refactor: Add config for init, app and query (#158)
**README.md** (+58)

@@ -265,6 +265,64 @@ _The embedding is confirmed to work as expected. It returns the right document,

**The dry run will still consume tokens to embed your query, but it is only ~1/15 of the prompt.**
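
For instance, a dry run is just a method call on the bot (a minimal sketch, reusing the `naval_chat_bot` instance from the example above):

```python
# dry_run returns the prompt that would be sent to the LLM without
# making the completion call, so only the query embedding is billed.
print(naval_chat_bot.dry_run("What did Naval say about reading?"))
```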

# Advanced

## Configuration

Embedchain is designed to work out of the box. For advanced users, however, we also offer configuration options, all of which are optional and come with sane defaults.

### Example

Here's the README example, rewritten with configuration options.

```python
import os

from embedchain import App
from embedchain.config import InitConfig, AddConfig, QueryConfig
from chromadb.utils import embedding_functions

# Example: use your own embedding function
config = InitConfig(ef=embedding_functions.OpenAIEmbeddingFunction(
    api_key=os.getenv("OPENAI_API_KEY"),
    organization_id=os.getenv("OPENAI_ORGANIZATION"),
    model_name="text-embedding-ada-002"
))
naval_chat_bot = App(config)

add_config = AddConfig()  # Currently no options
naval_chat_bot.add("youtube_video", "https://www.youtube.com/watch?v=3qHkcs3kG44", add_config)
naval_chat_bot.add("pdf_file", "https://navalmanack.s3.amazonaws.com/Eric-Jorgenson_The-Almanack-of-Naval-Ravikant_Final.pdf", add_config)
naval_chat_bot.add("web_page", "https://nav.al/feedback", add_config)
naval_chat_bot.add("web_page", "https://nav.al/agi", add_config)

naval_chat_bot.add_local("qna_pair", ("Who is Naval Ravikant?", "Naval Ravikant is an Indian-American entrepreneur and investor."), add_config)

query_config = QueryConfig()  # Currently no options
print(naval_chat_bot.query("What unique capacity does Naval argue humans possess when it comes to understanding explanations or concepts?", query_config))
```

### Configs

This section describes all possible config options.

#### **InitConfig**

|option|description|type|default|
|---|---|---|---|
|ef|embedding function|chromadb.utils.embedding_functions|{text-embedding-ada-002}|
|db|vector database (experimental)|BaseVectorDB|ChromaDB|
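
For illustration, here is a minimal sketch of the experimental `db` option. It assumes `ChromaDB` takes the embedding function via `ef`, mirroring the default wiring inside `InitConfig` (shown later in this diff):

```python
import os

from embedchain import App
from embedchain.config import InitConfig
from embedchain.vectordb.chroma_db import ChromaDB
from chromadb.utils import embedding_functions

# Assumption: ChromaDB(ef=...) matches the default path in InitConfig,
# where db falls back to ChromaDB(ef=self.ef).
ef = embedding_functions.OpenAIEmbeddingFunction(
    api_key=os.getenv("OPENAI_API_KEY"),
    model_name="text-embedding-ada-002"
)
config = InitConfig(ef=ef, db=ChromaDB(ef=ef))
app = App(config)
```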

#### **AddConfig**

*coming soon*

#### **QueryConfig**

*coming soon*

#### **ChatConfig**

All options from the query config, and...

*coming soon*

# How does it work?

Creating a chat bot over any dataset requires the following steps to happen

**embedchain/config/AddConfig.py** (new file, +8)

```python
from embedchain.config.BaseConfig import BaseConfig

class AddConfig(BaseConfig):
    """
    Config for the `add` method.
    """
    def __init__(self):
        pass
```

**embedchain/config/BaseConfig.py** (new file, +9)

```python
class BaseConfig:
    """
    Base config.
    """
    def __init__(self):
        pass

    def as_dict(self):
        return vars(self)
```
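
Because every config class inherits from `BaseConfig`, any instance can be dumped for inspection — a small sketch:

```python
from embedchain.config import InitConfig

config = InitConfig()
# as_dict() is simply vars(self), so it returns the instance attributes,
# here the embedding function and the vector database.
print(config.as_dict())  # {'ef': <...EmbeddingFunction...>, 'db': <...ChromaDB...>}
```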

**embedchain/config/ChatConfig.py** (new file, +8)

```python
from embedchain.config.QueryConfig import QueryConfig

class ChatConfig(QueryConfig):
    """
    Config for the `chat` method, inherits from `QueryConfig`.
    """
    def __init__(self):
        pass
```

**embedchain/config/InitConfig.py** (new file, +36)

```python
import os

from embedchain.config.BaseConfig import BaseConfig


class InitConfig(BaseConfig):
    """
    Config to initialize an embedchain `App` instance.
    """
    def __init__(self, ef=None, db=None):
        """
        :param ef: Optional. Embedding function to use.
        :param db: Optional. (Vector) database to use for embeddings.
        """
        # Embedding function
        if ef is None:
            from chromadb.utils import embedding_functions
            self.ef = embedding_functions.OpenAIEmbeddingFunction(
                api_key=os.getenv("OPENAI_API_KEY"),
                organization_id=os.getenv("OPENAI_ORGANIZATION"),
                model_name="text-embedding-ada-002"
            )
        else:
            self.ef = ef

        # Vector database
        if db is None:
            from embedchain.vectordb.chroma_db import ChromaDB
            self.db = ChromaDB(ef=self.ef)
        else:
            self.db = db

    def _set_embedding_function(self, ef):
        self.ef = ef
```

**embedchain/config/QueryConfig.py** (new file, +8)

```python
from embedchain.config.BaseConfig import BaseConfig

class QueryConfig(BaseConfig):
    """
    Config for the `query` method.
    """
    def __init__(self):
        pass
```

**embedchain/config/__init__.py** (new file, +5)

```python
from .BaseConfig import BaseConfig
from .AddConfig import AddConfig
from .ChatConfig import ChatConfig
from .InitConfig import InitConfig
from .QueryConfig import QueryConfig
```

**embedchain/embedchain.py**

```diff
@@ -6,6 +6,7 @@ from dotenv import load_dotenv
 from langchain.docstore.document import Document
 from langchain.embeddings.openai import OpenAIEmbeddings
 from langchain.memory import ConversationBufferMemory
+from embedchain.config import InitConfig, AddConfig, QueryConfig, ChatConfig
 
 from embedchain.loaders.youtube_video import YoutubeVideoLoader
 from embedchain.loaders.pdf_file import PdfFileLoader
```

```diff
@@ -33,17 +34,17 @@ memory = ConversationBufferMemory()
 
 
 class EmbedChain:
-    def __init__(self, db=None, ef=None):
+    def __init__(self, config: InitConfig):
         """
         Initializes the EmbedChain instance, sets up a vector DB client and
         creates a collection.
 
-        :param db: The instance of the VectorDB subclass.
+        :param config: InitConfig instance to load as configuration.
         """
-        if db is None:
-            db = ChromaDB(ef=ef)
-        self.db_client = db.client
-        self.collection = db.collection
+        self.config = config
+        self.db_client = self.config.db.client
+        self.collection = self.config.db.collection
         self.user_asks = []
 
     def _get_loader(self, data_type):
```

```diff
@@ -86,7 +87,7 @@ class EmbedChain:
         else:
             raise ValueError(f"Unsupported data type: {data_type}")
 
-    def add(self, data_type, url):
+    def add(self, data_type, url, config: AddConfig = None):
         """
         Adds the data from the given URL to the vector db.
         Loads the data, chunks it, create embedding for each chunk
@@ -94,13 +95,16 @@
 
         :param data_type: The type of the data to add.
         :param url: The URL where the data is located.
+        :param config: Optional. The `AddConfig` instance to use as configuration options.
         """
+        if config is None:
+            config = AddConfig()
         loader = self._get_loader(data_type)
         chunker = self._get_chunker(data_type)
         self.user_asks.append([data_type, url])
         self.load_and_embed(loader, chunker, url)
 
-    def add_local(self, data_type, content):
+    def add_local(self, data_type, content, config: AddConfig = None):
         """
         Adds the data you supply to the vector db.
         Loads the data, chunks it, create embedding for each chunk
@@ -108,7 +112,10 @@
 
         :param data_type: The type of the data to add.
         :param content: The local data. Refer to the `README` for formatting.
+        :param config: Optional. The `AddConfig` instance to use as configuration options.
         """
+        if config is None:
+            config = AddConfig()
         loader = self._get_loader(data_type)
         chunker = self._get_chunker(data_type)
         self.user_asks.append([data_type, content])
```

```diff
@@ -210,15 +217,18 @@
         answer = self.get_llm_model_answer(prompt)
         return answer
 
-    def query(self, input_query):
+    def query(self, input_query, config: QueryConfig = None):
         """
         Queries the vector database based on the given input query.
         Gets relevant doc based on the query and then passes it to an
         LLM as context to get the answer.
 
         :param input_query: The query to use.
+        :param config: Optional. The `QueryConfig` instance to use as configuration options.
         :return: The answer to the query.
         """
+        if config is None:
+            config = QueryConfig()
         context = self.retrieve_from_database(input_query)
         prompt = self.generate_prompt(input_query, context)
         answer = self.get_answer_from_llm(prompt)
```

```diff
@@ -243,14 +253,19 @@
         prompt += suffix_prompt
         return prompt
 
-    def chat(self, input_query):
+    def chat(self, input_query, config: ChatConfig = None):
         """
         Queries the vector database on the given input query.
         Gets relevant doc based on the query and then passes it to an
         LLM as context to get the answer.
 
         Maintains last 5 conversations in memory.
         :param input_query: The query to use.
+        :param config: Optional. The `ChatConfig` instance to use as configuration options.
         :return: The answer to the query.
         """
+        if config is None:
+            config = ChatConfig()
         context = self.retrieve_from_database(input_query)
         global memory
         chat_history = memory.load_memory_variables({})["history"]
```

```diff
@@ -274,8 +289,11 @@
         the `max_tokens` parameter.
 
         :param input_query: The query to use.
+        :param config: Optional. The `QueryConfig` instance to use as configuration options.
         :return: The prompt that would be sent to the LLM
         """
+        if config is None:
+            config = QueryConfig()
         context = self.retrieve_from_database(input_query)
         prompt = self.generate_prompt(input_query, context)
         return prompt
```

```diff
@@ -291,14 +309,13 @@ class App(EmbedChain):
     dry_run(query): test your prompt without consuming tokens.
     """
 
-    def __int__(self, db=None, ef=None):
-        if ef is None:
-            ef = embedding_functions.OpenAIEmbeddingFunction(
-                api_key=os.getenv("OPENAI_API_KEY"),
-                organization_id=os.getenv("OPENAI_ORGANIZATION"),
-                model_name="text-embedding-ada-002"
-            )
-        super().__init__(db, ef)
+    def __init__(self, config: InitConfig = None):
+        """
+        :param config: InitConfig instance to load as configuration. Optional.
+        """
+        if config is None:
+            config = InitConfig()
+        super().__init__(config)
 
     def get_llm_model_answer(self, prompt):
         messages = []
```

```diff
@@ -326,14 +343,25 @@ class OpenSourceApp(EmbedChain):
     query(query): finds answer to the given query using vector database and LLM.
     """
 
-    def __init__(self, db=None, ef=None):
+    def __init__(self, config: InitConfig = None):
+        """
+        :param config: InitConfig instance to load as configuration. Optional. `ef` defaults to open source.
+        """
         print("Loading open source embedding model. This may take some time...")
-        if ef is None:
-            ef = embedding_functions.SentenceTransformerEmbeddingFunction(
-                model_name="all-MiniLM-L6-v2"
-            )
+        if not config or not config.ef:
+            if config is None:
+                config = InitConfig(
+                    ef=embedding_functions.SentenceTransformerEmbeddingFunction(
+                        model_name="all-MiniLM-L6-v2"
+                    )
+                )
+            else:
+                config._set_embedding_function(
+                    embedding_functions.SentenceTransformerEmbeddingFunction(
+                        model_name="all-MiniLM-L6-v2"
+                    ))
         print("Successfully loaded open source embedding model.")
-        super().__init__(db, ef)
+        super().__init__(config)
 
     def get_llm_model_answer(self, prompt):
         from gpt4all import GPT4All
```
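
Taken together, every new `config` parameter is optional, so existing call sites keep working. A minimal sketch of the zero-config path (assuming `App` and `OpenSourceApp` are importable from the package root, as in the README example):

```python
from embedchain import App, OpenSourceApp

# App() builds an InitConfig() internally, which wires up the default
# OpenAI embedding function and a ChromaDB instance.
bot = App()
bot.add("web_page", "https://nav.al/feedback")  # config defaults to AddConfig()
print(bot.query("What is this page about?"))    # config defaults to QueryConfig()

# The open source variant falls back to a SentenceTransformer embedding function.
os_bot = OpenSourceApp()
```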