diff --git a/embedchain/embedchain.py b/embedchain/embedchain.py index cada461d..192ed3e6 100644 --- a/embedchain/embedchain.py +++ b/embedchain/embedchain.py @@ -515,7 +515,7 @@ class EmbedChain(JSONSerializable): where: Optional[Dict] = None, citations: bool = False, **kwargs: Dict[str, Any], - ) -> Union[Tuple[str, List[Tuple[str, str, str]]], str]: + ) -> Union[Tuple[str, List[Tuple[str, Dict]]], str]: """ Queries the vector database based on the given input query. Gets relevant doc based on the query and then passes it to an @@ -566,7 +566,7 @@ class EmbedChain(JSONSerializable): where: Optional[Dict[str, str]] = None, citations: bool = False, **kwargs: Dict[str, Any], - ) -> Union[Tuple[str, List[Tuple[str, str, str]]], str]: + ) -> Union[Tuple[str, List[Tuple[str, Dict]]], str]: """ Queries the vector database on the given input query. Gets relevant doc based on the query and then passes it to an diff --git a/embedchain/vectordb/chroma.py b/embedchain/vectordb/chroma.py index 7b6fd8fe..32e528f0 100644 --- a/embedchain/vectordb/chroma.py +++ b/embedchain/vectordb/chroma.py @@ -200,7 +200,7 @@ class ChromaDB(BaseVectorDB): skip_embedding: bool, citations: bool = False, **kwargs: Optional[Dict[str, Any]], - ) -> Union[List[Tuple[str, str, str]], List[str]]: + ) -> Union[List[Tuple[str, Dict]], List[str]]: """ Query contents from vector database based on vector similarity @@ -250,6 +250,7 @@ class ChromaDB(BaseVectorDB): context = result[0].page_content if citations: metadata = result[0].metadata + metadata["score"] = result[1] contexts.append((context, metadata)) else: contexts.append(context) diff --git a/embedchain/vectordb/elasticsearch.py b/embedchain/vectordb/elasticsearch.py index e7d4b0a1..11610b12 100644 --- a/embedchain/vectordb/elasticsearch.py +++ b/embedchain/vectordb/elasticsearch.py @@ -164,7 +164,7 @@ class ElasticsearchDB(BaseVectorDB): skip_embedding: bool, citations: bool = False, **kwargs: Optional[Dict[str, Any]], - ) -> Union[List[Tuple[str, str, str]], List[str]]: + ) -> Union[List[Tuple[str, Dict]], List[str]]: """ query contents from vector data base based on vector similarity @@ -210,6 +210,7 @@ class ElasticsearchDB(BaseVectorDB): context = doc["_source"]["text"] if citations: metadata = doc["_source"]["metadata"] + metadata["score"] = doc["_score"] contexts.append(tuple((context, metadata))) else: contexts.append(context) diff --git a/embedchain/vectordb/opensearch.py b/embedchain/vectordb/opensearch.py index 4cc2f8a2..a1f408f1 100644 --- a/embedchain/vectordb/opensearch.py +++ b/embedchain/vectordb/opensearch.py @@ -169,7 +169,7 @@ class OpenSearchDB(BaseVectorDB): skip_embedding: bool, citations: bool = False, **kwargs: Optional[Dict[str, Any]], - ) -> Union[List[Tuple[str, str, str]], List[str]]: + ) -> Union[List[Tuple[str, Dict]], List[str]]: """ query contents from vector data base based on vector similarity @@ -202,7 +202,7 @@ class OpenSearchDB(BaseVectorDB): if "app_id" in where: app_id = where["app_id"] pre_filter = {"bool": {"must": [{"term": {"metadata.app_id.keyword": app_id}}]}} - docs = docsearch.similarity_search( + docs = docsearch.similarity_search_with_score( input_query, search_type="script_scoring", space_type="cosinesimil", @@ -215,10 +215,12 @@ class OpenSearchDB(BaseVectorDB): ) contexts = [] - for doc in docs: + for doc, score in docs: context = doc.page_content if citations: - contexts.append(tuple((context, doc.metadata))) + metadata = doc.metadata + metadata["score"] = score + contexts.append(tuple((context, metadata))) else: contexts.append(context) return contexts diff --git a/embedchain/vectordb/pinecone.py b/embedchain/vectordb/pinecone.py index a15d03fe..cd039d62 100644 --- a/embedchain/vectordb/pinecone.py +++ b/embedchain/vectordb/pinecone.py @@ -127,7 +127,7 @@ class PineconeDB(BaseVectorDB): skip_embedding: bool, citations: bool = False, **kwargs: Optional[Dict[str, any]], - ) -> Union[List[Tuple[str, str, str]], List[str]]: + ) -> Union[List[Tuple[str, Dict]], List[str]]: """ query contents from vector database based on vector similarity :param input_query: list of query string @@ -154,6 +154,7 @@ class PineconeDB(BaseVectorDB): metadata = doc["metadata"] context = metadata["text"] if citations: + metadata["score"] = doc["score"] contexts.append(tuple((context, metadata))) else: contexts.append(context) diff --git a/embedchain/vectordb/qdrant.py b/embedchain/vectordb/qdrant.py index 1ca111df..e9df0217 100644 --- a/embedchain/vectordb/qdrant.py +++ b/embedchain/vectordb/qdrant.py @@ -170,7 +170,7 @@ class QdrantDB(BaseVectorDB): skip_embedding: bool, citations: bool = False, **kwargs: Optional[Dict[str, Any]], - ) -> Union[List[Tuple[str, str, str]], List[str]]: + ) -> Union[List[Tuple[str, Dict]], List[str]]: """ query contents from vector database based on vector similarity :param input_query: list of query string @@ -219,6 +219,7 @@ class QdrantDB(BaseVectorDB): context = result.payload["text"] if citations: metadata = result.payload["metadata"] + metadata["score"] = result.score contexts.append(tuple((context, metadata))) else: contexts.append(context) diff --git a/embedchain/vectordb/weaviate.py b/embedchain/vectordb/weaviate.py index 37cd419a..620087bf 100644 --- a/embedchain/vectordb/weaviate.py +++ b/embedchain/vectordb/weaviate.py @@ -205,7 +205,7 @@ class WeaviateDB(BaseVectorDB): skip_embedding: bool, citations: bool = False, **kwargs: Optional[Dict[str, Any]], - ) -> Union[List[Tuple[str, str, str]], List[str]]: + ) -> Union[List[Tuple[str, Dict]], List[str]]: """ query contents from vector database based on vector similarity :param input_query: list of query string @@ -255,6 +255,7 @@ class WeaviateDB(BaseVectorDB): .with_where(weaviate_where_clause) .with_near_vector({"vector": query_vector}) .with_limit(n_results) + .with_additional(["distance"]) .do() ) else: @@ -262,6 +263,7 @@ class WeaviateDB(BaseVectorDB): self.client.query.get(self.index_name, data_fields) .with_near_vector({"vector": query_vector}) .with_limit(n_results) + .with_additional(["distance"]) .do() ) @@ -271,6 +273,8 @@ class WeaviateDB(BaseVectorDB): context = doc["text"] if citations: metadata = doc["metadata"][0] + score = doc["_additional"]["distance"] + metadata["score"] = score contexts.append((context, metadata)) else: contexts.append(context) diff --git a/embedchain/vectordb/zilliz.py b/embedchain/vectordb/zilliz.py index 00d7a23e..ca398f14 100644 --- a/embedchain/vectordb/zilliz.py +++ b/embedchain/vectordb/zilliz.py @@ -135,7 +135,7 @@ class ZillizVectorDB(BaseVectorDB): skip_embedding: bool, citations: bool = False, **kwargs: Optional[Dict[str, Any]], - ) -> Union[List[Tuple[str, str, str]], List[str]]: + ) -> Union[List[Tuple[str, Dict]], List[str]]: """ Query contents from vector data base based on vector similarity @@ -159,7 +159,7 @@ class ZillizVectorDB(BaseVectorDB): if not isinstance(where, str): where = None - output_fields = ["text", "url", "doc_id"] + output_fields = ["*"] if skip_embedding: query_vector = input_query query_result = self.client.search( @@ -181,12 +181,18 @@ class ZillizVectorDB(BaseVectorDB): output_fields=output_fields, **kwargs, ) - + query_result = query_result[0] contexts = [] for query in query_result: - data = query[0]["entity"] + data = query["entity"] + score = query["distance"] context = data["text"] + + if "embeddings" in data: + data.pop("embeddings") + if citations: + data["score"] = score contexts.append(tuple((context, data))) else: contexts.append(context) diff --git a/pyproject.toml b/pyproject.toml index b93a0618..9eb94680 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "embedchain" -version = "0.1.44" +version = "0.1.45" description = "Data platform for LLMs - Load, index, retrieve and sync any unstructured data" authors = [ "Taranjeet Singh ", diff --git a/tests/vectordb/test_chroma_db.py b/tests/vectordb/test_chroma_db.py index d2a57689..43f7d88b 100644 --- a/tests/vectordb/test_chroma_db.py +++ b/tests/vectordb/test_chroma_db.py @@ -342,8 +342,22 @@ def test_chroma_db_collection_query(app_with_settings): input_query=[0, 0, 0], where={}, n_results=2, skip_embedding=True, citations=True ) expected_value_with_citations = [ - ("document", {"url": "url_1", "doc_id": "doc_id_1"}), - ("document2", {"url": "url_2", "doc_id": "doc_id_2"}), + ( + "document", + { + "url": "url_1", + "doc_id": "doc_id_1", + "score": 0.0, + }, + ), + ( + "document2", + { + "url": "url_2", + "doc_id": "doc_id_2", + "score": 1.0, + }, + ), ] assert data_with_citations == expected_value_with_citations diff --git a/tests/vectordb/test_elasticsearch_db.py b/tests/vectordb/test_elasticsearch_db.py index 9ceaa846..2a97e865 100644 --- a/tests/vectordb/test_elasticsearch_db.py +++ b/tests/vectordb/test_elasticsearch_db.py @@ -66,8 +66,8 @@ class TestEsDB(unittest.TestCase): results_with_citations = self.db.query(query, n_results=2, where={}, skip_embedding=False, citations=True) expected_results_with_citations = [ - ("This is a document.", {"url": "url_1", "doc_id": "doc_id_1"}), - ("This is another document.", {"url": "url_2", "doc_id": "doc_id_2"}), + ("This is a document.", {"url": "url_1", "doc_id": "doc_id_1", "score": 0.9}), + ("This is another document.", {"url": "url_2", "doc_id": "doc_id_2", "score": 0.8}), ] self.assertEqual(results_with_citations, expected_results_with_citations) diff --git a/tests/vectordb/test_zilliz_db.py b/tests/vectordb/test_zilliz_db.py index bdc0fc15..d4d9fdd4 100644 --- a/tests/vectordb/test_zilliz_db.py +++ b/tests/vectordb/test_zilliz_db.py @@ -123,7 +123,14 @@ class TestZillizDBCollection: # Mock the MilvusClient search method with patch.object(zilliz_db.client, "search") as mock_search: # Mock the search result - mock_search.return_value = [[{"entity": {"text": "result_doc", "url": "url_1", "doc_id": "doc_id_1"}}]] + mock_search.return_value = [ + [ + { + "distance": 0.5, + "entity": {"text": "result_doc", "url": "url_1", "doc_id": "doc_id_1", "embeddings": [1, 2, 3]}, + } + ] + ] # Call the query method with skip_embedding=True query_result = zilliz_db.query(input_query=["query_text"], n_results=1, where={}, skip_embedding=True) @@ -133,7 +140,7 @@ class TestZillizDBCollection: collection_name=mock_config.collection_name, data=["query_text"], limit=1, - output_fields=["text", "url", "doc_id"], + output_fields=["*"], ) # Assert that the query result matches the expected result @@ -147,11 +154,11 @@ class TestZillizDBCollection: collection_name=mock_config.collection_name, data=["query_text"], limit=1, - output_fields=["text", "url", "doc_id"], + output_fields=["*"], ) assert query_result_with_citations == [ - ("result_doc", {"text": "result_doc", "url": "url_1", "doc_id": "doc_id_1"}) + ("result_doc", {"text": "result_doc", "url": "url_1", "doc_id": "doc_id_1", "score": 0.5}) ] @patch("embedchain.vectordb.zilliz.MilvusClient", autospec=True) @@ -177,7 +184,14 @@ class TestZillizDBCollection: mock_embedder.embedding_fn.return_value = ["query_vector"] # Mock the search result - mock_search.return_value = [[{"entity": {"text": "result_doc", "url": "url_1", "doc_id": "doc_id_1"}}]] + mock_search.return_value = [ + [ + { + "distance": 0.0, + "entity": {"text": "result_doc", "url": "url_1", "doc_id": "doc_id_1", "embeddings": [1, 2, 3]}, + } + ] + ] # Call the query method with skip_embedding=False query_result = zilliz_db.query(input_query=["query_text"], n_results=1, where={}, skip_embedding=False) @@ -187,7 +201,7 @@ class TestZillizDBCollection: collection_name=mock_config.collection_name, data=["query_vector"], limit=1, - output_fields=["text", "url", "doc_id"], + output_fields=["*"], ) # Assert that the query result matches the expected result @@ -201,9 +215,9 @@ class TestZillizDBCollection: collection_name=mock_config.collection_name, data=["query_vector"], limit=1, - output_fields=["text", "url", "doc_id"], + output_fields=["*"], ) assert query_result_with_citations == [ - ("result_doc", {"text": "result_doc", "url": "url_1", "doc_id": "doc_id_1"}) + ("result_doc", {"text": "result_doc", "url": "url_1", "doc_id": "doc_id_1", "score": 0.0}) ]