[Feature] Update db.query to return source of context (#831)

This commit is contained in:
Deven Patel
2023-10-25 22:20:32 -07:00
committed by GitHub
parent a27eeb3255
commit d77e8da3f3
13 changed files with 195 additions and 73 deletions

1
.gitignore vendored
View File

@@ -174,4 +174,5 @@ test-db
notebooks/*.yaml notebooks/*.yaml
.ipynb_checkpoints/ .ipynb_checkpoints/
!configs/*.yaml !configs/*.yaml

View File

@@ -500,13 +500,17 @@ class EmbedChain(JSONSerializable):
db_query = ClipProcessor.get_text_features(query=input_query) db_query = ClipProcessor.get_text_features(query=input_query)
contents = self.db.query( contexts = self.db.query(
input_query=db_query, input_query=db_query,
n_results=query_config.number_documents, n_results=query_config.number_documents,
where=where, where=where,
skip_embedding=(hasattr(config, "query_type") and config.query_type == "Images"), skip_embedding=(hasattr(config, "query_type") and config.query_type == "Images"),
) )
return contents
if len(contexts) > 0 and isinstance(contexts[0], tuple):
contexts = list(map(lambda x: x[0], contexts))
return contexts
def query(self, input_query: str, config: BaseLlmConfig = None, dry_run=False, where: Optional[Dict] = None) -> str: def query(self, input_query: str, config: BaseLlmConfig = None, dry_run=False, where: Optional[Dict] = None) -> str:
""" """

View File

@@ -41,15 +41,15 @@ class LlmFactory:
class EmbedderFactory: class EmbedderFactory:
provider_to_class = { provider_to_class = {
"azure_openai": "embedchain.embedder.openai.OpenAIEmbedder",
"gpt4all": "embedchain.embedder.gpt4all.GPT4AllEmbedder", "gpt4all": "embedchain.embedder.gpt4all.GPT4AllEmbedder",
"huggingface": "embedchain.embedder.huggingface.HuggingFaceEmbedder", "huggingface": "embedchain.embedder.huggingface.HuggingFaceEmbedder",
"vertexai": "embedchain.embedder.vertexai.VertexAIEmbedder",
"azure_openai": "embedchain.embedder.openai.OpenAIEmbedder",
"openai": "embedchain.embedder.openai.OpenAIEmbedder", "openai": "embedchain.embedder.openai.OpenAIEmbedder",
"vertexai": "embedchain.embedder.vertexai.VertexAIEmbedder",
} }
provider_to_config_class = { provider_to_config_class = {
"openai": "embedchain.config.embedder.base.BaseEmbedderConfig",
"azure_openai": "embedchain.config.embedder.base.BaseEmbedderConfig", "azure_openai": "embedchain.config.embedder.base.BaseEmbedderConfig",
"openai": "embedchain.config.embedder.base.BaseEmbedderConfig",
} }
@classmethod @classmethod
@@ -72,16 +72,18 @@ class VectorDBFactory:
"elasticsearch": "embedchain.vectordb.elasticsearch.ElasticsearchDB", "elasticsearch": "embedchain.vectordb.elasticsearch.ElasticsearchDB",
"opensearch": "embedchain.vectordb.opensearch.OpenSearchDB", "opensearch": "embedchain.vectordb.opensearch.OpenSearchDB",
"pinecone": "embedchain.vectordb.pinecone.PineconeDB", "pinecone": "embedchain.vectordb.pinecone.PineconeDB",
"weaviate": "embedchain.vectordb.weaviate.WeaviateDB",
"qdrant": "embedchain.vectordb.qdrant.QdrantDB", "qdrant": "embedchain.vectordb.qdrant.QdrantDB",
"weaviate": "embedchain.vectordb.weaviate.WeaviateDB",
"zilliz": "embedchain.vectordb.zilliz.ZillizVectorDB",
} }
provider_to_config_class = { provider_to_config_class = {
"chroma": "embedchain.config.vectordb.chroma.ChromaDbConfig", "chroma": "embedchain.config.vectordb.chroma.ChromaDbConfig",
"elasticsearch": "embedchain.config.vectordb.elasticsearch.ElasticsearchDBConfig", "elasticsearch": "embedchain.config.vectordb.elasticsearch.ElasticsearchDBConfig",
"opensearch": "embedchain.config.vectordb.opensearch.OpenSearchDBConfig", "opensearch": "embedchain.config.vectordb.opensearch.OpenSearchDBConfig",
"pinecone": "embedchain.config.vectordb.pinecone.PineconeDBConfig", "pinecone": "embedchain.config.vectordb.pinecone.PineconeDBConfig",
"weaviate": "embedchain.config.vectordb.weaviate.WeaviateDBConfig",
"qdrant": "embedchain.config.vectordb.qdrant.QdrantDBConfig", "qdrant": "embedchain.config.vectordb.qdrant.QdrantDBConfig",
"weaviate": "embedchain.config.vectordb.weaviate.WeaviateDBConfig",
"zilliz": "embedchain.config.vectordb.zilliz.ZillizDBConfig",
} }
@classmethod @classmethod

View File

@@ -1,5 +1,5 @@
import logging import logging
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional, Tuple
from chromadb import Collection, QueryResult from chromadb import Collection, QueryResult
from langchain.docstore.document import Document from langchain.docstore.document import Document
@@ -191,7 +191,9 @@ class ChromaDB(BaseVectorDB):
) )
] ]
def query(self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool) -> List[str]: def query(
self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool
) -> List[Tuple[str, str, str]]:
""" """
Query contents from vector database based on vector similarity Query contents from vector database based on vector similarity
@@ -204,8 +206,8 @@ class ChromaDB(BaseVectorDB):
:param skip_embedding: Optional. If True, then the input_query is assumed to be already embedded. :param skip_embedding: Optional. If True, then the input_query is assumed to be already embedded.
:type skip_embedding: bool :type skip_embedding: bool
:raises InvalidDimensionException: Dimensions do not match. :raises InvalidDimensionException: Dimensions do not match.
:return: The content of the document that matched your query. :return: The content of the document that matched your query, url of the source, doc_id
:rtype: List[str] :rtype: List[Tuple[str,str,str]]
""" """
try: try:
if skip_embedding: if skip_embedding:
@@ -231,8 +233,14 @@ class ChromaDB(BaseVectorDB):
" embeddings, is used to retrieve an embedding from the database." " embeddings, is used to retrieve an embedding from the database."
) from None ) from None
results_formatted = self._format_result(result) results_formatted = self._format_result(result)
contents = [result[0].page_content for result in results_formatted] contexts = []
return contents for result in results_formatted:
context = result[0].page_content
metadata = result[0].metadata
source = metadata["url"]
doc_id = metadata["doc_id"]
contexts.append((context, source, doc_id))
return contexts
def set_collection_name(self, name: str): def set_collection_name(self, name: str):
""" """

View File

@@ -1,5 +1,5 @@
import logging import logging
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional, Tuple
try: try:
from elasticsearch import Elasticsearch from elasticsearch import Elasticsearch
@@ -135,7 +135,9 @@ class ElasticsearchDB(BaseVectorDB):
bulk(self.client, docs) bulk(self.client, docs)
self.client.indices.refresh(index=self._get_index()) self.client.indices.refresh(index=self._get_index())
def query(self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool) -> List[str]: def query(
self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool
) -> List[Tuple[str, str, str]]:
""" """
query contents from vector data base based on vector similarity query contents from vector data base based on vector similarity
@@ -147,8 +149,9 @@ class ElasticsearchDB(BaseVectorDB):
:type where: Dict[str, any] :type where: Dict[str, any]
:param skip_embedding: Optional. If True, then the input_query is assumed to be already embedded. :param skip_embedding: Optional. If True, then the input_query is assumed to be already embedded.
:type skip_embedding: bool :type skip_embedding: bool
:return: Database contents that are the result of the query :return: The context of the document that matched your query, url of the source, doc_id
:rtype: List[str]
:rtype: List[Tuple[str,str,str]]
""" """
if skip_embedding: if skip_embedding:
query_vector = input_query query_vector = input_query
@@ -156,6 +159,7 @@ class ElasticsearchDB(BaseVectorDB):
input_query_vector = self.embedder.embedding_fn(input_query) input_query_vector = self.embedder.embedding_fn(input_query)
query_vector = input_query_vector[0] query_vector = input_query_vector[0]
# `https://www.elastic.co/guide/en/elasticsearch/reference/7.17/query-dsl-script-score-query.html`
query = { query = {
"script_score": { "script_score": {
"query": {"bool": {"must": [{"exists": {"field": "text"}}]}}, "query": {"bool": {"must": [{"exists": {"field": "text"}}]}},
@@ -167,11 +171,17 @@ class ElasticsearchDB(BaseVectorDB):
} }
if "app_id" in where: if "app_id" in where:
app_id = where["app_id"] app_id = where["app_id"]
query["script_score"]["query"]["bool"]["must"] = [{"term": {"metadata.app_id": app_id}}] query["script_score"]["query"] = {"match": {"metadata.app_id": app_id}}
_source = ["text"] _source = ["text", "metadata.url", "metadata.doc_id"]
response = self.client.search(index=self._get_index(), query=query, _source=_source, size=n_results) response = self.client.search(index=self._get_index(), query=query, _source=_source, size=n_results)
docs = response["hits"]["hits"] docs = response["hits"]["hits"]
contents = [doc["_source"]["text"] for doc in docs] contents = []
for doc in docs:
context = doc["_source"]["text"]
metadata = doc["_source"]["metadata"]
source = metadata["url"]
doc_id = metadata["doc_id"]
contents.append(tuple((context, source, doc_id)))
return contents return contents
def set_collection_name(self, name: str): def set_collection_name(self, name: str):

View File

@@ -1,5 +1,5 @@
import logging import logging
from typing import Dict, List, Optional, Set from typing import Dict, List, Optional, Set, Tuple
try: try:
from opensearchpy import OpenSearch from opensearchpy import OpenSearch
@@ -145,7 +145,9 @@ class OpenSearchDB(BaseVectorDB):
bulk(self.client, docs) bulk(self.client, docs)
self.client.indices.refresh(index=self._get_index()) self.client.indices.refresh(index=self._get_index())
def query(self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool) -> List[str]: def query(
self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool
) -> List[Tuple[str, str, str]]:
""" """
query contents from vector data base based on vector similarity query contents from vector data base based on vector similarity
@@ -157,8 +159,8 @@ class OpenSearchDB(BaseVectorDB):
:type where: Dict[str, any] :type where: Dict[str, any]
:param skip_embedding: Optional. If True, then the input_query is assumed to be already embedded. :param skip_embedding: Optional. If True, then the input_query is assumed to be already embedded.
:type skip_embedding: bool :type skip_embedding: bool
:return: Database contents that are the result of the query :return: The content of the document that matched your query, url of the source, doc_id
:rtype: List[str] :rtype: List[Tuple[str,str,str]]
""" """
# TODO(rupeshbansal, deshraj): Add support for skip embeddings here if already exists # TODO(rupeshbansal, deshraj): Add support for skip embeddings here if already exists
embeddings = OpenAIEmbeddings() embeddings = OpenAIEmbeddings()
@@ -185,7 +187,13 @@ class OpenSearchDB(BaseVectorDB):
pre_filter=pre_filter, pre_filter=pre_filter,
k=n_results, k=n_results,
) )
contents = [doc.page_content for doc in docs]
contents = []
for doc in docs:
context = doc.page_content
source = doc.metadata["url"]
doc_id = doc.metadata["doc_id"]
contents.append(tuple((context, source, doc_id)))
return contents return contents
def set_collection_name(self, name: str): def set_collection_name(self, name: str):

View File

@@ -1,5 +1,5 @@
import os import os
from typing import Dict, List, Optional from typing import Dict, List, Optional, Tuple
try: try:
import pinecone import pinecone
@@ -118,7 +118,9 @@ class PineconeDB(BaseVectorDB):
for i in range(0, len(docs), self.BATCH_SIZE): for i in range(0, len(docs), self.BATCH_SIZE):
self.client.upsert(docs[i : i + self.BATCH_SIZE]) self.client.upsert(docs[i : i + self.BATCH_SIZE])
def query(self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool) -> List[str]: def query(
self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool
) -> List[Tuple[str, str, str]]:
""" """
query contents from vector database based on vector similarity query contents from vector database based on vector similarity
:param input_query: list of query string :param input_query: list of query string
@@ -129,16 +131,22 @@ class PineconeDB(BaseVectorDB):
:type where: Dict[str, any] :type where: Dict[str, any]
:param skip_embedding: Optional. if True, input_query is already embedded :param skip_embedding: Optional. if True, input_query is already embedded
:type skip_embedding: bool :type skip_embedding: bool
:return: Database contents that are the result of the query :return: The content of the document that matched your query, url of the source, doc_id
:rtype: List[str] :rtype: List[Tuple[str,str,str]]
""" """
if not skip_embedding: if not skip_embedding:
query_vector = self.embedder.embedding_fn([input_query])[0] query_vector = self.embedder.embedding_fn([input_query])[0]
else: else:
query_vector = input_query query_vector = input_query
contents = self.client.query(vector=query_vector, filter=where, top_k=n_results, include_metadata=True) data = self.client.query(vector=query_vector, filter=where, top_k=n_results, include_metadata=True)
embeddings = list(map(lambda content: content["metadata"]["text"], contents["matches"])) contents = []
return embeddings for doc in data["matches"]:
metadata = doc["metadata"]
context = metadata["text"]
source = metadata["url"]
doc_id = metadata["doc_id"]
contents.append(tuple((context, source, doc_id)))
return contents
def set_collection_name(self, name: str): def set_collection_name(self, name: str):
""" """

View File

@@ -1,7 +1,7 @@
import copy import copy
import os import os
import uuid import uuid
from typing import Dict, List, Optional from typing import Dict, List, Optional, Tuple
try: try:
from qdrant_client import QdrantClient from qdrant_client import QdrantClient
@@ -160,7 +160,9 @@ class QdrantDB(BaseVectorDB):
), ),
) )
def query(self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool) -> List[str]: def query(
self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool
) -> List[Tuple[str, str, str]]:
""" """
query contents from vector database based on vector similarity query contents from vector database based on vector similarity
:param input_query: list of query string :param input_query: list of query string
@@ -172,8 +174,8 @@ class QdrantDB(BaseVectorDB):
:param skip_embedding: A boolean flag indicating if the embedding for the documents to be added is to be :param skip_embedding: A boolean flag indicating if the embedding for the documents to be added is to be
generated or not generated or not
:type skip_embedding: bool :type skip_embedding: bool
:return: Database contents that are the result of the query :return: The context of the document that matched your query, url of the source, doc_id
:rtype: List[str] :rtype: List[Tuple[str,str,str]]
""" """
if not skip_embedding: if not skip_embedding:
query_vector = self.embedder.embedding_fn([input_query])[0] query_vector = self.embedder.embedding_fn([input_query])[0]
@@ -199,9 +201,14 @@ class QdrantDB(BaseVectorDB):
query_vector=query_vector, query_vector=query_vector,
limit=n_results, limit=n_results,
) )
response = [] response = []
for result in results: for result in results:
response.append(result.payload.get("text", "")) context = result.payload["text"]
metadata = result.payload["metadata"]
source = metadata["url"]
doc_id = metadata["doc_id"]
response.append(tuple((context, source, doc_id)))
return response return response
def count(self) -> int: def count(self) -> int:
@@ -211,3 +218,15 @@ class QdrantDB(BaseVectorDB):
def reset(self): def reset(self):
self.client.delete_collection(collection_name=self.collection_name) self.client.delete_collection(collection_name=self.collection_name)
self._initialize() self._initialize()
def set_collection_name(self, name: str):
"""
Set the name of the collection. A collection is an isolated space for vectors.
:param name: Name of the collection.
:type name: str
"""
if not isinstance(name, str):
raise TypeError("Collection name must be a string")
self.config.collection_name = name
self.collection_name = self._get_or_create_collection()

View File

@@ -1,6 +1,6 @@
import copy import copy
import os import os
from typing import Dict, List, Optional from typing import Dict, List, Optional, Tuple
try: try:
import weaviate import weaviate
@@ -194,7 +194,9 @@ class WeaviateDB(BaseVectorDB):
) )
batch.add_reference(obj_uuid, self.index_name, "metadata", metadata_uuid, self.index_name + "_metadata") batch.add_reference(obj_uuid, self.index_name, "metadata", metadata_uuid, self.index_name + "_metadata")
def query(self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool) -> List[str]: def query(
self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool
) -> List[Tuple[str, str, str]]:
""" """
query contents from vector database based on vector similarity query contents from vector database based on vector similarity
:param input_query: list of query string :param input_query: list of query string
@@ -206,14 +208,15 @@ class WeaviateDB(BaseVectorDB):
:param skip_embedding: A boolean flag indicating if the embedding for the documents to be added is to be :param skip_embedding: A boolean flag indicating if the embedding for the documents to be added is to be
generated or not generated or not
:type skip_embedding: bool :type skip_embedding: bool
:return: Database contents that are the result of the query :return: The context of the document that matched your query, url of the source, doc_id
:rtype: List[str] :rtype: List[Tuple[str,str,str]]
""" """
if not skip_embedding: if not skip_embedding:
query_vector = self.embedder.embedding_fn([input_query])[0] query_vector = self.embedder.embedding_fn([input_query])[0]
else: else:
query_vector = input_query query_vector = input_query
keys = set(where.keys() if where is not None else set()) keys = set(where.keys() if where is not None else set())
data_fields = ["text"]
if len(keys.intersection(self.metadata_keys)) != 0: if len(keys.intersection(self.metadata_keys)) != 0:
weaviate_where_operands = [] weaviate_where_operands = []
for key in keys: for key in keys:
@@ -231,7 +234,7 @@ class WeaviateDB(BaseVectorDB):
weaviate_where_clause = {"operator": "And", "operands": weaviate_where_operands} weaviate_where_clause = {"operator": "And", "operands": weaviate_where_operands}
results = ( results = (
self.client.query.get(self.index_name, ["text"]) self.client.query.get(self.index_name, data_fields)
.with_where(weaviate_where_clause) .with_where(weaviate_where_clause)
.with_near_vector({"vector": query_vector}) .with_near_vector({"vector": query_vector})
.with_limit(n_results) .with_limit(n_results)
@@ -239,16 +242,13 @@ class WeaviateDB(BaseVectorDB):
) )
else: else:
results = ( results = (
self.client.query.get(self.index_name, ["text"]) self.client.query.get(self.index_name, data_fields)
.with_near_vector({"vector": query_vector}) .with_near_vector({"vector": query_vector})
.with_limit(n_results) .with_limit(n_results)
.do() .do()
) )
matched_tokens = [] contexts = results["data"]["Get"].get(self.index_name)
for result in results["data"]["Get"].get(self.index_name): return contexts
matched_tokens.append(result["text"])
return matched_tokens
def set_collection_name(self, name: str): def set_collection_name(self, name: str):
""" """

View File

@@ -1,4 +1,5 @@
from typing import Dict, List, Optional import logging
from typing import Dict, List, Optional, Tuple
from embedchain.config import ZillizDBConfig from embedchain.config import ZillizDBConfig
from embedchain.helper.json_serializable import register_deserializable from embedchain.helper.json_serializable import register_deserializable
@@ -61,6 +62,7 @@ class ZillizVectorDB(BaseVectorDB):
:type name: str :type name: str
""" """
if utility.has_collection(name): if utility.has_collection(name):
logging.info(f"[ZillizDB]: found an existing collection {name}, make sure the auto-id is disabled.")
self.collection = Collection(name) self.collection = Collection(name)
else: else:
fields = [ fields = [
@@ -124,7 +126,9 @@ class ZillizVectorDB(BaseVectorDB):
self.collection.flush() self.collection.flush()
self.client.flush(self.config.collection_name) self.client.flush(self.config.collection_name)
def query(self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool) -> List[str]: def query(
self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool
) -> List[Tuple[str, str, str]]:
""" """
Query contents from vector data base based on vector similarity Query contents from vector data base based on vector similarity
@@ -135,8 +139,8 @@ class ZillizVectorDB(BaseVectorDB):
:param where: to filter data :param where: to filter data
:type where: str :type where: str
:raises InvalidDimensionException: Dimensions do not match. :raises InvalidDimensionException: Dimensions do not match.
:return: The content of the document that matched your query. :return: The context of the document that matched your query, url of the source, doc_id
:rtype: List[str] :rtype: List[Tuple[str,str,str]]
""" """
if self.collection.is_empty: if self.collection.is_empty:
@@ -145,13 +149,14 @@ class ZillizVectorDB(BaseVectorDB):
if not isinstance(where, str): if not isinstance(where, str):
where = None where = None
output_fields = ["text", "url", "doc_id"]
if skip_embedding: if skip_embedding:
query_vector = input_query query_vector = input_query
query_result = self.client.search( query_result = self.client.search(
collection_name=self.config.collection_name, collection_name=self.config.collection_name,
data=query_vector, data=query_vector,
limit=n_results, limit=n_results,
output_fields=["text"], output_fields=output_fields,
) )
else: else:
@@ -162,13 +167,16 @@ class ZillizVectorDB(BaseVectorDB):
collection_name=self.config.collection_name, collection_name=self.config.collection_name,
data=[query_vector], data=[query_vector],
limit=n_results, limit=n_results,
output_fields=["text"], output_fields=output_fields,
) )
doc_list = [] doc_list = []
for query in query_result: for query in query_result:
doc_list.append(query[0]["entity"]["text"]) data = query[0]["entity"]
context = data["text"]
source = data["url"]
doc_id = data["doc_id"]
doc_list.append(tuple((context, source, doc_id)))
return doc_list return doc_list
def count(self) -> int: def count(self) -> int:

View File

@@ -146,7 +146,7 @@ def test_chroma_db_collection_add_with_skip_embedding(app_with_settings):
app_with_settings.db.add( app_with_settings.db.add(
embeddings=[[0, 0, 0]], embeddings=[[0, 0, 0]],
documents=["document"], documents=["document"],
metadatas=[{"value": "somevalue"}], metadatas=[{"url": "url_1", "doc_id": "doc_id_1"}],
ids=["id"], ids=["id"],
skip_embedding=True, skip_embedding=True,
) )
@@ -158,13 +158,13 @@ def test_chroma_db_collection_add_with_skip_embedding(app_with_settings):
"documents": ["document"], "documents": ["document"],
"embeddings": None, "embeddings": None,
"ids": ["id"], "ids": ["id"],
"metadatas": [{"value": "somevalue"}], "metadatas": [{"url": "url_1", "doc_id": "doc_id_1"}],
} }
assert data == expected_value assert data == expected_value
data = app_with_settings.db.query(input_query=[0, 0, 0], where={}, n_results=1, skip_embedding=True) data = app_with_settings.db.query(input_query=[0, 0, 0], where={}, n_results=1, skip_embedding=True)
expected_value = ["document"] expected_value = [("document", "url_1", "doc_id_1")]
assert data == expected_value assert data == expected_value
app_with_settings.db.reset() app_with_settings.db.reset()
@@ -299,3 +299,35 @@ def test_chroma_db_collection_reset():
app2.db.reset() app2.db.reset()
app3.db.reset() app3.db.reset()
app4.db.reset() app4.db.reset()
def test_chroma_db_collection_query(app_with_settings):
app_with_settings.db.reset()
assert app_with_settings.db.count() == 0
app_with_settings.db.add(
embeddings=[[0, 0, 0]],
documents=["document"],
metadatas=[{"url": "url_1", "doc_id": "doc_id_1"}],
ids=["id"],
skip_embedding=True,
)
assert app_with_settings.db.count() == 1
app_with_settings.db.add(
embeddings=[[0, 1, 0]],
documents=["document2"],
metadatas=[{"url": "url_2", "doc_id": "doc_id_2"}],
ids=["id2"],
skip_embedding=True,
)
assert app_with_settings.db.count() == 2
data = app_with_settings.db.query(input_query=[0, 0, 0], where={}, n_results=2, skip_embedding=True)
expected_value = [("document", "url_1", "doc_id_1"), ("document2", "url_2", "doc_id_2")]
assert data == expected_value
app_with_settings.db.reset()

View File

@@ -31,7 +31,7 @@ class TestEsDB(unittest.TestCase):
# Create some dummy data. # Create some dummy data.
embeddings = [[1, 2, 3], [4, 5, 6]] embeddings = [[1, 2, 3], [4, 5, 6]]
documents = ["This is a document.", "This is another document."] documents = ["This is a document.", "This is another document."]
metadatas = [{}, {}] metadatas = [{"url": "url_1", "doc_id": "doc_id_1"}, {"url": "url_2", "doc_id": "doc_id_2"}]
ids = ["doc_1", "doc_2"] ids = ["doc_1", "doc_2"]
# Add the data to the database. # Add the data to the database.
@@ -40,8 +40,17 @@ class TestEsDB(unittest.TestCase):
search_response = { search_response = {
"hits": { "hits": {
"hits": [ "hits": [
{"_source": {"text": "This is a document."}, "_score": 0.9}, {
{"_source": {"text": "This is another document."}, "_score": 0.8}, "_source": {"text": "This is a document.", "metadata": {"url": "url_1", "doc_id": "doc_id_1"}},
"_score": 0.9,
},
{
"_source": {
"text": "This is another document.",
"metadata": {"url": "url_2", "doc_id": "doc_id_2"},
},
"_score": 0.8,
},
] ]
} }
} }
@@ -54,7 +63,9 @@ class TestEsDB(unittest.TestCase):
results = self.db.query(query, n_results=2, where={}, skip_embedding=False) results = self.db.query(query, n_results=2, where={}, skip_embedding=False)
# Assert that the results are correct. # Assert that the results are correct.
self.assertEqual(results, ["This is a document.", "This is another document."]) self.assertEqual(
results, [("This is a document.", "url_1", "doc_id_1"), ("This is another document.", "url_2", "doc_id_2")]
)
@patch("embedchain.vectordb.elasticsearch.Elasticsearch") @patch("embedchain.vectordb.elasticsearch.Elasticsearch")
def test_query_with_skip_embedding(self, mock_client): def test_query_with_skip_embedding(self, mock_client):
@@ -68,7 +79,7 @@ class TestEsDB(unittest.TestCase):
# Create some dummy data. # Create some dummy data.
embeddings = [[1, 2, 3], [4, 5, 6]] embeddings = [[1, 2, 3], [4, 5, 6]]
documents = ["This is a document.", "This is another document."] documents = ["This is a document.", "This is another document."]
metadatas = [{}, {}] metadatas = [{"url": "url_1", "doc_id": "doc_id_1"}, {"url": "url_2", "doc_id": "doc_id_2"}]
ids = ["doc_1", "doc_2"] ids = ["doc_1", "doc_2"]
# Add the data to the database. # Add the data to the database.
@@ -77,8 +88,17 @@ class TestEsDB(unittest.TestCase):
search_response = { search_response = {
"hits": { "hits": {
"hits": [ "hits": [
{"_source": {"text": "This is a document."}, "_score": 0.9}, {
{"_source": {"text": "This is another document."}, "_score": 0.8}, "_source": {"text": "This is a document.", "metadata": {"url": "url_1", "doc_id": "doc_id_1"}},
"_score": 0.9,
},
{
"_source": {
"text": "This is another document.",
"metadata": {"url": "url_2", "doc_id": "doc_id_2"},
},
"_score": 0.8,
},
] ]
} }
} }
@@ -91,7 +111,9 @@ class TestEsDB(unittest.TestCase):
results = self.db.query(query, n_results=2, where={}, skip_embedding=True) results = self.db.query(query, n_results=2, where={}, skip_embedding=True)
# Assert that the results are correct. # Assert that the results are correct.
self.assertEqual(results, ["This is a document.", "This is another document."]) self.assertEqual(
results, [("This is a document.", "url_1", "doc_id_1"), ("This is another document.", "url_2", "doc_id_2")]
)
def test_init_without_url(self): def test_init_without_url(self):
# Make sure it's not loaded from env # Make sure it's not loaded from env

View File

@@ -123,7 +123,7 @@ class TestZillizDBCollection:
# Mock the MilvusClient search method # Mock the MilvusClient search method
with patch.object(zilliz_db.client, "search") as mock_search: with patch.object(zilliz_db.client, "search") as mock_search:
# Mock the search result # Mock the search result
mock_search.return_value = [[{"entity": {"text": "result_doc"}}]] mock_search.return_value = [[{"entity": {"text": "result_doc", "url": "url_1", "doc_id": "doc_id_1"}}]]
# Call the query method with skip_embedding=True # Call the query method with skip_embedding=True
query_result = zilliz_db.query(input_query=["query_text"], n_results=1, where={}, skip_embedding=True) query_result = zilliz_db.query(input_query=["query_text"], n_results=1, where={}, skip_embedding=True)
@@ -133,11 +133,11 @@ class TestZillizDBCollection:
collection_name=mock_config.collection_name, collection_name=mock_config.collection_name,
data=["query_text"], data=["query_text"],
limit=1, limit=1,
output_fields=["text"], output_fields=["text", "url", "doc_id"],
) )
# Assert that the query result matches the expected result # Assert that the query result matches the expected result
assert query_result == ["result_doc"] assert query_result == [("result_doc", "url_1", "doc_id_1")]
@patch("embedchain.vectordb.zilliz.MilvusClient", autospec=True) @patch("embedchain.vectordb.zilliz.MilvusClient", autospec=True)
@patch("embedchain.vectordb.zilliz.connections", autospec=True) @patch("embedchain.vectordb.zilliz.connections", autospec=True)
@@ -162,7 +162,7 @@ class TestZillizDBCollection:
mock_embedder.embedding_fn.return_value = ["query_vector"] mock_embedder.embedding_fn.return_value = ["query_vector"]
# Mock the search result # Mock the search result
mock_search.return_value = [[{"entity": {"text": "result_doc"}}]] mock_search.return_value = [[{"entity": {"text": "result_doc", "url": "url_1", "doc_id": "doc_id_1"}}]]
# Call the query method with skip_embedding=False # Call the query method with skip_embedding=False
query_result = zilliz_db.query(input_query=["query_text"], n_results=1, where={}, skip_embedding=False) query_result = zilliz_db.query(input_query=["query_text"], n_results=1, where={}, skip_embedding=False)
@@ -172,8 +172,8 @@ class TestZillizDBCollection:
collection_name=mock_config.collection_name, collection_name=mock_config.collection_name,
data=["query_vector"], data=["query_vector"],
limit=1, limit=1,
output_fields=["text"], output_fields=["text", "url", "doc_id"],
) )
# Assert that the query result matches the expected result # Assert that the query result matches the expected result
assert query_result == ["result_doc"] assert query_result == [("result_doc", "url_1", "doc_id_1")]