From 862ff6cca6502f63fb013ff5b83963fc6187e585 Mon Sep 17 00:00:00 2001 From: Deshraj Yadav Date: Fri, 12 Jan 2024 14:15:39 +0530 Subject: [PATCH] [Bug fix] Fix embedding issue for opensearch and some other vector databases (#1163) --- embedchain/chunkers/base_chunker.py | 2 +- embedchain/embedchain.py | 10 +---- embedchain/vectordb/chroma.py | 4 -- embedchain/vectordb/elasticsearch.py | 3 -- embedchain/vectordb/opensearch.py | 18 ++------ embedchain/vectordb/pinecone.py | 1 - embedchain/vectordb/qdrant.py | 3 -- embedchain/vectordb/weaviate.py | 20 ++------- embedchain/vectordb/zilliz.py | 1 - tests/vectordb/test_elasticsearch_db.py | 5 +-- tests/vectordb/test_pinecone.py | 6 +-- tests/vectordb/test_qdrant.py | 5 +-- tests/vectordb/test_weaviate.py | 57 +++++++++++-------------- 13 files changed, 40 insertions(+), 95 deletions(-) diff --git a/embedchain/chunkers/base_chunker.py b/embedchain/chunkers/base_chunker.py index 653d52a8..6ce10b67 100644 --- a/embedchain/chunkers/base_chunker.py +++ b/embedchain/chunkers/base_chunker.py @@ -27,7 +27,7 @@ class BaseChunker(JSONSerializable): chunk_ids = [] id_map = {} min_chunk_size = config.min_chunk_size if config is not None else 1 - logging.info(f"[INFO] Skipping chunks smaller than {min_chunk_size} characters") + logging.info(f"Skipping chunks smaller than {min_chunk_size} characters") data_result = loader.load_data(src) data_records = data_result["data"] doc_id = data_result["doc_id"] diff --git a/embedchain/embedchain.py b/embedchain/embedchain.py index f52956a7..bb69051c 100644 --- a/embedchain/embedchain.py +++ b/embedchain/embedchain.py @@ -369,7 +369,7 @@ class EmbedChain(JSONSerializable): metadatas = embeddings_data["metadatas"] ids = embeddings_data["ids"] new_doc_id = embeddings_data["doc_id"] - embeddings = embeddings_data.get("embeddings") + if existing_doc_id and existing_doc_id == new_doc_id: print("Doc content has not changed. Skipping creating chunks and embeddings") return [], [], [], 0 @@ -433,13 +433,7 @@ class EmbedChain(JSONSerializable): # Count before, to calculate a delta in the end. chunks_before_addition = self.db.count() - self.db.add( - embeddings=embeddings, - documents=documents, - metadatas=metadatas, - ids=ids, - **kwargs, - ) + self.db.add(documents=documents, metadatas=metadatas, ids=ids, **kwargs) count_new_chunks = self.db.count() - chunks_before_addition print(f"Successfully saved {src} ({chunker.data_type}). New chunks count: {count_new_chunks}") diff --git a/embedchain/vectordb/chroma.py b/embedchain/vectordb/chroma.py index 827f12d3..7e17eb55 100644 --- a/embedchain/vectordb/chroma.py +++ b/embedchain/vectordb/chroma.py @@ -129,17 +129,13 @@ class ChromaDB(BaseVectorDB): def add( self, - embeddings: list[list[float]], documents: list[str], metadatas: list[object], ids: list[str], - **kwargs: Optional[dict[str, Any]], ) -> Any: """ Add vectors to chroma database - :param embeddings: list of embeddings to add - :type embeddings: list[list[str]] :param documents: Documents :type documents: list[str] :param metadatas: Metadatas diff --git a/embedchain/vectordb/elasticsearch.py b/embedchain/vectordb/elasticsearch.py index d47344d6..23fac609 100644 --- a/embedchain/vectordb/elasticsearch.py +++ b/embedchain/vectordb/elasticsearch.py @@ -110,7 +110,6 @@ class ElasticsearchDB(BaseVectorDB): def add( self, - embeddings: list[list[float]], documents: list[str], metadatas: list[object], ids: list[str], @@ -118,8 +117,6 @@ class ElasticsearchDB(BaseVectorDB): ) -> Any: """ add data in vector database - :param embeddings: list of embeddings to add - :type embeddings: list[list[str]] :param documents: list of texts to add :type documents: list[str] :param metadatas: list of metadata associated with docs diff --git a/embedchain/vectordb/opensearch.py b/embedchain/vectordb/opensearch.py index 1626f772..fe798c0e 100644 --- a/embedchain/vectordb/opensearch.py +++ b/embedchain/vectordb/opensearch.py @@ -114,22 +114,10 @@ class OpenSearchDB(BaseVectorDB): result["metadatas"].append({"doc_id": doc_id}) return result - def add( - self, - embeddings: list[list[str]], - documents: list[str], - metadatas: list[object], - ids: list[str], - **kwargs: Optional[dict[str, any]], - ): - """Add data in vector database. + def add(self, documents: list[str], metadatas: list[object], ids: list[str], **kwargs: Optional[dict[str, any]]): + """Adds documents to the opensearch index""" - Args: - embeddings (list[list[str]]): list of embeddings to add. - documents (list[str]): list of texts to add. - metadatas (list[object]): list of metadata associated with docs. - ids (list[str]): IDs of docs. - """ + embeddings = self.embedder.embedding_fn(documents) for batch_start in tqdm(range(0, len(documents), self.BATCH_SIZE), desc="Inserting batches in opensearch"): batch_end = batch_start + self.BATCH_SIZE batch_documents = documents[batch_start:batch_end] diff --git a/embedchain/vectordb/pinecone.py b/embedchain/vectordb/pinecone.py index 16076dda..982789ad 100644 --- a/embedchain/vectordb/pinecone.py +++ b/embedchain/vectordb/pinecone.py @@ -88,7 +88,6 @@ class PineconeDB(BaseVectorDB): def add( self, - embeddings: list[list[float]], documents: list[str], metadatas: list[object], ids: list[str], diff --git a/embedchain/vectordb/qdrant.py b/embedchain/vectordb/qdrant.py index b107568d..1f0a1b6d 100644 --- a/embedchain/vectordb/qdrant.py +++ b/embedchain/vectordb/qdrant.py @@ -122,15 +122,12 @@ class QdrantDB(BaseVectorDB): def add( self, - embeddings: list[list[float]], documents: list[str], metadatas: list[object], ids: list[str], **kwargs: Optional[dict[str, any]], ): """add data in vector database - :param embeddings: list of embeddings for the corresponding documents to be added - :type documents: list[list[float]] :param documents: list of texts to add :type documents: list[str] :param metadatas: list of metadata associated with docs diff --git a/embedchain/vectordb/weaviate.py b/embedchain/vectordb/weaviate.py index b3535bbe..d2b70a04 100644 --- a/embedchain/vectordb/weaviate.py +++ b/embedchain/vectordb/weaviate.py @@ -1,6 +1,6 @@ import copy import os -from typing import Any, Optional, Union +from typing import Optional, Union try: import weaviate @@ -151,17 +151,8 @@ class WeaviateDB(BaseVectorDB): return {"ids": existing_ids} - def add( - self, - embeddings: list[list[float]], - documents: list[str], - metadatas: list[object], - ids: list[str], - **kwargs: Optional[dict[str, any]], - ): + def add(self, documents: list[str], metadatas: list[object], ids: list[str], **kwargs: Optional[dict[str, any]]): """add data in vector database - :param embeddings: list of embeddings for the corresponding documents to be added - :type documents: list[list[float]] :param documents: list of texts to add :type documents: list[str] :param metadatas: list of metadata associated with docs @@ -191,12 +182,7 @@ class WeaviateDB(BaseVectorDB): ) def query( - self, - input_query: list[str], - n_results: int, - where: dict[str, any], - citations: bool = False, - **kwargs: Optional[dict[str, Any]], + self, input_query: list[str], n_results: int, where: dict[str, any], citations: bool = False ) -> Union[list[tuple[str, dict]], list[str]]: """ query contents from vector database based on vector similarity diff --git a/embedchain/vectordb/zilliz.py b/embedchain/vectordb/zilliz.py index 657eb644..e957cd4d 100644 --- a/embedchain/vectordb/zilliz.py +++ b/embedchain/vectordb/zilliz.py @@ -108,7 +108,6 @@ class ZillizVectorDB(BaseVectorDB): def add( self, - embeddings: list[list[float]], documents: list[str], metadatas: list[object], ids: list[str], diff --git a/tests/vectordb/test_elasticsearch_db.py b/tests/vectordb/test_elasticsearch_db.py index 953f7813..5445dbb6 100644 --- a/tests/vectordb/test_elasticsearch_db.py +++ b/tests/vectordb/test_elasticsearch_db.py @@ -28,14 +28,13 @@ class TestEsDB(unittest.TestCase): # Assert that the Elasticsearch client is stored in the ElasticsearchDB class. self.assertEqual(self.db.client, mock_client.return_value) - # Create some dummy data. - embeddings = [[1, 2, 3], [4, 5, 6]] + # Create some dummy data documents = ["This is a document.", "This is another document."] metadatas = [{"url": "url_1", "doc_id": "doc_id_1"}, {"url": "url_2", "doc_id": "doc_id_2"}] ids = ["doc_1", "doc_2"] # Add the data to the database. - self.db.add(embeddings, documents, metadatas, ids) + self.db.add(documents, metadatas, ids) search_response = { "hits": { diff --git a/tests/vectordb/test_pinecone.py b/tests/vectordb/test_pinecone.py index 08a18a65..8cb08788 100644 --- a/tests/vectordb/test_pinecone.py +++ b/tests/vectordb/test_pinecone.py @@ -43,8 +43,8 @@ class TestPinecone: embedding_function = mock.Mock() base_embedder = BaseEmbedder() base_embedder.set_embedding_fn(embedding_function) - vectors = [[0, 0, 0], [1, 1, 1]] - embedding_function.return_value = vectors + embedding_function.return_value = [[0, 0, 0], [1, 1, 1]] + # Create a PineconeDb instance db = PineconeDB() app_config = AppConfig(collect_metrics=False) @@ -54,7 +54,7 @@ class TestPinecone: documents = ["This is a document.", "This is another document."] metadatas = [{}, {}] ids = ["doc1", "doc2"] - db.add(vectors, documents, metadatas, ids) + db.add(documents, metadatas, ids) expected_pinecone_upsert_args = [ {"id": "doc1", "values": [0, 0, 0], "metadata": {"text": "This is a document."}}, diff --git a/tests/vectordb/test_qdrant.py b/tests/vectordb/test_qdrant.py index c38e5786..563f50be 100644 --- a/tests/vectordb/test_qdrant.py +++ b/tests/vectordb/test_qdrant.py @@ -75,11 +75,10 @@ class TestQdrantDB(unittest.TestCase): app_config = AppConfig(collect_metrics=False) App(config=app_config, db=db, embedding_model=embedder) - embeddings = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]] documents = ["This is a test document.", "This is another test document."] metadatas = [{}, {}] ids = ["123", "456"] - db.add(embeddings, documents, metadatas, ids) + db.add(documents, metadatas, ids) qdrant_client_mock.return_value.upsert.assert_called_once_with( collection_name="embedchain-store-1526", points=Batch( @@ -96,7 +95,7 @@ class TestQdrantDB(unittest.TestCase): "metadata": {"text": "This is another test document."}, }, ], - vectors=embeddings, + vectors=[[1, 2, 3], [4, 5, 6]], ), ) diff --git a/tests/vectordb/test_weaviate.py b/tests/vectordb/test_weaviate.py index ba4045a7..66e75e76 100644 --- a/tests/vectordb/test_weaviate.py +++ b/tests/vectordb/test_weaviate.py @@ -29,7 +29,7 @@ class TestWeaviateDb(unittest.TestCase): weaviate_client_schema_mock.exists.return_value = False # Set the embedder embedder = BaseEmbedder() - embedder.set_vector_dimension(1526) + embedder.set_vector_dimension(1536) embedder.set_embedding_fn(mock_embedding_fn) # Create a Weaviate instance @@ -40,7 +40,7 @@ class TestWeaviateDb(unittest.TestCase): expected_class_obj = { "classes": [ { - "class": "Embedchain_store_1526", + "class": "Embedchain_store_1536", "vectorizer": "none", "properties": [ { @@ -53,12 +53,12 @@ class TestWeaviateDb(unittest.TestCase): }, { "name": "metadata", - "dataType": ["Embedchain_store_1526_metadata"], + "dataType": ["Embedchain_store_1536_metadata"], }, ], }, { - "class": "Embedchain_store_1526_metadata", + "class": "Embedchain_store_1536_metadata", "vectorizer": "none", "properties": [ { @@ -88,7 +88,7 @@ class TestWeaviateDb(unittest.TestCase): # Assert that the Weaviate client was initialized weaviate_mock.Client.assert_called_once() - self.assertEqual(db.index_name, "Embedchain_store_1526") + self.assertEqual(db.index_name, "Embedchain_store_1536") weaviate_client_schema_mock.create.assert_called_once_with(expected_class_obj) @patch("embedchain.vectordb.weaviate.weaviate") @@ -97,7 +97,7 @@ class TestWeaviateDb(unittest.TestCase): weaviate_client_mock = weaviate_mock.Client.return_value embedder = BaseEmbedder() - embedder.set_vector_dimension(1526) + embedder.set_vector_dimension(1536) embedder.set_embedding_fn(mock_embedding_fn) # Create a Weaviate instance @@ -117,7 +117,7 @@ class TestWeaviateDb(unittest.TestCase): # Set the embedder embedder = BaseEmbedder() - embedder.set_vector_dimension(1526) + embedder.set_vector_dimension(1536) embedder.set_embedding_fn(mock_embedding_fn) # Create a Weaviate instance @@ -126,30 +126,21 @@ class TestWeaviateDb(unittest.TestCase): App(config=app_config, db=db, embedding_model=embedder) db.BATCH_SIZE = 1 - embeddings = [[1, 2, 3], [4, 5, 6]] - documents = ["This is a test document.", "This is another test document."] - metadatas = [None, None] - ids = ["123", "456"] - db.add(embeddings, documents, metadatas, ids) + documents = ["This is test document"] + metadatas = [None] + ids = ["id_1"] + db.add(documents, metadatas, ids) # Check if the document was added to the database. weaviate_client_batch_mock.configure.assert_called_once_with(batch_size=1, timeout_retries=3) weaviate_client_batch_enter_mock.add_data_object.assert_any_call( - data_object={"text": documents[0]}, class_name="Embedchain_store_1526_metadata", vector=embeddings[0] - ) - weaviate_client_batch_enter_mock.add_data_object.assert_any_call( - data_object={"text": documents[1]}, class_name="Embedchain_store_1526_metadata", vector=embeddings[1] + data_object={"text": documents[0]}, class_name="Embedchain_store_1536_metadata", vector=[1, 2, 3] ) weaviate_client_batch_enter_mock.add_data_object.assert_any_call( - data_object={"identifier": ids[0], "text": documents[0]}, - class_name="Embedchain_store_1526", - vector=embeddings[0], - ) - weaviate_client_batch_enter_mock.add_data_object.assert_any_call( - data_object={"identifier": ids[1], "text": documents[1]}, - class_name="Embedchain_store_1526", - vector=embeddings[1], + data_object={"text": documents[0]}, + class_name="Embedchain_store_1536_metadata", + vector=[1, 2, 3], ) @patch("embedchain.vectordb.weaviate.weaviate") @@ -161,7 +152,7 @@ class TestWeaviateDb(unittest.TestCase): # Set the embedder embedder = BaseEmbedder() - embedder.set_vector_dimension(1526) + embedder.set_vector_dimension(1536) embedder.set_embedding_fn(mock_embedding_fn) # Create a Weaviate instance @@ -172,7 +163,7 @@ class TestWeaviateDb(unittest.TestCase): # Query for the document. db.query(input_query=["This is a test document."], n_results=1, where={}) - weaviate_client_query_mock.get.assert_called_once_with("Embedchain_store_1526", ["text"]) + weaviate_client_query_mock.get.assert_called_once_with("Embedchain_store_1536", ["text"]) weaviate_client_query_get_mock.with_near_vector.assert_called_once_with({"vector": [1, 2, 3]}) @patch("embedchain.vectordb.weaviate.weaviate") @@ -185,7 +176,7 @@ class TestWeaviateDb(unittest.TestCase): # Set the embedder embedder = BaseEmbedder() - embedder.set_vector_dimension(1526) + embedder.set_vector_dimension(1536) embedder.set_embedding_fn(mock_embedding_fn) # Create a Weaviate instance @@ -196,9 +187,9 @@ class TestWeaviateDb(unittest.TestCase): # Query for the document. db.query(input_query=["This is a test document."], n_results=1, where={"doc_id": "123"}) - weaviate_client_query_mock.get.assert_called_once_with("Embedchain_store_1526", ["text"]) + weaviate_client_query_mock.get.assert_called_once_with("Embedchain_store_1536", ["text"]) weaviate_client_query_get_mock.with_where.assert_called_once_with( - {"operator": "Equal", "path": ["metadata", "Embedchain_store_1526_metadata", "doc_id"], "valueText": "123"} + {"operator": "Equal", "path": ["metadata", "Embedchain_store_1536_metadata", "doc_id"], "valueText": "123"} ) weaviate_client_query_get_where_mock.with_near_vector.assert_called_once_with({"vector": [1, 2, 3]}) @@ -210,7 +201,7 @@ class TestWeaviateDb(unittest.TestCase): # Set the embedder embedder = BaseEmbedder() - embedder.set_vector_dimension(1526) + embedder.set_vector_dimension(1536) embedder.set_embedding_fn(mock_embedding_fn) # Create a Weaviate instance @@ -222,7 +213,7 @@ class TestWeaviateDb(unittest.TestCase): db.reset() weaviate_client_batch_mock.delete_objects.assert_called_once_with( - "Embedchain_store_1526", where={"path": ["identifier"], "operator": "Like", "valueText": ".*"} + "Embedchain_store_1536", where={"path": ["identifier"], "operator": "Like", "valueText": ".*"} ) @patch("embedchain.vectordb.weaviate.weaviate") @@ -233,7 +224,7 @@ class TestWeaviateDb(unittest.TestCase): # Set the embedder embedder = BaseEmbedder() - embedder.set_vector_dimension(1526) + embedder.set_vector_dimension(1536) embedder.set_embedding_fn(mock_embedding_fn) # Create a Weaviate instance @@ -244,4 +235,4 @@ class TestWeaviateDb(unittest.TestCase): # Reset the database. db.count() - weaviate_client_query.aggregate.assert_called_once_with("Embedchain_store_1526") + weaviate_client_query.aggregate.assert_called_once_with("Embedchain_store_1536")