refactor: app design concept (#305)
65  embedchain/Apps/PersonApp.py  Normal file
@@ -0,0 +1,65 @@
+from string import Template
+
+from embedchain.apps.App import App
+from embedchain.apps.OpenSourceApp import OpenSourceApp
+from embedchain.config import ChatConfig, QueryConfig
+from embedchain.config.apps.BaseAppConfig import BaseAppConfig
+from embedchain.config.QueryConfig import (DEFAULT_PROMPT,
+                                           DEFAULT_PROMPT_WITH_HISTORY)
+
+
+class EmbedChainPersonApp:
+    """
+    Base class to create a person bot.
+    This bot behaves and speaks like a person.
+
+    :param person: name of the person, better if its a well known person.
+    :param config: BaseAppConfig instance to load as configuration.
+    """
+
+    def __init__(self, person, config: BaseAppConfig = None):
+        self.person = person
+        self.person_prompt = f"You are {person}. Whatever you say, you will always say in {person} style."  # noqa:E501
+        if config is None:
+            config = BaseAppConfig()
+        super().__init__(config)
+
+
+class PersonApp(EmbedChainPersonApp, App):
+    """
+    The Person app.
+    Extends functionality from EmbedChainPersonApp and App
+    """
+
+    def query(self, input_query, config: QueryConfig = None):
+        self.template = Template(self.person_prompt + " " + DEFAULT_PROMPT)
+        query_config = QueryConfig(
+            template=self.template,
+        )
+        return super().query(input_query, query_config)
+
+    def chat(self, input_query, config: ChatConfig = None):
+        self.template = Template(self.person_prompt + " " + DEFAULT_PROMPT_WITH_HISTORY)
+        chat_config = ChatConfig(
+            template=self.template,
+        )
+        return super().chat(input_query, chat_config)
+
+
+class PersonOpenSourceApp(EmbedChainPersonApp, OpenSourceApp):
+    """
+    The Person app.
+    Extends functionality from EmbedChainPersonApp and OpenSourceApp
+    """
+
+    def query(self, input_query, config: QueryConfig = None):
+        query_config = QueryConfig(
+            template=self.template,
+        )
+        return super().query(input_query, query_config)
+
+    def chat(self, input_query, config: ChatConfig = None):
+        chat_config = ChatConfig(
+            template=self.template,
+        )
+        return super().chat(input_query, chat_config)
embedchain/__init__.py
@@ -2,7 +2,7 @@ import importlib.metadata
 
 __version__ = importlib.metadata.version(__package__ or __name__)
 
-from .embedchain import App  # noqa: F401
-from .embedchain import OpenSourceApp  # noqa: F401
-from .embedchain import PersonApp  # noqa: F401
-from .embedchain import PersonOpenSourceApp  # noqa: F401
+from embedchain.apps.App import App  # noqa: F401
+from embedchain.apps.OpenSourceApp import OpenSourceApp  # noqa: F401
+from embedchain.apps.PersonApp import (PersonApp,  # noqa: F401
+                                       PersonOpenSourceApp)
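The package root keeps re-exporting the public classes, so user-facing imports survive the refactor unchanged. A minimal smoke test of the new import surface (a sketch, not part of the diff):

from embedchain import App, OpenSourceApp, PersonApp, PersonOpenSourceApp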
49  embedchain/apps/App.py  Normal file
@@ -0,0 +1,49 @@
+import openai
+
+from embedchain.config import AppConfig, ChatConfig
+from embedchain.embedchain import EmbedChain
+
+
+class App(EmbedChain):
+    """
+    The EmbedChain app.
+    Has two functions: add and query.
+
+    adds(data_type, url): adds the data from the given URL to the vector db.
+    query(query): finds answer to the given query using vector database and LLM.
+    dry_run(query): test your prompt without consuming tokens.
+    """
+
+    def __init__(self, config: AppConfig = None):
+        """
+        :param config: AppConfig instance to load as configuration. Optional.
+        """
+        if config is None:
+            config = AppConfig()
+
+        super().__init__(config)
+
+    def get_llm_model_answer(self, prompt, config: ChatConfig):
+        messages = []
+        messages.append({"role": "user", "content": prompt})
+        response = openai.ChatCompletion.create(
+            model=config.model,
+            messages=messages,
+            temperature=config.temperature,
+            max_tokens=config.max_tokens,
+            top_p=config.top_p,
+            stream=config.stream,
+        )
+
+        if config.stream:
+            return self._stream_llm_model_response(response)
+        else:
+            return response["choices"][0]["message"]["content"]
+
+    def _stream_llm_model_response(self, response):
+        """
+        This is a generator for streaming response from the OpenAI completions API
+        """
+        for line in response:
+            chunk = line["choices"][0].get("delta", {}).get("content", "")
+            yield chunk
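For orientation, a minimal usage sketch of the relocated App class (a sketch, not part of the diff; assumes OPENAI_API_KEY is set, and the URL is an illustrative stand-in):

from embedchain import App
from embedchain.config import AppConfig

app = App(AppConfig(log_level="INFO"))
app.add("web_page", "https://example.com/docs")  # illustrative source
print(app.query("What does the page describe?"))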
39  embedchain/apps/OpenSourceApp.py  Normal file
@@ -0,0 +1,39 @@
+import logging
+
+from embedchain.config import ChatConfig, OpenSourceAppConfig
+from embedchain.embedchain import EmbedChain
+
+gpt4all_model = None
+
+
+class OpenSourceApp(EmbedChain):
+    """
+    The OpenSource app.
+    Same as App, but uses an open source embedding model and LLM.
+
+    Has two function: add and query.
+
+    adds(data_type, url): adds the data from the given URL to the vector db.
+    query(query): finds answer to the given query using vector database and LLM.
+    """
+
+    def __init__(self, config: OpenSourceAppConfig = None):
+        """
+        :param config: InitConfig instance to load as configuration. Optional.
+        `ef` defaults to open source.
+        """
+        logging.info("Loading open source embedding model. This may take some time...")  # noqa:E501
+        if not config:
+            config = OpenSourceAppConfig()
+
+        logging.info("Successfully loaded open source embedding model.")
+        super().__init__(config)
+
+    def get_llm_model_answer(self, prompt, config: ChatConfig):
+        from gpt4all import GPT4All
+
+        global gpt4all_model
+        if gpt4all_model is None:
+            gpt4all_model = GPT4All("orca-mini-3b.ggmlv3.q4_0.bin")
+        response = gpt4all_model.generate(prompt=prompt, streaming=config.stream)
+        return response
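A corresponding sketch for OpenSourceApp (not part of the diff): no OpenAI key is required, but the gpt4all package must be installed and the orca-mini weights are fetched on first use; the URL is illustrative:

from embedchain import OpenSourceApp

app = OpenSourceApp()  # all-MiniLM-L6-v2 embeddings, GPT4All LLM
app.add("web_page", "https://example.com/article")  # illustrative source
print(app.query("Summarize the article."))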
65  embedchain/apps/PersonApp.py  Normal file
@@ -0,0 +1,65 @@
+from string import Template
+
+from embedchain.apps.App import App
+from embedchain.apps.OpenSourceApp import OpenSourceApp
+from embedchain.config import ChatConfig, QueryConfig
+from embedchain.config.apps.BaseAppConfig import BaseAppConfig
+from embedchain.config.QueryConfig import (DEFAULT_PROMPT,
+                                           DEFAULT_PROMPT_WITH_HISTORY)
+
+
+class EmbedChainPersonApp:
+    """
+    Base class to create a person bot.
+    This bot behaves and speaks like a person.
+
+    :param person: name of the person, better if its a well known person.
+    :param config: BaseAppConfig instance to load as configuration.
+    """
+
+    def __init__(self, person, config: BaseAppConfig = None):
+        self.person = person
+        self.person_prompt = f"You are {person}. Whatever you say, you will always say in {person} style."  # noqa:E501
+        if config is None:
+            config = BaseAppConfig()
+        super().__init__(config)
+
+
+class PersonApp(EmbedChainPersonApp, App):
+    """
+    The Person app.
+    Extends functionality from EmbedChainPersonApp and App
+    """
+
+    def query(self, input_query, config: QueryConfig = None):
+        self.template = Template(self.person_prompt + " " + DEFAULT_PROMPT)
+        query_config = QueryConfig(
+            template=self.template,
+        )
+        return super().query(input_query, query_config)
+
+    def chat(self, input_query, config: ChatConfig = None):
+        self.template = Template(self.person_prompt + " " + DEFAULT_PROMPT_WITH_HISTORY)
+        chat_config = ChatConfig(
+            template=self.template,
+        )
+        return super().chat(input_query, chat_config)
+
+
+class PersonOpenSourceApp(EmbedChainPersonApp, OpenSourceApp):
+    """
+    The Person app.
+    Extends functionality from EmbedChainPersonApp and OpenSourceApp
+    """
+
+    def query(self, input_query, config: QueryConfig = None):
+        query_config = QueryConfig(
+            template=self.template,
+        )
+        return super().query(input_query, query_config)
+
+    def chat(self, input_query, config: ChatConfig = None):
+        chat_config = ChatConfig(
+            template=self.template,
+        )
+        return super().chat(input_query, chat_config)
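The person apps prepend a persona instruction to the default prompt templates before delegating to query/chat. A usage sketch (not part of the diff; assumes the OpenAI-backed PersonApp and an illustrative source):

from embedchain import PersonApp

app = PersonApp("Yoda")  # any well-known name
app.add("web_page", "https://example.com/essays")  # illustrative source
print(app.chat("What should I focus on?"))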
0  embedchain/apps/__init__.py  Normal file
embedchain/config/InitConfig.py
@@ -1,74 +0,0 @@
-import logging
-import os
-
-from chromadb.utils import embedding_functions
-
-from embedchain.config.BaseConfig import BaseConfig
-
-
-class InitConfig(BaseConfig):
-    """
-    Config to initialize an embedchain `App` instance.
-    """
-
-    def __init__(self, log_level=None, ef=None, db=None, host=None, port=None, id=None):
-        """
-        :param log_level: Optional. (String) Debug level
-        ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'].
-        :param ef: Optional. Embedding function to use.
-        :param db: Optional. (Vector) database to use for embeddings.
-        :param id: Optional. ID of the app. Document metadata will have this id.
-        :param host: Optional. Hostname for the database server.
-        :param port: Optional. Port for the database server.
-        """
-        self._setup_logging(log_level)
-        self.ef = ef
-        self.db = db
-        self.host = host
-        self.port = port
-        self.id = id
-        return
-
-    def _set_embedding_function(self, ef):
-        self.ef = ef
-        return
-
-    def _set_embedding_function_to_default(self):
-        """
-        Sets embedding function to default (`text-embedding-ada-002`).
-
-        :raises ValueError: If the template is not valid as template should contain
-        $context and $query
-        """
-        if os.getenv("OPENAI_API_KEY") is None and os.getenv("OPENAI_ORGANIZATION") is None:
-            raise ValueError("OPENAI_API_KEY or OPENAI_ORGANIZATION environment variables not provided")  # noqa:E501
-        self.ef = embedding_functions.OpenAIEmbeddingFunction(
-            api_key=os.getenv("OPENAI_API_KEY"),
-            organization_id=os.getenv("OPENAI_ORGANIZATION"),
-            model_name="text-embedding-ada-002",
-        )
-        return
-
-    def _set_db(self, db):
-        if db:
-            self.db = db
-        return
-
-    def _set_db_to_default(self):
-        """
-        Sets database to default (`ChromaDb`).
-        """
-        from embedchain.vectordb.chroma_db import ChromaDB
-
-        self.db = ChromaDB(ef=self.ef, host=self.host, port=self.port)
-
-    def _setup_logging(self, debug_level):
-        level = logging.WARNING  # Default level
-        if debug_level is not None:
-            level = getattr(logging, debug_level.upper(), None)
-            if not isinstance(level, int):
-                raise ValueError(f"Invalid log level: {debug_level}")
-
-        logging.basicConfig(format="%(asctime)s [%(name)s] [%(levelname)s] %(message)s", level=level)
-        self.logger = logging.getLogger(__name__)
-        return
embedchain/config/__init__.py
@@ -1,5 +1,7 @@
 from .AddConfig import AddConfig, ChunkerConfig  # noqa: F401
+from .apps.AppConfig import AppConfig  # noqa: F401
+from .apps.CustomAppConfig import CustomAppConfig  # noqa: F401
+from .apps.OpenSourceAppConfig import OpenSourceAppConfig  # noqa: F401
 from .BaseConfig import BaseConfig  # noqa: F401
 from .ChatConfig import ChatConfig  # noqa: F401
-from .InitConfig import InitConfig  # noqa: F401
 from .QueryConfig import QueryConfig  # noqa: F401
38  embedchain/config/apps/AppConfig.py  Normal file
@@ -0,0 +1,38 @@
+import os
+
+from chromadb.utils import embedding_functions
+
+from .BaseAppConfig import BaseAppConfig
+
+
+class AppConfig(BaseAppConfig):
+    """
+    Config to initialize an embedchain custom `App` instance, with extra config options.
+    """
+
+    def __init__(self, log_level=None, host=None, port=None, id=None):
+        """
+        :param log_level: Optional. (String) Debug level
+        ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'].
+        :param id: Optional. ID of the app. Document metadata will have this id.
+        :param host: Optional. Hostname for the database server.
+        :param port: Optional. Port for the database server.
+        """
+        super().__init__(log_level=log_level, ef=AppConfig.default_embedding_function(), host=host, port=port, id=id)
+
+    @staticmethod
+    def default_embedding_function():
+        """
+        Sets embedding function to default (`text-embedding-ada-002`).
+
+        :raises ValueError: If the template is not valid as template should contain
+        $context and $query
+        :returns: The default embedding function for the app class.
+        """
+        if os.getenv("OPENAI_API_KEY") is None and os.getenv("OPENAI_ORGANIZATION") is None:
+            raise ValueError("OPENAI_API_KEY or OPENAI_ORGANIZATION environment variables not provided")  # noqa:E501
+        return embedding_functions.OpenAIEmbeddingFunction(
+            api_key=os.getenv("OPENAI_API_KEY"),
+            organization_id=os.getenv("OPENAI_ORGANIZATION"),
+            model_name="text-embedding-ada-002",
+        )
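AppConfig bakes the OpenAI embedding function in, so the only runtime requirement is the environment-variable check shown above. A sketch with illustrative values (not part of the diff):

import os

from embedchain.config import AppConfig

os.environ["OPENAI_API_KEY"] = "sk-..."  # placeholder key
config = AppConfig(log_level="DEBUG", id="my-app")  # ef defaults to text-embedding-ada-002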
53  embedchain/config/apps/BaseAppConfig.py  Normal file
@@ -0,0 +1,53 @@
+import logging
+
+from embedchain.config.BaseConfig import BaseConfig
+
+
+class BaseAppConfig(BaseConfig):
+    """
+    Parent config to initialize an instance of `App`, `OpenSourceApp` or `CustomApp`.
+    """
+
+    def __init__(self, log_level=None, ef=None, db=None, host=None, port=None, id=None):
+        """
+        :param log_level: Optional. (String) Debug level
+        ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'].
+        :param ef: Embedding function to use.
+        :param db: Optional. (Vector) database instance to use for embeddings.
+        :param id: Optional. ID of the app. Document metadata will have this id.
+        :param host: Optional. Hostname for the database server.
+        :param port: Optional. Port for the database server.
+        """
+        self._setup_logging(log_level)
+
+        self.db = db if db else BaseAppConfig.default_db(ef=ef, host=host, port=port)
+        self.id = id
+        return
+
+    @staticmethod
+    def default_db(ef, host, port):
+        """
+        Sets database to default (`ChromaDb`).
+
+        :param ef: Embedding function to use in database.
+        :param host: Optional. Hostname for the database server.
+        :param port: Optional. Port for the database server.
+        :returns: Default database
+        :raises ValueError: BaseAppConfig knows no default embedding function.
+        """
+        if ef is None:
+            raise ValueError("ChromaDb cannot be instantiated without an embedding function")
+        from embedchain.vectordb.chroma_db import ChromaDB
+
+        return ChromaDB(ef=ef, host=host, port=port)
+
+    def _setup_logging(self, debug_level):
+        level = logging.WARNING  # Default level
+        if debug_level is not None:
+            level = getattr(logging, debug_level.upper(), None)
+            if not isinstance(level, int):
+                raise ValueError(f"Invalid log level: {debug_level}")
+
+        logging.basicConfig(format="%(asctime)s [%(name)s] [%(levelname)s] %(message)s", level=level)
+        self.logger = logging.getLogger(__name__)
+        return
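BaseAppConfig deliberately has no default embedding function: default_db raises unless an ef is supplied. A sketch of constructing it directly with an open source ef (illustrative; the concrete subclasses normally do this for you):

from chromadb.utils import embedding_functions

from embedchain.config.apps.BaseAppConfig import BaseAppConfig

ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="all-MiniLM-L6-v2")
config = BaseAppConfig(ef=ef, host=None, port=None)  # db defaults to ChromaDB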
19  embedchain/config/apps/CustomAppConfig.py  Normal file
@@ -0,0 +1,19 @@
+from .BaseAppConfig import BaseAppConfig
+
+
+class CustomAppConfig(BaseAppConfig):
+    """
+    Config to initialize an embedchain custom `App` instance, with extra config options.
+    """
+
+    def __init__(self, log_level=None, ef=None, db=None, host=None, port=None, id=None):
+        """
+        :param log_level: Optional. (String) Debug level
+        ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'].
+        :param ef: Optional. Embedding function to use.
+        :param db: Optional. (Vector) database to use for embeddings.
+        :param id: Optional. ID of the app. Document metadata will have this id.
+        :param host: Optional. Hostname for the database server.
+        :param port: Optional. Port for the database server.
+        """
+        super().__init__(log_level=log_level, db=db, host=host, port=port, id=id)
30  embedchain/config/apps/OpenSourceAppConfig.py  Normal file
@@ -0,0 +1,30 @@
+from chromadb.utils import embedding_functions
+
+from .BaseAppConfig import BaseAppConfig
+
+
+class OpenSourceAppConfig(BaseAppConfig):
+    """
+    Config to initialize an embedchain custom `OpenSourceApp` instance, with extra config options.
+    """
+
+    def __init__(self, log_level=None, host=None, port=None, id=None):
+        """
+        :param log_level: Optional. (String) Debug level
+        ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'].
+        :param id: Optional. ID of the app. Document metadata will have this id.
+        :param host: Optional. Hostname for the database server.
+        :param port: Optional. Port for the database server.
+        """
+        super().__init__(
+            log_level=log_level, ef=OpenSourceAppConfig.default_embedding_function(), host=host, port=port, id=id
+        )
+
+    @staticmethod
+    def default_embedding_function():
+        """
+        Sets embedding function to default (`all-MiniLM-L6-v2`).
+
+        :returns: The default embedding function
+        """
+        return embedding_functions.SentenceTransformerEmbeddingFunction(model_name="all-MiniLM-L6-v2")
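Taken together, the old all-purpose InitConfig is replaced by a small hierarchy: BaseAppConfig holds the shared db and logging plumbing, while AppConfig, OpenSourceAppConfig and CustomAppConfig pin down the embedding function. A pairing sketch (not part of the diff):

from embedchain import App, OpenSourceApp
from embedchain.config import AppConfig, OpenSourceAppConfig

app = App(AppConfig(id="docs-bot"))  # OpenAI embeddings, needs OPENAI_API_KEY
oss_app = OpenSourceApp(OpenSourceAppConfig(log_level="INFO"))  # local embeddings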
0  embedchain/config/apps/__init__.py  Normal file
embedchain/embedchain.py
@@ -1,15 +1,13 @@
 import logging
 import os
-from string import Template
 
-import openai
-from chromadb.utils import embedding_functions
 from dotenv import load_dotenv
 from langchain.docstore.document import Document
 from langchain.memory import ConversationBufferMemory
 
-from embedchain.config import AddConfig, ChatConfig, InitConfig, QueryConfig
-from embedchain.config.QueryConfig import DOCS_SITE_PROMPT_TEMPLATE, DEFAULT_PROMPT, DEFAULT_PROMPT_WITH_HISTORY
+from embedchain.config import AddConfig, ChatConfig, QueryConfig
+from embedchain.config.apps.BaseAppConfig import BaseAppConfig
+from embedchain.config.QueryConfig import DOCS_SITE_PROMPT_TEMPLATE
 from embedchain.data_formatter import DataFormatter
 
 gpt4all_model = None
@@ -23,7 +21,7 @@ memory = ConversationBufferMemory()
 
 
 class EmbedChain:
-    def __init__(self, config: InitConfig):
+    def __init__(self, config: BaseAppConfig):
         """
         Initializes the EmbedChain instance, sets up a vector DB client and
         creates a collection.
@@ -139,7 +137,10 @@ class EmbedChain:
             )
         ]
 
-    def get_llm_model_answer(self, prompt):
+    def get_llm_model_answer(self):
+        """
+        Usually implemented by child class
+        """
         raise NotImplementedError
 
     def retrieve_from_database(self, input_query, config: QueryConfig):
@@ -329,152 +330,3 @@ class EmbedChain:
         `App` has to be reinitialized after using this method.
         """
         self.db_client.reset()
-
-
-class App(EmbedChain):
-    """
-    The EmbedChain app.
-    Has two functions: add and query.
-
-    adds(data_type, url): adds the data from the given URL to the vector db.
-    query(query): finds answer to the given query using vector database and LLM.
-    dry_run(query): test your prompt without consuming tokens.
-    """
-
-    def __init__(self, config: InitConfig = None):
-        """
-        :param config: InitConfig instance to load as configuration. Optional.
-        """
-        if config is None:
-            config = InitConfig()
-
-        if not config.ef:
-            config._set_embedding_function_to_default()
-
-        if not config.db:
-            config._set_db_to_default()
-
-        super().__init__(config)
-
-    def get_llm_model_answer(self, prompt, config: ChatConfig):
-        messages = []
-        messages.append({"role": "user", "content": prompt})
-        response = openai.ChatCompletion.create(
-            model=config.model,
-            messages=messages,
-            temperature=config.temperature,
-            max_tokens=config.max_tokens,
-            top_p=config.top_p,
-            stream=config.stream,
-        )
-
-        if config.stream:
-            return self._stream_llm_model_response(response)
-        else:
-            return response["choices"][0]["message"]["content"]
-
-    def _stream_llm_model_response(self, response):
-        """
-        This is a generator for streaming response from the OpenAI completions API
-        """
-        for line in response:
-            chunk = line["choices"][0].get("delta", {}).get("content", "")
-            yield chunk
-
-
-class OpenSourceApp(EmbedChain):
-    """
-    The OpenSource app.
-    Same as App, but uses an open source embedding model and LLM.
-
-    Has two function: add and query.
-
-    adds(data_type, url): adds the data from the given URL to the vector db.
-    query(query): finds answer to the given query using vector database and LLM.
-    """
-
-    def __init__(self, config: InitConfig = None):
-        """
-        :param config: InitConfig instance to load as configuration. Optional.
-        `ef` defaults to open source.
-        """
-        print("Loading open source embedding model. This may take some time...")  # noqa:E501
-        if not config:
-            config = InitConfig()
-
-        if not config.ef:
-            config._set_embedding_function(
-                embedding_functions.SentenceTransformerEmbeddingFunction(model_name="all-MiniLM-L6-v2")
-            )
-
-        if not config.db:
-            config._set_db_to_default()
-
-        print("Successfully loaded open source embedding model.")
-        super().__init__(config)
-
-    def get_llm_model_answer(self, prompt, config: ChatConfig):
-        from gpt4all import GPT4All
-
-        global gpt4all_model
-        if gpt4all_model is None:
-            gpt4all_model = GPT4All("orca-mini-3b.ggmlv3.q4_0.bin")
-        response = gpt4all_model.generate(prompt=prompt, streaming=config.stream)
-        return response
-
-
-class EmbedChainPersonApp:
-    """
-    Base class to create a person bot.
-    This bot behaves and speaks like a person.
-
-    :param person: name of the person, better if its a well known person.
-    :param config: InitConfig instance to load as configuration.
-    """
-
-    def __init__(self, person, config: InitConfig = None):
-        self.person = person
-        self.person_prompt = f"You are {person}. Whatever you say, you will always say in {person} style."  # noqa:E501
-        if config is None:
-            config = InitConfig()
-        super().__init__(config)
-
-
-class PersonApp(EmbedChainPersonApp, App):
-    """
-    The Person app.
-    Extends functionality from EmbedChainPersonApp and App
-    """
-
-    def query(self, input_query, config: QueryConfig = None):
-        self.template = Template(self.person_prompt + " " + DEFAULT_PROMPT)
-        query_config = QueryConfig(
-            template=self.template,
-        )
-        return super().query(input_query, query_config)
-
-    def chat(self, input_query, config: ChatConfig = None):
-        self.template = Template(self.person_prompt + " " + DEFAULT_PROMPT_WITH_HISTORY)
-        chat_config = ChatConfig(
-            template=self.template,
-        )
-        return super().chat(input_query, chat_config)
-
-
-class PersonOpenSourceApp(EmbedChainPersonApp, OpenSourceApp):
-    """
-    The Person app.
-    Extends functionality from EmbedChainPersonApp and OpenSourceApp
-    """
-
-    def query(self, input_query, config: QueryConfig = None):
-        query_config = QueryConfig(
-            template=self.template,
-        )
-        return super().query(input_query, query_config)
-
-    def chat(self, input_query, config: ChatConfig = None):
-        chat_config = ChatConfig(
-            template=self.template,
-        )
-        return super().chat(input_query, chat_config)
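After this change, embedchain.py keeps only the shared pipeline; the concrete apps live under embedchain/apps/ and implement get_llm_model_answer themselves. A sketch of what a custom app now looks like (class name and body are illustrative, not from the diff):

from embedchain.config import ChatConfig
from embedchain.config.apps.BaseAppConfig import BaseAppConfig
from embedchain.embedchain import EmbedChain


class EchoApp(EmbedChain):  # hypothetical subclass
    def __init__(self, config: BaseAppConfig):
        super().__init__(config)

    def get_llm_model_answer(self, prompt, config: ChatConfig):
        return f"echo: {prompt}"  # plug any LLM call in here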
@@ -3,7 +3,7 @@ import unittest
 from unittest.mock import patch
 
 from embedchain import App
-from embedchain.config import InitConfig
+from embedchain.config import AppConfig
 
 
 class TestChromaDbHostsLoglevel(unittest.TestCase):
@@ -25,7 +25,7 @@ class TestChromaDbHostsLoglevel(unittest.TestCase):
         """
         Test if the `App` instance is initialized without a config that does not contain default hosts and ports.
         """
-        config = InitConfig(log_level="DEBUG")
+        config = AppConfig(log_level="DEBUG")
 
         app = App(config)
 
@@ -4,7 +4,7 @@ import unittest
 from unittest.mock import patch
 
 from embedchain import App
-from embedchain.config import InitConfig
+from embedchain.config import AppConfig
 from embedchain.vectordb.chroma_db import ChromaDB, chromadb
 
 
@@ -38,7 +38,7 @@ class TestChromaDbHostsInit(unittest.TestCase):
         host = "test-host"
         port = "1234"
 
-        config = InitConfig(host=host, port=port)
+        config = AppConfig(host=host, port=port)
 
         _app = App(config)
 
@@ -65,7 +65,7 @@ class TestChromaDbHostsLoglevel(unittest.TestCase):
         """
         Test if the `App` instance is initialized without a config that does not contain default hosts and ports.
        """
-        config = InitConfig(log_level="DEBUG")
+        config = AppConfig(log_level="DEBUG")
 
         _app = App(config)