Refactoring vectordb naming convention in embedchain.config (#1469)

This commit is contained in:
Vatsal Rathod
2024-07-08 19:01:17 -04:00
committed by GitHub
parent 1a5d0d236a
commit 83e8c97295
20 changed files with 124 additions and 124 deletions

View File

@@ -8,8 +8,8 @@ from .embedder.base import BaseEmbedderConfig
from .embedder.base import BaseEmbedderConfig as EmbedderConfig from .embedder.base import BaseEmbedderConfig as EmbedderConfig
from .embedder.ollama import OllamaEmbedderConfig from .embedder.ollama import OllamaEmbedderConfig
from .llm.base import BaseLlmConfig from .llm.base import BaseLlmConfig
from .vectordb.chroma import ChromaDbConfig from .vector_db.chroma import ChromaDbConfig
from .vectordb.elasticsearch import ElasticsearchDBConfig from .vector_db.elasticsearch import ElasticsearchDBConfig
from .vectordb.opensearch import OpenSearchDBConfig from .vector_db.opensearch import OpenSearchDBConfig
from .vectordb.zilliz import ZillizDBConfig from .vector_db.zilliz import ZillizDBConfig
from .mem0_config import Mem0Config from .mem0_config import Mem0Config

View File

@@ -1,6 +1,6 @@
from typing import Optional from typing import Optional
from embedchain.config.vectordb.base import BaseVectorDbConfig from embedchain.config.vector_db.base import BaseVectorDbConfig
from embedchain.helpers.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable

View File

@@ -1,56 +1,56 @@
import os import os
from typing import Optional, Union from typing import Optional, Union
from embedchain.config.vectordb.base import BaseVectorDbConfig from embedchain.config.vector_db.base import BaseVectorDbConfig
from embedchain.helpers.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable @register_deserializable
class ElasticsearchDBConfig(BaseVectorDbConfig): class ElasticsearchDBConfig(BaseVectorDbConfig):
def __init__( def __init__(
self, self,
collection_name: Optional[str] = None, collection_name: Optional[str] = None,
dir: Optional[str] = None, dir: Optional[str] = None,
es_url: Union[str, list[str]] = None, es_url: Union[str, list[str]] = None,
cloud_id: Optional[str] = None, cloud_id: Optional[str] = None,
batch_size: Optional[int] = 100, batch_size: Optional[int] = 100,
**ES_EXTRA_PARAMS: dict[str, any], **ES_EXTRA_PARAMS: dict[str, any],
): ):
""" """
Initializes a configuration class instance for an Elasticsearch client. Initializes a configuration class instance for an Elasticsearch client.
:param collection_name: Default name for the collection, defaults to None :param collection_name: Default name for the collection, defaults to None
:type collection_name: Optional[str], optional :type collection_name: Optional[str], optional
:param dir: Path to the database directory, where the database is stored, defaults to None :param dir: Path to the database directory, where the database is stored, defaults to None
:type dir: Optional[str], optional :type dir: Optional[str], optional
:param es_url: elasticsearch url or list of nodes url to be used for connection, defaults to None :param es_url: elasticsearch url or list of nodes url to be used for connection, defaults to None
:type es_url: Union[str, list[str]], optional :type es_url: Union[str, list[str]], optional
:param cloud_id: cloud id of the elasticsearch cluster, defaults to None :param cloud_id: cloud id of the elasticsearch cluster, defaults to None
:type cloud_id: Optional[str], optional :type cloud_id: Optional[str], optional
:param batch_size: Number of items to insert in one batch, defaults to 100 :param batch_size: Number of items to insert in one batch, defaults to 100
:type batch_size: Optional[int], optional :type batch_size: Optional[int], optional
:param ES_EXTRA_PARAMS: extra params dict that can be passed to elasticsearch. :param ES_EXTRA_PARAMS: extra params dict that can be passed to elasticsearch.
:type ES_EXTRA_PARAMS: dict[str, Any], optional :type ES_EXTRA_PARAMS: dict[str, Any], optional
""" """
if es_url and cloud_id: if es_url and cloud_id:
raise ValueError("Only one of `es_url` and `cloud_id` can be set.") raise ValueError("Only one of `es_url` and `cloud_id` can be set.")
# self, es_url: Union[str, list[str]] = None, **ES_EXTRA_PARAMS: dict[str, any]): # self, es_url: Union[str, list[str]] = None, **ES_EXTRA_PARAMS: dict[str, any]):
self.ES_URL = es_url or os.environ.get("ELASTICSEARCH_URL") self.ES_URL = es_url or os.environ.get("ELASTICSEARCH_URL")
self.CLOUD_ID = cloud_id or os.environ.get("ELASTICSEARCH_CLOUD_ID") self.CLOUD_ID = cloud_id or os.environ.get("ELASTICSEARCH_CLOUD_ID")
if not self.ES_URL and not self.CLOUD_ID: if not self.ES_URL and not self.CLOUD_ID:
raise AttributeError( raise AttributeError(
"Elasticsearch needs a URL or CLOUD_ID attribute, " "Elasticsearch needs a URL or CLOUD_ID attribute, "
"this can either be passed to `ElasticsearchDBConfig` or as `ELASTICSEARCH_URL` or `ELASTICSEARCH_CLOUD_ID` in `.env`" # noqa: E501 "this can either be passed to `ElasticsearchDBConfig` or as `ELASTICSEARCH_URL` or `ELASTICSEARCH_CLOUD_ID` in `.env`" # noqa: E501
) )
self.ES_EXTRA_PARAMS = ES_EXTRA_PARAMS self.ES_EXTRA_PARAMS = ES_EXTRA_PARAMS
# Load API key from .env if it's not explicitly passed. # Load API key from .env if it's not explicitly passed.
# Can only set one of 'api_key', 'basic_auth', and 'bearer_auth' # Can only set one of 'api_key', 'basic_auth', and 'bearer_auth'
if ( if (
not self.ES_EXTRA_PARAMS.get("api_key") not self.ES_EXTRA_PARAMS.get("api_key")
and not self.ES_EXTRA_PARAMS.get("basic_auth") and not self.ES_EXTRA_PARAMS.get("basic_auth")
and not self.ES_EXTRA_PARAMS.get("bearer_auth") and not self.ES_EXTRA_PARAMS.get("bearer_auth")
): ):
self.ES_EXTRA_PARAMS["api_key"] = os.environ.get("ELASTICSEARCH_API_KEY") self.ES_EXTRA_PARAMS["api_key"] = os.environ.get("ELASTICSEARCH_API_KEY")
self.batch_size = batch_size self.batch_size = batch_size
super().__init__(collection_name=collection_name, dir=dir) super().__init__(collection_name=collection_name, dir=dir)

View File

@@ -1,6 +1,6 @@
from typing import Optional from typing import Optional
from embedchain.config.vectordb.base import BaseVectorDbConfig from embedchain.config.vector_db.base import BaseVectorDbConfig
from embedchain.helpers.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable

View File

@@ -1,41 +1,41 @@
from typing import Optional from typing import Optional
from embedchain.config.vectordb.base import BaseVectorDbConfig from embedchain.config.vector_db.base import BaseVectorDbConfig
from embedchain.helpers.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable @register_deserializable
class OpenSearchDBConfig(BaseVectorDbConfig): class OpenSearchDBConfig(BaseVectorDbConfig):
def __init__( def __init__(
self, self,
opensearch_url: str, opensearch_url: str,
http_auth: tuple[str, str], http_auth: tuple[str, str],
vector_dimension: int = 1536, vector_dimension: int = 1536,
collection_name: Optional[str] = None, collection_name: Optional[str] = None,
dir: Optional[str] = None, dir: Optional[str] = None,
batch_size: Optional[int] = 100, batch_size: Optional[int] = 100,
**extra_params: dict[str, any], **extra_params: dict[str, any],
): ):
""" """
Initializes a configuration class instance for an OpenSearch client. Initializes a configuration class instance for an OpenSearch client.
:param collection_name: Default name for the collection, defaults to None :param collection_name: Default name for the collection, defaults to None
:type collection_name: Optional[str], optional :type collection_name: Optional[str], optional
:param opensearch_url: URL of the OpenSearch domain :param opensearch_url: URL of the OpenSearch domain
:type opensearch_url: str, Eg, "http://localhost:9200" :type opensearch_url: str, Eg, "http://localhost:9200"
:param http_auth: Tuple of username and password :param http_auth: Tuple of username and password
:type http_auth: tuple[str, str], Eg, ("username", "password") :type http_auth: tuple[str, str], Eg, ("username", "password")
:param vector_dimension: Dimension of the vector, defaults to 1536 (openai embedding model) :param vector_dimension: Dimension of the vector, defaults to 1536 (openai embedding model)
:type vector_dimension: int, optional :type vector_dimension: int, optional
:param dir: Path to the database directory, where the database is stored, defaults to None :param dir: Path to the database directory, where the database is stored, defaults to None
:type dir: Optional[str], optional :type dir: Optional[str], optional
:param batch_size: Number of items to insert in one batch, defaults to 100 :param batch_size: Number of items to insert in one batch, defaults to 100
:type batch_size: Optional[int], optional :type batch_size: Optional[int], optional
""" """
self.opensearch_url = opensearch_url self.opensearch_url = opensearch_url
self.http_auth = http_auth self.http_auth = http_auth
self.vector_dimension = vector_dimension self.vector_dimension = vector_dimension
self.extra_params = extra_params self.extra_params = extra_params
self.batch_size = batch_size self.batch_size = batch_size
super().__init__(collection_name=collection_name, dir=dir) super().__init__(collection_name=collection_name, dir=dir)

View File

@@ -1,7 +1,7 @@
import os import os
from typing import Optional from typing import Optional
from embedchain.config.vectordb.base import BaseVectorDbConfig from embedchain.config.vector_db.base import BaseVectorDbConfig
from embedchain.helpers.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable

View File

@@ -1,6 +1,6 @@
from typing import Optional from typing import Optional
from embedchain.config.vectordb.base import BaseVectorDbConfig from embedchain.config.vector_db.base import BaseVectorDbConfig
from embedchain.helpers.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable

View File

@@ -1,6 +1,6 @@
from typing import Optional from typing import Optional
from embedchain.config.vectordb.base import BaseVectorDbConfig from embedchain.config.vector_db.base import BaseVectorDbConfig
from embedchain.helpers.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable

View File

@@ -1,7 +1,7 @@
import os import os
from typing import Optional from typing import Optional
from embedchain.config.vectordb.base import BaseVectorDbConfig from embedchain.config.vector_db.base import BaseVectorDbConfig
from embedchain.helpers.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable

View File

@@ -98,14 +98,14 @@ class VectorDBFactory:
"zilliz": "embedchain.vectordb.zilliz.ZillizVectorDB", "zilliz": "embedchain.vectordb.zilliz.ZillizVectorDB",
} }
provider_to_config_class = { provider_to_config_class = {
"chroma": "embedchain.config.vectordb.chroma.ChromaDbConfig", "chroma": "embedchain.config.vector_db.chroma.ChromaDbConfig",
"elasticsearch": "embedchain.config.vectordb.elasticsearch.ElasticsearchDBConfig", "elasticsearch": "embedchain.config.vector_db.elasticsearch.ElasticsearchDBConfig",
"opensearch": "embedchain.config.vectordb.opensearch.OpenSearchDBConfig", "opensearch": "embedchain.config.vector_db.opensearch.OpenSearchDBConfig",
"lancedb": "embedchain.config.vectordb.lancedb.LanceDBConfig", "lancedb": "embedchain.config.vector_db.lancedb.LanceDBConfig",
"pinecone": "embedchain.config.vectordb.pinecone.PineconeDBConfig", "pinecone": "embedchain.config.vector_db.pinecone.PineconeDBConfig",
"qdrant": "embedchain.config.vectordb.qdrant.QdrantDBConfig", "qdrant": "embedchain.config.vector_db.qdrant.QdrantDBConfig",
"weaviate": "embedchain.config.vectordb.weaviate.WeaviateDBConfig", "weaviate": "embedchain.config.vector_db.weaviate.WeaviateDBConfig",
"zilliz": "embedchain.config.vectordb.zilliz.ZillizDBConfig", "zilliz": "embedchain.config.vector_db.zilliz.ZillizDBConfig",
} }
@classmethod @classmethod

