[Feature]: Add support for creating app using yaml config (#787)

This commit is contained in:
Deshraj Yadav
2023-10-12 15:35:49 -07:00
committed by GitHub
parent 4820ea15d6
commit a86d7f52e9
36 changed files with 479 additions and 95 deletions

View File

@@ -7,5 +7,5 @@ from embedchain.apps.custom_app import CustomApp # noqa: F401
from embedchain.apps.Llama2App import Llama2App # noqa: F401
from embedchain.apps.open_source_app import OpenSourceApp # noqa: F401
from embedchain.apps.person_app import (PersonApp, # noqa: F401
PersonOpenSourceApp)
PersonOpenSourceApp)
from embedchain.vectordb.chroma import ChromaDB # noqa: F401

View File

@@ -1,12 +1,13 @@
import logging
from typing import Optional
from embedchain.config import (AppConfig, BaseEmbedderConfig, BaseLlmConfig,
ChromaDbConfig)
import yaml
from embedchain.config import AppConfig, BaseEmbedderConfig, BaseLlmConfig
from embedchain.config.vectordb.base import BaseVectorDbConfig
from embedchain.embedchain import EmbedChain
from embedchain.embedder.base import BaseEmbedder
from embedchain.embedder.openai import OpenAIEmbedder
from embedchain.factory import EmbedderFactory, LlmFactory, VectorDBFactory
from embedchain.helper.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm
from embedchain.llm.openai import OpenAILlm
@@ -35,7 +36,6 @@ class App(EmbedChain):
db_config: Optional[BaseVectorDbConfig] = None,
embedder: BaseEmbedder = None,
embedder_config: Optional[BaseEmbedderConfig] = None,
chromadb_config: Optional[ChromaDbConfig] = None,
system_prompt: Optional[str] = None,
):
"""
@@ -60,20 +60,10 @@ class App(EmbedChain):
:param embedder_config: Allows you to configure the Embedder.
example: `from embedchain.config import BaseEmbedderConfig`, defaults to None
:type embedder_config: Optional[BaseEmbedderConfig], optional
:param chromadb_config: Deprecated alias of `db_config`, defaults to None
:type chromadb_config: Optional[ChromaDbConfig], optional
:param system_prompt: System prompt that will be provided to the LLM as such, defaults to None
:type system_prompt: Optional[str], optional
:raises TypeError: LLM, database or embedder or their config is not a valid class instance.
"""
# Overwrite deprecated arguments
if chromadb_config:
logging.warning(
"DEPRECATION WARNING: Please use `db_config` argument instead of `chromadb_config`."
"`chromadb_config` will be removed in a future release."
)
db_config = chromadb_config
# Type check configs
if config and not isinstance(config, AppConfig):
raise TypeError(
@@ -123,3 +113,33 @@ class App(EmbedChain):
"Please make sure the type is right and that you are passing an instance."
)
super().__init__(config, llm=llm, db=db, embedder=embedder, system_prompt=system_prompt)
@classmethod
def from_config(cls, yaml_path: str):
    """
    Instantiate an App object from a YAML configuration file.

    The file may define four top-level sections: `app`, `llm`, `vectordb` and
    `embedder`. From `app` only the `config` mapping is used; from the other
    three sections only `provider` and `config` are read. Missing, empty or
    null sections fall back to the defaults (`openai` LLM, `chroma` vector
    database, `openai` embedder, empty configs).

    :param yaml_path: Path to the YAML configuration file.
    :type yaml_path: str
    :raises ValueError: If a provider name is not supported by its factory.
    :return: An instance of the App class.
    :rtype: App
    """
    with open(yaml_path, "r") as file:
        config_data = yaml.safe_load(file)

    # An empty YAML document is parsed as None; treat it as "no overrides"
    # instead of failing on `None.get(...)`.
    if config_data is None:
        config_data = {}

    def _section(name):
        # A key written without a value (e.g. a bare `llm:`) is loaded as None.
        return config_data.get(name) or {}

    def _config_of(section):
        # Same for a bare `config:` inside a section.
        return section.get("config") or {}

    app_config_data = _section("app")
    llm_config_data = _section("llm")
    db_config_data = _section("vectordb")
    embedder_config_data = _section("embedder")

    app_config = AppConfig(**_config_of(app_config_data))

    llm_provider = llm_config_data.get("provider", "openai")
    llm = LlmFactory.create(llm_provider, _config_of(llm_config_data))

    db_provider = db_config_data.get("provider", "chroma")
    db = VectorDBFactory.create(db_provider, _config_of(db_config_data))

    embedder_provider = embedder_config_data.get("provider", "openai")
    embedder = EmbedderFactory.create(embedder_provider, _config_of(embedder_config_data))

    return cls(config=app_config, llm=llm, db=db, embedder=embedder)

