[Bug Fix] Handle chat sessions properly during app.chat() calls (#1084)

This commit is contained in:
Deshraj Yadav
2023-12-30 15:36:24 +05:30
committed by GitHub
parent 52b4577d3b
commit a54dde0509
10 changed files with 124 additions and 187 deletions

View File

@@ -8,7 +8,7 @@ from embedchain.config.llm.base import (DEFAULT_PROMPT,
DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE,
DOCS_SITE_PROMPT_TEMPLATE)
from embedchain.helpers.json_serializable import JSONSerializable
from embedchain.memory.base import ECChatMemory
from embedchain.memory.base import ChatHistory
from embedchain.memory.message import ChatMessage
@@ -24,7 +24,7 @@ class BaseLlm(JSONSerializable):
else:
self.config = config
self.memory = ECChatMemory()
self.memory = ChatHistory()
self.is_docs_site_instance = False
self.online = False
self.history: Any = None
@@ -45,17 +45,24 @@ class BaseLlm(JSONSerializable):
"""
self.history = history
def update_history(self, app_id: str):
def update_history(self, app_id: str, session_id: str = "default"):
"""Update class history attribute with history in memory (for chat method)"""
chat_history = self.memory.get_recent_memories(app_id=app_id, num_rounds=10)
chat_history = self.memory.get(app_id=app_id, session_id=session_id, num_rounds=10)
self.set_history([str(history) for history in chat_history])
def add_history(self, app_id: str, question: str, answer: str, metadata: Optional[Dict[str, Any]] = None):
def add_history(
self,
app_id: str,
question: str,
answer: str,
metadata: Optional[Dict[str, Any]] = None,
session_id: str = "default",
):
chat_message = ChatMessage()
chat_message.add_user_message(question, metadata=metadata)
chat_message.add_ai_message(answer, metadata=metadata)
self.memory.add(app_id=app_id, chat_message=chat_message)
self.update_history(app_id=app_id)
self.memory.add(app_id=app_id, chat_message=chat_message, session_id=session_id)
self.update_history(app_id=app_id, session_id=session_id)
def generate_prompt(self, input_query: str, contexts: List[str], **kwargs: Dict[str, Any]) -> str:
"""
@@ -212,7 +219,9 @@ class BaseLlm(JSONSerializable):
# Restore previous config
self.config: BaseLlmConfig = BaseLlmConfig.deserialize(prev_config)
def chat(self, input_query: str, contexts: List[str], config: BaseLlmConfig = None, dry_run=False):
def chat(
self, input_query: str, contexts: List[str], config: BaseLlmConfig = None, dry_run=False, session_id: str = None
):
"""
Queries the vector database on the given input query.
Gets relevant doc based on the query and then passes it to an
@@ -230,6 +239,8 @@ class BaseLlm(JSONSerializable):
:param dry_run: A dry run does everything except send the resulting prompt to
the LLM. The purpose is to test the prompt, not the response., defaults to False
:type dry_run: bool, optional
:param session_id: Session ID to use for the conversation, defaults to None
:type session_id: str, optional
:return: The answer to the query or the dry run result
:rtype: str
"""