View File

@@ -1,4 +1,4 @@
from embedchain.config.vectordb.base import BaseVectorDbConfig from embedchain.config.vector_db.base import BaseVectorDbConfig
from embedchain.embedder.base import BaseEmbedder from embedchain.embedder.base import BaseEmbedder
from embedchain.helpers.json_serializable import JSONSerializable from embedchain.helpers.json_serializable import JSONSerializable

View File

@@ -7,7 +7,7 @@ try:
except ImportError: except ImportError:
raise ImportError('LanceDB is required. Install with pip install "embedchain[lancedb]"') from None raise ImportError('LanceDB is required. Install with pip install "embedchain[lancedb]"') from None
from embedchain.config.vectordb.lancedb import LanceDBConfig from embedchain.config.vector_db.lancedb import LanceDBConfig
from embedchain.helpers.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
from embedchain.vectordb.base import BaseVectorDB from embedchain.vectordb.base import BaseVectorDB

View File

@@ -11,7 +11,7 @@ except ImportError:
from pinecone_text.sparse import BM25Encoder from pinecone_text.sparse import BM25Encoder
from embedchain.config.vectordb.pinecone import PineconeDBConfig from embedchain.config.vector_db.pinecone import PineconeDBConfig
from embedchain.helpers.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
from embedchain.utils.misc import chunks from embedchain.utils.misc import chunks
from embedchain.vectordb.base import BaseVectorDB from embedchain.vectordb.base import BaseVectorDB

