[Feature] Add Qdrant support (#822)
This commit is contained in:
@@ -183,12 +183,29 @@ vectordb:
|
||||
|
||||
## Qdrant
|
||||
|
||||
_Coming soon_
|
||||
In order to use Qdrant as a vector database, set the environment variables `QDRANT_URL` and `QDRANT_API_KEY` which you can find on [Qdrant Dashboard](https://cloud.qdrant.io/).
|
||||
|
||||
<CodeGroup>
|
||||
```python main.py
|
||||
from embedchain import App
|
||||
|
||||
# load qdrant configuration from yaml file
|
||||
app = App.from_config(yaml_path="config.yaml")
|
||||
```
|
||||
|
||||
```yaml config.yaml
|
||||
vectordb:
|
||||
provider: qdrant
|
||||
config:
|
||||
collection_name: my_qdrant_index
|
||||
```
|
||||
</CodeGroup>
|
||||
|
||||
## Weaviate
|
||||
|
||||
In order to use Weaviate as a vector database, set the environment variables `WEAVIATE_ENDPOINT` and `WEAVIATE_API_KEY` which you can find on [Weaviate dashboard](https://console.weaviate.cloud/dashboard).
|
||||
|
||||
<CodeGroup>
|
||||
```python main.py
|
||||
from embedchain import App
|
||||
|
||||
@@ -202,6 +219,6 @@ vectordb:
|
||||
config:
|
||||
collection_name: my_weaviate_index
|
||||
```
|
||||
|
||||
</CodeGroup>
|
||||
|
||||
<Snippet file="missing-vector-db-tip.mdx" />
|
||||
|
||||
44
embedchain/config/vectordb/qdrant.py
Normal file
44
embedchain/config/vectordb/qdrant.py
Normal file
@@ -0,0 +1,44 @@
|
||||
from typing import Dict, Optional
|
||||
|
||||
from embedchain.config.vectordb.base import BaseVectorDbConfig
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class QdrantDBConfig(BaseVectorDbConfig):
|
||||
"""
|
||||
Config to initialize an qdrant client.
|
||||
:param url. qdrant url or list of nodes url to be used for connection
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
collection_name: Optional[str] = None,
|
||||
dir: Optional[str] = None,
|
||||
hnsw_config: Optional[Dict[str, any]] = None,
|
||||
quantization_config: Optional[Dict[str, any]] = None,
|
||||
on_disk: Optional[bool] = None,
|
||||
**extra_params: Dict[str, any],
|
||||
):
|
||||
"""
|
||||
Initializes a configuration class instance for a qdrant client.
|
||||
|
||||
:param collection_name: Default name for the collection, defaults to None
|
||||
:type collection_name: Optional[str], optional
|
||||
:param dir: Path to the database directory, where the database is stored, defaults to None
|
||||
:type dir: Optional[str], optional
|
||||
:param hnsw_config: Params for HNSW index
|
||||
:type hnsw_config: Optional[Dict[str, any]], defaults to None
|
||||
:param quantization_config: Params for quantization, if None - quantization will be disabled
|
||||
:type quantization_config: Optional[Dict[str, any]], defaults to None
|
||||
:param on_disk: If true - point`s payload will not be stored in memory.
|
||||
It will be read from the disk every time it is requested.
|
||||
This setting saves RAM by (slightly) increasing the response time.
|
||||
Note: those payload values that are involved in filtering and are indexed - remain in RAM.
|
||||
:type on_disk: bool, optional, defaults to None
|
||||
"""
|
||||
self.hnsw_config = hnsw_config
|
||||
self.quantization_config = quantization_config
|
||||
self.on_disk = on_disk
|
||||
self.extra_params = extra_params
|
||||
super().__init__(collection_name=collection_name, dir=dir)
|
||||
@@ -73,6 +73,7 @@ class VectorDBFactory:
|
||||
"opensearch": "embedchain.vectordb.opensearch.OpenSearchDB",
|
||||
"pinecone": "embedchain.vectordb.pinecone.PineconeDB",
|
||||
"weaviate": "embedchain.vectordb.weaviate.WeaviateDB",
|
||||
"qdrant": "embedchain.vectordb.qdrant.QdrantDB",
|
||||
}
|
||||
provider_to_config_class = {
|
||||
"chroma": "embedchain.config.vectordb.chroma.ChromaDbConfig",
|
||||
@@ -80,6 +81,7 @@ class VectorDBFactory:
|
||||
"opensearch": "embedchain.config.vectordb.opensearch.OpenSearchDBConfig",
|
||||
"pinecone": "embedchain.config.vectordb.pinecone.PineconeDBConfig",
|
||||
"weaviate": "embedchain.config.vectordb.weaviate.WeaviateDBConfig",
|
||||
"qdrant": "embedchain.config.vectordb.qdrant.QdrantDBConfig",
|
||||
}
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -31,7 +31,8 @@ class OpenAILlm(BaseLlm):
|
||||
if config.top_p:
|
||||
kwargs["model_kwargs"]["top_p"] = config.top_p
|
||||
if config.stream:
|
||||
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
||||
from langchain.callbacks.streaming_stdout import \
|
||||
StreamingStdOutCallbackHandler
|
||||
|
||||
chat = ChatOpenAI(**kwargs, streaming=config.stream, callbacks=[StreamingStdOutCallbackHandler()])
|
||||
else:
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from .embedding_functions import EmbeddingFunctions # noqa: F401
|
||||
from .providers import Providers # noqa: F401
|
||||
from .vector_databases import VectorDatabases # noqa: F401
|
||||
from .vector_dimensions import VectorDimensions # noqa: F401
|
||||
|
||||
@@ -1,8 +0,0 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class VectorDatabases(Enum):
|
||||
CHROMADB = "CHROMADB"
|
||||
ELASTICSEARCH = "ELASTICSEARCH"
|
||||
OPENSEARCH = "OPENSEARCH"
|
||||
ZILLIZ = "ZILLIZ"
|
||||
213
embedchain/vectordb/qdrant.py
Normal file
213
embedchain/vectordb/qdrant.py
Normal file
@@ -0,0 +1,213 @@
|
||||
import copy
|
||||
import os
|
||||
import uuid
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
try:
|
||||
from qdrant_client import QdrantClient
|
||||
from qdrant_client.http import models
|
||||
from qdrant_client.http.models import Batch
|
||||
from qdrant_client.models import Distance, VectorParams
|
||||
except ImportError:
|
||||
raise ImportError("Qdrant requires extra dependencies. Install with `pip install embedchain[qdrant]`") from None
|
||||
|
||||
from embedchain.config.vectordb.qdrant import QdrantDBConfig
|
||||
from embedchain.vectordb.base import BaseVectorDB
|
||||
|
||||
|
||||
class QdrantDB(BaseVectorDB):
|
||||
"""
|
||||
Qdrant as vector database
|
||||
"""
|
||||
|
||||
BATCH_SIZE = 10
|
||||
|
||||
def __init__(self, config: QdrantDBConfig = None):
|
||||
"""
|
||||
Qdrant as vector database
|
||||
:param config. Qdrant database config to be used for connection
|
||||
"""
|
||||
if config is None:
|
||||
config = QdrantDBConfig()
|
||||
else:
|
||||
if not isinstance(config, QdrantDBConfig):
|
||||
raise TypeError(
|
||||
"config is not a `QdrantDBConfig` instance. "
|
||||
"Please make sure the type is right and that you are passing an instance."
|
||||
)
|
||||
self.config = config
|
||||
self.client = QdrantClient(url=os.getenv("QDRANT_URL"), api_key=os.getenv("QDRANT_API_KEY"))
|
||||
# Call parent init here because embedder is needed
|
||||
super().__init__(config=self.config)
|
||||
|
||||
def _initialize(self):
|
||||
"""
|
||||
This method is needed because `embedder` attribute needs to be set externally before it can be initialized.
|
||||
"""
|
||||
if not self.embedder:
|
||||
raise ValueError("Embedder not set. Please set an embedder with `set_embedder` before initialization.")
|
||||
|
||||
self.collection_name = self._get_or_create_collection()
|
||||
self.metadata_keys = {"data_type", "doc_id", "url", "hash", "app_id", "text"}
|
||||
all_collections = self.client.get_collections()
|
||||
collection_names = [collection.name for collection in all_collections.collections]
|
||||
if self.collection_name not in collection_names:
|
||||
self.client.recreate_collection(
|
||||
collection_name=self.collection_name,
|
||||
vectors_config=VectorParams(
|
||||
size=self.embedder.vector_dimension,
|
||||
distance=Distance.COSINE,
|
||||
hnsw_config=self.config.hnsw_config,
|
||||
quantization_config=self.config.quantization_config,
|
||||
on_disk=self.config.on_disk,
|
||||
),
|
||||
)
|
||||
|
||||
def _get_or_create_db(self):
|
||||
return self.client
|
||||
|
||||
def _get_or_create_collection(self):
|
||||
return f"{self.config.collection_name}-{self.embedder.vector_dimension}".lower().replace("_", "-")
|
||||
|
||||
def get(self, ids: Optional[List[str]] = None, where: Optional[Dict[str, any]] = None, limit: Optional[int] = None):
|
||||
"""
|
||||
Get existing doc ids present in vector database
|
||||
|
||||
:param ids: _list of doc ids to check for existence
|
||||
:type ids: List[str]
|
||||
:param where: to filter data
|
||||
:type where: Dict[str, any]
|
||||
:param limit: The number of entries to be fetched
|
||||
:type limit: Optional int, defaults to None
|
||||
:return: All the existing IDs
|
||||
:rtype: Set[str]
|
||||
"""
|
||||
if ids is None or len(ids) == 0:
|
||||
return {"ids": []}
|
||||
|
||||
keys = set(where.keys() if where is not None else set())
|
||||
|
||||
qdrant_must_filters = [
|
||||
models.FieldCondition(
|
||||
key="identifier",
|
||||
match=models.MatchAny(
|
||||
any=ids,
|
||||
),
|
||||
)
|
||||
]
|
||||
if len(keys.intersection(self.metadata_keys)) != 0:
|
||||
for key in keys.intersection(self.metadata_keys):
|
||||
qdrant_must_filters.append(
|
||||
models.FieldCondition(
|
||||
key="metadata.{}".format(key),
|
||||
match=models.MatchValue(
|
||||
value=where.get(key),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
offset = 0
|
||||
existing_ids = []
|
||||
while offset is not None:
|
||||
response = self.client.scroll(
|
||||
collection_name=self.collection_name,
|
||||
scroll_filter=models.Filter(must=qdrant_must_filters),
|
||||
offset=offset,
|
||||
limit=self.BATCH_SIZE,
|
||||
)
|
||||
offset = response[1]
|
||||
for doc in response[0]:
|
||||
existing_ids.append(doc.payload["identifier"])
|
||||
return {"ids": existing_ids}
|
||||
|
||||
def add(
|
||||
self,
|
||||
embeddings: List[List[float]],
|
||||
documents: List[str],
|
||||
metadatas: List[object],
|
||||
ids: List[str],
|
||||
skip_embedding: bool,
|
||||
):
|
||||
"""add data in vector database
|
||||
:param embeddings: list of embeddings for the corresponding documents to be added
|
||||
:type documents: List[List[float]]
|
||||
:param documents: list of texts to add
|
||||
:type documents: List[str]
|
||||
:param metadatas: list of metadata associated with docs
|
||||
:type metadatas: List[object]
|
||||
:param ids: ids of docs
|
||||
:type ids: List[str]
|
||||
:param skip_embedding: A boolean flag indicating if the embedding for the documents to be added is to be
|
||||
generated or not
|
||||
:type skip_embedding: bool
|
||||
"""
|
||||
if not skip_embedding:
|
||||
embeddings = self.embedder.embedding_fn(documents)
|
||||
|
||||
payloads = []
|
||||
qdrant_ids = []
|
||||
for id, document, metadata in zip(ids, documents, metadatas):
|
||||
metadata["text"] = document
|
||||
qdrant_ids.append(str(uuid.uuid4()))
|
||||
payloads.append({"identifier": id, "text": document, "metadata": copy.deepcopy(metadata)})
|
||||
for i in range(0, len(qdrant_ids), self.BATCH_SIZE):
|
||||
self.client.upsert(
|
||||
collection_name=self.collection_name,
|
||||
points=Batch(
|
||||
ids=qdrant_ids[i : i + self.BATCH_SIZE],
|
||||
payloads=payloads[i : i + self.BATCH_SIZE],
|
||||
vectors=embeddings[i : i + self.BATCH_SIZE],
|
||||
),
|
||||
)
|
||||
|
||||
def query(self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool) -> List[str]:
|
||||
"""
|
||||
query contents from vector database based on vector similarity
|
||||
:param input_query: list of query string
|
||||
:type input_query: List[str]
|
||||
:param n_results: no of similar documents to fetch from database
|
||||
:type n_results: int
|
||||
:param where: Optional. to filter data
|
||||
:type where: Dict[str, any]
|
||||
:param skip_embedding: A boolean flag indicating if the embedding for the documents to be added is to be
|
||||
generated or not
|
||||
:type skip_embedding: bool
|
||||
:return: Database contents that are the result of the query
|
||||
:rtype: List[str]
|
||||
"""
|
||||
if not skip_embedding:
|
||||
query_vector = self.embedder.embedding_fn([input_query])[0]
|
||||
else:
|
||||
query_vector = input_query
|
||||
|
||||
keys = set(where.keys() if where is not None else set())
|
||||
|
||||
qdrant_must_filters = []
|
||||
if len(keys.intersection(self.metadata_keys)) != 0:
|
||||
for key in keys.intersection(self.metadata_keys):
|
||||
qdrant_must_filters.append(
|
||||
models.FieldCondition(
|
||||
key="payload.metadata.{}".format(key),
|
||||
match=models.MatchValue(
|
||||
value=where.get(key),
|
||||
),
|
||||
)
|
||||
)
|
||||
results = self.client.search(
|
||||
collection_name=self.collection_name,
|
||||
query_filter=models.Filter(must=qdrant_must_filters),
|
||||
query_vector=query_vector,
|
||||
limit=n_results,
|
||||
)
|
||||
response = []
|
||||
for result in results:
|
||||
response.append(result.payload.get("text", ""))
|
||||
return response
|
||||
|
||||
def count(self) -> int:
|
||||
response = self.client.get_collection(collection_name=self.collection_name)
|
||||
return response.points_count
|
||||
|
||||
def reset(self):
|
||||
self.client.delete_collection(collection_name=self.collection_name)
|
||||
self._initialize()
|
||||
@@ -85,7 +85,7 @@ exclude = '''
|
||||
color = true
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.9,<3.9.7 || >3.9.7,<4.0"
|
||||
python = ">=3.9,<3.13"
|
||||
python-dotenv = "^1.0.0"
|
||||
langchain = "^0.0.279"
|
||||
requests = "^2.31.0"
|
||||
@@ -114,6 +114,7 @@ cohere = { version = "^4.27", optional= true }
|
||||
weaviate-client = { version = "^3.24.1", optional= true }
|
||||
docx2txt = { version="^0.8", optional=true }
|
||||
pinecone-client = { version = "^2.2.4", optional = true }
|
||||
qdrant-client = { version = "1.6.3", optional = true }
|
||||
unstructured = {extras = ["local-inference"], version = "^0.10.18", optional=true}
|
||||
pillow = { version = "10.0.1", optional = true }
|
||||
torchvision = { version = ">=0.15.1, !=0.15.2", optional = true }
|
||||
@@ -151,6 +152,7 @@ slack = ["slack-sdk", "flask"]
|
||||
whatsapp = ["twilio", "flask"]
|
||||
weaviate = ["weaviate-client"]
|
||||
pinecone = ["pinecone-client"]
|
||||
qdrant = ["qdrant-client"]
|
||||
images = ["torch", "ftfy", "regex", "pillow", "torchvision"]
|
||||
huggingface_hub=["huggingface_hub"]
|
||||
cohere = ["cohere"]
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import pytest
|
||||
from string import Template
|
||||
|
||||
import pytest
|
||||
|
||||
from embedchain.llm.base import BaseLlm, BaseLlmConfig
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from embedchain.config import BaseLlmConfig
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import importlib
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from embedchain.config import BaseLlmConfig
|
||||
from embedchain.llm.huggingface import HuggingFaceLlm
|
||||
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
||||
|
||||
from embedchain.config import BaseLlmConfig
|
||||
from embedchain.llm.jina import JinaLlm
|
||||
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from embedchain.llm.llama2 import Llama2Llm
|
||||
|
||||
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
||||
|
||||
from embedchain.config import BaseLlmConfig
|
||||
from embedchain.llm.openai import OpenAILlm
|
||||
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import os
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from embedchain import App
|
||||
from embedchain.config import AppConfig, BaseLlmConfig
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from embedchain.models.data_type import DataType, DirectDataType, IndirectDataType, SpecialDataType
|
||||
from embedchain.models.data_type import (DataType, DirectDataType,
|
||||
IndirectDataType, SpecialDataType)
|
||||
|
||||
|
||||
def test_subclass_types_in_data_type():
|
||||
|
||||
158
tests/vectordb/test_qdrant.py
Normal file
158
tests/vectordb/test_qdrant.py
Normal file
@@ -0,0 +1,158 @@
|
||||
import unittest
|
||||
import uuid
|
||||
|
||||
from mock import patch
|
||||
from qdrant_client.http import models
|
||||
from qdrant_client.http.models import Batch
|
||||
|
||||
from embedchain import App
|
||||
from embedchain.config import AppConfig
|
||||
from embedchain.config.vectordb.pinecone import PineconeDBConfig
|
||||
from embedchain.embedder.base import BaseEmbedder
|
||||
from embedchain.vectordb.qdrant import QdrantDB
|
||||
|
||||
|
||||
class TestQdrantDB(unittest.TestCase):
|
||||
TEST_UUIDS = ["abc", "def", "ghi"]
|
||||
|
||||
def test_incorrect_config_throws_error(self):
|
||||
"""Test the init method of the Qdrant class throws error for incorrect config"""
|
||||
with self.assertRaises(TypeError):
|
||||
QdrantDB(config=PineconeDBConfig())
|
||||
|
||||
@patch("embedchain.vectordb.qdrant.QdrantClient")
|
||||
def test_initialize(self, qdrant_client_mock):
|
||||
# Set the embedder
|
||||
embedder = BaseEmbedder()
|
||||
embedder.set_vector_dimension(1526)
|
||||
|
||||
# Create a Qdrant instance
|
||||
db = QdrantDB()
|
||||
app_config = AppConfig(collect_metrics=False)
|
||||
App(config=app_config, db=db, embedder=embedder)
|
||||
|
||||
self.assertEqual(db.collection_name, "embedchain-store-1526")
|
||||
self.assertEqual(db.client, qdrant_client_mock.return_value)
|
||||
qdrant_client_mock.return_value.get_collections.assert_called_once()
|
||||
|
||||
@patch("embedchain.vectordb.qdrant.QdrantClient")
|
||||
def test_get(self, qdrant_client_mock):
|
||||
qdrant_client_mock.return_value.scroll.return_value = ([], None)
|
||||
|
||||
# Set the embedder
|
||||
embedder = BaseEmbedder()
|
||||
embedder.set_vector_dimension(1526)
|
||||
|
||||
# Create a Qdrant instance
|
||||
db = QdrantDB()
|
||||
app_config = AppConfig(collect_metrics=False)
|
||||
App(config=app_config, db=db, embedder=embedder)
|
||||
|
||||
resp = db.get(ids=[], where={})
|
||||
self.assertEqual(resp, {"ids": []})
|
||||
resp2 = db.get(ids=["123", "456"], where={"url": "https://ai.ai"})
|
||||
self.assertEqual(resp2, {"ids": []})
|
||||
|
||||
@patch("embedchain.vectordb.qdrant.QdrantClient")
|
||||
@patch.object(uuid, "uuid4", side_effect=TEST_UUIDS)
|
||||
def test_add(self, uuid_mock, qdrant_client_mock):
|
||||
qdrant_client_mock.return_value.scroll.return_value = ([], None)
|
||||
|
||||
# Set the embedder
|
||||
embedder = BaseEmbedder()
|
||||
embedder.set_vector_dimension(1526)
|
||||
|
||||
# Create a Qdrant instance
|
||||
db = QdrantDB()
|
||||
app_config = AppConfig(collect_metrics=False)
|
||||
App(config=app_config, db=db, embedder=embedder)
|
||||
|
||||
embeddings = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
|
||||
documents = ["This is a test document.", "This is another test document."]
|
||||
metadatas = [{}, {}]
|
||||
ids = ["123", "456"]
|
||||
skip_embedding = True
|
||||
db.add(embeddings, documents, metadatas, ids, skip_embedding)
|
||||
qdrant_client_mock.return_value.upsert.assert_called_once_with(
|
||||
collection_name="embedchain-store-1526",
|
||||
points=Batch(
|
||||
ids=["def", "ghi"],
|
||||
payloads=[
|
||||
{
|
||||
"identifier": "123",
|
||||
"text": "This is a test document.",
|
||||
"metadata": {"text": "This is a test document."},
|
||||
},
|
||||
{
|
||||
"identifier": "456",
|
||||
"text": "This is another test document.",
|
||||
"metadata": {"text": "This is another test document."},
|
||||
},
|
||||
],
|
||||
vectors=embeddings,
|
||||
),
|
||||
)
|
||||
|
||||
@patch("embedchain.vectordb.qdrant.QdrantClient")
|
||||
def test_query(self, qdrant_client_mock):
|
||||
# Set the embedder
|
||||
embedder = BaseEmbedder()
|
||||
embedder.set_vector_dimension(1526)
|
||||
|
||||
# Create a Qdrant instance
|
||||
db = QdrantDB()
|
||||
app_config = AppConfig(collect_metrics=False)
|
||||
App(config=app_config, db=db, embedder=embedder)
|
||||
|
||||
# Query for the document.
|
||||
db.query(input_query=["This is a test document."], n_results=1, where={"doc_id": "123"}, skip_embedding=True)
|
||||
|
||||
qdrant_client_mock.return_value.search.assert_called_once_with(
|
||||
collection_name="embedchain-store-1526",
|
||||
query_filter=models.Filter(
|
||||
must=[
|
||||
models.FieldCondition(
|
||||
key="payload.metadata.doc_id",
|
||||
match=models.MatchValue(
|
||||
value="123",
|
||||
),
|
||||
)
|
||||
]
|
||||
),
|
||||
query_vector=["This is a test document."],
|
||||
limit=1,
|
||||
)
|
||||
|
||||
@patch("embedchain.vectordb.qdrant.QdrantClient")
|
||||
def test_count(self, qdrant_client_mock):
|
||||
# Set the embedder
|
||||
embedder = BaseEmbedder()
|
||||
embedder.set_vector_dimension(1526)
|
||||
|
||||
# Create a Qdrant instance
|
||||
db = QdrantDB()
|
||||
app_config = AppConfig(collect_metrics=False)
|
||||
App(config=app_config, db=db, embedder=embedder)
|
||||
|
||||
db.count()
|
||||
qdrant_client_mock.return_value.get_collection.assert_called_once_with(collection_name="embedchain-store-1526")
|
||||
|
||||
@patch("embedchain.vectordb.qdrant.QdrantClient")
|
||||
def test_reset(self, qdrant_client_mock):
|
||||
# Set the embedder
|
||||
embedder = BaseEmbedder()
|
||||
embedder.set_vector_dimension(1526)
|
||||
|
||||
# Create a Qdrant instance
|
||||
db = QdrantDB()
|
||||
app_config = AppConfig(collect_metrics=False)
|
||||
App(config=app_config, db=db, embedder=embedder)
|
||||
|
||||
db.reset()
|
||||
qdrant_client_mock.return_value.delete_collection.assert_called_once_with(
|
||||
collection_name="embedchain-store-1526"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user