refactor: classes and configs (#528)

cachho
2023-09-05 10:12:58 +02:00
committed by GitHub
parent 387b042a49
commit 344e7470f6
50 changed files with 1221 additions and 997 deletions

embedchain/embedchain.py

@@ -11,30 +11,37 @@ from typing import Dict, Optional
import requests
from dotenv import load_dotenv
from langchain.docstore.document import Document
from langchain.memory import ConversationBufferMemory
from tenacity import retry, stop_after_attempt, wait_fixed
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config import AddConfig, ChatConfig, QueryConfig
from embedchain.config import AddConfig, BaseLlmConfig
from embedchain.config.apps.BaseAppConfig import BaseAppConfig
from embedchain.config.QueryConfig import DOCS_SITE_PROMPT_TEMPLATE
from embedchain.data_formatter import DataFormatter
from embedchain.embedder.base_embedder import BaseEmbedder
from embedchain.helper_classes.json_serializable import JSONSerializable
from embedchain.llm.base_llm import BaseLlm
from embedchain.loaders.base_loader import BaseLoader
from embedchain.models.data_type import DataType
from embedchain.utils import detect_datatype
from embedchain.vectordb.base_vector_db import BaseVectorDB

load_dotenv()

ABS_PATH = os.getcwd()
DB_DIR = os.path.join(ABS_PATH, "db")
HOME_DIR = str(Path.home())
CONFIG_DIR = os.path.join(HOME_DIR, ".embedchain")
CONFIG_FILE = os.path.join(CONFIG_DIR, "config.json")


class EmbedChain(JSONSerializable):
def __init__(self, config: BaseAppConfig, system_prompt: Optional[str] = None):
def __init__(
self,
config: BaseAppConfig,
llm: BaseLlm,
db: BaseVectorDB = None,
embedder: BaseEmbedder = None,
system_prompt: Optional[str] = None,
):
"""
Initializes the EmbedChain instance, sets up a vector DB client and
creates a collection.
@@ -44,17 +51,40 @@ class EmbedChain(JSONSerializable):
"""
self.config = config
self.system_prompt = system_prompt
self.collection = self.config.db._get_or_create_collection(self.config.collection_name)
self.db = self.config.db
# Add subclasses
## Llm
self.llm = llm
## Database
# Database has support for config assignment for backwards compatibility
if db is None and (not hasattr(self.config, "db") or self.config.db is None):
raise ValueError("App requires Database.")
self.db = db or self.config.db
## Embedder
if embedder is None:
raise ValueError("App requires Embedder.")
self.embedder = embedder
# Initialize database
self.db._set_embedder(self.embedder)
self.db._initialize()
# Set collection name from app config for backwards compatibility.
if config.collection_name:
self.db.set_collection_name(config.collection_name)
# Add variables that are "shortcuts"
if system_prompt:
self.llm.config.system_prompt = system_prompt
# Attributes that aren't subclass related.
self.user_asks = []
self.is_docs_site_instance = False
self.online = False
self.memory = ConversationBufferMemory()
# Send anonymous telemetry
self.s_id = self.config.id if self.config.id else str(uuid.uuid4())
self.u_id = self._load_or_generate_user_id()
# NOTE: Uncomment the next two lines when running tests to see if any test fires a telemetry event.
# if (self.config.collect_metrics):
# raise ConnectionRefusedError("Collection of metrics should not be allowed.")
thread_telemetry = threading.Thread(target=self._send_telemetry_event, args=("init",))
thread_telemetry.start()
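
The refactored constructor takes the Llm, database, and embedder as injected components instead of reading everything from the config. A minimal sketch of its fallback-and-validation logic in isolation (illustrative reduction with a stand-in config object, not code from this commit):

from typing import Any, Optional

def resolve_components(config: Any, db: Optional[Any], embedder: Optional[Any]):
    # An explicit `db` argument wins; otherwise fall back to `config.db`
    # for backwards compatibility with the old config-driven style.
    if db is None and getattr(config, "db", None) is None:
        raise ValueError("App requires Database.")
    if embedder is None:
        raise ValueError("App requires Embedder.")
    return db or getattr(config, "db", None), embedder

class LegacyConfig:
    db = "db-from-config"  # stand-in for a real BaseVectorDB instance

db, embedder = resolve_components(LegacyConfig(), db=None, embedder="stub-embedder")
assert db == "db-from-config"  # old-style config assignment still works
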
@@ -227,10 +257,10 @@ class EmbedChain(JSONSerializable):
metadatas = new_metadatas
# Count before, to calculate a delta in the end.
chunks_before_addition = self.count()
chunks_before_addition = self.db.count()
self.db.add(documents=documents, metadatas=metadatas, ids=ids)
count_new_chunks = self.count() - chunks_before_addition
count_new_chunks = self.db.count() - chunks_before_addition
print((f"Successfully saved {src} ({chunker.data_type}). New chunks count: {count_new_chunks}"))
return list(documents), metadatas, ids, count_new_chunks
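
The chunk-count delta is now read from the database itself rather than the deprecated `count()` wrapper. The before/after pattern, sketched with a stand-in counter (not the real vector DB):

class FakeDb:
    """Stand-in exposing the same count/add surface used above."""
    def __init__(self):
        self._ids = []
    def count(self):
        return len(self._ids)
    def add(self, documents, metadatas, ids):
        self._ids.extend(ids)

db = FakeDb()
chunks_before_addition = db.count()
db.add(documents=["chunk text"], metadatas=[{"url": "example"}], ids=["id-1"])
print(f"New chunks count: {db.count() - chunks_before_addition}")  # -> 1
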
@@ -244,13 +274,7 @@ class EmbedChain(JSONSerializable):
)
]
def get_llm_model_answer(self):
"""
Usually implemented by child class
"""
raise NotImplementedError
def retrieve_from_database(self, input_query, config: QueryConfig, where=None):
def retrieve_from_database(self, input_query, config: Optional[BaseLlmConfig] = None, where=None):
"""
Queries the vector database based on the given input query.
Gets the documents relevant to that query.
@@ -260,11 +284,12 @@ class EmbedChain(JSONSerializable):
:param where: Optional. A dictionary of key-value pairs to filter the database results.
:return: The content of the document that matched your query.
"""
query_config = config or self.llm.config
if where is not None:
where = where
elif config is not None and config.where is not None:
where = config.where
elif query_config is not None and query_config.where is not None:
where = query_config.where
else:
where = {}
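
The filter resolution above has three tiers: an explicit `where` argument beats the per-call config, which beats the persistent `self.llm.config`. The same precedence as a standalone sketch (stand-in config class; an assumed-equivalent reduction, not code from this commit):

def resolve_where(where, call_config, llm_config):
    query_config = call_config or llm_config
    if where is not None:
        return where
    if query_config is not None and query_config.where is not None:
        return query_config.where
    return {}

class Cfg:
    def __init__(self, where=None):
        self.where = where

assert resolve_where({"app": "a"}, Cfg({"app": "b"}), Cfg()) == {"app": "a"}
assert resolve_where(None, Cfg({"app": "b"}), Cfg()) == {"app": "b"}
assert resolve_where(None, None, Cfg()) == {}
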
@@ -273,64 +298,21 @@ class EmbedChain(JSONSerializable):
contents = self.db.query(
input_query=input_query,
n_results=config.number_documents,
n_results=query_config.number_documents,
where=where,
)
return contents
def _append_search_and_context(self, context, web_search_result):
return f"{context}\nWeb Search Result: {web_search_result}"
def generate_prompt(self, input_query, contexts, config: QueryConfig, **kwargs):
"""
Generates a prompt based on the given query and context, ready to be
passed to an LLM
:param input_query: The query to use.
:param contexts: List of similar documents to the query used as context.
:param config: Optional. The `QueryConfig` instance to use as
configuration options.
:return: The prompt
"""
context_string = (" | ").join(contexts)
web_search_result = kwargs.get("web_search_result", "")
if web_search_result:
context_string = self._append_search_and_context(context_string, web_search_result)
if not config.history:
prompt = config.template.substitute(context=context_string, query=input_query)
else:
prompt = config.template.substitute(context=context_string, query=input_query, history=config.history)
return prompt
def get_answer_from_llm(self, prompt, config: ChatConfig):
"""
Gets an answer based on the given query and context by passing it
to an LLM.
:param prompt: The prompt to pass to the LLM.
:param config: The `ChatConfig` instance to use.
:return: The answer.
"""
return self.get_llm_model_answer(prompt, config)
def access_search_and_get_results(self, input_query):
from langchain.tools import DuckDuckGoSearchRun
search = DuckDuckGoSearchRun()
logging.info(f"Access search to get answers for {input_query}")
return search.run(input_query)
def query(self, input_query, config: QueryConfig = None, dry_run=False, where=None):
def query(self, input_query, config: BaseLlmConfig = None, dry_run=False, where=None):
"""
Queries the vector database based on the given input query.
Gets relevant doc based on the query and then passes it to an
LLM as context to get the answer.
:param input_query: The query to use.
:param config: Optional. The `QueryConfig` instance to use as
configuration options.
:param config: Optional. The `LlmConfig` instance to use as configuration options.
This is used for one method call. To persistently use a config, declare it during app init.
:param dry_run: Optional. A dry run does everything except send the resulting prompt to
the LLM. The purpose is to test the prompt, not the response.
You can use it to test your prompt, including the context provided
@@ -340,41 +322,16 @@ class EmbedChain(JSONSerializable):
:param where: Optional. A dictionary of key-value pairs to filter the database results.
:return: The answer to the query.
"""
if config is None:
config = QueryConfig()
if self.is_docs_site_instance:
config.template = DOCS_SITE_PROMPT_TEMPLATE
config.number_documents = 5
k = {}
if self.online:
k["web_search_result"] = self.access_search_and_get_results(input_query)
contexts = self.retrieve_from_database(input_query, config, where)
prompt = self.generate_prompt(input_query, contexts, config, **k)
logging.info(f"Prompt: {prompt}")
if dry_run:
return prompt
answer = self.get_answer_from_llm(prompt, config)
contexts = self.retrieve_from_database(input_query=input_query, config=config, where=where)
answer = self.llm.query(input_query=input_query, contexts=contexts, config=config, dry_run=dry_run)
# Send anonymous telemetry
thread_telemetry = threading.Thread(target=self._send_telemetry_event, args=("query",))
thread_telemetry.start()
if isinstance(answer, str):
logging.info(f"Answer: {answer}")
return answer
else:
return self._stream_query_response(answer)
return answer
def _stream_query_response(self, answer):
streamed_answer = ""
for chunk in answer:
streamed_answer = streamed_answer + chunk
yield chunk
logging.info(f"Answer: {streamed_answer}")
def chat(self, input_query, config: ChatConfig = None, dry_run=False, where=None):
def chat(self, input_query, config: BaseLlmConfig = None, dry_run=False, where=None):
"""
Queries the vector database based on the given input query.
Gets the documents relevant to the query and then passes them to an
@@ -382,8 +339,8 @@ class EmbedChain(JSONSerializable):
Maintains the whole conversation in memory.
:param input_query: The query to use.
:param config: Optional. The `ChatConfig` instance to use as
configuration options.
:param config: Optional. The `LlmConfig` instance to use as configuration options.
This is used for one method call. To persistently use a config, declare it during app init.
:param dry_run: Optional. A dry run does everything except send the resulting prompt to
the LLM. The purpose is to test the prompt, not the response.
You can use it to test your prompt, including the context provided
@@ -393,50 +350,14 @@ class EmbedChain(JSONSerializable):
:param where: Optional. A dictionary of key-value pairs to filter the database results.
:return: The answer to the query.
"""
if config is None:
config = ChatConfig()
if self.is_docs_site_instance:
config.template = DOCS_SITE_PROMPT_TEMPLATE
config.number_documents = 5
k = {}
if self.online:
k["web_search_result"] = self.access_search_and_get_results(input_query)
contexts = self.retrieve_from_database(input_query, config, where)
chat_history = self.memory.load_memory_variables({})["history"]
if chat_history:
config.set_history(chat_history)
prompt = self.generate_prompt(input_query, contexts, config, **k)
logging.info(f"Prompt: {prompt}")
if dry_run:
return prompt
answer = self.get_answer_from_llm(prompt, config)
self.memory.chat_memory.add_user_message(input_query)
contexts = self.retrieve_from_database(input_query=input_query, config=config, where=where)
answer = self.llm.chat(input_query=input_query, contexts=contexts, config=config, dry_run=dry_run)
# Send anonymous telemetry
thread_telemetry = threading.Thread(target=self._send_telemetry_event, args=("chat",))
thread_telemetry.start()
if isinstance(answer, str):
self.memory.chat_memory.add_ai_message(answer)
logging.info(f"Answer: {answer}")
return answer
else:
# this is a streamed response and needs to be handled differently.
return self._stream_chat_response(answer)
def _stream_chat_response(self, answer):
streamed_answer = ""
for chunk in answer:
streamed_answer = streamed_answer + chunk
yield chunk
self.memory.chat_memory.add_ai_message(streamed_answer)
logging.info(f"Answer: {streamed_answer}")
return answer
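
`chat` follows the same retrieval-then-delegate shape, but conversation memory now lives on the Llm instead of on `EmbedChain`. A hedged usage sketch (assumes an initialized `app` with data already added):

first = app.chat("Who is the author of this document?")
# History from the first turn is carried by the llm, so a follow-up
# can use pronouns that only make sense in context.
followup = app.chat("When did they write it?")
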
def set_collection(self, collection_name):
"""
@@ -444,34 +365,36 @@ class EmbedChain(JSONSerializable):
:param collection_name: The name of the collection to use.
"""
self.collection = self.config.db._get_or_create_collection(collection_name)
self.db.set_collection_name(collection_name)
# Create the collection if it does not exist
self.db._get_or_create_collection(collection_name)
# TODO: Check whether it is necessary to assign to the `self.collection` attribute,
# since the main purpose is creating the collection.
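
Usage stays a one-liner; subsequent adds and queries target the named collection (sketch, assuming an initialized `app`; the URL is illustrative):

app.set_collection("work-notes")
app.add("https://example.com/notes")  # stored in the "work-notes" collection
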
def count(self) -> int:
"""
Count the number of embeddings.
DEPRECATED IN FAVOR OF `db.count()`
:return: The number of embeddings.
"""
logging.warning("DEPRECATION WARNING: Please use `db.count()` instead of `count()`.")
return self.db.count()
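
The deprecated wrapper only logs a warning and forwards to the database, so migration is mechanical:

n = app.count()     # deprecated: logs a warning, then proxies to the db
n = app.db.count()  # preferred going forward
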
def reset(self):
"""
Resets the database. Deletes all embeddings irreversibly.
`App` does not have to be reinitialized after using this method.
DEPRECATED IN FAVOR OF `db.reset()`
"""
# Send anonymous telemetry
thread_telemetry = threading.Thread(target=self._send_telemetry_event, args=("reset",))
thread_telemetry.start()
collection_name = self.collection.name
logging.warning("DEPRECATION WARNING: Please use `db.reset()` instead of `reset()`.")
self.db.reset()
self.collection = self.config.db._get_or_create_collection(collection_name)
# TODO: Automatically recreating a collection with the same name cannot be the best way to handle a reset.
# A downside of this implementation is that, if you have two instances,
# the other instance will not get the updated `self.collection` attribute.
# A better way would be to create the collection if it is called again after being reset.
# That means checking whether the collection exists in the db-consuming methods, and creating it if it doesn't.
# That's an extra step for every use, just to satisfy a niche use case in a niche method. For now, this will do.
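
`reset` is on the same deprecation path; the wrapper still fires telemetry and warns before delegating:

app.reset()     # deprecated wrapper around db.reset()
app.db.reset()  # preferred: irreversibly deletes all embeddings
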
@retry(stop=stop_after_attempt(3), wait=wait_fixed(1))
def _send_telemetry_event(self, method: str, extra_metadata: Optional[dict] = None):