[Bugfix] fix qdrant and weaviate db integration (#1181)
Co-authored-by: Deven Patel <deven298@yahoo.com>
This commit is contained in:
@@ -11,6 +11,8 @@ try:
|
||||
except ImportError:
|
||||
raise ImportError("Qdrant requires extra dependencies. Install with `pip install embedchain[qdrant]`") from None
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
from embedchain.config.vectordb.qdrant import QdrantDBConfig
|
||||
from embedchain.vectordb.base import BaseVectorDB
|
||||
|
||||
@@ -48,7 +50,6 @@ class QdrantDB(BaseVectorDB):
|
||||
raise ValueError("Embedder not set. Please set an embedder with `set_embedder` before initialization.")
|
||||
|
||||
self.collection_name = self._get_or_create_collection()
|
||||
self.metadata_keys = {"data_type", "doc_id", "url", "hash", "app_id", "text"}
|
||||
all_collections = self.client.get_collections()
|
||||
collection_names = [collection.name for collection in all_collections.collections]
|
||||
if self.collection_name not in collection_names:
|
||||
@@ -82,21 +83,23 @@ class QdrantDB(BaseVectorDB):
|
||||
:return: All the existing IDs
|
||||
:rtype: Set[str]
|
||||
"""
|
||||
if ids is None or len(ids) == 0:
|
||||
return {"ids": []}
|
||||
|
||||
keys = set(where.keys() if where is not None else set())
|
||||
|
||||
qdrant_must_filters = [
|
||||
models.FieldCondition(
|
||||
key="identifier",
|
||||
match=models.MatchAny(
|
||||
any=ids,
|
||||
),
|
||||
qdrant_must_filters = []
|
||||
|
||||
if ids:
|
||||
qdrant_must_filters.append(
|
||||
models.FieldCondition(
|
||||
key="identifier",
|
||||
match=models.MatchAny(
|
||||
any=ids,
|
||||
),
|
||||
)
|
||||
)
|
||||
]
|
||||
if len(keys.intersection(self.metadata_keys)) != 0:
|
||||
for key in keys.intersection(self.metadata_keys):
|
||||
|
||||
if len(keys) > 0:
|
||||
for key in keys:
|
||||
qdrant_must_filters.append(
|
||||
models.FieldCondition(
|
||||
key="metadata.{}".format(key),
|
||||
@@ -108,6 +111,7 @@ class QdrantDB(BaseVectorDB):
|
||||
|
||||
offset = 0
|
||||
existing_ids = []
|
||||
metadatas = []
|
||||
while offset is not None:
|
||||
response = self.client.scroll(
|
||||
collection_name=self.collection_name,
|
||||
@@ -118,7 +122,8 @@ class QdrantDB(BaseVectorDB):
|
||||
offset = response[1]
|
||||
for doc in response[0]:
|
||||
existing_ids.append(doc.payload["identifier"])
|
||||
return {"ids": existing_ids}
|
||||
metadatas.append(doc.payload["metadata"])
|
||||
return {"ids": existing_ids, "metadatas": metadatas}
|
||||
|
||||
def add(
|
||||
self,
|
||||
@@ -143,7 +148,8 @@ class QdrantDB(BaseVectorDB):
|
||||
metadata["text"] = document
|
||||
qdrant_ids.append(str(uuid.uuid4()))
|
||||
payloads.append({"identifier": id, "text": document, "metadata": copy.deepcopy(metadata)})
|
||||
for i in range(0, len(qdrant_ids), self.BATCH_SIZE):
|
||||
|
||||
for i in tqdm(range(0, len(qdrant_ids), self.BATCH_SIZE), desc="Adding data in batches"):
|
||||
self.client.upsert(
|
||||
collection_name=self.collection_name,
|
||||
points=Batch(
|
||||
@@ -180,16 +186,17 @@ class QdrantDB(BaseVectorDB):
|
||||
keys = set(where.keys() if where is not None else set())
|
||||
|
||||
qdrant_must_filters = []
|
||||
if len(keys.intersection(self.metadata_keys)) != 0:
|
||||
for key in keys.intersection(self.metadata_keys):
|
||||
if len(keys) > 0:
|
||||
for key in keys:
|
||||
qdrant_must_filters.append(
|
||||
models.FieldCondition(
|
||||
key="payload.metadata.{}".format(key),
|
||||
key="metadata.{}".format(key),
|
||||
match=models.MatchValue(
|
||||
value=where.get(key),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
results = self.client.search(
|
||||
collection_name=self.collection_name,
|
||||
query_filter=models.Filter(must=qdrant_must_filters),
|
||||
@@ -228,3 +235,21 @@ class QdrantDB(BaseVectorDB):
|
||||
raise TypeError("Collection name must be a string")
|
||||
self.config.collection_name = name
|
||||
self.collection_name = self._get_or_create_collection()
|
||||
|
||||
@staticmethod
|
||||
def _generate_query(where: dict):
|
||||
must_fields = []
|
||||
for key, value in where.items():
|
||||
must_fields.append(
|
||||
models.FieldCondition(
|
||||
key=f"metadata.{key}",
|
||||
match=models.MatchValue(
|
||||
value=value,
|
||||
),
|
||||
)
|
||||
)
|
||||
return models.Filter(must=must_fields)
|
||||
|
||||
def delete(self, where: dict):
|
||||
db_filter = self._generate_query(where)
|
||||
self.client.delete(collection_name=self.collection_name, points_selector=db_filter)
|
||||
|
||||
Reference in New Issue
Block a user