[Bug Fix] Handle chat sessions properly during app.chat() calls (#1084)
This commit is contained in:
@@ -9,40 +9,41 @@ from embedchain.memory.message import ChatMessage
|
||||
from embedchain.memory.utils import merge_metadata_dict
|
||||
|
||||
CHAT_MESSAGE_CREATE_TABLE_QUERY = """
|
||||
CREATE TABLE IF NOT EXISTS chat_history (
|
||||
app_id TEXT,
|
||||
id TEXT,
|
||||
question TEXT,
|
||||
answer TEXT,
|
||||
metadata TEXT,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
PRIMARY KEY (id, app_id)
|
||||
)
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS ec_chat_history (
|
||||
app_id TEXT,
|
||||
id TEXT,
|
||||
session_id TEXT,
|
||||
question TEXT,
|
||||
answer TEXT,
|
||||
metadata TEXT,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
PRIMARY KEY (id, app_id, session_id)
|
||||
)
|
||||
"""
|
||||
|
||||
|
||||
class ECChatMemory:
|
||||
class ChatHistory:
|
||||
def __init__(self) -> None:
|
||||
with sqlite3.connect(SQLITE_PATH, check_same_thread=False) as self.connection:
|
||||
self.cursor = self.connection.cursor()
|
||||
|
||||
self.cursor.execute(CHAT_MESSAGE_CREATE_TABLE_QUERY)
|
||||
self.connection.commit()
|
||||
|
||||
def add(self, app_id, chat_message: ChatMessage) -> Optional[str]:
|
||||
def add(self, app_id, session_id, chat_message: ChatMessage) -> Optional[str]:
|
||||
memory_id = str(uuid.uuid4())
|
||||
metadata_dict = merge_metadata_dict(chat_message.human_message.metadata, chat_message.ai_message.metadata)
|
||||
if metadata_dict:
|
||||
metadata = self._serialize_json(metadata_dict)
|
||||
ADD_CHAT_MESSAGE_QUERY = """
|
||||
INSERT INTO chat_history (app_id, id, question, answer, metadata)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
INSERT INTO ec_chat_history (app_id, id, session_id, question, answer, metadata)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
"""
|
||||
self.cursor.execute(
|
||||
ADD_CHAT_MESSAGE_QUERY,
|
||||
(
|
||||
app_id,
|
||||
memory_id,
|
||||
session_id,
|
||||
chat_message.human_message.content,
|
||||
chat_message.ai_message.content,
|
||||
metadata if metadata_dict else "{}",
|
||||
@@ -52,37 +53,41 @@ class ECChatMemory:
|
||||
logging.info(f"Added chat memory to db with id: {memory_id}")
|
||||
return memory_id
|
||||
|
||||
def delete_chat_history(self, app_id: str):
|
||||
DELETE_CHAT_HISTORY_QUERY = """
|
||||
DELETE FROM chat_history WHERE app_id=?
|
||||
def delete(self, app_id: str, session_id: str):
|
||||
"""
|
||||
self.cursor.execute(
|
||||
DELETE_CHAT_HISTORY_QUERY,
|
||||
(app_id,),
|
||||
)
|
||||
Delete all chat history for a given app_id and session_id.
|
||||
This is useful for deleting chat history for a given user.
|
||||
|
||||
:param app_id: The app_id to delete chat history for
|
||||
:param session_id: The session_id to delete chat history for
|
||||
|
||||
:return: None
|
||||
"""
|
||||
DELETE_CHAT_HISTORY_QUERY = "DELETE FROM ec_chat_history WHERE app_id=? AND session_id=?"
|
||||
self.cursor.execute(DELETE_CHAT_HISTORY_QUERY, (app_id, session_id))
|
||||
self.connection.commit()
|
||||
|
||||
def get_recent_memories(self, app_id, num_rounds=10, display_format=False) -> List[ChatMessage]:
|
||||
def get(self, app_id, session_id, num_rounds=10, display_format=False) -> List[ChatMessage]:
|
||||
"""
|
||||
Get the most recent num_rounds rounds of conversations
|
||||
between human and AI, for a given app_id.
|
||||
"""
|
||||
|
||||
QUERY = """
|
||||
SELECT * FROM chat_history
|
||||
WHERE app_id=?
|
||||
SELECT * FROM ec_chat_history
|
||||
WHERE app_id=? AND session_id=?
|
||||
ORDER BY created_at DESC
|
||||
LIMIT ?
|
||||
"""
|
||||
self.cursor.execute(
|
||||
QUERY,
|
||||
(app_id, num_rounds),
|
||||
(app_id, session_id, num_rounds),
|
||||
)
|
||||
|
||||
results = self.cursor.fetchall()
|
||||
history = []
|
||||
for result in results:
|
||||
app_id, _, question, answer, metadata, timestamp = result
|
||||
app_id, _, session_id, question, answer, metadata, timestamp = result
|
||||
metadata = self._deserialize_json(metadata=metadata)
|
||||
# Return list of dict if display_format is True
|
||||
if display_format:
|
||||
@@ -94,6 +99,20 @@ class ECChatMemory:
|
||||
history.append(memory)
|
||||
return history
|
||||
|
||||
def count(self, app_id: str, session_id: str):
|
||||
"""
|
||||
Count the number of chat messages for a given app_id and session_id.
|
||||
|
||||
:param app_id: The app_id to count chat history for
|
||||
:param session_id: The session_id to count chat history for
|
||||
|
||||
:return: The number of chat messages for a given app_id and session_id
|
||||
"""
|
||||
QUERY = "SELECT COUNT(*) FROM ec_chat_history WHERE app_id=? AND session_id=?"
|
||||
self.cursor.execute(QUERY, (app_id, session_id))
|
||||
count = self.cursor.fetchone()[0]
|
||||
return count
|
||||
|
||||
def _serialize_json(self, metadata: Dict[str, Any]):
|
||||
return json.dumps(metadata)
|
||||
|
||||
@@ -102,15 +121,3 @@ class ECChatMemory:
|
||||
|
||||
def close_connection(self):
|
||||
self.connection.close()
|
||||
|
||||
def count_history_messages(self, app_id: str):
|
||||
QUERY = """
|
||||
SELECT COUNT(*) FROM chat_history
|
||||
WHERE app_id=?
|
||||
"""
|
||||
self.cursor.execute(
|
||||
QUERY,
|
||||
(app_id,),
|
||||
)
|
||||
count = self.cursor.fetchone()[0]
|
||||
return count
|
||||
|
||||
@@ -14,16 +14,16 @@ class BaseMessage(JSONSerializable):
|
||||
# The string content of the message.
|
||||
content: str
|
||||
|
||||
# The creator of the message. AI, Human, Bot etc.
|
||||
by: str
|
||||
# The created_by of the message. AI, Human, Bot etc.
|
||||
created_by: str
|
||||
|
||||
# Any additional info.
|
||||
metadata: Dict[str, Any]
|
||||
|
||||
def __init__(self, content: str, creator: str, metadata: Optional[Dict[str, Any]] = None) -> None:
|
||||
def __init__(self, content: str, created_by: str, metadata: Optional[Dict[str, Any]] = None) -> None:
|
||||
super().__init__()
|
||||
self.content = content
|
||||
self.creator = creator
|
||||
self.created_by = created_by
|
||||
self.metadata = metadata
|
||||
|
||||
@property
|
||||
@@ -36,7 +36,7 @@ class BaseMessage(JSONSerializable):
|
||||
return True
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.creator}: {self.content}"
|
||||
return f"{self.created_by}: {self.content}"
|
||||
|
||||
|
||||
class ChatMessage(JSONSerializable):
|
||||
@@ -57,7 +57,7 @@ class ChatMessage(JSONSerializable):
|
||||
overwritting it with new message."
|
||||
)
|
||||
|
||||
self.human_message = BaseMessage(content=message, creator="human", metadata=metadata)
|
||||
self.human_message = BaseMessage(content=message, created_by="human", metadata=metadata)
|
||||
|
||||
def add_ai_message(self, message: str, metadata: Optional[dict] = None):
|
||||
if self.ai_message:
|
||||
@@ -66,7 +66,7 @@ class ChatMessage(JSONSerializable):
|
||||
overwritting it with new message."
|
||||
)
|
||||
|
||||
self.ai_message = BaseMessage(content=message, creator="ai", metadata=metadata)
|
||||
self.ai_message = BaseMessage(content=message, created_by="ai", metadata=metadata)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.human_message}\n{self.ai_message}"
|
||||
|
||||
Reference in New Issue
Block a user