[Feature]: Add support for creating an app from a YAML config (#787)

This commit is contained in:
Deshraj Yadav
2023-10-12 15:35:49 -07:00
committed by GitHub
parent 4820ea15d6
commit a86d7f52e9
36 changed files with 479 additions and 95 deletions

View File

@@ -1,8 +1,11 @@
import os
import unittest
import yaml
from embedchain import App, CustomApp, Llama2App, OpenSourceApp
from embedchain.config import ChromaDbConfig, AppConfig, AddConfig, BaseLlmConfig, BaseEmbedderConfig
from embedchain.config import (AddConfig, AppConfig, BaseEmbedderConfig,
BaseLlmConfig, ChromaDbConfig)
from embedchain.embedder.base import BaseEmbedder
from embedchain.llm.base import BaseLlm
from embedchain.vectordb.base import BaseVectorDB, BaseVectorDbConfig
@@ -100,3 +103,74 @@ class TestConfigForAppComponents(unittest.TestCase):
App(embedder_config=wrong_embedder_config)
self.assertIsInstance(embedder_config, BaseEmbedderConfig)
class TestAppFromConfig:
    """Check that ``App.from_config`` builds an app matching a YAML config file.

    Each test loads a YAML file twice: once through ``App.from_config`` and
    once directly with ``yaml.safe_load``; the resulting app's component
    configs are then compared against the raw YAML values.

    NOTE(review): the YAML paths are relative, so these tests presumably have
    to be run from the repository root - confirm against the CI setup.
    """

    def load_config_data(self, yaml_path):
        """Parse the YAML file at ``yaml_path`` and return the loaded data."""
        # Explicit encoding keeps the read independent of the platform's locale default.
        with open(yaml_path, "r", encoding="utf-8") as file:
            return yaml.safe_load(file)

    def _assert_components_match(self, app, config_data):
        """Assert the checks shared by every YAML-driven test.

        Compares the app, LLM, vector DB and embedder configs of ``app``
        against the corresponding ``config`` sections of ``config_data``.
        ``collect_metrics`` is intentionally left to the individual tests,
        because not every YAML file declares it.
        """
        # Check if the App instance and its components were created correctly
        assert isinstance(app, App)

        # Validate the AppConfig values
        app_config = config_data["app"]["config"]
        assert app.config.id == app_config["id"]
        assert app.config.collection_name == app_config["collection_name"]

        # Validate the LLM config values
        llm_config = config_data["llm"]["config"]
        assert app.llm.config.temperature == llm_config["temperature"]
        assert app.llm.config.max_tokens == llm_config["max_tokens"]
        assert app.llm.config.top_p == llm_config["top_p"]
        assert app.llm.config.stream == llm_config["stream"]

        # Validate the VectorDB config values
        db_config = config_data["vectordb"]["config"]
        assert app.db.config.collection_name == db_config["collection_name"]
        assert app.db.config.dir == db_config["dir"]
        assert app.db.config.allow_reset == db_config["allow_reset"]

        # Validate the Embedder config values
        embedder_config = config_data["embedder"]["config"]
        assert app.embedder.config.model == embedder_config["model"]
        assert app.embedder.config.deployment_name == embedder_config["deployment_name"]

    def test_from_chroma_config(self):
        """Build an app from ``chroma.yaml`` and compare it with the file's values."""
        yaml_path = "embedchain/yaml/chroma.yaml"
        config_data = self.load_config_data(yaml_path)

        app = App.from_config(yaml_path)

        self._assert_components_match(app, config_data)
        # Even though not present in the config, the default value is used
        assert app.config.collect_metrics is True

    def test_from_opensource_config(self):
        """Build an app from ``opensource.yaml`` and compare it with the file's values."""
        yaml_path = "embedchain/yaml/opensource.yaml"
        config_data = self.load_config_data(yaml_path)

        app = App.from_config(yaml_path)

        self._assert_components_match(app, config_data)
        # This file sets collect_metrics explicitly, so compare against it.
        assert app.config.collect_metrics == config_data["app"]["config"]["collect_metrics"]

View File

@@ -1,9 +1,11 @@
import os
import pytest
from embedchain.config import AddConfig, BaseLlmConfig
from embedchain.bots.base import BaseBot
from unittest.mock import patch
import pytest
from embedchain.bots.base import BaseBot
from embedchain.config import AddConfig, BaseLlmConfig
@pytest.fixture
def base_bot():

View File

@@ -1,6 +1,8 @@
import hashlib
import pytest
from unittest.mock import MagicMock
import pytest
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.models.data_type import DataType

View File

@@ -41,7 +41,8 @@ class TestChromaDbHostsLoglevel(unittest.TestCase):
Test if the `App` instance is correctly reconstructed after a reset.
"""
config = AppConfig(log_level="DEBUG", collect_metrics=False)
app = App(config=config, chromadb_config=ChromaDbConfig(chroma_settings={"allow_reset": True}))
chroma_config = {"allow_reset": True}
app = App(config=config, db_config=ChromaDbConfig(**chroma_config))
app.reset()
# Make sure the client is still healthy

View File

@@ -1,9 +1,11 @@
import pytest
from unittest.mock import MagicMock
from embedchain.embedder.base import BaseEmbedder
from embedchain.config.embedder.base import BaseEmbedderConfig
import pytest
from chromadb.api.types import Documents, Embeddings
from embedchain.config.embedder.base import BaseEmbedderConfig
from embedchain.embedder.base import BaseEmbedder
@pytest.fixture
def base_embedder():

View File

@@ -1,62 +1,63 @@
import pytest
from unittest.mock import MagicMock, patch
from embedchain.llm.antrophic import AntrophicLlm
from embedchain.config import BaseLlmConfig
import pytest
from langchain.schema import HumanMessage, SystemMessage
from embedchain.config import BaseLlmConfig
from embedchain.llm.anthropic import AnthropicLlm
@pytest.fixture
def antrophic_llm():
def anthropic_llm():
config = BaseLlmConfig(temperature=0.5, model="gpt2")
return AntrophicLlm(config)
return AnthropicLlm(config)
def test_get_llm_model_answer(antrophic_llm):
with patch.object(AntrophicLlm, "_get_answer", return_value="Test Response") as mock_method:
def test_get_llm_model_answer(anthropic_llm):
with patch.object(AnthropicLlm, "_get_answer", return_value="Test Response") as mock_method:
prompt = "Test Prompt"
response = antrophic_llm.get_llm_model_answer(prompt)
response = anthropic_llm.get_llm_model_answer(prompt)
assert response == "Test Response"
mock_method.assert_called_once_with(prompt=prompt, config=antrophic_llm.config)
mock_method.assert_called_once_with(prompt=prompt, config=anthropic_llm.config)
def test_get_answer(antrophic_llm):
def test_get_answer(anthropic_llm):
with patch("langchain.chat_models.ChatAnthropic") as mock_chat:
mock_chat_instance = mock_chat.return_value
mock_chat_instance.return_value = MagicMock(content="Test Response")
prompt = "Test Prompt"
response = antrophic_llm._get_answer(prompt, antrophic_llm.config)
response = anthropic_llm._get_answer(prompt, anthropic_llm.config)
assert response == "Test Response"
mock_chat.assert_called_once_with(
temperature=antrophic_llm.config.temperature, model=antrophic_llm.config.model
temperature=anthropic_llm.config.temperature, model=anthropic_llm.config.model
)
mock_chat_instance.assert_called_once_with(
antrophic_llm._get_messages(prompt, system_prompt=antrophic_llm.config.system_prompt)
anthropic_llm._get_messages(prompt, system_prompt=anthropic_llm.config.system_prompt)
)
def test_get_messages(antrophic_llm):
def test_get_messages(anthropic_llm):
prompt = "Test Prompt"
system_prompt = "Test System Prompt"
messages = antrophic_llm._get_messages(prompt, system_prompt)
messages = anthropic_llm._get_messages(prompt, system_prompt)
assert messages == [
SystemMessage(content="Test System Prompt", additional_kwargs={}),
HumanMessage(content="Test Prompt", additional_kwargs={}, example=False),
]
def test_get_answer_max_tokens_is_provided(antrophic_llm, caplog):
def test_get_answer_max_tokens_is_provided(anthropic_llm, caplog):
with patch("langchain.chat_models.ChatAnthropic") as mock_chat:
mock_chat_instance = mock_chat.return_value
mock_chat_instance.return_value = MagicMock(content="Test Response")
prompt = "Test Prompt"
config = antrophic_llm.config
config = anthropic_llm.config
config.max_tokens = 500
response = antrophic_llm._get_answer(prompt, config)
response = anthropic_llm._get_answer(prompt, config)
assert response == "Test Response"
mock_chat.assert_called_once_with(temperature=config.temperature, model=config.model)

View File

@@ -1,9 +1,11 @@
import pytest
from unittest.mock import MagicMock, patch
from embedchain.llm.azure_openai import AzureOpenAILlm
from embedchain.config import BaseLlmConfig
import pytest
from langchain.schema import HumanMessage, SystemMessage
from embedchain.config import BaseLlmConfig
from embedchain.llm.azure_openai import AzureOpenAILlm
@pytest.fixture
def azure_openai_llm():

View File

@@ -1,7 +1,7 @@
import importlib
import os
import unittest
from unittest.mock import patch, MagicMock
from unittest.mock import MagicMock, patch
from embedchain.config import BaseLlmConfig
from embedchain.llm.hugging_face_hub import HuggingFaceHubLlm

View File

@@ -1,9 +1,11 @@
import pytest
from unittest.mock import MagicMock, patch
from embedchain.llm.vertex_ai import VertexAiLlm
from embedchain.config import BaseLlmConfig
import pytest
from langchain.schema import HumanMessage, SystemMessage
from embedchain.config import BaseLlmConfig
from embedchain.llm.vertex_ai import VertexAiLlm
@pytest.fixture
def vertexai_llm():

View File

@@ -1,7 +1,9 @@
import hashlib
import pytest
from unittest.mock import Mock, patch
import pytest
from requests import Response
from embedchain.loaders.docs_site_loader import DocsSiteLoader

View File

@@ -1,6 +1,8 @@
import hashlib
import pytest
from unittest.mock import MagicMock, patch
import pytest
from embedchain.loaders.docx_file import DocxFileLoader

View File

@@ -1,5 +1,7 @@
import hashlib
import pytest
from embedchain.loaders.local_qna_pair import LocalQnaPairLoader

View File

@@ -1,5 +1,7 @@
import hashlib
import pytest
from embedchain.loaders.local_text import LocalTextLoader

View File

@@ -1,6 +1,8 @@
import hashlib
from unittest.mock import mock_open, patch
import pytest
from unittest.mock import patch, mock_open
from embedchain.loaders.mdx import MdxLoader

View File

@@ -1,7 +1,9 @@
import hashlib
import os
import pytest
from unittest.mock import Mock, patch
import pytest
from embedchain.loaders.notion import NotionLoader

View File

@@ -1,6 +1,8 @@
import hashlib
import pytest
from unittest.mock import Mock, patch
import pytest
from embedchain.loaders.web_page import WebPageLoader

View File

@@ -1,6 +1,8 @@
import hashlib
import pytest
from unittest.mock import MagicMock, Mock, patch
import pytest
from embedchain.loaders.youtube_video import YoutubeVideoLoader

61
tests/test_factory.py Normal file
View File

@@ -0,0 +1,61 @@
import pytest
import embedchain
import embedchain.embedder.gpt4all
import embedchain.embedder.huggingface
import embedchain.embedder.openai
import embedchain.embedder.vertexai
import embedchain.llm.anthropic
import embedchain.llm.openai
import embedchain.vectordb.chroma
import embedchain.vectordb.elasticsearch
import embedchain.vectordb.opensearch
from embedchain.factory import EmbedderFactory, LlmFactory, VectorDBFactory
# Parametrize tables. Every row is
# (provider name, config data handed to the factory, class expected back).
_PARAM_NAMES = "provider_name, config_data, expected_class"

_LLM_CASES = [
    ("openai", {}, embedchain.llm.openai.OpenAILlm),
    ("anthropic", {}, embedchain.llm.anthropic.AnthropicLlm),
]

_EMBEDDER_CASES = [
    ("gpt4all", {}, embedchain.embedder.gpt4all.GPT4AllEmbedder),
    (
        "huggingface",
        {"model": "sentence-transformers/all-mpnet-base-v2"},
        embedchain.embedder.huggingface.HuggingFaceEmbedder,
    ),
    ("vertexai", {"model": "textembedding-gecko"}, embedchain.embedder.vertexai.VertexAiEmbedder),
    ("openai", {}, embedchain.embedder.openai.OpenAIEmbedder),
]

_VECTORDB_CASES = [
    ("chroma", {}, embedchain.vectordb.chroma.ChromaDB),
    (
        "opensearch",
        {"opensearch_url": "http://localhost:9200", "http_auth": ("admin", "admin")},
        embedchain.vectordb.opensearch.OpenSearchDB,
    ),
    ("elasticsearch", {"es_url": "http://localhost:9200"}, embedchain.vectordb.elasticsearch.ElasticsearchDB),
]


class TestFactories:
    """Verify that each factory returns an instance of the class expected for a provider."""

    @pytest.mark.parametrize(_PARAM_NAMES, _LLM_CASES)
    def test_llm_factory_create(self, provider_name, config_data, expected_class):
        """``LlmFactory.create`` yields an instance of ``expected_class``."""
        created = LlmFactory.create(provider_name, config_data)
        assert isinstance(created, expected_class)

    @pytest.mark.parametrize(_PARAM_NAMES, _EMBEDDER_CASES)
    def test_embedder_factory_create(self, mocker, provider_name, config_data, expected_class):
        """``EmbedderFactory.create`` yields an instance of ``expected_class``."""
        # The Vertex AI embedder is swapped for an autospecced mock for every case.
        # NOTE(review): presumably this avoids needing real cloud credentials - confirm.
        mocker.patch("embedchain.embedder.vertexai.VertexAiEmbedder", autospec=True)
        created = EmbedderFactory.create(provider_name, config_data)
        assert isinstance(created, expected_class)

    @pytest.mark.parametrize(_PARAM_NAMES, _VECTORDB_CASES)
    def test_vectordb_factory_create(self, mocker, provider_name, config_data, expected_class):
        """``VectorDBFactory.create`` yields an instance of ``expected_class``."""
        # The OpenSearch backend is swapped for an autospecced mock for every case.
        # NOTE(review): presumably this avoids needing a reachable cluster - confirm.
        mocker.patch("embedchain.vectordb.opensearch.OpenSearchDB", autospec=True)
        created = VectorDBFactory.create(provider_name, config_data)
        assert isinstance(created, expected_class)

View File

@@ -28,19 +28,25 @@ class TestChromaDbHosts(unittest.TestCase):
host = "test-host"
port = "1234"
chroma_auth_settings = {
"chroma_client_auth_provider": "chromadb.auth.basic.BasicAuthClientProvider",
"chroma_client_auth_credentials": "admin:admin",
chroma_config = {
"host": host,
"port": port,
"chroma_settings": {
"chroma_client_auth_provider": "chromadb.auth.basic.BasicAuthClientProvider",
"chroma_client_auth_credentials": "admin:admin",
},
}
config = ChromaDbConfig(host=host, port=port, chroma_settings=chroma_auth_settings)
config = ChromaDbConfig(**chroma_config)
db = ChromaDB(config=config)
settings = db.client.get_settings()
self.assertEqual(settings.chroma_server_host, host)
self.assertEqual(settings.chroma_server_http_port, port)
self.assertEqual(settings.chroma_client_auth_provider, chroma_auth_settings["chroma_client_auth_provider"])
self.assertEqual(
settings.chroma_client_auth_credentials, chroma_auth_settings["chroma_client_auth_credentials"]
settings.chroma_client_auth_provider, chroma_config["chroma_settings"]["chroma_client_auth_provider"]
)
self.assertEqual(
settings.chroma_client_auth_credentials, chroma_config["chroma_settings"]["chroma_client_auth_credentials"]
)
@@ -55,9 +61,9 @@ class TestChromaDbHostsInit(unittest.TestCase):
port = "1234"
config = AppConfig(collect_metrics=False)
chromadb_config = ChromaDbConfig(host=host, port=port)
db_config = ChromaDbConfig(host=host, port=port)
_app = App(config, chromadb_config=chromadb_config)
_app = App(config, db_config=db_config)
called_settings: Settings = mock_client.call_args[0][0]
@@ -95,7 +101,7 @@ class TestChromaDbHostsLoglevel(unittest.TestCase):
class TestChromaDbDuplicateHandling:
chroma_config = ChromaDbConfig(allow_reset=True)
app_config = AppConfig(collection_name=False, collect_metrics=False)
app_with_settings = App(config=app_config, chromadb_config=chroma_config)
app_with_settings = App(config=app_config, db_config=chroma_config)
def test_duplicates_throw_warning(self, caplog):
"""
@@ -131,7 +137,7 @@ class TestChromaDbDuplicateHandling:
class TestChromaDbCollection(unittest.TestCase):
chroma_config = ChromaDbConfig(allow_reset=True)
app_config = AppConfig(collection_name=False, collect_metrics=False)
app_with_settings = App(config=app_config, chromadb_config=chroma_config)
app_with_settings = App(config=app_config, db_config=chroma_config)
def test_init_with_default_collection(self):
"""
@@ -296,7 +302,7 @@ class TestChromaDbCollection(unittest.TestCase):
# Create four apps.
# app1, which we are about to reset, shares an app with one, and an id with the other, none with the last.
app1 = App(AppConfig(id="new_app_id_1", collect_metrics=False), chromadb_config=self.chroma_config)
app1 = App(AppConfig(id="new_app_id_1", collect_metrics=False), db_config=self.chroma_config)
app1.set_collection_name("one_collection")
app2 = App(AppConfig(id="new_app_id_2", collect_metrics=False))
app2.set_collection_name("one_collection")

View File

@@ -1,14 +1,15 @@
# ruff: noqa: E501
import os
from unittest import mock
from unittest.mock import Mock, patch
import pytest
from unittest import mock
from unittest.mock import patch, Mock
from embedchain.config import ZillizDBConfig
from embedchain.vectordb.zilliz import ZillizVectorDB
# to run tests, provide the URI and TOKEN in .env file
class TestZillizVectorDBConfig:
@mock.patch.dict(os.environ, {"ZILLIZ_CLOUD_URI": "mocked_uri", "ZILLIZ_CLOUD_TOKEN": "mocked_token"})
@@ -51,6 +52,7 @@ class TestZillizVectorDBConfig:
with pytest.raises(AttributeError):
ZillizDBConfig()
class TestZillizVectorDB:
@pytest.fixture
@mock.patch.dict(os.environ, {"ZILLIZ_CLOUD_URI": "mocked_uri", "ZILLIZ_CLOUD_TOKEN": "mocked_token"})
@@ -147,7 +149,7 @@ class TestZillizDBCollection:
zilliz_db = ZillizVectorDB(config=mock_config)
# Add a 'embedder' attribute to the ZillizVectorDB instance for testing
zilliz_db.embedder = mock_embedder # Mock the 'collection' object
zilliz_db.embedder = mock_embedder # Mock the 'collection' object
# Add a 'collection' attribute to the ZillizVectorDB instance for testing
zilliz_db.collection = Mock(is_empty=False) # Mock the 'collection' object