From 17129e2eaa7c4876c5e4029c86e132435c0ec8a8 Mon Sep 17 00:00:00 2001 From: Deshraj Yadav Date: Thu, 9 Nov 2023 15:17:51 -0800 Subject: [PATCH] [Improvement] Add support for reloading history for an existing app (#930) --- embedchain/embedchain.py | 14 ++++++++++++-- embedchain/memory/base.py | 16 ++++++++++------ embedchain/memory/message.py | 2 +- pyproject.toml | 4 +--- tests/memory/test_memory_messages.py | 2 +- 5 files changed, 25 insertions(+), 13 deletions(-) diff --git a/embedchain/embedchain.py b/embedchain/embedchain.py index 373bf635..84ac5ff7 100644 --- a/embedchain/embedchain.py +++ b/embedchain/embedchain.py @@ -74,6 +74,9 @@ class EmbedChain(JSONSerializable): if system_prompt: self.llm.config.system_prompt = system_prompt + # Fetch the history from the database if exists + self.llm.update_history(app_id=self.config.id) + # Attributes that aren't subclass related. self.user_asks = [] @@ -641,9 +644,16 @@ class EmbedChain(JSONSerializable): self.db.reset() self.cursor.execute("DELETE FROM data_sources WHERE pipeline_id = ?", (self.config.id,)) self.connection.commit() - self.clear_history() + self.delete_history() # Send anonymous telemetry self.telemetry.capture(event_name="reset", properties=self._telemetry_props) - def clear_history(self): + def get_history(self, num_rounds: int = 10, display_format: bool = True): + return self.llm.memory.get_recent_memories( + app_id=self.config.id, + num_rounds=num_rounds, + display_format=display_format, + ) + + def delete_history(self): self.llm.memory.delete_chat_history(app_id=self.config.id) diff --git a/embedchain/memory/base.py b/embedchain/memory/base.py index 83bb146b..d3386664 100644 --- a/embedchain/memory/base.py +++ b/embedchain/memory/base.py @@ -62,7 +62,7 @@ class ECChatMemory: ) self.connection.commit() - def get_recent_memories(self, app_id, num_rounds=10) -> List[ChatMessage]: + def get_recent_memories(self, app_id, num_rounds=10, display_format=False) -> List[ChatMessage]: """ Get the most recent num_rounds rounds of conversations between human and AI, for a given app_id. @@ -82,12 +82,16 @@ class ECChatMemory: results = self.cursor.fetchall() history = [] for result in results: - app_id, id, question, answer, metadata, timestamp = result + app_id, _, question, answer, metadata, timestamp = result metadata = self._deserialize_json(metadata=metadata) - memory = ChatMessage() - memory.add_user_message(question, metadata=metadata) - memory.add_ai_message(answer, metadata=metadata) - history.append(memory) + # Return list of dict if display_format is True + if display_format: + history.append({"human": question, "ai": answer, "metadata": metadata, "timestamp": timestamp}) + else: + memory = ChatMessage() + memory.add_user_message(question, metadata=metadata) + memory.add_ai_message(answer, metadata=metadata) + history.append(memory) return history def _serialize_json(self, metadata: Dict[str, Any]): diff --git a/embedchain/memory/message.py b/embedchain/memory/message.py index 3fe74af2..383b7c0f 100644 --- a/embedchain/memory/message.py +++ b/embedchain/memory/message.py @@ -69,4 +69,4 @@ class ChatMessage(JSONSerializable): self.ai_message = BaseMessage(content=message, creator="ai", metadata=metadata) def __str__(self) -> str: - return f"{self.human_message} | {self.ai_message}" + return f"{self.human_message}\n{self.ai_message}" diff --git a/pyproject.toml b/pyproject.toml index 8eb4cf73..67fb9648 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "embedchain" -version = "0.1.3" +version = "0.1.4" description = "Data platform for LLMs - Load, index, retrieve and sync any unstructured data" authors = [ "Taranjeet Singh ", @@ -191,6 +191,4 @@ postgres = ["psycopg", "psycopg-binary", "psycopg-pool"] [tool.poetry.group.docs.dependencies] - - [tool.poetry.scripts] diff --git a/tests/memory/test_memory_messages.py b/tests/memory/test_memory_messages.py index 81adc79c..75d915a3 100644 --- a/tests/memory/test_memory_messages.py +++ b/tests/memory/test_memory_messages.py @@ -34,4 +34,4 @@ def test_ec_base_chat_message(): assert chat_message.ai_message.creator == "ai" assert chat_message.ai_message.metadata == ai_metadata - assert str(chat_message) == f"human: {human_message_content} | ai: {ai_message_content}" + assert str(chat_message) == f"human: {human_message_content}\nai: {ai_message_content}"