[Improvement] Add support for reloading history for an existing app (#930)
This commit is contained in:
@@ -74,6 +74,9 @@ class EmbedChain(JSONSerializable):
|
|||||||
if system_prompt:
|
if system_prompt:
|
||||||
self.llm.config.system_prompt = system_prompt
|
self.llm.config.system_prompt = system_prompt
|
||||||
|
|
||||||
|
# Fetch the history from the database if exists
|
||||||
|
self.llm.update_history(app_id=self.config.id)
|
||||||
|
|
||||||
# Attributes that aren't subclass related.
|
# Attributes that aren't subclass related.
|
||||||
self.user_asks = []
|
self.user_asks = []
|
||||||
|
|
||||||
@@ -641,9 +644,16 @@ class EmbedChain(JSONSerializable):
|
|||||||
self.db.reset()
|
self.db.reset()
|
||||||
self.cursor.execute("DELETE FROM data_sources WHERE pipeline_id = ?", (self.config.id,))
|
self.cursor.execute("DELETE FROM data_sources WHERE pipeline_id = ?", (self.config.id,))
|
||||||
self.connection.commit()
|
self.connection.commit()
|
||||||
self.clear_history()
|
self.delete_history()
|
||||||
# Send anonymous telemetry
|
# Send anonymous telemetry
|
||||||
self.telemetry.capture(event_name="reset", properties=self._telemetry_props)
|
self.telemetry.capture(event_name="reset", properties=self._telemetry_props)
|
||||||
|
|
||||||
def clear_history(self):
|
def get_history(self, num_rounds: int = 10, display_format: bool = True):
|
||||||
|
return self.llm.memory.get_recent_memories(
|
||||||
|
app_id=self.config.id,
|
||||||
|
num_rounds=num_rounds,
|
||||||
|
display_format=display_format,
|
||||||
|
)
|
||||||
|
|
||||||
|
def delete_history(self):
|
||||||
self.llm.memory.delete_chat_history(app_id=self.config.id)
|
self.llm.memory.delete_chat_history(app_id=self.config.id)
|
||||||
|
|||||||
@@ -62,7 +62,7 @@ class ECChatMemory:
|
|||||||
)
|
)
|
||||||
self.connection.commit()
|
self.connection.commit()
|
||||||
|
|
||||||
def get_recent_memories(self, app_id, num_rounds=10) -> List[ChatMessage]:
|
def get_recent_memories(self, app_id, num_rounds=10, display_format=False) -> List[ChatMessage]:
|
||||||
"""
|
"""
|
||||||
Get the most recent num_rounds rounds of conversations
|
Get the most recent num_rounds rounds of conversations
|
||||||
between human and AI, for a given app_id.
|
between human and AI, for a given app_id.
|
||||||
@@ -82,12 +82,16 @@ class ECChatMemory:
|
|||||||
results = self.cursor.fetchall()
|
results = self.cursor.fetchall()
|
||||||
history = []
|
history = []
|
||||||
for result in results:
|
for result in results:
|
||||||
app_id, id, question, answer, metadata, timestamp = result
|
app_id, _, question, answer, metadata, timestamp = result
|
||||||
metadata = self._deserialize_json(metadata=metadata)
|
metadata = self._deserialize_json(metadata=metadata)
|
||||||
memory = ChatMessage()
|
# Return list of dict if display_format is True
|
||||||
memory.add_user_message(question, metadata=metadata)
|
if display_format:
|
||||||
memory.add_ai_message(answer, metadata=metadata)
|
history.append({"human": question, "ai": answer, "metadata": metadata, "timestamp": timestamp})
|
||||||
history.append(memory)
|
else:
|
||||||
|
memory = ChatMessage()
|
||||||
|
memory.add_user_message(question, metadata=metadata)
|
||||||
|
memory.add_ai_message(answer, metadata=metadata)
|
||||||
|
history.append(memory)
|
||||||
return history
|
return history
|
||||||
|
|
||||||
def _serialize_json(self, metadata: Dict[str, Any]):
|
def _serialize_json(self, metadata: Dict[str, Any]):
|
||||||
|
|||||||
@@ -69,4 +69,4 @@ class ChatMessage(JSONSerializable):
|
|||||||
self.ai_message = BaseMessage(content=message, creator="ai", metadata=metadata)
|
self.ai_message = BaseMessage(content=message, creator="ai", metadata=metadata)
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return f"{self.human_message} | {self.ai_message}"
|
return f"{self.human_message}\n{self.ai_message}"
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
[tool.poetry]
|
[tool.poetry]
|
||||||
name = "embedchain"
|
name = "embedchain"
|
||||||
version = "0.1.3"
|
version = "0.1.4"
|
||||||
description = "Data platform for LLMs - Load, index, retrieve and sync any unstructured data"
|
description = "Data platform for LLMs - Load, index, retrieve and sync any unstructured data"
|
||||||
authors = [
|
authors = [
|
||||||
"Taranjeet Singh <taranjeet@embedchain.ai>",
|
"Taranjeet Singh <taranjeet@embedchain.ai>",
|
||||||
@@ -191,6 +191,4 @@ postgres = ["psycopg", "psycopg-binary", "psycopg-pool"]
|
|||||||
|
|
||||||
[tool.poetry.group.docs.dependencies]
|
[tool.poetry.group.docs.dependencies]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
[tool.poetry.scripts]
|
[tool.poetry.scripts]
|
||||||
|
|||||||
@@ -34,4 +34,4 @@ def test_ec_base_chat_message():
|
|||||||
assert chat_message.ai_message.creator == "ai"
|
assert chat_message.ai_message.creator == "ai"
|
||||||
assert chat_message.ai_message.metadata == ai_metadata
|
assert chat_message.ai_message.metadata == ai_metadata
|
||||||
|
|
||||||
assert str(chat_message) == f"human: {human_message_content} | ai: {ai_message_content}"
|
assert str(chat_message) == f"human: {human_message_content}\nai: {ai_message_content}"
|
||||||
|
|||||||
Reference in New Issue
Block a user