Refactoring vectordb naming convention in embedchain.config (#1469)
This commit is contained in:
@@ -8,8 +8,8 @@ from .embedder.base import BaseEmbedderConfig
|
||||
from .embedder.base import BaseEmbedderConfig as EmbedderConfig
|
||||
from .embedder.ollama import OllamaEmbedderConfig
|
||||
from .llm.base import BaseLlmConfig
|
||||
from .vectordb.chroma import ChromaDbConfig
|
||||
from .vectordb.elasticsearch import ElasticsearchDBConfig
|
||||
from .vectordb.opensearch import OpenSearchDBConfig
|
||||
from .vectordb.zilliz import ZillizDBConfig
|
||||
from .vector_db.chroma import ChromaDbConfig
|
||||
from .vector_db.elasticsearch import ElasticsearchDBConfig
|
||||
from .vector_db.opensearch import OpenSearchDBConfig
|
||||
from .vector_db.zilliz import ZillizDBConfig
|
||||
from .mem0_config import Mem0Config
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import Optional
|
||||
|
||||
from embedchain.config.vectordb.base import BaseVectorDbConfig
|
||||
from embedchain.config.vector_db.base import BaseVectorDbConfig
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@@ -1,56 +1,56 @@
|
||||
import os
|
||||
from typing import Optional, Union
|
||||
|
||||
from embedchain.config.vectordb.base import BaseVectorDbConfig
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
class ElasticsearchDBConfig(BaseVectorDbConfig):
    """Connection and batching settings for an Elasticsearch client."""

    def __init__(
        self,
        collection_name: Optional[str] = None,
        dir: Optional[str] = None,
        es_url: Union[str, list[str]] = None,
        cloud_id: Optional[str] = None,
        batch_size: Optional[int] = 100,
        # NOTE(review): `any` below is the builtin function, not `typing.Any`;
        # kept as-is so the signature stays unchanged.
        **ES_EXTRA_PARAMS: dict[str, any],
    ):
        """
        Build the configuration from explicit arguments, falling back to environment variables.

        :param collection_name: Default name for the collection, defaults to None
        :type collection_name: Optional[str], optional
        :param dir: Path to the database directory, where the database is stored, defaults to None
        :type dir: Optional[str], optional
        :param es_url: elasticsearch url or list of node urls used for the connection; when falsy,
            the `ELASTICSEARCH_URL` environment variable is used instead, defaults to None
        :type es_url: Union[str, list[str]], optional
        :param cloud_id: cloud id of the elasticsearch cluster; when falsy, the
            `ELASTICSEARCH_CLOUD_ID` environment variable is used instead, defaults to None
        :type cloud_id: Optional[str], optional
        :param batch_size: Number of items to insert in one batch, defaults to 100
        :type batch_size: Optional[int], optional
        :param ES_EXTRA_PARAMS: extra keyword arguments, stored on `self.ES_EXTRA_PARAMS`.
        :type ES_EXTRA_PARAMS: dict[str, Any], optional
        :raises ValueError: if both `es_url` and `cloud_id` are passed explicitly
        :raises AttributeError: if neither a URL nor a cloud id is available from the
            arguments or the environment
        """
        # The two connection targets are mutually exclusive when given explicitly.
        if es_url and cloud_id:
            raise ValueError("Only one of `es_url` and `cloud_id` can be set.")

        environment = os.environ
        # An explicit argument wins; the environment is consulted only when it is falsy.
        self.ES_URL = es_url if es_url else environment.get("ELASTICSEARCH_URL")
        self.CLOUD_ID = cloud_id if cloud_id else environment.get("ELASTICSEARCH_CLOUD_ID")
        if not (self.ES_URL or self.CLOUD_ID):
            raise AttributeError(
                "Elasticsearch needs a URL or CLOUD_ID attribute, "
                "this can either be passed to `ElasticsearchDBConfig` or as `ELASTICSEARCH_URL` or `ELASTICSEARCH_CLOUD_ID` in `.env`"  # noqa: E501
            )

        self.ES_EXTRA_PARAMS = ES_EXTRA_PARAMS
        # Only one of 'api_key', 'basic_auth' and 'bearer_auth' may be set, so the
        # API key is read from the environment only when none of them was supplied.
        credential_keys = ("api_key", "basic_auth", "bearer_auth")
        if not any(self.ES_EXTRA_PARAMS.get(key) for key in credential_keys):
            self.ES_EXTRA_PARAMS["api_key"] = environment.get("ELASTICSEARCH_API_KEY")

        self.batch_size = batch_size
        super().__init__(collection_name=collection_name, dir=dir)
|
||||
import os
|
||||
from typing import Optional, Union
|
||||
|
||||
from embedchain.config.vector_db.base import BaseVectorDbConfig
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
class ElasticsearchDBConfig(BaseVectorDbConfig):
    """Connection and batching settings for an Elasticsearch client."""

    def __init__(
        self,
        collection_name: Optional[str] = None,
        dir: Optional[str] = None,
        es_url: Union[str, list[str]] = None,
        cloud_id: Optional[str] = None,
        batch_size: Optional[int] = 100,
        # NOTE(review): `any` below is the builtin function, not `typing.Any`;
        # kept as-is so the signature stays unchanged.
        **ES_EXTRA_PARAMS: dict[str, any],
    ):
        """
        Build the configuration from explicit arguments, falling back to environment variables.

        :param collection_name: Default name for the collection, defaults to None
        :type collection_name: Optional[str], optional
        :param dir: Path to the database directory, where the database is stored, defaults to None
        :type dir: Optional[str], optional
        :param es_url: elasticsearch url or list of node urls used for the connection; when falsy,
            the `ELASTICSEARCH_URL` environment variable is used instead, defaults to None
        :type es_url: Union[str, list[str]], optional
        :param cloud_id: cloud id of the elasticsearch cluster; when falsy, the
            `ELASTICSEARCH_CLOUD_ID` environment variable is used instead, defaults to None
        :type cloud_id: Optional[str], optional
        :param batch_size: Number of items to insert in one batch, defaults to 100
        :type batch_size: Optional[int], optional
        :param ES_EXTRA_PARAMS: extra keyword arguments, stored on `self.ES_EXTRA_PARAMS`.
        :type ES_EXTRA_PARAMS: dict[str, Any], optional
        :raises ValueError: if both `es_url` and `cloud_id` are passed explicitly
        :raises AttributeError: if neither a URL nor a cloud id is available from the
            arguments or the environment
        """
        # The two connection targets are mutually exclusive when given explicitly.
        if es_url and cloud_id:
            raise ValueError("Only one of `es_url` and `cloud_id` can be set.")

        environment = os.environ
        # An explicit argument wins; the environment is consulted only when it is falsy.
        self.ES_URL = es_url if es_url else environment.get("ELASTICSEARCH_URL")
        self.CLOUD_ID = cloud_id if cloud_id else environment.get("ELASTICSEARCH_CLOUD_ID")
        if not (self.ES_URL or self.CLOUD_ID):
            raise AttributeError(
                "Elasticsearch needs a URL or CLOUD_ID attribute, "
                "this can either be passed to `ElasticsearchDBConfig` or as `ELASTICSEARCH_URL` or `ELASTICSEARCH_CLOUD_ID` in `.env`"  # noqa: E501
            )

        self.ES_EXTRA_PARAMS = ES_EXTRA_PARAMS
        # Only one of 'api_key', 'basic_auth' and 'bearer_auth' may be set, so the
        # API key is read from the environment only when none of them was supplied.
        credential_keys = ("api_key", "basic_auth", "bearer_auth")
        if not any(self.ES_EXTRA_PARAMS.get(key) for key in credential_keys):
            self.ES_EXTRA_PARAMS["api_key"] = environment.get("ELASTICSEARCH_API_KEY")

        self.batch_size = batch_size
        super().__init__(collection_name=collection_name, dir=dir)
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import Optional
|
||||
|
||||
from embedchain.config.vectordb.base import BaseVectorDbConfig
|
||||
from embedchain.config.vector_db.base import BaseVectorDbConfig
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@@ -1,41 +1,41 @@
|
||||
from typing import Optional
|
||||
|
||||
from embedchain.config.vectordb.base import BaseVectorDbConfig
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
class OpenSearchDBConfig(BaseVectorDbConfig):
    """Connection and batching settings for an OpenSearch client."""

    def __init__(
        self,
        opensearch_url: str,
        http_auth: tuple[str, str],
        vector_dimension: int = 1536,
        collection_name: Optional[str] = None,
        dir: Optional[str] = None,
        batch_size: Optional[int] = 100,
        # NOTE(review): `any` below is the builtin function, not `typing.Any`;
        # kept as-is so the signature stays unchanged.
        **extra_params: dict[str, any],
    ):
        """
        Store the OpenSearch connection settings and forward the shared ones to the base config.

        :param opensearch_url: URL of the OpenSearch domain
        :type opensearch_url: str, Eg, "http://localhost:9200"
        :param http_auth: Tuple of username and password
        :type http_auth: tuple[str, str], Eg, ("username", "password")
        :param vector_dimension: Dimension of the vector, defaults to 1536 (openai embedding model)
        :type vector_dimension: int, optional
        :param collection_name: Default name for the collection, defaults to None
        :type collection_name: Optional[str], optional
        :param dir: Path to the database directory, where the database is stored, defaults to None
        :type dir: Optional[str], optional
        :param batch_size: Number of items to insert in one batch, defaults to 100
        :type batch_size: Optional[int], optional
        :param extra_params: extra keyword arguments, stored unchanged on `self.extra_params`
        :type extra_params: dict[str, Any], optional
        """
        # Connection target and credentials.
        self.opensearch_url, self.http_auth = opensearch_url, http_auth

        # Index shape and write batching.
        self.vector_dimension = vector_dimension
        self.batch_size = batch_size

        # Any remaining keyword arguments are kept verbatim.
        self.extra_params = extra_params

        super().__init__(collection_name=collection_name, dir=dir)
|
||||
from typing import Optional
|
||||
|
||||
from embedchain.config.vector_db.base import BaseVectorDbConfig
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
class OpenSearchDBConfig(BaseVectorDbConfig):
    """Connection and batching settings for an OpenSearch client."""

    def __init__(
        self,
        opensearch_url: str,
        http_auth: tuple[str, str],
        vector_dimension: int = 1536,
        collection_name: Optional[str] = None,
        dir: Optional[str] = None,
        batch_size: Optional[int] = 100,
        # NOTE(review): `any` below is the builtin function, not `typing.Any`;
        # kept as-is so the signature stays unchanged.
        **extra_params: dict[str, any],
    ):
        """
        Store the OpenSearch connection settings and forward the shared ones to the base config.

        :param opensearch_url: URL of the OpenSearch domain
        :type opensearch_url: str, Eg, "http://localhost:9200"
        :param http_auth: Tuple of username and password
        :type http_auth: tuple[str, str], Eg, ("username", "password")
        :param vector_dimension: Dimension of the vector, defaults to 1536 (openai embedding model)
        :type vector_dimension: int, optional
        :param collection_name: Default name for the collection, defaults to None
        :type collection_name: Optional[str], optional
        :param dir: Path to the database directory, where the database is stored, defaults to None
        :type dir: Optional[str], optional
        :param batch_size: Number of items to insert in one batch, defaults to 100
        :type batch_size: Optional[int], optional
        :param extra_params: extra keyword arguments, stored unchanged on `self.extra_params`
        :type extra_params: dict[str, Any], optional
        """
        # Connection target and credentials.
        self.opensearch_url, self.http_auth = opensearch_url, http_auth

        # Index shape and write batching.
        self.vector_dimension = vector_dimension
        self.batch_size = batch_size

        # Any remaining keyword arguments are kept verbatim.
        self.extra_params = extra_params

        super().__init__(collection_name=collection_name, dir=dir)
|
||||
@@ -1,7 +1,7 @@
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from embedchain.config.vectordb.base import BaseVectorDbConfig
|
||||
from embedchain.config.vector_db.base import BaseVectorDbConfig
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import Optional
|
||||
|
||||
from embedchain.config.vectordb.base import BaseVectorDbConfig
|
||||
from embedchain.config.vector_db.base import BaseVectorDbConfig
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import Optional
|
||||
|
||||
from embedchain.config.vectordb.base import BaseVectorDbConfig
|
||||
from embedchain.config.vector_db.base import BaseVectorDbConfig
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from embedchain.config.vectordb.base import BaseVectorDbConfig
|
||||
from embedchain.config.vector_db.base import BaseVectorDbConfig
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@@ -98,14 +98,14 @@ class VectorDBFactory:
|
||||
"zilliz": "embedchain.vectordb.zilliz.ZillizVectorDB",
|
||||
}
|
||||
provider_to_config_class = {
|
||||
"chroma": "embedchain.config.vectordb.chroma.ChromaDbConfig",
|
||||
"elasticsearch": "embedchain.config.vectordb.elasticsearch.ElasticsearchDBConfig",
|
||||
"opensearch": "embedchain.config.vectordb.opensearch.OpenSearchDBConfig",
|
||||
"lancedb": "embedchain.config.vectordb.lancedb.LanceDBConfig",
|
||||
"pinecone": "embedchain.config.vectordb.pinecone.PineconeDBConfig",
|
||||
"qdrant": "embedchain.config.vectordb.qdrant.QdrantDBConfig",
|
||||
"weaviate": "embedchain.config.vectordb.weaviate.WeaviateDBConfig",
|
||||
"zilliz": "embedchain.config.vectordb.zilliz.ZillizDBConfig",
|
||||
"chroma": "embedchain.config.vector_db.chroma.ChromaDbConfig",
|
||||
"elasticsearch": "embedchain.config.vector_db.elasticsearch.ElasticsearchDBConfig",
|
||||
"opensearch": "embedchain.config.vector_db.opensearch.OpenSearchDBConfig",
|
||||
"lancedb": "embedchain.config.vector_db.lancedb.LanceDBConfig",
|
||||
"pinecone": "embedchain.config.vector_db.pinecone.PineconeDBConfig",
|
||||
"qdrant": "embedchain.config.vector_db.qdrant.QdrantDBConfig",
|
||||
"weaviate": "embedchain.config.vector_db.weaviate.WeaviateDBConfig",
|
||||
"zilliz": "embedchain.config.vector_db.zilliz.ZillizDBConfig",
|
||||
}
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from embedchain.config.vectordb.base import BaseVectorDbConfig
|
||||
from embedchain.config.vector_db.base import BaseVectorDbConfig
|
||||
from embedchain.embedder.base import BaseEmbedder
|
||||
from embedchain.helpers.json_serializable import JSONSerializable
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ try:
|
||||
except ImportError:
|
||||
raise ImportError('LanceDB is required. Install with pip install "embedchain[lancedb]"') from None
|
||||
|
||||
from embedchain.config.vectordb.lancedb import LanceDBConfig
|
||||
from embedchain.config.vector_db.lancedb import LanceDBConfig
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
from embedchain.vectordb.base import BaseVectorDB
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ except ImportError:
|
||||
|
||||
from pinecone_text.sparse import BM25Encoder
|
||||
|
||||
from embedchain.config.vectordb.pinecone import PineconeDBConfig
|
||||
from embedchain.config.vector_db.pinecone import PineconeDBConfig
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
from embedchain.utils.misc import chunks
|
||||
from embedchain.vectordb.base import BaseVectorDB
|
||||
|
||||
@@ -12,7 +12,7 @@ except ImportError:
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
from embedchain.config.vectordb.qdrant import QdrantDBConfig
|
||||
from embedchain.config.vector_db.qdrant import QdrantDBConfig
|
||||
from embedchain.vectordb.base import BaseVectorDB
|
||||
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ except ImportError:
|
||||
"Weaviate requires extra dependencies. Install with `pip install --upgrade 'embedchain[weaviate]'`"
|
||||
) from None
|
||||
|
||||
from embedchain.config.vectordb.weaviate import WeaviateDBConfig
|
||||
from embedchain.config.vector_db.weaviate import WeaviateDBConfig
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
from embedchain.vectordb.base import BaseVectorDB
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ import pytest
|
||||
|
||||
from embedchain import App
|
||||
from embedchain.config import AppConfig
|
||||
from embedchain.config.vectordb.lancedb import LanceDBConfig
|
||||
from embedchain.config.vector_db.lancedb import LanceDBConfig
|
||||
from embedchain.vectordb.lancedb import LanceDB
|
||||
|
||||
os.environ["OPENAI_API_KEY"] = "test-api-key"
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import pytest
|
||||
|
||||
from embedchain.config.vectordb.pinecone import PineconeDBConfig
|
||||
from embedchain.config.vector_db.pinecone import PineconeDBConfig
|
||||
from embedchain.vectordb.pinecone import PineconeDB
|
||||
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ from qdrant_client.http.models import Batch
|
||||
|
||||
from embedchain import App
|
||||
from embedchain.config import AppConfig
|
||||
from embedchain.config.vectordb.pinecone import PineconeDBConfig
|
||||
from embedchain.config.vector_db.pinecone import PineconeDBConfig
|
||||
from embedchain.embedder.base import BaseEmbedder
|
||||
from embedchain.vectordb.qdrant import QdrantDB
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ from unittest.mock import patch
|
||||
|
||||
from embedchain import App
|
||||
from embedchain.config import AppConfig
|
||||
from embedchain.config.vectordb.pinecone import PineconeDBConfig
|
||||
from embedchain.config.vector_db.pinecone import PineconeDBConfig
|
||||
from embedchain.embedder.base import BaseEmbedder
|
||||
from embedchain.vectordb.weaviate import WeaviateDB
|
||||
|
||||
|
||||
Reference in New Issue
Block a user