refactor: classes and configs (#528)
This commit is contained in:
@@ -3,8 +3,10 @@
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
from chromadb.config import Settings
|
||||
|
||||
from embedchain import App
|
||||
from embedchain.config import AppConfig, CustomAppConfig
|
||||
from embedchain.config import AppConfig, ChromaDbConfig
|
||||
from embedchain.models import EmbeddingFunctions, Providers
|
||||
from embedchain.vectordb.chroma_db import ChromaDB
|
||||
|
||||
@@ -16,8 +18,9 @@ class TestChromaDbHosts(unittest.TestCase):
|
||||
"""
|
||||
host = "test-host"
|
||||
port = "1234"
|
||||
config = ChromaDbConfig(host=host, port=port)
|
||||
|
||||
db = ChromaDB(host=host, port=port, embedding_fn=len)
|
||||
db = ChromaDB(config=config)
|
||||
settings = db.client.get_settings()
|
||||
self.assertEqual(settings.chroma_server_host, host)
|
||||
self.assertEqual(settings.chroma_server_http_port, port)
|
||||
@@ -31,7 +34,8 @@ class TestChromaDbHosts(unittest.TestCase):
|
||||
"chroma_client_auth_credentials": "admin:admin",
|
||||
}
|
||||
|
||||
db = ChromaDB(host=host, port=port, embedding_fn=len, chroma_settings=chroma_auth_settings)
|
||||
config = ChromaDbConfig(host=host, port=port, chroma_settings=chroma_auth_settings)
|
||||
db = ChromaDB(config=config)
|
||||
settings = db.client.get_settings()
|
||||
self.assertEqual(settings.chroma_server_host, host)
|
||||
self.assertEqual(settings.chroma_server_http_port, port)
|
||||
@@ -44,37 +48,41 @@ class TestChromaDbHosts(unittest.TestCase):
|
||||
# Review this test
|
||||
class TestChromaDbHostsInit(unittest.TestCase):
|
||||
@patch("embedchain.vectordb.chroma_db.chromadb.Client")
|
||||
def test_init_with_host_and_port(self, mock_client):
|
||||
def test_app_init_with_host_and_port(self, mock_client):
|
||||
"""
|
||||
Test if the `App` instance is initialized with the correct host and port values.
|
||||
"""
|
||||
host = "test-host"
|
||||
port = "1234"
|
||||
|
||||
config = AppConfig(host=host, port=port, collect_metrics=False)
|
||||
config = AppConfig(collect_metrics=False)
|
||||
chromadb_config = ChromaDbConfig(host=host, port=port)
|
||||
|
||||
_app = App(config)
|
||||
_app = App(config, chromadb_config=chromadb_config)
|
||||
|
||||
# self.assertEqual(mock_client.call_args[0][0].chroma_server_host, host)
|
||||
# self.assertEqual(mock_client.call_args[0][0].chroma_server_http_port, port)
|
||||
called_settings: Settings = mock_client.call_args[0][0]
|
||||
|
||||
self.assertEqual(called_settings.chroma_server_host, host)
|
||||
self.assertEqual(called_settings.chroma_server_http_port, port)
|
||||
|
||||
|
||||
class TestChromaDbHostsNone(unittest.TestCase):
|
||||
@patch("embedchain.vectordb.chroma_db.chromadb.Client")
|
||||
def test_init_with_host_and_port(self, mock_client):
|
||||
def test_init_with_host_and_port_none(self, mock_client):
|
||||
"""
|
||||
Test if the `App` instance is initialized without default hosts and ports.
|
||||
"""
|
||||
|
||||
_app = App(config=AppConfig(collect_metrics=False))
|
||||
|
||||
self.assertEqual(mock_client.call_args[0][0].chroma_server_host, None)
|
||||
self.assertEqual(mock_client.call_args[0][0].chroma_server_http_port, None)
|
||||
called_settings: Settings = mock_client.call_args[0][0]
|
||||
self.assertEqual(called_settings.chroma_server_host, None)
|
||||
self.assertEqual(called_settings.chroma_server_http_port, None)
|
||||
|
||||
|
||||
class TestChromaDbHostsLoglevel(unittest.TestCase):
|
||||
@patch("embedchain.vectordb.chroma_db.chromadb.Client")
|
||||
def test_init_with_host_and_port(self, mock_client):
|
||||
def test_init_with_host_and_port_log_level(self, mock_client):
|
||||
"""
|
||||
Test if the `App` instance is initialized without a config that does not contain default hosts and ports.
|
||||
"""
|
||||
@@ -87,11 +95,10 @@ class TestChromaDbHostsLoglevel(unittest.TestCase):
|
||||
|
||||
|
||||
class TestChromaDbDuplicateHandling:
|
||||
app_with_settings = App(
|
||||
CustomAppConfig(
|
||||
provider=Providers.OPENAI, embedding_fn=EmbeddingFunctions.OPENAI, chroma_settings={"allow_reset": True}
|
||||
)
|
||||
)
|
||||
chroma_settings = {"allow_reset": True}
|
||||
chroma_config = ChromaDbConfig(chroma_settings=chroma_settings)
|
||||
app_config = AppConfig(collection_name=False, collect_metrics=False)
|
||||
app_with_settings = App(config=app_config, chromadb_config=chroma_config)
|
||||
|
||||
def test_duplicates_throw_warning(self, caplog):
|
||||
"""
|
||||
@@ -101,8 +108,8 @@ class TestChromaDbDuplicateHandling:
|
||||
self.app_with_settings.reset()
|
||||
|
||||
app = App(config=AppConfig(collect_metrics=False))
|
||||
app.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
|
||||
app.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
|
||||
app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
|
||||
app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
|
||||
assert "Insert of existing embedding ID: 0" in caplog.text
|
||||
assert "Add of existing embedding ID: 0" in caplog.text
|
||||
|
||||
@@ -117,19 +124,18 @@ class TestChromaDbDuplicateHandling:
|
||||
|
||||
app = App(config=AppConfig(collect_metrics=False))
|
||||
app.set_collection("test_collection_1")
|
||||
app.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
|
||||
app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
|
||||
app.set_collection("test_collection_2")
|
||||
app.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
|
||||
app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
|
||||
assert "Insert of existing embedding ID: 0" not in caplog.text # not
|
||||
assert "Add of existing embedding ID: 0" not in caplog.text # not
|
||||
|
||||
|
||||
class TestChromaDbCollection(unittest.TestCase):
|
||||
app_with_settings = App(
|
||||
CustomAppConfig(
|
||||
provider=Providers.OPENAI, embedding_fn=EmbeddingFunctions.OPENAI, chroma_settings={"allow_reset": True}
|
||||
)
|
||||
)
|
||||
chroma_settings = {"allow_reset": True}
|
||||
chroma_config = ChromaDbConfig(chroma_settings=chroma_settings)
|
||||
app_config = AppConfig(collection_name=False, collect_metrics=False)
|
||||
app_with_settings = App(config=app_config, chromadb_config=chroma_config)
|
||||
|
||||
def test_init_with_default_collection(self):
|
||||
"""
|
||||
@@ -137,16 +143,17 @@ class TestChromaDbCollection(unittest.TestCase):
|
||||
"""
|
||||
app = App(config=AppConfig(collect_metrics=False))
|
||||
|
||||
self.assertEqual(app.collection.name, "embedchain_store")
|
||||
self.assertEqual(app.db.collection.name, "embedchain_store")
|
||||
|
||||
def test_init_with_custom_collection(self):
|
||||
"""
|
||||
Test if the `App` instance is initialized with the correct custom collection name.
|
||||
"""
|
||||
config = AppConfig(collection_name="test_collection", collect_metrics=False)
|
||||
app = App(config)
|
||||
config = AppConfig(collect_metrics=False)
|
||||
app = App(config=config)
|
||||
app.set_collection(collection_name="test_collection")
|
||||
|
||||
self.assertEqual(app.collection.name, "test_collection")
|
||||
self.assertEqual(app.db.collection.name, "test_collection")
|
||||
|
||||
def test_set_collection(self):
|
||||
"""
|
||||
@@ -155,7 +162,7 @@ class TestChromaDbCollection(unittest.TestCase):
|
||||
app = App(config=AppConfig(collect_metrics=False))
|
||||
app.set_collection("test_collection")
|
||||
|
||||
self.assertEqual(app.collection.name, "test_collection")
|
||||
self.assertEqual(app.db.collection.name, "test_collection")
|
||||
|
||||
def test_changes_encapsulated(self):
|
||||
"""
|
||||
@@ -169,7 +176,7 @@ class TestChromaDbCollection(unittest.TestCase):
|
||||
# Collection should be empty when created
|
||||
self.assertEqual(app.count(), 0)
|
||||
|
||||
app.collection.add(embeddings=[0, 0, 0], ids=["0"])
|
||||
app.db.collection.add(embeddings=[0, 0, 0], ids=["0"])
|
||||
# After adding, should contain one item
|
||||
self.assertEqual(app.count(), 1)
|
||||
|
||||
@@ -178,7 +185,7 @@ class TestChromaDbCollection(unittest.TestCase):
|
||||
self.assertEqual(app.count(), 0)
|
||||
|
||||
# Adding to new collection should not effect existing collection
|
||||
app.collection.add(embeddings=[0, 0, 0], ids=["0"])
|
||||
app.db.collection.add(embeddings=[0, 0, 0], ids=["0"])
|
||||
app.set_collection("test_collection_1")
|
||||
# Should still be 1, not 2.
|
||||
self.assertEqual(app.count(), 1)
|
||||
@@ -192,7 +199,7 @@ class TestChromaDbCollection(unittest.TestCase):
|
||||
|
||||
app = App(config=AppConfig(collect_metrics=False))
|
||||
app.set_collection("test_collection_1")
|
||||
app.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
|
||||
app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
|
||||
del app
|
||||
|
||||
app = App(config=AppConfig(collect_metrics=False))
|
||||
@@ -213,13 +220,13 @@ class TestChromaDbCollection(unittest.TestCase):
|
||||
app2 = App(AppConfig(collection_name="test_collection_2", collect_metrics=False))
|
||||
|
||||
# app2 has been created last, but adding to app1 will still write to collection 1.
|
||||
app1.collection.add(embeddings=[0, 0, 0], ids=["0"])
|
||||
self.assertEqual(app1.count(), 1)
|
||||
self.assertEqual(app2.count(), 0)
|
||||
app1.db.collection.add(embeddings=[0, 0, 0], ids=["0"])
|
||||
self.assertEqual(app1.db.count(), 1)
|
||||
self.assertEqual(app2.db.count(), 0)
|
||||
|
||||
# Add data
|
||||
app1.collection.add(embeddings=[[0, 0, 0], [1, 1, 1]], ids=["1", "2"])
|
||||
app2.collection.add(embeddings=[0, 0, 0], ids=["0"])
|
||||
app1.db.collection.add(embeddings=[[0, 0, 0], [1, 1, 1]], ids=["1", "2"])
|
||||
app2.db.collection.add(embeddings=[0, 0, 0], ids=["0"])
|
||||
|
||||
# Swap names and test
|
||||
app1.set_collection("test_collection_2")
|
||||
@@ -235,12 +242,14 @@ class TestChromaDbCollection(unittest.TestCase):
|
||||
self.app_with_settings.reset()
|
||||
|
||||
# Create two apps
|
||||
app1 = App(AppConfig(collection_name="one_collection", id="new_app_id_1", collect_metrics=False))
|
||||
app2 = App(AppConfig(collection_name="one_collection", id="new_app_id_2", collect_metrics=False))
|
||||
app1 = App(AppConfig(id="new_app_id_1", collect_metrics=False))
|
||||
app1.set_collection("one_collection")
|
||||
app2 = App(AppConfig(id="new_app_id_2", collect_metrics=False))
|
||||
app2.set_collection("one_collection")
|
||||
|
||||
# Add data
|
||||
app1.collection.add(embeddings=[[0, 0, 0], [1, 1, 1]], ids=["0", "1"])
|
||||
app2.collection.add(embeddings=[0, 0, 0], ids=["2"])
|
||||
app1.db.collection.add(embeddings=[[0, 0, 0], [1, 1, 1]], ids=["0", "1"])
|
||||
app2.db.collection.add(embeddings=[0, 0, 0], ids=["2"])
|
||||
|
||||
# Both should have the same collection
|
||||
self.assertEqual(app1.count(), 3)
|
||||
@@ -255,25 +264,20 @@ class TestChromaDbCollection(unittest.TestCase):
|
||||
|
||||
# Create four apps.
|
||||
# app1, which we are about to reset, shares an app with one, and an id with the other, none with the last.
|
||||
app1 = App(
|
||||
CustomAppConfig(
|
||||
collection_name="one_collection",
|
||||
id="new_app_id_1",
|
||||
collect_metrics=False,
|
||||
provider=Providers.OPENAI,
|
||||
embedding_fn=EmbeddingFunctions.OPENAI,
|
||||
chroma_settings={"allow_reset": True},
|
||||
)
|
||||
)
|
||||
app2 = App(AppConfig(collection_name="one_collection", id="new_app_id_2", collect_metrics=False))
|
||||
app3 = App(AppConfig(collection_name="three_collection", id="new_app_id_1", collect_metrics=False))
|
||||
app4 = App(AppConfig(collection_name="four_collection", id="new_app_id_4", collect_metrics=False))
|
||||
app1 = App(AppConfig(id="new_app_id_1", collect_metrics=False), chromadb_config=self.chroma_config)
|
||||
app1.set_collection("one_collection")
|
||||
app2 = App(AppConfig(id="new_app_id_2", collect_metrics=False))
|
||||
app2.set_collection("one_collection")
|
||||
app3 = App(AppConfig(id="new_app_id_1", collect_metrics=False))
|
||||
app3.set_collection("three_collection")
|
||||
app4 = App(AppConfig(id="new_app_id_4", collect_metrics=False))
|
||||
app4.set_collection("four_collection")
|
||||
|
||||
# Each one of them get data
|
||||
app1.collection.add(embeddings=[0, 0, 0], ids=["1"])
|
||||
app2.collection.add(embeddings=[0, 0, 0], ids=["2"])
|
||||
app3.collection.add(embeddings=[0, 0, 0], ids=["3"])
|
||||
app4.collection.add(embeddings=[0, 0, 0], ids=["4"])
|
||||
app1.db.collection.add(embeddings=[0, 0, 0], ids=["1"])
|
||||
app2.db.collection.add(embeddings=[0, 0, 0], ids=["2"])
|
||||
app3.db.collection.add(embeddings=[0, 0, 0], ids=["3"])
|
||||
app4.db.collection.add(embeddings=[0, 0, 0], ids=["4"])
|
||||
|
||||
# Resetting the first one should reset them all.
|
||||
app1.reset()
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import unittest
|
||||
from unittest.mock import Mock
|
||||
|
||||
from embedchain.config import ElasticsearchDBConfig
|
||||
from embedchain.embedder.base_embedder import BaseEmbedder
|
||||
from embedchain.vectordb.elasticsearch_db import ElasticsearchDB
|
||||
|
||||
|
||||
@@ -10,24 +10,20 @@ class TestEsDB(unittest.TestCase):
|
||||
self.es_config = ElasticsearchDBConfig()
|
||||
self.vector_dim = 384
|
||||
|
||||
def test_init_with_invalid_embedding_fn(self):
|
||||
# Test if an exception is raised when an invalid embedding_fn is provided
|
||||
with self.assertRaises(ValueError):
|
||||
ElasticsearchDB(embedding_fn=None)
|
||||
|
||||
def test_init_with_invalid_es_config(self):
|
||||
# Test if an exception is raised when an invalid es_config is provided
|
||||
with self.assertRaises(ValueError):
|
||||
ElasticsearchDB(embedding_fn=Mock(), es_config=None)
|
||||
ElasticsearchDB(es_config=None)
|
||||
|
||||
def test_init_with_invalid_vector_dim(self):
|
||||
# Test if an exception is raised when an invalid vector_dim is provided
|
||||
embedder = BaseEmbedder()
|
||||
embedder.set_vector_dimension(None)
|
||||
with self.assertRaises(ValueError):
|
||||
ElasticsearchDB(embedding_fn=Mock(), es_config=self.es_config, vector_dim=None)
|
||||
ElasticsearchDB(es_config=self.es_config)
|
||||
|
||||
def test_init_with_invalid_collection_name(self):
|
||||
# Test if an exception is raised when an invalid collection_name is provided
|
||||
self.es_config.collection_name = None
|
||||
with self.assertRaises(ValueError):
|
||||
ElasticsearchDB(
|
||||
embedding_fn=Mock(), es_config=self.es_config, vector_dim=self.vector_dim, collection_name=None
|
||||
)
|
||||
ElasticsearchDB(es_config=self.es_config)
|
||||
|
||||
Reference in New Issue
Block a user