[Misc] Lint code and fix code smells (#1871)

This commit is contained in:
Deshraj Yadav
2024-09-16 17:39:54 -07:00
committed by GitHub
parent 0a78cb9f7a
commit 55c54beeab
57 changed files with 1178 additions and 1357 deletions

View File

@@ -3,16 +3,9 @@ import os
import shutil
from qdrant_client import QdrantClient
from qdrant_client.models import (
Distance,
FieldCondition,
Filter,
MatchValue,
PointIdsList,
PointStruct,
Range,
VectorParams,
)
from qdrant_client.models import (Distance, FieldCondition, Filter, MatchValue,
PointIdsList, PointStruct, Range,
VectorParams)
from mem0.vector_stores.base import VectorStoreBase
@@ -68,9 +61,7 @@ class Qdrant(VectorStoreBase):
self.collection_name = collection_name
self.create_col(embedding_model_dims, on_disk)
def create_col(
self, vector_size: int, on_disk: bool, distance: Distance = Distance.COSINE
):
def create_col(self, vector_size: int, on_disk: bool, distance: Distance = Distance.COSINE):
"""
Create a new collection.
@@ -83,16 +74,12 @@ class Qdrant(VectorStoreBase):
response = self.list_cols()
for collection in response.collections:
if collection.name == self.collection_name:
logging.debug(
f"Collection {self.collection_name} already exists. Skipping creation."
)
logging.debug(f"Collection {self.collection_name} already exists. Skipping creation.")
return
self.client.create_collection(
collection_name=self.collection_name,
vectors_config=VectorParams(
size=vector_size, distance=distance, on_disk=on_disk
),
vectors_config=VectorParams(size=vector_size, distance=distance, on_disk=on_disk),
)
def insert(self, vectors: list, payloads: list = None, ids: list = None):
@@ -128,15 +115,9 @@ class Qdrant(VectorStoreBase):
conditions = []
for key, value in filters.items():
if isinstance(value, dict) and "gte" in value and "lte" in value:
conditions.append(
FieldCondition(
key=key, range=Range(gte=value["gte"], lte=value["lte"])
)
)
conditions.append(FieldCondition(key=key, range=Range(gte=value["gte"], lte=value["lte"])))
else:
conditions.append(
FieldCondition(key=key, match=MatchValue(value=value))
)
conditions.append(FieldCondition(key=key, match=MatchValue(value=value)))
return Filter(must=conditions) if conditions else None
def search(self, query: list, limit: int = 5, filters: dict = None) -> list:
@@ -196,9 +177,7 @@ class Qdrant(VectorStoreBase):
Returns:
dict: Retrieved vector.
"""
result = self.client.retrieve(
collection_name=self.collection_name, ids=[vector_id], with_payload=True
)
result = self.client.retrieve(collection_name=self.collection_name, ids=[vector_id], with_payload=True)
return result[0] if result else None
def list_cols(self) -> list: