Improve tests (#795)
This commit is contained in:
@@ -1,10 +1,9 @@
|
||||
from typing import Any
|
||||
|
||||
from embedchain import CustomApp
|
||||
from embedchain.config import AddConfig, CustomAppConfig, LlmConfig
|
||||
from embedchain import App
|
||||
from embedchain.config import AddConfig, AppConfig, LlmConfig
|
||||
from embedchain.embedder.openai import OpenAIEmbedder
|
||||
from embedchain.helper.json_serializable import (JSONSerializable,
|
||||
register_deserializable)
|
||||
from embedchain.helper.json_serializable import JSONSerializable, register_deserializable
|
||||
from embedchain.llm.openai import OpenAILlm
|
||||
from embedchain.vectordb.chroma import ChromaDB
|
||||
|
||||
@@ -12,7 +11,7 @@ from embedchain.vectordb.chroma import ChromaDB
|
||||
@register_deserializable
|
||||
class BaseBot(JSONSerializable):
|
||||
def __init__(self):
|
||||
self.app = CustomApp(config=CustomAppConfig(), llm=OpenAILlm(), db=ChromaDB(), embedder=OpenAIEmbedder())
|
||||
self.app = App(config=AppConfig(), llm=OpenAILlm(), db=ChromaDB(), embedder=OpenAIEmbedder())
|
||||
|
||||
def add(self, data: Any, config: AddConfig = None):
|
||||
"""
|
||||
|
||||
@@ -2,18 +2,15 @@ import os
|
||||
import unittest
|
||||
|
||||
from embedchain import App, CustomApp, Llama2App, OpenSourceApp
|
||||
from embedchain.config import ChromaDbConfig
|
||||
from embedchain.config import ChromaDbConfig, AppConfig, AddConfig, BaseLlmConfig, BaseEmbedderConfig
|
||||
from embedchain.embedder.base import BaseEmbedder
|
||||
from embedchain.llm.base import BaseLlm
|
||||
from embedchain.vectordb.base import BaseVectorDB
|
||||
from embedchain.vectordb.base import BaseVectorDB, BaseVectorDbConfig
|
||||
from embedchain.vectordb.chroma import ChromaDB
|
||||
|
||||
|
||||
class TestApps(unittest.TestCase):
|
||||
try:
|
||||
del os.environ["OPENAI_KEY"]
|
||||
except KeyError:
|
||||
pass
|
||||
os.environ["OPENAI_API_KEY"] = "test_api_key"
|
||||
|
||||
def test_app(self):
|
||||
app = App()
|
||||
@@ -21,6 +18,18 @@ class TestApps(unittest.TestCase):
|
||||
self.assertIsInstance(app.db, BaseVectorDB)
|
||||
self.assertIsInstance(app.embedder, BaseEmbedder)
|
||||
|
||||
wrong_llm = "wrong_llm"
|
||||
with self.assertRaises(TypeError):
|
||||
App(llm=wrong_llm)
|
||||
|
||||
wrong_db = "wrong_db"
|
||||
with self.assertRaises(TypeError):
|
||||
App(db=wrong_db)
|
||||
|
||||
wrong_embedder = "wrong_embedder"
|
||||
with self.assertRaises(TypeError):
|
||||
App(embedder=wrong_embedder)
|
||||
|
||||
def test_custom_app(self):
|
||||
app = CustomApp()
|
||||
self.assertIsInstance(app.llm, BaseLlm)
|
||||
@@ -58,3 +67,36 @@ class TestConfigForAppComponents(unittest.TestCase):
|
||||
database = ChromaDB(config=ChromaDbConfig(collection_name=self.collection_name))
|
||||
app = App(db=database)
|
||||
self.assertEqual(app.db.config.collection_name, self.collection_name)
|
||||
|
||||
def test_different_configs_are_proper_instances(self):
|
||||
config = AppConfig()
|
||||
wrong_app_config = AddConfig()
|
||||
|
||||
with self.assertRaises(TypeError):
|
||||
App(config=wrong_app_config)
|
||||
|
||||
self.assertIsInstance(config, AppConfig)
|
||||
|
||||
llm_config = BaseLlmConfig()
|
||||
wrong_llm_config = "wrong_llm_config"
|
||||
|
||||
with self.assertRaises(TypeError):
|
||||
App(llm_config=wrong_llm_config)
|
||||
|
||||
self.assertIsInstance(llm_config, BaseLlmConfig)
|
||||
|
||||
db_config = BaseVectorDbConfig()
|
||||
wrong_db_config = "wrong_db_config"
|
||||
|
||||
with self.assertRaises(TypeError):
|
||||
App(db_config=wrong_db_config)
|
||||
|
||||
self.assertIsInstance(db_config, BaseVectorDbConfig)
|
||||
|
||||
embedder_config = BaseEmbedderConfig()
|
||||
wrong_embedder_config = "wrong_embedder_config"
|
||||
|
||||
with self.assertRaises(TypeError):
|
||||
App(embedder_config=wrong_embedder_config)
|
||||
|
||||
self.assertIsInstance(embedder_config, BaseEmbedderConfig)
|
||||
|
||||
49
tests/bots/test_base.py
Normal file
49
tests/bots/test_base.py
Normal file
@@ -0,0 +1,49 @@
|
||||
import os
|
||||
import pytest
|
||||
from embedchain.config import AddConfig, BaseLlmConfig
|
||||
from embedchain.bots.base import BaseBot
|
||||
from unittest.mock import patch
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def base_bot():
|
||||
os.environ["OPENAI_API_KEY"] = "test_api_key" # needed by App
|
||||
return BaseBot()
|
||||
|
||||
|
||||
def test_add(base_bot):
|
||||
data = "Test data"
|
||||
config = AddConfig()
|
||||
|
||||
with patch.object(base_bot.app, "add") as mock_add:
|
||||
base_bot.add(data, config)
|
||||
mock_add.assert_called_with(data, config=config)
|
||||
|
||||
|
||||
def test_query(base_bot):
|
||||
query = "Test query"
|
||||
config = BaseLlmConfig()
|
||||
|
||||
with patch.object(base_bot.app, "query") as mock_query:
|
||||
mock_query.return_value = "Query result"
|
||||
|
||||
result = base_bot.query(query, config)
|
||||
|
||||
assert isinstance(result, str)
|
||||
assert result == "Query result"
|
||||
|
||||
|
||||
def test_start():
|
||||
class TestBot(BaseBot):
|
||||
def start(self):
|
||||
return "Bot started"
|
||||
|
||||
bot = TestBot()
|
||||
result = bot.start()
|
||||
assert result == "Bot started"
|
||||
|
||||
|
||||
def test_start_not_implemented():
|
||||
bot = BaseBot()
|
||||
with pytest.raises(NotImplementedError):
|
||||
bot.start()
|
||||
@@ -1,11 +1,57 @@
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock
|
||||
from embedchain.embedder.base import BaseEmbedder
|
||||
from embedchain.config.embedder.base import BaseEmbedderConfig
|
||||
from chromadb.api.types import Documents, Embeddings
|
||||
|
||||
|
||||
class TestEmbedder(unittest.TestCase):
|
||||
def test_init_with_invalid_vector_dim(self):
|
||||
# Test if an exception is raised when an invalid vector_dim is provided
|
||||
embedder = BaseEmbedder()
|
||||
with self.assertRaises(TypeError):
|
||||
embedder.set_vector_dimension(None)
|
||||
@pytest.fixture
|
||||
def base_embedder():
|
||||
return BaseEmbedder()
|
||||
|
||||
|
||||
def test_initialization(base_embedder):
|
||||
assert isinstance(base_embedder.config, BaseEmbedderConfig)
|
||||
# not initialized
|
||||
assert not hasattr(base_embedder, "embedding_fn")
|
||||
assert not hasattr(base_embedder, "vector_dimension")
|
||||
|
||||
|
||||
def test_set_embedding_fn(base_embedder):
|
||||
def embedding_function(texts: Documents) -> Embeddings:
|
||||
return [f"Embedding for {text}" for text in texts]
|
||||
|
||||
base_embedder.set_embedding_fn(embedding_function)
|
||||
assert hasattr(base_embedder, "embedding_fn")
|
||||
assert callable(base_embedder.embedding_fn)
|
||||
embeddings = base_embedder.embedding_fn(["text1", "text2"])
|
||||
assert embeddings == ["Embedding for text1", "Embedding for text2"]
|
||||
|
||||
|
||||
def test_set_embedding_fn_when_not_a_function(base_embedder):
|
||||
with pytest.raises(ValueError):
|
||||
base_embedder.set_embedding_fn(None)
|
||||
|
||||
|
||||
def test_set_vector_dimension(base_embedder):
|
||||
base_embedder.set_vector_dimension(256)
|
||||
assert hasattr(base_embedder, "vector_dimension")
|
||||
assert base_embedder.vector_dimension == 256
|
||||
|
||||
|
||||
def test_set_vector_dimension_type_error(base_embedder):
|
||||
with pytest.raises(TypeError):
|
||||
base_embedder.set_vector_dimension(None)
|
||||
|
||||
|
||||
def test_langchain_default_concept():
|
||||
embeddings = MagicMock()
|
||||
embeddings.embed_documents.return_value = ["Embedding1", "Embedding2"]
|
||||
embed_function = BaseEmbedder._langchain_default_concept(embeddings)
|
||||
result = embed_function(["text1", "text2"])
|
||||
assert result == ["Embedding1", "Embedding2"]
|
||||
|
||||
|
||||
def test_embedder_with_config():
|
||||
embedder = BaseEmbedder(BaseEmbedderConfig())
|
||||
assert isinstance(embedder.config, BaseEmbedderConfig)
|
||||
|
||||
64
tests/llm/test_antrophic.py
Normal file
64
tests/llm/test_antrophic.py
Normal file
@@ -0,0 +1,64 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from embedchain.llm.antrophic import AntrophicLlm
|
||||
from embedchain.config import BaseLlmConfig
|
||||
from langchain.schema import HumanMessage, SystemMessage
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def antrophic_llm():
|
||||
config = BaseLlmConfig(temperature=0.5, model="gpt2")
|
||||
return AntrophicLlm(config)
|
||||
|
||||
|
||||
def test_get_llm_model_answer(antrophic_llm):
|
||||
with patch.object(AntrophicLlm, "_get_answer", return_value="Test Response") as mock_method:
|
||||
prompt = "Test Prompt"
|
||||
response = antrophic_llm.get_llm_model_answer(prompt)
|
||||
assert response == "Test Response"
|
||||
mock_method.assert_called_once_with(prompt=prompt, config=antrophic_llm.config)
|
||||
|
||||
|
||||
def test_get_answer(antrophic_llm):
|
||||
with patch("langchain.chat_models.ChatAnthropic") as mock_chat:
|
||||
mock_chat_instance = mock_chat.return_value
|
||||
mock_chat_instance.return_value = MagicMock(content="Test Response")
|
||||
|
||||
prompt = "Test Prompt"
|
||||
response = antrophic_llm._get_answer(prompt, antrophic_llm.config)
|
||||
|
||||
assert response == "Test Response"
|
||||
mock_chat.assert_called_once_with(
|
||||
temperature=antrophic_llm.config.temperature, model=antrophic_llm.config.model
|
||||
)
|
||||
mock_chat_instance.assert_called_once_with(
|
||||
antrophic_llm._get_messages(prompt, system_prompt=antrophic_llm.config.system_prompt)
|
||||
)
|
||||
|
||||
|
||||
def test_get_messages(antrophic_llm):
|
||||
prompt = "Test Prompt"
|
||||
system_prompt = "Test System Prompt"
|
||||
messages = antrophic_llm._get_messages(prompt, system_prompt)
|
||||
assert messages == [
|
||||
SystemMessage(content="Test System Prompt", additional_kwargs={}),
|
||||
HumanMessage(content="Test Prompt", additional_kwargs={}, example=False),
|
||||
]
|
||||
|
||||
|
||||
def test_get_answer_max_tokens_is_provided(antrophic_llm, caplog):
|
||||
with patch("langchain.chat_models.ChatAnthropic") as mock_chat:
|
||||
mock_chat_instance = mock_chat.return_value
|
||||
mock_chat_instance.return_value = MagicMock(content="Test Response")
|
||||
|
||||
prompt = "Test Prompt"
|
||||
config = antrophic_llm.config
|
||||
config.max_tokens = 500
|
||||
|
||||
response = antrophic_llm._get_answer(prompt, config)
|
||||
|
||||
assert response == "Test Response"
|
||||
mock_chat.assert_called_once_with(temperature=config.temperature, model=config.model)
|
||||
|
||||
assert "Config option `max_tokens` is not supported by this model." in caplog.text
|
||||
91
tests/llm/test_azure_openai.py
Normal file
91
tests/llm/test_azure_openai.py
Normal file
@@ -0,0 +1,91 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from embedchain.llm.azure_openai import AzureOpenAILlm
|
||||
from embedchain.config import BaseLlmConfig
|
||||
from langchain.schema import HumanMessage, SystemMessage
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def azure_openai_llm():
|
||||
config = BaseLlmConfig(
|
||||
deployment_name="azure_deployment",
|
||||
temperature=0.7,
|
||||
model="gpt-3.5-turbo",
|
||||
max_tokens=50,
|
||||
system_prompt="System Prompt",
|
||||
)
|
||||
return AzureOpenAILlm(config)
|
||||
|
||||
|
||||
def test_get_llm_model_answer(azure_openai_llm):
|
||||
with patch.object(AzureOpenAILlm, "_get_answer", return_value="Test Response") as mock_method:
|
||||
prompt = "Test Prompt"
|
||||
response = azure_openai_llm.get_llm_model_answer(prompt)
|
||||
assert response == "Test Response"
|
||||
mock_method.assert_called_once_with(prompt=prompt, config=azure_openai_llm.config)
|
||||
|
||||
|
||||
def test_get_answer(azure_openai_llm):
|
||||
with patch("langchain.chat_models.AzureChatOpenAI") as mock_chat:
|
||||
mock_chat_instance = mock_chat.return_value
|
||||
mock_chat_instance.return_value = MagicMock(content="Test Response")
|
||||
|
||||
prompt = "Test Prompt"
|
||||
response = azure_openai_llm._get_answer(prompt, azure_openai_llm.config)
|
||||
|
||||
assert response == "Test Response"
|
||||
mock_chat.assert_called_once_with(
|
||||
deployment_name=azure_openai_llm.config.deployment_name,
|
||||
openai_api_version="2023-05-15",
|
||||
model_name=azure_openai_llm.config.model or "gpt-3.5-turbo",
|
||||
temperature=azure_openai_llm.config.temperature,
|
||||
max_tokens=azure_openai_llm.config.max_tokens,
|
||||
streaming=azure_openai_llm.config.stream,
|
||||
)
|
||||
mock_chat_instance.assert_called_once_with(
|
||||
azure_openai_llm._get_messages(prompt, system_prompt=azure_openai_llm.config.system_prompt)
|
||||
)
|
||||
|
||||
|
||||
def test_get_messages(azure_openai_llm):
|
||||
prompt = "Test Prompt"
|
||||
system_prompt = "Test System Prompt"
|
||||
messages = azure_openai_llm._get_messages(prompt, system_prompt)
|
||||
assert messages == [
|
||||
SystemMessage(content="Test System Prompt", additional_kwargs={}),
|
||||
HumanMessage(content="Test Prompt", additional_kwargs={}, example=False),
|
||||
]
|
||||
|
||||
|
||||
def test_get_answer_top_p_is_provided(azure_openai_llm, caplog):
|
||||
with patch("langchain.chat_models.AzureChatOpenAI") as mock_chat:
|
||||
mock_chat_instance = mock_chat.return_value
|
||||
mock_chat_instance.return_value = MagicMock(content="Test Response")
|
||||
|
||||
prompt = "Test Prompt"
|
||||
config = azure_openai_llm.config
|
||||
config.top_p = 0.5
|
||||
|
||||
response = azure_openai_llm._get_answer(prompt, config)
|
||||
|
||||
assert response == "Test Response"
|
||||
mock_chat.assert_called_once_with(
|
||||
deployment_name=config.deployment_name,
|
||||
openai_api_version="2023-05-15",
|
||||
model_name=config.model or "gpt-3.5-turbo",
|
||||
temperature=config.temperature,
|
||||
max_tokens=config.max_tokens,
|
||||
streaming=config.stream,
|
||||
)
|
||||
mock_chat_instance.assert_called_once_with(
|
||||
azure_openai_llm._get_messages(prompt, system_prompt=config.system_prompt)
|
||||
)
|
||||
|
||||
assert "Config option `top_p` is not supported by this model." in caplog.text
|
||||
|
||||
|
||||
def test_when_no_deployment_name_provided():
|
||||
config = BaseLlmConfig(temperature=0.7, model="gpt-3.5-turbo", max_tokens=50, system_prompt="System Prompt")
|
||||
with pytest.raises(ValueError):
|
||||
llm = AzureOpenAILlm(config)
|
||||
llm.get_llm_model_answer("Test Prompt")
|
||||
63
tests/llm/test_vertex_ai.py
Normal file
63
tests/llm/test_vertex_ai.py
Normal file
@@ -0,0 +1,63 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from embedchain.llm.vertex_ai import VertexAiLlm
|
||||
from embedchain.config import BaseLlmConfig
|
||||
from langchain.schema import HumanMessage, SystemMessage
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def vertexai_llm():
|
||||
config = BaseLlmConfig(temperature=0.6, model="vertexai_model", system_prompt="System Prompt")
|
||||
return VertexAiLlm(config)
|
||||
|
||||
|
||||
def test_get_llm_model_answer(vertexai_llm):
|
||||
with patch.object(VertexAiLlm, "_get_answer", return_value="Test Response") as mock_method:
|
||||
prompt = "Test Prompt"
|
||||
response = vertexai_llm.get_llm_model_answer(prompt)
|
||||
assert response == "Test Response"
|
||||
mock_method.assert_called_once_with(prompt=prompt, config=vertexai_llm.config)
|
||||
|
||||
|
||||
def test_get_answer_with_warning(vertexai_llm, caplog):
|
||||
with patch("langchain.chat_models.ChatVertexAI") as mock_chat:
|
||||
mock_chat_instance = mock_chat.return_value
|
||||
mock_chat_instance.return_value = MagicMock(content="Test Response")
|
||||
|
||||
prompt = "Test Prompt"
|
||||
config = vertexai_llm.config
|
||||
config.top_p = 0.5
|
||||
|
||||
response = vertexai_llm._get_answer(prompt, config)
|
||||
|
||||
assert response == "Test Response"
|
||||
mock_chat.assert_called_once_with(temperature=config.temperature, model=config.model)
|
||||
|
||||
assert "Config option `top_p` is not supported by this model." in caplog.text
|
||||
|
||||
|
||||
def test_get_answer_no_warning(vertexai_llm, caplog):
|
||||
with patch("langchain.chat_models.ChatVertexAI") as mock_chat:
|
||||
mock_chat_instance = mock_chat.return_value
|
||||
mock_chat_instance.return_value = MagicMock(content="Test Response")
|
||||
|
||||
prompt = "Test Prompt"
|
||||
config = vertexai_llm.config
|
||||
config.top_p = 1.0
|
||||
|
||||
response = vertexai_llm._get_answer(prompt, config)
|
||||
|
||||
assert response == "Test Response"
|
||||
mock_chat.assert_called_once_with(temperature=config.temperature, model=config.model)
|
||||
|
||||
assert "Config option `top_p` is not supported by this model." not in caplog.text
|
||||
|
||||
|
||||
def test_get_messages(vertexai_llm):
|
||||
prompt = "Test Prompt"
|
||||
system_prompt = "Test System Prompt"
|
||||
messages = vertexai_llm._get_messages(prompt, system_prompt)
|
||||
assert messages == [
|
||||
SystemMessage(content="Test System Prompt", additional_kwargs={}),
|
||||
HumanMessage(content="Test Prompt", additional_kwargs={}, example=False),
|
||||
]
|
||||
Reference in New Issue
Block a user