Refactoring vectordb naming convention in embedchain.config (#1469)
This commit is contained in:
@@ -8,8 +8,8 @@ from .embedder.base import BaseEmbedderConfig
|
|||||||
from .embedder.base import BaseEmbedderConfig as EmbedderConfig
|
from .embedder.base import BaseEmbedderConfig as EmbedderConfig
|
||||||
from .embedder.ollama import OllamaEmbedderConfig
|
from .embedder.ollama import OllamaEmbedderConfig
|
||||||
from .llm.base import BaseLlmConfig
|
from .llm.base import BaseLlmConfig
|
||||||
from .vectordb.chroma import ChromaDbConfig
|
from .vector_db.chroma import ChromaDbConfig
|
||||||
from .vectordb.elasticsearch import ElasticsearchDBConfig
|
from .vector_db.elasticsearch import ElasticsearchDBConfig
|
||||||
from .vectordb.opensearch import OpenSearchDBConfig
|
from .vector_db.opensearch import OpenSearchDBConfig
|
||||||
from .vectordb.zilliz import ZillizDBConfig
|
from .vector_db.zilliz import ZillizDBConfig
|
||||||
from .mem0_config import Mem0Config
|
from .mem0_config import Mem0Config
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from embedchain.config.vectordb.base import BaseVectorDbConfig
|
from embedchain.config.vector_db.base import BaseVectorDbConfig
|
||||||
from embedchain.helpers.json_serializable import register_deserializable
|
from embedchain.helpers.json_serializable import register_deserializable
|
||||||
|
|
||||||
|
|
||||||
@@ -1,56 +1,56 @@
|
|||||||
import os
|
import os
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
from embedchain.config.vectordb.base import BaseVectorDbConfig
|
from embedchain.config.vector_db.base import BaseVectorDbConfig
|
||||||
from embedchain.helpers.json_serializable import register_deserializable
|
from embedchain.helpers.json_serializable import register_deserializable
|
||||||
|
|
||||||
|
|
||||||
@register_deserializable
|
@register_deserializable
|
||||||
class ElasticsearchDBConfig(BaseVectorDbConfig):
|
class ElasticsearchDBConfig(BaseVectorDbConfig):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
collection_name: Optional[str] = None,
|
collection_name: Optional[str] = None,
|
||||||
dir: Optional[str] = None,
|
dir: Optional[str] = None,
|
||||||
es_url: Union[str, list[str]] = None,
|
es_url: Union[str, list[str]] = None,
|
||||||
cloud_id: Optional[str] = None,
|
cloud_id: Optional[str] = None,
|
||||||
batch_size: Optional[int] = 100,
|
batch_size: Optional[int] = 100,
|
||||||
**ES_EXTRA_PARAMS: dict[str, any],
|
**ES_EXTRA_PARAMS: dict[str, any],
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initializes a configuration class instance for an Elasticsearch client.
|
Initializes a configuration class instance for an Elasticsearch client.
|
||||||
|
|
||||||
:param collection_name: Default name for the collection, defaults to None
|
:param collection_name: Default name for the collection, defaults to None
|
||||||
:type collection_name: Optional[str], optional
|
:type collection_name: Optional[str], optional
|
||||||
:param dir: Path to the database directory, where the database is stored, defaults to None
|
:param dir: Path to the database directory, where the database is stored, defaults to None
|
||||||
:type dir: Optional[str], optional
|
:type dir: Optional[str], optional
|
||||||
:param es_url: elasticsearch url or list of nodes url to be used for connection, defaults to None
|
:param es_url: elasticsearch url or list of nodes url to be used for connection, defaults to None
|
||||||
:type es_url: Union[str, list[str]], optional
|
:type es_url: Union[str, list[str]], optional
|
||||||
:param cloud_id: cloud id of the elasticsearch cluster, defaults to None
|
:param cloud_id: cloud id of the elasticsearch cluster, defaults to None
|
||||||
:type cloud_id: Optional[str], optional
|
:type cloud_id: Optional[str], optional
|
||||||
:param batch_size: Number of items to insert in one batch, defaults to 100
|
:param batch_size: Number of items to insert in one batch, defaults to 100
|
||||||
:type batch_size: Optional[int], optional
|
:type batch_size: Optional[int], optional
|
||||||
:param ES_EXTRA_PARAMS: extra params dict that can be passed to elasticsearch.
|
:param ES_EXTRA_PARAMS: extra params dict that can be passed to elasticsearch.
|
||||||
:type ES_EXTRA_PARAMS: dict[str, Any], optional
|
:type ES_EXTRA_PARAMS: dict[str, Any], optional
|
||||||
"""
|
"""
|
||||||
if es_url and cloud_id:
|
if es_url and cloud_id:
|
||||||
raise ValueError("Only one of `es_url` and `cloud_id` can be set.")
|
raise ValueError("Only one of `es_url` and `cloud_id` can be set.")
|
||||||
# self, es_url: Union[str, list[str]] = None, **ES_EXTRA_PARAMS: dict[str, any]):
|
# self, es_url: Union[str, list[str]] = None, **ES_EXTRA_PARAMS: dict[str, any]):
|
||||||
self.ES_URL = es_url or os.environ.get("ELASTICSEARCH_URL")
|
self.ES_URL = es_url or os.environ.get("ELASTICSEARCH_URL")
|
||||||
self.CLOUD_ID = cloud_id or os.environ.get("ELASTICSEARCH_CLOUD_ID")
|
self.CLOUD_ID = cloud_id or os.environ.get("ELASTICSEARCH_CLOUD_ID")
|
||||||
if not self.ES_URL and not self.CLOUD_ID:
|
if not self.ES_URL and not self.CLOUD_ID:
|
||||||
raise AttributeError(
|
raise AttributeError(
|
||||||
"Elasticsearch needs a URL or CLOUD_ID attribute, "
|
"Elasticsearch needs a URL or CLOUD_ID attribute, "
|
||||||
"this can either be passed to `ElasticsearchDBConfig` or as `ELASTICSEARCH_URL` or `ELASTICSEARCH_CLOUD_ID` in `.env`" # noqa: E501
|
"this can either be passed to `ElasticsearchDBConfig` or as `ELASTICSEARCH_URL` or `ELASTICSEARCH_CLOUD_ID` in `.env`" # noqa: E501
|
||||||
)
|
)
|
||||||
self.ES_EXTRA_PARAMS = ES_EXTRA_PARAMS
|
self.ES_EXTRA_PARAMS = ES_EXTRA_PARAMS
|
||||||
# Load API key from .env if it's not explicitly passed.
|
# Load API key from .env if it's not explicitly passed.
|
||||||
# Can only set one of 'api_key', 'basic_auth', and 'bearer_auth'
|
# Can only set one of 'api_key', 'basic_auth', and 'bearer_auth'
|
||||||
if (
|
if (
|
||||||
not self.ES_EXTRA_PARAMS.get("api_key")
|
not self.ES_EXTRA_PARAMS.get("api_key")
|
||||||
and not self.ES_EXTRA_PARAMS.get("basic_auth")
|
and not self.ES_EXTRA_PARAMS.get("basic_auth")
|
||||||
and not self.ES_EXTRA_PARAMS.get("bearer_auth")
|
and not self.ES_EXTRA_PARAMS.get("bearer_auth")
|
||||||
):
|
):
|
||||||
self.ES_EXTRA_PARAMS["api_key"] = os.environ.get("ELASTICSEARCH_API_KEY")
|
self.ES_EXTRA_PARAMS["api_key"] = os.environ.get("ELASTICSEARCH_API_KEY")
|
||||||
|
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
super().__init__(collection_name=collection_name, dir=dir)
|
super().__init__(collection_name=collection_name, dir=dir)
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from embedchain.config.vectordb.base import BaseVectorDbConfig
|
from embedchain.config.vector_db.base import BaseVectorDbConfig
|
||||||
from embedchain.helpers.json_serializable import register_deserializable
|
from embedchain.helpers.json_serializable import register_deserializable
|
||||||
|
|
||||||
|
|
||||||
@@ -1,41 +1,41 @@
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from embedchain.config.vectordb.base import BaseVectorDbConfig
|
from embedchain.config.vector_db.base import BaseVectorDbConfig
|
||||||
from embedchain.helpers.json_serializable import register_deserializable
|
from embedchain.helpers.json_serializable import register_deserializable
|
||||||
|
|
||||||
|
|
||||||
@register_deserializable
|
@register_deserializable
|
||||||
class OpenSearchDBConfig(BaseVectorDbConfig):
|
class OpenSearchDBConfig(BaseVectorDbConfig):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
opensearch_url: str,
|
opensearch_url: str,
|
||||||
http_auth: tuple[str, str],
|
http_auth: tuple[str, str],
|
||||||
vector_dimension: int = 1536,
|
vector_dimension: int = 1536,
|
||||||
collection_name: Optional[str] = None,
|
collection_name: Optional[str] = None,
|
||||||
dir: Optional[str] = None,
|
dir: Optional[str] = None,
|
||||||
batch_size: Optional[int] = 100,
|
batch_size: Optional[int] = 100,
|
||||||
**extra_params: dict[str, any],
|
**extra_params: dict[str, any],
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initializes a configuration class instance for an OpenSearch client.
|
Initializes a configuration class instance for an OpenSearch client.
|
||||||
|
|
||||||
:param collection_name: Default name for the collection, defaults to None
|
:param collection_name: Default name for the collection, defaults to None
|
||||||
:type collection_name: Optional[str], optional
|
:type collection_name: Optional[str], optional
|
||||||
:param opensearch_url: URL of the OpenSearch domain
|
:param opensearch_url: URL of the OpenSearch domain
|
||||||
:type opensearch_url: str, Eg, "http://localhost:9200"
|
:type opensearch_url: str, Eg, "http://localhost:9200"
|
||||||
:param http_auth: Tuple of username and password
|
:param http_auth: Tuple of username and password
|
||||||
:type http_auth: tuple[str, str], Eg, ("username", "password")
|
:type http_auth: tuple[str, str], Eg, ("username", "password")
|
||||||
:param vector_dimension: Dimension of the vector, defaults to 1536 (openai embedding model)
|
:param vector_dimension: Dimension of the vector, defaults to 1536 (openai embedding model)
|
||||||
:type vector_dimension: int, optional
|
:type vector_dimension: int, optional
|
||||||
:param dir: Path to the database directory, where the database is stored, defaults to None
|
:param dir: Path to the database directory, where the database is stored, defaults to None
|
||||||
:type dir: Optional[str], optional
|
:type dir: Optional[str], optional
|
||||||
:param batch_size: Number of items to insert in one batch, defaults to 100
|
:param batch_size: Number of items to insert in one batch, defaults to 100
|
||||||
:type batch_size: Optional[int], optional
|
:type batch_size: Optional[int], optional
|
||||||
"""
|
"""
|
||||||
self.opensearch_url = opensearch_url
|
self.opensearch_url = opensearch_url
|
||||||
self.http_auth = http_auth
|
self.http_auth = http_auth
|
||||||
self.vector_dimension = vector_dimension
|
self.vector_dimension = vector_dimension
|
||||||
self.extra_params = extra_params
|
self.extra_params = extra_params
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
|
|
||||||
super().__init__(collection_name=collection_name, dir=dir)
|
super().__init__(collection_name=collection_name, dir=dir)
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from embedchain.config.vectordb.base import BaseVectorDbConfig
|
from embedchain.config.vector_db.base import BaseVectorDbConfig
|
||||||
from embedchain.helpers.json_serializable import register_deserializable
|
from embedchain.helpers.json_serializable import register_deserializable
|
||||||
|
|
||||||
|
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from embedchain.config.vectordb.base import BaseVectorDbConfig
|
from embedchain.config.vector_db.base import BaseVectorDbConfig
|
||||||
from embedchain.helpers.json_serializable import register_deserializable
|
from embedchain.helpers.json_serializable import register_deserializable
|
||||||
|
|
||||||
|
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from embedchain.config.vectordb.base import BaseVectorDbConfig
|
from embedchain.config.vector_db.base import BaseVectorDbConfig
|
||||||
from embedchain.helpers.json_serializable import register_deserializable
|
from embedchain.helpers.json_serializable import register_deserializable
|
||||||
|
|
||||||
|
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from embedchain.config.vectordb.base import BaseVectorDbConfig
|
from embedchain.config.vector_db.base import BaseVectorDbConfig
|
||||||
from embedchain.helpers.json_serializable import register_deserializable
|
from embedchain.helpers.json_serializable import register_deserializable
|
||||||
|
|
||||||
|
|
||||||
@@ -98,14 +98,14 @@ class VectorDBFactory:
|
|||||||
"zilliz": "embedchain.vectordb.zilliz.ZillizVectorDB",
|
"zilliz": "embedchain.vectordb.zilliz.ZillizVectorDB",
|
||||||
}
|
}
|
||||||
provider_to_config_class = {
|
provider_to_config_class = {
|
||||||
"chroma": "embedchain.config.vectordb.chroma.ChromaDbConfig",
|
"chroma": "embedchain.config.vector_db.chroma.ChromaDbConfig",
|
||||||
"elasticsearch": "embedchain.config.vectordb.elasticsearch.ElasticsearchDBConfig",
|
"elasticsearch": "embedchain.config.vector_db.elasticsearch.ElasticsearchDBConfig",
|
||||||
"opensearch": "embedchain.config.vectordb.opensearch.OpenSearchDBConfig",
|
"opensearch": "embedchain.config.vector_db.opensearch.OpenSearchDBConfig",
|
||||||
"lancedb": "embedchain.config.vectordb.lancedb.LanceDBConfig",
|
"lancedb": "embedchain.config.vector_db.lancedb.LanceDBConfig",
|
||||||
"pinecone": "embedchain.config.vectordb.pinecone.PineconeDBConfig",
|
"pinecone": "embedchain.config.vector_db.pinecone.PineconeDBConfig",
|
||||||
"qdrant": "embedchain.config.vectordb.qdrant.QdrantDBConfig",
|
"qdrant": "embedchain.config.vector_db.qdrant.QdrantDBConfig",
|
||||||
"weaviate": "embedchain.config.vectordb.weaviate.WeaviateDBConfig",
|
"weaviate": "embedchain.config.vector_db.weaviate.WeaviateDBConfig",
|
||||||
"zilliz": "embedchain.config.vectordb.zilliz.ZillizDBConfig",
|
"zilliz": "embedchain.config.vector_db.zilliz.ZillizDBConfig",
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from embedchain.config.vectordb.base import BaseVectorDbConfig
|
from embedchain.config.vector_db.base import BaseVectorDbConfig
|
||||||
from embedchain.embedder.base import BaseEmbedder
|
from embedchain.embedder.base import BaseEmbedder
|
||||||
from embedchain.helpers.json_serializable import JSONSerializable
|
from embedchain.helpers.json_serializable import JSONSerializable
|
||||||
|
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
raise ImportError('LanceDB is required. Install with pip install "embedchain[lancedb]"') from None
|
raise ImportError('LanceDB is required. Install with pip install "embedchain[lancedb]"') from None
|
||||||
|
|
||||||
from embedchain.config.vectordb.lancedb import LanceDBConfig
|
from embedchain.config.vector_db.lancedb import LanceDBConfig
|
||||||
from embedchain.helpers.json_serializable import register_deserializable
|
from embedchain.helpers.json_serializable import register_deserializable
|
||||||
from embedchain.vectordb.base import BaseVectorDB
|
from embedchain.vectordb.base import BaseVectorDB
|
||||||
|
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ except ImportError:
|
|||||||
|
|
||||||
from pinecone_text.sparse import BM25Encoder
|
from pinecone_text.sparse import BM25Encoder
|
||||||
|
|
||||||
from embedchain.config.vectordb.pinecone import PineconeDBConfig
|
from embedchain.config.vector_db.pinecone import PineconeDBConfig
|
||||||
from embedchain.helpers.json_serializable import register_deserializable
|
from embedchain.helpers.json_serializable import register_deserializable
|
||||||
from embedchain.utils.misc import chunks
|
from embedchain.utils.misc import chunks
|
||||||
from embedchain.vectordb.base import BaseVectorDB
|
from embedchain.vectordb.base import BaseVectorDB
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ except ImportError:
|
|||||||
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from embedchain.config.vectordb.qdrant import QdrantDBConfig
|
from embedchain.config.vector_db.qdrant import QdrantDBConfig
|
||||||
from embedchain.vectordb.base import BaseVectorDB
|
from embedchain.vectordb.base import BaseVectorDB
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ except ImportError:
|
|||||||
"Weaviate requires extra dependencies. Install with `pip install --upgrade 'embedchain[weaviate]'`"
|
"Weaviate requires extra dependencies. Install with `pip install --upgrade 'embedchain[weaviate]'`"
|
||||||
) from None
|
) from None
|
||||||
|
|
||||||
from embedchain.config.vectordb.weaviate import WeaviateDBConfig
|
from embedchain.config.vector_db.weaviate import WeaviateDBConfig
|
||||||
from embedchain.helpers.json_serializable import register_deserializable
|
from embedchain.helpers.json_serializable import register_deserializable
|
||||||
from embedchain.vectordb.base import BaseVectorDB
|
from embedchain.vectordb.base import BaseVectorDB
|
||||||
|
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import pytest
|
|||||||
|
|
||||||
from embedchain import App
|
from embedchain import App
|
||||||
from embedchain.config import AppConfig
|
from embedchain.config import AppConfig
|
||||||
from embedchain.config.vectordb.lancedb import LanceDBConfig
|
from embedchain.config.vector_db.lancedb import LanceDBConfig
|
||||||
from embedchain.vectordb.lancedb import LanceDB
|
from embedchain.vectordb.lancedb import LanceDB
|
||||||
|
|
||||||
os.environ["OPENAI_API_KEY"] = "test-api-key"
|
os.environ["OPENAI_API_KEY"] = "test-api-key"
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from embedchain.config.vectordb.pinecone import PineconeDBConfig
|
from embedchain.config.vector_db.pinecone import PineconeDBConfig
|
||||||
from embedchain.vectordb.pinecone import PineconeDB
|
from embedchain.vectordb.pinecone import PineconeDB
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ from qdrant_client.http.models import Batch
|
|||||||
|
|
||||||
from embedchain import App
|
from embedchain import App
|
||||||
from embedchain.config import AppConfig
|
from embedchain.config import AppConfig
|
||||||
from embedchain.config.vectordb.pinecone import PineconeDBConfig
|
from embedchain.config.vector_db.pinecone import PineconeDBConfig
|
||||||
from embedchain.embedder.base import BaseEmbedder
|
from embedchain.embedder.base import BaseEmbedder
|
||||||
from embedchain.vectordb.qdrant import QdrantDB
|
from embedchain.vectordb.qdrant import QdrantDB
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from unittest.mock import patch
|
|||||||
|
|
||||||
from embedchain import App
|
from embedchain import App
|
||||||
from embedchain.config import AppConfig
|
from embedchain.config import AppConfig
|
||||||
from embedchain.config.vectordb.pinecone import PineconeDBConfig
|
from embedchain.config.vector_db.pinecone import PineconeDBConfig
|
||||||
from embedchain.embedder.base import BaseEmbedder
|
from embedchain.embedder.base import BaseEmbedder
|
||||||
from embedchain.vectordb.weaviate import WeaviateDB
|
from embedchain.vectordb.weaviate import WeaviateDB
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user