[Feature]: Add support for creating app using yaml config (#787)
This commit is contained in:
@@ -1,8 +1,11 @@
|
||||
import os
|
||||
import unittest
|
||||
|
||||
import yaml
|
||||
|
||||
from embedchain import App, CustomApp, Llama2App, OpenSourceApp
|
||||
from embedchain.config import ChromaDbConfig, AppConfig, AddConfig, BaseLlmConfig, BaseEmbedderConfig
|
||||
from embedchain.config import (AddConfig, AppConfig, BaseEmbedderConfig,
|
||||
BaseLlmConfig, ChromaDbConfig)
|
||||
from embedchain.embedder.base import BaseEmbedder
|
||||
from embedchain.llm.base import BaseLlm
|
||||
from embedchain.vectordb.base import BaseVectorDB, BaseVectorDbConfig
|
||||
@@ -100,3 +103,74 @@ class TestConfigForAppComponents(unittest.TestCase):
|
||||
App(embedder_config=wrong_embedder_config)
|
||||
|
||||
self.assertIsInstance(embedder_config, BaseEmbedderConfig)
|
||||
|
||||
|
||||
class TestAppFromConfig:
    """Check that an app built by ``App.from_config`` mirrors the values in its YAML file."""

    # Attributes compared one-to-one between each component's config object and
    # the matching ``config`` mapping in the YAML file.
    _LLM_FIELDS = ("temperature", "max_tokens", "top_p", "stream")
    _VECTORDB_FIELDS = ("collection_name", "dir", "allow_reset")
    _EMBEDDER_FIELDS = ("model", "deployment_name")

    def load_config_data(self, yaml_path):
        """Parse the YAML file at ``yaml_path`` with ``yaml.safe_load`` and return the result."""
        with open(yaml_path, "r") as stream:
            parsed = yaml.safe_load(stream)
        return parsed

    @staticmethod
    def _assert_fields_match(actual_config, expected_values, field_names):
        """Assert that every named attribute of ``actual_config`` equals ``expected_values[name]``."""
        for name in field_names:
            assert getattr(actual_config, name) == expected_values[name]

    def _assert_components_match(self, app, config_data):
        """Compare the app's LLM, vector DB and embedder configs against the YAML data."""
        # Validate the LLM config values
        self._assert_fields_match(app.llm.config, config_data["llm"]["config"], self._LLM_FIELDS)
        # Validate the VectorDB config values
        self._assert_fields_match(app.db.config, config_data["vectordb"]["config"], self._VECTORDB_FIELDS)
        # Validate the Embedder config values
        self._assert_fields_match(app.embedder.config, config_data["embedder"]["config"], self._EMBEDDER_FIELDS)

    def test_from_chroma_config(self):
        """Build an app from the chroma YAML and compare it with the file's contents."""
        yaml_path = "embedchain/yaml/chroma.yaml"
        expected = self.load_config_data(yaml_path)

        app = App.from_config(yaml_path)

        # The factory method must hand back an App instance
        assert isinstance(app, App)

        # Validate the AppConfig values
        app_section = expected["app"]["config"]
        assert app.config.id == app_section["id"]
        assert app.config.collection_name == app_section["collection_name"]
        # collect_metrics is not present in this config file, so the default value is expected
        assert app.config.collect_metrics is True

        self._assert_components_match(app, expected)

    def test_from_opensource_config(self):
        """Build an app from the opensource YAML and compare it with the file's contents."""
        yaml_path = "embedchain/yaml/opensource.yaml"
        expected = self.load_config_data(yaml_path)

        app = App.from_config(yaml_path)

        # The factory method must hand back an App instance
        assert isinstance(app, App)

        # Validate the AppConfig values; here collect_metrics is read from the file
        app_section = expected["app"]["config"]
        assert app.config.id == app_section["id"]
        assert app.config.collection_name == app_section["collection_name"]
        assert app.config.collect_metrics == app_section["collect_metrics"]

        self._assert_components_match(app, expected)
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
import os
|
||||
import pytest
|
||||
from embedchain.config import AddConfig, BaseLlmConfig
|
||||
from embedchain.bots.base import BaseBot
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from embedchain.bots.base import BaseBot
|
||||
from embedchain.config import AddConfig, BaseLlmConfig
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def base_bot():
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import hashlib
|
||||
import pytest
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from embedchain.chunkers.base_chunker import BaseChunker
|
||||
from embedchain.models.data_type import DataType
|
||||
|
||||
|
||||
@@ -41,7 +41,8 @@ class TestChromaDbHostsLoglevel(unittest.TestCase):
|
||||
Test if the `App` instance is correctly reconstructed after a reset.
|
||||
"""
|
||||
config = AppConfig(log_level="DEBUG", collect_metrics=False)
|
||||
app = App(config=config, chromadb_config=ChromaDbConfig(chroma_settings={"allow_reset": True}))
|
||||
chroma_config = {"allow_reset": True}
|
||||
app = App(config=config, db_config=ChromaDbConfig(**chroma_config))
|
||||
app.reset()
|
||||
|
||||
# Make sure the client is still healthy
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock
|
||||
from embedchain.embedder.base import BaseEmbedder
|
||||
from embedchain.config.embedder.base import BaseEmbedderConfig
|
||||
|
||||
import pytest
|
||||
from chromadb.api.types import Documents, Embeddings
|
||||
|
||||
from embedchain.config.embedder.base import BaseEmbedderConfig
|
||||
from embedchain.embedder.base import BaseEmbedder
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def base_embedder():
|
||||
|
||||
@@ -1,62 +1,63 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from embedchain.llm.antrophic import AntrophicLlm
|
||||
from embedchain.config import BaseLlmConfig
|
||||
import pytest
|
||||
from langchain.schema import HumanMessage, SystemMessage
|
||||
|
||||
from embedchain.config import BaseLlmConfig
|
||||
from embedchain.llm.anthropic import AnthropicLlm
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def antrophic_llm():
|
||||
def anthropic_llm():
|
||||
config = BaseLlmConfig(temperature=0.5, model="gpt2")
|
||||
return AntrophicLlm(config)
|
||||
return AnthropicLlm(config)
|
||||
|
||||
|
||||
def test_get_llm_model_answer(antrophic_llm):
|
||||
with patch.object(AntrophicLlm, "_get_answer", return_value="Test Response") as mock_method:
|
||||
def test_get_llm_model_answer(anthropic_llm):
|
||||
with patch.object(AnthropicLlm, "_get_answer", return_value="Test Response") as mock_method:
|
||||
prompt = "Test Prompt"
|
||||
response = antrophic_llm.get_llm_model_answer(prompt)
|
||||
response = anthropic_llm.get_llm_model_answer(prompt)
|
||||
assert response == "Test Response"
|
||||
mock_method.assert_called_once_with(prompt=prompt, config=antrophic_llm.config)
|
||||
mock_method.assert_called_once_with(prompt=prompt, config=anthropic_llm.config)
|
||||
|
||||
|
||||
def test_get_answer(antrophic_llm):
|
||||
def test_get_answer(anthropic_llm):
|
||||
with patch("langchain.chat_models.ChatAnthropic") as mock_chat:
|
||||
mock_chat_instance = mock_chat.return_value
|
||||
mock_chat_instance.return_value = MagicMock(content="Test Response")
|
||||
|
||||
prompt = "Test Prompt"
|
||||
response = antrophic_llm._get_answer(prompt, antrophic_llm.config)
|
||||
response = anthropic_llm._get_answer(prompt, anthropic_llm.config)
|
||||
|
||||
assert response == "Test Response"
|
||||
mock_chat.assert_called_once_with(
|
||||
temperature=antrophic_llm.config.temperature, model=antrophic_llm.config.model
|
||||
temperature=anthropic_llm.config.temperature, model=anthropic_llm.config.model
|
||||
)
|
||||
mock_chat_instance.assert_called_once_with(
|
||||
antrophic_llm._get_messages(prompt, system_prompt=antrophic_llm.config.system_prompt)
|
||||
anthropic_llm._get_messages(prompt, system_prompt=anthropic_llm.config.system_prompt)
|
||||
)
|
||||
|
||||
|
||||
def test_get_messages(antrophic_llm):
|
||||
def test_get_messages(anthropic_llm):
|
||||
prompt = "Test Prompt"
|
||||
system_prompt = "Test System Prompt"
|
||||
messages = antrophic_llm._get_messages(prompt, system_prompt)
|
||||
messages = anthropic_llm._get_messages(prompt, system_prompt)
|
||||
assert messages == [
|
||||
SystemMessage(content="Test System Prompt", additional_kwargs={}),
|
||||
HumanMessage(content="Test Prompt", additional_kwargs={}, example=False),
|
||||
]
|
||||
|
||||
|
||||
def test_get_answer_max_tokens_is_provided(antrophic_llm, caplog):
|
||||
def test_get_answer_max_tokens_is_provided(anthropic_llm, caplog):
|
||||
with patch("langchain.chat_models.ChatAnthropic") as mock_chat:
|
||||
mock_chat_instance = mock_chat.return_value
|
||||
mock_chat_instance.return_value = MagicMock(content="Test Response")
|
||||
|
||||
prompt = "Test Prompt"
|
||||
config = antrophic_llm.config
|
||||
config = anthropic_llm.config
|
||||
config.max_tokens = 500
|
||||
|
||||
response = antrophic_llm._get_answer(prompt, config)
|
||||
response = anthropic_llm._get_answer(prompt, config)
|
||||
|
||||
assert response == "Test Response"
|
||||
mock_chat.assert_called_once_with(temperature=config.temperature, model=config.model)
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from embedchain.llm.azure_openai import AzureOpenAILlm
|
||||
from embedchain.config import BaseLlmConfig
|
||||
|
||||
import pytest
|
||||
from langchain.schema import HumanMessage, SystemMessage
|
||||
|
||||
from embedchain.config import BaseLlmConfig
|
||||
from embedchain.llm.azure_openai import AzureOpenAILlm
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def azure_openai_llm():
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import importlib
|
||||
import os
|
||||
import unittest
|
||||
from unittest.mock import patch, MagicMock
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from embedchain.config import BaseLlmConfig
|
||||
from embedchain.llm.hugging_face_hub import HuggingFaceHubLlm
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from embedchain.llm.vertex_ai import VertexAiLlm
|
||||
from embedchain.config import BaseLlmConfig
|
||||
|
||||
import pytest
|
||||
from langchain.schema import HumanMessage, SystemMessage
|
||||
|
||||
from embedchain.config import BaseLlmConfig
|
||||
from embedchain.llm.vertex_ai import VertexAiLlm
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def vertexai_llm():
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import hashlib
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from requests import Response
|
||||
|
||||
from embedchain.loaders.docs_site_loader import DocsSiteLoader
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import hashlib
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from embedchain.loaders.docx_file import DocxFileLoader
|
||||
|
||||
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import hashlib
|
||||
|
||||
import pytest
|
||||
|
||||
from embedchain.loaders.local_qna_pair import LocalQnaPairLoader
|
||||
|
||||
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import hashlib
|
||||
|
||||
import pytest
|
||||
|
||||
from embedchain.loaders.local_text import LocalTextLoader
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import hashlib
|
||||
from unittest.mock import mock_open, patch
|
||||
|
||||
import pytest
|
||||
from unittest.mock import patch, mock_open
|
||||
|
||||
from embedchain.loaders.mdx import MdxLoader
|
||||
|
||||
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import hashlib
|
||||
import os
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from embedchain.loaders.notion import NotionLoader
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import hashlib
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from embedchain.loaders.web_page import WebPageLoader
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import hashlib
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from embedchain.loaders.youtube_video import YoutubeVideoLoader
|
||||
|
||||
|
||||
|
||||
61
tests/test_factory.py
Normal file
61
tests/test_factory.py
Normal file
@@ -0,0 +1,61 @@
|
||||
import pytest
|
||||
|
||||
import embedchain
|
||||
import embedchain.embedder.gpt4all
|
||||
import embedchain.embedder.huggingface
|
||||
import embedchain.embedder.openai
|
||||
import embedchain.embedder.vertexai
|
||||
import embedchain.llm.anthropic
|
||||
import embedchain.llm.openai
|
||||
import embedchain.vectordb.chroma
|
||||
import embedchain.vectordb.elasticsearch
|
||||
import embedchain.vectordb.opensearch
|
||||
from embedchain.factory import EmbedderFactory, LlmFactory, VectorDBFactory
|
||||
|
||||
|
||||
class TestFactories:
    """Check that each factory's ``create`` returns an instance of the expected provider class."""

    # Argument names shared by all three parametrized tests.
    _ARGNAMES = "provider_name, config_data, expected_class"

    # (provider name, config passed to the factory, class the result must be an instance of)
    _LLM_CASES = [
        ("openai", {}, embedchain.llm.openai.OpenAILlm),
        ("anthropic", {}, embedchain.llm.anthropic.AnthropicLlm),
    ]

    _EMBEDDER_CASES = [
        ("gpt4all", {}, embedchain.embedder.gpt4all.GPT4AllEmbedder),
        (
            "huggingface",
            {"model": "sentence-transformers/all-mpnet-base-v2"},
            embedchain.embedder.huggingface.HuggingFaceEmbedder,
        ),
        ("vertexai", {"model": "textembedding-gecko"}, embedchain.embedder.vertexai.VertexAiEmbedder),
        ("openai", {}, embedchain.embedder.openai.OpenAIEmbedder),
    ]

    _VECTORDB_CASES = [
        ("chroma", {}, embedchain.vectordb.chroma.ChromaDB),
        (
            "opensearch",
            {"opensearch_url": "http://localhost:9200", "http_auth": ("admin", "admin")},
            embedchain.vectordb.opensearch.OpenSearchDB,
        ),
        ("elasticsearch", {"es_url": "http://localhost:9200"}, embedchain.vectordb.elasticsearch.ElasticsearchDB),
    ]

    @pytest.mark.parametrize(_ARGNAMES, _LLM_CASES)
    def test_llm_factory_create(self, provider_name, config_data, expected_class):
        """``LlmFactory.create`` returns an instance of the class listed for the provider."""
        created = LlmFactory.create(provider_name, config_data)
        assert isinstance(created, expected_class)

    @pytest.mark.parametrize(_ARGNAMES, _EMBEDDER_CASES)
    def test_embedder_factory_create(self, mocker, provider_name, config_data, expected_class):
        """``EmbedderFactory.create`` returns an instance of the class listed for the provider."""
        # The patch is applied for every case, not only "vertexai".
        # NOTE(review): presumably avoids needing real Vertex AI access — confirm.
        mocker.patch("embedchain.embedder.vertexai.VertexAiEmbedder", autospec=True)
        created = EmbedderFactory.create(provider_name, config_data)
        assert isinstance(created, expected_class)

    @pytest.mark.parametrize(_ARGNAMES, _VECTORDB_CASES)
    def test_vectordb_factory_create(self, mocker, provider_name, config_data, expected_class):
        """``VectorDBFactory.create`` returns an instance of the class listed for the provider."""
        # The patch is applied for every case, not only "opensearch".
        # NOTE(review): presumably avoids needing a reachable OpenSearch server — confirm.
        mocker.patch("embedchain.vectordb.opensearch.OpenSearchDB", autospec=True)
        created = VectorDBFactory.create(provider_name, config_data)
        assert isinstance(created, expected_class)
|
||||
@@ -28,19 +28,25 @@ class TestChromaDbHosts(unittest.TestCase):
|
||||
host = "test-host"
|
||||
port = "1234"
|
||||
|
||||
chroma_auth_settings = {
|
||||
"chroma_client_auth_provider": "chromadb.auth.basic.BasicAuthClientProvider",
|
||||
"chroma_client_auth_credentials": "admin:admin",
|
||||
chroma_config = {
|
||||
"host": host,
|
||||
"port": port,
|
||||
"chroma_settings": {
|
||||
"chroma_client_auth_provider": "chromadb.auth.basic.BasicAuthClientProvider",
|
||||
"chroma_client_auth_credentials": "admin:admin",
|
||||
},
|
||||
}
|
||||
|
||||
config = ChromaDbConfig(host=host, port=port, chroma_settings=chroma_auth_settings)
|
||||
config = ChromaDbConfig(**chroma_config)
|
||||
db = ChromaDB(config=config)
|
||||
settings = db.client.get_settings()
|
||||
self.assertEqual(settings.chroma_server_host, host)
|
||||
self.assertEqual(settings.chroma_server_http_port, port)
|
||||
self.assertEqual(settings.chroma_client_auth_provider, chroma_auth_settings["chroma_client_auth_provider"])
|
||||
self.assertEqual(
|
||||
settings.chroma_client_auth_credentials, chroma_auth_settings["chroma_client_auth_credentials"]
|
||||
settings.chroma_client_auth_provider, chroma_config["chroma_settings"]["chroma_client_auth_provider"]
|
||||
)
|
||||
self.assertEqual(
|
||||
settings.chroma_client_auth_credentials, chroma_config["chroma_settings"]["chroma_client_auth_credentials"]
|
||||
)
|
||||
|
||||
|
||||
@@ -55,9 +61,9 @@ class TestChromaDbHostsInit(unittest.TestCase):
|
||||
port = "1234"
|
||||
|
||||
config = AppConfig(collect_metrics=False)
|
||||
chromadb_config = ChromaDbConfig(host=host, port=port)
|
||||
db_config = ChromaDbConfig(host=host, port=port)
|
||||
|
||||
_app = App(config, chromadb_config=chromadb_config)
|
||||
_app = App(config, db_config=db_config)
|
||||
|
||||
called_settings: Settings = mock_client.call_args[0][0]
|
||||
|
||||
@@ -95,7 +101,7 @@ class TestChromaDbHostsLoglevel(unittest.TestCase):
|
||||
class TestChromaDbDuplicateHandling:
|
||||
chroma_config = ChromaDbConfig(allow_reset=True)
|
||||
app_config = AppConfig(collection_name=False, collect_metrics=False)
|
||||
app_with_settings = App(config=app_config, chromadb_config=chroma_config)
|
||||
app_with_settings = App(config=app_config, db_config=chroma_config)
|
||||
|
||||
def test_duplicates_throw_warning(self, caplog):
|
||||
"""
|
||||
@@ -131,7 +137,7 @@ class TestChromaDbDuplicateHandling:
|
||||
class TestChromaDbCollection(unittest.TestCase):
|
||||
chroma_config = ChromaDbConfig(allow_reset=True)
|
||||
app_config = AppConfig(collection_name=False, collect_metrics=False)
|
||||
app_with_settings = App(config=app_config, chromadb_config=chroma_config)
|
||||
app_with_settings = App(config=app_config, db_config=chroma_config)
|
||||
|
||||
def test_init_with_default_collection(self):
|
||||
"""
|
||||
@@ -296,7 +302,7 @@ class TestChromaDbCollection(unittest.TestCase):
|
||||
|
||||
# Create four apps.
|
||||
# app1, which we are about to reset, shares an app with one, and an id with the other, none with the last.
|
||||
app1 = App(AppConfig(id="new_app_id_1", collect_metrics=False), chromadb_config=self.chroma_config)
|
||||
app1 = App(AppConfig(id="new_app_id_1", collect_metrics=False), db_config=self.chroma_config)
|
||||
app1.set_collection_name("one_collection")
|
||||
app2 = App(AppConfig(id="new_app_id_2", collect_metrics=False))
|
||||
app2.set_collection_name("one_collection")
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
# ruff: noqa: E501
|
||||
|
||||
import os
|
||||
from unittest import mock
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from unittest import mock
|
||||
from unittest.mock import patch, Mock
|
||||
|
||||
from embedchain.config import ZillizDBConfig
|
||||
from embedchain.vectordb.zilliz import ZillizVectorDB
|
||||
|
||||
|
||||
# To run these tests, provide the URI and TOKEN in a .env file.
|
||||
class TestZillizVectorDBConfig:
|
||||
@mock.patch.dict(os.environ, {"ZILLIZ_CLOUD_URI": "mocked_uri", "ZILLIZ_CLOUD_TOKEN": "mocked_token"})
|
||||
@@ -51,6 +52,7 @@ class TestZillizVectorDBConfig:
|
||||
with pytest.raises(AttributeError):
|
||||
ZillizDBConfig()
|
||||
|
||||
|
||||
class TestZillizVectorDB:
|
||||
@pytest.fixture
|
||||
@mock.patch.dict(os.environ, {"ZILLIZ_CLOUD_URI": "mocked_uri", "ZILLIZ_CLOUD_TOKEN": "mocked_token"})
|
||||
@@ -147,7 +149,7 @@ class TestZillizDBCollection:
|
||||
zilliz_db = ZillizVectorDB(config=mock_config)
|
||||
|
||||
# Add an 'embedder' attribute to the ZillizVectorDB instance for testing
|
||||
zilliz_db.embedder = mock_embedder # Mock the 'collection' object
|
||||
zilliz_db.embedder = mock_embedder # Mock the 'collection' object
|
||||
|
||||
# Add a 'collection' attribute to the ZillizVectorDB instance for testing
|
||||
zilliz_db.collection = Mock(is_empty=False) # Mock the 'collection' object
|
||||
|
||||
Reference in New Issue
Block a user