Refactoring vectordb naming convention in embedchain.config (#1469)

This commit is contained in:
Vatsal Rathod
2024-07-08 19:01:17 -04:00
committed by GitHub
parent 1a5d0d236a
commit 83e8c97295
20 changed files with 124 additions and 124 deletions

View File

@@ -8,8 +8,8 @@ from .embedder.base import BaseEmbedderConfig
from .embedder.base import BaseEmbedderConfig as EmbedderConfig
from .embedder.ollama import OllamaEmbedderConfig
from .llm.base import BaseLlmConfig
from .vectordb.chroma import ChromaDbConfig
from .vectordb.elasticsearch import ElasticsearchDBConfig
from .vectordb.opensearch import OpenSearchDBConfig
from .vectordb.zilliz import ZillizDBConfig
from .vector_db.chroma import ChromaDbConfig
from .vector_db.elasticsearch import ElasticsearchDBConfig
from .vector_db.opensearch import OpenSearchDBConfig
from .vector_db.zilliz import ZillizDBConfig
from .mem0_config import Mem0Config

View File

@@ -1,6 +1,6 @@
from typing import Optional
from embedchain.config.vectordb.base import BaseVectorDbConfig
from embedchain.config.vector_db.base import BaseVectorDbConfig
from embedchain.helpers.json_serializable import register_deserializable

View File

@@ -1,56 +1,56 @@
import os
from typing import Optional, Union
from embedchain.config.vectordb.base import BaseVectorDbConfig
from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable
class ElasticsearchDBConfig(BaseVectorDbConfig):
def __init__(
self,
collection_name: Optional[str] = None,
dir: Optional[str] = None,
es_url: Union[str, list[str]] = None,
cloud_id: Optional[str] = None,
batch_size: Optional[int] = 100,
**ES_EXTRA_PARAMS: dict[str, any],
):
"""
Initializes a configuration class instance for an Elasticsearch client.
:param collection_name: Default name for the collection, defaults to None
:type collection_name: Optional[str], optional
:param dir: Path to the database directory, where the database is stored, defaults to None
:type dir: Optional[str], optional
:param es_url: elasticsearch url or list of nodes url to be used for connection, defaults to None
:type es_url: Union[str, list[str]], optional
:param cloud_id: cloud id of the elasticsearch cluster, defaults to None
:type cloud_id: Optional[str], optional
:param batch_size: Number of items to insert in one batch, defaults to 100
:type batch_size: Optional[int], optional
:param ES_EXTRA_PARAMS: extra params dict that can be passed to elasticsearch.
:type ES_EXTRA_PARAMS: dict[str, Any], optional
"""
if es_url and cloud_id:
raise ValueError("Only one of `es_url` and `cloud_id` can be set.")
# self, es_url: Union[str, list[str]] = None, **ES_EXTRA_PARAMS: dict[str, any]):
self.ES_URL = es_url or os.environ.get("ELASTICSEARCH_URL")
self.CLOUD_ID = cloud_id or os.environ.get("ELASTICSEARCH_CLOUD_ID")
if not self.ES_URL and not self.CLOUD_ID:
raise AttributeError(
"Elasticsearch needs a URL or CLOUD_ID attribute, "
"this can either be passed to `ElasticsearchDBConfig` or as `ELASTICSEARCH_URL` or `ELASTICSEARCH_CLOUD_ID` in `.env`" # noqa: E501
)
self.ES_EXTRA_PARAMS = ES_EXTRA_PARAMS
# Load API key from .env if it's not explicitly passed.
# Can only set one of 'api_key', 'basic_auth', and 'bearer_auth'
if (
not self.ES_EXTRA_PARAMS.get("api_key")
and not self.ES_EXTRA_PARAMS.get("basic_auth")
and not self.ES_EXTRA_PARAMS.get("bearer_auth")
):
self.ES_EXTRA_PARAMS["api_key"] = os.environ.get("ELASTICSEARCH_API_KEY")
self.batch_size = batch_size
super().__init__(collection_name=collection_name, dir=dir)
import os
from typing import Optional, Union
from embedchain.config.vector_db.base import BaseVectorDbConfig
from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable
class ElasticsearchDBConfig(BaseVectorDbConfig):
def __init__(
self,
collection_name: Optional[str] = None,
dir: Optional[str] = None,
es_url: Union[str, list[str]] = None,
cloud_id: Optional[str] = None,
batch_size: Optional[int] = 100,
**ES_EXTRA_PARAMS: dict[str, any],
):
"""
Initializes a configuration class instance for an Elasticsearch client.
:param collection_name: Default name for the collection, defaults to None
:type collection_name: Optional[str], optional
:param dir: Path to the database directory, where the database is stored, defaults to None
:type dir: Optional[str], optional
:param es_url: elasticsearch url or list of nodes url to be used for connection, defaults to None
:type es_url: Union[str, list[str]], optional
:param cloud_id: cloud id of the elasticsearch cluster, defaults to None
:type cloud_id: Optional[str], optional
:param batch_size: Number of items to insert in one batch, defaults to 100
:type batch_size: Optional[int], optional
:param ES_EXTRA_PARAMS: extra params dict that can be passed to elasticsearch.
:type ES_EXTRA_PARAMS: dict[str, Any], optional
"""
if es_url and cloud_id:
raise ValueError("Only one of `es_url` and `cloud_id` can be set.")
# self, es_url: Union[str, list[str]] = None, **ES_EXTRA_PARAMS: dict[str, any]):
self.ES_URL = es_url or os.environ.get("ELASTICSEARCH_URL")
self.CLOUD_ID = cloud_id or os.environ.get("ELASTICSEARCH_CLOUD_ID")
if not self.ES_URL and not self.CLOUD_ID:
raise AttributeError(
"Elasticsearch needs a URL or CLOUD_ID attribute, "
"this can either be passed to `ElasticsearchDBConfig` or as `ELASTICSEARCH_URL` or `ELASTICSEARCH_CLOUD_ID` in `.env`" # noqa: E501
)
self.ES_EXTRA_PARAMS = ES_EXTRA_PARAMS
# Load API key from .env if it's not explicitly passed.
# Can only set one of 'api_key', 'basic_auth', and 'bearer_auth'
if (
not self.ES_EXTRA_PARAMS.get("api_key")
and not self.ES_EXTRA_PARAMS.get("basic_auth")
and not self.ES_EXTRA_PARAMS.get("bearer_auth")
):
self.ES_EXTRA_PARAMS["api_key"] = os.environ.get("ELASTICSEARCH_API_KEY")
self.batch_size = batch_size
super().__init__(collection_name=collection_name, dir=dir)

View File

@@ -1,6 +1,6 @@
from typing import Optional
from embedchain.config.vectordb.base import BaseVectorDbConfig
from embedchain.config.vector_db.base import BaseVectorDbConfig
from embedchain.helpers.json_serializable import register_deserializable

View File

@@ -1,41 +1,41 @@
from typing import Optional
from embedchain.config.vectordb.base import BaseVectorDbConfig
from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable
class OpenSearchDBConfig(BaseVectorDbConfig):
def __init__(
self,
opensearch_url: str,
http_auth: tuple[str, str],
vector_dimension: int = 1536,
collection_name: Optional[str] = None,
dir: Optional[str] = None,
batch_size: Optional[int] = 100,
**extra_params: dict[str, any],
):
"""
Initializes a configuration class instance for an OpenSearch client.
:param collection_name: Default name for the collection, defaults to None
:type collection_name: Optional[str], optional
:param opensearch_url: URL of the OpenSearch domain
:type opensearch_url: str, Eg, "http://localhost:9200"
:param http_auth: Tuple of username and password
:type http_auth: tuple[str, str], Eg, ("username", "password")
:param vector_dimension: Dimension of the vector, defaults to 1536 (openai embedding model)
:type vector_dimension: int, optional
:param dir: Path to the database directory, where the database is stored, defaults to None
:type dir: Optional[str], optional
:param batch_size: Number of items to insert in one batch, defaults to 100
:type batch_size: Optional[int], optional
"""
self.opensearch_url = opensearch_url
self.http_auth = http_auth
self.vector_dimension = vector_dimension
self.extra_params = extra_params
self.batch_size = batch_size
super().__init__(collection_name=collection_name, dir=dir)
from typing import Optional
from embedchain.config.vector_db.base import BaseVectorDbConfig
from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable
class OpenSearchDBConfig(BaseVectorDbConfig):
def __init__(
self,
opensearch_url: str,
http_auth: tuple[str, str],
vector_dimension: int = 1536,
collection_name: Optional[str] = None,
dir: Optional[str] = None,
batch_size: Optional[int] = 100,
**extra_params: dict[str, any],
):
"""
Initializes a configuration class instance for an OpenSearch client.
:param collection_name: Default name for the collection, defaults to None
:type collection_name: Optional[str], optional
:param opensearch_url: URL of the OpenSearch domain
:type opensearch_url: str, Eg, "http://localhost:9200"
:param http_auth: Tuple of username and password
:type http_auth: tuple[str, str], Eg, ("username", "password")
:param vector_dimension: Dimension of the vector, defaults to 1536 (openai embedding model)
:type vector_dimension: int, optional
:param dir: Path to the database directory, where the database is stored, defaults to None
:type dir: Optional[str], optional
:param batch_size: Number of items to insert in one batch, defaults to 100
:type batch_size: Optional[int], optional
"""
self.opensearch_url = opensearch_url
self.http_auth = http_auth
self.vector_dimension = vector_dimension
self.extra_params = extra_params
self.batch_size = batch_size
super().__init__(collection_name=collection_name, dir=dir)

View File

@@ -1,7 +1,7 @@
import os
from typing import Optional
from embedchain.config.vectordb.base import BaseVectorDbConfig
from embedchain.config.vector_db.base import BaseVectorDbConfig
from embedchain.helpers.json_serializable import register_deserializable

View File

@@ -1,6 +1,6 @@
from typing import Optional
from embedchain.config.vectordb.base import BaseVectorDbConfig
from embedchain.config.vector_db.base import BaseVectorDbConfig
from embedchain.helpers.json_serializable import register_deserializable

View File

@@ -1,6 +1,6 @@
from typing import Optional
from embedchain.config.vectordb.base import BaseVectorDbConfig
from embedchain.config.vector_db.base import BaseVectorDbConfig
from embedchain.helpers.json_serializable import register_deserializable

View File

@@ -1,7 +1,7 @@
import os
from typing import Optional
from embedchain.config.vectordb.base import BaseVectorDbConfig
from embedchain.config.vector_db.base import BaseVectorDbConfig
from embedchain.helpers.json_serializable import register_deserializable

View File

@@ -98,14 +98,14 @@ class VectorDBFactory:
"zilliz": "embedchain.vectordb.zilliz.ZillizVectorDB",
}
provider_to_config_class = {
"chroma": "embedchain.config.vectordb.chroma.ChromaDbConfig",
"elasticsearch": "embedchain.config.vectordb.elasticsearch.ElasticsearchDBConfig",
"opensearch": "embedchain.config.vectordb.opensearch.OpenSearchDBConfig",
"lancedb": "embedchain.config.vectordb.lancedb.LanceDBConfig",
"pinecone": "embedchain.config.vectordb.pinecone.PineconeDBConfig",
"qdrant": "embedchain.config.vectordb.qdrant.QdrantDBConfig",
"weaviate": "embedchain.config.vectordb.weaviate.WeaviateDBConfig",
"zilliz": "embedchain.config.vectordb.zilliz.ZillizDBConfig",
"chroma": "embedchain.config.vector_db.chroma.ChromaDbConfig",
"elasticsearch": "embedchain.config.vector_db.elasticsearch.ElasticsearchDBConfig",
"opensearch": "embedchain.config.vector_db.opensearch.OpenSearchDBConfig",
"lancedb": "embedchain.config.vector_db.lancedb.LanceDBConfig",
"pinecone": "embedchain.config.vector_db.pinecone.PineconeDBConfig",
"qdrant": "embedchain.config.vector_db.qdrant.QdrantDBConfig",
"weaviate": "embedchain.config.vector_db.weaviate.WeaviateDBConfig",
"zilliz": "embedchain.config.vector_db.zilliz.ZillizDBConfig",
}
@classmethod

