Add support for Langchain VectorStores (#2518)

This commit is contained in:
Dev Khant
2025-04-11 13:37:18 +05:30
committed by GitHub
parent 8b789adb15
commit 19d7beef43
9 changed files with 386 additions and 5 deletions

View File

@@ -0,0 +1,85 @@
---
title: LangChain
---
Mem0 supports LangChain as a vector store provider. LangChain exposes a unified interface over many vector databases, so different backends can be plugged into Mem0 through a single, consistent API.
<Note>
When using LangChain as your vector store provider, you must set the collection name to "mem0". This is a required configuration for proper integration with Mem0.
</Note>
## Usage
<CodeGroup>
```python Python
import os
from mem0 import Memory
from langchain_community.vectorstores import Chroma
from langchain_openai import OpenAIEmbeddings

# Initialize a LangChain vector store
embeddings = OpenAIEmbeddings()
vector_store = Chroma(
    persist_directory="./chroma_db",
    embedding_function=embeddings,
    collection_name="mem0"  # Required collection name
)

# Pass the initialized vector store to the config
config = {
    "vector_store": {
        "provider": "langchain",
        "config": {
            "client": vector_store
        }
    }
}

m = Memory.from_config(config)

messages = [
    {"role": "user", "content": "I'm planning to watch a movie tonight. Any recommendations?"},
    {"role": "assistant", "content": "How about a thriller movie? They can be quite engaging."},
    {"role": "user", "content": "I'm not a big fan of thriller movies but I love sci-fi movies."},
    {"role": "assistant", "content": "Got it! I'll avoid thriller recommendations and suggest sci-fi movies in the future."}
]

m.add(messages, user_id="alice", metadata={"category": "movies"})
```
</CodeGroup>
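Once memories are added, retrieval goes through the usual Mem0 API; the LangChain store is only swapped in underneath. A short sketch continuing the example above (the query string is illustrative):
```python
# Retrieve memories relevant to a query for the same user.
related = m.search("What kind of movies does Alice like?", user_id="alice")
print(related)
```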
## Supported LangChain Vector Stores
LangChain supports a wide range of vector store providers, including:
- Chroma
- FAISS
- Pinecone
- Weaviate
- Milvus
- Qdrant
- And many more
You can use any of these vector store instances directly in your configuration. For a complete and up-to-date list of available providers, refer to the [LangChain Vector Stores documentation](https://python.langchain.com/docs/integrations/vectorstores).
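As an illustration of swapping backends, the sketch below passes an in-memory FAISS store instead of Chroma. It assumes the `faiss-cpu`, `langchain-community`, and `langchain-openai` packages are installed and an embedding dimension of 1536 (OpenAI's `text-embedding-ada-002`); treat it as a sketch, not a tested recipe.
```python
import faiss
from langchain_community.docstore.in_memory import InMemoryDocstore
from langchain_community.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings

embeddings = OpenAIEmbeddings()

# Build an empty FAISS store; 1536 is the assumed embedding dimension.
vector_store = FAISS(
    embedding_function=embeddings,
    index=faiss.IndexFlatL2(1536),
    docstore=InMemoryDocstore(),
    index_to_docstore_id={},
)

config = {
    "vector_store": {
        "provider": "langchain",
        "config": {"client": vector_store},
    }
}
```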
## Limitations
When using LangChain as a vector store provider, there are some limitations to be aware of:
1. **Bulk Operations**: The `get_all` and `delete_all` operations are not supported when using LangChain as the vector store provider, because LangChain's vector store interface doesn't expose standardized bulk methods across all providers (see the sketch after this list).
2. **Provider-Specific Features**: Some advanced features may not be available depending on the specific vector store implementation you're using through LangChain.
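If your application relies on bulk reads, guard them explicitly. A minimal sketch, assuming the store's `NotImplementedError` propagates unchanged through `Memory.get_all`:
```python
# Hypothetical guard: bulk reads are not supported with the LangChain provider.
try:
    all_memories = m.get_all(user_id="alice")
except NotImplementedError:
    # Fall back to targeted operations such as m.search(...) or m.get(memory_id).
    all_memories = []
```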
## Provider-Specific Configuration
When using LangChain as a vector store provider, you'll need to:
1. Set the appropriate environment variables for your chosen vector store provider
2. Import and initialize the specific vector store class you want to use
3. Pass the initialized vector store instance to the config (a hedged sketch follows the note below)
<Note>
Make sure to install the necessary LangChain packages and any provider-specific dependencies.
</Note>
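Putting the three steps together, here is a hedged sketch using Pinecone via the `langchain-pinecone` package. The index name, environment variables, and key values are placeholders, and the sketch assumes the `mem0` index already exists in your Pinecone project.
```python
import os

from langchain_openai import OpenAIEmbeddings
from langchain_pinecone import PineconeVectorStore

from mem0 import Memory

# 1. Provider-specific environment variables (placeholder values)
os.environ["PINECONE_API_KEY"] = "your-pinecone-api-key"
os.environ["OPENAI_API_KEY"] = "your-openai-api-key"

# 2. Initialize the specific vector store class
embeddings = OpenAIEmbeddings()
vector_store = PineconeVectorStore(index_name="mem0", embedding=embeddings)

# 3. Pass the initialized instance to the Mem0 config
config = {
    "vector_store": {
        "provider": "langchain",
        "config": {"client": vector_store},
    }
}
m = Memory.from_config(config)
```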
## Config
All available parameters for the `langchain` vector store config are present in [Master List of All Params in Config](../config).

View File

@@ -29,6 +29,7 @@ See the list of supported vector databases below.
<Card title="Vertex AI" href="/components/vectordbs/dbs/vertex_ai"></Card>
<Card title="Weaviate" href="/components/vectordbs/dbs/weaviate"></Card>
<Card title="FAISS" href="/components/vectordbs/dbs/faiss"></Card>
<Card title="LangChain" href="/components/vectordbs/dbs/langchain"></Card>
</CardGroup>
## Usage

View File

@@ -140,7 +140,8 @@
"components/vectordbs/dbs/supabase",
"components/vectordbs/dbs/vertex_ai",
"components/vectordbs/dbs/weaviate",
"components/vectordbs/dbs/faiss"
"components/vectordbs/dbs/faiss",
"components/vectordbs/dbs/langchain"
]
}
]

View File

@@ -0,0 +1,30 @@
from typing import Any, ClassVar, Dict

from pydantic import BaseModel, Field, model_validator


class LangchainConfig(BaseModel):
    try:
        from langchain_community.vectorstores import VectorStore
    except ImportError:
        raise ImportError("The 'langchain_community' library is required. Please install it using 'pip install langchain_community'.")

    VectorStore: ClassVar[type] = VectorStore

    client: VectorStore = Field(description="Existing VectorStore instance")
    collection_name: str = Field("mem0", description="Name of the collection to use")

    @model_validator(mode="before")
    @classmethod
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        allowed_fields = set(cls.model_fields.keys())
        input_fields = set(values.keys())
        extra_fields = input_fields - allowed_fields
        if extra_fields:
            raise ValueError(
                f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}"
            )
        return values

    model_config = {
        "arbitrary_types_allowed": True,
    }

View File

@@ -84,6 +84,7 @@ class VectorStoreFactory:
"supabase": "mem0.vector_stores.supabase.Supabase",
"weaviate": "mem0.vector_stores.weaviate.Weaviate",
"faiss": "mem0.vector_stores.faiss.FAISS",
"langchain": "mem0.vector_stores.langchain.Langchain",
}
@classmethod

View File

@@ -25,6 +25,7 @@ class VectorStoreConfig(BaseModel):
"supabase": "SupabaseConfig",
"weaviate": "WeaviateConfig",
"faiss": "FAISSConfig",
"langchain": "LangchainConfig",
}
@model_validator(mode="after")

View File

@@ -0,0 +1,161 @@
from typing import Dict, List, Optional

from pydantic import BaseModel

try:
    from langchain_community.vectorstores import VectorStore
except ImportError:
    raise ImportError("The 'langchain_community' library is required. Please install it using 'pip install langchain_community'.")

from mem0.vector_stores.base import VectorStoreBase


class OutputData(BaseModel):
    id: Optional[str]  # memory id
    score: Optional[float]  # distance
    payload: Optional[Dict]  # metadata


class Langchain(VectorStoreBase):
    def __init__(self, client: VectorStore, collection_name: str = "mem0"):
        self.client = client
        self.collection_name = collection_name

    def _parse_output(self, data: Dict) -> List[OutputData]:
        """
        Parse the output data.

        Args:
            data (Dict): Output data or list of Document objects.

        Returns:
            List[OutputData]: Parsed output data.
        """
        # Check if input is a list of Document objects
        if isinstance(data, list) and all(hasattr(doc, 'metadata') for doc in data if hasattr(doc, '__dict__')):
            result = []
            for doc in data:
                entry = OutputData(
                    id=getattr(doc, "id", None),
                    score=None,  # Document objects typically don't include scores
                    payload=getattr(doc, "metadata", {})
                )
                result.append(entry)
            return result

        # Original format handling
        keys = ["ids", "distances", "metadatas"]
        values = []
        for key in keys:
            value = data.get(key, [])
            if isinstance(value, list) and value and isinstance(value[0], list):
                value = value[0]
            values.append(value)

        ids, distances, metadatas = values
        max_length = max(len(v) for v in values if isinstance(v, list) and v is not None)

        result = []
        for i in range(max_length):
            entry = OutputData(
                id=ids[i] if isinstance(ids, list) and ids and i < len(ids) else None,
                score=(distances[i] if isinstance(distances, list) and distances and i < len(distances) else None),
                payload=(metadatas[i] if isinstance(metadatas, list) and metadatas and i < len(metadatas) else None),
            )
            result.append(entry)

        return result

    def create_col(self, name, vector_size=None, distance=None):
        self.collection_name = name
        return self.client

    def insert(self, vectors: List[List[float]], payloads: Optional[List[Dict]] = None, ids: Optional[List[str]] = None):
        """
        Insert vectors into the LangChain vectorstore.
        """
        # Check if client has add_embeddings method
        if hasattr(self.client, "add_embeddings"):
            # Some LangChain vectorstores have a direct add_embeddings method
            self.client.add_embeddings(
                embeddings=vectors,
                metadatas=payloads,
                ids=ids
            )
        else:
            # Fallback to add_texts method
            texts = [payload.get("data", "") for payload in payloads] if payloads else [""] * len(vectors)
            self.client.add_texts(
                texts=texts,
                metadatas=payloads,
                ids=ids
            )

    def search(self, query: str, vectors: List[List[float]], limit: int = 5, filters: Optional[Dict] = None):
        """
        Search for similar vectors in LangChain.
        """
        # For each vector, perform a similarity search
        if filters:
            results = self.client.similarity_search_by_vector(
                embedding=vectors,
                k=limit,
                filter=filters
            )
        else:
            results = self.client.similarity_search_by_vector(
                embedding=vectors,
                k=limit
            )

        final_results = self._parse_output(results)
        return final_results

    def delete(self, vector_id):
        """
        Delete a vector by ID.
        """
        self.client.delete(ids=[vector_id])

    def update(self, vector_id, vector=None, payload=None):
        """
        Update a vector and its payload.
        """
        self.delete(vector_id)
        self.insert(vector, payload, [vector_id])

    def get(self, vector_id):
        """
        Retrieve a vector by ID.
        """
        docs = self.client.get_by_ids([vector_id])
        if docs and len(docs) > 0:
            doc = docs[0]
            return self._parse_output([doc])[0]
        return None

    def list_cols(self):
        """
        List all collections.
        """
        # LangChain doesn't have collections
        return [self.collection_name]

    def delete_col(self):
        """
        Delete a collection.
        """
        self.client.delete(ids=None)

    def col_info(self):
        """
        Get information about a collection.
        """
        return {"name": self.collection_name}

    def list(self, filters=None, limit=None):
        """
        List all vectors in a collection.
        """
        # This would require implementation-specific access to the underlying store
        raise NotImplementedError("Listing all vectors not directly supported by LangChain vectorstores")

View File

@@ -39,14 +39,14 @@ def test_create_col(faiss_instance, mock_faiss_index):
    # Test creating a collection with euclidean distance
    with patch('faiss.IndexFlatL2', return_value=mock_faiss_index) as mock_index_flat_l2:
        with patch('faiss.write_index'):
            faiss_instance.create_col(name="new_collection", vector_size=256)
            mock_index_flat_l2.assert_called_once_with(256)
            faiss_instance.create_col(name="new_collection")
            mock_index_flat_l2.assert_called_once_with(faiss_instance.embedding_model_dims)

    # Test creating a collection with inner product distance
    with patch('faiss.IndexFlatIP', return_value=mock_faiss_index) as mock_index_flat_ip:
        with patch('faiss.write_index'):
            faiss_instance.create_col(name="new_collection", vector_size=256, distance="inner_product")
            mock_index_flat_ip.assert_called_once_with(256)
            faiss_instance.create_col(name="new_collection", distance="inner_product")
            mock_index_flat_ip.assert_called_once_with(faiss_instance.embedding_model_dims)
def test_insert(faiss_instance, mock_faiss_index):

View File

@@ -0,0 +1,101 @@
from unittest.mock import Mock, patch

import pytest
from langchain_community.vectorstores import VectorStore

from mem0.vector_stores.langchain import Langchain


@pytest.fixture
def mock_langchain_client():
    with patch("langchain_community.vectorstores.VectorStore") as mock_client:
        yield mock_client


@pytest.fixture
def langchain_instance(mock_langchain_client):
    mock_client = Mock(spec=VectorStore)
    return Langchain(client=mock_client, collection_name="test_collection")


def test_insert_vectors(langchain_instance):
    # Test data
    vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
    payloads = [{"data": "text1", "name": "vector1"}, {"data": "text2", "name": "vector2"}]
    ids = ["id1", "id2"]

    # Test with add_embeddings method
    langchain_instance.client.add_embeddings = Mock()
    langchain_instance.insert(vectors=vectors, payloads=payloads, ids=ids)
    langchain_instance.client.add_embeddings.assert_called_once_with(
        embeddings=vectors,
        metadatas=payloads,
        ids=ids
    )

    # Test with add_texts method
    delattr(langchain_instance.client, "add_embeddings")  # Remove attribute completely
    langchain_instance.client.add_texts = Mock()
    langchain_instance.insert(vectors=vectors, payloads=payloads, ids=ids)
    langchain_instance.client.add_texts.assert_called_once_with(
        texts=["text1", "text2"],
        metadatas=payloads,
        ids=ids
    )

    # Test with empty payloads
    langchain_instance.client.add_texts.reset_mock()
    langchain_instance.insert(vectors=vectors, payloads=None, ids=ids)
    langchain_instance.client.add_texts.assert_called_once_with(
        texts=["", ""],
        metadatas=None,
        ids=ids
    )


def test_search_vectors(langchain_instance):
    # Mock search results
    mock_docs = [
        Mock(metadata={"name": "vector1"}, id="id1"),
        Mock(metadata={"name": "vector2"}, id="id2")
    ]
    langchain_instance.client.similarity_search_by_vector.return_value = mock_docs

    # Test search without filters
    vectors = [[0.1, 0.2, 0.3]]
    results = langchain_instance.search(query="", vectors=vectors, limit=2)
    langchain_instance.client.similarity_search_by_vector.assert_called_once_with(
        embedding=vectors,
        k=2
    )
    assert len(results) == 2
    assert results[0].id == "id1"
    assert results[0].payload == {"name": "vector1"}
    assert results[1].id == "id2"
    assert results[1].payload == {"name": "vector2"}

    # Test search with filters
    filters = {"name": "vector1"}
    langchain_instance.search(query="", vectors=vectors, limit=2, filters=filters)
    langchain_instance.client.similarity_search_by_vector.assert_called_with(
        embedding=vectors,
        k=2,
        filter=filters
    )


def test_get_vector(langchain_instance):
    # Mock get result
    mock_doc = Mock(metadata={"name": "vector1"}, id="id1")
    langchain_instance.client.get_by_ids.return_value = [mock_doc]

    # Test get existing vector
    result = langchain_instance.get("id1")
    langchain_instance.client.get_by_ids.assert_called_once_with(["id1"])
    assert result is not None
    assert result.id == "id1"
    assert result.payload == {"name": "vector1"}

    # Test get non-existent vector
    langchain_instance.client.get_by_ids.return_value = []
    result = langchain_instance.get("non_existent_id")
    assert result is None