[Feature] Return score when doing search in vectorDB (#1060)

Co-authored-by: Deven Patel <deven298@yahoo.com>
This commit is contained in:
Deven Patel
2023-12-29 15:56:12 +05:30
committed by GitHub
parent 19d80914df
commit c0aafd38c9
12 changed files with 72 additions and 28 deletions

View File

@@ -515,7 +515,7 @@ class EmbedChain(JSONSerializable):
where: Optional[Dict] = None, where: Optional[Dict] = None,
citations: bool = False, citations: bool = False,
**kwargs: Dict[str, Any], **kwargs: Dict[str, Any],
) -> Union[Tuple[str, List[Tuple[str, str, str]]], str]: ) -> Union[Tuple[str, List[Tuple[str, Dict]]], str]:
""" """
Queries the vector database based on the given input query. Queries the vector database based on the given input query.
Gets relevant doc based on the query and then passes it to an Gets relevant doc based on the query and then passes it to an
@@ -566,7 +566,7 @@ class EmbedChain(JSONSerializable):
where: Optional[Dict[str, str]] = None, where: Optional[Dict[str, str]] = None,
citations: bool = False, citations: bool = False,
**kwargs: Dict[str, Any], **kwargs: Dict[str, Any],
) -> Union[Tuple[str, List[Tuple[str, str, str]]], str]: ) -> Union[Tuple[str, List[Tuple[str, Dict]]], str]:
""" """
Queries the vector database on the given input query. Queries the vector database on the given input query.
Gets relevant doc based on the query and then passes it to an Gets relevant doc based on the query and then passes it to an

View File

@@ -200,7 +200,7 @@ class ChromaDB(BaseVectorDB):
skip_embedding: bool, skip_embedding: bool,
citations: bool = False, citations: bool = False,
**kwargs: Optional[Dict[str, Any]], **kwargs: Optional[Dict[str, Any]],
) -> Union[List[Tuple[str, str, str]], List[str]]: ) -> Union[List[Tuple[str, Dict]], List[str]]:
""" """
Query contents from vector database based on vector similarity Query contents from vector database based on vector similarity
@@ -250,6 +250,7 @@ class ChromaDB(BaseVectorDB):
context = result[0].page_content context = result[0].page_content
if citations: if citations:
metadata = result[0].metadata metadata = result[0].metadata
metadata["score"] = result[1]
contexts.append((context, metadata)) contexts.append((context, metadata))
else: else:
contexts.append(context) contexts.append(context)

View File

@@ -164,7 +164,7 @@ class ElasticsearchDB(BaseVectorDB):
skip_embedding: bool, skip_embedding: bool,
citations: bool = False, citations: bool = False,
**kwargs: Optional[Dict[str, Any]], **kwargs: Optional[Dict[str, Any]],
) -> Union[List[Tuple[str, str, str]], List[str]]: ) -> Union[List[Tuple[str, Dict]], List[str]]:
""" """
query contents from vector data base based on vector similarity query contents from vector data base based on vector similarity
@@ -210,6 +210,7 @@ class ElasticsearchDB(BaseVectorDB):
context = doc["_source"]["text"] context = doc["_source"]["text"]
if citations: if citations:
metadata = doc["_source"]["metadata"] metadata = doc["_source"]["metadata"]
metadata["score"] = doc["_score"]
contexts.append(tuple((context, metadata))) contexts.append(tuple((context, metadata)))
else: else:
contexts.append(context) contexts.append(context)

View File

