[Mem0] Update dependencies and make the package lighter (#1708)
Co-authored-by: Dev-Khant <devkhant24@gmail.com>
This commit is contained in:
@@ -7,14 +7,16 @@ try:
|
||||
import chromadb
|
||||
from chromadb.config import Settings
|
||||
except ImportError:
|
||||
raise ImportError("Chromadb requires extra dependencies. Install with `pip install chromadb`") from None
|
||||
raise ImportError(
|
||||
"Chromadb requires extra dependencies. Install with `pip install chromadb`"
|
||||
) from None
|
||||
|
||||
from mem0.vector_stores.base import VectorStoreBase
|
||||
|
||||
|
||||
class OutputData(BaseModel):
|
||||
id: Optional[str] # memory id
|
||||
score: Optional[float] # distance
|
||||
score: Optional[float] # distance
|
||||
payload: Optional[Dict] # metadata
|
||||
|
||||
|
||||
@@ -25,7 +27,7 @@ class ChromaDB(VectorStoreBase):
|
||||
client: Optional[chromadb.Client] = None,
|
||||
host: Optional[str] = None,
|
||||
port: Optional[int] = None,
|
||||
path: Optional[str] = None
|
||||
path: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the Chromadb vector store.
|
||||
@@ -68,7 +70,7 @@ class ChromaDB(VectorStoreBase):
|
||||
Returns:
|
||||
List[OutputData]: Parsed output data.
|
||||
"""
|
||||
keys = ['ids', 'distances', 'metadatas']
|
||||
keys = ["ids", "distances", "metadatas"]
|
||||
values = []
|
||||
|
||||
for key in keys:
|
||||
@@ -78,14 +80,24 @@ class ChromaDB(VectorStoreBase):
|
||||
values.append(value)
|
||||
|
||||
ids, distances, metadatas = values
|
||||
max_length = max(len(v) for v in values if isinstance(v, list) and v is not None)
|
||||
max_length = max(
|
||||
len(v) for v in values if isinstance(v, list) and v is not None
|
||||
)
|
||||
|
||||
result = []
|
||||
for i in range(max_length):
|
||||
entry = OutputData(
|
||||
id=ids[i] if isinstance(ids, list) and ids and i < len(ids) else None,
|
||||
score=distances[i] if isinstance(distances, list) and distances and i < len(distances) else None,
|
||||
payload=metadatas[i] if isinstance(metadatas, list) and metadatas and i < len(metadatas) else None,
|
||||
score=(
|
||||
distances[i]
|
||||
if isinstance(distances, list) and distances and i < len(distances)
|
||||
else None
|
||||
),
|
||||
payload=(
|
||||
metadatas[i]
|
||||
if isinstance(metadatas, list) and metadatas and i < len(metadatas)
|
||||
else None
|
||||
),
|
||||
)
|
||||
result.append(entry)
|
||||
|
||||
@@ -114,7 +126,12 @@ class ChromaDB(VectorStoreBase):
|
||||
)
|
||||
return collection
|
||||
|
||||
def insert(self, vectors: List[list], payloads: Optional[List[Dict]] = None, ids: Optional[List[str]] = None):
|
||||
def insert(
|
||||
self,
|
||||
vectors: List[list],
|
||||
payloads: Optional[List[Dict]] = None,
|
||||
ids: Optional[List[str]] = None,
|
||||
):
|
||||
"""
|
||||
Insert vectors into a collection.
|
||||
|
||||
@@ -125,7 +142,9 @@ class ChromaDB(VectorStoreBase):
|
||||
"""
|
||||
self.collection.add(ids=ids, embeddings=vectors, metadatas=payloads)
|
||||
|
||||
def search(self, query: List[list], limit: int = 5, filters: Optional[Dict] = None) -> List[OutputData]:
|
||||
def search(
|
||||
self, query: List[list], limit: int = 5, filters: Optional[Dict] = None
|
||||
) -> List[OutputData]:
|
||||
"""
|
||||
Search for similar vectors.
|
||||
|
||||
@@ -137,7 +156,9 @@ class ChromaDB(VectorStoreBase):
|
||||
Returns:
|
||||
List[OutputData]: Search results.
|
||||
"""
|
||||
results = self.collection.query(query_embeddings=query, where=filters, n_results=limit)
|
||||
results = self.collection.query(
|
||||
query_embeddings=query, where=filters, n_results=limit
|
||||
)
|
||||
final_results = self._parse_output(results)
|
||||
return final_results
|
||||
|
||||
@@ -150,7 +171,12 @@ class ChromaDB(VectorStoreBase):
|
||||
"""
|
||||
self.collection.delete(ids=vector_id)
|
||||
|
||||
def update(self, vector_id: str, vector: Optional[List[float]] = None, payload: Optional[Dict] = None):
|
||||
def update(
|
||||
self,
|
||||
vector_id: str,
|
||||
vector: Optional[List[float]] = None,
|
||||
payload: Optional[Dict] = None,
|
||||
):
|
||||
"""
|
||||
Update a vector and its payload.
|
||||
|
||||
@@ -184,8 +210,8 @@ class ChromaDB(VectorStoreBase):
|
||||
return self.client.list_collections()
|
||||
|
||||
def delete_col(self):
|
||||
"""
|
||||
Delete a collection.
|
||||
"""
|
||||
Delete a collection.
|
||||
"""
|
||||
self.client.delete_collection(name=self.collection_name)
|
||||
|
||||
@@ -198,7 +224,9 @@ class ChromaDB(VectorStoreBase):
|
||||
"""
|
||||
return self.client.get_collection(name=self.collection_name)
|
||||
|
||||
def list(self, filters: Optional[Dict] = None, limit: int = 100) -> List[OutputData]:
|
||||
def list(
|
||||
self, filters: Optional[Dict] = None, limit: int = 100
|
||||
) -> List[OutputData]:
|
||||
"""
|
||||
List all vectors in a collection.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user