[Feature] Add Qdrant support (#822)

This commit is contained in:
Rupesh Bansal
2023-10-19 02:57:57 +05:30
committed by GitHub
parent 7641cba01d
commit c8846e0e93
17 changed files with 460 additions and 18 deletions

View File

@@ -183,12 +183,29 @@ vectordb:
## Qdrant ## Qdrant
_Coming soon_ In order to use Qdrant as a vector database, set the environment variables `QDRANT_URL` and `QDRANT_API_KEY` which you can find on [Qdrant Dashboard](https://cloud.qdrant.io/).
<CodeGroup>
```python main.py
from embedchain import App
# load qdrant configuration from yaml file
app = App.from_config(yaml_path="config.yaml")
```
```yaml config.yaml
vectordb:
provider: qdrant
config:
collection_name: my_qdrant_index
```
</CodeGroup>
## Weaviate ## Weaviate
In order to use Weaviate as a vector database, set the environment variables `WEAVIATE_ENDPOINT` and `WEAVIATE_API_KEY` which you can find on [Weaviate dashboard](https://console.weaviate.cloud/dashboard). In order to use Weaviate as a vector database, set the environment variables `WEAVIATE_ENDPOINT` and `WEAVIATE_API_KEY` which you can find on [Weaviate dashboard](https://console.weaviate.cloud/dashboard).
<CodeGroup>
```python main.py ```python main.py
from embedchain import App from embedchain import App
@@ -202,6 +219,6 @@ vectordb:
config: config:
collection_name: my_weaviate_index collection_name: my_weaviate_index
``` ```
</CodeGroup>
<Snippet file="missing-vector-db-tip.mdx" /> <Snippet file="missing-vector-db-tip.mdx" />

View File

@@ -0,0 +1,44 @@
from typing import Dict, Optional
from embedchain.config.vectordb.base import BaseVectorDbConfig
from embedchain.helper.json_serializable import register_deserializable
@register_deserializable
class QdrantDBConfig(BaseVectorDbConfig):
"""
Config to initialize an qdrant client.
:param url. qdrant url or list of nodes url to be used for connection
"""
def __init__(
self,
collection_name: Optional[str] = None,
dir: Optional[str] = None,
hnsw_config: Optional[Dict[str, any]] = None,
quantization_config: Optional[Dict[str, any]] = None,
on_disk: Optional[bool] = None,
**extra_params: Dict[str, any],
):
"""
Initializes a configuration class instance for a qdrant client.
:param collection_name: Default name for the collection, defaults to None
:type collection_name: Optional[str], optional
:param dir: Path to the database directory, where the database is stored, defaults to None
:type dir: Optional[str], optional
:param hnsw_config: Params for HNSW index
:type hnsw_config: Optional[Dict[str, any]], defaults to None
:param quantization_config: Params for quantization, if None - quantization will be disabled
:type quantization_config: Optional[Dict[str, any]], defaults to None
:param on_disk: If true - point`s payload will not be stored in memory.
It will be read from the disk every time it is requested.
This setting saves RAM by (slightly) increasing the response time.
Note: those payload values that are involved in filtering and are indexed - remain in RAM.
:type on_disk: bool, optional, defaults to None
"""
self.hnsw_config = hnsw_config
self.quantization_config = quantization_config
self.on_disk = on_disk
self.extra_params = extra_params
super().__init__(collection_name=collection_name, dir=dir)

View File

@@ -73,6 +73,7 @@ class VectorDBFactory:
"opensearch": "embedchain.vectordb.opensearch.OpenSearchDB", "opensearch": "embedchain.vectordb.opensearch.OpenSearchDB",
"pinecone": "embedchain.vectordb.pinecone.PineconeDB", "pinecone": "embedchain.vectordb.pinecone.PineconeDB",
"weaviate": "embedchain.vectordb.weaviate.WeaviateDB", "weaviate": "embedchain.vectordb.weaviate.WeaviateDB",
"qdrant": "embedchain.vectordb.qdrant.QdrantDB",
} }
provider_to_config_class = { provider_to_config_class = {
"chroma": "embedchain.config.vectordb.chroma.ChromaDbConfig", "chroma": "embedchain.config.vectordb.chroma.ChromaDbConfig",
@@ -80,6 +81,7 @@ class VectorDBFactory:
"opensearch": "embedchain.config.vectordb.opensearch.OpenSearchDBConfig", "opensearch": "embedchain.config.vectordb.opensearch.OpenSearchDBConfig",
"pinecone": "embedchain.config.vectordb.pinecone.PineconeDBConfig", "pinecone": "embedchain.config.vectordb.pinecone.PineconeDBConfig",
"weaviate": "embedchain.config.vectordb.weaviate.WeaviateDBConfig", "weaviate": "embedchain.config.vectordb.weaviate.WeaviateDBConfig",
"qdrant": "embedchain.config.vectordb.qdrant.QdrantDBConfig",
} }
@classmethod @classmethod

View File

@@ -31,7 +31,8 @@ class OpenAILlm(BaseLlm):
if config.top_p: if config.top_p:
kwargs["model_kwargs"]["top_p"] = config.top_p kwargs["model_kwargs"]["top_p"] = config.top_p
if config.stream: if config.stream:
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler from langchain.callbacks.streaming_stdout import \
StreamingStdOutCallbackHandler
chat = ChatOpenAI(**kwargs, streaming=config.stream, callbacks=[StreamingStdOutCallbackHandler()]) chat = ChatOpenAI(**kwargs, streaming=config.stream, callbacks=[StreamingStdOutCallbackHandler()])
else: else:

View File

@@ -1,4 +1,3 @@
from .embedding_functions import EmbeddingFunctions # noqa: F401 from .embedding_functions import EmbeddingFunctions # noqa: F401
from .providers import Providers # noqa: F401 from .providers import Providers # noqa: F401
from .vector_databases import VectorDatabases # noqa: F401
from .vector_dimensions import VectorDimensions # noqa: F401 from .vector_dimensions import VectorDimensions # noqa: F401

View File

@@ -1,8 +0,0 @@
from enum import Enum
class VectorDatabases(Enum):
CHROMADB = "CHROMADB"
ELASTICSEARCH = "ELASTICSEARCH"
OPENSEARCH = "OPENSEARCH"
ZILLIZ = "ZILLIZ"

View File

@@ -0,0 +1,213 @@
import copy
import os
import uuid
from typing import Dict, List, Optional
try:
from qdrant_client import QdrantClient
from qdrant_client.http import models
from qdrant_client.http.models import Batch
from qdrant_client.models import Distance, VectorParams
except ImportError:
raise ImportError("Qdrant requires extra dependencies. Install with `pip install embedchain[qdrant]`") from None
from embedchain.config.vectordb.qdrant import QdrantDBConfig
from embedchain.vectordb.base import BaseVectorDB
class QdrantDB(BaseVectorDB):
"""
Qdrant as vector database
"""
BATCH_SIZE = 10
def __init__(self, config: QdrantDBConfig = None):
"""
Qdrant as vector database
:param config. Qdrant database config to be used for connection
"""
if config is None:
config = QdrantDBConfig()
else:
if not isinstance(config, QdrantDBConfig):
raise TypeError(
"config is not a `QdrantDBConfig` instance. "
"Please make sure the type is right and that you are passing an instance."
)
self.config = config
self.client = QdrantClient(url=os.getenv("QDRANT_URL"), api_key=os.getenv("QDRANT_API_KEY"))
# Call parent init here because embedder is needed
super().__init__(config=self.config)
def _initialize(self):
"""
This method is needed because `embedder` attribute needs to be set externally before it can be initialized.
"""
if not self.embedder:
raise ValueError("Embedder not set. Please set an embedder with `set_embedder` before initialization.")
self.collection_name = self._get_or_create_collection()
self.metadata_keys = {"data_type", "doc_id", "url", "hash", "app_id", "text"}
all_collections = self.client.get_collections()
collection_names = [collection.name for collection in all_collections.collections]
if self.collection_name not in collection_names:
self.client.recreate_collection(
collection_name=self.collection_name,
vectors_config=VectorParams(
size=self.embedder.vector_dimension,
distance=Distance.COSINE,
hnsw_config=self.config.hnsw_config,
quantization_config=self.config.quantization_config,
on_disk=self.config.on_disk,
),
)
def _get_or_create_db(self):
return self.client
def _get_or_create_collection(self):
return f"{self.config.collection_name}-{self.embedder.vector_dimension}".lower().replace("_", "-")
def get(self, ids: Optional[List[str]] = None, where: Optional[Dict[str, any]] = None, limit: Optional[int] = None):
"""
Get existing doc ids present in vector database
:param ids: _list of doc ids to check for existence
:type ids: List[str]
:param where: to filter data
:type where: Dict[str, any]
:param limit: The number of entries to be fetched
:type limit: Optional int, defaults to None
:return: All the existing IDs
:rtype: Set[str]
"""
if ids is None or len(ids) == 0:
return {"ids": []}
keys = set(where.keys() if where is not None else set())
qdrant_must_filters = [
models.FieldCondition(
key="identifier",
match=models.MatchAny(
any=ids,
),
)
]
if len(keys.intersection(self.metadata_keys)) != 0:
for key in keys.intersection(self.metadata_keys):
qdrant_must_filters.append(
models.FieldCondition(
key="metadata.{}".format(key),
match=models.MatchValue(
value=where.get(key),
),
)
)
offset = 0
existing_ids = []
while offset is not None:
response = self.client.scroll(
collection_name=self.collection_name,
scroll_filter=models.Filter(must=qdrant_must_filters),
offset=offset,
limit=self.BATCH_SIZE,
)
offset = response[1]
for doc in response[0]:
existing_ids.append(doc.payload["identifier"])
return {"ids": existing_ids}
def add(
self,
embeddings: List[List[float]],
documents: List[str],
metadatas: List[object],
ids: List[str],
skip_embedding: bool,
):
"""add data in vector database
:param embeddings: list of embeddings for the corresponding documents to be added
:type documents: List[List[float]]
:param documents: list of texts to add
:type documents: List[str]
:param metadatas: list of metadata associated with docs
:type metadatas: List[object]
:param ids: ids of docs
:type ids: List[str]
:param skip_embedding: A boolean flag indicating if the embedding for the documents to be added is to be
generated or not
:type skip_embedding: bool
"""
if not skip_embedding:
embeddings = self.embedder.embedding_fn(documents)
payloads = []
qdrant_ids = []
for id, document, metadata in zip(ids, documents, metadatas):
metadata["text"] = document
qdrant_ids.append(str(uuid.uuid4()))
payloads.append({"identifier": id, "text": document, "metadata": copy.deepcopy(metadata)})
for i in range(0, len(qdrant_ids), self.BATCH_SIZE):
self.client.upsert(
collection_name=self.collection_name,
points=Batch(
ids=qdrant_ids[i : i + self.BATCH_SIZE],
payloads=payloads[i : i + self.BATCH_SIZE],
vectors=embeddings[i : i + self.BATCH_SIZE],
),
)
def query(self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool) -> List[str]:
"""
query contents from vector database based on vector similarity
:param input_query: list of query string
:type input_query: List[str]
:param n_results: no of similar documents to fetch from database
:type n_results: int
:param where: Optional. to filter data
:type where: Dict[str, any]
:param skip_embedding: A boolean flag indicating if the embedding for the documents to be added is to be
generated or not
:type skip_embedding: bool
:return: Database contents that are the result of the query
:rtype: List[str]
"""
if not skip_embedding:
query_vector = self.embedder.embedding_fn([input_query])[0]
else:
query_vector = input_query
keys = set(where.keys() if where is not None else set())
qdrant_must_filters = []
if len(keys.intersection(self.metadata_keys)) != 0:
for key in keys.intersection(self.metadata_keys):
qdrant_must_filters.append(
models.FieldCondition(
key="payload.metadata.{}".format(key),
match=models.MatchValue(
value=where.get(key),
),
)
)
results = self.client.search(
collection_name=self.collection_name,
query_filter=models.Filter(must=qdrant_must_filters),
query_vector=query_vector,
limit=n_results,
)
response = []
for result in results:
response.append(result.payload.get("text", ""))
return response
def count(self) -> int:
response = self.client.get_collection(collection_name=self.collection_name)
return response.points_count
def reset(self):
self.client.delete_collection(collection_name=self.collection_name)
self._initialize()

View File

@@ -85,7 +85,7 @@ exclude = '''
color = true color = true
[tool.poetry.dependencies] [tool.poetry.dependencies]
python = ">=3.9,<3.9.7 || >3.9.7,<4.0" python = ">=3.9,<3.13"
python-dotenv = "^1.0.0" python-dotenv = "^1.0.0"
langchain = "^0.0.279" langchain = "^0.0.279"
requests = "^2.31.0" requests = "^2.31.0"
@@ -114,6 +114,7 @@ cohere = { version = "^4.27", optional= true }
weaviate-client = { version = "^3.24.1", optional= true } weaviate-client = { version = "^3.24.1", optional= true }
docx2txt = { version="^0.8", optional=true } docx2txt = { version="^0.8", optional=true }
pinecone-client = { version = "^2.2.4", optional = true } pinecone-client = { version = "^2.2.4", optional = true }
qdrant-client = { version = "1.6.3", optional = true }
unstructured = {extras = ["local-inference"], version = "^0.10.18", optional=true} unstructured = {extras = ["local-inference"], version = "^0.10.18", optional=true}
pillow = { version = "10.0.1", optional = true } pillow = { version = "10.0.1", optional = true }
torchvision = { version = ">=0.15.1, !=0.15.2", optional = true } torchvision = { version = ">=0.15.1, !=0.15.2", optional = true }
@@ -151,6 +152,7 @@ slack = ["slack-sdk", "flask"]
whatsapp = ["twilio", "flask"] whatsapp = ["twilio", "flask"]
weaviate = ["weaviate-client"] weaviate = ["weaviate-client"]
pinecone = ["pinecone-client"] pinecone = ["pinecone-client"]
qdrant = ["qdrant-client"]
images = ["torch", "ftfy", "regex", "pillow", "torchvision"] images = ["torch", "ftfy", "regex", "pillow", "torchvision"]
huggingface_hub=["huggingface_hub"] huggingface_hub=["huggingface_hub"]
cohere = ["cohere"] cohere = ["cohere"]

View File

@@ -1,5 +1,7 @@
import pytest
from string import Template from string import Template
import pytest
from embedchain.llm.base import BaseLlm, BaseLlmConfig from embedchain.llm.base import BaseLlm, BaseLlmConfig

View File

@@ -1,4 +1,5 @@
import os import os
import pytest import pytest
from embedchain.config import BaseLlmConfig from embedchain.config import BaseLlmConfig

View File

@@ -1,6 +1,8 @@
import importlib import importlib
import os import os
import pytest import pytest
from embedchain.config import BaseLlmConfig from embedchain.config import BaseLlmConfig
from embedchain.llm.huggingface import HuggingFaceLlm from embedchain.llm.huggingface import HuggingFaceLlm

View File

@@ -1,8 +1,10 @@
import os import os
import pytest import pytest
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from embedchain.config import BaseLlmConfig from embedchain.config import BaseLlmConfig
from embedchain.llm.jina import JinaLlm from embedchain.llm.jina import JinaLlm
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
@pytest.fixture @pytest.fixture

View File

@@ -1,5 +1,7 @@
import os import os
import pytest import pytest
from embedchain.llm.llama2 import Llama2Llm from embedchain.llm.llama2 import Llama2Llm

View File

@@ -1,8 +1,10 @@
import os import os
import pytest import pytest
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from embedchain.config import BaseLlmConfig from embedchain.config import BaseLlmConfig
from embedchain.llm.openai import OpenAILlm from embedchain.llm.openai import OpenAILlm
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
@pytest.fixture @pytest.fixture

View File

@@ -1,6 +1,8 @@
import os import os
import pytest
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import pytest
from embedchain import App from embedchain import App
from embedchain.config import AppConfig, BaseLlmConfig from embedchain.config import AppConfig, BaseLlmConfig

View File

@@ -1,4 +1,5 @@
from embedchain.models.data_type import DataType, DirectDataType, IndirectDataType, SpecialDataType from embedchain.models.data_type import (DataType, DirectDataType,
IndirectDataType, SpecialDataType)
def test_subclass_types_in_data_type(): def test_subclass_types_in_data_type():

View File

@@ -0,0 +1,158 @@
import unittest
import uuid
from mock import patch
from qdrant_client.http import models
from qdrant_client.http.models import Batch
from embedchain import App
from embedchain.config import AppConfig
from embedchain.config.vectordb.pinecone import PineconeDBConfig
from embedchain.embedder.base import BaseEmbedder
from embedchain.vectordb.qdrant import QdrantDB
class TestQdrantDB(unittest.TestCase):
TEST_UUIDS = ["abc", "def", "ghi"]
def test_incorrect_config_throws_error(self):
"""Test the init method of the Qdrant class throws error for incorrect config"""
with self.assertRaises(TypeError):
QdrantDB(config=PineconeDBConfig())
@patch("embedchain.vectordb.qdrant.QdrantClient")
def test_initialize(self, qdrant_client_mock):
# Set the embedder
embedder = BaseEmbedder()
embedder.set_vector_dimension(1526)
# Create a Qdrant instance
db = QdrantDB()
app_config = AppConfig(collect_metrics=False)
App(config=app_config, db=db, embedder=embedder)
self.assertEqual(db.collection_name, "embedchain-store-1526")
self.assertEqual(db.client, qdrant_client_mock.return_value)
qdrant_client_mock.return_value.get_collections.assert_called_once()
@patch("embedchain.vectordb.qdrant.QdrantClient")
def test_get(self, qdrant_client_mock):
qdrant_client_mock.return_value.scroll.return_value = ([], None)
# Set the embedder
embedder = BaseEmbedder()
embedder.set_vector_dimension(1526)
# Create a Qdrant instance
db = QdrantDB()
app_config = AppConfig(collect_metrics=False)
App(config=app_config, db=db, embedder=embedder)
resp = db.get(ids=[], where={})
self.assertEqual(resp, {"ids": []})
resp2 = db.get(ids=["123", "456"], where={"url": "https://ai.ai"})
self.assertEqual(resp2, {"ids": []})
@patch("embedchain.vectordb.qdrant.QdrantClient")
@patch.object(uuid, "uuid4", side_effect=TEST_UUIDS)
def test_add(self, uuid_mock, qdrant_client_mock):
qdrant_client_mock.return_value.scroll.return_value = ([], None)
# Set the embedder
embedder = BaseEmbedder()
embedder.set_vector_dimension(1526)
# Create a Qdrant instance
db = QdrantDB()
app_config = AppConfig(collect_metrics=False)
App(config=app_config, db=db, embedder=embedder)
embeddings = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
documents = ["This is a test document.", "This is another test document."]
metadatas = [{}, {}]
ids = ["123", "456"]
skip_embedding = True
db.add(embeddings, documents, metadatas, ids, skip_embedding)
qdrant_client_mock.return_value.upsert.assert_called_once_with(
collection_name="embedchain-store-1526",
points=Batch(
ids=["def", "ghi"],
payloads=[
{
"identifier": "123",
"text": "This is a test document.",
"metadata": {"text": "This is a test document."},
},
{
"identifier": "456",
"text": "This is another test document.",
"metadata": {"text": "This is another test document."},
},
],
vectors=embeddings,
),
)
@patch("embedchain.vectordb.qdrant.QdrantClient")
def test_query(self, qdrant_client_mock):
# Set the embedder
embedder = BaseEmbedder()
embedder.set_vector_dimension(1526)
# Create a Qdrant instance
db = QdrantDB()
app_config = AppConfig(collect_metrics=False)
App(config=app_config, db=db, embedder=embedder)
# Query for the document.
db.query(input_query=["This is a test document."], n_results=1, where={"doc_id": "123"}, skip_embedding=True)
qdrant_client_mock.return_value.search.assert_called_once_with(
collection_name="embedchain-store-1526",
query_filter=models.Filter(
must=[
models.FieldCondition(
key="payload.metadata.doc_id",
match=models.MatchValue(
value="123",
),
)
]
),
query_vector=["This is a test document."],
limit=1,
)
@patch("embedchain.vectordb.qdrant.QdrantClient")
def test_count(self, qdrant_client_mock):
# Set the embedder
embedder = BaseEmbedder()
embedder.set_vector_dimension(1526)
# Create a Qdrant instance
db = QdrantDB()
app_config = AppConfig(collect_metrics=False)
App(config=app_config, db=db, embedder=embedder)
db.count()
qdrant_client_mock.return_value.get_collection.assert_called_once_with(collection_name="embedchain-store-1526")
@patch("embedchain.vectordb.qdrant.QdrantClient")
def test_reset(self, qdrant_client_mock):
# Set the embedder
embedder = BaseEmbedder()
embedder.set_vector_dimension(1526)
# Create a Qdrant instance
db = QdrantDB()
app_config = AppConfig(collect_metrics=False)
App(config=app_config, db=db, embedder=embedder)
db.reset()
qdrant_client_mock.return_value.delete_collection.assert_called_once_with(
collection_name="embedchain-store-1526"
)
if __name__ == "__main__":
unittest.main()