[Bug Fix] Handle chat sessions properly during app.chat() calls (#1084)
This commit is contained in:
@@ -7,9 +7,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from dotenv import load_dotenv
|
||||
from langchain.docstore.document import Document
|
||||
|
||||
from embedchain.cache import (adapt, get_gptcache_session,
|
||||
gptcache_data_convert,
|
||||
gptcache_update_cache_callback)
|
||||
from embedchain.cache import adapt, get_gptcache_session, gptcache_data_convert, gptcache_update_cache_callback
|
||||
from embedchain.chunkers.base_chunker import BaseChunker
|
||||
from embedchain.config import AddConfig, BaseLlmConfig, ChunkerConfig
|
||||
from embedchain.config.base_app_config import BaseAppConfig
|
||||
@@ -19,8 +17,7 @@ from embedchain.embedder.base import BaseEmbedder
|
||||
from embedchain.helpers.json_serializable import JSONSerializable
|
||||
from embedchain.llm.base import BaseLlm
|
||||
from embedchain.loaders.base_loader import BaseLoader
|
||||
from embedchain.models.data_type import (DataType, DirectDataType,
|
||||
IndirectDataType, SpecialDataType)
|
||||
from embedchain.models.data_type import DataType, DirectDataType, IndirectDataType, SpecialDataType
|
||||
from embedchain.telemetry.posthog import AnonymousTelemetry
|
||||
from embedchain.utils import detect_datatype, is_valid_json_string
|
||||
from embedchain.vectordb.base import BaseVectorDB
|
||||
@@ -580,6 +577,7 @@ class EmbedChain(JSONSerializable):
|
||||
input_query: str,
|
||||
config: Optional[BaseLlmConfig] = None,
|
||||
dry_run=False,
|
||||
session_id: str = "default",
|
||||
where: Optional[Dict[str, str]] = None,
|
||||
citations: bool = False,
|
||||
**kwargs: Dict[str, Any],
|
||||
@@ -599,6 +597,8 @@ class EmbedChain(JSONSerializable):
|
||||
:param dry_run: A dry run does everything except send the resulting prompt to
|
||||
the LLM. The purpose is to test the prompt, not the response., defaults to False
|
||||
:type dry_run: bool, optional
|
||||
:param session_id: The session id to use for chat history, defaults to 'default'.
|
||||
:type session_id: Optional[str], optional
|
||||
:param where: A dictionary of key-value pairs to filter the database results., defaults to None
|
||||
:type where: Optional[Dict[str, str]], optional
|
||||
:param kwargs: To read more params for the query function. Ex. we use citations boolean
|
||||
@@ -616,6 +616,9 @@ class EmbedChain(JSONSerializable):
|
||||
else:
|
||||
contexts_data_for_llm_query = contexts
|
||||
|
||||
# Update the history beforehand so that we can handle multiple chat sessions in the same python session
|
||||
self.llm.update_history(app_id=self.config.id, session_id=session_id)
|
||||
|
||||
if self.cache_config is not None:
|
||||
logging.info("Cache enabled. Checking cache...")
|
||||
answer = adapt(
|
||||
@@ -634,7 +637,7 @@ class EmbedChain(JSONSerializable):
|
||||
)
|
||||
|
||||
# add conversation in memory
|
||||
self.llm.add_history(self.config.id, input_query, answer)
|
||||
self.llm.add_history(self.config.id, input_query, answer, session_id=session_id)
|
||||
|
||||
# Send anonymous telemetry
|
||||
self.telemetry.capture(event_name="chat", properties=self._telemetry_props)
|
||||
@@ -684,12 +687,8 @@ class EmbedChain(JSONSerializable):
|
||||
self.telemetry.capture(event_name="reset", properties=self._telemetry_props)
|
||||
|
||||
def get_history(self, num_rounds: int = 10, display_format: bool = True):
|
||||
return self.llm.memory.get_recent_memories(
|
||||
app_id=self.config.id,
|
||||
num_rounds=num_rounds,
|
||||
display_format=display_format,
|
||||
)
|
||||
return self.llm.memory.get(app_id=self.config.id, num_rounds=num_rounds, display_format=display_format)
|
||||
|
||||
def delete_chat_history(self):
|
||||
self.llm.memory.delete_chat_history(app_id=self.config.id)
|
||||
def delete_chat_history(self, session_id: str = "default"):
|
||||
self.llm.memory.delete(app_id=self.config.id, session_id=session_id)
|
||||
self.llm.update_history(app_id=self.config.id)
|
||||
|
||||
Reference in New Issue
Block a user