Fix failing tests (#3162)

This commit is contained in:
Parshva Daftari
2025-07-25 00:58:45 +05:30
committed by GitHub
parent 37ee3c5eb2
commit 4433666117
11 changed files with 144 additions and 169 deletions

View File

@@ -8,35 +8,30 @@ You need MEM0_API_KEY and OPENAI_API_KEY to run the example.
"""
import asyncio
from datetime import datetime
from dotenv import load_dotenv
import logging
from datetime import datetime
from dotenv import load_dotenv
# LlamaIndex imports
from llama_index.core.agent.workflow import AgentWorkflow, FunctionAgent
from llama_index.llms.openai import OpenAI
from llama_index.core.tools import FunctionTool
from llama_index.llms.openai import OpenAI
# Memory integration
from llama_index.memory.mem0 import Mem0Memory
import warnings
load_dotenv()
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[
logging.StreamHandler(),
logging.FileHandler('learning_system.log')
]
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
handlers=[logging.StreamHandler(), logging.FileHandler("learning_system.log")],
)
logger = logging.getLogger(__name__)
class MultiAgentLearningSystem:
"""
Multi-Agent Architecture:
@@ -51,9 +46,7 @@ class MultiAgentLearningSystem:
# Memory context for this student
self.memory_context = {"user_id": student_id, "app": "learning_assistant"}
self.memory = Mem0Memory.from_client(
context=self.memory_context
)
self.memory = Mem0Memory.from_client(context=self.memory_context)
self._setup_agents()
@@ -84,7 +77,7 @@ class MultiAgentLearningSystem:
# Convert to FunctionTools
tools = [
FunctionTool.from_defaults(async_fn=assess_understanding),
FunctionTool.from_defaults(async_fn=track_progress)
FunctionTool.from_defaults(async_fn=track_progress),
]
# === AGENTS ===
@@ -111,7 +104,7 @@ class MultiAgentLearningSystem:
""",
tools=tools,
llm=self.llm,
can_handoff_to=["PracticeAgent"]
can_handoff_to=["PracticeAgent"],
)
# Practice Agent - Exercises and reinforcement
@@ -137,7 +130,7 @@ class MultiAgentLearningSystem:
""",
tools=tools,
llm=self.llm,
can_handoff_to=["TutorAgent"]
can_handoff_to=["TutorAgent"],
)
# Create the multi-agent workflow
@@ -148,8 +141,8 @@ class MultiAgentLearningSystem:
"current_topic": "",
"student_level": "beginner",
"learning_style": "unknown",
"session_goals": []
}
"session_goals": [],
},
)
async def start_learning_session(self, topic: str, student_message: str = "") -> str:
@@ -163,10 +156,7 @@ class MultiAgentLearningSystem:
request = f"I want to learn about {topic}."
# The magic happens here - multi-agent memory is automatically shared!
response = await self.workflow.run(
user_msg=request,
memory=self.memory
)
response = await self.workflow.run(user_msg=request, memory=self.memory)
return str(response)
@@ -174,10 +164,7 @@ class MultiAgentLearningSystem:
"""Show what the system remembers about this student"""
try:
# Search memory for learning patterns
memories = self.memory.search(
user_id=self.student_id,
query="learning machine learning"
)
memories = self.memory.search(user_id=self.student_id, query="learning machine learning")
if memories and len(memories):
history = "\n".join(f"- {m['memory']}" for m in memories)
@@ -190,20 +177,19 @@ class MultiAgentLearningSystem:
async def run_learning_agent():
learning_system = MultiAgentLearningSystem(student_id="Alexander")
# First session
logger.info("Session 1:")
response = await learning_system.start_learning_session(
"Vision Language Models",
"I'm new to machine learning but I have good hold on Python and have 4 years of work experience.")
"I'm new to machine learning but I have good hold on Python and have 4 years of work experience.",
)
logger.info(response)
# Second session - multi-agent memory will remember the first
logger.info("\nSession 2:")
response2 = await learning_system.start_learning_session(
"Machine Learning", "what all did I cover so far?")
response2 = await learning_system.start_learning_session("Machine Learning", "what all did I cover so far?")
logger.info(response2)
# Show what the multi-agent system remembers

View File

