[Improvement] Use SQLite for chat memory (#910)
Co-authored-by: Deven Patel <deven298@yahoo.com>
This commit is contained in:
@@ -7,6 +7,7 @@ from embedchain import App
|
||||
from embedchain.config import AppConfig, ChromaDbConfig
|
||||
from embedchain.embedchain import EmbedChain
|
||||
from embedchain.llm.base import BaseLlm
|
||||
from embedchain.memory.base import ECChatMemory
|
||||
|
||||
os.environ["OPENAI_API_KEY"] = "test-api-key"
|
||||
|
||||
@@ -25,6 +26,11 @@ def test_whole_app(app_instance, mocker):
|
||||
mocker.patch.object(BaseLlm, "get_answer_from_llm", return_value=knowledge)
|
||||
mocker.patch.object(BaseLlm, "get_llm_model_answer", return_value=knowledge)
|
||||
mocker.patch.object(BaseLlm, "generate_prompt")
|
||||
mocker.patch.object(
|
||||
BaseLlm,
|
||||
"add_history",
|
||||
)
|
||||
mocker.patch.object(ECChatMemory, "delete_chat_history", autospec=True)
|
||||
|
||||
app_instance.add(knowledge, data_type="text")
|
||||
app_instance.query("What text did I give you?")
|
||||
@@ -41,6 +47,10 @@ def test_add_after_reset(app_instance, mocker):
|
||||
chroma_config = {"allow_reset": True}
|
||||
|
||||
app_instance = App(config=config, db_config=ChromaDbConfig(**chroma_config))
|
||||
|
||||
# mock delete chat history
|
||||
mocker.patch.object(ECChatMemory, "delete_chat_history", autospec=True)
|
||||
|
||||
app_instance.reset()
|
||||
|
||||
app_instance.db.client.heartbeat()
|
||||
|
||||
@@ -5,6 +5,8 @@ from unittest.mock import MagicMock, patch
|
||||
from embedchain import App
|
||||
from embedchain.config import AppConfig, BaseLlmConfig
|
||||
from embedchain.llm.base import BaseLlm
|
||||
from embedchain.memory.base import ECChatMemory
|
||||
from embedchain.memory.message import ChatMessage
|
||||
|
||||
|
||||
class TestApp(unittest.TestCase):
|
||||
@@ -31,14 +33,14 @@ class TestApp(unittest.TestCase):
|
||||
"""
|
||||
config = AppConfig(collect_metrics=False)
|
||||
app = App(config=config)
|
||||
first_answer = app.chat("Test query 1")
|
||||
self.assertEqual(first_answer, "Test answer")
|
||||
self.assertEqual(len(app.llm.memory.chat_memory.messages), 2)
|
||||
self.assertEqual(len(app.llm.history.splitlines()), 2)
|
||||
second_answer = app.chat("Test query 2")
|
||||
self.assertEqual(second_answer, "Test answer")
|
||||
self.assertEqual(len(app.llm.memory.chat_memory.messages), 4)
|
||||
self.assertEqual(len(app.llm.history.splitlines()), 4)
|
||||
with patch.object(BaseLlm, "add_history") as mock_history:
|
||||
first_answer = app.chat("Test query 1")
|
||||
self.assertEqual(first_answer, "Test answer")
|
||||
mock_history.assert_called_with(app.config.id, "Test query 1", "Test answer")
|
||||
|
||||
second_answer = app.chat("Test query 2")
|
||||
self.assertEqual(second_answer, "Test answer")
|
||||
mock_history.assert_called_with(app.config.id, "Test query 2", "Test answer")
|
||||
|
||||
@patch.object(App, "retrieve_from_database", return_value=["Test context"])
|
||||
@patch.object(BaseLlm, "get_answer_from_llm", return_value="Test answer")
|
||||
@@ -49,16 +51,22 @@ class TestApp(unittest.TestCase):
|
||||
|
||||
Also tests that a dry run does not change the history
|
||||
"""
|
||||
config = AppConfig(collect_metrics=False)
|
||||
app = App(config=config)
|
||||
first_answer = app.chat("Test query 1")
|
||||
self.assertEqual(first_answer, "Test answer")
|
||||
self.assertEqual(len(app.llm.history.splitlines()), 2)
|
||||
history = app.llm.history
|
||||
dry_run = app.chat("Test query 2", dry_run=True)
|
||||
self.assertIn("History:", dry_run)
|
||||
self.assertEqual(history, app.llm.history)
|
||||
self.assertEqual(len(app.llm.history.splitlines()), 2)
|
||||
with patch.object(ECChatMemory, "get_recent_memories") as mock_memory:
|
||||
mock_message = ChatMessage()
|
||||
mock_message.add_user_message("Test query 1")
|
||||
mock_message.add_ai_message("Test answer")
|
||||
mock_memory.return_value = [mock_message]
|
||||
|
||||
config = AppConfig(collect_metrics=False)
|
||||
app = App(config=config)
|
||||
first_answer = app.chat("Test query 1")
|
||||
self.assertEqual(first_answer, "Test answer")
|
||||
self.assertEqual(len(app.llm.history), 1)
|
||||
history = app.llm.history
|
||||
dry_run = app.chat("Test query 2", dry_run=True)
|
||||
self.assertIn("History:", dry_run)
|
||||
self.assertEqual(history, app.llm.history)
|
||||
self.assertEqual(len(app.llm.history), 1)
|
||||
|
||||
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
|
||||
def test_chat_with_where_in_params(self):
|
||||
|
||||
67
tests/memory/test_chat_memory.py
Normal file
67
tests/memory/test_chat_memory.py
Normal file
@@ -0,0 +1,67 @@
|
||||
import pytest
|
||||
|
||||
from embedchain.memory.base import ECChatMemory
|
||||
from embedchain.memory.message import ChatMessage
|
||||
|
||||
|
||||
# Fixture for creating an instance of ECChatMemory
|
||||
@pytest.fixture
|
||||
def chat_memory_instance():
|
||||
return ECChatMemory()
|
||||
|
||||
|
||||
def test_add_chat_memory(chat_memory_instance):
|
||||
app_id = "test_app"
|
||||
human_message = "Hello, how are you?"
|
||||
ai_message = "I'm fine, thank you!"
|
||||
|
||||
chat_message = ChatMessage()
|
||||
chat_message.add_user_message(human_message)
|
||||
chat_message.add_ai_message(ai_message)
|
||||
|
||||
chat_memory_instance.add(app_id, chat_message)
|
||||
|
||||
assert chat_memory_instance.count_history_messages(app_id) == 1
|
||||
chat_memory_instance.delete_chat_history(app_id)
|
||||
|
||||
|
||||
def test_get_recent_memories(chat_memory_instance):
|
||||
app_id = "test_app"
|
||||
|
||||
for i in range(1, 7):
|
||||
human_message = f"Question {i}"
|
||||
ai_message = f"Answer {i}"
|
||||
|
||||
chat_message = ChatMessage()
|
||||
chat_message.add_user_message(human_message)
|
||||
chat_message.add_ai_message(ai_message)
|
||||
|
||||
chat_memory_instance.add(app_id, chat_message)
|
||||
|
||||
recent_memories = chat_memory_instance.get_recent_memories(app_id, num_rounds=5)
|
||||
|
||||
assert len(recent_memories) == 5
|
||||
|
||||
|
||||
def test_delete_chat_history(chat_memory_instance):
|
||||
app_id = "test_app"
|
||||
|
||||
for i in range(1, 6):
|
||||
human_message = f"Question {i}"
|
||||
ai_message = f"Answer {i}"
|
||||
|
||||
chat_message = ChatMessage()
|
||||
chat_message.add_user_message(human_message)
|
||||
chat_message.add_ai_message(ai_message)
|
||||
|
||||
chat_memory_instance.add(app_id, chat_message)
|
||||
|
||||
chat_memory_instance.delete_chat_history(app_id)
|
||||
|
||||
assert chat_memory_instance.count_history_messages(app_id) == 0
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def close_connection(chat_memory_instance):
|
||||
yield
|
||||
chat_memory_instance.close_connection()
|
||||
37
tests/memory/test_memory_messages.py
Normal file
37
tests/memory/test_memory_messages.py
Normal file
@@ -0,0 +1,37 @@
|
||||
from embedchain.memory.message import BaseMessage, ChatMessage
|
||||
|
||||
|
||||
def test_ec_base_message():
|
||||
content = "Hello, how are you?"
|
||||
creator = "human"
|
||||
metadata = {"key": "value"}
|
||||
|
||||
message = BaseMessage(content=content, creator=creator, metadata=metadata)
|
||||
|
||||
assert message.content == content
|
||||
assert message.creator == creator
|
||||
assert message.metadata == metadata
|
||||
assert message.type is None
|
||||
assert message.is_lc_serializable() is True
|
||||
assert str(message) == f"{creator}: {content}"
|
||||
|
||||
|
||||
def test_ec_base_chat_message():
|
||||
human_message_content = "Hello, how are you?"
|
||||
ai_message_content = "I'm fine, thank you!"
|
||||
human_metadata = {"user": "John"}
|
||||
ai_metadata = {"response_time": 0.5}
|
||||
|
||||
chat_message = ChatMessage()
|
||||
chat_message.add_user_message(human_message_content, metadata=human_metadata)
|
||||
chat_message.add_ai_message(ai_message_content, metadata=ai_metadata)
|
||||
|
||||
assert chat_message.human_message.content == human_message_content
|
||||
assert chat_message.human_message.creator == "human"
|
||||
assert chat_message.human_message.metadata == human_metadata
|
||||
|
||||
assert chat_message.ai_message.content == ai_message_content
|
||||
assert chat_message.ai_message.creator == "ai"
|
||||
assert chat_message.ai_message.metadata == ai_metadata
|
||||
|
||||
assert str(chat_message) == f"human: {human_message_content} | ai: {ai_message_content}"
|
||||
@@ -1,4 +1,5 @@
|
||||
import yaml
|
||||
|
||||
from embedchain.utils import validate_yaml_config
|
||||
|
||||
CONFIG_YAMLS = [
|
||||
|
||||
Reference in New Issue
Block a user