[Bug fix] reset() erases everything from db (#844)

This commit is contained in:
Sidharth Mohanty
2023-10-25 22:52:20 +05:30
committed by GitHub
parent 3ce2d8a656
commit bbce18caac
3 changed files with 242 additions and 307 deletions

1
.gitignore vendored
View File

@@ -165,6 +165,7 @@ cython_debug/
# Database # Database
db db
test-db
.vscode .vscode
/poetry.lock /poetry.lock

View File

@@ -262,9 +262,9 @@ class ChromaDB(BaseVectorDB):
""" """
Resets the database. Deletes all embeddings irreversibly. Resets the database. Deletes all embeddings irreversibly.
""" """
# Delete all data from the database # Delete all data from the collection
try: try:
self.client.reset() self.client.delete_collection(self.config.collection_name)
except ValueError: except ValueError:
raise ValueError( raise ValueError(
"For safety reasons, resetting is disabled. " "For safety reasons, resetting is disabled. "

View File

@@ -1,366 +1,300 @@
# ruff: noqa: E501 import os
import shutil
import unittest import pytest
from unittest.mock import patch from unittest.mock import patch
from chromadb.config import Settings from chromadb.config import Settings
from embedchain import App from embedchain import App
from embedchain.config import AppConfig, ChromaDbConfig from embedchain.config import AppConfig, ChromaDbConfig
from embedchain.vectordb.chroma import ChromaDB from embedchain.vectordb.chroma import ChromaDB
os.environ["OPENAI_API_KEY"] = "test-api-key"
class TestChromaDbHosts(unittest.TestCase):
    """Verify that connection settings passed via `ChromaDbConfig` reach the Chroma client."""

    def test_init_with_host_and_port(self):
        """
        A `ChromaDB` built from a host/port config reports both values in its client settings.
        """
        expected_host, expected_port = "test-host", "1234"
        db = ChromaDB(config=ChromaDbConfig(host=expected_host, port=expected_port))

        client_settings = db.client.get_settings()
        self.assertEqual(client_settings.chroma_server_host, expected_host)
        self.assertEqual(client_settings.chroma_server_http_port, expected_port)

    def test_init_with_basic_auth(self):
        """
        Auth provider and credentials given under `chroma_settings` show up in the
        client settings alongside host and port.
        """
        expected_host, expected_port = "test-host", "1234"
        auth_settings = {
            "chroma_client_auth_provider": "chromadb.auth.basic.BasicAuthClientProvider",
            "chroma_client_auth_credentials": "admin:admin",
        }
        db = ChromaDB(
            config=ChromaDbConfig(host=expected_host, port=expected_port, chroma_settings=auth_settings)
        )

        client_settings = db.client.get_settings()
        self.assertEqual(client_settings.chroma_server_host, expected_host)
        self.assertEqual(client_settings.chroma_server_http_port, expected_port)
        # Each auth option is read back from the settings attribute of the same name.
        for option, expected in auth_settings.items():
            self.assertEqual(getattr(client_settings, option), expected)
# Review this test @pytest.fixture
class TestChromaDbHostsInit(unittest.TestCase): def chroma_db():
@patch("embedchain.vectordb.chroma.chromadb.Client") return ChromaDB(config=ChromaDbConfig(host="test-host", port="1234"))
def test_app_init_with_host_and_port(self, mock_client):
"""
Test if the `App` instance is initialized with the correct host and port values.
"""
host = "test-host"
port = "1234"
config = AppConfig(collect_metrics=False)
db_config = ChromaDbConfig(host=host, port=port)
_app = App(config, db_config=db_config)
called_settings: Settings = mock_client.call_args[0][0]
self.assertEqual(called_settings.chroma_server_host, host)
self.assertEqual(called_settings.chroma_server_http_port, port)
class TestChromaDbHostsNone(unittest.TestCase): @pytest.fixture
@patch("embedchain.vectordb.chroma.chromadb.Client") def app_with_settings():
def test_init_with_host_and_port_none(self, mock_client): chroma_config = ChromaDbConfig(allow_reset=True, dir="test-db")
""" app_config = AppConfig(collect_metrics=False)
Test if the `App` instance is initialized without default hosts and ports. return App(config=app_config, db_config=chroma_config)
"""
_app = App(config=AppConfig(collect_metrics=False))
called_settings: Settings = mock_client.call_args[0][0]
self.assertEqual(called_settings.chroma_server_host, None)
self.assertEqual(called_settings.chroma_server_http_port, None)
class TestChromaDbHostsLoglevel(unittest.TestCase): @pytest.fixture(scope="session", autouse=True)
@patch("embedchain.vectordb.chroma.chromadb.Client") def cleanup_db():
def test_init_with_host_and_port_log_level(self, mock_client): yield
""" try:
Test if the `App` instance is initialized without a config that does not contain default hosts and ports. shutil.rmtree("test-db")
""" except OSError as e:
print("Error: %s - %s." % (e.filename, e.strerror))
_app = App(config=AppConfig(collect_metrics=False))
self.assertEqual(mock_client.call_args[0][0].chroma_server_host, None)
self.assertEqual(mock_client.call_args[0][0].chroma_server_http_port, None)
class TestChromaDbDuplicateHandling: def test_chroma_db_init_with_host_and_port(chroma_db):
chroma_config = ChromaDbConfig(allow_reset=True) settings = chroma_db.client.get_settings()
app_config = AppConfig(collection_name=False, collect_metrics=False) assert settings.chroma_server_host == "test-host"
app_with_settings = App(config=app_config, db_config=chroma_config) assert settings.chroma_server_http_port == "1234"
def test_duplicates_throw_warning(self, caplog):
    """
    Adding the same embedding ID twice to one collection logs the duplicate-ID messages.
    """
    # Wipe whatever earlier tests left behind
    self.app_with_settings.reset()
    app = App(config=AppConfig(collect_metrics=False))
    # Insert the identical ID twice; the second insert is the duplicate
    for _ in range(2):
        app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
    logged = caplog.text
    for expected in ("Insert of existing embedding ID: 0", "Add of existing embedding ID: 0"):
        assert expected in logged
def test_duplicates_collections_no_warning(self, caplog):
    """
    The same embedding ID can be added to two different collections without a duplicate-ID message.
    """
    # NOTE: Lives outside TestChromaDbCollection because `unittest.TestCase` doesn't have caplog.
    # Wipe whatever earlier tests left behind
    self.app_with_settings.reset()
    app = App(config=AppConfig(collect_metrics=False))
    # Same ID, one insert per collection
    for collection_name in ("test_collection_1", "test_collection_2"):
        app.set_collection_name(collection_name)
        app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
    logged = caplog.text
    assert "Insert of existing embedding ID: 0" not in logged
    assert "Add of existing embedding ID: 0" not in logged
class TestChromaDbCollection(unittest.TestCase): def test_chroma_db_init_with_basic_auth():
chroma_config = ChromaDbConfig(allow_reset=True) chroma_config = {
app_config = AppConfig(collection_name=False, collect_metrics=False) "host": "test-host",
app_with_settings = App(config=app_config, db_config=chroma_config) "port": "1234",
"chroma_settings": {
"chroma_client_auth_provider": "chromadb.auth.basic.BasicAuthClientProvider",
"chroma_client_auth_credentials": "admin:admin",
},
}
def test_init_with_default_collection(self): db = ChromaDB(config=ChromaDbConfig(**chroma_config))
""" settings = db.client.get_settings()
Test if the `App` instance is initialized with the correct default collection name. assert settings.chroma_server_host == "test-host"
""" assert settings.chroma_server_http_port == "1234"
app = App(config=AppConfig(collect_metrics=False)) assert settings.chroma_client_auth_provider == chroma_config["chroma_settings"]["chroma_client_auth_provider"]
assert settings.chroma_client_auth_credentials == chroma_config["chroma_settings"]["chroma_client_auth_credentials"]
self.assertEqual(app.db.collection.name, "embedchain_store")
def test_init_with_custom_collection(self): @patch("embedchain.vectordb.chroma.chromadb.Client")
""" def test_app_init_with_host_and_port(mock_client):
Test if the `App` instance is initialized with the correct custom collection name. host = "test-host"
""" port = "1234"
config = AppConfig(collect_metrics=False) config = AppConfig(collect_metrics=False)
app = App(config=config) db_config = ChromaDbConfig(host=host, port=port)
app.set_collection_name(name="test_collection") _app = App(config, db_config=db_config)
self.assertEqual(app.db.collection.name, "test_collection") called_settings: Settings = mock_client.call_args[0][0]
assert called_settings.chroma_server_host == host
assert called_settings.chroma_server_http_port == port
def test_set_collection_name(self):
"""
Test if the `App` collection is correctly switched using the `set_collection_name` method.
"""
app = App(config=AppConfig(collect_metrics=False))
app.set_collection_name("test_collection")
self.assertEqual(app.db.collection.name, "test_collection") @patch("embedchain.vectordb.chroma.chromadb.Client")
def test_app_init_with_host_and_port_none(mock_client):
_app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db"))
def test_changes_encapsulated(self): called_settings: Settings = mock_client.call_args[0][0]
""" assert called_settings.chroma_server_host is None
Test that changes to one collection do not affect the other collection assert called_settings.chroma_server_http_port is None
"""
# Start with a clean app
self.app_with_settings.reset()
app = App(config=AppConfig(collect_metrics=False))
app.set_collection_name("test_collection_1")
# Collection should be empty when created
self.assertEqual(app.db.count(), 0)
app.db.collection.add(embeddings=[0, 0, 0], ids=["0"]) def test_chroma_db_duplicates_throw_warning(caplog):
# After adding, should contain one item app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db"))
self.assertEqual(app.db.count(), 1) app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
assert "Insert of existing embedding ID: 0" in caplog.text
assert "Add of existing embedding ID: 0" in caplog.text
app.db.reset()
app.set_collection_name("test_collection_2")
# New collection is empty
self.assertEqual(app.db.count(), 0)
# Adding to new collection should not effect existing collection def test_chroma_db_duplicates_collections_no_warning(caplog):
app.db.collection.add(embeddings=[0, 0, 0], ids=["0"]) app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db"))
app.set_collection_name("test_collection_1") app.set_collection_name("test_collection_1")
# Should still be 1, not 2. app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
self.assertEqual(app.db.count(), 1) app.set_collection_name("test_collection_2")
app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
assert "Insert of existing embedding ID: 0" not in caplog.text
assert "Add of existing embedding ID: 0" not in caplog.text
app.db.reset()
app.set_collection_name("test_collection_1")
app.db.reset()
def test_add_with_skip_embedding(self):
"""
Test that changes to one collection do not affect the other collection
"""
# Start with a clean app
self.app_with_settings.reset()
# app = App(config=AppConfig(collect_metrics=False), db=db)
# Collection should be empty when created def test_chroma_db_collection_init_with_default_collection():
self.assertEqual(self.app_with_settings.db.count(), 0) app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db"))
assert app.db.collection.name == "embedchain_store"
self.app_with_settings.db.add(
def test_chroma_db_collection_init_with_custom_collection():
    """After switching to a custom collection, the app's db reports that collection's name."""
    app_config = AppConfig(collect_metrics=False)
    db_config = ChromaDbConfig(allow_reset=True, dir="test-db")
    app = App(config=app_config, db_config=db_config)

    app.set_collection_name(name="test_collection")

    active_collection = app.db.collection
    assert active_collection.name == "test_collection"
def test_chroma_db_collection_set_collection_name():
    """`set_collection_name` called positionally switches the app's active collection."""
    app_config = AppConfig(collect_metrics=False)
    db_config = ChromaDbConfig(allow_reset=True, dir="test-db")
    app = App(config=app_config, db_config=db_config)

    app.set_collection_name("test_collection")

    active_collection = app.db.collection
    assert active_collection.name == "test_collection"
def test_chroma_db_collection_changes_encapsulated():
    """
    Writes to one collection must not be visible in another collection.

    The test uses the on-disk "test-db" store, so both collections are reset in a
    `finally` block: a failed assertion must not leave rows behind that would break
    the `count() == 0` expectations of later tests or later runs.
    """
    app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db"))
    try:
        app.set_collection_name("test_collection_1")
        # Collection should be empty when created
        assert app.db.count() == 0

        app.db.collection.add(embeddings=[0, 0, 0], ids=["0"])
        # After adding, should contain one item
        assert app.db.count() == 1

        app.set_collection_name("test_collection_2")
        # The second collection starts out empty
        assert app.db.count() == 0

        # Adding to the second collection should not affect the first one
        app.db.collection.add(embeddings=[0, 0, 0], ids=["0"])
        app.set_collection_name("test_collection_1")
        # Should still be 1, not 2
        assert app.db.count() == 1
    finally:
        # cleanup: runs on success and on failure alike
        for collection_name in ("test_collection_1", "test_collection_2"):
            app.set_collection_name(collection_name)
            app.db.reset()
def test_chroma_db_collection_add_with_skip_embedding(app_with_settings):
    """
    Round-trip one pre-embedded document through `add`, `get` and `query`.

    `app_with_settings` is the fixture-provided app backed by the on-disk "test-db"
    store. The final reset sits in a `finally` block so that a failed assertion does
    not leave the inserted document behind for subsequent tests.
    """
    db = app_with_settings.db
    # Start with a clean app
    db.reset()
    try:
        assert db.count() == 0

        db.add(
            embeddings=[[0, 0, 0]],
            documents=["document"],
            metadatas=[{"value": "somevalue"}],
            ids=["id"],
            skip_embedding=True,
        )
        # After adding, should contain one item
        assert db.count() == 1

        # Validate that the get utility returns the stored record
        data = db.get(["id"], limit=1)
        expected_value = {
            "documents": ["document"],
            "embeddings": None,
            "ids": ["id"],
            "metadatas": [{"value": "somevalue"}],
        }
        assert data == expected_value

        # Validate that the query utility finds the stored document
        data = db.query(input_query=[0, 0, 0], where={}, n_results=1, skip_embedding=True)
        expected_value = ["document"]
        assert data == expected_value
    finally:
        # cleanup: runs on success and on failure alike
        db.reset()
def test_chroma_db_collection_add_with_invalid_inputs(app_with_settings):
# Start with a clean app
app_with_settings.db.reset()
assert app_with_settings.db.count() == 0
with pytest.raises(ValueError):
app_with_settings.db.add(
embeddings=[[0, 0, 0]], embeddings=[[0, 0, 0]],
documents=["document"], documents=["document", "document2"],
metadatas=[{"value": "somevalue"}], metadatas=[{"value": "somevalue"}],
ids=["id"], ids=["id"],
skip_embedding=True, skip_embedding=True,
) )
# After adding, should contain one item
self.assertEqual(self.app_with_settings.db.count(), 1)
# Validate if the get utility of the database is working as expected assert app_with_settings.db.count() == 0
data = self.app_with_settings.db.get(["id"], limit=1)
expected_value = {
"documents": ["document"],
"embeddings": None,
"ids": ["id"],
"metadatas": [{"value": "somevalue"}],
}
self.assertEqual(data, expected_value)
# Validate if the query utility of the database is working as expected with pytest.raises(ValueError):
data = self.app_with_settings.db.query(input_query=[0, 0, 0], where={}, n_results=1, skip_embedding=True) app_with_settings.db.add(
expected_value = ["document"] embeddings=None,
self.assertEqual(data, expected_value) documents=["document", "document2"],
metadatas=[{"value": "somevalue"}],
ids=["id"],
skip_embedding=True,
)
def test_add_with_invalid_inputs(self): assert app_with_settings.db.count() == 0
""" app_with_settings.db.reset()
Test add fails with invalid inputs
"""
# Start with a clean app
self.app_with_settings.reset()
# app = App(config=AppConfig(collect_metrics=False), db=db)
# Collection should be empty when created
self.assertEqual(self.app_with_settings.db.count(), 0)
with self.assertRaises(ValueError): def test_chroma_db_collection_collections_are_persistent():
self.app_with_settings.db.add( app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db"))
embeddings=[[0, 0, 0]], app.set_collection_name("test_collection_1")
documents=["document", "document2"], app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
metadatas=[{"value": "somevalue"}], del app
ids=["id"],
skip_embedding=True,
)
# After adding, should contain no item
self.assertEqual(self.app_with_settings.db.count(), 0)
with self.assertRaises(ValueError): app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db"))
self.app_with_settings.db.add( app.set_collection_name("test_collection_1")
embeddings=None, assert app.db.count() == 1
documents=["document", "document2"],
metadatas=[{"value": "somevalue"}],
ids=["id"],
skip_embedding=True,
)
# After adding, should contain no item app.db.reset()
self.assertEqual(self.app_with_settings.db.count(), 0)
def test_collections_are_persistent(self):
"""
Test that a collection can be picked up later.
"""
# Start with a clean app
self.app_with_settings.reset()
app = App(config=AppConfig(collect_metrics=False)) def test_chroma_db_collection_parallel_collections():
app.set_collection_name("test_collection_1") app1 = App(
app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"]) AppConfig(collection_name="test_collection_1", collect_metrics=False),
del app db_config=ChromaDbConfig(allow_reset=True, dir="test-db"),
)
app2 = App(
AppConfig(collection_name="test_collection_2", collect_metrics=False),
db_config=ChromaDbConfig(allow_reset=True, dir="test-db"),
)
app = App(config=AppConfig(collect_metrics=False)) # cleanup if any previous tests failed or were interrupted
app.set_collection_name("test_collection_1") app1.db.reset()
self.assertEqual(app.db.count(), 1) app2.db.reset()
def test_parallel_collections(self): app1.db.collection.add(embeddings=[0, 0, 0], ids=["0"])
""" assert app1.db.count() == 1
Test that two apps can have different collections open in parallel. assert app2.db.count() == 0
Switching the names will allow instant access to the collection of
the other app.
"""
# Start clean
self.app_with_settings.reset()
# Create two apps app1.db.collection.add(embeddings=[[0, 0, 0], [1, 1, 1]], ids=["1", "2"])
app1 = App(AppConfig(collection_name="test_collection_1", collect_metrics=False)) app2.db.collection.add(embeddings=[0, 0, 0], ids=["0"])
app2 = App(AppConfig(collection_name="test_collection_2", collect_metrics=False))
# app2 has been created last, but adding to app1 will still write to collection 1. app1.set_collection_name("test_collection_2")
app1.db.collection.add(embeddings=[0, 0, 0], ids=["0"]) assert app1.db.count() == 1
self.assertEqual(app1.db.count(), 1) app2.set_collection_name("test_collection_1")
self.assertEqual(app2.db.count(), 0) assert app2.db.count() == 3
# Add data # cleanup
app1.db.collection.add(embeddings=[[0, 0, 0], [1, 1, 1]], ids=["1", "2"]) app1.db.reset()
app2.db.collection.add(embeddings=[0, 0, 0], ids=["0"]) app2.db.reset()
# Swap names and test
app1.set_collection_name("test_collection_2")
self.assertEqual(app1.count(), 1)
app2.set_collection_name("test_collection_1")
self.assertEqual(app2.count(), 3)
def test_ids_share_collections(self): def test_chroma_db_collection_ids_share_collections():
""" app1 = App(
Different ids should still share collections. AppConfig(id="new_app_id_1", collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db")
""" )
# Start clean app1.set_collection_name("one_collection")
self.app_with_settings.reset() app2 = App(
AppConfig(id="new_app_id_2", collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db")
)
app2.set_collection_name("one_collection")
# Create two apps app1.db.collection.add(embeddings=[[0, 0, 0], [1, 1, 1]], ids=["0", "1"])
app1 = App(AppConfig(id="new_app_id_1", collect_metrics=False)) app2.db.collection.add(embeddings=[0, 0, 0], ids=["2"])
app1.set_collection_name("one_collection")
app2 = App(AppConfig(id="new_app_id_2", collect_metrics=False))
app2.set_collection_name("one_collection")
# Add data assert app1.db.count() == 3
app1.db.collection.add(embeddings=[[0, 0, 0], [1, 1, 1]], ids=["0", "1"]) assert app2.db.count() == 3
app2.db.collection.add(embeddings=[0, 0, 0], ids=["2"])
# Both should have the same collection # cleanup
self.assertEqual(app1.count(), 3) app1.db.reset()
self.assertEqual(app2.count(), 3) app2.db.reset()
def test_reset(self):
"""
Resetting should hit all collections and ids.
"""
# Start clean
self.app_with_settings.reset()
# Create four apps. def test_chroma_db_collection_reset():
# app1, which we are about to reset, shares an app with one, and an id with the other, none with the last. app1 = App(
app1 = App(AppConfig(id="new_app_id_1", collect_metrics=False), db_config=self.chroma_config) AppConfig(id="new_app_id_1", collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db")
app1.set_collection_name("one_collection") )
app2 = App(AppConfig(id="new_app_id_2", collect_metrics=False)) app1.set_collection_name("one_collection")
app2.set_collection_name("one_collection") app2 = App(
app3 = App(AppConfig(id="new_app_id_1", collect_metrics=False)) AppConfig(id="new_app_id_2", collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db")
app3.set_collection_name("three_collection") )
app4 = App(AppConfig(id="new_app_id_4", collect_metrics=False)) app2.set_collection_name("two_collection")
app4.set_collection_name("four_collection") app3 = App(
AppConfig(id="new_app_id_1", collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db")
)
app3.set_collection_name("three_collection")
app4 = App(
AppConfig(id="new_app_id_4", collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db")
)
app4.set_collection_name("four_collection")
# Each one of them get data app1.db.collection.add(embeddings=[0, 0, 0], ids=["1"])
app1.db.collection.add(embeddings=[0, 0, 0], ids=["1"]) app2.db.collection.add(embeddings=[0, 0, 0], ids=["2"])
app2.db.collection.add(embeddings=[0, 0, 0], ids=["2"]) app3.db.collection.add(embeddings=[0, 0, 0], ids=["3"])
app3.db.collection.add(embeddings=[0, 0, 0], ids=["3"]) app4.db.collection.add(embeddings=[0, 0, 0], ids=["4"])
app4.db.collection.add(embeddings=[0, 0, 0], ids=["4"])
# Resetting the first one should reset them all. app1.db.reset()
app1.reset()
# Reinstantiate app2-4, app1 doesn't have to be reinstantiated (PR #319) assert app1.db.count() == 0
app2 = App(AppConfig(collection_name="one_collection", id="new_app_id_2", collect_metrics=False)) assert app2.db.count() == 1
app3 = App(AppConfig(collection_name="three_collection", id="new_app_id_3", collect_metrics=False)) assert app3.db.count() == 1
app4 = App(AppConfig(collection_name="four_collection", id="new_app_id_3", collect_metrics=False)) assert app4.db.count() == 1
# All should be empty # cleanup
self.assertEqual(app1.count(), 0) app2.db.reset()
self.assertEqual(app2.count(), 0) app3.db.reset()
self.assertEqual(app3.count(), 0) app4.db.reset()
self.assertEqual(app4.count(), 0)