Files
t6_mem0/mem0/vector_stores/langchain.py
2025-04-11 13:37:18 +05:30

162 lines
5.3 KiB
Python

from itertools import zip_longest
from typing import Any, Dict, List, Optional, Union

from pydantic import BaseModel

try:
    from langchain_community.vectorstores import VectorStore
except ImportError:
    raise ImportError("The 'langchain_community' library is required. Please install it using 'pip install langchain_community'.")

from mem0.vector_stores.base import VectorStoreBase
class OutputData(BaseModel):
    """One result entry produced by the LangChain vector-store adapter."""

    # Identifier of the stored memory.
    id: Optional[str]
    # Distance / similarity value reported by the store, if any.
    score: Optional[float]
    # Metadata dictionary stored alongside the entry.
    payload: Optional[Dict]
class Langchain(VectorStoreBase):
    """mem0 vector store that delegates to an existing LangChain ``VectorStore``.

    LangChain vector stores expose no named collections, so ``collection_name``
    is only tracked locally (and reported by ``list_cols`` / ``col_info``);
    every data operation is forwarded to the wrapped ``client``.
    """

    def __init__(self, client: VectorStore, collection_name: str = "mem0"):
        """
        Args:
            client (VectorStore): Already-configured LangChain vector store to delegate to.
            collection_name (str): Name reported for the single logical collection.
        """
        self.client = client
        self.collection_name = collection_name

    def _parse_output(self, data: Union[Dict, List[Any]]) -> List[OutputData]:
        """
        Parse vector-store results into ``OutputData`` entries.

        Two input shapes are accepted:

        * A list of LangChain ``Document``-like objects. ``id`` and ``metadata``
          are read with ``getattr`` (defaulting to ``None`` and ``{}``). These
          objects carry no score, so ``score`` is ``None``.
        * A dict with optional ``"ids"``, ``"distances"`` and ``"metadatas"``
          lists. Each may be wrapped in one extra list (``[[...]]``), in which
          case the first inner list is used. One entry is produced per position
          of the longest list. Shorter, missing or non-list columns contribute
          ``None``.

        Args:
            data (Union[Dict, List]): Output data dict or list of Document objects.

        Returns:
            List[OutputData]: Parsed output data; empty for empty or ``None`` input.
        """
        if not data:
            return []

        if not isinstance(data, dict):
            return [
                OutputData(
                    id=getattr(doc, "id", None),
                    score=None,  # Document objects don't include scores
                    payload=getattr(doc, "metadata", {}),
                )
                for doc in data
            ]

        columns = []
        for key in ("ids", "distances", "metadatas"):
            value = data.get(key)
            # Unwrap one level of nesting: [[a, b]] -> [a, b].
            if isinstance(value, list) and value and isinstance(value[0], list):
                value = value[0]
            # Treat missing / non-list values (e.g. an explicit None) as empty
            # columns so they are simply padded with None below.
            columns.append(value if isinstance(value, list) else [])

        return [
            OutputData(id=id_, score=distance, payload=metadata)
            for id_, distance, metadata in zip_longest(*columns)
        ]

    def create_col(self, name, vector_size=None, distance=None):
        """
        Record ``name`` as the collection name and return the wrapped client.

        ``vector_size`` and ``distance`` are accepted for interface
        compatibility but ignored: the LangChain store is supplied pre-built
        to ``__init__``.
        """
        self.collection_name = name
        return self.client

    def insert(self, vectors: List[List[float]], payloads: Optional[List[Dict]] = None, ids: Optional[List[str]] = None):
        """
        Insert vectors into the LangChain vectorstore.

        If the client has an ``add_embeddings`` method, the precomputed vectors
        are passed to it. Otherwise the client's ``add_texts`` is used. In that
        case the text is taken from each payload's ``"data"`` field, and the
        store embeds it itself (``vectors`` is used only to size the empty
        texts when no payloads are given).

        Args:
            vectors (List[List[float]]): One embedding per entry.
            payloads (Optional[List[Dict]]): One metadata dict per entry.
            ids (Optional[List[str]]): One ID per entry.
        """
        if hasattr(self.client, "add_embeddings"):
            # NOTE(review): LangChain stores that define add_embeddings differ in
            # signature (e.g. FAISS takes ``text_embeddings`` as (text, vector)
            # pairs). Confirm that ``embeddings=`` matches the concrete store in use.
            self.client.add_embeddings(embeddings=vectors, metadatas=payloads, ids=ids)
        else:
            texts = [payload.get("data", "") for payload in payloads] if payloads else [""] * len(vectors)
            self.client.add_texts(texts=texts, metadatas=payloads, ids=ids)

    def search(self, query: str, vectors: List[List[float]], limit: int = 5, filters: Optional[Dict] = None):
        """
        Search the store for entries closest to ``vectors``.

        Args:
            query (str): Original query text. Unused; the search is by vector.
            vectors: Forwarded unchanged as ``embedding`` to
                ``similarity_search_by_vector``.
            limit (int): Maximum number of results (``k``).
            filters (Optional[Dict]): Forwarded as ``filter`` when non-empty.

        Returns:
            List[OutputData]: Matches. ``score`` is ``None`` because Documents
            returned by ``similarity_search_by_vector`` carry no score.
        """
        # NOTE(review): LangChain's similarity_search_by_vector expects a single
        # embedding (List[float]), while this signature is annotated
        # List[List[float]]. Confirm what callers actually pass.
        search_kwargs = {"embedding": vectors, "k": limit}
        if filters:
            search_kwargs["filter"] = filters
        results = self.client.similarity_search_by_vector(**search_kwargs)
        return self._parse_output(results)

    def delete(self, vector_id):
        """
        Delete a vector by ID.

        Args:
            vector_id: ID of the entry to remove.
        """
        self.client.delete(ids=[vector_id])

    def update(self, vector_id, vector=None, payload=None):
        """
        Update a vector and its payload by deleting the entry and re-inserting
        it under the same ID.

        This is not atomic: the entry is deleted first, so a failing insert
        leaves it removed.

        Args:
            vector_id: ID of the entry to replace.
            vector: New embedding for the entry (a single vector).
            payload: New metadata dict for the entry (a single dict).
        """
        self.delete(vector_id)
        # insert() works on batches, so wrap the single vector/payload in lists.
        # Passing the bare dict would make insert() iterate its keys.
        payloads = [payload] if payload is not None else None
        self.insert([vector], payloads, [vector_id])

    def get(self, vector_id):
        """
        Retrieve a vector by ID.

        Args:
            vector_id: ID of the entry to fetch.

        Returns:
            Optional[OutputData]: The first matching entry, or ``None`` if the
            client returns no documents.
        """
        docs = self.client.get_by_ids([vector_id])
        if not docs:
            return None
        return self._parse_output([docs[0]])[0]

    def list_cols(self):
        """
        List all collections.

        LangChain doesn't have collections, so this is just the locally
        recorded collection name.
        """
        return [self.collection_name]

    def delete_col(self):
        """
        Delete the collection by calling ``client.delete(ids=None)``.
        """
        # NOTE(review): LangChain's base VectorStore.delete documents ids=None as
        # "delete all", but some implementations reject None. Verify for the
        # concrete store in use.
        self.client.delete(ids=None)

    def col_info(self):
        """
        Get information about a collection (only its recorded name).
        """
        return {"name": self.collection_name}

    def list(self, filters=None, limit=None):
        """
        List all vectors in a collection.

        Raises:
            NotImplementedError: Always; the generic LangChain VectorStore
            interface offers no enumeration method.
        """
        # This would require implementation-specific access to the underlying store
        raise NotImplementedError("Listing all vectors not directly supported by LangChain vectorstores")