feature: Add support for zilliz vector database (#771)

This commit is contained in:
LuciAkirami
2023-10-12 01:47:33 +05:30
committed by GitHub
parent 16e123b7bb
commit d6ed2050d4
7 changed files with 438 additions and 0 deletions

View File

@@ -18,6 +18,9 @@ install_es:
install_opensearch: install_opensearch:
poetry install --extras opensearch poetry install --extras opensearch
install_milvus:
poetry install --extras milvus
shell: shell:
poetry shell poetry shell

View File

@@ -12,3 +12,4 @@ from .llm.base_llm_config import BaseLlmConfig as LlmConfig
from .vectordb.chroma import ChromaDbConfig from .vectordb.chroma import ChromaDbConfig
from .vectordb.elasticsearch import ElasticsearchDBConfig from .vectordb.elasticsearch import ElasticsearchDBConfig
from .vectordb.opensearch import OpenSearchDBConfig from .vectordb.opensearch import OpenSearchDBConfig
from .vectordb.zilliz import ZillizDBConfig

View File

@@ -0,0 +1,49 @@
import os
from typing import Optional
from embedchain.config.vectordb.base import BaseVectorDbConfig
from embedchain.helper.json_serializable import register_deserializable
@register_deserializable
class ZillizDBConfig(BaseVectorDbConfig):
def __init__(
self,
collection_name: Optional[str] = None,
dir: Optional[str] = None,
uri: Optional[str] = None,
token: Optional[str] = None,
vector_dim: Optional[str] = None,
metric_type: Optional[str] = None,
):
"""
Initializes a configuration class instance for the vector database.
:param collection_name: Default name for the collection, defaults to None
:type collection_name: Optional[str], optional
:param dir: Path to the database directory, where the database is stored, defaults to "db"
:type dir: str, optional
:param uri: Cluster endpoint obtained from the Zilliz Console, defaults to None
:type uri: Optional[str], optional
:param token: API Key, if a Serverless Cluster, username:password, if a Dedicated Cluster, defaults to None
:type port: Optional[str], optional
"""
self.uri = uri or os.environ.get("ZILLIZ_CLOUD_URI")
if not self.uri:
raise AttributeError(
"Zilliz needs a URI attribute, "
"this can either be passed to `ZILLIZ_CLOUD_URI` or as `ZILLIZ_CLOUD_URI` in `.env`"
)
self.token = token or os.environ.get("ZILLIZ_CLOUD_TOKEN")
if not self.token:
raise AttributeError(
"Zilliz needs a token attribute, "
"this can either be passed to `ZILLIZ_CLOUD_TOKEN` or as `ZILLIZ_CLOUD_TOKEN` in `.env`,"
"if having a username and password, pass it in the form 'username:password' to `ZILLIZ_CLOUD_TOKEN`"
)
self.metric_type = metric_type if metric_type else "L2"
self.vector_dim = vector_dim
super().__init__(collection_name=collection_name, dir=dir)

View File

@@ -5,3 +5,4 @@ class VectorDatabases(Enum):
CHROMADB = "CHROMADB" CHROMADB = "CHROMADB"
ELASTICSEARCH = "ELASTICSEARCH" ELASTICSEARCH = "ELASTICSEARCH"
OPENSEARCH = "OPENSEARCH" OPENSEARCH = "OPENSEARCH"
ZILLIZ = "ZILLIZ"

View File

@@ -0,0 +1,205 @@
from typing import Dict, List, Optional
from embedchain.config import ZillizDBConfig
from embedchain.helper.json_serializable import register_deserializable
from embedchain.vectordb.base import BaseVectorDB
try:
from pymilvus import MilvusClient
from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility
except ImportError:
raise ImportError(
"Zilliz requires extra dependencies. Install with `pip install --upgrade embedchain[milvus]`"
) from None
@register_deserializable
class ZillizVectorDB(BaseVectorDB):
"""Base class for vector database."""
def __init__(self, config: ZillizDBConfig = None):
"""Initialize the database. Save the config and client as an attribute.
:param config: Database configuration class instance.
:type config: ZillizDBConfig
"""
if config is None:
self.config = ZillizDBConfig()
else:
self.config = config
self.client = MilvusClient(
uri=self.config.uri,
token=self.config.token,
)
self.connection = connections.connect(
uri=self.config.uri,
token=self.config.token,
)
super().__init__(config=self.config)
def _initialize(self):
"""
This method is needed because `embedder` attribute needs to be set externally before it can be initialized.
So it's can't be done in __init__ in one step.
"""
self._get_or_create_collection(self.config.collection_name)
def _get_or_create_db(self):
"""Get or create the database."""
return self.client
def _get_or_create_collection(self, name):
"""
Get or create a named collection.
:param name: Name of the collection
:type name: str
"""
if utility.has_collection(name):
self.collection = Collection(name)
else:
fields = [
FieldSchema(name="id", dtype=DataType.VARCHAR, is_primary=True, max_length=512),
FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=2048),
FieldSchema(name="embeddings", dtype=DataType.FLOAT_VECTOR, dim=self.embedder.vector_dimension),
]
schema = CollectionSchema(fields, enable_dynamic_field=True)
self.collection = Collection(name=name, schema=schema)
index = {
"index_type": "AUTOINDEX",
"metric_type": self.config.metric_type,
}
self.collection.create_index("embeddings", index)
return self.collection
def get(self, ids: Optional[List[str]] = None, where: Optional[Dict[str, any]] = None, limit: Optional[int] = None):
"""
Get existing doc ids present in vector database
:param ids: list of doc ids to check for existence
:type ids: List[str]
:param where: Optional. to filter data
:type where: Dict[str, Any]
:param limit: Optional. maximum number of documents
:type limit: Optional[int]
:return: Existing documents.
:rtype: Set[str]
"""
if ids is None or len(ids) == 0 or self.collection.num_entities == 0:
return {"ids": []}
if not (self.collection.is_empty):
filter = f"id in {ids}"
results = self.client.query(
collection_name=self.config.collection_name, filter=filter, output_fields=["id"]
)
results = [res["id"] for res in results]
return {"ids": set(results)}
def add(
self,
embeddings: List[List[float]],
documents: List[str],
metadatas: List[object],
ids: List[str],
skip_embedding: bool,
):
"""Add to database"""
if not skip_embedding:
embeddings = self.embedder.embedding_fn(documents)
for id, doc, metadata, embedding in zip(ids, documents, metadatas, embeddings):
data = {**metadata, "id": id, "text": doc, "embeddings": embedding}
self.client.insert(collection_name=self.config.collection_name, data=data)
self.collection.load()
self.collection.flush()
self.client.flush(self.config.collection_name)
def query(self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool) -> List[str]:
"""
Query contents from vector data base based on vector similarity
:param input_query: list of query string
:type input_query: List[str]
:param n_results: no of similar documents to fetch from database
:type n_results: int
:param where: to filter data
:type where: str
:raises InvalidDimensionException: Dimensions do not match.
:return: The content of the document that matched your query.
:rtype: List[str]
"""
if self.collection.is_empty:
return []
if not isinstance(where, str):
where = None
if skip_embedding:
query_vector = input_query
query_result = self.client.search(
collection_name=self.config.collection_name,
data=query_vector,
limit=n_results,
output_fields=["text"],
)
else:
input_query_vector = self.embedder.embedding_fn([input_query])
query_vector = input_query_vector[0]
query_result = self.client.search(
collection_name=self.config.collection_name,
data=[query_vector],
limit=n_results,
output_fields=["text"],
)
doc_list = []
for query in query_result:
doc_list.append(query[0]["entity"]["text"])
return doc_list
def count(self) -> int:
"""
Count number of documents/chunks embedded in the database.
:return: number of documents
:rtype: int
"""
return self.collection.num_entities
def reset(self, collection_names: List[str] = None):
"""
Resets the database. Deletes all embeddings irreversibly.
"""
if self.config.collection_name:
if collection_names:
for collection_name in collection_names:
if collection_name in self.client.list_collections():
self.client.drop_collection(collection_name=collection_name)
else:
self.client.drop_collection(collection_name=self.config.collection_name)
self._get_or_create_collection(self.config.collection_name)
def set_collection_name(self, name: str):
"""
Set the name of the collection. A collection is an isolated space for vectors.
:param name: Name of the collection.
:type name: str
"""
if not isinstance(name, str):
raise TypeError("Collection name must be a string")
self.config.collection_name = name

