[Feature]: Add support for creating app using yaml config (#787)
This commit is contained in:
@@ -7,5 +7,5 @@ from embedchain.apps.custom_app import CustomApp # noqa: F401
|
||||
from embedchain.apps.Llama2App import Llama2App # noqa: F401
|
||||
from embedchain.apps.open_source_app import OpenSourceApp # noqa: F401
|
||||
from embedchain.apps.person_app import (PersonApp, # noqa: F401
|
||||
PersonOpenSourceApp)
|
||||
PersonOpenSourceApp)
|
||||
from embedchain.vectordb.chroma import ChromaDB # noqa: F401
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from embedchain.config import (AppConfig, BaseEmbedderConfig, BaseLlmConfig,
|
||||
ChromaDbConfig)
|
||||
import yaml
|
||||
|
||||
from embedchain.config import AppConfig, BaseEmbedderConfig, BaseLlmConfig
|
||||
from embedchain.config.vectordb.base import BaseVectorDbConfig
|
||||
from embedchain.embedchain import EmbedChain
|
||||
from embedchain.embedder.base import BaseEmbedder
|
||||
from embedchain.embedder.openai import OpenAIEmbedder
|
||||
from embedchain.factory import EmbedderFactory, LlmFactory, VectorDBFactory
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
from embedchain.llm.base import BaseLlm
|
||||
from embedchain.llm.openai import OpenAILlm
|
||||
@@ -35,7 +36,6 @@ class App(EmbedChain):
|
||||
db_config: Optional[BaseVectorDbConfig] = None,
|
||||
embedder: BaseEmbedder = None,
|
||||
embedder_config: Optional[BaseEmbedderConfig] = None,
|
||||
chromadb_config: Optional[ChromaDbConfig] = None,
|
||||
system_prompt: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
@@ -60,20 +60,10 @@ class App(EmbedChain):
|
||||
:param embedder_config: Allows you to configure the Embedder.
|
||||
example: `from embedchain.config import BaseEmbedderConfig`, defaults to None
|
||||
:type embedder_config: Optional[BaseEmbedderConfig], optional
|
||||
:param chromadb_config: Deprecated alias of `db_config`, defaults to None
|
||||
:type chromadb_config: Optional[ChromaDbConfig], optional
|
||||
:param system_prompt: System prompt that will be provided to the LLM as such, defaults to None
|
||||
:type system_prompt: Optional[str], optional
|
||||
:raises TypeError: LLM, database or embedder or their config is not a valid class instance.
|
||||
"""
|
||||
# Overwrite deprecated arguments
|
||||
if chromadb_config:
|
||||
logging.warning(
|
||||
"DEPRECATION WARNING: Please use `db_config` argument instead of `chromadb_config`."
|
||||
"`chromadb_config` will be removed in a future release."
|
||||
)
|
||||
db_config = chromadb_config
|
||||
|
||||
# Type check configs
|
||||
if config and not isinstance(config, AppConfig):
|
||||
raise TypeError(
|
||||
@@ -123,3 +113,33 @@ class App(EmbedChain):
|
||||
"Please make sure the type is right and that you are passing an instance."
|
||||
)
|
||||
super().__init__(config, llm=llm, db=db, embedder=embedder, system_prompt=system_prompt)
|
||||
|
||||
@classmethod
def from_config(cls, yaml_path: str):
    """
    Instantiate an App object from a YAML configuration file.

    The file may contain the top-level sections `app`, `llm`, `vectordb` and
    `embedder`. Each section is optional; from `llm`, `vectordb` and `embedder`
    only the keys `provider` and `config` are read, and from `app` only `config`.

    :param yaml_path: Path to the YAML configuration file.
    :type yaml_path: str
    :raises ValueError: The YAML document is not a mapping, or a provider is not supported.
    :return: An instance of the App class.
    :rtype: App
    """
    # YAML is UTF-8 by convention; do not depend on the platform default encoding.
    with open(yaml_path, "r", encoding="utf-8") as file:
        # `safe_load` returns None for an empty document; treat that as "all defaults".
        config_data = yaml.safe_load(file) or {}

    if not isinstance(config_data, dict):
        raise ValueError(f"Expected a mapping at the top level of {yaml_path}, got {type(config_data).__name__}")

    # A key that is present but has no body (e.g. a bare `app:`) parses to None,
    # so `.get(name, {})` alone is not enough; `or {}` covers both cases.
    app_config_data = config_data.get("app") or {}
    llm_config_data = config_data.get("llm") or {}
    db_config_data = config_data.get("vectordb") or {}
    embedder_config_data = config_data.get("embedder") or {}

    app_config = AppConfig(**(app_config_data.get("config") or {}))

    # NOTE(review): keys placed beside `provider` (for example a `model` entry next to
    # `llm.provider`) are not read here; only the nested `config` mapping is forwarded.
    llm_provider = llm_config_data.get("provider", "openai")
    llm = LlmFactory.create(llm_provider, llm_config_data.get("config") or {})

    db_provider = db_config_data.get("provider", "chroma")
    db = VectorDBFactory.create(db_provider, db_config_data.get("config") or {})

    embedder_provider = embedder_config_data.get("provider", "openai")
    embedder = EmbedderFactory.create(embedder_provider, embedder_config_data.get("config") or {})

    return cls(config=app_config, llm=llm, db=db, embedder=embedder)
|
||||
|
||||
@@ -3,7 +3,8 @@ from typing import Any
|
||||
from embedchain import App
|
||||
from embedchain.config import AddConfig, AppConfig, LlmConfig
|
||||
from embedchain.embedder.openai import OpenAIEmbedder
|
||||
from embedchain.helper.json_serializable import JSONSerializable, register_deserializable
|
||||
from embedchain.helper.json_serializable import (JSONSerializable,
|
||||
register_deserializable)
|
||||
from embedchain.llm.openai import OpenAILlm
|
||||
from embedchain.vectordb.chroma import ChromaDB
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ class BaseVectorDbConfig(BaseConfig):
|
||||
dir: str = "db",
|
||||
host: Optional[str] = None,
|
||||
port: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Initializes a configuration class instance for the vector database.
|
||||
@@ -22,8 +23,14 @@ class BaseVectorDbConfig(BaseConfig):
|
||||
:type host: Optional[str], optional
|
||||
:param port: Database connection remote port. Use this if you run Embedchain as a client, defaults to None
|
||||
:type port: Optional[str], optional
|
||||
:param kwargs: Additional keyword arguments
|
||||
:type kwargs: dict
|
||||
"""
|
||||
self.collection_name = collection_name or "embedchain_store"
|
||||
self.dir = dir
|
||||
self.host = host
|
||||
self.port = port
|
||||
# Assign additional keyword arguments
|
||||
if kwargs:
|
||||
for key, value in kwargs.items():
|
||||
setattr(self, key, value)
|
||||
|
||||
88
embedchain/factory.py
Normal file
88
embedchain/factory.py
Normal file
@@ -0,0 +1,88 @@
|
||||
import importlib
|
||||
|
||||
|
||||
def load_class(class_type):
    """
    Resolve a dotted path such as `package.module.ClassName` to the object it names.

    :param class_type: Dotted path whose last component is an attribute of the module
        named by everything before the final dot.
    :type class_type: str
    :return: The attribute looked up on the imported module.
    """
    # Split once from the right: the module path itself may contain dots.
    owner_path, attribute_name = class_type.rsplit(".", 1)
    return getattr(importlib.import_module(owner_path), attribute_name)
