Support Custom Prompt for Memory Action Decision (#2371)

This commit is contained in:
Wonbin Kim
2025-03-18 14:13:01 +09:00
committed by GitHub
parent b8f40f728f
commit 66d3f9b93c
9 changed files with 498 additions and 165 deletions

View File

@@ -0,0 +1,17 @@
from mem0.configs import prompts
def test_get_update_memory_messages():
retrieved_old_memory_dict = [{"id": "1", "text": "old memory 1"}]
response_content = ["new fact"]
custom_update_memory_prompt = "custom prompt determining memory update"
## When custom update memory prompt is provided
##
result = prompts.get_update_memory_messages(retrieved_old_memory_dict, response_content, custom_update_memory_prompt)
assert result.startswith(custom_update_memory_prompt)
## When custom update memory prompt is not provided
##
result = prompts.get_update_memory_messages(retrieved_old_memory_dict, response_content, None)
assert result.startswith(prompts.DEFAULT_UPDATE_MEMORY_PROMPT)

View File

@@ -30,6 +30,25 @@ def memory_instance():
config = MemoryConfig(version="v1.1")
config.graph_store.config = {"some_config": "value"}
return Memory(config)
@pytest.fixture
def memory_custom_instance():
with patch("mem0.utils.factory.EmbedderFactory") as mock_embedder, patch(
"mem0.utils.factory.VectorStoreFactory"
) as mock_vector_store, patch("mem0.utils.factory.LlmFactory") as mock_llm, patch(
"mem0.memory.telemetry.capture_event"
), patch("mem0.memory.graph_memory.MemoryGraph"):
mock_embedder.create.return_value = Mock()
mock_vector_store.create.return_value = Mock()
mock_llm.create.return_value = Mock()
config = MemoryConfig(
version="v1.1",
custom_fact_extraction_prompt="custom prompt extracting memory",
custom_update_memory_prompt="custom prompt determining memory update"
)
config.graph_store.config = {"some_config": "value"}
return Memory(config)
@pytest.mark.parametrize("version, enable_graph", [("v1.0", False), ("v1.1", True)])
@@ -239,3 +258,33 @@ def test_get_all(memory_instance, version, enable_graph, expected_result):
memory_instance.graph.get_all.assert_called_once_with({"user_id": "test_user"}, 100)
else:
memory_instance.graph.get_all.assert_not_called()
def test_custom_prompts(memory_custom_instance):
messages = [{"role": "user", "content": "Test message"}]
memory_custom_instance.llm.generate_response = Mock()
with patch("mem0.memory.main.parse_messages", return_value="Test message") as mock_parse_messages:
with patch("mem0.memory.main.get_update_memory_messages", return_value="custom update memory prompt") as mock_get_update_memory_messages:
memory_custom_instance.add(messages=messages, user_id="test_user")
## custom prompt
##
mock_parse_messages.assert_called_once_with(messages)
memory_custom_instance.llm.generate_response.assert_any_call(
messages=[
{"role": "system", "content": memory_custom_instance.config.custom_fact_extraction_prompt},
{"role": "user", "content": f"Input:\n{mock_parse_messages.return_value}"},
],
response_format={"type": "json_object"},
)
## custom update memory prompt
##
mock_get_update_memory_messages.assert_called_once_with([],[],memory_custom_instance.config.custom_update_memory_prompt)
memory_custom_instance.llm.generate_response.assert_any_call(
messages=[{"role": "user", "content": mock_get_update_memory_messages.return_value}],
response_format={"type": "json_object"},
)