[Bugfix] fix qdrant and weaviate db integration (#1181)
Co-authored-by: Deven Patel <deven298@yahoo.com>
This commit is contained in:
@@ -11,6 +11,8 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
raise ImportError("Qdrant requires extra dependencies. Install with `pip install embedchain[qdrant]`") from None
|
raise ImportError("Qdrant requires extra dependencies. Install with `pip install embedchain[qdrant]`") from None
|
||||||
|
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
from embedchain.config.vectordb.qdrant import QdrantDBConfig
|
from embedchain.config.vectordb.qdrant import QdrantDBConfig
|
||||||
from embedchain.vectordb.base import BaseVectorDB
|
from embedchain.vectordb.base import BaseVectorDB
|
||||||
|
|
||||||
@@ -48,7 +50,6 @@ class QdrantDB(BaseVectorDB):
|
|||||||
raise ValueError("Embedder not set. Please set an embedder with `set_embedder` before initialization.")
|
raise ValueError("Embedder not set. Please set an embedder with `set_embedder` before initialization.")
|
||||||
|
|
||||||
self.collection_name = self._get_or_create_collection()
|
self.collection_name = self._get_or_create_collection()
|
||||||
self.metadata_keys = {"data_type", "doc_id", "url", "hash", "app_id", "text"}
|
|
||||||
all_collections = self.client.get_collections()
|
all_collections = self.client.get_collections()
|
||||||
collection_names = [collection.name for collection in all_collections.collections]
|
collection_names = [collection.name for collection in all_collections.collections]
|
||||||
if self.collection_name not in collection_names:
|
if self.collection_name not in collection_names:
|
||||||
@@ -82,21 +83,23 @@ class QdrantDB(BaseVectorDB):
|
|||||||
:return: All the existing IDs
|
:return: All the existing IDs
|
||||||
:rtype: Set[str]
|
:rtype: Set[str]
|
||||||
"""
|
"""
|
||||||
if ids is None or len(ids) == 0:
|
|
||||||
return {"ids": []}
|
|
||||||
|
|
||||||
keys = set(where.keys() if where is not None else set())
|
keys = set(where.keys() if where is not None else set())
|
||||||
|
|
||||||
qdrant_must_filters = [
|
qdrant_must_filters = []
|
||||||
models.FieldCondition(
|
|
||||||
key="identifier",
|
if ids:
|
||||||
match=models.MatchAny(
|
qdrant_must_filters.append(
|
||||||
any=ids,
|
models.FieldCondition(
|
||||||
),
|
key="identifier",
|
||||||
|
match=models.MatchAny(
|
||||||
|
any=ids,
|
||||||
|
),
|
||||||
|
)
|
||||||
)
|
)
|
||||||
]
|
|
||||||
if len(keys.intersection(self.metadata_keys)) != 0:
|
if len(keys) > 0:
|
||||||
for key in keys.intersection(self.metadata_keys):
|
for key in keys:
|
||||||
qdrant_must_filters.append(
|
qdrant_must_filters.append(
|
||||||
models.FieldCondition(
|
models.FieldCondition(
|
||||||
key="metadata.{}".format(key),
|
key="metadata.{}".format(key),
|
||||||
@@ -108,6 +111,7 @@ class QdrantDB(BaseVectorDB):
|
|||||||
|
|
||||||
offset = 0
|
offset = 0
|
||||||
existing_ids = []
|
existing_ids = []
|
||||||
|
metadatas = []
|
||||||
while offset is not None:
|
while offset is not None:
|
||||||
response = self.client.scroll(
|
response = self.client.scroll(
|
||||||
collection_name=self.collection_name,
|
collection_name=self.collection_name,
|
||||||
@@ -118,7 +122,8 @@ class QdrantDB(BaseVectorDB):
|
|||||||
offset = response[1]
|
offset = response[1]
|
||||||
for doc in response[0]:
|
for doc in response[0]:
|
||||||
existing_ids.append(doc.payload["identifier"])
|
existing_ids.append(doc.payload["identifier"])
|
||||||
return {"ids": existing_ids}
|
metadatas.append(doc.payload["metadata"])
|
||||||
|
return {"ids": existing_ids, "metadatas": metadatas}
|
||||||
|
|
||||||
def add(
|
def add(
|
||||||
self,
|
self,
|
||||||
@@ -143,7 +148,8 @@ class QdrantDB(BaseVectorDB):
|
|||||||
metadata["text"] = document
|
metadata["text"] = document
|
||||||
qdrant_ids.append(str(uuid.uuid4()))
|
qdrant_ids.append(str(uuid.uuid4()))
|
||||||
payloads.append({"identifier": id, "text": document, "metadata": copy.deepcopy(metadata)})
|
payloads.append({"identifier": id, "text": document, "metadata": copy.deepcopy(metadata)})
|
||||||
for i in range(0, len(qdrant_ids), self.BATCH_SIZE):
|
|
||||||
|
for i in tqdm(range(0, len(qdrant_ids), self.BATCH_SIZE), desc="Adding data in batches"):
|
||||||
self.client.upsert(
|
self.client.upsert(
|
||||||
collection_name=self.collection_name,
|
collection_name=self.collection_name,
|
||||||
points=Batch(
|
points=Batch(
|
||||||
@@ -180,16 +186,17 @@ class QdrantDB(BaseVectorDB):
|
|||||||
keys = set(where.keys() if where is not None else set())
|
keys = set(where.keys() if where is not None else set())
|
||||||
|
|
||||||
qdrant_must_filters = []
|
qdrant_must_filters = []
|
||||||
if len(keys.intersection(self.metadata_keys)) != 0:
|
if len(keys) > 0:
|
||||||
for key in keys.intersection(self.metadata_keys):
|
for key in keys:
|
||||||
qdrant_must_filters.append(
|
qdrant_must_filters.append(
|
||||||
models.FieldCondition(
|
models.FieldCondition(
|
||||||
key="payload.metadata.{}".format(key),
|
key="metadata.{}".format(key),
|
||||||
match=models.MatchValue(
|
match=models.MatchValue(
|
||||||
value=where.get(key),
|
value=where.get(key),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
results = self.client.search(
|
results = self.client.search(
|
||||||
collection_name=self.collection_name,
|
collection_name=self.collection_name,
|
||||||
query_filter=models.Filter(must=qdrant_must_filters),
|
query_filter=models.Filter(must=qdrant_must_filters),
|
||||||
@@ -228,3 +235,21 @@ class QdrantDB(BaseVectorDB):
|
|||||||
raise TypeError("Collection name must be a string")
|
raise TypeError("Collection name must be a string")
|
||||||
self.config.collection_name = name
|
self.config.collection_name = name
|
||||||
self.collection_name = self._get_or_create_collection()
|
self.collection_name = self._get_or_create_collection()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _generate_query(where: dict):
|
||||||
|
must_fields = []
|
||||||
|
for key, value in where.items():
|
||||||
|
must_fields.append(
|
||||||
|
models.FieldCondition(
|
||||||
|
key=f"metadata.{key}",
|
||||||
|
match=models.MatchValue(
|
||||||
|
value=value,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return models.Filter(must=must_fields)
|
||||||
|
|
||||||
|
def delete(self, where: dict):
|
||||||
|
db_filter = self._generate_query(where)
|
||||||
|
self.client.delete(collection_name=self.collection_name, points_selector=db_filter)
|
||||||
|
|||||||
@@ -45,6 +45,9 @@ class WeaviateDB(BaseVectorDB):
|
|||||||
auth_client_secret=weaviate.AuthApiKey(api_key=os.environ.get("WEAVIATE_API_KEY")),
|
auth_client_secret=weaviate.AuthApiKey(api_key=os.environ.get("WEAVIATE_API_KEY")),
|
||||||
**self.config.extra_params,
|
**self.config.extra_params,
|
||||||
)
|
)
|
||||||
|
# Since weaviate uses graphQL, we need to keep track of metadata keys added in the vectordb.
|
||||||
|
# This is needed to filter data while querying.
|
||||||
|
self.metadata_keys = {"data_type", "doc_id", "url", "hash", "app_id"}
|
||||||
|
|
||||||
# Call parent init here because embedder is needed
|
# Call parent init here because embedder is needed
|
||||||
super().__init__(config=self.config)
|
super().__init__(config=self.config)
|
||||||
@@ -58,7 +61,6 @@ class WeaviateDB(BaseVectorDB):
|
|||||||
raise ValueError("Embedder not set. Please set an embedder with `set_embedder` before initialization.")
|
raise ValueError("Embedder not set. Please set an embedder with `set_embedder` before initialization.")
|
||||||
|
|
||||||
self.index_name = self._get_index_name()
|
self.index_name = self._get_index_name()
|
||||||
self.metadata_keys = {"data_type", "doc_id", "url", "hash", "app_id"}
|
|
||||||
if not self.client.schema.exists(self.index_name):
|
if not self.client.schema.exists(self.index_name):
|
||||||
# id is a reserved field in Weaviate, hence we had to change the name of the id field to identifier
|
# id is a reserved field in Weaviate, hence we had to change the name of the id field to identifier
|
||||||
# The none vectorizer is crucial as we have our own custom embedding function
|
# The none vectorizer is crucial as we have our own custom embedding function
|
||||||
@@ -127,29 +129,64 @@ class WeaviateDB(BaseVectorDB):
|
|||||||
:return: ids
|
:return: ids
|
||||||
:rtype: Set[str]
|
:rtype: Set[str]
|
||||||
"""
|
"""
|
||||||
|
weaviate_where_operands = []
|
||||||
|
|
||||||
if ids is None or len(ids) == 0:
|
if ids:
|
||||||
return {"ids": []}
|
for doc_id in ids:
|
||||||
|
weaviate_where_operands.append({"path": ["identifier"], "operator": "Equal", "valueText": doc_id})
|
||||||
|
|
||||||
|
keys = set(where.keys() if where is not None else set())
|
||||||
|
if len(keys) > 0:
|
||||||
|
for key in keys:
|
||||||
|
weaviate_where_operands.append(
|
||||||
|
{
|
||||||
|
"path": ["metadata", self.index_name + "_metadata", key],
|
||||||
|
"operator": "Equal",
|
||||||
|
"valueText": where.get(key),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(weaviate_where_operands) == 1:
|
||||||
|
weaviate_where_clause = weaviate_where_operands[0]
|
||||||
|
else:
|
||||||
|
weaviate_where_clause = {"operator": "And", "operands": weaviate_where_operands}
|
||||||
|
|
||||||
existing_ids = []
|
existing_ids = []
|
||||||
|
metadatas = []
|
||||||
cursor = None
|
cursor = None
|
||||||
|
offset = 0
|
||||||
has_iterated_once = False
|
has_iterated_once = False
|
||||||
|
query_metadata_keys = self.metadata_keys.union(keys)
|
||||||
while cursor is not None or not has_iterated_once:
|
while cursor is not None or not has_iterated_once:
|
||||||
has_iterated_once = True
|
has_iterated_once = True
|
||||||
results = self._query_with_cursor(
|
results = self._query_with_offset(
|
||||||
self.client.query.get(self.index_name, ["identifier"])
|
self.client.query.get(
|
||||||
|
self.index_name,
|
||||||
|
[
|
||||||
|
"identifier",
|
||||||
|
weaviate.LinkTo("metadata", self.index_name + "_metadata", list(query_metadata_keys)),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
.with_where(weaviate_where_clause)
|
||||||
.with_additional(["id"])
|
.with_additional(["id"])
|
||||||
.with_limit(self.BATCH_SIZE),
|
.with_limit(limit or self.BATCH_SIZE),
|
||||||
cursor,
|
offset,
|
||||||
)
|
)
|
||||||
|
|
||||||
fetched_results = results["data"]["Get"].get(self.index_name, [])
|
fetched_results = results["data"]["Get"].get(self.index_name, [])
|
||||||
if len(fetched_results) == 0:
|
if not fetched_results:
|
||||||
break
|
break
|
||||||
|
|
||||||
for result in fetched_results:
|
for result in fetched_results:
|
||||||
existing_ids.append(result["identifier"])
|
existing_ids.append(result["identifier"])
|
||||||
|
metadatas.append(result["metadata"][0])
|
||||||
cursor = result["_additional"]["id"]
|
cursor = result["_additional"]["id"]
|
||||||
|
offset += 1
|
||||||
|
|
||||||
return {"ids": existing_ids}
|
if limit is not None and len(existing_ids) >= limit:
|
||||||
|
break
|
||||||
|
|
||||||
|
return {"ids": existing_ids, "metadatas": metadatas}
|
||||||
|
|
||||||
def add(self, documents: list[str], metadatas: list[object], ids: list[str], **kwargs: Optional[dict[str, any]]):
|
def add(self, documents: list[str], metadatas: list[object], ids: list[str], **kwargs: Optional[dict[str, any]]):
|
||||||
"""add data in vector database
|
"""add data in vector database
|
||||||
@@ -201,21 +238,20 @@ class WeaviateDB(BaseVectorDB):
|
|||||||
query_vector = self.embedder.embedding_fn([input_query])[0]
|
query_vector = self.embedder.embedding_fn([input_query])[0]
|
||||||
keys = set(where.keys() if where is not None else set())
|
keys = set(where.keys() if where is not None else set())
|
||||||
data_fields = ["text"]
|
data_fields = ["text"]
|
||||||
|
query_metadata_keys = self.metadata_keys.union(keys)
|
||||||
if citations:
|
if citations:
|
||||||
data_fields.append(weaviate.LinkTo("metadata", self.index_name + "_metadata", list(self.metadata_keys)))
|
data_fields.append(weaviate.LinkTo("metadata", self.index_name + "_metadata", list(query_metadata_keys)))
|
||||||
|
|
||||||
if len(keys.intersection(self.metadata_keys)) != 0:
|
if len(keys) > 0:
|
||||||
weaviate_where_operands = []
|
weaviate_where_operands = []
|
||||||
for key in keys:
|
for key in keys:
|
||||||
if key in self.metadata_keys:
|
weaviate_where_operands.append(
|
||||||
weaviate_where_operands.append(
|
{
|
||||||
{
|
"path": ["metadata", self.index_name + "_metadata", key],
|
||||||
"path": ["metadata", self.index_name + "_metadata", key],
|
"operator": "Equal",
|
||||||
"operator": "Equal",
|
"valueText": where.get(key),
|
||||||
"valueText": where.get(key),
|
}
|
||||||
}
|
)
|
||||||
)
|
|
||||||
if len(weaviate_where_operands) == 1:
|
if len(weaviate_where_operands) == 1:
|
||||||
weaviate_where_clause = weaviate_where_operands[0]
|
weaviate_where_clause = weaviate_where_operands[0]
|
||||||
else:
|
else:
|
||||||
@@ -289,11 +325,37 @@ class WeaviateDB(BaseVectorDB):
|
|||||||
:return: Weaviate index
|
:return: Weaviate index
|
||||||
:rtype: str
|
:rtype: str
|
||||||
"""
|
"""
|
||||||
return f"{self.config.collection_name}_{self.embedder.vector_dimension}".capitalize()
|
return f"{self.config.collection_name}_{self.embedder.vector_dimension}".capitalize().replace("-", "_")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _query_with_cursor(query, cursor):
|
def _query_with_offset(query, offset):
|
||||||
if cursor is not None:
|
if offset:
|
||||||
query.with_after(cursor)
|
query.with_offset(offset)
|
||||||
results = query.do()
|
results = query.do()
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
def _generate_query(self, where: dict):
|
||||||
|
weaviate_where_operands = []
|
||||||
|
for key, value in where.items():
|
||||||
|
weaviate_where_operands.append(
|
||||||
|
{
|
||||||
|
"path": ["metadata", self.index_name + "_metadata", key],
|
||||||
|
"operator": "Equal",
|
||||||
|
"valueText": value,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(weaviate_where_operands) == 1:
|
||||||
|
weaviate_where_clause = weaviate_where_operands[0]
|
||||||
|
else:
|
||||||
|
weaviate_where_clause = {"operator": "And", "operands": weaviate_where_operands}
|
||||||
|
|
||||||
|
return weaviate_where_clause
|
||||||
|
|
||||||
|
def delete(self, where: dict):
|
||||||
|
"""Delete from database.
|
||||||
|
:param where: to filter data
|
||||||
|
:type where: dict[str, any]
|
||||||
|
"""
|
||||||
|
query = self._generate_query(where)
|
||||||
|
self.client.batch.delete_objects(self.index_name, where=query)
|
||||||
|
|||||||
@@ -56,9 +56,9 @@ class TestQdrantDB(unittest.TestCase):
|
|||||||
App(config=app_config, db=db, embedding_model=embedder)
|
App(config=app_config, db=db, embedding_model=embedder)
|
||||||
|
|
||||||
resp = db.get(ids=[], where={})
|
resp = db.get(ids=[], where={})
|
||||||
self.assertEqual(resp, {"ids": []})
|
self.assertEqual(resp, {"ids": [], "metadatas": []})
|
||||||
resp2 = db.get(ids=["123", "456"], where={"url": "https://ai.ai"})
|
resp2 = db.get(ids=["123", "456"], where={"url": "https://ai.ai"})
|
||||||
self.assertEqual(resp2, {"ids": []})
|
self.assertEqual(resp2, {"ids": [], "metadatas": []})
|
||||||
|
|
||||||
@patch("embedchain.vectordb.qdrant.QdrantClient")
|
@patch("embedchain.vectordb.qdrant.QdrantClient")
|
||||||
@patch.object(uuid, "uuid4", side_effect=TEST_UUIDS)
|
@patch.object(uuid, "uuid4", side_effect=TEST_UUIDS)
|
||||||
@@ -119,7 +119,7 @@ class TestQdrantDB(unittest.TestCase):
|
|||||||
query_filter=models.Filter(
|
query_filter=models.Filter(
|
||||||
must=[
|
must=[
|
||||||
models.FieldCondition(
|
models.FieldCondition(
|
||||||
key="payload.metadata.doc_id",
|
key="metadata.doc_id",
|
||||||
match=models.MatchValue(
|
match=models.MatchValue(
|
||||||
value="123",
|
value="123",
|
||||||
),
|
),
|
||||||
|
|||||||
Reference in New Issue
Block a user