[Feature] Return score when doing search in vectorDB (#1060)
Co-authored-by: Deven Patel <deven298@yahoo.com>
@@ -515,7 +515,7 @@ class EmbedChain(JSONSerializable):
         where: Optional[Dict] = None,
         citations: bool = False,
         **kwargs: Dict[str, Any],
-    ) -> Union[Tuple[str, List[Tuple[str, str, str]]], str]:
+    ) -> Union[Tuple[str, List[Tuple[str, Dict]]], str]:
         """
         Queries the vector database based on the given input query.
         Gets relevant doc based on the query and then passes it to an
@@ -566,7 +566,7 @@ class EmbedChain(JSONSerializable):
         where: Optional[Dict[str, str]] = None,
         citations: bool = False,
         **kwargs: Dict[str, Any],
-    ) -> Union[Tuple[str, List[Tuple[str, str, str]]], str]:
+    ) -> Union[Tuple[str, List[Tuple[str, Dict]]], str]:
         """
         Queries the vector database on the given input query.
         Gets relevant doc based on the query and then passes it to an
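With citations=True, the query/chat methods now return the answer together with (context, metadata) pairs, and the metadata dict carries the retrieval score. A minimal usage sketch; the App constructor call, the added URL, and the question are illustrative assumptions, not part of this diff:

from embedchain import App

app = App()  # assumes a default config and OPENAI_API_KEY in the environment
app.add("https://www.forbes.com/profile/elon-musk")

# citations=True yields Tuple[str, List[Tuple[str, Dict]]]
answer, sources = app.query("What does Elon Musk do?", citations=True)
for context, metadata in sources:
    # metadata is the stored document metadata; this change adds a "score" key
    print(metadata.get("score"), metadata.get("url"), context[:60])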
@@ -200,7 +200,7 @@ class ChromaDB(BaseVectorDB):
         skip_embedding: bool,
         citations: bool = False,
         **kwargs: Optional[Dict[str, Any]],
-    ) -> Union[List[Tuple[str, str, str]], List[str]]:
+    ) -> Union[List[Tuple[str, Dict]], List[str]]:
         """
         Query contents from vector database based on vector similarity

@@ -250,6 +250,7 @@ class ChromaDB(BaseVectorDB):
             context = result[0].page_content
             if citations:
                 metadata = result[0].metadata
+                metadata["score"] = result[1]
                 contexts.append((context, metadata))
             else:
                 contexts.append(context)
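Every vector DB implementation in this change follows the same shape in its citations branch: take the document's metadata dict and attach the backend's similarity score (or distance) under a "score" key. A minimal sketch of that shared pattern; the helper name and the pre-scored input list are assumptions for illustration only:

from typing import Any, Dict, List, Tuple, Union

def format_contexts(
    scored_results: List[Tuple[str, Dict[str, Any], float]],
    citations: bool,
) -> Union[List[Tuple[str, Dict[str, Any]]], List[str]]:
    # scored_results: (document text, metadata dict, backend score) triples
    contexts = []
    for context, metadata, score in scored_results:
        if citations:
            metadata["score"] = score
            contexts.append((context, metadata))
        else:
            contexts.append(context)
    return contexts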
@@ -164,7 +164,7 @@ class ElasticsearchDB(BaseVectorDB):
         skip_embedding: bool,
         citations: bool = False,
         **kwargs: Optional[Dict[str, Any]],
-    ) -> Union[List[Tuple[str, str, str]], List[str]]:
+    ) -> Union[List[Tuple[str, Dict]], List[str]]:
         """
         query contents from vector data base based on vector similarity

@@ -210,6 +210,7 @@ class ElasticsearchDB(BaseVectorDB):
             context = doc["_source"]["text"]
             if citations:
                 metadata = doc["_source"]["metadata"]
+                metadata["score"] = doc["_score"]
                 contexts.append(tuple((context, metadata)))
             else:
                 contexts.append(context)
@@ -169,7 +169,7 @@ class OpenSearchDB(BaseVectorDB):
         skip_embedding: bool,
         citations: bool = False,
         **kwargs: Optional[Dict[str, Any]],
-    ) -> Union[List[Tuple[str, str, str]], List[str]]:
+    ) -> Union[List[Tuple[str, Dict]], List[str]]:
         """
         query contents from vector data base based on vector similarity

@@ -202,7 +202,7 @@ class OpenSearchDB(BaseVectorDB):
         if "app_id" in where:
             app_id = where["app_id"]
             pre_filter = {"bool": {"must": [{"term": {"metadata.app_id.keyword": app_id}}]}}
-        docs = docsearch.similarity_search(
+        docs = docsearch.similarity_search_with_score(
             input_query,
             search_type="script_scoring",
             space_type="cosinesimil",
@@ -215,10 +215,12 @@ class OpenSearchDB(BaseVectorDB):
         )

         contexts = []
-        for doc in docs:
+        for doc, score in docs:
             context = doc.page_content
             if citations:
-                contexts.append(tuple((context, doc.metadata)))
+                metadata = doc.metadata
+                metadata["score"] = score
+                contexts.append(tuple((context, metadata)))
             else:
                 contexts.append(context)
         return contexts
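Switching from similarity_search to similarity_search_with_score means each OpenSearch result is a (Document, score) pair rather than a bare Document, which is why the loop now unpacks `for doc, score in docs`. A standalone consumption sketch; the connection URL, index name, and embedder are placeholder assumptions, and the import paths assume the langchain version in use at the time:

from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores import OpenSearchVectorSearch

# Placeholder connection details; any reachable OpenSearch index works.
docsearch = OpenSearchVectorSearch(
    opensearch_url="http://localhost:9200",
    index_name="embedchain_store",
    embedding_function=OpenAIEmbeddings(),
)

docs_and_scores = docsearch.similarity_search_with_score(
    "What does Elon Musk do?",
    k=2,
    search_type="script_scoring",
    space_type="cosinesimil",
)
for doc, score in docs_and_scores:
    # score is the backend similarity value that ends up in metadata["score"]
    print(score, doc.metadata.get("url"), doc.page_content[:60])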
@@ -127,7 +127,7 @@ class PineconeDB(BaseVectorDB):
         skip_embedding: bool,
         citations: bool = False,
         **kwargs: Optional[Dict[str, any]],
-    ) -> Union[List[Tuple[str, str, str]], List[str]]:
+    ) -> Union[List[Tuple[str, Dict]], List[str]]:
         """
         query contents from vector database based on vector similarity
         :param input_query: list of query string
@@ -154,6 +154,7 @@ class PineconeDB(BaseVectorDB):
             metadata = doc["metadata"]
             context = metadata["text"]
             if citations:
+                metadata["score"] = doc["score"]
                 contexts.append(tuple((context, metadata)))
             else:
                 contexts.append(context)
@@ -170,7 +170,7 @@ class QdrantDB(BaseVectorDB):
         skip_embedding: bool,
         citations: bool = False,
         **kwargs: Optional[Dict[str, Any]],
-    ) -> Union[List[Tuple[str, str, str]], List[str]]:
+    ) -> Union[List[Tuple[str, Dict]], List[str]]:
         """
         query contents from vector database based on vector similarity
         :param input_query: list of query string
@@ -219,6 +219,7 @@ class QdrantDB(BaseVectorDB):
             context = result.payload["text"]
             if citations:
                 metadata = result.payload["metadata"]
+                metadata["score"] = result.score
                 contexts.append(tuple((context, metadata)))
             else:
                 contexts.append(context)
@@ -205,7 +205,7 @@ class WeaviateDB(BaseVectorDB):
         skip_embedding: bool,
         citations: bool = False,
         **kwargs: Optional[Dict[str, Any]],
-    ) -> Union[List[Tuple[str, str, str]], List[str]]:
+    ) -> Union[List[Tuple[str, Dict]], List[str]]:
         """
         query contents from vector database based on vector similarity
         :param input_query: list of query string
@@ -255,6 +255,7 @@ class WeaviateDB(BaseVectorDB):
                 .with_where(weaviate_where_clause)
                 .with_near_vector({"vector": query_vector})
                 .with_limit(n_results)
+                .with_additional(["distance"])
                 .do()
             )
         else:
@@ -262,6 +263,7 @@ class WeaviateDB(BaseVectorDB):
                 self.client.query.get(self.index_name, data_fields)
                 .with_near_vector({"vector": query_vector})
                 .with_limit(n_results)
+                .with_additional(["distance"])
                 .do()
             )

@@ -271,6 +273,8 @@ class WeaviateDB(BaseVectorDB):
             context = doc["text"]
             if citations:
                 metadata = doc["metadata"][0]
+                score = doc["_additional"]["distance"]
+                metadata["score"] = score
                 contexts.append((context, metadata))
             else:
                 contexts.append(context)
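Weaviate does not include a score field by default; the query builder has to request it with .with_additional(["distance"]), and the value then shows up under _additional.distance in each returned object (it is a distance, so lower means closer). A rough sketch of reading it out of a .do()-style response; the class name "EmbedchainIndex" and the literal values are assumptions for illustration:

# Illustrative shape of the GraphQL response once "distance" is requested
results = {
    "data": {
        "Get": {
            "EmbedchainIndex": [
                {
                    "text": "This is a document.",
                    "metadata": [{"url": "url_1", "doc_id": "doc_id_1"}],
                    "_additional": {"distance": 0.23},
                }
            ]
        }
    }
}

for doc in results["data"]["Get"]["EmbedchainIndex"]:
    metadata = doc["metadata"][0]
    # mirror the diff: store the distance under the shared "score" key
    metadata["score"] = doc["_additional"]["distance"]
    print(metadata["score"], doc["text"])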
@@ -135,7 +135,7 @@ class ZillizVectorDB(BaseVectorDB):
         skip_embedding: bool,
         citations: bool = False,
         **kwargs: Optional[Dict[str, Any]],
-    ) -> Union[List[Tuple[str, str, str]], List[str]]:
+    ) -> Union[List[Tuple[str, Dict]], List[str]]:
         """
         Query contents from vector data base based on vector similarity

@@ -159,7 +159,7 @@ class ZillizVectorDB(BaseVectorDB):
         if not isinstance(where, str):
             where = None

-        output_fields = ["text", "url", "doc_id"]
+        output_fields = ["*"]
         if skip_embedding:
             query_vector = input_query
             query_result = self.client.search(
@@ -181,12 +181,18 @@ class ZillizVectorDB(BaseVectorDB):
                 output_fields=output_fields,
                 **kwargs,
             )
+        query_result = query_result[0]
         contexts = []
         for query in query_result:
-            data = query[0]["entity"]
+            data = query["entity"]
+            score = query["distance"]
             context = data["text"]

+            if "embeddings" in data:
+                data.pop("embeddings")
+
             if citations:
+                data["score"] = score
                 contexts.append(tuple((context, data)))
             else:
                 contexts.append(context)
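pymilvus's MilvusClient.search returns one list of hits per query vector, each hit carrying a "distance" plus the requested "entity" fields. That is why the Zilliz code now takes query_result[0], reads query["distance"], and drops the bulky embeddings field before using the entity as citation metadata. A sketch of that unpacking against an illustrative result; the literal values are assumptions:

# Illustrative shape of MilvusClient.search(...) output for a single query vector
query_result = [
    [
        {
            "distance": 0.5,
            "entity": {"text": "result_doc", "url": "url_1", "doc_id": "doc_id_1", "embeddings": [1, 2, 3]},
        }
    ]
]

hits = query_result[0]  # hits for the first (and only) query vector
contexts = []
for hit in hits:
    data = hit["entity"]
    data.pop("embeddings", None)  # keep the citation metadata small
    data["score"] = hit["distance"]
    contexts.append((data["text"], data))
print(contexts)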
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "embedchain"
-version = "0.1.44"
+version = "0.1.45"
 description = "Data platform for LLMs - Load, index, retrieve and sync any unstructured data"
 authors = [
     "Taranjeet Singh <taranjeet@embedchain.ai>",
@@ -342,8 +342,22 @@ def test_chroma_db_collection_query(app_with_settings):
         input_query=[0, 0, 0], where={}, n_results=2, skip_embedding=True, citations=True
     )
     expected_value_with_citations = [
-        ("document", {"url": "url_1", "doc_id": "doc_id_1"}),
-        ("document2", {"url": "url_2", "doc_id": "doc_id_2"}),
+        (
+            "document",
+            {
+                "url": "url_1",
+                "doc_id": "doc_id_1",
+                "score": 0.0,
+            },
+        ),
+        (
+            "document2",
+            {
+                "url": "url_2",
+                "doc_id": "doc_id_2",
+                "score": 1.0,
+            },
+        ),
     ]
     assert data_with_citations == expected_value_with_citations

@@ -66,8 +66,8 @@ class TestEsDB(unittest.TestCase):

         results_with_citations = self.db.query(query, n_results=2, where={}, skip_embedding=False, citations=True)
         expected_results_with_citations = [
-            ("This is a document.", {"url": "url_1", "doc_id": "doc_id_1"}),
-            ("This is another document.", {"url": "url_2", "doc_id": "doc_id_2"}),
+            ("This is a document.", {"url": "url_1", "doc_id": "doc_id_1", "score": 0.9}),
+            ("This is another document.", {"url": "url_2", "doc_id": "doc_id_2", "score": 0.8}),
         ]
         self.assertEqual(results_with_citations, expected_results_with_citations)

@@ -123,7 +123,14 @@ class TestZillizDBCollection:
         # Mock the MilvusClient search method
         with patch.object(zilliz_db.client, "search") as mock_search:
             # Mock the search result
-            mock_search.return_value = [[{"entity": {"text": "result_doc", "url": "url_1", "doc_id": "doc_id_1"}}]]
+            mock_search.return_value = [
+                [
+                    {
+                        "distance": 0.5,
+                        "entity": {"text": "result_doc", "url": "url_1", "doc_id": "doc_id_1", "embeddings": [1, 2, 3]},
+                    }
+                ]
+            ]

             # Call the query method with skip_embedding=True
             query_result = zilliz_db.query(input_query=["query_text"], n_results=1, where={}, skip_embedding=True)
@@ -133,7 +140,7 @@ class TestZillizDBCollection:
                 collection_name=mock_config.collection_name,
                 data=["query_text"],
                 limit=1,
-                output_fields=["text", "url", "doc_id"],
+                output_fields=["*"],
             )

             # Assert that the query result matches the expected result
@@ -147,11 +154,11 @@ class TestZillizDBCollection:
                 collection_name=mock_config.collection_name,
                 data=["query_text"],
                 limit=1,
-                output_fields=["text", "url", "doc_id"],
+                output_fields=["*"],
             )

             assert query_result_with_citations == [
-                ("result_doc", {"text": "result_doc", "url": "url_1", "doc_id": "doc_id_1"})
+                ("result_doc", {"text": "result_doc", "url": "url_1", "doc_id": "doc_id_1", "score": 0.5})
             ]

     @patch("embedchain.vectordb.zilliz.MilvusClient", autospec=True)
@@ -177,7 +184,14 @@ class TestZillizDBCollection:
         mock_embedder.embedding_fn.return_value = ["query_vector"]

         # Mock the search result
-        mock_search.return_value = [[{"entity": {"text": "result_doc", "url": "url_1", "doc_id": "doc_id_1"}}]]
+        mock_search.return_value = [
+            [
+                {
+                    "distance": 0.0,
+                    "entity": {"text": "result_doc", "url": "url_1", "doc_id": "doc_id_1", "embeddings": [1, 2, 3]},
+                }
+            ]
+        ]

         # Call the query method with skip_embedding=False
         query_result = zilliz_db.query(input_query=["query_text"], n_results=1, where={}, skip_embedding=False)
@@ -187,7 +201,7 @@ class TestZillizDBCollection:
             collection_name=mock_config.collection_name,
             data=["query_vector"],
             limit=1,
-            output_fields=["text", "url", "doc_id"],
+            output_fields=["*"],
         )

         # Assert that the query result matches the expected result
@@ -201,9 +215,9 @@ class TestZillizDBCollection:
             collection_name=mock_config.collection_name,
             data=["query_vector"],
             limit=1,
-            output_fields=["text", "url", "doc_id"],
+            output_fields=["*"],
         )

         assert query_result_with_citations == [
-            ("result_doc", {"text": "result_doc", "url": "url_1", "doc_id": "doc_id_1"})
+            ("result_doc", {"text": "result_doc", "url": "url_1", "doc_id": "doc_id_1", "score": 0.0})
         ]