|
||||
|
||||
|
||||
class LlmFactory:
    """Builds LLM instances from a provider name and a plain config mapping."""

    # Provider name -> dotted path of the LLM class.
    provider_to_class = {
        "anthropic": "embedchain.llm.anthropic.AnthropicLlm",
        "azure_openai": "embedchain.llm.azure_openai.AzureOpenAILlm",
        "cohere": "embedchain.llm.cohere.CohereLlm",
        "gpt4all": "embedchain.llm.gpt4all.GPT4ALLLlm",
        "hugging_face_llm": "embedchain.llm.hugging_face_llm.HuggingFaceLlm",
        "jina": "embedchain.llm.jina.JinaLlm",
        "llama2": "embedchain.llm.llama2.Llama2Llm",
        "openai": "embedchain.llm.openai.OpenAILlm",
        "vertexai": "embedchain.llm.vertex_ai.VertexAILlm",
    }
    # Provider name -> dotted path of the config class; "embedchain" is the fallback entry.
    provider_to_config_class = {
        "embedchain": "embedchain.config.llm.base_llm_config.BaseLlmConfig",
        "openai": "embedchain.config.llm.base_llm_config.BaseLlmConfig",
        "anthropic": "embedchain.config.llm.base_llm_config.BaseLlmConfig",
    }

    @classmethod
    def create(cls, provider_name, config_data):
        """
        Instantiate the LLM registered for `provider_name`.

        :param provider_name: Key into `provider_to_class`.
        :param config_data: Keyword arguments forwarded to the config class constructor.
        :raises ValueError: `provider_name` is not a registered provider.
        :return: The LLM instance, constructed with `config=<config instance>`.
        """
        llm_path = cls.provider_to_class.get(provider_name)
        # Providers without a dedicated config entry share the "embedchain" base config.
        if provider_name in cls.provider_to_config_class:
            config_key = provider_name
        else:
            config_key = "embedchain"
        config_path = cls.provider_to_config_class.get(config_key)

        if not llm_path:
            raise ValueError(f"Unsupported Llm provider: {provider_name}")

        llm_class = load_class(llm_path)
        llm_config_class = load_class(config_path)
        return llm_class(config=llm_config_class(**config_data))
