[Misc] Lint code and fix code smells (#1871)

This commit is contained in:
Deshraj Yadav
2024-09-16 17:39:54 -07:00
committed by GitHub
parent 0a78cb9f7a
commit 55c54beeab
57 changed files with 1178 additions and 1357 deletions

View File

@@ -80,24 +80,14 @@ class ChromaDB(VectorStoreBase):
values.append(value)
ids, distances, metadatas = values
max_length = max(
len(v) for v in values if isinstance(v, list) and v is not None
)
max_length = max(len(v) for v in values if isinstance(v, list) and v is not None)
result = []
for i in range(max_length):
entry = OutputData(
id=ids[i] if isinstance(ids, list) and ids and i < len(ids) else None,
score=(
distances[i]
if isinstance(distances, list) and distances and i < len(distances)
else None
),
payload=(
metadatas[i]
if isinstance(metadatas, list) and metadatas and i < len(metadatas)
else None
),
score=(distances[i] if isinstance(distances, list) and distances and i < len(distances) else None),
payload=(metadatas[i] if isinstance(metadatas, list) and metadatas and i < len(metadatas) else None),
)
result.append(entry)
@@ -143,9 +133,7 @@ class ChromaDB(VectorStoreBase):
logger.info(f"Inserting {len(vectors)} vectors into collection {self.collection_name}")
self.collection.add(ids=ids, embeddings=vectors, metadatas=payloads)
def search(
self, query: List[list], limit: int = 5, filters: Optional[Dict] = None
) -> List[OutputData]:
def search(self, query: List[list], limit: int = 5, filters: Optional[Dict] = None) -> List[OutputData]:
"""
Search for similar vectors.
@@ -157,9 +145,7 @@ class ChromaDB(VectorStoreBase):
Returns:
List[OutputData]: Search results.
"""
results = self.collection.query(
query_embeddings=query, where=filters, n_results=limit
)
results = self.collection.query(query_embeddings=query, where=filters, n_results=limit)
final_results = self._parse_output(results)
return final_results
@@ -225,9 +211,7 @@ class ChromaDB(VectorStoreBase):
"""
return self.client.get_collection(name=self.collection_name)
def list(
self, filters: Optional[Dict] = None, limit: int = 100
) -> List[OutputData]:
def list(self, filters: Optional[Dict] = None, limit: int = 100) -> List[OutputData]:
"""
List all vectors in a collection.

View File

@@ -8,15 +8,13 @@ class VectorStoreConfig(BaseModel):
description="Provider of the vector store (e.g., 'qdrant', 'chroma')",
default="qdrant",
)
config: Optional[Dict] = Field(
description="Configuration for the specific vector store", default=None
)
config: Optional[Dict] = Field(description="Configuration for the specific vector store", default=None)
_provider_configs: Dict[str, str] = {
"qdrant": "QdrantConfig",
"chroma": "ChromaDbConfig",
"pgvector": "PGVectorConfig",
"milvus" : "MilvusDBConfig"
"milvus": "MilvusDBConfig",
}
@model_validator(mode="after")

View File