View File

@@ -3,7 +3,8 @@ from typing import Any
from embedchain import App
from embedchain.config import AddConfig, AppConfig, LlmConfig
from embedchain.embedder.openai import OpenAIEmbedder
from embedchain.helper.json_serializable import JSONSerializable, register_deserializable
from embedchain.helper.json_serializable import (JSONSerializable,
register_deserializable)
from embedchain.llm.openai import OpenAILlm
from embedchain.vectordb.chroma import ChromaDB

View File

@@ -10,6 +10,7 @@ class BaseVectorDbConfig(BaseConfig):
dir: str = "db",
host: Optional[str] = None,
port: Optional[str] = None,
**kwargs,
):
"""
Initializes a configuration class instance for the vector database.
@@ -22,8 +23,14 @@ class BaseVectorDbConfig(BaseConfig):
:type host: Optional[str], optional
:param port: Database connection remote port. Use this if you run Embedchain as a client, defaults to None
:type port: Optional[str], optional
:param kwargs: Additional keyword arguments
:type kwargs: dict
"""
self.collection_name = collection_name or "embedchain_store"
self.dir = dir
self.host = host
self.port = port
# Assign additional keyword arguments
if kwargs:
for key, value in kwargs.items():
setattr(self, key, value)

88
embedchain/factory.py Normal file
View File

@@ -0,0 +1,88 @@
import importlib
def load_class(class_type):
    """
    Resolve a dotted path such as `package.module.ClassName` to the object it names.

    The path is split on its last dot: everything before it is imported as a
    module and the remainder is looked up as an attribute of that module.

    :param class_type: Fully qualified dotted path of the object to load.
    :type class_type: str
    :return: The attribute found on the imported module.
    """
    parts = class_type.rsplit(".", 1)
    dotted_module, attribute = parts
    return getattr(importlib.import_module(dotted_module), attribute)
class LlmFactory:
    """Create LLM instances from a provider name and a plain config mapping."""

    # Provider name -> dotted path of the LLM implementation.
    provider_to_class = {
        "anthropic": "embedchain.llm.anthropic.AnthropicLlm",
        "azure_openai": "embedchain.llm.azure_openai.AzureOpenAILlm",
        "cohere": "embedchain.llm.cohere.CohereLlm",
        "gpt4all": "embedchain.llm.gpt4all.GPT4ALLLlm",
        "hugging_face_llm": "embedchain.llm.hugging_face_llm.HuggingFaceLlm",
        "jina": "embedchain.llm.jina.JinaLlm",
        "llama2": "embedchain.llm.llama2.Llama2Llm",
        "openai": "embedchain.llm.openai.OpenAILlm",
        "vertexai": "embedchain.llm.vertex_ai.VertexAILlm",
    }
    # Provider name -> dotted path of its config class. Providers that are
    # not listed here use the "embedchain" entry.
    provider_to_config_class = {
        "embedchain": "embedchain.config.llm.base_llm_config.BaseLlmConfig",
        "openai": "embedchain.config.llm.base_llm_config.BaseLlmConfig",
        "anthropic": "embedchain.config.llm.base_llm_config.BaseLlmConfig",
    }

    @classmethod
    def create(cls, provider_name, config_data):
        """
        Build the LLM registered for `provider_name`.

        :param provider_name: Key into `provider_to_class`.
        :param config_data: Keyword arguments passed to the config class constructor.
        :raises ValueError: If `provider_name` is not a supported LLM provider.
        :return: The LLM instance, constructed with `config=<config instance>`.
        """
        llm_path = cls.provider_to_class.get(provider_name)
        # Fall back to the shared base config when the provider has no entry of its own.
        if provider_name in cls.provider_to_config_class:
            config_key = provider_name
        else:
            config_key = "embedchain"
        config_path = cls.provider_to_config_class.get(config_key)

        if not llm_path:
            raise ValueError(f"Unsupported Llm provider: {provider_name}")

        llm_cls = load_class(llm_path)
        config_cls = load_class(config_path)
        return llm_cls(config=config_cls(**config_data))