@@ -169,7 +169,7 @@ class OpenSearchDB(BaseVectorDB):
skip_embedding: bool, skip_embedding: bool,
citations: bool = False, citations: bool = False,
**kwargs: Optional[Dict[str, Any]], **kwargs: Optional[Dict[str, Any]],
) -> Union[List[Tuple[str, str, str]], List[str]]: ) -> Union[List[Tuple[str, Dict]], List[str]]:
""" """
query contents from vector data base based on vector similarity query contents from vector data base based on vector similarity
@@ -202,7 +202,7 @@ class OpenSearchDB(BaseVectorDB):
if "app_id" in where: if "app_id" in where:
app_id = where["app_id"] app_id = where["app_id"]
pre_filter = {"bool": {"must": [{"term": {"metadata.app_id.keyword": app_id}}]}} pre_filter = {"bool": {"must": [{"term": {"metadata.app_id.keyword": app_id}}]}}
docs = docsearch.similarity_search( docs = docsearch.similarity_search_with_score(
input_query, input_query,
search_type="script_scoring", search_type="script_scoring",
space_type="cosinesimil", space_type="cosinesimil",
@@ -215,10 +215,12 @@ class OpenSearchDB(BaseVectorDB):
) )
contexts = [] contexts = []
for doc in docs: for doc, score in docs:
context = doc.page_content context = doc.page_content
if citations: if citations:
contexts.append(tuple((context, doc.metadata))) metadata = doc.metadata
metadata["score"] = score
contexts.append(tuple((context, metadata)))
else: else:
contexts.append(context) contexts.append(context)
return contexts return contexts

View File

@@ -127,7 +127,7 @@ class PineconeDB(BaseVectorDB):
skip_embedding: bool, skip_embedding: bool,
citations: bool = False, citations: bool = False,
**kwargs: Optional[Dict[str, any]], **kwargs: Optional[Dict[str, any]],
) -> Union[List[Tuple[str, str, str]], List[str]]: ) -> Union[List[Tuple[str, Dict]], List[str]]:
""" """
query contents from vector database based on vector similarity query contents from vector database based on vector similarity
:param input_query: list of query string :param input_query: list of query string
@@ -154,6 +154,7 @@ class PineconeDB(BaseVectorDB):
metadata = doc["metadata"] metadata = doc["metadata"]
context = metadata["text"] context = metadata["text"]
if citations: if citations:
metadata["score"] = doc["score"]
contexts.append(tuple((context, metadata))) contexts.append(tuple((context, metadata)))
else: else:
contexts.append(context) contexts.append(context)

View File

@@ -170,7 +170,7 @@ class QdrantDB(BaseVectorDB):
skip_embedding: bool, skip_embedding: bool,
citations: bool = False, citations: bool = False,
**kwargs: Optional[Dict[str, Any]], **kwargs: Optional[Dict[str, Any]],
) -> Union[List[Tuple[str, str, str]], List[str]]: ) -> Union[List[Tuple[str, Dict]], List[str]]:
""" """
query contents from vector database based on vector similarity query contents from vector database based on vector similarity
:param input_query: list of query string :param input_query: list of query string
@@ -219,6 +219,7 @@ class QdrantDB(BaseVectorDB):
context = result.payload["text"] context = result.payload["text"]
if citations: if citations:
metadata = result.payload["metadata"] metadata = result.payload["metadata"]
metadata["score"] = result.score
contexts.append(tuple((context, metadata))) contexts.append(tuple((context, metadata)))
else: else:
contexts.append(context) contexts.append(context)

View File

@@ -205,7 +205,7 @@ class WeaviateDB(BaseVectorDB):
skip_embedding: bool, skip_embedding: bool,
citations: bool = False, citations: bool = False,
**kwargs: Optional[Dict[str, Any]], **kwargs: Optional[Dict[str, Any]],
) -> Union[List[Tuple[str, str, str]], List[str]]: ) -> Union[List[Tuple[str, Dict]], List[str]]:
""" """
query contents from vector database based on vector similarity query contents from vector database based on vector similarity
:param input_query: list of query string :param input_query: list of query string
@@ -255,6 +255,7 @@ class WeaviateDB(BaseVectorDB):
.with_where(weaviate_where_clause) .with_where(weaviate_where_clause)
.with_near_vector({"vector": query_vector}) .with_near_vector({"vector": query_vector})
.with_limit(n_results) .with_limit(n_results)
.with_additional(["distance"])
.do() .do()
) )
else: else:
@@ -262,6 +263,7 @@ class WeaviateDB(BaseVectorDB):
self.client.query.get(self.index_name, data_fields) self.client.query.get(self.index_name, data_fields)
.with_near_vector({"vector": query_vector}) .with_near_vector({"vector": query_vector})
.with_limit(n_results) .with_limit(n_results)
.with_additional(["distance"])
.do() .do()
) )
@@ -271,6 +273,8 @@ class WeaviateDB(BaseVectorDB):
context = doc["text"] context = doc["text"]
if citations: if citations:
metadata = doc["metadata"][0] metadata = doc["metadata"][0]
score = doc["_additional"]["distance"]
metadata["score"] = score
contexts.append((context, metadata)) contexts.append((context, metadata))
else: else:
contexts.append(context) contexts.append(context)

