[Feature]: Add support for creating app using yaml config (#787)

This commit is contained in:
Deshraj Yadav
2023-10-12 15:35:49 -07:00
committed by GitHub
parent 4820ea15d6
commit a86d7f52e9
36 changed files with 479 additions and 95 deletions

View File

@@ -23,9 +23,9 @@ jobs:
- name: Install dependencies
run: poetry install --all-extras
- name: Lint with ruff
run: make ci_lint
run: make lint
- name: Test with pytest
run: make ci_test
run: make test
- name: Generate coverage report
run: make coverage
- name: Upload coverage reports to Codecov

View File

@@ -31,19 +31,13 @@ format:
$(PYTHON) -m black .
$(PYTHON) -m isort .
lint:
$(PYTHON) -m ruff .
clean:
rm -rf dist build *.egg-info
test:
$(PYTHON) -m pytest
ci_lint:
lint:
poetry run ruff .
ci_test:
test:
poetry run pytest
coverage:

View File

@@ -1,12 +1,13 @@
import logging
from typing import Optional
from embedchain.config import (AppConfig, BaseEmbedderConfig, BaseLlmConfig,
ChromaDbConfig)
import yaml
from embedchain.config import AppConfig, BaseEmbedderConfig, BaseLlmConfig
from embedchain.config.vectordb.base import BaseVectorDbConfig
from embedchain.embedchain import EmbedChain
from embedchain.embedder.base import BaseEmbedder
from embedchain.embedder.openai import OpenAIEmbedder
from embedchain.factory import EmbedderFactory, LlmFactory, VectorDBFactory
from embedchain.helper.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm
from embedchain.llm.openai import OpenAILlm
@@ -35,7 +36,6 @@ class App(EmbedChain):
db_config: Optional[BaseVectorDbConfig] = None,
embedder: BaseEmbedder = None,
embedder_config: Optional[BaseEmbedderConfig] = None,
chromadb_config: Optional[ChromaDbConfig] = None,
system_prompt: Optional[str] = None,
):
"""
@@ -60,20 +60,10 @@ class App(EmbedChain):
:param embedder_config: Allows you to configure the Embedder.
example: `from embedchain.config import BaseEmbedderConfig`, defaults to None
:type embedder_config: Optional[BaseEmbedderConfig], optional
:param chromadb_config: Deprecated alias of `db_config`, defaults to None
:type chromadb_config: Optional[ChromaDbConfig], optional
:param system_prompt: System prompt that will be provided to the LLM as such, defaults to None
:type system_prompt: Optional[str], optional
:raises TypeError: LLM, database or embedder or their config is not a valid class instance.
"""
# Overwrite deprecated arguments
if chromadb_config:
logging.warning(
"DEPRECATION WARNING: Please use `db_config` argument instead of `chromadb_config`."
"`chromadb_config` will be removed in a future release."
)
db_config = chromadb_config
# Type check configs
if config and not isinstance(config, AppConfig):
raise TypeError(
@@ -123,3 +113,33 @@ class App(EmbedChain):
"Please make sure the type is right and that you are passing an instance."
)
super().__init__(config, llm=llm, db=db, embedder=embedder, system_prompt=system_prompt)
@classmethod
def from_config(cls, yaml_path: str):
"""
Instantiate an App object from a YAML configuration file.
:param yaml_path: Path to the YAML configuration file.
:type yaml_path: str
:return: An instance of the App class.
:rtype: App
"""
with open(yaml_path, "r") as file:
config_data = yaml.safe_load(file)
app_config_data = config_data.get("app", {})
llm_config_data = config_data.get("llm", {})
db_config_data = config_data.get("vectordb", {})
embedder_config_data = config_data.get("embedder", {})
app_config = AppConfig(**app_config_data.get("config", {}))
llm_provider = llm_config_data.get("provider", "openai")
llm = LlmFactory.create(llm_provider, llm_config_data.get("config", {}))
db_provider = db_config_data.get("provider", "chroma")
db = VectorDBFactory.create(db_provider, db_config_data.get("config", {}))
embedder_provider = embedder_config_data.get("provider", "openai")
embedder = EmbedderFactory.create(embedder_provider, embedder_config_data.get("config", {}))
return cls(config=app_config, llm=llm, db=db, embedder=embedder)