View File

@@ -12,7 +12,7 @@ except ImportError:
from tqdm import tqdm from tqdm import tqdm
from embedchain.config.vectordb.qdrant import QdrantDBConfig from embedchain.config.vector_db.qdrant import QdrantDBConfig
from embedchain.vectordb.base import BaseVectorDB from embedchain.vectordb.base import BaseVectorDB

View File

@@ -9,7 +9,7 @@ except ImportError:
"Weaviate requires extra dependencies. Install with `pip install --upgrade 'embedchain[weaviate]'`" "Weaviate requires extra dependencies. Install with `pip install --upgrade 'embedchain[weaviate]'`"
) from None ) from None
from embedchain.config.vectordb.weaviate import WeaviateDBConfig from embedchain.config.vector_db.weaviate import WeaviateDBConfig
from embedchain.helpers.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
from embedchain.vectordb.base import BaseVectorDB from embedchain.vectordb.base import BaseVectorDB

View File

@@ -5,7 +5,7 @@ import pytest
from embedchain import App from embedchain import App
from embedchain.config import AppConfig from embedchain.config import AppConfig
from embedchain.config.vectordb.lancedb import LanceDBConfig from embedchain.config.vector_db.lancedb import LanceDBConfig
from embedchain.vectordb.lancedb import LanceDB from embedchain.vectordb.lancedb import LanceDB
os.environ["OPENAI_API_KEY"] = "test-api-key" os.environ["OPENAI_API_KEY"] = "test-api-key"

