docs: update docstrings (#565)

This commit is contained in:
cachho
2023-09-07 02:04:44 +02:00
committed by GitHub
parent 4754372fcd
commit 1ac8aef4de
25 changed files with 736 additions and 298 deletions

View File

@@ -1,5 +1,5 @@
import logging
from typing import List, Optional
from typing import Any, Dict, Generator, List, Optional
from langchain.memory import ConversationBufferMemory
from langchain.schema import BaseMessage
@@ -13,6 +13,11 @@ from embedchain.helper_classes.json_serializable import JSONSerializable
class BaseLlm(JSONSerializable):
def __init__(self, config: Optional[BaseLlmConfig] = None):
"""Initialize a base LLM class
:param config: LLM configuration option class, defaults to None
:type config: Optional[BaseLlmConfig], optional
"""
if config is None:
self.config = BaseLlmConfig()
else:
@@ -21,7 +26,7 @@ class BaseLlm(JSONSerializable):
self.memory = ConversationBufferMemory()
self.is_docs_site_instance = False
self.online = False
self.history: any = None
self.history: Any = None
def get_llm_model_answer(self):
"""
@@ -29,24 +34,33 @@ class BaseLlm(JSONSerializable):
"""
raise NotImplementedError
def set_history(self, history: any):
def set_history(self, history: Any):
"""
Provide your own history.
Especially interesting for the query method, which does not internally manage conversation history.
:param history: History to set
:type history: Any
"""
self.history = history
def update_history(self):
"""Update class history attribute with history in memory (for chat method)"""
chat_history = self.memory.load_memory_variables({})["history"]
if chat_history:
self.set_history(chat_history)
def generate_prompt(self, input_query, contexts, **kwargs):
def generate_prompt(self, input_query: str, contexts: List[str], **kwargs: Dict[str, Any]) -> str:
"""
Generates a prompt based on the given query and context, ready to be
passed to an LLM
:param input_query: The query to use.
:type input_query: str
:param contexts: List of similar documents to the query used as context.
:param config: Optional. The `QueryConfig` instance to use as
configuration options.
:type contexts: List[str]
:return: The prompt
:rtype: str
"""
context_string = (" | ").join(contexts)
web_search_result = kwargs.get("web_search_result", "")
@@ -73,36 +87,67 @@ class BaseLlm(JSONSerializable):
)
return prompt
def _append_search_and_context(self, context, web_search_result):
def _append_search_and_context(self, context: str, web_search_result: str) -> str:
"""Append web search context to existing context
:param context: Existing context
:type context: str
:param web_search_result: Web search result
:type web_search_result: str
:return: Concatenated web search result
:rtype: str
"""
return f"{context}\nWeb Search Result: {web_search_result}"
def get_answer_from_llm(self, prompt):
def get_answer_from_llm(self, prompt: str):
"""
Gets an answer based on the given query and context by passing it
to an LLM.
:param query: The query to use.
:param context: Similar documents to the query used as context.
:param prompt: Gets an answer based on the given query and context by passing it to an LLM.
:type prompt: str
:return: The answer.
:rtype: _type_
"""
return self.get_llm_model_answer(prompt)
def access_search_and_get_results(self, input_query):
def access_search_and_get_results(self, input_query: str):
"""
Search the internet for additional context
:param input_query: search query
:type input_query: str
:return: Search results
:rtype: Unknown
"""
from langchain.tools import DuckDuckGoSearchRun
search = DuckDuckGoSearchRun()
logging.info(f"Access search to get answers for {input_query}")
return search.run(input_query)
def _stream_query_response(self, answer):
def _stream_query_response(self, answer: Any) -> Generator[Any, Any, None]:
"""Generator to be used as streaming response
:param answer: Answer chunk from llm
:type answer: Any
:yield: Answer chunk from llm
:rtype: Generator[Any, Any, None]
"""
streamed_answer = ""
for chunk in answer:
streamed_answer = streamed_answer + chunk
yield chunk
logging.info(f"Answer: {streamed_answer}")
def _stream_chat_response(self, answer):
def _stream_chat_response(self, answer: Any) -> Generator[Any, Any, None]:
"""Generator to be used as streaming response
:param answer: Answer chunk from llm
:type answer: Any
:yield: Answer chunk from llm
:rtype: Generator[Any, Any, None]
"""
streamed_answer = ""
for chunk in answer:
streamed_answer = streamed_answer + chunk
@@ -110,23 +155,24 @@ class BaseLlm(JSONSerializable):
self.memory.chat_memory.add_ai_message(streamed_answer)
logging.info(f"Answer: {streamed_answer}")
def query(self, input_query, contexts, config: BaseLlmConfig = None, dry_run=False, where=None):
def query(self, input_query: str, contexts: List[str], config: BaseLlmConfig = None, dry_run=False):
"""
Queries the vector database based on the given input query.
Gets relevant doc based on the query and then passes it to an
LLM as context to get the answer.
:param input_query: The query to use.
:param config: Optional. The `LlmConfig` instance to use as configuration options.
This is used for one method call. To persistently use a config, declare it during app init.
:param dry_run: Optional. A dry run does everything except send the resulting prompt to
the LLM. The purpose is to test the prompt, not the response.
You can use it to test your prompt, including the context provided
by the vector database's doc retrieval.
The only thing the dry run does not consider is the cut-off due to
the `max_tokens` parameter.
:param where: Optional. A dictionary of key-value pairs to filter the database results.
:return: The answer to the query.
:type input_query: str
:param contexts: Embeddings retrieved from the database to be used as context.
:type contexts: List[str]
:param config: The `LlmConfig` instance to use as configuration options. This is used for one method call.
To persistently use a config, declare it during app init., defaults to None
:type config: Optional[BaseLlmConfig], optional
:param dry_run: A dry run does everything except send the resulting prompt to
the LLM. The purpose is to test the prompt, not the response., defaults to False
:type dry_run: bool, optional
:return: The answer to the query or the dry run result
:rtype: str
"""
query_config = config or self.config
@@ -150,24 +196,26 @@ class BaseLlm(JSONSerializable):
else:
return self._stream_query_response(answer)
def chat(self, input_query, contexts, config: BaseLlmConfig = None, dry_run=False, where=None):
def chat(self, input_query: str, contexts: List[str], config: BaseLlmConfig = None, dry_run=False):
"""
Queries the vector database on the given input query.
Gets relevant doc based on the query and then passes it to an
LLM as context to get the answer.
Maintains the whole conversation in memory.
:param input_query: The query to use.
:param config: Optional. The `LlmConfig` instance to use as configuration options.
This is used for one method call. To persistently use a config, declare it during app init.
:param dry_run: Optional. A dry run does everything except send the resulting prompt to
the LLM. The purpose is to test the prompt, not the response.
You can use it to test your prompt, including the context provided
by the vector database's doc retrieval.
The only thing the dry run does not consider is the cut-off due to
the `max_tokens` parameter.
:param where: Optional. A dictionary of key-value pairs to filter the database results.
:return: The answer to the query.
:type input_query: str
:param contexts: Embeddings retrieved from the database to be used as context.
:type contexts: List[str]
:param config: The `LlmConfig` instance to use as configuration options. This is used for one method call.
To persistently use a config, declare it during app init., defaults to None
:type config: Optional[BaseLlmConfig], optional
:param dry_run: A dry run does everything except send the resulting prompt to
the LLM. The purpose is to test the prompt, not the response., defaults to False
:type dry_run: bool, optional
:return: The answer to the query or the dry run result
:rtype: str
"""
query_config = config or self.config
@@ -205,6 +253,16 @@ class BaseLlm(JSONSerializable):
@staticmethod
def _get_messages(prompt: str, system_prompt: Optional[str] = None) -> List[BaseMessage]:
"""
Construct a list of langchain messages
:param prompt: User prompt
:type prompt: str
:param system_prompt: System prompt, defaults to None
:type system_prompt: Optional[str], optional
:return: List of messages
:rtype: List[BaseMessage]
"""
from langchain.schema import HumanMessage, SystemMessage
messages = []