diff --git a/docs/components/vector-databases.mdx b/docs/components/vector-databases.mdx
index 44bb0d92..94e965d3 100644
--- a/docs/components/vector-databases.mdx
+++ b/docs/components/vector-databases.mdx
@@ -183,12 +183,29 @@ vectordb:
## Qdrant
-_Coming soon_
+To use Qdrant as a vector database, set the environment variables `QDRANT_URL` and `QDRANT_API_KEY`, which you can find on the [Qdrant Dashboard](https://cloud.qdrant.io/).
+
+
+```python main.py
+from embedchain import App
+
+# load qdrant configuration from yaml file
+app = App.from_config(yaml_path="config.yaml")
+```
+
+```yaml config.yaml
+vectordb:
+ provider: qdrant
+ config:
+ collection_name: my_qdrant_index
+```
+
## Weaviate
In order to use Weaviate as a vector database, set the environment variables `WEAVIATE_ENDPOINT` and `WEAVIATE_API_KEY` which you can find on [Weaviate dashboard](https://console.weaviate.cloud/dashboard).
+
```python main.py
from embedchain import App
@@ -202,6 +219,6 @@ vectordb:
config:
collection_name: my_weaviate_index
```
-
+
diff --git a/embedchain/config/vectordb/qdrant.py b/embedchain/config/vectordb/qdrant.py
new file mode 100644
index 00000000..4468c7b2
--- /dev/null
+++ b/embedchain/config/vectordb/qdrant.py
@@ -0,0 +1,44 @@
+from typing import Dict, Optional
+
+from embedchain.config.vectordb.base import BaseVectorDbConfig
+from embedchain.helper.json_serializable import register_deserializable
+
+
+@register_deserializable
+class QdrantDBConfig(BaseVectorDbConfig):
+ """
+ Config to initialize a Qdrant client.
+ Note: the connection URL is not set here; `QdrantDB` reads it from the `QDRANT_URL` environment variable.
+ """
+
+ def __init__(
+ self,
+ collection_name: Optional[str] = None,
+ dir: Optional[str] = None,
+ hnsw_config: Optional[Dict[str, any]] = None,
+ quantization_config: Optional[Dict[str, any]] = None,
+ on_disk: Optional[bool] = None,
+ **extra_params: Dict[str, any],
+ ):
+ """
+ Initializes a configuration class instance for a qdrant client.
+
+ :param collection_name: Default name for the collection, defaults to None
+ :type collection_name: Optional[str], optional
+ :param dir: Path to the database directory, where the database is stored, defaults to None
+ :type dir: Optional[str], optional
+ :param hnsw_config: Params for HNSW index
+ :type hnsw_config: Optional[Dict[str, any]], defaults to None
+ :param quantization_config: Params for quantization, if None - quantization will be disabled
+ :type quantization_config: Optional[Dict[str, any]], defaults to None
+ :param on_disk: If true - vectors are served from disk instead of being held in memory.
+ The value is forwarded as `on_disk` to the collection's `VectorParams`.
+ This setting saves RAM by (slightly) increasing the response time.
+ Note: it governs vector storage, not payload storage.
+ :type on_disk: bool, optional, defaults to None
+ """
+ self.hnsw_config = hnsw_config
+ self.quantization_config = quantization_config
+ self.on_disk = on_disk
+ self.extra_params = extra_params
+ super().__init__(collection_name=collection_name, dir=dir)
diff --git a/embedchain/factory.py b/embedchain/factory.py
index dee01d2f..97453144 100644
--- a/embedchain/factory.py
+++ b/embedchain/factory.py
@@ -73,6 +73,7 @@ class VectorDBFactory:
"opensearch": "embedchain.vectordb.opensearch.OpenSearchDB",
"pinecone": "embedchain.vectordb.pinecone.PineconeDB",
"weaviate": "embedchain.vectordb.weaviate.WeaviateDB",
+ "qdrant": "embedchain.vectordb.qdrant.QdrantDB",
}
provider_to_config_class = {
"chroma": "embedchain.config.vectordb.chroma.ChromaDbConfig",
@@ -80,6 +81,7 @@ class VectorDBFactory:
"opensearch": "embedchain.config.vectordb.opensearch.OpenSearchDBConfig",
"pinecone": "embedchain.config.vectordb.pinecone.PineconeDBConfig",
"weaviate": "embedchain.config.vectordb.weaviate.WeaviateDBConfig",
+ "qdrant": "embedchain.config.vectordb.qdrant.QdrantDBConfig",
}
@classmethod
diff --git a/embedchain/llm/openai.py b/embedchain/llm/openai.py
index f7b8bad4..9e69085c 100644
--- a/embedchain/llm/openai.py
+++ b/embedchain/llm/openai.py
@@ -31,7 +31,8 @@ class OpenAILlm(BaseLlm):
if config.top_p:
kwargs["model_kwargs"]["top_p"] = config.top_p
if config.stream:
- from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
+ from langchain.callbacks.streaming_stdout import \
+ StreamingStdOutCallbackHandler
chat = ChatOpenAI(**kwargs, streaming=config.stream, callbacks=[StreamingStdOutCallbackHandler()])
else:
diff --git a/embedchain/models/__init__.py b/embedchain/models/__init__.py
index fc073230..48887545 100644
--- a/embedchain/models/__init__.py
+++ b/embedchain/models/__init__.py
@@ -1,4 +1,3 @@
from .embedding_functions import EmbeddingFunctions # noqa: F401
from .providers import Providers # noqa: F401
-from .vector_databases import VectorDatabases # noqa: F401
from .vector_dimensions import VectorDimensions # noqa: F401
diff --git a/embedchain/models/vector_databases.py b/embedchain/models/vector_databases.py
deleted file mode 100644
index 30f2c635..00000000
--- a/embedchain/models/vector_databases.py
+++ /dev/null
@@ -1,8 +0,0 @@
-from enum import Enum
-
-
-class VectorDatabases(Enum):
- CHROMADB = "CHROMADB"
- ELASTICSEARCH = "ELASTICSEARCH"
- OPENSEARCH = "OPENSEARCH"
- ZILLIZ = "ZILLIZ"
diff --git a/embedchain/vectordb/qdrant.py b/embedchain/vectordb/qdrant.py
new file mode 100644
index 00000000..477fa58c
--- /dev/null
+++ b/embedchain/vectordb/qdrant.py
@@ -0,0 +1,213 @@
+import copy
+import os
+import uuid
+from typing import Dict, List, Optional
+
+try:
+ from qdrant_client import QdrantClient
+ from qdrant_client.http import models
+ from qdrant_client.http.models import Batch
+ from qdrant_client.models import Distance, VectorParams
+except ImportError:
+ raise ImportError("Qdrant requires extra dependencies. Install with `pip install embedchain[qdrant]`") from None
+
+from embedchain.config.vectordb.qdrant import QdrantDBConfig
+from embedchain.vectordb.base import BaseVectorDB
+
+
+class QdrantDB(BaseVectorDB):
+ """
+ Qdrant as vector database
+ """
+
+ BATCH_SIZE = 10
+
+ def __init__(self, config: QdrantDBConfig = None):
+ """
+ Qdrant as vector database
+ :param config: Qdrant database config; a default `QdrantDBConfig` is created when None
+ """
+ if config is None:
+ config = QdrantDBConfig()
+ else:
+ if not isinstance(config, QdrantDBConfig):
+ raise TypeError(
+ "config is not a `QdrantDBConfig` instance. "
+ "Please make sure the type is right and that you are passing an instance."
+ )
+ self.config = config
+ self.client = QdrantClient(url=os.getenv("QDRANT_URL"), api_key=os.getenv("QDRANT_API_KEY"))
+ # Call parent init here because embedder is needed
+ super().__init__(config=self.config)
+
+ def _initialize(self):
+ """
+ This method is needed because `embedder` attribute needs to be set externally before it can be initialized.
+ """
+ if not self.embedder:
+ raise ValueError("Embedder not set. Please set an embedder with `set_embedder` before initialization.")
+
+ self.collection_name = self._get_or_create_collection()
+ self.metadata_keys = {"data_type", "doc_id", "url", "hash", "app_id", "text"}
+ all_collections = self.client.get_collections()
+ collection_names = [collection.name for collection in all_collections.collections]
+ if self.collection_name not in collection_names:
+ self.client.recreate_collection(
+ collection_name=self.collection_name,
+ vectors_config=VectorParams(
+ size=self.embedder.vector_dimension,
+ distance=Distance.COSINE,
+ hnsw_config=self.config.hnsw_config,
+ quantization_config=self.config.quantization_config,
+ on_disk=self.config.on_disk,
+ ),
+ )
+
+ def _get_or_create_db(self):
+ return self.client
+
+ def _get_or_create_collection(self):
+ return f"{self.config.collection_name}-{self.embedder.vector_dimension}".lower().replace("_", "-")
+
+ def get(self, ids: Optional[List[str]] = None, where: Optional[Dict[str, any]] = None, limit: Optional[int] = None):
+ """
+ Get existing doc ids present in vector database
+
+ :param ids: list of doc ids to check for existence
+ :type ids: List[str]
+ :param where: to filter data
+ :type where: Dict[str, any]
+ :param limit: The number of entries to be fetched (accepted but currently not applied)
+ :type limit: Optional int, defaults to None
+ :return: Dict whose `ids` key holds the list of existing IDs
+ :rtype: Dict[str, List[str]]
+ """
+ if ids is None or len(ids) == 0:
+ return {"ids": []}
+
+ keys = set(where.keys() if where is not None else set())
+
+ qdrant_must_filters = [
+ models.FieldCondition(
+ key="identifier",
+ match=models.MatchAny(
+ any=ids,
+ ),
+ )
+ ]
+ if len(keys.intersection(self.metadata_keys)) != 0:
+ for key in keys.intersection(self.metadata_keys):
+ qdrant_must_filters.append(
+ models.FieldCondition(
+ key="metadata.{}".format(key),
+ match=models.MatchValue(
+ value=where.get(key),
+ ),
+ )
+ )
+
+ offset = 0
+ existing_ids = []
+ while offset is not None:
+ response = self.client.scroll(
+ collection_name=self.collection_name,
+ scroll_filter=models.Filter(must=qdrant_must_filters),
+ offset=offset,
+ limit=self.BATCH_SIZE,
+ )
+ offset = response[1]
+ for doc in response[0]:
+ existing_ids.append(doc.payload["identifier"])
+ return {"ids": existing_ids}
+
+ def add(
+ self,
+ embeddings: List[List[float]],
+ documents: List[str],
+ metadatas: List[object],
+ ids: List[str],
+ skip_embedding: bool,
+ ):
+ """add data in vector database
+ :param embeddings: list of embeddings for the corresponding documents to be added
+ :type embeddings: List[List[float]]
+ :param documents: list of texts to add
+ :type documents: List[str]
+ :param metadatas: list of metadata associated with docs
+ :type metadatas: List[object]
+ :param ids: ids of docs
+ :type ids: List[str]
+ :param skip_embedding: A boolean flag indicating if the embedding for the documents to be added is to be
+ generated or not
+ :type skip_embedding: bool
+ """
+ if not skip_embedding:
+ embeddings = self.embedder.embedding_fn(documents)
+
+ payloads = []
+ qdrant_ids = []
+ for id, document, metadata in zip(ids, documents, metadatas):
+ metadata["text"] = document
+ qdrant_ids.append(str(uuid.uuid4()))
+ payloads.append({"identifier": id, "text": document, "metadata": copy.deepcopy(metadata)})
+ for i in range(0, len(qdrant_ids), self.BATCH_SIZE):
+ self.client.upsert(
+ collection_name=self.collection_name,
+ points=Batch(
+ ids=qdrant_ids[i : i + self.BATCH_SIZE],
+ payloads=payloads[i : i + self.BATCH_SIZE],
+ vectors=embeddings[i : i + self.BATCH_SIZE],
+ ),
+ )
+
+ def query(self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool) -> List[str]:
+ """
+ query contents from vector database based on vector similarity
+ :param input_query: list of query string
+ :type input_query: List[str]
+ :param n_results: no of similar documents to fetch from database
+ :type n_results: int
+ :param where: Optional. to filter data
+ :type where: Dict[str, any]
+ :param skip_embedding: A boolean flag indicating if `input_query` is already an embedding, in which case
+ it is used as the query vector as-is instead of being embedded
+ :type skip_embedding: bool
+ :return: Database contents that are the result of the query
+ :rtype: List[str]
+ """
+ if not skip_embedding:
+ query_vector = self.embedder.embedding_fn([input_query])[0]
+ else:
+ query_vector = input_query
+
+ keys = set(where.keys() if where is not None else set())
+
+ qdrant_must_filters = []
+ if len(keys.intersection(self.metadata_keys)) != 0:
+ for key in keys.intersection(self.metadata_keys):
+ qdrant_must_filters.append(
+ models.FieldCondition(
+ key="payload.metadata.{}".format(key),
+ match=models.MatchValue(
+ value=where.get(key),
+ ),
+ )
+ )
+ results = self.client.search(
+ collection_name=self.collection_name,
+ query_filter=models.Filter(must=qdrant_must_filters),
+ query_vector=query_vector,
+ limit=n_results,
+ )
+ response = []
+ for result in results:
+ response.append(result.payload.get("text", ""))
+ return response
+
+ def count(self) -> int:
+ response = self.client.get_collection(collection_name=self.collection_name)
+ return response.points_count
+
+ def reset(self):
+ self.client.delete_collection(collection_name=self.collection_name)
+ self._initialize()
diff --git a/pyproject.toml b/pyproject.toml
index ca9175e4..85d78916 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -85,7 +85,7 @@ exclude = '''
color = true
[tool.poetry.dependencies]
-python = ">=3.9,<3.9.7 || >3.9.7,<4.0"
+python = ">=3.9,<3.13"
python-dotenv = "^1.0.0"
langchain = "^0.0.279"
requests = "^2.31.0"
@@ -114,6 +114,7 @@ cohere = { version = "^4.27", optional= true }
weaviate-client = { version = "^3.24.1", optional= true }
docx2txt = { version="^0.8", optional=true }
pinecone-client = { version = "^2.2.4", optional = true }
+qdrant-client = { version = "1.6.3", optional = true }
unstructured = {extras = ["local-inference"], version = "^0.10.18", optional=true}
pillow = { version = "10.0.1", optional = true }
torchvision = { version = ">=0.15.1, !=0.15.2", optional = true }
@@ -151,6 +152,7 @@ slack = ["slack-sdk", "flask"]
whatsapp = ["twilio", "flask"]
weaviate = ["weaviate-client"]
pinecone = ["pinecone-client"]
+qdrant = ["qdrant-client"]
images = ["torch", "ftfy", "regex", "pillow", "torchvision"]
huggingface_hub=["huggingface_hub"]
cohere = ["cohere"]
diff --git a/tests/llm/test_base_llm.py b/tests/llm/test_base_llm.py
index c740e91a..ddbc4747 100644
--- a/tests/llm/test_base_llm.py
+++ b/tests/llm/test_base_llm.py
@@ -1,5 +1,7 @@
-import pytest
from string import Template
+
+import pytest
+
from embedchain.llm.base import BaseLlm, BaseLlmConfig
diff --git a/tests/llm/test_cohere.py b/tests/llm/test_cohere.py
index 5d1a625d..1bee4cff 100644
--- a/tests/llm/test_cohere.py
+++ b/tests/llm/test_cohere.py
@@ -1,4 +1,5 @@
import os
+
import pytest
from embedchain.config import BaseLlmConfig
diff --git a/tests/llm/test_huggingface.py b/tests/llm/test_huggingface.py
index a8a7a646..c43b099e 100644
--- a/tests/llm/test_huggingface.py
+++ b/tests/llm/test_huggingface.py
@@ -1,6 +1,8 @@
import importlib
import os
+
import pytest
+
from embedchain.config import BaseLlmConfig
from embedchain.llm.huggingface import HuggingFaceLlm
diff --git a/tests/llm/test_jina.py b/tests/llm/test_jina.py
index 9ca3f647..4639c410 100644
--- a/tests/llm/test_jina.py
+++ b/tests/llm/test_jina.py
@@ -1,8 +1,10 @@
import os
+
import pytest
+from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
+
from embedchain.config import BaseLlmConfig
from embedchain.llm.jina import JinaLlm
-from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
@pytest.fixture
diff --git a/tests/llm/test_llama2.py b/tests/llm/test_llama2.py
index 688149b1..40885fd2 100644
--- a/tests/llm/test_llama2.py
+++ b/tests/llm/test_llama2.py
@@ -1,5 +1,7 @@
import os
+
import pytest
+
from embedchain.llm.llama2 import Llama2Llm
diff --git a/tests/llm/test_openai.py b/tests/llm/test_openai.py
index a1795a6c..fc823337 100644
--- a/tests/llm/test_openai.py
+++ b/tests/llm/test_openai.py
@@ -1,8 +1,10 @@
import os
+
import pytest
+from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
+
from embedchain.config import BaseLlmConfig
from embedchain.llm.openai import OpenAILlm
-from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
@pytest.fixture
diff --git a/tests/llm/test_query.py b/tests/llm/test_query.py
index b208e00c..9ebbecd4 100644
--- a/tests/llm/test_query.py
+++ b/tests/llm/test_query.py
@@ -1,6 +1,8 @@
import os
-import pytest
from unittest.mock import MagicMock, patch
+
+import pytest
+
from embedchain import App
from embedchain.config import AppConfig, BaseLlmConfig
diff --git a/tests/models/test_data_type.py b/tests/models/test_data_type.py
index f0baa588..bf3d6e1e 100644
--- a/tests/models/test_data_type.py
+++ b/tests/models/test_data_type.py
@@ -1,4 +1,5 @@
-from embedchain.models.data_type import DataType, DirectDataType, IndirectDataType, SpecialDataType
+from embedchain.models.data_type import (DataType, DirectDataType,
+ IndirectDataType, SpecialDataType)
def test_subclass_types_in_data_type():
diff --git a/tests/vectordb/test_qdrant.py b/tests/vectordb/test_qdrant.py
new file mode 100644
index 00000000..47b54504
--- /dev/null
+++ b/tests/vectordb/test_qdrant.py
@@ -0,0 +1,158 @@
+import unittest
+import uuid
+
+from mock import patch
+from qdrant_client.http import models
+from qdrant_client.http.models import Batch
+
+from embedchain import App
+from embedchain.config import AppConfig
+from embedchain.config.vectordb.pinecone import PineconeDBConfig
+from embedchain.embedder.base import BaseEmbedder
+from embedchain.vectordb.qdrant import QdrantDB
+
+
+class TestQdrantDB(unittest.TestCase):
+ TEST_UUIDS = ["abc", "def", "ghi"]
+
+ def test_incorrect_config_throws_error(self):
+ """Test the init method of the Qdrant class throws error for incorrect config"""
+ with self.assertRaises(TypeError):
+ QdrantDB(config=PineconeDBConfig())
+
+ @patch("embedchain.vectordb.qdrant.QdrantClient")
+ def test_initialize(self, qdrant_client_mock):
+ # Set the embedder
+ embedder = BaseEmbedder()
+ embedder.set_vector_dimension(1526)
+
+ # Create a Qdrant instance
+ db = QdrantDB()
+ app_config = AppConfig(collect_metrics=False)
+ App(config=app_config, db=db, embedder=embedder)
+
+ self.assertEqual(db.collection_name, "embedchain-store-1526")
+ self.assertEqual(db.client, qdrant_client_mock.return_value)
+ qdrant_client_mock.return_value.get_collections.assert_called_once()
+
+ @patch("embedchain.vectordb.qdrant.QdrantClient")
+ def test_get(self, qdrant_client_mock):
+ qdrant_client_mock.return_value.scroll.return_value = ([], None)
+
+ # Set the embedder
+ embedder = BaseEmbedder()
+ embedder.set_vector_dimension(1526)
+
+ # Create a Qdrant instance
+ db = QdrantDB()
+ app_config = AppConfig(collect_metrics=False)
+ App(config=app_config, db=db, embedder=embedder)
+
+ resp = db.get(ids=[], where={})
+ self.assertEqual(resp, {"ids": []})
+ resp2 = db.get(ids=["123", "456"], where={"url": "https://ai.ai"})
+ self.assertEqual(resp2, {"ids": []})
+
+ @patch("embedchain.vectordb.qdrant.QdrantClient")
+ @patch.object(uuid, "uuid4", side_effect=TEST_UUIDS)
+ def test_add(self, uuid_mock, qdrant_client_mock):
+ qdrant_client_mock.return_value.scroll.return_value = ([], None)
+
+ # Set the embedder
+ embedder = BaseEmbedder()
+ embedder.set_vector_dimension(1526)
+
+ # Create a Qdrant instance
+ db = QdrantDB()
+ app_config = AppConfig(collect_metrics=False)
+ App(config=app_config, db=db, embedder=embedder)
+
+ embeddings = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
+ documents = ["This is a test document.", "This is another test document."]
+ metadatas = [{}, {}]
+ ids = ["123", "456"]
+ skip_embedding = True
+ db.add(embeddings, documents, metadatas, ids, skip_embedding)
+ qdrant_client_mock.return_value.upsert.assert_called_once_with(
+ collection_name="embedchain-store-1526",
+ points=Batch(
+ ids=["def", "ghi"],
+ payloads=[
+ {
+ "identifier": "123",
+ "text": "This is a test document.",
+ "metadata": {"text": "This is a test document."},
+ },
+ {
+ "identifier": "456",
+ "text": "This is another test document.",
+ "metadata": {"text": "This is another test document."},
+ },
+ ],
+ vectors=embeddings,
+ ),
+ )
+
+ @patch("embedchain.vectordb.qdrant.QdrantClient")
+ def test_query(self, qdrant_client_mock):
+ # Set the embedder
+ embedder = BaseEmbedder()
+ embedder.set_vector_dimension(1526)
+
+ # Create a Qdrant instance
+ db = QdrantDB()
+ app_config = AppConfig(collect_metrics=False)
+ App(config=app_config, db=db, embedder=embedder)
+
+ # Query for the document.
+ db.query(input_query=["This is a test document."], n_results=1, where={"doc_id": "123"}, skip_embedding=True)
+
+ qdrant_client_mock.return_value.search.assert_called_once_with(
+ collection_name="embedchain-store-1526",
+ query_filter=models.Filter(
+ must=[
+ models.FieldCondition(
+ key="payload.metadata.doc_id",
+ match=models.MatchValue(
+ value="123",
+ ),
+ )
+ ]
+ ),
+ query_vector=["This is a test document."],
+ limit=1,
+ )
+
+ @patch("embedchain.vectordb.qdrant.QdrantClient")
+ def test_count(self, qdrant_client_mock):
+ # Set the embedder
+ embedder = BaseEmbedder()
+ embedder.set_vector_dimension(1526)
+
+ # Create a Qdrant instance
+ db = QdrantDB()
+ app_config = AppConfig(collect_metrics=False)
+ App(config=app_config, db=db, embedder=embedder)
+
+ db.count()
+ qdrant_client_mock.return_value.get_collection.assert_called_once_with(collection_name="embedchain-store-1526")
+
+ @patch("embedchain.vectordb.qdrant.QdrantClient")
+ def test_reset(self, qdrant_client_mock):
+ # Set the embedder
+ embedder = BaseEmbedder()
+ embedder.set_vector_dimension(1526)
+
+ # Create a Qdrant instance
+ db = QdrantDB()
+ app_config = AppConfig(collect_metrics=False)
+ App(config=app_config, db=db, embedder=embedder)
+
+ db.reset()
+ qdrant_client_mock.return_value.delete_collection.assert_called_once_with(
+ collection_name="embedchain-store-1526"
+ )
+
+
+if __name__ == "__main__":
+ unittest.main()