refactor: Add config for init, app and query (#158)

cachho
2023-07-06 20:16:40 +02:00
committed by GitHub
parent 68e732a426
commit e50c7e6843
8 changed files with 183 additions and 23 deletions

README.md

@@ -265,6 +265,64 @@ _The embedding is confirmed to work as expected. It returns the right document,
**The dry run will still consume tokens to embed your query, but it is only ~1/15 of the prompt.**
# Advanced
## Configuration
Embedchain is designed to work out of the box. For advanced users, we also offer configuration options. All of these options are optional and have sane defaults.
### Example
Here's the README example, extended with configuration options.
```python
import os

from embedchain import App
from embedchain.config import InitConfig, AddConfig, QueryConfig
from chromadb.utils import embedding_functions

# Example: use your own embedding function
config = InitConfig(ef=embedding_functions.OpenAIEmbeddingFunction(
    api_key=os.getenv("OPENAI_API_KEY"),
    organization_id=os.getenv("OPENAI_ORGANIZATION"),
    model_name="text-embedding-ada-002"
))
naval_chat_bot = App(config)

add_config = AddConfig()  # Currently no options
naval_chat_bot.add("youtube_video", "https://www.youtube.com/watch?v=3qHkcs3kG44", add_config)
naval_chat_bot.add("pdf_file", "https://navalmanack.s3.amazonaws.com/Eric-Jorgenson_The-Almanack-of-Naval-Ravikant_Final.pdf", add_config)
naval_chat_bot.add("web_page", "https://nav.al/feedback", add_config)
naval_chat_bot.add("web_page", "https://nav.al/agi", add_config)
naval_chat_bot.add_local("qna_pair", ("Who is Naval Ravikant?", "Naval Ravikant is an Indian-American entrepreneur and investor."), add_config)

query_config = QueryConfig()  # Currently no options
print(naval_chat_bot.query("What unique capacity does Naval argue humans possess when it comes to understanding explanations or concepts?", query_config))
```
### Configs
This section describes all possible config options.
#### **InitConfig**
|option|description|type|default|
|---|---|---|---|
|ef|embedding function|chromadb.utils.embedding_functions|`OpenAIEmbeddingFunction` with `text-embedding-ada-002`|
|db|vector database (experimental)|BaseVectorDB|ChromaDB|
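For illustration, here's a minimal sketch that sets both options explicitly. It assumes the `ChromaDB` wrapper in `embedchain.vectordb.chroma_db` accepts an `ef` argument, which is how `InitConfig` builds the default database internally.
```python
import os
from embedchain import App
from embedchain.config import InitConfig
from embedchain.vectordb.chroma_db import ChromaDB
from chromadb.utils import embedding_functions

# Build the embedding function once and share it with the database wrapper
ef = embedding_functions.OpenAIEmbeddingFunction(
    api_key=os.getenv("OPENAI_API_KEY"),
    model_name="text-embedding-ada-002"
)
app = App(InitConfig(ef=ef, db=ChromaDB(ef=ef)))
```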
#### **AddConfig**
*coming soon*
#### **QueryConfig**
*coming soon*
#### **ChatConfig**
Inherits all options from `QueryConfig`, plus:
*coming soon*
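Because `ChatConfig` inherits from `QueryConfig` and currently adds no options of its own, using it mirrors the query example above. A minimal sketch, reusing `naval_chat_bot` from the example (the question string is illustrative):
```python
from embedchain.config import ChatConfig

chat_config = ChatConfig()  # Currently no options
print(naval_chat_bot.chat("What does Naval say about reading?", chat_config))
```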
# How does it work?
Creating a chat bot over any dataset requires the following steps:

embedchain/config/AddConfig.py

@@ -0,0 +1,8 @@
from embedchain.config.BaseConfig import BaseConfig

class AddConfig(BaseConfig):
    """
    Config for the `add` method.
    """
    def __init__(self):
        pass

embedchain/config/BaseConfig.py

@@ -0,0 +1,9 @@
class BaseConfig:
    """
    Base config.
    """
    def __init__(self):
        pass

    def as_dict(self):
        return vars(self)
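`as_dict` simply exposes the instance attributes via `vars`, so every subclass inherits a dict view of whatever its `__init__` sets. A quick sketch of what that returns (the `InitConfig` line assumes `OPENAI_API_KEY` is set, since its defaults construct an OpenAI embedding function):
```python
from embedchain.config import InitConfig, QueryConfig

QueryConfig().as_dict()  # {} -- QueryConfig sets no attributes yet
InitConfig().as_dict()   # {'ef': <embedding function>, 'db': <ChromaDB instance>}
```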

embedchain/config/ChatConfig.py

@@ -0,0 +1,8 @@
from embedchain.config.QueryConfig import QueryConfig

class ChatConfig(QueryConfig):
    """
    Config for the `chat` method, inherits from `QueryConfig`.
    """
    def __init__(self):
        pass

embedchain/config/InitConfig.py

@@ -0,0 +1,36 @@
import os

from embedchain.config.BaseConfig import BaseConfig

class InitConfig(BaseConfig):
    """
    Config to initialize an embedchain `App` instance.
    """
    def __init__(self, ef=None, db=None):
        """
        :param ef: Optional. Embedding function to use.
        :param db: Optional. (Vector) database to use for embeddings.
        """
        # Embedding Function
        if ef is None:
            from chromadb.utils import embedding_functions
            self.ef = embedding_functions.OpenAIEmbeddingFunction(
                api_key=os.getenv("OPENAI_API_KEY"),
                organization_id=os.getenv("OPENAI_ORGANIZATION"),
                model_name="text-embedding-ada-002"
            )
        else:
            self.ef = ef

        if db is None:
            from embedchain.vectordb.chroma_db import ChromaDB
            self.db = ChromaDB(ef=self.ef)
        else:
            self.db = db
        return

    def _set_embedding_function(self, ef):
        self.ef = ef
        return
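Note that the heavy imports (`chromadb.utils.embedding_functions`, `ChromaDB`) happen lazily inside `__init__`, so they are only pulled in when you don't supply your own objects. The default embedding function reads its credentials from the environment; a minimal sketch of relying on those defaults (the key value is a placeholder):
```python
import os
from embedchain.config import InitConfig

# Placeholder key; the default ef is OpenAI's text-embedding-ada-002
os.environ["OPENAI_API_KEY"] = "sk-..."
config = InitConfig()  # ef -> OpenAIEmbeddingFunction, db -> ChromaDB(ef=config.ef)
```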

embedchain/config/QueryConfig.py

@@ -0,0 +1,8 @@
from embedchain.config.BaseConfig import BaseConfig

class QueryConfig(BaseConfig):
    """
    Config for the `query` method.
    """
    def __init__(self):
        pass

embedchain/config/__init__.py

@@ -0,0 +1,5 @@
from .BaseConfig import BaseConfig
from .AddConfig import AddConfig
from .ChatConfig import ChatConfig
from .InitConfig import InitConfig
from .QueryConfig import QueryConfig

embedchain/embedchain.py

@@ -6,6 +6,7 @@ from dotenv import load_dotenv
 from langchain.docstore.document import Document
 from langchain.embeddings.openai import OpenAIEmbeddings
 from langchain.memory import ConversationBufferMemory
+from embedchain.config import InitConfig, AddConfig, QueryConfig, ChatConfig
 from embedchain.loaders.youtube_video import YoutubeVideoLoader
 from embedchain.loaders.pdf_file import PdfFileLoader
@@ -33,17 +34,17 @@ memory = ConversationBufferMemory()
 class EmbedChain:
-    def __init__(self, db=None, ef=None):
+    def __init__(self, config: InitConfig):
         """
         Initializes the EmbedChain instance, sets up a vector DB client and
         creates a collection.

-        :param db: The instance of the VectorDB subclass.
+        :param config: InitConfig instance to load as configuration.
         """
-        if db is None:
-            db = ChromaDB(ef=ef)
-        self.db_client = db.client
-        self.collection = db.collection
+        self.config = config
+        self.db_client = self.config.db.client
+        self.collection = self.config.db.collection
         self.user_asks = []

     def _get_loader(self, data_type):
@@ -86,7 +87,7 @@ class EmbedChain:
         else:
             raise ValueError(f"Unsupported data type: {data_type}")

-    def add(self, data_type, url):
+    def add(self, data_type, url, config: AddConfig = None):
         """
         Adds the data from the given URL to the vector db.
         Loads the data, chunks it, create embedding for each chunk
@@ -94,13 +95,16 @@ class EmbedChain:
         :param data_type: The type of the data to add.
         :param url: The URL where the data is located.
+        :param config: Optional. The `AddConfig` instance to use as configuration options.
         """
+        if config is None:
+            config = AddConfig()
         loader = self._get_loader(data_type)
         chunker = self._get_chunker(data_type)
         self.user_asks.append([data_type, url])
         self.load_and_embed(loader, chunker, url)

-    def add_local(self, data_type, content):
+    def add_local(self, data_type, content, config: AddConfig = None):
         """
         Adds the data you supply to the vector db.
         Loads the data, chunks it, create embedding for each chunk
@@ -108,7 +112,10 @@ class EmbedChain:
         :param data_type: The type of the data to add.
         :param content: The local data. Refer to the `README` for formatting.
+        :param config: Optional. The `AddConfig` instance to use as configuration options.
         """
+        if config is None:
+            config = AddConfig()
         loader = self._get_loader(data_type)
         chunker = self._get_chunker(data_type)
         self.user_asks.append([data_type, content])
@@ -210,15 +217,18 @@ class EmbedChain:
         answer = self.get_llm_model_answer(prompt)
         return answer

-    def query(self, input_query):
+    def query(self, input_query, config: QueryConfig = None):
         """
         Queries the vector database based on the given input query.
         Gets relevant doc based on the query and then passes it to an
         LLM as context to get the answer.

         :param input_query: The query to use.
+        :param config: Optional. The `QueryConfig` instance to use as configuration options.
         :return: The answer to the query.
         """
+        if config is None:
+            config = QueryConfig()
         context = self.retrieve_from_database(input_query)
         prompt = self.generate_prompt(input_query, context)
         answer = self.get_answer_from_llm(prompt)
@@ -243,14 +253,19 @@ class EmbedChain:
             prompt += suffix_prompt
         return prompt

-    def chat(self, input_query):
+    def chat(self, input_query, config: ChatConfig = None):
         """
         Queries the vector database on the given input query.
         Gets relevant doc based on the query and then passes it to an
         LLM as context to get the answer.
         Maintains last 5 conversations in memory.

         :param input_query: The query to use.
+        :param config: Optional. The `ChatConfig` instance to use as configuration options.
         :return: The answer to the query.
         """
+        if config is None:
+            config = ChatConfig()
         context = self.retrieve_from_database(input_query)
         global memory
         chat_history = memory.load_memory_variables({})["history"]
@@ -274,8 +289,11 @@ class EmbedChain:
         the `max_tokens` parameter.

         :param input_query: The query to use.
+        :param config: Optional. The `QueryConfig` instance to use as configuration options.
         :return: The prompt that would be sent to the LLM
         """
+        if config is None:
+            config = QueryConfig()
         context = self.retrieve_from_database(input_query)
         prompt = self.generate_prompt(input_query, context)
         return prompt
@@ -291,14 +309,13 @@ class App(EmbedChain):
     dry_run(query): test your prompt without consuming tokens.
     """

-    def __int__(self, db=None, ef=None):
-        if ef is None:
-            ef = embedding_functions.OpenAIEmbeddingFunction(
-                api_key=os.getenv("OPENAI_API_KEY"),
-                organization_id=os.getenv("OPENAI_ORGANIZATION"),
-                model_name="text-embedding-ada-002"
-            )
-        super().__init__(db, ef)
+    def __init__(self, config: InitConfig = None):
+        """
+        :param config: InitConfig instance to load as configuration. Optional.
+        """
+        if config is None:
+            config = InitConfig()
+        super().__init__(config)

     def get_llm_model_answer(self, prompt):
         messages = []
@@ -326,14 +343,25 @@ class OpenSourceApp(EmbedChain):
     query(query): finds answer to the given query using vector database and LLM.
     """

-    def __init__(self, db=None, ef=None):
+    def __init__(self, config: InitConfig = None):
+        """
+        :param config: InitConfig instance to load as configuration. Optional. `ef` defaults to open source.
+        """
         print("Loading open source embedding model. This may take some time...")
-        if ef is None:
-            ef = embedding_functions.SentenceTransformerEmbeddingFunction(
-                model_name="all-MiniLM-L6-v2"
-            )
+        if not config or not config.ef:
+            if config is None:
+                config = InitConfig(
+                    ef=embedding_functions.SentenceTransformerEmbeddingFunction(
+                        model_name="all-MiniLM-L6-v2"
+                    )
+                )
+            else:
+                config._set_embedding_function(
+                    embedding_functions.SentenceTransformerEmbeddingFunction(
+                        model_name="all-MiniLM-L6-v2"
+                    ))
         print("Successfully loaded open source embedding model.")
-        super().__init__(db, ef)
+        super().__init__(config)

     def get_llm_model_answer(self, prompt):
         from gpt4all import GPT4All
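For completeness, a minimal usage sketch of the refactored `OpenSourceApp`. Called with no arguments, it wraps the `all-MiniLM-L6-v2` sentence-transformer embedding function in an `InitConfig`, as shown in the diff above; the URL and question are illustrative, and it assumes `OpenSourceApp` is exported at the package level like `App`.
```python
from embedchain import OpenSourceApp

app = OpenSourceApp()  # loads the open source embedding model by default
app.add("web_page", "https://nav.al/feedback")
print(app.query("What is this page about?"))
```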