Improve and add more tests (#807)

Sidharth Mohanty
2023-10-18 02:36:47 +05:30
committed by GitHub
parent d065cbf934
commit e8a2846449
11 changed files with 403 additions and 260 deletions

View File

@@ -13,13 +13,9 @@ class OpenAILlm(BaseLlm):
     def __init__(self, config: Optional[BaseLlmConfig] = None):
         super().__init__(config=config)

-    def get_llm_model_answer(self, prompt):
+    def get_llm_model_answer(self, prompt) -> str:
         response = OpenAILlm._get_answer(prompt, self.config)
-        if self.config.stream:
-            return response
-        else:
-            return response.content
+        return response

     def _get_answer(prompt: str, config: BaseLlmConfig) -> str:
         messages = []
@@ -35,10 +31,9 @@ class OpenAILlm(BaseLlm):
         if config.top_p:
             kwargs["model_kwargs"]["top_p"] = config.top_p
         if config.stream:
-            from langchain.callbacks.streaming_stdout import \
-                StreamingStdOutCallbackHandler
+            from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler

             chat = ChatOpenAI(**kwargs, streaming=config.stream, callbacks=[StreamingStdOutCallbackHandler()])
         else:
             chat = ChatOpenAI(**kwargs)
-        return chat(messages)
+        return chat(messages).content
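With this change, `get_llm_model_answer` always returns a plain string: `_get_answer` now extracts `.content` from the LangChain message itself, so the streaming/non-streaming branch in the public method is no longer needed. A minimal usage sketch under that assumption (illustrative only; it assumes a valid `OPENAI_API_KEY` in the environment):

```python
from embedchain.config import BaseLlmConfig
from embedchain.llm.openai import OpenAILlm

# Both stream=True and stream=False now yield a str from the public method;
# streamed tokens are echoed to stdout by StreamingStdOutCallbackHandler.
llm = OpenAILlm(config=BaseLlmConfig(temperature=0.7, stream=False))
answer = llm.get_llm_model_answer("What is embedchain?")
assert isinstance(answer, str)
```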

View File

@@ -133,6 +133,7 @@ isort = "^5.12.0"
 pytest-cov = "^4.1.0"
 responses = "^0.23.3"
 mock = "^5.1.0"
+pytest-asyncio = "^0.21.1"

 [tool.poetry.extras]
 streamlit = ["streamlit"]
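The new `pytest-asyncio` dev dependency enables async tests. A minimal sketch of the kind of test it unlocks (the coroutine here is hypothetical, not part of embedchain):

```python
import asyncio

import pytest


async def fetch_answer(prompt: str) -> str:
    # Hypothetical coroutine standing in for an async embedchain call.
    await asyncio.sleep(0)
    return f"answer to {prompt}"


@pytest.mark.asyncio  # marker provided by pytest-asyncio
async def test_fetch_answer():
    assert await fetch_answer("hi") == "answer to hi"
```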

View File

@@ -104,6 +104,19 @@ class TestConfigForAppComponents:
         assert isinstance(embedder_config, BaseEmbedderConfig)

+    def test_components_raises_type_error_if_not_proper_instances(self):
+        wrong_llm = "wrong_llm"
+        with pytest.raises(TypeError):
+            App(llm=wrong_llm)
+
+        wrong_db = "wrong_db"
+        with pytest.raises(TypeError):
+            App(db=wrong_db)
+
+        wrong_embedder = "wrong_embedder"
+        with pytest.raises(TypeError):
+            App(embedder=wrong_embedder)
+

 class TestAppFromConfig:
     def load_config_data(self, yaml_path):
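The new test assumes `App` rejects components of the wrong type at construction time. A rough sketch of that kind of `isinstance` guard, with hypothetical base-class names (the real embedchain constructor may differ):

```python
class BaseLlm: ...        # hypothetical stand-ins for embedchain's base classes
class BaseVectorDB: ...


class App:
    def __init__(self, llm=None, db=None):
        # Reject anything that is not an instance of the expected base class,
        # so a stray string raises TypeError instead of failing later.
        if llm is not None and not isinstance(llm, BaseLlm):
            raise TypeError("llm must be an instance of BaseLlm")
        if db is not None and not isinstance(db, BaseVectorDB):
            raise TypeError("db must be an instance of BaseVectorDB")
        self.llm, self.db = llm, db
```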

View File

@@ -1,5 +1,5 @@
 import pytest
+from string import Template

 from embedchain.llm.base import BaseLlm, BaseLlmConfig
@@ -14,6 +14,18 @@ def test_is_get_llm_model_answer_not_implemented(base_llm):
         base_llm.get_llm_model_answer()
+
+def test_is_stream_bool():
+    with pytest.raises(ValueError):
+        config = BaseLlmConfig(stream="test value")
+        BaseLlm(config=config)
+
+
+def test_template_string_gets_converted_to_Template_instance():
+    config = BaseLlmConfig(template="test value $query $context")
+    llm = BaseLlm(config=config)
+    assert isinstance(llm.config.template, Template)
+

 def test_is_get_llm_model_answer_implemented():
     class TestLlm(BaseLlm):
         def get_llm_model_answer(self):
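The template test relies on Python's `string.Template` placeholder syntax (`$query`, `$context`). For reference, substitution behaves like this:

```python
from string import Template

template = Template("Answer $query using $context")
# substitute() fills every $placeholder and raises KeyError if one is missing;
# safe_substitute() would leave missing placeholders untouched instead.
assert template.substitute(query="Q", context="C") == "Answer Q using C"
```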

View File

The Cohere tests are migrated from unittest to pytest fixtures, and a test that mocks the Cohere client directly is added:

@@ -1,33 +1,55 @@
import os

import pytest

from embedchain.config import BaseLlmConfig
from embedchain.llm.cohere import CohereLlm


@pytest.fixture
def cohere_llm_config():
    os.environ["COHERE_API_KEY"] = "test_api_key"
    config = BaseLlmConfig(model="gptd-instruct-tft", max_tokens=50, temperature=0.7, top_p=0.8)
    yield config
    os.environ.pop("COHERE_API_KEY")


def test_init_raises_value_error_without_api_key(mocker):
    mocker.patch.dict(os.environ, clear=True)
    with pytest.raises(ValueError):
        CohereLlm()


def test_get_llm_model_answer_raises_value_error_for_system_prompt(cohere_llm_config):
    llm = CohereLlm(cohere_llm_config)
    llm.config.system_prompt = "system_prompt"
    with pytest.raises(ValueError):
        llm.get_llm_model_answer("prompt")


def test_get_llm_model_answer(cohere_llm_config, mocker):
    mocker.patch("embedchain.llm.cohere.CohereLlm._get_answer", return_value="Test answer")
    llm = CohereLlm(cohere_llm_config)
    answer = llm.get_llm_model_answer("Test query")
    assert answer == "Test answer"


def test_get_answer_mocked_cohere(cohere_llm_config, mocker):
    mocked_cohere = mocker.patch("embedchain.llm.cohere.Cohere")
    mock_instance = mocked_cohere.return_value
    mock_instance.return_value = "Mocked answer"

    llm = CohereLlm(cohere_llm_config)
    prompt = "Test query"
    answer = llm.get_llm_model_answer(prompt)

    assert answer == "Mocked answer"
    mocked_cohere.assert_called_once_with(
        cohere_api_key="test_api_key",
        model="gptd-instruct-tft",
        max_tokens=50,
        temperature=0.7,
        p=0.8,
    )
    mock_instance.assert_called_once_with(prompt)
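`mocker` is the pytest-mock fixture; its `patch.dict(os.environ, clear=True)` call empties the environment for the duration of the test and restores it afterwards. The plain `unittest.mock` equivalent shows the mechanics:

```python
import os
from unittest.mock import patch

os.environ["COHERE_API_KEY"] = "test_api_key"

with patch.dict(os.environ, clear=True):
    # Inside the block the environment is empty, so CohereLlm() would
    # fail to find its API key and raise ValueError.
    assert "COHERE_API_KEY" not in os.environ

# On exit the original environment is restored automatically.
assert os.environ["COHERE_API_KEY"] == "test_api_key"
```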

View File

The Hugging Face tests get the same unittest-to-pytest migration:

@@ -1,64 +1,61 @@
import importlib
import os

import pytest

from embedchain.config import BaseLlmConfig
from embedchain.llm.huggingface import HuggingFaceLlm


@pytest.fixture
def huggingface_llm_config():
    os.environ["HUGGINGFACE_ACCESS_TOKEN"] = "test_access_token"
    config = BaseLlmConfig(model="google/flan-t5-xxl", max_tokens=50, temperature=0.7, top_p=0.8)
    yield config
    os.environ.pop("HUGGINGFACE_ACCESS_TOKEN")


def test_init_raises_value_error_without_api_key(mocker):
    mocker.patch.dict(os.environ, clear=True)
    with pytest.raises(ValueError):
        HuggingFaceLlm()


def test_get_llm_model_answer_raises_value_error_for_system_prompt(huggingface_llm_config):
    llm = HuggingFaceLlm(huggingface_llm_config)
    llm.config.system_prompt = "system_prompt"
    with pytest.raises(ValueError):
        llm.get_llm_model_answer("prompt")


def test_top_p_value_within_range():
    config = BaseLlmConfig(top_p=1.0)
    with pytest.raises(ValueError):
        HuggingFaceLlm._get_answer("test_prompt", config)


def test_dependency_is_imported():
    importlib_installed = True
    try:
        importlib.import_module("huggingface_hub")
    except ImportError:
        importlib_installed = False
    assert importlib_installed


def test_get_llm_model_answer(huggingface_llm_config, mocker):
    mocker.patch("embedchain.llm.huggingface.HuggingFaceLlm._get_answer", return_value="Test answer")
    llm = HuggingFaceLlm(huggingface_llm_config)
    answer = llm.get_llm_model_answer("Test query")
    assert answer == "Test answer"


def test_hugging_face_mock(huggingface_llm_config, mocker):
    mock_llm_instance = mocker.Mock(return_value="Test answer")
    mocker.patch("embedchain.llm.huggingface.HuggingFaceHub", return_value=mock_llm_instance)
    llm = HuggingFaceLlm(huggingface_llm_config)
    answer = llm.get_llm_model_answer("Test query")
    assert answer == "Test answer"
    mock_llm_instance.assert_called_once_with("Test query")
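`test_top_p_value_within_range` expects `_get_answer` to reject `top_p` values outside the open interval (0, 1). A minimal sketch of that kind of guard (hypothetical; the real `HuggingFaceLlm` implementation may differ):

```python
def validate_top_p(top_p: float) -> float:
    # Nucleus sampling only makes sense for 0 < top_p < 1, so boundary
    # values such as 1.0 are rejected, matching the test's expectation.
    if not 0.0 < top_p < 1.0:
        raise ValueError("top_p must be in the open interval (0, 1)")
    return top_p
```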

View File

The Jina tests are migrated to pytest and gain coverage for empty prompts, streaming callbacks, and the no-system-prompt path:

@@ -1,40 +1,76 @@
import os

import pytest
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler

from embedchain.config import BaseLlmConfig
from embedchain.llm.jina import JinaLlm


@pytest.fixture
def config():
    os.environ["JINACHAT_API_KEY"] = "test_api_key"
    config = BaseLlmConfig(temperature=0.7, max_tokens=50, top_p=0.8, stream=False, system_prompt="System prompt")
    yield config
    os.environ.pop("JINACHAT_API_KEY")


def test_init_raises_value_error_without_api_key(mocker):
    mocker.patch.dict(os.environ, clear=True)
    with pytest.raises(ValueError):
        JinaLlm()


def test_get_llm_model_answer(config, mocker):
    mocked_get_answer = mocker.patch("embedchain.llm.jina.JinaLlm._get_answer", return_value="Test answer")
    llm = JinaLlm(config)
    answer = llm.get_llm_model_answer("Test query")
    assert answer == "Test answer"
    mocked_get_answer.assert_called_once_with("Test query", config)


def test_get_llm_model_answer_with_system_prompt(config, mocker):
    config.system_prompt = "Custom system prompt"
    mocked_get_answer = mocker.patch("embedchain.llm.jina.JinaLlm._get_answer", return_value="Test answer")
    llm = JinaLlm(config)
    answer = llm.get_llm_model_answer("Test query")
    assert answer == "Test answer"
    mocked_get_answer.assert_called_once_with("Test query", config)


def test_get_llm_model_answer_empty_prompt(config, mocker):
    mocked_get_answer = mocker.patch("embedchain.llm.jina.JinaLlm._get_answer", return_value="Test answer")
    llm = JinaLlm(config)
    answer = llm.get_llm_model_answer("")
    assert answer == "Test answer"
    mocked_get_answer.assert_called_once_with("", config)


def test_get_llm_model_answer_with_streaming(config, mocker):
    config.stream = True
    mocked_jinachat = mocker.patch("embedchain.llm.jina.JinaChat")
    llm = JinaLlm(config)
    llm.get_llm_model_answer("Test query")

    mocked_jinachat.assert_called_once()
    callbacks = [callback[1]["callbacks"] for callback in mocked_jinachat.call_args_list]
    assert any(isinstance(callback[0], StreamingStdOutCallbackHandler) for callback in callbacks)


def test_get_llm_model_answer_without_system_prompt(config, mocker):
    config.system_prompt = None
    mocked_jinachat = mocker.patch("embedchain.llm.jina.JinaChat")
    llm = JinaLlm(config)
    llm.get_llm_model_answer("Test query")

    mocked_jinachat.assert_called_once_with(
        temperature=config.temperature,
        max_tokens=config.max_tokens,
        model_kwargs={"top_p": config.top_p},
    )
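The streaming test digs the `callbacks` keyword argument back out of the mock. Each entry of `call_args_list` unpacks as an `(args, kwargs)` pair, which is why the comprehension indexes `[1]["callbacks"]`; in isolation:

```python
from unittest.mock import MagicMock

chat_cls = MagicMock()
chat_cls(temperature=0.7, callbacks=["my_handler"])

args, kwargs = chat_cls.call_args_list[0]  # each recorded call is (args, kwargs)
assert args == ()
assert kwargs["callbacks"] == ["my_handler"]
```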

tests/llm/test_llama2.py (new file, 47 lines)
View File

@@ -0,0 +1,47 @@
import os

import pytest

from embedchain.llm.llama2 import Llama2Llm


@pytest.fixture
def llama2_llm():
    os.environ["REPLICATE_API_TOKEN"] = "test_api_token"
    llm = Llama2Llm()
    return llm


def test_init_raises_value_error_without_api_key(mocker):
    mocker.patch.dict(os.environ, clear=True)
    with pytest.raises(ValueError):
        Llama2Llm()


def test_get_llm_model_answer_raises_value_error_for_system_prompt(llama2_llm):
    llama2_llm.config.system_prompt = "system_prompt"
    with pytest.raises(ValueError):
        llama2_llm.get_llm_model_answer("prompt")


def test_get_llm_model_answer(llama2_llm, mocker):
    mocked_replicate = mocker.patch("embedchain.llm.llama2.Replicate")
    mocked_replicate_instance = mocker.MagicMock()
    mocked_replicate.return_value = mocked_replicate_instance
    mocked_replicate_instance.return_value = "Test answer"

    llama2_llm.config.model = "test_model"
    llama2_llm.config.max_tokens = 50
    llama2_llm.config.temperature = 0.7
    llama2_llm.config.top_p = 0.8

    answer = llama2_llm.get_llm_model_answer("Test query")

    assert answer == "Test answer"
    mocked_replicate.assert_called_once_with(
        model="test_model",
        input={
            "temperature": 0.7,
            "max_length": 50,
            "top_p": 0.8,
        },
    )
    mocked_replicate_instance.assert_called_once_with("Test query")
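`Replicate` here is constructed and then the resulting client is called, so the test mocks both layers: the class (to assert constructor kwargs) and its `return_value` (to assert the call itself). The pattern in isolation, using only the standard library:

```python
from unittest.mock import MagicMock

replicate_cls = MagicMock()  # stands in for the patched Replicate class
replicate_cls.return_value = MagicMock(return_value="Test answer")

client = replicate_cls(model="test_model")    # "constructing" the client
assert client("Test query") == "Test answer"  # calling the instance
replicate_cls.assert_called_once_with(model="test_model")
```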

tests/llm/test_openai.py (new file, 73 lines)
View File

@@ -0,0 +1,73 @@
import os

import pytest
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler

from embedchain.config import BaseLlmConfig
from embedchain.llm.openai import OpenAILlm


@pytest.fixture
def config():
    os.environ["OPENAI_API_KEY"] = "test_api_key"
    config = BaseLlmConfig(
        temperature=0.7, max_tokens=50, top_p=0.8, stream=False, system_prompt="System prompt", model="gpt-3.5-turbo"
    )
    yield config
    os.environ.pop("OPENAI_API_KEY")


def test_get_llm_model_answer(config, mocker):
    mocked_get_answer = mocker.patch("embedchain.llm.openai.OpenAILlm._get_answer", return_value="Test answer")
    llm = OpenAILlm(config)
    answer = llm.get_llm_model_answer("Test query")
    assert answer == "Test answer"
    mocked_get_answer.assert_called_once_with("Test query", config)


def test_get_llm_model_answer_with_system_prompt(config, mocker):
    config.system_prompt = "Custom system prompt"
    mocked_get_answer = mocker.patch("embedchain.llm.openai.OpenAILlm._get_answer", return_value="Test answer")
    llm = OpenAILlm(config)
    answer = llm.get_llm_model_answer("Test query")
    assert answer == "Test answer"
    mocked_get_answer.assert_called_once_with("Test query", config)


def test_get_llm_model_answer_empty_prompt(config, mocker):
    mocked_get_answer = mocker.patch("embedchain.llm.openai.OpenAILlm._get_answer", return_value="Test answer")
    llm = OpenAILlm(config)
    answer = llm.get_llm_model_answer("")
    assert answer == "Test answer"
    mocked_get_answer.assert_called_once_with("", config)


def test_get_llm_model_answer_with_streaming(config, mocker):
    config.stream = True
    mocked_chat_openai = mocker.patch("embedchain.llm.openai.ChatOpenAI")
    llm = OpenAILlm(config)
    llm.get_llm_model_answer("Test query")

    mocked_chat_openai.assert_called_once()
    callbacks = [callback[1]["callbacks"] for callback in mocked_chat_openai.call_args_list]
    assert any(isinstance(callback[0], StreamingStdOutCallbackHandler) for callback in callbacks)


def test_get_llm_model_answer_without_system_prompt(config, mocker):
    config.system_prompt = None
    mocked_chat_openai = mocker.patch("embedchain.llm.openai.ChatOpenAI")
    llm = OpenAILlm(config)
    llm.get_llm_model_answer("Test query")

    mocked_chat_openai.assert_called_once_with(
        model=config.model,
        temperature=config.temperature,
        max_tokens=config.max_tokens,
        model_kwargs={"top_p": config.top_p},
    )

View File

The App query tests drop their repetitive docstrings and the unittest setUp in favor of a pytest fixture, and now assert on the plain string returned by the mocked `_get_answer`:

@@ -1,135 +1,85 @@
import os
from unittest.mock import MagicMock, patch

import pytest

from embedchain import App
from embedchain.config import AppConfig, BaseLlmConfig


@pytest.fixture
def app():
    os.environ["OPENAI_API_KEY"] = "test_api_key"
    app = App(config=AppConfig(collect_metrics=False))
    return app


@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
def test_query(app):
    with patch.object(app, "retrieve_from_database") as mock_retrieve:
        mock_retrieve.return_value = ["Test context"]
        with patch.object(app.llm, "get_llm_model_answer") as mock_answer:
            mock_answer.return_value = "Test answer"
            answer = app.query(input_query="Test query")
            assert answer == "Test answer"

        mock_retrieve.assert_called_once()
        _, kwargs = mock_retrieve.call_args
        input_query_arg = kwargs.get("input_query")
        assert input_query_arg == "Test query"
        mock_answer.assert_called_once()


@patch("embedchain.llm.openai.OpenAILlm._get_answer")
def test_query_config_app_passing(mock_get_answer):
    mock_get_answer.return_value = "Test answer"
    config = AppConfig(collect_metrics=False)
    chat_config = BaseLlmConfig(system_prompt="Test system prompt")
    app = App(config=config, llm_config=chat_config)
    answer = app.llm.get_llm_model_answer("Test query")

    assert app.llm.config.system_prompt == "Test system prompt"
    assert answer == "Test answer"


@patch("embedchain.llm.openai.OpenAILlm._get_answer")
def test_app_passing(mock_get_answer):
    mock_get_answer.return_value = "Test answer"
    config = AppConfig(collect_metrics=False)
    chat_config = BaseLlmConfig()
    app = App(config=config, llm_config=chat_config, system_prompt="Test system prompt")
    answer = app.llm.get_llm_model_answer("Test query")

    assert app.llm.config.system_prompt == "Test system prompt"
    assert answer == "Test answer"


@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
def test_query_with_where_in_params(app):
    with patch.object(app, "retrieve_from_database") as mock_retrieve:
        mock_retrieve.return_value = ["Test context"]
        with patch.object(app.llm, "get_llm_model_answer") as mock_answer:
            mock_answer.return_value = "Test answer"
            answer = app.query("Test query", where={"attribute": "value"})
            assert answer == "Test answer"

        _, kwargs = mock_retrieve.call_args
        assert kwargs.get("input_query") == "Test query"
        assert kwargs.get("where") == {"attribute": "value"}
        mock_answer.assert_called_once()


@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
def test_query_with_where_in_query_config(app):
    with patch.object(app.llm, "get_llm_model_answer") as mock_answer:
        mock_answer.return_value = "Test answer"
        with patch.object(app.db, "query") as mock_database_query:
            mock_database_query.return_value = ["Test context"]
            llm_config = BaseLlmConfig(where={"attribute": "value"})
            answer = app.query("Test query", llm_config)
            assert answer == "Test answer"

        _, kwargs = mock_database_query.call_args
        assert kwargs.get("input_query") == "Test query"
        assert kwargs.get("where") == {"attribute": "value"}
        mock_answer.assert_called_once()

View File

The data-type enum tests become plain pytest functions:

@@ -1,32 +1,29 @@
from embedchain.models.data_type import DataType, DirectDataType, IndirectDataType, SpecialDataType


def test_subclass_types_in_data_type():
    """Test that all data type category subclasses are contained in the composite data type"""
    # Check if DirectDataType values are in DataType
    for data_type in DirectDataType:
        assert data_type.value in DataType._value2member_map_

    # Check if IndirectDataType values are in DataType
    for data_type in IndirectDataType:
        assert data_type.value in DataType._value2member_map_

    # Check if SpecialDataType values are in DataType
    for data_type in SpecialDataType:
        assert data_type.value in DataType._value2member_map_


def test_data_type_in_subclasses():
    """Test that all data types in the composite data type are categorized in a subclass"""
    for data_type in DataType:
        if data_type.value in DirectDataType._value2member_map_:
            assert data_type.value in DirectDataType._value2member_map_
        elif data_type.value in IndirectDataType._value2member_map_:
            assert data_type.value in IndirectDataType._value2member_map_
        elif data_type.value in SpecialDataType._value2member_map_:
            assert data_type.value in SpecialDataType._value2member_map_
        else:
            assert False, f"{data_type.value} not found in any subclass enums"
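Both tests lean on `Enum._value2member_map_`, the enum's internal value-to-member lookup table (undocumented but long-stable in CPython), which makes membership checks by raw value cheap:

```python
from enum import Enum


class Color(Enum):
    RED = "red"
    BLUE = "blue"


# Maps raw values back to enum members, so `in` checks by value are O(1).
assert "red" in Color._value2member_map_
assert Color._value2member_map_["red"] is Color.RED
```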