[Bug fix] reset() erases everything from db (#844)

This commit is contained in:
Sidharth Mohanty
2023-10-25 22:52:20 +05:30
committed by GitHub
parent 3ce2d8a656
commit bbce18caac
3 changed files with 242 additions and 307 deletions

1
.gitignore vendored
View File

@@ -165,6 +165,7 @@ cython_debug/
# Database
db
test-db
.vscode
/poetry.lock

View File

@@ -262,9 +262,9 @@ class ChromaDB(BaseVectorDB):
"""
Resets the database. Deletes all embeddings irreversibly.
"""
# Delete all data from the database
# Delete all data from the collection
try:
self.client.reset()
self.client.delete_collection(self.config.collection_name)
except ValueError:
raise ValueError(
"For safety reasons, resetting is disabled. "

View File

@@ -1,366 +1,300 @@
# ruff: noqa: E501
import unittest
import os
import shutil
import pytest
from unittest.mock import patch
from chromadb.config import Settings
from embedchain import App
from embedchain.config import AppConfig, ChromaDbConfig
from embedchain.vectordb.chroma import ChromaDB
class TestChromaDbHosts(unittest.TestCase):
    """Connection options given to `ChromaDbConfig` must show up in the client settings."""

    def test_init_with_host_and_port(self):
        """
        A `ChromaDB` built with an explicit host and port reports both through `client.get_settings()`.
        """
        expected_host = "test-host"
        expected_port = "1234"

        database = ChromaDB(config=ChromaDbConfig(host=expected_host, port=expected_port))
        client_settings = database.client.get_settings()

        self.assertEqual(client_settings.chroma_server_host, expected_host)
        self.assertEqual(client_settings.chroma_server_http_port, expected_port)

    def test_init_with_basic_auth(self):
        """
        The `chroma_settings` entries (auth provider and credentials) are reported next to host and port.
        """
        expected_host = "test-host"
        expected_port = "1234"
        auth_settings = {
            "chroma_client_auth_provider": "chromadb.auth.basic.BasicAuthClientProvider",
            "chroma_client_auth_credentials": "admin:admin",
        }

        database = ChromaDB(
            config=ChromaDbConfig(host=expected_host, port=expected_port, chroma_settings=auth_settings)
        )
        client_settings = database.client.get_settings()

        self.assertEqual(client_settings.chroma_server_host, expected_host)
        self.assertEqual(client_settings.chroma_server_http_port, expected_port)
        self.assertEqual(
            client_settings.chroma_client_auth_provider,
            auth_settings["chroma_client_auth_provider"],
        )
        self.assertEqual(
            client_settings.chroma_client_auth_credentials,
            auth_settings["chroma_client_auth_credentials"],
        )
# Set a placeholder OpenAI key at import time for everything defined below.
# NOTE(review): presumably constructing `App` needs this variable to be present — confirm.
os.environ["OPENAI_API_KEY"] = "test-api-key"
# Review this test
class TestChromaDbHostsInit(unittest.TestCase):
    """Host and port given to an `App` via `db_config` must reach the Chroma client constructor."""

    @patch("embedchain.vectordb.chroma.chromadb.Client")
    def test_app_init_with_host_and_port(self, mock_client):
        """
        Build an `App` against a patched Chroma client and inspect the settings it was called with.
        """
        expected_host = "test-host"
        expected_port = "1234"
        app_config = AppConfig(collect_metrics=False)
        chroma_config = ChromaDbConfig(host=expected_host, port=expected_port)

        _app = App(app_config, db_config=chroma_config)

        # The settings object is the first positional argument of the patched call.
        positional_args = mock_client.call_args[0]
        called_settings: Settings = positional_args[0]

        self.assertEqual(called_settings.chroma_server_host, expected_host)
        self.assertEqual(called_settings.chroma_server_http_port, expected_port)
@pytest.fixture
def chroma_db():
    """Provide a `ChromaDB` configured with host "test-host" and port "1234"."""
    remote_config = ChromaDbConfig(host="test-host", port="1234")
    return ChromaDB(config=remote_config)
class TestChromaDbHostsNone(unittest.TestCase):
    """An `App` built without a db config must not pass a server host or port to Chroma."""

    @patch("embedchain.vectordb.chroma.chromadb.Client")
    def test_init_with_host_and_port_none(self, mock_client):
        """
        With no host/port configured, the settings handed to the patched client carry `None` for both.
        """
        _app = App(config=AppConfig(collect_metrics=False))

        # The settings object is the first positional argument of the patched call.
        positional_args = mock_client.call_args[0]
        called_settings: Settings = positional_args[0]

        self.assertEqual(called_settings.chroma_server_host, None)
        self.assertEqual(called_settings.chroma_server_http_port, None)
@pytest.fixture
def app_with_settings():
    """Provide an `App` with metrics off, backed by a resettable Chroma store in `test-db`."""
    return App(
        config=AppConfig(collect_metrics=False),
        db_config=ChromaDbConfig(allow_reset=True, dir="test-db"),
    )
class TestChromaDbHostsLoglevel(unittest.TestCase):
    """An `App` built from a bare `AppConfig` must leave the Chroma server host and port unset."""

    @patch("embedchain.vectordb.chroma.chromadb.Client")
    def test_init_with_host_and_port_log_level(self, mock_client):
        """
        Test that an `App` created from a config without host or port passes `None` for both to the client.
        """
        _app = App(config=AppConfig(collect_metrics=False))

        # First positional argument of the patched client call is the settings object.
        called_settings = mock_client.call_args[0][0]

        self.assertEqual(called_settings.chroma_server_host, None)
        self.assertEqual(called_settings.chroma_server_http_port, None)
@pytest.fixture(scope="session", autouse=True)
def cleanup_db():
    """Session-scoped autouse fixture that removes the `test-db` directory during teardown."""
    # Nothing to set up; all the work happens after the yield.
    yield
    try:
        shutil.rmtree("test-db")
    except OSError as err:
        # Best-effort cleanup: report the failure instead of raising.
        print("Error: %s - %s." % (err.filename, err.strerror))
class TestChromaDbDuplicateHandling:
    """Checks what is logged when the same embedding id is added more than once."""

    # Shared, resettable app that each test uses to start from a clean state.
    chroma_config = ChromaDbConfig(allow_reset=True)
    app_config = AppConfig(collection_name=False, collect_metrics=False)
    app_with_settings = App(config=app_config, db_config=chroma_config)

    def test_duplicates_throw_warning(self, caplog):
        """
        Adding the same id twice to one collection must log the duplicate warnings.
        """
        # Start with a clean app
        self.app_with_settings.reset()

        fresh_app = App(config=AppConfig(collect_metrics=False))
        for _ in range(2):
            fresh_app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])

        expected_messages = (
            "Insert of existing embedding ID: 0",
            "Add of existing embedding ID: 0",
        )
        for message in expected_messages:
            assert message in caplog.text

    def test_duplicates_collections_no_warning(self, caplog):
        """
        The same id may live in two different collections without any duplicate warning.
        """
        # NOTE: Not part of the TestChromaDbCollection because `unittest.TestCase` doesn't have caplog.
        # Start with a clean app
        self.app_with_settings.reset()

        fresh_app = App(config=AppConfig(collect_metrics=False))
        for collection_name in ("test_collection_1", "test_collection_2"):
            fresh_app.set_collection_name(collection_name)
            fresh_app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])

        unexpected_messages = (
            "Insert of existing embedding ID: 0",
            "Add of existing embedding ID: 0",
        )
        for message in unexpected_messages:
            assert message not in caplog.text
