refactor: classes and configs (#528)
This commit is contained in:
@@ -3,8 +3,7 @@ import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
from embedchain import App
|
||||
from embedchain.config import AppConfig, CustomAppConfig
|
||||
from embedchain.models import EmbeddingFunctions, Providers
|
||||
from embedchain.config import AppConfig, ChromaDbConfig
|
||||
|
||||
|
||||
class TestChromaDbHostsLoglevel(unittest.TestCase):
|
||||
@@ -13,8 +12,8 @@ class TestChromaDbHostsLoglevel(unittest.TestCase):
|
||||
@patch("chromadb.api.models.Collection.Collection.add")
|
||||
@patch("chromadb.api.models.Collection.Collection.get")
|
||||
@patch("embedchain.embedchain.EmbedChain.retrieve_from_database")
|
||||
@patch("embedchain.embedchain.EmbedChain.get_answer_from_llm")
|
||||
@patch("embedchain.embedchain.EmbedChain.get_llm_model_answer")
|
||||
@patch("embedchain.llm.base_llm.BaseLlm.get_answer_from_llm")
|
||||
@patch("embedchain.llm.base_llm.BaseLlm.get_llm_model_answer")
|
||||
def test_whole_app(
|
||||
self,
|
||||
_mock_get,
|
||||
@@ -43,17 +42,14 @@ class TestChromaDbHostsLoglevel(unittest.TestCase):
|
||||
"""
|
||||
Test if the `App` instance is correctly reconstructed after a reset.
|
||||
"""
|
||||
app = App(
|
||||
CustomAppConfig(
|
||||
provider=Providers.OPENAI, embedding_fn=EmbeddingFunctions.OPENAI, chroma_settings={"allow_reset": True}
|
||||
)
|
||||
)
|
||||
config = AppConfig(log_level="DEBUG", collect_metrics=False)
|
||||
app = App(config=config, chromadb_config=ChromaDbConfig(chroma_settings={"allow_reset": True}))
|
||||
app.reset()
|
||||
|
||||
# Make sure the client is still healthy
|
||||
app.db.client.heartbeat()
|
||||
# Make sure the collection exists, and can be added to
|
||||
app.collection.add(
|
||||
app.db.collection.add(
|
||||
embeddings=[[1.1, 2.3, 3.2], [4.5, 6.9, 4.4], [1.1, 2.3, 3.2]],
|
||||
metadatas=[
|
||||
{"chapter": "3", "verse": "16"},
|
||||
|
||||
@@ -59,7 +59,7 @@ class TestJsonSerializable(unittest.TestCase):
|
||||
def test_recursive(self):
|
||||
"""Test recursiveness with the real app"""
|
||||
random_id = str(random.random())
|
||||
config = AppConfig(id=random_id)
|
||||
config = AppConfig(id=random_id, collect_metrics=False)
|
||||
# config class is set under app.config.
|
||||
app = App(config=config)
|
||||
# w/o recursion it would just be <embedchain.config.apps.OpenSourceAppConfig.OpenSourceAppConfig object at x>
|
||||
@@ -67,4 +67,5 @@ class TestJsonSerializable(unittest.TestCase):
|
||||
new_app: App = App.deserialize(s)
|
||||
# The id of the new app is the same as the first one.
|
||||
self.assertEqual(random_id, new_app.config.id)
|
||||
# We have proven that a nested class (app.config) can be serialized and deserialized just the same.
|
||||
# TODO: test deeper recursion
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
|
||||
import os
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from embedchain import App
|
||||
from embedchain.config import AppConfig, ChatConfig
|
||||
from embedchain.config import AppConfig, BaseLlmConfig
|
||||
from embedchain.llm.base_llm import BaseLlm
|
||||
|
||||
|
||||
class TestApp(unittest.TestCase):
|
||||
@@ -12,7 +14,7 @@ class TestApp(unittest.TestCase):
|
||||
self.app = App(config=AppConfig(collect_metrics=False))
|
||||
|
||||
@patch.object(App, "retrieve_from_database", return_value=["Test context"])
|
||||
@patch.object(App, "get_answer_from_llm", return_value="Test answer")
|
||||
@patch.object(BaseLlm, "get_answer_from_llm", return_value="Test answer")
|
||||
def test_chat_with_memory(self, mock_get_answer, mock_retrieve):
|
||||
"""
|
||||
This test checks the functionality of the 'chat' method in the App class with respect to the chat history
|
||||
@@ -28,13 +30,36 @@ class TestApp(unittest.TestCase):
|
||||
The test isolates the 'chat' method behavior by mocking out 'retrieve_from_database', 'get_answer_from_llm' and
|
||||
'memory' methods.
|
||||
"""
|
||||
app = App()
|
||||
config = AppConfig(collect_metrics=False)
|
||||
app = App(config=config)
|
||||
first_answer = app.chat("Test query 1")
|
||||
self.assertEqual(first_answer, "Test answer")
|
||||
self.assertEqual(len(app.memory.chat_memory.messages), 2)
|
||||
self.assertEqual(len(app.llm.memory.chat_memory.messages), 2)
|
||||
self.assertEqual(len(app.llm.history.splitlines()), 2)
|
||||
second_answer = app.chat("Test query 2")
|
||||
self.assertEqual(second_answer, "Test answer")
|
||||
self.assertEqual(len(app.memory.chat_memory.messages), 4)
|
||||
self.assertEqual(len(app.llm.memory.chat_memory.messages), 4)
|
||||
self.assertEqual(len(app.llm.history.splitlines()), 4)
|
||||
|
||||
@patch.object(App, "retrieve_from_database", return_value=["Test context"])
|
||||
@patch.object(BaseLlm, "get_answer_from_llm", return_value="Test answer")
|
||||
def test_template_replacement(self, mock_get_answer, mock_retrieve):
|
||||
"""
|
||||
Tests that if a default template is used and it doesn't contain history,
|
||||
the default template is swapped in.
|
||||
|
||||
Also tests that a dry run does not change the history
|
||||
"""
|
||||
config = AppConfig(collect_metrics=False)
|
||||
app = App(config=config)
|
||||
first_answer = app.chat("Test query 1")
|
||||
self.assertEqual(first_answer, "Test answer")
|
||||
self.assertEqual(len(app.llm.history.splitlines()), 2)
|
||||
history = app.llm.history
|
||||
dry_run = app.chat("Test query 2", dry_run=True)
|
||||
self.assertIn("History:", dry_run)
|
||||
self.assertEqual(history, app.llm.history)
|
||||
self.assertEqual(len(app.llm.history.splitlines()), 2)
|
||||
|
||||
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
|
||||
def test_chat_with_where_in_params(self):
|
||||
@@ -57,13 +82,14 @@ class TestApp(unittest.TestCase):
|
||||
"""
|
||||
with patch.object(self.app, "retrieve_from_database") as mock_retrieve:
|
||||
mock_retrieve.return_value = ["Test context"]
|
||||
with patch.object(self.app, "get_llm_model_answer") as mock_answer:
|
||||
with patch.object(self.app.llm, "get_llm_model_answer") as mock_answer:
|
||||
mock_answer.return_value = "Test answer"
|
||||
answer = self.app.chat("Test chat", where={"attribute": "value"})
|
||||
answer = self.app.chat("Test query", where={"attribute": "value"})
|
||||
|
||||
self.assertEqual(answer, "Test answer")
|
||||
self.assertEqual(mock_retrieve.call_args[0][0], "Test chat")
|
||||
self.assertEqual(mock_retrieve.call_args[0][2], {"attribute": "value"})
|
||||
_args, kwargs = mock_retrieve.call_args
|
||||
self.assertEqual(kwargs.get('input_query'), "Test query")
|
||||
self.assertEqual(kwargs.get('where'), {"attribute": "value"})
|
||||
mock_answer.assert_called_once()
|
||||
|
||||
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
|
||||
@@ -85,15 +111,15 @@ class TestApp(unittest.TestCase):
|
||||
The test isolates the 'chat' method behavior by mocking out 'retrieve_from_database' and
|
||||
'get_llm_model_answer' methods.
|
||||
"""
|
||||
with patch.object(self.app, "retrieve_from_database") as mock_retrieve:
|
||||
mock_retrieve.return_value = ["Test context"]
|
||||
with patch.object(self.app, "get_llm_model_answer") as mock_answer:
|
||||
mock_answer.return_value = "Test answer"
|
||||
chatConfig = ChatConfig(where={"attribute": "value"})
|
||||
answer = self.app.chat("Test chat", chatConfig)
|
||||
with patch.object(self.app.llm, "get_llm_model_answer") as mock_answer:
|
||||
mock_answer.return_value = "Test answer"
|
||||
with patch.object(self.app.db, "query") as mock_database_query:
|
||||
mock_database_query.return_value = ["Test context"]
|
||||
queryConfig = BaseLlmConfig(where={"attribute": "value"})
|
||||
answer = self.app.chat("Test query", queryConfig)
|
||||
|
||||
self.assertEqual(answer, "Test answer")
|
||||
self.assertEqual(mock_retrieve.call_args[0][0], "Test chat")
|
||||
self.assertEqual(mock_retrieve.call_args[0][1].where, {"attribute": "value"})
|
||||
self.assertIsInstance(mock_retrieve.call_args[0][1], ChatConfig)
|
||||
_args, kwargs = mock_database_query.call_args
|
||||
self.assertEqual(kwargs.get('input_query'), "Test query")
|
||||
self.assertEqual(kwargs.get('where'), {"attribute": "value"})
|
||||
mock_answer.assert_called_once()
|
||||
@@ -2,7 +2,7 @@ import unittest
|
||||
from string import Template
|
||||
|
||||
from embedchain import App
|
||||
from embedchain.config import AppConfig, QueryConfig
|
||||
from embedchain.config import AppConfig, BaseLlmConfig
|
||||
|
||||
|
||||
class TestGeneratePrompt(unittest.TestCase):
|
||||
@@ -23,10 +23,11 @@ class TestGeneratePrompt(unittest.TestCase):
|
||||
input_query = "Test query"
|
||||
contexts = ["Context 1", "Context 2", "Context 3"]
|
||||
template = "You are a bot. Context: ${context} - Query: ${query} - Helpful answer:"
|
||||
config = QueryConfig(template=Template(template))
|
||||
config = BaseLlmConfig(template=Template(template))
|
||||
self.app.llm.config = config
|
||||
|
||||
# Execute
|
||||
result = self.app.generate_prompt(input_query, contexts, config)
|
||||
result = self.app.llm.generate_prompt(input_query, contexts)
|
||||
|
||||
# Assert
|
||||
expected_result = (
|
||||
@@ -45,10 +46,11 @@ class TestGeneratePrompt(unittest.TestCase):
|
||||
# Setup
|
||||
input_query = "Test query"
|
||||
contexts = ["Context 1", "Context 2", "Context 3"]
|
||||
config = QueryConfig()
|
||||
config = BaseLlmConfig()
|
||||
|
||||
# Execute
|
||||
result = self.app.generate_prompt(input_query, contexts, config)
|
||||
self.app.llm.config = config
|
||||
result = self.app.llm.generate_prompt(input_query, contexts)
|
||||
|
||||
# Assert
|
||||
expected_result = config.template.substitute(context="Context 1 | Context 2 | Context 3", query=input_query)
|
||||
@@ -58,9 +60,11 @@ class TestGeneratePrompt(unittest.TestCase):
|
||||
"""
|
||||
Test the 'generate_prompt' method with QueryConfig containing a history attribute.
|
||||
"""
|
||||
config = QueryConfig(history=["Past context 1", "Past context 2"])
|
||||
config = BaseLlmConfig()
|
||||
config.template = Template("Context: $context | Query: $query | History: $history")
|
||||
prompt = self.app.generate_prompt("Test query", ["Test context"], config)
|
||||
self.app.llm.config = config
|
||||
self.app.llm.set_history(["Past context 1", "Past context 2"])
|
||||
prompt = self.app.llm.generate_prompt("Test query", ["Test context"])
|
||||
|
||||
expected_prompt = "Context: Test context | Query: Test query | History: ['Past context 1', 'Past context 2']"
|
||||
self.assertEqual(prompt, expected_prompt)
|
||||
@@ -3,7 +3,7 @@ import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from embedchain import App
|
||||
from embedchain.config import AppConfig, QueryConfig
|
||||
from embedchain.config import AppConfig, BaseLlmConfig
|
||||
|
||||
|
||||
class TestApp(unittest.TestCase):
|
||||
@@ -33,29 +33,35 @@ class TestApp(unittest.TestCase):
|
||||
"""
|
||||
with patch.object(self.app, "retrieve_from_database") as mock_retrieve:
|
||||
mock_retrieve.return_value = ["Test context"]
|
||||
with patch.object(self.app, "get_llm_model_answer") as mock_answer:
|
||||
with patch.object(self.app.llm, "get_llm_model_answer") as mock_answer:
|
||||
mock_answer.return_value = "Test answer"
|
||||
answer = self.app.query("Test query")
|
||||
_answer = self.app.query(input_query="Test query")
|
||||
|
||||
self.assertEqual(answer, "Test answer")
|
||||
self.assertEqual(mock_retrieve.call_args[0][0], "Test query")
|
||||
self.assertIsInstance(mock_retrieve.call_args[0][1], QueryConfig)
|
||||
# Ensure retrieve_from_database was called
|
||||
mock_retrieve.assert_called_once()
|
||||
|
||||
# Check the call arguments
|
||||
args, kwargs = mock_retrieve.call_args
|
||||
input_query_arg = kwargs.get("input_query")
|
||||
self.assertEqual(input_query_arg, "Test query")
|
||||
mock_answer.assert_called_once()
|
||||
|
||||
@patch("openai.ChatCompletion.create")
|
||||
def test_query_config_app_passing(self, mock_create):
|
||||
mock_create.return_value = {"choices": [{"message": {"content": "response"}}]} # Mock response
|
||||
|
||||
config = AppConfig()
|
||||
chat_config = QueryConfig(system_prompt="Test system prompt")
|
||||
app = App(config=config)
|
||||
config = AppConfig(collect_metrics=False)
|
||||
chat_config = BaseLlmConfig(system_prompt="Test system prompt")
|
||||
app = App(config=config, llm_config=chat_config)
|
||||
|
||||
app.get_llm_model_answer("Test query", chat_config)
|
||||
app.llm.get_llm_model_answer("Test query")
|
||||
|
||||
# Test system_prompt: Check that the 'create' method was called with the correct 'messages' argument
|
||||
messages_arg = mock_create.call_args.kwargs["messages"]
|
||||
self.assertEqual(messages_arg[0]["role"], "system")
|
||||
self.assertEqual(messages_arg[0]["content"], "Test system prompt")
|
||||
self.assertTrue(messages_arg[0].get("role"), "system")
|
||||
self.assertEqual(messages_arg[0].get("content"), "Test system prompt")
|
||||
self.assertTrue(messages_arg[1].get("role"), "user")
|
||||
self.assertEqual(messages_arg[1].get("content"), "Test query")
|
||||
|
||||
# TODO: Add tests for other config variables
|
||||
|
||||
@@ -63,16 +69,18 @@ class TestApp(unittest.TestCase):
|
||||
def test_app_passing(self, mock_create):
|
||||
mock_create.return_value = {"choices": [{"message": {"content": "response"}}]} # Mock response
|
||||
|
||||
config = AppConfig()
|
||||
chat_config = QueryConfig()
|
||||
app = App(config=config, system_prompt="Test system prompt")
|
||||
config = AppConfig(collect_metrics=False)
|
||||
chat_config = BaseLlmConfig()
|
||||
app = App(config=config, llm_config=chat_config, system_prompt="Test system prompt")
|
||||
|
||||
app.get_llm_model_answer("Test query", chat_config)
|
||||
self.assertEqual(app.llm.config.system_prompt, "Test system prompt")
|
||||
|
||||
app.llm.get_llm_model_answer("Test query")
|
||||
|
||||
# Test system_prompt: Check that the 'create' method was called with the correct 'messages' argument
|
||||
messages_arg = mock_create.call_args.kwargs["messages"]
|
||||
self.assertEqual(messages_arg[0]["role"], "system")
|
||||
self.assertEqual(messages_arg[0]["content"], "Test system prompt")
|
||||
self.assertTrue(messages_arg[0].get("role"), "system")
|
||||
self.assertEqual(messages_arg[0].get("content"), "Test system prompt")
|
||||
|
||||
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
|
||||
def test_query_with_where_in_params(self):
|
||||
@@ -95,13 +103,14 @@ class TestApp(unittest.TestCase):
|
||||
"""
|
||||
with patch.object(self.app, "retrieve_from_database") as mock_retrieve:
|
||||
mock_retrieve.return_value = ["Test context"]
|
||||
with patch.object(self.app, "get_llm_model_answer") as mock_answer:
|
||||
with patch.object(self.app.llm, "get_llm_model_answer") as mock_answer:
|
||||
mock_answer.return_value = "Test answer"
|
||||
answer = self.app.query("Test query", where={"attribute": "value"})
|
||||
|
||||
self.assertEqual(answer, "Test answer")
|
||||
self.assertEqual(mock_retrieve.call_args[0][0], "Test query")
|
||||
self.assertEqual(mock_retrieve.call_args[0][2], {"attribute": "value"})
|
||||
_args, kwargs = mock_retrieve.call_args
|
||||
self.assertEqual(kwargs.get('input_query'), "Test query")
|
||||
self.assertEqual(kwargs.get('where'), {"attribute": "value"})
|
||||
mock_answer.assert_called_once()
|
||||
|
||||
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
|
||||
@@ -123,15 +132,16 @@ class TestApp(unittest.TestCase):
|
||||
The test isolates the 'query' method behavior by mocking out 'retrieve_from_database' and
|
||||
'get_llm_model_answer' methods.
|
||||
"""
|
||||
with patch.object(self.app, "retrieve_from_database") as mock_retrieve:
|
||||
mock_retrieve.return_value = ["Test context"]
|
||||
with patch.object(self.app, "get_llm_model_answer") as mock_answer:
|
||||
mock_answer.return_value = "Test answer"
|
||||
queryConfig = QueryConfig(where={"attribute": "value"})
|
||||
|
||||
with patch.object(self.app.llm, "get_llm_model_answer") as mock_answer:
|
||||
mock_answer.return_value = "Test answer"
|
||||
with patch.object(self.app.db, "query") as mock_database_query:
|
||||
mock_database_query.return_value = ["Test context"]
|
||||
queryConfig = BaseLlmConfig(where={"attribute": "value"})
|
||||
answer = self.app.query("Test query", queryConfig)
|
||||
|
||||
self.assertEqual(answer, "Test answer")
|
||||
self.assertEqual(mock_retrieve.call_args[0][0], "Test query")
|
||||
self.assertEqual(mock_retrieve.call_args[0][1].where, {"attribute": "value"})
|
||||
self.assertIsInstance(mock_retrieve.call_args[0][1], QueryConfig)
|
||||
_args, kwargs = mock_database_query.call_args
|
||||
self.assertEqual(kwargs.get('input_query'), "Test query")
|
||||
self.assertEqual(kwargs.get('where'), {"attribute": "value"})
|
||||
mock_answer.assert_called_once()
|
||||
@@ -3,8 +3,10 @@
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
from chromadb.config import Settings
|
||||
|
||||
from embedchain import App
|
||||
from embedchain.config import AppConfig, CustomAppConfig
|
||||
from embedchain.config import AppConfig, ChromaDbConfig
|
||||
from embedchain.models import EmbeddingFunctions, Providers
|
||||
from embedchain.vectordb.chroma_db import ChromaDB
|
||||
|
||||
@@ -16,8 +18,9 @@ class TestChromaDbHosts(unittest.TestCase):
|
||||
"""
|
||||
host = "test-host"
|
||||
port = "1234"
|
||||
config = ChromaDbConfig(host=host, port=port)
|
||||
|
||||
db = ChromaDB(host=host, port=port, embedding_fn=len)
|
||||
db = ChromaDB(config=config)
|
||||
settings = db.client.get_settings()
|
||||
self.assertEqual(settings.chroma_server_host, host)
|
||||
self.assertEqual(settings.chroma_server_http_port, port)
|
||||
@@ -31,7 +34,8 @@ class TestChromaDbHosts(unittest.TestCase):
|
||||
"chroma_client_auth_credentials": "admin:admin",
|
||||
}
|
||||
|
||||
db = ChromaDB(host=host, port=port, embedding_fn=len, chroma_settings=chroma_auth_settings)
|
||||
config = ChromaDbConfig(host=host, port=port, chroma_settings=chroma_auth_settings)
|
||||
db = ChromaDB(config=config)
|
||||
settings = db.client.get_settings()
|
||||
self.assertEqual(settings.chroma_server_host, host)
|
||||
self.assertEqual(settings.chroma_server_http_port, port)
|
||||
@@ -44,37 +48,41 @@ class TestChromaDbHosts(unittest.TestCase):
|
||||
# Review this test
|
||||
class TestChromaDbHostsInit(unittest.TestCase):
|
||||
@patch("embedchain.vectordb.chroma_db.chromadb.Client")
|
||||
def test_init_with_host_and_port(self, mock_client):
|
||||
def test_app_init_with_host_and_port(self, mock_client):
|
||||
"""
|
||||
Test if the `App` instance is initialized with the correct host and port values.
|
||||
"""
|
||||
host = "test-host"
|
||||
port = "1234"
|
||||
|
||||
config = AppConfig(host=host, port=port, collect_metrics=False)
|
||||
config = AppConfig(collect_metrics=False)
|
||||
chromadb_config = ChromaDbConfig(host=host, port=port)
|
||||
|
||||
_app = App(config)
|
||||
_app = App(config, chromadb_config=chromadb_config)
|
||||
|
||||
# self.assertEqual(mock_client.call_args[0][0].chroma_server_host, host)
|
||||
# self.assertEqual(mock_client.call_args[0][0].chroma_server_http_port, port)
|
||||
called_settings: Settings = mock_client.call_args[0][0]
|
||||
|
||||
self.assertEqual(called_settings.chroma_server_host, host)
|
||||
self.assertEqual(called_settings.chroma_server_http_port, port)
|
||||
|
||||
|
||||
class TestChromaDbHostsNone(unittest.TestCase):
|
||||
@patch("embedchain.vectordb.chroma_db.chromadb.Client")
|
||||
def test_init_with_host_and_port(self, mock_client):
|
||||
def test_init_with_host_and_port_none(self, mock_client):
|
||||
"""
|
||||
Test if the `App` instance is initialized without default hosts and ports.
|
||||
"""
|
||||
|
||||
_app = App(config=AppConfig(collect_metrics=False))
|
||||
|
||||
self.assertEqual(mock_client.call_args[0][0].chroma_server_host, None)
|
||||
self.assertEqual(mock_client.call_args[0][0].chroma_server_http_port, None)
|
||||
called_settings: Settings = mock_client.call_args[0][0]
|
||||
self.assertEqual(called_settings.chroma_server_host, None)
|
||||
self.assertEqual(called_settings.chroma_server_http_port, None)
|
||||
|
||||
|
||||
class TestChromaDbHostsLoglevel(unittest.TestCase):
|
||||
@patch("embedchain.vectordb.chroma_db.chromadb.Client")
|
||||
def test_init_with_host_and_port(self, mock_client):
|
||||
def test_init_with_host_and_port_log_level(self, mock_client):
|
||||
"""
|
||||
Test if the `App` instance is initialized without a config that does not contain default hosts and ports.
|
||||
"""
|
||||
@@ -87,11 +95,10 @@ class TestChromaDbHostsLoglevel(unittest.TestCase):
|
||||
|
||||
|
||||
class TestChromaDbDuplicateHandling:
|
||||
app_with_settings = App(
|
||||
CustomAppConfig(
|
||||
provider=Providers.OPENAI, embedding_fn=EmbeddingFunctions.OPENAI, chroma_settings={"allow_reset": True}
|
||||
)
|
||||
)
|
||||
chroma_settings = {"allow_reset": True}
|
||||
chroma_config = ChromaDbConfig(chroma_settings=chroma_settings)
|
||||
app_config = AppConfig(collection_name=False, collect_metrics=False)
|
||||
app_with_settings = App(config=app_config, chromadb_config=chroma_config)
|
||||
|
||||
def test_duplicates_throw_warning(self, caplog):
|
||||
"""
|
||||
@@ -101,8 +108,8 @@ class TestChromaDbDuplicateHandling:
|
||||
self.app_with_settings.reset()
|
||||
|
||||
app = App(config=AppConfig(collect_metrics=False))
|
||||
app.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
|
||||
app.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
|
||||
app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
|
||||
app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
|
||||
assert "Insert of existing embedding ID: 0" in caplog.text
|
||||
assert "Add of existing embedding ID: 0" in caplog.text
|
||||
|
||||
@@ -117,19 +124,18 @@ class TestChromaDbDuplicateHandling:
|
||||
|
||||
app = App(config=AppConfig(collect_metrics=False))
|
||||
app.set_collection("test_collection_1")
|
||||
app.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
|
||||
app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
|
||||
app.set_collection("test_collection_2")
|
||||
app.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
|
||||
app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
|
||||
assert "Insert of existing embedding ID: 0" not in caplog.text # not
|
||||
assert "Add of existing embedding ID: 0" not in caplog.text # not
|
||||
|
||||
|
||||
class TestChromaDbCollection(unittest.TestCase):
|
||||
app_with_settings = App(
|
||||
CustomAppConfig(
|
||||
provider=Providers.OPENAI, embedding_fn=EmbeddingFunctions.OPENAI, chroma_settings={"allow_reset": True}
|
||||
)
|
||||
)
|
||||
chroma_settings = {"allow_reset": True}
|
||||
chroma_config = ChromaDbConfig(chroma_settings=chroma_settings)
|
||||
app_config = AppConfig(collection_name=False, collect_metrics=False)
|
||||
app_with_settings = App(config=app_config, chromadb_config=chroma_config)
|
||||
|
||||
def test_init_with_default_collection(self):
|
||||
"""
|
||||
@@ -137,16 +143,17 @@ class TestChromaDbCollection(unittest.TestCase):
|
||||
"""
|
||||
app = App(config=AppConfig(collect_metrics=False))
|
||||
|
||||
self.assertEqual(app.collection.name, "embedchain_store")
|
||||
self.assertEqual(app.db.collection.name, "embedchain_store")
|
||||
|
||||
def test_init_with_custom_collection(self):
|
||||
"""
|
||||
Test if the `App` instance is initialized with the correct custom collection name.
|
||||
"""
|
||||
config = AppConfig(collection_name="test_collection", collect_metrics=False)
|
||||
app = App(config)
|
||||
config = AppConfig(collect_metrics=False)
|
||||
app = App(config=config)
|
||||
app.set_collection(collection_name="test_collection")
|
||||
|
||||
self.assertEqual(app.collection.name, "test_collection")
|
||||
self.assertEqual(app.db.collection.name, "test_collection")
|
||||
|
||||
def test_set_collection(self):
|
||||
"""
|
||||
@@ -155,7 +162,7 @@ class TestChromaDbCollection(unittest.TestCase):
|
||||
app = App(config=AppConfig(collect_metrics=False))
|
||||
app.set_collection("test_collection")
|
||||
|
||||
self.assertEqual(app.collection.name, "test_collection")
|
||||
self.assertEqual(app.db.collection.name, "test_collection")
|
||||
|
||||
def test_changes_encapsulated(self):
|
||||
"""
|
||||
@@ -169,7 +176,7 @@ class TestChromaDbCollection(unittest.TestCase):
|
||||
# Collection should be empty when created
|
||||
self.assertEqual(app.count(), 0)
|
||||
|
||||
app.collection.add(embeddings=[0, 0, 0], ids=["0"])
|
||||
app.db.collection.add(embeddings=[0, 0, 0], ids=["0"])
|
||||
# After adding, should contain one item
|
||||
self.assertEqual(app.count(), 1)
|
||||
|
||||
@@ -178,7 +185,7 @@ class TestChromaDbCollection(unittest.TestCase):
|
||||
self.assertEqual(app.count(), 0)
|
||||
|
||||
# Adding to new collection should not effect existing collection
|
||||
app.collection.add(embeddings=[0, 0, 0], ids=["0"])
|
||||
app.db.collection.add(embeddings=[0, 0, 0], ids=["0"])
|
||||
app.set_collection("test_collection_1")
|
||||
# Should still be 1, not 2.
|
||||
self.assertEqual(app.count(), 1)
|
||||
@@ -192,7 +199,7 @@ class TestChromaDbCollection(unittest.TestCase):
|
||||
|
||||
app = App(config=AppConfig(collect_metrics=False))
|
||||
app.set_collection("test_collection_1")
|
||||
app.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
|
||||
app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
|
||||
del app
|
||||
|
||||
app = App(config=AppConfig(collect_metrics=False))
|
||||
@@ -213,13 +220,13 @@ class TestChromaDbCollection(unittest.TestCase):
|
||||
app2 = App(AppConfig(collection_name="test_collection_2", collect_metrics=False))
|
||||
|
||||
# app2 has been created last, but adding to app1 will still write to collection 1.
|
||||
app1.collection.add(embeddings=[0, 0, 0], ids=["0"])
|
||||
self.assertEqual(app1.count(), 1)
|
||||
self.assertEqual(app2.count(), 0)
|
||||
app1.db.collection.add(embeddings=[0, 0, 0], ids=["0"])
|
||||
self.assertEqual(app1.db.count(), 1)
|
||||
self.assertEqual(app2.db.count(), 0)
|
||||
|
||||
# Add data
|
||||
app1.collection.add(embeddings=[[0, 0, 0], [1, 1, 1]], ids=["1", "2"])
|
||||
app2.collection.add(embeddings=[0, 0, 0], ids=["0"])
|
||||
app1.db.collection.add(embeddings=[[0, 0, 0], [1, 1, 1]], ids=["1", "2"])
|
||||
app2.db.collection.add(embeddings=[0, 0, 0], ids=["0"])
|
||||
|
||||
# Swap names and test
|
||||
app1.set_collection("test_collection_2")
|
||||
@@ -235,12 +242,14 @@ class TestChromaDbCollection(unittest.TestCase):
|
||||
self.app_with_settings.reset()
|
||||
|
||||
# Create two apps
|
||||
app1 = App(AppConfig(collection_name="one_collection", id="new_app_id_1", collect_metrics=False))
|
||||
app2 = App(AppConfig(collection_name="one_collection", id="new_app_id_2", collect_metrics=False))
|
||||
app1 = App(AppConfig(id="new_app_id_1", collect_metrics=False))
|
||||
app1.set_collection("one_collection")
|
||||
app2 = App(AppConfig(id="new_app_id_2", collect_metrics=False))
|
||||
app2.set_collection("one_collection")
|
||||
|
||||
# Add data
|
||||
app1.collection.add(embeddings=[[0, 0, 0], [1, 1, 1]], ids=["0", "1"])
|
||||
app2.collection.add(embeddings=[0, 0, 0], ids=["2"])
|
||||
app1.db.collection.add(embeddings=[[0, 0, 0], [1, 1, 1]], ids=["0", "1"])
|
||||
app2.db.collection.add(embeddings=[0, 0, 0], ids=["2"])
|
||||
|
||||
# Both should have the same collection
|
||||
self.assertEqual(app1.count(), 3)
|
||||
@@ -255,25 +264,20 @@ class TestChromaDbCollection(unittest.TestCase):
|
||||
|
||||
# Create four apps.
|
||||
# app1, which we are about to reset, shares an app with one, and an id with the other, none with the last.
|
||||
app1 = App(
|
||||
CustomAppConfig(
|
||||
collection_name="one_collection",
|
||||
id="new_app_id_1",
|
||||
collect_metrics=False,
|
||||
provider=Providers.OPENAI,
|
||||
embedding_fn=EmbeddingFunctions.OPENAI,
|
||||
chroma_settings={"allow_reset": True},
|
||||
)
|
||||
)
|
||||
app2 = App(AppConfig(collection_name="one_collection", id="new_app_id_2", collect_metrics=False))
|
||||
app3 = App(AppConfig(collection_name="three_collection", id="new_app_id_1", collect_metrics=False))
|
||||
app4 = App(AppConfig(collection_name="four_collection", id="new_app_id_4", collect_metrics=False))
|
||||
app1 = App(AppConfig(id="new_app_id_1", collect_metrics=False), chromadb_config=self.chroma_config)
|
||||
app1.set_collection("one_collection")
|
||||
app2 = App(AppConfig(id="new_app_id_2", collect_metrics=False))
|
||||
app2.set_collection("one_collection")
|
||||
app3 = App(AppConfig(id="new_app_id_1", collect_metrics=False))
|
||||
app3.set_collection("three_collection")
|
||||
app4 = App(AppConfig(id="new_app_id_4", collect_metrics=False))
|
||||
app4.set_collection("four_collection")
|
||||
|
||||
# Each one of them get data
|
||||
app1.collection.add(embeddings=[0, 0, 0], ids=["1"])
|
||||
app2.collection.add(embeddings=[0, 0, 0], ids=["2"])
|
||||
app3.collection.add(embeddings=[0, 0, 0], ids=["3"])
|
||||
app4.collection.add(embeddings=[0, 0, 0], ids=["4"])
|
||||
app1.db.collection.add(embeddings=[0, 0, 0], ids=["1"])
|
||||
app2.db.collection.add(embeddings=[0, 0, 0], ids=["2"])
|
||||
app3.db.collection.add(embeddings=[0, 0, 0], ids=["3"])
|
||||
app4.db.collection.add(embeddings=[0, 0, 0], ids=["4"])
|
||||
|
||||
# Resetting the first one should reset them all.
|
||||
app1.reset()
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import unittest
|
||||
from unittest.mock import Mock
|
||||
|
||||
from embedchain.config import ElasticsearchDBConfig
|
||||
from embedchain.embedder.base_embedder import BaseEmbedder
|
||||
from embedchain.vectordb.elasticsearch_db import ElasticsearchDB
|
||||
|
||||
|
||||
@@ -10,24 +10,20 @@ class TestEsDB(unittest.TestCase):
|
||||
self.es_config = ElasticsearchDBConfig()
|
||||
self.vector_dim = 384
|
||||
|
||||
def test_init_with_invalid_embedding_fn(self):
|
||||
# Test if an exception is raised when an invalid embedding_fn is provided
|
||||
with self.assertRaises(ValueError):
|
||||
ElasticsearchDB(embedding_fn=None)
|
||||
|
||||
def test_init_with_invalid_es_config(self):
|
||||
# Test if an exception is raised when an invalid es_config is provided
|
||||
with self.assertRaises(ValueError):
|
||||
ElasticsearchDB(embedding_fn=Mock(), es_config=None)
|
||||
ElasticsearchDB(es_config=None)
|
||||
|
||||
def test_init_with_invalid_vector_dim(self):
|
||||
# Test if an exception is raised when an invalid vector_dim is provided
|
||||
embedder = BaseEmbedder()
|
||||
embedder.set_vector_dimension(None)
|
||||
with self.assertRaises(ValueError):
|
||||
ElasticsearchDB(embedding_fn=Mock(), es_config=self.es_config, vector_dim=None)
|
||||
ElasticsearchDB(es_config=self.es_config)
|
||||
|
||||
def test_init_with_invalid_collection_name(self):
|
||||
# Test if an exception is raised when an invalid collection_name is provided
|
||||
self.es_config.collection_name = None
|
||||
with self.assertRaises(ValueError):
|
||||
ElasticsearchDB(
|
||||
embedding_fn=Mock(), es_config=self.es_config, vector_dim=self.vector_dim, collection_name=None
|
||||
)
|
||||
ElasticsearchDB(es_config=self.es_config)
|
||||
|
||||
Reference in New Issue
Block a user