refactor: classes and configs (#528)

This commit is contained in:
cachho
2023-09-05 10:12:58 +02:00
committed by GitHub
parent 387b042a49
commit 344e7470f6
50 changed files with 1221 additions and 997 deletions

View File

@@ -3,8 +3,7 @@ import unittest
from unittest.mock import patch
from embedchain import App
from embedchain.config import AppConfig, CustomAppConfig
from embedchain.models import EmbeddingFunctions, Providers
from embedchain.config import AppConfig, ChromaDbConfig
class TestChromaDbHostsLoglevel(unittest.TestCase):
@@ -13,8 +12,8 @@ class TestChromaDbHostsLoglevel(unittest.TestCase):
@patch("chromadb.api.models.Collection.Collection.add")
@patch("chromadb.api.models.Collection.Collection.get")
@patch("embedchain.embedchain.EmbedChain.retrieve_from_database")
@patch("embedchain.embedchain.EmbedChain.get_answer_from_llm")
@patch("embedchain.embedchain.EmbedChain.get_llm_model_answer")
@patch("embedchain.llm.base_llm.BaseLlm.get_answer_from_llm")
@patch("embedchain.llm.base_llm.BaseLlm.get_llm_model_answer")
def test_whole_app(
self,
_mock_get,
@@ -43,17 +42,14 @@ class TestChromaDbHostsLoglevel(unittest.TestCase):
"""
Test if the `App` instance is correctly reconstructed after a reset.
"""
app = App(
CustomAppConfig(
provider=Providers.OPENAI, embedding_fn=EmbeddingFunctions.OPENAI, chroma_settings={"allow_reset": True}
)
)
config = AppConfig(log_level="DEBUG", collect_metrics=False)
app = App(config=config, chromadb_config=ChromaDbConfig(chroma_settings={"allow_reset": True}))
app.reset()
# Make sure the client is still healthy
app.db.client.heartbeat()
# Make sure the collection exists, and can be added to
app.collection.add(
app.db.collection.add(
embeddings=[[1.1, 2.3, 3.2], [4.5, 6.9, 4.4], [1.1, 2.3, 3.2]],
metadatas=[
{"chapter": "3", "verse": "16"},

View File

@@ -59,7 +59,7 @@ class TestJsonSerializable(unittest.TestCase):
def test_recursive(self):
"""Test recursiveness with the real app"""
random_id = str(random.random())
config = AppConfig(id=random_id)
config = AppConfig(id=random_id, collect_metrics=False)
# config class is set under app.config.
app = App(config=config)
# w/o recursion it would just be <embedchain.config.apps.OpenSourceAppConfig.OpenSourceAppConfig object at x>
@@ -67,4 +67,5 @@ class TestJsonSerializable(unittest.TestCase):
new_app: App = App.deserialize(s)
# The id of the new app is the same as the first one.
self.assertEqual(random_id, new_app.config.id)
# We have proven that a nested class (app.config) can be serialized and deserialized just the same.
# TODO: test deeper recursion

View File

@@ -1,9 +1,11 @@
import os
import unittest
from unittest.mock import MagicMock, patch
from unittest.mock import patch, MagicMock
from embedchain import App
from embedchain.config import AppConfig, ChatConfig
from embedchain.config import AppConfig, BaseLlmConfig
from embedchain.llm.base_llm import BaseLlm
class TestApp(unittest.TestCase):
@@ -12,7 +14,7 @@ class TestApp(unittest.TestCase):
self.app = App(config=AppConfig(collect_metrics=False))
@patch.object(App, "retrieve_from_database", return_value=["Test context"])
@patch.object(App, "get_answer_from_llm", return_value="Test answer")
@patch.object(BaseLlm, "get_answer_from_llm", return_value="Test answer")
def test_chat_with_memory(self, mock_get_answer, mock_retrieve):
"""
This test checks the functionality of the 'chat' method in the App class with respect to the chat history
@@ -28,13 +30,36 @@ class TestApp(unittest.TestCase):
The test isolates the 'chat' method behavior by mocking out 'retrieve_from_database', 'get_answer_from_llm' and
'memory' methods.
"""
app = App()
config = AppConfig(collect_metrics=False)
app = App(config=config)
first_answer = app.chat("Test query 1")
self.assertEqual(first_answer, "Test answer")
self.assertEqual(len(app.memory.chat_memory.messages), 2)
self.assertEqual(len(app.llm.memory.chat_memory.messages), 2)
self.assertEqual(len(app.llm.history.splitlines()), 2)
second_answer = app.chat("Test query 2")
self.assertEqual(second_answer, "Test answer")
self.assertEqual(len(app.memory.chat_memory.messages), 4)
self.assertEqual(len(app.llm.memory.chat_memory.messages), 4)
self.assertEqual(len(app.llm.history.splitlines()), 4)
@patch.object(App, "retrieve_from_database", return_value=["Test context"])
@patch.object(BaseLlm, "get_answer_from_llm", return_value="Test answer")
def test_template_replacement(self, mock_get_answer, mock_retrieve):
"""
Tests that if a default template is used and it doesn't contain history,
the default template is swapped in.
Also tests that a dry run does not change the history
"""
config = AppConfig(collect_metrics=False)
app = App(config=config)
first_answer = app.chat("Test query 1")
self.assertEqual(first_answer, "Test answer")
self.assertEqual(len(app.llm.history.splitlines()), 2)
history = app.llm.history
dry_run = app.chat("Test query 2", dry_run=True)
self.assertIn("History:", dry_run)
self.assertEqual(history, app.llm.history)
self.assertEqual(len(app.llm.history.splitlines()), 2)
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
def test_chat_with_where_in_params(self):
@@ -57,13 +82,14 @@ class TestApp(unittest.TestCase):
"""
with patch.object(self.app, "retrieve_from_database") as mock_retrieve:
mock_retrieve.return_value = ["Test context"]
with patch.object(self.app, "get_llm_model_answer") as mock_answer:
with patch.object(self.app.llm, "get_llm_model_answer") as mock_answer:
mock_answer.return_value = "Test answer"
answer = self.app.chat("Test chat", where={"attribute": "value"})
answer = self.app.chat("Test query", where={"attribute": "value"})
self.assertEqual(answer, "Test answer")
self.assertEqual(mock_retrieve.call_args[0][0], "Test chat")
self.assertEqual(mock_retrieve.call_args[0][2], {"attribute": "value"})
_args, kwargs = mock_retrieve.call_args
self.assertEqual(kwargs.get('input_query'), "Test query")
self.assertEqual(kwargs.get('where'), {"attribute": "value"})
mock_answer.assert_called_once()
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
@@ -85,15 +111,15 @@ class TestApp(unittest.TestCase):
The test isolates the 'chat' method behavior by mocking out 'retrieve_from_database' and
'get_llm_model_answer' methods.
"""
with patch.object(self.app, "retrieve_from_database") as mock_retrieve:
mock_retrieve.return_value = ["Test context"]
with patch.object(self.app, "get_llm_model_answer") as mock_answer:
mock_answer.return_value = "Test answer"
chatConfig = ChatConfig(where={"attribute": "value"})
answer = self.app.chat("Test chat", chatConfig)
with patch.object(self.app.llm, "get_llm_model_answer") as mock_answer:
mock_answer.return_value = "Test answer"
with patch.object(self.app.db, "query") as mock_database_query:
mock_database_query.return_value = ["Test context"]
queryConfig = BaseLlmConfig(where={"attribute": "value"})
answer = self.app.chat("Test query", queryConfig)
self.assertEqual(answer, "Test answer")
self.assertEqual(mock_retrieve.call_args[0][0], "Test chat")
self.assertEqual(mock_retrieve.call_args[0][1].where, {"attribute": "value"})
self.assertIsInstance(mock_retrieve.call_args[0][1], ChatConfig)
_args, kwargs = mock_database_query.call_args
self.assertEqual(kwargs.get('input_query'), "Test query")
self.assertEqual(kwargs.get('where'), {"attribute": "value"})
mock_answer.assert_called_once()

View File

@@ -2,7 +2,7 @@ import unittest
from string import Template
from embedchain import App
from embedchain.config import AppConfig, QueryConfig
from embedchain.config import AppConfig, BaseLlmConfig
class TestGeneratePrompt(unittest.TestCase):
@@ -23,10 +23,11 @@ class TestGeneratePrompt(unittest.TestCase):
input_query = "Test query"
contexts = ["Context 1", "Context 2", "Context 3"]
template = "You are a bot. Context: ${context} - Query: ${query} - Helpful answer:"
config = QueryConfig(template=Template(template))
config = BaseLlmConfig(template=Template(template))
self.app.llm.config = config
# Execute
result = self.app.generate_prompt(input_query, contexts, config)
result = self.app.llm.generate_prompt(input_query, contexts)
# Assert
expected_result = (
@@ -45,10 +46,11 @@ class TestGeneratePrompt(unittest.TestCase):
# Setup
input_query = "Test query"
contexts = ["Context 1", "Context 2", "Context 3"]
config = QueryConfig()
config = BaseLlmConfig()
# Execute
result = self.app.generate_prompt(input_query, contexts, config)
self.app.llm.config = config
result = self.app.llm.generate_prompt(input_query, contexts)
# Assert
expected_result = config.template.substitute(context="Context 1 | Context 2 | Context 3", query=input_query)
@@ -58,9 +60,11 @@ class TestGeneratePrompt(unittest.TestCase):
"""
Test the 'generate_prompt' method with QueryConfig containing a history attribute.
"""
config = QueryConfig(history=["Past context 1", "Past context 2"])
config = BaseLlmConfig()
config.template = Template("Context: $context | Query: $query | History: $history")
prompt = self.app.generate_prompt("Test query", ["Test context"], config)
self.app.llm.config = config
self.app.llm.set_history(["Past context 1", "Past context 2"])
prompt = self.app.llm.generate_prompt("Test query", ["Test context"])
expected_prompt = "Context: Test context | Query: Test query | History: ['Past context 1', 'Past context 2']"
self.assertEqual(prompt, expected_prompt)

View File

@@ -3,7 +3,7 @@ import unittest
from unittest.mock import MagicMock, patch
from embedchain import App
from embedchain.config import AppConfig, QueryConfig
from embedchain.config import AppConfig, BaseLlmConfig
class TestApp(unittest.TestCase):
@@ -33,29 +33,35 @@ class TestApp(unittest.TestCase):
"""
with patch.object(self.app, "retrieve_from_database") as mock_retrieve:
mock_retrieve.return_value = ["Test context"]
with patch.object(self.app, "get_llm_model_answer") as mock_answer:
with patch.object(self.app.llm, "get_llm_model_answer") as mock_answer:
mock_answer.return_value = "Test answer"
answer = self.app.query("Test query")
_answer = self.app.query(input_query="Test query")
self.assertEqual(answer, "Test answer")
self.assertEqual(mock_retrieve.call_args[0][0], "Test query")
self.assertIsInstance(mock_retrieve.call_args[0][1], QueryConfig)
# Ensure retrieve_from_database was called
mock_retrieve.assert_called_once()
# Check the call arguments
args, kwargs = mock_retrieve.call_args
input_query_arg = kwargs.get("input_query")
self.assertEqual(input_query_arg, "Test query")
mock_answer.assert_called_once()
@patch("openai.ChatCompletion.create")
def test_query_config_app_passing(self, mock_create):
mock_create.return_value = {"choices": [{"message": {"content": "response"}}]} # Mock response
config = AppConfig()
chat_config = QueryConfig(system_prompt="Test system prompt")
app = App(config=config)
config = AppConfig(collect_metrics=False)
chat_config = BaseLlmConfig(system_prompt="Test system prompt")
app = App(config=config, llm_config=chat_config)
app.get_llm_model_answer("Test query", chat_config)
app.llm.get_llm_model_answer("Test query")
# Test system_prompt: Check that the 'create' method was called with the correct 'messages' argument
messages_arg = mock_create.call_args.kwargs["messages"]
self.assertEqual(messages_arg[0]["role"], "system")
self.assertEqual(messages_arg[0]["content"], "Test system prompt")
self.assertTrue(messages_arg[0].get("role"), "system")
self.assertEqual(messages_arg[0].get("content"), "Test system prompt")
self.assertTrue(messages_arg[1].get("role"), "user")
self.assertEqual(messages_arg[1].get("content"), "Test query")
# TODO: Add tests for other config variables
@@ -63,16 +69,18 @@ class TestApp(unittest.TestCase):
def test_app_passing(self, mock_create):
mock_create.return_value = {"choices": [{"message": {"content": "response"}}]} # Mock response
config = AppConfig()
chat_config = QueryConfig()
app = App(config=config, system_prompt="Test system prompt")
config = AppConfig(collect_metrics=False)
chat_config = BaseLlmConfig()
app = App(config=config, llm_config=chat_config, system_prompt="Test system prompt")
app.get_llm_model_answer("Test query", chat_config)
self.assertEqual(app.llm.config.system_prompt, "Test system prompt")
app.llm.get_llm_model_answer("Test query")
# Test system_prompt: Check that the 'create' method was called with the correct 'messages' argument
messages_arg = mock_create.call_args.kwargs["messages"]
self.assertEqual(messages_arg[0]["role"], "system")
self.assertEqual(messages_arg[0]["content"], "Test system prompt")
self.assertTrue(messages_arg[0].get("role"), "system")
self.assertEqual(messages_arg[0].get("content"), "Test system prompt")
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
def test_query_with_where_in_params(self):
@@ -95,13 +103,14 @@ class TestApp(unittest.TestCase):
"""
with patch.object(self.app, "retrieve_from_database") as mock_retrieve:
mock_retrieve.return_value = ["Test context"]
with patch.object(self.app, "get_llm_model_answer") as mock_answer:
with patch.object(self.app.llm, "get_llm_model_answer") as mock_answer:
mock_answer.return_value = "Test answer"
answer = self.app.query("Test query", where={"attribute": "value"})
self.assertEqual(answer, "Test answer")
self.assertEqual(mock_retrieve.call_args[0][0], "Test query")
self.assertEqual(mock_retrieve.call_args[0][2], {"attribute": "value"})
_args, kwargs = mock_retrieve.call_args
self.assertEqual(kwargs.get('input_query'), "Test query")
self.assertEqual(kwargs.get('where'), {"attribute": "value"})
mock_answer.assert_called_once()
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
@@ -123,15 +132,16 @@ class TestApp(unittest.TestCase):
The test isolates the 'query' method behavior by mocking out 'retrieve_from_database' and
'get_llm_model_answer' methods.
"""
with patch.object(self.app, "retrieve_from_database") as mock_retrieve:
mock_retrieve.return_value = ["Test context"]
with patch.object(self.app, "get_llm_model_answer") as mock_answer:
mock_answer.return_value = "Test answer"
queryConfig = QueryConfig(where={"attribute": "value"})
with patch.object(self.app.llm, "get_llm_model_answer") as mock_answer:
mock_answer.return_value = "Test answer"
with patch.object(self.app.db, "query") as mock_database_query:
mock_database_query.return_value = ["Test context"]
queryConfig = BaseLlmConfig(where={"attribute": "value"})
answer = self.app.query("Test query", queryConfig)
self.assertEqual(answer, "Test answer")
self.assertEqual(mock_retrieve.call_args[0][0], "Test query")
self.assertEqual(mock_retrieve.call_args[0][1].where, {"attribute": "value"})
self.assertIsInstance(mock_retrieve.call_args[0][1], QueryConfig)
_args, kwargs = mock_database_query.call_args
self.assertEqual(kwargs.get('input_query'), "Test query")
self.assertEqual(kwargs.get('where'), {"attribute": "value"})
mock_answer.assert_called_once()

View File

@@ -3,8 +3,10 @@
import unittest
from unittest.mock import patch
from chromadb.config import Settings
from embedchain import App
from embedchain.config import AppConfig, CustomAppConfig
from embedchain.config import AppConfig, ChromaDbConfig
from embedchain.models import EmbeddingFunctions, Providers
from embedchain.vectordb.chroma_db import ChromaDB
@@ -16,8 +18,9 @@ class TestChromaDbHosts(unittest.TestCase):
"""
host = "test-host"
port = "1234"
config = ChromaDbConfig(host=host, port=port)
db = ChromaDB(host=host, port=port, embedding_fn=len)
db = ChromaDB(config=config)
settings = db.client.get_settings()
self.assertEqual(settings.chroma_server_host, host)
self.assertEqual(settings.chroma_server_http_port, port)
@@ -31,7 +34,8 @@ class TestChromaDbHosts(unittest.TestCase):
"chroma_client_auth_credentials": "admin:admin",
}
db = ChromaDB(host=host, port=port, embedding_fn=len, chroma_settings=chroma_auth_settings)
config = ChromaDbConfig(host=host, port=port, chroma_settings=chroma_auth_settings)
db = ChromaDB(config=config)
settings = db.client.get_settings()
self.assertEqual(settings.chroma_server_host, host)
self.assertEqual(settings.chroma_server_http_port, port)
@@ -44,37 +48,41 @@ class TestChromaDbHosts(unittest.TestCase):
# Review this test
class TestChromaDbHostsInit(unittest.TestCase):
@patch("embedchain.vectordb.chroma_db.chromadb.Client")
def test_init_with_host_and_port(self, mock_client):
def test_app_init_with_host_and_port(self, mock_client):
"""
Test if the `App` instance is initialized with the correct host and port values.
"""
host = "test-host"
port = "1234"
config = AppConfig(host=host, port=port, collect_metrics=False)
config = AppConfig(collect_metrics=False)
chromadb_config = ChromaDbConfig(host=host, port=port)
_app = App(config)
_app = App(config, chromadb_config=chromadb_config)
# self.assertEqual(mock_client.call_args[0][0].chroma_server_host, host)
# self.assertEqual(mock_client.call_args[0][0].chroma_server_http_port, port)
called_settings: Settings = mock_client.call_args[0][0]
self.assertEqual(called_settings.chroma_server_host, host)
self.assertEqual(called_settings.chroma_server_http_port, port)
class TestChromaDbHostsNone(unittest.TestCase):
@patch("embedchain.vectordb.chroma_db.chromadb.Client")
def test_init_with_host_and_port(self, mock_client):
def test_init_with_host_and_port_none(self, mock_client):
"""
Test if the `App` instance is initialized without default hosts and ports.
"""
_app = App(config=AppConfig(collect_metrics=False))
self.assertEqual(mock_client.call_args[0][0].chroma_server_host, None)
self.assertEqual(mock_client.call_args[0][0].chroma_server_http_port, None)
called_settings: Settings = mock_client.call_args[0][0]
self.assertEqual(called_settings.chroma_server_host, None)
self.assertEqual(called_settings.chroma_server_http_port, None)
class TestChromaDbHostsLoglevel(unittest.TestCase):
@patch("embedchain.vectordb.chroma_db.chromadb.Client")
def test_init_with_host_and_port(self, mock_client):
def test_init_with_host_and_port_log_level(self, mock_client):
"""
Test if the `App` instance is initialized without a config that does not contain default hosts and ports.
"""
@@ -87,11 +95,10 @@ class TestChromaDbHostsLoglevel(unittest.TestCase):
class TestChromaDbDuplicateHandling:
app_with_settings = App(
CustomAppConfig(
provider=Providers.OPENAI, embedding_fn=EmbeddingFunctions.OPENAI, chroma_settings={"allow_reset": True}
)
)
chroma_settings = {"allow_reset": True}
chroma_config = ChromaDbConfig(chroma_settings=chroma_settings)
app_config = AppConfig(collection_name=False, collect_metrics=False)
app_with_settings = App(config=app_config, chromadb_config=chroma_config)
def test_duplicates_throw_warning(self, caplog):
"""
@@ -101,8 +108,8 @@ class TestChromaDbDuplicateHandling:
self.app_with_settings.reset()
app = App(config=AppConfig(collect_metrics=False))
app.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
app.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
assert "Insert of existing embedding ID: 0" in caplog.text
assert "Add of existing embedding ID: 0" in caplog.text
@@ -117,19 +124,18 @@ class TestChromaDbDuplicateHandling:
app = App(config=AppConfig(collect_metrics=False))
app.set_collection("test_collection_1")
app.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
app.set_collection("test_collection_2")
app.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
assert "Insert of existing embedding ID: 0" not in caplog.text # not
assert "Add of existing embedding ID: 0" not in caplog.text # not
class TestChromaDbCollection(unittest.TestCase):
app_with_settings = App(
CustomAppConfig(
provider=Providers.OPENAI, embedding_fn=EmbeddingFunctions.OPENAI, chroma_settings={"allow_reset": True}
)
)
chroma_settings = {"allow_reset": True}
chroma_config = ChromaDbConfig(chroma_settings=chroma_settings)
app_config = AppConfig(collection_name=False, collect_metrics=False)
app_with_settings = App(config=app_config, chromadb_config=chroma_config)
def test_init_with_default_collection(self):
"""
@@ -137,16 +143,17 @@ class TestChromaDbCollection(unittest.TestCase):
"""
app = App(config=AppConfig(collect_metrics=False))
self.assertEqual(app.collection.name, "embedchain_store")
self.assertEqual(app.db.collection.name, "embedchain_store")
def test_init_with_custom_collection(self):
"""
Test if the `App` instance is initialized with the correct custom collection name.
"""
config = AppConfig(collection_name="test_collection", collect_metrics=False)
app = App(config)
config = AppConfig(collect_metrics=False)
app = App(config=config)
app.set_collection(collection_name="test_collection")
self.assertEqual(app.collection.name, "test_collection")
self.assertEqual(app.db.collection.name, "test_collection")
def test_set_collection(self):
"""
@@ -155,7 +162,7 @@ class TestChromaDbCollection(unittest.TestCase):
app = App(config=AppConfig(collect_metrics=False))
app.set_collection("test_collection")
self.assertEqual(app.collection.name, "test_collection")
self.assertEqual(app.db.collection.name, "test_collection")
def test_changes_encapsulated(self):
"""
@@ -169,7 +176,7 @@ class TestChromaDbCollection(unittest.TestCase):
# Collection should be empty when created
self.assertEqual(app.count(), 0)
app.collection.add(embeddings=[0, 0, 0], ids=["0"])
app.db.collection.add(embeddings=[0, 0, 0], ids=["0"])
# After adding, should contain one item
self.assertEqual(app.count(), 1)
@@ -178,7 +185,7 @@ class TestChromaDbCollection(unittest.TestCase):
self.assertEqual(app.count(), 0)
# Adding to new collection should not effect existing collection
app.collection.add(embeddings=[0, 0, 0], ids=["0"])
app.db.collection.add(embeddings=[0, 0, 0], ids=["0"])
app.set_collection("test_collection_1")
# Should still be 1, not 2.
self.assertEqual(app.count(), 1)
@@ -192,7 +199,7 @@ class TestChromaDbCollection(unittest.TestCase):
app = App(config=AppConfig(collect_metrics=False))
app.set_collection("test_collection_1")
app.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
del app
app = App(config=AppConfig(collect_metrics=False))
@@ -213,13 +220,13 @@ class TestChromaDbCollection(unittest.TestCase):
app2 = App(AppConfig(collection_name="test_collection_2", collect_metrics=False))
# app2 has been created last, but adding to app1 will still write to collection 1.
app1.collection.add(embeddings=[0, 0, 0], ids=["0"])
self.assertEqual(app1.count(), 1)
self.assertEqual(app2.count(), 0)
app1.db.collection.add(embeddings=[0, 0, 0], ids=["0"])
self.assertEqual(app1.db.count(), 1)
self.assertEqual(app2.db.count(), 0)
# Add data
app1.collection.add(embeddings=[[0, 0, 0], [1, 1, 1]], ids=["1", "2"])
app2.collection.add(embeddings=[0, 0, 0], ids=["0"])
app1.db.collection.add(embeddings=[[0, 0, 0], [1, 1, 1]], ids=["1", "2"])
app2.db.collection.add(embeddings=[0, 0, 0], ids=["0"])
# Swap names and test
app1.set_collection("test_collection_2")
@@ -235,12 +242,14 @@ class TestChromaDbCollection(unittest.TestCase):
self.app_with_settings.reset()
# Create two apps
app1 = App(AppConfig(collection_name="one_collection", id="new_app_id_1", collect_metrics=False))
app2 = App(AppConfig(collection_name="one_collection", id="new_app_id_2", collect_metrics=False))
app1 = App(AppConfig(id="new_app_id_1", collect_metrics=False))
app1.set_collection("one_collection")
app2 = App(AppConfig(id="new_app_id_2", collect_metrics=False))
app2.set_collection("one_collection")
# Add data
app1.collection.add(embeddings=[[0, 0, 0], [1, 1, 1]], ids=["0", "1"])
app2.collection.add(embeddings=[0, 0, 0], ids=["2"])
app1.db.collection.add(embeddings=[[0, 0, 0], [1, 1, 1]], ids=["0", "1"])
app2.db.collection.add(embeddings=[0, 0, 0], ids=["2"])
# Both should have the same collection
self.assertEqual(app1.count(), 3)
@@ -255,25 +264,20 @@ class TestChromaDbCollection(unittest.TestCase):
# Create four apps.
# app1, which we are about to reset, shares an app with one, and an id with the other, none with the last.
app1 = App(
CustomAppConfig(
collection_name="one_collection",
id="new_app_id_1",
collect_metrics=False,
provider=Providers.OPENAI,
embedding_fn=EmbeddingFunctions.OPENAI,
chroma_settings={"allow_reset": True},
)
)
app2 = App(AppConfig(collection_name="one_collection", id="new_app_id_2", collect_metrics=False))
app3 = App(AppConfig(collection_name="three_collection", id="new_app_id_1", collect_metrics=False))
app4 = App(AppConfig(collection_name="four_collection", id="new_app_id_4", collect_metrics=False))
app1 = App(AppConfig(id="new_app_id_1", collect_metrics=False), chromadb_config=self.chroma_config)
app1.set_collection("one_collection")
app2 = App(AppConfig(id="new_app_id_2", collect_metrics=False))
app2.set_collection("one_collection")
app3 = App(AppConfig(id="new_app_id_1", collect_metrics=False))
app3.set_collection("three_collection")
app4 = App(AppConfig(id="new_app_id_4", collect_metrics=False))
app4.set_collection("four_collection")
# Each one of them get data
app1.collection.add(embeddings=[0, 0, 0], ids=["1"])
app2.collection.add(embeddings=[0, 0, 0], ids=["2"])
app3.collection.add(embeddings=[0, 0, 0], ids=["3"])
app4.collection.add(embeddings=[0, 0, 0], ids=["4"])
app1.db.collection.add(embeddings=[0, 0, 0], ids=["1"])
app2.db.collection.add(embeddings=[0, 0, 0], ids=["2"])
app3.db.collection.add(embeddings=[0, 0, 0], ids=["3"])
app4.db.collection.add(embeddings=[0, 0, 0], ids=["4"])
# Resetting the first one should reset them all.
app1.reset()

View File

@@ -1,7 +1,7 @@
import unittest
from unittest.mock import Mock
from embedchain.config import ElasticsearchDBConfig
from embedchain.embedder.base_embedder import BaseEmbedder
from embedchain.vectordb.elasticsearch_db import ElasticsearchDB
@@ -10,24 +10,20 @@ class TestEsDB(unittest.TestCase):
self.es_config = ElasticsearchDBConfig()
self.vector_dim = 384
def test_init_with_invalid_embedding_fn(self):
# Test if an exception is raised when an invalid embedding_fn is provided
with self.assertRaises(ValueError):
ElasticsearchDB(embedding_fn=None)
def test_init_with_invalid_es_config(self):
# Test if an exception is raised when an invalid es_config is provided
with self.assertRaises(ValueError):
ElasticsearchDB(embedding_fn=Mock(), es_config=None)
ElasticsearchDB(es_config=None)
def test_init_with_invalid_vector_dim(self):
# Test if an exception is raised when an invalid vector_dim is provided
embedder = BaseEmbedder()
embedder.set_vector_dimension(None)
with self.assertRaises(ValueError):
ElasticsearchDB(embedding_fn=Mock(), es_config=self.es_config, vector_dim=None)
ElasticsearchDB(es_config=self.es_config)
def test_init_with_invalid_collection_name(self):
# Test if an exception is raised when an invalid collection_name is provided
self.es_config.collection_name = None
with self.assertRaises(ValueError):
ElasticsearchDB(
embedding_fn=Mock(), es_config=self.es_config, vector_dim=self.vector_dim, collection_name=None
)
ElasticsearchDB(es_config=self.es_config)