View File

@@ -3,7 +3,8 @@ from typing import Any
from embedchain import App
from embedchain.config import AddConfig, AppConfig, LlmConfig
from embedchain.embedder.openai import OpenAIEmbedder
from embedchain.helper.json_serializable import JSONSerializable, register_deserializable
from embedchain.helper.json_serializable import (JSONSerializable,
register_deserializable)
from embedchain.llm.openai import OpenAILlm
from embedchain.vectordb.chroma import ChromaDB

View File

@@ -10,6 +10,7 @@ class BaseVectorDbConfig(BaseConfig):
dir: str = "db",
host: Optional[str] = None,
port: Optional[str] = None,
**kwargs,
):
"""
Initializes a configuration class instance for the vector database.
@@ -22,8 +23,14 @@ class BaseVectorDbConfig(BaseConfig):
:type host: Optional[str], optional
:param host: Database connection remote port. Use this if you run Embedchain as a client, defaults to None
:type port: Optional[str], optional
:param kwargs: Additional keyword arguments
:type kwargs: dict
"""
self.collection_name = collection_name or "embedchain_store"
self.dir = dir
self.host = host
self.port = port
# Assign additional keyword arguments
if kwargs:
for key, value in kwargs.items():
setattr(self, key, value)

88
embedchain/factory.py Normal file
View File

@@ -0,0 +1,88 @@
import importlib
def load_class(class_type):
module_path, class_name = class_type.rsplit(".", 1)
module = importlib.import_module(module_path)
return getattr(module, class_name)
class LlmFactory:
provider_to_class = {
"anthropic": "embedchain.llm.anthropic.AnthropicLlm",
"azure_openai": "embedchain.llm.azure_openai.AzureOpenAILlm",
"cohere": "embedchain.llm.cohere.CohereLlm",
"gpt4all": "embedchain.llm.gpt4all.GPT4ALLLlm",
"hugging_face_llm": "embedchain.llm.hugging_face_llm.HuggingFaceLlm",
"jina": "embedchain.llm.jina.JinaLlm",
"llama2": "embedchain.llm.llama2.Llama2Llm",
"openai": "embedchain.llm.openai.OpenAILlm",
"vertexai": "embedchain.llm.vertex_ai.VertexAILlm",
}
provider_to_config_class = {
"embedchain": "embedchain.config.llm.base_llm_config.BaseLlmConfig",
"openai": "embedchain.config.llm.base_llm_config.BaseLlmConfig",
"anthropic": "embedchain.config.llm.base_llm_config.BaseLlmConfig",
}
@classmethod
def create(cls, provider_name, config_data):
class_type = cls.provider_to_class.get(provider_name)
# Default to embedchain base config if the provider is not in the config map
config_name = "embedchain" if provider_name not in cls.provider_to_config_class else provider_name
config_class_type = cls.provider_to_config_class.get(config_name)
if class_type:
llm_class = load_class(class_type)
llm_config_class = load_class(config_class_type)
return llm_class(config=llm_config_class(**config_data))
else:
raise ValueError(f"Unsupported Llm provider: {provider_name}")
class EmbedderFactory:
provider_to_class = {
"gpt4all": "embedchain.embedder.gpt4all.GPT4AllEmbedder",
"huggingface": "embedchain.embedder.huggingface.HuggingFaceEmbedder",
"vertexai": "embedchain.embedder.vertexai.VertexAiEmbedder",
"openai": "embedchain.embedder.openai.OpenAIEmbedder",
}
provider_to_config_class = {
"openai": "embedchain.config.embedder.base.BaseEmbedderConfig",
}
@classmethod
def create(cls, provider_name, config_data):
class_type = cls.provider_to_class.get(provider_name)
# Default to openai config if the provider is not in the config map
config_name = "openai" if provider_name not in cls.provider_to_config_class else provider_name
config_class_type = cls.provider_to_config_class.get(config_name)
if class_type:
embedder_class = load_class(class_type)
embedder_config_class = load_class(config_class_type)
return embedder_class(config=embedder_config_class(**config_data))
else:
raise ValueError(f"Unsupported Embedder provider: {provider_name}")
class VectorDBFactory:
provider_to_class = {
"chroma": "embedchain.vectordb.chroma.ChromaDB",
"elasticsearch": "embedchain.vectordb.elasticsearch.ElasticsearchDB",
"opensearch": "embedchain.vectordb.opensearch.OpenSearchDB",
}
provider_to_config_class = {
"chroma": "embedchain.config.vectordb.chroma.ChromaDbConfig",
"elasticsearch": "embedchain.config.vectordb.elasticsearch.ElasticsearchDBConfig",
"opensearch": "embedchain.config.vectordb.opensearch.OpenSearchDBConfig",
}
@classmethod
def create(cls, provider_name, config_data):
class_type = cls.provider_to_class.get(provider_name)
config_class_type = cls.provider_to_config_class.get(provider_name)
if class_type:
embedder_class = load_class(class_type)
embedder_config_class = load_class(config_class_type)
return embedder_class(config=embedder_config_class(**config_data))
else:
raise ValueError(f"Unsupported Embedder provider: {provider_name}")

