126 lines
4.8 KiB
Python
126 lines
4.8 KiB
Python
import json
|
|
import logging
|
|
import uuid
|
|
from typing import Any, Optional
|
|
|
|
from embedchain.core.db.database import get_session
|
|
from embedchain.core.db.models import ChatHistory as ChatHistoryModel
|
|
from embedchain.memory.message import ChatMessage
|
|
from embedchain.memory.utils import merge_metadata_dict
|
|
|
|
|
|
class ChatHistory:
    """
    Store and retrieve chat rounds (a human question plus the AI answer) in the database.

    All queries run on the SQLAlchemy session returned by ``get_session()``. Message
    metadata is persisted as a JSON-encoded string and decoded again when read back.
    """

    def __init__(self) -> None:
        # Single session used for every query, commit and rollback made by this instance.
        self.db_session = get_session()

    def add(self, app_id, session_id, chat_message: ChatMessage) -> Optional[str]:
        """
        Persist one chat round for the given app_id and session_id.

        The metadata of the human and AI messages is merged with ``merge_metadata_dict``
        and stored as a JSON string ("{}" when the merged result is empty).

        :param app_id: The app_id the round belongs to
        :param session_id: The session_id the round belongs to
        :param chat_message: The round to store (its human_message and ai_message)

        :return: The generated id of the stored row, or None if the commit failed
        """
        memory_id = str(uuid.uuid4())
        metadata_dict = merge_metadata_dict(chat_message.human_message.metadata, chat_message.ai_message.metadata)
        metadata = self._serialize_json(metadata_dict) if metadata_dict else "{}"
        self.db_session.add(
            ChatHistoryModel(
                app_id=app_id,
                id=memory_id,
                session_id=session_id,
                question=chat_message.human_message.content,
                answer=chat_message.ai_message.content,
                # `get` reads this column back as `result.meta_data`. SQLAlchemy's declarative
                # base reserves the attribute name `metadata`, so it cannot be the mapped column
                # and a `metadata=` keyword would not be persisted.
                meta_data=metadata,
            )
        )
        try:
            self.db_session.commit()
        except Exception as e:
            logging.error("Error adding chat memory to db: %s", e)
            self.db_session.rollback()
            return None

        logging.info("Added chat memory to db with id: %s", memory_id)
        return memory_id

    def delete(self, app_id: str, session_id: Optional[str] = None):
        """
        Delete all chat history for a given app_id and session_id.
        This is useful for deleting chat history for a given user.

        :param app_id: The app_id to delete chat history for
        :param session_id: The session_id to delete chat history for. When falsy, every
            session of the app_id is deleted.

        :return: None
        """
        params = self._filter_params(app_id, session_id)
        try:
            # The bulk DELETE can fail just like the commit, so both sit inside the try.
            # That way the session is always rolled back on error.
            self.db_session.query(ChatHistoryModel).filter_by(**params).delete()
            self.db_session.commit()
        except Exception as e:
            logging.error("Error deleting chat history: %s", e)
            self.db_session.rollback()

    def get(
        self, app_id, session_id: str = "default", num_rounds=10, fetch_all: bool = False, display_format=False
    ) -> list[ChatMessage]:
        """
        Get the chat history for a given app_id, oldest round first.

        param: app_id - The app_id to get chat history
        param: session_id (optional) - The session_id to get chat history. Defaults to "default".
            Ignored when fetch_all is True.
        param: num_rounds (optional) - The number of most recent rounds to return. Defaults to 10.
            Ignored when fetch_all is True.
        param: fetch_all (optional) - Return every round of the app_id, across all sessions and
            without a limit. Defaults to False
        param: display_format (optional) - Return plain dicts (session_id, human, ai, the raw
            metadata string, timestamp) instead of ChatMessage objects. Defaults to False

        return: A list of ChatMessage objects, or a list of dicts when display_format is True
        """
        params = {"app_id": app_id}
        if not fetch_all:
            params["session_id"] = session_id
        query = self.db_session.query(ChatHistoryModel).filter_by(**params)

        if fetch_all:
            rows = query.order_by(ChatHistoryModel.created_at.asc()).all()
        else:
            # Select the newest `num_rounds` rows, then flip them back into chronological order.
            # Applying LIMIT to an ascending query would return the *oldest* rounds instead, so
            # the history would stop advancing once the session grew past `num_rounds`.
            rows = query.order_by(ChatHistoryModel.created_at.desc()).limit(num_rounds).all()
            rows.reverse()

        history = []
        for row in rows:
            if display_format:
                # Display format exposes the stored metadata string as-is.
                history.append(
                    {
                        "session_id": row.session_id,
                        "human": row.question,
                        "ai": row.answer,
                        "metadata": row.meta_data,
                        "timestamp": row.created_at,
                    }
                )
            else:
                metadata = self._deserialize_json(metadata=row.meta_data or "{}")
                memory = ChatMessage()
                memory.add_user_message(row.question, metadata=metadata)
                memory.add_ai_message(row.answer, metadata=metadata)
                history.append(memory)
        return history

    def count(self, app_id: str, session_id: Optional[str] = None):
        """
        Count the number of chat messages for a given app_id and session_id.

        :param app_id: The app_id to count chat history for
        :param session_id: The session_id to count chat history for. When falsy, rows from
            every session of the app_id are counted.

        :return: The number of chat messages for a given app_id and session_id
        """
        params = self._filter_params(app_id, session_id)
        return self.db_session.query(ChatHistoryModel).filter_by(**params).count()

    @staticmethod
    def _filter_params(app_id: str, session_id: Optional[str] = None) -> dict[str, str]:
        """Build ``filter_by`` kwargs for app_id, narrowed to session_id when it is truthy."""
        params = {"app_id": app_id}
        if session_id:
            params["session_id"] = session_id
        return params

    @staticmethod
    def _serialize_json(metadata: dict[str, Any]):
        """Encode a metadata dict as a JSON string for storage."""
        return json.dumps(metadata)

    @staticmethod
    def _deserialize_json(metadata: str):
        """Decode a stored JSON metadata string back into Python objects."""
        return json.loads(metadata)

    def close_connection(self):
        """Close the SQLAlchemy session opened in ``__init__``."""
        # The session is the only database handle this class holds; it has no separate
        # `connection` attribute.
        self.db_session.close()
|