@@ -1,15 +1,17 @@
import logging
from typing import Dict, Optional
from pydantic import BaseModel
from typing import Optional, Dict
from mem0.vector_stores.base import VectorStoreBase
from mem0.configs.vector_stores.milvus import MetricType
from mem0.vector_stores.base import VectorStoreBase
try:
import pymilvus
import pymilvus # noqa: F401
except ImportError:
raise ImportError("The 'pymilvus' library is required. Please install it using 'pip install pymilvus'.")
from pymilvus import MilvusClient, CollectionSchema, FieldSchema, DataType
from pymilvus import CollectionSchema, DataType, FieldSchema, MilvusClient
logger = logging.getLogger(__name__)
@@ -20,9 +22,15 @@ class OutputData(BaseModel):
payload: Optional[Dict] # metadata
class MilvusDB(VectorStoreBase):
def __init__(self, url: str, token: str, collection_name: str, embedding_model_dims: int, metric_type: MetricType) -> None:
def __init__(
self,
url: str,
token: str,
collection_name: str,
embedding_model_dims: int,
metric_type: MetricType,
) -> None:
"""Initialize the MilvusDB database.
Args:
@@ -32,22 +40,21 @@ class MilvusDB(VectorStoreBase):
embedding_model_dims (int): Dimensions of the embedding model (defaults to 1536).
metric_type (MetricType): Metric type for similarity search (defaults to L2).
"""
self.collection_name = collection_name
self.embedding_model_dims = embedding_model_dims
self.metric_type = metric_type
self.client = MilvusClient(uri=url,token=token)
self.client = MilvusClient(uri=url, token=token)
self.create_col(
collection_name=self.collection_name,
vector_size=self.embedding_model_dims,
metric_type=self.metric_type
metric_type=self.metric_type,
)
def create_col(
self, collection_name : str, vector_size : str, metric_type : MetricType = MetricType.COSINE
self,
collection_name: str,
vector_size: str,
metric_type: MetricType = MetricType.COSINE,
) -> None:
"""Create a new collection with index_type AUTOINDEX.
@@ -65,7 +72,7 @@ class MilvusDB(VectorStoreBase):
FieldSchema(name="vectors", dtype=DataType.FLOAT_VECTOR, dim=vector_size),
FieldSchema(name="metadata", dtype=DataType.JSON),
]
schema = CollectionSchema(fields, enable_dynamic_field=True)
index = self.client.prepare_index_params(
@@ -73,12 +80,10 @@ class MilvusDB(VectorStoreBase):
metric_type=metric_type,
index_type="AUTOINDEX",
index_name="vector_index",
params={ "nlist": 128 }
params={"nlist": 128},
)
self.client.create_collection(collection_name=collection_name, schema=schema, index_params=index)
def insert(self, ids, vectors, payloads, **kwargs: Optional[dict[str, any]]):
"""Insert vectors into a collection.
@@ -91,9 +96,8 @@ class MilvusDB(VectorStoreBase):
data = {"id": idx, "vectors": embedding, "metadata": metadata}
self.client.insert(collection_name=self.collection_name, data=data, **kwargs)
def _create_filter(self, filters: dict):
"""Prepare filters for efficient query.
"""Prepare filters for efficient query.
Args:
filters (dict): filters [user_id, agent_id, run_id]
@@ -109,8 +113,7 @@ class MilvusDB(VectorStoreBase):
operands.append(f'(metadata["{key}"] == {value})')
return " and ".join(operands)
def _parse_output(self, data: list):
"""
Parse the output data.
@@ -125,16 +128,15 @@ class MilvusDB(VectorStoreBase):
for value in data:
uid, score, metadata = (
value.get("id"),
value.get("distance"),
value.get("entity",{}).get("metadata")
value.get("id"),
value.get("distance"),
value.get("entity", {}).get("metadata"),
)
memory_obj = OutputData(id=uid, score=score, payload=metadata)
memory.append(memory_obj)
return memory
def search(self, query: list, limit: int = 5, filters: dict = None) -> list:
"""
@@ -150,14 +152,15 @@ class MilvusDB(VectorStoreBase):
"""
query_filter = self._create_filter(filters) if filters else None
hits = self.client.search(
collection_name=self.collection_name,
data=[query], limit=limit, filter=query_filter,
output_fields=["*"]
collection_name=self.collection_name,
data=[query],
limit=limit,
filter=query_filter,
output_fields=["*"],
)
result = self._parse_output(data=hits[0])
return result
def delete(self, vector_id):
"""
Delete a vector by ID.
@@ -166,7 +169,6 @@ class MilvusDB(VectorStoreBase):
vector_id (str): ID of the vector to delete.
"""
self.client.delete(collection_name=self.collection_name, ids=vector_id)
def update(self, vector_id=None, vector=None, payload=None):
"""
@@ -177,7 +179,7 @@ class MilvusDB(VectorStoreBase):
vector (List[float], optional): Updated vector.
payload (Dict, optional): Updated payload.
"""
schema = {"id" : vector_id, "vectors": vector, "metadata" : payload}
schema = {"id": vector_id, "vectors": vector, "metadata": payload}
self.client.upsert(collection_name=self.collection_name, data=schema)
def get(self, vector_id):
@@ -191,7 +193,11 @@ class MilvusDB(VectorStoreBase):
OutputData: Retrieved vector.
"""
result = self.client.get(collection_name=self.collection_name, ids=vector_id)
output = OutputData(id=result[0].get("id", None), score=None, payload=result[0].get("metadata", None))
output = OutputData(
id=result[0].get("id", None),
score=None,
payload=result[0].get("metadata", None),
)
return output
def list_cols(self):
@@ -228,12 +234,9 @@ class MilvusDB(VectorStoreBase):
List[OutputData]: List of vectors.
"""
query_filter = self._create_filter(filters) if filters else None
result = self.client.query(
collection_name=self.collection_name,
filter=query_filter,
limit=limit)
result = self.client.query(collection_name=self.collection_name, filter=query_filter, limit=limit)
memories = []
for data in result:
obj = OutputData(id=data.get("id"), score=None, payload=data.get("metadata"))
memories.append(obj)
return [memories]
return [memories]

View File

@@ -14,6 +14,7 @@ from mem0.vector_stores.base import VectorStoreBase
logger = logging.getLogger(__name__)
class OutputData(BaseModel):
id: Optional[str]
score: Optional[float]
@@ -22,7 +23,15 @@ class OutputData(BaseModel):
class PGVector(VectorStoreBase):
def __init__(
self, dbname, collection_name, embedding_model_dims, user, password, host, port, diskann
self,
dbname,
collection_name,
embedding_model_dims,
user,
password,
host,
port,
diskann,
):
"""
Initialize the PGVector database.
@@ -40,9 +49,7 @@ class PGVector(VectorStoreBase):
self.collection_name = collection_name
self.use_diskann = diskann
self.conn = psycopg2.connect(
dbname=dbname, user=user, password=password, host=host, port=port
)
self.conn = psycopg2.connect(dbname=dbname, user=user, password=password, host=host, port=port)
self.cur = self.conn.cursor()
collections = self.list_cols()
@@ -73,7 +80,8 @@ class PGVector(VectorStoreBase):
self.cur.execute("SELECT * FROM pg_extension WHERE extname = 'vectorscale'")
if self.cur.fetchone():
# Create DiskANN index if extension is installed for faster search
self.cur.execute(f"""
self.cur.execute(
f"""
CREATE INDEX IF NOT EXISTS {self.collection_name}_vector_idx
ON {self.collection_name}
USING diskann (vector);
@@ -94,10 +102,7 @@ class PGVector(VectorStoreBase):
logger.info(f"Inserting {len(vectors)} vectors into collection {self.collection_name}")
json_payloads = [json.dumps(payload) for payload in payloads]
data = [
(id, vector, payload)
for id, vector, payload in zip(ids, vectors, json_payloads)
]
data = [(id, vector, payload) for id, vector, payload in zip(ids, vectors, json_payloads)]
execute_values(
self.cur,
f"INSERT INTO {self.collection_name} (id, vector, payload) VALUES %s",
@@ -125,9 +130,7 @@ class PGVector(VectorStoreBase):
filter_conditions.append("payload->>%s = %s")
filter_params.extend([k, str(v)])
filter_clause = (
"WHERE " + " AND ".join(filter_conditions) if filter_conditions else ""
)
filter_clause = "WHERE " + " AND ".join(filter_conditions) if filter_conditions else ""
self.cur.execute(
f"""
@@ -137,13 +140,11 @@ class PGVector(VectorStoreBase):
ORDER BY distance
LIMIT %s
""",
(query, *filter_params, limit),
(query, *filter_params, limit),
)
results = self.cur.fetchall()
return [
OutputData(id=str(r[0]), score=float(r[1]), payload=r[2]) for r in results
]
return [OutputData(id=str(r[0]), score=float(r[1]), payload=r[2]) for r in results]
def delete(self, vector_id):
"""
@@ -152,9 +153,7 @@ class PGVector(VectorStoreBase):
Args:
vector_id (str): ID of the vector to delete.
"""
self.cur.execute(
f"DELETE FROM {self.collection_name} WHERE id = %s", (vector_id,)
)
self.cur.execute(f"DELETE FROM {self.collection_name} WHERE id = %s", (vector_id,))
self.conn.commit()
def update(self, vector_id, vector=None, payload=None):
@@ -204,9 +203,7 @@ class PGVector(VectorStoreBase):
Returns:
List[str]: List of collection names.
"""
self.cur.execute(
"SELECT table_name FROM information_schema.tables WHERE table_schema = 'public'"
)
self.cur.execute("SELECT table_name FROM information_schema.tables WHERE table_schema = 'public'")
return [row[0] for row in self.cur.fetchall()]
def delete_col(self):
@@ -254,9 +251,7 @@ class PGVector(VectorStoreBase):
filter_conditions.append("payload->>%s = %s")
filter_params.extend([k, str(v)])
filter_clause = (
"WHERE " + " AND ".join(filter_conditions) if filter_conditions else ""
)
filter_clause = "WHERE " + " AND ".join(filter_conditions) if filter_conditions else ""
query = f"""
SELECT id, vector, payload

View File

@@ -3,16 +3,9 @@ import os
import shutil
from qdrant_client import QdrantClient
from qdrant_client.models import (
Distance,
FieldCondition,
Filter,
MatchValue,
PointIdsList,
PointStruct,
Range,
VectorParams,
)
from qdrant_client.models import (Distance, FieldCondition, Filter, MatchValue,
PointIdsList, PointStruct, Range,
VectorParams)
from mem0.vector_stores.base import VectorStoreBase
@@ -68,9 +61,7 @@ class Qdrant(VectorStoreBase):
self.collection_name = collection_name
self.create_col(embedding_model_dims, on_disk)
def create_col(
self, vector_size: int, on_disk: bool, distance: Distance = Distance.COSINE
):
def create_col(self, vector_size: int, on_disk: bool, distance: Distance = Distance.COSINE):
"""
Create a new collection.
@@ -83,16 +74,12 @@ class Qdrant(VectorStoreBase):
response = self.list_cols()
for collection in response.collections:
if collection.name == self.collection_name:
logging.debug(
f"Collection {self.collection_name} already exists. Skipping creation."
)
logging.debug(f"Collection {self.collection_name} already exists. Skipping creation.")
return
self.client.create_collection(
collection_name=self.collection_name,
vectors_config=VectorParams(
size=vector_size, distance=distance, on_disk=on_disk
),
vectors_config=VectorParams(size=vector_size, distance=distance, on_disk=on_disk),
)
def insert(self, vectors: list, payloads: list = None, ids: list = None):
@@ -128,15 +115,9 @@ class Qdrant(VectorStoreBase):
conditions = []
for key, value in filters.items():
if isinstance(value, dict) and "gte" in value and "lte" in value:
conditions.append(
FieldCondition(
key=key, range=Range(gte=value["gte"], lte=value["lte"])
)
)
conditions.append(FieldCondition(key=key, range=Range(gte=value["gte"], lte=value["lte"])))
else:
conditions.append(
FieldCondition(key=key, match=MatchValue(value=value))
)
conditions.append(FieldCondition(key=key, match=MatchValue(value=value)))
return Filter(must=conditions) if conditions else None
def search(self, query: list, limit: int = 5, filters: dict = None) -> list:
@@ -196,9 +177,7 @@ class Qdrant(VectorStoreBase):
Returns:
dict: Retrieved vector.
"""
result = self.client.retrieve(
collection_name=self.collection_name, ids=[vector_id], with_payload=True
)
result = self.client.retrieve(collection_name=self.collection_name, ids=[vector_id], with_payload=True)
return result[0] if result else None
def list_cols(self) -> list: