diff --git a/docs/advanced/configuration.mdx b/docs/advanced/configuration.mdx index 613679e0..8cafe6a0 100644 --- a/docs/advanced/configuration.mdx +++ b/docs/advanced/configuration.mdx @@ -20,6 +20,10 @@ from chromadb.utils import embedding_functions config = AppConfig(log_level="DEBUG") naval_chat_bot = App(config) +# Example: specify a custom collection name +config = AppConfig(collection_name="naval_chat_bot") +naval_chat_bot = App(config) + # Example: define your own chunker config for `youtube_video` chunker_config = ChunkerConfig(chunk_size=1000, chunk_overlap=100, length_function=len) naval_chat_bot.add("youtube_video", "https://www.youtube.com/watch?v=3qHkcs3kG44", AddConfig(chunker=chunker_config)) diff --git a/docs/advanced/query_configuration.mdx b/docs/advanced/query_configuration.mdx index f743bd9a..76a9c264 100644 --- a/docs/advanced/query_configuration.mdx +++ b/docs/advanced/query_configuration.mdx @@ -4,11 +4,12 @@ title: '🔍 Query configurations' ## AppConfig -| option | description | type | default | -|-------------|-----------------------|---------------------------------|------------------------| -| log_level | log level | string | WARNING | +| option | description | type | default | +|-----------|-----------------------|---------------------------------|------------------------| +| log_level | log level | string | WARNING | | embedding_fn| embedding function | chromadb.utils.embedding_functions | \{text-embedding-ada-002\} | -| db | vector database (experimental) | BaseVectorDB | ChromaDB | +| db | vector database (experimental) | BaseVectorDB | ChromaDB | +| collection_name | initial collection name for the database | string | embedchain_store | ## AddConfig diff --git a/embedchain/config/apps/AppConfig.py b/embedchain/config/apps/AppConfig.py index 548839ab..1e08040e 100644 --- a/embedchain/config/apps/AppConfig.py +++ b/embedchain/config/apps/AppConfig.py @@ -16,16 +16,22 @@ class AppConfig(BaseAppConfig): Config to initialize an embedchain custom `App` instance, with extra config options. """ - def __init__(self, log_level=None, host=None, port=None, id=None): + def __init__(self, log_level=None, host=None, port=None, id=None, collection_name=None): """ :param log_level: Optional. (String) Debug level ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL']. :param host: Optional. Hostname for the database server. :param port: Optional. Port for the database server. :param id: Optional. ID of the app. Document metadata will have this id. + :param collection_name: Optional. Collection name for the database. """ super().__init__( - log_level=log_level, embedding_fn=AppConfig.default_embedding_function(), host=host, port=port, id=id + log_level=log_level, + embedding_fn=AppConfig.default_embedding_function(), + host=host, + port=port, + id=id, + collection_name=collection_name, ) @staticmethod diff --git a/embedchain/config/apps/BaseAppConfig.py b/embedchain/config/apps/BaseAppConfig.py index 0a94cf98..c1f0daa6 100644 --- a/embedchain/config/apps/BaseAppConfig.py +++ b/embedchain/config/apps/BaseAppConfig.py @@ -8,19 +8,21 @@ class BaseAppConfig(BaseConfig): Parent config to initialize an instance of `App`, `OpenSourceApp` or `CustomApp`. """ - def __init__(self, log_level=None, embedding_fn=None, db=None, host=None, port=None, id=None): + def __init__(self, log_level=None, embedding_fn=None, db=None, host=None, port=None, id=None, collection_name=None): """ :param log_level: Optional. (String) Debug level ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL']. :param embedding_fn: Embedding function to use. :param db: Optional. (Vector) database instance to use for embeddings. - :param id: Optional. ID of the app. Document metadata will have this id. :param host: Optional. Hostname for the database server. :param port: Optional. Port for the database server. + :param id: Optional. ID of the app. Document metadata will have this id. + :param collection_name: Optional. Collection name for the database. """ self._setup_logging(log_level) self.db = db if db else BaseAppConfig.default_db(embedding_fn=embedding_fn, host=host, port=port) + self.collection_name = collection_name if collection_name else "embedchain_store" self.id = id return diff --git a/embedchain/config/apps/CustomAppConfig.py b/embedchain/config/apps/CustomAppConfig.py index 22f8818e..edacf4b5 100644 --- a/embedchain/config/apps/CustomAppConfig.py +++ b/embedchain/config/apps/CustomAppConfig.py @@ -24,8 +24,8 @@ class CustomAppConfig(BaseAppConfig): host=None, port=None, id=None, + collection_name=None, provider: Providers = None, - model=None, open_source_app_config=None, deployment_name=None, ): @@ -35,9 +35,10 @@ class CustomAppConfig(BaseAppConfig): :param embedding_fn: Optional. Embedding function to use. :param embedding_fn_model: Optional. Model name to use for embedding function. :param db: Optional. (Vector) database to use for embeddings. - :param id: Optional. ID of the app. Document metadata will have this id. :param host: Optional. Hostname for the database server. :param port: Optional. Port for the database server. + :param id: Optional. ID of the app. Document metadata will have this id. + :param collection_name: Optional. Collection name for the database. :param provider: Optional. (Providers): LLM Provider to use. :param open_source_app_config: Optional. Config instance needed for open source apps. """ @@ -58,6 +59,7 @@ class CustomAppConfig(BaseAppConfig): host=host, port=port, id=id, + collection_name=collection_name, ) @staticmethod diff --git a/embedchain/config/apps/OpenSourceAppConfig.py b/embedchain/config/apps/OpenSourceAppConfig.py index 82b4cdcc..8666f125 100644 --- a/embedchain/config/apps/OpenSourceAppConfig.py +++ b/embedchain/config/apps/OpenSourceAppConfig.py @@ -8,11 +8,12 @@ class OpenSourceAppConfig(BaseAppConfig): Config to initialize an embedchain custom `OpenSourceApp` instance, with extra config options. """ - def __init__(self, log_level=None, host=None, port=None, id=None, model=None): + def __init__(self, log_level=None, host=None, port=None, id=None, collection_name=None, model=None): """ :param log_level: Optional. (String) Debug level ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL']. :param id: Optional. ID of the app. Document metadata will have this id. + :param collection_name: Optional. Collection name for the database. :param host: Optional. Hostname for the database server. :param port: Optional. Port for the database server. :param model: Optional. GPT4ALL uses the model to instantiate the class. @@ -26,6 +27,7 @@ class OpenSourceAppConfig(BaseAppConfig): host=host, port=port, id=id, + collection_name=collection_name, ) @staticmethod diff --git a/embedchain/embedchain.py b/embedchain/embedchain.py index a552aa89..8033eff0 100644 --- a/embedchain/embedchain.py +++ b/embedchain/embedchain.py @@ -32,7 +32,7 @@ class EmbedChain: self.config = config self.db_client = self.config.db.client - self.collection = self.config.db.collection + self.collection = self.config.db._get_or_create_collection(self.config.collection_name) self.user_asks = [] self.is_docs_site_instance = False self.online = False @@ -325,6 +325,14 @@ class EmbedChain: memory.chat_memory.add_ai_message(streamed_answer) logging.info(f"Answer: {streamed_answer}") + def set_collection(self, collection_name): + """ + Set the collection to use. + + :param collection_name: The name of the collection to use. + """ + self.collection = self.config.db._get_or_create_collection(collection_name) + def count(self): """ Count the number of embeddings. diff --git a/embedchain/vectordb/base_vector_db.py b/embedchain/vectordb/base_vector_db.py index 79089063..f38e3d31 100644 --- a/embedchain/vectordb/base_vector_db.py +++ b/embedchain/vectordb/base_vector_db.py @@ -3,7 +3,6 @@ class BaseVectorDB: def __init__(self): self.client = self._get_or_create_db() - self.collection = self._get_or_create_collection() def _get_or_create_db(self): """Get or create the database.""" diff --git a/embedchain/vectordb/chroma_db.py b/embedchain/vectordb/chroma_db.py index 61ade9af..168c6221 100644 --- a/embedchain/vectordb/chroma_db.py +++ b/embedchain/vectordb/chroma_db.py @@ -39,9 +39,9 @@ class ChromaDB(BaseVectorDB): """Get or create the database.""" return self.client - def _get_or_create_collection(self): + def _get_or_create_collection(self, name): """Get or create the collection.""" return self.client.get_or_create_collection( - "embedchain_store", + name=name, embedding_function=self.embedding_fn, ) diff --git a/tests/vectordb/test_chroma_db.py b/tests/vectordb/test_chroma_db.py index 510aff1b..9ddff085 100644 --- a/tests/vectordb/test_chroma_db.py +++ b/tests/vectordb/test_chroma_db.py @@ -72,3 +72,186 @@ class TestChromaDbHostsLoglevel(unittest.TestCase): self.assertEqual(mock_client.call_args[0][0].chroma_server_host, None) self.assertEqual(mock_client.call_args[0][0].chroma_server_http_port, None) + +class TestChromaDbDuplicateHandling: + def test_duplicates_throw_warning(self, caplog): + """ + Test that add duplicates throws an error. + """ + # Start with a clean app + App().reset() + + app = App() + app.collection.add(embeddings=[[0, 0, 0]], ids=["0"]) + app.collection.add(embeddings=[[0, 0, 0]], ids=["0"]) + assert "Insert of existing embedding ID: 0" in caplog.text + assert "Add of existing embedding ID: 0" in caplog.text + + def test_duplicates_collections_no_warning(self, caplog): + """ + Test that different collections can have duplicates. + """ + # NOTE: Not part of the TestChromaDbCollection because `unittest.TestCase` doesn't have caplog. + + # Start with a clean app + App().reset() + + app = App() + app.set_collection("test_collection_1") + app.collection.add(embeddings=[[0, 0, 0]], ids=["0"]) + app.set_collection("test_collection_2") + app.collection.add(embeddings=[[0, 0, 0]], ids=["0"]) + assert "Insert of existing embedding ID: 0" not in caplog.text # not + assert "Add of existing embedding ID: 0" not in caplog.text # not + + +class TestChromaDbCollection(unittest.TestCase): + def test_init_with_default_collection(self): + """ + Test if the `App` instance is initialized with the correct default collection name. + """ + app = App() + + self.assertEqual(app.collection.name, "embedchain_store") + + def test_init_with_custom_collection(self): + """ + Test if the `App` instance is initialized with the correct custom collection name. + """ + config = AppConfig(collection_name="test_collection") + app = App(config) + + self.assertEqual(app.collection.name, "test_collection") + + def test_set_collection(self): + """ + Test if the `App` collection is correctly switched using the `set_collection` method. + """ + app = App() + app.set_collection("test_collection") + + self.assertEqual(app.collection.name, "test_collection") + + def test_changes_encapsulated(self): + """ + Test that changes to one collection do not affect the other collection + """ + # Start with a clean app + App().reset() + + app = App() + app.set_collection("test_collection_1") + # Collection should be empty when created + self.assertEqual(app.count(), 0) + + app.collection.add(embeddings=[0, 0, 0], ids=["0"]) + # After adding, should contain one item + self.assertEqual(app.count(), 1) + + app.set_collection("test_collection_2") + # New collection is empty + self.assertEqual(app.count(), 0) + + # Adding to new collection should not effect existing collection + app.collection.add(embeddings=[0, 0, 0], ids=["0"]) + app.set_collection("test_collection_1") + # Should still be 1, not 2. + self.assertEqual(app.count(), 1) + + def test_collections_are_persistent(self): + """ + Test that a collection can be picked up later. + """ + # Start with a clean app + App().reset() + + app = App() + app.set_collection("test_collection_1") + app.collection.add(embeddings=[[0, 0, 0]], ids=["0"]) + del app + + app = App() + app.set_collection("test_collection_1") + self.assertEqual(app.count(), 1) + + def test_parallel_collections(self): + """ + Test that two apps can have different collections open in parallel. + Switching the names will allow instant access to the collection of + the other app. + """ + # Start clean + App().reset() + + # Create two apps + app1 = App(AppConfig(collection_name="test_collection_1")) + app2 = App(AppConfig(collection_name="test_collection_2")) + + # app2 has been created last, but adding to app1 will still write to collection 1. + app1.collection.add(embeddings=[0, 0, 0], ids=["0"]) + self.assertEqual(app1.count(), 1) + self.assertEqual(app2.count(), 0) + + # Add data + app1.collection.add(embeddings=[[0, 0, 0], [1, 1, 1]], ids=["1", "2"]) + app2.collection.add(embeddings=[0, 0, 0], ids=["0"]) + + # Swap names and test + app1.set_collection('test_collection_2') + self.assertEqual(app1.count(), 1) + app2.set_collection('test_collection_1') + self.assertEqual(app2.count(), 3) + + def test_ids_share_collections(self): + """ + Different ids should still share collections. + """ + # Start clean + App().reset() + + # Create two apps + app1 = App(AppConfig(collection_name="one_collection", id="new_app_id_1")) + app2 = App(AppConfig(collection_name="one_collection", id="new_app_id_2")) + + # Add data + app1.collection.add(embeddings=[[0, 0, 0], [1, 1, 1]], ids=["0", "1"]) + app2.collection.add(embeddings=[0, 0, 0], ids=["2"]) + + # Both should have the same collection + self.assertEqual(app1.count(), 3) + self.assertEqual(app2.count(), 3) + + def test_reset(self): + """ + Resetting should hit all collections and ids. + """ + # Start clean + App().reset() + + # Create four apps. + # app1, which we are about to reset, shares an app with one, and an id with the other, none with the last. + app1 = App(AppConfig(collection_name="one_collection", id="new_app_id_1")) + app2 = App(AppConfig(collection_name="one_collection", id="new_app_id_2")) + app3 = App(AppConfig(collection_name="three_collection", id="new_app_id_1")) + app4 = App(AppConfig(collection_name="four_collection", id="new_app_id_4")) + + # Each one of them get data + app1.collection.add(embeddings=[0, 0, 0], ids=["1"]) + app2.collection.add(embeddings=[0, 0, 0], ids=["2"]) + app3.collection.add(embeddings=[0, 0, 0], ids=["3"]) + app4.collection.add(embeddings=[0, 0, 0], ids=["4"]) + + # Resetting the first one should reset them all. + app1.reset() + + # Reinstantiate them + app1 = App(AppConfig(collection_name="one_collection", id="new_app_id_1")) + app2 = App(AppConfig(collection_name="one_collection", id="new_app_id_2")) + app3 = App(AppConfig(collection_name="three_collection", id="new_app_id_3")) + app4 = App(AppConfig(collection_name="four_collection", id="new_app_id_3")) + + # All should be empty + self.assertEqual(app1.count(), 0) + self.assertEqual(app2.count(), 0) + self.assertEqual(app3.count(), 0) + self.assertEqual(app4.count(), 0)