[Improvement] Use SQLite for chat memory (#910)
Co-authored-by: Deven Patel <deven298@yahoo.com>
This commit is contained in:
@@ -1,14 +1,15 @@
|
||||
import logging
|
||||
from typing import Any, Dict, Generator, List, Optional
|
||||
|
||||
from langchain.memory import ConversationBufferMemory
|
||||
from langchain.schema import BaseMessage
|
||||
from langchain.schema import BaseMessage as LCBaseMessage
|
||||
|
||||
from embedchain.config import BaseLlmConfig
|
||||
from embedchain.config.llm.base import (DEFAULT_PROMPT,
|
||||
DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE,
|
||||
DOCS_SITE_PROMPT_TEMPLATE)
|
||||
from embedchain.helper.json_serializable import JSONSerializable
|
||||
from embedchain.memory.base import ECChatMemory
|
||||
from embedchain.memory.message import ChatMessage
|
||||
|
||||
|
||||
class BaseLlm(JSONSerializable):
|
||||
@@ -23,7 +24,7 @@ class BaseLlm(JSONSerializable):
|
||||
else:
|
||||
self.config = config
|
||||
|
||||
self.memory = ConversationBufferMemory()
|
||||
self.memory = ECChatMemory()
|
||||
self.is_docs_site_instance = False
|
||||
self.online = False
|
||||
self.history: Any = None
|
||||
@@ -44,11 +45,18 @@ class BaseLlm(JSONSerializable):
|
||||
"""
|
||||
self.history = history
|
||||
|
||||
def update_history(self):
|
||||
def update_history(self, app_id: str):
|
||||
"""Update class history attribute with history in memory (for chat method)"""
|
||||
chat_history = self.memory.load_memory_variables({})["history"]
|
||||
chat_history = self.memory.get_recent_memories(app_id=app_id, num_rounds=10)
|
||||
if chat_history:
|
||||
self.set_history(chat_history)
|
||||
self.set_history([str(history) for history in chat_history])
|
||||
|
||||
def add_history(self, app_id: str, question: str, answer: str, metadata: Optional[Dict[str, Any]] = None):
|
||||
chat_message = ChatMessage()
|
||||
chat_message.add_user_message(question, metadata=metadata)
|
||||
chat_message.add_ai_message(answer, metadata=metadata)
|
||||
self.memory.add(app_id=app_id, chat_message=chat_message)
|
||||
self.update_history(app_id=app_id)
|
||||
|
||||
def generate_prompt(self, input_query: str, contexts: List[str], **kwargs: Dict[str, Any]) -> str:
|
||||
"""
|
||||
@@ -165,7 +173,6 @@ class BaseLlm(JSONSerializable):
|
||||
for chunk in answer:
|
||||
streamed_answer = streamed_answer + chunk
|
||||
yield chunk
|
||||
self.memory.chat_memory.add_ai_message(streamed_answer)
|
||||
logging.info(f"Answer: {streamed_answer}")
|
||||
|
||||
def query(self, input_query: str, contexts: List[str], config: BaseLlmConfig = None, dry_run=False):
|
||||
@@ -257,8 +264,6 @@ class BaseLlm(JSONSerializable):
|
||||
if self.online:
|
||||
k["web_search_result"] = self.access_search_and_get_results(input_query)
|
||||
|
||||
self.update_history()
|
||||
|
||||
prompt = self.generate_prompt(input_query, contexts, **k)
|
||||
logging.info(f"Prompt: {prompt}")
|
||||
|
||||
@@ -267,16 +272,9 @@ class BaseLlm(JSONSerializable):
|
||||
|
||||
answer = self.get_answer_from_llm(prompt)
|
||||
|
||||
self.memory.chat_memory.add_user_message(input_query)
|
||||
|
||||
if isinstance(answer, str):
|
||||
self.memory.chat_memory.add_ai_message(answer)
|
||||
logging.info(f"Answer: {answer}")
|
||||
|
||||
# NOTE: Adding to history before and after. This could be seen as redundant.
|
||||
# If we change it, we have to change the tests (no big deal).
|
||||
self.update_history()
|
||||
|
||||
return answer
|
||||
else:
|
||||
# this is a streamed response and needs to be handled differently.
|
||||
@@ -287,7 +285,7 @@ class BaseLlm(JSONSerializable):
|
||||
self.config: BaseLlmConfig = BaseLlmConfig.deserialize(prev_config)
|
||||
|
||||
@staticmethod
|
||||
def _get_messages(prompt: str, system_prompt: Optional[str] = None) -> List[BaseMessage]:
|
||||
def _get_messages(prompt: str, system_prompt: Optional[str] = None) -> List[LCBaseMessage]:
|
||||
"""
|
||||
Construct a list of langchain messages
|
||||
|
||||
|
||||
Reference in New Issue
Block a user