[Feature] Add support for using any SQL database as the metadata storage for embedchain apps (#1273)
This commit is contained in:
@@ -1,55 +1,40 @@
|
||||
import json
|
||||
import logging
|
||||
import sqlite3
|
||||
import uuid
|
||||
from typing import Any, Optional
|
||||
|
||||
from embedchain.constants import SQLITE_PATH
|
||||
from embedchain.core.db.database import get_session
|
||||
from embedchain.core.db.models import ChatHistory as ChatHistoryModel
|
||||
from embedchain.memory.message import ChatMessage
|
||||
from embedchain.memory.utils import merge_metadata_dict
|
||||
|
||||
CHAT_MESSAGE_CREATE_TABLE_QUERY = """
|
||||
CREATE TABLE IF NOT EXISTS ec_chat_history (
|
||||
app_id TEXT,
|
||||
id TEXT,
|
||||
session_id TEXT,
|
||||
question TEXT,
|
||||
answer TEXT,
|
||||
metadata TEXT,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
PRIMARY KEY (id, app_id, session_id)
|
||||
)
|
||||
"""
|
||||
|
||||
|
||||
class ChatHistory:
|
||||
def __init__(self) -> None:
|
||||
with sqlite3.connect(SQLITE_PATH, check_same_thread=False) as self.connection:
|
||||
self.cursor = self.connection.cursor()
|
||||
self.cursor.execute(CHAT_MESSAGE_CREATE_TABLE_QUERY)
|
||||
self.connection.commit()
|
||||
self.db_session = get_session()
|
||||
|
||||
def add(self, app_id, session_id, chat_message: ChatMessage) -> Optional[str]:
|
||||
memory_id = str(uuid.uuid4())
|
||||
metadata_dict = merge_metadata_dict(chat_message.human_message.metadata, chat_message.ai_message.metadata)
|
||||
if metadata_dict:
|
||||
metadata = self._serialize_json(metadata_dict)
|
||||
ADD_CHAT_MESSAGE_QUERY = """
|
||||
INSERT INTO ec_chat_history (app_id, id, session_id, question, answer, metadata)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
"""
|
||||
self.cursor.execute(
|
||||
ADD_CHAT_MESSAGE_QUERY,
|
||||
(
|
||||
app_id,
|
||||
memory_id,
|
||||
session_id,
|
||||
chat_message.human_message.content,
|
||||
chat_message.ai_message.content,
|
||||
metadata if metadata_dict else "{}",
|
||||
),
|
||||
self.db_session.add(
|
||||
ChatHistoryModel(
|
||||
app_id=app_id,
|
||||
id=memory_id,
|
||||
session_id=session_id,
|
||||
question=chat_message.human_message.content,
|
||||
answer=chat_message.ai_message.content,
|
||||
metadata=metadata if metadata_dict else "{}",
|
||||
)
|
||||
)
|
||||
self.connection.commit()
|
||||
try:
|
||||
self.db_session.commit()
|
||||
except Exception as e:
|
||||
logging.error(f"Error adding chat memory to db: {e}")
|
||||
self.db_session.rollback()
|
||||
return None
|
||||
|
||||
logging.info(f"Added chat memory to db with id: {memory_id}")
|
||||
return memory_id
|
||||
|
||||
@@ -63,15 +48,15 @@ class ChatHistory:
|
||||
|
||||
:return: None
|
||||
"""
|
||||
params = {"app_id": app_id}
|
||||
if session_id:
|
||||
DELETE_CHAT_HISTORY_QUERY = "DELETE FROM ec_chat_history WHERE app_id=? AND session_id=?"
|
||||
params = (app_id, session_id)
|
||||
else:
|
||||
DELETE_CHAT_HISTORY_QUERY = "DELETE FROM ec_chat_history WHERE app_id=?"
|
||||
params = (app_id,)
|
||||
|
||||
self.cursor.execute(DELETE_CHAT_HISTORY_QUERY, params)
|
||||
self.connection.commit()
|
||||
params["session_id"] = session_id
|
||||
self.db_session.query(ChatHistoryModel).filter_by(**params).delete()
|
||||
try:
|
||||
self.db_session.commit()
|
||||
except Exception as e:
|
||||
logging.error(f"Error deleting chat history: {e}")
|
||||
self.db_session.rollback()
|
||||
|
||||
def get(
|
||||
self, app_id, session_id: str = "default", num_rounds=10, fetch_all: bool = False, display_format=False
|
||||
@@ -85,50 +70,31 @@ class ChatHistory:
|
||||
param: fetch_all (optional) - Whether to fetch all chat history or not. Defaults to False
|
||||
param: display_format (optional) - Whether to return the chat history in display format. Defaults to False
|
||||
"""
|
||||
|
||||
base_query = """
|
||||
SELECT * FROM ec_chat_history
|
||||
WHERE app_id=?
|
||||
"""
|
||||
|
||||
if fetch_all:
|
||||
additional_query = "ORDER BY created_at ASC"
|
||||
params = (app_id,)
|
||||
else:
|
||||
additional_query = """
|
||||
AND session_id=?
|
||||
ORDER BY created_at ASC
|
||||
LIMIT ?
|
||||
"""
|
||||
params = (app_id, session_id, num_rounds)
|
||||
|
||||
QUERY = base_query + additional_query
|
||||
|
||||
self.cursor.execute(
|
||||
QUERY,
|
||||
params,
|
||||
params = {"app_id": app_id}
|
||||
if not fetch_all:
|
||||
params["session_id"] = session_id
|
||||
results = (
|
||||
self.db_session.query(ChatHistoryModel).filter_by(**params).order_by(ChatHistoryModel.created_at.asc())
|
||||
)
|
||||
|
||||
results = self.cursor.fetchall()
|
||||
results = results.limit(num_rounds) if not fetch_all else results
|
||||
history = []
|
||||
for result in results:
|
||||
app_id, _, session_id, question, answer, metadata, timestamp = result
|
||||
metadata = self._deserialize_json(metadata=metadata)
|
||||
metadata = self._deserialize_json(metadata=result.meta_data or "{}")
|
||||
# Return list of dict if display_format is True
|
||||
if display_format:
|
||||
history.append(
|
||||
{
|
||||
"session_id": session_id,
|
||||
"human": question,
|
||||
"ai": answer,
|
||||
"metadata": metadata,
|
||||
"timestamp": timestamp,
|
||||
"session_id": result.session_id,
|
||||
"human": result.question,
|
||||
"ai": result.answer,
|
||||
"metadata": result.meta_data,
|
||||
"timestamp": result.created_at,
|
||||
}
|
||||
)
|
||||
else:
|
||||
memory = ChatMessage()
|
||||
memory.add_user_message(question, metadata=metadata)
|
||||
memory.add_ai_message(answer, metadata=metadata)
|
||||
memory.add_user_message(result.question, metadata=metadata)
|
||||
memory.add_ai_message(result.answer, metadata=metadata)
|
||||
history.append(memory)
|
||||
return history
|
||||
|
||||
@@ -141,16 +107,11 @@ class ChatHistory:
|
||||
|
||||
:return: The number of chat messages for a given app_id and session_id
|
||||
"""
|
||||
# Rewrite the logic below with sqlalchemy
|
||||
params = {"app_id": app_id}
|
||||
if session_id:
|
||||
QUERY = "SELECT COUNT(*) FROM ec_chat_history WHERE app_id=? AND session_id=?"
|
||||
params = (app_id, session_id)
|
||||
else:
|
||||
QUERY = "SELECT COUNT(*) FROM ec_chat_history WHERE app_id=?"
|
||||
params = (app_id,)
|
||||
|
||||
self.cursor.execute(QUERY, params)
|
||||
count = self.cursor.fetchone()[0]
|
||||
return count
|
||||
params["session_id"] = session_id
|
||||
return self.db_session.query(ChatHistoryModel).filter_by(**params).count()
|
||||
|
||||
@staticmethod
|
||||
def _serialize_json(metadata: dict[str, Any]):
|
||||
|
||||
Reference in New Issue
Block a user