diff --git a/embedchain/bots/base.py b/embedchain/bots/base.py
index d25df78e..5be27121 100644
--- a/embedchain/bots/base.py
+++ b/embedchain/bots/base.py
@@ -1,10 +1,9 @@
 from typing import Any
 
-from embedchain import CustomApp
-from embedchain.config import AddConfig, CustomAppConfig, LlmConfig
+from embedchain import App
+from embedchain.config import AddConfig, AppConfig, LlmConfig
 from embedchain.embedder.openai import OpenAIEmbedder
-from embedchain.helper.json_serializable import (JSONSerializable,
-                                                 register_deserializable)
+from embedchain.helper.json_serializable import JSONSerializable, register_deserializable
 from embedchain.llm.openai import OpenAILlm
 from embedchain.vectordb.chroma import ChromaDB
 
@@ -12,7 +11,7 @@ from embedchain.vectordb.chroma import ChromaDB
 @register_deserializable
 class BaseBot(JSONSerializable):
     def __init__(self):
-        self.app = CustomApp(config=CustomAppConfig(), llm=OpenAILlm(), db=ChromaDB(), embedder=OpenAIEmbedder())
+        self.app = App(config=AppConfig(), llm=OpenAILlm(), db=ChromaDB(), embedder=OpenAIEmbedder())
 
     def add(self, data: Any, config: AddConfig = None):
         """
diff --git a/tests/apps/test_apps.py b/tests/apps/test_apps.py
index 79a4c85b..0b2ad263 100644
--- a/tests/apps/test_apps.py
+++ b/tests/apps/test_apps.py
@@ -2,18 +2,15 @@ import os
 import unittest
 
 from embedchain import App, CustomApp, Llama2App, OpenSourceApp
-from embedchain.config import ChromaDbConfig
+from embedchain.config import ChromaDbConfig, AppConfig, AddConfig, BaseLlmConfig, BaseEmbedderConfig
 from embedchain.embedder.base import BaseEmbedder
 from embedchain.llm.base import BaseLlm
-from embedchain.vectordb.base import BaseVectorDB
+from embedchain.vectordb.base import BaseVectorDB, BaseVectorDbConfig
 from embedchain.vectordb.chroma import ChromaDB
 
 
 class TestApps(unittest.TestCase):
-    try:
-        del os.environ["OPENAI_KEY"]
-    except KeyError:
-        pass
+    os.environ["OPENAI_API_KEY"] = "test_api_key"
 
     def test_app(self):
         app = App()
@@ -21,6 +18,18 @@ class TestApps(unittest.TestCase):
         self.assertIsInstance(app.db, BaseVectorDB)
         self.assertIsInstance(app.embedder, BaseEmbedder)
 
+        wrong_llm = "wrong_llm"
+        with self.assertRaises(TypeError):
+            App(llm=wrong_llm)
+
+        wrong_db = "wrong_db"
+        with self.assertRaises(TypeError):
+            App(db=wrong_db)
+
+        wrong_embedder = "wrong_embedder"
+        with self.assertRaises(TypeError):
+            App(embedder=wrong_embedder)
+
     def test_custom_app(self):
         app = CustomApp()
         self.assertIsInstance(app.llm, BaseLlm)
@@ -58,3 +67,36 @@ class TestConfigForAppComponents(unittest.TestCase):
         database = ChromaDB(config=ChromaDbConfig(collection_name=self.collection_name))
         app = App(db=database)
         self.assertEqual(app.db.config.collection_name, self.collection_name)
+
+    def test_different_configs_are_proper_instances(self):
+        config = AppConfig()
+        wrong_app_config = AddConfig()
+
+        with self.assertRaises(TypeError):
+            App(config=wrong_app_config)
+
+        self.assertIsInstance(config, AppConfig)
+
+        llm_config = BaseLlmConfig()
+        wrong_llm_config = "wrong_llm_config"
+
+        with self.assertRaises(TypeError):
+            App(llm_config=wrong_llm_config)
+
+        self.assertIsInstance(llm_config, BaseLlmConfig)
+
+        db_config = BaseVectorDbConfig()
+        wrong_db_config = "wrong_db_config"
+
+        with self.assertRaises(TypeError):
+            App(db_config=wrong_db_config)
+
+        self.assertIsInstance(db_config, BaseVectorDbConfig)
+
+        embedder_config = BaseEmbedderConfig()
+        wrong_embedder_config = "wrong_embedder_config"
+
+        with self.assertRaises(TypeError):
+            App(embedder_config=wrong_embedder_config)
+
+        self.assertIsInstance(embedder_config, BaseEmbedderConfig)
diff --git a/tests/bots/test_base.py b/tests/bots/test_base.py
new file mode 100644
index 00000000..050de6e6
--- /dev/null
+++ b/tests/bots/test_base.py
@@ -0,0 +1,49 @@
+import os
+import pytest
+from embedchain.config import AddConfig, BaseLlmConfig
+from embedchain.bots.base import BaseBot
+from unittest.mock import patch
+
+
+@pytest.fixture
+def base_bot():
+    os.environ["OPENAI_API_KEY"] = "test_api_key"  # needed by App
+    return BaseBot()
+
+
+def test_add(base_bot):
+    data = "Test data"
+    config = AddConfig()
+
+    with patch.object(base_bot.app, "add") as mock_add:
+        base_bot.add(data, config)
+        mock_add.assert_called_with(data, config=config)
+
+
+def test_query(base_bot):
+    query = "Test query"
+    config = BaseLlmConfig()
+
+    with patch.object(base_bot.app, "query") as mock_query:
+        mock_query.return_value = "Query result"
+
+        result = base_bot.query(query, config)
+
+        assert isinstance(result, str)
+        assert result == "Query result"
+
+
+def test_start():
+    class TestBot(BaseBot):
+        def start(self):
+            return "Bot started"
+
+    bot = TestBot()
+    result = bot.start()
+    assert result == "Bot started"
+
+
+def test_start_not_implemented():
+    bot = BaseBot()
+    with pytest.raises(NotImplementedError):
+        bot.start()
diff --git a/tests/embedder/test_embedder.py b/tests/embedder/test_embedder.py
index bfae3821..e83f9770 100644
--- a/tests/embedder/test_embedder.py
+++ b/tests/embedder/test_embedder.py
@@ -1,11 +1,57 @@
-import unittest
-
+import pytest
+from unittest.mock import MagicMock
 from embedchain.embedder.base import BaseEmbedder
+from embedchain.config.embedder.base import BaseEmbedderConfig
+from chromadb.api.types import Documents, Embeddings
 
 
-class TestEmbedder(unittest.TestCase):
-    def test_init_with_invalid_vector_dim(self):
-        # Test if an exception is raised when an invalid vector_dim is provided
-        embedder = BaseEmbedder()
-        with self.assertRaises(TypeError):
-            embedder.set_vector_dimension(None)
+@pytest.fixture
+def base_embedder():
+    return BaseEmbedder()
+
+
+def test_initialization(base_embedder):
+    assert isinstance(base_embedder.config, BaseEmbedderConfig)
+    # not initialized
+    assert not hasattr(base_embedder, "embedding_fn")
+    assert not hasattr(base_embedder, "vector_dimension")
+
+
+def test_set_embedding_fn(base_embedder):
+    def embedding_function(texts: Documents) -> Embeddings:
+        return [f"Embedding for {text}" for text in texts]
+
+    base_embedder.set_embedding_fn(embedding_function)
+    assert hasattr(base_embedder, "embedding_fn")
+    assert callable(base_embedder.embedding_fn)
+    embeddings = base_embedder.embedding_fn(["text1", "text2"])
+    assert embeddings == ["Embedding for text1", "Embedding for text2"]
+
+
+def test_set_embedding_fn_when_not_a_function(base_embedder):
+    with pytest.raises(ValueError):
+        base_embedder.set_embedding_fn(None)
+
+
+def test_set_vector_dimension(base_embedder):
+    base_embedder.set_vector_dimension(256)
+    assert hasattr(base_embedder, "vector_dimension")
+    assert base_embedder.vector_dimension == 256
+
+
+def test_set_vector_dimension_type_error(base_embedder):
+    with pytest.raises(TypeError):
+        base_embedder.set_vector_dimension(None)
+
+
+def test_langchain_default_concept():
+    embeddings = MagicMock()
+    embeddings.embed_documents.return_value = ["Embedding1", "Embedding2"]
+    embed_function = BaseEmbedder._langchain_default_concept(embeddings)
+    result = embed_function(["text1", "text2"])
+    assert result == ["Embedding1", "Embedding2"]
+
+
+def test_embedder_with_config():
+    embedder = BaseEmbedder(BaseEmbedderConfig())
+    assert isinstance(embedder.config, BaseEmbedderConfig)
diff --git a/tests/llm/test_antrophic.py b/tests/llm/test_antrophic.py
new file mode 100644
index 00000000..d2489a05
--- /dev/null
+++ b/tests/llm/test_antrophic.py
@@ -0,0 +1,64 @@
+import pytest
+from unittest.mock import MagicMock, patch
+
+from embedchain.llm.antrophic import AntrophicLlm
+from embedchain.config import BaseLlmConfig
+from langchain.schema import HumanMessage, SystemMessage
+
+
+@pytest.fixture
+def antrophic_llm():
+    config = BaseLlmConfig(temperature=0.5, model="gpt2")
+    return AntrophicLlm(config)
+
+
+def test_get_llm_model_answer(antrophic_llm):
+    with patch.object(AntrophicLlm, "_get_answer", return_value="Test Response") as mock_method:
+        prompt = "Test Prompt"
+        response = antrophic_llm.get_llm_model_answer(prompt)
+        assert response == "Test Response"
+        mock_method.assert_called_once_with(prompt=prompt, config=antrophic_llm.config)
+
+
+def test_get_answer(antrophic_llm):
+    with patch("langchain.chat_models.ChatAnthropic") as mock_chat:
+        mock_chat_instance = mock_chat.return_value
+        mock_chat_instance.return_value = MagicMock(content="Test Response")
+
+        prompt = "Test Prompt"
+        response = antrophic_llm._get_answer(prompt, antrophic_llm.config)
+
+        assert response == "Test Response"
+        mock_chat.assert_called_once_with(
+            temperature=antrophic_llm.config.temperature, model=antrophic_llm.config.model
+        )
+        mock_chat_instance.assert_called_once_with(
+            antrophic_llm._get_messages(prompt, system_prompt=antrophic_llm.config.system_prompt)
+        )
+
+
+def test_get_messages(antrophic_llm):
+    prompt = "Test Prompt"
+    system_prompt = "Test System Prompt"
+    messages = antrophic_llm._get_messages(prompt, system_prompt)
+    assert messages == [
+        SystemMessage(content="Test System Prompt", additional_kwargs={}),
+        HumanMessage(content="Test Prompt", additional_kwargs={}, example=False),
+    ]
+
+
+def test_get_answer_max_tokens_is_provided(antrophic_llm, caplog):
+    with patch("langchain.chat_models.ChatAnthropic") as mock_chat:
+        mock_chat_instance = mock_chat.return_value
+        mock_chat_instance.return_value = MagicMock(content="Test Response")
+
+        prompt = "Test Prompt"
+        config = antrophic_llm.config
+        config.max_tokens = 500
+
+        response = antrophic_llm._get_answer(prompt, config)
+
+        assert response == "Test Response"
+        mock_chat.assert_called_once_with(temperature=config.temperature, model=config.model)
+
+        assert "Config option `max_tokens` is not supported by this model." in caplog.text
diff --git a/tests/llm/test_azure_openai.py b/tests/llm/test_azure_openai.py
new file mode 100644
index 00000000..7f2c7614
--- /dev/null
+++ b/tests/llm/test_azure_openai.py
@@ -0,0 +1,91 @@
+import pytest
+from unittest.mock import MagicMock, patch
+from embedchain.llm.azure_openai import AzureOpenAILlm
+from embedchain.config import BaseLlmConfig
+from langchain.schema import HumanMessage, SystemMessage
+
+
+@pytest.fixture
+def azure_openai_llm():
+    config = BaseLlmConfig(
+        deployment_name="azure_deployment",
+        temperature=0.7,
+        model="gpt-3.5-turbo",
+        max_tokens=50,
+        system_prompt="System Prompt",
+    )
+    return AzureOpenAILlm(config)
+
+
+def test_get_llm_model_answer(azure_openai_llm):
+    with patch.object(AzureOpenAILlm, "_get_answer", return_value="Test Response") as mock_method:
+        prompt = "Test Prompt"
+        response = azure_openai_llm.get_llm_model_answer(prompt)
+        assert response == "Test Response"
+        mock_method.assert_called_once_with(prompt=prompt, config=azure_openai_llm.config)
+
+
+def test_get_answer(azure_openai_llm):
+    with patch("langchain.chat_models.AzureChatOpenAI") as mock_chat:
+        mock_chat_instance = mock_chat.return_value
+        mock_chat_instance.return_value = MagicMock(content="Test Response")
+
+        prompt = "Test Prompt"
+        response = azure_openai_llm._get_answer(prompt, azure_openai_llm.config)
+
+        assert response == "Test Response"
+        mock_chat.assert_called_once_with(
+            deployment_name=azure_openai_llm.config.deployment_name,
+            openai_api_version="2023-05-15",
+            model_name=azure_openai_llm.config.model or "gpt-3.5-turbo",
+            temperature=azure_openai_llm.config.temperature,
+            max_tokens=azure_openai_llm.config.max_tokens,
+            streaming=azure_openai_llm.config.stream,
+        )
+        mock_chat_instance.assert_called_once_with(
+            azure_openai_llm._get_messages(prompt, system_prompt=azure_openai_llm.config.system_prompt)
+        )
+
+
+def test_get_messages(azure_openai_llm):
+    prompt = "Test Prompt"
+    system_prompt = "Test System Prompt"
+    messages = azure_openai_llm._get_messages(prompt, system_prompt)
+    assert messages == [
+        SystemMessage(content="Test System Prompt", additional_kwargs={}),
+        HumanMessage(content="Test Prompt", additional_kwargs={}, example=False),
+    ]
+
+
+def test_get_answer_top_p_is_provided(azure_openai_llm, caplog):
+    with patch("langchain.chat_models.AzureChatOpenAI") as mock_chat:
+        mock_chat_instance = mock_chat.return_value
+        mock_chat_instance.return_value = MagicMock(content="Test Response")
+
+        prompt = "Test Prompt"
+        config = azure_openai_llm.config
+        config.top_p = 0.5
+
+        response = azure_openai_llm._get_answer(prompt, config)
+
+        assert response == "Test Response"
+        mock_chat.assert_called_once_with(
+            deployment_name=config.deployment_name,
+            openai_api_version="2023-05-15",
+            model_name=config.model or "gpt-3.5-turbo",
+            temperature=config.temperature,
+            max_tokens=config.max_tokens,
+            streaming=config.stream,
+        )
+        mock_chat_instance.assert_called_once_with(
+            azure_openai_llm._get_messages(prompt, system_prompt=config.system_prompt)
+        )
+
+        assert "Config option `top_p` is not supported by this model." in caplog.text
+
+
+def test_when_no_deployment_name_provided():
+    config = BaseLlmConfig(temperature=0.7, model="gpt-3.5-turbo", max_tokens=50, system_prompt="System Prompt")
+    with pytest.raises(ValueError):
+        llm = AzureOpenAILlm(config)
+        llm.get_llm_model_answer("Test Prompt")
diff --git a/tests/llm/test_vertex_ai.py b/tests/llm/test_vertex_ai.py
new file mode 100644
index 00000000..952d5c39
--- /dev/null
+++ b/tests/llm/test_vertex_ai.py
@@ -0,0 +1,63 @@
+import pytest
+from unittest.mock import MagicMock, patch
+from embedchain.llm.vertex_ai import VertexAiLlm
+from embedchain.config import BaseLlmConfig
+from langchain.schema import HumanMessage, SystemMessage
+
+
+@pytest.fixture
+def vertexai_llm():
+    config = BaseLlmConfig(temperature=0.6, model="vertexai_model", system_prompt="System Prompt")
+    return VertexAiLlm(config)
+
+
+def test_get_llm_model_answer(vertexai_llm):
+    with patch.object(VertexAiLlm, "_get_answer", return_value="Test Response") as mock_method:
+        prompt = "Test Prompt"
+        response = vertexai_llm.get_llm_model_answer(prompt)
+        assert response == "Test Response"
+        mock_method.assert_called_once_with(prompt=prompt, config=vertexai_llm.config)
+
+
+def test_get_answer_with_warning(vertexai_llm, caplog):
+    with patch("langchain.chat_models.ChatVertexAI") as mock_chat:
+        mock_chat_instance = mock_chat.return_value
+        mock_chat_instance.return_value = MagicMock(content="Test Response")
+
+        prompt = "Test Prompt"
+        config = vertexai_llm.config
+        config.top_p = 0.5
+
+        response = vertexai_llm._get_answer(prompt, config)
+
+        assert response == "Test Response"
+        mock_chat.assert_called_once_with(temperature=config.temperature, model=config.model)
+
+        assert "Config option `top_p` is not supported by this model." in caplog.text
+
+
+def test_get_answer_no_warning(vertexai_llm, caplog):
+    with patch("langchain.chat_models.ChatVertexAI") as mock_chat:
+        mock_chat_instance = mock_chat.return_value
+        mock_chat_instance.return_value = MagicMock(content="Test Response")
+
+        prompt = "Test Prompt"
+        config = vertexai_llm.config
+        config.top_p = 1.0
+
+        response = vertexai_llm._get_answer(prompt, config)
+
+        assert response == "Test Response"
+        mock_chat.assert_called_once_with(temperature=config.temperature, model=config.model)
+
+        assert "Config option `top_p` is not supported by this model." not in caplog.text
+
+
+def test_get_messages(vertexai_llm):
+    prompt = "Test Prompt"
+    system_prompt = "Test System Prompt"
+    messages = vertexai_llm._get_messages(prompt, system_prompt)
+    assert messages == [
+        SystemMessage(content="Test System Prompt", additional_kwargs={}),
+        HumanMessage(content="Test Prompt", additional_kwargs={}, example=False),
+    ]