Formatting (#2750)

Author:     Dev Khant
Date:       2025-05-22 01:17:29 +05:30
Committed:  GitHub
Parent:     dff91154a7
Commit:     d85fcda037

71 changed files with 1391 additions and 1823 deletions


@@ -17,11 +17,13 @@ def mock_openai():

 @pytest.fixture
 def memory_instance():
-    with patch("mem0.utils.factory.EmbedderFactory") as mock_embedder, patch(
-        "mem0.utils.factory.VectorStoreFactory"
-    ) as mock_vector_store, patch("mem0.utils.factory.LlmFactory") as mock_llm, patch(
-        "mem0.memory.telemetry.capture_event"
-    ), patch("mem0.memory.graph_memory.MemoryGraph"):
+    with (
+        patch("mem0.utils.factory.EmbedderFactory") as mock_embedder,
+        patch("mem0.utils.factory.VectorStoreFactory") as mock_vector_store,
+        patch("mem0.utils.factory.LlmFactory") as mock_llm,
+        patch("mem0.memory.telemetry.capture_event"),
+        patch("mem0.memory.graph_memory.MemoryGraph"),
+    ):
         mock_embedder.create.return_value = Mock()
         mock_vector_store.create.return_value = Mock()
         mock_llm.create.return_value = Mock()
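
The hunk above swaps the old continuation-wrapped chain of `patch(...)` calls for the parenthesized form of the `with` statement, which Python 3.10+ accepts and which formatters such as Black and Ruff emit when several context managers are grouped. A minimal, self-contained sketch of the same pattern follows; the `resource(...)` context manager and its names are illustrative stand-ins for `unittest.mock.patch`, not code from this repository:

    from contextlib import contextmanager

    @contextmanager
    def resource(name):
        # Illustrative context manager standing in for unittest.mock.patch.
        print(f"enter {name}")
        yield name
        print(f"exit {name}")

    # Python 3.10+ lets several context managers be grouped in parentheses,
    # one per line with a trailing comma, instead of one long wrapped line.
    with (
        resource("embedder") as emb,
        resource("vector_store") as store,
        resource("llm") as llm,
    ):
        print(emb, store, llm)
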
@@ -30,13 +32,16 @@ def memory_instance():
         config.graph_store.config = {"some_config": "value"}
         return Memory(config)

+
 @pytest.fixture
 def memory_custom_instance():
-    with patch("mem0.utils.factory.EmbedderFactory") as mock_embedder, patch(
-        "mem0.utils.factory.VectorStoreFactory"
-    ) as mock_vector_store, patch("mem0.utils.factory.LlmFactory") as mock_llm, patch(
-        "mem0.memory.telemetry.capture_event"
-    ), patch("mem0.memory.graph_memory.MemoryGraph"):
+    with (
+        patch("mem0.utils.factory.EmbedderFactory") as mock_embedder,
+        patch("mem0.utils.factory.VectorStoreFactory") as mock_vector_store,
+        patch("mem0.utils.factory.LlmFactory") as mock_llm,
+        patch("mem0.memory.telemetry.capture_event"),
+        patch("mem0.memory.graph_memory.MemoryGraph"),
+    ):
         mock_embedder.create.return_value = Mock()
         mock_vector_store.create.return_value = Mock()
         mock_llm.create.return_value = Mock()
@@ -44,7 +49,7 @@ def memory_custom_instance():
         config = MemoryConfig(
             version="v1.1",
             custom_fact_extraction_prompt="custom prompt extracting memory",
-            custom_update_memory_prompt="custom prompt determining memory update"
+            custom_update_memory_prompt="custom prompt determining memory update",
         )
         config.graph_store.config = {"some_config": "value"}
         return Memory(config)
@@ -194,7 +199,6 @@ def test_delete_all(memory_instance, version, enable_graph):
     assert result["message"] == "Memories deleted successfully!"


-
 @pytest.mark.parametrize(
     "version, enable_graph, expected_result",
     [
@@ -242,20 +246,22 @@ def test_get_all(memory_instance, version, enable_graph, expected_result):
         memory_instance.graph.get_all.assert_called_once_with({"user_id": "test_user"}, 100)
     else:
         memory_instance.graph.get_all.assert_not_called()


 def test_custom_prompts(memory_custom_instance):
     messages = [{"role": "user", "content": "Test message"}]
     memory_custom_instance.llm.generate_response = Mock()

     with patch("mem0.memory.main.parse_messages", return_value="Test message") as mock_parse_messages:
-        with patch("mem0.memory.main.get_update_memory_messages", return_value="custom update memory prompt") as mock_get_update_memory_messages:
+        with patch(
+            "mem0.memory.main.get_update_memory_messages", return_value="custom update memory prompt"
+        ) as mock_get_update_memory_messages:
             memory_custom_instance.add(messages=messages, user_id="test_user")

             ## custom prompt
             mock_parse_messages.assert_called_once_with(messages)
             memory_custom_instance.llm.generate_response.assert_any_call(
                 messages=[
                     {"role": "system", "content": memory_custom_instance.config.custom_fact_extraction_prompt},
@@ -263,12 +269,14 @@ def test_custom_prompts(memory_custom_instance):
                 ],
                 response_format={"type": "json_object"},
             )

             ## custom update memory prompt
-            mock_get_update_memory_messages.assert_called_once_with([],[],memory_custom_instance.config.custom_update_memory_prompt)
+            mock_get_update_memory_messages.assert_called_once_with(
+                [], [], memory_custom_instance.config.custom_update_memory_prompt
+            )
             memory_custom_instance.llm.generate_response.assert_any_call(
                 messages=[{"role": "user", "content": mock_get_update_memory_messages.return_value}],
                 response_format={"type": "json_object"},
             )
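
The assertions re-wrapped in the last two hunks are standard `unittest.mock` call checks. As a reference for readers skimming the diff, a small self-contained sketch of the same methods (`assert_any_call`, `assert_called_once_with`, `patch`) follows; the `ask(...)` helper and its arguments are illustrative and not part of mem0:

    import os.path
    from unittest.mock import Mock, patch

    def ask(llm, system_prompt):
        # Toy function under test: it calls the injected LLM twice.
        llm.generate_response(messages=[{"role": "system", "content": system_prompt}])
        llm.generate_response(messages=[{"role": "user", "content": "hello"}])

    llm = Mock()
    ask(llm, "be brief")

    # assert_any_call passes if any one of the recorded calls matches these
    # arguments, which is how the test above checks two prompts in turn.
    llm.generate_response.assert_any_call(messages=[{"role": "user", "content": "hello"}])

    # patch swaps the named attribute for a Mock inside the with block, as the
    # fixtures above do for the factory classes; assert_called_once_with then
    # demands exactly one recorded call with exactly these arguments.
    with patch("os.path.exists", return_value=True) as mock_exists:
        assert os.path.exists("/definitely/not/there")
        mock_exists.assert_called_once_with("/definitely/not/there")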