[Feature] Add support to use any sql database as the metadata storage for embedchain apps (#1273)

This commit is contained in:
Deshraj Yadav
2024-02-19 13:04:18 -08:00
committed by GitHub
parent 6c12bc9044
commit 5e2e7fb639
20 changed files with 601 additions and 202 deletions

View File

@@ -1,55 +1,40 @@
import json
import logging
import sqlite3
import uuid
from typing import Any, Optional
from embedchain.constants import SQLITE_PATH
from embedchain.core.db.database import get_session
from embedchain.core.db.models import ChatHistory as ChatHistoryModel
from embedchain.memory.message import ChatMessage
from embedchain.memory.utils import merge_metadata_dict
CHAT_MESSAGE_CREATE_TABLE_QUERY = """
CREATE TABLE IF NOT EXISTS ec_chat_history (
app_id TEXT,
id TEXT,
session_id TEXT,
question TEXT,
answer TEXT,
metadata TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY (id, app_id, session_id)
)
"""
class ChatHistory:
def __init__(self) -> None:
with sqlite3.connect(SQLITE_PATH, check_same_thread=False) as self.connection:
self.cursor = self.connection.cursor()
self.cursor.execute(CHAT_MESSAGE_CREATE_TABLE_QUERY)
self.connection.commit()
self.db_session = get_session()
def add(self, app_id, session_id, chat_message: ChatMessage) -> Optional[str]:
memory_id = str(uuid.uuid4())
metadata_dict = merge_metadata_dict(chat_message.human_message.metadata, chat_message.ai_message.metadata)
if metadata_dict:
metadata = self._serialize_json(metadata_dict)
ADD_CHAT_MESSAGE_QUERY = """
INSERT INTO ec_chat_history (app_id, id, session_id, question, answer, metadata)
VALUES (?, ?, ?, ?, ?, ?)
"""
self.cursor.execute(
ADD_CHAT_MESSAGE_QUERY,
(
app_id,
memory_id,
session_id,
chat_message.human_message.content,
chat_message.ai_message.content,
metadata if metadata_dict else "{}",
),
self.db_session.add(
ChatHistoryModel(
app_id=app_id,
id=memory_id,
session_id=session_id,
question=chat_message.human_message.content,
answer=chat_message.ai_message.content,
metadata=metadata if metadata_dict else "{}",
)
)
self.connection.commit()
try:
self.db_session.commit()
except Exception as e:
logging.error(f"Error adding chat memory to db: {e}")
self.db_session.rollback()
return None
logging.info(f"Added chat memory to db with id: {memory_id}")
return memory_id
@@ -63,15 +48,15 @@ class ChatHistory:
:return: None
"""
params = {"app_id": app_id}
if session_id:
DELETE_CHAT_HISTORY_QUERY = "DELETE FROM ec_chat_history WHERE app_id=? AND session_id=?"
params = (app_id, session_id)
else:
DELETE_CHAT_HISTORY_QUERY = "DELETE FROM ec_chat_history WHERE app_id=?"
params = (app_id,)
self.cursor.execute(DELETE_CHAT_HISTORY_QUERY, params)
self.connection.commit()
params["session_id"] = session_id
self.db_session.query(ChatHistoryModel).filter_by(**params).delete()
try:
self.db_session.commit()
except Exception as e:
logging.error(f"Error deleting chat history: {e}")
self.db_session.rollback()
def get(
self, app_id, session_id: str = "default", num_rounds=10, fetch_all: bool = False, display_format=False
@@ -85,50 +70,31 @@ class ChatHistory:
param: fetch_all (optional) - Whether to fetch all chat history or not. Defaults to False
param: display_format (optional) - Whether to return the chat history in display format. Defaults to False
"""
base_query = """
SELECT * FROM ec_chat_history
WHERE app_id=?
"""
if fetch_all:
additional_query = "ORDER BY created_at ASC"
params = (app_id,)
else:
additional_query = """
AND session_id=?
ORDER BY created_at ASC
LIMIT ?
"""
params = (app_id, session_id, num_rounds)
QUERY = base_query + additional_query
self.cursor.execute(
QUERY,
params,
params = {"app_id": app_id}
if not fetch_all:
params["session_id"] = session_id
results = (
self.db_session.query(ChatHistoryModel).filter_by(**params).order_by(ChatHistoryModel.created_at.asc())
)
results = self.cursor.fetchall()
results = results.limit(num_rounds) if not fetch_all else results
history = []
for result in results:
app_id, _, session_id, question, answer, metadata, timestamp = result
metadata = self._deserialize_json(metadata=metadata)
metadata = self._deserialize_json(metadata=result.meta_data or "{}")
# Return list of dict if display_format is True
if display_format:
history.append(
{
"session_id": session_id,
"human": question,
"ai": answer,
"metadata": metadata,
"timestamp": timestamp,
"session_id": result.session_id,
"human": result.question,
"ai": result.answer,
"metadata": result.meta_data,
"timestamp": result.created_at,
}
)
else:
memory = ChatMessage()
memory.add_user_message(question, metadata=metadata)
memory.add_ai_message(answer, metadata=metadata)
memory.add_user_message(result.question, metadata=metadata)
memory.add_ai_message(result.answer, metadata=metadata)
history.append(memory)
return history
@@ -141,16 +107,11 @@ class ChatHistory:
:return: The number of chat messages for a given app_id and session_id
"""
# Rewrite the logic below with sqlalchemy
params = {"app_id": app_id}
if session_id:
QUERY = "SELECT COUNT(*) FROM ec_chat_history WHERE app_id=? AND session_id=?"
params = (app_id, session_id)
else:
QUERY = "SELECT COUNT(*) FROM ec_chat_history WHERE app_id=?"
params = (app_id,)
self.cursor.execute(QUERY, params)
count = self.cursor.fetchone()[0]
return count
params["session_id"] = session_id
return self.db_session.query(ChatHistoryModel).filter_by(**params).count()
@staticmethod
def _serialize_json(metadata: dict[str, Any]):