diff --git a/embedchain/config/__init__.py b/embedchain/config/__init__.py index d0bf2a0a..f497b2e6 100644 --- a/embedchain/config/__init__.py +++ b/embedchain/config/__init__.py @@ -8,8 +8,8 @@ from .embedder.base import BaseEmbedderConfig from .embedder.base import BaseEmbedderConfig as EmbedderConfig from .embedder.ollama import OllamaEmbedderConfig from .llm.base import BaseLlmConfig -from .vectordb.chroma import ChromaDbConfig -from .vectordb.elasticsearch import ElasticsearchDBConfig -from .vectordb.opensearch import OpenSearchDBConfig -from .vectordb.zilliz import ZillizDBConfig +from .vector_db.chroma import ChromaDbConfig +from .vector_db.elasticsearch import ElasticsearchDBConfig +from .vector_db.opensearch import OpenSearchDBConfig +from .vector_db.zilliz import ZillizDBConfig from .mem0_config import Mem0Config diff --git a/embedchain/config/vectordb/base.py b/embedchain/config/vector_db/base.py similarity index 100% rename from embedchain/config/vectordb/base.py rename to embedchain/config/vector_db/base.py diff --git a/embedchain/config/vectordb/chroma.py b/embedchain/config/vector_db/chroma.py similarity index 96% rename from embedchain/config/vectordb/chroma.py rename to embedchain/config/vector_db/chroma.py index b99f1d94..64220165 100644 --- a/embedchain/config/vectordb/chroma.py +++ b/embedchain/config/vector_db/chroma.py @@ -1,6 +1,6 @@ from typing import Optional -from embedchain.config.vectordb.base import BaseVectorDbConfig +from embedchain.config.vector_db.base import BaseVectorDbConfig from embedchain.helpers.json_serializable import register_deserializable diff --git a/embedchain/config/vectordb/elasticsearch.py b/embedchain/config/vector_db/elasticsearch.py similarity index 95% rename from embedchain/config/vectordb/elasticsearch.py rename to embedchain/config/vector_db/elasticsearch.py index 1679700a..5e8ef6b6 100644 --- a/embedchain/config/vectordb/elasticsearch.py +++ b/embedchain/config/vector_db/elasticsearch.py @@ -1,56 +1,56 @@ -import os -from typing import Optional, Union - -from embedchain.config.vectordb.base import BaseVectorDbConfig -from embedchain.helpers.json_serializable import register_deserializable - - -@register_deserializable -class ElasticsearchDBConfig(BaseVectorDbConfig): - def __init__( - self, - collection_name: Optional[str] = None, - dir: Optional[str] = None, - es_url: Union[str, list[str]] = None, - cloud_id: Optional[str] = None, - batch_size: Optional[int] = 100, - **ES_EXTRA_PARAMS: dict[str, any], - ): - """ - Initializes a configuration class instance for an Elasticsearch client. - - :param collection_name: Default name for the collection, defaults to None - :type collection_name: Optional[str], optional - :param dir: Path to the database directory, where the database is stored, defaults to None - :type dir: Optional[str], optional - :param es_url: elasticsearch url or list of nodes url to be used for connection, defaults to None - :type es_url: Union[str, list[str]], optional - :param cloud_id: cloud id of the elasticsearch cluster, defaults to None - :type cloud_id: Optional[str], optional - :param batch_size: Number of items to insert in one batch, defaults to 100 - :type batch_size: Optional[int], optional - :param ES_EXTRA_PARAMS: extra params dict that can be passed to elasticsearch. - :type ES_EXTRA_PARAMS: dict[str, Any], optional - """ - if es_url and cloud_id: - raise ValueError("Only one of `es_url` and `cloud_id` can be set.") - # self, es_url: Union[str, list[str]] = None, **ES_EXTRA_PARAMS: dict[str, any]): - self.ES_URL = es_url or os.environ.get("ELASTICSEARCH_URL") - self.CLOUD_ID = cloud_id or os.environ.get("ELASTICSEARCH_CLOUD_ID") - if not self.ES_URL and not self.CLOUD_ID: - raise AttributeError( - "Elasticsearch needs a URL or CLOUD_ID attribute, " - "this can either be passed to `ElasticsearchDBConfig` or as `ELASTICSEARCH_URL` or `ELASTICSEARCH_CLOUD_ID` in `.env`" # noqa: E501 - ) - self.ES_EXTRA_PARAMS = ES_EXTRA_PARAMS - # Load API key from .env if it's not explicitly passed. - # Can only set one of 'api_key', 'basic_auth', and 'bearer_auth' - if ( - not self.ES_EXTRA_PARAMS.get("api_key") - and not self.ES_EXTRA_PARAMS.get("basic_auth") - and not self.ES_EXTRA_PARAMS.get("bearer_auth") - ): - self.ES_EXTRA_PARAMS["api_key"] = os.environ.get("ELASTICSEARCH_API_KEY") - - self.batch_size = batch_size - super().__init__(collection_name=collection_name, dir=dir) +import os +from typing import Optional, Union + +from embedchain.config.vector_db.base import BaseVectorDbConfig +from embedchain.helpers.json_serializable import register_deserializable + + +@register_deserializable +class ElasticsearchDBConfig(BaseVectorDbConfig): + def __init__( + self, + collection_name: Optional[str] = None, + dir: Optional[str] = None, + es_url: Union[str, list[str]] = None, + cloud_id: Optional[str] = None, + batch_size: Optional[int] = 100, + **ES_EXTRA_PARAMS: dict[str, any], + ): + """ + Initializes a configuration class instance for an Elasticsearch client. + + :param collection_name: Default name for the collection, defaults to None + :type collection_name: Optional[str], optional + :param dir: Path to the database directory, where the database is stored, defaults to None + :type dir: Optional[str], optional + :param es_url: elasticsearch url or list of nodes url to be used for connection, defaults to None + :type es_url: Union[str, list[str]], optional + :param cloud_id: cloud id of the elasticsearch cluster, defaults to None + :type cloud_id: Optional[str], optional + :param batch_size: Number of items to insert in one batch, defaults to 100 + :type batch_size: Optional[int], optional + :param ES_EXTRA_PARAMS: extra params dict that can be passed to elasticsearch. + :type ES_EXTRA_PARAMS: dict[str, Any], optional + """ + if es_url and cloud_id: + raise ValueError("Only one of `es_url` and `cloud_id` can be set.") + # self, es_url: Union[str, list[str]] = None, **ES_EXTRA_PARAMS: dict[str, any]): + self.ES_URL = es_url or os.environ.get("ELASTICSEARCH_URL") + self.CLOUD_ID = cloud_id or os.environ.get("ELASTICSEARCH_CLOUD_ID") + if not self.ES_URL and not self.CLOUD_ID: + raise AttributeError( + "Elasticsearch needs a URL or CLOUD_ID attribute, " + "this can either be passed to `ElasticsearchDBConfig` or as `ELASTICSEARCH_URL` or `ELASTICSEARCH_CLOUD_ID` in `.env`" # noqa: E501 + ) + self.ES_EXTRA_PARAMS = ES_EXTRA_PARAMS + # Load API key from .env if it's not explicitly passed. + # Can only set one of 'api_key', 'basic_auth', and 'bearer_auth' + if ( + not self.ES_EXTRA_PARAMS.get("api_key") + and not self.ES_EXTRA_PARAMS.get("basic_auth") + and not self.ES_EXTRA_PARAMS.get("bearer_auth") + ): + self.ES_EXTRA_PARAMS["api_key"] = os.environ.get("ELASTICSEARCH_API_KEY") + + self.batch_size = batch_size + super().__init__(collection_name=collection_name, dir=dir) diff --git a/embedchain/config/vectordb/lancedb.py b/embedchain/config/vector_db/lancedb.py similarity index 95% rename from embedchain/config/vectordb/lancedb.py rename to embedchain/config/vector_db/lancedb.py index 2e53ccda..08b7d0ac 100644 --- a/embedchain/config/vectordb/lancedb.py +++ b/embedchain/config/vector_db/lancedb.py @@ -1,6 +1,6 @@ from typing import Optional -from embedchain.config.vectordb.base import BaseVectorDbConfig +from embedchain.config.vector_db.base import BaseVectorDbConfig from embedchain.helpers.json_serializable import register_deserializable diff --git a/embedchain/config/vectordb/opensearch.py b/embedchain/config/vector_db/opensearch.py similarity index 94% rename from embedchain/config/vectordb/opensearch.py rename to embedchain/config/vector_db/opensearch.py index 3d0b6099..5beeb8ce 100644 --- a/embedchain/config/vectordb/opensearch.py +++ b/embedchain/config/vector_db/opensearch.py @@ -1,41 +1,41 @@ -from typing import Optional - -from embedchain.config.vectordb.base import BaseVectorDbConfig -from embedchain.helpers.json_serializable import register_deserializable - - -@register_deserializable -class OpenSearchDBConfig(BaseVectorDbConfig): - def __init__( - self, - opensearch_url: str, - http_auth: tuple[str, str], - vector_dimension: int = 1536, - collection_name: Optional[str] = None, - dir: Optional[str] = None, - batch_size: Optional[int] = 100, - **extra_params: dict[str, any], - ): - """ - Initializes a configuration class instance for an OpenSearch client. - - :param collection_name: Default name for the collection, defaults to None - :type collection_name: Optional[str], optional - :param opensearch_url: URL of the OpenSearch domain - :type opensearch_url: str, Eg, "http://localhost:9200" - :param http_auth: Tuple of username and password - :type http_auth: tuple[str, str], Eg, ("username", "password") - :param vector_dimension: Dimension of the vector, defaults to 1536 (openai embedding model) - :type vector_dimension: int, optional - :param dir: Path to the database directory, where the database is stored, defaults to None - :type dir: Optional[str], optional - :param batch_size: Number of items to insert in one batch, defaults to 100 - :type batch_size: Optional[int], optional - """ - self.opensearch_url = opensearch_url - self.http_auth = http_auth - self.vector_dimension = vector_dimension - self.extra_params = extra_params - self.batch_size = batch_size - - super().__init__(collection_name=collection_name, dir=dir) +from typing import Optional + +from embedchain.config.vector_db.base import BaseVectorDbConfig +from embedchain.helpers.json_serializable import register_deserializable + + +@register_deserializable +class OpenSearchDBConfig(BaseVectorDbConfig): + def __init__( + self, + opensearch_url: str, + http_auth: tuple[str, str], + vector_dimension: int = 1536, + collection_name: Optional[str] = None, + dir: Optional[str] = None, + batch_size: Optional[int] = 100, + **extra_params: dict[str, any], + ): + """ + Initializes a configuration class instance for an OpenSearch client. + + :param collection_name: Default name for the collection, defaults to None + :type collection_name: Optional[str], optional + :param opensearch_url: URL of the OpenSearch domain + :type opensearch_url: str, Eg, "http://localhost:9200" + :param http_auth: Tuple of username and password + :type http_auth: tuple[str, str], Eg, ("username", "password") + :param vector_dimension: Dimension of the vector, defaults to 1536 (openai embedding model) + :type vector_dimension: int, optional + :param dir: Path to the database directory, where the database is stored, defaults to None + :type dir: Optional[str], optional + :param batch_size: Number of items to insert in one batch, defaults to 100 + :type batch_size: Optional[int], optional + """ + self.opensearch_url = opensearch_url + self.http_auth = http_auth + self.vector_dimension = vector_dimension + self.extra_params = extra_params + self.batch_size = batch_size + + super().__init__(collection_name=collection_name, dir=dir) diff --git a/embedchain/config/vectordb/pinecone.py b/embedchain/config/vector_db/pinecone.py similarity index 96% rename from embedchain/config/vectordb/pinecone.py rename to embedchain/config/vector_db/pinecone.py index 4eb66f21..83248579 100644 --- a/embedchain/config/vectordb/pinecone.py +++ b/embedchain/config/vector_db/pinecone.py @@ -1,7 +1,7 @@ import os from typing import Optional -from embedchain.config.vectordb.base import BaseVectorDbConfig +from embedchain.config.vector_db.base import BaseVectorDbConfig from embedchain.helpers.json_serializable import register_deserializable diff --git a/embedchain/config/vectordb/qdrant.py b/embedchain/config/vector_db/qdrant.py similarity index 97% rename from embedchain/config/vectordb/qdrant.py rename to embedchain/config/vector_db/qdrant.py index 2520226a..acdeacff 100644 --- a/embedchain/config/vectordb/qdrant.py +++ b/embedchain/config/vector_db/qdrant.py @@ -1,6 +1,6 @@ from typing import Optional -from embedchain.config.vectordb.base import BaseVectorDbConfig +from embedchain.config.vector_db.base import BaseVectorDbConfig from embedchain.helpers.json_serializable import register_deserializable diff --git a/embedchain/config/vectordb/weaviate.py b/embedchain/config/vector_db/weaviate.py similarity index 89% rename from embedchain/config/vectordb/weaviate.py rename to embedchain/config/vector_db/weaviate.py index 8277e88f..f40c472e 100644 --- a/embedchain/config/vectordb/weaviate.py +++ b/embedchain/config/vector_db/weaviate.py @@ -1,6 +1,6 @@ from typing import Optional -from embedchain.config.vectordb.base import BaseVectorDbConfig +from embedchain.config.vector_db.base import BaseVectorDbConfig from embedchain.helpers.json_serializable import register_deserializable diff --git a/embedchain/config/vectordb/zilliz.py b/embedchain/config/vector_db/zilliz.py similarity index 96% rename from embedchain/config/vectordb/zilliz.py rename to embedchain/config/vector_db/zilliz.py index 25c3c957..26894115 100644 --- a/embedchain/config/vectordb/zilliz.py +++ b/embedchain/config/vector_db/zilliz.py @@ -1,7 +1,7 @@ import os from typing import Optional -from embedchain.config.vectordb.base import BaseVectorDbConfig +from embedchain.config.vector_db.base import BaseVectorDbConfig from embedchain.helpers.json_serializable import register_deserializable diff --git a/embedchain/factory.py b/embedchain/factory.py index db07e377..dd402f63 100644 --- a/embedchain/factory.py +++ b/embedchain/factory.py @@ -98,14 +98,14 @@ class VectorDBFactory: "zilliz": "embedchain.vectordb.zilliz.ZillizVectorDB", } provider_to_config_class = { - "chroma": "embedchain.config.vectordb.chroma.ChromaDbConfig", - "elasticsearch": "embedchain.config.vectordb.elasticsearch.ElasticsearchDBConfig", - "opensearch": "embedchain.config.vectordb.opensearch.OpenSearchDBConfig", - "lancedb": "embedchain.config.vectordb.lancedb.LanceDBConfig", - "pinecone": "embedchain.config.vectordb.pinecone.PineconeDBConfig", - "qdrant": "embedchain.config.vectordb.qdrant.QdrantDBConfig", - "weaviate": "embedchain.config.vectordb.weaviate.WeaviateDBConfig", - "zilliz": "embedchain.config.vectordb.zilliz.ZillizDBConfig", + "chroma": "embedchain.config.vector_db.chroma.ChromaDbConfig", + "elasticsearch": "embedchain.config.vector_db.elasticsearch.ElasticsearchDBConfig", + "opensearch": "embedchain.config.vector_db.opensearch.OpenSearchDBConfig", + "lancedb": "embedchain.config.vector_db.lancedb.LanceDBConfig", + "pinecone": "embedchain.config.vector_db.pinecone.PineconeDBConfig", + "qdrant": "embedchain.config.vector_db.qdrant.QdrantDBConfig", + "weaviate": "embedchain.config.vector_db.weaviate.WeaviateDBConfig", + "zilliz": "embedchain.config.vector_db.zilliz.ZillizDBConfig", } @classmethod diff --git a/embedchain/vectordb/base.py b/embedchain/vectordb/base.py index a31cf100..e65cde01 100644 --- a/embedchain/vectordb/base.py +++ b/embedchain/vectordb/base.py @@ -1,4 +1,4 @@ -from embedchain.config.vectordb.base import BaseVectorDbConfig +from embedchain.config.vector_db.base import BaseVectorDbConfig from embedchain.embedder.base import BaseEmbedder from embedchain.helpers.json_serializable import JSONSerializable diff --git a/embedchain/vectordb/lancedb.py b/embedchain/vectordb/lancedb.py index 3179eb5c..d3d4b689 100644 --- a/embedchain/vectordb/lancedb.py +++ b/embedchain/vectordb/lancedb.py @@ -7,7 +7,7 @@ try: except ImportError: raise ImportError('LanceDB is required. Install with pip install "embedchain[lancedb]"') from None -from embedchain.config.vectordb.lancedb import LanceDBConfig +from embedchain.config.vector_db.lancedb import LanceDBConfig from embedchain.helpers.json_serializable import register_deserializable from embedchain.vectordb.base import BaseVectorDB diff --git a/embedchain/vectordb/pinecone.py b/embedchain/vectordb/pinecone.py index c89dbcda..3c0520ce 100644 --- a/embedchain/vectordb/pinecone.py +++ b/embedchain/vectordb/pinecone.py @@ -11,7 +11,7 @@ except ImportError: from pinecone_text.sparse import BM25Encoder -from embedchain.config.vectordb.pinecone import PineconeDBConfig +from embedchain.config.vector_db.pinecone import PineconeDBConfig from embedchain.helpers.json_serializable import register_deserializable from embedchain.utils.misc import chunks from embedchain.vectordb.base import BaseVectorDB diff --git a/embedchain/vectordb/qdrant.py b/embedchain/vectordb/qdrant.py index d9b05020..cdac19cf 100644 --- a/embedchain/vectordb/qdrant.py +++ b/embedchain/vectordb/qdrant.py @@ -12,7 +12,7 @@ except ImportError: from tqdm import tqdm -from embedchain.config.vectordb.qdrant import QdrantDBConfig +from embedchain.config.vector_db.qdrant import QdrantDBConfig from embedchain.vectordb.base import BaseVectorDB diff --git a/embedchain/vectordb/weaviate.py b/embedchain/vectordb/weaviate.py index f4632436..897412a6 100644 --- a/embedchain/vectordb/weaviate.py +++ b/embedchain/vectordb/weaviate.py @@ -9,7 +9,7 @@ except ImportError: "Weaviate requires extra dependencies. Install with `pip install --upgrade 'embedchain[weaviate]'`" ) from None -from embedchain.config.vectordb.weaviate import WeaviateDBConfig +from embedchain.config.vector_db.weaviate import WeaviateDBConfig from embedchain.helpers.json_serializable import register_deserializable from embedchain.vectordb.base import BaseVectorDB diff --git a/tests/vectordb/test_lancedb.py b/tests/vectordb/test_lancedb.py index f50660bb..91885bdd 100644 --- a/tests/vectordb/test_lancedb.py +++ b/tests/vectordb/test_lancedb.py @@ -5,7 +5,7 @@ import pytest from embedchain import App from embedchain.config import AppConfig -from embedchain.config.vectordb.lancedb import LanceDBConfig +from embedchain.config.vector_db.lancedb import LanceDBConfig from embedchain.vectordb.lancedb import LanceDB os.environ["OPENAI_API_KEY"] = "test-api-key" diff --git a/tests/vectordb/test_pinecone.py b/tests/vectordb/test_pinecone.py index 472b0593..00051ed9 100644 --- a/tests/vectordb/test_pinecone.py +++ b/tests/vectordb/test_pinecone.py @@ -1,6 +1,6 @@ import pytest -from embedchain.config.vectordb.pinecone import PineconeDBConfig +from embedchain.config.vector_db.pinecone import PineconeDBConfig from embedchain.vectordb.pinecone import PineconeDB diff --git a/tests/vectordb/test_qdrant.py b/tests/vectordb/test_qdrant.py index 47326952..b2b3dfa0 100644 --- a/tests/vectordb/test_qdrant.py +++ b/tests/vectordb/test_qdrant.py @@ -7,7 +7,7 @@ from qdrant_client.http.models import Batch from embedchain import App from embedchain.config import AppConfig -from embedchain.config.vectordb.pinecone import PineconeDBConfig +from embedchain.config.vector_db.pinecone import PineconeDBConfig from embedchain.embedder.base import BaseEmbedder from embedchain.vectordb.qdrant import QdrantDB diff --git a/tests/vectordb/test_weaviate.py b/tests/vectordb/test_weaviate.py index b263579f..a51870d4 100644 --- a/tests/vectordb/test_weaviate.py +++ b/tests/vectordb/test_weaviate.py @@ -3,7 +3,7 @@ from unittest.mock import patch from embedchain import App from embedchain.config import AppConfig -from embedchain.config.vectordb.pinecone import PineconeDBConfig +from embedchain.config.vector_db.pinecone import PineconeDBConfig from embedchain.embedder.base import BaseEmbedder from embedchain.vectordb.weaviate import WeaviateDB