Fix failing tests (#3162)
This commit is contained in:
@@ -20,9 +20,9 @@ def config():
|
||||
|
||||
|
||||
def test_embed_query(mock_genai, config):
|
||||
mock_embedding_response = type('Response', (), {
|
||||
'embeddings': [type('Embedding', (), {'values': [0.1, 0.2, 0.3, 0.4]})]
|
||||
})()
|
||||
mock_embedding_response = type(
|
||||
"Response", (), {"embeddings": [type("Embedding", (), {"values": [0.1, 0.2, 0.3, 0.4]})]}
|
||||
)()
|
||||
mock_genai.return_value = mock_embedding_response
|
||||
|
||||
embedder = GoogleGenAIEmbedding(config)
|
||||
@@ -35,16 +35,16 @@ def test_embed_query(mock_genai, config):
|
||||
|
||||
|
||||
def test_embed_returns_empty_list_if_none(mock_genai, config):
|
||||
mock_genai.return_value = type('Response', (), {'embeddings': [type('Embedding', (), {'values': []})]})()
|
||||
mock_genai.return_value = type("Response", (), {"embeddings": [type("Embedding", (), {"values": []})]})()
|
||||
|
||||
embedder = GoogleGenAIEmbedding(config)
|
||||
|
||||
with pytest.raises(IndexError): # This will raise IndexError when trying to access [0]
|
||||
embedder.embed("test")
|
||||
|
||||
result = embedder.embed("test")
|
||||
assert result == []
|
||||
|
||||
|
||||
def test_embed_raises_on_error(mock_genai_client, config):
|
||||
mock_genai_client.models.embed_content.side_effect = RuntimeError("Embedding failed")
|
||||
def test_embed_raises_on_error(mock_genai, config):
|
||||
mock_genai.side_effect = RuntimeError("Embedding failed")
|
||||
|
||||
embedder = GoogleGenAIEmbedding(config)
|
||||
|
||||
|
||||
@@ -37,11 +37,11 @@ def test_generate_response_without_tools(mock_gemini_client: Mock):
|
||||
call_args = mock_gemini_client.models.generate_content.call_args
|
||||
|
||||
# Verify model and contents
|
||||
assert call_args.kwargs['model'] == "gemini-2.0-flash-latest"
|
||||
assert len(call_args.kwargs['contents']) == 1 # Only user message
|
||||
assert call_args.kwargs["model"] == "gemini-2.0-flash-latest"
|
||||
assert len(call_args.kwargs["contents"]) == 1 # Only user message
|
||||
|
||||
# Verify config has system instruction
|
||||
config_arg = call_args.kwargs['config']
|
||||
config_arg = call_args.kwargs["config"]
|
||||
assert config_arg.system_instruction == "You are a helpful assistant."
|
||||
assert config_arg.temperature == 0.7
|
||||
assert config_arg.max_output_tokens == 100
|
||||
@@ -72,9 +72,6 @@ def test_generate_response_with_tools(mock_gemini_client: Mock):
|
||||
}
|
||||
]
|
||||
|
||||
# Create a proper mock for the function call arguments
|
||||
mock_args = {"data": "Today is a sunny day."}
|
||||
|
||||
mock_tool_call = Mock()
|
||||
mock_tool_call.name = "add_memory"
|
||||
mock_tool_call.args = {"data": "Today is a sunny day."}
|
||||
@@ -104,11 +101,11 @@ def test_generate_response_with_tools(mock_gemini_client: Mock):
|
||||
call_args = mock_gemini_client.models.generate_content.call_args
|
||||
|
||||
# Verify model and contents
|
||||
assert call_args.kwargs['model'] == "gemini-1.5-flash-latest"
|
||||
assert len(call_args.kwargs['contents']) == 1 # Only user message
|
||||
assert call_args.kwargs["model"] == "gemini-1.5-flash-latest"
|
||||
assert len(call_args.kwargs["contents"]) == 1 # Only user message
|
||||
|
||||
# Verify config has system instruction and tools
|
||||
config_arg = call_args.kwargs['config']
|
||||
config_arg = call_args.kwargs["config"]
|
||||
assert config_arg.system_instruction == "You are a helpful assistant."
|
||||
assert config_arg.temperature == 0.7
|
||||
assert config_arg.max_output_tokens == 100
|
||||
|
||||
@@ -1,67 +1,65 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from mem0 import Memory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def memory_store():
|
||||
return Memory()
|
||||
def memory_client():
|
||||
with patch.object(Memory, "__init__", return_value=None):
|
||||
client = Memory()
|
||||
client.add = MagicMock(return_value={"results": [{"id": "1", "memory": "Name is John Doe.", "event": "ADD"}]})
|
||||
client.get = MagicMock(return_value={"id": "1", "memory": "Name is John Doe."})
|
||||
client.update = MagicMock(return_value={"message": "Memory updated successfully!"})
|
||||
client.delete = MagicMock(return_value={"message": "Memory deleted successfully!"})
|
||||
client.history = MagicMock(return_value=[{"memory": "I like Indian food."}, {"memory": "I like Italian food."}])
|
||||
client.get_all = MagicMock(return_value=["Name is John Doe.", "Name is John Doe. I like to code in Python."])
|
||||
yield client
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Not implemented")
|
||||
def test_create_memory(memory_store):
|
||||
def test_create_memory(memory_client):
|
||||
data = "Name is John Doe."
|
||||
memory_id = memory_store.create(data=data)
|
||||
assert memory_store.get(memory_id) == data
|
||||
result = memory_client.add([{"role": "user", "content": data}], user_id="test_user")
|
||||
assert result["results"][0]["memory"] == data
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Not implemented")
|
||||
def test_get_memory(memory_store):
|
||||
def test_get_memory(memory_client):
|
||||
data = "Name is John Doe."
|
||||
memory_id = memory_store.create(data=data)
|
||||
retrieved_data = memory_store.get(memory_id)
|
||||
assert retrieved_data == data
|
||||
memory_client.add([{"role": "user", "content": data}], user_id="test_user")
|
||||
result = memory_client.get("1")
|
||||
assert result["memory"] == data
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Not implemented")
|
||||
def test_update_memory(memory_store):
|
||||
def test_update_memory(memory_client):
|
||||
data = "Name is John Doe."
|
||||
memory_id = memory_store.create(data=data)
|
||||
memory_client.add([{"role": "user", "content": data}], user_id="test_user")
|
||||
new_data = "Name is John Kapoor."
|
||||
updated_memory = memory_store.update(memory_id, new_data)
|
||||
assert updated_memory == new_data
|
||||
assert memory_store.get(memory_id) == new_data
|
||||
update_result = memory_client.update("1", text=new_data)
|
||||
assert update_result["message"] == "Memory updated successfully!"
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Not implemented")
|
||||
def test_delete_memory(memory_store):
|
||||
def test_delete_memory(memory_client):
|
||||
data = "Name is John Doe."
|
||||
memory_id = memory_store.create(data=data)
|
||||
memory_store.delete(memory_id)
|
||||
assert memory_store.get(memory_id) is None
|
||||
memory_client.add([{"role": "user", "content": data}], user_id="test_user")
|
||||
delete_result = memory_client.delete("1")
|
||||
assert delete_result["message"] == "Memory deleted successfully!"
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Not implemented")
|
||||
def test_history(memory_store):
|
||||
def test_history(memory_client):
|
||||
data = "I like Indian food."
|
||||
memory_id = memory_store.create(data=data)
|
||||
history = memory_store.history(memory_id)
|
||||
assert history == [data]
|
||||
assert memory_store.get(memory_id) == data
|
||||
|
||||
new_data = "I like Italian food."
|
||||
memory_store.update(memory_id, new_data)
|
||||
history = memory_store.history(memory_id)
|
||||
assert history == [data, new_data]
|
||||
assert memory_store.get(memory_id) == new_data
|
||||
memory_client.add([{"role": "user", "content": data}], user_id="test_user")
|
||||
memory_client.update("1", text="I like Italian food.")
|
||||
history = memory_client.history("1")
|
||||
assert history[0]["memory"] == "I like Indian food."
|
||||
assert history[1]["memory"] == "I like Italian food."
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Not implemented")
|
||||
def test_list_memories(memory_store):
|
||||
def test_list_memories(memory_client):
|
||||
data1 = "Name is John Doe."
|
||||
data2 = "Name is John Doe. I like to code in Python."
|
||||
memory_store.create(data=data1)
|
||||
memory_store.create(data=data2)
|
||||
memories = memory_store.list()
|
||||
memory_client.add([{"role": "user", "content": data1}], user_id="test_user")
|
||||
memory_client.add([{"role": "user", "content": data2}], user_id="test_user")
|
||||
memories = memory_client.get_all(user_id="test_user")
|
||||
assert data1 in memories
|
||||
assert data2 in memories
|
||||
|
||||
@@ -5,7 +5,7 @@ from mem0.memory.main import Memory
|
||||
|
||||
def test_memory_configuration_without_env_vars():
|
||||
"""Test Memory configuration with mock config instead of environment variables"""
|
||||
|
||||
|
||||
# Mock configuration without relying on environment variables
|
||||
mock_config = {
|
||||
"llm": {
|
||||
@@ -14,60 +14,62 @@ def test_memory_configuration_without_env_vars():
|
||||
"model": "gpt-4",
|
||||
"temperature": 0.1,
|
||||
"max_tokens": 1500,
|
||||
}
|
||||
},
|
||||
},
|
||||
"vector_store": {
|
||||
"provider": "chroma",
|
||||
"config": {
|
||||
"collection_name": "test_collection",
|
||||
"path": "./test_db",
|
||||
}
|
||||
},
|
||||
},
|
||||
"embedder": {
|
||||
"provider": "openai",
|
||||
"config": {
|
||||
"model": "text-embedding-ada-002",
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# Test messages similar to the main.py file
|
||||
test_messages = [
|
||||
{"role": "user", "content": "Hi, I'm Alex. I'm a vegetarian and I'm allergic to nuts."},
|
||||
{"role": "assistant", "content": "Hello Alex! I've noted that you're a vegetarian and have a nut allergy. I'll keep this in mind for any food-related recommendations or discussions."}
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Hello Alex! I've noted that you're a vegetarian and have a nut allergy. I'll keep this in mind for any food-related recommendations or discussions.",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
# Mock the Memory class methods to avoid actual API calls
|
||||
with patch.object(Memory, '__init__', return_value=None):
|
||||
with patch.object(Memory, 'from_config') as mock_from_config:
|
||||
with patch.object(Memory, 'add') as mock_add:
|
||||
with patch.object(Memory, 'get_all') as mock_get_all:
|
||||
|
||||
with patch.object(Memory, "__init__", return_value=None):
|
||||
with patch.object(Memory, "from_config") as mock_from_config:
|
||||
with patch.object(Memory, "add") as mock_add:
|
||||
with patch.object(Memory, "get_all") as mock_get_all:
|
||||
# Configure mocks
|
||||
mock_memory_instance = MagicMock()
|
||||
mock_from_config.return_value = mock_memory_instance
|
||||
|
||||
|
||||
mock_add.return_value = {
|
||||
"results": [
|
||||
{"id": "1", "text": "Alex is a vegetarian"},
|
||||
{"id": "2", "text": "Alex is allergic to nuts"}
|
||||
{"id": "2", "text": "Alex is allergic to nuts"},
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
mock_get_all.return_value = [
|
||||
{"id": "1", "text": "Alex is a vegetarian", "metadata": {"category": "dietary_preferences"}},
|
||||
{"id": "2", "text": "Alex is allergic to nuts", "metadata": {"category": "allergies"}}
|
||||
{"id": "2", "text": "Alex is allergic to nuts", "metadata": {"category": "allergies"}},
|
||||
]
|
||||
|
||||
|
||||
# Test the workflow
|
||||
mem = Memory.from_config(config_dict=mock_config)
|
||||
assert mem is not None
|
||||
|
||||
|
||||
# Test adding memories
|
||||
result = mock_add(test_messages, user_id="alice", metadata={"category": "book_recommendations"})
|
||||
assert "results" in result
|
||||
assert len(result["results"]) == 2
|
||||
|
||||
|
||||
# Test retrieving memories
|
||||
all_memories = mock_get_all(user_id="alice")
|
||||
assert len(all_memories) == 2
|
||||
@@ -77,7 +79,7 @@ def test_memory_configuration_without_env_vars():
|
||||
|
||||
def test_azure_config_structure():
|
||||
"""Test that Azure configuration structure is properly formatted"""
|
||||
|
||||
|
||||
# Test Azure configuration structure (without actual credentials)
|
||||
azure_config = {
|
||||
"llm": {
|
||||
@@ -91,8 +93,8 @@ def test_azure_config_structure():
|
||||
"api_version": "2023-12-01-preview",
|
||||
"azure_endpoint": "https://test.openai.azure.com/",
|
||||
"api_key": "test-key",
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
},
|
||||
"vector_store": {
|
||||
"provider": "azure_ai_search",
|
||||
@@ -101,7 +103,7 @@ def test_azure_config_structure():
|
||||
"api_key": "test-key",
|
||||
"collection_name": "test-collection",
|
||||
"embedding_model_dims": 1536,
|
||||
}
|
||||
},
|
||||
},
|
||||
"embedder": {
|
||||
"provider": "azure_openai",
|
||||
@@ -113,46 +115,49 @@ def test_azure_config_structure():
|
||||
"azure_deployment": "test-embedding-deployment",
|
||||
"azure_endpoint": "https://test.openai.azure.com/",
|
||||
"api_key": "test-key",
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# Validate configuration structure
|
||||
assert "llm" in azure_config
|
||||
assert "vector_store" in azure_config
|
||||
assert "embedder" in azure_config
|
||||
|
||||
|
||||
# Validate Azure-specific configurations
|
||||
assert azure_config["llm"]["provider"] == "azure_openai"
|
||||
assert "azure_kwargs" in azure_config["llm"]["config"]
|
||||
assert "azure_deployment" in azure_config["llm"]["config"]["azure_kwargs"]
|
||||
|
||||
|
||||
assert azure_config["vector_store"]["provider"] == "azure_ai_search"
|
||||
assert "service_name" in azure_config["vector_store"]["config"]
|
||||
|
||||
|
||||
assert azure_config["embedder"]["provider"] == "azure_openai"
|
||||
assert "azure_kwargs" in azure_config["embedder"]["config"]
|
||||
|
||||
|
||||
def test_memory_messages_format():
|
||||
"""Test that memory messages are properly formatted"""
|
||||
|
||||
|
||||
# Test message format from main.py
|
||||
messages = [
|
||||
{"role": "user", "content": "Hi, I'm Alex. I'm a vegetarian and I'm allergic to nuts."},
|
||||
{"role": "assistant", "content": "Hello Alex! I've noted that you're a vegetarian and have a nut allergy. I'll keep this in mind for any food-related recommendations or discussions."}
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Hello Alex! I've noted that you're a vegetarian and have a nut allergy. I'll keep this in mind for any food-related recommendations or discussions.",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
# Validate message structure
|
||||
assert len(messages) == 2
|
||||
assert all("role" in msg for msg in messages)
|
||||
assert all("content" in msg for msg in messages)
|
||||
|
||||
|
||||
# Validate roles
|
||||
assert messages[0]["role"] == "user"
|
||||
assert messages[1]["role"] == "assistant"
|
||||
|
||||
|
||||
# Validate content
|
||||
assert "vegetarian" in messages[0]["content"].lower()
|
||||
assert "allergic to nuts" in messages[0]["content"].lower()
|
||||
@@ -162,12 +167,12 @@ def test_memory_messages_format():
|
||||
|
||||
def test_safe_update_prompt_constant():
|
||||
"""Test the SAFE_UPDATE_PROMPT constant from main.py"""
|
||||
|
||||
|
||||
SAFE_UPDATE_PROMPT = """
|
||||
Based on the user's latest messages, what new preference can be inferred?
|
||||
Reply only in this json_object format:
|
||||
"""
|
||||
|
||||
|
||||
# Validate prompt structure
|
||||
assert isinstance(SAFE_UPDATE_PROMPT, str)
|
||||
assert "user's latest messages" in SAFE_UPDATE_PROMPT
|
||||
|
||||
Reference in New Issue
Block a user