From eecdbc5e065d3f49e1153a231747ce58c246bddb Mon Sep 17 00:00:00 2001 From: wangJm Date: Mon, 4 Sep 2023 14:31:08 +0800 Subject: [PATCH] Upgrade the chromadb version to 0.4.8 and open its settings configuration. (#517) --- embedchain/config/apps/BaseAppConfig.py | 7 ++- embedchain/config/apps/CustomAppConfig.py | 3 ++ embedchain/vectordb/chroma_db.py | 22 ++++++--- pyproject.toml | 2 +- tests/embedchain/test_embedchain.py | 9 +++- tests/vectordb/test_chroma_db.py | 58 +++++++++++++++++++---- 6 files changed, 80 insertions(+), 21 deletions(-) diff --git a/embedchain/config/apps/BaseAppConfig.py b/embedchain/config/apps/BaseAppConfig.py index 0fed691f..890045e1 100644 --- a/embedchain/config/apps/BaseAppConfig.py +++ b/embedchain/config/apps/BaseAppConfig.py @@ -24,6 +24,7 @@ class BaseAppConfig(BaseConfig, JSONSerializable): db_type: VectorDatabases = None, vector_dim: VectorDimensions = None, es_config: ElasticsearchDBConfig = None, + chroma_settings: dict = {}, ): """ :param log_level: Optional. (String) Debug level @@ -38,6 +39,7 @@ class BaseAppConfig(BaseConfig, JSONSerializable): :param db_type: Optional. type of Vector database to use :param vector_dim: Vector dimension generated by embedding fn :param es_config: Optional. elasticsearch database config to be used for connection + :param chroma_settings: Optional. Chroma settings for connection. """ self._setup_logging(log_level) self.collection_name = collection_name if collection_name else "embedchain_store" @@ -50,13 +52,14 @@ class BaseAppConfig(BaseConfig, JSONSerializable): vector_dim=vector_dim, collection_name=self.collection_name, es_config=es_config, + chroma_settings=chroma_settings, ) self.id = id self.collect_metrics = True if (collect_metrics is True or collect_metrics is None) else False return @staticmethod - def get_db(db, embedding_fn, host, port, db_type, vector_dim, collection_name, es_config): + def get_db(db, embedding_fn, host, port, db_type, vector_dim, collection_name, es_config, chroma_settings): """ Get db based on db_type, db with default database (`ChromaDb`) :param Optional. (Vector) database to use for embeddings. @@ -85,7 +88,7 @@ class BaseAppConfig(BaseConfig, JSONSerializable): from embedchain.vectordb.chroma_db import ChromaDB - return ChromaDB(embedding_fn=embedding_fn, host=host, port=port) + return ChromaDB(embedding_fn=embedding_fn, host=host, port=port, chroma_settings=chroma_settings) def _setup_logging(self, debug_level): level = logging.WARNING # Default level diff --git a/embedchain/config/apps/CustomAppConfig.py b/embedchain/config/apps/CustomAppConfig.py index 332ef0e8..d6dfe486 100644 --- a/embedchain/config/apps/CustomAppConfig.py +++ b/embedchain/config/apps/CustomAppConfig.py @@ -35,6 +35,7 @@ class CustomAppConfig(BaseAppConfig): collect_metrics: Optional[bool] = None, db_type: VectorDatabases = None, es_config: ElasticsearchDBConfig = None, + chroma_settings: dict = {}, ): """ :param log_level: Optional. (String) Debug level @@ -51,6 +52,7 @@ class CustomAppConfig(BaseAppConfig): :param collect_metrics: Defaults to True. Send anonymous telemetry to improve embedchain. :param db_type: Optional. type of Vector database to use. :param es_config: Optional. elasticsearch database config to be used for connection + :param chroma_settings: Optional. Chroma settings for connection. """ if provider: self.provider = provider @@ -73,6 +75,7 @@ class CustomAppConfig(BaseAppConfig): db_type=db_type, vector_dim=CustomAppConfig.get_vector_dimension(embedding_function=embedding_fn), es_config=es_config, + chroma_settings=chroma_settings, ) @staticmethod diff --git a/embedchain/vectordb/chroma_db.py b/embedchain/vectordb/chroma_db.py index 1f8097f6..c019c1bb 100644 --- a/embedchain/vectordb/chroma_db.py +++ b/embedchain/vectordb/chroma_db.py @@ -22,23 +22,31 @@ from embedchain.vectordb.base_vector_db import BaseVectorDB class ChromaDB(BaseVectorDB): """Vector database using ChromaDB.""" - def __init__(self, db_dir=None, embedding_fn=None, host=None, port=None): + def __init__(self, db_dir=None, embedding_fn=None, host=None, port=None, chroma_settings={}): self.embedding_fn = embedding_fn if not hasattr(embedding_fn, "__call__"): raise ValueError("Embedding function is not a function") + self.settings = Settings() + for key, value in chroma_settings.items(): + if hasattr(self.settings, key): + setattr(self.settings, key, value) + if host and port: logging.info(f"Connecting to ChromaDB server: {host}:{port}") - self.client = chromadb.HttpClient(host=host, port=port) + self.settings.chroma_server_host = host + self.settings.chroma_server_http_port = port + self.settings.chroma_api_impl = "chromadb.api.fastapi.FastAPI" + else: if db_dir is None: db_dir = "db" - self.settings = Settings(anonymized_telemetry=False, allow_reset=True) - self.client = chromadb.PersistentClient( - path=db_dir, - settings=self.settings, - ) + + self.settings.persist_directory = db_dir + self.settings.is_persistent = True + + self.client = chromadb.Client(self.settings) super().__init__() def _get_or_create_db(self): diff --git a/pyproject.toml b/pyproject.toml index 8a310742..8e228639 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -86,7 +86,7 @@ langchain = "^0.0.279" requests = "^2.31.0" openai = "^0.27.5" tiktoken = "^0.4.0" -chromadb ="^0.4.2" +chromadb ="^0.4.8" youtube-transcript-api = "^0.6.1" beautifulsoup4 = "^4.12.2" pypdf = "^3.11.0" diff --git a/tests/embedchain/test_embedchain.py b/tests/embedchain/test_embedchain.py index 7ac8d9b0..df2f1f81 100644 --- a/tests/embedchain/test_embedchain.py +++ b/tests/embedchain/test_embedchain.py @@ -3,7 +3,8 @@ import unittest from unittest.mock import patch from embedchain import App -from embedchain.config import AppConfig +from embedchain.config import AppConfig, CustomAppConfig +from embedchain.models import EmbeddingFunctions, Providers class TestChromaDbHostsLoglevel(unittest.TestCase): @@ -42,7 +43,11 @@ class TestChromaDbHostsLoglevel(unittest.TestCase): """ Test if the `App` instance is correctly reconstructed after a reset. """ - app = App() + app = App( + CustomAppConfig( + provider=Providers.OPENAI, embedding_fn=EmbeddingFunctions.OPENAI, chroma_settings={"allow_reset": True} + ) + ) app.reset() # Make sure the client is still healthy diff --git a/tests/vectordb/test_chroma_db.py b/tests/vectordb/test_chroma_db.py index 77431576..d1e95d4c 100644 --- a/tests/vectordb/test_chroma_db.py +++ b/tests/vectordb/test_chroma_db.py @@ -4,7 +4,8 @@ import unittest from unittest.mock import patch from embedchain import App -from embedchain.config import AppConfig +from embedchain.config import AppConfig, CustomAppConfig +from embedchain.models import EmbeddingFunctions, Providers from embedchain.vectordb.chroma_db import ChromaDB @@ -21,6 +22,24 @@ class TestChromaDbHosts(unittest.TestCase): self.assertEqual(settings.chroma_server_host, host) self.assertEqual(settings.chroma_server_http_port, port) + def test_init_with_basic_auth(self): + host = "test-host" + port = "1234" + + chroma_auth_settings = { + "chroma_client_auth_provider": "chromadb.auth.basic.BasicAuthClientProvider", + "chroma_client_auth_credentials": "admin:admin", + } + + db = ChromaDB(host=host, port=port, embedding_fn=len, chroma_settings=chroma_auth_settings) + settings = db.client.get_settings() + self.assertEqual(settings.chroma_server_host, host) + self.assertEqual(settings.chroma_server_http_port, port) + self.assertEqual(settings.chroma_client_auth_provider, chroma_auth_settings["chroma_client_auth_provider"]) + self.assertEqual( + settings.chroma_client_auth_credentials, chroma_auth_settings["chroma_client_auth_credentials"] + ) + # Review this test class TestChromaDbHostsInit(unittest.TestCase): @@ -68,12 +87,18 @@ class TestChromaDbHostsLoglevel(unittest.TestCase): class TestChromaDbDuplicateHandling: + app_with_settings = App( + CustomAppConfig( + provider=Providers.OPENAI, embedding_fn=EmbeddingFunctions.OPENAI, chroma_settings={"allow_reset": True} + ) + ) + def test_duplicates_throw_warning(self, caplog): """ Test that add duplicates throws an error. """ # Start with a clean app - App().reset() + self.app_with_settings.reset() app = App(config=AppConfig(collect_metrics=False)) app.collection.add(embeddings=[[0, 0, 0]], ids=["0"]) @@ -88,7 +113,7 @@ class TestChromaDbDuplicateHandling: # NOTE: Not part of the TestChromaDbCollection because `unittest.TestCase` doesn't have caplog. # Start with a clean app - App().reset() + self.app_with_settings.reset() app = App(config=AppConfig(collect_metrics=False)) app.set_collection("test_collection_1") @@ -100,6 +125,12 @@ class TestChromaDbDuplicateHandling: class TestChromaDbCollection(unittest.TestCase): + app_with_settings = App( + CustomAppConfig( + provider=Providers.OPENAI, embedding_fn=EmbeddingFunctions.OPENAI, chroma_settings={"allow_reset": True} + ) + ) + def test_init_with_default_collection(self): """ Test if the `App` instance is initialized with the correct default collection name. @@ -131,7 +162,7 @@ class TestChromaDbCollection(unittest.TestCase): Test that changes to one collection do not affect the other collection """ # Start with a clean app - App().reset() + self.app_with_settings.reset() app = App(config=AppConfig(collect_metrics=False)) app.set_collection("test_collection_1") @@ -157,7 +188,7 @@ class TestChromaDbCollection(unittest.TestCase): Test that a collection can be picked up later. """ # Start with a clean app - App().reset() + self.app_with_settings.reset() app = App(config=AppConfig(collect_metrics=False)) app.set_collection("test_collection_1") @@ -175,7 +206,7 @@ class TestChromaDbCollection(unittest.TestCase): the other app. """ # Start clean - App().reset() + self.app_with_settings.reset() # Create two apps app1 = App(AppConfig(collection_name="test_collection_1", collect_metrics=False)) @@ -201,7 +232,7 @@ class TestChromaDbCollection(unittest.TestCase): Different ids should still share collections. """ # Start clean - App().reset() + self.app_with_settings.reset() # Create two apps app1 = App(AppConfig(collection_name="one_collection", id="new_app_id_1", collect_metrics=False)) @@ -220,11 +251,20 @@ class TestChromaDbCollection(unittest.TestCase): Resetting should hit all collections and ids. """ # Start clean - App().reset() + self.app_with_settings.reset() # Create four apps. # app1, which we are about to reset, shares an app with one, and an id with the other, none with the last. - app1 = App(AppConfig(collection_name="one_collection", id="new_app_id_1", collect_metrics=False)) + app1 = App( + CustomAppConfig( + collection_name="one_collection", + id="new_app_id_1", + collect_metrics=False, + provider=Providers.OPENAI, + embedding_fn=EmbeddingFunctions.OPENAI, + chroma_settings={"allow_reset": True}, + ) + ) app2 = App(AppConfig(collection_name="one_collection", id="new_app_id_2", collect_metrics=False)) app3 = App(AppConfig(collection_name="three_collection", id="new_app_id_1", collect_metrics=False)) app4 = App(AppConfig(collection_name="four_collection", id="new_app_id_4", collect_metrics=False))