@@ -1,5 +1,6 @@
 import logging
-from typing import Any, Dict, Generator, List, Optional
+from collections.abc import Generator
+from typing import Any, Optional
 
 from langchain.schema import BaseMessage as LCBaseMessage
 
@@ -55,7 +56,7 @@ class BaseLlm(JSONSerializable):
         app_id: str,
         question: str,
         answer: str,
-        metadata: Optional[Dict[str, Any]] = None,
+        metadata: Optional[dict[str, Any]] = None,
         session_id: str = "default",
     ):
         chat_message = ChatMessage()
@@ -64,7 +65,7 @@ class BaseLlm(JSONSerializable):
         self.memory.add(app_id=app_id, chat_message=chat_message, session_id=session_id)
         self.update_history(app_id=app_id, session_id=session_id)
 
-    def generate_prompt(self, input_query: str, contexts: List[str], **kwargs: Dict[str, Any]) -> str:
+    def generate_prompt(self, input_query: str, contexts: list[str], **kwargs: dict[str, Any]) -> str:
        """
        Generates a prompt based on the given query and context, ready to be
        passed to an LLM
@@ -72,7 +73,7 @@ class BaseLlm(JSONSerializable):
        :param input_query: The query to use.
        :type input_query: str
        :param contexts: List of similar documents to the query used as context.
-        :type contexts: List[str]
+        :type contexts: list[str]
        :return: The prompt
        :rtype: str
        """
@@ -170,7 +171,7 @@ class BaseLlm(JSONSerializable):
            yield chunk
        logging.info(f"Answer: {streamed_answer}")
 
-    def query(self, input_query: str, contexts: List[str], config: BaseLlmConfig = None, dry_run=False):
+    def query(self, input_query: str, contexts: list[str], config: BaseLlmConfig = None, dry_run=False):
        """
        Queries the vector database based on the given input query.
        Gets relevant doc based on the query and then passes it to an
@@ -179,7 +180,7 @@ class BaseLlm(JSONSerializable):
        :param input_query: The query to use.
        :type input_query: str
        :param contexts: Embeddings retrieved from the database to be used as context.
-        :type contexts: List[str]
+        :type contexts: list[str]
        :param config: The `BaseLlmConfig` instance to use as configuration options. This is used for one method call.
        To persistently use a config, declare it during app init., defaults to None
        :type config: Optional[BaseLlmConfig], optional
@@ -223,7 +224,7 @@ class BaseLlm(JSONSerializable):
            self.config: BaseLlmConfig = BaseLlmConfig.deserialize(prev_config)
 
    def chat(
-        self, input_query: str, contexts: List[str], config: BaseLlmConfig = None, dry_run=False, session_id: str = None
+        self, input_query: str, contexts: list[str], config: BaseLlmConfig = None, dry_run=False, session_id: str = None
    ):
        """
        Queries the vector database on the given input query.
@@ -235,7 +236,7 @@ class BaseLlm(JSONSerializable):
        :param input_query: The query to use.
        :type input_query: str
        :param contexts: Embeddings retrieved from the database to be used as context.
-        :type contexts: List[str]
+        :type contexts: list[str]
        :param config: The `BaseLlmConfig` instance to use as configuration options. This is used for one method call.
        To persistently use a config, declare it during app init., defaults to None
        :type config: Optional[BaseLlmConfig], optional
@@ -281,7 +282,7 @@ class BaseLlm(JSONSerializable):
            self.config: BaseLlmConfig = BaseLlmConfig.deserialize(prev_config)
 
    @staticmethod
-    def _get_messages(prompt: str, system_prompt: Optional[str] = None) -> List[LCBaseMessage]:
+    def _get_messages(prompt: str, system_prompt: Optional[str] = None) -> list[LCBaseMessage]:
        """
        Construct a list of langchain messages
 
@@ -290,7 +291,7 @@ class BaseLlm(JSONSerializable):
        :param system_prompt: System prompt, defaults to None
        :type system_prompt: Optional[str], optional
        :return: List of messages
-        :rtype: List[BaseMessage]
+        :rtype: list[BaseMessage]
        """
        from langchain.schema import HumanMessage, SystemMessage
 