@@ -2,16 +2,15 @@ import hashlib
import logging
import os
import warnings
from functools import wraps
from typing import Any, Dict, List, Optional
import httpx
import requests
from mem0.client.project import AsyncProject, Project
from mem0.client.utils import api_error_handler
from mem0.memory.setup import get_user_id, setup_config
from mem0.memory.telemetry import capture_client_event
from mem0.client.project import Project, AsyncProject
from mem0.client.utils import api_error_handler
logger = logging.getLogger(__name__)
@@ -562,7 +561,9 @@ class MemoryClient:
APIError: If the API request fails.
ValueError: If org_id or project_id are not set.
"""
logger.warning("get_project() method is going to be deprecated in version v1.0 of the package. Please use the client.project.get() method instead.")
logger.warning(
"get_project() method is going to be deprecated in version v1.0 of the package. Please use the client.project.get() method instead."
)
if not (self.org_id and self.project_id):
raise ValueError("org_id and project_id must be set to access instructions or categories")
@@ -604,7 +605,9 @@ class MemoryClient:
APIError: If the API request fails.
ValueError: If org_id or project_id are not set.
"""
logger.warning("update_project() method is going to be deprecated in version v1.0 of the package. Please use the client.project.update() method instead.")
logger.warning(
"update_project() method is going to be deprecated in version v1.0 of the package. Please use the client.project.update() method instead."
)
if not (self.org_id and self.project_id):
raise ValueError("org_id and project_id must be set to update instructions or categories")
@@ -1330,7 +1333,9 @@ class AsyncMemoryClient:
APIError: If the API request fails.
ValueError: If org_id or project_id are not set.
"""
logger.warning("get_project() method is going to be deprecated in version v1.0 of the package. Please use the client.project.get() method instead.")
logger.warning(
"get_project() method is going to be deprecated in version v1.0 of the package. Please use the client.project.get() method instead."
)
if not (self.org_id and self.project_id):
raise ValueError("org_id and project_id must be set to access instructions or categories")
@@ -1368,7 +1373,9 @@ class AsyncMemoryClient:
APIError: If the API request fails.
ValueError: If org_id or project_id are not set.
"""
logger.warning("update_project() method is going to be deprecated in version v1.0 of the package. Please use the client.project.update() method instead.")
logger.warning(
"update_project() method is going to be deprecated in version v1.0 of the package. Please use the client.project.update() method instead."
)
if not (self.org_id and self.project_id):
raise ValueError("org_id and project_id must be set to update instructions or categories")

View File