|
||||
|
||||
|
||||
class EmbedderFactory:
    """Builds embedder instances from a provider name and a plain config mapping."""

    # Provider name -> dotted path of the embedder class.
    provider_to_class = {
        "gpt4all": "embedchain.embedder.gpt4all.GPT4AllEmbedder",
        "huggingface": "embedchain.embedder.huggingface.HuggingFaceEmbedder",
        "vertexai": "embedchain.embedder.vertexai.VertexAiEmbedder",
        "openai": "embedchain.embedder.openai.OpenAIEmbedder",
    }
    # Provider name -> dotted path of the config class; "openai" doubles as the fallback entry.
    provider_to_config_class = {
        "openai": "embedchain.config.embedder.base.BaseEmbedderConfig",
    }

    @classmethod
    def create(cls, provider_name, config_data):
        """
        Instantiate the embedder registered for `provider_name`.

        :param provider_name: Key into `provider_to_class`.
        :param config_data: Keyword arguments forwarded to the config class constructor.
        :raises ValueError: `provider_name` is not a registered provider.
        :return: The embedder instance, constructed with `config=<config instance>`.
        """
        embedder_path = cls.provider_to_class.get(provider_name)
        # Providers without their own config entry reuse the one registered for "openai".
        if provider_name in cls.provider_to_config_class:
            config_key = provider_name
        else:
            config_key = "openai"
        config_path = cls.provider_to_config_class.get(config_key)

        if not embedder_path:
            raise ValueError(f"Unsupported Embedder provider: {provider_name}")

        embedder_class = load_class(embedder_path)
        embedder_config_class = load_class(config_path)
        return embedder_class(config=embedder_config_class(**config_data))
|
||||
|
||||
|
||||
class VectorDBFactory:
    """Builds vector database instances from a provider name and a plain config mapping."""

    # Provider name -> dotted path of the vector database class.
    provider_to_class = {
        "chroma": "embedchain.vectordb.chroma.ChromaDB",
        "elasticsearch": "embedchain.vectordb.elasticsearch.ElasticsearchDB",
        "opensearch": "embedchain.vectordb.opensearch.OpenSearchDB",
    }
    # Provider name -> dotted path of the matching config class (same keys as above).
    provider_to_config_class = {
        "chroma": "embedchain.config.vectordb.chroma.ChromaDbConfig",
        "elasticsearch": "embedchain.config.vectordb.elasticsearch.ElasticsearchDBConfig",
        "opensearch": "embedchain.config.vectordb.opensearch.OpenSearchDBConfig",
    }

    @classmethod
    def create(cls, provider_name, config_data):
        """
        Instantiate the vector database registered for `provider_name`.

        :param provider_name: Key into `provider_to_class`.
        :param config_data: Keyword arguments forwarded to the config class constructor.
        :raises ValueError: `provider_name` is not a registered provider.
        :return: The vector database instance, constructed with `config=<config instance>`.
        """
        class_type = cls.provider_to_class.get(provider_name)
        config_class_type = cls.provider_to_config_class.get(provider_name)
        if class_type:
            db_class = load_class(class_type)
            db_config_class = load_class(config_class_type)
            return db_class(config=db_config_class(**config_data))
        else:
            # The message names the vector DB; it previously said "Embedder",
            # copied from EmbedderFactory.
            raise ValueError(f"Unsupported VectorDB provider: {provider_name}")
