[Bug Fix] Handle chat sessions properly during app.chat() calls (#1084)

This commit is contained in:
Deshraj Yadav
2023-12-30 15:36:24 +05:30
committed by GitHub
parent 52b4577d3b
commit a54dde0509
10 changed files with 124 additions and 187 deletions

View File

@@ -7,9 +7,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
from dotenv import load_dotenv
from langchain.docstore.document import Document
from embedchain.cache import (adapt, get_gptcache_session,
gptcache_data_convert,
gptcache_update_cache_callback)
from embedchain.cache import adapt, get_gptcache_session, gptcache_data_convert, gptcache_update_cache_callback
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config import AddConfig, BaseLlmConfig, ChunkerConfig
from embedchain.config.base_app_config import BaseAppConfig
@@ -19,8 +17,7 @@ from embedchain.embedder.base import BaseEmbedder
from embedchain.helpers.json_serializable import JSONSerializable
from embedchain.llm.base import BaseLlm
from embedchain.loaders.base_loader import BaseLoader
from embedchain.models.data_type import (DataType, DirectDataType,
IndirectDataType, SpecialDataType)
from embedchain.models.data_type import DataType, DirectDataType, IndirectDataType, SpecialDataType
from embedchain.telemetry.posthog import AnonymousTelemetry
from embedchain.utils import detect_datatype, is_valid_json_string
from embedchain.vectordb.base import BaseVectorDB
@@ -580,6 +577,7 @@ class EmbedChain(JSONSerializable):
input_query: str,
config: Optional[BaseLlmConfig] = None,
dry_run=False,
session_id: str = "default",
where: Optional[Dict[str, str]] = None,
citations: bool = False,
**kwargs: Dict[str, Any],
@@ -599,6 +597,8 @@ class EmbedChain(JSONSerializable):
:param dry_run: A dry run does everything except send the resulting prompt to
the LLM. The purpose is to test the prompt, not the response., defaults to False
:type dry_run: bool, optional
:param session_id: The session id to use for chat history, defaults to 'default'.
:type session_id: Optional[str], optional
:param where: A dictionary of key-value pairs to filter the database results., defaults to None
:type where: Optional[Dict[str, str]], optional
:param kwargs: To read more params for the query function. Ex. we use citations boolean
@@ -616,6 +616,9 @@ class EmbedChain(JSONSerializable):
else:
contexts_data_for_llm_query = contexts
# Update the history beforehand so that we can handle multiple chat sessions in the same Python session
self.llm.update_history(app_id=self.config.id, session_id=session_id)
if self.cache_config is not None:
logging.info("Cache enabled. Checking cache...")
answer = adapt(
@@ -634,7 +637,7 @@ class EmbedChain(JSONSerializable):
)
# add conversation in memory
self.llm.add_history(self.config.id, input_query, answer)
self.llm.add_history(self.config.id, input_query, answer, session_id=session_id)
# Send anonymous telemetry
self.telemetry.capture(event_name="chat", properties=self._telemetry_props)
@@ -684,12 +687,8 @@ class EmbedChain(JSONSerializable):
self.telemetry.capture(event_name="reset", properties=self._telemetry_props)
def get_history(self, num_rounds: int = 10, display_format: bool = True):
return self.llm.memory.get_recent_memories(
app_id=self.config.id,
num_rounds=num_rounds,
display_format=display_format,
)
return self.llm.memory.get(app_id=self.config.id, num_rounds=num_rounds, display_format=display_format)
def delete_chat_history(self):
self.llm.memory.delete_chat_history(app_id=self.config.id)
def delete_chat_history(self, session_id: str = "default"):
self.llm.memory.delete(app_id=self.config.id, session_id=session_id)
self.llm.update_history(app_id=self.config.id)

View File

@@ -8,7 +8,7 @@ from embedchain.config.llm.base import (DEFAULT_PROMPT,
DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE,
DOCS_SITE_PROMPT_TEMPLATE)
from embedchain.helpers.json_serializable import JSONSerializable
from embedchain.memory.base import ECChatMemory
from embedchain.memory.base import ChatHistory
from embedchain.memory.message import ChatMessage
@@ -24,7 +24,7 @@ class BaseLlm(JSONSerializable):
else:
self.config = config
self.memory = ECChatMemory()
self.memory = ChatHistory()
self.is_docs_site_instance = False
self.online = False
self.history: Any = None
@@ -45,17 +45,24 @@ class BaseLlm(JSONSerializable):
"""
self.history = history
def update_history(self, app_id: str):
def update_history(self, app_id: str, session_id: str = "default"):
"""Update class history attribute with history in memory (for chat method)"""
chat_history = self.memory.get_recent_memories(app_id=app_id, num_rounds=10)
chat_history = self.memory.get(app_id=app_id, session_id=session_id, num_rounds=10)
self.set_history([str(history) for history in chat_history])
def add_history(self, app_id: str, question: str, answer: str, metadata: Optional[Dict[str, Any]] = None):
def add_history(
self,
app_id: str,
question: str,
answer: str,
metadata: Optional[Dict[str, Any]] = None,
session_id: str = "default",
):
chat_message = ChatMessage()
chat_message.add_user_message(question, metadata=metadata)
chat_message.add_ai_message(answer, metadata=metadata)
self.memory.add(app_id=app_id, chat_message=chat_message)
self.update_history(app_id=app_id)
self.memory.add(app_id=app_id, chat_message=chat_message, session_id=session_id)
self.update_history(app_id=app_id, session_id=session_id)
def generate_prompt(self, input_query: str, contexts: List[str], **kwargs: Dict[str, Any]) -> str:
"""
@@ -212,7 +219,9 @@ class BaseLlm(JSONSerializable):
# Restore previous config
self.config: BaseLlmConfig = BaseLlmConfig.deserialize(prev_config)
def chat(self, input_query: str, contexts: List[str], config: BaseLlmConfig = None, dry_run=False):
def chat(
self, input_query: str, contexts: List[str], config: BaseLlmConfig = None, dry_run=False, session_id: str = None
):
"""
Queries the vector database on the given input query.
Gets relevant doc based on the query and then passes it to an
@@ -230,6 +239,8 @@ class BaseLlm(JSONSerializable):
:param dry_run: A dry run does everything except send the resulting prompt to
the LLM. The purpose is to test the prompt, not the response., defaults to False
:type dry_run: bool, optional
:param session_id: Session ID to use for the conversation, defaults to None
:type session_id: str, optional
:return: The answer to the query or the dry run result
:rtype: str
"""

View File

@@ -9,40 +9,41 @@ from embedchain.memory.message import ChatMessage
from embedchain.memory.utils import merge_metadata_dict
CHAT_MESSAGE_CREATE_TABLE_QUERY = """
CREATE TABLE IF NOT EXISTS chat_history (
app_id TEXT,
id TEXT,
question TEXT,
answer TEXT,
metadata TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY (id, app_id)
)
"""
CREATE TABLE IF NOT EXISTS ec_chat_history (
app_id TEXT,
id TEXT,
session_id TEXT,
question TEXT,
answer TEXT,
metadata TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY (id, app_id, session_id)
)
"""
class ECChatMemory:
class ChatHistory:
def __init__(self) -> None:
with sqlite3.connect(SQLITE_PATH, check_same_thread=False) as self.connection:
self.cursor = self.connection.cursor()
self.cursor.execute(CHAT_MESSAGE_CREATE_TABLE_QUERY)
self.connection.commit()
def add(self, app_id, chat_message: ChatMessage) -> Optional[str]:
def add(self, app_id, session_id, chat_message: ChatMessage) -> Optional[str]:
memory_id = str(uuid.uuid4())
metadata_dict = merge_metadata_dict(chat_message.human_message.metadata, chat_message.ai_message.metadata)
if metadata_dict:
metadata = self._serialize_json(metadata_dict)
ADD_CHAT_MESSAGE_QUERY = """
INSERT INTO chat_history (app_id, id, question, answer, metadata)
VALUES (?, ?, ?, ?, ?)
INSERT INTO ec_chat_history (app_id, id, session_id, question, answer, metadata)
VALUES (?, ?, ?, ?, ?, ?)
"""
self.cursor.execute(
ADD_CHAT_MESSAGE_QUERY,
(
app_id,
memory_id,
session_id,
chat_message.human_message.content,
chat_message.ai_message.content,
metadata if metadata_dict else "{}",
@@ -52,37 +53,41 @@ class ECChatMemory:
logging.info(f"Added chat memory to db with id: {memory_id}")
return memory_id
def delete_chat_history(self, app_id: str):
DELETE_CHAT_HISTORY_QUERY = """
DELETE FROM chat_history WHERE app_id=?
def delete(self, app_id: str, session_id: str):
"""
self.cursor.execute(
DELETE_CHAT_HISTORY_QUERY,
(app_id,),
)
Delete all chat history for a given app_id and session_id.
This is useful for deleting chat history for a given user.
:param app_id: The app_id to delete chat history for
:param session_id: The session_id to delete chat history for
:return: None
"""
DELETE_CHAT_HISTORY_QUERY = "DELETE FROM ec_chat_history WHERE app_id=? AND session_id=?"
self.cursor.execute(DELETE_CHAT_HISTORY_QUERY, (app_id, session_id))
self.connection.commit()
def get_recent_memories(self, app_id, num_rounds=10, display_format=False) -> List[ChatMessage]:
def get(self, app_id, session_id, num_rounds=10, display_format=False) -> List[ChatMessage]:
"""
Get the most recent num_rounds rounds of conversations
between human and AI, for a given app_id.
"""
QUERY = """
SELECT * FROM chat_history
WHERE app_id=?
SELECT * FROM ec_chat_history
WHERE app_id=? AND session_id=?
ORDER BY created_at DESC
LIMIT ?
"""
self.cursor.execute(
QUERY,
(app_id, num_rounds),
(app_id, session_id, num_rounds),
)
results = self.cursor.fetchall()
history = []
for result in results:
app_id, _, question, answer, metadata, timestamp = result
app_id, _, session_id, question, answer, metadata, timestamp = result
metadata = self._deserialize_json(metadata=metadata)
# Return list of dict if display_format is True
if display_format:
@@ -94,6 +99,20 @@ class ECChatMemory:
history.append(memory)
return history
def count(self, app_id: str, session_id: str):
"""
Count the number of chat messages for a given app_id and session_id.
:param app_id: The app_id to count chat history for
:param session_id: The session_id to count chat history for
:return: The number of chat messages for a given app_id and session_id
"""
QUERY = "SELECT COUNT(*) FROM ec_chat_history WHERE app_id=? AND session_id=?"
self.cursor.execute(QUERY, (app_id, session_id))
count = self.cursor.fetchone()[0]
return count
def _serialize_json(self, metadata: Dict[str, Any]):
return json.dumps(metadata)
@@ -102,15 +121,3 @@ class ECChatMemory:
def close_connection(self):
self.connection.close()
def count_history_messages(self, app_id: str):
QUERY = """
SELECT COUNT(*) FROM chat_history
WHERE app_id=?
"""
self.cursor.execute(
QUERY,
(app_id,),
)
count = self.cursor.fetchone()[0]
return count

View File

@@ -14,16 +14,16 @@ class BaseMessage(JSONSerializable):
# The string content of the message.
content: str
# The creator of the message. AI, Human, Bot etc.
by: str
# The created_by of the message. AI, Human, Bot etc.
created_by: str
# Any additional info.
metadata: Dict[str, Any]
def __init__(self, content: str, creator: str, metadata: Optional[Dict[str, Any]] = None) -> None:
def __init__(self, content: str, created_by: str, metadata: Optional[Dict[str, Any]] = None) -> None:
super().__init__()
self.content = content
self.creator = creator
self.created_by = created_by
self.metadata = metadata
@property
@@ -36,7 +36,7 @@ class BaseMessage(JSONSerializable):
return True
def __str__(self) -> str:
return f"{self.creator}: {self.content}"
return f"{self.created_by}: {self.content}"
class ChatMessage(JSONSerializable):
@@ -57,7 +57,7 @@ class ChatMessage(JSONSerializable):
overwriting it with new message."
)
self.human_message = BaseMessage(content=message, creator="human", metadata=metadata)
self.human_message = BaseMessage(content=message, created_by="human", metadata=metadata)
def add_ai_message(self, message: str, metadata: Optional[dict] = None):
if self.ai_message:
@@ -66,7 +66,7 @@ class ChatMessage(JSONSerializable):
overwriting it with new message."
)
self.ai_message = BaseMessage(content=message, creator="ai", metadata=metadata)
self.ai_message = BaseMessage(content=message, created_by="ai", metadata=metadata)
def __str__(self) -> str:
return f"{self.human_message}\n{self.ai_message}"