From cdfd6519c8e57698e879fd7de096805ff5d543c8 Mon Sep 17 00:00:00 2001 From: Rupesh Bansal Date: Wed, 18 Oct 2023 10:48:53 +0530 Subject: [PATCH] [Feature] Add support for weaviate vector db (#782) --- configs/weaviate.yaml | 4 + docs/components/vector-databases.mdx | 17 +- embedchain/config/vectordb/weaviate.py | 16 ++ embedchain/embedchain.py | 3 - embedchain/factory.py | 2 + embedchain/llm/base.py | 1 - embedchain/vectordb/weaviate.py | 297 +++++++++++++++++++++++++ pyproject.toml | 2 + tests/vectordb/test_weaviate.py | 244 ++++++++++++++++++++ 9 files changed, 581 insertions(+), 5 deletions(-) create mode 100644 configs/weaviate.yaml create mode 100644 embedchain/config/vectordb/weaviate.py create mode 100644 embedchain/vectordb/weaviate.py create mode 100644 tests/vectordb/test_weaviate.py diff --git a/configs/weaviate.yaml b/configs/weaviate.yaml new file mode 100644 index 00000000..a27623ab --- /dev/null +++ b/configs/weaviate.yaml @@ -0,0 +1,4 @@ +vectordb: + provider: weaviate + config: + collection_name: my_weaviate_index diff --git a/docs/components/vector-databases.mdx b/docs/components/vector-databases.mdx index 2b30140d..44bb0d92 100644 --- a/docs/components/vector-databases.mdx +++ b/docs/components/vector-databases.mdx @@ -187,6 +187,21 @@ _Coming soon_ ## Weaviate -_Coming soon_ +In order to use Weaviate as a vector database, set the environment variables `WEAVIATE_ENDPOINT` and `WEAVIATE_API_KEY` which you can find on [Weaviate dashboard](https://console.weaviate.cloud/dashboard). + +```python main.py +from embedchain import App + +# load weaviate configuration from yaml file +app = App.from_config(yaml_path="config.yaml") +``` + +```yaml config.yaml +vectordb: + provider: weaviate + config: + collection_name: my_weaviate_index +``` + diff --git a/embedchain/config/vectordb/weaviate.py b/embedchain/config/vectordb/weaviate.py new file mode 100644 index 00000000..4035877b --- /dev/null +++ b/embedchain/config/vectordb/weaviate.py @@ -0,0 +1,16 @@ +from typing import Dict, Optional + +from embedchain.config.vectordb.base import BaseVectorDbConfig +from embedchain.helper.json_serializable import register_deserializable + + +@register_deserializable +class WeaviateDBConfig(BaseVectorDbConfig): + def __init__( + self, + collection_name: Optional[str] = None, + dir: Optional[str] = None, + **extra_params: Dict[str, any], + ): + self.extra_params = extra_params + super().__init__(collection_name=collection_name, dir=dir) diff --git a/embedchain/embedchain.py b/embedchain/embedchain.py index 133e3e42..47544769 100644 --- a/embedchain/embedchain.py +++ b/embedchain/embedchain.py @@ -359,7 +359,6 @@ class EmbedChain(JSONSerializable): db_result = self.db.get(ids=ids, where=where) # optional filter existing_ids = set(db_result["ids"]) - if len(existing_ids): data_dict = {id: (doc, meta) for id, doc, meta in zip(ids, documents, metadatas)} data_dict = {id: value for id, value in data_dict.items() if id not in existing_ids} @@ -436,7 +435,6 @@ class EmbedChain(JSONSerializable): :rtype: List[str] """ query_config = config or self.llm.config - if where is not None: where = where elif query_config is not None and query_config.where is not None: @@ -463,7 +461,6 @@ class EmbedChain(JSONSerializable): where=where, skip_embedding=(hasattr(config, "query_type") and config.query_type == "Images"), ) - return contents def query(self, input_query: str, config: BaseLlmConfig = None, dry_run=False, where: Optional[Dict] = None) -> str: diff --git a/embedchain/factory.py b/embedchain/factory.py index e1ebcf37..dee01d2f 100644 --- a/embedchain/factory.py +++ b/embedchain/factory.py @@ -72,12 +72,14 @@ class VectorDBFactory: "elasticsearch": "embedchain.vectordb.elasticsearch.ElasticsearchDB", "opensearch": "embedchain.vectordb.opensearch.OpenSearchDB", "pinecone": "embedchain.vectordb.pinecone.PineconeDB", + "weaviate": "embedchain.vectordb.weaviate.WeaviateDB", } provider_to_config_class = { "chroma": "embedchain.config.vectordb.chroma.ChromaDbConfig", "elasticsearch": "embedchain.config.vectordb.elasticsearch.ElasticsearchDBConfig", "opensearch": "embedchain.config.vectordb.opensearch.OpenSearchDBConfig", "pinecone": "embedchain.config.vectordb.pinecone.PineconeDBConfig", + "weaviate": "embedchain.config.vectordb.weaviate.WeaviateDBConfig", } @classmethod diff --git a/embedchain/llm/base.py b/embedchain/llm/base.py index 2a1819c5..c90771a7 100644 --- a/embedchain/llm/base.py +++ b/embedchain/llm/base.py @@ -206,7 +206,6 @@ class BaseLlm(JSONSerializable): k["web_search_result"] = self.access_search_and_get_results(input_query) prompt = self.generate_prompt(input_query, contexts, **k) logging.info(f"Prompt: {prompt}") - if dry_run: return prompt diff --git a/embedchain/vectordb/weaviate.py b/embedchain/vectordb/weaviate.py new file mode 100644 index 00000000..6416cb35 --- /dev/null +++ b/embedchain/vectordb/weaviate.py @@ -0,0 +1,297 @@ +import copy +import os +from typing import Dict, List, Optional + +try: + import weaviate +except ImportError: + raise ImportError( + "Weaviate requires extra dependencies. Install with `pip install --upgrade 'embedchain[weaviate]'`" + ) from None + +from embedchain.config.vectordb.weaviate import WeaviateDBConfig +from embedchain.helper.json_serializable import register_deserializable +from embedchain.vectordb.base import BaseVectorDB + + +@register_deserializable +class WeaviateDB(BaseVectorDB): + """ + Weaviate as vector database + """ + + BATCH_SIZE = 100 + + def __init__( + self, + config: Optional[WeaviateDBConfig] = None, + ): + """Weaviate as vector database. + :param config: Weaviate database config, defaults to None + :type config: WeaviateDBConfig, optional + :raises ValueError: No config provided + """ + if config is None: + self.config = WeaviateDBConfig() + else: + if not isinstance(config, WeaviateDBConfig): + raise TypeError( + "config is not a `WeaviateDBConfig` instance. " + "Please make sure the type is right and that you are passing an instance." + ) + self.config = config + self.client = weaviate.Client( + url=os.environ.get("WEAVIATE_ENDPOINT"), + auth_client_secret=weaviate.AuthApiKey(api_key=os.environ.get("WEAVIATE_API_KEY")), + **self.config.extra_params, + ) + + # Call parent init here because embedder is needed + super().__init__(config=self.config) + + def _initialize(self): + """ + This method is needed because `embedder` attribute needs to be set externally before it can be initialized. + """ + + if not self.embedder: + raise ValueError("Embedder not set. Please set an embedder with `set_embedder` before initialization.") + + self.index_name = self._get_index_name() + self.metadata_keys = {"data_type", "doc_id", "url", "hash", "app_id", "text"} + if not self.client.schema.exists(self.index_name): + # id is a reserved field in Weaviate, hence we had to change the name of the id field to identifier + # The none vectorizer is crucial as we have our own custom embedding function + class_obj = { + "classes": [ + { + "class": self.index_name, + "vectorizer": "none", + "properties": [ + { + "name": "identifier", + "dataType": ["text"], + }, + { + "name": "text", + "dataType": ["text"], + }, + { + "name": "metadata", + "dataType": [self.index_name + "_metadata"], + }, + ], + }, + { + "class": self.index_name + "_metadata", + "vectorizer": "none", + "properties": [ + { + "name": "data_type", + "dataType": ["text"], + }, + { + "name": "doc_id", + "dataType": ["text"], + }, + { + "name": "url", + "dataType": ["text"], + }, + { + "name": "hash", + "dataType": ["text"], + }, + { + "name": "app_id", + "dataType": ["text"], + }, + { + "name": "text", + "dataType": ["text"], + }, + ], + }, + ] + } + + self.client.schema.create(class_obj) + + def get(self, ids: Optional[List[str]] = None, where: Optional[Dict[str, any]] = None, limit: Optional[int] = None): + """ + Get existing doc ids present in vector database + :param ids: _list of doc ids to check for existance + :type ids: List[str] + :param where: to filter data + :type where: Dict[str, any] + :return: ids + :rtype: Set[str] + """ + + if ids is None or len(ids) == 0: + return {"ids": []} + + existing_ids = [] + cursor = None + has_iterated_once = False + while cursor is not None or not has_iterated_once: + has_iterated_once = True + results = self._query_with_cursor( + self.client.query.get(self.index_name, ["identifier"]) + .with_additional(["id"]) + .with_limit(self.BATCH_SIZE), + cursor, + ) + fetched_results = results["data"]["Get"].get(self.index_name, []) + if len(fetched_results) == 0: + break + for result in fetched_results: + existing_ids.append(result["identifier"]) + cursor = result["_additional"]["id"] + + return {"ids": existing_ids} + + def add( + self, + embeddings: List[List[float]], + documents: List[str], + metadatas: List[object], + ids: List[str], + skip_embedding: bool, + ): + """add data in vector database + :param embeddings: list of embeddings for the corresponding documents to be added + :type documents: List[List[float]] + :param documents: list of texts to add + :type documents: List[str] + :param metadatas: list of metadata associated with docs + :type metadatas: List[object] + :param ids: ids of docs + :type ids: List[str] + :param skip_embedding: A boolean flag indicating if the embedding for the documents to be added is to be + generated or not + :type skip_embedding: bool + """ + + print("Adding documents to Weaviate...") + if not skip_embedding: + embeddings = self.embedder.embedding_fn(documents) + self.client.batch.configure(batch_size=self.BATCH_SIZE, timeout_retries=3) # Configure batch + with self.client.batch as batch: # Initialize a batch process + for id, text, metadata, embedding in zip(ids, documents, metadatas, embeddings): + doc = {"identifier": id, "text": text} + updated_metadata = {"text": text} + if metadata is not None: + updated_metadata.update(**metadata) + + obj_uuid = batch.add_data_object( + data_object=copy.deepcopy(doc), class_name=self.index_name, vector=embedding + ) + metadata_uuid = batch.add_data_object( + data_object=copy.deepcopy(updated_metadata), + class_name=self.index_name + "_metadata", + vector=embedding, + ) + batch.add_reference(obj_uuid, self.index_name, "metadata", metadata_uuid, self.index_name + "_metadata") + + def query(self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool) -> List[str]: + """ + query contents from vector database based on vector similarity + :param input_query: list of query string + :type input_query: List[str] + :param n_results: no of similar documents to fetch from database + :type n_results: int + :param where: Optional. to filter data + :type where: Dict[str, any] + :param skip_embedding: A boolean flag indicating if the embedding for the documents to be added is to be + generated or not + :type skip_embedding: bool + :return: Database contents that are the result of the query + :rtype: List[str] + """ + if not skip_embedding: + query_vector = self.embedder.embedding_fn([input_query])[0] + else: + query_vector = input_query + keys = set(where.keys() if where is not None else set()) + if len(keys.intersection(self.metadata_keys)) != 0: + weaviate_where_operands = [] + for key in keys: + if key in self.metadata_keys: + weaviate_where_operands.append( + { + "path": ["metadata", self.index_name + "_metadata", key], + "operator": "Equal", + "valueText": where.get(key), + } + ) + if len(weaviate_where_operands) == 1: + weaviate_where_clause = weaviate_where_operands[0] + else: + weaviate_where_clause = {"operator": "And", "operands": weaviate_where_operands} + + results = ( + self.client.query.get(self.index_name, ["text"]) + .with_where(weaviate_where_clause) + .with_near_vector({"vector": query_vector}) + .with_limit(n_results) + .do() + ) + else: + results = ( + self.client.query.get(self.index_name, ["text"]) + .with_near_vector({"vector": query_vector}) + .with_limit(n_results) + .do() + ) + matched_tokens = [] + for result in results["data"]["Get"].get(self.index_name): + matched_tokens.append(result["text"]) + + return matched_tokens + + def set_collection_name(self, name: str): + """ + Set the name of the collection. A collection is an isolated space for vectors. + :param name: Name of the collection. + :type name: str + """ + if not isinstance(name, str): + raise TypeError("Collection name must be a string") + self.config.collection_name = name + + def count(self) -> int: + """ + Count number of documents/chunks embedded in the database. + :return: number of documents + :rtype: int + """ + data = self.client.query.aggregate(self.index_name).with_meta_count().do() + return data["data"]["Aggregate"].get(self.index_name)[0]["meta"]["count"] + + def _get_or_create_db(self): + """Called during initialization""" + return self.client + + def reset(self): + """ + Resets the database. Deletes all embeddings irreversibly. + """ + # Delete all data from the database + self.client.batch.delete_objects( + self.index_name, where={"path": ["identifier"], "operator": "Like", "valueText": ".*"} + ) + + # Weaviate internally by default capitalizes the class name + def _get_index_name(self) -> str: + """Get the Weaviate index for a collection + :return: Weaviate index + :rtype: str + """ + return f"{self.config.collection_name}_{self.embedder.vector_dimension}".capitalize() + + def _query_with_cursor(self, query, cursor): + if cursor is not None: + query.with_after(cursor) + results = query.do() + return results diff --git a/pyproject.toml b/pyproject.toml index 9d60c5a9..e9d68a1b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -111,6 +111,7 @@ fastapi-poe = { version = "0.0.16", optional = true } discord = { version = "^2.3.2", optional = true } slack-sdk = { version = "3.21.3", optional = true } cohere = { version = "^4.27", optional= true } +weaviate-client = { version = "^3.24.1", optional= true } docx2txt = { version="^0.8", optional=true } pinecone-client = { version = "^2.2.4", optional = true } unstructured = {extras = ["local-inference"], version = "^0.10.18", optional=true} @@ -145,6 +146,7 @@ poe = ["fastapi-poe"] discord = ["discord"] slack = ["slack-sdk", "flask"] whatsapp = ["twilio", "flask"] +weaviate = ["weaviate-client"] pinecone = ["pinecone-client"] images = ["torch", "ftfy", "regex", "pillow", "torchvision"] huggingface_hub=["huggingface_hub"] diff --git a/tests/vectordb/test_weaviate.py b/tests/vectordb/test_weaviate.py new file mode 100644 index 00000000..5ced280a --- /dev/null +++ b/tests/vectordb/test_weaviate.py @@ -0,0 +1,244 @@ +import unittest +from unittest.mock import patch + +from embedchain import App +from embedchain.config import AppConfig +from embedchain.config.vectordb.pinecone import PineconeDBConfig +from embedchain.embedder.base import BaseEmbedder +from embedchain.vectordb.weaviate import WeaviateDB + + +class TestWeaviateDb(unittest.TestCase): + def test_incorrect_config_throws_error(self): + """Test the init method of the WeaviateDb class throws error for incorrect config""" + with self.assertRaises(TypeError): + WeaviateDB(config=PineconeDBConfig()) + + @patch("embedchain.vectordb.weaviate.weaviate") + def test_initialize(self, weaviate_mock): + """Test the init method of the WeaviateDb class.""" + weaviate_client_mock = weaviate_mock.Client.return_value + weaviate_client_schema_mock = weaviate_client_mock.schema + + # Mock that schema doesn't already exist so that a new schema is created + weaviate_client_schema_mock.exists.return_value = False + # Set the embedder + embedder = BaseEmbedder() + embedder.set_vector_dimension(1526) + + # Create a Weaviate instance + db = WeaviateDB() + app_config = AppConfig(collect_metrics=False) + App(config=app_config, db=db, embedder=embedder) + + expected_class_obj = { + "classes": [ + { + "class": "Embedchain_store_1526", + "vectorizer": "none", + "properties": [ + { + "name": "identifier", + "dataType": ["text"], + }, + { + "name": "text", + "dataType": ["text"], + }, + { + "name": "metadata", + "dataType": ["Embedchain_store_1526_metadata"], + }, + ], + }, + { + "class": "Embedchain_store_1526_metadata", + "vectorizer": "none", + "properties": [ + { + "name": "data_type", + "dataType": ["text"], + }, + { + "name": "doc_id", + "dataType": ["text"], + }, + { + "name": "url", + "dataType": ["text"], + }, + { + "name": "hash", + "dataType": ["text"], + }, + { + "name": "app_id", + "dataType": ["text"], + }, + { + "name": "text", + "dataType": ["text"], + }, + ], + }, + ] + } + + # Assert that the Weaviate client was initialized + weaviate_mock.Client.assert_called_once() + self.assertEqual(db.index_name, "Embedchain_store_1526") + weaviate_client_schema_mock.create.assert_called_once_with(expected_class_obj) + + @patch("embedchain.vectordb.weaviate.weaviate") + def test_get_or_create_db(self, weaviate_mock): + """Test the _get_or_create_db method of the WeaviateDb class.""" + weaviate_client_mock = weaviate_mock.Client.return_value + + embedder = BaseEmbedder() + embedder.set_vector_dimension(1526) + + # Create a Weaviate instance + db = WeaviateDB() + app_config = AppConfig(collect_metrics=False) + App(config=app_config, db=db, embedder=embedder) + + expected_client = db._get_or_create_db() + self.assertEqual(expected_client, weaviate_client_mock) + + @patch("embedchain.vectordb.weaviate.weaviate") + def test_add(self, weaviate_mock): + """Test the add method of the WeaviateDb class.""" + weaviate_client_mock = weaviate_mock.Client.return_value + weaviate_client_batch_mock = weaviate_client_mock.batch + weaviate_client_batch_enter_mock = weaviate_client_mock.batch.__enter__.return_value + + # Set the embedder + embedder = BaseEmbedder() + embedder.set_vector_dimension(1526) + + # Create a Weaviate instance + db = WeaviateDB() + app_config = AppConfig(collect_metrics=False) + App(config=app_config, db=db, embedder=embedder) + db.BATCH_SIZE = 1 + + embeddings = [[1, 2, 3], [4, 5, 6]] + documents = ["This is a test document.", "This is another test document."] + metadatas = [None, None] + ids = ["123", "456"] + skip_embedding = True + db.add(embeddings, documents, metadatas, ids, skip_embedding) + + # Check if the document was added to the database. + weaviate_client_batch_mock.configure.assert_called_once_with(batch_size=1, timeout_retries=3) + weaviate_client_batch_enter_mock.add_data_object.assert_any_call( + data_object={"text": documents[0]}, class_name="Embedchain_store_1526_metadata", vector=embeddings[0] + ) + weaviate_client_batch_enter_mock.add_data_object.assert_any_call( + data_object={"text": documents[1]}, class_name="Embedchain_store_1526_metadata", vector=embeddings[1] + ) + + weaviate_client_batch_enter_mock.add_data_object.assert_any_call( + data_object={"identifier": ids[0], "text": documents[0]}, + class_name="Embedchain_store_1526", + vector=embeddings[0], + ) + weaviate_client_batch_enter_mock.add_data_object.assert_any_call( + data_object={"identifier": ids[1], "text": documents[1]}, + class_name="Embedchain_store_1526", + vector=embeddings[1], + ) + + @patch("embedchain.vectordb.weaviate.weaviate") + def test_query_without_where(self, weaviate_mock): + """Test the query method of the WeaviateDb class.""" + weaviate_client_mock = weaviate_mock.Client.return_value + weaviate_client_query_mock = weaviate_client_mock.query + weaviate_client_query_get_mock = weaviate_client_query_mock.get.return_value + + # Set the embedder + embedder = BaseEmbedder() + embedder.set_vector_dimension(1526) + + # Create a Weaviate instance + db = WeaviateDB() + app_config = AppConfig(collect_metrics=False) + App(config=app_config, db=db, embedder=embedder) + + # Query for the document. + db.query(input_query=["This is a test document."], n_results=1, where={}, skip_embedding=True) + + weaviate_client_query_mock.get.assert_called_once_with("Embedchain_store_1526", ["text"]) + weaviate_client_query_get_mock.with_near_vector.assert_called_once_with( + {"vector": ["This is a test document."]} + ) + + @patch("embedchain.vectordb.weaviate.weaviate") + def test_query_with_where(self, weaviate_mock): + """Test the query method of the WeaviateDb class.""" + weaviate_client_mock = weaviate_mock.Client.return_value + weaviate_client_query_mock = weaviate_client_mock.query + weaviate_client_query_get_mock = weaviate_client_query_mock.get.return_value + weaviate_client_query_get_where_mock = weaviate_client_query_get_mock.with_where.return_value + + # Set the embedder + embedder = BaseEmbedder() + embedder.set_vector_dimension(1526) + + # Create a Weaviate instance + db = WeaviateDB() + app_config = AppConfig(collect_metrics=False) + App(config=app_config, db=db, embedder=embedder) + + # Query for the document. + db.query(input_query=["This is a test document."], n_results=1, where={"doc_id": "123"}, skip_embedding=True) + + weaviate_client_query_mock.get.assert_called_once_with("Embedchain_store_1526", ["text"]) + weaviate_client_query_get_mock.with_where.assert_called_once_with( + {"operator": "Equal", "path": ["metadata", "Embedchain_store_1526_metadata", "doc_id"], "valueText": "123"} + ) + weaviate_client_query_get_where_mock.with_near_vector.assert_called_once_with( + {"vector": ["This is a test document."]} + ) + + @patch("embedchain.vectordb.weaviate.weaviate") + def test_reset(self, weaviate_mock): + """Test the reset method of the WeaviateDb class.""" + weaviate_client_mock = weaviate_mock.Client.return_value + weaviate_client_batch_mock = weaviate_client_mock.batch + + # Set the embedder + embedder = BaseEmbedder() + embedder.set_vector_dimension(1526) + + # Create a Weaviate instance + db = WeaviateDB() + app_config = AppConfig(collect_metrics=False) + App(config=app_config, db=db, embedder=embedder) + + # Reset the database. + db.reset() + + weaviate_client_batch_mock.delete_objects.assert_called_once_with( + "Embedchain_store_1526", where={"path": ["identifier"], "operator": "Like", "valueText": ".*"} + ) + + @patch("embedchain.vectordb.weaviate.weaviate") + def test_count(self, weaviate_mock): + """Test the reset method of the WeaviateDb class.""" + weaviate_client_mock = weaviate_mock.Client.return_value + weaviate_client_query = weaviate_client_mock.query + + # Set the embedder + embedder = BaseEmbedder() + embedder.set_vector_dimension(1526) + + # Create a Weaviate instance + db = WeaviateDB() + app_config = AppConfig(collect_metrics=False) + App(config=app_config, db=db, embedder=embedder) + + # Reset the database. + db.count() + + weaviate_client_query.aggregate.assert_called_once_with("Embedchain_store_1526")