@@ -5,8 +5,8 @@ from typing import Any, Dict, List, Optional
import httpx
from pydantic import BaseModel, Field
from mem0.memory.telemetry import capture_client_event
from mem0.client.utils import api_error_handler
from mem0.memory.telemetry import capture_client_event
logger = logging.getLogger(__name__)
@@ -16,18 +16,9 @@ class ProjectConfig(BaseModel):
Configuration for project management operations.
"""
org_id: Optional[str] = Field(
default=None,
description="Organization ID"
)
project_id: Optional[str] = Field(
default=None,
description="Project ID"
)
user_email: Optional[str] = Field(
default=None,
description="User email"
)
org_id: Optional[str] = Field(default=None, description="Organization ID")
project_id: Optional[str] = Field(default=None, description="Project ID")
user_email: Optional[str] = Field(default=None, description="User email")
class Config:
validate_assignment = True
@@ -64,11 +55,7 @@ class BaseProject(ABC):
self.config = config
else:
# Create config from parameters
self.config = ProjectConfig(
org_id=org_id,
project_id=project_id,
user_email=user_email
)
self.config = ProjectConfig(org_id=org_id, project_id=project_id, user_email=user_email)
@property
def org_id(self) -> Optional[str]:
@@ -93,13 +80,9 @@ class BaseProject(ABC):
ValueError: If org_id or project_id are not set.
"""
if not (self.config.org_id and self.config.project_id):
raise ValueError(
"org_id and project_id must be set to access project operations"
)
raise ValueError("org_id and project_id must be set to access project operations")
def _prepare_params(
self, kwargs: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
def _prepare_params(self, kwargs: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
"""
Prepare query parameters for API requests.
@@ -124,9 +107,7 @@ class BaseProject(ABC):
return {k: v for k, v in kwargs.items() if v is not None}
def _prepare_org_params(
self, kwargs: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
def _prepare_org_params(self, kwargs: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
"""
Prepare query parameters for organization-level API requests.
@@ -423,7 +404,7 @@ class Project(BaseProject):
"custom_instructions": custom_instructions,
"custom_categories": custom_categories,
"retrieval_criteria": retrieval_criteria,
"enable_graph": enable_graph
"enable_graph": enable_graph,
}
)
response = self._client.patch(
@@ -716,7 +697,7 @@ class AsyncProject(BaseProject):
"custom_instructions": custom_instructions,
"custom_categories": custom_categories,
"retrieval_criteria": retrieval_criteria,
"enable_graph": enable_graph
"enable_graph": enable_graph,
}
)
response = await self._client.patch(

View File

@@ -1,10 +1,13 @@
import httpx
import logging
import httpx
logger = logging.getLogger(__name__)
class APIError(Exception):
"""Exception raised for errors in the API."""
pass

View File

@@ -82,11 +82,11 @@ class AzureOpenAILLM(LLMBase):
str: The generated response.
"""
user_prompt = messages[-1]['content']
user_prompt = messages[-1]["content"]
user_prompt = user_prompt.replace("assistant", "ai")
messages[-1]['content'] = user_prompt
messages[-1]["content"] = user_prompt
common_params = {
"model": self.config.model,

View File

@@ -49,11 +49,11 @@ class AzureOpenAIStructuredLLM(LLMBase):
str: The generated response.
"""
user_prompt = messages[-1]['content']
user_prompt = messages[-1]["content"]
user_prompt = user_prompt.replace("assistant", "ai")
messages[-1]['content'] = user_prompt
messages[-1]["content"] = user_prompt
params = {
"model": self.config.model,

View File

@@ -4,8 +4,6 @@ from typing import Dict, List, Optional
from openai import OpenAI
from openai import OpenAI
from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.base import LLMBase
from mem0.memory.utils import extract_json

View File

@@ -20,9 +20,9 @@ def config():
def test_embed_query(mock_genai, config):
mock_embedding_response = type('Response', (), {
'embeddings': [type('Embedding', (), {'values': [0.1, 0.2, 0.3, 0.4]})]
})()
mock_embedding_response = type(
"Response", (), {"embeddings": [type("Embedding", (), {"values": [0.1, 0.2, 0.3, 0.4]})]}
)()
mock_genai.return_value = mock_embedding_response
embedder = GoogleGenAIEmbedding(config)
@@ -35,16 +35,16 @@ def test_embed_query(mock_genai, config):
def test_embed_returns_empty_list_if_none(mock_genai, config):
mock_genai.return_value = type('Response', (), {'embeddings': [type('Embedding', (), {'values': []})]})()
mock_genai.return_value = type("Response", (), {"embeddings": [type("Embedding", (), {"values": []})]})()
embedder = GoogleGenAIEmbedding(config)
with pytest.raises(IndexError): # This will raise IndexError when trying to access [0]
embedder.embed("test")
result = embedder.embed("test")
assert result == []
def test_embed_raises_on_error(mock_genai_client, config):
mock_genai_client.models.embed_content.side_effect = RuntimeError("Embedding failed")
def test_embed_raises_on_error(mock_genai, config):
mock_genai.side_effect = RuntimeError("Embedding failed")
embedder = GoogleGenAIEmbedding(config)

View File

@@ -37,11 +37,11 @@ def test_generate_response_without_tools(mock_gemini_client: Mock):
call_args = mock_gemini_client.models.generate_content.call_args
# Verify model and contents
assert call_args.kwargs['model'] == "gemini-2.0-flash-latest"
assert len(call_args.kwargs['contents']) == 1 # Only user message
assert call_args.kwargs["model"] == "gemini-2.0-flash-latest"
assert len(call_args.kwargs["contents"]) == 1 # Only user message
# Verify config has system instruction
config_arg = call_args.kwargs['config']
config_arg = call_args.kwargs["config"]
assert config_arg.system_instruction == "You are a helpful assistant."
assert config_arg.temperature == 0.7
assert config_arg.max_output_tokens == 100
@@ -72,9 +72,6 @@ def test_generate_response_with_tools(mock_gemini_client: Mock):
}
]
# Create a proper mock for the function call arguments
mock_args = {"data": "Today is a sunny day."}
mock_tool_call = Mock()
mock_tool_call.name = "add_memory"
mock_tool_call.args = {"data": "Today is a sunny day."}
@@ -104,11 +101,11 @@ def test_generate_response_with_tools(mock_gemini_client: Mock):
call_args = mock_gemini_client.models.generate_content.call_args
# Verify model and contents
assert call_args.kwargs['model'] == "gemini-1.5-flash-latest"
assert len(call_args.kwargs['contents']) == 1 # Only user message
assert call_args.kwargs["model"] == "gemini-1.5-flash-latest"
assert len(call_args.kwargs["contents"]) == 1 # Only user message
# Verify config has system instruction and tools
config_arg = call_args.kwargs['config']
config_arg = call_args.kwargs["config"]
assert config_arg.system_instruction == "You are a helpful assistant."
assert config_arg.temperature == 0.7
assert config_arg.max_output_tokens == 100

