refactor: Add config for init, app and query (#158)
This commit is contained in:
58
README.md
58
README.md
@@ -265,6 +265,64 @@ _The embedding is confirmed to work as expected. It returns the right document,
|
|||||||
|
|
||||||
**The dry run will still consume tokens to embed your query, but it is only ~1/15 of the prompt.**
|
**The dry run will still consume tokens to embed your query, but it is only ~1/15 of the prompt.**
|
||||||
|
|
||||||
|
# Advanced
|
||||||
|
|
||||||
|
## Configuration
|
||||||
|
Embedchain works out of the box. For advanced users, it also offers configuration options; all of them are optional and have sensible defaults.
|
||||||
|
|
||||||
|
### Example
|
||||||
|
|
||||||
|
Here's the main README example again, this time with configuration options passed explicitly.
|
||||||
|
|
||||||
|
```python
|
||||||
|
import os
|
||||||
|
from embedchain import App
|
||||||
|
from embedchain.config import InitConfig, AddConfig, QueryConfig
|
||||||
|
from chromadb.utils import embedding_functions
|
||||||
|
|
||||||
|
# Example: use your own embedding function
|
||||||
|
config = InitConfig(ef=embedding_functions.OpenAIEmbeddingFunction(
|
||||||
|
api_key=os.getenv("OPENAI_API_KEY"),
|
||||||
|
organization_id=os.getenv("OPENAI_ORGANIZATION"),
|
||||||
|
model_name="text-embedding-ada-002"
|
||||||
|
))
|
||||||
|
naval_chat_bot = App(config)
|
||||||
|
|
||||||
|
add_config = AddConfig() # Currently no options
|
||||||
|
naval_chat_bot.add("youtube_video", "https://www.youtube.com/watch?v=3qHkcs3kG44", add_config)
|
||||||
|
naval_chat_bot.add("pdf_file", "https://navalmanack.s3.amazonaws.com/Eric-Jorgenson_The-Almanack-of-Naval-Ravikant_Final.pdf", add_config)
|
||||||
|
naval_chat_bot.add("web_page", "https://nav.al/feedback", add_config)
|
||||||
|
naval_chat_bot.add("web_page", "https://nav.al/agi", add_config)
|
||||||
|
|
||||||
|
naval_chat_bot.add_local("qna_pair", ("Who is Naval Ravikant?", "Naval Ravikant is an Indian-American entrepreneur and investor."), add_config)
|
||||||
|
|
||||||
|
query_config = QueryConfig() # Currently no options
|
||||||
|
print(naval_chat_bot.query("What unique capacity does Naval argue humans possess when it comes to understanding explanations or concepts?", query_config))
|
||||||
|
```
|
||||||
|
|
||||||
|
### Configs
|
||||||
|
This section describes all possible config options.
|
||||||
|
|
||||||
|
#### **InitConfig**
|
||||||
|
|option|description|type|default|
|
||||||
|
|---|---|---|---|
|
||||||
|
|ef|embedding function|chromadb.utils.embedding_functions|OpenAIEmbeddingFunction (`text-embedding-ada-002`)|
|
||||||
|
|db|vector database (experimental)|BaseVectorDB|ChromaDB|
|
||||||
|
|
||||||
|
#### **AddConfig**
|
||||||
|
|
||||||
|
*coming soon*
|
||||||
|
|
||||||
|
#### **QueryConfig**
|
||||||
|
|
||||||
|
*coming soon*
|
||||||
|
|
||||||
|
#### **ChatConfig**
|
||||||
|
|
||||||
|
All options for query, plus the following chat-specific options:
|
||||||
|
|
||||||
|
*coming soon*
|
||||||
|
|
||||||
# How does it work?
|
# How does it work?
|
||||||
|
|
||||||
Creating a chat bot over any dataset needs the following steps to happen
|
Creating a chat bot over any dataset needs the following steps to happen
|
||||||
|
|||||||
8
embedchain/config/AddConfig.py
Normal file
8
embedchain/config/AddConfig.py
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
from embedchain.config.BaseConfig import BaseConfig
|
||||||
|
|
||||||
|
class AddConfig(BaseConfig):
|
||||||
|
"""
|
||||||
|
Config for the `add` method.
|
||||||
|
"""
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
9
embedchain/config/BaseConfig.py
Normal file
9
embedchain/config/BaseConfig.py
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
class BaseConfig:
|
||||||
|
"""
|
||||||
|
Base config.
|
||||||
|
"""
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def as_dict(self):
|
||||||
|
return vars(self)
|
||||||
8
embedchain/config/ChatConfig.py
Normal file
8
embedchain/config/ChatConfig.py
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
from embedchain.config.QueryConfig import QueryConfig
|
||||||
|
|
||||||
|
class ChatConfig(QueryConfig):
|
||||||
|
"""
|
||||||
|
Config for the `chat` method, inherits from `QueryConfig`.
|
||||||
|
"""
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
36
embedchain/config/InitConfig.py
Normal file
36
embedchain/config/InitConfig.py
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
from embedchain.config.BaseConfig import BaseConfig
|
||||||
|
|
||||||
|
class InitConfig(BaseConfig):
|
||||||
|
"""
|
||||||
|
Config to initialize an embedchain `App` instance.
|
||||||
|
"""
|
||||||
|
def __init__(self, ef=None, db=None):
|
||||||
|
"""
|
||||||
|
:param ef: Optional. Embedding function to use.
|
||||||
|
:param db: Optional. (Vector) database to use for embeddings.
|
||||||
|
"""
|
||||||
|
# Embedding Function
|
||||||
|
if ef is None:
|
||||||
|
from chromadb.utils import embedding_functions
|
||||||
|
self.ef = embedding_functions.OpenAIEmbeddingFunction(
|
||||||
|
api_key=os.getenv("OPENAI_API_KEY"),
|
||||||
|
organization_id=os.getenv("OPENAI_ORGANIZATION"),
|
||||||
|
model_name="text-embedding-ada-002"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.ef = ef
|
||||||
|
|
||||||
|
if db is None:
|
||||||
|
from embedchain.vectordb.chroma_db import ChromaDB
|
||||||
|
self.db = ChromaDB(ef=self.ef)
|
||||||
|
else:
|
||||||
|
self.db = db
|
||||||
|
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
def _set_embedding_function(self, ef):
|
||||||
|
self.ef = ef
|
||||||
|
return
|
||||||
8
embedchain/config/QueryConfig.py
Normal file
8
embedchain/config/QueryConfig.py
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
from embedchain.config.BaseConfig import BaseConfig
|
||||||
|
|
||||||
|
class QueryConfig(BaseConfig):
|
||||||
|
"""
|
||||||
|
Config for the `query` method.
|
||||||
|
"""
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
5
embedchain/config/__init__.py
Normal file
5
embedchain/config/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
from .BaseConfig import BaseConfig
|
||||||
|
from .AddConfig import AddConfig
|
||||||
|
from .ChatConfig import ChatConfig
|
||||||
|
from .InitConfig import InitConfig
|
||||||
|
from .QueryConfig import QueryConfig
|
||||||
@@ -6,6 +6,7 @@ from dotenv import load_dotenv
|
|||||||
from langchain.docstore.document import Document
|
from langchain.docstore.document import Document
|
||||||
from langchain.embeddings.openai import OpenAIEmbeddings
|
from langchain.embeddings.openai import OpenAIEmbeddings
|
||||||
from langchain.memory import ConversationBufferMemory
|
from langchain.memory import ConversationBufferMemory
|
||||||
|
from embedchain.config import InitConfig, AddConfig, QueryConfig, ChatConfig
|
||||||
|
|
||||||
from embedchain.loaders.youtube_video import YoutubeVideoLoader
|
from embedchain.loaders.youtube_video import YoutubeVideoLoader
|
||||||
from embedchain.loaders.pdf_file import PdfFileLoader
|
from embedchain.loaders.pdf_file import PdfFileLoader
|
||||||
@@ -33,17 +34,17 @@ memory = ConversationBufferMemory()
|
|||||||
|
|
||||||
|
|
||||||
class EmbedChain:
|
class EmbedChain:
|
||||||
def __init__(self, db=None, ef=None):
|
def __init__(self, config: InitConfig):
|
||||||
"""
|
"""
|
||||||
Initializes the EmbedChain instance, sets up a vector DB client and
|
Initializes the EmbedChain instance, sets up a vector DB client and
|
||||||
creates a collection.
|
creates a collection.
|
||||||
|
|
||||||
:param db: The instance of the VectorDB subclass.
|
:param config: InitConfig instance to load as configuration.
|
||||||
"""
|
"""
|
||||||
if db is None:
|
|
||||||
db = ChromaDB(ef=ef)
|
self.config = config
|
||||||
self.db_client = db.client
|
self.db_client = self.config.db.client
|
||||||
self.collection = db.collection
|
self.collection = self.config.db.collection
|
||||||
self.user_asks = []
|
self.user_asks = []
|
||||||
|
|
||||||
def _get_loader(self, data_type):
|
def _get_loader(self, data_type):
|
||||||
@@ -86,7 +87,7 @@ class EmbedChain:
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported data type: {data_type}")
|
raise ValueError(f"Unsupported data type: {data_type}")
|
||||||
|
|
||||||
def add(self, data_type, url):
|
def add(self, data_type, url, config: AddConfig = None):
|
||||||
"""
|
"""
|
||||||
Adds the data from the given URL to the vector db.
|
Adds the data from the given URL to the vector db.
|
||||||
Loads the data, chunks it, create embedding for each chunk
|
Loads the data, chunks it, create embedding for each chunk
|
||||||
@@ -94,13 +95,16 @@ class EmbedChain:
|
|||||||
|
|
||||||
:param data_type: The type of the data to add.
|
:param data_type: The type of the data to add.
|
||||||
:param url: The URL where the data is located.
|
:param url: The URL where the data is located.
|
||||||
|
:param config: Optional. The `AddConfig` instance to use as configuration options.
|
||||||
"""
|
"""
|
||||||
|
if config is None:
|
||||||
|
config = AddConfig()
|
||||||
loader = self._get_loader(data_type)
|
loader = self._get_loader(data_type)
|
||||||
chunker = self._get_chunker(data_type)
|
chunker = self._get_chunker(data_type)
|
||||||
self.user_asks.append([data_type, url])
|
self.user_asks.append([data_type, url])
|
||||||
self.load_and_embed(loader, chunker, url)
|
self.load_and_embed(loader, chunker, url)
|
||||||
|
|
||||||
def add_local(self, data_type, content):
|
def add_local(self, data_type, content, config: AddConfig = None):
|
||||||
"""
|
"""
|
||||||
Adds the data you supply to the vector db.
|
Adds the data you supply to the vector db.
|
||||||
Loads the data, chunks it, create embedding for each chunk
|
Loads the data, chunks it, create embedding for each chunk
|
||||||
@@ -108,7 +112,10 @@ class EmbedChain:
|
|||||||
|
|
||||||
:param data_type: The type of the data to add.
|
:param data_type: The type of the data to add.
|
||||||
:param content: The local data. Refer to the `README` for formatting.
|
:param content: The local data. Refer to the `README` for formatting.
|
||||||
|
:param config: Optional. The `AddConfig` instance to use as configuration options.
|
||||||
"""
|
"""
|
||||||
|
if config is None:
|
||||||
|
config = AddConfig()
|
||||||
loader = self._get_loader(data_type)
|
loader = self._get_loader(data_type)
|
||||||
chunker = self._get_chunker(data_type)
|
chunker = self._get_chunker(data_type)
|
||||||
self.user_asks.append([data_type, content])
|
self.user_asks.append([data_type, content])
|
||||||
@@ -210,15 +217,18 @@ class EmbedChain:
|
|||||||
answer = self.get_llm_model_answer(prompt)
|
answer = self.get_llm_model_answer(prompt)
|
||||||
return answer
|
return answer
|
||||||
|
|
||||||
def query(self, input_query):
|
def query(self, input_query, config: QueryConfig = None):
|
||||||
"""
|
"""
|
||||||
Queries the vector database based on the given input query.
|
Queries the vector database based on the given input query.
|
||||||
Gets relevant doc based on the query and then passes it to an
|
Gets relevant doc based on the query and then passes it to an
|
||||||
LLM as context to get the answer.
|
LLM as context to get the answer.
|
||||||
|
|
||||||
:param input_query: The query to use.
|
:param input_query: The query to use.
|
||||||
|
:param config: Optional. The `QueryConfig` instance to use as configuration options.
|
||||||
:return: The answer to the query.
|
:return: The answer to the query.
|
||||||
"""
|
"""
|
||||||
|
if config is None:
|
||||||
|
config = QueryConfig()
|
||||||
context = self.retrieve_from_database(input_query)
|
context = self.retrieve_from_database(input_query)
|
||||||
prompt = self.generate_prompt(input_query, context)
|
prompt = self.generate_prompt(input_query, context)
|
||||||
answer = self.get_answer_from_llm(prompt)
|
answer = self.get_answer_from_llm(prompt)
|
||||||
@@ -243,14 +253,19 @@ class EmbedChain:
|
|||||||
prompt += suffix_prompt
|
prompt += suffix_prompt
|
||||||
return prompt
|
return prompt
|
||||||
|
|
||||||
def chat(self, input_query):
|
def chat(self, input_query, config: ChatConfig = None):
|
||||||
"""
|
"""
|
||||||
Queries the vector database on the given input query.
|
Queries the vector database on the given input query.
|
||||||
Gets relevant doc based on the query and then passes it to an
|
Gets relevant doc based on the query and then passes it to an
|
||||||
LLM as context to get the answer.
|
LLM as context to get the answer.
|
||||||
|
|
||||||
Maintains last 5 conversations in memory.
|
Maintains last 5 conversations in memory.
|
||||||
|
:param input_query: The query to use.
|
||||||
|
:param config: Optional. The `ChatConfig` instance to use as configuration options.
|
||||||
|
:return: The answer to the query.
|
||||||
"""
|
"""
|
||||||
|
if config is None:
|
||||||
|
config = ChatConfig()
|
||||||
context = self.retrieve_from_database(input_query)
|
context = self.retrieve_from_database(input_query)
|
||||||
global memory
|
global memory
|
||||||
chat_history = memory.load_memory_variables({})["history"]
|
chat_history = memory.load_memory_variables({})["history"]
|
||||||
@@ -274,8 +289,11 @@ class EmbedChain:
|
|||||||
the `max_tokens` parameter.
|
the `max_tokens` parameter.
|
||||||
|
|
||||||
:param input_query: The query to use.
|
:param input_query: The query to use.
|
||||||
|
:param config: Optional. The `QueryConfig` instance to use as configuration options.
|
||||||
:return: The prompt that would be sent to the LLM
|
:return: The prompt that would be sent to the LLM
|
||||||
"""
|
"""
|
||||||
|
if config is None:
|
||||||
|
config = QueryConfig()
|
||||||
context = self.retrieve_from_database(input_query)
|
context = self.retrieve_from_database(input_query)
|
||||||
prompt = self.generate_prompt(input_query, context)
|
prompt = self.generate_prompt(input_query, context)
|
||||||
return prompt
|
return prompt
|
||||||
@@ -291,14 +309,13 @@ class App(EmbedChain):
|
|||||||
dry_run(query): test your prompt without consuming tokens.
|
dry_run(query): test your prompt without consuming tokens.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __int__(self, db=None, ef=None):
|
def __init__(self, config: InitConfig = None):
|
||||||
if ef is None:
|
"""
|
||||||
ef = embedding_functions.OpenAIEmbeddingFunction(
|
:param config: InitConfig instance to load as configuration. Optional.
|
||||||
api_key=os.getenv("OPENAI_API_KEY"),
|
"""
|
||||||
organization_id=os.getenv("OPENAI_ORGANIZATION"),
|
if config is None:
|
||||||
model_name="text-embedding-ada-002"
|
config = InitConfig()
|
||||||
)
|
super().__init__(config)
|
||||||
super().__init__(db, ef)
|
|
||||||
|
|
||||||
def get_llm_model_answer(self, prompt):
|
def get_llm_model_answer(self, prompt):
|
||||||
messages = []
|
messages = []
|
||||||
@@ -326,14 +343,25 @@ class OpenSourceApp(EmbedChain):
|
|||||||
query(query): finds answer to the given query using vector database and LLM.
|
query(query): finds answer to the given query using vector database and LLM.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, db=None, ef=None):
|
def __init__(self, config: InitConfig = None):
|
||||||
|
"""
|
||||||
|
:param config: InitConfig instance to load as configuration. Optional. `ef` defaults to open source.
|
||||||
|
"""
|
||||||
print("Loading open source embedding model. This may take some time...")
|
print("Loading open source embedding model. This may take some time...")
|
||||||
if ef is None:
|
if not config or not config.ef:
|
||||||
ef = embedding_functions.SentenceTransformerEmbeddingFunction(
|
if config is None:
|
||||||
|
config = InitConfig(
|
||||||
|
ef = embedding_functions.SentenceTransformerEmbeddingFunction(
|
||||||
|
model_name="all-MiniLM-L6-v2"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
config._set_embedding_function(
|
||||||
|
embedding_functions.SentenceTransformerEmbeddingFunction(
|
||||||
model_name="all-MiniLM-L6-v2"
|
model_name="all-MiniLM-L6-v2"
|
||||||
)
|
))
|
||||||
print("Successfully loaded open source embedding model.")
|
print("Successfully loaded open source embedding model.")
|
||||||
super().__init__(db, ef)
|
super().__init__(config)
|
||||||
|
|
||||||
def get_llm_model_answer(self, prompt):
|
def get_llm_model_answer(self, prompt):
|
||||||
from gpt4all import GPT4All
|
from gpt4all import GPT4All
|
||||||
|
|||||||
Reference in New Issue
Block a user