Files
t6_mem0/embedchain/memory/base.py

126 lines
4.4 KiB
Python

import json
import logging
import sqlite3
import uuid
from typing import Any, Dict, List, Optional
from embedchain.constants import SQLITE_PATH
from embedchain.memory.message import ChatMessage
from embedchain.memory.utils import merge_metadata_dict
# DDL for the chat-history table, executed by ChatHistory.__init__.
# Each row stores one question/answer round for an (app_id, session_id) pair.
# ChatHistory.add writes `metadata` as JSON text ("{}" when there is none).
# ChatHistory.add does not supply `created_at`, so SQLite fills it from the
# DEFAULT. `IF NOT EXISTS` makes the statement safe to run every time a
# ChatHistory is constructed.
CHAT_MESSAGE_CREATE_TABLE_QUERY = """
CREATE TABLE IF NOT EXISTS ec_chat_history (
app_id TEXT,
id TEXT,
session_id TEXT,
question TEXT,
answer TEXT,
metadata TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY (id, app_id, session_id)
)
"""
class ChatHistory:
    """SQLite-backed store of question/answer rounds keyed by app_id and session_id.

    A single connection (opened with ``check_same_thread=False``) and a single
    cursor are created in ``__init__`` and reused by every method.
    NOTE(review): this class does no locking of its own, so concurrent use of
    one instance from several threads is not synchronized here.
    """

    def __init__(self) -> None:
        # The connection is kept open for the lifetime of the object. Using the
        # connection as a context manager would only commit or roll back a
        # transaction; it never closes the connection. A plain assignment
        # therefore states the intent more honestly.
        self.connection = sqlite3.connect(SQLITE_PATH, check_same_thread=False)
        self.cursor = self.connection.cursor()
        self.cursor.execute(CHAT_MESSAGE_CREATE_TABLE_QUERY)
        self.connection.commit()

    def add(self, app_id, session_id, chat_message: "ChatMessage") -> str:
        """Persist one question/answer round and return its generated id.

        :param app_id: The app_id the round belongs to
        :param session_id: The session_id the round belongs to
        :param chat_message: Provides the human question and the AI answer.
            The metadata of both messages is combined via ``merge_metadata_dict``
            and stored as a single JSON value.
        :return: The id of the new row (a random UUID4 string)
        """
        memory_id = str(uuid.uuid4())
        metadata_dict = merge_metadata_dict(chat_message.human_message.metadata, chat_message.ai_message.metadata)
        # A falsy merge result (e.g. None or {}) is stored as "{}", so reading
        # the row back always yields valid JSON.
        metadata = self._serialize_json(metadata_dict) if metadata_dict else "{}"
        self.cursor.execute(
            """
            INSERT INTO ec_chat_history (app_id, id, session_id, question, answer, metadata)
            VALUES (?, ?, ?, ?, ?, ?)
            """,
            (
                app_id,
                memory_id,
                session_id,
                chat_message.human_message.content,
                chat_message.ai_message.content,
                metadata,
            ),
        )
        self.connection.commit()
        logging.info("Added chat memory to db with id: %s", memory_id)
        return memory_id

    def delete(self, app_id: str, session_id: str):
        """
        Delete all chat history for a given app_id and session_id.
        This is useful for deleting chat history for a given user.
        :param app_id: The app_id to delete chat history for
        :param session_id: The session_id to delete chat history for
        :return: None
        """
        self.cursor.execute("DELETE FROM ec_chat_history WHERE app_id=? AND session_id=?", (app_id, session_id))
        self.connection.commit()

    def get(self, app_id, session_id, num_rounds=10, display_format=False) -> List["ChatMessage"]:
        """Return the most recent ``num_rounds`` rounds for an app_id/session_id, newest first.

        :param app_id: The app_id to read chat history for
        :param session_id: The session_id to read chat history for
        :param num_rounds: Maximum number of rounds (rows) to return
        :param display_format: When True, return plain dicts with the keys
            "human", "ai", "metadata" and "timestamp" instead of ChatMessage objects
        :return: A list of ChatMessage objects, or of dicts when display_format is True
        """
        # created_at only has the resolution of SQLite's CURRENT_TIMESTAMP, so
        # rounds added in quick succession can tie. rowid, which SQLite
        # normally assigns in increasing order on insert, breaks the tie so the
        # newest round still comes first.
        query = """
            SELECT question, answer, metadata, created_at FROM ec_chat_history
            WHERE app_id=? AND session_id=?
            ORDER BY created_at DESC, rowid DESC
            LIMIT ?
        """
        self.cursor.execute(query, (app_id, session_id, num_rounds))
        history = []
        for question, answer, metadata_json, timestamp in self.cursor.fetchall():
            # add() always writes JSON text, but tolerate NULL or empty values
            # (e.g. rows written by other code) instead of crashing in json.loads.
            metadata = self._deserialize_json(metadata=metadata_json or "{}")
            if display_format:
                history.append({"human": question, "ai": answer, "metadata": metadata, "timestamp": timestamp})
            else:
                memory = ChatMessage()
                memory.add_user_message(question, metadata=metadata)
                memory.add_ai_message(answer, metadata=metadata)
                history.append(memory)
        return history

    def count(self, app_id: str, session_id: str):
        """
        Count the number of chat messages for a given app_id and session_id.
        :param app_id: The app_id to count chat history for
        :param session_id: The session_id to count chat history for
        :return: The number of chat messages (rows) for the given app_id and session_id
        """
        self.cursor.execute("SELECT COUNT(*) FROM ec_chat_history WHERE app_id=? AND session_id=?", (app_id, session_id))
        (count,) = self.cursor.fetchone()
        return count

    @staticmethod
    def _serialize_json(metadata: Dict[str, Any]):
        """Encode a metadata mapping as JSON text for the `metadata` column."""
        return json.dumps(metadata)

    @staticmethod
    def _deserialize_json(metadata: str):
        """Decode JSON text read from the `metadata` column."""
        return json.loads(metadata)

    def close_connection(self):
        """Close the SQLite connection opened in ``__init__``."""
        self.connection.close()