Improve and add more tests (#807)

This commit is contained in:
Sidharth Mohanty
2023-10-18 02:36:47 +05:30
committed by GitHub
parent d065cbf934
commit e8a2846449
11 changed files with 403 additions and 260 deletions

View File

@@ -13,13 +13,9 @@ class OpenAILlm(BaseLlm):
def __init__(self, config: Optional[BaseLlmConfig] = None): def __init__(self, config: Optional[BaseLlmConfig] = None):
super().__init__(config=config) super().__init__(config=config)
def get_llm_model_answer(self, prompt): def get_llm_model_answer(self, prompt) -> str:
response = OpenAILlm._get_answer(prompt, self.config) response = OpenAILlm._get_answer(prompt, self.config)
return response
if self.config.stream:
return response
else:
return response.content
def _get_answer(prompt: str, config: BaseLlmConfig) -> str: def _get_answer(prompt: str, config: BaseLlmConfig) -> str:
messages = [] messages = []
@@ -35,10 +31,9 @@ class OpenAILlm(BaseLlm):
if config.top_p: if config.top_p:
kwargs["model_kwargs"]["top_p"] = config.top_p kwargs["model_kwargs"]["top_p"] = config.top_p
if config.stream: if config.stream:
from langchain.callbacks.streaming_stdout import \ from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
StreamingStdOutCallbackHandler
chat = ChatOpenAI(**kwargs, streaming=config.stream, callbacks=[StreamingStdOutCallbackHandler()]) chat = ChatOpenAI(**kwargs, streaming=config.stream, callbacks=[StreamingStdOutCallbackHandler()])
else: else:
chat = ChatOpenAI(**kwargs) chat = ChatOpenAI(**kwargs)
return chat(messages) return chat(messages).content

View File

@@ -133,6 +133,7 @@ isort = "^5.12.0"
pytest-cov = "^4.1.0" pytest-cov = "^4.1.0"
responses = "^0.23.3" responses = "^0.23.3"
mock = "^5.1.0" mock = "^5.1.0"
pytest-asyncio = "^0.21.1"
[tool.poetry.extras] [tool.poetry.extras]
streamlit = ["streamlit"] streamlit = ["streamlit"]

View File

@@ -104,6 +104,19 @@ class TestConfigForAppComponents:
assert isinstance(embedder_config, BaseEmbedderConfig) assert isinstance(embedder_config, BaseEmbedderConfig)
def test_components_raises_type_error_if_not_proper_instances(self):
wrong_llm = "wrong_llm"
with pytest.raises(TypeError):
App(llm=wrong_llm)
wrong_db = "wrong_db"
with pytest.raises(TypeError):
App(db=wrong_db)
wrong_embedder = "wrong_embedder"
with pytest.raises(TypeError):
App(embedder=wrong_embedder)
class TestAppFromConfig: class TestAppFromConfig:
def load_config_data(self, yaml_path): def load_config_data(self, yaml_path):

View File

@@ -1,5 +1,5 @@
import pytest import pytest
from string import Template
from embedchain.llm.base import BaseLlm, BaseLlmConfig from embedchain.llm.base import BaseLlm, BaseLlmConfig
@@ -14,6 +14,18 @@ def test_is_get_llm_model_answer_not_implemented(base_llm):
base_llm.get_llm_model_answer() base_llm.get_llm_model_answer()
def test_is_stream_bool():
with pytest.raises(ValueError):
config = BaseLlmConfig(stream="test value")
BaseLlm(config=config)
def test_template_string_gets_converted_to_Template_instance():
config = BaseLlmConfig(template="test value $query $context")
llm = BaseLlm(config=config)
assert isinstance(llm.config.template, Template)
def test_is_get_llm_model_answer_implemented(): def test_is_get_llm_model_answer_implemented():
class TestLlm(BaseLlm): class TestLlm(BaseLlm):
def get_llm_model_answer(self): def get_llm_model_answer(self):

View File

@@ -1,33 +1,55 @@
import os import os
import unittest import pytest
from unittest.mock import patch
from embedchain.config import BaseLlmConfig from embedchain.config import BaseLlmConfig
from embedchain.llm.cohere import CohereLlm from embedchain.llm.cohere import CohereLlm
class TestCohereLlm(unittest.TestCase): @pytest.fixture
def setUp(self): def cohere_llm_config():
os.environ["COHERE_API_KEY"] = "test_api_key" os.environ["COHERE_API_KEY"] = "test_api_key"
self.config = BaseLlmConfig(model="gptd-instruct-tft", max_tokens=50, temperature=0.7, top_p=0.8) config = BaseLlmConfig(model="gptd-instruct-tft", max_tokens=50, temperature=0.7, top_p=0.8)
yield config
os.environ.pop("COHERE_API_KEY")
def test_init_raises_value_error_without_api_key(self):
os.environ.pop("COHERE_API_KEY")
with self.assertRaises(ValueError):
CohereLlm()
def test_get_llm_model_answer_raises_value_error_for_system_prompt(self): def test_init_raises_value_error_without_api_key(mocker):
llm = CohereLlm(self.config) mocker.patch.dict(os.environ, clear=True)
llm.config.system_prompt = "system_prompt" with pytest.raises(ValueError):
with self.assertRaises(ValueError): CohereLlm()
llm.get_llm_model_answer("prompt")
@patch("embedchain.llm.cohere.CohereLlm._get_answer")
def test_get_llm_model_answer(self, mock_get_answer):
mock_get_answer.return_value = "Test answer"
llm = CohereLlm(self.config) def test_get_llm_model_answer_raises_value_error_for_system_prompt(cohere_llm_config):
answer = llm.get_llm_model_answer("Test query") llm = CohereLlm(cohere_llm_config)
llm.config.system_prompt = "system_prompt"
with pytest.raises(ValueError):
llm.get_llm_model_answer("prompt")
self.assertEqual(answer, "Test answer")
mock_get_answer.assert_called_once() def test_get_llm_model_answer(cohere_llm_config, mocker):
mocker.patch("embedchain.llm.cohere.CohereLlm._get_answer", return_value="Test answer")
llm = CohereLlm(cohere_llm_config)
answer = llm.get_llm_model_answer("Test query")
assert answer == "Test answer"
def test_get_answer_mocked_cohere(cohere_llm_config, mocker):
mocked_cohere = mocker.patch("embedchain.llm.cohere.Cohere")
mock_instance = mocked_cohere.return_value
mock_instance.return_value = "Mocked answer"
llm = CohereLlm(cohere_llm_config)
prompt = "Test query"
answer = llm.get_llm_model_answer(prompt)
assert answer == "Mocked answer"
mocked_cohere.assert_called_once_with(
cohere_api_key="test_api_key",
model="gptd-instruct-tft",
max_tokens=50,
temperature=0.7,
p=0.8,
)
mock_instance.assert_called_once_with(prompt)

View File

@@ -1,64 +1,61 @@
import importlib import importlib
import os import os
import unittest import pytest
from unittest.mock import MagicMock, patch
from embedchain.config import BaseLlmConfig from embedchain.config import BaseLlmConfig
from embedchain.llm.huggingface import HuggingFaceLlm from embedchain.llm.huggingface import HuggingFaceLlm
class TestHuggingFaceLlm(unittest.TestCase): @pytest.fixture
def setUp(self): def huggingface_llm_config():
os.environ["HUGGINGFACE_ACCESS_TOKEN"] = "test_access_token" os.environ["HUGGINGFACE_ACCESS_TOKEN"] = "test_access_token"
self.config = BaseLlmConfig(model="google/flan-t5-xxl", max_tokens=50, temperature=0.7, top_p=0.8) config = BaseLlmConfig(model="google/flan-t5-xxl", max_tokens=50, temperature=0.7, top_p=0.8)
yield config
os.environ.pop("HUGGINGFACE_ACCESS_TOKEN")
def test_init_raises_value_error_without_api_key(self):
os.environ.pop("HUGGINGFACE_ACCESS_TOKEN")
with self.assertRaises(ValueError):
HuggingFaceLlm()
def test_get_llm_model_answer_raises_value_error_for_system_prompt(self): def test_init_raises_value_error_without_api_key(mocker):
llm = HuggingFaceLlm(self.config) mocker.patch.dict(os.environ, clear=True)
llm.config.system_prompt = "system_prompt" with pytest.raises(ValueError):
with self.assertRaises(ValueError): HuggingFaceLlm()
llm.get_llm_model_answer("prompt")
def test_top_p_value_within_range(self):
config = BaseLlmConfig(top_p=1.0)
with self.assertRaises(ValueError):
HuggingFaceLlm._get_answer("test_prompt", config)
def test_dependency_is_imported(self): def test_get_llm_model_answer_raises_value_error_for_system_prompt(huggingface_llm_config):
importlib_installed = True llm = HuggingFaceLlm(huggingface_llm_config)
try: llm.config.system_prompt = "system_prompt"
importlib.import_module("huggingface_hub") with pytest.raises(ValueError):
except ImportError: llm.get_llm_model_answer("prompt")
importlib_installed = False
self.assertTrue(importlib_installed)
@patch("embedchain.llm.huggingface.HuggingFaceLlm._get_answer")
def test_get_llm_model_answer(self, mock_get_answer):
mock_get_answer.return_value = "Test answer"
llm = HuggingFaceLlm(self.config) def test_top_p_value_within_range():
answer = llm.get_llm_model_answer("Test query") config = BaseLlmConfig(top_p=1.0)
with pytest.raises(ValueError):
HuggingFaceLlm._get_answer("test_prompt", config)
self.assertEqual(answer, "Test answer")
mock_get_answer.assert_called_once()
@patch("embedchain.llm.huggingface.HuggingFaceHub") def test_dependency_is_imported():
def test_hugging_face_mock(self, mock_huggingface): importlib_installed = True
mock_llm_instance = MagicMock() try:
mock_llm_instance.return_value = "Test answer" importlib.import_module("huggingface_hub")
mock_huggingface.return_value = mock_llm_instance except ImportError:
importlib_installed = False
assert importlib_installed
llm = HuggingFaceLlm(self.config)
answer = llm.get_llm_model_answer("Test query")
self.assertEqual(answer, "Test answer") def test_get_llm_model_answer(huggingface_llm_config, mocker):
mock_huggingface.assert_called_once_with( mocker.patch("embedchain.llm.huggingface.HuggingFaceLlm._get_answer", return_value="Test answer")
huggingfacehub_api_token="test_access_token",
repo_id="google/flan-t5-xxl", llm = HuggingFaceLlm(huggingface_llm_config)
model_kwargs={"temperature": 0.7, "max_new_tokens": 50, "top_p": 0.8}, answer = llm.get_llm_model_answer("Test query")
)
mock_llm_instance.assert_called_once_with("Test query") assert answer == "Test answer"
def test_hugging_face_mock(huggingface_llm_config, mocker):
mock_llm_instance = mocker.Mock(return_value="Test answer")
mocker.patch("embedchain.llm.huggingface.HuggingFaceHub", return_value=mock_llm_instance)
llm = HuggingFaceLlm(huggingface_llm_config)
answer = llm.get_llm_model_answer("Test query")
assert answer == "Test answer"
mock_llm_instance.assert_called_once_with("Test query")

View File

@@ -1,40 +1,76 @@
import os import os
import unittest import pytest
from unittest.mock import patch
from embedchain.config import BaseLlmConfig from embedchain.config import BaseLlmConfig
from embedchain.llm.jina import JinaLlm from embedchain.llm.jina import JinaLlm
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
class TestJinaLlm(unittest.TestCase): @pytest.fixture
def setUp(self): def config():
os.environ["JINACHAT_API_KEY"] = "test_api_key" os.environ["JINACHAT_API_KEY"] = "test_api_key"
self.config = BaseLlmConfig( config = BaseLlmConfig(temperature=0.7, max_tokens=50, top_p=0.8, stream=False, system_prompt="System prompt")
temperature=0.7, max_tokens=50, top_p=0.8, stream=False, system_prompt="System prompt" yield config
) os.environ.pop("JINACHAT_API_KEY")
def test_init_raises_value_error_without_api_key(self):
os.environ.pop("JINACHAT_API_KEY")
with self.assertRaises(ValueError):
JinaLlm()
@patch("embedchain.llm.jina.JinaLlm._get_answer") def test_init_raises_value_error_without_api_key(mocker):
def test_get_llm_model_answer(self, mock_get_answer): mocker.patch.dict(os.environ, clear=True)
mock_get_answer.return_value = "Test answer" with pytest.raises(ValueError):
JinaLlm()
llm = JinaLlm(self.config)
answer = llm.get_llm_model_answer("Test query")
self.assertEqual(answer, "Test answer") def test_get_llm_model_answer(config, mocker):
mock_get_answer.assert_called_once() mocked_get_answer = mocker.patch("embedchain.llm.jina.JinaLlm._get_answer", return_value="Test answer")
@patch("embedchain.llm.jina.JinaLlm._get_answer") llm = JinaLlm(config)
def test_get_llm_model_answer_with_system_prompt(self, mock_get_answer): answer = llm.get_llm_model_answer("Test query")
self.config.system_prompt = "Custom system prompt"
mock_get_answer.return_value = "Test answer"
llm = JinaLlm(self.config) assert answer == "Test answer"
answer = llm.get_llm_model_answer("Test query") mocked_get_answer.assert_called_once_with("Test query", config)
self.assertEqual(answer, "Test answer")
mock_get_answer.assert_called_once() def test_get_llm_model_answer_with_system_prompt(config, mocker):
config.system_prompt = "Custom system prompt"
mocked_get_answer = mocker.patch("embedchain.llm.jina.JinaLlm._get_answer", return_value="Test answer")
llm = JinaLlm(config)
answer = llm.get_llm_model_answer("Test query")
assert answer == "Test answer"
mocked_get_answer.assert_called_once_with("Test query", config)
def test_get_llm_model_answer_empty_prompt(config, mocker):
mocked_get_answer = mocker.patch("embedchain.llm.jina.JinaLlm._get_answer", return_value="Test answer")
llm = JinaLlm(config)
answer = llm.get_llm_model_answer("")
assert answer == "Test answer"
mocked_get_answer.assert_called_once_with("", config)
def test_get_llm_model_answer_with_streaming(config, mocker):
config.stream = True
mocked_jinachat = mocker.patch("embedchain.llm.jina.JinaChat")
llm = JinaLlm(config)
llm.get_llm_model_answer("Test query")
mocked_jinachat.assert_called_once()
callbacks = [callback[1]["callbacks"] for callback in mocked_jinachat.call_args_list]
assert any(isinstance(callback[0], StreamingStdOutCallbackHandler) for callback in callbacks)
def test_get_llm_model_answer_without_system_prompt(config, mocker):
config.system_prompt = None
mocked_jinachat = mocker.patch("embedchain.llm.jina.JinaChat")
llm = JinaLlm(config)
llm.get_llm_model_answer("Test query")
mocked_jinachat.assert_called_once_with(
temperature=config.temperature,
max_tokens=config.max_tokens,
model_kwargs={"top_p": config.top_p},
)

47
tests/llm/test_llama2.py Normal file
View File

@@ -0,0 +1,47 @@
import os
import pytest
from embedchain.llm.llama2 import Llama2Llm
@pytest.fixture
def llama2_llm():
os.environ["REPLICATE_API_TOKEN"] = "test_api_token"
llm = Llama2Llm()
return llm
def test_init_raises_value_error_without_api_key(mocker):
mocker.patch.dict(os.environ, clear=True)
with pytest.raises(ValueError):
Llama2Llm()
def test_get_llm_model_answer_raises_value_error_for_system_prompt(llama2_llm):
llama2_llm.config.system_prompt = "system_prompt"
with pytest.raises(ValueError):
llama2_llm.get_llm_model_answer("prompt")
def test_get_llm_model_answer(llama2_llm, mocker):
mocked_replicate = mocker.patch("embedchain.llm.llama2.Replicate")
mocked_replicate_instance = mocker.MagicMock()
mocked_replicate.return_value = mocked_replicate_instance
mocked_replicate_instance.return_value = "Test answer"
llama2_llm.config.model = "test_model"
llama2_llm.config.max_tokens = 50
llama2_llm.config.temperature = 0.7
llama2_llm.config.top_p = 0.8
answer = llama2_llm.get_llm_model_answer("Test query")
assert answer == "Test answer"
mocked_replicate.assert_called_once_with(
model="test_model",
input={
"temperature": 0.7,
"max_length": 50,
"top_p": 0.8,
},
)
mocked_replicate_instance.assert_called_once_with("Test query")

73
tests/llm/test_openai.py Normal file
View File

@@ -0,0 +1,73 @@
import os
import pytest
from embedchain.config import BaseLlmConfig
from embedchain.llm.openai import OpenAILlm
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
@pytest.fixture
def config():
os.environ["OPENAI_API_KEY"] = "test_api_key"
config = BaseLlmConfig(
temperature=0.7, max_tokens=50, top_p=0.8, stream=False, system_prompt="System prompt", model="gpt-3.5-turbo"
)
yield config
os.environ.pop("OPENAI_API_KEY")
def test_get_llm_model_answer(config, mocker):
mocked_get_answer = mocker.patch("embedchain.llm.openai.OpenAILlm._get_answer", return_value="Test answer")
llm = OpenAILlm(config)
answer = llm.get_llm_model_answer("Test query")
assert answer == "Test answer"
mocked_get_answer.assert_called_once_with("Test query", config)
def test_get_llm_model_answer_with_system_prompt(config, mocker):
config.system_prompt = "Custom system prompt"
mocked_get_answer = mocker.patch("embedchain.llm.openai.OpenAILlm._get_answer", return_value="Test answer")
llm = OpenAILlm(config)
answer = llm.get_llm_model_answer("Test query")
assert answer == "Test answer"
mocked_get_answer.assert_called_once_with("Test query", config)
def test_get_llm_model_answer_empty_prompt(config, mocker):
mocked_get_answer = mocker.patch("embedchain.llm.openai.OpenAILlm._get_answer", return_value="Test answer")
llm = OpenAILlm(config)
answer = llm.get_llm_model_answer("")
assert answer == "Test answer"
mocked_get_answer.assert_called_once_with("", config)
def test_get_llm_model_answer_with_streaming(config, mocker):
config.stream = True
mocked_jinachat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
llm = OpenAILlm(config)
llm.get_llm_model_answer("Test query")
mocked_jinachat.assert_called_once()
callbacks = [callback[1]["callbacks"] for callback in mocked_jinachat.call_args_list]
assert any(isinstance(callback[0], StreamingStdOutCallbackHandler) for callback in callbacks)
def test_get_llm_model_answer_without_system_prompt(config, mocker):
config.system_prompt = None
mocked_jinachat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
llm = OpenAILlm(config)
llm.get_llm_model_answer("Test query")
mocked_jinachat.assert_called_once_with(
model=config.model,
temperature=config.temperature,
max_tokens=config.max_tokens,
model_kwargs={"top_p": config.top_p},
)

View File

@@ -1,135 +1,85 @@
import os import os
import unittest import pytest
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
from embedchain import App from embedchain import App
from embedchain.config import AppConfig, BaseLlmConfig from embedchain.config import AppConfig, BaseLlmConfig
class TestApp(unittest.TestCase): @pytest.fixture
os.environ["OPENAI_API_KEY"] = "test_key" def app():
os.environ["OPENAI_API_KEY"] = "test_api_key"
app = App(config=AppConfig(collect_metrics=False))
return app
def setUp(self):
self.app = App(config=AppConfig(collect_metrics=False))
@patch("chromadb.api.models.Collection.Collection.add", MagicMock) @patch("chromadb.api.models.Collection.Collection.add", MagicMock)
def test_query(self): def test_query(app):
""" with patch.object(app, "retrieve_from_database") as mock_retrieve:
This test checks the functionality of the 'query' method in the App class. mock_retrieve.return_value = ["Test context"]
It simulates a scenario where the 'retrieve_from_database' method returns a context list and with patch.object(app.llm, "get_llm_model_answer") as mock_answer:
'get_llm_model_answer' returns an expected answer string.
The 'query' method is expected to call 'retrieve_from_database' and 'get_llm_model_answer' methods
appropriately and return the right answer.
Key assumptions tested:
- 'retrieve_from_database' method is called exactly once with arguments: "Test query" and an instance of
BaseLlmConfig.
- 'get_llm_model_answer' is called exactly once. The specific arguments are not checked in this test.
- 'query' method returns the value it received from 'get_llm_model_answer'.
The test isolates the 'query' method behavior by mocking out 'retrieve_from_database' and
'get_llm_model_answer' methods.
"""
with patch.object(self.app, "retrieve_from_database") as mock_retrieve:
mock_retrieve.return_value = ["Test context"]
with patch.object(self.app.llm, "get_llm_model_answer") as mock_answer:
mock_answer.return_value = "Test answer"
_answer = self.app.query(input_query="Test query")
# Ensure retrieve_from_database was called
mock_retrieve.assert_called_once()
# Check the call arguments
args, kwargs = mock_retrieve.call_args
input_query_arg = kwargs.get("input_query")
self.assertEqual(input_query_arg, "Test query")
mock_answer.assert_called_once()
@patch("embedchain.llm.openai.OpenAILlm._get_answer")
def test_query_config_app_passing(self, mock_get_answer):
mock_get_answer.return_value = MagicMock()
mock_get_answer.return_value.content = "Test answer"
config = AppConfig(collect_metrics=False)
chat_config = BaseLlmConfig(system_prompt="Test system prompt")
app = App(config=config, llm_config=chat_config)
answer = app.llm.get_llm_model_answer("Test query")
self.assertEqual(app.llm.config.system_prompt, "Test system prompt")
self.assertEqual(answer, "Test answer")
@patch("embedchain.llm.openai.OpenAILlm._get_answer")
def test_app_passing(self, mock_get_answer):
mock_get_answer.return_value = MagicMock()
mock_get_answer.return_value.content = "Test answer"
config = AppConfig(collect_metrics=False)
chat_config = BaseLlmConfig()
app = App(config=config, llm_config=chat_config, system_prompt="Test system prompt")
answer = app.llm.get_llm_model_answer("Test query")
self.assertEqual(app.llm.config.system_prompt, "Test system prompt")
self.assertEqual(answer, "Test answer")
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
def test_query_with_where_in_params(self):
"""
This test checks the functionality of the 'query' method in the App class.
It simulates a scenario where the 'retrieve_from_database' method returns a context list based on
a where filter and 'get_llm_model_answer' returns an expected answer string.
The 'query' method is expected to call 'retrieve_from_database' with the where filter and
'get_llm_model_answer' methods appropriately and return the right answer.
Key assumptions tested:
- 'retrieve_from_database' method is called exactly once with arguments: "Test query" and an instance of
BaseLlmConfig.
- 'get_llm_model_answer' is called exactly once. The specific arguments are not checked in this test.
- 'query' method returns the value it received from 'get_llm_model_answer'.
The test isolates the 'query' method behavior by mocking out 'retrieve_from_database' and
'get_llm_model_answer' methods.
"""
with patch.object(self.app, "retrieve_from_database") as mock_retrieve:
mock_retrieve.return_value = ["Test context"]
with patch.object(self.app.llm, "get_llm_model_answer") as mock_answer:
mock_answer.return_value = "Test answer"
answer = self.app.query("Test query", where={"attribute": "value"})
self.assertEqual(answer, "Test answer")
_args, kwargs = mock_retrieve.call_args
self.assertEqual(kwargs.get("input_query"), "Test query")
self.assertEqual(kwargs.get("where"), {"attribute": "value"})
mock_answer.assert_called_once()
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
def test_query_with_where_in_query_config(self):
"""
This test checks the functionality of the 'query' method in the App class.
It simulates a scenario where the 'retrieve_from_database' method returns a context list based on
a where filter and 'get_llm_model_answer' returns an expected answer string.
The 'query' method is expected to call 'retrieve_from_database' with the where filter and
'get_llm_model_answer' methods appropriately and return the right answer.
Key assumptions tested:
- 'retrieve_from_database' method is called exactly once with arguments: "Test query" and an instance of
BaseLlmConfig.
- 'get_llm_model_answer' is called exactly once. The specific arguments are not checked in this test.
- 'query' method returns the value it received from 'get_llm_model_answer'.
The test isolates the 'query' method behavior by mocking out 'retrieve_from_database' and
'get_llm_model_answer' methods.
"""
with patch.object(self.app.llm, "get_llm_model_answer") as mock_answer:
mock_answer.return_value = "Test answer" mock_answer.return_value = "Test answer"
with patch.object(self.app.db, "query") as mock_database_query: answer = app.query(input_query="Test query")
mock_database_query.return_value = ["Test context"] assert answer == "Test answer"
llm_config = BaseLlmConfig(where={"attribute": "value"})
answer = self.app.query("Test query", llm_config)
self.assertEqual(answer, "Test answer") mock_retrieve.assert_called_once()
_args, kwargs = mock_database_query.call_args _, kwargs = mock_retrieve.call_args
self.assertEqual(kwargs.get("input_query"), "Test query") input_query_arg = kwargs.get("input_query")
self.assertEqual(kwargs.get("where"), {"attribute": "value"}) assert input_query_arg == "Test query"
mock_answer.assert_called_once() mock_answer.assert_called_once()
@patch("embedchain.llm.openai.OpenAILlm._get_answer")
def test_query_config_app_passing(mock_get_answer):
mock_get_answer.return_value = MagicMock()
mock_get_answer.return_value = "Test answer"
config = AppConfig(collect_metrics=False)
chat_config = BaseLlmConfig(system_prompt="Test system prompt")
app = App(config=config, llm_config=chat_config)
answer = app.llm.get_llm_model_answer("Test query")
assert app.llm.config.system_prompt == "Test system prompt"
assert answer == "Test answer"
@patch("embedchain.llm.openai.OpenAILlm._get_answer")
def test_app_passing(mock_get_answer):
mock_get_answer.return_value = MagicMock()
mock_get_answer.return_value = "Test answer"
config = AppConfig(collect_metrics=False)
chat_config = BaseLlmConfig()
app = App(config=config, llm_config=chat_config, system_prompt="Test system prompt")
answer = app.llm.get_llm_model_answer("Test query")
assert app.llm.config.system_prompt == "Test system prompt"
assert answer == "Test answer"
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
def test_query_with_where_in_params(app):
with patch.object(app, "retrieve_from_database") as mock_retrieve:
mock_retrieve.return_value = ["Test context"]
with patch.object(app.llm, "get_llm_model_answer") as mock_answer:
mock_answer.return_value = "Test answer"
answer = app.query("Test query", where={"attribute": "value"})
assert answer == "Test answer"
_, kwargs = mock_retrieve.call_args
assert kwargs.get("input_query") == "Test query"
assert kwargs.get("where") == {"attribute": "value"}
mock_answer.assert_called_once()
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
def test_query_with_where_in_query_config(app):
with patch.object(app.llm, "get_llm_model_answer") as mock_answer:
mock_answer.return_value = "Test answer"
with patch.object(app.db, "query") as mock_database_query:
mock_database_query.return_value = ["Test context"]
llm_config = BaseLlmConfig(where={"attribute": "value"})
answer = app.query("Test query", llm_config)
assert answer == "Test answer"
_, kwargs = mock_database_query.call_args
assert kwargs.get("input_query") == "Test query"
assert kwargs.get("where") == {"attribute": "value"}
mock_answer.assert_called_once()

View File

@@ -1,32 +1,29 @@
import unittest from embedchain.models.data_type import DataType, DirectDataType, IndirectDataType, SpecialDataType
from embedchain.models.data_type import (DataType, DirectDataType,
IndirectDataType, SpecialDataType)
class TestDataTypeEnums(unittest.TestCase): def test_subclass_types_in_data_type():
def test_subclass_types_in_data_type(self): """Test that all data type category subclasses are contained in the composite data type"""
"""Test that all data type category subclasses are contained in the composite data type""" # Check if DirectDataType values are in DataType
# Check if DirectDataType values are in DataType for data_type in DirectDataType:
for data_type in DirectDataType: assert data_type.value in DataType._value2member_map_
self.assertIn(data_type.value, DataType._value2member_map_)
# Check if IndirectDataType values are in DataType # Check if IndirectDataType values are in DataType
for data_type in IndirectDataType: for data_type in IndirectDataType:
self.assertIn(data_type.value, DataType._value2member_map_) assert data_type.value in DataType._value2member_map_
# Check if SpecialDataType values are in DataType # Check if SpecialDataType values are in DataType
for data_type in SpecialDataType: for data_type in SpecialDataType:
self.assertIn(data_type.value, DataType._value2member_map_) assert data_type.value in DataType._value2member_map_
def test_data_type_in_subclasses(self):
"""Test that all data types in the composite data type are categorized in a subclass""" def test_data_type_in_subclasses():
for data_type in DataType: """Test that all data types in the composite data type are categorized in a subclass"""
if data_type.value in DirectDataType._value2member_map_: for data_type in DataType:
self.assertIn(data_type.value, DirectDataType._value2member_map_) if data_type.value in DirectDataType._value2member_map_:
elif data_type.value in IndirectDataType._value2member_map_: assert data_type.value in DirectDataType._value2member_map_
self.assertIn(data_type.value, IndirectDataType._value2member_map_) elif data_type.value in IndirectDataType._value2member_map_:
elif data_type.value in SpecialDataType._value2member_map_: assert data_type.value in IndirectDataType._value2member_map_
self.assertIn(data_type.value, SpecialDataType._value2member_map_) elif data_type.value in SpecialDataType._value2member_map_:
else: assert data_type.value in SpecialDataType._value2member_map_
self.fail(f"{data_type.value} not found in any subclass enums") else:
assert False, f"{data_type.value} not found in any subclass enums"