View File

@@ -7,12 +7,12 @@ from embedchain.llm.base import BaseLlm
@register_deserializable
class AntrophicLlm(BaseLlm):
class AnthropicLlm(BaseLlm):
def __init__(self, config: Optional[BaseLlmConfig] = None):
super().__init__(config=config)
def get_llm_model_answer(self, prompt):
return AntrophicLlm._get_answer(prompt=prompt, config=self.config)
return AnthropicLlm._get_answer(prompt=prompt, config=self.config)
@staticmethod
def _get_answer(prompt: str, config: BaseLlmConfig) -> str:

View File

@@ -34,7 +34,8 @@ class JinaLlm(BaseLlm):
if config.top_p:
kwargs["model_kwargs"]["top_p"] = config.top_p
if config.stream:
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.callbacks.streaming_stdout import \
StreamingStdOutCallbackHandler
chat = JinaChat(**kwargs, streaming=config.stream, callbacks=[StreamingStdOutCallbackHandler()])
else:

View File

@@ -2,9 +2,7 @@ try:
from PIL import Image, UnidentifiedImageError
from sentence_transformers import SentenceTransformer
except ImportError:
raise ImportError(
"Images requires extra dependencies. Install with `pip install 'embedchain[images]'"
) from None
raise ImportError("Images requires extra dependencies. Install with `pip install 'embedchain[images]'") from None
MODEL_NAME = "clip-ViT-B-32"

View File

@@ -37,7 +37,7 @@ class ChromaDB(BaseVectorDB):
self.config = ChromaDbConfig()
self.settings = Settings()
self.settings.allow_reset = self.config.allow_reset
self.settings.allow_reset = self.config.allow_reset if hasattr(self.config, "allow_reset") else False
if self.config.chroma_settings:
for key, value in self.config.chroma_settings.items():
if hasattr(self.settings, key):
@@ -72,6 +72,17 @@ class ChromaDB(BaseVectorDB):
"""Called during initialization"""
return self.client
def _generate_where_clause(self, where: Dict[str, any]) -> str:
# If only one filter is supplied, return it as is
# (no need to wrap in $and based on chroma docs)
if len(where.keys()) == 1:
return where
where_filters = []
for k, v in where.items():
if isinstance(v, str):
where_filters.append({k: v})
return {"$and": where_filters}
def _get_or_create_collection(self, name: str) -> Collection:
"""
Get or create a named collection.
@@ -107,13 +118,14 @@ class ChromaDB(BaseVectorDB):
if ids:
args["ids"] = ids
if where:
args["where"] = where
args["where"] = self._generate_where_clause(where)
if limit:
args["limit"] = limit
return self.collection.get(**args)
def get_advanced(self, where):
return self.collection.get(where=where, limit=1)
where_clause = self._generate_where_clause(where)
return self.collection.get(where=where_clause, limit=1)
def add(
self,

View File

@@ -110,8 +110,13 @@ class OpenSearchDB(BaseVectorDB):
return result
def add(
self, embeddings: List[List[str]], documents: List[str], metadatas: List[object], ids: List[str],
skip_embedding: bool):
self,
embeddings: List[List[str]],
documents: List[str],
metadatas: List[object],
ids: List[str],
skip_embedding: bool,
):
"""add data in vector database
:param embeddings: list of embeddings to add
@@ -162,7 +167,8 @@ class OpenSearchDB(BaseVectorDB):
embedding_function=embeddings,
opensearch_url=f"{self.config.opensearch_url}",
http_auth=self.config.http_auth,
use_ssl=True,
use_ssl=hasattr(self.config, "use_ssl") and self.config.use_ssl,
verify_certs=hasattr(self.config, "verify_certs") and self.config.verify_certs,
)
pre_filter = {"match_all": {}} # default