View File

@@ -135,7 +135,7 @@ class ZillizVectorDB(BaseVectorDB):
skip_embedding: bool, skip_embedding: bool,
citations: bool = False, citations: bool = False,
**kwargs: Optional[Dict[str, Any]], **kwargs: Optional[Dict[str, Any]],
) -> Union[List[Tuple[str, str, str]], List[str]]: ) -> Union[List[Tuple[str, Dict]], List[str]]:
""" """
Query contents from vector data base based on vector similarity Query contents from vector data base based on vector similarity
@@ -159,7 +159,7 @@ class ZillizVectorDB(BaseVectorDB):
if not isinstance(where, str): if not isinstance(where, str):
where = None where = None
output_fields = ["text", "url", "doc_id"] output_fields = ["*"]
if skip_embedding: if skip_embedding:
query_vector = input_query query_vector = input_query
query_result = self.client.search( query_result = self.client.search(
@@ -181,12 +181,18 @@ class ZillizVectorDB(BaseVectorDB):
output_fields=output_fields, output_fields=output_fields,
**kwargs, **kwargs,
) )
query_result = query_result[0]
contexts = [] contexts = []
for query in query_result: for query in query_result:
data = query[0]["entity"] data = query["entity"]
score = query["distance"]
context = data["text"] context = data["text"]
if "embeddings" in data:
data.pop("embeddings")
if citations: if citations:
data["score"] = score
contexts.append(tuple((context, data))) contexts.append(tuple((context, data)))
else: else:
contexts.append(context) contexts.append(context)

View File

@@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "embedchain" name = "embedchain"
version = "0.1.44" version = "0.1.45"
description = "Data platform for LLMs - Load, index, retrieve and sync any unstructured data" description = "Data platform for LLMs - Load, index, retrieve and sync any unstructured data"
authors = [ authors = [
"Taranjeet Singh <taranjeet@embedchain.ai>", "Taranjeet Singh <taranjeet@embedchain.ai>",

View File

@@ -342,8 +342,22 @@ def test_chroma_db_collection_query(app_with_settings):
input_query=[0, 0, 0], where={}, n_results=2, skip_embedding=True, citations=True input_query=[0, 0, 0], where={}, n_results=2, skip_embedding=True, citations=True
) )
expected_value_with_citations = [ expected_value_with_citations = [
("document", {"url": "url_1", "doc_id": "doc_id_1"}), (
("document2", {"url": "url_2", "doc_id": "doc_id_2"}), "document",
{
"url": "url_1",
"doc_id": "doc_id_1",
"score": 0.0,
},
),
(
"document2",
{
"url": "url_2",
"doc_id": "doc_id_2",
"score": 1.0,
},
),
] ]
assert data_with_citations == expected_value_with_citations assert data_with_citations == expected_value_with_citations

View File

@@ -66,8 +66,8 @@ class TestEsDB(unittest.TestCase):
results_with_citations = self.db.query(query, n_results=2, where={}, skip_embedding=False, citations=True) results_with_citations = self.db.query(query, n_results=2, where={}, skip_embedding=False, citations=True)
expected_results_with_citations = [ expected_results_with_citations = [
("This is a document.", {"url": "url_1", "doc_id": "doc_id_1"}), ("This is a document.", {"url": "url_1", "doc_id": "doc_id_1", "score": 0.9}),
("This is another document.", {"url": "url_2", "doc_id": "doc_id_2"}), ("This is another document.", {"url": "url_2", "doc_id": "doc_id_2", "score": 0.8}),
] ]
self.assertEqual(results_with_citations, expected_results_with_citations) self.assertEqual(results_with_citations, expected_results_with_citations)

