[Improvement] Use SQLite for chat memory (#910)

Co-authored-by: Deven Patel <deven298@yahoo.com>
This commit is contained in:
Deven Patel
2023-11-09 13:56:28 -08:00
committed by GitHub
parent 9d3568ef75
commit 654fd8d74c
15 changed files with 396 additions and 48 deletions

View File

112
embedchain/memory/base.py Normal file
View File

@@ -0,0 +1,112 @@
import json
import logging
import sqlite3
import uuid
from typing import Any, Dict, List, Optional
from embedchain.constants import SQLITE_PATH
from embedchain.memory.message import ChatMessage
from embedchain.memory.utils import merge_metadata_dict
CHAT_MESSAGE_CREATE_TABLE_QUERY = """
CREATE TABLE IF NOT EXISTS chat_history (
app_id TEXT,
id TEXT,
question TEXT,
answer TEXT,
metadata TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY (id, app_id)
)
"""
class ECChatMemory:
def __init__(self) -> None:
with sqlite3.connect(SQLITE_PATH) as self.connection:
self.cursor = self.connection.cursor()
self.cursor.execute(CHAT_MESSAGE_CREATE_TABLE_QUERY)
self.connection.commit()
def add(self, app_id, chat_message: ChatMessage) -> Optional[str]:
memory_id = str(uuid.uuid4())
metadata_dict = merge_metadata_dict(chat_message.human_message.metadata, chat_message.ai_message.metadata)
if metadata_dict:
metadata = self._serialize_json(metadata_dict)
ADD_CHAT_MESSAGE_QUERY = """
INSERT INTO chat_history (app_id, id, question, answer, metadata)
VALUES (?, ?, ?, ?, ?)
"""
self.cursor.execute(
ADD_CHAT_MESSAGE_QUERY,
(
app_id,
memory_id,
chat_message.human_message.content,
chat_message.ai_message.content,
metadata if metadata_dict else "{}",
),
)
self.connection.commit()
logging.info(f"Added chat memory to db with id: {memory_id}")
return memory_id
def delete_chat_history(self, app_id: str):
DELETE_CHAT_HISTORY_QUERY = """
DELETE FROM chat_history WHERE app_id=?
"""
self.cursor.execute(
DELETE_CHAT_HISTORY_QUERY,
(app_id,),
)
self.connection.commit()
def get_recent_memories(self, app_id, num_rounds=10) -> List[ChatMessage]:
"""
Get the most recent num_rounds rounds of conversations
between human and AI, for a given app_id.
"""
QUERY = """
SELECT * FROM chat_history
WHERE app_id=?
ORDER BY created_at DESC
LIMIT ?
"""
self.cursor.execute(
QUERY,
(app_id, num_rounds),
)
results = self.cursor.fetchall()
history = []
for result in results:
app_id, id, question, answer, metadata, timestamp = result
metadata = self._deserialize_json(metadata=metadata)
memory = ChatMessage()
memory.add_user_message(question, metadata=metadata)
memory.add_ai_message(answer, metadata=metadata)
history.append(memory)
return history
def _serialize_json(self, metadata: Dict[str, Any]):
return json.dumps(metadata)
def _deserialize_json(self, metadata: str):
return json.loads(metadata)
def close_connection(self):
self.connection.close()
def count_history_messages(self, app_id: str):
QUERY = """
SELECT COUNT(*) FROM chat_history
WHERE app_id=?
"""
self.cursor.execute(
QUERY,
(app_id,),
)
count = self.cursor.fetchone()[0]
return count

View File

@@ -0,0 +1,72 @@
import logging
from typing import Any, Dict, Optional
from embedchain.helper.json_serializable import JSONSerializable
class BaseMessage(JSONSerializable):
"""
The base abstract message class.
Messages are the inputs and outputs of Models.
"""
# The string content of the message.
content: str
# The creator of the message. AI, Human, Bot etc.
by: str
# Any additional info.
metadata: Dict[str, Any]
def __init__(self, content: str, creator: str, metadata: Optional[Dict[str, Any]] = None) -> None:
super().__init__()
self.content = content
self.creator = creator
self.metadata = metadata
@property
def type(self) -> str:
"""Type of the Message, used for serialization."""
@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether this class is serializable."""
return True
def __str__(self) -> str:
return f"{self.creator}: {self.content}"
class ChatMessage(JSONSerializable):
"""
The base abstract chat message class.
Chat messages are the pair of (question, answer) conversation
between human and model.
"""
human_message: Optional[BaseMessage] = None
ai_message: Optional[BaseMessage] = None
def add_user_message(self, message: str, metadata: Optional[dict] = None):
if self.human_message:
logging.info(
"Human message already exists in the chat message,\
overwritting it with new message."
)
self.human_message = BaseMessage(content=message, creator="human", metadata=metadata)
def add_ai_message(self, message: str, metadata: Optional[dict] = None):
if self.ai_message:
logging.info(
"AI message already exists in the chat message,\
overwritting it with new message."
)
self.ai_message = BaseMessage(content=message, creator="ai", metadata=metadata)
def __str__(self) -> str:
return f"{self.human_message} | {self.ai_message}"

View File

@@ -0,0 +1,35 @@
from typing import Any, Dict, Optional
def merge_metadata_dict(left: Optional[Dict[str, Any]], right: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
"""
Merge the metadatas of two BaseMessage types.
Args:
left (Dict[str, Any]): metadata of human message
right (Dict[str, Any]): metadata of ai message
Returns:
Dict[str, Any]: combined metadata dict with dedup
to be saved in db.
"""
if not left and not right:
return None
elif not left:
return right
elif not right:
return left
merged = left.copy()
for k, v in right.items():
if k not in merged:
merged[k] = v
elif type(merged[k]) != type(v):
raise ValueError(f'additional_kwargs["{k}"] already exists in this message,' " but with a different type.")
elif isinstance(merged[k], str):
merged[k] += v
elif isinstance(merged[k], dict):
merged[k] = merge_metadata_dict(merged[k], v)
else:
raise ValueError(f"Additional kwargs key {k} already exists in this message.")
return merged