165 lines
5.7 KiB
Python
165 lines
5.7 KiB
Python
import json
|
|
import logging
|
|
import sqlite3
|
|
import uuid
|
|
from typing import Any, Optional
|
|
|
|
from embedchain.constants import SQLITE_PATH
|
|
from embedchain.memory.message import ChatMessage
|
|
from embedchain.memory.utils import merge_metadata_dict
|
|
|
|
# DDL for the chat history table. One row is one question/answer round.
# The statement uses IF NOT EXISTS, so it is safe to run on every start-up.
# A row is identified by the composite key (id, app_id, session_id);
# created_at is filled in by SQLite when the insert does not supply it.
CHAT_MESSAGE_CREATE_TABLE_QUERY = """
    CREATE TABLE IF NOT EXISTS ec_chat_history (
        app_id TEXT,
        id TEXT,
        session_id TEXT,
        question TEXT,
        answer TEXT,
        metadata TEXT,
        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
        PRIMARY KEY (id, app_id, session_id)
    )
"""
|
|
|
|
|
|
class ChatHistory:
    """SQLite-backed store for chat question/answer rounds.

    Rows live in the ``ec_chat_history`` table and are scoped by ``app_id``
    and ``session_id``. One connection and one cursor are opened when the
    object is created and are reused by every method.
    """

    def __init__(self) -> None:
        """Open the database at ``SQLITE_PATH`` and make sure the table exists."""
        # check_same_thread=False disables sqlite3's check that the connection
        # is only used from the thread that created it.
        self.connection = sqlite3.connect(SQLITE_PATH, check_same_thread=False)
        self.cursor = self.connection.cursor()
        # The connection context manager commits on success and rolls back on
        # error; it does not close the connection.
        with self.connection:
            self.cursor.execute(CHAT_MESSAGE_CREATE_TABLE_QUERY)

    def add(self, app_id, session_id, chat_message: ChatMessage) -> Optional[str]:
        """
        Store one question/answer round and return its generated id.

        :param app_id: The app_id the round belongs to
        :param session_id: The session_id the round belongs to
        :param chat_message: The message pair to store; the human message
            content is saved as the question and the AI message content as
            the answer

        :return: The id (a UUID4 string) of the inserted row
        """
        memory_id = str(uuid.uuid4())
        metadata_dict = merge_metadata_dict(chat_message.human_message.metadata, chat_message.ai_message.metadata)
        # An empty/absent merged metadata is stored as an empty JSON object so
        # that the column always holds something json.loads can read back.
        metadata = self._serialize_json(metadata_dict) if metadata_dict else "{}"

        self.cursor.execute(
            """
            INSERT INTO ec_chat_history (app_id, id, session_id, question, answer, metadata)
            VALUES (?, ?, ?, ?, ?, ?)
            """,
            (
                app_id,
                memory_id,
                session_id,
                chat_message.human_message.content,
                chat_message.ai_message.content,
                metadata,
            ),
        )
        self.connection.commit()
        logging.info(f"Added chat memory to db with id: {memory_id}")
        return memory_id

    def delete(self, app_id: str, session_id: Optional[str] = None):
        """
        Delete chat history for a given app_id, optionally limited to one session.

        :param app_id: The app_id to delete chat history for
        :param session_id: The session_id to delete chat history for. When it
            is falsy (``None`` or ``""``) every session of the app is deleted

        :return: None
        """
        if session_id:
            query = "DELETE FROM ec_chat_history WHERE app_id=? AND session_id=?"
            params = (app_id, session_id)
        else:
            query = "DELETE FROM ec_chat_history WHERE app_id=?"
            params = (app_id,)

        self.cursor.execute(query, params)
        self.connection.commit()

    def get(
        self, app_id, session_id: str = "default", num_rounds=10, fetch_all: bool = False, display_format=False
    ) -> list[ChatMessage]:
        """
        Get the chat history for a given app_id, oldest round first.

        :param app_id: The app_id to get chat history for
        :param session_id: The session_id to get chat history for. Defaults to
            "default". Ignored when fetch_all is True
        :param num_rounds: Maximum number of rounds to return. Defaults to 10.
            Ignored when fetch_all is True
        :param fetch_all: When True, return every round of the app across all
            sessions. Defaults to False
        :param display_format: When True, return plain dicts with the keys
            "session_id", "human", "ai", "metadata" and "timestamp" instead of
            ChatMessage objects. Defaults to False

        :return: The list of rounds, as ChatMessage objects or as dicts
        """
        # Columns are named explicitly so the unpacking below does not depend
        # on the physical column order of the table.
        query = "SELECT session_id, question, answer, metadata, created_at FROM ec_chat_history WHERE app_id=?"
        if fetch_all:
            query += " ORDER BY created_at ASC"
            params = (app_id,)
        else:
            # NOTE(review): rows are sorted oldest-first before LIMIT applies,
            # so num_rounds keeps the earliest rounds of the session. Confirm
            # that this, rather than the most recent rounds, is intended.
            query += " AND session_id=? ORDER BY created_at ASC LIMIT ?"
            params = (app_id, session_id, num_rounds)

        self.cursor.execute(query, params)

        history = []
        # Row values get their own names so the method's app_id/session_id
        # arguments are not overwritten while iterating.
        for row_session_id, question, answer, raw_metadata, timestamp in self.cursor.fetchall():
            metadata = self._deserialize_json(metadata=raw_metadata)
            if display_format:
                history.append(
                    {
                        "session_id": row_session_id,
                        "human": question,
                        "ai": answer,
                        "metadata": metadata,
                        "timestamp": timestamp,
                    }
                )
            else:
                memory = ChatMessage()
                memory.add_user_message(question, metadata=metadata)
                memory.add_ai_message(answer, metadata=metadata)
                history.append(memory)
        return history

    def count(self, app_id: str, session_id: Optional[str] = None):
        """
        Count the number of stored rounds for a given app_id and session_id.

        :param app_id: The app_id to count chat history for
        :param session_id: The session_id to count chat history for. When it
            is falsy (``None`` or ``""``) rounds of every session are counted

        :return: The number of matching rows
        """
        if session_id:
            query = "SELECT COUNT(*) FROM ec_chat_history WHERE app_id=? AND session_id=?"
            params = (app_id, session_id)
        else:
            query = "SELECT COUNT(*) FROM ec_chat_history WHERE app_id=?"
            params = (app_id,)

        self.cursor.execute(query, params)
        return self.cursor.fetchone()[0]

    @staticmethod
    def _serialize_json(metadata: dict[str, Any]):
        """Encode a metadata dict as a JSON string for storage."""
        return json.dumps(metadata)

    @staticmethod
    def _deserialize_json(metadata: Optional[str]):
        """Decode stored metadata; a NULL or empty column yields an empty dict."""
        # The metadata column is nullable, and json.loads(None) would raise
        # TypeError, so missing values are treated as "no metadata".
        if not metadata:
            return {}
        return json.loads(metadata)

    def close_connection(self):
        """Close the underlying SQLite connection."""
        self.connection.close()
|