View File

@@ -123,7 +123,14 @@ class TestZillizDBCollection:
# Mock the MilvusClient search method # Mock the MilvusClient search method
with patch.object(zilliz_db.client, "search") as mock_search: with patch.object(zilliz_db.client, "search") as mock_search:
# Mock the search result # Mock the search result
mock_search.return_value = [[{"entity": {"text": "result_doc", "url": "url_1", "doc_id": "doc_id_1"}}]] mock_search.return_value = [
[
{
"distance": 0.5,
"entity": {"text": "result_doc", "url": "url_1", "doc_id": "doc_id_1", "embeddings": [1, 2, 3]},
}
]
]
# Call the query method with skip_embedding=True # Call the query method with skip_embedding=True
query_result = zilliz_db.query(input_query=["query_text"], n_results=1, where={}, skip_embedding=True) query_result = zilliz_db.query(input_query=["query_text"], n_results=1, where={}, skip_embedding=True)
@@ -133,7 +140,7 @@ class TestZillizDBCollection:
collection_name=mock_config.collection_name, collection_name=mock_config.collection_name,
data=["query_text"], data=["query_text"],
limit=1, limit=1,
output_fields=["text", "url", "doc_id"], output_fields=["*"],
) )
# Assert that the query result matches the expected result # Assert that the query result matches the expected result
@@ -147,11 +154,11 @@ class TestZillizDBCollection:
collection_name=mock_config.collection_name, collection_name=mock_config.collection_name,
data=["query_text"], data=["query_text"],
limit=1, limit=1,
output_fields=["text", "url", "doc_id"], output_fields=["*"],
) )
assert query_result_with_citations == [ assert query_result_with_citations == [
("result_doc", {"text": "result_doc", "url": "url_1", "doc_id": "doc_id_1"}) ("result_doc", {"text": "result_doc", "url": "url_1", "doc_id": "doc_id_1", "score": 0.5})
] ]
@patch("embedchain.vectordb.zilliz.MilvusClient", autospec=True) @patch("embedchain.vectordb.zilliz.MilvusClient", autospec=True)
@@ -177,7 +184,14 @@ class TestZillizDBCollection:
mock_embedder.embedding_fn.return_value = ["query_vector"] mock_embedder.embedding_fn.return_value = ["query_vector"]
# Mock the search result # Mock the search result
mock_search.return_value = [[{"entity": {"text": "result_doc", "url": "url_1", "doc_id": "doc_id_1"}}]] mock_search.return_value = [
[
{
"distance": 0.0,
"entity": {"text": "result_doc", "url": "url_1", "doc_id": "doc_id_1", "embeddings": [1, 2, 3]},
}
]
]
# Call the query method with skip_embedding=False # Call the query method with skip_embedding=False
query_result = zilliz_db.query(input_query=["query_text"], n_results=1, where={}, skip_embedding=False) query_result = zilliz_db.query(input_query=["query_text"], n_results=1, where={}, skip_embedding=False)
@@ -187,7 +201,7 @@ class TestZillizDBCollection:
collection_name=mock_config.collection_name, collection_name=mock_config.collection_name,
data=["query_vector"], data=["query_vector"],
limit=1, limit=1,
output_fields=["text", "url", "doc_id"], output_fields=["*"],
) )
# Assert that the query result matches the expected result # Assert that the query result matches the expected result
@@ -201,9 +215,9 @@ class TestZillizDBCollection:
collection_name=mock_config.collection_name, collection_name=mock_config.collection_name,
data=["query_vector"], data=["query_vector"],
limit=1, limit=1,
output_fields=["text", "url", "doc_id"], output_fields=["*"],
) )
assert query_result_with_citations == [ assert query_result_with_citations == [
("result_doc", {"text": "result_doc", "url": "url_1", "doc_id": "doc_id_1"}) ("result_doc", {"text": "result_doc", "url": "url_1", "doc_id": "doc_id_1", "score": 0.0})
] ]