[Improvement] update LLM memory get function (#1162)
Co-authored-by: Deven Patel <deven298@yahoo.com>
This commit is contained in:
@@ -667,8 +667,11 @@ class EmbedChain(JSONSerializable):
|
|||||||
# Send anonymous telemetry
|
# Send anonymous telemetry
|
||||||
self.telemetry.capture(event_name="reset", properties=self._telemetry_props)
|
self.telemetry.capture(event_name="reset", properties=self._telemetry_props)
|
||||||
|
|
||||||
def get_history(self, num_rounds: int = 10, display_format: bool = True):
|
def get_history(self, num_rounds: int = 10, display_format: bool = True, session_id: Optional[str] = "default"):
|
||||||
return self.llm.memory.get(app_id=self.config.id, num_rounds=num_rounds, display_format=display_format)
|
history = self.llm.memory.get(
|
||||||
|
app_id=self.config.id, session_id=session_id, num_rounds=num_rounds, display_format=display_format
|
||||||
|
)
|
||||||
|
return history
|
||||||
|
|
||||||
def delete_session_chat_history(self, session_id: str = "default"):
|
def delete_session_chat_history(self, session_id: str = "default"):
|
||||||
self.llm.memory.delete(app_id=self.config.id, session_id=session_id)
|
self.llm.memory.delete(app_id=self.config.id, session_id=session_id)
|
||||||
|
|||||||
@@ -73,21 +73,40 @@ class ChatHistory:
|
|||||||
self.cursor.execute(DELETE_CHAT_HISTORY_QUERY, params)
|
self.cursor.execute(DELETE_CHAT_HISTORY_QUERY, params)
|
||||||
self.connection.commit()
|
self.connection.commit()
|
||||||
|
|
||||||
def get(self, app_id, session_id, num_rounds=10, display_format=False) -> list[ChatMessage]:
|
def get(
|
||||||
|
self, app_id, session_id: str = "default", num_rounds=10, fetch_all: bool = False, display_format=False
|
||||||
|
) -> list[ChatMessage]:
|
||||||
"""
|
"""
|
||||||
Get the most recent num_rounds rounds of conversations
|
Get the chat history for a given app_id.
|
||||||
between human and AI, for a given app_id.
|
|
||||||
|
param: app_id - The app_id to get chat history
|
||||||
|
param: session_id (optional) - The session_id to get chat history. Defaults to "default"
|
||||||
|
param: num_rounds (optional) - The number of rounds to get chat history. Defaults to 10
|
||||||
|
param: fetch_all (optional) - Whether to fetch all chat history or not. Defaults to False
|
||||||
|
param: display_format (optional) - Whether to return the chat history in display format. Defaults to False
|
||||||
"""
|
"""
|
||||||
|
|
||||||
QUERY = """
|
base_query = """
|
||||||
SELECT * FROM ec_chat_history
|
SELECT * FROM ec_chat_history
|
||||||
WHERE app_id=? AND session_id=?
|
WHERE app_id=?
|
||||||
ORDER BY created_at DESC
|
|
||||||
LIMIT ?
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if fetch_all:
|
||||||
|
additional_query = "ORDER BY created_at DESC"
|
||||||
|
params = (app_id,)
|
||||||
|
else:
|
||||||
|
additional_query = """
|
||||||
|
AND session_id=?
|
||||||
|
ORDER BY created_at DESC
|
||||||
|
LIMIT ?
|
||||||
|
"""
|
||||||
|
params = (app_id, session_id, num_rounds)
|
||||||
|
|
||||||
|
QUERY = base_query + additional_query
|
||||||
|
|
||||||
self.cursor.execute(
|
self.cursor.execute(
|
||||||
QUERY,
|
QUERY,
|
||||||
(app_id, session_id, num_rounds),
|
params,
|
||||||
)
|
)
|
||||||
|
|
||||||
results = self.cursor.fetchall()
|
results = self.cursor.fetchall()
|
||||||
@@ -97,7 +116,15 @@ class ChatHistory:
|
|||||||
metadata = self._deserialize_json(metadata=metadata)
|
metadata = self._deserialize_json(metadata=metadata)
|
||||||
# Return list of dict if display_format is True
|
# Return list of dict if display_format is True
|
||||||
if display_format:
|
if display_format:
|
||||||
history.append({"human": question, "ai": answer, "metadata": metadata, "timestamp": timestamp})
|
history.append(
|
||||||
|
{
|
||||||
|
"session_id": session_id,
|
||||||
|
"human": question,
|
||||||
|
"ai": answer,
|
||||||
|
"metadata": metadata,
|
||||||
|
"timestamp": timestamp,
|
||||||
|
}
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
memory = ChatMessage()
|
memory = ChatMessage()
|
||||||
memory.add_user_message(question, metadata=metadata)
|
memory.add_user_message(question, metadata=metadata)
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
[tool.poetry]
|
[tool.poetry]
|
||||||
name = "embedchain"
|
name = "embedchain"
|
||||||
version = "0.1.62"
|
version = "0.1.63"
|
||||||
description = "Data platform for LLMs - Load, index, retrieve and sync any unstructured data"
|
description = "Data platform for LLMs - Load, index, retrieve and sync any unstructured data"
|
||||||
authors = [
|
authors = [
|
||||||
"Taranjeet Singh <taranjeet@embedchain.ai>",
|
"Taranjeet Singh <taranjeet@embedchain.ai>",
|
||||||
|
|||||||
@@ -44,6 +44,10 @@ def test_get(chat_memory_instance):
|
|||||||
|
|
||||||
assert len(recent_memories) == 5
|
assert len(recent_memories) == 5
|
||||||
|
|
||||||
|
all_memories = chat_memory_instance.get(app_id, fetch_all=True)
|
||||||
|
|
||||||
|
assert len(all_memories) == 6
|
||||||
|
|
||||||
|
|
||||||
def test_delete_chat_history(chat_memory_instance):
|
def test_delete_chat_history(chat_memory_instance):
|
||||||
app_id = "test_app"
|
app_id = "test_app"
|
||||||
|
|||||||
Reference in New Issue
Block a user