[Feature]: Add support for creating app using yaml config (#787)
This commit is contained in:
4
.github/workflows/ci.yml
vendored
4
.github/workflows/ci.yml
vendored
@@ -23,9 +23,9 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: poetry install --all-extras
|
||||
- name: Lint with ruff
|
||||
run: make ci_lint
|
||||
run: make lint
|
||||
- name: Test with pytest
|
||||
run: make ci_test
|
||||
run: make test
|
||||
- name: Generate coverage report
|
||||
run: make coverage
|
||||
- name: Upload coverage reports to Codecov
|
||||
|
||||
12
Makefile
12
Makefile
@@ -20,7 +20,7 @@ install_opensearch:
|
||||
|
||||
install_milvus:
|
||||
poetry install --extras milvus
|
||||
|
||||
|
||||
shell:
|
||||
poetry shell
|
||||
|
||||
@@ -31,19 +31,13 @@ format:
|
||||
$(PYTHON) -m black .
|
||||
$(PYTHON) -m isort .
|
||||
|
||||
lint:
|
||||
$(PYTHON) -m ruff .
|
||||
|
||||
clean:
|
||||
rm -rf dist build *.egg-info
|
||||
|
||||
test:
|
||||
$(PYTHON) -m pytest
|
||||
|
||||
ci_lint:
|
||||
lint:
|
||||
poetry run ruff .
|
||||
|
||||
ci_test:
|
||||
test:
|
||||
poetry run pytest
|
||||
|
||||
coverage:
|
||||
|
||||
@@ -7,5 +7,5 @@ from embedchain.apps.custom_app import CustomApp # noqa: F401
|
||||
from embedchain.apps.Llama2App import Llama2App # noqa: F401
|
||||
from embedchain.apps.open_source_app import OpenSourceApp # noqa: F401
|
||||
from embedchain.apps.person_app import (PersonApp, # noqa: F401
|
||||
PersonOpenSourceApp)
|
||||
PersonOpenSourceApp)
|
||||
from embedchain.vectordb.chroma import ChromaDB # noqa: F401
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from embedchain.config import (AppConfig, BaseEmbedderConfig, BaseLlmConfig,
|
||||
ChromaDbConfig)
|
||||
import yaml
|
||||
|
||||
from embedchain.config import AppConfig, BaseEmbedderConfig, BaseLlmConfig
|
||||
from embedchain.config.vectordb.base import BaseVectorDbConfig
|
||||
from embedchain.embedchain import EmbedChain
|
||||
from embedchain.embedder.base import BaseEmbedder
|
||||
from embedchain.embedder.openai import OpenAIEmbedder
|
||||
from embedchain.factory import EmbedderFactory, LlmFactory, VectorDBFactory
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
from embedchain.llm.base import BaseLlm
|
||||
from embedchain.llm.openai import OpenAILlm
|
||||
@@ -35,7 +36,6 @@ class App(EmbedChain):
|
||||
db_config: Optional[BaseVectorDbConfig] = None,
|
||||
embedder: BaseEmbedder = None,
|
||||
embedder_config: Optional[BaseEmbedderConfig] = None,
|
||||
chromadb_config: Optional[ChromaDbConfig] = None,
|
||||
system_prompt: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
@@ -60,20 +60,10 @@ class App(EmbedChain):
|
||||
:param embedder_config: Allows you to configure the Embedder.
|
||||
example: `from embedchain.config import BaseEmbedderConfig`, defaults to None
|
||||
:type embedder_config: Optional[BaseEmbedderConfig], optional
|
||||
:param chromadb_config: Deprecated alias of `db_config`, defaults to None
|
||||
:type chromadb_config: Optional[ChromaDbConfig], optional
|
||||
:param system_prompt: System prompt that will be provided to the LLM as such, defaults to None
|
||||
:type system_prompt: Optional[str], optional
|
||||
:raises TypeError: LLM, database or embedder or their config is not a valid class instance.
|
||||
"""
|
||||
# Overwrite deprecated arguments
|
||||
if chromadb_config:
|
||||
logging.warning(
|
||||
"DEPRECATION WARNING: Please use `db_config` argument instead of `chromadb_config`."
|
||||
"`chromadb_config` will be removed in a future release."
|
||||
)
|
||||
db_config = chromadb_config
|
||||
|
||||
# Type check configs
|
||||
if config and not isinstance(config, AppConfig):
|
||||
raise TypeError(
|
||||
@@ -123,3 +113,33 @@ class App(EmbedChain):
|
||||
"Please make sure the type is right and that you are passing an instance."
|
||||
)
|
||||
super().__init__(config, llm=llm, db=db, embedder=embedder, system_prompt=system_prompt)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, yaml_path: str):
|
||||
"""
|
||||
Instantiate an App object from a YAML configuration file.
|
||||
|
||||
:param yaml_path: Path to the YAML configuration file.
|
||||
:type yaml_path: str
|
||||
:return: An instance of the App class.
|
||||
:rtype: App
|
||||
"""
|
||||
with open(yaml_path, "r") as file:
|
||||
config_data = yaml.safe_load(file)
|
||||
|
||||
app_config_data = config_data.get("app", {})
|
||||
llm_config_data = config_data.get("llm", {})
|
||||
db_config_data = config_data.get("vectordb", {})
|
||||
embedder_config_data = config_data.get("embedder", {})
|
||||
|
||||
app_config = AppConfig(**app_config_data.get("config", {}))
|
||||
|
||||
llm_provider = llm_config_data.get("provider", "openai")
|
||||
llm = LlmFactory.create(llm_provider, llm_config_data.get("config", {}))
|
||||
|
||||
db_provider = db_config_data.get("provider", "chroma")
|
||||
db = VectorDBFactory.create(db_provider, db_config_data.get("config", {}))
|
||||
|
||||
embedder_provider = embedder_config_data.get("provider", "openai")
|
||||
embedder = EmbedderFactory.create(embedder_provider, embedder_config_data.get("config", {}))
|
||||
return cls(config=app_config, llm=llm, db=db, embedder=embedder)
|
||||
|
||||
@@ -3,7 +3,8 @@ from typing import Any
|
||||
from embedchain import App
|
||||
from embedchain.config import AddConfig, AppConfig, LlmConfig
|
||||
from embedchain.embedder.openai import OpenAIEmbedder
|
||||
from embedchain.helper.json_serializable import JSONSerializable, register_deserializable
|
||||
from embedchain.helper.json_serializable import (JSONSerializable,
|
||||
register_deserializable)
|
||||
from embedchain.llm.openai import OpenAILlm
|
||||
from embedchain.vectordb.chroma import ChromaDB
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ class BaseVectorDbConfig(BaseConfig):
|
||||
dir: str = "db",
|
||||
host: Optional[str] = None,
|
||||
port: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Initializes a configuration class instance for the vector database.
|
||||
@@ -22,8 +23,14 @@ class BaseVectorDbConfig(BaseConfig):
|
||||
:type host: Optional[str], optional
|
||||
:param host: Database connection remote port. Use this if you run Embedchain as a client, defaults to None
|
||||
:type port: Optional[str], optional
|
||||
:param kwargs: Additional keyword arguments
|
||||
:type kwargs: dict
|
||||
"""
|
||||
self.collection_name = collection_name or "embedchain_store"
|
||||
self.dir = dir
|
||||
self.host = host
|
||||
self.port = port
|
||||
# Assign additional keyword arguments
|
||||
if kwargs:
|
||||
for key, value in kwargs.items():
|
||||
setattr(self, key, value)
|
||||
|
||||
88
embedchain/factory.py
Normal file
88
embedchain/factory.py
Normal file
@@ -0,0 +1,88 @@
|
||||
import importlib
|
||||
|
||||
|
||||
def load_class(class_type):
|
||||
module_path, class_name = class_type.rsplit(".", 1)
|
||||
module = importlib.import_module(module_path)
|
||||
return getattr(module, class_name)
|
||||
|
||||
|
||||
class LlmFactory:
|
||||
provider_to_class = {
|
||||
"anthropic": "embedchain.llm.anthropic.AnthropicLlm",
|
||||
"azure_openai": "embedchain.llm.azure_openai.AzureOpenAILlm",
|
||||
"cohere": "embedchain.llm.cohere.CohereLlm",
|
||||
"gpt4all": "embedchain.llm.gpt4all.GPT4ALLLlm",
|
||||
"hugging_face_llm": "embedchain.llm.hugging_face_llm.HuggingFaceLlm",
|
||||
"jina": "embedchain.llm.jina.JinaLlm",
|
||||
"llama2": "embedchain.llm.llama2.Llama2Llm",
|
||||
"openai": "embedchain.llm.openai.OpenAILlm",
|
||||
"vertexai": "embedchain.llm.vertex_ai.VertexAILlm",
|
||||
}
|
||||
provider_to_config_class = {
|
||||
"embedchain": "embedchain.config.llm.base_llm_config.BaseLlmConfig",
|
||||
"openai": "embedchain.config.llm.base_llm_config.BaseLlmConfig",
|
||||
"anthropic": "embedchain.config.llm.base_llm_config.BaseLlmConfig",
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def create(cls, provider_name, config_data):
|
||||
class_type = cls.provider_to_class.get(provider_name)
|
||||
# Default to embedchain base config if the provider is not in the config map
|
||||
config_name = "embedchain" if provider_name not in cls.provider_to_config_class else provider_name
|
||||
config_class_type = cls.provider_to_config_class.get(config_name)
|
||||
if class_type:
|
||||
llm_class = load_class(class_type)
|
||||
llm_config_class = load_class(config_class_type)
|
||||
return llm_class(config=llm_config_class(**config_data))
|
||||
else:
|
||||
raise ValueError(f"Unsupported Llm provider: {provider_name}")
|
||||
|
||||
|
||||
class EmbedderFactory:
|
||||
provider_to_class = {
|
||||
"gpt4all": "embedchain.embedder.gpt4all.GPT4AllEmbedder",
|
||||
"huggingface": "embedchain.embedder.huggingface.HuggingFaceEmbedder",
|
||||
"vertexai": "embedchain.embedder.vertexai.VertexAiEmbedder",
|
||||
"openai": "embedchain.embedder.openai.OpenAIEmbedder",
|
||||
}
|
||||
provider_to_config_class = {
|
||||
"openai": "embedchain.config.embedder.base.BaseEmbedderConfig",
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def create(cls, provider_name, config_data):
|
||||
class_type = cls.provider_to_class.get(provider_name)
|
||||
# Default to openai config if the provider is not in the config map
|
||||
config_name = "openai" if provider_name not in cls.provider_to_config_class else provider_name
|
||||
config_class_type = cls.provider_to_config_class.get(config_name)
|
||||
if class_type:
|
||||
embedder_class = load_class(class_type)
|
||||
embedder_config_class = load_class(config_class_type)
|
||||
return embedder_class(config=embedder_config_class(**config_data))
|
||||
else:
|
||||
raise ValueError(f"Unsupported Embedder provider: {provider_name}")
|
||||
|
||||
|
||||
class VectorDBFactory:
|
||||
provider_to_class = {
|
||||
"chroma": "embedchain.vectordb.chroma.ChromaDB",
|
||||
"elasticsearch": "embedchain.vectordb.elasticsearch.ElasticsearchDB",
|
||||
"opensearch": "embedchain.vectordb.opensearch.OpenSearchDB",
|
||||
}
|
||||
provider_to_config_class = {
|
||||
"chroma": "embedchain.config.vectordb.chroma.ChromaDbConfig",
|
||||
"elasticsearch": "embedchain.config.vectordb.elasticsearch.ElasticsearchDBConfig",
|
||||
"opensearch": "embedchain.config.vectordb.opensearch.OpenSearchDBConfig",
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def create(cls, provider_name, config_data):
|
||||
class_type = cls.provider_to_class.get(provider_name)
|
||||
config_class_type = cls.provider_to_config_class.get(provider_name)
|
||||
if class_type:
|
||||
embedder_class = load_class(class_type)
|
||||
embedder_config_class = load_class(config_class_type)
|
||||
return embedder_class(config=embedder_config_class(**config_data))
|
||||
else:
|
||||
raise ValueError(f"Unsupported Embedder provider: {provider_name}")
|
||||
@@ -7,12 +7,12 @@ from embedchain.llm.base import BaseLlm
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class AntrophicLlm(BaseLlm):
|
||||
class AnthropicLlm(BaseLlm):
|
||||
def __init__(self, config: Optional[BaseLlmConfig] = None):
|
||||
super().__init__(config=config)
|
||||
|
||||
def get_llm_model_answer(self, prompt):
|
||||
return AntrophicLlm._get_answer(prompt=prompt, config=self.config)
|
||||
return AnthropicLlm._get_answer(prompt=prompt, config=self.config)
|
||||
|
||||
@staticmethod
|
||||
def _get_answer(prompt: str, config: BaseLlmConfig) -> str:
|
||||
@@ -34,7 +34,8 @@ class JinaLlm(BaseLlm):
|
||||
if config.top_p:
|
||||
kwargs["model_kwargs"]["top_p"] = config.top_p
|
||||
if config.stream:
|
||||
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
||||
from langchain.callbacks.streaming_stdout import \
|
||||
StreamingStdOutCallbackHandler
|
||||
|
||||
chat = JinaChat(**kwargs, streaming=config.stream, callbacks=[StreamingStdOutCallbackHandler()])
|
||||
else:
|
||||
|
||||
@@ -2,9 +2,7 @@ try:
|
||||
from PIL import Image, UnidentifiedImageError
|
||||
from sentence_transformers import SentenceTransformer
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Images requires extra dependencies. Install with `pip install 'embedchain[images]'"
|
||||
) from None
|
||||
raise ImportError("Images requires extra dependencies. Install with `pip install 'embedchain[images]'") from None
|
||||
|
||||
MODEL_NAME = "clip-ViT-B-32"
|
||||
|
||||
|
||||
@@ -37,7 +37,7 @@ class ChromaDB(BaseVectorDB):
|
||||
self.config = ChromaDbConfig()
|
||||
|
||||
self.settings = Settings()
|
||||
self.settings.allow_reset = self.config.allow_reset
|
||||
self.settings.allow_reset = self.config.allow_reset if hasattr(self.config, "allow_reset") else False
|
||||
if self.config.chroma_settings:
|
||||
for key, value in self.config.chroma_settings.items():
|
||||
if hasattr(self.settings, key):
|
||||
@@ -72,6 +72,17 @@ class ChromaDB(BaseVectorDB):
|
||||
"""Called during initialization"""
|
||||
return self.client
|
||||
|
||||
def _generate_where_clause(self, where: Dict[str, any]) -> str:
|
||||
# If only one filter is supplied, return it as is
|
||||
# (no need to wrap in $and based on chroma docs)
|
||||
if len(where.keys()) == 1:
|
||||
return where
|
||||
where_filters = []
|
||||
for k, v in where.items():
|
||||
if isinstance(v, str):
|
||||
where_filters.append({k: v})
|
||||
return {"$and": where_filters}
|
||||
|
||||
def _get_or_create_collection(self, name: str) -> Collection:
|
||||
"""
|
||||
Get or create a named collection.
|
||||
@@ -107,13 +118,14 @@ class ChromaDB(BaseVectorDB):
|
||||
if ids:
|
||||
args["ids"] = ids
|
||||
if where:
|
||||
args["where"] = where
|
||||
args["where"] = self._generate_where_clause(where)
|
||||
if limit:
|
||||
args["limit"] = limit
|
||||
return self.collection.get(**args)
|
||||
|
||||
def get_advanced(self, where):
|
||||
return self.collection.get(where=where, limit=1)
|
||||
where_clause = self._generate_where_clause(where)
|
||||
return self.collection.get(where=where_clause, limit=1)
|
||||
|
||||
def add(
|
||||
self,
|
||||
|
||||
@@ -110,8 +110,13 @@ class OpenSearchDB(BaseVectorDB):
|
||||
return result
|
||||
|
||||
def add(
|
||||
self, embeddings: List[List[str]], documents: List[str], metadatas: List[object], ids: List[str],
|
||||
skip_embedding: bool):
|
||||
self,
|
||||
embeddings: List[List[str]],
|
||||
documents: List[str],
|
||||
metadatas: List[object],
|
||||
ids: List[str],
|
||||
skip_embedding: bool,
|
||||
):
|
||||
"""add data in vector database
|
||||
|
||||
:param embeddings: list of embeddings to add
|
||||
@@ -162,7 +167,8 @@ class OpenSearchDB(BaseVectorDB):
|
||||
embedding_function=embeddings,
|
||||
opensearch_url=f"{self.config.opensearch_url}",
|
||||
http_auth=self.config.http_auth,
|
||||
use_ssl=True,
|
||||
use_ssl=hasattr(self.config, "use_ssl") and self.config.use_ssl,
|
||||
verify_certs=hasattr(self.config, "verify_certs") and self.config.verify_certs,
|
||||
)
|
||||
|
||||
pre_filter = {"match_all": {}} # default
|
||||
|
||||
@@ -5,8 +5,8 @@ from embedchain.helper.json_serializable import register_deserializable
|
||||
from embedchain.vectordb.base import BaseVectorDB
|
||||
|
||||
try:
|
||||
from pymilvus import MilvusClient
|
||||
from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility
|
||||
from pymilvus import (Collection, CollectionSchema, DataType, FieldSchema,
|
||||
MilvusClient, connections, utility)
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Zilliz requires extra dependencies. Install with `pip install --upgrade embedchain[milvus]`"
|
||||
|
||||
26
embedchain/yaml/chroma.yaml
Normal file
26
embedchain/yaml/chroma.yaml
Normal file
@@ -0,0 +1,26 @@
|
||||
app:
|
||||
config:
|
||||
id: 'my-app'
|
||||
collection_name: 'my-app'
|
||||
|
||||
llm:
|
||||
provider: openai
|
||||
model: 'gpt-3.5-turbo'
|
||||
config:
|
||||
temperature: 0.5
|
||||
max_tokens: 1000
|
||||
top_p: 1
|
||||
stream: false
|
||||
|
||||
vectordb:
|
||||
provider: chroma
|
||||
config:
|
||||
collection_name: 'my-app'
|
||||
dir: db
|
||||
allow_reset: true
|
||||
|
||||
embedder:
|
||||
provider: openai
|
||||
config:
|
||||
model: 'text-embedding-ada-002'
|
||||
deployment_name: null
|
||||
33
embedchain/yaml/opensearch.yaml
Normal file
33
embedchain/yaml/opensearch.yaml
Normal file
@@ -0,0 +1,33 @@
|
||||
app:
|
||||
config:
|
||||
id: 'my-app'
|
||||
log_level: 'WARN'
|
||||
collect_metrics: true
|
||||
collection_name: 'my-app'
|
||||
|
||||
llm:
|
||||
provider: openai
|
||||
model: 'gpt-3.5-turbo'
|
||||
config:
|
||||
temperature: 0.5
|
||||
max_tokens: 1000
|
||||
top_p: 1
|
||||
stream: false
|
||||
|
||||
vectordb:
|
||||
provider: opensearch
|
||||
config:
|
||||
opensearch_url: 'https://localhost:9200'
|
||||
http_auth:
|
||||
- admin
|
||||
- admin
|
||||
vector_dimension: 1536
|
||||
collection_name: 'my-app'
|
||||
use_ssl: false
|
||||
verify_certs: false
|
||||
|
||||
embedder:
|
||||
provider: openai
|
||||
config:
|
||||
model: 'text-embedding-ada-002'
|
||||
deployment_name: null
|
||||
27
embedchain/yaml/opensource.yaml
Normal file
27
embedchain/yaml/opensource.yaml
Normal file
@@ -0,0 +1,27 @@
|
||||
app:
|
||||
config:
|
||||
id: 'open-source-app'
|
||||
collection_name: 'open-source-app'
|
||||
collect_metrics: false
|
||||
|
||||
llm:
|
||||
provider: gpt4all
|
||||
model: 'orca-mini-3b.ggmlv3.q4_0.bin'
|
||||
config:
|
||||
temperature: 0.5
|
||||
max_tokens: 1000
|
||||
top_p: 1
|
||||
stream: false
|
||||
|
||||
vectordb:
|
||||
provider: chroma
|
||||
config:
|
||||
collection_name: 'open-source-app'
|
||||
dir: db
|
||||
allow_reset: true
|
||||
|
||||
embedder:
|
||||
provider: gpt4all
|
||||
config:
|
||||
model: 'all-MiniLM-L6-v2'
|
||||
deployment_name: null
|
||||
@@ -1,8 +1,11 @@
|
||||
import os
|
||||
import unittest
|
||||
|
||||
import yaml
|
||||
|
||||
from embedchain import App, CustomApp, Llama2App, OpenSourceApp
|
||||
from embedchain.config import ChromaDbConfig, AppConfig, AddConfig, BaseLlmConfig, BaseEmbedderConfig
|
||||
from embedchain.config import (AddConfig, AppConfig, BaseEmbedderConfig,
|
||||
BaseLlmConfig, ChromaDbConfig)
|
||||
from embedchain.embedder.base import BaseEmbedder
|
||||
from embedchain.llm.base import BaseLlm
|
||||
from embedchain.vectordb.base import BaseVectorDB, BaseVectorDbConfig
|
||||
@@ -100,3 +103,74 @@ class TestConfigForAppComponents(unittest.TestCase):
|
||||
App(embedder_config=wrong_embedder_config)
|
||||
|
||||
self.assertIsInstance(embedder_config, BaseEmbedderConfig)
|
||||
|
||||
|
||||
class TestAppFromConfig:
|
||||
def load_config_data(self, yaml_path):
|
||||
with open(yaml_path, "r") as file:
|
||||
return yaml.safe_load(file)
|
||||
|
||||
def test_from_chroma_config(self):
|
||||
yaml_path = "embedchain/yaml/chroma.yaml"
|
||||
config_data = self.load_config_data(yaml_path)
|
||||
|
||||
app = App.from_config(yaml_path)
|
||||
|
||||
# Check if the App instance and its components were created correctly
|
||||
assert isinstance(app, App)
|
||||
|
||||
# Validate the AppConfig values
|
||||
assert app.config.id == config_data["app"]["config"]["id"]
|
||||
assert app.config.collection_name == config_data["app"]["config"]["collection_name"]
|
||||
# Even though not present in the config, the default value is used
|
||||
assert app.config.collect_metrics is True
|
||||
|
||||
# Validate the LLM config values
|
||||
llm_config = config_data["llm"]["config"]
|
||||
assert app.llm.config.temperature == llm_config["temperature"]
|
||||
assert app.llm.config.max_tokens == llm_config["max_tokens"]
|
||||
assert app.llm.config.top_p == llm_config["top_p"]
|
||||
assert app.llm.config.stream == llm_config["stream"]
|
||||
|
||||
# Validate the VectorDB config values
|
||||
db_config = config_data["vectordb"]["config"]
|
||||
assert app.db.config.collection_name == db_config["collection_name"]
|
||||
assert app.db.config.dir == db_config["dir"]
|
||||
assert app.db.config.allow_reset == db_config["allow_reset"]
|
||||
|
||||
# Validate the Embedder config values
|
||||
embedder_config = config_data["embedder"]["config"]
|
||||
assert app.embedder.config.model == embedder_config["model"]
|
||||
assert app.embedder.config.deployment_name == embedder_config["deployment_name"]
|
||||
|
||||
def test_from_opensource_config(self):
|
||||
yaml_path = "embedchain/yaml/opensource.yaml"
|
||||
config_data = self.load_config_data(yaml_path)
|
||||
|
||||
app = App.from_config(yaml_path)
|
||||
|
||||
# Check if the App instance and its components were created correctly
|
||||
assert isinstance(app, App)
|
||||
|
||||
# Validate the AppConfig values
|
||||
assert app.config.id == config_data["app"]["config"]["id"]
|
||||
assert app.config.collection_name == config_data["app"]["config"]["collection_name"]
|
||||
assert app.config.collect_metrics == config_data["app"]["config"]["collect_metrics"]
|
||||
|
||||
# Validate the LLM config values
|
||||
llm_config = config_data["llm"]["config"]
|
||||
assert app.llm.config.temperature == llm_config["temperature"]
|
||||
assert app.llm.config.max_tokens == llm_config["max_tokens"]
|
||||
assert app.llm.config.top_p == llm_config["top_p"]
|
||||
assert app.llm.config.stream == llm_config["stream"]
|
||||
|
||||
# Validate the VectorDB config values
|
||||
db_config = config_data["vectordb"]["config"]
|
||||
assert app.db.config.collection_name == db_config["collection_name"]
|
||||
assert app.db.config.dir == db_config["dir"]
|
||||
assert app.db.config.allow_reset == db_config["allow_reset"]
|
||||
|
||||
# Validate the Embedder config values
|
||||
embedder_config = config_data["embedder"]["config"]
|
||||
assert app.embedder.config.model == embedder_config["model"]
|
||||
assert app.embedder.config.deployment_name == embedder_config["deployment_name"]
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
import os
|
||||
import pytest
|
||||
from embedchain.config import AddConfig, BaseLlmConfig
|
||||
from embedchain.bots.base import BaseBot
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from embedchain.bots.base import BaseBot
|
||||
from embedchain.config import AddConfig, BaseLlmConfig
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def base_bot():
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import hashlib
|
||||
import pytest
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from embedchain.chunkers.base_chunker import BaseChunker
|
||||
from embedchain.models.data_type import DataType
|
||||
|
||||
|
||||
@@ -41,7 +41,8 @@ class TestChromaDbHostsLoglevel(unittest.TestCase):
|
||||
Test if the `App` instance is correctly reconstructed after a reset.
|
||||
"""
|
||||
config = AppConfig(log_level="DEBUG", collect_metrics=False)
|
||||
app = App(config=config, chromadb_config=ChromaDbConfig(chroma_settings={"allow_reset": True}))
|
||||
chroma_config = {"allow_reset": True}
|
||||
app = App(config=config, db_config=ChromaDbConfig(**chroma_config))
|
||||
app.reset()
|
||||
|
||||
# Make sure the client is still healthy
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock
|
||||
from embedchain.embedder.base import BaseEmbedder
|
||||
from embedchain.config.embedder.base import BaseEmbedderConfig
|
||||
|
||||
import pytest
|
||||
from chromadb.api.types import Documents, Embeddings
|
||||
|
||||
from embedchain.config.embedder.base import BaseEmbedderConfig
|
||||
from embedchain.embedder.base import BaseEmbedder
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def base_embedder():
|
||||
|
||||
@@ -1,62 +1,63 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from embedchain.llm.antrophic import AntrophicLlm
|
||||
from embedchain.config import BaseLlmConfig
|
||||
import pytest
|
||||
from langchain.schema import HumanMessage, SystemMessage
|
||||
|
||||
from embedchain.config import BaseLlmConfig
|
||||
from embedchain.llm.anthropic import AnthropicLlm
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def antrophic_llm():
|
||||
def anthropic_llm():
|
||||
config = BaseLlmConfig(temperature=0.5, model="gpt2")
|
||||
return AntrophicLlm(config)
|
||||
return AnthropicLlm(config)
|
||||
|
||||
|
||||
def test_get_llm_model_answer(antrophic_llm):
|
||||
with patch.object(AntrophicLlm, "_get_answer", return_value="Test Response") as mock_method:
|
||||
def test_get_llm_model_answer(anthropic_llm):
|
||||
with patch.object(AnthropicLlm, "_get_answer", return_value="Test Response") as mock_method:
|
||||
prompt = "Test Prompt"
|
||||
response = antrophic_llm.get_llm_model_answer(prompt)
|
||||
response = anthropic_llm.get_llm_model_answer(prompt)
|
||||
assert response == "Test Response"
|
||||
mock_method.assert_called_once_with(prompt=prompt, config=antrophic_llm.config)
|
||||
mock_method.assert_called_once_with(prompt=prompt, config=anthropic_llm.config)
|
||||
|
||||
|
||||
def test_get_answer(antrophic_llm):
|
||||
def test_get_answer(anthropic_llm):
|
||||
with patch("langchain.chat_models.ChatAnthropic") as mock_chat:
|
||||
mock_chat_instance = mock_chat.return_value
|
||||
mock_chat_instance.return_value = MagicMock(content="Test Response")
|
||||
|
||||
prompt = "Test Prompt"
|
||||
response = antrophic_llm._get_answer(prompt, antrophic_llm.config)
|
||||
response = anthropic_llm._get_answer(prompt, anthropic_llm.config)
|
||||
|
||||
assert response == "Test Response"
|
||||
mock_chat.assert_called_once_with(
|
||||
temperature=antrophic_llm.config.temperature, model=antrophic_llm.config.model
|
||||
temperature=anthropic_llm.config.temperature, model=anthropic_llm.config.model
|
||||
)
|
||||
mock_chat_instance.assert_called_once_with(
|
||||
antrophic_llm._get_messages(prompt, system_prompt=antrophic_llm.config.system_prompt)
|
||||
anthropic_llm._get_messages(prompt, system_prompt=anthropic_llm.config.system_prompt)
|
||||
)
|
||||
|
||||
|
||||
def test_get_messages(antrophic_llm):
|
||||
def test_get_messages(anthropic_llm):
|
||||
prompt = "Test Prompt"
|
||||
system_prompt = "Test System Prompt"
|
||||
messages = antrophic_llm._get_messages(prompt, system_prompt)
|
||||
messages = anthropic_llm._get_messages(prompt, system_prompt)
|
||||
assert messages == [
|
||||
SystemMessage(content="Test System Prompt", additional_kwargs={}),
|
||||
HumanMessage(content="Test Prompt", additional_kwargs={}, example=False),
|
||||
]
|
||||
|
||||
|
||||
def test_get_answer_max_tokens_is_provided(antrophic_llm, caplog):
|
||||
def test_get_answer_max_tokens_is_provided(anthropic_llm, caplog):
|
||||
with patch("langchain.chat_models.ChatAnthropic") as mock_chat:
|
||||
mock_chat_instance = mock_chat.return_value
|
||||
mock_chat_instance.return_value = MagicMock(content="Test Response")
|
||||
|
||||
prompt = "Test Prompt"
|
||||
config = antrophic_llm.config
|
||||
config = anthropic_llm.config
|
||||
config.max_tokens = 500
|
||||
|
||||
response = antrophic_llm._get_answer(prompt, config)
|
||||
response = anthropic_llm._get_answer(prompt, config)
|
||||
|
||||
assert response == "Test Response"
|
||||
mock_chat.assert_called_once_with(temperature=config.temperature, model=config.model)
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from embedchain.llm.azure_openai import AzureOpenAILlm
|
||||
from embedchain.config import BaseLlmConfig
|
||||
|
||||
import pytest
|
||||
from langchain.schema import HumanMessage, SystemMessage
|
||||
|
||||
from embedchain.config import BaseLlmConfig
|
||||
from embedchain.llm.azure_openai import AzureOpenAILlm
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def azure_openai_llm():
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import importlib
|
||||
import os
|
||||
import unittest
|
||||
from unittest.mock import patch, MagicMock
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from embedchain.config import BaseLlmConfig
|
||||
from embedchain.llm.hugging_face_hub import HuggingFaceHubLlm
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from embedchain.llm.vertex_ai import VertexAiLlm
|
||||
from embedchain.config import BaseLlmConfig
|
||||
|
||||
import pytest
|
||||
from langchain.schema import HumanMessage, SystemMessage
|
||||
|
||||
from embedchain.config import BaseLlmConfig
|
||||
from embedchain.llm.vertex_ai import VertexAiLlm
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def vertexai_llm():
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import hashlib
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from requests import Response
|
||||
|
||||
from embedchain.loaders.docs_site_loader import DocsSiteLoader
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import hashlib
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from embedchain.loaders.docx_file import DocxFileLoader
|
||||
|
||||
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import hashlib
|
||||
|
||||
import pytest
|
||||
|
||||
from embedchain.loaders.local_qna_pair import LocalQnaPairLoader
|
||||
|
||||
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import hashlib
|
||||
|
||||
import pytest
|
||||
|
||||
from embedchain.loaders.local_text import LocalTextLoader
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import hashlib
|
||||
from unittest.mock import mock_open, patch
|
||||
|
||||
import pytest
|
||||
from unittest.mock import patch, mock_open
|
||||
|
||||
from embedchain.loaders.mdx import MdxLoader
|
||||
|
||||
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import hashlib
|
||||
import os
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from embedchain.loaders.notion import NotionLoader
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import hashlib
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from embedchain.loaders.web_page import WebPageLoader
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import hashlib
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from embedchain.loaders.youtube_video import YoutubeVideoLoader
|
||||
|
||||
|
||||
|
||||
61
tests/test_factory.py
Normal file
61
tests/test_factory.py
Normal file
@@ -0,0 +1,61 @@
|
||||
import pytest
|
||||
|
||||
import embedchain
|
||||
import embedchain.embedder.gpt4all
|
||||
import embedchain.embedder.huggingface
|
||||
import embedchain.embedder.openai
|
||||
import embedchain.embedder.vertexai
|
||||
import embedchain.llm.anthropic
|
||||
import embedchain.llm.openai
|
||||
import embedchain.vectordb.chroma
|
||||
import embedchain.vectordb.elasticsearch
|
||||
import embedchain.vectordb.opensearch
|
||||
from embedchain.factory import EmbedderFactory, LlmFactory, VectorDBFactory
|
||||
|
||||
|
||||
class TestFactories:
|
||||
@pytest.mark.parametrize(
|
||||
"provider_name, config_data, expected_class",
|
||||
[
|
||||
("openai", {}, embedchain.llm.openai.OpenAILlm),
|
||||
("anthropic", {}, embedchain.llm.anthropic.AnthropicLlm),
|
||||
],
|
||||
)
|
||||
def test_llm_factory_create(self, provider_name, config_data, expected_class):
|
||||
llm_instance = LlmFactory.create(provider_name, config_data)
|
||||
assert isinstance(llm_instance, expected_class)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"provider_name, config_data, expected_class",
|
||||
[
|
||||
("gpt4all", {}, embedchain.embedder.gpt4all.GPT4AllEmbedder),
|
||||
(
|
||||
"huggingface",
|
||||
{"model": "sentence-transformers/all-mpnet-base-v2"},
|
||||
embedchain.embedder.huggingface.HuggingFaceEmbedder,
|
||||
),
|
||||
("vertexai", {"model": "textembedding-gecko"}, embedchain.embedder.vertexai.VertexAiEmbedder),
|
||||
("openai", {}, embedchain.embedder.openai.OpenAIEmbedder),
|
||||
],
|
||||
)
|
||||
def test_embedder_factory_create(self, mocker, provider_name, config_data, expected_class):
|
||||
mocker.patch("embedchain.embedder.vertexai.VertexAiEmbedder", autospec=True)
|
||||
embedder_instance = EmbedderFactory.create(provider_name, config_data)
|
||||
assert isinstance(embedder_instance, expected_class)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"provider_name, config_data, expected_class",
|
||||
[
|
||||
("chroma", {}, embedchain.vectordb.chroma.ChromaDB),
|
||||
(
|
||||
"opensearch",
|
||||
{"opensearch_url": "http://localhost:9200", "http_auth": ("admin", "admin")},
|
||||
embedchain.vectordb.opensearch.OpenSearchDB,
|
||||
),
|
||||
("elasticsearch", {"es_url": "http://localhost:9200"}, embedchain.vectordb.elasticsearch.ElasticsearchDB),
|
||||
],
|
||||
)
|
||||
def test_vectordb_factory_create(self, mocker, provider_name, config_data, expected_class):
|
||||
mocker.patch("embedchain.vectordb.opensearch.OpenSearchDB", autospec=True)
|
||||
vectordb_instance = VectorDBFactory.create(provider_name, config_data)
|
||||
assert isinstance(vectordb_instance, expected_class)
|
||||
@@ -28,19 +28,25 @@ class TestChromaDbHosts(unittest.TestCase):
|
||||
host = "test-host"
|
||||
port = "1234"
|
||||
|
||||
chroma_auth_settings = {
|
||||
"chroma_client_auth_provider": "chromadb.auth.basic.BasicAuthClientProvider",
|
||||
"chroma_client_auth_credentials": "admin:admin",
|
||||
chroma_config = {
|
||||
"host": host,
|
||||
"port": port,
|
||||
"chroma_settings": {
|
||||
"chroma_client_auth_provider": "chromadb.auth.basic.BasicAuthClientProvider",
|
||||
"chroma_client_auth_credentials": "admin:admin",
|
||||
},
|
||||
}
|
||||
|
||||
config = ChromaDbConfig(host=host, port=port, chroma_settings=chroma_auth_settings)
|
||||
config = ChromaDbConfig(**chroma_config)
|
||||
db = ChromaDB(config=config)
|
||||
settings = db.client.get_settings()
|
||||
self.assertEqual(settings.chroma_server_host, host)
|
||||
self.assertEqual(settings.chroma_server_http_port, port)
|
||||
self.assertEqual(settings.chroma_client_auth_provider, chroma_auth_settings["chroma_client_auth_provider"])
|
||||
self.assertEqual(
|
||||
settings.chroma_client_auth_credentials, chroma_auth_settings["chroma_client_auth_credentials"]
|
||||
settings.chroma_client_auth_provider, chroma_config["chroma_settings"]["chroma_client_auth_provider"]
|
||||
)
|
||||
self.assertEqual(
|
||||
settings.chroma_client_auth_credentials, chroma_config["chroma_settings"]["chroma_client_auth_credentials"]
|
||||
)
|
||||
|
||||
|
||||
@@ -55,9 +61,9 @@ class TestChromaDbHostsInit(unittest.TestCase):
|
||||
port = "1234"
|
||||
|
||||
config = AppConfig(collect_metrics=False)
|
||||
chromadb_config = ChromaDbConfig(host=host, port=port)
|
||||
db_config = ChromaDbConfig(host=host, port=port)
|
||||
|
||||
_app = App(config, chromadb_config=chromadb_config)
|
||||
_app = App(config, db_config=db_config)
|
||||
|
||||
called_settings: Settings = mock_client.call_args[0][0]
|
||||
|
||||
@@ -95,7 +101,7 @@ class TestChromaDbHostsLoglevel(unittest.TestCase):
|
||||
class TestChromaDbDuplicateHandling:
|
||||
chroma_config = ChromaDbConfig(allow_reset=True)
|
||||
app_config = AppConfig(collection_name=False, collect_metrics=False)
|
||||
app_with_settings = App(config=app_config, chromadb_config=chroma_config)
|
||||
app_with_settings = App(config=app_config, db_config=chroma_config)
|
||||
|
||||
def test_duplicates_throw_warning(self, caplog):
|
||||
"""
|
||||
@@ -131,7 +137,7 @@ class TestChromaDbDuplicateHandling:
|
||||
class TestChromaDbCollection(unittest.TestCase):
|
||||
chroma_config = ChromaDbConfig(allow_reset=True)
|
||||
app_config = AppConfig(collection_name=False, collect_metrics=False)
|
||||
app_with_settings = App(config=app_config, chromadb_config=chroma_config)
|
||||
app_with_settings = App(config=app_config, db_config=chroma_config)
|
||||
|
||||
def test_init_with_default_collection(self):
|
||||
"""
|
||||
@@ -296,7 +302,7 @@ class TestChromaDbCollection(unittest.TestCase):
|
||||
|
||||
# Create four apps.
|
||||
# app1, which we are about to reset, shares an app with one, and an id with the other, none with the last.
|
||||
app1 = App(AppConfig(id="new_app_id_1", collect_metrics=False), chromadb_config=self.chroma_config)
|
||||
app1 = App(AppConfig(id="new_app_id_1", collect_metrics=False), db_config=self.chroma_config)
|
||||
app1.set_collection_name("one_collection")
|
||||
app2 = App(AppConfig(id="new_app_id_2", collect_metrics=False))
|
||||
app2.set_collection_name("one_collection")
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
# ruff: noqa: E501
|
||||
|
||||
import os
|
||||
from unittest import mock
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from unittest import mock
|
||||
from unittest.mock import patch, Mock
|
||||
|
||||
from embedchain.config import ZillizDBConfig
|
||||
from embedchain.vectordb.zilliz import ZillizVectorDB
|
||||
|
||||
|
||||
# to run tests, provide the URI and TOKEN in .env file
|
||||
class TestZillizVectorDBConfig:
|
||||
@mock.patch.dict(os.environ, {"ZILLIZ_CLOUD_URI": "mocked_uri", "ZILLIZ_CLOUD_TOKEN": "mocked_token"})
|
||||
@@ -51,6 +52,7 @@ class TestZillizVectorDBConfig:
|
||||
with pytest.raises(AttributeError):
|
||||
ZillizDBConfig()
|
||||
|
||||
|
||||
class TestZillizVectorDB:
|
||||
@pytest.fixture
|
||||
@mock.patch.dict(os.environ, {"ZILLIZ_CLOUD_URI": "mocked_uri", "ZILLIZ_CLOUD_TOKEN": "mocked_token"})
|
||||
@@ -147,7 +149,7 @@ class TestZillizDBCollection:
|
||||
zilliz_db = ZillizVectorDB(config=mock_config)
|
||||
|
||||
# Add a 'embedder' attribute to the ZillizVectorDB instance for testing
|
||||
zilliz_db.embedder = mock_embedder # Mock the 'collection' object
|
||||
zilliz_db.embedder = mock_embedder # Mock the 'collection' object
|
||||
|
||||
# Add a 'collection' attribute to the ZillizVectorDB instance for testing
|
||||
zilliz_db.collection = Mock(is_empty=False) # Mock the 'collection' object
|
||||
|
||||
Reference in New Issue
Block a user