[Bug fix] reset() erases everything from db (#844)

This commit is contained in:
Sidharth Mohanty
2023-10-25 22:52:20 +05:30
committed by GitHub
parent 3ce2d8a656
commit bbce18caac
3 changed files with 242 additions and 307 deletions

1
.gitignore vendored
View File

@@ -165,6 +165,7 @@ cython_debug/
# Database # Database
db db
test-db
.vscode .vscode
/poetry.lock /poetry.lock

View File

@@ -262,9 +262,9 @@ class ChromaDB(BaseVectorDB):
""" """
Resets the database. Deletes all embeddings irreversibly. Resets the database. Deletes all embeddings irreversibly.
""" """
# Delete all data from the database # Delete all data from the collection
try: try:
self.client.reset() self.client.delete_collection(self.config.collection_name)
except ValueError: except ValueError:
raise ValueError( raise ValueError(
"For safety reasons, resetting is disabled. " "For safety reasons, resetting is disabled. "

View File

@@ -1,257 +1,193 @@
# ruff: noqa: E501 import os
import shutil
import unittest import pytest
from unittest.mock import patch from unittest.mock import patch
from chromadb.config import Settings from chromadb.config import Settings
from embedchain import App from embedchain import App
from embedchain.config import AppConfig, ChromaDbConfig from embedchain.config import AppConfig, ChromaDbConfig
from embedchain.vectordb.chroma import ChromaDB from embedchain.vectordb.chroma import ChromaDB
os.environ["OPENAI_API_KEY"] = "test-api-key"
class TestChromaDbHosts(unittest.TestCase):
def test_init_with_host_and_port(self):
"""
Test if the `ChromaDB` instance is initialized with the correct host and port values.
"""
host = "test-host"
port = "1234"
config = ChromaDbConfig(host=host, port=port)
db = ChromaDB(config=config) @pytest.fixture
settings = db.client.get_settings() def chroma_db():
self.assertEqual(settings.chroma_server_host, host) return ChromaDB(config=ChromaDbConfig(host="test-host", port="1234"))
self.assertEqual(settings.chroma_server_http_port, port)
def test_init_with_basic_auth(self):
host = "test-host"
port = "1234"
@pytest.fixture
def app_with_settings():
chroma_config = ChromaDbConfig(allow_reset=True, dir="test-db")
app_config = AppConfig(collect_metrics=False)
return App(config=app_config, db_config=chroma_config)
@pytest.fixture(scope="session", autouse=True)
def cleanup_db():
yield
try:
shutil.rmtree("test-db")
except OSError as e:
print("Error: %s - %s." % (e.filename, e.strerror))
def test_chroma_db_init_with_host_and_port(chroma_db):
settings = chroma_db.client.get_settings()
assert settings.chroma_server_host == "test-host"
assert settings.chroma_server_http_port == "1234"
def test_chroma_db_init_with_basic_auth():
chroma_config = { chroma_config = {
"host": host, "host": "test-host",
"port": port, "port": "1234",
"chroma_settings": { "chroma_settings": {
"chroma_client_auth_provider": "chromadb.auth.basic.BasicAuthClientProvider", "chroma_client_auth_provider": "chromadb.auth.basic.BasicAuthClientProvider",
"chroma_client_auth_credentials": "admin:admin", "chroma_client_auth_credentials": "admin:admin",
}, },
} }
config = ChromaDbConfig(**chroma_config) db = ChromaDB(config=ChromaDbConfig(**chroma_config))
db = ChromaDB(config=config)
settings = db.client.get_settings() settings = db.client.get_settings()
self.assertEqual(settings.chroma_server_host, host) assert settings.chroma_server_host == "test-host"
self.assertEqual(settings.chroma_server_http_port, port) assert settings.chroma_server_http_port == "1234"
self.assertEqual( assert settings.chroma_client_auth_provider == chroma_config["chroma_settings"]["chroma_client_auth_provider"]
settings.chroma_client_auth_provider, chroma_config["chroma_settings"]["chroma_client_auth_provider"] assert settings.chroma_client_auth_credentials == chroma_config["chroma_settings"]["chroma_client_auth_credentials"]
)
self.assertEqual(
settings.chroma_client_auth_credentials, chroma_config["chroma_settings"]["chroma_client_auth_credentials"]
)
# Review this test
class TestChromaDbHostsInit(unittest.TestCase):
@patch("embedchain.vectordb.chroma.chromadb.Client") @patch("embedchain.vectordb.chroma.chromadb.Client")
def test_app_init_with_host_and_port(self, mock_client): def test_app_init_with_host_and_port(mock_client):
"""
Test if the `App` instance is initialized with the correct host and port values.
"""
host = "test-host" host = "test-host"
port = "1234" port = "1234"
config = AppConfig(collect_metrics=False) config = AppConfig(collect_metrics=False)
db_config = ChromaDbConfig(host=host, port=port) db_config = ChromaDbConfig(host=host, port=port)
_app = App(config, db_config=db_config) _app = App(config, db_config=db_config)
called_settings: Settings = mock_client.call_args[0][0] called_settings: Settings = mock_client.call_args[0][0]
assert called_settings.chroma_server_host == host
self.assertEqual(called_settings.chroma_server_host, host) assert called_settings.chroma_server_http_port == port
self.assertEqual(called_settings.chroma_server_http_port, port)
class TestChromaDbHostsNone(unittest.TestCase):
@patch("embedchain.vectordb.chroma.chromadb.Client") @patch("embedchain.vectordb.chroma.chromadb.Client")
def test_init_with_host_and_port_none(self, mock_client): def test_app_init_with_host_and_port_none(mock_client):
""" _app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db"))
Test if the `App` instance is initialized without default hosts and ports.
"""
_app = App(config=AppConfig(collect_metrics=False))
called_settings: Settings = mock_client.call_args[0][0] called_settings: Settings = mock_client.call_args[0][0]
self.assertEqual(called_settings.chroma_server_host, None) assert called_settings.chroma_server_host is None
self.assertEqual(called_settings.chroma_server_http_port, None) assert called_settings.chroma_server_http_port is None
class TestChromaDbHostsLoglevel(unittest.TestCase): def test_chroma_db_duplicates_throw_warning(caplog):
@patch("embedchain.vectordb.chroma.chromadb.Client") app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db"))
def test_init_with_host_and_port_log_level(self, mock_client):
"""
Test if the `App` instance is initialized without a config that does not contain default hosts and ports.
"""
_app = App(config=AppConfig(collect_metrics=False))
self.assertEqual(mock_client.call_args[0][0].chroma_server_host, None)
self.assertEqual(mock_client.call_args[0][0].chroma_server_http_port, None)
class TestChromaDbDuplicateHandling:
chroma_config = ChromaDbConfig(allow_reset=True)
app_config = AppConfig(collection_name=False, collect_metrics=False)
app_with_settings = App(config=app_config, db_config=chroma_config)
def test_duplicates_throw_warning(self, caplog):
"""
Test that add duplicates throws an error.
"""
# Start with a clean app
self.app_with_settings.reset()
app = App(config=AppConfig(collect_metrics=False))
app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"]) app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"]) app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
assert "Insert of existing embedding ID: 0" in caplog.text assert "Insert of existing embedding ID: 0" in caplog.text
assert "Add of existing embedding ID: 0" in caplog.text assert "Add of existing embedding ID: 0" in caplog.text
app.db.reset()
def test_duplicates_collections_no_warning(self, caplog):
"""
Test that different collections can have duplicates.
"""
# NOTE: Not part of the TestChromaDbCollection because `unittest.TestCase` doesn't have caplog.
# Start with a clean app def test_chroma_db_duplicates_collections_no_warning(caplog):
self.app_with_settings.reset() app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db"))
app = App(config=AppConfig(collect_metrics=False))
app.set_collection_name("test_collection_1") app.set_collection_name("test_collection_1")
app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"]) app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
app.set_collection_name("test_collection_2") app.set_collection_name("test_collection_2")
app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"]) app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
assert "Insert of existing embedding ID: 0" not in caplog.text # not assert "Insert of existing embedding ID: 0" not in caplog.text
assert "Add of existing embedding ID: 0" not in caplog.text # not assert "Add of existing embedding ID: 0" not in caplog.text
app.db.reset()
app.set_collection_name("test_collection_1")
app.db.reset()
class TestChromaDbCollection(unittest.TestCase): def test_chroma_db_collection_init_with_default_collection():
chroma_config = ChromaDbConfig(allow_reset=True) app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db"))
app_config = AppConfig(collection_name=False, collect_metrics=False) assert app.db.collection.name == "embedchain_store"
app_with_settings = App(config=app_config, db_config=chroma_config)
def test_init_with_default_collection(self):
"""
Test if the `App` instance is initialized with the correct default collection name.
"""
app = App(config=AppConfig(collect_metrics=False))
self.assertEqual(app.db.collection.name, "embedchain_store") def test_chroma_db_collection_init_with_custom_collection():
app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db"))
def test_init_with_custom_collection(self):
"""
Test if the `App` instance is initialized with the correct custom collection name.
"""
config = AppConfig(collect_metrics=False)
app = App(config=config)
app.set_collection_name(name="test_collection") app.set_collection_name(name="test_collection")
assert app.db.collection.name == "test_collection"
self.assertEqual(app.db.collection.name, "test_collection")
def test_set_collection_name(self): def test_chroma_db_collection_set_collection_name():
""" app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db"))
Test if the `App` collection is correctly switched using the `set_collection_name` method.
"""
app = App(config=AppConfig(collect_metrics=False))
app.set_collection_name("test_collection") app.set_collection_name("test_collection")
assert app.db.collection.name == "test_collection"
self.assertEqual(app.db.collection.name, "test_collection")
def test_changes_encapsulated(self): def test_chroma_db_collection_changes_encapsulated():
""" app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db"))
Test that changes to one collection do not affect the other collection
"""
# Start with a clean app
self.app_with_settings.reset()
app = App(config=AppConfig(collect_metrics=False))
app.set_collection_name("test_collection_1") app.set_collection_name("test_collection_1")
# Collection should be empty when created assert app.db.count() == 0
self.assertEqual(app.db.count(), 0)
app.db.collection.add(embeddings=[0, 0, 0], ids=["0"]) app.db.collection.add(embeddings=[0, 0, 0], ids=["0"])
# After adding, should contain one item assert app.db.count() == 1
self.assertEqual(app.db.count(), 1)
app.set_collection_name("test_collection_2") app.set_collection_name("test_collection_2")
# New collection is empty assert app.db.count() == 0
self.assertEqual(app.db.count(), 0)
# Adding to new collection should not effect existing collection
app.db.collection.add(embeddings=[0, 0, 0], ids=["0"]) app.db.collection.add(embeddings=[0, 0, 0], ids=["0"])
app.set_collection_name("test_collection_1") app.set_collection_name("test_collection_1")
# Should still be 1, not 2. assert app.db.count() == 1
self.assertEqual(app.db.count(), 1) app.db.reset()
app.set_collection_name("test_collection_2")
app.db.reset()
def test_add_with_skip_embedding(self):
""" def test_chroma_db_collection_add_with_skip_embedding(app_with_settings):
Test that changes to one collection do not affect the other collection
"""
# Start with a clean app # Start with a clean app
self.app_with_settings.reset() app_with_settings.db.reset()
# app = App(config=AppConfig(collect_metrics=False), db=db)
# Collection should be empty when created assert app_with_settings.db.count() == 0
self.assertEqual(self.app_with_settings.db.count(), 0)
self.app_with_settings.db.add( app_with_settings.db.add(
embeddings=[[0, 0, 0]], embeddings=[[0, 0, 0]],
documents=["document"], documents=["document"],
metadatas=[{"value": "somevalue"}], metadatas=[{"value": "somevalue"}],
ids=["id"], ids=["id"],
skip_embedding=True, skip_embedding=True,
) )
# After adding, should contain one item
self.assertEqual(self.app_with_settings.db.count(), 1)
# Validate if the get utility of the database is working as expected assert app_with_settings.db.count() == 1
data = self.app_with_settings.db.get(["id"], limit=1)
data = app_with_settings.db.get(["id"], limit=1)
expected_value = { expected_value = {
"documents": ["document"], "documents": ["document"],
"embeddings": None, "embeddings": None,
"ids": ["id"], "ids": ["id"],
"metadatas": [{"value": "somevalue"}], "metadatas": [{"value": "somevalue"}],
} }
self.assertEqual(data, expected_value)
# Validate if the query utility of the database is working as expected assert data == expected_value
data = self.app_with_settings.db.query(input_query=[0, 0, 0], where={}, n_results=1, skip_embedding=True)
data = app_with_settings.db.query(input_query=[0, 0, 0], where={}, n_results=1, skip_embedding=True)
expected_value = ["document"] expected_value = ["document"]
self.assertEqual(data, expected_value)
def test_add_with_invalid_inputs(self): assert data == expected_value
""" app_with_settings.db.reset()
Test add fails with invalid inputs
"""
def test_chroma_db_collection_add_with_invalid_inputs(app_with_settings):
# Start with a clean app # Start with a clean app
self.app_with_settings.reset() app_with_settings.db.reset()
# app = App(config=AppConfig(collect_metrics=False), db=db)
# Collection should be empty when created assert app_with_settings.db.count() == 0
self.assertEqual(self.app_with_settings.db.count(), 0)
with self.assertRaises(ValueError): with pytest.raises(ValueError):
self.app_with_settings.db.add( app_with_settings.db.add(
embeddings=[[0, 0, 0]], embeddings=[[0, 0, 0]],
documents=["document", "document2"], documents=["document", "document2"],
metadatas=[{"value": "somevalue"}], metadatas=[{"value": "somevalue"}],
ids=["id"], ids=["id"],
skip_embedding=True, skip_embedding=True,
) )
# After adding, should contain no item
self.assertEqual(self.app_with_settings.db.count(), 0)
with self.assertRaises(ValueError): assert app_with_settings.db.count() == 0
self.app_with_settings.db.add(
with pytest.raises(ValueError):
app_with_settings.db.add(
embeddings=None, embeddings=None,
documents=["document", "document2"], documents=["document", "document2"],
metadatas=[{"value": "somevalue"}], metadatas=[{"value": "somevalue"}],
@@ -259,108 +195,106 @@ class TestChromaDbCollection(unittest.TestCase):
skip_embedding=True, skip_embedding=True,
) )
# After adding, should contain no item assert app_with_settings.db.count() == 0
self.assertEqual(self.app_with_settings.db.count(), 0) app_with_settings.db.reset()
def test_collections_are_persistent(self):
"""
Test that a collection can be picked up later.
"""
# Start with a clean app
self.app_with_settings.reset()
app = App(config=AppConfig(collect_metrics=False)) def test_chroma_db_collection_collections_are_persistent():
app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db"))
app.set_collection_name("test_collection_1") app.set_collection_name("test_collection_1")
app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"]) app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
del app del app
app = App(config=AppConfig(collect_metrics=False)) app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db"))
app.set_collection_name("test_collection_1") app.set_collection_name("test_collection_1")
self.assertEqual(app.db.count(), 1) assert app.db.count() == 1
def test_parallel_collections(self): app.db.reset()
"""
Test that two apps can have different collections open in parallel.
Switching the names will allow instant access to the collection of
the other app.
"""
# Start clean
self.app_with_settings.reset()
# Create two apps
app1 = App(AppConfig(collection_name="test_collection_1", collect_metrics=False))
app2 = App(AppConfig(collection_name="test_collection_2", collect_metrics=False))
# app2 has been created last, but adding to app1 will still write to collection 1. def test_chroma_db_collection_parallel_collections():
app1 = App(
AppConfig(collection_name="test_collection_1", collect_metrics=False),
db_config=ChromaDbConfig(allow_reset=True, dir="test-db"),
)
app2 = App(
AppConfig(collection_name="test_collection_2", collect_metrics=False),
db_config=ChromaDbConfig(allow_reset=True, dir="test-db"),
)
# cleanup if any previous tests failed or were interrupted
app1.db.reset()
app2.db.reset()
app1.db.collection.add(embeddings=[0, 0, 0], ids=["0"]) app1.db.collection.add(embeddings=[0, 0, 0], ids=["0"])
self.assertEqual(app1.db.count(), 1) assert app1.db.count() == 1
self.assertEqual(app2.db.count(), 0) assert app2.db.count() == 0
# Add data
app1.db.collection.add(embeddings=[[0, 0, 0], [1, 1, 1]], ids=["1", "2"]) app1.db.collection.add(embeddings=[[0, 0, 0], [1, 1, 1]], ids=["1", "2"])
app2.db.collection.add(embeddings=[0, 0, 0], ids=["0"]) app2.db.collection.add(embeddings=[0, 0, 0], ids=["0"])
# Swap names and test
app1.set_collection_name("test_collection_2") app1.set_collection_name("test_collection_2")
self.assertEqual(app1.count(), 1) assert app1.db.count() == 1
app2.set_collection_name("test_collection_1") app2.set_collection_name("test_collection_1")
self.assertEqual(app2.count(), 3) assert app2.db.count() == 3
def test_ids_share_collections(self): # cleanup
""" app1.db.reset()
Different ids should still share collections. app2.db.reset()
"""
# Start clean
self.app_with_settings.reset()
# Create two apps
app1 = App(AppConfig(id="new_app_id_1", collect_metrics=False)) def test_chroma_db_collection_ids_share_collections():
app1 = App(
AppConfig(id="new_app_id_1", collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db")
)
app1.set_collection_name("one_collection") app1.set_collection_name("one_collection")
app2 = App(AppConfig(id="new_app_id_2", collect_metrics=False)) app2 = App(
AppConfig(id="new_app_id_2", collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db")
)
app2.set_collection_name("one_collection") app2.set_collection_name("one_collection")
# Add data
app1.db.collection.add(embeddings=[[0, 0, 0], [1, 1, 1]], ids=["0", "1"]) app1.db.collection.add(embeddings=[[0, 0, 0], [1, 1, 1]], ids=["0", "1"])
app2.db.collection.add(embeddings=[0, 0, 0], ids=["2"]) app2.db.collection.add(embeddings=[0, 0, 0], ids=["2"])
# Both should have the same collection assert app1.db.count() == 3
self.assertEqual(app1.count(), 3) assert app2.db.count() == 3
self.assertEqual(app2.count(), 3)
def test_reset(self): # cleanup
""" app1.db.reset()
Resetting should hit all collections and ids. app2.db.reset()
"""
# Start clean
self.app_with_settings.reset()
# Create four apps.
# app1, which we are about to reset, shares an app with one, and an id with the other, none with the last. def test_chroma_db_collection_reset():
app1 = App(AppConfig(id="new_app_id_1", collect_metrics=False), db_config=self.chroma_config) app1 = App(
AppConfig(id="new_app_id_1", collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db")
)
app1.set_collection_name("one_collection") app1.set_collection_name("one_collection")
app2 = App(AppConfig(id="new_app_id_2", collect_metrics=False)) app2 = App(
app2.set_collection_name("one_collection") AppConfig(id="new_app_id_2", collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db")
app3 = App(AppConfig(id="new_app_id_1", collect_metrics=False)) )
app2.set_collection_name("two_collection")
app3 = App(
AppConfig(id="new_app_id_1", collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db")
)
app3.set_collection_name("three_collection") app3.set_collection_name("three_collection")
app4 = App(AppConfig(id="new_app_id_4", collect_metrics=False)) app4 = App(
AppConfig(id="new_app_id_4", collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db")
)
app4.set_collection_name("four_collection") app4.set_collection_name("four_collection")
# Each one of them get data
app1.db.collection.add(embeddings=[0, 0, 0], ids=["1"]) app1.db.collection.add(embeddings=[0, 0, 0], ids=["1"])
app2.db.collection.add(embeddings=[0, 0, 0], ids=["2"]) app2.db.collection.add(embeddings=[0, 0, 0], ids=["2"])
app3.db.collection.add(embeddings=[0, 0, 0], ids=["3"]) app3.db.collection.add(embeddings=[0, 0, 0], ids=["3"])
app4.db.collection.add(embeddings=[0, 0, 0], ids=["4"]) app4.db.collection.add(embeddings=[0, 0, 0], ids=["4"])
# Resetting the first one should reset them all. app1.db.reset()
app1.reset()
# Reinstantiate app2-4, app1 doesn't have to be reinstantiated (PR #319) assert app1.db.count() == 0
app2 = App(AppConfig(collection_name="one_collection", id="new_app_id_2", collect_metrics=False)) assert app2.db.count() == 1
app3 = App(AppConfig(collection_name="three_collection", id="new_app_id_3", collect_metrics=False)) assert app3.db.count() == 1
app4 = App(AppConfig(collection_name="four_collection", id="new_app_id_3", collect_metrics=False)) assert app4.db.count() == 1
# All should be empty # cleanup
self.assertEqual(app1.count(), 0) app2.db.reset()
self.assertEqual(app2.count(), 0) app3.db.reset()
self.assertEqual(app3.count(), 0) app4.db.reset()
self.assertEqual(app4.count(), 0)