Fix user_id functionality (#2548)

This commit is contained in:
Dev Khant
2025-04-16 13:32:33 +05:30
committed by GitHub
parent 541030d69c
commit 3613e2f14a
9 changed files with 86 additions and 49 deletions

View File

@@ -5,7 +5,9 @@ from pydantic import BaseModel
try:
from langchain_community.vectorstores import VectorStore
except ImportError:
raise ImportError("The 'langchain_community' library is required. Please install it using 'pip install langchain_community'.")
raise ImportError(
"The 'langchain_community' library is required. Please install it using 'pip install langchain_community'."
)
from mem0.vector_stores.base import VectorStoreBase
@@ -15,11 +17,12 @@ class OutputData(BaseModel):
score: Optional[float] # distance
payload: Optional[Dict] # metadata
class Langchain(VectorStoreBase):
def __init__(self, client: VectorStore, collection_name: str = "mem0"):
    """
    Wrap an existing LangChain vector store.

    Args:
        client (VectorStore): LangChain vector store instance that the other methods call into.
        collection_name (str, optional): Name stored for this wrapper. Defaults to "mem0".
    """
    self.collection_name = collection_name
    self.client = client
def _parse_output(self, data: Dict) -> List[OutputData]:
"""
Parse the output data.
@@ -31,17 +34,17 @@ class Langchain(VectorStoreBase):
List[OutputData]: Parsed output data.
"""
# Check if input is a list of Document objects
if isinstance(data, list) and all(hasattr(doc, 'metadata') for doc in data if hasattr(doc, '__dict__')):
if isinstance(data, list) and all(hasattr(doc, "metadata") for doc in data if hasattr(doc, "__dict__")):
result = []
for doc in data:
entry = OutputData(
id=getattr(doc, "id", None),
score=None, # Document objects typically don't include scores
payload=getattr(doc, "metadata", {})
payload=getattr(doc, "metadata", {}),
)
result.append(entry)
return result
# Original format handling
keys = ["ids", "distances", "metadatas"]
values = []
@@ -70,26 +73,20 @@ class Langchain(VectorStoreBase):
self.collection_name = name
return self.client
def insert(
    self, vectors: List[List[float]], payloads: Optional[List[Dict]] = None, ids: Optional[List[str]] = None
):
    """
    Insert vectors into the LangChain vectorstore.

    Args:
        vectors (List[List[float]]): Embeddings to store.
        payloads (Optional[List[Dict]]): Metadata dicts forwarded to the store as ``metadatas``.
        ids (Optional[List[str]]): Identifiers forwarded to the store as ``ids``.
    """
    # Some LangChain vectorstores expose a direct add_embeddings method; use it when present.
    if hasattr(self.client, "add_embeddings"):
        self.client.add_embeddings(embeddings=vectors, metadatas=payloads, ids=ids)
        return

    # Fallback to add_texts: the supplied vectors are not passed along here, only the
    # text found under each payload's "data" key (or an empty string per vector when
    # no payloads were given).
    texts = [payload.get("data", "") for payload in payloads] if payloads else [""] * len(vectors)
    self.client.add_texts(texts=texts, metadatas=payloads, ids=ids)
def search(self, query: str, vectors: List[List[float]], limit: int = 5, filters: Optional[Dict] = None):
"""
@@ -97,16 +94,9 @@ class Langchain(VectorStoreBase):
"""
# For each vector, perform a similarity search
if filters:
results = self.client.similarity_search_by_vector(
embedding=vectors,
k=limit,
filter=filters
)
results = self.client.similarity_search_by_vector(embedding=vectors, k=limit, filter=filters)
else:
results = self.client.similarity_search_by_vector(
embedding=vectors,
k=limit
)
results = self.client.similarity_search_by_vector(embedding=vectors, k=limit)
final_results = self._parse_output(results)
return final_results
@@ -133,26 +123,26 @@ class Langchain(VectorStoreBase):
doc = docs[0]
return self._parse_output([doc])[0]
return None
def list_cols(self):
    """
    List all collections.

    Returns:
        list: A one-element list holding this wrapper's collection name.
    """
    # LangChain doesn't have collections, so only the configured name is reported.
    names = [self.collection_name]
    return names
def delete_col(self):
    """
    Delete a collection.

    Uses the wrapped store's own ``delete_collection`` or ``reset_collection``
    method when it has one, and only otherwise falls back to ``delete(ids=None)``.
    """
    store = self.client
    # NOTE(review): not every LangChain store accepts ids=None as "delete everything";
    # some raise on it, so a dedicated collection-level method is tried first.
    for method_name in ("delete_collection", "reset_collection"):
        drop = getattr(store, method_name, None)
        if callable(drop):
            drop()
            return
    store.delete(ids=None)
def col_info(self):
    """
    Get information about a collection.

    Returns:
        dict: Mapping with a single "name" key set to this wrapper's collection name.
    """
    info = dict(name=self.collection_name)
    return info
def list(self, filters=None, limit=None):
"""
List all vectors in a collection.