Add support for Langchain VectorStores (#2518)

docs/components/vectordbs/dbs/langchain.mdx (new file, 85 lines)
@@ -0,0 +1,85 @@
---
title: LangChain
---

Mem0 supports LangChain as a vector store provider. LangChain offers a unified interface to many vector databases, so you can integrate different vector store backends through a single, consistent API.

<Note>
When using LangChain as your vector store provider, you must set the collection name to "mem0". This is required for proper integration with Mem0.
</Note>

## Usage

<CodeGroup>
```python Python
import os

from mem0 import Memory
from langchain_community.vectorstores import Chroma
from langchain_openai import OpenAIEmbeddings

# Initialize a LangChain vector store
embeddings = OpenAIEmbeddings()
vector_store = Chroma(
    persist_directory="./chroma_db",
    embedding_function=embeddings,
    collection_name="mem0"  # Required collection name
)

# Pass the initialized vector store to the config
config = {
    "vector_store": {
        "provider": "langchain",
        "config": {
            "client": vector_store
        }
    }
}

m = Memory.from_config(config)

messages = [
    {"role": "user", "content": "I'm planning to watch a movie tonight. Any recommendations?"},
    {"role": "assistant", "content": "How about a thriller movie? They can be quite engaging."},
    {"role": "user", "content": "I'm not a big fan of thriller movies but I love sci-fi movies."},
    {"role": "assistant", "content": "Got it! I'll avoid thriller recommendations and suggest sci-fi movies in the future."}
]
m.add(messages, user_id="alice", metadata={"category": "movies"})
```
</CodeGroup>

## Supported LangChain Vector Stores

LangChain supports a wide range of vector store providers, including:

- Chroma
- FAISS
- Pinecone
- Weaviate
- Milvus
- Qdrant
- And many more

You can use any of these vector store instances directly in your configuration. For a complete and up-to-date list of available providers, refer to the [LangChain Vector Stores documentation](https://python.langchain.com/docs/integrations/vectorstores).
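
As an illustration, the following is a minimal sketch that swaps the Chroma store from the usage example for a FAISS store built through LangChain. It assumes `faiss-cpu` and `langchain-community` are installed and uses a 1536-dimension index to match OpenAI's default embedding size; adapt the initialization to your chosen provider.

<CodeGroup>
```python Python
import faiss
from langchain_community.docstore.in_memory import InMemoryDocstore
from langchain_community.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings

from mem0 import Memory

embeddings = OpenAIEmbeddings()

# Wrap an empty FAISS index in LangChain's FAISS vector store
vector_store = FAISS(
    embedding_function=embeddings,
    index=faiss.IndexFlatL2(1536),  # assumes 1536-dimensional embeddings
    docstore=InMemoryDocstore(),
    index_to_docstore_id={},
)

config = {
    "vector_store": {
        "provider": "langchain",
        "config": {
            "client": vector_store
        }
    }
}

m = Memory.from_config(config)
```
</CodeGroup>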

## Limitations

When using LangChain as a vector store provider, there are some limitations to be aware of:

1. **Bulk Operations**: The `get_all` and `delete_all` operations are not supported when using LangChain as the vector store provider, because LangChain's vector store interface doesn't provide standardized methods for these bulk operations across all providers.

2. **Provider-Specific Features**: Some advanced features may not be available, depending on the specific vector store implementation you're using through LangChain.
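
For example, with the LangChain provider a call that relies on bulk listing will surface a `NotImplementedError` from the underlying vector store wrapper. A minimal sketch, assuming `m` is the `Memory` instance from the usage example above:

<CodeGroup>
```python Python
# get_all depends on listing every vector in the store, which the
# LangChain wrapper does not support, so guard the call accordingly.
try:
    all_memories = m.get_all(user_id="alice")
except NotImplementedError as e:
    print(f"Bulk operations are not available with the LangChain provider: {e}")
```
</CodeGroup>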

## Provider-Specific Configuration

When using LangChain as a vector store provider, you'll need to:

1. Set the appropriate environment variables for your chosen vector store provider
2. Import and initialize the specific vector store class you want to use
3. Pass the initialized vector store instance to the config

<Note>
Make sure to install the necessary LangChain packages and any provider-specific dependencies.
</Note>
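
Putting these steps together, a compact sketch (the environment variable names depend on your embedder and vector store provider; `OPENAI_API_KEY` is shown here because the example uses OpenAI embeddings):

<CodeGroup>
```python Python
import os

# 1. Set the environment variables your provider and embedder need
os.environ["OPENAI_API_KEY"] = "your-api-key"  # placeholder

# 2. Import and initialize the specific vector store class
from langchain_community.vectorstores import Chroma
from langchain_openai import OpenAIEmbeddings

vector_store = Chroma(
    collection_name="mem0",  # required collection name
    embedding_function=OpenAIEmbeddings(),
    persist_directory="./chroma_db",
)

# 3. Pass the initialized vector store instance to the config
config = {
    "vector_store": {
        "provider": "langchain",
        "config": {
            "client": vector_store
        }
    }
}
```
</CodeGroup>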

## Config

All available parameters for the `langchain` vector store config are present in the [Master List of All Params in Config](../config).
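
For the LangChain provider specifically, the config accepts a `client` (an initialized LangChain `VectorStore` instance, required) and an optional `collection_name`, which defaults to `"mem0"`:

<CodeGroup>
```python Python
config = {
    "vector_store": {
        "provider": "langchain",
        "config": {
            "client": vector_store,       # an initialized LangChain VectorStore instance
            "collection_name": "mem0",    # optional; defaults to "mem0"
        }
    }
}
```
</CodeGroup>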

@@ -29,6 +29,7 @@ See the list of supported vector databases below.
   <Card title="Vertex AI" href="/components/vectordbs/dbs/vertex_ai"></Card>
   <Card title="Weaviate" href="/components/vectordbs/dbs/weaviate"></Card>
   <Card title="FAISS" href="/components/vectordbs/dbs/faiss"></Card>
+  <Card title="LangChain" href="/components/vectordbs/dbs/langchain"></Card>
 </CardGroup>

 ## Usage

@@ -140,7 +140,8 @@
       "components/vectordbs/dbs/supabase",
       "components/vectordbs/dbs/vertex_ai",
       "components/vectordbs/dbs/weaviate",
-      "components/vectordbs/dbs/faiss"
+      "components/vectordbs/dbs/faiss",
+      "components/vectordbs/dbs/langchain"
     ]
   }
 ]

mem0/configs/vector_stores/langchain.py (new file, 30 lines)
@@ -0,0 +1,30 @@
from typing import Any, ClassVar, Dict

from pydantic import BaseModel, Field, model_validator


class LangchainConfig(BaseModel):
    # VectorStore is imported inside the class body so a missing dependency fails with a
    # helpful message, and it is kept as a ClassVar so pydantic does not treat it as a field.
    try:
        from langchain_community.vectorstores import VectorStore
    except ImportError:
        raise ImportError(
            "The 'langchain_community' library is required. Please install it using 'pip install langchain_community'."
        )
    VectorStore: ClassVar[type] = VectorStore

    client: VectorStore = Field(description="Existing VectorStore instance")
    collection_name: str = Field("mem0", description="Name of the collection to use")

    @model_validator(mode="before")
    @classmethod
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        allowed_fields = set(cls.model_fields.keys())
        input_fields = set(values.keys())
        extra_fields = input_fields - allowed_fields
        if extra_fields:
            raise ValueError(
                f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}"
            )
        return values

    model_config = {
        "arbitrary_types_allowed": True,
    }
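
For reference, a minimal sketch of how this config class behaves, assuming `vector_store` is an initialized LangChain `VectorStore` instance: unknown keys are rejected by the `validate_extra_fields` validator, and `collection_name` falls back to its default.

```python
from mem0.configs.vector_stores.langchain import LangchainConfig

cfg = LangchainConfig(client=vector_store)
print(cfg.collection_name)  # -> "mem0" (default)

# An unknown key is rejected; pydantic surfaces the
# "Extra fields not allowed: foo. ..." message raised by the validator.
LangchainConfig(client=vector_store, foo="bar")
```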

@@ -84,6 +84,7 @@ class VectorStoreFactory:
         "supabase": "mem0.vector_stores.supabase.Supabase",
         "weaviate": "mem0.vector_stores.weaviate.Weaviate",
         "faiss": "mem0.vector_stores.faiss.FAISS",
+        "langchain": "mem0.vector_stores.langchain.Langchain",
     }

     @classmethod

@@ -25,6 +25,7 @@ class VectorStoreConfig(BaseModel):
         "supabase": "SupabaseConfig",
         "weaviate": "WeaviateConfig",
         "faiss": "FAISSConfig",
+        "langchain": "LangchainConfig",
     }

     @model_validator(mode="after")

mem0/vector_stores/langchain.py (new file, 161 lines)
@@ -0,0 +1,161 @@
from typing import Dict, List, Optional

from pydantic import BaseModel

try:
    from langchain_community.vectorstores import VectorStore
except ImportError:
    raise ImportError("The 'langchain_community' library is required. Please install it using 'pip install langchain_community'.")

from mem0.vector_stores.base import VectorStoreBase


class OutputData(BaseModel):
    id: Optional[str]  # memory id
    score: Optional[float]  # distance
    payload: Optional[Dict]  # metadata


class Langchain(VectorStoreBase):
    def __init__(self, client: VectorStore, collection_name: str = "mem0"):
        self.client = client
        self.collection_name = collection_name

    def _parse_output(self, data: Dict) -> List[OutputData]:
        """
        Parse the output data.

        Args:
            data (Dict): Output data or list of Document objects.

        Returns:
            List[OutputData]: Parsed output data.
        """
        # Check if input is a list of Document objects
        if isinstance(data, list) and all(hasattr(doc, "metadata") for doc in data if hasattr(doc, "__dict__")):
            result = []
            for doc in data:
                entry = OutputData(
                    id=getattr(doc, "id", None),
                    score=None,  # Document objects typically don't include scores
                    payload=getattr(doc, "metadata", {})
                )
                result.append(entry)
            return result

        # Original format handling
        keys = ["ids", "distances", "metadatas"]
        values = []

        for key in keys:
            value = data.get(key, [])
            if isinstance(value, list) and value and isinstance(value[0], list):
                value = value[0]
            values.append(value)

        ids, distances, metadatas = values
        max_length = max(len(v) for v in values if isinstance(v, list) and v is not None)

        result = []
        for i in range(max_length):
            entry = OutputData(
                id=ids[i] if isinstance(ids, list) and ids and i < len(ids) else None,
                score=(distances[i] if isinstance(distances, list) and distances and i < len(distances) else None),
                payload=(metadatas[i] if isinstance(metadatas, list) and metadatas and i < len(metadatas) else None),
            )
            result.append(entry)

        return result

    def create_col(self, name, vector_size=None, distance=None):
        # The LangChain client is created by the caller, so only the collection name is recorded here
        self.collection_name = name
        return self.client

    def insert(self, vectors: List[List[float]], payloads: Optional[List[Dict]] = None, ids: Optional[List[str]] = None):
        """
        Insert vectors into the LangChain vectorstore.
        """
        # Check if client has add_embeddings method
        if hasattr(self.client, "add_embeddings"):
            # Some LangChain vectorstores have a direct add_embeddings method
            self.client.add_embeddings(
                embeddings=vectors,
                metadatas=payloads,
                ids=ids
            )
        else:
            # Fallback to add_texts method
            texts = [payload.get("data", "") for payload in payloads] if payloads else [""] * len(vectors)
            self.client.add_texts(
                texts=texts,
                metadatas=payloads,
                ids=ids
            )

    def search(self, query: str, vectors: List[List[float]], limit: int = 5, filters: Optional[Dict] = None):
        """
        Search for similar vectors in LangChain.
        """
        # Perform a similarity search with the query embedding(s)
        if filters:
            results = self.client.similarity_search_by_vector(
                embedding=vectors,
                k=limit,
                filter=filters
            )
        else:
            results = self.client.similarity_search_by_vector(
                embedding=vectors,
                k=limit
            )

        final_results = self._parse_output(results)
        return final_results

    def delete(self, vector_id):
        """
        Delete a vector by ID.
        """
        self.client.delete(ids=[vector_id])

    def update(self, vector_id, vector=None, payload=None):
        """
        Update a vector and its payload by deleting and re-inserting it.
        """
        self.delete(vector_id)
        self.insert([vector], [payload], [vector_id])

    def get(self, vector_id):
        """
        Retrieve a vector by ID.
        """
        docs = self.client.get_by_ids([vector_id])
        if docs and len(docs) > 0:
            doc = docs[0]
            return self._parse_output([doc])[0]
        return None

    def list_cols(self):
        """
        List all collections.
        """
        # LangChain doesn't have collections
        return [self.collection_name]

    def delete_col(self):
        """
        Delete a collection.
        """
        # Passing ids=None delegates to the underlying store; behavior is provider-specific
        self.client.delete(ids=None)

    def col_info(self):
        """
        Get information about a collection.
        """
        return {"name": self.collection_name}

    def list(self, filters=None, limit=None):
        """
        List all vectors in a collection.
        """
        # This would require implementation-specific access to the underlying store
        raise NotImplementedError("Listing all vectors not directly supported by LangChain vectorstores")

@@ -39,14 +39,14 @@ def test_create_col(faiss_instance, mock_faiss_index):
     # Test creating a collection with euclidean distance
     with patch('faiss.IndexFlatL2', return_value=mock_faiss_index) as mock_index_flat_l2:
         with patch('faiss.write_index'):
-            faiss_instance.create_col(name="new_collection", vector_size=256)
-            mock_index_flat_l2.assert_called_once_with(256)
+            faiss_instance.create_col(name="new_collection")
+            mock_index_flat_l2.assert_called_once_with(faiss_instance.embedding_model_dims)

     # Test creating a collection with inner product distance
     with patch('faiss.IndexFlatIP', return_value=mock_faiss_index) as mock_index_flat_ip:
         with patch('faiss.write_index'):
-            faiss_instance.create_col(name="new_collection", vector_size=256, distance="inner_product")
-            mock_index_flat_ip.assert_called_once_with(256)
+            faiss_instance.create_col(name="new_collection", distance="inner_product")
+            mock_index_flat_ip.assert_called_once_with(faiss_instance.embedding_model_dims)


 def test_insert(faiss_instance, mock_faiss_index):

tests/vector_stores/test_langchain_vector_store.py (new file, 101 lines)
@@ -0,0 +1,101 @@
from unittest.mock import Mock, patch

import pytest
from langchain_community.vectorstores import VectorStore

from mem0.vector_stores.langchain import Langchain


@pytest.fixture
def mock_langchain_client():
    with patch("langchain_community.vectorstores.VectorStore") as mock_client:
        yield mock_client


@pytest.fixture
def langchain_instance(mock_langchain_client):
    mock_client = Mock(spec=VectorStore)
    return Langchain(client=mock_client, collection_name="test_collection")


def test_insert_vectors(langchain_instance):
    # Test data
    vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
    payloads = [{"data": "text1", "name": "vector1"}, {"data": "text2", "name": "vector2"}]
    ids = ["id1", "id2"]

    # Test with add_embeddings method
    langchain_instance.client.add_embeddings = Mock()
    langchain_instance.insert(vectors=vectors, payloads=payloads, ids=ids)
    langchain_instance.client.add_embeddings.assert_called_once_with(
        embeddings=vectors,
        metadatas=payloads,
        ids=ids
    )

    # Test with add_texts method
    delattr(langchain_instance.client, "add_embeddings")  # Remove attribute completely
    langchain_instance.client.add_texts = Mock()
    langchain_instance.insert(vectors=vectors, payloads=payloads, ids=ids)
    langchain_instance.client.add_texts.assert_called_once_with(
        texts=["text1", "text2"],
        metadatas=payloads,
        ids=ids
    )

    # Test with empty payloads
    langchain_instance.client.add_texts.reset_mock()
    langchain_instance.insert(vectors=vectors, payloads=None, ids=ids)
    langchain_instance.client.add_texts.assert_called_once_with(
        texts=["", ""],
        metadatas=None,
        ids=ids
    )


def test_search_vectors(langchain_instance):
    # Mock search results
    mock_docs = [
        Mock(metadata={"name": "vector1"}, id="id1"),
        Mock(metadata={"name": "vector2"}, id="id2")
    ]
    langchain_instance.client.similarity_search_by_vector.return_value = mock_docs

    # Test search without filters
    vectors = [[0.1, 0.2, 0.3]]
    results = langchain_instance.search(query="", vectors=vectors, limit=2)

    langchain_instance.client.similarity_search_by_vector.assert_called_once_with(
        embedding=vectors,
        k=2
    )

    assert len(results) == 2
    assert results[0].id == "id1"
    assert results[0].payload == {"name": "vector1"}
    assert results[1].id == "id2"
    assert results[1].payload == {"name": "vector2"}

    # Test search with filters
    filters = {"name": "vector1"}
    langchain_instance.search(query="", vectors=vectors, limit=2, filters=filters)
    langchain_instance.client.similarity_search_by_vector.assert_called_with(
        embedding=vectors,
        k=2,
        filter=filters
    )


def test_get_vector(langchain_instance):
    # Mock get result
    mock_doc = Mock(metadata={"name": "vector1"}, id="id1")
    langchain_instance.client.get_by_ids.return_value = [mock_doc]

    # Test get existing vector
    result = langchain_instance.get("id1")
    langchain_instance.client.get_by_ids.assert_called_once_with(["id1"])

    assert result is not None
    assert result.id == "id1"
    assert result.payload == {"name": "vector1"}

    # Test get non-existent vector
    langchain_instance.client.get_by_ids.return_value = []
    result = langchain_instance.get("non_existent_id")
    assert result is None