|
||||
@@ -7,12 +7,12 @@ from embedchain.llm.base import BaseLlm
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class AntrophicLlm(BaseLlm):
|
||||
class AnthropicLlm(BaseLlm):
|
||||
def __init__(self, config: Optional[BaseLlmConfig] = None):
|
||||
super().__init__(config=config)
|
||||
|
||||
def get_llm_model_answer(self, prompt):
|
||||
return AntrophicLlm._get_answer(prompt=prompt, config=self.config)
|
||||
return AnthropicLlm._get_answer(prompt=prompt, config=self.config)
|
||||
|
||||
@staticmethod
|
||||
def _get_answer(prompt: str, config: BaseLlmConfig) -> str:
|
||||
@@ -34,7 +34,8 @@ class JinaLlm(BaseLlm):
|
||||
if config.top_p:
|
||||
kwargs["model_kwargs"]["top_p"] = config.top_p
|
||||
if config.stream:
|
||||
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
||||
from langchain.callbacks.streaming_stdout import \
|
||||
StreamingStdOutCallbackHandler
|
||||
|
||||
chat = JinaChat(**kwargs, streaming=config.stream, callbacks=[StreamingStdOutCallbackHandler()])
|
||||
else:
|
||||
|
||||
@@ -2,9 +2,7 @@ try:
|
||||
from PIL import Image, UnidentifiedImageError
|
||||
from sentence_transformers import SentenceTransformer
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Images requires extra dependencies. Install with `pip install 'embedchain[images]'"
|
||||
) from None
|
||||
raise ImportError("Images requires extra dependencies. Install with `pip install 'embedchain[images]'") from None
|
||||
|
||||
MODEL_NAME = "clip-ViT-B-32"
|
||||
|
||||
|
||||
@@ -37,7 +37,7 @@ class ChromaDB(BaseVectorDB):
|
||||
self.config = ChromaDbConfig()
|
||||
|
||||
self.settings = Settings()
|
||||
self.settings.allow_reset = self.config.allow_reset
|
||||
self.settings.allow_reset = self.config.allow_reset if hasattr(self.config, "allow_reset") else False
|
||||
if self.config.chroma_settings:
|
||||
for key, value in self.config.chroma_settings.items():
|
||||
if hasattr(self.settings, key):
|
||||
@@ -72,6 +72,17 @@ class ChromaDB(BaseVectorDB):
|
||||
"""Called during initialization"""
|
||||
return self.client
|
||||
|
||||
def _generate_where_clause(self, where: Dict[str, any]) -> str:
|
||||
# If only one filter is supplied, return it as is
|
||||
# (no need to wrap in $and based on chroma docs)
|
||||
if len(where.keys()) == 1:
|
||||
return where
|
||||
where_filters = []
|
||||
for k, v in where.items():
|
||||
if isinstance(v, str):
|
||||
where_filters.append({k: v})
|
||||
return {"$and": where_filters}
|
||||
|
||||
def _get_or_create_collection(self, name: str) -> Collection:
|
||||
"""
|
||||
Get or create a named collection.
|
||||
@@ -107,13 +118,14 @@ class ChromaDB(BaseVectorDB):
|
||||
if ids:
|
||||
args["ids"] = ids
|
||||
if where:
|
||||
args["where"] = where
|
||||
args["where"] = self._generate_where_clause(where)
|
||||
if limit:
|
||||
args["limit"] = limit
|
||||
return self.collection.get(**args)
|
||||
|
||||
def get_advanced(self, where):
|
||||
return self.collection.get(where=where, limit=1)
|
||||
where_clause = self._generate_where_clause(where)
|
||||
return self.collection.get(where=where_clause, limit=1)
|
||||
|
||||
def add(
|
||||
self,
|
||||
|
||||
@@ -110,8 +110,13 @@ class OpenSearchDB(BaseVectorDB):
|
||||
return result
|
||||
|
||||
def add(
|
||||
self, embeddings: List[List[str]], documents: List[str], metadatas: List[object], ids: List[str],
|
||||
skip_embedding: bool):
|
||||
self,
|
||||
embeddings: List[List[str]],
|
||||
documents: List[str],
|
||||
metadatas: List[object],
|
||||
ids: List[str],
|
||||
skip_embedding: bool,
|
||||
):
|
||||
"""add data in vector database
|
||||
|
||||
:param embeddings: list of embeddings to add
|
||||
@@ -162,7 +167,8 @@ class OpenSearchDB(BaseVectorDB):
|
||||
embedding_function=embeddings,
|
||||
opensearch_url=f"{self.config.opensearch_url}",
|
||||
http_auth=self.config.http_auth,
|
||||
use_ssl=True,
|
||||
use_ssl=hasattr(self.config, "use_ssl") and self.config.use_ssl,
|
||||
verify_certs=hasattr(self.config, "verify_certs") and self.config.verify_certs,
|
||||
)
|
||||
|
||||
pre_filter = {"match_all": {}} # default
|
||||
|
||||
@@ -5,8 +5,8 @@ from embedchain.helper.json_serializable import register_deserializable
|
||||
from embedchain.vectordb.base import BaseVectorDB
|
||||
|
||||
try:
|
||||
from pymilvus import MilvusClient
|
||||
from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility
|
||||
from pymilvus import (Collection, CollectionSchema, DataType, FieldSchema,
|
||||
MilvusClient, connections, utility)
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Zilliz requires extra dependencies. Install with `pip install --upgrade embedchain[milvus]`"
|
||||
|
||||
26
embedchain/yaml/chroma.yaml
Normal file
26
embedchain/yaml/chroma.yaml
Normal file
@@ -0,0 +1,26 @@
|
||||
app:
|
||||
config:
|
||||
id: 'my-app'
|
||||
collection_name: 'my-app'
|
||||
|
||||
llm:
|
||||
provider: openai
|
||||
model: 'gpt-3.5-turbo'
|
||||
config:
|
||||
temperature: 0.5
|
||||
max_tokens: 1000
|
||||
top_p: 1
|
||||
stream: false
|
||||
|
||||
vectordb:
|
||||
provider: chroma
|
||||
config:
|
||||
collection_name: 'my-app'
|
||||
dir: db
|
||||
allow_reset: true
|
||||
|
||||
embedder:
|
||||
provider: openai
|
||||
config:
|
||||
model: 'text-embedding-ada-002'
|
||||
deployment_name: null
|
||||
33
embedchain/yaml/opensearch.yaml
Normal file
33
embedchain/yaml/opensearch.yaml
Normal file
@@ -0,0 +1,33 @@
|
||||
app:
|
||||
config:
|
||||
id: 'my-app'
|
||||
log_level: 'WARN'
|
||||
collect_metrics: true
|
||||
collection_name: 'my-app'
|
||||
|
||||
llm:
|
||||
provider: openai
|
||||
model: 'gpt-3.5-turbo'
|
||||
config:
|
||||
temperature: 0.5
|
||||
max_tokens: 1000
|
||||
top_p: 1
|
||||
stream: false
|
||||
|
||||
vectordb:
|
||||
provider: opensearch
|
||||
config:
|
||||
opensearch_url: 'https://localhost:9200'
|
||||
http_auth:
|
||||
- admin
|
||||
- admin
|
||||
vector_dimension: 1536
|
||||
collection_name: 'my-app'
|
||||
use_ssl: false
|
||||
verify_certs: false
|
||||
|
||||
embedder:
|
||||
provider: openai
|
||||
config:
|
||||
model: 'text-embedding-ada-002'
|
||||
deployment_name: null
|
||||
27
embedchain/yaml/opensource.yaml
Normal file
27
embedchain/yaml/opensource.yaml
Normal file
@@ -0,0 +1,27 @@
|
||||
app:
|
||||
config:
|
||||
id: 'open-source-app'
|
||||
collection_name: 'open-source-app'
|
||||
collect_metrics: false
|
||||
|
||||
llm:
|
||||
provider: gpt4all
|
||||
model: 'orca-mini-3b.ggmlv3.q4_0.bin'
|
||||
config:
|
||||
temperature: 0.5
|
||||
max_tokens: 1000
|
||||
top_p: 1
|
||||
stream: false
|
||||
|
||||
vectordb:
|
||||
provider: chroma
|
||||
config:
|
||||
collection_name: 'open-source-app'
|
||||
dir: db
|
||||
allow_reset: true
|
||||
|
||||
embedder:
|
||||
provider: gpt4all
|
||||
config:
|
||||
model: 'all-MiniLM-L6-v2'
|
||||
deployment_name: null
|
||||
Reference in New Issue
Block a user