diff --git a/embedchain/vectordb/qdrant.py b/embedchain/vectordb/qdrant.py index 1f0a1b6d..bb72fb3d 100644 --- a/embedchain/vectordb/qdrant.py +++ b/embedchain/vectordb/qdrant.py @@ -11,6 +11,8 @@ try: except ImportError: raise ImportError("Qdrant requires extra dependencies. Install with `pip install embedchain[qdrant]`") from None +from tqdm import tqdm + from embedchain.config.vectordb.qdrant import QdrantDBConfig from embedchain.vectordb.base import BaseVectorDB @@ -48,7 +50,6 @@ class QdrantDB(BaseVectorDB): raise ValueError("Embedder not set. Please set an embedder with `set_embedder` before initialization.") self.collection_name = self._get_or_create_collection() - self.metadata_keys = {"data_type", "doc_id", "url", "hash", "app_id", "text"} all_collections = self.client.get_collections() collection_names = [collection.name for collection in all_collections.collections] if self.collection_name not in collection_names: @@ -82,21 +83,23 @@ class QdrantDB(BaseVectorDB): :return: All the existing IDs :rtype: Set[str] """ - if ids is None or len(ids) == 0: - return {"ids": []} keys = set(where.keys() if where is not None else set()) - qdrant_must_filters = [ - models.FieldCondition( - key="identifier", - match=models.MatchAny( - any=ids, - ), + qdrant_must_filters = [] + + if ids: + qdrant_must_filters.append( + models.FieldCondition( + key="identifier", + match=models.MatchAny( + any=ids, + ), + ) ) - ] - if len(keys.intersection(self.metadata_keys)) != 0: - for key in keys.intersection(self.metadata_keys): + + if len(keys) > 0: + for key in keys: qdrant_must_filters.append( models.FieldCondition( key="metadata.{}".format(key), @@ -108,6 +111,7 @@ class QdrantDB(BaseVectorDB): offset = 0 existing_ids = [] + metadatas = [] while offset is not None: response = self.client.scroll( collection_name=self.collection_name, @@ -118,7 +122,8 @@ class QdrantDB(BaseVectorDB): offset = response[1] for doc in response[0]: existing_ids.append(doc.payload["identifier"]) - return {"ids": existing_ids} + metadatas.append(doc.payload["metadata"]) + return {"ids": existing_ids, "metadatas": metadatas} def add( self, @@ -143,7 +148,8 @@ class QdrantDB(BaseVectorDB): metadata["text"] = document qdrant_ids.append(str(uuid.uuid4())) payloads.append({"identifier": id, "text": document, "metadata": copy.deepcopy(metadata)}) - for i in range(0, len(qdrant_ids), self.BATCH_SIZE): + + for i in tqdm(range(0, len(qdrant_ids), self.BATCH_SIZE), desc="Adding data in batches"): self.client.upsert( collection_name=self.collection_name, points=Batch( @@ -180,16 +186,17 @@ class QdrantDB(BaseVectorDB): keys = set(where.keys() if where is not None else set()) qdrant_must_filters = [] - if len(keys.intersection(self.metadata_keys)) != 0: - for key in keys.intersection(self.metadata_keys): + if len(keys) > 0: + for key in keys: qdrant_must_filters.append( models.FieldCondition( - key="payload.metadata.{}".format(key), + key="metadata.{}".format(key), match=models.MatchValue( value=where.get(key), ), ) ) + results = self.client.search( collection_name=self.collection_name, query_filter=models.Filter(must=qdrant_must_filters), @@ -228,3 +235,21 @@ class QdrantDB(BaseVectorDB): raise TypeError("Collection name must be a string") self.config.collection_name = name self.collection_name = self._get_or_create_collection() + + @staticmethod + def _generate_query(where: dict): + must_fields = [] + for key, value in where.items(): + must_fields.append( + models.FieldCondition( + key=f"metadata.{key}", + match=models.MatchValue( + value=value, + ), + ) + ) + return models.Filter(must=must_fields) + + def delete(self, where: dict): + db_filter = self._generate_query(where) + self.client.delete(collection_name=self.collection_name, points_selector=db_filter) diff --git a/embedchain/vectordb/weaviate.py b/embedchain/vectordb/weaviate.py index d2b70a04..446db496 100644 --- a/embedchain/vectordb/weaviate.py +++ b/embedchain/vectordb/weaviate.py @@ -45,6 +45,9 @@ class WeaviateDB(BaseVectorDB): auth_client_secret=weaviate.AuthApiKey(api_key=os.environ.get("WEAVIATE_API_KEY")), **self.config.extra_params, ) + # Since weaviate uses graphQL, we need to keep track of metadata keys added in the vectordb. + # This is needed to filter data while querying. + self.metadata_keys = {"data_type", "doc_id", "url", "hash", "app_id"} # Call parent init here because embedder is needed super().__init__(config=self.config) @@ -58,7 +61,6 @@ class WeaviateDB(BaseVectorDB): raise ValueError("Embedder not set. Please set an embedder with `set_embedder` before initialization.") self.index_name = self._get_index_name() - self.metadata_keys = {"data_type", "doc_id", "url", "hash", "app_id"} if not self.client.schema.exists(self.index_name): # id is a reserved field in Weaviate, hence we had to change the name of the id field to identifier # The none vectorizer is crucial as we have our own custom embedding function @@ -127,29 +129,64 @@ class WeaviateDB(BaseVectorDB): :return: ids :rtype: Set[str] """ + weaviate_where_operands = [] - if ids is None or len(ids) == 0: - return {"ids": []} + if ids: + for doc_id in ids: + weaviate_where_operands.append({"path": ["identifier"], "operator": "Equal", "valueText": doc_id}) + + keys = set(where.keys() if where is not None else set()) + if len(keys) > 0: + for key in keys: + weaviate_where_operands.append( + { + "path": ["metadata", self.index_name + "_metadata", key], + "operator": "Equal", + "valueText": where.get(key), + } + ) + + if len(weaviate_where_operands) == 1: + weaviate_where_clause = weaviate_where_operands[0] + else: + weaviate_where_clause = {"operator": "And", "operands": weaviate_where_operands} existing_ids = [] + metadatas = [] cursor = None + offset = 0 has_iterated_once = False + query_metadata_keys = self.metadata_keys.union(keys) while cursor is not None or not has_iterated_once: has_iterated_once = True - results = self._query_with_cursor( - self.client.query.get(self.index_name, ["identifier"]) + results = self._query_with_offset( + self.client.query.get( + self.index_name, + [ + "identifier", + weaviate.LinkTo("metadata", self.index_name + "_metadata", list(query_metadata_keys)), + ], + ) + .with_where(weaviate_where_clause) .with_additional(["id"]) - .with_limit(self.BATCH_SIZE), - cursor, + .with_limit(limit or self.BATCH_SIZE), + offset, ) + fetched_results = results["data"]["Get"].get(self.index_name, []) - if len(fetched_results) == 0: + if not fetched_results: break + for result in fetched_results: existing_ids.append(result["identifier"]) + metadatas.append(result["metadata"][0]) cursor = result["_additional"]["id"] + offset += 1 - return {"ids": existing_ids} + if limit is not None and len(existing_ids) >= limit: + break + + return {"ids": existing_ids, "metadatas": metadatas} def add(self, documents: list[str], metadatas: list[object], ids: list[str], **kwargs: Optional[dict[str, any]]): """add data in vector database @@ -201,21 +238,20 @@ class WeaviateDB(BaseVectorDB): query_vector = self.embedder.embedding_fn([input_query])[0] keys = set(where.keys() if where is not None else set()) data_fields = ["text"] - + query_metadata_keys = self.metadata_keys.union(keys) if citations: - data_fields.append(weaviate.LinkTo("metadata", self.index_name + "_metadata", list(self.metadata_keys))) + data_fields.append(weaviate.LinkTo("metadata", self.index_name + "_metadata", list(query_metadata_keys))) - if len(keys.intersection(self.metadata_keys)) != 0: + if len(keys) > 0: weaviate_where_operands = [] for key in keys: - if key in self.metadata_keys: - weaviate_where_operands.append( - { - "path": ["metadata", self.index_name + "_metadata", key], - "operator": "Equal", - "valueText": where.get(key), - } - ) + weaviate_where_operands.append( + { + "path": ["metadata", self.index_name + "_metadata", key], + "operator": "Equal", + "valueText": where.get(key), + } + ) if len(weaviate_where_operands) == 1: weaviate_where_clause = weaviate_where_operands[0] else: @@ -289,11 +325,37 @@ class WeaviateDB(BaseVectorDB): :return: Weaviate index :rtype: str """ - return f"{self.config.collection_name}_{self.embedder.vector_dimension}".capitalize() + return f"{self.config.collection_name}_{self.embedder.vector_dimension}".capitalize().replace("-", "_") @staticmethod - def _query_with_cursor(query, cursor): - if cursor is not None: - query.with_after(cursor) + def _query_with_offset(query, offset): + if offset: + query.with_offset(offset) results = query.do() return results + + def _generate_query(self, where: dict): + weaviate_where_operands = [] + for key, value in where.items(): + weaviate_where_operands.append( + { + "path": ["metadata", self.index_name + "_metadata", key], + "operator": "Equal", + "valueText": value, + } + ) + + if len(weaviate_where_operands) == 1: + weaviate_where_clause = weaviate_where_operands[0] + else: + weaviate_where_clause = {"operator": "And", "operands": weaviate_where_operands} + + return weaviate_where_clause + + def delete(self, where: dict): + """Delete from database. + :param where: to filter data + :type where: dict[str, any] + """ + query = self._generate_query(where) + self.client.batch.delete_objects(self.index_name, where=query) diff --git a/tests/vectordb/test_qdrant.py b/tests/vectordb/test_qdrant.py index 563f50be..c12c6848 100644 --- a/tests/vectordb/test_qdrant.py +++ b/tests/vectordb/test_qdrant.py @@ -56,9 +56,9 @@ class TestQdrantDB(unittest.TestCase): App(config=app_config, db=db, embedding_model=embedder) resp = db.get(ids=[], where={}) - self.assertEqual(resp, {"ids": []}) + self.assertEqual(resp, {"ids": [], "metadatas": []}) resp2 = db.get(ids=["123", "456"], where={"url": "https://ai.ai"}) - self.assertEqual(resp2, {"ids": []}) + self.assertEqual(resp2, {"ids": [], "metadatas": []}) @patch("embedchain.vectordb.qdrant.QdrantClient") @patch.object(uuid, "uuid4", side_effect=TEST_UUIDS) @@ -119,7 +119,7 @@ class TestQdrantDB(unittest.TestCase): query_filter=models.Filter( must=[ models.FieldCondition( - key="payload.metadata.doc_id", + key="metadata.doc_id", match=models.MatchValue( value="123", ),