def test_chroma_db_init_with_host_and_port(chroma_db):
    """The `chroma_db` fixture's client settings must report host "test-host" and port "1234"."""
    client_settings = chroma_db.client.get_settings()

    assert client_settings.chroma_server_host == "test-host"
    assert client_settings.chroma_server_http_port == "1234"
class TestChromaDbCollection(unittest.TestCase):
chroma_config = ChromaDbConfig(allow_reset=True)
app_config = AppConfig(collection_name=False, collect_metrics=False)
app_with_settings = App(config=app_config, db_config=chroma_config)
def test_chroma_db_init_with_basic_auth():
chroma_config = {
"host": "test-host",
"port": "1234",
"chroma_settings": {
"chroma_client_auth_provider": "chromadb.auth.basic.BasicAuthClientProvider",
"chroma_client_auth_credentials": "admin:admin",
},
}
def test_init_with_default_collection(self):
"""
Test if the `App` instance is initialized with the correct default collection name.
"""
app = App(config=AppConfig(collect_metrics=False))
db = ChromaDB(config=ChromaDbConfig(**chroma_config))
settings = db.client.get_settings()
assert settings.chroma_server_host == "test-host"
assert settings.chroma_server_http_port == "1234"
assert settings.chroma_client_auth_provider == chroma_config["chroma_settings"]["chroma_client_auth_provider"]
assert settings.chroma_client_auth_credentials == chroma_config["chroma_settings"]["chroma_client_auth_credentials"]
self.assertEqual(app.db.collection.name, "embedchain_store")
def test_init_with_custom_collection(self):
"""
Test if the `App` instance is initialized with the correct custom collection name.
"""
config = AppConfig(collect_metrics=False)
app = App(config=config)
app.set_collection_name(name="test_collection")
@patch("embedchain.vectordb.chroma.chromadb.Client")
def test_app_init_with_host_and_port(mock_client):
host = "test-host"
port = "1234"
config = AppConfig(collect_metrics=False)
db_config = ChromaDbConfig(host=host, port=port)
_app = App(config, db_config=db_config)
self.assertEqual(app.db.collection.name, "test_collection")
called_settings: Settings = mock_client.call_args[0][0]
assert called_settings.chroma_server_host == host
assert called_settings.chroma_server_http_port == port
def test_set_collection_name(self):
"""
Test if the `App` collection is correctly switched using the `set_collection_name` method.
"""
app = App(config=AppConfig(collect_metrics=False))
app.set_collection_name("test_collection")
self.assertEqual(app.db.collection.name, "test_collection")
@patch("embedchain.vectordb.chroma.chromadb.Client")
def test_app_init_with_host_and_port_none(mock_client):
_app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db"))
def test_changes_encapsulated(self):
"""
Test that changes to one collection do not affect the other collection
"""
# Start with a clean app
self.app_with_settings.reset()
called_settings: Settings = mock_client.call_args[0][0]
assert called_settings.chroma_server_host is None
assert called_settings.chroma_server_http_port is None
app = App(config=AppConfig(collect_metrics=False))
app.set_collection_name("test_collection_1")
# Collection should be empty when created
self.assertEqual(app.db.count(), 0)
app.db.collection.add(embeddings=[0, 0, 0], ids=["0"])
# After adding, should contain one item
self.assertEqual(app.db.count(), 1)
def test_chroma_db_duplicates_throw_warning(caplog):
app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db"))
app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
assert "Insert of existing embedding ID: 0" in caplog.text
assert "Add of existing embedding ID: 0" in caplog.text
app.db.reset()
app.set_collection_name("test_collection_2")
# New collection is empty
self.assertEqual(app.db.count(), 0)
# Adding to the new collection should not affect the existing collection
app.db.collection.add(embeddings=[0, 0, 0], ids=["0"])
app.set_collection_name("test_collection_1")
# Should still be 1, not 2.
self.assertEqual(app.db.count(), 1)
def test_chroma_db_duplicates_collections_no_warning(caplog):
app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db"))
app.set_collection_name("test_collection_1")
app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
app.set_collection_name("test_collection_2")
app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
assert "Insert of existing embedding ID: 0" not in caplog.text
assert "Add of existing embedding ID: 0" not in caplog.text
app.db.reset()
app.set_collection_name("test_collection_1")
app.db.reset()
def test_add_with_skip_embedding(self):
"""
Test that changes to one collection do not affect the other collection
"""
# Start with a clean app
self.app_with_settings.reset()
# app = App(config=AppConfig(collect_metrics=False), db=db)
# Collection should be empty when created
self.assertEqual(self.app_with_settings.db.count(), 0)
def test_chroma_db_collection_init_with_default_collection():
app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db"))
assert app.db.collection.name == "embedchain_store"
self.app_with_settings.db.add(
def test_chroma_db_collection_init_with_custom_collection():
app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db"))
app.set_collection_name(name="test_collection")
assert app.db.collection.name == "test_collection"
def test_chroma_db_collection_set_collection_name():
app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db"))
app.set_collection_name("test_collection")
assert app.db.collection.name == "test_collection"
def test_chroma_db_collection_changes_encapsulated():
app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db"))
app.set_collection_name("test_collection_1")
assert app.db.count() == 0
app.db.collection.add(embeddings=[0, 0, 0], ids=["0"])
assert app.db.count() == 1
app.set_collection_name("test_collection_2")
assert app.db.count() == 0
app.db.collection.add(embeddings=[0, 0, 0], ids=["0"])
app.set_collection_name("test_collection_1")
assert app.db.count() == 1
app.db.reset()
app.set_collection_name("test_collection_2")
app.db.reset()
def test_chroma_db_collection_add_with_skip_embedding(app_with_settings):
# Start with a clean app
app_with_settings.db.reset()
assert app_with_settings.db.count() == 0
app_with_settings.db.add(
embeddings=[[0, 0, 0]],
documents=["document"],
metadatas=[{"value": "somevalue"}],
ids=["id"],
skip_embedding=True,
)
assert app_with_settings.db.count() == 1
data = app_with_settings.db.get(["id"], limit=1)
expected_value = {
"documents": ["document"],
"embeddings": None,
"ids": ["id"],
"metadatas": [{"value": "somevalue"}],
}
assert data == expected_value
data = app_with_settings.db.query(input_query=[0, 0, 0], where={}, n_results=1, skip_embedding=True)
expected_value = ["document"]
assert data == expected_value
app_with_settings.db.reset()
def test_chroma_db_collection_add_with_invalid_inputs(app_with_settings):
# Start with a clean app
app_with_settings.db.reset()
assert app_with_settings.db.count() == 0
with pytest.raises(ValueError):
app_with_settings.db.add(
embeddings=[[0, 0, 0]],
documents=["document"],
documents=["document", "document2"],
metadatas=[{"value": "somevalue"}],
ids=["id"],
skip_embedding=True,
)
# After adding, should contain one item
self.assertEqual(self.app_with_settings.db.count(), 1)
# Validate if the get utility of the database is working as expected
data = self.app_with_settings.db.get(["id"], limit=1)
expected_value = {
"documents": ["document"],
"embeddings": None,
"ids": ["id"],
"metadatas": [{"value": "somevalue"}],
}
self.assertEqual(data, expected_value)
assert app_with_settings.db.count() == 0
# Validate if the query utility of the database is working as expected
data = self.app_with_settings.db.query(input_query=[0, 0, 0], where={}, n_results=1, skip_embedding=True)
expected_value = ["document"]
self.assertEqual(data, expected_value)
with pytest.raises(ValueError):
app_with_settings.db.add(
embeddings=None,
documents=["document", "document2"],
metadatas=[{"value": "somevalue"}],
ids=["id"],
skip_embedding=True,
)
def test_add_with_invalid_inputs(self):
"""
Test add fails with invalid inputs
"""
# Start with a clean app
self.app_with_settings.reset()
# app = App(config=AppConfig(collect_metrics=False), db=db)
assert app_with_settings.db.count() == 0
app_with_settings.db.reset()
# Collection should be empty when created
self.assertEqual(self.app_with_settings.db.count(), 0)
with self.assertRaises(ValueError):
self.app_with_settings.db.add(
embeddings=[[0, 0, 0]],
documents=["document", "document2"],
metadatas=[{"value": "somevalue"}],
ids=["id"],
skip_embedding=True,
)
# After adding, should contain no item
self.assertEqual(self.app_with_settings.db.count(), 0)
def test_chroma_db_collection_collections_are_persistent():
app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db"))
app.set_collection_name("test_collection_1")
app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
del app
with self.assertRaises(ValueError):
self.app_with_settings.db.add(
embeddings=None,
documents=["document", "document2"],
metadatas=[{"value": "somevalue"}],
ids=["id"],
skip_embedding=True,
)
app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db"))
app.set_collection_name("test_collection_1")
assert app.db.count() == 1
# After adding, should contain no item
self.assertEqual(self.app_with_settings.db.count(), 0)
app.db.reset()
def test_collections_are_persistent(self):
"""
Test that a collection can be picked up later.
"""
# Start with a clean app
self.app_with_settings.reset()
app = App(config=AppConfig(collect_metrics=False))
app.set_collection_name("test_collection_1")
app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
del app
def test_chroma_db_collection_parallel_collections():
app1 = App(
AppConfig(collection_name="test_collection_1", collect_metrics=False),
db_config=ChromaDbConfig(allow_reset=True, dir="test-db"),
)
app2 = App(
AppConfig(collection_name="test_collection_2", collect_metrics=False),
db_config=ChromaDbConfig(allow_reset=True, dir="test-db"),
)
app = App(config=AppConfig(collect_metrics=False))
app.set_collection_name("test_collection_1")
self.assertEqual(app.db.count(), 1)
# cleanup if any previous tests failed or were interrupted
app1.db.reset()
app2.db.reset()
def test_parallel_collections(self):
"""
Test that two apps can have different collections open in parallel.
Switching the names will allow instant access to the collection of
the other app.
"""
# Start clean
self.app_with_settings.reset()
app1.db.collection.add(embeddings=[0, 0, 0], ids=["0"])
assert app1.db.count() == 1
assert app2.db.count() == 0
# Create two apps
app1 = App(AppConfig(collection_name="test_collection_1", collect_metrics=False))
app2 = App(AppConfig(collection_name="test_collection_2", collect_metrics=False))
app1.db.collection.add(embeddings=[[0, 0, 0], [1, 1, 1]], ids=["1", "2"])
app2.db.collection.add(embeddings=[0, 0, 0], ids=["0"])
# app2 has been created last, but adding to app1 will still write to collection 1.
app1.db.collection.add(embeddings=[0, 0, 0], ids=["0"])
self.assertEqual(app1.db.count(), 1)
self.assertEqual(app2.db.count(), 0)
app1.set_collection_name("test_collection_2")
assert app1.db.count() == 1
app2.set_collection_name("test_collection_1")
assert app2.db.count() == 3
# Add data
app1.db.collection.add(embeddings=[[0, 0, 0], [1, 1, 1]], ids=["1", "2"])
app2.db.collection.add(embeddings=[0, 0, 0], ids=["0"])
# cleanup
app1.db.reset()
app2.db.reset()
# Swap names and test
app1.set_collection_name("test_collection_2")
self.assertEqual(app1.count(), 1)
app2.set_collection_name("test_collection_1")
self.assertEqual(app2.count(), 3)
def test_ids_share_collections(self):
"""
Different ids should still share collections.
"""
# Start clean
self.app_with_settings.reset()
def test_chroma_db_collection_ids_share_collections():
app1 = App(
AppConfig(id="new_app_id_1", collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db")
)
app1.set_collection_name("one_collection")
app2 = App(
AppConfig(id="new_app_id_2", collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db")
)
app2.set_collection_name("one_collection")
# Create two apps
app1 = App(AppConfig(id="new_app_id_1", collect_metrics=False))
app1.set_collection_name("one_collection")
app2 = App(AppConfig(id="new_app_id_2", collect_metrics=False))
app2.set_collection_name("one_collection")
app1.db.collection.add(embeddings=[[0, 0, 0], [1, 1, 1]], ids=["0", "1"])
app2.db.collection.add(embeddings=[0, 0, 0], ids=["2"])
# Add data
app1.db.collection.add(embeddings=[[0, 0, 0], [1, 1, 1]], ids=["0", "1"])
app2.db.collection.add(embeddings=[0, 0, 0], ids=["2"])
assert app1.db.count() == 3
assert app2.db.count() == 3
# Both should have the same collection
self.assertEqual(app1.count(), 3)
self.assertEqual(app2.count(), 3)
# cleanup
app1.db.reset()
app2.db.reset()
def test_reset(self):
"""
Resetting should hit all collections and ids.
"""
# Start clean
self.app_with_settings.reset()
# Create four apps.
# app1, which we are about to reset, shares an app with one, and an id with the other, none with the last.
app1 = App(AppConfig(id="new_app_id_1", collect_metrics=False), db_config=self.chroma_config)
app1.set_collection_name("one_collection")
app2 = App(AppConfig(id="new_app_id_2", collect_metrics=False))
app2.set_collection_name("one_collection")
app3 = App(AppConfig(id="new_app_id_1", collect_metrics=False))
app3.set_collection_name("three_collection")
app4 = App(AppConfig(id="new_app_id_4", collect_metrics=False))
app4.set_collection_name("four_collection")
def test_chroma_db_collection_reset():
app1 = App(
AppConfig(id="new_app_id_1", collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db")
)
app1.set_collection_name("one_collection")
app2 = App(
AppConfig(id="new_app_id_2", collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db")
)
app2.set_collection_name("two_collection")
app3 = App(
AppConfig(id="new_app_id_1", collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db")
)
app3.set_collection_name("three_collection")
app4 = App(
AppConfig(id="new_app_id_4", collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db")
)
app4.set_collection_name("four_collection")
# Each one of them get data
app1.db.collection.add(embeddings=[0, 0, 0], ids=["1"])
app2.db.collection.add(embeddings=[0, 0, 0], ids=["2"])
app3.db.collection.add(embeddings=[0, 0, 0], ids=["3"])
app4.db.collection.add(embeddings=[0, 0, 0], ids=["4"])
app1.db.collection.add(embeddings=[0, 0, 0], ids=["1"])
app2.db.collection.add(embeddings=[0, 0, 0], ids=["2"])
app3.db.collection.add(embeddings=[0, 0, 0], ids=["3"])
app4.db.collection.add(embeddings=[0, 0, 0], ids=["4"])
# Resetting the first one should reset them all.
app1.reset()
app1.db.reset()
# Reinstantiate app2-4, app1 doesn't have to be reinstantiated (PR #319)
app2 = App(AppConfig(collection_name="one_collection", id="new_app_id_2", collect_metrics=False))
app3 = App(AppConfig(collection_name="three_collection", id="new_app_id_3", collect_metrics=False))
app4 = App(AppConfig(collection_name="four_collection", id="new_app_id_3", collect_metrics=False))
assert app1.db.count() == 0
assert app2.db.count() == 1
assert app3.db.count() == 1
assert app4.db.count() == 1
# All should be empty
self.assertEqual(app1.count(), 0)
self.assertEqual(app2.count(), 0)
self.assertEqual(app3.count(), 0)
self.assertEqual(app4.count(), 0)
# cleanup
app2.db.reset()
app3.db.reset()
app4.db.reset()