[Feature] Add Qdrant support (#822)

This commit is contained in:
Rupesh Bansal
2023-10-19 02:57:57 +05:30
committed by GitHub
parent 7641cba01d
commit c8846e0e93
17 changed files with 460 additions and 18 deletions

View File

@@ -0,0 +1,44 @@
from typing import Dict, Optional
from embedchain.config.vectordb.base import BaseVectorDbConfig
from embedchain.helper.json_serializable import register_deserializable
@register_deserializable
class QdrantDBConfig(BaseVectorDbConfig):
"""
Config to initialize an qdrant client.
:param url. qdrant url or list of nodes url to be used for connection
"""
def __init__(
self,
collection_name: Optional[str] = None,
dir: Optional[str] = None,
hnsw_config: Optional[Dict[str, any]] = None,
quantization_config: Optional[Dict[str, any]] = None,
on_disk: Optional[bool] = None,
**extra_params: Dict[str, any],
):
"""
Initializes a configuration class instance for a qdrant client.
:param collection_name: Default name for the collection, defaults to None
:type collection_name: Optional[str], optional
:param dir: Path to the database directory, where the database is stored, defaults to None
:type dir: Optional[str], optional
:param hnsw_config: Params for HNSW index
:type hnsw_config: Optional[Dict[str, any]], defaults to None
:param quantization_config: Params for quantization, if None - quantization will be disabled
:type quantization_config: Optional[Dict[str, any]], defaults to None
:param on_disk: If true - point`s payload will not be stored in memory.
It will be read from the disk every time it is requested.
This setting saves RAM by (slightly) increasing the response time.
Note: those payload values that are involved in filtering and are indexed - remain in RAM.
:type on_disk: bool, optional, defaults to None
"""
self.hnsw_config = hnsw_config
self.quantization_config = quantization_config
self.on_disk = on_disk
self.extra_params = extra_params
super().__init__(collection_name=collection_name, dir=dir)

View File

@@ -73,6 +73,7 @@ class VectorDBFactory:
"opensearch": "embedchain.vectordb.opensearch.OpenSearchDB",
"pinecone": "embedchain.vectordb.pinecone.PineconeDB",
"weaviate": "embedchain.vectordb.weaviate.WeaviateDB",
"qdrant": "embedchain.vectordb.qdrant.QdrantDB",
}
provider_to_config_class = {
"chroma": "embedchain.config.vectordb.chroma.ChromaDbConfig",
@@ -80,6 +81,7 @@ class VectorDBFactory:
"opensearch": "embedchain.config.vectordb.opensearch.OpenSearchDBConfig",
"pinecone": "embedchain.config.vectordb.pinecone.PineconeDBConfig",
"weaviate": "embedchain.config.vectordb.weaviate.WeaviateDBConfig",
"qdrant": "embedchain.config.vectordb.qdrant.QdrantDBConfig",
}
@classmethod

View File

@@ -31,7 +31,8 @@ class OpenAILlm(BaseLlm):
if config.top_p:
kwargs["model_kwargs"]["top_p"] = config.top_p
if config.stream:
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.callbacks.streaming_stdout import \
StreamingStdOutCallbackHandler
chat = ChatOpenAI(**kwargs, streaming=config.stream, callbacks=[StreamingStdOutCallbackHandler()])
else:

View File

@@ -1,4 +1,3 @@
from .embedding_functions import EmbeddingFunctions # noqa: F401
from .providers import Providers # noqa: F401
from .vector_databases import VectorDatabases # noqa: F401
from .vector_dimensions import VectorDimensions # noqa: F401

View File

@@ -1,8 +0,0 @@
from enum import Enum
class VectorDatabases(Enum):
CHROMADB = "CHROMADB"
ELASTICSEARCH = "ELASTICSEARCH"
OPENSEARCH = "OPENSEARCH"
ZILLIZ = "ZILLIZ"

View File

