Support Custom Prompt for Memory Action Decision (#2371)
This commit is contained in:
17
tests/configs/test_prompts.py
Normal file
17
tests/configs/test_prompts.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from mem0.configs import prompts
|
||||
|
||||
|
||||
def test_get_update_memory_messages():
|
||||
retrieved_old_memory_dict = [{"id": "1", "text": "old memory 1"}]
|
||||
response_content = ["new fact"]
|
||||
custom_update_memory_prompt = "custom prompt determining memory update"
|
||||
|
||||
## When custom update memory prompt is provided
|
||||
##
|
||||
result = prompts.get_update_memory_messages(retrieved_old_memory_dict, response_content, custom_update_memory_prompt)
|
||||
assert result.startswith(custom_update_memory_prompt)
|
||||
|
||||
## When custom update memory prompt is not provided
|
||||
##
|
||||
result = prompts.get_update_memory_messages(retrieved_old_memory_dict, response_content, None)
|
||||
assert result.startswith(prompts.DEFAULT_UPDATE_MEMORY_PROMPT)
|
||||
@@ -30,6 +30,25 @@ def memory_instance():
|
||||
config = MemoryConfig(version="v1.1")
|
||||
config.graph_store.config = {"some_config": "value"}
|
||||
return Memory(config)
|
||||
|
||||
@pytest.fixture
|
||||
def memory_custom_instance():
|
||||
with patch("mem0.utils.factory.EmbedderFactory") as mock_embedder, patch(
|
||||
"mem0.utils.factory.VectorStoreFactory"
|
||||
) as mock_vector_store, patch("mem0.utils.factory.LlmFactory") as mock_llm, patch(
|
||||
"mem0.memory.telemetry.capture_event"
|
||||
), patch("mem0.memory.graph_memory.MemoryGraph"):
|
||||
mock_embedder.create.return_value = Mock()
|
||||
mock_vector_store.create.return_value = Mock()
|
||||
mock_llm.create.return_value = Mock()
|
||||
|
||||
config = MemoryConfig(
|
||||
version="v1.1",
|
||||
custom_fact_extraction_prompt="custom prompt extracting memory",
|
||||
custom_update_memory_prompt="custom prompt determining memory update"
|
||||
)
|
||||
config.graph_store.config = {"some_config": "value"}
|
||||
return Memory(config)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("version, enable_graph", [("v1.0", False), ("v1.1", True)])
|
||||
@@ -239,3 +258,33 @@ def test_get_all(memory_instance, version, enable_graph, expected_result):
|
||||
memory_instance.graph.get_all.assert_called_once_with({"user_id": "test_user"}, 100)
|
||||
else:
|
||||
memory_instance.graph.get_all.assert_not_called()
|
||||
|
||||
|
||||
def test_custom_prompts(memory_custom_instance):
|
||||
messages = [{"role": "user", "content": "Test message"}]
|
||||
memory_custom_instance.llm.generate_response = Mock()
|
||||
|
||||
with patch("mem0.memory.main.parse_messages", return_value="Test message") as mock_parse_messages:
|
||||
with patch("mem0.memory.main.get_update_memory_messages", return_value="custom update memory prompt") as mock_get_update_memory_messages:
|
||||
memory_custom_instance.add(messages=messages, user_id="test_user")
|
||||
|
||||
## custom prompt
|
||||
##
|
||||
mock_parse_messages.assert_called_once_with(messages)
|
||||
|
||||
memory_custom_instance.llm.generate_response.assert_any_call(
|
||||
messages=[
|
||||
{"role": "system", "content": memory_custom_instance.config.custom_fact_extraction_prompt},
|
||||
{"role": "user", "content": f"Input:\n{mock_parse_messages.return_value}"},
|
||||
],
|
||||
response_format={"type": "json_object"},
|
||||
)
|
||||
|
||||
## custom update memory prompt
|
||||
##
|
||||
mock_get_update_memory_messages.assert_called_once_with([],[],memory_custom_instance.config.custom_update_memory_prompt)
|
||||
|
||||
memory_custom_instance.llm.generate_response.assert_any_call(
|
||||
messages=[{"role": "user", "content": mock_get_update_memory_messages.return_value}],
|
||||
response_format={"type": "json_object"},
|
||||
)
|
||||
Reference in New Issue
Block a user