[Bug Fix] Handle chat sessions properly during app.chat() calls (#1084)
This commit is contained in:
@@ -1,17 +1,18 @@
|
||||
import pytest
|
||||
|
||||
from embedchain.memory.base import ECChatMemory
|
||||
from embedchain.memory.base import ChatHistory
|
||||
from embedchain.memory.message import ChatMessage
|
||||
|
||||
|
||||
# Fixture for creating an instance of ECChatMemory
|
||||
# Fixture for creating an instance of ChatHistory
|
||||
@pytest.fixture
|
||||
def chat_memory_instance():
|
||||
return ECChatMemory()
|
||||
return ChatHistory()
|
||||
|
||||
|
||||
def test_add_chat_memory(chat_memory_instance):
|
||||
app_id = "test_app"
|
||||
session_id = "test_session"
|
||||
human_message = "Hello, how are you?"
|
||||
ai_message = "I'm fine, thank you!"
|
||||
|
||||
@@ -19,14 +20,15 @@ def test_add_chat_memory(chat_memory_instance):
|
||||
chat_message.add_user_message(human_message)
|
||||
chat_message.add_ai_message(ai_message)
|
||||
|
||||
chat_memory_instance.add(app_id, chat_message)
|
||||
chat_memory_instance.add(app_id, session_id, chat_message)
|
||||
|
||||
assert chat_memory_instance.count_history_messages(app_id) == 1
|
||||
chat_memory_instance.delete_chat_history(app_id)
|
||||
assert chat_memory_instance.count(app_id, session_id) == 1
|
||||
chat_memory_instance.delete(app_id, session_id)
|
||||
|
||||
|
||||
def test_get_recent_memories(chat_memory_instance):
|
||||
def test_get(chat_memory_instance):
|
||||
app_id = "test_app"
|
||||
session_id = "test_session"
|
||||
|
||||
for i in range(1, 7):
|
||||
human_message = f"Question {i}"
|
||||
@@ -36,15 +38,16 @@ def test_get_recent_memories(chat_memory_instance):
|
||||
chat_message.add_user_message(human_message)
|
||||
chat_message.add_ai_message(ai_message)
|
||||
|
||||
chat_memory_instance.add(app_id, chat_message)
|
||||
chat_memory_instance.add(app_id, session_id, chat_message)
|
||||
|
||||
recent_memories = chat_memory_instance.get_recent_memories(app_id, num_rounds=5)
|
||||
recent_memories = chat_memory_instance.get(app_id, session_id, num_rounds=5)
|
||||
|
||||
assert len(recent_memories) == 5
|
||||
|
||||
|
||||
def test_delete_chat_history(chat_memory_instance):
|
||||
app_id = "test_app"
|
||||
session_id = "test_session"
|
||||
|
||||
for i in range(1, 6):
|
||||
human_message = f"Question {i}"
|
||||
@@ -54,11 +57,11 @@ def test_delete_chat_history(chat_memory_instance):
|
||||
chat_message.add_user_message(human_message)
|
||||
chat_message.add_ai_message(ai_message)
|
||||
|
||||
chat_memory_instance.add(app_id, chat_message)
|
||||
chat_memory_instance.add(app_id, session_id, chat_message)
|
||||
|
||||
chat_memory_instance.delete_chat_history(app_id)
|
||||
chat_memory_instance.delete(app_id, session_id)
|
||||
|
||||
assert chat_memory_instance.count_history_messages(app_id) == 0
|
||||
assert chat_memory_instance.count(app_id, session_id) == 0
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@@ -3,17 +3,17 @@ from embedchain.memory.message import BaseMessage, ChatMessage
|
||||
|
||||
def test_ec_base_message():
|
||||
content = "Hello, how are you?"
|
||||
creator = "human"
|
||||
created_by = "human"
|
||||
metadata = {"key": "value"}
|
||||
|
||||
message = BaseMessage(content=content, creator=creator, metadata=metadata)
|
||||
message = BaseMessage(content=content, created_by=created_by, metadata=metadata)
|
||||
|
||||
assert message.content == content
|
||||
assert message.creator == creator
|
||||
assert message.created_by == created_by
|
||||
assert message.metadata == metadata
|
||||
assert message.type is None
|
||||
assert message.is_lc_serializable() is True
|
||||
assert str(message) == f"{creator}: {content}"
|
||||
assert str(message) == f"{created_by}: {content}"
|
||||
|
||||
|
||||
def test_ec_base_chat_message():
|
||||
@@ -27,11 +27,11 @@ def test_ec_base_chat_message():
|
||||
chat_message.add_ai_message(ai_message_content, metadata=ai_metadata)
|
||||
|
||||
assert chat_message.human_message.content == human_message_content
|
||||
assert chat_message.human_message.creator == "human"
|
||||
assert chat_message.human_message.created_by == "human"
|
||||
assert chat_message.human_message.metadata == human_metadata
|
||||
|
||||
assert chat_message.ai_message.content == ai_message_content
|
||||
assert chat_message.ai_message.creator == "ai"
|
||||
assert chat_message.ai_message.created_by == "ai"
|
||||
assert chat_message.ai_message.metadata == ai_metadata
|
||||
|
||||
assert str(chat_message) == f"human: {human_message_content}\nai: {ai_message_content}"
|
||||
|
||||
Reference in New Issue
Block a user