refactor: classes and configs (#528)

This commit is contained in:
cachho
2023-09-05 10:12:58 +02:00
committed by GitHub
parent 387b042a49
commit 344e7470f6
50 changed files with 1221 additions and 997 deletions

View File

@@ -3,8 +3,10 @@
import unittest
from unittest.mock import patch
from chromadb.config import Settings
from embedchain import App
from embedchain.config import AppConfig, CustomAppConfig
from embedchain.config import AppConfig, ChromaDbConfig
from embedchain.models import EmbeddingFunctions, Providers
from embedchain.vectordb.chroma_db import ChromaDB
@@ -16,8 +18,9 @@ class TestChromaDbHosts(unittest.TestCase):
"""
host = "test-host"
port = "1234"
config = ChromaDbConfig(host=host, port=port)
db = ChromaDB(host=host, port=port, embedding_fn=len)
db = ChromaDB(config=config)
settings = db.client.get_settings()
self.assertEqual(settings.chroma_server_host, host)
self.assertEqual(settings.chroma_server_http_port, port)
@@ -31,7 +34,8 @@ class TestChromaDbHosts(unittest.TestCase):
"chroma_client_auth_credentials": "admin:admin",
}
db = ChromaDB(host=host, port=port, embedding_fn=len, chroma_settings=chroma_auth_settings)
config = ChromaDbConfig(host=host, port=port, chroma_settings=chroma_auth_settings)
db = ChromaDB(config=config)
settings = db.client.get_settings()
self.assertEqual(settings.chroma_server_host, host)
self.assertEqual(settings.chroma_server_http_port, port)
@@ -44,37 +48,41 @@ class TestChromaDbHosts(unittest.TestCase):
# Review this test
class TestChromaDbHostsInit(unittest.TestCase):
@patch("embedchain.vectordb.chroma_db.chromadb.Client")
def test_init_with_host_and_port(self, mock_client):
def test_app_init_with_host_and_port(self, mock_client):
"""
Test if the `App` instance is initialized with the correct host and port values.
"""
host = "test-host"
port = "1234"
config = AppConfig(host=host, port=port, collect_metrics=False)
config = AppConfig(collect_metrics=False)
chromadb_config = ChromaDbConfig(host=host, port=port)
_app = App(config)
_app = App(config, chromadb_config=chromadb_config)
# self.assertEqual(mock_client.call_args[0][0].chroma_server_host, host)
# self.assertEqual(mock_client.call_args[0][0].chroma_server_http_port, port)
called_settings: Settings = mock_client.call_args[0][0]
self.assertEqual(called_settings.chroma_server_host, host)
self.assertEqual(called_settings.chroma_server_http_port, port)
class TestChromaDbHostsNone(unittest.TestCase):
@patch("embedchain.vectordb.chroma_db.chromadb.Client")
def test_init_with_host_and_port(self, mock_client):
def test_init_with_host_and_port_none(self, mock_client):
"""
Test if the `App` instance is initialized without default hosts and ports.
"""
_app = App(config=AppConfig(collect_metrics=False))
self.assertEqual(mock_client.call_args[0][0].chroma_server_host, None)
self.assertEqual(mock_client.call_args[0][0].chroma_server_http_port, None)
called_settings: Settings = mock_client.call_args[0][0]
self.assertEqual(called_settings.chroma_server_host, None)
self.assertEqual(called_settings.chroma_server_http_port, None)
class TestChromaDbHostsLoglevel(unittest.TestCase):
@patch("embedchain.vectordb.chroma_db.chromadb.Client")
def test_init_with_host_and_port(self, mock_client):
def test_init_with_host_and_port_log_level(self, mock_client):
"""
Test if the `App` instance is initialized with a config that does not contain default hosts and ports.
"""
@@ -87,11 +95,10 @@ class TestChromaDbHostsLoglevel(unittest.TestCase):
class TestChromaDbDuplicateHandling:
app_with_settings = App(
CustomAppConfig(
provider=Providers.OPENAI, embedding_fn=EmbeddingFunctions.OPENAI, chroma_settings={"allow_reset": True}
)
)
chroma_settings = {"allow_reset": True}
chroma_config = ChromaDbConfig(chroma_settings=chroma_settings)
app_config = AppConfig(collection_name=False, collect_metrics=False)
app_with_settings = App(config=app_config, chromadb_config=chroma_config)
def test_duplicates_throw_warning(self, caplog):
"""
@@ -101,8 +108,8 @@ class TestChromaDbDuplicateHandling:
self.app_with_settings.reset()
app = App(config=AppConfig(collect_metrics=False))
app.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
app.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
assert "Insert of existing embedding ID: 0" in caplog.text
assert "Add of existing embedding ID: 0" in caplog.text
@@ -117,19 +124,18 @@ class TestChromaDbDuplicateHandling:
app = App(config=AppConfig(collect_metrics=False))
app.set_collection("test_collection_1")
app.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
app.set_collection("test_collection_2")
app.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
assert "Insert of existing embedding ID: 0" not in caplog.text # not
assert "Add of existing embedding ID: 0" not in caplog.text # not
class TestChromaDbCollection(unittest.TestCase):
app_with_settings = App(
CustomAppConfig(
provider=Providers.OPENAI, embedding_fn=EmbeddingFunctions.OPENAI, chroma_settings={"allow_reset": True}
)
)
chroma_settings = {"allow_reset": True}
chroma_config = ChromaDbConfig(chroma_settings=chroma_settings)
app_config = AppConfig(collection_name=False, collect_metrics=False)
app_with_settings = App(config=app_config, chromadb_config=chroma_config)
def test_init_with_default_collection(self):
"""
@@ -137,16 +143,17 @@ class TestChromaDbCollection(unittest.TestCase):
"""
app = App(config=AppConfig(collect_metrics=False))
self.assertEqual(app.collection.name, "embedchain_store")
self.assertEqual(app.db.collection.name, "embedchain_store")
def test_init_with_custom_collection(self):
"""
Test if the `App` instance is initialized with the correct custom collection name.
"""
config = AppConfig(collection_name="test_collection", collect_metrics=False)
app = App(config)
config = AppConfig(collect_metrics=False)
app = App(config=config)
app.set_collection(collection_name="test_collection")
self.assertEqual(app.collection.name, "test_collection")
self.assertEqual(app.db.collection.name, "test_collection")
def test_set_collection(self):
"""
@@ -155,7 +162,7 @@ class TestChromaDbCollection(unittest.TestCase):
app = App(config=AppConfig(collect_metrics=False))
app.set_collection("test_collection")
self.assertEqual(app.collection.name, "test_collection")
self.assertEqual(app.db.collection.name, "test_collection")
def test_changes_encapsulated(self):
"""
@@ -169,7 +176,7 @@ class TestChromaDbCollection(unittest.TestCase):
# Collection should be empty when created
self.assertEqual(app.count(), 0)
app.collection.add(embeddings=[0, 0, 0], ids=["0"])
app.db.collection.add(embeddings=[0, 0, 0], ids=["0"])
# After adding, should contain one item
self.assertEqual(app.count(), 1)
@@ -178,7 +185,7 @@ class TestChromaDbCollection(unittest.TestCase):
self.assertEqual(app.count(), 0)
# Adding to a new collection should not affect the existing collection
app.collection.add(embeddings=[0, 0, 0], ids=["0"])
app.db.collection.add(embeddings=[0, 0, 0], ids=["0"])
app.set_collection("test_collection_1")
# Should still be 1, not 2.
self.assertEqual(app.count(), 1)
@@ -192,7 +199,7 @@ class TestChromaDbCollection(unittest.TestCase):
app = App(config=AppConfig(collect_metrics=False))
app.set_collection("test_collection_1")
app.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
del app
app = App(config=AppConfig(collect_metrics=False))
@@ -213,13 +220,13 @@ class TestChromaDbCollection(unittest.TestCase):
app2 = App(AppConfig(collection_name="test_collection_2", collect_metrics=False))
# app2 has been created last, but adding to app1 will still write to collection 1.
app1.collection.add(embeddings=[0, 0, 0], ids=["0"])
self.assertEqual(app1.count(), 1)
self.assertEqual(app2.count(), 0)
app1.db.collection.add(embeddings=[0, 0, 0], ids=["0"])
self.assertEqual(app1.db.count(), 1)
self.assertEqual(app2.db.count(), 0)
# Add data
app1.collection.add(embeddings=[[0, 0, 0], [1, 1, 1]], ids=["1", "2"])
app2.collection.add(embeddings=[0, 0, 0], ids=["0"])
app1.db.collection.add(embeddings=[[0, 0, 0], [1, 1, 1]], ids=["1", "2"])
app2.db.collection.add(embeddings=[0, 0, 0], ids=["0"])
# Swap names and test
app1.set_collection("test_collection_2")
@@ -235,12 +242,14 @@ class TestChromaDbCollection(unittest.TestCase):
self.app_with_settings.reset()
# Create two apps
app1 = App(AppConfig(collection_name="one_collection", id="new_app_id_1", collect_metrics=False))
app2 = App(AppConfig(collection_name="one_collection", id="new_app_id_2", collect_metrics=False))
app1 = App(AppConfig(id="new_app_id_1", collect_metrics=False))
app1.set_collection("one_collection")
app2 = App(AppConfig(id="new_app_id_2", collect_metrics=False))
app2.set_collection("one_collection")
# Add data
app1.collection.add(embeddings=[[0, 0, 0], [1, 1, 1]], ids=["0", "1"])
app2.collection.add(embeddings=[0, 0, 0], ids=["2"])
app1.db.collection.add(embeddings=[[0, 0, 0], [1, 1, 1]], ids=["0", "1"])
app2.db.collection.add(embeddings=[0, 0, 0], ids=["2"])
# Both should have the same collection
self.assertEqual(app1.count(), 3)
@@ -255,25 +264,20 @@ class TestChromaDbCollection(unittest.TestCase):
# Create four apps.
# app1, which we are about to reset, shares a collection with one, an id with another, and nothing with the last.
app1 = App(
CustomAppConfig(
collection_name="one_collection",
id="new_app_id_1",
collect_metrics=False,
provider=Providers.OPENAI,
embedding_fn=EmbeddingFunctions.OPENAI,
chroma_settings={"allow_reset": True},
)
)
app2 = App(AppConfig(collection_name="one_collection", id="new_app_id_2", collect_metrics=False))
app3 = App(AppConfig(collection_name="three_collection", id="new_app_id_1", collect_metrics=False))
app4 = App(AppConfig(collection_name="four_collection", id="new_app_id_4", collect_metrics=False))
app1 = App(AppConfig(id="new_app_id_1", collect_metrics=False), chromadb_config=self.chroma_config)
app1.set_collection("one_collection")
app2 = App(AppConfig(id="new_app_id_2", collect_metrics=False))
app2.set_collection("one_collection")
app3 = App(AppConfig(id="new_app_id_1", collect_metrics=False))
app3.set_collection("three_collection")
app4 = App(AppConfig(id="new_app_id_4", collect_metrics=False))
app4.set_collection("four_collection")
# Each one of them gets data
app1.collection.add(embeddings=[0, 0, 0], ids=["1"])
app2.collection.add(embeddings=[0, 0, 0], ids=["2"])
app3.collection.add(embeddings=[0, 0, 0], ids=["3"])
app4.collection.add(embeddings=[0, 0, 0], ids=["4"])
app1.db.collection.add(embeddings=[0, 0, 0], ids=["1"])
app2.db.collection.add(embeddings=[0, 0, 0], ids=["2"])
app3.db.collection.add(embeddings=[0, 0, 0], ids=["3"])
app4.db.collection.add(embeddings=[0, 0, 0], ids=["4"])
# Resetting the first one should reset them all.
app1.reset()

View File

@@ -1,7 +1,7 @@
import unittest
from unittest.mock import Mock
from embedchain.config import ElasticsearchDBConfig
from embedchain.embedder.base_embedder import BaseEmbedder
from embedchain.vectordb.elasticsearch_db import ElasticsearchDB
@@ -10,24 +10,20 @@ class TestEsDB(unittest.TestCase):
self.es_config = ElasticsearchDBConfig()
self.vector_dim = 384
def test_init_with_invalid_embedding_fn(self):
# Test if an exception is raised when an invalid embedding_fn is provided
with self.assertRaises(ValueError):
ElasticsearchDB(embedding_fn=None)
def test_init_with_invalid_es_config(self):
# Test if an exception is raised when an invalid es_config is provided
with self.assertRaises(ValueError):
ElasticsearchDB(embedding_fn=Mock(), es_config=None)
ElasticsearchDB(es_config=None)
def test_init_with_invalid_vector_dim(self):
# Test if an exception is raised when an invalid vector_dim is provided
embedder = BaseEmbedder()
embedder.set_vector_dimension(None)
with self.assertRaises(ValueError):
ElasticsearchDB(embedding_fn=Mock(), es_config=self.es_config, vector_dim=None)
ElasticsearchDB(es_config=self.es_config)
def test_init_with_invalid_collection_name(self):
# Test if an exception is raised when an invalid collection_name is provided
self.es_config.collection_name = None
with self.assertRaises(ValueError):
ElasticsearchDB(
embedding_fn=Mock(), es_config=self.es_config, vector_dim=self.vector_dim, collection_name=None
)
ElasticsearchDB(es_config=self.es_config)