Improve and add more tests (#807)

This commit is contained in:
Sidharth Mohanty
2023-10-18 02:36:47 +05:30
committed by GitHub
parent d065cbf934
commit e8a2846449
11 changed files with 403 additions and 260 deletions

View File

@@ -13,13 +13,9 @@ class OpenAILlm(BaseLlm):
def __init__(self, config: Optional[BaseLlmConfig] = None):
super().__init__(config=config)
def get_llm_model_answer(self, prompt):
def get_llm_model_answer(self, prompt) -> str:
response = OpenAILlm._get_answer(prompt, self.config)
if self.config.stream:
return response
else:
return response.content
return response
def _get_answer(prompt: str, config: BaseLlmConfig) -> str:
messages = []
@@ -35,10 +31,9 @@ class OpenAILlm(BaseLlm):
if config.top_p:
kwargs["model_kwargs"]["top_p"] = config.top_p
if config.stream:
from langchain.callbacks.streaming_stdout import \
StreamingStdOutCallbackHandler
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
chat = ChatOpenAI(**kwargs, streaming=config.stream, callbacks=[StreamingStdOutCallbackHandler()])
else:
chat = ChatOpenAI(**kwargs)
return chat(messages)
return chat(messages).content

View File

@@ -133,6 +133,7 @@ isort = "^5.12.0"
pytest-cov = "^4.1.0"
responses = "^0.23.3"
mock = "^5.1.0"
pytest-asyncio = "^0.21.1"
[tool.poetry.extras]
streamlit = ["streamlit"]

View File

@@ -104,6 +104,19 @@ class TestConfigForAppComponents:
assert isinstance(embedder_config, BaseEmbedderConfig)
def test_components_raises_type_error_if_not_proper_instances(self):
wrong_llm = "wrong_llm"
with pytest.raises(TypeError):
App(llm=wrong_llm)
wrong_db = "wrong_db"
with pytest.raises(TypeError):
App(db=wrong_db)
wrong_embedder = "wrong_embedder"
with pytest.raises(TypeError):
App(embedder=wrong_embedder)
class TestAppFromConfig:
def load_config_data(self, yaml_path):

View File

@@ -1,5 +1,5 @@
import pytest
from string import Template
from embedchain.llm.base import BaseLlm, BaseLlmConfig
@@ -14,6 +14,18 @@ def test_is_get_llm_model_answer_not_implemented(base_llm):
base_llm.get_llm_model_answer()
def test_is_stream_bool():
with pytest.raises(ValueError):
config = BaseLlmConfig(stream="test value")
BaseLlm(config=config)
def test_template_string_gets_converted_to_Template_instance():
config = BaseLlmConfig(template="test value $query $context")
llm = BaseLlm(config=config)
assert isinstance(llm.config.template, Template)
def test_is_get_llm_model_answer_implemented():
class TestLlm(BaseLlm):
def get_llm_model_answer(self):

View File

@@ -1,33 +1,55 @@
import os
import unittest
from unittest.mock import patch
import pytest
from embedchain.config import BaseLlmConfig
from embedchain.llm.cohere import CohereLlm
class TestCohereLlm(unittest.TestCase):
def setUp(self):
os.environ["COHERE_API_KEY"] = "test_api_key"
self.config = BaseLlmConfig(model="gptd-instruct-tft", max_tokens=50, temperature=0.7, top_p=0.8)
@pytest.fixture
def cohere_llm_config():
os.environ["COHERE_API_KEY"] = "test_api_key"
config = BaseLlmConfig(model="gptd-instruct-tft", max_tokens=50, temperature=0.7, top_p=0.8)
yield config
os.environ.pop("COHERE_API_KEY")
def test_init_raises_value_error_without_api_key(self):
os.environ.pop("COHERE_API_KEY")
with self.assertRaises(ValueError):
CohereLlm()
def test_get_llm_model_answer_raises_value_error_for_system_prompt(self):
llm = CohereLlm(self.config)
llm.config.system_prompt = "system_prompt"
with self.assertRaises(ValueError):
llm.get_llm_model_answer("prompt")
def test_init_raises_value_error_without_api_key(mocker):
mocker.patch.dict(os.environ, clear=True)
with pytest.raises(ValueError):
CohereLlm()
@patch("embedchain.llm.cohere.CohereLlm._get_answer")
def test_get_llm_model_answer(self, mock_get_answer):
mock_get_answer.return_value = "Test answer"
llm = CohereLlm(self.config)
answer = llm.get_llm_model_answer("Test query")
def test_get_llm_model_answer_raises_value_error_for_system_prompt(cohere_llm_config):
llm = CohereLlm(cohere_llm_config)
llm.config.system_prompt = "system_prompt"
with pytest.raises(ValueError):
llm.get_llm_model_answer("prompt")
self.assertEqual(answer, "Test answer")
mock_get_answer.assert_called_once()
def test_get_llm_model_answer(cohere_llm_config, mocker):
mocker.patch("embedchain.llm.cohere.CohereLlm._get_answer", return_value="Test answer")
llm = CohereLlm(cohere_llm_config)
answer = llm.get_llm_model_answer("Test query")
assert answer == "Test answer"
def test_get_answer_mocked_cohere(cohere_llm_config, mocker):
mocked_cohere = mocker.patch("embedchain.llm.cohere.Cohere")
mock_instance = mocked_cohere.return_value
mock_instance.return_value = "Mocked answer"
llm = CohereLlm(cohere_llm_config)
prompt = "Test query"
answer = llm.get_llm_model_answer(prompt)
assert answer == "Mocked answer"
mocked_cohere.assert_called_once_with(
cohere_api_key="test_api_key",
model="gptd-instruct-tft",
max_tokens=50,
temperature=0.7,
p=0.8,
)
mock_instance.assert_called_once_with(prompt)

View File

@@ -1,64 +1,61 @@
import importlib
import os
import unittest
from unittest.mock import MagicMock, patch
import pytest
from embedchain.config import BaseLlmConfig
from embedchain.llm.huggingface import HuggingFaceLlm
class TestHuggingFaceLlm(unittest.TestCase):
def setUp(self):
os.environ["HUGGINGFACE_ACCESS_TOKEN"] = "test_access_token"
self.config = BaseLlmConfig(model="google/flan-t5-xxl", max_tokens=50, temperature=0.7, top_p=0.8)
@pytest.fixture
def huggingface_llm_config():
os.environ["HUGGINGFACE_ACCESS_TOKEN"] = "test_access_token"
config = BaseLlmConfig(model="google/flan-t5-xxl", max_tokens=50, temperature=0.7, top_p=0.8)
yield config
os.environ.pop("HUGGINGFACE_ACCESS_TOKEN")
def test_init_raises_value_error_without_api_key(self):
os.environ.pop("HUGGINGFACE_ACCESS_TOKEN")
with self.assertRaises(ValueError):
HuggingFaceLlm()
def test_get_llm_model_answer_raises_value_error_for_system_prompt(self):
llm = HuggingFaceLlm(self.config)
llm.config.system_prompt = "system_prompt"
with self.assertRaises(ValueError):
llm.get_llm_model_answer("prompt")
def test_init_raises_value_error_without_api_key(mocker):
mocker.patch.dict(os.environ, clear=True)
with pytest.raises(ValueError):
HuggingFaceLlm()
def test_top_p_value_within_range(self):
config = BaseLlmConfig(top_p=1.0)
with self.assertRaises(ValueError):
HuggingFaceLlm._get_answer("test_prompt", config)
def test_dependency_is_imported(self):
importlib_installed = True
try:
importlib.import_module("huggingface_hub")
except ImportError:
importlib_installed = False
self.assertTrue(importlib_installed)
def test_get_llm_model_answer_raises_value_error_for_system_prompt(huggingface_llm_config):
llm = HuggingFaceLlm(huggingface_llm_config)
llm.config.system_prompt = "system_prompt"
with pytest.raises(ValueError):
llm.get_llm_model_answer("prompt")
@patch("embedchain.llm.huggingface.HuggingFaceLlm._get_answer")
def test_get_llm_model_answer(self, mock_get_answer):
mock_get_answer.return_value = "Test answer"
llm = HuggingFaceLlm(self.config)
answer = llm.get_llm_model_answer("Test query")
def test_top_p_value_within_range():
config = BaseLlmConfig(top_p=1.0)
with pytest.raises(ValueError):
HuggingFaceLlm._get_answer("test_prompt", config)
self.assertEqual(answer, "Test answer")
mock_get_answer.assert_called_once()
@patch("embedchain.llm.huggingface.HuggingFaceHub")
def test_hugging_face_mock(self, mock_huggingface):
mock_llm_instance = MagicMock()
mock_llm_instance.return_value = "Test answer"
mock_huggingface.return_value = mock_llm_instance
def test_dependency_is_imported():
importlib_installed = True
try:
importlib.import_module("huggingface_hub")
except ImportError:
importlib_installed = False
assert importlib_installed
llm = HuggingFaceLlm(self.config)
answer = llm.get_llm_model_answer("Test query")
self.assertEqual(answer, "Test answer")
mock_huggingface.assert_called_once_with(
huggingfacehub_api_token="test_access_token",
repo_id="google/flan-t5-xxl",
model_kwargs={"temperature": 0.7, "max_new_tokens": 50, "top_p": 0.8},
)
mock_llm_instance.assert_called_once_with("Test query")
def test_get_llm_model_answer(huggingface_llm_config, mocker):
mocker.patch("embedchain.llm.huggingface.HuggingFaceLlm._get_answer", return_value="Test answer")
llm = HuggingFaceLlm(huggingface_llm_config)
answer = llm.get_llm_model_answer("Test query")
assert answer == "Test answer"
def test_hugging_face_mock(huggingface_llm_config, mocker):
mock_llm_instance = mocker.Mock(return_value="Test answer")
mocker.patch("embedchain.llm.huggingface.HuggingFaceHub", return_value=mock_llm_instance)
llm = HuggingFaceLlm(huggingface_llm_config)
answer = llm.get_llm_model_answer("Test query")
assert answer == "Test answer"
mock_llm_instance.assert_called_once_with("Test query")

View File

@@ -1,40 +1,76 @@
import os
import unittest
from unittest.mock import patch
import pytest
from embedchain.config import BaseLlmConfig
from embedchain.llm.jina import JinaLlm
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
class TestJinaLlm(unittest.TestCase):
def setUp(self):
os.environ["JINACHAT_API_KEY"] = "test_api_key"
self.config = BaseLlmConfig(
temperature=0.7, max_tokens=50, top_p=0.8, stream=False, system_prompt="System prompt"
)
@pytest.fixture
def config():
os.environ["JINACHAT_API_KEY"] = "test_api_key"
config = BaseLlmConfig(temperature=0.7, max_tokens=50, top_p=0.8, stream=False, system_prompt="System prompt")
yield config
os.environ.pop("JINACHAT_API_KEY")
def test_init_raises_value_error_without_api_key(self):
os.environ.pop("JINACHAT_API_KEY")
with self.assertRaises(ValueError):
JinaLlm()
@patch("embedchain.llm.jina.JinaLlm._get_answer")
def test_get_llm_model_answer(self, mock_get_answer):
mock_get_answer.return_value = "Test answer"
def test_init_raises_value_error_without_api_key(mocker):
mocker.patch.dict(os.environ, clear=True)
with pytest.raises(ValueError):
JinaLlm()
llm = JinaLlm(self.config)
answer = llm.get_llm_model_answer("Test query")
self.assertEqual(answer, "Test answer")
mock_get_answer.assert_called_once()
def test_get_llm_model_answer(config, mocker):
mocked_get_answer = mocker.patch("embedchain.llm.jina.JinaLlm._get_answer", return_value="Test answer")
@patch("embedchain.llm.jina.JinaLlm._get_answer")
def test_get_llm_model_answer_with_system_prompt(self, mock_get_answer):
self.config.system_prompt = "Custom system prompt"
mock_get_answer.return_value = "Test answer"
llm = JinaLlm(config)
answer = llm.get_llm_model_answer("Test query")
llm = JinaLlm(self.config)
answer = llm.get_llm_model_answer("Test query")
assert answer == "Test answer"
mocked_get_answer.assert_called_once_with("Test query", config)
self.assertEqual(answer, "Test answer")
mock_get_answer.assert_called_once()
def test_get_llm_model_answer_with_system_prompt(config, mocker):
config.system_prompt = "Custom system prompt"
mocked_get_answer = mocker.patch("embedchain.llm.jina.JinaLlm._get_answer", return_value="Test answer")
llm = JinaLlm(config)
answer = llm.get_llm_model_answer("Test query")
assert answer == "Test answer"
mocked_get_answer.assert_called_once_with("Test query", config)
def test_get_llm_model_answer_empty_prompt(config, mocker):
mocked_get_answer = mocker.patch("embedchain.llm.jina.JinaLlm._get_answer", return_value="Test answer")
llm = JinaLlm(config)
answer = llm.get_llm_model_answer("")
assert answer == "Test answer"
mocked_get_answer.assert_called_once_with("", config)
def test_get_llm_model_answer_with_streaming(config, mocker):
config.stream = True
mocked_jinachat = mocker.patch("embedchain.llm.jina.JinaChat")
llm = JinaLlm(config)
llm.get_llm_model_answer("Test query")
mocked_jinachat.assert_called_once()
callbacks = [callback[1]["callbacks"] for callback in mocked_jinachat.call_args_list]
assert any(isinstance(callback[0], StreamingStdOutCallbackHandler) for callback in callbacks)
def test_get_llm_model_answer_without_system_prompt(config, mocker):
config.system_prompt = None
mocked_jinachat = mocker.patch("embedchain.llm.jina.JinaChat")
llm = JinaLlm(config)
llm.get_llm_model_answer("Test query")
mocked_jinachat.assert_called_once_with(
temperature=config.temperature,
max_tokens=config.max_tokens,
model_kwargs={"top_p": config.top_p},
)

47
tests/llm/test_llama2.py Normal file
View File

@@ -0,0 +1,47 @@
import os
import pytest
from embedchain.llm.llama2 import Llama2Llm
@pytest.fixture
def llama2_llm():
os.environ["REPLICATE_API_TOKEN"] = "test_api_token"
llm = Llama2Llm()
return llm
def test_init_raises_value_error_without_api_key(mocker):
mocker.patch.dict(os.environ, clear=True)
with pytest.raises(ValueError):
Llama2Llm()
def test_get_llm_model_answer_raises_value_error_for_system_prompt(llama2_llm):
llama2_llm.config.system_prompt = "system_prompt"
with pytest.raises(ValueError):
llama2_llm.get_llm_model_answer("prompt")
def test_get_llm_model_answer(llama2_llm, mocker):
mocked_replicate = mocker.patch("embedchain.llm.llama2.Replicate")
mocked_replicate_instance = mocker.MagicMock()
mocked_replicate.return_value = mocked_replicate_instance
mocked_replicate_instance.return_value = "Test answer"
llama2_llm.config.model = "test_model"
llama2_llm.config.max_tokens = 50
llama2_llm.config.temperature = 0.7
llama2_llm.config.top_p = 0.8
answer = llama2_llm.get_llm_model_answer("Test query")
assert answer == "Test answer"
mocked_replicate.assert_called_once_with(
model="test_model",
input={
"temperature": 0.7,
"max_length": 50,
"top_p": 0.8,
},
)
mocked_replicate_instance.assert_called_once_with("Test query")

73
tests/llm/test_openai.py Normal file
View File

@@ -0,0 +1,73 @@
import os
import pytest
from embedchain.config import BaseLlmConfig
from embedchain.llm.openai import OpenAILlm
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
@pytest.fixture
def config():
os.environ["OPENAI_API_KEY"] = "test_api_key"
config = BaseLlmConfig(
temperature=0.7, max_tokens=50, top_p=0.8, stream=False, system_prompt="System prompt", model="gpt-3.5-turbo"
)
yield config
os.environ.pop("OPENAI_API_KEY")
def test_get_llm_model_answer(config, mocker):
mocked_get_answer = mocker.patch("embedchain.llm.openai.OpenAILlm._get_answer", return_value="Test answer")
llm = OpenAILlm(config)
answer = llm.get_llm_model_answer("Test query")
assert answer == "Test answer"
mocked_get_answer.assert_called_once_with("Test query", config)
def test_get_llm_model_answer_with_system_prompt(config, mocker):
config.system_prompt = "Custom system prompt"
mocked_get_answer = mocker.patch("embedchain.llm.openai.OpenAILlm._get_answer", return_value="Test answer")
llm = OpenAILlm(config)
answer = llm.get_llm_model_answer("Test query")
assert answer == "Test answer"
mocked_get_answer.assert_called_once_with("Test query", config)
def test_get_llm_model_answer_empty_prompt(config, mocker):
mocked_get_answer = mocker.patch("embedchain.llm.openai.OpenAILlm._get_answer", return_value="Test answer")
llm = OpenAILlm(config)
answer = llm.get_llm_model_answer("")
assert answer == "Test answer"
mocked_get_answer.assert_called_once_with("", config)
def test_get_llm_model_answer_with_streaming(config, mocker):
config.stream = True
mocked_jinachat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
llm = OpenAILlm(config)
llm.get_llm_model_answer("Test query")
mocked_jinachat.assert_called_once()
callbacks = [callback[1]["callbacks"] for callback in mocked_jinachat.call_args_list]
assert any(isinstance(callback[0], StreamingStdOutCallbackHandler) for callback in callbacks)
def test_get_llm_model_answer_without_system_prompt(config, mocker):
config.system_prompt = None
mocked_jinachat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
llm = OpenAILlm(config)
llm.get_llm_model_answer("Test query")
mocked_jinachat.assert_called_once_with(
model=config.model,
temperature=config.temperature,
max_tokens=config.max_tokens,
model_kwargs={"top_p": config.top_p},
)

View File

@@ -1,135 +1,85 @@
import os
import unittest
import pytest
from unittest.mock import MagicMock, patch
from embedchain import App
from embedchain.config import AppConfig, BaseLlmConfig
class TestApp(unittest.TestCase):
os.environ["OPENAI_API_KEY"] = "test_key"
@pytest.fixture
def app():
os.environ["OPENAI_API_KEY"] = "test_api_key"
app = App(config=AppConfig(collect_metrics=False))
return app
def setUp(self):
self.app = App(config=AppConfig(collect_metrics=False))
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
def test_query(self):
"""
This test checks the functionality of the 'query' method in the App class.
It simulates a scenario where the 'retrieve_from_database' method returns a context list and
'get_llm_model_answer' returns an expected answer string.
The 'query' method is expected to call 'retrieve_from_database' and 'get_llm_model_answer' methods
appropriately and return the right answer.
Key assumptions tested:
- 'retrieve_from_database' method is called exactly once with arguments: "Test query" and an instance of
BaseLlmConfig.
- 'get_llm_model_answer' is called exactly once. The specific arguments are not checked in this test.
- 'query' method returns the value it received from 'get_llm_model_answer'.
The test isolates the 'query' method behavior by mocking out 'retrieve_from_database' and
'get_llm_model_answer' methods.
"""
with patch.object(self.app, "retrieve_from_database") as mock_retrieve:
mock_retrieve.return_value = ["Test context"]
with patch.object(self.app.llm, "get_llm_model_answer") as mock_answer:
mock_answer.return_value = "Test answer"
_answer = self.app.query(input_query="Test query")
# Ensure retrieve_from_database was called
mock_retrieve.assert_called_once()
# Check the call arguments
args, kwargs = mock_retrieve.call_args
input_query_arg = kwargs.get("input_query")
self.assertEqual(input_query_arg, "Test query")
mock_answer.assert_called_once()
@patch("embedchain.llm.openai.OpenAILlm._get_answer")
def test_query_config_app_passing(self, mock_get_answer):
mock_get_answer.return_value = MagicMock()
mock_get_answer.return_value.content = "Test answer"
config = AppConfig(collect_metrics=False)
chat_config = BaseLlmConfig(system_prompt="Test system prompt")
app = App(config=config, llm_config=chat_config)
answer = app.llm.get_llm_model_answer("Test query")
self.assertEqual(app.llm.config.system_prompt, "Test system prompt")
self.assertEqual(answer, "Test answer")
@patch("embedchain.llm.openai.OpenAILlm._get_answer")
def test_app_passing(self, mock_get_answer):
mock_get_answer.return_value = MagicMock()
mock_get_answer.return_value.content = "Test answer"
config = AppConfig(collect_metrics=False)
chat_config = BaseLlmConfig()
app = App(config=config, llm_config=chat_config, system_prompt="Test system prompt")
answer = app.llm.get_llm_model_answer("Test query")
self.assertEqual(app.llm.config.system_prompt, "Test system prompt")
self.assertEqual(answer, "Test answer")
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
def test_query_with_where_in_params(self):
"""
This test checks the functionality of the 'query' method in the App class.
It simulates a scenario where the 'retrieve_from_database' method returns a context list based on
a where filter and 'get_llm_model_answer' returns an expected answer string.
The 'query' method is expected to call 'retrieve_from_database' with the where filter and
'get_llm_model_answer' methods appropriately and return the right answer.
Key assumptions tested:
- 'retrieve_from_database' method is called exactly once with arguments: "Test query" and an instance of
BaseLlmConfig.
- 'get_llm_model_answer' is called exactly once. The specific arguments are not checked in this test.
- 'query' method returns the value it received from 'get_llm_model_answer'.
The test isolates the 'query' method behavior by mocking out 'retrieve_from_database' and
'get_llm_model_answer' methods.
"""
with patch.object(self.app, "retrieve_from_database") as mock_retrieve:
mock_retrieve.return_value = ["Test context"]
with patch.object(self.app.llm, "get_llm_model_answer") as mock_answer:
mock_answer.return_value = "Test answer"
answer = self.app.query("Test query", where={"attribute": "value"})
self.assertEqual(answer, "Test answer")
_args, kwargs = mock_retrieve.call_args
self.assertEqual(kwargs.get("input_query"), "Test query")
self.assertEqual(kwargs.get("where"), {"attribute": "value"})
mock_answer.assert_called_once()
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
def test_query_with_where_in_query_config(self):
"""
This test checks the functionality of the 'query' method in the App class.
It simulates a scenario where the 'retrieve_from_database' method returns a context list based on
a where filter and 'get_llm_model_answer' returns an expected answer string.
The 'query' method is expected to call 'retrieve_from_database' with the where filter and
'get_llm_model_answer' methods appropriately and return the right answer.
Key assumptions tested:
- 'retrieve_from_database' method is called exactly once with arguments: "Test query" and an instance of
BaseLlmConfig.
- 'get_llm_model_answer' is called exactly once. The specific arguments are not checked in this test.
- 'query' method returns the value it received from 'get_llm_model_answer'.
The test isolates the 'query' method behavior by mocking out 'retrieve_from_database' and
'get_llm_model_answer' methods.
"""
with patch.object(self.app.llm, "get_llm_model_answer") as mock_answer:
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
def test_query(app):
with patch.object(app, "retrieve_from_database") as mock_retrieve:
mock_retrieve.return_value = ["Test context"]
with patch.object(app.llm, "get_llm_model_answer") as mock_answer:
mock_answer.return_value = "Test answer"
with patch.object(self.app.db, "query") as mock_database_query:
mock_database_query.return_value = ["Test context"]
llm_config = BaseLlmConfig(where={"attribute": "value"})
answer = self.app.query("Test query", llm_config)
answer = app.query(input_query="Test query")
assert answer == "Test answer"
self.assertEqual(answer, "Test answer")
_args, kwargs = mock_database_query.call_args
self.assertEqual(kwargs.get("input_query"), "Test query")
self.assertEqual(kwargs.get("where"), {"attribute": "value"})
mock_answer.assert_called_once()
mock_retrieve.assert_called_once()
_, kwargs = mock_retrieve.call_args
input_query_arg = kwargs.get("input_query")
assert input_query_arg == "Test query"
mock_answer.assert_called_once()
@patch("embedchain.llm.openai.OpenAILlm._get_answer")
def test_query_config_app_passing(mock_get_answer):
mock_get_answer.return_value = MagicMock()
mock_get_answer.return_value = "Test answer"
config = AppConfig(collect_metrics=False)
chat_config = BaseLlmConfig(system_prompt="Test system prompt")
app = App(config=config, llm_config=chat_config)
answer = app.llm.get_llm_model_answer("Test query")
assert app.llm.config.system_prompt == "Test system prompt"
assert answer == "Test answer"
@patch("embedchain.llm.openai.OpenAILlm._get_answer")
def test_app_passing(mock_get_answer):
mock_get_answer.return_value = MagicMock()
mock_get_answer.return_value = "Test answer"
config = AppConfig(collect_metrics=False)
chat_config = BaseLlmConfig()
app = App(config=config, llm_config=chat_config, system_prompt="Test system prompt")
answer = app.llm.get_llm_model_answer("Test query")
assert app.llm.config.system_prompt == "Test system prompt"
assert answer == "Test answer"
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
def test_query_with_where_in_params(app):
with patch.object(app, "retrieve_from_database") as mock_retrieve:
mock_retrieve.return_value = ["Test context"]
with patch.object(app.llm, "get_llm_model_answer") as mock_answer:
mock_answer.return_value = "Test answer"
answer = app.query("Test query", where={"attribute": "value"})
assert answer == "Test answer"
_, kwargs = mock_retrieve.call_args
assert kwargs.get("input_query") == "Test query"
assert kwargs.get("where") == {"attribute": "value"}
mock_answer.assert_called_once()
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
def test_query_with_where_in_query_config(app):
with patch.object(app.llm, "get_llm_model_answer") as mock_answer:
mock_answer.return_value = "Test answer"
with patch.object(app.db, "query") as mock_database_query:
mock_database_query.return_value = ["Test context"]
llm_config = BaseLlmConfig(where={"attribute": "value"})
answer = app.query("Test query", llm_config)
assert answer == "Test answer"
_, kwargs = mock_database_query.call_args
assert kwargs.get("input_query") == "Test query"
assert kwargs.get("where") == {"attribute": "value"}
mock_answer.assert_called_once()

View File

@@ -1,32 +1,29 @@
import unittest
from embedchain.models.data_type import (DataType, DirectDataType,
IndirectDataType, SpecialDataType)
from embedchain.models.data_type import DataType, DirectDataType, IndirectDataType, SpecialDataType
class TestDataTypeEnums(unittest.TestCase):
def test_subclass_types_in_data_type(self):
"""Test that all data type category subclasses are contained in the composite data type"""
# Check if DirectDataType values are in DataType
for data_type in DirectDataType:
self.assertIn(data_type.value, DataType._value2member_map_)
def test_subclass_types_in_data_type():
"""Test that all data type category subclasses are contained in the composite data type"""
# Check if DirectDataType values are in DataType
for data_type in DirectDataType:
assert data_type.value in DataType._value2member_map_
# Check if IndirectDataType values are in DataType
for data_type in IndirectDataType:
self.assertIn(data_type.value, DataType._value2member_map_)
# Check if IndirectDataType values are in DataType
for data_type in IndirectDataType:
assert data_type.value in DataType._value2member_map_
# Check if SpecialDataType values are in DataType
for data_type in SpecialDataType:
self.assertIn(data_type.value, DataType._value2member_map_)
# Check if SpecialDataType values are in DataType
for data_type in SpecialDataType:
assert data_type.value in DataType._value2member_map_
def test_data_type_in_subclasses(self):
"""Test that all data types in the composite data type are categorized in a subclass"""
for data_type in DataType:
if data_type.value in DirectDataType._value2member_map_:
self.assertIn(data_type.value, DirectDataType._value2member_map_)
elif data_type.value in IndirectDataType._value2member_map_:
self.assertIn(data_type.value, IndirectDataType._value2member_map_)
elif data_type.value in SpecialDataType._value2member_map_:
self.assertIn(data_type.value, SpecialDataType._value2member_map_)
else:
self.fail(f"{data_type.value} not found in any subclass enums")
def test_data_type_in_subclasses():
"""Test that all data types in the composite data type are categorized in a subclass"""
for data_type in DataType:
if data_type.value in DirectDataType._value2member_map_:
assert data_type.value in DirectDataType._value2member_map_
elif data_type.value in IndirectDataType._value2member_map_:
assert data_type.value in IndirectDataType._value2member_map_
elif data_type.value in SpecialDataType._value2member_map_:
assert data_type.value in SpecialDataType._value2member_map_
else:
assert False, f"{data_type.value} not found in any subclass enums"