[Bug Fix] Handle chat sessions properly during app.chat() calls (#1084)

This commit is contained in:
Deshraj Yadav
2023-12-30 15:36:24 +05:30
committed by GitHub
parent 52b4577d3b
commit a54dde0509
10 changed files with 124 additions and 187 deletions

View File

@@ -7,7 +7,7 @@ from embedchain import App
from embedchain.config import AppConfig, ChromaDbConfig
from embedchain.embedchain import EmbedChain
from embedchain.llm.base import BaseLlm
from embedchain.memory.base import ECChatMemory
from embedchain.memory.base import ChatHistory
from embedchain.vectordb.chroma import ChromaDB
os.environ["OPENAI_API_KEY"] = "test-api-key"
@@ -31,7 +31,7 @@ def test_whole_app(app_instance, mocker):
BaseLlm,
"add_history",
)
mocker.patch.object(ECChatMemory, "delete_chat_history", autospec=True)
mocker.patch.object(ChatHistory, "delete", autospec=True)
app_instance.add(knowledge, data_type="text")
app_instance.query("What text did I give you?")
@@ -50,7 +50,7 @@ def test_add_after_reset(app_instance, mocker):
app_instance = App(config=config, db=db)
# mock delete chat history
mocker.patch.object(ECChatMemory, "delete_chat_history", autospec=True)
mocker.patch.object(ChatHistory, "delete", autospec=True)
app_instance.reset()

View File

@@ -5,7 +5,7 @@ from unittest.mock import MagicMock, patch
from embedchain import App
from embedchain.config import AppConfig, BaseLlmConfig
from embedchain.llm.base import BaseLlm
from embedchain.memory.base import ECChatMemory
from embedchain.memory.base import ChatHistory
from embedchain.memory.message import ChatMessage
@@ -36,11 +36,11 @@ class TestApp(unittest.TestCase):
with patch.object(BaseLlm, "add_history") as mock_history:
first_answer = app.chat("Test query 1")
self.assertEqual(first_answer, "Test answer")
mock_history.assert_called_with(app.config.id, "Test query 1", "Test answer")
mock_history.assert_called_with(app.config.id, "Test query 1", "Test answer", session_id="default")
second_answer = app.chat("Test query 2")
second_answer = app.chat("Test query 2", session_id="test_session")
self.assertEqual(second_answer, "Test answer")
mock_history.assert_called_with(app.config.id, "Test query 2", "Test answer")
mock_history.assert_called_with(app.config.id, "Test query 2", "Test answer", session_id="test_session")
@patch.object(App, "_retrieve_from_database", return_value=["Test context"])
@patch.object(BaseLlm, "get_answer_from_llm", return_value="Test answer")
@@ -51,7 +51,7 @@ class TestApp(unittest.TestCase):
Also tests that a dry run does not change the history
"""
with patch.object(ECChatMemory, "get_recent_memories") as mock_memory:
with patch.object(ChatHistory, "get") as mock_memory:
mock_message = ChatMessage()
mock_message.add_user_message("Test query 1")
mock_message.add_ai_message("Test answer")

View File

@@ -1,17 +1,18 @@
import pytest
from embedchain.memory.base import ECChatMemory
from embedchain.memory.base import ChatHistory
from embedchain.memory.message import ChatMessage
# Fixture for creating an instance of ECChatMemory
# Fixture for creating an instance of ChatHistory
@pytest.fixture
def chat_memory_instance():
return ECChatMemory()
return ChatHistory()
def test_add_chat_memory(chat_memory_instance):
app_id = "test_app"
session_id = "test_session"
human_message = "Hello, how are you?"
ai_message = "I'm fine, thank you!"
@@ -19,14 +20,15 @@ def test_add_chat_memory(chat_memory_instance):
chat_message.add_user_message(human_message)
chat_message.add_ai_message(ai_message)
chat_memory_instance.add(app_id, chat_message)
chat_memory_instance.add(app_id, session_id, chat_message)
assert chat_memory_instance.count_history_messages(app_id) == 1
chat_memory_instance.delete_chat_history(app_id)
assert chat_memory_instance.count(app_id, session_id) == 1
chat_memory_instance.delete(app_id, session_id)
def test_get_recent_memories(chat_memory_instance):
def test_get(chat_memory_instance):
app_id = "test_app"
session_id = "test_session"
for i in range(1, 7):
human_message = f"Question {i}"
@@ -36,15 +38,16 @@ def test_get_recent_memories(chat_memory_instance):
chat_message.add_user_message(human_message)
chat_message.add_ai_message(ai_message)
chat_memory_instance.add(app_id, chat_message)
chat_memory_instance.add(app_id, session_id, chat_message)
recent_memories = chat_memory_instance.get_recent_memories(app_id, num_rounds=5)
recent_memories = chat_memory_instance.get(app_id, session_id, num_rounds=5)
assert len(recent_memories) == 5
def test_delete_chat_history(chat_memory_instance):
app_id = "test_app"
session_id = "test_session"
for i in range(1, 6):
human_message = f"Question {i}"
@@ -54,11 +57,11 @@ def test_delete_chat_history(chat_memory_instance):
chat_message.add_user_message(human_message)
chat_message.add_ai_message(ai_message)
chat_memory_instance.add(app_id, chat_message)
chat_memory_instance.add(app_id, session_id, chat_message)
chat_memory_instance.delete_chat_history(app_id)
chat_memory_instance.delete(app_id, session_id)
assert chat_memory_instance.count_history_messages(app_id) == 0
assert chat_memory_instance.count(app_id, session_id) == 0
@pytest.fixture

View File

@@ -3,17 +3,17 @@ from embedchain.memory.message import BaseMessage, ChatMessage
def test_ec_base_message():
content = "Hello, how are you?"
creator = "human"
created_by = "human"
metadata = {"key": "value"}
message = BaseMessage(content=content, creator=creator, metadata=metadata)
message = BaseMessage(content=content, created_by=created_by, metadata=metadata)
assert message.content == content
assert message.creator == creator
assert message.created_by == created_by
assert message.metadata == metadata
assert message.type is None
assert message.is_lc_serializable() is True
assert str(message) == f"{creator}: {content}"
assert str(message) == f"{created_by}: {content}"
def test_ec_base_chat_message():
@@ -27,11 +27,11 @@ def test_ec_base_chat_message():
chat_message.add_ai_message(ai_message_content, metadata=ai_metadata)
assert chat_message.human_message.content == human_message_content
assert chat_message.human_message.creator == "human"
assert chat_message.human_message.created_by == "human"
assert chat_message.human_message.metadata == human_metadata
assert chat_message.ai_message.content == ai_message_content
assert chat_message.ai_message.creator == "ai"
assert chat_message.ai_message.created_by == "ai"
assert chat_message.ai_message.metadata == ai_metadata
assert str(chat_message) == f"human: {human_message_content}\nai: {ai_message_content}"