[Bugfix] fix qdrant and weaviate db integration (#1181)

Co-authored-by: Deven Patel <deven298@yahoo.com>
This commit is contained in:
Deven Patel
2024-01-23 14:24:29 +05:30
committed by GitHub
parent 22e14b5e65
commit 2d9fbd4e49
3 changed files with 131 additions and 44 deletions

View File

@@ -11,6 +11,8 @@ try:
except ImportError: except ImportError:
raise ImportError("Qdrant requires extra dependencies. Install with `pip install embedchain[qdrant]`") from None raise ImportError("Qdrant requires extra dependencies. Install with `pip install embedchain[qdrant]`") from None
from tqdm import tqdm
from embedchain.config.vectordb.qdrant import QdrantDBConfig from embedchain.config.vectordb.qdrant import QdrantDBConfig
from embedchain.vectordb.base import BaseVectorDB from embedchain.vectordb.base import BaseVectorDB
@@ -48,7 +50,6 @@ class QdrantDB(BaseVectorDB):
raise ValueError("Embedder not set. Please set an embedder with `set_embedder` before initialization.") raise ValueError("Embedder not set. Please set an embedder with `set_embedder` before initialization.")
self.collection_name = self._get_or_create_collection() self.collection_name = self._get_or_create_collection()
self.metadata_keys = {"data_type", "doc_id", "url", "hash", "app_id", "text"}
all_collections = self.client.get_collections() all_collections = self.client.get_collections()
collection_names = [collection.name for collection in all_collections.collections] collection_names = [collection.name for collection in all_collections.collections]
if self.collection_name not in collection_names: if self.collection_name not in collection_names:
@@ -82,21 +83,23 @@ class QdrantDB(BaseVectorDB):
:return: All the existing IDs :return: All the existing IDs
:rtype: Set[str] :rtype: Set[str]
""" """
if ids is None or len(ids) == 0:
return {"ids": []}
keys = set(where.keys() if where is not None else set()) keys = set(where.keys() if where is not None else set())
qdrant_must_filters = [ qdrant_must_filters = []
models.FieldCondition(
key="identifier", if ids:
match=models.MatchAny( qdrant_must_filters.append(
any=ids, models.FieldCondition(
), key="identifier",
match=models.MatchAny(
any=ids,
),
)
) )
]
if len(keys.intersection(self.metadata_keys)) != 0: if len(keys) > 0:
for key in keys.intersection(self.metadata_keys): for key in keys:
qdrant_must_filters.append( qdrant_must_filters.append(
models.FieldCondition( models.FieldCondition(
key="metadata.{}".format(key), key="metadata.{}".format(key),
@@ -108,6 +111,7 @@ class QdrantDB(BaseVectorDB):
offset = 0 offset = 0
existing_ids = [] existing_ids = []
metadatas = []
while offset is not None: while offset is not None:
response = self.client.scroll( response = self.client.scroll(
collection_name=self.collection_name, collection_name=self.collection_name,
@@ -118,7 +122,8 @@ class QdrantDB(BaseVectorDB):
offset = response[1] offset = response[1]
for doc in response[0]: for doc in response[0]:
existing_ids.append(doc.payload["identifier"]) existing_ids.append(doc.payload["identifier"])
return {"ids": existing_ids} metadatas.append(doc.payload["metadata"])
return {"ids": existing_ids, "metadatas": metadatas}
def add( def add(
self, self,
@@ -143,7 +148,8 @@ class QdrantDB(BaseVectorDB):
metadata["text"] = document metadata["text"] = document
qdrant_ids.append(str(uuid.uuid4())) qdrant_ids.append(str(uuid.uuid4()))
payloads.append({"identifier": id, "text": document, "metadata": copy.deepcopy(metadata)}) payloads.append({"identifier": id, "text": document, "metadata": copy.deepcopy(metadata)})
for i in range(0, len(qdrant_ids), self.BATCH_SIZE):
for i in tqdm(range(0, len(qdrant_ids), self.BATCH_SIZE), desc="Adding data in batches"):
self.client.upsert( self.client.upsert(
collection_name=self.collection_name, collection_name=self.collection_name,
points=Batch( points=Batch(
@@ -180,16 +186,17 @@ class QdrantDB(BaseVectorDB):
keys = set(where.keys() if where is not None else set()) keys = set(where.keys() if where is not None else set())
qdrant_must_filters = [] qdrant_must_filters = []
if len(keys.intersection(self.metadata_keys)) != 0: if len(keys) > 0:
for key in keys.intersection(self.metadata_keys): for key in keys:
qdrant_must_filters.append( qdrant_must_filters.append(
models.FieldCondition( models.FieldCondition(
key="payload.metadata.{}".format(key), key="metadata.{}".format(key),
match=models.MatchValue( match=models.MatchValue(
value=where.get(key), value=where.get(key),
), ),
) )
) )
results = self.client.search( results = self.client.search(
collection_name=self.collection_name, collection_name=self.collection_name,
query_filter=models.Filter(must=qdrant_must_filters), query_filter=models.Filter(must=qdrant_must_filters),
@@ -228,3 +235,21 @@ class QdrantDB(BaseVectorDB):
raise TypeError("Collection name must be a string") raise TypeError("Collection name must be a string")
self.config.collection_name = name self.config.collection_name = name
self.collection_name = self._get_or_create_collection() self.collection_name = self._get_or_create_collection()
@staticmethod
def _generate_query(where: dict):
must_fields = []
for key, value in where.items():
must_fields.append(
models.FieldCondition(
key=f"metadata.{key}",
match=models.MatchValue(
value=value,
),
)
)
return models.Filter(must=must_fields)
def delete(self, where: dict):
db_filter = self._generate_query(where)
self.client.delete(collection_name=self.collection_name, points_selector=db_filter)

View File

@@ -45,6 +45,9 @@ class WeaviateDB(BaseVectorDB):
auth_client_secret=weaviate.AuthApiKey(api_key=os.environ.get("WEAVIATE_API_KEY")), auth_client_secret=weaviate.AuthApiKey(api_key=os.environ.get("WEAVIATE_API_KEY")),
**self.config.extra_params, **self.config.extra_params,
) )
# Since weaviate uses graphQL, we need to keep track of metadata keys added in the vectordb.
# This is needed to filter data while querying.
self.metadata_keys = {"data_type", "doc_id", "url", "hash", "app_id"}
# Call parent init here because embedder is needed # Call parent init here because embedder is needed
super().__init__(config=self.config) super().__init__(config=self.config)
@@ -58,7 +61,6 @@ class WeaviateDB(BaseVectorDB):
raise ValueError("Embedder not set. Please set an embedder with `set_embedder` before initialization.") raise ValueError("Embedder not set. Please set an embedder with `set_embedder` before initialization.")
self.index_name = self._get_index_name() self.index_name = self._get_index_name()
self.metadata_keys = {"data_type", "doc_id", "url", "hash", "app_id"}
if not self.client.schema.exists(self.index_name): if not self.client.schema.exists(self.index_name):
# id is a reserved field in Weaviate, hence we had to change the name of the id field to identifier # id is a reserved field in Weaviate, hence we had to change the name of the id field to identifier
# The none vectorizer is crucial as we have our own custom embedding function # The none vectorizer is crucial as we have our own custom embedding function
@@ -127,29 +129,64 @@ class WeaviateDB(BaseVectorDB):
:return: ids :return: ids
:rtype: Set[str] :rtype: Set[str]
""" """
weaviate_where_operands = []
if ids is None or len(ids) == 0: if ids:
return {"ids": []} for doc_id in ids:
weaviate_where_operands.append({"path": ["identifier"], "operator": "Equal", "valueText": doc_id})
keys = set(where.keys() if where is not None else set())
if len(keys) > 0:
for key in keys:
weaviate_where_operands.append(
{
"path": ["metadata", self.index_name + "_metadata", key],
"operator": "Equal",
"valueText": where.get(key),
}
)
if len(weaviate_where_operands) == 1:
weaviate_where_clause = weaviate_where_operands[0]
else:
weaviate_where_clause = {"operator": "And", "operands": weaviate_where_operands}
existing_ids = [] existing_ids = []
metadatas = []
cursor = None cursor = None
offset = 0
has_iterated_once = False has_iterated_once = False
query_metadata_keys = self.metadata_keys.union(keys)
while cursor is not None or not has_iterated_once: while cursor is not None or not has_iterated_once:
has_iterated_once = True has_iterated_once = True
results = self._query_with_cursor( results = self._query_with_offset(
self.client.query.get(self.index_name, ["identifier"]) self.client.query.get(
self.index_name,
[
"identifier",
weaviate.LinkTo("metadata", self.index_name + "_metadata", list(query_metadata_keys)),
],
)
.with_where(weaviate_where_clause)
.with_additional(["id"]) .with_additional(["id"])
.with_limit(self.BATCH_SIZE), .with_limit(limit or self.BATCH_SIZE),
cursor, offset,
) )
fetched_results = results["data"]["Get"].get(self.index_name, []) fetched_results = results["data"]["Get"].get(self.index_name, [])
if len(fetched_results) == 0: if not fetched_results:
break break
for result in fetched_results: for result in fetched_results:
existing_ids.append(result["identifier"]) existing_ids.append(result["identifier"])
metadatas.append(result["metadata"][0])
cursor = result["_additional"]["id"] cursor = result["_additional"]["id"]
offset += 1
return {"ids": existing_ids} if limit is not None and len(existing_ids) >= limit:
break
return {"ids": existing_ids, "metadatas": metadatas}
def add(self, documents: list[str], metadatas: list[object], ids: list[str], **kwargs: Optional[dict[str, any]]): def add(self, documents: list[str], metadatas: list[object], ids: list[str], **kwargs: Optional[dict[str, any]]):
"""add data in vector database """add data in vector database
@@ -201,21 +238,20 @@ class WeaviateDB(BaseVectorDB):
query_vector = self.embedder.embedding_fn([input_query])[0] query_vector = self.embedder.embedding_fn([input_query])[0]
keys = set(where.keys() if where is not None else set()) keys = set(where.keys() if where is not None else set())
data_fields = ["text"] data_fields = ["text"]
query_metadata_keys = self.metadata_keys.union(keys)
if citations: if citations:
data_fields.append(weaviate.LinkTo("metadata", self.index_name + "_metadata", list(self.metadata_keys))) data_fields.append(weaviate.LinkTo("metadata", self.index_name + "_metadata", list(query_metadata_keys)))
if len(keys.intersection(self.metadata_keys)) != 0: if len(keys) > 0:
weaviate_where_operands = [] weaviate_where_operands = []
for key in keys: for key in keys:
if key in self.metadata_keys: weaviate_where_operands.append(
weaviate_where_operands.append( {
{ "path": ["metadata", self.index_name + "_metadata", key],
"path": ["metadata", self.index_name + "_metadata", key], "operator": "Equal",
"operator": "Equal", "valueText": where.get(key),
"valueText": where.get(key), }
} )
)
if len(weaviate_where_operands) == 1: if len(weaviate_where_operands) == 1:
weaviate_where_clause = weaviate_where_operands[0] weaviate_where_clause = weaviate_where_operands[0]
else: else:
@@ -289,11 +325,37 @@ class WeaviateDB(BaseVectorDB):
:return: Weaviate index :return: Weaviate index
:rtype: str :rtype: str
""" """
return f"{self.config.collection_name}_{self.embedder.vector_dimension}".capitalize() return f"{self.config.collection_name}_{self.embedder.vector_dimension}".capitalize().replace("-", "_")
@staticmethod @staticmethod
def _query_with_cursor(query, cursor): def _query_with_offset(query, offset):
if cursor is not None: if offset:
query.with_after(cursor) query.with_offset(offset)
results = query.do() results = query.do()
return results return results
def _generate_query(self, where: dict):
weaviate_where_operands = []
for key, value in where.items():
weaviate_where_operands.append(
{
"path": ["metadata", self.index_name + "_metadata", key],
"operator": "Equal",
"valueText": value,
}
)
if len(weaviate_where_operands) == 1:
weaviate_where_clause = weaviate_where_operands[0]
else:
weaviate_where_clause = {"operator": "And", "operands": weaviate_where_operands}
return weaviate_where_clause
def delete(self, where: dict):
"""Delete from database.
:param where: to filter data
:type where: dict[str, any]
"""
query = self._generate_query(where)
self.client.batch.delete_objects(self.index_name, where=query)

View File

@@ -56,9 +56,9 @@ class TestQdrantDB(unittest.TestCase):
App(config=app_config, db=db, embedding_model=embedder) App(config=app_config, db=db, embedding_model=embedder)
resp = db.get(ids=[], where={}) resp = db.get(ids=[], where={})
self.assertEqual(resp, {"ids": []}) self.assertEqual(resp, {"ids": [], "metadatas": []})
resp2 = db.get(ids=["123", "456"], where={"url": "https://ai.ai"}) resp2 = db.get(ids=["123", "456"], where={"url": "https://ai.ai"})
self.assertEqual(resp2, {"ids": []}) self.assertEqual(resp2, {"ids": [], "metadatas": []})
@patch("embedchain.vectordb.qdrant.QdrantClient") @patch("embedchain.vectordb.qdrant.QdrantClient")
@patch.object(uuid, "uuid4", side_effect=TEST_UUIDS) @patch.object(uuid, "uuid4", side_effect=TEST_UUIDS)
@@ -119,7 +119,7 @@ class TestQdrantDB(unittest.TestCase):
query_filter=models.Filter( query_filter=models.Filter(
must=[ must=[
models.FieldCondition( models.FieldCondition(
key="payload.metadata.doc_id", key="metadata.doc_id",
match=models.MatchValue( match=models.MatchValue(
value="123", value="123",
), ),