[Improvement] Use SQLite for chat memory (#910)

Co-authored-by: Deven Patel <deven298@yahoo.com>
This commit is contained in:
Deven Patel
2023-11-09 13:56:28 -08:00
committed by GitHub
parent 9d3568ef75
commit 654fd8d74c
15 changed files with 396 additions and 48 deletions

View File

@@ -5,7 +5,7 @@ import uuid
import requests
from embedchain.embedchain import CONFIG_DIR, CONFIG_FILE
from embedchain.constants import CONFIG_DIR, CONFIG_FILE
class Client:

8
embedchain/constants.py Normal file
View File

@@ -0,0 +1,8 @@
"""Filesystem locations shared across embedchain (config dir, config file, SQLite db)."""
import os
from pathlib import Path
# Process working directory captured at import time.
ABS_PATH = os.getcwd()
# User's home directory, kept as a plain str for the os.path.join calls below.
HOME_DIR = str(Path.home())
# Per-user embedchain config directory, e.g. ~/.embedchain
CONFIG_DIR = os.path.join(HOME_DIR, ".embedchain")
# JSON config file inside CONFIG_DIR.
CONFIG_FILE = os.path.join(CONFIG_DIR, "config.json")
# SQLite database file used for chat history storage.
SQLITE_PATH = os.path.join(CONFIG_DIR, "embedchain.db")

View File

@@ -1,9 +1,7 @@
import hashlib
import json
import logging
import os
import sqlite3
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
from dotenv import load_dotenv
@@ -12,6 +10,7 @@ from langchain.docstore.document import Document
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config import AddConfig, BaseLlmConfig, ChunkerConfig
from embedchain.config.apps.base_app_config import BaseAppConfig
from embedchain.constants import SQLITE_PATH
from embedchain.data_formatter import DataFormatter
from embedchain.embedder.base import BaseEmbedder
from embedchain.helper.json_serializable import JSONSerializable
@@ -25,12 +24,6 @@ from embedchain.vectordb.base import BaseVectorDB
load_dotenv()
ABS_PATH = os.getcwd()
HOME_DIR = str(Path.home())
CONFIG_DIR = os.path.join(HOME_DIR, ".embedchain")
CONFIG_FILE = os.path.join(CONFIG_DIR, "config.json")
SQLITE_PATH = os.path.join(CONFIG_DIR, "embedchain.db")
class EmbedChain(JSONSerializable):
def __init__(
@@ -602,6 +595,9 @@ class EmbedChain(JSONSerializable):
input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run
)
# add conversation in memory
self.llm.add_history(self.config.id, input_query, answer)
# Send anonymous telemetry
self.telemetry.capture(event_name="chat", properties=self._telemetry_props)
@@ -645,5 +641,9 @@ class EmbedChain(JSONSerializable):
self.db.reset()
self.cursor.execute("DELETE FROM data_sources WHERE pipeline_id = ?", (self.config.id,))
self.connection.commit()
self.clear_history()
# Send anonymous telemetry
self.telemetry.capture(event_name="reset", properties=self._telemetry_props)
def clear_history(self):
self.llm.memory.delete_chat_history(app_id=self.config.id)

View File

@@ -1,14 +1,15 @@
import logging
from typing import Any, Dict, Generator, List, Optional
from langchain.memory import ConversationBufferMemory
from langchain.schema import BaseMessage
from langchain.schema import BaseMessage as LCBaseMessage
from embedchain.config import BaseLlmConfig
from embedchain.config.llm.base import (DEFAULT_PROMPT,
DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE,
DOCS_SITE_PROMPT_TEMPLATE)
from embedchain.helper.json_serializable import JSONSerializable
from embedchain.memory.base import ECChatMemory
from embedchain.memory.message import ChatMessage
class BaseLlm(JSONSerializable):
@@ -23,7 +24,7 @@ class BaseLlm(JSONSerializable):
else:
self.config = config
self.memory = ConversationBufferMemory()
self.memory = ECChatMemory()
self.is_docs_site_instance = False
self.online = False
self.history: Any = None
@@ -44,11 +45,18 @@ class BaseLlm(JSONSerializable):
"""
self.history = history
def update_history(self):
def update_history(self, app_id: str):
"""Update class history attribute with history in memory (for chat method)"""
chat_history = self.memory.load_memory_variables({})["history"]
chat_history = self.memory.get_recent_memories(app_id=app_id, num_rounds=10)
if chat_history:
self.set_history(chat_history)
self.set_history([str(history) for history in chat_history])
def add_history(self, app_id: str, question: str, answer: str, metadata: Optional[Dict[str, Any]] = None):
chat_message = ChatMessage()
chat_message.add_user_message(question, metadata=metadata)
chat_message.add_ai_message(answer, metadata=metadata)
self.memory.add(app_id=app_id, chat_message=chat_message)
self.update_history(app_id=app_id)
def generate_prompt(self, input_query: str, contexts: List[str], **kwargs: Dict[str, Any]) -> str:
"""
@@ -165,7 +173,6 @@ class BaseLlm(JSONSerializable):
for chunk in answer:
streamed_answer = streamed_answer + chunk
yield chunk
self.memory.chat_memory.add_ai_message(streamed_answer)
logging.info(f"Answer: {streamed_answer}")
def query(self, input_query: str, contexts: List[str], config: BaseLlmConfig = None, dry_run=False):
@@ -257,8 +264,6 @@ class BaseLlm(JSONSerializable):
if self.online:
k["web_search_result"] = self.access_search_and_get_results(input_query)
self.update_history()
prompt = self.generate_prompt(input_query, contexts, **k)
logging.info(f"Prompt: {prompt}")
@@ -267,16 +272,9 @@ class BaseLlm(JSONSerializable):
answer = self.get_answer_from_llm(prompt)
self.memory.chat_memory.add_user_message(input_query)
if isinstance(answer, str):
self.memory.chat_memory.add_ai_message(answer)
logging.info(f"Answer: {answer}")
# NOTE: Adding to history before and after. This could be seen as redundant.
# If we change it, we have to change the tests (no big deal).
self.update_history()
return answer
else:
# this is a streamed response and needs to be handled differently.
@@ -287,7 +285,7 @@ class BaseLlm(JSONSerializable):
self.config: BaseLlmConfig = BaseLlmConfig.deserialize(prev_config)
@staticmethod
def _get_messages(prompt: str, system_prompt: Optional[str] = None) -> List[BaseMessage]:
def _get_messages(prompt: str, system_prompt: Optional[str] = None) -> List[LCBaseMessage]:
"""
Construct a list of langchain messages

View File

112
embedchain/memory/base.py Normal file
View File

@@ -0,0 +1,112 @@
import json
import logging
import sqlite3
import uuid
from typing import Any, Dict, List, Optional
from embedchain.constants import SQLITE_PATH
from embedchain.memory.message import ChatMessage
from embedchain.memory.utils import merge_metadata_dict
CHAT_MESSAGE_CREATE_TABLE_QUERY = """
CREATE TABLE IF NOT EXISTS chat_history (
app_id TEXT,
id TEXT,
question TEXT,
answer TEXT,
metadata TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY (id, app_id)
)
"""
class ECChatMemory:
def __init__(self) -> None:
with sqlite3.connect(SQLITE_PATH) as self.connection:
self.cursor = self.connection.cursor()
self.cursor.execute(CHAT_MESSAGE_CREATE_TABLE_QUERY)
self.connection.commit()
def add(self, app_id, chat_message: ChatMessage) -> Optional[str]:
memory_id = str(uuid.uuid4())
metadata_dict = merge_metadata_dict(chat_message.human_message.metadata, chat_message.ai_message.metadata)
if metadata_dict:
metadata = self._serialize_json(metadata_dict)
ADD_CHAT_MESSAGE_QUERY = """
INSERT INTO chat_history (app_id, id, question, answer, metadata)
VALUES (?, ?, ?, ?, ?)
"""
self.cursor.execute(
ADD_CHAT_MESSAGE_QUERY,
(
app_id,
memory_id,
chat_message.human_message.content,
chat_message.ai_message.content,
metadata if metadata_dict else "{}",
),
)
self.connection.commit()
logging.info(f"Added chat memory to db with id: {memory_id}")
return memory_id
def delete_chat_history(self, app_id: str):
DELETE_CHAT_HISTORY_QUERY = """
DELETE FROM chat_history WHERE app_id=?
"""
self.cursor.execute(
DELETE_CHAT_HISTORY_QUERY,
(app_id,),
)
self.connection.commit()
def get_recent_memories(self, app_id, num_rounds=10) -> List[ChatMessage]:
"""
Get the most recent num_rounds rounds of conversations
between human and AI, for a given app_id.
"""
QUERY = """
SELECT * FROM chat_history
WHERE app_id=?
ORDER BY created_at DESC
LIMIT ?
"""
self.cursor.execute(
QUERY,
(app_id, num_rounds),
)
results = self.cursor.fetchall()
history = []
for result in results:
app_id, id, question, answer, metadata, timestamp = result
metadata = self._deserialize_json(metadata=metadata)
memory = ChatMessage()
memory.add_user_message(question, metadata=metadata)
memory.add_ai_message(answer, metadata=metadata)
history.append(memory)
return history
def _serialize_json(self, metadata: Dict[str, Any]):
return json.dumps(metadata)
def _deserialize_json(self, metadata: str):
return json.loads(metadata)
def close_connection(self):
self.connection.close()
def count_history_messages(self, app_id: str):
QUERY = """
SELECT COUNT(*) FROM chat_history
WHERE app_id=?
"""
self.cursor.execute(
QUERY,
(app_id,),
)
count = self.cursor.fetchone()[0]
return count

View File

@@ -0,0 +1,72 @@
import logging
from typing import Any, Dict, Optional
from embedchain.helper.json_serializable import JSONSerializable
class BaseMessage(JSONSerializable):
    """
    The base message class.

    Messages are the inputs and outputs of Models: a piece of text content,
    the creator that produced it, and optional metadata.
    """

    # The string content of the message.
    content: str
    # The creator of the message. AI, Human, Bot etc.
    # (fixed: was annotated as `by` while __init__ assigns `creator`)
    creator: str
    # Any additional info supplied by the caller; may be None.
    metadata: Optional[Dict[str, Any]]

    def __init__(self, content: str, creator: str, metadata: Optional[Dict[str, Any]] = None) -> None:
        super().__init__()
        self.content = content
        self.creator = creator
        self.metadata = metadata

    @property
    def type(self) -> str:
        """Type of the Message, used for serialization."""

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Return whether this class is serializable."""
        return True

    def __str__(self) -> str:
        return f"{self.creator}: {self.content}"
class ChatMessage(JSONSerializable):
    """
    One round of conversation: the (question, answer) pair exchanged
    between human and model, stored as two BaseMessage objects.
    """

    # Human side of the round; None until add_user_message is called.
    human_message: Optional[BaseMessage] = None
    # AI side of the round; None until add_ai_message is called.
    ai_message: Optional[BaseMessage] = None

    def add_user_message(self, message: str, metadata: Optional[dict] = None):
        """Set the human message of this round, replacing any existing one."""
        if self.human_message:
            # Fixed typo ("overwritting") and removed the backslash line
            # continuation that embedded a run of spaces in the logged text.
            logging.info("Human message already exists in the chat message, overwriting it with new message.")
        self.human_message = BaseMessage(content=message, creator="human", metadata=metadata)

    def add_ai_message(self, message: str, metadata: Optional[dict] = None):
        """Set the AI message of this round, replacing any existing one."""
        if self.ai_message:
            logging.info("AI message already exists in the chat message, overwriting it with new message.")
        self.ai_message = BaseMessage(content=message, creator="ai", metadata=metadata)

    def __str__(self) -> str:
        return f"{self.human_message} | {self.ai_message}"

View File

@@ -0,0 +1,35 @@
from typing import Any, Dict, Optional
def merge_metadata_dict(left: Optional[Dict[str, Any]], right: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
    """
    Merge the metadatas of two BaseMessage types.

    Keys present in only one dict are kept as-is. When a key appears in both,
    the values are combined by exact type:
      - str: concatenated (left + right)
      - dict: merged recursively with the same rules
      - any other matching type: ValueError (no safe way to merge)
    A type mismatch between the two values also raises ValueError.

    Args:
        left (Dict[str, Any]): metadata of human message
        right (Dict[str, Any]): metadata of ai message

    Returns:
        Dict[str, Any]: combined metadata dict with dedup
        to be saved in db, or None when both inputs are empty/None.

    Raises:
        ValueError: when a shared key holds values of different types, or of
            a type that cannot be merged.
    """
    if not left and not right:
        return None
    elif not left:
        return right
    elif not right:
        return left

    merged = left.copy()
    for k, v in right.items():
        if k not in merged:
            merged[k] = v
        # Exact-type identity check (`is not`), matching the original strict
        # type equality without the E721 `type(a) != type(b)` comparison.
        # isinstance() would be wrong here: it tolerates subclasses.
        elif type(merged[k]) is not type(v):
            raise ValueError(f'metadata["{k}"] already exists in this message, but with a different type.')
        elif isinstance(merged[k], str):
            merged[k] += v
        elif isinstance(merged[k], dict):
            merged[k] = merge_metadata_dict(merged[k], v)
        else:
            raise ValueError(f"Metadata key {k} already exists in this message and cannot be merged.")
    return merged

View File

@@ -10,7 +10,8 @@ import yaml
from embedchain import Client
from embedchain.config import ChunkerConfig, PipelineConfig
from embedchain.embedchain import CONFIG_DIR, EmbedChain
from embedchain.constants import SQLITE_PATH
from embedchain.embedchain import EmbedChain
from embedchain.embedder.base import BaseEmbedder
from embedchain.embedder.openai import OpenAIEmbedder
from embedchain.factory import EmbedderFactory, LlmFactory, VectorDBFactory
@@ -22,8 +23,6 @@ from embedchain.utils import validate_yaml_config
from embedchain.vectordb.base import BaseVectorDB
from embedchain.vectordb.chroma import ChromaDB
SQLITE_PATH = os.path.join(CONFIG_DIR, "embedchain.db")
@register_deserializable
class Pipeline(EmbedChain):

View File

@@ -138,7 +138,8 @@ def detect_datatype(source: Any) -> DataType:
formatted_source = format_source(str(source), 30)
if url:
from langchain.document_loaders.youtube import ALLOWED_NETLOCK as YOUTUBE_ALLOWED_NETLOCS
from langchain.document_loaders.youtube import \
ALLOWED_NETLOCK as YOUTUBE_ALLOWED_NETLOCS
if url.netloc in YOUTUBE_ALLOWED_NETLOCS:
logging.debug(f"Source of `{formatted_source}` detected as `youtube_video`.")