feat: add support for Elasticsearch as vector data source (#402)
This commit is contained in:
committed by
GitHub
parent
f0abfea55d
commit
0179141b2e
@@ -5,3 +5,4 @@ from .apps.OpenSourceAppConfig import OpenSourceAppConfig # noqa: F401
|
||||
from .BaseConfig import BaseConfig # noqa: F401
|
||||
from .ChatConfig import ChatConfig # noqa: F401
|
||||
from .QueryConfig import QueryConfig # noqa: F401
|
||||
from .vectordbs.ElasticsearchDBConfig import ElasticsearchDBConfig # noqa: F401
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import logging
|
||||
|
||||
from embedchain.config.BaseConfig import BaseConfig
|
||||
from embedchain.config.vectordbs import ElasticsearchDBConfig
|
||||
from embedchain.models import VectorDatabases, VectorDimensions
|
||||
|
||||
|
||||
class BaseAppConfig(BaseConfig):
|
||||
@@ -8,7 +10,19 @@ class BaseAppConfig(BaseConfig):
|
||||
Parent config to initialize an instance of `App`, `OpenSourceApp` or `CustomApp`.
|
||||
"""
|
||||
|
||||
def __init__(self, log_level=None, embedding_fn=None, db=None, host=None, port=None, id=None, collection_name=None):
|
||||
def __init__(
|
||||
self,
|
||||
log_level=None,
|
||||
embedding_fn=None,
|
||||
db=None,
|
||||
host=None,
|
||||
port=None,
|
||||
id=None,
|
||||
collection_name=None,
|
||||
db_type: VectorDatabases = None,
|
||||
vector_dim: VectorDimensions = None,
|
||||
es_config: ElasticsearchDBConfig = None,
|
||||
):
|
||||
"""
|
||||
:param log_level: Optional. (String) Debug level
|
||||
['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'].
|
||||
@@ -18,27 +32,53 @@ class BaseAppConfig(BaseConfig):
|
||||
:param port: Optional. Port for the database server.
|
||||
:param id: Optional. ID of the app. Document metadata will have this id.
|
||||
:param collection_name: Optional. Collection name for the database.
|
||||
:param db_type: Optional. type of Vector database to use
|
||||
:param vector_dim: Vector dimension generated by embedding fn
|
||||
:param es_config: Optional. elasticsearch database config to be used for connection
|
||||
"""
|
||||
self._setup_logging(log_level)
|
||||
|
||||
self.db = db if db else BaseAppConfig.default_db(embedding_fn=embedding_fn, host=host, port=port)
|
||||
self.collection_name = collection_name if collection_name else "embedchain_store"
|
||||
self.db = BaseAppConfig.get_db(
|
||||
db=db,
|
||||
embedding_fn=embedding_fn,
|
||||
host=host,
|
||||
port=port,
|
||||
db_type=db_type,
|
||||
vector_dim=vector_dim,
|
||||
collection_name=self.collection_name,
|
||||
es_config=es_config,
|
||||
)
|
||||
self.id = id
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
def default_db(embedding_fn, host, port):
|
||||
def get_db(db, embedding_fn, host, port, db_type, vector_dim, collection_name, es_config):
|
||||
"""
|
||||
Sets database to default (`ChromaDb`).
|
||||
|
||||
Get db based on db_type, db with default database (`ChromaDb`)
|
||||
:param db: Optional. (Vector) database to use for embeddings.
|
||||
:param embedding_fn: Embedding function to use in database.
|
||||
:param host: Optional. Hostname for the database server.
|
||||
:param port: Optional. Port for the database server.
|
||||
:returns: Default database
|
||||
:param db_type: Optional. db type to use. Supported values (`es`, `chroma`)
|
||||
:param vector_dim: Vector dimension generated by embedding fn
|
||||
:param collection_name: Optional. Collection name for the database.
|
||||
:param es_config: Optional. elasticsearch database config to be used for connection
|
||||
:raises ValueError: BaseAppConfig knows no default embedding function.
|
||||
:returns: database instance
|
||||
"""
|
||||
if db:
|
||||
return db
|
||||
|
||||
if embedding_fn is None:
|
||||
raise ValueError("ChromaDb cannot be instantiated without an embedding function")
|
||||
|
||||
if db_type == VectorDatabases.ELASTICSEARCH:
|
||||
from embedchain.vectordb.elasticsearch_db import ElasticsearchDB
|
||||
|
||||
return ElasticsearchDB(
|
||||
embedding_fn=embedding_fn, vector_dim=vector_dim, collection_name=collection_name, es_config=es_config
|
||||
)
|
||||
|
||||
from embedchain.vectordb.chroma_db import ChromaDB
|
||||
|
||||
return ChromaDB(embedding_fn=embedding_fn, host=host, port=port)
|
||||
|
||||
@@ -3,7 +3,8 @@ from typing import Any
|
||||
from chromadb.api.types import Documents, Embeddings
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from embedchain.models import EmbeddingFunctions, Providers
|
||||
from embedchain.config.vectordbs import ElasticsearchDBConfig
|
||||
from embedchain.models import EmbeddingFunctions, Providers, VectorDatabases, VectorDimensions
|
||||
|
||||
from .BaseAppConfig import BaseAppConfig
|
||||
|
||||
@@ -28,6 +29,8 @@ class CustomAppConfig(BaseAppConfig):
|
||||
provider: Providers = None,
|
||||
open_source_app_config=None,
|
||||
deployment_name=None,
|
||||
db_type: VectorDatabases = None,
|
||||
es_config: ElasticsearchDBConfig = None,
|
||||
):
|
||||
"""
|
||||
:param log_level: Optional. (String) Debug level
|
||||
@@ -41,6 +44,8 @@ class CustomAppConfig(BaseAppConfig):
|
||||
:param collection_name: Optional. Collection name for the database.
|
||||
:param provider: Optional. (Providers): LLM Provider to use.
|
||||
:param open_source_app_config: Optional. Config instance needed for open source apps.
|
||||
:param db_type: Optional. type of Vector database to use.
|
||||
:param es_config: Optional. elasticsearch database config to be used for connection
|
||||
"""
|
||||
if provider:
|
||||
self.provider = provider
|
||||
@@ -59,6 +64,9 @@ class CustomAppConfig(BaseAppConfig):
|
||||
port=port,
|
||||
id=id,
|
||||
collection_name=collection_name,
|
||||
db_type=db_type,
|
||||
vector_dim=CustomAppConfig.get_vector_dimension(embedding_function=embedding_fn),
|
||||
es_config=es_config,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -108,3 +116,20 @@ class CustomAppConfig(BaseAppConfig):
|
||||
from chromadb.utils import embedding_functions
|
||||
|
||||
return embedding_functions.SentenceTransformerEmbeddingFunction(model_name=model)
|
||||
|
||||
@staticmethod
def get_vector_dimension(embedding_function: EmbeddingFunctions):
    """
    Look up the vector dimension associated with an embedding function.

    :param embedding_function: (EmbeddingFunctions) member to look up.
    :raises ValueError: If `embedding_function` is not an `EmbeddingFunctions` member.
    :returns: The `.value` of the matching `VectorDimensions` member, or None
    when the member is not one of the four handled here.
    """
    if not isinstance(embedding_function, EmbeddingFunctions):
        raise ValueError(f"Invalid option: '{embedding_function}'.")

    # Store the enum members, not their values, so `.value` is only read for
    # the entry that actually matches.
    dimension_by_function = {
        EmbeddingFunctions.OPENAI: VectorDimensions.OPENAI,
        EmbeddingFunctions.HUGGING_FACE: VectorDimensions.HUGGING_FACE,
        EmbeddingFunctions.VERTEX_AI: VectorDimensions.VERTEX_AI,
        EmbeddingFunctions.GPT4ALL: VectorDimensions.GPT4ALL,
    }

    matched = dimension_by_function.get(embedding_function)
    return None if matched is None else matched.value
|
||||
|
||||
15
embedchain/config/vectordbs/ElasticsearchDBConfig.py
Normal file
15
embedchain/config/vectordbs/ElasticsearchDBConfig.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
from embedchain.config.BaseConfig import BaseConfig
|
||||
|
||||
|
||||
class ElasticsearchDBConfig(BaseConfig):
    """
    Config to initialize an elasticsearch client.

    :param es_url: Optional. elasticsearch url or list of nodes url to be used for connection
    :param ES_EXTRA_PARAMS: Optional. extra keyword arguments; they are collected
    into a dict and stored unchanged on the config.
    """

    def __init__(self, es_url: Union[str, List[str]] = None, **ES_EXTRA_PARAMS: Any):
        # A `**kwargs` annotation describes the type of each value, not of the
        # resulting dict, hence `Any` rather than `Dict[str, ...]`.
        self.ES_URL = es_url
        self.ES_EXTRA_PARAMS = ES_EXTRA_PARAMS
|
||||
0
embedchain/config/vectordbs/__init__.py
Normal file
0
embedchain/config/vectordbs/__init__.py
Normal file
Reference in New Issue
Block a user