[Improvement] Use SQLite for chat memory (#910)

Co-authored-by: Deven Patel <deven298@yahoo.com>
This commit is contained in:
Deven Patel
2023-11-09 13:56:28 -08:00
committed by GitHub
parent 9d3568ef75
commit 654fd8d74c
15 changed files with 396 additions and 48 deletions

View File

@@ -1,14 +1,15 @@
import logging
from typing import Any, Dict, Generator, List, Optional
from langchain.memory import ConversationBufferMemory
from langchain.schema import BaseMessage
from langchain.schema import BaseMessage as LCBaseMessage
from embedchain.config import BaseLlmConfig
from embedchain.config.llm.base import (DEFAULT_PROMPT,
DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE,
DOCS_SITE_PROMPT_TEMPLATE)
from embedchain.helper.json_serializable import JSONSerializable
from embedchain.memory.base import ECChatMemory
from embedchain.memory.message import ChatMessage
class BaseLlm(JSONSerializable):
@@ -23,7 +24,7 @@ class BaseLlm(JSONSerializable):
else:
self.config = config
self.memory = ConversationBufferMemory()
self.memory = ECChatMemory()
self.is_docs_site_instance = False
self.online = False
self.history: Any = None
@@ -44,11 +45,18 @@ class BaseLlm(JSONSerializable):
"""
self.history = history
def update_history(self):
def update_history(self, app_id: str):
"""Update class history attribute with history in memory (for chat method)"""
chat_history = self.memory.load_memory_variables({})["history"]
chat_history = self.memory.get_recent_memories(app_id=app_id, num_rounds=10)
if chat_history:
self.set_history(chat_history)
self.set_history([str(history) for history in chat_history])
def add_history(self, app_id: str, question: str, answer: str, metadata: Optional[Dict[str, Any]] = None):
chat_message = ChatMessage()
chat_message.add_user_message(question, metadata=metadata)
chat_message.add_ai_message(answer, metadata=metadata)
self.memory.add(app_id=app_id, chat_message=chat_message)
self.update_history(app_id=app_id)
def generate_prompt(self, input_query: str, contexts: List[str], **kwargs: Dict[str, Any]) -> str:
"""
@@ -165,7 +173,6 @@ class BaseLlm(JSONSerializable):
for chunk in answer:
streamed_answer = streamed_answer + chunk
yield chunk
self.memory.chat_memory.add_ai_message(streamed_answer)
logging.info(f"Answer: {streamed_answer}")
def query(self, input_query: str, contexts: List[str], config: BaseLlmConfig = None, dry_run=False):
@@ -257,8 +264,6 @@ class BaseLlm(JSONSerializable):
if self.online:
k["web_search_result"] = self.access_search_and_get_results(input_query)
self.update_history()
prompt = self.generate_prompt(input_query, contexts, **k)
logging.info(f"Prompt: {prompt}")
@@ -267,16 +272,9 @@ class BaseLlm(JSONSerializable):
answer = self.get_answer_from_llm(prompt)
self.memory.chat_memory.add_user_message(input_query)
if isinstance(answer, str):
self.memory.chat_memory.add_ai_message(answer)
logging.info(f"Answer: {answer}")
# NOTE: Adding to history before and after. This could be seen as redundant.
# If we change it, we have to change the tests (no big deal).
self.update_history()
return answer
else:
# this is a streamed response and needs to be handled differently.
@@ -287,7 +285,7 @@ class BaseLlm(JSONSerializable):
self.config: BaseLlmConfig = BaseLlmConfig.deserialize(prev_config)
@staticmethod
def _get_messages(prompt: str, system_prompt: Optional[str] = None) -> List[BaseMessage]:
def _get_messages(prompt: str, system_prompt: Optional[str] = None) -> List[LCBaseMessage]:
"""
Construct a list of langchain messages