View File

@@ -113,6 +113,7 @@ torchvision = { version = ">=0.15.1, !=0.15.2", optional = true }
ftfy = { version = "6.1.1", optional = true } ftfy = { version = "6.1.1", optional = true }
regex = { version = "2023.8.8", optional = true } regex = { version = "2023.8.8", optional = true }
huggingface_hub = { version = "^0.17.3", optional = true } huggingface_hub = { version = "^0.17.3", optional = true }
pymilvus = "2.3.1"
[tool.poetry.group.dev.dependencies] [tool.poetry.group.dev.dependencies]
black = "^23.3.0" black = "^23.3.0"
@@ -139,6 +140,7 @@ whatsapp = ["twilio", "flask"]
images = ["torch", "ftfy", "regex", "pillow", "torchvision"] images = ["torch", "ftfy", "regex", "pillow", "torchvision"]
huggingface_hub=["huggingface_hub"] huggingface_hub=["huggingface_hub"]
cohere = ["cohere"] cohere = ["cohere"]
milvus = ["pymilvus"]
[tool.poetry.group.docs.dependencies] [tool.poetry.group.docs.dependencies]

View File

@@ -0,0 +1,177 @@
# ruff: noqa: E501
import os
import pytest
from unittest import mock
from unittest.mock import patch, Mock
from embedchain.config import ZillizDBConfig
from embedchain.vectordb.zilliz import ZillizVectorDB
# to run tests, provide the URI and TOKEN in .env file
class TestZillizVectorDBConfig:
@mock.patch.dict(os.environ, {"ZILLIZ_CLOUD_URI": "mocked_uri", "ZILLIZ_CLOUD_TOKEN": "mocked_token"})
def test_init_with_uri_and_token(self):
"""
Test if the `ZillizVectorDBConfig` instance is initialized with the correct uri and token values.
"""
# Create a ZillizDBConfig instance with mocked values
expected_uri = "mocked_uri"
expected_token = "mocked_token"
db_config = ZillizDBConfig()
# Assert that the values in the ZillizVectorDB instance match the mocked values
assert db_config.uri == expected_uri
assert db_config.token == expected_token
@mock.patch.dict(os.environ, {"ZILLIZ_CLOUD_URI": "mocked_uri", "ZILLIZ_CLOUD_TOKEN": "mocked_token"})
def test_init_without_uri(self):
"""
Test if the `ZillizVectorDBConfig` instance throws an error when no URI found.
"""
try:
del os.environ["ZILLIZ_CLOUD_URI"]
except KeyError:
pass
with pytest.raises(AttributeError):
ZillizDBConfig()
@mock.patch.dict(os.environ, {"ZILLIZ_CLOUD_URI": "mocked_uri", "ZILLIZ_CLOUD_TOKEN": "mocked_token"})
def test_init_without_token(self):
"""
Test if the `ZillizVectorDBConfig` instance throws an error when no Token found.
"""
try:
del os.environ["ZILLIZ_CLOUD_TOKEN"]
except KeyError:
pass
# Test if an exception is raised when ZILLIZ_CLOUD_TOKEN is missing
with pytest.raises(AttributeError):
ZillizDBConfig()
class TestZillizVectorDB:
@pytest.fixture
@mock.patch.dict(os.environ, {"ZILLIZ_CLOUD_URI": "mocked_uri", "ZILLIZ_CLOUD_TOKEN": "mocked_token"})
def mock_config(self, mocker):
return mocker.Mock(spec=ZillizDBConfig())
@patch("embedchain.vectordb.zilliz.MilvusClient", autospec=True)
@patch("embedchain.vectordb.zilliz.connections.connect", autospec=True)
def test_zilliz_vector_db_setup(self, mock_connect, mock_client, mock_config):
"""
Test if the `ZillizVectorDB` instance is initialized with the correct uri and token values.
"""
# Create an instance of ZillizVectorDB with the mock config
# zilliz_db = ZillizVectorDB(config=mock_config)
ZillizVectorDB(config=mock_config)
# Assert that the MilvusClient and connections.connect were called
mock_client.assert_called_once_with(uri=mock_config.uri, token=mock_config.token)
mock_connect.assert_called_once_with(uri=mock_config.uri, token=mock_config.token)
class TestZillizDBCollection:
@pytest.fixture
@mock.patch.dict(os.environ, {"ZILLIZ_CLOUD_URI": "mocked_uri", "ZILLIZ_CLOUD_TOKEN": "mocked_token"})
def mock_config(self, mocker):
return mocker.Mock(spec=ZillizDBConfig())
@pytest.fixture
def mock_embedder(self, mocker):
return mocker.Mock()
@mock.patch.dict(os.environ, {"ZILLIZ_CLOUD_URI": "mocked_uri", "ZILLIZ_CLOUD_TOKEN": "mocked_token"})
def test_init_with_default_collection(self):
"""
Test if the `ZillizVectorDB` instance is initialized with the correct default collection name.
"""
# Create a ZillizDBConfig instance
db_config = ZillizDBConfig()
assert db_config.collection_name == "embedchain_store"
@mock.patch.dict(os.environ, {"ZILLIZ_CLOUD_URI": "mocked_uri", "ZILLIZ_CLOUD_TOKEN": "mocked_token"})
def test_init_with_custom_collection(self):
"""
Test if the `ZillizVectorDB` instance is initialized with the correct custom collection name.
"""
# Create a ZillizDBConfig instance with mocked values
expected_collection = "test_collection"
db_config = ZillizDBConfig(collection_name="test_collection")
assert db_config.collection_name == expected_collection
@patch("embedchain.vectordb.zilliz.MilvusClient", autospec=True)
@patch("embedchain.vectordb.zilliz.connections", autospec=True)
def test_query_with_skip_embedding(self, mock_connect, mock_client, mock_config):
"""
Test if the `ZillizVectorDB` instance is takes in the query with skip_embeddings.
"""
# Create an instance of ZillizVectorDB with mock config
zilliz_db = ZillizVectorDB(config=mock_config)
# Add a 'collection' attribute to the ZillizVectorDB instance for testing
zilliz_db.collection = Mock(is_empty=False) # Mock the 'collection' object
assert zilliz_db.client == mock_client()
# Mock the MilvusClient search method
with patch.object(zilliz_db.client, "search") as mock_search:
# Mock the search result
mock_search.return_value = [[{"entity": {"text": "result_doc"}}]]
# Call the query method with skip_embedding=True
query_result = zilliz_db.query(input_query=["query_text"], n_results=1, where={}, skip_embedding=True)
# Assert that MilvusClient.search was called with the correct parameters
mock_search.assert_called_once_with(
collection_name=mock_config.collection_name,
data=["query_text"],
limit=1,
output_fields=["text"],
)
# Assert that the query result matches the expected result
assert query_result == ["result_doc"]
@patch("embedchain.vectordb.zilliz.MilvusClient", autospec=True)
@patch("embedchain.vectordb.zilliz.connections", autospec=True)
def test_query_without_skip_embedding(self, mock_connect, mock_client, mock_embedder, mock_config):
"""
Test if the `ZillizVectorDB` instance is takes in the query without skip_embeddings.
"""
# Create an instance of ZillizVectorDB with mock config
zilliz_db = ZillizVectorDB(config=mock_config)
# Add a 'embedder' attribute to the ZillizVectorDB instance for testing
zilliz_db.embedder = mock_embedder # Mock the 'collection' object
# Add a 'collection' attribute to the ZillizVectorDB instance for testing
zilliz_db.collection = Mock(is_empty=False) # Mock the 'collection' object
assert zilliz_db.client == mock_client()
# Mock the MilvusClient search method
with patch.object(zilliz_db.client, "search") as mock_search:
# Mock the embedding function
mock_embedder.embedding_fn.return_value = ["query_vector"]
# Mock the search result
mock_search.return_value = [[{"entity": {"text": "result_doc"}}]]
# Call the query method with skip_embedding=False
query_result = zilliz_db.query(input_query=["query_text"], n_results=1, where={}, skip_embedding=False)
# Assert that MilvusClient.search was called with the correct parameters
mock_search.assert_called_once_with(
collection_name=mock_config.collection_name,
data=["query_vector"],
limit=1,
output_fields=["text"],
)
# Assert that the query result matches the expected result
assert query_result == ["result_doc"]