Refactoring vectordb naming convention in embedchain.config (#1469)

This commit is contained in:
Vatsal Rathod
2024-07-08 19:01:17 -04:00
committed by GitHub
parent 1a5d0d236a
commit 83e8c97295
20 changed files with 124 additions and 124 deletions

View File

@@ -0,0 +1,36 @@
from typing import Optional
from embedchain.config.base_config import BaseConfig
class BaseVectorDbConfig(BaseConfig):
def __init__(
self,
collection_name: Optional[str] = None,
dir: str = "db",
host: Optional[str] = None,
port: Optional[str] = None,
**kwargs,
):
"""
Initializes a configuration class instance for the vector database.
:param collection_name: Default name for the collection, defaults to None
:type collection_name: Optional[str], optional
:param dir: Path to the database directory, where the database is stored, defaults to "db"
:type dir: str, optional
:param host: Database connection remote host. Use this if you run Embedchain as a client, defaults to None
:type host: Optional[str], optional
:param host: Database connection remote port. Use this if you run Embedchain as a client, defaults to None
:type port: Optional[str], optional
:param kwargs: Additional keyword arguments
:type kwargs: dict
"""
self.collection_name = collection_name or "embedchain_store"
self.dir = dir
self.host = host
self.port = port
# Assign additional keyword arguments
if kwargs:
for key, value in kwargs.items():
setattr(self, key, value)

View File

@@ -0,0 +1,41 @@
from typing import Optional
from embedchain.config.vector_db.base import BaseVectorDbConfig
from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable
class ChromaDbConfig(BaseVectorDbConfig):
def __init__(
self,
collection_name: Optional[str] = None,
dir: Optional[str] = None,
host: Optional[str] = None,
port: Optional[str] = None,
batch_size: Optional[int] = 100,
allow_reset=False,
chroma_settings: Optional[dict] = None,
):
"""
Initializes a configuration class instance for ChromaDB.
:param collection_name: Default name for the collection, defaults to None
:type collection_name: Optional[str], optional
:param dir: Path to the database directory, where the database is stored, defaults to None
:type dir: Optional[str], optional
:param host: Database connection remote host. Use this if you run Embedchain as a client, defaults to None
:type host: Optional[str], optional
:param port: Database connection remote port. Use this if you run Embedchain as a client, defaults to None
:type port: Optional[str], optional
:param batch_size: Number of items to insert in one batch, defaults to 100
:type batch_size: Optional[int], optional
:param allow_reset: Resets the database. defaults to False
:type allow_reset: bool
:param chroma_settings: Chroma settings dict, defaults to None
:type chroma_settings: Optional[dict], optional
"""
self.chroma_settings = chroma_settings
self.allow_reset = allow_reset
self.batch_size = batch_size
super().__init__(collection_name=collection_name, dir=dir, host=host, port=port)

View File

@@ -0,0 +1,56 @@
import os
from typing import Optional, Union
from embedchain.config.vector_db.base import BaseVectorDbConfig
from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable
class ElasticsearchDBConfig(BaseVectorDbConfig):
def __init__(
self,
collection_name: Optional[str] = None,
dir: Optional[str] = None,
es_url: Union[str, list[str]] = None,
cloud_id: Optional[str] = None,
batch_size: Optional[int] = 100,
**ES_EXTRA_PARAMS: dict[str, any],
):
"""
Initializes a configuration class instance for an Elasticsearch client.
:param collection_name: Default name for the collection, defaults to None
:type collection_name: Optional[str], optional
:param dir: Path to the database directory, where the database is stored, defaults to None
:type dir: Optional[str], optional
:param es_url: elasticsearch url or list of nodes url to be used for connection, defaults to None
:type es_url: Union[str, list[str]], optional
:param cloud_id: cloud id of the elasticsearch cluster, defaults to None
:type cloud_id: Optional[str], optional
:param batch_size: Number of items to insert in one batch, defaults to 100
:type batch_size: Optional[int], optional
:param ES_EXTRA_PARAMS: extra params dict that can be passed to elasticsearch.
:type ES_EXTRA_PARAMS: dict[str, Any], optional
"""
if es_url and cloud_id:
raise ValueError("Only one of `es_url` and `cloud_id` can be set.")
# self, es_url: Union[str, list[str]] = None, **ES_EXTRA_PARAMS: dict[str, any]):
self.ES_URL = es_url or os.environ.get("ELASTICSEARCH_URL")
self.CLOUD_ID = cloud_id or os.environ.get("ELASTICSEARCH_CLOUD_ID")
if not self.ES_URL and not self.CLOUD_ID:
raise AttributeError(
"Elasticsearch needs a URL or CLOUD_ID attribute, "
"this can either be passed to `ElasticsearchDBConfig` or as `ELASTICSEARCH_URL` or `ELASTICSEARCH_CLOUD_ID` in `.env`" # noqa: E501
)
self.ES_EXTRA_PARAMS = ES_EXTRA_PARAMS
# Load API key from .env if it's not explicitly passed.
# Can only set one of 'api_key', 'basic_auth', and 'bearer_auth'
if (
not self.ES_EXTRA_PARAMS.get("api_key")
and not self.ES_EXTRA_PARAMS.get("basic_auth")
and not self.ES_EXTRA_PARAMS.get("bearer_auth")
):
self.ES_EXTRA_PARAMS["api_key"] = os.environ.get("ELASTICSEARCH_API_KEY")
self.batch_size = batch_size
super().__init__(collection_name=collection_name, dir=dir)

View File

@@ -0,0 +1,33 @@
from typing import Optional
from embedchain.config.vector_db.base import BaseVectorDbConfig
from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable
class LanceDBConfig(BaseVectorDbConfig):
def __init__(
self,
collection_name: Optional[str] = None,
dir: Optional[str] = None,
host: Optional[str] = None,
port: Optional[str] = None,
allow_reset=True,
):
"""
Initializes a configuration class instance for LanceDB.
:param collection_name: Default name for the collection, defaults to None
:type collection_name: Optional[str], optional
:param dir: Path to the database directory, where the database is stored, defaults to None
:type dir: Optional[str], optional
:param host: Database connection remote host. Use this if you run Embedchain as a client, defaults to None
:type host: Optional[str], optional
:param port: Database connection remote port. Use this if you run Embedchain as a client, defaults to None
:type port: Optional[str], optional
:param allow_reset: Resets the database. defaults to False
:type allow_reset: bool
"""
self.allow_reset = allow_reset
super().__init__(collection_name=collection_name, dir=dir, host=host, port=port)

View File

@@ -0,0 +1,41 @@
from typing import Optional
from embedchain.config.vector_db.base import BaseVectorDbConfig
from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable
class OpenSearchDBConfig(BaseVectorDbConfig):
def __init__(
self,
opensearch_url: str,
http_auth: tuple[str, str],
vector_dimension: int = 1536,
collection_name: Optional[str] = None,
dir: Optional[str] = None,
batch_size: Optional[int] = 100,
**extra_params: dict[str, any],
):
"""
Initializes a configuration class instance for an OpenSearch client.
:param collection_name: Default name for the collection, defaults to None
:type collection_name: Optional[str], optional
:param opensearch_url: URL of the OpenSearch domain
:type opensearch_url: str, Eg, "http://localhost:9200"
:param http_auth: Tuple of username and password
:type http_auth: tuple[str, str], Eg, ("username", "password")
:param vector_dimension: Dimension of the vector, defaults to 1536 (openai embedding model)
:type vector_dimension: int, optional
:param dir: Path to the database directory, where the database is stored, defaults to None
:type dir: Optional[str], optional
:param batch_size: Number of items to insert in one batch, defaults to 100
:type batch_size: Optional[int], optional
"""
self.opensearch_url = opensearch_url
self.http_auth = http_auth
self.vector_dimension = vector_dimension
self.extra_params = extra_params
self.batch_size = batch_size
super().__init__(collection_name=collection_name, dir=dir)

View File

@@ -0,0 +1,47 @@
import os
from typing import Optional
from embedchain.config.vector_db.base import BaseVectorDbConfig
from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable
class PineconeDBConfig(BaseVectorDbConfig):
def __init__(
self,
index_name: Optional[str] = None,
api_key: Optional[str] = None,
vector_dimension: int = 1536,
metric: Optional[str] = "cosine",
pod_config: Optional[dict[str, any]] = None,
serverless_config: Optional[dict[str, any]] = None,
hybrid_search: bool = False,
bm25_encoder: any = None,
batch_size: Optional[int] = 100,
**extra_params: dict[str, any],
):
self.metric = metric
self.api_key = api_key
self.index_name = index_name
self.vector_dimension = vector_dimension
self.extra_params = extra_params
self.hybrid_search = hybrid_search
self.bm25_encoder = bm25_encoder
self.batch_size = batch_size
if pod_config is None and serverless_config is None:
# If no config is provided, use the default pod spec config
pod_environment = os.environ.get("PINECONE_ENV", "gcp-starter")
self.pod_config = {"environment": pod_environment, "metadata_config": {"indexed": ["*"]}}
else:
self.pod_config = pod_config
self.serverless_config = serverless_config
if self.pod_config and self.serverless_config:
raise ValueError("Only one of pod_config or serverless_config can be provided.")
if self.hybrid_search and self.metric != "dotproduct":
raise ValueError(
"Hybrid search is only supported with dotproduct metric in Pinecone. See full docs here: https://docs.pinecone.io/docs/hybrid-search#limitations"
) # noqa:E501
super().__init__(collection_name=self.index_name, dir=None)

View File

@@ -0,0 +1,48 @@
from typing import Optional
from embedchain.config.vector_db.base import BaseVectorDbConfig
from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable
class QdrantDBConfig(BaseVectorDbConfig):
"""
Config to initialize a qdrant client.
:param: url. qdrant url or list of nodes url to be used for connection
"""
def __init__(
self,
collection_name: Optional[str] = None,
dir: Optional[str] = None,
hnsw_config: Optional[dict[str, any]] = None,
quantization_config: Optional[dict[str, any]] = None,
on_disk: Optional[bool] = None,
batch_size: Optional[int] = 10,
**extra_params: dict[str, any],
):
"""
Initializes a configuration class instance for a qdrant client.
:param collection_name: Default name for the collection, defaults to None
:type collection_name: Optional[str], optional
:param dir: Path to the database directory, where the database is stored, defaults to None
:type dir: Optional[str], optional
:param hnsw_config: Params for HNSW index
:type hnsw_config: Optional[dict[str, any]], defaults to None
:param quantization_config: Params for quantization, if None - quantization will be disabled
:type quantization_config: Optional[dict[str, any]], defaults to None
:param on_disk: If true - point`s payload will not be stored in memory.
It will be read from the disk every time it is requested.
This setting saves RAM by (slightly) increasing the response time.
Note: those payload values that are involved in filtering and are indexed - remain in RAM.
:type on_disk: bool, optional, defaults to None
:param batch_size: Number of items to insert in one batch, defaults to 10
:type batch_size: Optional[int], optional
"""
self.hnsw_config = hnsw_config
self.quantization_config = quantization_config
self.on_disk = on_disk
self.batch_size = batch_size
self.extra_params = extra_params
super().__init__(collection_name=collection_name, dir=dir)

View File

@@ -0,0 +1,18 @@
from typing import Optional
from embedchain.config.vector_db.base import BaseVectorDbConfig
from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable
class WeaviateDBConfig(BaseVectorDbConfig):
def __init__(
self,
collection_name: Optional[str] = None,
dir: Optional[str] = None,
batch_size: Optional[int] = 100,
**extra_params: dict[str, any],
):
self.batch_size = batch_size
self.extra_params = extra_params
super().__init__(collection_name=collection_name, dir=dir)

View File

@@ -0,0 +1,49 @@
import os
from typing import Optional
from embedchain.config.vector_db.base import BaseVectorDbConfig
from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable
class ZillizDBConfig(BaseVectorDbConfig):
def __init__(
self,
collection_name: Optional[str] = None,
dir: Optional[str] = None,
uri: Optional[str] = None,
token: Optional[str] = None,
vector_dim: Optional[str] = None,
metric_type: Optional[str] = None,
):
"""
Initializes a configuration class instance for the vector database.
:param collection_name: Default name for the collection, defaults to None
:type collection_name: Optional[str], optional
:param dir: Path to the database directory, where the database is stored, defaults to "db"
:type dir: str, optional
:param uri: Cluster endpoint obtained from the Zilliz Console, defaults to None
:type uri: Optional[str], optional
:param token: API Key, if a Serverless Cluster, username:password, if a Dedicated Cluster, defaults to None
:type token: Optional[str], optional
"""
self.uri = uri or os.environ.get("ZILLIZ_CLOUD_URI")
if not self.uri:
raise AttributeError(
"Zilliz needs a URI attribute, "
"this can either be passed to `ZILLIZ_CLOUD_URI` or as `ZILLIZ_CLOUD_URI` in `.env`"
)
self.token = token or os.environ.get("ZILLIZ_CLOUD_TOKEN")
if not self.token:
raise AttributeError(
"Zilliz needs a token attribute, "
"this can either be passed to `ZILLIZ_CLOUD_TOKEN` or as `ZILLIZ_CLOUD_TOKEN` in `.env`,"
"if having a username and password, pass it in the form 'username:password' to `ZILLIZ_CLOUD_TOKEN`"
)
self.metric_type = metric_type if metric_type else "L2"
self.vector_dim = vector_dim
super().__init__(collection_name=collection_name, dir=dir)