View File

@@ -1,67 +1,65 @@
from unittest.mock import MagicMock, patch
import pytest
from mem0 import Memory
@pytest.fixture
def memory_store():
return Memory()
def memory_client():
with patch.object(Memory, "__init__", return_value=None):
client = Memory()
client.add = MagicMock(return_value={"results": [{"id": "1", "memory": "Name is John Doe.", "event": "ADD"}]})
client.get = MagicMock(return_value={"id": "1", "memory": "Name is John Doe."})
client.update = MagicMock(return_value={"message": "Memory updated successfully!"})
client.delete = MagicMock(return_value={"message": "Memory deleted successfully!"})
client.history = MagicMock(return_value=[{"memory": "I like Indian food."}, {"memory": "I like Italian food."}])
client.get_all = MagicMock(return_value=["Name is John Doe.", "Name is John Doe. I like to code in Python."])
yield client
@pytest.mark.skip(reason="Not implemented")
def test_create_memory(memory_store):
def test_create_memory(memory_client):
data = "Name is John Doe."
memory_id = memory_store.create(data=data)
assert memory_store.get(memory_id) == data
result = memory_client.add([{"role": "user", "content": data}], user_id="test_user")
assert result["results"][0]["memory"] == data
@pytest.mark.skip(reason="Not implemented")
def test_get_memory(memory_store):
def test_get_memory(memory_client):
data = "Name is John Doe."
memory_id = memory_store.create(data=data)
retrieved_data = memory_store.get(memory_id)
assert retrieved_data == data
memory_client.add([{"role": "user", "content": data}], user_id="test_user")
result = memory_client.get("1")
assert result["memory"] == data
@pytest.mark.skip(reason="Not implemented")
def test_update_memory(memory_store):
def test_update_memory(memory_client):
data = "Name is John Doe."
memory_id = memory_store.create(data=data)
memory_client.add([{"role": "user", "content": data}], user_id="test_user")
new_data = "Name is John Kapoor."
updated_memory = memory_store.update(memory_id, new_data)
assert updated_memory == new_data
assert memory_store.get(memory_id) == new_data
update_result = memory_client.update("1", text=new_data)
assert update_result["message"] == "Memory updated successfully!"
@pytest.mark.skip(reason="Not implemented")
def test_delete_memory(memory_store):
def test_delete_memory(memory_client):
data = "Name is John Doe."
memory_id = memory_store.create(data=data)
memory_store.delete(memory_id)
assert memory_store.get(memory_id) is None
memory_client.add([{"role": "user", "content": data}], user_id="test_user")
delete_result = memory_client.delete("1")
assert delete_result["message"] == "Memory deleted successfully!"
@pytest.mark.skip(reason="Not implemented")
def test_history(memory_store):
def test_history(memory_client):
data = "I like Indian food."
memory_id = memory_store.create(data=data)
history = memory_store.history(memory_id)
assert history == [data]
assert memory_store.get(memory_id) == data
new_data = "I like Italian food."
memory_store.update(memory_id, new_data)
history = memory_store.history(memory_id)
assert history == [data, new_data]
assert memory_store.get(memory_id) == new_data
memory_client.add([{"role": "user", "content": data}], user_id="test_user")
memory_client.update("1", text="I like Italian food.")
history = memory_client.history("1")
assert history[0]["memory"] == "I like Indian food."
assert history[1]["memory"] == "I like Italian food."
@pytest.mark.skip(reason="Not implemented")
def test_list_memories(memory_store):
def test_list_memories(memory_client):
data1 = "Name is John Doe."
data2 = "Name is John Doe. I like to code in Python."
memory_store.create(data=data1)
memory_store.create(data=data2)
memories = memory_store.list()
memory_client.add([{"role": "user", "content": data1}], user_id="test_user")
memory_client.add([{"role": "user", "content": data2}], user_id="test_user")
memories = memory_client.get_all(user_id="test_user")
assert data1 in memories
assert data2 in memories

