[Misc] Lint code and fix code smells (#1871)

This commit is contained in:
Deshraj Yadav
2024-09-16 17:39:54 -07:00
committed by GitHub
parent 0a78cb9f7a
commit 55c54beeab
57 changed files with 1178 additions and 1357 deletions

View File

@@ -4,42 +4,39 @@ from unittest.mock import Mock, patch
from mem0.memory.main import Memory
from mem0.configs.base import MemoryConfig
@pytest.fixture(autouse=True)
def mock_openai():
os.environ['OPENAI_API_KEY'] = "123"
with patch('openai.OpenAI') as mock:
os.environ["OPENAI_API_KEY"] = "123"
with patch("openai.OpenAI") as mock:
mock.return_value = Mock()
yield mock
@pytest.fixture
def memory_instance():
with patch('mem0.utils.factory.EmbedderFactory') as mock_embedder, \
patch('mem0.utils.factory.VectorStoreFactory') as mock_vector_store, \
patch('mem0.utils.factory.LlmFactory') as mock_llm, \
patch('mem0.memory.telemetry.capture_event'), \
patch('mem0.memory.graph_memory.MemoryGraph'):
with patch("mem0.utils.factory.EmbedderFactory") as mock_embedder, patch(
"mem0.utils.factory.VectorStoreFactory"
) as mock_vector_store, patch("mem0.utils.factory.LlmFactory") as mock_llm, patch(
"mem0.memory.telemetry.capture_event"
), patch("mem0.memory.graph_memory.MemoryGraph"):
mock_embedder.create.return_value = Mock()
mock_vector_store.create.return_value = Mock()
mock_llm.create.return_value = Mock()
config = MemoryConfig(version="v1.1")
config.graph_store.config = {"some_config": "value"}
return Memory(config)
@pytest.mark.parametrize("version, enable_graph", [
("v1.0", False),
("v1.1", True)
])
@pytest.mark.parametrize("version, enable_graph", [("v1.0", False), ("v1.1", True)])
def test_add(memory_instance, version, enable_graph):
memory_instance.config.version = version
memory_instance.enable_graph = enable_graph
memory_instance._add_to_vector_store = Mock(return_value=[{"memory": "Test memory", "event": "ADD"}])
memory_instance._add_to_graph = Mock(return_value=[])
result = memory_instance.add(
messages=[{"role": "user", "content": "Test message"}],
user_id="test_user"
)
result = memory_instance.add(messages=[{"role": "user", "content": "Test message"}], user_id="test_user")
assert "results" in result
assert result["results"] == [{"memory": "Test memory", "event": "ADD"}]
@@ -47,26 +44,27 @@ def test_add(memory_instance, version, enable_graph):
assert result["relations"] == []
memory_instance._add_to_vector_store.assert_called_once_with(
[{"role": "user", "content": "Test message"}],
{"user_id": "test_user"},
{"user_id": "test_user"}
)
# Remove the conditional assertion for _add_to_graph
memory_instance._add_to_graph.assert_called_once_with(
[{"role": "user", "content": "Test message"}],
{"user_id": "test_user"}
[{"role": "user", "content": "Test message"}], {"user_id": "test_user"}, {"user_id": "test_user"}
)
# Remove the conditional assertion for _add_to_graph
memory_instance._add_to_graph.assert_called_once_with(
[{"role": "user", "content": "Test message"}], {"user_id": "test_user"}
)
def test_get(memory_instance):
mock_memory = Mock(id="test_id", payload={
"data": "Test memory",
"user_id": "test_user",
"hash": "test_hash",
"created_at": "2023-01-01T00:00:00",
"updated_at": "2023-01-02T00:00:00",
"extra_field": "extra_value"
})
mock_memory = Mock(
id="test_id",
payload={
"data": "Test memory",
"user_id": "test_user",
"hash": "test_hash",
"created_at": "2023-01-01T00:00:00",
"updated_at": "2023-01-02T00:00:00",
"extra_field": "extra_value",
},
)
memory_instance.vector_store.get = Mock(return_value=mock_memory)
result = memory_instance.get("test_id")
@@ -79,16 +77,14 @@ def test_get(memory_instance):
assert result["updated_at"] == "2023-01-02T00:00:00"
assert result["metadata"] == {"extra_field": "extra_value"}
@pytest.mark.parametrize("version, enable_graph", [
("v1.0", False),
("v1.1", True)
])
@pytest.mark.parametrize("version, enable_graph", [("v1.0", False), ("v1.1", True)])
def test_search(memory_instance, version, enable_graph):
memory_instance.config.version = version
memory_instance.enable_graph = enable_graph
mock_memories = [
Mock(id="1", payload={"data": "Memory 1", "user_id": "test_user"}, score=0.9),
Mock(id="2", payload={"data": "Memory 2", "user_id": "test_user"}, score=0.8)
Mock(id="2", payload={"data": "Memory 2", "user_id": "test_user"}, score=0.8),
]
memory_instance.vector_store.search = Mock(return_value=mock_memories)
memory_instance.embedding_model.embed = Mock(return_value=[0.1, 0.2, 0.3])
@@ -118,17 +114,16 @@ def test_search(memory_instance, version, enable_graph):
assert result["results"][0]["score"] == 0.9
memory_instance.vector_store.search.assert_called_once_with(
query=[0.1, 0.2, 0.3],
limit=100,
filters={"user_id": "test_user"}
query=[0.1, 0.2, 0.3], limit=100, filters={"user_id": "test_user"}
)
memory_instance.embedding_model.embed.assert_called_once_with("test query")
if enable_graph:
memory_instance.graph.search.assert_called_once_with("test query", {"user_id": "test_user"})
else:
memory_instance.graph.search.assert_not_called()
def test_update(memory_instance):
memory_instance._update_memory = Mock()
@@ -137,6 +132,7 @@ def test_update(memory_instance):
memory_instance._update_memory.assert_called_once_with("test_id", "Updated memory")
assert result["message"] == "Memory updated successfully!"
def test_delete(memory_instance):
memory_instance._delete_memory = Mock()
@@ -145,10 +141,8 @@ def test_delete(memory_instance):
memory_instance._delete_memory.assert_called_once_with("test_id")
assert result["message"] == "Memory deleted successfully!"
@pytest.mark.parametrize("version, enable_graph", [
("v1.0", False),
("v1.1", True)
])
@pytest.mark.parametrize("version, enable_graph", [("v1.0", False), ("v1.1", True)])
def test_delete_all(memory_instance, version, enable_graph):
memory_instance.config.version = version
memory_instance.enable_graph = enable_graph
@@ -160,14 +154,15 @@ def test_delete_all(memory_instance, version, enable_graph):
result = memory_instance.delete_all(user_id="test_user")
assert memory_instance._delete_memory.call_count == 2
if enable_graph:
memory_instance.graph.delete_all.assert_called_once_with({"user_id": "test_user"})
else:
memory_instance.graph.delete_all.assert_not_called()
assert result["message"] == "Memories deleted successfully!"
def test_reset(memory_instance):
memory_instance.vector_store.delete_col = Mock()
memory_instance.db.reset = Mock()
@@ -177,22 +172,30 @@ def test_reset(memory_instance):
memory_instance.vector_store.delete_col.assert_called_once()
memory_instance.db.reset.assert_called_once()
@pytest.mark.parametrize("version, enable_graph, expected_result", [
("v1.0", False, {"results": [{"id": "1", "memory": "Memory 1", "user_id": "test_user"}]}),
("v1.1", False, {"results": [{"id": "1", "memory": "Memory 1", "user_id": "test_user"}]}),
("v1.1", True, {
"results": [{"id": "1", "memory": "Memory 1", "user_id": "test_user"}],
"relations": [{"source": "entity1", "relationship": "rel", "target": "entity2"}]
})
])
@pytest.mark.parametrize(
"version, enable_graph, expected_result",
[
("v1.0", False, {"results": [{"id": "1", "memory": "Memory 1", "user_id": "test_user"}]}),
("v1.1", False, {"results": [{"id": "1", "memory": "Memory 1", "user_id": "test_user"}]}),
(
"v1.1",
True,
{
"results": [{"id": "1", "memory": "Memory 1", "user_id": "test_user"}],
"relations": [{"source": "entity1", "relationship": "rel", "target": "entity2"}],
},
),
],
)
def test_get_all(memory_instance, version, enable_graph, expected_result):
memory_instance.config.version = version
memory_instance.enable_graph = enable_graph
mock_memories = [Mock(id="1", payload={"data": "Memory 1", "user_id": "test_user"})]
memory_instance.vector_store.list = Mock(return_value=(mock_memories, None))
memory_instance.graph.get_all = Mock(return_value=[
{"source": "entity1", "relationship": "rel", "target": "entity2"}
])
memory_instance.graph.get_all = Mock(
return_value=[{"source": "entity1", "relationship": "rel", "target": "entity2"}]
)
result = memory_instance.get_all(user_id="test_user")
@@ -204,7 +207,7 @@ def test_get_all(memory_instance, version, enable_graph, expected_result):
assert result_item["id"] == expected_item["id"]
assert result_item["memory"] == expected_item["memory"]
assert result_item["user_id"] == expected_item["user_id"]
if enable_graph:
assert "relations" in result
assert result["relations"] == expected_result["relations"]
@@ -212,7 +215,7 @@ def test_get_all(memory_instance, version, enable_graph, expected_result):
assert "relations" not in result
memory_instance.vector_store.list.assert_called_once_with(filters={"user_id": "test_user"}, limit=100)
if enable_graph:
memory_instance.graph.get_all.assert_called_once_with({"user_id": "test_user"})
else: