refactor: app design concept (#305)

Author: cachho
Committed: 2023-07-18 01:20:26 +02:00 (committed by GitHub)
Parent: 7ed46260b3
Commit: 0ea278f633
16 changed files with 378 additions and 240 deletions

embedchain/__init__.py

@@ -2,7 +2,7 @@ import importlib.metadata
 __version__ = importlib.metadata.version(__package__ or __name__)

-from .embedchain import App  # noqa: F401
-from .embedchain import OpenSourceApp  # noqa: F401
-from .embedchain import PersonApp  # noqa: F401
-from .embedchain import PersonOpenSourceApp  # noqa: F401
+from embedchain.apps.App import App  # noqa: F401
+from embedchain.apps.OpenSourceApp import OpenSourceApp  # noqa: F401
+from embedchain.apps.PersonApp import (PersonApp,  # noqa: F401
+                                       PersonOpenSourceApp)
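The refactor keeps the public import surface intact: the classes now live under embedchain.apps but are still re-exported from the package root, so existing callers are unaffected:

# Still works after the refactor; the classes moved to embedchain.apps.
from embedchain import App, OpenSourceApp, PersonApp, PersonOpenSourceApp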

embedchain/apps/App.py (new file, 49 lines)

@@ -0,0 +1,49 @@
import openai

from embedchain.config import AppConfig, ChatConfig
from embedchain.embedchain import EmbedChain


class App(EmbedChain):
    """
    The EmbedChain app.
    Has two core functions: add and query.

    add(data_type, url): adds the data from the given URL to the vector db.
    query(query): finds an answer to the given query using the vector database and the LLM.
    dry_run(query): tests your prompt without consuming tokens.
    """

    def __init__(self, config: AppConfig = None):
        """
        :param config: AppConfig instance to load as configuration. Optional.
        """
        if config is None:
            config = AppConfig()

        super().__init__(config)

    def get_llm_model_answer(self, prompt, config: ChatConfig):
        messages = []
        messages.append({"role": "user", "content": prompt})
        response = openai.ChatCompletion.create(
            model=config.model,
            messages=messages,
            temperature=config.temperature,
            max_tokens=config.max_tokens,
            top_p=config.top_p,
            stream=config.stream,
        )

        if config.stream:
            return self._stream_llm_model_response(response)
        else:
            return response["choices"][0]["message"]["content"]

    def _stream_llm_model_response(self, response):
        """
        A generator that streams the response from the OpenAI chat completions API.
        """
        for line in response:
            chunk = line["choices"][0].get("delta", {}).get("content", "")
            yield chunk
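For context, a minimal usage sketch of the refactored App. The URL and questions are placeholders, OPENAI_API_KEY must be set for the default config, and it assumes QueryConfig accepts stream, as the streaming branch above suggests:

from embedchain import App
from embedchain.config import QueryConfig

app = App()  # AppConfig() is built internally when no config is passed
app.add("web_page", "https://example.com/article")  # placeholder source

# A plain query returns a string answer.
print(app.query("What is the article about?"))

# With stream=True the answer arrives as a generator of chunks.
for chunk in app.query("Summarize the article.", QueryConfig(stream=True)):
    print(chunk, end="")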

embedchain/apps/OpenSourceApp.py (new file)

@@ -0,0 +1,39 @@
import logging

from embedchain.config import ChatConfig, OpenSourceAppConfig
from embedchain.embedchain import EmbedChain

gpt4all_model = None


class OpenSourceApp(EmbedChain):
    """
    The OpenSource app.
    Same as App, but uses an open source embedding model and LLM.

    Has two core functions: add and query.

    add(data_type, url): adds the data from the given URL to the vector db.
    query(query): finds an answer to the given query using the vector database and the LLM.
    """

    def __init__(self, config: OpenSourceAppConfig = None):
        """
        :param config: OpenSourceAppConfig instance to load as configuration. Optional.
        `ef` defaults to open source.
        """
        logging.info("Loading open source embedding model. This may take some time...")  # noqa:E501
        if not config:
            config = OpenSourceAppConfig()

        logging.info("Successfully loaded open source embedding model.")
        super().__init__(config)

    def get_llm_model_answer(self, prompt, config: ChatConfig):
        from gpt4all import GPT4All

        global gpt4all_model
        if gpt4all_model is None:
            gpt4all_model = GPT4All("orca-mini-3b.ggmlv3.q4_0.bin")
        response = gpt4all_model.generate(prompt=prompt, streaming=config.stream)
        return response
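A corresponding sketch for the open-source path. The first run downloads the GPT4All weights and the sentence-transformers embedding model; the source URL is a placeholder:

from embedchain import OpenSourceApp

app = OpenSourceApp()  # OpenSourceAppConfig() with MiniLM embeddings by default
app.add("web_page", "https://example.com/article")  # placeholder source
print(app.query("What is the article about?"))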

embedchain/apps/PersonApp.py (new file)

@@ -0,0 +1,65 @@
from string import Template

from embedchain.apps.App import App
from embedchain.apps.OpenSourceApp import OpenSourceApp
from embedchain.config import ChatConfig, QueryConfig
from embedchain.config.apps.BaseAppConfig import BaseAppConfig
from embedchain.config.QueryConfig import (DEFAULT_PROMPT,
                                           DEFAULT_PROMPT_WITH_HISTORY)


class EmbedChainPersonApp:
    """
    Base class to create a person bot.
    This bot behaves and speaks like a person.

    :param person: name of the person; works best if it's a well-known person.
    :param config: BaseAppConfig instance to load as configuration.
    """

    def __init__(self, person, config: BaseAppConfig = None):
        self.person = person
        self.person_prompt = f"You are {person}. Whatever you say, you will always say in {person} style."  # noqa:E501
        if config is None:
            config = BaseAppConfig()
        super().__init__(config)


class PersonApp(EmbedChainPersonApp, App):
    """
    The Person app.
    Extends functionality from EmbedChainPersonApp and App.
    """

    def query(self, input_query, config: QueryConfig = None):
        self.template = Template(self.person_prompt + " " + DEFAULT_PROMPT)
        query_config = QueryConfig(
            template=self.template,
        )
        return super().query(input_query, query_config)

    def chat(self, input_query, config: ChatConfig = None):
        self.template = Template(self.person_prompt + " " + DEFAULT_PROMPT_WITH_HISTORY)
        chat_config = ChatConfig(
            template=self.template,
        )
        return super().chat(input_query, chat_config)


class PersonOpenSourceApp(EmbedChainPersonApp, OpenSourceApp):
    """
    The Person app.
    Extends functionality from EmbedChainPersonApp and OpenSourceApp.
    """

    def query(self, input_query, config: QueryConfig = None):
        # Build the persona template here as well; the original code read
        # self.template without ever assigning it, which would raise
        # AttributeError on first use.
        self.template = Template(self.person_prompt + " " + DEFAULT_PROMPT)
        query_config = QueryConfig(
            template=self.template,
        )
        return super().query(input_query, query_config)

    def chat(self, input_query, config: ChatConfig = None):
        self.template = Template(self.person_prompt + " " + DEFAULT_PROMPT_WITH_HISTORY)
        chat_config = ChatConfig(
            template=self.template,
        )
        return super().chat(input_query, chat_config)
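A usage sketch for the persona wrapper; the persona name and source URL are placeholders:

from embedchain import PersonApp

app = PersonApp("Naval Ravikant")  # any well-known person
app.add("web_page", "https://example.com/article")  # placeholder source

# The answer is phrased in the persona's style via the injected template.
print(app.query("What is the article about?"))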

embedchain/config/InitConfig.py (deleted file)

@@ -1,74 +0,0 @@
import logging
import os

from chromadb.utils import embedding_functions

from embedchain.config.BaseConfig import BaseConfig


class InitConfig(BaseConfig):
    """
    Config to initialize an embedchain `App` instance.
    """

    def __init__(self, log_level=None, ef=None, db=None, host=None, port=None, id=None):
        """
        :param log_level: Optional. (String) Debug level
        ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'].
        :param ef: Optional. Embedding function to use.
        :param db: Optional. (Vector) database to use for embeddings.
        :param id: Optional. ID of the app. Document metadata will have this id.
        :param host: Optional. Hostname for the database server.
        :param port: Optional. Port for the database server.
        """
        self._setup_logging(log_level)
        self.ef = ef
        self.db = db
        self.host = host
        self.port = port
        self.id = id
        return

    def _set_embedding_function(self, ef):
        self.ef = ef
        return

    def _set_embedding_function_to_default(self):
        """
        Sets embedding function to default (`text-embedding-ada-002`).

        :raises ValueError: If the template is not valid as template should contain
        $context and $query
        """
        if os.getenv("OPENAI_API_KEY") is None and os.getenv("OPENAI_ORGANIZATION") is None:
            raise ValueError("OPENAI_API_KEY or OPENAI_ORGANIZATION environment variables not provided")  # noqa:E501
        self.ef = embedding_functions.OpenAIEmbeddingFunction(
            api_key=os.getenv("OPENAI_API_KEY"),
            organization_id=os.getenv("OPENAI_ORGANIZATION"),
            model_name="text-embedding-ada-002",
        )
        return

    def _set_db(self, db):
        if db:
            self.db = db
        return

    def _set_db_to_default(self):
        """
        Sets database to default (`ChromaDb`).
        """
        from embedchain.vectordb.chroma_db import ChromaDB

        self.db = ChromaDB(ef=self.ef, host=self.host, port=self.port)

    def _setup_logging(self, debug_level):
        level = logging.WARNING  # Default level
        if debug_level is not None:
            level = getattr(logging, debug_level.upper(), None)
            if not isinstance(level, int):
                raise ValueError(f"Invalid log level: {debug_level}")

        logging.basicConfig(format="%(asctime)s [%(name)s] [%(levelname)s] %(message)s", level=level)
        self.logger = logging.getLogger(__name__)
        return

embedchain/config/__init__.py

@@ -1,5 +1,7 @@
 from .AddConfig import AddConfig, ChunkerConfig  # noqa: F401
+from .apps.AppConfig import AppConfig  # noqa: F401
+from .apps.CustomAppConfig import CustomAppConfig  # noqa: F401
+from .apps.OpenSourceAppConfig import OpenSourceAppConfig  # noqa: F401
 from .BaseConfig import BaseConfig  # noqa: F401
 from .ChatConfig import ChatConfig  # noqa: F401
-from .InitConfig import InitConfig  # noqa: F401
 from .QueryConfig import QueryConfig  # noqa: F401
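The new config classes are re-exported from embedchain.config, so callers import them the same way the updated tests at the end of this diff do:

from embedchain.config import (AppConfig, CustomAppConfig,
                               OpenSourceAppConfig)

config = AppConfig(log_level="DEBUG")  # needs OPENAI_API_KEY for the default ef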

embedchain/config/apps/AppConfig.py (new file)

@@ -0,0 +1,38 @@
import os

from chromadb.utils import embedding_functions

from .BaseAppConfig import BaseAppConfig


class AppConfig(BaseAppConfig):
    """
    Config to initialize an embedchain `App` instance.
    """

    def __init__(self, log_level=None, host=None, port=None, id=None):
        """
        :param log_level: Optional. (String) Debug level
        ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'].
        :param id: Optional. ID of the app. Document metadata will have this id.
        :param host: Optional. Hostname for the database server.
        :param port: Optional. Port for the database server.
        """
        super().__init__(log_level=log_level, ef=AppConfig.default_embedding_function(), host=host, port=port, id=id)

    @staticmethod
    def default_embedding_function():
        """
        Returns the default embedding function (`text-embedding-ada-002`).

        :raises ValueError: If neither the OPENAI_API_KEY nor the
        OPENAI_ORGANIZATION environment variable is set.
        :returns: The default embedding function for the app class.
        """
        if os.getenv("OPENAI_API_KEY") is None and os.getenv("OPENAI_ORGANIZATION") is None:
            raise ValueError("OPENAI_API_KEY or OPENAI_ORGANIZATION environment variables not provided")  # noqa:E501
        return embedding_functions.OpenAIEmbeddingFunction(
            api_key=os.getenv("OPENAI_API_KEY"),
            organization_id=os.getenv("OPENAI_ORGANIZATION"),
            model_name="text-embedding-ada-002",
        )
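A sketch of pointing App at a remote Chroma server through the new config; the host, port, and id values are placeholders:

from embedchain import App
from embedchain.config import AppConfig

# Assumes OPENAI_API_KEY (or OPENAI_ORGANIZATION) is set, since AppConfig
# builds the OpenAI embedding function by default.
config = AppConfig(log_level="INFO", host="localhost", port="8000", id="my-app")
app = App(config)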

embedchain/config/apps/BaseAppConfig.py (new file)

@@ -0,0 +1,53 @@
import logging

from embedchain.config.BaseConfig import BaseConfig


class BaseAppConfig(BaseConfig):
    """
    Parent config to initialize an instance of `App`, `OpenSourceApp` or `CustomApp`.
    """

    def __init__(self, log_level=None, ef=None, db=None, host=None, port=None, id=None):
        """
        :param log_level: Optional. (String) Debug level
        ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'].
        :param ef: Embedding function to use.
        :param db: Optional. (Vector) database instance to use for embeddings.
        :param id: Optional. ID of the app. Document metadata will have this id.
        :param host: Optional. Hostname for the database server.
        :param port: Optional. Port for the database server.
        """
        self._setup_logging(log_level)
        self.db = db if db else BaseAppConfig.default_db(ef=ef, host=host, port=port)
        self.id = id
        return

    @staticmethod
    def default_db(ef, host, port):
        """
        Returns the default database (`ChromaDb`).

        :param ef: Embedding function to use in the database.
        :param host: Optional. Hostname for the database server.
        :param port: Optional. Port for the database server.
        :returns: Default database.
        :raises ValueError: If `ef` is None; BaseAppConfig provides no default
        embedding function.
        """
        if ef is None:
            raise ValueError("ChromaDb cannot be instantiated without an embedding function")
        from embedchain.vectordb.chroma_db import ChromaDB

        return ChromaDB(ef=ef, host=host, port=port)

    def _setup_logging(self, debug_level):
        level = logging.WARNING  # Default level
        if debug_level is not None:
            level = getattr(logging, debug_level.upper(), None)
            if not isinstance(level, int):
                raise ValueError(f"Invalid log level: {debug_level}")

        logging.basicConfig(format="%(asctime)s [%(name)s] [%(levelname)s] %(message)s", level=level)
        self.logger = logging.getLogger(__name__)
        return
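A sketch of how a downstream config can plug in its own embedding function, mirroring what AppConfig and OpenSourceAppConfig do above; MyAppConfig and the model name are hypothetical:

from chromadb.utils import embedding_functions

from embedchain.config.apps.BaseAppConfig import BaseAppConfig


class MyAppConfig(BaseAppConfig):
    """Hypothetical config that swaps in a different sentence-transformers model."""

    def __init__(self, log_level=None, host=None, port=None, id=None):
        ef = embedding_functions.SentenceTransformerEmbeddingFunction(
            model_name="multi-qa-MiniLM-L6-cos-v1"  # illustrative model choice
        )
        super().__init__(log_level=log_level, ef=ef, host=host, port=port, id=id)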

embedchain/config/apps/CustomAppConfig.py (new file)

@@ -0,0 +1,19 @@
from .BaseAppConfig import BaseAppConfig


class CustomAppConfig(BaseAppConfig):
    """
    Config to initialize an embedchain custom `App` instance, with extra config options.
    """

    def __init__(self, log_level=None, ef=None, db=None, host=None, port=None, id=None):
        """
        :param log_level: Optional. (String) Debug level
        ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'].
        :param ef: Optional. Embedding function to use.
        :param db: Optional. (Vector) database to use for embeddings.
        :param id: Optional. ID of the app. Document metadata will have this id.
        :param host: Optional. Hostname for the database server.
        :param port: Optional. Port for the database server.
        """
        # Forward `ef` as well; the original call dropped it, so a custom
        # embedding function passed here would silently be ignored.
        super().__init__(log_level=log_level, ef=ef, db=db, host=host, port=port, id=id)
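With ef forwarded, CustomAppConfig lets callers supply any Chroma-compatible embedding function directly; the model name and id are illustrative:

from chromadb.utils import embedding_functions

from embedchain.config import CustomAppConfig

ef = embedding_functions.SentenceTransformerEmbeddingFunction(
    model_name="all-mpnet-base-v2"  # illustrative model choice
)
config = CustomAppConfig(ef=ef, id="custom-embeddings-app")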

embedchain/config/apps/OpenSourceAppConfig.py (new file)

@@ -0,0 +1,30 @@
from chromadb.utils import embedding_functions

from .BaseAppConfig import BaseAppConfig


class OpenSourceAppConfig(BaseAppConfig):
    """
    Config to initialize an embedchain `OpenSourceApp` instance, with extra config options.
    """

    def __init__(self, log_level=None, host=None, port=None, id=None):
        """
        :param log_level: Optional. (String) Debug level
        ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'].
        :param id: Optional. ID of the app. Document metadata will have this id.
        :param host: Optional. Hostname for the database server.
        :param port: Optional. Port for the database server.
        """
        super().__init__(
            log_level=log_level, ef=OpenSourceAppConfig.default_embedding_function(), host=host, port=port, id=id
        )

    @staticmethod
    def default_embedding_function():
        """
        Returns the default embedding function (`all-MiniLM-L6-v2`).

        :returns: The default embedding function
        """
        return embedding_functions.SentenceTransformerEmbeddingFunction(model_name="all-MiniLM-L6-v2")
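And the fully local combination; the id is a placeholder:

from embedchain import OpenSourceApp
from embedchain.config import OpenSourceAppConfig

# all-MiniLM-L6-v2 embeddings plus the GPT4All LLM; no API keys needed.
config = OpenSourceAppConfig(log_level="INFO", id="local-app")
app = OpenSourceApp(config)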

embedchain/embedchain.py

@@ -1,15 +1,13 @@
 import logging
 import os
-from string import Template

-import openai
-from chromadb.utils import embedding_functions
 from dotenv import load_dotenv
 from langchain.docstore.document import Document
 from langchain.memory import ConversationBufferMemory

-from embedchain.config import AddConfig, ChatConfig, InitConfig, QueryConfig
-from embedchain.config.QueryConfig import DOCS_SITE_PROMPT_TEMPLATE, DEFAULT_PROMPT, DEFAULT_PROMPT_WITH_HISTORY
+from embedchain.config import AddConfig, ChatConfig, QueryConfig
+from embedchain.config.apps.BaseAppConfig import BaseAppConfig
+from embedchain.config.QueryConfig import DOCS_SITE_PROMPT_TEMPLATE
 from embedchain.data_formatter import DataFormatter

 gpt4all_model = None

@@ -23,7 +21,7 @@ memory = ConversationBufferMemory()

 class EmbedChain:
-    def __init__(self, config: InitConfig):
+    def __init__(self, config: BaseAppConfig):
         """
         Initializes the EmbedChain instance, sets up a vector DB client and
         creates a collection.
@@ -139,7 +137,10 @@ class EmbedChain:
             )
         ]

-    def get_llm_model_answer(self, prompt):
+    def get_llm_model_answer(self):
+        """
+        Usually implemented by child class
+        """
         raise NotImplementedError

     def retrieve_from_database(self, input_query, config: QueryConfig):
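This makes get_llm_model_answer the single override point for LLM backends: EmbedChain owns retrieval and prompt assembly, each app class owns the model call. A hypothetical subclass under that contract; the echo "model" is a stand-in, and a real backend would call an LLM here as App and OpenSourceApp do:

from embedchain.config import AppConfig, ChatConfig
from embedchain.embedchain import EmbedChain


class EchoApp(EmbedChain):
    """Hypothetical app that returns the assembled prompt verbatim."""

    def __init__(self, config: AppConfig = None):
        # AppConfig's default embedding function needs OPENAI_API_KEY.
        super().__init__(config or AppConfig())

    def get_llm_model_answer(self, prompt, config: ChatConfig):
        # A real backend would send `prompt` to its LLM and return the answer.
        return f"ECHO: {prompt}"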
@@ -329,152 +330,3 @@ class EmbedChain:
         `App` has to be reinitialized after using this method.
         """
         self.db_client.reset()
-
-
-class App(EmbedChain):
-    """
-    The EmbedChain app.
-    Has two functions: add and query.
-
-    adds(data_type, url): adds the data from the given URL to the vector db.
-    query(query): finds answer to the given query using vector database and LLM.
-    dry_run(query): test your prompt without consuming tokens.
-    """
-
-    def __init__(self, config: InitConfig = None):
-        """
-        :param config: InitConfig instance to load as configuration. Optional.
-        """
-        if config is None:
-            config = InitConfig()
-        if not config.ef:
-            config._set_embedding_function_to_default()
-        if not config.db:
-            config._set_db_to_default()
-        super().__init__(config)
-
-    def get_llm_model_answer(self, prompt, config: ChatConfig):
-        messages = []
-        messages.append({"role": "user", "content": prompt})
-        response = openai.ChatCompletion.create(
-            model=config.model,
-            messages=messages,
-            temperature=config.temperature,
-            max_tokens=config.max_tokens,
-            top_p=config.top_p,
-            stream=config.stream,
-        )
-        if config.stream:
-            return self._stream_llm_model_response(response)
-        else:
-            return response["choices"][0]["message"]["content"]
-
-    def _stream_llm_model_response(self, response):
-        """
-        This is a generator for streaming response from the OpenAI completions API
-        """
-        for line in response:
-            chunk = line["choices"][0].get("delta", {}).get("content", "")
-            yield chunk
-
-
-class OpenSourceApp(EmbedChain):
-    """
-    The OpenSource app.
-    Same as App, but uses an open source embedding model and LLM.
-
-    Has two function: add and query.
-
-    adds(data_type, url): adds the data from the given URL to the vector db.
-    query(query): finds answer to the given query using vector database and LLM.
-    """
-
-    def __init__(self, config: InitConfig = None):
-        """
-        :param config: InitConfig instance to load as configuration. Optional.
-        `ef` defaults to open source.
-        """
-        print("Loading open source embedding model. This may take some time...")  # noqa:E501
-        if not config:
-            config = InitConfig()
-        if not config.ef:
-            config._set_embedding_function(
-                embedding_functions.SentenceTransformerEmbeddingFunction(model_name="all-MiniLM-L6-v2")
-            )
-        if not config.db:
-            config._set_db_to_default()
-        print("Successfully loaded open source embedding model.")
-        super().__init__(config)
-
-    def get_llm_model_answer(self, prompt, config: ChatConfig):
-        from gpt4all import GPT4All
-
-        global gpt4all_model
-        if gpt4all_model is None:
-            gpt4all_model = GPT4All("orca-mini-3b.ggmlv3.q4_0.bin")
-        response = gpt4all_model.generate(prompt=prompt, streaming=config.stream)
-        return response
-
-
-class EmbedChainPersonApp:
-    """
-    Base class to create a person bot.
-    This bot behaves and speaks like a person.
-
-    :param person: name of the person, better if its a well known person.
-    :param config: InitConfig instance to load as configuration.
-    """
-
-    def __init__(self, person, config: InitConfig = None):
-        self.person = person
-        self.person_prompt = f"You are {person}. Whatever you say, you will always say in {person} style."  # noqa:E501
-        if config is None:
-            config = InitConfig()
-        super().__init__(config)
-
-
-class PersonApp(EmbedChainPersonApp, App):
-    """
-    The Person app.
-    Extends functionality from EmbedChainPersonApp and App
-    """
-
-    def query(self, input_query, config: QueryConfig = None):
-        self.template = Template(self.person_prompt + " " + DEFAULT_PROMPT)
-        query_config = QueryConfig(
-            template=self.template,
-        )
-        return super().query(input_query, query_config)
-
-    def chat(self, input_query, config: ChatConfig = None):
-        self.template = Template(self.person_prompt + " " + DEFAULT_PROMPT_WITH_HISTORY)
-        chat_config = ChatConfig(
-            template=self.template,
-        )
-        return super().chat(input_query, chat_config)
-
-
-class PersonOpenSourceApp(EmbedChainPersonApp, OpenSourceApp):
-    """
-    The Person app.
-    Extends functionality from EmbedChainPersonApp and OpenSourceApp
-    """
-
-    def query(self, input_query, config: QueryConfig = None):
-        query_config = QueryConfig(
-            template=self.template,
-        )
-        return super().query(input_query, query_config)
-
-    def chat(self, input_query, config: ChatConfig = None):
-        chat_config = ChatConfig(
-            template=self.template,
-        )
-        return super().chat(input_query, chat_config)

@@ -3,7 +3,7 @@ import unittest
 from unittest.mock import patch

 from embedchain import App
-from embedchain.config import InitConfig
+from embedchain.config import AppConfig


 class TestChromaDbHostsLoglevel(unittest.TestCase):

@@ -25,7 +25,7 @@ class TestChromaDbHostsLoglevel(unittest.TestCase):
         """
         Test if the `App` instance is initialized with a config that does not contain default hosts and ports.
         """
-        config = InitConfig(log_level="DEBUG")
+        config = AppConfig(log_level="DEBUG")
         app = App(config)

@@ -4,7 +4,7 @@ import unittest
 from unittest.mock import patch

 from embedchain import App
-from embedchain.config import InitConfig
+from embedchain.config import AppConfig
 from embedchain.vectordb.chroma_db import ChromaDB, chromadb

@@ -38,7 +38,7 @@ class TestChromaDbHostsInit(unittest.TestCase):
         host = "test-host"
         port = "1234"

-        config = InitConfig(host=host, port=port)
+        config = AppConfig(host=host, port=port)
         _app = App(config)

@@ -65,7 +65,7 @@ class TestChromaDbHostsLoglevel(unittest.TestCase):
         """
         Test if the `App` instance is initialized with a config that does not contain default hosts and ports.
         """
-        config = InitConfig(log_level="DEBUG")
+        config = AppConfig(log_level="DEBUG")
         _app = App(config)