View File

@@ -1,6 +1,6 @@
import pytest import pytest
from embedchain.config.vectordb.pinecone import PineconeDBConfig from embedchain.config.vector_db.pinecone import PineconeDBConfig
from embedchain.vectordb.pinecone import PineconeDB from embedchain.vectordb.pinecone import PineconeDB

View File

@@ -7,7 +7,7 @@ from qdrant_client.http.models import Batch
from embedchain import App from embedchain import App
from embedchain.config import AppConfig from embedchain.config import AppConfig
from embedchain.config.vectordb.pinecone import PineconeDBConfig from embedchain.config.vector_db.pinecone import PineconeDBConfig
from embedchain.embedder.base import BaseEmbedder from embedchain.embedder.base import BaseEmbedder
from embedchain.vectordb.qdrant import QdrantDB from embedchain.vectordb.qdrant import QdrantDB

View File

@@ -3,7 +3,7 @@ from unittest.mock import patch
from embedchain import App from embedchain import App
from embedchain.config import AppConfig from embedchain.config import AppConfig
from embedchain.config.vectordb.pinecone import PineconeDBConfig from embedchain.config.vector_db.pinecone import PineconeDBConfig
from embedchain.embedder.base import BaseEmbedder from embedchain.embedder.base import BaseEmbedder
from embedchain.vectordb.weaviate import WeaviateDB from embedchain.vectordb.weaviate import WeaviateDB