class EmbedderFactory:
    """Create embedder instances from a provider name and a plain config mapping."""

    # Provider name -> dotted path of the embedder implementation.
    provider_to_class = {
        "gpt4all": "embedchain.embedder.gpt4all.GPT4AllEmbedder",
        "huggingface": "embedchain.embedder.huggingface.HuggingFaceEmbedder",
        "vertexai": "embedchain.embedder.vertexai.VertexAiEmbedder",
        "openai": "embedchain.embedder.openai.OpenAIEmbedder",
    }
    # Provider name -> dotted path of its config class. Providers that are
    # not listed here use the "openai" entry.
    provider_to_config_class = {
        "openai": "embedchain.config.embedder.base.BaseEmbedderConfig",
    }

    @classmethod
    def create(cls, provider_name, config_data):
        """
        Build the embedder registered for `provider_name`.

        :param provider_name: Key into `provider_to_class`.
        :param config_data: Keyword arguments passed to the config class constructor.
        :raises ValueError: If `provider_name` is not a supported embedder provider.
        :return: The embedder instance, constructed with `config=<config instance>`.
        """
        embedder_path = cls.provider_to_class.get(provider_name)
        # Fall back to the "openai" config entry when the provider has none of its own.
        if provider_name in cls.provider_to_config_class:
            config_key = provider_name
        else:
            config_key = "openai"
        config_path = cls.provider_to_config_class.get(config_key)

        if not embedder_path:
            raise ValueError(f"Unsupported Embedder provider: {provider_name}")

        embedder_cls = load_class(embedder_path)
        config_cls = load_class(config_path)
        return embedder_cls(config=config_cls(**config_data))
class VectorDBFactory:
    """Create vector database instances from a provider name and a plain config mapping."""

    # Provider name -> dotted path of the vector database implementation.
    provider_to_class = {
        "chroma": "embedchain.vectordb.chroma.ChromaDB",
        "elasticsearch": "embedchain.vectordb.elasticsearch.ElasticsearchDB",
        "opensearch": "embedchain.vectordb.opensearch.OpenSearchDB",
    }
    # Provider name -> dotted path of the matching config class.
    provider_to_config_class = {
        "chroma": "embedchain.config.vectordb.chroma.ChromaDbConfig",
        "elasticsearch": "embedchain.config.vectordb.elasticsearch.ElasticsearchDBConfig",
        "opensearch": "embedchain.config.vectordb.opensearch.OpenSearchDBConfig",
    }

    @classmethod
    def create(cls, provider_name, config_data):
        """
        Build the vector database registered for `provider_name`.

        :param provider_name: Key into `provider_to_class` and `provider_to_config_class`.
        :param config_data: Keyword arguments passed to the config class constructor.
        :raises ValueError: If `provider_name` is not a supported vector database provider.
        :return: The vector database instance, constructed with `config=<config instance>`.
        """
        class_type = cls.provider_to_class.get(provider_name)
        config_class_type = cls.provider_to_config_class.get(provider_name)
        if not class_type:
            # The message names the vector database, not the embedder, so a
            # bad `vectordb.provider` is reported against the right section.
            raise ValueError(f"Unsupported VectorDB provider: {provider_name}")

        vectordb_class = load_class(class_type)
        vectordb_config_class = load_class(config_class_type)
        return vectordb_class(config=vectordb_config_class(**config_data))

View File

@@ -7,12 +7,12 @@ from embedchain.llm.base import BaseLlm
@register_deserializable
class AntrophicLlm(BaseLlm):
class AnthropicLlm(BaseLlm):
def __init__(self, config: Optional[BaseLlmConfig] = None):
super().__init__(config=config)
def get_llm_model_answer(self, prompt):
return AntrophicLlm._get_answer(prompt=prompt, config=self.config)
return AnthropicLlm._get_answer(prompt=prompt, config=self.config)
@staticmethod
def _get_answer(prompt: str, config: BaseLlmConfig) -> str:

View File

@@ -34,7 +34,8 @@ class JinaLlm(BaseLlm):
if config.top_p:
kwargs["model_kwargs"]["top_p"] = config.top_p
if config.stream:
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.callbacks.streaming_stdout import \
StreamingStdOutCallbackHandler
chat = JinaChat(**kwargs, streaming=config.stream, callbacks=[StreamingStdOutCallbackHandler()])
else:

View File

@@ -2,9 +2,7 @@ try:
from PIL import Image, UnidentifiedImageError
from sentence_transformers import SentenceTransformer
except ImportError:
raise ImportError(
"Images requires extra dependencies. Install with `pip install 'embedchain[images]'"
) from None
raise ImportError("Images requires extra dependencies. Install with `pip install 'embedchain[images]'") from None
MODEL_NAME = "clip-ViT-B-32"

View File