@@ -0,0 +1,213 @@
import copy
import os
import uuid
from typing import Dict, List, Optional
try:
from qdrant_client import QdrantClient
from qdrant_client.http import models
from qdrant_client.http.models import Batch
from qdrant_client.models import Distance, VectorParams
except ImportError:
raise ImportError("Qdrant requires extra dependencies. Install with `pip install embedchain[qdrant]`") from None
from embedchain.config.vectordb.qdrant import QdrantDBConfig
from embedchain.vectordb.base import BaseVectorDB
class QdrantDB(BaseVectorDB):
"""
Qdrant as vector database
"""
BATCH_SIZE = 10
def __init__(self, config: QdrantDBConfig = None):
"""
Qdrant as vector database
:param config. Qdrant database config to be used for connection
"""
if config is None:
config = QdrantDBConfig()
else:
if not isinstance(config, QdrantDBConfig):
raise TypeError(
"config is not a `QdrantDBConfig` instance. "
"Please make sure the type is right and that you are passing an instance."
)
self.config = config
self.client = QdrantClient(url=os.getenv("QDRANT_URL"), api_key=os.getenv("QDRANT_API_KEY"))
# Call parent init here because embedder is needed
super().__init__(config=self.config)
def _initialize(self):
"""
This method is needed because `embedder` attribute needs to be set externally before it can be initialized.
"""
if not self.embedder:
raise ValueError("Embedder not set. Please set an embedder with `set_embedder` before initialization.")
self.collection_name = self._get_or_create_collection()
self.metadata_keys = {"data_type", "doc_id", "url", "hash", "app_id", "text"}
all_collections = self.client.get_collections()
collection_names = [collection.name for collection in all_collections.collections]
if self.collection_name not in collection_names:
self.client.recreate_collection(
collection_name=self.collection_name,
vectors_config=VectorParams(
size=self.embedder.vector_dimension,
distance=Distance.COSINE,
hnsw_config=self.config.hnsw_config,
quantization_config=self.config.quantization_config,
on_disk=self.config.on_disk,
),
)
def _get_or_create_db(self):
return self.client
def _get_or_create_collection(self):
return f"{self.config.collection_name}-{self.embedder.vector_dimension}".lower().replace("_", "-")
def get(self, ids: Optional[List[str]] = None, where: Optional[Dict[str, any]] = None, limit: Optional[int] = None):
"""
Get existing doc ids present in vector database
:param ids: _list of doc ids to check for existence
:type ids: List[str]
:param where: to filter data
:type where: Dict[str, any]
:param limit: The number of entries to be fetched
:type limit: Optional int, defaults to None
:return: All the existing IDs
:rtype: Set[str]
"""
if ids is None or len(ids) == 0:
return {"ids": []}
keys = set(where.keys() if where is not None else set())
qdrant_must_filters = [
models.FieldCondition(
key="identifier",
match=models.MatchAny(
any=ids,
),
)
]
if len(keys.intersection(self.metadata_keys)) != 0:
for key in keys.intersection(self.metadata_keys):
qdrant_must_filters.append(
models.FieldCondition(
key="metadata.{}".format(key),
match=models.MatchValue(
value=where.get(key),
),
)
)
offset = 0
existing_ids = []
while offset is not None:
response = self.client.scroll(
collection_name=self.collection_name,
scroll_filter=models.Filter(must=qdrant_must_filters),
offset=offset,
limit=self.BATCH_SIZE,
)
offset = response[1]
for doc in response[0]:
existing_ids.append(doc.payload["identifier"])
return {"ids": existing_ids}
def add(
self,
embeddings: List[List[float]],
documents: List[str],
metadatas: List[object],
ids: List[str],
skip_embedding: bool,
):
"""add data in vector database
:param embeddings: list of embeddings for the corresponding documents to be added
:type documents: List[List[float]]
:param documents: list of texts to add
:type documents: List[str]
:param metadatas: list of metadata associated with docs
:type metadatas: List[object]
:param ids: ids of docs
:type ids: List[str]
:param skip_embedding: A boolean flag indicating if the embedding for the documents to be added is to be
generated or not
:type skip_embedding: bool
"""
if not skip_embedding:
embeddings = self.embedder.embedding_fn(documents)
payloads = []
qdrant_ids = []
for id, document, metadata in zip(ids, documents, metadatas):
metadata["text"] = document
qdrant_ids.append(str(uuid.uuid4()))
payloads.append({"identifier": id, "text": document, "metadata": copy.deepcopy(metadata)})
for i in range(0, len(qdrant_ids), self.BATCH_SIZE):
self.client.upsert(
collection_name=self.collection_name,
points=Batch(
ids=qdrant_ids[i : i + self.BATCH_SIZE],
payloads=payloads[i : i + self.BATCH_SIZE],
vectors=embeddings[i : i + self.BATCH_SIZE],
),
)
def query(self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool) -> List[str]:
"""
query contents from vector database based on vector similarity
:param input_query: list of query string
:type input_query: List[str]
:param n_results: no of similar documents to fetch from database
:type n_results: int
:param where: Optional. to filter data
:type where: Dict[str, any]
:param skip_embedding: A boolean flag indicating if the embedding for the documents to be added is to be
generated or not
:type skip_embedding: bool
:return: Database contents that are the result of the query
:rtype: List[str]
"""
if not skip_embedding:
query_vector = self.embedder.embedding_fn([input_query])[0]
else:
query_vector = input_query
keys = set(where.keys() if where is not None else set())
qdrant_must_filters = []
if len(keys.intersection(self.metadata_keys)) != 0:
for key in keys.intersection(self.metadata_keys):
qdrant_must_filters.append(
models.FieldCondition(
key="payload.metadata.{}".format(key),
match=models.MatchValue(
value=where.get(key),
),
)
)
results = self.client.search(
collection_name=self.collection_name,
query_filter=models.Filter(must=qdrant_must_filters),
query_vector=query_vector,
limit=n_results,
)
response = []
for result in results:
response.append(result.payload.get("text", ""))
return response
def count(self) -> int:
response = self.client.get_collection(collection_name=self.collection_name)
return response.points_count
def reset(self):
self.client.delete_collection(collection_name=self.collection_name)
self._initialize()