From 19d7beef430df2d22adf2e538a5e3ba2626cebb2 Mon Sep 17 00:00:00 2001 From: Dev Khant Date: Fri, 11 Apr 2025 13:37:18 +0530 Subject: [PATCH] Add support for Langchain VectorStores (#2518) --- docs/components/vectordbs/dbs/langchain.mdx | 85 +++++++++ docs/components/vectordbs/overview.mdx | 1 + docs/docs.json | 3 +- mem0/configs/vector_stores/langchain.py | 30 ++++ mem0/utils/factory.py | 1 + mem0/vector_stores/configs.py | 1 + mem0/vector_stores/langchain.py | 161 ++++++++++++++++++ tests/vector_stores/test_faiss.py | 8 +- .../test_langchain_vector_store.py | 101 +++++++++++ 9 files changed, 386 insertions(+), 5 deletions(-) create mode 100644 docs/components/vectordbs/dbs/langchain.mdx create mode 100644 mem0/configs/vector_stores/langchain.py create mode 100644 mem0/vector_stores/langchain.py create mode 100644 tests/vector_stores/test_langchain_vector_store.py diff --git a/docs/components/vectordbs/dbs/langchain.mdx b/docs/components/vectordbs/dbs/langchain.mdx new file mode 100644 index 00000000..8c2a2f1a --- /dev/null +++ b/docs/components/vectordbs/dbs/langchain.mdx @@ -0,0 +1,85 @@ +--- +title: LangChain +--- + +Mem0 supports LangChain as a provider for vector store integration. LangChain provides a unified interface to various vector databases, making it easy to integrate different vector store providers through a consistent API. + + + When using LangChain as your vector store provider, you must set the collection name to "mem0". This is a required configuration for proper integration with Mem0. 
+ + +## Usage + + +```python Python +import os +from mem0 import Memory +from langchain_community.vectorstores import Chroma +from langchain_openai import OpenAIEmbeddings + +# Initialize a LangChain vector store +embeddings = OpenAIEmbeddings() +vector_store = Chroma( + persist_directory="./chroma_db", + embedding_function=embeddings, + collection_name="mem0" # Required collection name +) + +# Pass the initialized vector store to the config +config = { + "vector_store": { + "provider": "langchain", + "config": { + "client": vector_store + } + } +} + +m = Memory.from_config(config) +messages = [ + {"role": "user", "content": "I'm planning to watch a movie tonight. Any recommendations?"}, + {"role": "assistant", "content": "How about thriller movies? They can be quite engaging."}, + {"role": "user", "content": "I'm not a big fan of thriller movies but I love sci-fi movies."}, + {"role": "assistant", "content": "Got it! I'll avoid thriller recommendations and suggest sci-fi movies in the future."} +] +m.add(messages, user_id="alice", metadata={"category": "movies"}) +``` + + +## Supported LangChain Vector Stores + +LangChain supports a wide range of vector store providers, including: + +- Chroma +- FAISS +- Pinecone +- Weaviate +- Milvus +- Qdrant +- And many more + +You can use any of these vector store instances directly in your configuration. For a complete and up-to-date list of available providers, refer to the [LangChain Vector Stores documentation](https://python.langchain.com/docs/integrations/vectorstores). + +## Limitations + +When using LangChain as a vector store provider, there are some limitations to be aware of: + +1. **Bulk Operations**: The `get_all` and `delete_all` operations are not supported when using LangChain as the vector store provider. This is because LangChain's vector store interface doesn't provide standardized methods for these bulk operations across all providers. + +2. 
class LangchainConfig(BaseModel):
    """Configuration for the ``langchain`` vector store provider.

    Holds an already-initialized LangChain ``VectorStore`` instance together
    with the collection name to use for it.

    Attributes:
        client: Existing ``VectorStore`` instance (required).
        collection_name: Name of the collection to use; defaults to ``"mem0"``.

    Raises:
        ImportError: While the class body executes, if ``langchain_community``
            cannot be imported.
    """

    # The import lives in the class body; a missing package surfaces as an
    # ImportError carrying an install hint, chained to the original failure.
    try:
        from langchain_community.vectorstores import VectorStore
    except ImportError as exc:
        raise ImportError(
            "The 'langchain_community' library is required. Please install it using 'pip install langchain_community'."
        ) from exc
    # Annotated as ClassVar so pydantic does not treat the imported class as a model field.
    VectorStore: ClassVar[type] = VectorStore

    client: VectorStore = Field(description="Existing VectorStore instance")
    collection_name: str = Field("mem0", description="Name of the collection to use")

    @model_validator(mode="before")
    @classmethod
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Reject input keys that are not declared fields of this model.

        Args:
            values: Raw input mapping handed to the model before field validation.

        Returns:
            The unmodified ``values`` mapping when every key is a known field.

        Raises:
            ValueError: If ``values`` contains keys outside the model's fields.
        """
        allowed_fields = set(cls.model_fields)
        extra_fields = values.keys() - allowed_fields
        if extra_fields:
            # Sets iterate in arbitrary order; sort so the message is stable across runs.
            raise ValueError(
                f"Extra fields not allowed: {', '.join(sorted(extra_fields))}. "
                f"Please input only the following fields: {', '.join(sorted(allowed_fields))}"
            )
        return values

    # The client is an arbitrary (non-pydantic) class instance.
    model_config = {
        "arbitrary_types_allowed": True,
    }
class OutputData(BaseModel):
    """Normalized record returned by the read operations of ``Langchain``."""

    id: Optional[str]  # memory id
    score: Optional[float]  # distance
    payload: Optional[Dict]  # metadata


class Langchain(VectorStoreBase):
    """Adapter exposing an existing LangChain ``VectorStore`` through ``VectorStoreBase``.

    The wrapped client is supplied by the caller already initialized; this
    class only forwards calls to it and converts results to ``OutputData``.
    """

    def __init__(self, client: VectorStore, collection_name: str = "mem0"):
        """Store the client and collection name.

        Args:
            client: Initialized LangChain vector store to delegate to.
            collection_name: Name reported by ``list_cols`` and ``col_info``.
        """
        self.client = client
        self.collection_name = collection_name

    def _parse_output(self, data) -> List[OutputData]:
        """Convert raw client output into ``OutputData`` entries.

        Args:
            data: Either a list of document-like objects (their ``id`` and
                ``metadata`` attributes are read, with ``None`` / ``{}`` as
                fallbacks), or a dict with optional ``"ids"``, ``"distances"``
                and ``"metadatas"`` lists. A list whose first element is itself
                a list is unwrapped one level.

        Returns:
            One ``OutputData`` per document, or per index of the longest dict
            column; shorter or missing columns contribute ``None``.
        """
        # Function-scope import keeps the module's import block untouched.
        from itertools import zip_longest

        if isinstance(data, list):
            # Documents returned by the similarity search carry no score here.
            return [
                OutputData(
                    id=getattr(doc, "id", None),
                    score=None,
                    payload=getattr(doc, "metadata", {}),
                )
                for doc in data
            ]

        columns = []
        for key in ("ids", "distances", "metadatas"):
            column = data.get(key)
            if not isinstance(column, list):
                # Absent, None or non-list columns count as empty instead of
                # making the length computation fail.
                column = []
            elif column and isinstance(column[0], list):
                column = column[0]
            columns.append(column)

        # zip_longest pads the shorter columns with None.
        return [
            OutputData(id=item_id, score=distance, payload=metadata)
            for item_id, distance, metadata in zip_longest(*columns)
        ]

    def create_col(self, name, vector_size=None, distance=None):
        """Record ``name`` as the collection name and return the wrapped client.

        ``vector_size`` and ``distance`` are accepted for interface
        compatibility and are not used.
        """
        self.collection_name = name
        return self.client

    def insert(self, vectors: List[List[float]], payloads: Optional[List[Dict]] = None, ids: Optional[List[str]] = None):
        """Insert vectors into the LangChain vectorstore.

        Args:
            vectors: One embedding per record.
            payloads: Optional metadata dicts, parallel to ``vectors``.
            ids: Optional record ids, parallel to ``vectors``.
        """
        if hasattr(self.client, "add_embeddings"):
            # NOTE(review): keyword names of add_embeddings differ between
            # LangChain stores - confirm against the stores you target.
            self.client.add_embeddings(embeddings=vectors, metadatas=payloads, ids=ids)
            return

        # Fallback: the text comes from each payload's "data" key (one empty
        # string per vector without payloads); the precomputed vectors are not
        # forwarded on this path.
        texts = [payload.get("data", "") for payload in payloads] if payloads else [""] * len(vectors)
        self.client.add_texts(texts=texts, metadatas=payloads, ids=ids)

    def search(self, query: str, vectors: List[List[float]], limit: int = 5, filters: Optional[Dict] = None):
        """Search for similar vectors in LangChain.

        Args:
            query: Unused; kept for interface compatibility.
            vectors: Forwarded unchanged as the client's ``embedding`` argument.
            limit: Maximum number of results (the client's ``k``).
            filters: Optional filter, forwarded as ``filter`` only when non-empty.

        Returns:
            List[OutputData]: Parsed results; ``score`` is always ``None``.
        """
        search_kwargs = {"embedding": vectors, "k": limit}
        if filters:
            search_kwargs["filter"] = filters
        return self._parse_output(self.client.similarity_search_by_vector(**search_kwargs))

    def delete(self, vector_id):
        """Delete a vector by ID."""
        self.client.delete(ids=[vector_id])

    def update(self, vector_id, vector=None, payload=None):
        """Replace a record by deleting it and inserting it again.

        Args:
            vector_id: Id of the record to replace.
            vector: New embedding for the record (a single vector).
            payload: New metadata dict for the record.
        """
        if vector is None and payload is None:
            # Nothing to write back; return before the delete so the record is not lost.
            return
        self.delete(vector_id)
        # insert() takes parallel lists, so the single record is wrapped.
        # NOTE(review): with vector=None an empty vector list is passed; clients
        # going through add_embeddings may reject that - confirm with callers.
        self.insert(
            vectors=[vector] if vector is not None else [],
            payloads=[payload] if payload is not None else None,
            ids=[vector_id],
        )

    def get(self, vector_id):
        """Retrieve a vector by ID.

        Returns:
            The first matching record as ``OutputData``, or ``None`` when the
            client returns nothing for that id.
        """
        docs = self.client.get_by_ids([vector_id])
        if not docs:
            return None
        return self._parse_output([docs[0]])[0]

    def list_cols(self):
        """Return a one-element list holding the configured collection name."""
        # LangChain doesn't have collections
        return [self.collection_name]

    def delete_col(self):
        """Delete a collection by calling the client's ``delete`` with ``ids=None``."""
        # NOTE(review): how ids=None is treated depends on the underlying
        # store implementation - confirm it clears the store you use.
        self.client.delete(ids=None)

    def col_info(self):
        """Return a dict with the collection name under the ``"name"`` key."""
        return {"name": self.collection_name}

    def list(self, filters=None, limit=None):
        """List all vectors in a collection.

        Raises:
            NotImplementedError: Always; both arguments are ignored.
        """
        # This would require implementation-specific access to the underlying store
        raise NotImplementedError("Listing all vectors not directly supported by LangChain vectorstores")
@pytest.fixture
def langchain_instance(mock_langchain_client):
    """Build a ``Langchain`` store around a ``Mock`` limited to the ``VectorStore`` spec."""
    client_double = Mock(spec=VectorStore)
    return Langchain(client=client_double, collection_name="test_collection")


def test_insert_vectors(langchain_instance):
    """``insert`` uses ``add_embeddings`` when present and ``add_texts`` otherwise."""
    client = langchain_instance.client
    embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
    metadata = [{"data": "text1", "name": "vector1"}, {"data": "text2", "name": "vector2"}]
    record_ids = ["id1", "id2"]

    # Path 1: the client offers add_embeddings.
    client.add_embeddings = Mock()
    langchain_instance.insert(vectors=embeddings, payloads=metadata, ids=record_ids)
    client.add_embeddings.assert_called_once_with(embeddings=embeddings, metadatas=metadata, ids=record_ids)

    # Path 2: with add_embeddings removed, texts come from each payload's "data" key.
    del client.add_embeddings
    client.add_texts = Mock()
    langchain_instance.insert(vectors=embeddings, payloads=metadata, ids=record_ids)
    client.add_texts.assert_called_once_with(texts=["text1", "text2"], metadatas=metadata, ids=record_ids)

    # Path 3: no payloads, so one empty text per vector.
    client.add_texts.reset_mock()
    langchain_instance.insert(vectors=embeddings, payloads=None, ids=record_ids)
    client.add_texts.assert_called_once_with(texts=["", ""], metadatas=None, ids=record_ids)


def test_search_vectors(langchain_instance):
    """``search`` forwards its arguments and parses the returned documents."""
    client = langchain_instance.client
    client.similarity_search_by_vector.return_value = [
        Mock(metadata={"name": "vector1"}, id="id1"),
        Mock(metadata={"name": "vector2"}, id="id2"),
    ]
    query_vectors = [[0.1, 0.2, 0.3]]

    # Without filters no ``filter`` keyword reaches the client.
    found = langchain_instance.search(query="", vectors=query_vectors, limit=2)
    client.similarity_search_by_vector.assert_called_once_with(embedding=query_vectors, k=2)

    assert len(found) == 2
    assert [(hit.id, hit.payload) for hit in found] == [
        ("id1", {"name": "vector1"}),
        ("id2", {"name": "vector2"}),
    ]

    # With filters they are passed through as ``filter``.
    name_filter = {"name": "vector1"}
    langchain_instance.search(query="", vectors=query_vectors, limit=2, filters=name_filter)
    client.similarity_search_by_vector.assert_called_with(embedding=query_vectors, k=2, filter=name_filter)


def test_get_vector(langchain_instance):
    """``get`` returns the parsed first document, or ``None`` when nothing is found."""
    client = langchain_instance.client
    client.get_by_ids.return_value = [Mock(metadata={"name": "vector1"}, id="id1")]

    # Existing id.
    fetched = langchain_instance.get("id1")
    client.get_by_ids.assert_called_once_with(["id1"])

    assert fetched is not None
    assert fetched.id == "id1"
    assert fetched.payload == {"name": "vector1"}

    # Unknown id: the client reports no documents.
    client.get_by_ids.return_value = []
    assert langchain_instance.get("non_existent_id") is None