View File

@@ -5,8 +5,8 @@ from embedchain.helper.json_serializable import register_deserializable
from embedchain.vectordb.base import BaseVectorDB
try:
from pymilvus import MilvusClient
from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility
from pymilvus import (Collection, CollectionSchema, DataType, FieldSchema,
MilvusClient, connections, utility)
except ImportError:
raise ImportError(
"Zilliz requires extra dependencies. Install with `pip install --upgrade embedchain[milvus]`"

View File

@@ -0,0 +1,26 @@
app:
config:
id: 'my-app'
collection_name: 'my-app'
llm:
provider: openai
model: 'gpt-3.5-turbo'
config:
temperature: 0.5
max_tokens: 1000
top_p: 1
stream: false
vectordb:
provider: chroma
config:
collection_name: 'my-app'
dir: db
allow_reset: true
embedder:
provider: openai
config:
model: 'text-embedding-ada-002'
deployment_name: null

View File

@@ -0,0 +1,33 @@
app:
config:
id: 'my-app'
log_level: 'WARN'
collect_metrics: true
collection_name: 'my-app'
llm:
provider: openai
model: 'gpt-3.5-turbo'
config:
temperature: 0.5
max_tokens: 1000
top_p: 1
stream: false
vectordb:
provider: opensearch
config:
opensearch_url: 'https://localhost:9200'
http_auth:
- admin
- admin
vector_dimension: 1536
collection_name: 'my-app'
use_ssl: false
verify_certs: false
embedder:
provider: openai
config:
model: 'text-embedding-ada-002'
deployment_name: null

View File

@@ -0,0 +1,27 @@
app:
config:
id: 'open-source-app'
collection_name: 'open-source-app'
collect_metrics: false
llm:
provider: gpt4all
model: 'orca-mini-3b.ggmlv3.q4_0.bin'
config:
temperature: 0.5
max_tokens: 1000
top_p: 1
stream: false
vectordb:
provider: chroma
config:
collection_name: 'open-source-app'
dir: db
allow_reset: true
embedder:
provider: gpt4all
config:
model: 'all-MiniLM-L6-v2'
deployment_name: null

View File

@@ -1,8 +1,11 @@
import os
import unittest
import yaml
from embedchain import App, CustomApp, Llama2App, OpenSourceApp
from embedchain.config import ChromaDbConfig, AppConfig, AddConfig, BaseLlmConfig, BaseEmbedderConfig
from embedchain.config import (AddConfig, AppConfig, BaseEmbedderConfig,
BaseLlmConfig, ChromaDbConfig)
from embedchain.embedder.base import BaseEmbedder
from embedchain.llm.base import BaseLlm
from embedchain.vectordb.base import BaseVectorDB, BaseVectorDbConfig
@@ -100,3 +103,74 @@ class TestConfigForAppComponents(unittest.TestCase):
App(embedder_config=wrong_embedder_config)
self.assertIsInstance(embedder_config, BaseEmbedderConfig)
class TestAppFromConfig:
def load_config_data(self, yaml_path):
with open(yaml_path, "r") as file:
return yaml.safe_load(file)
def test_from_chroma_config(self):
yaml_path = "embedchain/yaml/chroma.yaml"
config_data = self.load_config_data(yaml_path)
app = App.from_config(yaml_path)
# Check if the App instance and its components were created correctly
assert isinstance(app, App)
# Validate the AppConfig values
assert app.config.id == config_data["app"]["config"]["id"]
assert app.config.collection_name == config_data["app"]["config"]["collection_name"]
# Even though not present in the config, the default value is used
assert app.config.collect_metrics is True
# Validate the LLM config values
llm_config = config_data["llm"]["config"]
assert app.llm.config.temperature == llm_config["temperature"]
assert app.llm.config.max_tokens == llm_config["max_tokens"]
assert app.llm.config.top_p == llm_config["top_p"]
assert app.llm.config.stream == llm_config["stream"]
# Validate the VectorDB config values
db_config = config_data["vectordb"]["config"]
assert app.db.config.collection_name == db_config["collection_name"]
assert app.db.config.dir == db_config["dir"]
assert app.db.config.allow_reset == db_config["allow_reset"]
# Validate the Embedder config values
embedder_config = config_data["embedder"]["config"]
assert app.embedder.config.model == embedder_config["model"]
assert app.embedder.config.deployment_name == embedder_config["deployment_name"]
def test_from_opensource_config(self):
yaml_path = "embedchain/yaml/opensource.yaml"
config_data = self.load_config_data(yaml_path)
app = App.from_config(yaml_path)
# Check if the App instance and its components were created correctly
assert isinstance(app, App)
# Validate the AppConfig values
assert app.config.id == config_data["app"]["config"]["id"]
assert app.config.collection_name == config_data["app"]["config"]["collection_name"]
assert app.config.collect_metrics == config_data["app"]["config"]["collect_metrics"]
# Validate the LLM config values
llm_config = config_data["llm"]["config"]
assert app.llm.config.temperature == llm_config["temperature"]
assert app.llm.config.max_tokens == llm_config["max_tokens"]
assert app.llm.config.top_p == llm_config["top_p"]
assert app.llm.config.stream == llm_config["stream"]
# Validate the VectorDB config values
db_config = config_data["vectordb"]["config"]
assert app.db.config.collection_name == db_config["collection_name"]
assert app.db.config.dir == db_config["dir"]
assert app.db.config.allow_reset == db_config["allow_reset"]
# Validate the Embedder config values
embedder_config = config_data["embedder"]["config"]
assert app.embedder.config.model == embedder_config["model"]
assert app.embedder.config.deployment_name == embedder_config["deployment_name"]

View File

@@ -1,9 +1,11 @@
import os
import pytest
from embedchain.config import AddConfig, BaseLlmConfig
from embedchain.bots.base import BaseBot
from unittest.mock import patch
import pytest
from embedchain.bots.base import BaseBot
from embedchain.config import AddConfig, BaseLlmConfig
@pytest.fixture
def base_bot():

View File

@@ -1,6 +1,8 @@
import hashlib
import pytest
from unittest.mock import MagicMock
import pytest
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.models.data_type import DataType

View File

@@ -41,7 +41,8 @@ class TestChromaDbHostsLoglevel(unittest.TestCase):
Test if the `App` instance is correctly reconstructed after a reset.
"""
config = AppConfig(log_level="DEBUG", collect_metrics=False)
app = App(config=config, chromadb_config=ChromaDbConfig(chroma_settings={"allow_reset": True}))
chroma_config = {"allow_reset": True}
app = App(config=config, db_config=ChromaDbConfig(**chroma_config))
app.reset()
# Make sure the client is still healthy

View File

@@ -1,9 +1,11 @@
import pytest
from unittest.mock import MagicMock
from embedchain.embedder.base import BaseEmbedder
from embedchain.config.embedder.base import BaseEmbedderConfig
import pytest
from chromadb.api.types import Documents, Embeddings
from embedchain.config.embedder.base import BaseEmbedderConfig
from embedchain.embedder.base import BaseEmbedder
@pytest.fixture
def base_embedder():

View File

@@ -1,62 +1,63 @@
import pytest
from unittest.mock import MagicMock, patch
from embedchain.llm.antrophic import AntrophicLlm
from embedchain.config import BaseLlmConfig
import pytest
from langchain.schema import HumanMessage, SystemMessage
from embedchain.config import BaseLlmConfig
from embedchain.llm.anthropic import AnthropicLlm
@pytest.fixture
def antrophic_llm():
def anthropic_llm():
config = BaseLlmConfig(temperature=0.5, model="gpt2")
return AntrophicLlm(config)
return AnthropicLlm(config)
def test_get_llm_model_answer(antrophic_llm):
with patch.object(AntrophicLlm, "_get_answer", return_value="Test Response") as mock_method:
def test_get_llm_model_answer(anthropic_llm):
with patch.object(AnthropicLlm, "_get_answer", return_value="Test Response") as mock_method:
prompt = "Test Prompt"
response = antrophic_llm.get_llm_model_answer(prompt)
response = anthropic_llm.get_llm_model_answer(prompt)
assert response == "Test Response"
mock_method.assert_called_once_with(prompt=prompt, config=antrophic_llm.config)
mock_method.assert_called_once_with(prompt=prompt, config=anthropic_llm.config)
def test_get_answer(antrophic_llm):
def test_get_answer(anthropic_llm):
with patch("langchain.chat_models.ChatAnthropic") as mock_chat:
mock_chat_instance = mock_chat.return_value
mock_chat_instance.return_value = MagicMock(content="Test Response")
prompt = "Test Prompt"
response = antrophic_llm._get_answer(prompt, antrophic_llm.config)
response = anthropic_llm._get_answer(prompt, anthropic_llm.config)
assert response == "Test Response"
mock_chat.assert_called_once_with(
temperature=antrophic_llm.config.temperature, model=antrophic_llm.config.model
temperature=anthropic_llm.config.temperature, model=anthropic_llm.config.model
)
mock_chat_instance.assert_called_once_with(
antrophic_llm._get_messages(prompt, system_prompt=antrophic_llm.config.system_prompt)
anthropic_llm._get_messages(prompt, system_prompt=anthropic_llm.config.system_prompt)
)
def test_get_messages(antrophic_llm):
def test_get_messages(anthropic_llm):
prompt = "Test Prompt"
system_prompt = "Test System Prompt"
messages = antrophic_llm._get_messages(prompt, system_prompt)
messages = anthropic_llm._get_messages(prompt, system_prompt)
assert messages == [
SystemMessage(content="Test System Prompt", additional_kwargs={}),
HumanMessage(content="Test Prompt", additional_kwargs={}, example=False),
]
def test_get_answer_max_tokens_is_provided(antrophic_llm, caplog):
def test_get_answer_max_tokens_is_provided(anthropic_llm, caplog):
with patch("langchain.chat_models.ChatAnthropic") as mock_chat:
mock_chat_instance = mock_chat.return_value
mock_chat_instance.return_value = MagicMock(content="Test Response")
prompt = "Test Prompt"
config = antrophic_llm.config
config = anthropic_llm.config
config.max_tokens = 500
response = antrophic_llm._get_answer(prompt, config)
response = anthropic_llm._get_answer(prompt, config)
assert response == "Test Response"
mock_chat.assert_called_once_with(temperature=config.temperature, model=config.model)

View File

@@ -1,9 +1,11 @@
import pytest
from unittest.mock import MagicMock, patch
from embedchain.llm.azure_openai import AzureOpenAILlm
from embedchain.config import BaseLlmConfig
import pytest
from langchain.schema import HumanMessage, SystemMessage
from embedchain.config import BaseLlmConfig
from embedchain.llm.azure_openai import AzureOpenAILlm
@pytest.fixture
def azure_openai_llm():

View File

@@ -1,7 +1,7 @@
import importlib
import os
import unittest
from unittest.mock import patch, MagicMock
from unittest.mock import MagicMock, patch
from embedchain.config import BaseLlmConfig
from embedchain.llm.hugging_face_hub import HuggingFaceHubLlm

View File

@@ -1,9 +1,11 @@
import pytest
from unittest.mock import MagicMock, patch
from embedchain.llm.vertex_ai import VertexAiLlm
from embedchain.config import BaseLlmConfig
import pytest
from langchain.schema import HumanMessage, SystemMessage
from embedchain.config import BaseLlmConfig
from embedchain.llm.vertex_ai import VertexAiLlm
@pytest.fixture
def vertexai_llm():

View File

@@ -1,7 +1,9 @@
import hashlib
import pytest
from unittest.mock import Mock, patch
import pytest
from requests import Response
from embedchain.loaders.docs_site_loader import DocsSiteLoader

View File

@@ -1,6 +1,8 @@
import hashlib
import pytest
from unittest.mock import MagicMock, patch
import pytest
from embedchain.loaders.docx_file import DocxFileLoader

View File

@@ -1,5 +1,7 @@
import hashlib
import pytest
from embedchain.loaders.local_qna_pair import LocalQnaPairLoader

View File

@@ -1,5 +1,7 @@
import hashlib
import pytest
from embedchain.loaders.local_text import LocalTextLoader

View File

@@ -1,6 +1,8 @@
import hashlib
from unittest.mock import mock_open, patch
import pytest
from unittest.mock import patch, mock_open
from embedchain.loaders.mdx import MdxLoader

View File

@@ -1,7 +1,9 @@
import hashlib
import os
import pytest
from unittest.mock import Mock, patch
import pytest
from embedchain.loaders.notion import NotionLoader

View File

@@ -1,6 +1,8 @@
import hashlib
import pytest
from unittest.mock import Mock, patch
import pytest
from embedchain.loaders.web_page import WebPageLoader

View File

@@ -1,6 +1,8 @@
import hashlib
import pytest
from unittest.mock import MagicMock, Mock, patch
import pytest
from embedchain.loaders.youtube_video import YoutubeVideoLoader

61
tests/test_factory.py Normal file
View File

@@ -0,0 +1,61 @@
import pytest
import embedchain
import embedchain.embedder.gpt4all
import embedchain.embedder.huggingface
import embedchain.embedder.openai
import embedchain.embedder.vertexai
import embedchain.llm.anthropic
import embedchain.llm.openai
import embedchain.vectordb.chroma
import embedchain.vectordb.elasticsearch
import embedchain.vectordb.opensearch
from embedchain.factory import EmbedderFactory, LlmFactory, VectorDBFactory
class TestFactories:
@pytest.mark.parametrize(
"provider_name, config_data, expected_class",
[
("openai", {}, embedchain.llm.openai.OpenAILlm),
("anthropic", {}, embedchain.llm.anthropic.AnthropicLlm),
],
)
def test_llm_factory_create(self, provider_name, config_data, expected_class):
llm_instance = LlmFactory.create(provider_name, config_data)
assert isinstance(llm_instance, expected_class)
@pytest.mark.parametrize(
"provider_name, config_data, expected_class",
[
("gpt4all", {}, embedchain.embedder.gpt4all.GPT4AllEmbedder),
(
"huggingface",
{"model": "sentence-transformers/all-mpnet-base-v2"},
embedchain.embedder.huggingface.HuggingFaceEmbedder,
),
("vertexai", {"model": "textembedding-gecko"}, embedchain.embedder.vertexai.VertexAiEmbedder),
("openai", {}, embedchain.embedder.openai.OpenAIEmbedder),
],
)
def test_embedder_factory_create(self, mocker, provider_name, config_data, expected_class):
mocker.patch("embedchain.embedder.vertexai.VertexAiEmbedder", autospec=True)
embedder_instance = EmbedderFactory.create(provider_name, config_data)
assert isinstance(embedder_instance, expected_class)
@pytest.mark.parametrize(
"provider_name, config_data, expected_class",
[
("chroma", {}, embedchain.vectordb.chroma.ChromaDB),
(
"opensearch",
{"opensearch_url": "http://localhost:9200", "http_auth": ("admin", "admin")},
embedchain.vectordb.opensearch.OpenSearchDB,
),
("elasticsearch", {"es_url": "http://localhost:9200"}, embedchain.vectordb.elasticsearch.ElasticsearchDB),
],
)
def test_vectordb_factory_create(self, mocker, provider_name, config_data, expected_class):
mocker.patch("embedchain.vectordb.opensearch.OpenSearchDB", autospec=True)
vectordb_instance = VectorDBFactory.create(provider_name, config_data)
assert isinstance(vectordb_instance, expected_class)

View File

@@ -28,19 +28,25 @@ class TestChromaDbHosts(unittest.TestCase):
host = "test-host"
port = "1234"
chroma_auth_settings = {
chroma_config = {
"host": host,
"port": port,
"chroma_settings": {
"chroma_client_auth_provider": "chromadb.auth.basic.BasicAuthClientProvider",
"chroma_client_auth_credentials": "admin:admin",
},
}
config = ChromaDbConfig(host=host, port=port, chroma_settings=chroma_auth_settings)
config = ChromaDbConfig(**chroma_config)
db = ChromaDB(config=config)
settings = db.client.get_settings()
self.assertEqual(settings.chroma_server_host, host)
self.assertEqual(settings.chroma_server_http_port, port)
self.assertEqual(settings.chroma_client_auth_provider, chroma_auth_settings["chroma_client_auth_provider"])
self.assertEqual(
settings.chroma_client_auth_credentials, chroma_auth_settings["chroma_client_auth_credentials"]
settings.chroma_client_auth_provider, chroma_config["chroma_settings"]["chroma_client_auth_provider"]
)
self.assertEqual(
settings.chroma_client_auth_credentials, chroma_config["chroma_settings"]["chroma_client_auth_credentials"]
)
@@ -55,9 +61,9 @@ class TestChromaDbHostsInit(unittest.TestCase):
port = "1234"
config = AppConfig(collect_metrics=False)
chromadb_config = ChromaDbConfig(host=host, port=port)
db_config = ChromaDbConfig(host=host, port=port)
_app = App(config, chromadb_config=chromadb_config)
_app = App(config, db_config=db_config)
called_settings: Settings = mock_client.call_args[0][0]
@@ -95,7 +101,7 @@ class TestChromaDbHostsLoglevel(unittest.TestCase):
class TestChromaDbDuplicateHandling:
chroma_config = ChromaDbConfig(allow_reset=True)
app_config = AppConfig(collection_name=False, collect_metrics=False)
app_with_settings = App(config=app_config, chromadb_config=chroma_config)
app_with_settings = App(config=app_config, db_config=chroma_config)
def test_duplicates_throw_warning(self, caplog):
"""
@@ -131,7 +137,7 @@ class TestChromaDbDuplicateHandling:
class TestChromaDbCollection(unittest.TestCase):
chroma_config = ChromaDbConfig(allow_reset=True)
app_config = AppConfig(collection_name=False, collect_metrics=False)
app_with_settings = App(config=app_config, chromadb_config=chroma_config)
app_with_settings = App(config=app_config, db_config=chroma_config)
def test_init_with_default_collection(self):
"""
@@ -296,7 +302,7 @@ class TestChromaDbCollection(unittest.TestCase):
# Create four apps.
# app1, which we are about to reset, shares an app with one, and an id with the other, none with the last.
app1 = App(AppConfig(id="new_app_id_1", collect_metrics=False), chromadb_config=self.chroma_config)
app1 = App(AppConfig(id="new_app_id_1", collect_metrics=False), db_config=self.chroma_config)
app1.set_collection_name("one_collection")
app2 = App(AppConfig(id="new_app_id_2", collect_metrics=False))
app2.set_collection_name("one_collection")

View File

@@ -1,14 +1,15 @@
# ruff: noqa: E501
import os
from unittest import mock
from unittest.mock import Mock, patch
import pytest
from unittest import mock
from unittest.mock import patch, Mock
from embedchain.config import ZillizDBConfig
from embedchain.vectordb.zilliz import ZillizVectorDB
# to run tests, provide the URI and TOKEN in .env file
class TestZillizVectorDBConfig:
@mock.patch.dict(os.environ, {"ZILLIZ_CLOUD_URI": "mocked_uri", "ZILLIZ_CLOUD_TOKEN": "mocked_token"})
@@ -51,6 +52,7 @@ class TestZillizVectorDBConfig:
with pytest.raises(AttributeError):
ZillizDBConfig()
class TestZillizVectorDB:
@pytest.fixture
@mock.patch.dict(os.environ, {"ZILLIZ_CLOUD_URI": "mocked_uri", "ZILLIZ_CLOUD_TOKEN": "mocked_token"})