View File

@@ -14,35 +14,37 @@ def test_memory_configuration_without_env_vars():
"model": "gpt-4",
"temperature": 0.1,
"max_tokens": 1500,
}
},
},
"vector_store": {
"provider": "chroma",
"config": {
"collection_name": "test_collection",
"path": "./test_db",
}
},
},
"embedder": {
"provider": "openai",
"config": {
"model": "text-embedding-ada-002",
}
}
},
},
}
# Test messages similar to the main.py file
test_messages = [
{"role": "user", "content": "Hi, I'm Alex. I'm a vegetarian and I'm allergic to nuts."},
{"role": "assistant", "content": "Hello Alex! I've noted that you're a vegetarian and have a nut allergy. I'll keep this in mind for any food-related recommendations or discussions."}
{
"role": "assistant",
"content": "Hello Alex! I've noted that you're a vegetarian and have a nut allergy. I'll keep this in mind for any food-related recommendations or discussions.",
},
]
# Mock the Memory class methods to avoid actual API calls
with patch.object(Memory, '__init__', return_value=None):
with patch.object(Memory, 'from_config') as mock_from_config:
with patch.object(Memory, 'add') as mock_add:
with patch.object(Memory, 'get_all') as mock_get_all:
with patch.object(Memory, "__init__", return_value=None):
with patch.object(Memory, "from_config") as mock_from_config:
with patch.object(Memory, "add") as mock_add:
with patch.object(Memory, "get_all") as mock_get_all:
# Configure mocks
mock_memory_instance = MagicMock()
mock_from_config.return_value = mock_memory_instance
@@ -50,13 +52,13 @@ def test_memory_configuration_without_env_vars():
mock_add.return_value = {
"results": [
{"id": "1", "text": "Alex is a vegetarian"},
{"id": "2", "text": "Alex is allergic to nuts"}
{"id": "2", "text": "Alex is allergic to nuts"},
]
}
mock_get_all.return_value = [
{"id": "1", "text": "Alex is a vegetarian", "metadata": {"category": "dietary_preferences"}},
{"id": "2", "text": "Alex is allergic to nuts", "metadata": {"category": "allergies"}}
{"id": "2", "text": "Alex is allergic to nuts", "metadata": {"category": "allergies"}},
]
# Test the workflow
@@ -91,8 +93,8 @@ def test_azure_config_structure():
"api_version": "2023-12-01-preview",
"azure_endpoint": "https://test.openai.azure.com/",
"api_key": "test-key",
}
}
},
},
},
"vector_store": {
"provider": "azure_ai_search",
@@ -101,7 +103,7 @@ def test_azure_config_structure():
"api_key": "test-key",
"collection_name": "test-collection",
"embedding_model_dims": 1536,
}
},
},
"embedder": {
"provider": "azure_openai",
@@ -113,9 +115,9 @@ def test_azure_config_structure():
"azure_deployment": "test-embedding-deployment",
"azure_endpoint": "https://test.openai.azure.com/",
"api_key": "test-key",
}
}
}
},
},
},
}
# Validate configuration structure
@@ -141,7 +143,10 @@ def test_memory_messages_format():
# Test message format from main.py
messages = [
{"role": "user", "content": "Hi, I'm Alex. I'm a vegetarian and I'm allergic to nuts."},
{"role": "assistant", "content": "Hello Alex! I've noted that you're a vegetarian and have a nut allergy. I'll keep this in mind for any food-related recommendations or discussions."}
{
"role": "assistant",
"content": "Hello Alex! I've noted that you're a vegetarian and have a nut allergy. I'll keep this in mind for any food-related recommendations or discussions.",
},
]
# Validate message structure