@@ -37,7 +37,7 @@ class ChromaDB(BaseVectorDB):
self.config = ChromaDbConfig()
self.settings = Settings()
self.settings.allow_reset = self.config.allow_reset
self.settings.allow_reset = self.config.allow_reset if hasattr(self.config, "allow_reset") else False
if self.config.chroma_settings:
for key, value in self.config.chroma_settings.items():
if hasattr(self.settings, key):
@@ -72,6 +72,17 @@ class ChromaDB(BaseVectorDB):
"""Called during initialization"""
return self.client
def _generate_where_clause(self, where: Dict[str, any]) -> str:
# If only one filter is supplied, return it as is
# (no need to wrap in $and based on chroma docs)
if len(where.keys()) == 1:
return where
where_filters = []
for k, v in where.items():
if isinstance(v, str):
where_filters.append({k: v})
return {"$and": where_filters}
def _get_or_create_collection(self, name: str) -> Collection:
"""
Get or create a named collection.
@@ -107,13 +118,14 @@ class ChromaDB(BaseVectorDB):
if ids:
args["ids"] = ids
if where:
args["where"] = where
args["where"] = self._generate_where_clause(where)
if limit:
args["limit"] = limit
return self.collection.get(**args)
def get_advanced(self, where):
return self.collection.get(where=where, limit=1)
where_clause = self._generate_where_clause(where)
return self.collection.get(where=where_clause, limit=1)
def add(
self,

View File

@@ -110,8 +110,13 @@ class OpenSearchDB(BaseVectorDB):
return result
def add(
self, embeddings: List[List[str]], documents: List[str], metadatas: List[object], ids: List[str],
skip_embedding: bool):
self,
embeddings: List[List[str]],
documents: List[str],
metadatas: List[object],
ids: List[str],
skip_embedding: bool,
):
"""add data in vector database
:param embeddings: list of embeddings to add
@@ -162,7 +167,8 @@ class OpenSearchDB(BaseVectorDB):
embedding_function=embeddings,
opensearch_url=f"{self.config.opensearch_url}",
http_auth=self.config.http_auth,
use_ssl=True,
use_ssl=hasattr(self.config, "use_ssl") and self.config.use_ssl,
verify_certs=hasattr(self.config, "verify_certs") and self.config.verify_certs,
)
pre_filter = {"match_all": {}} # default

View File

@@ -5,8 +5,8 @@ from embedchain.helper.json_serializable import register_deserializable
from embedchain.vectordb.base import BaseVectorDB
try:
from pymilvus import MilvusClient
from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility
from pymilvus import (Collection, CollectionSchema, DataType, FieldSchema,
MilvusClient, connections, utility)
except ImportError:
raise ImportError(
"Zilliz requires extra dependencies. Install with `pip install --upgrade embedchain[milvus]`"

View File

@@ -0,0 +1,26 @@
app:
  config:
    id: 'my-app'
    collection_name: 'my-app'

llm:
  provider: openai
  config:
    # `model` sits under `config` because only `provider` and `config` are
    # read from this section when the app is built from YAML.
    model: 'gpt-3.5-turbo'
    temperature: 0.5
    max_tokens: 1000
    top_p: 1
    stream: false

vectordb:
  provider: chroma
  config:
    collection_name: 'my-app'
    dir: db
    allow_reset: true

embedder:
  provider: openai
  config:
    model: 'text-embedding-ada-002'
    deployment_name: null

View File

@@ -0,0 +1,33 @@
app:
  config:
    id: 'my-app'
    log_level: 'WARN'
    collect_metrics: true
    collection_name: 'my-app'

llm:
  provider: openai
  config:
    # `model` sits under `config` because only `provider` and `config` are
    # read from this section when the app is built from YAML.
    model: 'gpt-3.5-turbo'
    temperature: 0.5
    max_tokens: 1000
    top_p: 1
    stream: false

vectordb:
  provider: opensearch
  config:
    opensearch_url: 'https://localhost:9200'
    http_auth:
      - admin
      - admin
    vector_dimension: 1536
    collection_name: 'my-app'
    use_ssl: false
    verify_certs: false

embedder:
  provider: openai
  config:
    model: 'text-embedding-ada-002'
    deployment_name: null

View File

@@ -0,0 +1,27 @@
app:
  config:
    id: 'open-source-app'
    collection_name: 'open-source-app'
    collect_metrics: false

llm:
  provider: gpt4all
  config:
    # `model` sits under `config` because only `provider` and `config` are
    # read from this section when the app is built from YAML.
    model: 'orca-mini-3b.ggmlv3.q4_0.bin'
    temperature: 0.5
    max_tokens: 1000
    top_p: 1
    stream: false

vectordb:
  provider: chroma
  config:
    collection_name: 'open-source-app'
    dir: db
    allow_reset: true

embedder:
  provider: gpt4all
  config:
    model: 'all-MiniLM-L6-v2'
    deployment_name: null