[Misc] Lint code and fix code smells (#1871)

This commit is contained in:
Deshraj Yadav
2024-09-16 17:39:54 -07:00
committed by GitHub
parent 0a78cb9f7a
commit 55c54beeab
57 changed files with 1178 additions and 1357 deletions

View File

@@ -1,15 +1,17 @@
import logging
from typing import Dict, Optional
from pydantic import BaseModel
from typing import Optional, Dict
from mem0.vector_stores.base import VectorStoreBase
from mem0.configs.vector_stores.milvus import MetricType
from mem0.vector_stores.base import VectorStoreBase
try:
import pymilvus
import pymilvus # noqa: F401
except ImportError:
raise ImportError("The 'pymilvus' library is required. Please install it using 'pip install pymilvus'.")
from pymilvus import MilvusClient, CollectionSchema, FieldSchema, DataType
from pymilvus import CollectionSchema, DataType, FieldSchema, MilvusClient
logger = logging.getLogger(__name__)
@@ -20,9 +22,15 @@ class OutputData(BaseModel):
payload: Optional[Dict] # metadata
class MilvusDB(VectorStoreBase):
def __init__(self, url: str, token: str, collection_name: str, embedding_model_dims: int, metric_type: MetricType) -> None:
def __init__(
self,
url: str,
token: str,
collection_name: str,
embedding_model_dims: int,
metric_type: MetricType,
) -> None:
"""Initialize the MilvusDB database.
Args:
@@ -32,22 +40,21 @@ class MilvusDB(VectorStoreBase):
embedding_model_dims (int): Dimensions of the embedding model (defaults to 1536).
metric_type (MetricType): Metric type for similarity search (defaults to L2).
"""
self.collection_name = collection_name
self.embedding_model_dims = embedding_model_dims
self.metric_type = metric_type
self.client = MilvusClient(uri=url,token=token)
self.client = MilvusClient(uri=url, token=token)
self.create_col(
collection_name=self.collection_name,
vector_size=self.embedding_model_dims,
metric_type=self.metric_type
metric_type=self.metric_type,
)
def create_col(
self, collection_name : str, vector_size : str, metric_type : MetricType = MetricType.COSINE
self,
collection_name: str,
vector_size: str,
metric_type: MetricType = MetricType.COSINE,
) -> None:
"""Create a new collection with index_type AUTOINDEX.
@@ -65,7 +72,7 @@ class MilvusDB(VectorStoreBase):
FieldSchema(name="vectors", dtype=DataType.FLOAT_VECTOR, dim=vector_size),
FieldSchema(name="metadata", dtype=DataType.JSON),
]
schema = CollectionSchema(fields, enable_dynamic_field=True)
index = self.client.prepare_index_params(
@@ -73,12 +80,10 @@ class MilvusDB(VectorStoreBase):
metric_type=metric_type,
index_type="AUTOINDEX",
index_name="vector_index",
params={ "nlist": 128 }
params={"nlist": 128},
)
self.client.create_collection(collection_name=collection_name, schema=schema, index_params=index)
def insert(self, ids, vectors, payloads, **kwargs: Optional[dict[str, any]]):
"""Insert vectors into a collection.
@@ -91,9 +96,8 @@ class MilvusDB(VectorStoreBase):
data = {"id": idx, "vectors": embedding, "metadata": metadata}
self.client.insert(collection_name=self.collection_name, data=data, **kwargs)
def _create_filter(self, filters: dict):
"""Prepare filters for efficient query.
"""Prepare filters for efficient query.
Args:
filters (dict): filters [user_id, agent_id, run_id]
@@ -109,8 +113,7 @@ class MilvusDB(VectorStoreBase):
operands.append(f'(metadata["{key}"] == {value})')
return " and ".join(operands)
def _parse_output(self, data: list):
"""
Parse the output data.
@@ -125,16 +128,15 @@ class MilvusDB(VectorStoreBase):
for value in data:
uid, score, metadata = (
value.get("id"),
value.get("distance"),
value.get("entity",{}).get("metadata")
value.get("id"),
value.get("distance"),
value.get("entity", {}).get("metadata"),
)
memory_obj = OutputData(id=uid, score=score, payload=metadata)
memory.append(memory_obj)
return memory
def search(self, query: list, limit: int = 5, filters: dict = None) -> list:
"""
@@ -150,14 +152,15 @@ class MilvusDB(VectorStoreBase):
"""
query_filter = self._create_filter(filters) if filters else None
hits = self.client.search(
collection_name=self.collection_name,
data=[query], limit=limit, filter=query_filter,
output_fields=["*"]
collection_name=self.collection_name,
data=[query],
limit=limit,
filter=query_filter,
output_fields=["*"],
)
result = self._parse_output(data=hits[0])
return result
def delete(self, vector_id):
"""
Delete a vector by ID.
@@ -166,7 +169,6 @@ class MilvusDB(VectorStoreBase):
vector_id (str): ID of the vector to delete.
"""
self.client.delete(collection_name=self.collection_name, ids=vector_id)
def update(self, vector_id=None, vector=None, payload=None):
"""
@@ -177,7 +179,7 @@ class MilvusDB(VectorStoreBase):
vector (List[float], optional): Updated vector.
payload (Dict, optional): Updated payload.
"""
schema = {"id" : vector_id, "vectors": vector, "metadata" : payload}
schema = {"id": vector_id, "vectors": vector, "metadata": payload}
self.client.upsert(collection_name=self.collection_name, data=schema)
def get(self, vector_id):
@@ -191,7 +193,11 @@ class MilvusDB(VectorStoreBase):
OutputData: Retrieved vector.
"""
result = self.client.get(collection_name=self.collection_name, ids=vector_id)
output = OutputData(id=result[0].get("id", None), score=None, payload=result[0].get("metadata", None))
output = OutputData(
id=result[0].get("id", None),
score=None,
payload=result[0].get("metadata", None),
)
return output
def list_cols(self):
@@ -228,12 +234,9 @@ class MilvusDB(VectorStoreBase):
List[OutputData]: List of vectors.
"""
query_filter = self._create_filter(filters) if filters else None
result = self.client.query(
collection_name=self.collection_name,
filter=query_filter,
limit=limit)
result = self.client.query(collection_name=self.collection_name, filter=query_filter, limit=limit)
memories = []
for data in result:
obj = OutputData(id=data.get("id"), score=None, payload=data.get("metadata"))
memories.append(obj)
return [memories]
return [memories]