Improve tests (#795)

Author: Sidharth Mohanty
Date: 2023-10-13 01:45:22 +05:30
Committed by: GitHub
Parent: b5de605e2b
Commit: 4820ea15d6
7 changed files with 373 additions and 19 deletions


@@ -1,10 +1,9 @@
 from typing import Any

-from embedchain import CustomApp
-from embedchain.config import AddConfig, CustomAppConfig, LlmConfig
+from embedchain import App
+from embedchain.config import AddConfig, AppConfig, LlmConfig
 from embedchain.embedder.openai import OpenAIEmbedder
-from embedchain.helper.json_serializable import (JSONSerializable,
-                                                 register_deserializable)
+from embedchain.helper.json_serializable import JSONSerializable, register_deserializable
 from embedchain.llm.openai import OpenAILlm
 from embedchain.vectordb.chroma import ChromaDB
@@ -12,7 +11,7 @@ from embedchain.vectordb.chroma import ChromaDB
 @register_deserializable
 class BaseBot(JSONSerializable):
     def __init__(self):
-        self.app = CustomApp(config=CustomAppConfig(), llm=OpenAILlm(), db=ChromaDB(), embedder=OpenAIEmbedder())
+        self.app = App(config=AppConfig(), llm=OpenAILlm(), db=ChromaDB(), embedder=OpenAIEmbedder())

     def add(self, data: Any, config: AddConfig = None):
         """


@@ -2,18 +2,15 @@ import os
 import unittest

 from embedchain import App, CustomApp, Llama2App, OpenSourceApp
-from embedchain.config import ChromaDbConfig
+from embedchain.config import ChromaDbConfig, AppConfig, AddConfig, BaseLlmConfig, BaseEmbedderConfig
 from embedchain.embedder.base import BaseEmbedder
 from embedchain.llm.base import BaseLlm
-from embedchain.vectordb.base import BaseVectorDB
+from embedchain.vectordb.base import BaseVectorDB, BaseVectorDbConfig
 from embedchain.vectordb.chroma import ChromaDB


 class TestApps(unittest.TestCase):
-    try:
-        del os.environ["OPENAI_KEY"]
-    except KeyError:
-        pass
+    os.environ["OPENAI_API_KEY"] = "test_api_key"

     def test_app(self):
         app = App()
@@ -21,6 +18,18 @@ class TestApps(unittest.TestCase):
         self.assertIsInstance(app.db, BaseVectorDB)
         self.assertIsInstance(app.embedder, BaseEmbedder)

+        wrong_llm = "wrong_llm"
+        with self.assertRaises(TypeError):
+            App(llm=wrong_llm)
+
+        wrong_db = "wrong_db"
+        with self.assertRaises(TypeError):
+            App(db=wrong_db)
+
+        wrong_embedder = "wrong_embedder"
+        with self.assertRaises(TypeError):
+            App(embedder=wrong_embedder)
+
     def test_custom_app(self):
         app = CustomApp()
         self.assertIsInstance(app.llm, BaseLlm)
@@ -58,3 +67,36 @@ class TestConfigForAppComponents(unittest.TestCase):
         database = ChromaDB(config=ChromaDbConfig(collection_name=self.collection_name))
         app = App(db=database)
         self.assertEqual(app.db.config.collection_name, self.collection_name)
+
+    def test_different_configs_are_proper_instances(self):
+        config = AppConfig()
+        wrong_app_config = AddConfig()
+
+        with self.assertRaises(TypeError):
+            App(config=wrong_app_config)
+
+        self.assertIsInstance(config, AppConfig)
+
+        llm_config = BaseLlmConfig()
+        wrong_llm_config = "wrong_llm_config"
+
+        with self.assertRaises(TypeError):
+            App(llm_config=wrong_llm_config)
+
+        self.assertIsInstance(llm_config, BaseLlmConfig)
+
+        db_config = BaseVectorDbConfig()
+        wrong_db_config = "wrong_db_config"
+
+        with self.assertRaises(TypeError):
+            App(db_config=wrong_db_config)
+
+        self.assertIsInstance(db_config, BaseVectorDbConfig)
+
+        embedder_config = BaseEmbedderConfig()
+        wrong_embedder_config = "wrong_embedder_config"
+
+        with self.assertRaises(TypeError):
+            App(embedder_config=wrong_embedder_config)
+
+        self.assertIsInstance(embedder_config, BaseEmbedderConfig)
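
These TypeError assertions imply isinstance-style validation in the App constructor. A rough sketch of that pattern (hypothetical helper, not embedchain's actual code):

def validate_component(value, expected_type, name):
    # Hypothetical: reject components that are not instances of the
    # expected base class, mirroring what the tests above assert.
    if value is not None and not isinstance(value, expected_type):
        raise TypeError(
            f"{name} must be an instance of {expected_type.__name__}, got {type(value).__name__}"
        )

For example, validate_component("wrong_llm", BaseLlm, "llm") raises the TypeError that test_app expects from App(llm="wrong_llm").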

tests/bots/test_base.py (new file, 49 lines)

@@ -0,0 +1,49 @@
+import os
+import pytest
+from embedchain.config import AddConfig, BaseLlmConfig
+from embedchain.bots.base import BaseBot
+from unittest.mock import patch
+
+
+@pytest.fixture
+def base_bot():
+    os.environ["OPENAI_API_KEY"] = "test_api_key"  # needed by App
+    return BaseBot()
+
+
+def test_add(base_bot):
+    data = "Test data"
+    config = AddConfig()
+    with patch.object(base_bot.app, "add") as mock_add:
+        base_bot.add(data, config)
+        mock_add.assert_called_with(data, config=config)
+
+
+def test_query(base_bot):
+    query = "Test query"
+    config = BaseLlmConfig()
+    with patch.object(base_bot.app, "query") as mock_query:
+        mock_query.return_value = "Query result"
+        result = base_bot.query(query, config)
+        assert isinstance(result, str)
+        assert result == "Query result"
+
+
+def test_start():
+    class TestBot(BaseBot):
+        def start(self):
+            return "Bot started"
+
+    bot = TestBot()
+    result = bot.start()
+    assert result == "Bot started"
+
+
+def test_start_not_implemented():
+    bot = BaseBot()
+    with pytest.raises(NotImplementedError):
+        bot.start()


@@ -1,11 +1,57 @@
-import unittest
+import pytest
+from unittest.mock import MagicMock

 from embedchain.embedder.base import BaseEmbedder
+from embedchain.config.embedder.base import BaseEmbedderConfig
+from chromadb.api.types import Documents, Embeddings


-class TestEmbedder(unittest.TestCase):
-    def test_init_with_invalid_vector_dim(self):
-        # Test if an exception is raised when an invalid vector_dim is provided
-        embedder = BaseEmbedder()
-        with self.assertRaises(TypeError):
-            embedder.set_vector_dimension(None)
+@pytest.fixture
+def base_embedder():
+    return BaseEmbedder()
+
+
+def test_initialization(base_embedder):
+    assert isinstance(base_embedder.config, BaseEmbedderConfig)
+    # not initialized
+    assert not hasattr(base_embedder, "embedding_fn")
+    assert not hasattr(base_embedder, "vector_dimension")
+
+
+def test_set_embedding_fn(base_embedder):
+    def embedding_function(texts: Documents) -> Embeddings:
+        return [f"Embedding for {text}" for text in texts]
+
+    base_embedder.set_embedding_fn(embedding_function)
+    assert hasattr(base_embedder, "embedding_fn")
+    assert callable(base_embedder.embedding_fn)
+    embeddings = base_embedder.embedding_fn(["text1", "text2"])
+    assert embeddings == ["Embedding for text1", "Embedding for text2"]
+
+
+def test_set_embedding_fn_when_not_a_function(base_embedder):
+    with pytest.raises(ValueError):
+        base_embedder.set_embedding_fn(None)
+
+
+def test_set_vector_dimension(base_embedder):
+    base_embedder.set_vector_dimension(256)
+    assert hasattr(base_embedder, "vector_dimension")
+    assert base_embedder.vector_dimension == 256
+
+
+def test_set_vector_dimension_type_error(base_embedder):
+    with pytest.raises(TypeError):
+        base_embedder.set_vector_dimension(None)
+
+
+def test_langchain_default_concept():
+    embeddings = MagicMock()
+    embeddings.embed_documents.return_value = ["Embedding1", "Embedding2"]
+    embed_function = BaseEmbedder._langchain_default_concept(embeddings)
+    result = embed_function(["text1", "text2"])
+    assert result == ["Embedding1", "Embedding2"]
+
+
+def test_embedder_with_config():
+    embedder = BaseEmbedder(BaseEmbedderConfig())
+    assert isinstance(embedder.config, BaseEmbedderConfig)
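
The test_langchain_default_concept test suggests that helper simply adapts a LangChain embeddings object into a plain callable. A minimal sketch of that adapter, assuming the shape the test exercises (the real implementation may differ):

from typing import Callable, List


def langchain_default_concept(embeddings) -> Callable[[List[str]], List]:
    # Assumed shape of BaseEmbedder._langchain_default_concept: wrap the
    # embeddings object so callers can pass a list of texts directly.
    def embed_function(texts: List[str]):
        return embeddings.embed_documents(texts)

    return embed_function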


@@ -0,0 +1,64 @@
+import pytest
+from unittest.mock import MagicMock, patch
+
+from embedchain.llm.antrophic import AntrophicLlm
+from embedchain.config import BaseLlmConfig
+from langchain.schema import HumanMessage, SystemMessage
+
+
+@pytest.fixture
+def antrophic_llm():
+    config = BaseLlmConfig(temperature=0.5, model="gpt2")
+    return AntrophicLlm(config)
+
+
+def test_get_llm_model_answer(antrophic_llm):
+    with patch.object(AntrophicLlm, "_get_answer", return_value="Test Response") as mock_method:
+        prompt = "Test Prompt"
+        response = antrophic_llm.get_llm_model_answer(prompt)
+        assert response == "Test Response"
+        mock_method.assert_called_once_with(prompt=prompt, config=antrophic_llm.config)
+
+
+def test_get_answer(antrophic_llm):
+    with patch("langchain.chat_models.ChatAnthropic") as mock_chat:
+        mock_chat_instance = mock_chat.return_value
+        mock_chat_instance.return_value = MagicMock(content="Test Response")
+
+        prompt = "Test Prompt"
+        response = antrophic_llm._get_answer(prompt, antrophic_llm.config)
+
+        assert response == "Test Response"
+        mock_chat.assert_called_once_with(
+            temperature=antrophic_llm.config.temperature, model=antrophic_llm.config.model
+        )
+        mock_chat_instance.assert_called_once_with(
+            antrophic_llm._get_messages(prompt, system_prompt=antrophic_llm.config.system_prompt)
+        )
+
+
+def test_get_messages(antrophic_llm):
+    prompt = "Test Prompt"
+    system_prompt = "Test System Prompt"
+    messages = antrophic_llm._get_messages(prompt, system_prompt)
+    assert messages == [
+        SystemMessage(content="Test System Prompt", additional_kwargs={}),
+        HumanMessage(content="Test Prompt", additional_kwargs={}, example=False),
+    ]
+
+
+def test_get_answer_max_tokens_is_provided(antrophic_llm, caplog):
+    with patch("langchain.chat_models.ChatAnthropic") as mock_chat:
+        mock_chat_instance = mock_chat.return_value
+        mock_chat_instance.return_value = MagicMock(content="Test Response")
+
+        prompt = "Test Prompt"
+        config = antrophic_llm.config
+        config.max_tokens = 500
+
+        response = antrophic_llm._get_answer(prompt, config)
+
+        assert response == "Test Response"
+        mock_chat.assert_called_once_with(temperature=config.temperature, model=config.model)
+        assert "Config option `max_tokens` is not supported by this model." in caplog.text


@@ -0,0 +1,91 @@
+import pytest
+from unittest.mock import MagicMock, patch
+
+from embedchain.llm.azure_openai import AzureOpenAILlm
+from embedchain.config import BaseLlmConfig
+from langchain.schema import HumanMessage, SystemMessage
+
+
+@pytest.fixture
+def azure_openai_llm():
+    config = BaseLlmConfig(
+        deployment_name="azure_deployment",
+        temperature=0.7,
+        model="gpt-3.5-turbo",
+        max_tokens=50,
+        system_prompt="System Prompt",
+    )
+    return AzureOpenAILlm(config)
+
+
+def test_get_llm_model_answer(azure_openai_llm):
+    with patch.object(AzureOpenAILlm, "_get_answer", return_value="Test Response") as mock_method:
+        prompt = "Test Prompt"
+        response = azure_openai_llm.get_llm_model_answer(prompt)
+        assert response == "Test Response"
+        mock_method.assert_called_once_with(prompt=prompt, config=azure_openai_llm.config)
+
+
+def test_get_answer(azure_openai_llm):
+    with patch("langchain.chat_models.AzureChatOpenAI") as mock_chat:
+        mock_chat_instance = mock_chat.return_value
+        mock_chat_instance.return_value = MagicMock(content="Test Response")
+
+        prompt = "Test Prompt"
+        response = azure_openai_llm._get_answer(prompt, azure_openai_llm.config)
+
+        assert response == "Test Response"
+        mock_chat.assert_called_once_with(
+            deployment_name=azure_openai_llm.config.deployment_name,
+            openai_api_version="2023-05-15",
+            model_name=azure_openai_llm.config.model or "gpt-3.5-turbo",
+            temperature=azure_openai_llm.config.temperature,
+            max_tokens=azure_openai_llm.config.max_tokens,
+            streaming=azure_openai_llm.config.stream,
+        )
+        mock_chat_instance.assert_called_once_with(
+            azure_openai_llm._get_messages(prompt, system_prompt=azure_openai_llm.config.system_prompt)
+        )
+
+
+def test_get_messages(azure_openai_llm):
+    prompt = "Test Prompt"
+    system_prompt = "Test System Prompt"
+    messages = azure_openai_llm._get_messages(prompt, system_prompt)
+    assert messages == [
+        SystemMessage(content="Test System Prompt", additional_kwargs={}),
+        HumanMessage(content="Test Prompt", additional_kwargs={}, example=False),
+    ]
+
+
+def test_get_answer_top_p_is_provided(azure_openai_llm, caplog):
+    with patch("langchain.chat_models.AzureChatOpenAI") as mock_chat:
+        mock_chat_instance = mock_chat.return_value
+        mock_chat_instance.return_value = MagicMock(content="Test Response")
+
+        prompt = "Test Prompt"
+        config = azure_openai_llm.config
+        config.top_p = 0.5
+
+        response = azure_openai_llm._get_answer(prompt, config)
+
+        assert response == "Test Response"
+        mock_chat.assert_called_once_with(
+            deployment_name=config.deployment_name,
+            openai_api_version="2023-05-15",
+            model_name=config.model or "gpt-3.5-turbo",
+            temperature=config.temperature,
+            max_tokens=config.max_tokens,
+            streaming=config.stream,
+        )
+        mock_chat_instance.assert_called_once_with(
+            azure_openai_llm._get_messages(prompt, system_prompt=config.system_prompt)
+        )
+        assert "Config option `top_p` is not supported by this model." in caplog.text
+
+
+def test_when_no_deployment_name_provided():
+    config = BaseLlmConfig(temperature=0.7, model="gpt-3.5-turbo", max_tokens=50, system_prompt="System Prompt")
+    with pytest.raises(ValueError):
+        llm = AzureOpenAILlm(config)
+        llm.get_llm_model_answer("Test Prompt")


@@ -0,0 +1,63 @@
+import pytest
+from unittest.mock import MagicMock, patch
+
+from embedchain.llm.vertex_ai import VertexAiLlm
+from embedchain.config import BaseLlmConfig
+from langchain.schema import HumanMessage, SystemMessage
+
+
+@pytest.fixture
+def vertexai_llm():
+    config = BaseLlmConfig(temperature=0.6, model="vertexai_model", system_prompt="System Prompt")
+    return VertexAiLlm(config)
+
+
+def test_get_llm_model_answer(vertexai_llm):
+    with patch.object(VertexAiLlm, "_get_answer", return_value="Test Response") as mock_method:
+        prompt = "Test Prompt"
+        response = vertexai_llm.get_llm_model_answer(prompt)
+        assert response == "Test Response"
+        mock_method.assert_called_once_with(prompt=prompt, config=vertexai_llm.config)
+
+
+def test_get_answer_with_warning(vertexai_llm, caplog):
+    with patch("langchain.chat_models.ChatVertexAI") as mock_chat:
+        mock_chat_instance = mock_chat.return_value
+        mock_chat_instance.return_value = MagicMock(content="Test Response")
+
+        prompt = "Test Prompt"
+        config = vertexai_llm.config
+        config.top_p = 0.5
+
+        response = vertexai_llm._get_answer(prompt, config)
+
+        assert response == "Test Response"
+        mock_chat.assert_called_once_with(temperature=config.temperature, model=config.model)
+        assert "Config option `top_p` is not supported by this model." in caplog.text
+
+
+def test_get_answer_no_warning(vertexai_llm, caplog):
+    with patch("langchain.chat_models.ChatVertexAI") as mock_chat:
+        mock_chat_instance = mock_chat.return_value
+        mock_chat_instance.return_value = MagicMock(content="Test Response")
+
+        prompt = "Test Prompt"
+        config = vertexai_llm.config
+        config.top_p = 1.0
+
+        response = vertexai_llm._get_answer(prompt, config)
+
+        assert response == "Test Response"
+        mock_chat.assert_called_once_with(temperature=config.temperature, model=config.model)
+        assert "Config option `top_p` is not supported by this model." not in caplog.text
+
+
+def test_get_messages(vertexai_llm):
+    prompt = "Test Prompt"
+    system_prompt = "Test System Prompt"
+    messages = vertexai_llm._get_messages(prompt, system_prompt)
+    assert messages == [
+        SystemMessage(content="Test System Prompt", additional_kwargs={}),
+        HumanMessage(content="Test Prompt", additional_kwargs={}, example=False),
+    ]
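
All three new LLM test modules assert the same unsupported-option warning via caplog. A minimal sketch of the logging pattern they imply (hypothetical helper, not embedchain's actual code):

import logging

logger = logging.getLogger(__name__)


def warn_if_unsupported(option: str, value, default=None):
    # Hypothetical: emit the exact message the tests look for when an
    # unsupported option is set to a non-default value.
    if value is not None and value != default:
        logger.warning("Config option `%s` is not supported by this model.", option)

Called as warn_if_unsupported("top_p", config.top_p, default=1.0), this warns for top_p=0.5 but stays silent for top_p=1.0, matching the two VertexAI tests.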