From bbce18caacbb5630e6ccbf1ea7f8397b468868f3 Mon Sep 17 00:00:00 2001 From: Sidharth Mohanty Date: Wed, 25 Oct 2023 22:52:20 +0530 Subject: [PATCH] [Bug fix] `reset()` erases everything from db (#844) --- .gitignore | 1 + embedchain/vectordb/chroma.py | 4 +- tests/vectordb/test_chroma_db.py | 544 ++++++++++++++----------------- 3 files changed, 242 insertions(+), 307 deletions(-) diff --git a/.gitignore b/.gitignore index 04de5b05..f5126be0 100644 --- a/.gitignore +++ b/.gitignore @@ -165,6 +165,7 @@ cython_debug/ # Database db +test-db .vscode /poetry.lock diff --git a/embedchain/vectordb/chroma.py b/embedchain/vectordb/chroma.py index c77e83c1..b5e65c88 100644 --- a/embedchain/vectordb/chroma.py +++ b/embedchain/vectordb/chroma.py @@ -262,9 +262,9 @@ class ChromaDB(BaseVectorDB): """ Resets the database. Deletes all embeddings irreversibly. """ - # Delete all data from the database + # Delete all data from the collection try: - self.client.reset() + self.client.delete_collection(self.config.collection_name) except ValueError: raise ValueError( "For safety reasons, resetting is disabled. " diff --git a/tests/vectordb/test_chroma_db.py b/tests/vectordb/test_chroma_db.py index 8caf14ac..3477d717 100644 --- a/tests/vectordb/test_chroma_db.py +++ b/tests/vectordb/test_chroma_db.py @@ -1,366 +1,300 @@ -# ruff: noqa: E501 - -import unittest +import os +import shutil +import pytest from unittest.mock import patch from chromadb.config import Settings - from embedchain import App from embedchain.config import AppConfig, ChromaDbConfig from embedchain.vectordb.chroma import ChromaDB - -class TestChromaDbHosts(unittest.TestCase): - def test_init_with_host_and_port(self): - """ - Test if the `ChromaDB` instance is initialized with the correct host and port values. - """ - host = "test-host" - port = "1234" - config = ChromaDbConfig(host=host, port=port) - - db = ChromaDB(config=config) - settings = db.client.get_settings() - self.assertEqual(settings.chroma_server_host, host) - self.assertEqual(settings.chroma_server_http_port, port) - - def test_init_with_basic_auth(self): - host = "test-host" - port = "1234" - - chroma_config = { - "host": host, - "port": port, - "chroma_settings": { - "chroma_client_auth_provider": "chromadb.auth.basic.BasicAuthClientProvider", - "chroma_client_auth_credentials": "admin:admin", - }, - } - - config = ChromaDbConfig(**chroma_config) - db = ChromaDB(config=config) - settings = db.client.get_settings() - self.assertEqual(settings.chroma_server_host, host) - self.assertEqual(settings.chroma_server_http_port, port) - self.assertEqual( - settings.chroma_client_auth_provider, chroma_config["chroma_settings"]["chroma_client_auth_provider"] - ) - self.assertEqual( - settings.chroma_client_auth_credentials, chroma_config["chroma_settings"]["chroma_client_auth_credentials"] - ) +os.environ["OPENAI_API_KEY"] = "test-api-key" -# Review this test -class TestChromaDbHostsInit(unittest.TestCase): - @patch("embedchain.vectordb.chroma.chromadb.Client") - def test_app_init_with_host_and_port(self, mock_client): - """ - Test if the `App` instance is initialized with the correct host and port values. - """ - host = "test-host" - port = "1234" - - config = AppConfig(collect_metrics=False) - db_config = ChromaDbConfig(host=host, port=port) - - _app = App(config, db_config=db_config) - - called_settings: Settings = mock_client.call_args[0][0] - - self.assertEqual(called_settings.chroma_server_host, host) - self.assertEqual(called_settings.chroma_server_http_port, port) +@pytest.fixture +def chroma_db(): + return ChromaDB(config=ChromaDbConfig(host="test-host", port="1234")) -class TestChromaDbHostsNone(unittest.TestCase): - @patch("embedchain.vectordb.chroma.chromadb.Client") - def test_init_with_host_and_port_none(self, mock_client): - """ - Test if the `App` instance is initialized without default hosts and ports. - """ - - _app = App(config=AppConfig(collect_metrics=False)) - - called_settings: Settings = mock_client.call_args[0][0] - self.assertEqual(called_settings.chroma_server_host, None) - self.assertEqual(called_settings.chroma_server_http_port, None) +@pytest.fixture +def app_with_settings(): + chroma_config = ChromaDbConfig(allow_reset=True, dir="test-db") + app_config = AppConfig(collect_metrics=False) + return App(config=app_config, db_config=chroma_config) -class TestChromaDbHostsLoglevel(unittest.TestCase): - @patch("embedchain.vectordb.chroma.chromadb.Client") - def test_init_with_host_and_port_log_level(self, mock_client): - """ - Test if the `App` instance is initialized without a config that does not contain default hosts and ports. - """ - - _app = App(config=AppConfig(collect_metrics=False)) - - self.assertEqual(mock_client.call_args[0][0].chroma_server_host, None) - self.assertEqual(mock_client.call_args[0][0].chroma_server_http_port, None) +@pytest.fixture(scope="session", autouse=True) +def cleanup_db(): + yield + try: + shutil.rmtree("test-db") + except OSError as e: + print("Error: %s - %s." % (e.filename, e.strerror)) -class TestChromaDbDuplicateHandling: - chroma_config = ChromaDbConfig(allow_reset=True) - app_config = AppConfig(collection_name=False, collect_metrics=False) - app_with_settings = App(config=app_config, db_config=chroma_config) - - def test_duplicates_throw_warning(self, caplog): - """ - Test that add duplicates throws an error. - """ - # Start with a clean app - self.app_with_settings.reset() - - app = App(config=AppConfig(collect_metrics=False)) - app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"]) - app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"]) - assert "Insert of existing embedding ID: 0" in caplog.text - assert "Add of existing embedding ID: 0" in caplog.text - - def test_duplicates_collections_no_warning(self, caplog): - """ - Test that different collections can have duplicates. - """ - # NOTE: Not part of the TestChromaDbCollection because `unittest.TestCase` doesn't have caplog. - - # Start with a clean app - self.app_with_settings.reset() - - app = App(config=AppConfig(collect_metrics=False)) - app.set_collection_name("test_collection_1") - app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"]) - app.set_collection_name("test_collection_2") - app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"]) - assert "Insert of existing embedding ID: 0" not in caplog.text # not - assert "Add of existing embedding ID: 0" not in caplog.text # not +def test_chroma_db_init_with_host_and_port(chroma_db): + settings = chroma_db.client.get_settings() + assert settings.chroma_server_host == "test-host" + assert settings.chroma_server_http_port == "1234" -class TestChromaDbCollection(unittest.TestCase): - chroma_config = ChromaDbConfig(allow_reset=True) - app_config = AppConfig(collection_name=False, collect_metrics=False) - app_with_settings = App(config=app_config, db_config=chroma_config) +def test_chroma_db_init_with_basic_auth(): + chroma_config = { + "host": "test-host", + "port": "1234", + "chroma_settings": { + "chroma_client_auth_provider": "chromadb.auth.basic.BasicAuthClientProvider", + "chroma_client_auth_credentials": "admin:admin", + }, + } - def test_init_with_default_collection(self): - """ - Test if the `App` instance is initialized with the correct default collection name. - """ - app = App(config=AppConfig(collect_metrics=False)) + db = ChromaDB(config=ChromaDbConfig(**chroma_config)) + settings = db.client.get_settings() + assert settings.chroma_server_host == "test-host" + assert settings.chroma_server_http_port == "1234" + assert settings.chroma_client_auth_provider == chroma_config["chroma_settings"]["chroma_client_auth_provider"] + assert settings.chroma_client_auth_credentials == chroma_config["chroma_settings"]["chroma_client_auth_credentials"] - self.assertEqual(app.db.collection.name, "embedchain_store") - def test_init_with_custom_collection(self): - """ - Test if the `App` instance is initialized with the correct custom collection name. - """ - config = AppConfig(collect_metrics=False) - app = App(config=config) - app.set_collection_name(name="test_collection") +@patch("embedchain.vectordb.chroma.chromadb.Client") +def test_app_init_with_host_and_port(mock_client): + host = "test-host" + port = "1234" + config = AppConfig(collect_metrics=False) + db_config = ChromaDbConfig(host=host, port=port) + _app = App(config, db_config=db_config) - self.assertEqual(app.db.collection.name, "test_collection") + called_settings: Settings = mock_client.call_args[0][0] + assert called_settings.chroma_server_host == host + assert called_settings.chroma_server_http_port == port - def test_set_collection_name(self): - """ - Test if the `App` collection is correctly switched using the `set_collection_name` method. - """ - app = App(config=AppConfig(collect_metrics=False)) - app.set_collection_name("test_collection") - self.assertEqual(app.db.collection.name, "test_collection") +@patch("embedchain.vectordb.chroma.chromadb.Client") +def test_app_init_with_host_and_port_none(mock_client): + _app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db")) - def test_changes_encapsulated(self): - """ - Test that changes to one collection do not affect the other collection - """ - # Start with a clean app - self.app_with_settings.reset() + called_settings: Settings = mock_client.call_args[0][0] + assert called_settings.chroma_server_host is None + assert called_settings.chroma_server_http_port is None - app = App(config=AppConfig(collect_metrics=False)) - app.set_collection_name("test_collection_1") - # Collection should be empty when created - self.assertEqual(app.db.count(), 0) - app.db.collection.add(embeddings=[0, 0, 0], ids=["0"]) - # After adding, should contain one item - self.assertEqual(app.db.count(), 1) +def test_chroma_db_duplicates_throw_warning(caplog): + app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db")) + app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"]) + app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"]) + assert "Insert of existing embedding ID: 0" in caplog.text + assert "Add of existing embedding ID: 0" in caplog.text + app.db.reset() - app.set_collection_name("test_collection_2") - # New collection is empty - self.assertEqual(app.db.count(), 0) - # Adding to new collection should not effect existing collection - app.db.collection.add(embeddings=[0, 0, 0], ids=["0"]) - app.set_collection_name("test_collection_1") - # Should still be 1, not 2. - self.assertEqual(app.db.count(), 1) +def test_chroma_db_duplicates_collections_no_warning(caplog): + app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db")) + app.set_collection_name("test_collection_1") + app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"]) + app.set_collection_name("test_collection_2") + app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"]) + assert "Insert of existing embedding ID: 0" not in caplog.text + assert "Add of existing embedding ID: 0" not in caplog.text + app.db.reset() + app.set_collection_name("test_collection_1") + app.db.reset() - def test_add_with_skip_embedding(self): - """ - Test that changes to one collection do not affect the other collection - """ - # Start with a clean app - self.app_with_settings.reset() - # app = App(config=AppConfig(collect_metrics=False), db=db) - # Collection should be empty when created - self.assertEqual(self.app_with_settings.db.count(), 0) +def test_chroma_db_collection_init_with_default_collection(): + app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db")) + assert app.db.collection.name == "embedchain_store" - self.app_with_settings.db.add( + +def test_chroma_db_collection_init_with_custom_collection(): + app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db")) + app.set_collection_name(name="test_collection") + assert app.db.collection.name == "test_collection" + + +def test_chroma_db_collection_set_collection_name(): + app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db")) + app.set_collection_name("test_collection") + assert app.db.collection.name == "test_collection" + + +def test_chroma_db_collection_changes_encapsulated(): + app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db")) + app.set_collection_name("test_collection_1") + assert app.db.count() == 0 + + app.db.collection.add(embeddings=[0, 0, 0], ids=["0"]) + assert app.db.count() == 1 + + app.set_collection_name("test_collection_2") + assert app.db.count() == 0 + + app.db.collection.add(embeddings=[0, 0, 0], ids=["0"]) + app.set_collection_name("test_collection_1") + assert app.db.count() == 1 + app.db.reset() + app.set_collection_name("test_collection_2") + app.db.reset() + + +def test_chroma_db_collection_add_with_skip_embedding(app_with_settings): + # Start with a clean app + app_with_settings.db.reset() + + assert app_with_settings.db.count() == 0 + + app_with_settings.db.add( + embeddings=[[0, 0, 0]], + documents=["document"], + metadatas=[{"value": "somevalue"}], + ids=["id"], + skip_embedding=True, + ) + + assert app_with_settings.db.count() == 1 + + data = app_with_settings.db.get(["id"], limit=1) + expected_value = { + "documents": ["document"], + "embeddings": None, + "ids": ["id"], + "metadatas": [{"value": "somevalue"}], + } + + assert data == expected_value + + data = app_with_settings.db.query(input_query=[0, 0, 0], where={}, n_results=1, skip_embedding=True) + expected_value = ["document"] + + assert data == expected_value + app_with_settings.db.reset() + + +def test_chroma_db_collection_add_with_invalid_inputs(app_with_settings): + # Start with a clean app + app_with_settings.db.reset() + + assert app_with_settings.db.count() == 0 + + with pytest.raises(ValueError): + app_with_settings.db.add( embeddings=[[0, 0, 0]], - documents=["document"], + documents=["document", "document2"], metadatas=[{"value": "somevalue"}], ids=["id"], skip_embedding=True, ) - # After adding, should contain one item - self.assertEqual(self.app_with_settings.db.count(), 1) - # Validate if the get utility of the database is working as expected - data = self.app_with_settings.db.get(["id"], limit=1) - expected_value = { - "documents": ["document"], - "embeddings": None, - "ids": ["id"], - "metadatas": [{"value": "somevalue"}], - } - self.assertEqual(data, expected_value) + assert app_with_settings.db.count() == 0 - # Validate if the query utility of the database is working as expected - data = self.app_with_settings.db.query(input_query=[0, 0, 0], where={}, n_results=1, skip_embedding=True) - expected_value = ["document"] - self.assertEqual(data, expected_value) + with pytest.raises(ValueError): + app_with_settings.db.add( + embeddings=None, + documents=["document", "document2"], + metadatas=[{"value": "somevalue"}], + ids=["id"], + skip_embedding=True, + ) - def test_add_with_invalid_inputs(self): - """ - Test add fails with invalid inputs - """ - # Start with a clean app - self.app_with_settings.reset() - # app = App(config=AppConfig(collect_metrics=False), db=db) + assert app_with_settings.db.count() == 0 + app_with_settings.db.reset() - # Collection should be empty when created - self.assertEqual(self.app_with_settings.db.count(), 0) - with self.assertRaises(ValueError): - self.app_with_settings.db.add( - embeddings=[[0, 0, 0]], - documents=["document", "document2"], - metadatas=[{"value": "somevalue"}], - ids=["id"], - skip_embedding=True, - ) - # After adding, should contain no item - self.assertEqual(self.app_with_settings.db.count(), 0) +def test_chroma_db_collection_collections_are_persistent(): + app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db")) + app.set_collection_name("test_collection_1") + app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"]) + del app - with self.assertRaises(ValueError): - self.app_with_settings.db.add( - embeddings=None, - documents=["document", "document2"], - metadatas=[{"value": "somevalue"}], - ids=["id"], - skip_embedding=True, - ) + app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db")) + app.set_collection_name("test_collection_1") + assert app.db.count() == 1 - # After adding, should contain no item - self.assertEqual(self.app_with_settings.db.count(), 0) + app.db.reset() - def test_collections_are_persistent(self): - """ - Test that a collection can be picked up later. - """ - # Start with a clean app - self.app_with_settings.reset() - app = App(config=AppConfig(collect_metrics=False)) - app.set_collection_name("test_collection_1") - app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"]) - del app +def test_chroma_db_collection_parallel_collections(): + app1 = App( + AppConfig(collection_name="test_collection_1", collect_metrics=False), + db_config=ChromaDbConfig(allow_reset=True, dir="test-db"), + ) + app2 = App( + AppConfig(collection_name="test_collection_2", collect_metrics=False), + db_config=ChromaDbConfig(allow_reset=True, dir="test-db"), + ) - app = App(config=AppConfig(collect_metrics=False)) - app.set_collection_name("test_collection_1") - self.assertEqual(app.db.count(), 1) + # cleanup if any previous tests failed or were interrupted + app1.db.reset() + app2.db.reset() - def test_parallel_collections(self): - """ - Test that two apps can have different collections open in parallel. - Switching the names will allow instant access to the collection of - the other app. - """ - # Start clean - self.app_with_settings.reset() + app1.db.collection.add(embeddings=[0, 0, 0], ids=["0"]) + assert app1.db.count() == 1 + assert app2.db.count() == 0 - # Create two apps - app1 = App(AppConfig(collection_name="test_collection_1", collect_metrics=False)) - app2 = App(AppConfig(collection_name="test_collection_2", collect_metrics=False)) + app1.db.collection.add(embeddings=[[0, 0, 0], [1, 1, 1]], ids=["1", "2"]) + app2.db.collection.add(embeddings=[0, 0, 0], ids=["0"]) - # app2 has been created last, but adding to app1 will still write to collection 1. - app1.db.collection.add(embeddings=[0, 0, 0], ids=["0"]) - self.assertEqual(app1.db.count(), 1) - self.assertEqual(app2.db.count(), 0) + app1.set_collection_name("test_collection_2") + assert app1.db.count() == 1 + app2.set_collection_name("test_collection_1") + assert app2.db.count() == 3 - # Add data - app1.db.collection.add(embeddings=[[0, 0, 0], [1, 1, 1]], ids=["1", "2"]) - app2.db.collection.add(embeddings=[0, 0, 0], ids=["0"]) + # cleanup + app1.db.reset() + app2.db.reset() - # Swap names and test - app1.set_collection_name("test_collection_2") - self.assertEqual(app1.count(), 1) - app2.set_collection_name("test_collection_1") - self.assertEqual(app2.count(), 3) - def test_ids_share_collections(self): - """ - Different ids should still share collections. - """ - # Start clean - self.app_with_settings.reset() +def test_chroma_db_collection_ids_share_collections(): + app1 = App( + AppConfig(id="new_app_id_1", collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db") + ) + app1.set_collection_name("one_collection") + app2 = App( + AppConfig(id="new_app_id_2", collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db") + ) + app2.set_collection_name("one_collection") - # Create two apps - app1 = App(AppConfig(id="new_app_id_1", collect_metrics=False)) - app1.set_collection_name("one_collection") - app2 = App(AppConfig(id="new_app_id_2", collect_metrics=False)) - app2.set_collection_name("one_collection") + app1.db.collection.add(embeddings=[[0, 0, 0], [1, 1, 1]], ids=["0", "1"]) + app2.db.collection.add(embeddings=[0, 0, 0], ids=["2"]) - # Add data - app1.db.collection.add(embeddings=[[0, 0, 0], [1, 1, 1]], ids=["0", "1"]) - app2.db.collection.add(embeddings=[0, 0, 0], ids=["2"]) + assert app1.db.count() == 3 + assert app2.db.count() == 3 - # Both should have the same collection - self.assertEqual(app1.count(), 3) - self.assertEqual(app2.count(), 3) + # cleanup + app1.db.reset() + app2.db.reset() - def test_reset(self): - """ - Resetting should hit all collections and ids. - """ - # Start clean - self.app_with_settings.reset() - # Create four apps. - # app1, which we are about to reset, shares an app with one, and an id with the other, none with the last. - app1 = App(AppConfig(id="new_app_id_1", collect_metrics=False), db_config=self.chroma_config) - app1.set_collection_name("one_collection") - app2 = App(AppConfig(id="new_app_id_2", collect_metrics=False)) - app2.set_collection_name("one_collection") - app3 = App(AppConfig(id="new_app_id_1", collect_metrics=False)) - app3.set_collection_name("three_collection") - app4 = App(AppConfig(id="new_app_id_4", collect_metrics=False)) - app4.set_collection_name("four_collection") +def test_chroma_db_collection_reset(): + app1 = App( + AppConfig(id="new_app_id_1", collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db") + ) + app1.set_collection_name("one_collection") + app2 = App( + AppConfig(id="new_app_id_2", collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db") + ) + app2.set_collection_name("two_collection") + app3 = App( + AppConfig(id="new_app_id_1", collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db") + ) + app3.set_collection_name("three_collection") + app4 = App( + AppConfig(id="new_app_id_4", collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db") + ) + app4.set_collection_name("four_collection") - # Each one of them get data - app1.db.collection.add(embeddings=[0, 0, 0], ids=["1"]) - app2.db.collection.add(embeddings=[0, 0, 0], ids=["2"]) - app3.db.collection.add(embeddings=[0, 0, 0], ids=["3"]) - app4.db.collection.add(embeddings=[0, 0, 0], ids=["4"]) + app1.db.collection.add(embeddings=[0, 0, 0], ids=["1"]) + app2.db.collection.add(embeddings=[0, 0, 0], ids=["2"]) + app3.db.collection.add(embeddings=[0, 0, 0], ids=["3"]) + app4.db.collection.add(embeddings=[0, 0, 0], ids=["4"]) - # Resetting the first one should reset them all. - app1.reset() + app1.db.reset() - # Reinstantiate app2-4, app1 doesn't have to be reinstantiated (PR #319) - app2 = App(AppConfig(collection_name="one_collection", id="new_app_id_2", collect_metrics=False)) - app3 = App(AppConfig(collection_name="three_collection", id="new_app_id_3", collect_metrics=False)) - app4 = App(AppConfig(collection_name="four_collection", id="new_app_id_3", collect_metrics=False)) + assert app1.db.count() == 0 + assert app2.db.count() == 1 + assert app3.db.count() == 1 + assert app4.db.count() == 1 - # All should be empty - self.assertEqual(app1.count(), 0) - self.assertEqual(app2.count(), 0) - self.assertEqual(app3.count(), 0) - self.assertEqual(app4.count(), 0) + # cleanup + app2.db.reset() + app3.db.reset() + app4.db.reset()