From c8846e0e932bdc6fd26b51abd3318ce276ec9858 Mon Sep 17 00:00:00 2001 From: Rupesh Bansal Date: Thu, 19 Oct 2023 02:57:57 +0530 Subject: [PATCH] [Feature] Add Qdrant support (#822) --- docs/components/vector-databases.mdx | 21 ++- embedchain/config/vectordb/qdrant.py | 44 ++++++ embedchain/factory.py | 2 + embedchain/llm/openai.py | 3 +- embedchain/models/__init__.py | 1 - embedchain/models/vector_databases.py | 8 - embedchain/vectordb/qdrant.py | 213 ++++++++++++++++++++++++++ pyproject.toml | 4 +- tests/llm/test_base_llm.py | 4 +- tests/llm/test_cohere.py | 1 + tests/llm/test_huggingface.py | 2 + tests/llm/test_jina.py | 4 +- tests/llm/test_llama2.py | 2 + tests/llm/test_openai.py | 4 +- tests/llm/test_query.py | 4 +- tests/models/test_data_type.py | 3 +- tests/vectordb/test_qdrant.py | 158 +++++++++++++++++++ 17 files changed, 460 insertions(+), 18 deletions(-) create mode 100644 embedchain/config/vectordb/qdrant.py delete mode 100644 embedchain/models/vector_databases.py create mode 100644 embedchain/vectordb/qdrant.py create mode 100644 tests/vectordb/test_qdrant.py diff --git a/docs/components/vector-databases.mdx b/docs/components/vector-databases.mdx index 44bb0d92..94e965d3 100644 --- a/docs/components/vector-databases.mdx +++ b/docs/components/vector-databases.mdx @@ -183,12 +183,29 @@ vectordb: ## Qdrant -_Coming soon_ +In order to use Qdrant as a vector database, set the environment variables `QDRANT_URL` and `QDRANT_API_KEY`, which you can find on the [Qdrant Dashboard](https://cloud.qdrant.io/). 
+ + +```python main.py +from embedchain import App + +# load qdrant configuration from yaml file +app = App.from_config(yaml_path="config.yaml") +``` + +```yaml config.yaml +vectordb: + provider: qdrant + config: + collection_name: my_qdrant_index +``` + ## Weaviate In order to use Weaviate as a vector database, set the environment variables `WEAVIATE_ENDPOINT` and `WEAVIATE_API_KEY` which you can find on [Weaviate dashboard](https://console.weaviate.cloud/dashboard). + ```python main.py from embedchain import App @@ -202,6 +219,6 @@ vectordb: config: collection_name: my_weaviate_index ``` - + diff --git a/embedchain/config/vectordb/qdrant.py b/embedchain/config/vectordb/qdrant.py new file mode 100644 index 00000000..4468c7b2 --- /dev/null +++ b/embedchain/config/vectordb/qdrant.py @@ -0,0 +1,44 @@ +from typing import Dict, Optional + +from embedchain.config.vectordb.base import BaseVectorDbConfig +from embedchain.helper.json_serializable import register_deserializable + + +@register_deserializable +class QdrantDBConfig(BaseVectorDbConfig): + """ + Config to initialize a qdrant client. + :param url. qdrant url or list of nodes url to be used for connection + """ + + def __init__( + self, + collection_name: Optional[str] = None, + dir: Optional[str] = None, + hnsw_config: Optional[Dict[str, any]] = None, + quantization_config: Optional[Dict[str, any]] = None, + on_disk: Optional[bool] = None, + **extra_params: Dict[str, any], + ): + """ + Initializes a configuration class instance for a qdrant client. 
+ + :param collection_name: Default name for the collection, defaults to None + :type collection_name: Optional[str], optional + :param dir: Path to the database directory, where the database is stored, defaults to None + :type dir: Optional[str], optional + :param hnsw_config: Params for HNSW index + :type hnsw_config: Optional[Dict[str, any]], defaults to None + :param quantization_config: Params for quantization, if None - quantization will be disabled + :type quantization_config: Optional[Dict[str, any]], defaults to None + :param on_disk: If true - point`s payload will not be stored in memory. + It will be read from the disk every time it is requested. + This setting saves RAM by (slightly) increasing the response time. + Note: those payload values that are involved in filtering and are indexed - remain in RAM. + :type on_disk: bool, optional, defaults to None + """ + self.hnsw_config = hnsw_config + self.quantization_config = quantization_config + self.on_disk = on_disk + self.extra_params = extra_params + super().__init__(collection_name=collection_name, dir=dir) diff --git a/embedchain/factory.py b/embedchain/factory.py index dee01d2f..97453144 100644 --- a/embedchain/factory.py +++ b/embedchain/factory.py @@ -73,6 +73,7 @@ class VectorDBFactory: "opensearch": "embedchain.vectordb.opensearch.OpenSearchDB", "pinecone": "embedchain.vectordb.pinecone.PineconeDB", "weaviate": "embedchain.vectordb.weaviate.WeaviateDB", + "qdrant": "embedchain.vectordb.qdrant.QdrantDB", } provider_to_config_class = { "chroma": "embedchain.config.vectordb.chroma.ChromaDbConfig", @@ -80,6 +81,7 @@ class VectorDBFactory: "opensearch": "embedchain.config.vectordb.opensearch.OpenSearchDBConfig", "pinecone": "embedchain.config.vectordb.pinecone.PineconeDBConfig", "weaviate": "embedchain.config.vectordb.weaviate.WeaviateDBConfig", + "qdrant": "embedchain.config.vectordb.qdrant.QdrantDBConfig", } @classmethod diff --git a/embedchain/llm/openai.py b/embedchain/llm/openai.py index 
f7b8bad4..9e69085c 100644 --- a/embedchain/llm/openai.py +++ b/embedchain/llm/openai.py @@ -31,7 +31,8 @@ class OpenAILlm(BaseLlm): if config.top_p: kwargs["model_kwargs"]["top_p"] = config.top_p if config.stream: - from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler + from langchain.callbacks.streaming_stdout import \ + StreamingStdOutCallbackHandler chat = ChatOpenAI(**kwargs, streaming=config.stream, callbacks=[StreamingStdOutCallbackHandler()]) else: diff --git a/embedchain/models/__init__.py b/embedchain/models/__init__.py index fc073230..48887545 100644 --- a/embedchain/models/__init__.py +++ b/embedchain/models/__init__.py @@ -1,4 +1,3 @@ from .embedding_functions import EmbeddingFunctions # noqa: F401 from .providers import Providers # noqa: F401 -from .vector_databases import VectorDatabases # noqa: F401 from .vector_dimensions import VectorDimensions # noqa: F401 diff --git a/embedchain/models/vector_databases.py b/embedchain/models/vector_databases.py deleted file mode 100644 index 30f2c635..00000000 --- a/embedchain/models/vector_databases.py +++ /dev/null @@ -1,8 +0,0 @@ -from enum import Enum - - -class VectorDatabases(Enum): - CHROMADB = "CHROMADB" - ELASTICSEARCH = "ELASTICSEARCH" - OPENSEARCH = "OPENSEARCH" - ZILLIZ = "ZILLIZ" diff --git a/embedchain/vectordb/qdrant.py b/embedchain/vectordb/qdrant.py new file mode 100644 index 00000000..477fa58c --- /dev/null +++ b/embedchain/vectordb/qdrant.py @@ -0,0 +1,213 @@ +import copy +import os +import uuid +from typing import Dict, List, Optional + +try: + from qdrant_client import QdrantClient + from qdrant_client.http import models + from qdrant_client.http.models import Batch + from qdrant_client.models import Distance, VectorParams +except ImportError: + raise ImportError("Qdrant requires extra dependencies. 
Install with `pip install embedchain[qdrant]`") from None + +from embedchain.config.vectordb.qdrant import QdrantDBConfig +from embedchain.vectordb.base import BaseVectorDB + + +class QdrantDB(BaseVectorDB): + """ + Qdrant as vector database + """ + + BATCH_SIZE = 10 + + def __init__(self, config: QdrantDBConfig = None): + """ + Qdrant as vector database + :param config. Qdrant database config to be used for connection + """ + if config is None: + config = QdrantDBConfig() + else: + if not isinstance(config, QdrantDBConfig): + raise TypeError( + "config is not a `QdrantDBConfig` instance. " + "Please make sure the type is right and that you are passing an instance." + ) + self.config = config + self.client = QdrantClient(url=os.getenv("QDRANT_URL"), api_key=os.getenv("QDRANT_API_KEY")) + # Call parent init here because embedder is needed + super().__init__(config=self.config) + + def _initialize(self): + """ + This method is needed because `embedder` attribute needs to be set externally before it can be initialized. + """ + if not self.embedder: + raise ValueError("Embedder not set. 
Please set an embedder with `set_embedder` before initialization.") + + self.collection_name = self._get_or_create_collection() + self.metadata_keys = {"data_type", "doc_id", "url", "hash", "app_id", "text"} + all_collections = self.client.get_collections() + collection_names = [collection.name for collection in all_collections.collections] + if self.collection_name not in collection_names: + self.client.recreate_collection( + collection_name=self.collection_name, + vectors_config=VectorParams( + size=self.embedder.vector_dimension, + distance=Distance.COSINE, + hnsw_config=self.config.hnsw_config, + quantization_config=self.config.quantization_config, + on_disk=self.config.on_disk, + ), + ) + + def _get_or_create_db(self): + return self.client + + def _get_or_create_collection(self): + return f"{self.config.collection_name}-{self.embedder.vector_dimension}".lower().replace("_", "-") + + def get(self, ids: Optional[List[str]] = None, where: Optional[Dict[str, any]] = None, limit: Optional[int] = None): + """ + Get existing doc ids present in vector database + + :param ids: _list of doc ids to check for existence + :type ids: List[str] + :param where: to filter data + :type where: Dict[str, any] + :param limit: The number of entries to be fetched + :type limit: Optional int, defaults to None + :return: All the existing IDs + :rtype: Set[str] + """ + if ids is None or len(ids) == 0: + return {"ids": []} + + keys = set(where.keys() if where is not None else set()) + + qdrant_must_filters = [ + models.FieldCondition( + key="identifier", + match=models.MatchAny( + any=ids, + ), + ) + ] + if len(keys.intersection(self.metadata_keys)) != 0: + for key in keys.intersection(self.metadata_keys): + qdrant_must_filters.append( + models.FieldCondition( + key="metadata.{}".format(key), + match=models.MatchValue( + value=where.get(key), + ), + ) + ) + + offset = 0 + existing_ids = [] + while offset is not None: + response = self.client.scroll( + 
collection_name=self.collection_name, + scroll_filter=models.Filter(must=qdrant_must_filters), + offset=offset, + limit=self.BATCH_SIZE, + ) + offset = response[1] + for doc in response[0]: + existing_ids.append(doc.payload["identifier"]) + return {"ids": existing_ids} + + def add( + self, + embeddings: List[List[float]], + documents: List[str], + metadatas: List[object], + ids: List[str], + skip_embedding: bool, + ): + """add data in vector database + :param embeddings: list of embeddings for the corresponding documents to be added + :type documents: List[List[float]] + :param documents: list of texts to add + :type documents: List[str] + :param metadatas: list of metadata associated with docs + :type metadatas: List[object] + :param ids: ids of docs + :type ids: List[str] + :param skip_embedding: A boolean flag indicating if the embedding for the documents to be added is to be + generated or not + :type skip_embedding: bool + """ + if not skip_embedding: + embeddings = self.embedder.embedding_fn(documents) + + payloads = [] + qdrant_ids = [] + for id, document, metadata in zip(ids, documents, metadatas): + metadata["text"] = document + qdrant_ids.append(str(uuid.uuid4())) + payloads.append({"identifier": id, "text": document, "metadata": copy.deepcopy(metadata)}) + for i in range(0, len(qdrant_ids), self.BATCH_SIZE): + self.client.upsert( + collection_name=self.collection_name, + points=Batch( + ids=qdrant_ids[i : i + self.BATCH_SIZE], + payloads=payloads[i : i + self.BATCH_SIZE], + vectors=embeddings[i : i + self.BATCH_SIZE], + ), + ) + + def query(self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool) -> List[str]: + """ + query contents from vector database based on vector similarity + :param input_query: list of query string + :type input_query: List[str] + :param n_results: no of similar documents to fetch from database + :type n_results: int + :param where: Optional. 
to filter data + :type where: Dict[str, any] + :param skip_embedding: A boolean flag indicating if the embedding for the documents to be added is to be + generated or not + :type skip_embedding: bool + :return: Database contents that are the result of the query + :rtype: List[str] + """ + if not skip_embedding: + query_vector = self.embedder.embedding_fn([input_query])[0] + else: + query_vector = input_query + + keys = set(where.keys() if where is not None else set()) + + qdrant_must_filters = [] + if len(keys.intersection(self.metadata_keys)) != 0: + for key in keys.intersection(self.metadata_keys): + qdrant_must_filters.append( + models.FieldCondition( + key="payload.metadata.{}".format(key), + match=models.MatchValue( + value=where.get(key), + ), + ) + ) + results = self.client.search( + collection_name=self.collection_name, + query_filter=models.Filter(must=qdrant_must_filters), + query_vector=query_vector, + limit=n_results, + ) + response = [] + for result in results: + response.append(result.payload.get("text", "")) + return response + + def count(self) -> int: + response = self.client.get_collection(collection_name=self.collection_name) + return response.points_count + + def reset(self): + self.client.delete_collection(collection_name=self.collection_name) + self._initialize() diff --git a/pyproject.toml b/pyproject.toml index ca9175e4..85d78916 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -85,7 +85,7 @@ exclude = ''' color = true [tool.poetry.dependencies] -python = ">=3.9,<3.9.7 || >3.9.7,<4.0" +python = ">=3.9,<3.13" python-dotenv = "^1.0.0" langchain = "^0.0.279" requests = "^2.31.0" @@ -114,6 +114,7 @@ cohere = { version = "^4.27", optional= true } weaviate-client = { version = "^3.24.1", optional= true } docx2txt = { version="^0.8", optional=true } pinecone-client = { version = "^2.2.4", optional = true } +qdrant-client = { version = "1.6.3", optional = true } unstructured = {extras = ["local-inference"], version = "^0.10.18", optional=true} 
pillow = { version = "10.0.1", optional = true } torchvision = { version = ">=0.15.1, !=0.15.2", optional = true } @@ -151,6 +152,7 @@ slack = ["slack-sdk", "flask"] whatsapp = ["twilio", "flask"] weaviate = ["weaviate-client"] pinecone = ["pinecone-client"] +qdrant = ["qdrant-client"] images = ["torch", "ftfy", "regex", "pillow", "torchvision"] huggingface_hub=["huggingface_hub"] cohere = ["cohere"] diff --git a/tests/llm/test_base_llm.py b/tests/llm/test_base_llm.py index c740e91a..ddbc4747 100644 --- a/tests/llm/test_base_llm.py +++ b/tests/llm/test_base_llm.py @@ -1,5 +1,7 @@ -import pytest from string import Template + +import pytest + from embedchain.llm.base import BaseLlm, BaseLlmConfig diff --git a/tests/llm/test_cohere.py b/tests/llm/test_cohere.py index 5d1a625d..1bee4cff 100644 --- a/tests/llm/test_cohere.py +++ b/tests/llm/test_cohere.py @@ -1,4 +1,5 @@ import os + import pytest from embedchain.config import BaseLlmConfig diff --git a/tests/llm/test_huggingface.py b/tests/llm/test_huggingface.py index a8a7a646..c43b099e 100644 --- a/tests/llm/test_huggingface.py +++ b/tests/llm/test_huggingface.py @@ -1,6 +1,8 @@ import importlib import os + import pytest + from embedchain.config import BaseLlmConfig from embedchain.llm.huggingface import HuggingFaceLlm diff --git a/tests/llm/test_jina.py b/tests/llm/test_jina.py index 9ca3f647..4639c410 100644 --- a/tests/llm/test_jina.py +++ b/tests/llm/test_jina.py @@ -1,8 +1,10 @@ import os + import pytest +from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler + from embedchain.config import BaseLlmConfig from embedchain.llm.jina import JinaLlm -from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler @pytest.fixture diff --git a/tests/llm/test_llama2.py b/tests/llm/test_llama2.py index 688149b1..40885fd2 100644 --- a/tests/llm/test_llama2.py +++ b/tests/llm/test_llama2.py @@ -1,5 +1,7 @@ import os + import pytest + from embedchain.llm.llama2 import Llama2Llm diff 
--git a/tests/llm/test_openai.py b/tests/llm/test_openai.py index a1795a6c..fc823337 100644 --- a/tests/llm/test_openai.py +++ b/tests/llm/test_openai.py @@ -1,8 +1,10 @@ import os + import pytest +from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler + from embedchain.config import BaseLlmConfig from embedchain.llm.openai import OpenAILlm -from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler @pytest.fixture diff --git a/tests/llm/test_query.py b/tests/llm/test_query.py index b208e00c..9ebbecd4 100644 --- a/tests/llm/test_query.py +++ b/tests/llm/test_query.py @@ -1,6 +1,8 @@ import os -import pytest from unittest.mock import MagicMock, patch + +import pytest + from embedchain import App from embedchain.config import AppConfig, BaseLlmConfig diff --git a/tests/models/test_data_type.py b/tests/models/test_data_type.py index f0baa588..bf3d6e1e 100644 --- a/tests/models/test_data_type.py +++ b/tests/models/test_data_type.py @@ -1,4 +1,5 @@ -from embedchain.models.data_type import DataType, DirectDataType, IndirectDataType, SpecialDataType +from embedchain.models.data_type import (DataType, DirectDataType, + IndirectDataType, SpecialDataType) def test_subclass_types_in_data_type(): diff --git a/tests/vectordb/test_qdrant.py b/tests/vectordb/test_qdrant.py new file mode 100644 index 00000000..47b54504 --- /dev/null +++ b/tests/vectordb/test_qdrant.py @@ -0,0 +1,158 @@ +import unittest +import uuid + +from mock import patch +from qdrant_client.http import models +from qdrant_client.http.models import Batch + +from embedchain import App +from embedchain.config import AppConfig +from embedchain.config.vectordb.pinecone import PineconeDBConfig +from embedchain.embedder.base import BaseEmbedder +from embedchain.vectordb.qdrant import QdrantDB + + +class TestQdrantDB(unittest.TestCase): + TEST_UUIDS = ["abc", "def", "ghi"] + + def test_incorrect_config_throws_error(self): + """Test the init method of the Qdrant class throws 
error for incorrect config""" + with self.assertRaises(TypeError): + QdrantDB(config=PineconeDBConfig()) + + @patch("embedchain.vectordb.qdrant.QdrantClient") + def test_initialize(self, qdrant_client_mock): + # Set the embedder + embedder = BaseEmbedder() + embedder.set_vector_dimension(1526) + + # Create a Qdrant instance + db = QdrantDB() + app_config = AppConfig(collect_metrics=False) + App(config=app_config, db=db, embedder=embedder) + + self.assertEqual(db.collection_name, "embedchain-store-1526") + self.assertEqual(db.client, qdrant_client_mock.return_value) + qdrant_client_mock.return_value.get_collections.assert_called_once() + + @patch("embedchain.vectordb.qdrant.QdrantClient") + def test_get(self, qdrant_client_mock): + qdrant_client_mock.return_value.scroll.return_value = ([], None) + + # Set the embedder + embedder = BaseEmbedder() + embedder.set_vector_dimension(1526) + + # Create a Qdrant instance + db = QdrantDB() + app_config = AppConfig(collect_metrics=False) + App(config=app_config, db=db, embedder=embedder) + + resp = db.get(ids=[], where={}) + self.assertEqual(resp, {"ids": []}) + resp2 = db.get(ids=["123", "456"], where={"url": "https://ai.ai"}) + self.assertEqual(resp2, {"ids": []}) + + @patch("embedchain.vectordb.qdrant.QdrantClient") + @patch.object(uuid, "uuid4", side_effect=TEST_UUIDS) + def test_add(self, uuid_mock, qdrant_client_mock): + qdrant_client_mock.return_value.scroll.return_value = ([], None) + + # Set the embedder + embedder = BaseEmbedder() + embedder.set_vector_dimension(1526) + + # Create a Qdrant instance + db = QdrantDB() + app_config = AppConfig(collect_metrics=False) + App(config=app_config, db=db, embedder=embedder) + + embeddings = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]] + documents = ["This is a test document.", "This is another test document."] + metadatas = [{}, {}] + ids = ["123", "456"] + skip_embedding = True + db.add(embeddings, documents, metadatas, ids, skip_embedding) + 
qdrant_client_mock.return_value.upsert.assert_called_once_with( + collection_name="embedchain-store-1526", + points=Batch( + ids=["def", "ghi"], + payloads=[ + { + "identifier": "123", + "text": "This is a test document.", + "metadata": {"text": "This is a test document."}, + }, + { + "identifier": "456", + "text": "This is another test document.", + "metadata": {"text": "This is another test document."}, + }, + ], + vectors=embeddings, + ), + ) + + @patch("embedchain.vectordb.qdrant.QdrantClient") + def test_query(self, qdrant_client_mock): + # Set the embedder + embedder = BaseEmbedder() + embedder.set_vector_dimension(1526) + + # Create a Qdrant instance + db = QdrantDB() + app_config = AppConfig(collect_metrics=False) + App(config=app_config, db=db, embedder=embedder) + + # Query for the document. + db.query(input_query=["This is a test document."], n_results=1, where={"doc_id": "123"}, skip_embedding=True) + + qdrant_client_mock.return_value.search.assert_called_once_with( + collection_name="embedchain-store-1526", + query_filter=models.Filter( + must=[ + models.FieldCondition( + key="payload.metadata.doc_id", + match=models.MatchValue( + value="123", + ), + ) + ] + ), + query_vector=["This is a test document."], + limit=1, + ) + + @patch("embedchain.vectordb.qdrant.QdrantClient") + def test_count(self, qdrant_client_mock): + # Set the embedder + embedder = BaseEmbedder() + embedder.set_vector_dimension(1526) + + # Create a Qdrant instance + db = QdrantDB() + app_config = AppConfig(collect_metrics=False) + App(config=app_config, db=db, embedder=embedder) + + db.count() + qdrant_client_mock.return_value.get_collection.assert_called_once_with(collection_name="embedchain-store-1526") + + @patch("embedchain.vectordb.qdrant.QdrantClient") + def test_reset(self, qdrant_client_mock): + # Set the embedder + embedder = BaseEmbedder() + embedder.set_vector_dimension(1526) + + # Create a Qdrant instance + db = QdrantDB() + app_config = AppConfig(collect_metrics=False) + 
App(config=app_config, db=db, embedder=embedder) + + db.reset() + qdrant_client_mock.return_value.delete_collection.assert_called_once_with( + collection_name="embedchain-store-1526" + ) + + +if __name__ == "__main__": + unittest.main()