[Improvement] Use SQLite for chat memory (#910)
Co-authored-by: Deven Patel <deven298@yahoo.com>
@@ -5,7 +5,7 @@ import uuid
 import requests

-from embedchain.embedchain import CONFIG_DIR, CONFIG_FILE
+from embedchain.constants import CONFIG_DIR, CONFIG_FILE


 class Client:
embedchain/constants.py (new file, 8 lines)
@@ -0,0 +1,8 @@
import os
from pathlib import Path

ABS_PATH = os.getcwd()
HOME_DIR = str(Path.home())
CONFIG_DIR = os.path.join(HOME_DIR, ".embedchain")
CONFIG_FILE = os.path.join(CONFIG_DIR, "config.json")
SQLITE_PATH = os.path.join(CONFIG_DIR, "embedchain.db")
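These constants centralize paths that were previously defined inline in embedchain.py, so the new chat-history database lands next to the existing config file. A small illustrative sketch of how the paths resolve (output depends on the user's home directory; not part of this commit):

```python
# Resolves the same paths as embedchain/constants.py above (illustration only).
import os
from pathlib import Path

home_dir = str(Path.home())
config_dir = os.path.join(home_dir, ".embedchain")
print(os.path.join(config_dir, "embedchain.db"))
# e.g. /home/alice/.embedchain/embedchain.db
```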
@@ -1,9 +1,7 @@
 import hashlib
 import json
 import logging
-import os
 import sqlite3
-from pathlib import Path
 from typing import Any, Dict, List, Optional, Tuple, Union

 from dotenv import load_dotenv
@@ -12,6 +10,7 @@ from langchain.docstore.document import Document
 from embedchain.chunkers.base_chunker import BaseChunker
 from embedchain.config import AddConfig, BaseLlmConfig, ChunkerConfig
 from embedchain.config.apps.base_app_config import BaseAppConfig
+from embedchain.constants import SQLITE_PATH
 from embedchain.data_formatter import DataFormatter
 from embedchain.embedder.base import BaseEmbedder
 from embedchain.helper.json_serializable import JSONSerializable
@@ -25,12 +24,6 @@ from embedchain.vectordb.base import BaseVectorDB

 load_dotenv()

-ABS_PATH = os.getcwd()
-HOME_DIR = str(Path.home())
-CONFIG_DIR = os.path.join(HOME_DIR, ".embedchain")
-CONFIG_FILE = os.path.join(CONFIG_DIR, "config.json")
-SQLITE_PATH = os.path.join(CONFIG_DIR, "embedchain.db")
-

 class EmbedChain(JSONSerializable):
     def __init__(
@@ -602,6 +595,9 @@ class EmbedChain(JSONSerializable):
             input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run
         )

+        # add conversation in memory
+        self.llm.add_history(self.config.id, input_query, answer)
+
         # Send anonymous telemetry
         self.telemetry.capture(event_name="chat", properties=self._telemetry_props)
@@ -645,5 +641,9 @@ class EmbedChain(JSONSerializable):
         self.db.reset()
         self.cursor.execute("DELETE FROM data_sources WHERE pipeline_id = ?", (self.config.id,))
         self.connection.commit()
+        self.clear_history()
         # Send anonymous telemetry
         self.telemetry.capture(event_name="reset", properties=self._telemetry_props)
+
+    def clear_history(self):
+        self.llm.memory.delete_chat_history(app_id=self.config.id)
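With these two hooks, every `chat()` round is persisted per app id and `reset()` now also wipes that app's chat history. A hedged end-to-end sketch (assumes an `App` with the default config and an `OPENAI_API_KEY` set, which this commit does not change):

```python
from embedchain import App

app = App()
app.add("https://www.forbes.com/profile/elon-musk")
answer = app.chat("How many companies does Elon Musk run?")
# the (question, answer) round was written to ~/.embedchain/embedchain.db
app.reset()  # also calls clear_history(), deleting this app's chat_history rows
```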
@@ -1,14 +1,15 @@
 import logging
 from typing import Any, Dict, Generator, List, Optional

-from langchain.memory import ConversationBufferMemory
-from langchain.schema import BaseMessage
+from langchain.schema import BaseMessage as LCBaseMessage

 from embedchain.config import BaseLlmConfig
 from embedchain.config.llm.base import (DEFAULT_PROMPT,
                                         DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE,
                                         DOCS_SITE_PROMPT_TEMPLATE)
 from embedchain.helper.json_serializable import JSONSerializable
+from embedchain.memory.base import ECChatMemory
+from embedchain.memory.message import ChatMessage


 class BaseLlm(JSONSerializable):
@@ -23,7 +24,7 @@ class BaseLlm(JSONSerializable):
         else:
             self.config = config

-        self.memory = ConversationBufferMemory()
+        self.memory = ECChatMemory()
         self.is_docs_site_instance = False
         self.online = False
         self.history: Any = None
@@ -44,11 +45,18 @@ class BaseLlm(JSONSerializable):
         """
         self.history = history

-    def update_history(self):
+    def update_history(self, app_id: str):
         """Update class history attribute with history in memory (for chat method)"""
-        chat_history = self.memory.load_memory_variables({})["history"]
+        chat_history = self.memory.get_recent_memories(app_id=app_id, num_rounds=10)
         if chat_history:
-            self.set_history(chat_history)
+            self.set_history([str(history) for history in chat_history])
+
+    def add_history(self, app_id: str, question: str, answer: str, metadata: Optional[Dict[str, Any]] = None):
+        chat_message = ChatMessage()
+        chat_message.add_user_message(question, metadata=metadata)
+        chat_message.add_ai_message(answer, metadata=metadata)
+        self.memory.add(app_id=app_id, chat_message=chat_message)
+        self.update_history(app_id=app_id)

     def generate_prompt(self, input_query: str, contexts: List[str], **kwargs: Dict[str, Any]) -> str:
         """
@@ -165,7 +173,6 @@ class BaseLlm(JSONSerializable):
         for chunk in answer:
             streamed_answer = streamed_answer + chunk
             yield chunk
-        self.memory.chat_memory.add_ai_message(streamed_answer)
         logging.info(f"Answer: {streamed_answer}")

     def query(self, input_query: str, contexts: List[str], config: BaseLlmConfig = None, dry_run=False):
@@ -257,8 +264,6 @@ class BaseLlm(JSONSerializable):
         if self.online:
             k["web_search_result"] = self.access_search_and_get_results(input_query)

-        self.update_history()
-
         prompt = self.generate_prompt(input_query, contexts, **k)
         logging.info(f"Prompt: {prompt}")
@@ -267,16 +272,9 @@ class BaseLlm(JSONSerializable):

         answer = self.get_answer_from_llm(prompt)

-        self.memory.chat_memory.add_user_message(input_query)
-
         if isinstance(answer, str):
-            self.memory.chat_memory.add_ai_message(answer)
             logging.info(f"Answer: {answer}")

-            # NOTE: Adding to history before and after. This could be seen as redundant.
-            # If we change it, we have to change the tests (no big deal).
-            self.update_history()
-
             return answer
         else:
             # this is a streamed response and needs to be handled differently.
@@ -287,7 +285,7 @@ class BaseLlm(JSONSerializable):
         self.config: BaseLlmConfig = BaseLlmConfig.deserialize(prev_config)

     @staticmethod
-    def _get_messages(prompt: str, system_prompt: Optional[str] = None) -> List[BaseMessage]:
+    def _get_messages(prompt: str, system_prompt: Optional[str] = None) -> List[LCBaseMessage]:
         """
         Construct a list of langchain messages
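Net effect in BaseLlm: ConversationBufferMemory is gone, `add_history()` persists each question/answer pair through ECChatMemory, and `update_history()` rebuilds `self.history` from the stringified last ten rounds. A minimal round-trip sketch (assumes BaseLlm can be instantiated directly with its default config; the app id is arbitrary):

```python
from embedchain.llm.base import BaseLlm

llm = BaseLlm()  # self.memory is now an ECChatMemory
llm.add_history("demo-app", "What is embedchain?", "A framework for RAG apps.")
# add_history() called update_history(), so self.history now contains
# stringified rounds of the form "human: <question> | ai: <answer>"
print(llm.history)
```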
embedchain/memory/__init__.py (new file, empty)

embedchain/memory/base.py (new file, 112 lines)
@@ -0,0 +1,112 @@
import json
import logging
import sqlite3
import uuid
from typing import Any, Dict, List, Optional

from embedchain.constants import SQLITE_PATH
from embedchain.memory.message import ChatMessage
from embedchain.memory.utils import merge_metadata_dict

CHAT_MESSAGE_CREATE_TABLE_QUERY = """
    CREATE TABLE IF NOT EXISTS chat_history (
        app_id TEXT,
        id TEXT,
        question TEXT,
        answer TEXT,
        metadata TEXT,
        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
        PRIMARY KEY (id, app_id)
    )
"""


class ECChatMemory:
    def __init__(self) -> None:
        with sqlite3.connect(SQLITE_PATH) as self.connection:
            self.cursor = self.connection.cursor()

            self.cursor.execute(CHAT_MESSAGE_CREATE_TABLE_QUERY)
            self.connection.commit()

    def add(self, app_id, chat_message: ChatMessage) -> Optional[str]:
        memory_id = str(uuid.uuid4())
        metadata_dict = merge_metadata_dict(chat_message.human_message.metadata, chat_message.ai_message.metadata)
        if metadata_dict:
            metadata = self._serialize_json(metadata_dict)
        ADD_CHAT_MESSAGE_QUERY = """
            INSERT INTO chat_history (app_id, id, question, answer, metadata)
            VALUES (?, ?, ?, ?, ?)
        """
        self.cursor.execute(
            ADD_CHAT_MESSAGE_QUERY,
            (
                app_id,
                memory_id,
                chat_message.human_message.content,
                chat_message.ai_message.content,
                metadata if metadata_dict else "{}",
            ),
        )
        self.connection.commit()
        logging.info(f"Added chat memory to db with id: {memory_id}")
        return memory_id

    def delete_chat_history(self, app_id: str):
        DELETE_CHAT_HISTORY_QUERY = """
            DELETE FROM chat_history WHERE app_id=?
        """
        self.cursor.execute(
            DELETE_CHAT_HISTORY_QUERY,
            (app_id,),
        )
        self.connection.commit()

    def get_recent_memories(self, app_id, num_rounds=10) -> List[ChatMessage]:
        """
        Get the most recent num_rounds rounds of conversations
        between human and AI, for a given app_id.
        """
        QUERY = """
            SELECT * FROM chat_history
            WHERE app_id=?
            ORDER BY created_at DESC
            LIMIT ?
        """
        self.cursor.execute(
            QUERY,
            (app_id, num_rounds),
        )

        results = self.cursor.fetchall()
        history = []
        for result in results:
            app_id, id, question, answer, metadata, timestamp = result
            metadata = self._deserialize_json(metadata=metadata)
            memory = ChatMessage()
            memory.add_user_message(question, metadata=metadata)
            memory.add_ai_message(answer, metadata=metadata)
            history.append(memory)
        return history

    def _serialize_json(self, metadata: Dict[str, Any]):
        return json.dumps(metadata)

    def _deserialize_json(self, metadata: str):
        return json.loads(metadata)

    def close_connection(self):
        self.connection.close()

    def count_history_messages(self, app_id: str):
        QUERY = """
            SELECT COUNT(*) FROM chat_history
            WHERE app_id=?
        """
        self.cursor.execute(
            QUERY,
            (app_id,),
        )
        count = self.cursor.fetchone()[0]
        return count
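ECChatMemory can also be exercised on its own, which is what the new tests below do. A small usage sketch (note it writes to the real SQLITE_PATH under ~/.embedchain):

```python
from embedchain.memory.base import ECChatMemory
from embedchain.memory.message import ChatMessage

memory = ECChatMemory()
msg = ChatMessage()
msg.add_user_message("Question 1", metadata={"source": "cli"})
msg.add_ai_message("Answer 1", metadata={"model": "gpt-3.5-turbo"})

memory_id = memory.add(app_id="demo-app", chat_message=msg)  # returns the row's uuid
rounds = memory.get_recent_memories(app_id="demo-app")       # newest-first list of ChatMessage
memory.delete_chat_history(app_id="demo-app")
memory.close_connection()
```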
embedchain/memory/message.py (new file, 72 lines)
@@ -0,0 +1,72 @@
import logging
from typing import Any, Dict, Optional

from embedchain.helper.json_serializable import JSONSerializable


class BaseMessage(JSONSerializable):
    """
    The base abstract message class.

    Messages are the inputs and outputs of Models.
    """

    # The string content of the message.
    content: str

    # The creator of the message. AI, Human, Bot etc.
    by: str

    # Any additional info.
    metadata: Dict[str, Any]

    def __init__(self, content: str, creator: str, metadata: Optional[Dict[str, Any]] = None) -> None:
        super().__init__()
        self.content = content
        self.creator = creator
        self.metadata = metadata

    @property
    def type(self) -> str:
        """Type of the Message, used for serialization."""

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Return whether this class is serializable."""
        return True

    def __str__(self) -> str:
        return f"{self.creator}: {self.content}"


class ChatMessage(JSONSerializable):
    """
    The base abstract chat message class.

    Chat messages are the pair of (question, answer) conversation
    between human and model.
    """

    human_message: Optional[BaseMessage] = None
    ai_message: Optional[BaseMessage] = None

    def add_user_message(self, message: str, metadata: Optional[dict] = None):
        if self.human_message:
            logging.info(
                "Human message already exists in the chat message,\
                overwriting it with new message."
            )

        self.human_message = BaseMessage(content=message, creator="human", metadata=metadata)

    def add_ai_message(self, message: str, metadata: Optional[dict] = None):
        if self.ai_message:
            logging.info(
                "AI message already exists in the chat message,\
                overwriting it with new message."
            )

        self.ai_message = BaseMessage(content=message, creator="ai", metadata=metadata)

    def __str__(self) -> str:
        return f"{self.human_message} | {self.ai_message}"
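Since BaseLlm.update_history stringifies each ChatMessage, the `__str__` forms above define exactly what lands in the prompt history. For example:

```python
from embedchain.memory.message import ChatMessage

msg = ChatMessage()
msg.add_user_message("Hello, how are you?")
msg.add_ai_message("I'm fine, thank you!")
print(str(msg))  # human: Hello, how are you? | ai: I'm fine, thank you!
```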
embedchain/memory/utils.py (new file, 35 lines)
@@ -0,0 +1,35 @@
from typing import Any, Dict, Optional


def merge_metadata_dict(left: Optional[Dict[str, Any]], right: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
    """
    Merge the metadatas of two BaseMessage types.

    Args:
        left (Dict[str, Any]): metadata of human message
        right (Dict[str, Any]): metadata of ai message

    Returns:
        Dict[str, Any]: combined metadata dict with dedup
        to be saved in db.
    """
    if not left and not right:
        return None
    elif not left:
        return right
    elif not right:
        return left

    merged = left.copy()
    for k, v in right.items():
        if k not in merged:
            merged[k] = v
        elif type(merged[k]) != type(v):
            raise ValueError(f'additional_kwargs["{k}"] already exists in this message,' " but with a different type.")
        elif isinstance(merged[k], str):
            merged[k] += v
        elif isinstance(merged[k], dict):
            merged[k] = merge_metadata_dict(merged[k], v)
        else:
            raise ValueError(f"Additional kwargs key {k} already exists in this message.")
    return merged
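The merge is recursive for dict values, concatenates strings, and raises for same-key values of mismatched or other types. A worked example of those rules:

```python
from embedchain.memory.utils import merge_metadata_dict

left = {"session": "abc", "tags": {"lang": "en"}}
right = {"tags": {"topic": "chat"}, "latency_ms": 120}
print(merge_metadata_dict(left, right))
# {'session': 'abc', 'tags': {'lang': 'en', 'topic': 'chat'}, 'latency_ms': 120}
```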
@@ -10,7 +10,8 @@ import yaml

 from embedchain import Client
 from embedchain.config import ChunkerConfig, PipelineConfig
-from embedchain.embedchain import CONFIG_DIR, EmbedChain
+from embedchain.constants import SQLITE_PATH
+from embedchain.embedchain import EmbedChain
 from embedchain.embedder.base import BaseEmbedder
 from embedchain.embedder.openai import OpenAIEmbedder
 from embedchain.factory import EmbedderFactory, LlmFactory, VectorDBFactory

@@ -22,8 +23,6 @@ from embedchain.utils import validate_yaml_config
 from embedchain.vectordb.base import BaseVectorDB
 from embedchain.vectordb.chroma import ChromaDB

-SQLITE_PATH = os.path.join(CONFIG_DIR, "embedchain.db")
-

 @register_deserializable
 class Pipeline(EmbedChain):
@@ -138,7 +138,8 @@ def detect_datatype(source: Any) -> DataType:
     formatted_source = format_source(str(source), 30)

     if url:
-        from langchain.document_loaders.youtube import ALLOWED_NETLOCK as YOUTUBE_ALLOWED_NETLOCS
+        from langchain.document_loaders.youtube import \
+            ALLOWED_NETLOCK as YOUTUBE_ALLOWED_NETLOCS

         if url.netloc in YOUTUBE_ALLOWED_NETLOCS:
             logging.debug(f"Source of `{formatted_source}` detected as `youtube_video`.")
@@ -7,6 +7,7 @@ from embedchain import App
 from embedchain.config import AppConfig, ChromaDbConfig
 from embedchain.embedchain import EmbedChain
 from embedchain.llm.base import BaseLlm
+from embedchain.memory.base import ECChatMemory

 os.environ["OPENAI_API_KEY"] = "test-api-key"
@@ -25,6 +26,11 @@ def test_whole_app(app_instance, mocker):
     mocker.patch.object(BaseLlm, "get_answer_from_llm", return_value=knowledge)
     mocker.patch.object(BaseLlm, "get_llm_model_answer", return_value=knowledge)
     mocker.patch.object(BaseLlm, "generate_prompt")
+    mocker.patch.object(
+        BaseLlm,
+        "add_history",
+    )
+    mocker.patch.object(ECChatMemory, "delete_chat_history", autospec=True)

     app_instance.add(knowledge, data_type="text")
     app_instance.query("What text did I give you?")
@@ -41,6 +47,10 @@ def test_add_after_reset(app_instance, mocker):
     chroma_config = {"allow_reset": True}

     app_instance = App(config=config, db_config=ChromaDbConfig(**chroma_config))
+
+    # mock delete chat history
+    mocker.patch.object(ECChatMemory, "delete_chat_history", autospec=True)
+
     app_instance.reset()

     app_instance.db.client.heartbeat()
@@ -5,6 +5,8 @@ from unittest.mock import MagicMock, patch
 from embedchain import App
 from embedchain.config import AppConfig, BaseLlmConfig
 from embedchain.llm.base import BaseLlm
+from embedchain.memory.base import ECChatMemory
+from embedchain.memory.message import ChatMessage


 class TestApp(unittest.TestCase):
@@ -31,14 +33,14 @@ class TestApp(unittest.TestCase):
         """
         config = AppConfig(collect_metrics=False)
         app = App(config=config)
+        with patch.object(BaseLlm, "add_history") as mock_history:
             first_answer = app.chat("Test query 1")
             self.assertEqual(first_answer, "Test answer")
-            self.assertEqual(len(app.llm.memory.chat_memory.messages), 2)
+            self.assertEqual(len(app.llm.history.splitlines()), 2)
+            mock_history.assert_called_with(app.config.id, "Test query 1", "Test answer")

             second_answer = app.chat("Test query 2")
             self.assertEqual(second_answer, "Test answer")
-            self.assertEqual(len(app.llm.memory.chat_memory.messages), 4)
+            self.assertEqual(len(app.llm.history.splitlines()), 4)
+            mock_history.assert_called_with(app.config.id, "Test query 2", "Test answer")

     @patch.object(App, "retrieve_from_database", return_value=["Test context"])
     @patch.object(BaseLlm, "get_answer_from_llm", return_value="Test answer")
@@ -49,16 +51,22 @@ class TestApp(unittest.TestCase):

         Also tests that a dry run does not change the history
         """
+        with patch.object(ECChatMemory, "get_recent_memories") as mock_memory:
+            mock_message = ChatMessage()
+            mock_message.add_user_message("Test query 1")
+            mock_message.add_ai_message("Test answer")
+            mock_memory.return_value = [mock_message]
+
             config = AppConfig(collect_metrics=False)
             app = App(config=config)
             first_answer = app.chat("Test query 1")
             self.assertEqual(first_answer, "Test answer")
-            self.assertEqual(len(app.llm.history.splitlines()), 2)
+            self.assertEqual(len(app.llm.history), 1)
             history = app.llm.history
             dry_run = app.chat("Test query 2", dry_run=True)
             self.assertIn("History:", dry_run)
             self.assertEqual(history, app.llm.history)
-            self.assertEqual(len(app.llm.history.splitlines()), 2)
+            self.assertEqual(len(app.llm.history), 1)

     @patch("chromadb.api.models.Collection.Collection.add", MagicMock)
     def test_chat_with_where_in_params(self):
tests/memory/test_chat_memory.py (new file, 67 lines)
@@ -0,0 +1,67 @@
import pytest

from embedchain.memory.base import ECChatMemory
from embedchain.memory.message import ChatMessage


# Fixture for creating an instance of ECChatMemory
@pytest.fixture
def chat_memory_instance():
    return ECChatMemory()


def test_add_chat_memory(chat_memory_instance):
    app_id = "test_app"
    human_message = "Hello, how are you?"
    ai_message = "I'm fine, thank you!"

    chat_message = ChatMessage()
    chat_message.add_user_message(human_message)
    chat_message.add_ai_message(ai_message)

    chat_memory_instance.add(app_id, chat_message)

    assert chat_memory_instance.count_history_messages(app_id) == 1
    chat_memory_instance.delete_chat_history(app_id)


def test_get_recent_memories(chat_memory_instance):
    app_id = "test_app"

    for i in range(1, 7):
        human_message = f"Question {i}"
        ai_message = f"Answer {i}"

        chat_message = ChatMessage()
        chat_message.add_user_message(human_message)
        chat_message.add_ai_message(ai_message)

        chat_memory_instance.add(app_id, chat_message)

    recent_memories = chat_memory_instance.get_recent_memories(app_id, num_rounds=5)

    assert len(recent_memories) == 5


def test_delete_chat_history(chat_memory_instance):
    app_id = "test_app"

    for i in range(1, 6):
        human_message = f"Question {i}"
        ai_message = f"Answer {i}"

        chat_message = ChatMessage()
        chat_message.add_user_message(human_message)
        chat_message.add_ai_message(ai_message)

        chat_memory_instance.add(app_id, chat_message)

    chat_memory_instance.delete_chat_history(app_id)

    assert chat_memory_instance.count_history_messages(app_id) == 0


@pytest.fixture
def close_connection(chat_memory_instance):
    yield
    chat_memory_instance.close_connection()
tests/memory/test_memory_messages.py (new file, 37 lines)
@@ -0,0 +1,37 @@
from embedchain.memory.message import BaseMessage, ChatMessage


def test_ec_base_message():
    content = "Hello, how are you?"
    creator = "human"
    metadata = {"key": "value"}

    message = BaseMessage(content=content, creator=creator, metadata=metadata)

    assert message.content == content
    assert message.creator == creator
    assert message.metadata == metadata
    assert message.type is None
    assert message.is_lc_serializable() is True
    assert str(message) == f"{creator}: {content}"


def test_ec_base_chat_message():
    human_message_content = "Hello, how are you?"
    ai_message_content = "I'm fine, thank you!"
    human_metadata = {"user": "John"}
    ai_metadata = {"response_time": 0.5}

    chat_message = ChatMessage()
    chat_message.add_user_message(human_message_content, metadata=human_metadata)
    chat_message.add_ai_message(ai_message_content, metadata=ai_metadata)

    assert chat_message.human_message.content == human_message_content
    assert chat_message.human_message.creator == "human"
    assert chat_message.human_message.metadata == human_metadata

    assert chat_message.ai_message.content == ai_message_content
    assert chat_message.ai_message.creator == "ai"
    assert chat_message.ai_message.metadata == ai_metadata

    assert str(chat_message) == f"human: {human_message_content} | ai: {ai_message_content}"
@@ -1,4 +1,5 @@
 import yaml

 from embedchain.utils import validate_yaml_config

 CONFIG_YAMLS = [