refactor: classes and configs (#528)
@@ -11,30 +11,37 @@ from typing import Dict, Optional
 import requests
 from dotenv import load_dotenv
 from langchain.docstore.document import Document
 from langchain.memory import ConversationBufferMemory
 from tenacity import retry, stop_after_attempt, wait_fixed

 from embedchain.chunkers.base_chunker import BaseChunker
-from embedchain.config import AddConfig, ChatConfig, QueryConfig
+from embedchain.config import AddConfig, BaseLlmConfig
 from embedchain.config.apps.BaseAppConfig import BaseAppConfig
-from embedchain.config.QueryConfig import DOCS_SITE_PROMPT_TEMPLATE
 from embedchain.data_formatter import DataFormatter
+from embedchain.embedder.base_embedder import BaseEmbedder
 from embedchain.helper_classes.json_serializable import JSONSerializable
+from embedchain.llm.base_llm import BaseLlm
 from embedchain.loaders.base_loader import BaseLoader
 from embedchain.models.data_type import DataType
 from embedchain.utils import detect_datatype
+from embedchain.vectordb.base_vector_db import BaseVectorDB

 load_dotenv()

 ABS_PATH = os.getcwd()
 DB_DIR = os.path.join(ABS_PATH, "db")
 HOME_DIR = str(Path.home())
 CONFIG_DIR = os.path.join(HOME_DIR, ".embedchain")
 CONFIG_FILE = os.path.join(CONFIG_DIR, "config.json")


 class EmbedChain(JSONSerializable):
-    def __init__(self, config: BaseAppConfig, system_prompt: Optional[str] = None):
+    def __init__(
+        self,
+        config: BaseAppConfig,
+        llm: BaseLlm,
+        db: BaseVectorDB = None,
+        embedder: BaseEmbedder = None,
+        system_prompt: Optional[str] = None,
+    ):
         """
         Initializes the EmbedChain instance, sets up a vector DB client and
         creates a collection.
@@ -44,17 +51,40 @@ class EmbedChain(JSONSerializable):
         """

         self.config = config
-        self.system_prompt = system_prompt
-        self.collection = self.config.db._get_or_create_collection(self.config.collection_name)
-        self.db = self.config.db
+
+        # Add subclasses
+        ## Llm
+        self.llm = llm
+        ## Database
+        # Database has support for config assignment for backwards compatibility
+        if db is None and (not hasattr(self.config, "db") or self.config.db is None):
+            raise ValueError("App requires Database.")
+        self.db = db or self.config.db
+        ## Embedder
+        if embedder is None:
+            raise ValueError("App requires Embedder.")
+        self.embedder = embedder
+
+        # Initialize database
+        self.db._set_embedder(self.embedder)
+        self.db._initialize()
+        # Set collection name from app config for backwards compatibility.
+        if config.collection_name:
+            self.db.set_collection_name(config.collection_name)
+
+        # Add variables that are "shortcuts"
+        if system_prompt:
+            self.llm.config.system_prompt = system_prompt
+
+        # Attributes that aren't subclass related.
         self.user_asks = []
         self.is_docs_site_instance = False
         self.online = False
         self.memory = ConversationBufferMemory()

         # Send anonymous telemetry
         self.s_id = self.config.id if self.config.id else str(uuid.uuid4())
         self.u_id = self._load_or_generate_user_id()
         # NOTE: Uncomment the next two lines when running tests to see if any test fires a telemetry event.
         # if (self.config.collect_metrics):
         # raise ConnectionRefusedError("Collection of metrics should not be allowed.")
         thread_telemetry = threading.Thread(target=self._send_telemetry_event, args=("init",))
         thread_telemetry.start()
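With this refactor the app's LLM, vector database, and embedder are injected into `EmbedChain` as objects instead of being derived from a single config. A minimal wiring sketch, assuming concrete subclasses of the base classes exist (the `My*` names are illustrative, not part of this commit):

    # Any concrete implementations of the base classes will do here.
    llm = MyLlm()            # subclass of BaseLlm; required
    db = MyVectorDB()        # subclass of BaseVectorDB; optional if config.db is set
    embedder = MyEmbedder()  # subclass of BaseEmbedder; passing None raises ValueError

    app = EmbedChain(config=config, llm=llm, db=db, embedder=embedder)
    # __init__ then hands the embedder to the database and initializes it:
    #   self.db._set_embedder(self.embedder)
    #   self.db._initialize()

Note that `db` keeps a fallback to `config.db` purely for backwards compatibility, while `embedder` has no such fallback and is effectively mandatory.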
@@ -227,10 +257,10 @@ class EmbedChain(JSONSerializable):
         metadatas = new_metadatas

         # Count before, to calculate a delta in the end.
-        chunks_before_addition = self.count()
+        chunks_before_addition = self.db.count()

         self.db.add(documents=documents, metadatas=metadatas, ids=ids)
-        count_new_chunks = self.count() - chunks_before_addition
+        count_new_chunks = self.db.count() - chunks_before_addition
         print((f"Successfully saved {src} ({chunker.data_type}). New chunks count: {count_new_chunks}"))
         return list(documents), metadatas, ids, count_new_chunks
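Chunk bookkeeping now talks to the database object directly instead of going through the `EmbedChain.count()` wrapper, which is deprecated further down in this diff. In isolation the pattern is:

    before = self.db.count()        # embeddings stored so far
    self.db.add(documents=documents, metadatas=metadatas, ids=ids)
    new_chunks = self.db.count() - before

Measuring the delta with two counts, rather than `len(documents)`, reports what the database actually stored, which presumably matters when some chunks are skipped as duplicates.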
@@ -244,13 +274,7 @@ class EmbedChain(JSONSerializable):
             )
         ]

-    def get_llm_model_answer(self):
-        """
-        Usually implemented by child class
-        """
-        raise NotImplementedError
-
-    def retrieve_from_database(self, input_query, config: QueryConfig, where=None):
+    def retrieve_from_database(self, input_query, config: Optional[BaseLlmConfig] = None, where=None):
         """
         Queries the vector database based on the given input query.
         Gets relevant doc based on the query
@@ -260,11 +284,12 @@ class EmbedChain(JSONSerializable):
         :param where: Optional. A dictionary of key-value pairs to filter the database results.
         :return: The content of the document that matched your query.
         """
+        query_config = config or self.llm.config
+
         if where is not None:
             where = where
-        elif config is not None and config.where is not None:
-            where = config.where
+        elif query_config is not None and query_config.where is not None:
+            where = query_config.where
         else:
             where = {}
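Since `query_config = config or self.llm.config`, the `where` filter can now come from three places with a fixed precedence. The same logic as a standalone sketch (a hypothetical helper, not part of this commit):

    def resolve_where(where, config, llm_config):
        # 1. An explicit `where` argument wins.
        if where is not None:
            return where
        # 2. Otherwise the per-call config, falling back to the config
        #    persisted on the llm at app init.
        query_config = config or llm_config
        if query_config is not None and query_config.where is not None:
            return query_config.where
        # 3. No filter.
        return {}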
@@ -273,64 +298,21 @@ class EmbedChain(JSONSerializable):

         contents = self.db.query(
             input_query=input_query,
-            n_results=config.number_documents,
+            n_results=query_config.number_documents,
             where=where,
         )

         return contents

-    def _append_search_and_context(self, context, web_search_result):
-        return f"{context}\nWeb Search Result: {web_search_result}"
-
-    def generate_prompt(self, input_query, contexts, config: QueryConfig, **kwargs):
-        """
-        Generates a prompt based on the given query and context, ready to be
-        passed to an LLM
-
-        :param input_query: The query to use.
-        :param contexts: List of similar documents to the query used as context.
-        :param config: Optional. The `QueryConfig` instance to use as
-        configuration options.
-        :return: The prompt
-        """
-        context_string = (" | ").join(contexts)
-        web_search_result = kwargs.get("web_search_result", "")
-        if web_search_result:
-            context_string = self._append_search_and_context(context_string, web_search_result)
-        if not config.history:
-            prompt = config.template.substitute(context=context_string, query=input_query)
-        else:
-            prompt = config.template.substitute(context=context_string, query=input_query, history=config.history)
-        return prompt
-
-    def get_answer_from_llm(self, prompt, config: ChatConfig):
-        """
-        Gets an answer based on the given query and context by passing it
-        to an LLM.
-
-        :param query: The query to use.
-        :param context: Similar documents to the query used as context.
-        :return: The answer.
-        """
-
-        return self.get_llm_model_answer(prompt, config)
-
-    def access_search_and_get_results(self, input_query):
-        from langchain.tools import DuckDuckGoSearchRun
-
-        search = DuckDuckGoSearchRun()
-        logging.info(f"Access search to get answers for {input_query}")
-        return search.run(input_query)
-
-    def query(self, input_query, config: QueryConfig = None, dry_run=False, where=None):
+    def query(self, input_query, config: BaseLlmConfig = None, dry_run=False, where=None):
         """
         Queries the vector database based on the given input query.
         Gets relevant doc based on the query and then passes it to an
         LLM as context to get the answer.

         :param input_query: The query to use.
-        :param config: Optional. The `QueryConfig` instance to use as
-        configuration options.
+        :param config: Optional. The `LlmConfig` instance to use as configuration options.
+        This is used for one method call. To persistently use a config, declare it during app init.
         :param dry_run: Optional. A dry run does everything except send the resulting prompt to
         the LLM. The purpose is to test the prompt, not the response.
         You can use it to test your prompt, including the context provided
@@ -340,41 +322,16 @@ class EmbedChain(JSONSerializable):
         :param where: Optional. A dictionary of key-value pairs to filter the database results.
         :return: The answer to the query.
         """
-        if config is None:
-            config = QueryConfig()
-        if self.is_docs_site_instance:
-            config.template = DOCS_SITE_PROMPT_TEMPLATE
-            config.number_documents = 5
-        k = {}
-        if self.online:
-            k["web_search_result"] = self.access_search_and_get_results(input_query)
-        contexts = self.retrieve_from_database(input_query, config, where)
-        prompt = self.generate_prompt(input_query, contexts, config, **k)
-        logging.info(f"Prompt: {prompt}")
-
-        if dry_run:
-            return prompt
-
-        answer = self.get_answer_from_llm(prompt, config)
+        contexts = self.retrieve_from_database(input_query=input_query, config=config, where=where)
+        answer = self.llm.query(input_query=input_query, contexts=contexts, config=config, dry_run=dry_run)

         # Send anonymous telemetry
         thread_telemetry = threading.Thread(target=self._send_telemetry_event, args=("query",))
         thread_telemetry.start()

-        if isinstance(answer, str):
-            logging.info(f"Answer: {answer}")
-            return answer
-        else:
-            return self._stream_query_response(answer)
+        return answer

-    def _stream_query_response(self, answer):
-        streamed_answer = ""
-        for chunk in answer:
-            streamed_answer = streamed_answer + chunk
-            yield chunk
-        logging.info(f"Answer: {streamed_answer}")
-
-    def chat(self, input_query, config: ChatConfig = None, dry_run=False, where=None):
+    def chat(self, input_query, config: BaseLlmConfig = None, dry_run=False, where=None):
         """
         Queries the vector database on the given input query.
         Gets relevant doc based on the query and then passes it to an
@@ -382,8 +339,8 @@ class EmbedChain(JSONSerializable):

         Maintains the whole conversation in memory.
         :param input_query: The query to use.
-        :param config: Optional. The `ChatConfig` instance to use as
-        configuration options.
+        :param config: Optional. The `LlmConfig` instance to use as configuration options.
+        This is used for one method call. To persistently use a config, declare it during app init.
         :param dry_run: Optional. A dry run does everything except send the resulting prompt to
         the LLM. The purpose is to test the prompt, not the response.
         You can use it to test your prompt, including the context provided
@@ -393,50 +350,14 @@ class EmbedChain(JSONSerializable):
         :param where: Optional. A dictionary of key-value pairs to filter the database results.
         :return: The answer to the query.
         """
-        if config is None:
-            config = ChatConfig()
-        if self.is_docs_site_instance:
-            config.template = DOCS_SITE_PROMPT_TEMPLATE
-            config.number_documents = 5
-        k = {}
-        if self.online:
-            k["web_search_result"] = self.access_search_and_get_results(input_query)
-        contexts = self.retrieve_from_database(input_query, config, where)
-
-        chat_history = self.memory.load_memory_variables({})["history"]
-
-        if chat_history:
-            config.set_history(chat_history)
-
-        prompt = self.generate_prompt(input_query, contexts, config, **k)
-        logging.info(f"Prompt: {prompt}")
-
-        if dry_run:
-            return prompt
-
-        answer = self.get_answer_from_llm(prompt, config)
-
-        self.memory.chat_memory.add_user_message(input_query)
+        contexts = self.retrieve_from_database(input_query=input_query, config=config, where=where)
+        answer = self.llm.chat(input_query=input_query, contexts=contexts, config=config, dry_run=dry_run)

         # Send anonymous telemetry
         thread_telemetry = threading.Thread(target=self._send_telemetry_event, args=("chat",))
         thread_telemetry.start()

-        if isinstance(answer, str):
-            self.memory.chat_memory.add_ai_message(answer)
-            logging.info(f"Answer: {answer}")
-            return answer
-        else:
-            # this is a streamed response and needs to be handled differently.
-            return self._stream_chat_response(answer)
-
-    def _stream_chat_response(self, answer):
-        streamed_answer = ""
-        for chunk in answer:
-            streamed_answer = streamed_answer + chunk
-            yield chunk
-        self.memory.chat_memory.add_ai_message(streamed_answer)
-        logging.info(f"Answer: {streamed_answer}")
+        return answer

     def set_collection(self, collection_name):
         """
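`query` and `chat` are now thin orchestrators: `EmbedChain` retrieves the context, and prompt building, web search, history, and streaming all live on the `BaseLlm` subclass behind `self.llm.query` and `self.llm.chat`. Rough usage under that split (the `BaseLlmConfig` constructor arguments are an assumption, not shown in this diff):

    # A one-off config for a single call; the persistent config is the one
    # the llm was created with at app init.
    answer = app.query("What does the docs site cover?", config=BaseLlmConfig(number_documents=5))

    # dry_run is forwarded to self.llm.query: per the docstring above it
    # builds the prompt but never sends it to the LLM.
    prompt = app.query("What does the docs site cover?", dry_run=True)

    # chat() works the same way, with conversation history now handled
    # inside self.llm.chat rather than in EmbedChain.
    answer = app.chat("Summarize that in one sentence.")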
@@ -444,34 +365,36 @@ class EmbedChain(JSONSerializable):

         :param collection_name: The name of the collection to use.
         """
-        self.collection = self.config.db._get_or_create_collection(collection_name)
+        self.db.set_collection_name(collection_name)
+        # Create the collection if it does not exist
+        self.db._get_or_create_collection(collection_name)
+        # TODO: Check whether it is necessary to assign to the `self.collection` attribute,
+        # since the main purpose is the creation.

     def count(self) -> int:
         """
         Count the number of embeddings.

+        DEPRECATED IN FAVOR OF `db.count()`
+
         :return: The number of embeddings.
         """
+        logging.warning("DEPRECATION WARNING: Please use `db.count()` instead of `count()`.")
         return self.db.count()

     def reset(self):
         """
         Resets the database. Deletes all embeddings irreversibly.
         `App` does not have to be reinitialized after using this method.
+
+        DEPRECATED IN FAVOR OF `db.reset()`
         """
         # Send anonymous telemetry
         thread_telemetry = threading.Thread(target=self._send_telemetry_event, args=("reset",))
         thread_telemetry.start()

-        collection_name = self.collection.name
+        logging.warning("DEPRECATION WARNING: Please use `db.reset()` instead of `reset()`.")
         self.db.reset()
-        self.collection = self.config.db._get_or_create_collection(collection_name)
-        # Todo: Automatically recreating a collection with the same name cannot be the best way to handle a reset.
-        # A downside of this implementation is, if you have two instances,
-        # the other instance will not get the updated `self.collection` attribute.
-        # A better way would be to create the collection if it is called again after being reset.
-        # That means, checking if collection exists in the db-consuming methods, and creating it if it doesn't.
-        # That's an extra steps for all uses, just to satisfy a niche use case in a niche method. For now, this will do.

     @retry(stop=stop_after_attempt(3), wait=wait_fixed(1))
     def _send_telemetry_event(self, method: str, extra_metadata: Optional[dict] = None):
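`count()` and `reset()` survive as deprecated shims that log a warning and forward to the database object:

    # Old (still works, but logs a DEPRECATION WARNING):
    n = app.count()
    app.reset()

    # Preferred: call the vector database directly.
    n = app.db.count()
    app.db.reset()

One behavioral difference worth noting: `app.reset()` still fires the anonymous "reset" telemetry event before delegating, which a direct `app.db.reset()` call skips.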