[Bug Fix] Handle chat sessions properly during app.chat() calls (#1084)
This commit is contained in:
@@ -8,7 +8,7 @@ from embedchain.config.llm.base import (DEFAULT_PROMPT,
|
||||
DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE,
|
||||
DOCS_SITE_PROMPT_TEMPLATE)
|
||||
from embedchain.helpers.json_serializable import JSONSerializable
|
||||
from embedchain.memory.base import ECChatMemory
|
||||
from embedchain.memory.base import ChatHistory
|
||||
from embedchain.memory.message import ChatMessage
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ class BaseLlm(JSONSerializable):
|
||||
else:
|
||||
self.config = config
|
||||
|
||||
self.memory = ECChatMemory()
|
||||
self.memory = ChatHistory()
|
||||
self.is_docs_site_instance = False
|
||||
self.online = False
|
||||
self.history: Any = None
|
||||
@@ -45,17 +45,24 @@ class BaseLlm(JSONSerializable):
|
||||
"""
|
||||
self.history = history
|
||||
|
||||
def update_history(self, app_id: str):
|
||||
def update_history(self, app_id: str, session_id: str = "default"):
|
||||
"""Update class history attribute with history in memory (for chat method)"""
|
||||
chat_history = self.memory.get_recent_memories(app_id=app_id, num_rounds=10)
|
||||
chat_history = self.memory.get(app_id=app_id, session_id=session_id, num_rounds=10)
|
||||
self.set_history([str(history) for history in chat_history])
|
||||
|
||||
def add_history(self, app_id: str, question: str, answer: str, metadata: Optional[Dict[str, Any]] = None):
|
||||
def add_history(
|
||||
self,
|
||||
app_id: str,
|
||||
question: str,
|
||||
answer: str,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
session_id: str = "default",
|
||||
):
|
||||
chat_message = ChatMessage()
|
||||
chat_message.add_user_message(question, metadata=metadata)
|
||||
chat_message.add_ai_message(answer, metadata=metadata)
|
||||
self.memory.add(app_id=app_id, chat_message=chat_message)
|
||||
self.update_history(app_id=app_id)
|
||||
self.memory.add(app_id=app_id, chat_message=chat_message, session_id=session_id)
|
||||
self.update_history(app_id=app_id, session_id=session_id)
|
||||
|
||||
def generate_prompt(self, input_query: str, contexts: List[str], **kwargs: Dict[str, Any]) -> str:
|
||||
"""
|
||||
@@ -212,7 +219,9 @@ class BaseLlm(JSONSerializable):
|
||||
# Restore previous config
|
||||
self.config: BaseLlmConfig = BaseLlmConfig.deserialize(prev_config)
|
||||
|
||||
def chat(self, input_query: str, contexts: List[str], config: BaseLlmConfig = None, dry_run=False):
|
||||
def chat(
|
||||
self, input_query: str, contexts: List[str], config: BaseLlmConfig = None, dry_run=False, session_id: str = None
|
||||
):
|
||||
"""
|
||||
Queries the vector database on the given input query.
|
||||
Gets relevant doc based on the query and then passes it to an
|
||||
@@ -230,6 +239,8 @@ class BaseLlm(JSONSerializable):
|
||||
:param dry_run: A dry run does everything except send the resulting prompt to
|
||||
the LLM. The purpose is to test the prompt, not the response., defaults to False
|
||||
:type dry_run: bool, optional
|
||||
:param session_id: Session ID to use for the conversation, defaults to None
|
||||
:type session_id: str, optional
|
||||
:return: The answer to the query or the dry run result
|
||||
:rtype: str
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user