View File

@@ -1,4 +1,4 @@
from embedchain.config.vectordb.base import BaseVectorDbConfig
from embedchain.config.vector_db.base import BaseVectorDbConfig
from embedchain.embedder.base import BaseEmbedder
from embedchain.helpers.json_serializable import JSONSerializable

View File

@@ -7,7 +7,7 @@ try:
except ImportError:
raise ImportError('LanceDB is required. Install with pip install "embedchain[lancedb]"') from None
from embedchain.config.vectordb.lancedb import LanceDBConfig
from embedchain.config.vector_db.lancedb import LanceDBConfig
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.vectordb.base import BaseVectorDB

View File

@@ -11,7 +11,7 @@ except ImportError:
from pinecone_text.sparse import BM25Encoder
from embedchain.config.vectordb.pinecone import PineconeDBConfig
from embedchain.config.vector_db.pinecone import PineconeDBConfig
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.utils.misc import chunks
from embedchain.vectordb.base import BaseVectorDB

View File

@@ -12,7 +12,7 @@ except ImportError:
from tqdm import tqdm
from embedchain.config.vectordb.qdrant import QdrantDBConfig
from embedchain.config.vector_db.qdrant import QdrantDBConfig
from embedchain.vectordb.base import BaseVectorDB

View File

@@ -9,7 +9,7 @@ except ImportError:
"Weaviate requires extra dependencies. Install with `pip install --upgrade 'embedchain[weaviate]'`"
) from None
from embedchain.config.vectordb.weaviate import WeaviateDBConfig
from embedchain.config.vector_db.weaviate import WeaviateDBConfig
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.vectordb.base import BaseVectorDB