Upgrade the chromadb version to 0.4.8 and open its settings configuration. (#517)
This commit is contained in:
@@ -24,6 +24,7 @@ class BaseAppConfig(BaseConfig, JSONSerializable):
|
||||
db_type: VectorDatabases = None,
|
||||
vector_dim: VectorDimensions = None,
|
||||
es_config: ElasticsearchDBConfig = None,
|
||||
chroma_settings: dict = {},
|
||||
):
|
||||
"""
|
||||
:param log_level: Optional. (String) Debug level
|
||||
@@ -38,6 +39,7 @@ class BaseAppConfig(BaseConfig, JSONSerializable):
|
||||
:param db_type: Optional. type of Vector database to use
|
||||
:param vector_dim: Vector dimension generated by embedding fn
|
||||
:param es_config: Optional. elasticsearch database config to be used for connection
|
||||
:param chroma_settings: Optional. Chroma settings for connection.
|
||||
"""
|
||||
self._setup_logging(log_level)
|
||||
self.collection_name = collection_name if collection_name else "embedchain_store"
|
||||
@@ -50,13 +52,14 @@ class BaseAppConfig(BaseConfig, JSONSerializable):
|
||||
vector_dim=vector_dim,
|
||||
collection_name=self.collection_name,
|
||||
es_config=es_config,
|
||||
chroma_settings=chroma_settings,
|
||||
)
|
||||
self.id = id
|
||||
self.collect_metrics = True if (collect_metrics is True or collect_metrics is None) else False
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
def get_db(db, embedding_fn, host, port, db_type, vector_dim, collection_name, es_config):
|
||||
def get_db(db, embedding_fn, host, port, db_type, vector_dim, collection_name, es_config, chroma_settings):
|
||||
"""
|
||||
Get db based on db_type, db with default database (`ChromaDb`)
|
||||
:param Optional. (Vector) database to use for embeddings.
|
||||
@@ -85,7 +88,7 @@ class BaseAppConfig(BaseConfig, JSONSerializable):
|
||||
|
||||
from embedchain.vectordb.chroma_db import ChromaDB
|
||||
|
||||
return ChromaDB(embedding_fn=embedding_fn, host=host, port=port)
|
||||
return ChromaDB(embedding_fn=embedding_fn, host=host, port=port, chroma_settings=chroma_settings)
|
||||
|
||||
def _setup_logging(self, debug_level):
|
||||
level = logging.WARNING # Default level
|
||||
|
||||
@@ -35,6 +35,7 @@ class CustomAppConfig(BaseAppConfig):
|
||||
collect_metrics: Optional[bool] = None,
|
||||
db_type: VectorDatabases = None,
|
||||
es_config: ElasticsearchDBConfig = None,
|
||||
chroma_settings: dict = {},
|
||||
):
|
||||
"""
|
||||
:param log_level: Optional. (String) Debug level
|
||||
@@ -51,6 +52,7 @@ class CustomAppConfig(BaseAppConfig):
|
||||
:param collect_metrics: Defaults to True. Send anonymous telemetry to improve embedchain.
|
||||
:param db_type: Optional. type of Vector database to use.
|
||||
:param es_config: Optional. elasticsearch database config to be used for connection
|
||||
:param chroma_settings: Optional. Chroma settings for connection.
|
||||
"""
|
||||
if provider:
|
||||
self.provider = provider
|
||||
@@ -73,6 +75,7 @@ class CustomAppConfig(BaseAppConfig):
|
||||
db_type=db_type,
|
||||
vector_dim=CustomAppConfig.get_vector_dimension(embedding_function=embedding_fn),
|
||||
es_config=es_config,
|
||||
chroma_settings=chroma_settings,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -22,23 +22,31 @@ from embedchain.vectordb.base_vector_db import BaseVectorDB
|
||||
class ChromaDB(BaseVectorDB):
|
||||
"""Vector database using ChromaDB."""
|
||||
|
||||
def __init__(self, db_dir=None, embedding_fn=None, host=None, port=None):
|
||||
def __init__(self, db_dir=None, embedding_fn=None, host=None, port=None, chroma_settings={}):
|
||||
self.embedding_fn = embedding_fn
|
||||
|
||||
if not hasattr(embedding_fn, "__call__"):
|
||||
raise ValueError("Embedding function is not a function")
|
||||
|
||||
self.settings = Settings()
|
||||
for key, value in chroma_settings.items():
|
||||
if hasattr(self.settings, key):
|
||||
setattr(self.settings, key, value)
|
||||
|
||||
if host and port:
|
||||
logging.info(f"Connecting to ChromaDB server: {host}:{port}")
|
||||
self.client = chromadb.HttpClient(host=host, port=port)
|
||||
self.settings.chroma_server_host = host
|
||||
self.settings.chroma_server_http_port = port
|
||||
self.settings.chroma_api_impl = "chromadb.api.fastapi.FastAPI"
|
||||
|
||||
else:
|
||||
if db_dir is None:
|
||||
db_dir = "db"
|
||||
self.settings = Settings(anonymized_telemetry=False, allow_reset=True)
|
||||
self.client = chromadb.PersistentClient(
|
||||
path=db_dir,
|
||||
settings=self.settings,
|
||||
)
|
||||
|
||||
self.settings.persist_directory = db_dir
|
||||
self.settings.is_persistent = True
|
||||
|
||||
self.client = chromadb.Client(self.settings)
|
||||
super().__init__()
|
||||
|
||||
def _get_or_create_db(self):
|
||||
|
||||
Reference in New Issue
Block a user