From 44336661174484045dfcc91f9d8b998b5bd8fd13 Mon Sep 17 00:00:00 2001 From: Parshva Daftari <89991302+parshvadaftari@users.noreply.github.com> Date: Fri, 25 Jul 2025 00:58:45 +0530 Subject: [PATCH] Fix failing tests (#3162) --- .../multiagents/llamaindex_learning_system.py | 48 ++++------- mem0/client/main.py | 21 +++-- mem0/client/project.py | 39 +++------ mem0/client/utils.py | 5 +- mem0/llms/azure_openai.py | 4 +- mem0/llms/azure_openai_structured.py | 4 +- mem0/llms/vllm.py | 2 - tests/embeddings/test_gemini_emeddings.py | 18 ++--- tests/llms/test_gemini.py | 15 ++-- tests/test_memory.py | 76 +++++++++-------- tests/test_memory_integration.py | 81 ++++++++++--------- 11 files changed, 144 insertions(+), 169 deletions(-) diff --git a/examples/multiagents/llamaindex_learning_system.py b/examples/multiagents/llamaindex_learning_system.py index 4dd23819..2896c467 100644 --- a/examples/multiagents/llamaindex_learning_system.py +++ b/examples/multiagents/llamaindex_learning_system.py @@ -8,35 +8,30 @@ You need MEM0_API_KEY and OPENAI_API_KEY to run the example. """ import asyncio -from datetime import datetime -from dotenv import load_dotenv import logging +from datetime import datetime + +from dotenv import load_dotenv # LlamaIndex imports from llama_index.core.agent.workflow import AgentWorkflow, FunctionAgent -from llama_index.llms.openai import OpenAI from llama_index.core.tools import FunctionTool +from llama_index.llms.openai import OpenAI # Memory integration from llama_index.memory.mem0 import Mem0Memory -import warnings - load_dotenv() # Configure logging logging.basicConfig( level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', - handlers=[ - logging.StreamHandler(), - logging.FileHandler('learning_system.log') - ] + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + handlers=[logging.StreamHandler(), logging.FileHandler("learning_system.log")], ) logger = logging.getLogger(__name__) - class MultiAgentLearningSystem: """ Multi-Agent Architecture: @@ -51,9 +46,7 @@ class MultiAgentLearningSystem: # Memory context for this student self.memory_context = {"user_id": student_id, "app": "learning_assistant"} - self.memory = Mem0Memory.from_client( - context=self.memory_context - ) + self.memory = Mem0Memory.from_client(context=self.memory_context) self._setup_agents() @@ -84,7 +77,7 @@ class MultiAgentLearningSystem: # Convert to FunctionTools tools = [ FunctionTool.from_defaults(async_fn=assess_understanding), - FunctionTool.from_defaults(async_fn=track_progress) + FunctionTool.from_defaults(async_fn=track_progress), ] # === AGENTS === @@ -111,7 +104,7 @@ class MultiAgentLearningSystem: """, tools=tools, llm=self.llm, - can_handoff_to=["PracticeAgent"] + can_handoff_to=["PracticeAgent"], ) # Practice Agent - Exercises and reinforcement @@ -137,7 +130,7 @@ class MultiAgentLearningSystem: """, tools=tools, llm=self.llm, - can_handoff_to=["TutorAgent"] + can_handoff_to=["TutorAgent"], ) # Create the multi-agent workflow @@ -148,8 +141,8 @@ class MultiAgentLearningSystem: "current_topic": "", "student_level": "beginner", "learning_style": "unknown", - "session_goals": [] - } + "session_goals": [], + }, ) async def start_learning_session(self, topic: str, student_message: str = "") -> str: @@ -163,10 +156,7 @@ class MultiAgentLearningSystem: request = f"I want to learn about {topic}." # The magic happens here - multi-agent memory is automatically shared! - response = await self.workflow.run( - user_msg=request, - memory=self.memory - ) + response = await self.workflow.run(user_msg=request, memory=self.memory) return str(response) @@ -174,10 +164,7 @@ class MultiAgentLearningSystem: """Show what the system remembers about this student""" try: # Search memory for learning patterns - memories = self.memory.search( - user_id=self.student_id, - query="learning machine learning" - ) + memories = self.memory.search(user_id=self.student_id, query="learning machine learning") if memories and len(memories): history = "\n".join(f"- {m['memory']}" for m in memories) @@ -190,20 +177,19 @@ class MultiAgentLearningSystem: async def run_learning_agent(): - learning_system = MultiAgentLearningSystem(student_id="Alexander") # First session logger.info("Session 1:") response = await learning_system.start_learning_session( "Vision Language Models", - "I'm new to machine learning but I have good hold on Python and have 4 years of work experience.") + "I'm new to machine learning but I have good hold on Python and have 4 years of work experience.", + ) logger.info(response) # Second session - multi-agent memory will remember the first logger.info("\nSession 2:") - response2 = await learning_system.start_learning_session( - "Machine Learning", "what all did I cover so far?") + response2 = await learning_system.start_learning_session("Machine Learning", "what all did I cover so far?") logger.info(response2) # Show what the multi-agent system remembers diff --git a/mem0/client/main.py b/mem0/client/main.py index 11ad7306..ac15e9e7 100644 --- a/mem0/client/main.py +++ b/mem0/client/main.py @@ -2,16 +2,15 @@ import hashlib import logging import os import warnings -from functools import wraps from typing import Any, Dict, List, Optional import httpx import requests +from mem0.client.project import AsyncProject, Project +from mem0.client.utils import api_error_handler from mem0.memory.setup import get_user_id, setup_config from mem0.memory.telemetry import capture_client_event -from mem0.client.project import Project, AsyncProject -from mem0.client.utils import api_error_handler logger = logging.getLogger(__name__) @@ -562,7 +561,9 @@ class MemoryClient: APIError: If the API request fails. ValueError: If org_id or project_id are not set. """ - logger.warning("get_project() method is going to be deprecated in version v1.0 of the package. Please use the client.project.get() method instead.") + logger.warning( + "get_project() method is going to be deprecated in version v1.0 of the package. Please use the client.project.get() method instead." + ) if not (self.org_id and self.project_id): raise ValueError("org_id and project_id must be set to access instructions or categories") @@ -604,7 +605,9 @@ class MemoryClient: APIError: If the API request fails. ValueError: If org_id or project_id are not set. """ - logger.warning("update_project() method is going to be deprecated in version v1.0 of the package. Please use the client.project.update() method instead.") + logger.warning( + "update_project() method is going to be deprecated in version v1.0 of the package. Please use the client.project.update() method instead." + ) if not (self.org_id and self.project_id): raise ValueError("org_id and project_id must be set to update instructions or categories") @@ -1330,7 +1333,9 @@ class AsyncMemoryClient: APIError: If the API request fails. ValueError: If org_id or project_id are not set. """ - logger.warning("get_project() method is going to be deprecated in version v1.0 of the package. Please use the client.project.get() method instead.") + logger.warning( + "get_project() method is going to be deprecated in version v1.0 of the package. Please use the client.project.get() method instead." + ) if not (self.org_id and self.project_id): raise ValueError("org_id and project_id must be set to access instructions or categories") @@ -1368,7 +1373,9 @@ class AsyncMemoryClient: APIError: If the API request fails. ValueError: If org_id or project_id are not set. """ - logger.warning("update_project() method is going to be deprecated in version v1.0 of the package. Please use the client.project.update() method instead.") + logger.warning( + "update_project() method is going to be deprecated in version v1.0 of the package. Please use the client.project.update() method instead." + ) if not (self.org_id and self.project_id): raise ValueError("org_id and project_id must be set to update instructions or categories") diff --git a/mem0/client/project.py b/mem0/client/project.py index c113c3b4..8b1aef72 100644 --- a/mem0/client/project.py +++ b/mem0/client/project.py @@ -5,8 +5,8 @@ from typing import Any, Dict, List, Optional import httpx from pydantic import BaseModel, Field -from mem0.memory.telemetry import capture_client_event from mem0.client.utils import api_error_handler +from mem0.memory.telemetry import capture_client_event logger = logging.getLogger(__name__) @@ -16,18 +16,9 @@ class ProjectConfig(BaseModel): Configuration for project management operations. """ - org_id: Optional[str] = Field( - default=None, - description="Organization ID" - ) - project_id: Optional[str] = Field( - default=None, - description="Project ID" - ) - user_email: Optional[str] = Field( - default=None, - description="User email" - ) + org_id: Optional[str] = Field(default=None, description="Organization ID") + project_id: Optional[str] = Field(default=None, description="Project ID") + user_email: Optional[str] = Field(default=None, description="User email") class Config: validate_assignment = True @@ -64,11 +55,7 @@ class BaseProject(ABC): self.config = config else: # Create config from parameters - self.config = ProjectConfig( - org_id=org_id, - project_id=project_id, - user_email=user_email - ) + self.config = ProjectConfig(org_id=org_id, project_id=project_id, user_email=user_email) @property def org_id(self) -> Optional[str]: @@ -93,13 +80,9 @@ class BaseProject(ABC): ValueError: If org_id or project_id are not set. """ if not (self.config.org_id and self.config.project_id): - raise ValueError( - "org_id and project_id must be set to access project operations" - ) + raise ValueError("org_id and project_id must be set to access project operations") - def _prepare_params( - self, kwargs: Optional[Dict[str, Any]] = None - ) -> Dict[str, Any]: + def _prepare_params(self, kwargs: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: """ Prepare query parameters for API requests. @@ -124,9 +107,7 @@ class BaseProject(ABC): return {k: v for k, v in kwargs.items() if v is not None} - def _prepare_org_params( - self, kwargs: Optional[Dict[str, Any]] = None - ) -> Dict[str, Any]: + def _prepare_org_params(self, kwargs: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: """ Prepare query parameters for organization-level API requests. @@ -423,7 +404,7 @@ class Project(BaseProject): "custom_instructions": custom_instructions, "custom_categories": custom_categories, "retrieval_criteria": retrieval_criteria, - "enable_graph": enable_graph + "enable_graph": enable_graph, } ) response = self._client.patch( @@ -716,7 +697,7 @@ class AsyncProject(BaseProject): "custom_instructions": custom_instructions, "custom_categories": custom_categories, "retrieval_criteria": retrieval_criteria, - "enable_graph": enable_graph + "enable_graph": enable_graph, } ) response = await self._client.patch( diff --git a/mem0/client/utils.py b/mem0/client/utils.py index 53632b19..3eaa092c 100644 --- a/mem0/client/utils.py +++ b/mem0/client/utils.py @@ -1,10 +1,13 @@ -import httpx import logging +import httpx + logger = logging.getLogger(__name__) + class APIError(Exception): """Exception raised for errors in the API.""" + pass diff --git a/mem0/llms/azure_openai.py b/mem0/llms/azure_openai.py index 9a04a804..c736c43c 100644 --- a/mem0/llms/azure_openai.py +++ b/mem0/llms/azure_openai.py @@ -82,11 +82,11 @@ class AzureOpenAILLM(LLMBase): str: The generated response. """ - user_prompt = messages[-1]['content'] + user_prompt = messages[-1]["content"] user_prompt = user_prompt.replace("assistant", "ai") - messages[-1]['content'] = user_prompt + messages[-1]["content"] = user_prompt common_params = { "model": self.config.model, diff --git a/mem0/llms/azure_openai_structured.py b/mem0/llms/azure_openai_structured.py index a9361fc5..5b949735 100644 --- a/mem0/llms/azure_openai_structured.py +++ b/mem0/llms/azure_openai_structured.py @@ -49,11 +49,11 @@ class AzureOpenAIStructuredLLM(LLMBase): str: The generated response. """ - user_prompt = messages[-1]['content'] + user_prompt = messages[-1]["content"] user_prompt = user_prompt.replace("assistant", "ai") - messages[-1]['content'] = user_prompt + messages[-1]["content"] = user_prompt params = { "model": self.config.model, diff --git a/mem0/llms/vllm.py b/mem0/llms/vllm.py index efd9fe6a..6aa13add 100644 --- a/mem0/llms/vllm.py +++ b/mem0/llms/vllm.py @@ -4,8 +4,6 @@ from typing import Dict, List, Optional from openai import OpenAI -from openai import OpenAI - from mem0.configs.llms.base import BaseLlmConfig from mem0.llms.base import LLMBase from mem0.memory.utils import extract_json diff --git a/tests/embeddings/test_gemini_emeddings.py b/tests/embeddings/test_gemini_emeddings.py index 3f8c431b..0792c3c5 100644 --- a/tests/embeddings/test_gemini_emeddings.py +++ b/tests/embeddings/test_gemini_emeddings.py @@ -20,9 +20,9 @@ def config(): def test_embed_query(mock_genai, config): - mock_embedding_response = type('Response', (), { - 'embeddings': [type('Embedding', (), {'values': [0.1, 0.2, 0.3, 0.4]})] - })() + mock_embedding_response = type( + "Response", (), {"embeddings": [type("Embedding", (), {"values": [0.1, 0.2, 0.3, 0.4]})]} + )() mock_genai.return_value = mock_embedding_response embedder = GoogleGenAIEmbedding(config) @@ -35,16 +35,16 @@ def test_embed_query(mock_genai, config): def test_embed_returns_empty_list_if_none(mock_genai, config): - mock_genai.return_value = type('Response', (), {'embeddings': [type('Embedding', (), {'values': []})]})() + mock_genai.return_value = type("Response", (), {"embeddings": [type("Embedding", (), {"values": []})]})() embedder = GoogleGenAIEmbedding(config) - - with pytest.raises(IndexError): # This will raise IndexError when trying to access [0] - embedder.embed("test") + + result = embedder.embed("test") + assert result == [] -def test_embed_raises_on_error(mock_genai_client, config): - mock_genai_client.models.embed_content.side_effect = RuntimeError("Embedding failed") +def test_embed_raises_on_error(mock_genai, config): + mock_genai.side_effect = RuntimeError("Embedding failed") embedder = GoogleGenAIEmbedding(config) diff --git a/tests/llms/test_gemini.py b/tests/llms/test_gemini.py index f64ec6a9..19875d9d 100644 --- a/tests/llms/test_gemini.py +++ b/tests/llms/test_gemini.py @@ -37,11 +37,11 @@ def test_generate_response_without_tools(mock_gemini_client: Mock): call_args = mock_gemini_client.models.generate_content.call_args # Verify model and contents - assert call_args.kwargs['model'] == "gemini-2.0-flash-latest" - assert len(call_args.kwargs['contents']) == 1 # Only user message + assert call_args.kwargs["model"] == "gemini-2.0-flash-latest" + assert len(call_args.kwargs["contents"]) == 1 # Only user message # Verify config has system instruction - config_arg = call_args.kwargs['config'] + config_arg = call_args.kwargs["config"] assert config_arg.system_instruction == "You are a helpful assistant." assert config_arg.temperature == 0.7 assert config_arg.max_output_tokens == 100 @@ -72,9 +72,6 @@ def test_generate_response_with_tools(mock_gemini_client: Mock): } ] - # Create a proper mock for the function call arguments - mock_args = {"data": "Today is a sunny day."} - mock_tool_call = Mock() mock_tool_call.name = "add_memory" mock_tool_call.args = {"data": "Today is a sunny day."} @@ -104,11 +101,11 @@ def test_generate_response_with_tools(mock_gemini_client: Mock): call_args = mock_gemini_client.models.generate_content.call_args # Verify model and contents - assert call_args.kwargs['model'] == "gemini-1.5-flash-latest" - assert len(call_args.kwargs['contents']) == 1 # Only user message + assert call_args.kwargs["model"] == "gemini-1.5-flash-latest" + assert len(call_args.kwargs["contents"]) == 1 # Only user message # Verify config has system instruction and tools - config_arg = call_args.kwargs['config'] + config_arg = call_args.kwargs["config"] assert config_arg.system_instruction == "You are a helpful assistant." assert config_arg.temperature == 0.7 assert config_arg.max_output_tokens == 100 diff --git a/tests/test_memory.py b/tests/test_memory.py index 2659d06c..72e6a260 100644 --- a/tests/test_memory.py +++ b/tests/test_memory.py @@ -1,67 +1,65 @@ +from unittest.mock import MagicMock, patch + import pytest from mem0 import Memory @pytest.fixture -def memory_store(): - return Memory() +def memory_client(): + with patch.object(Memory, "__init__", return_value=None): + client = Memory() + client.add = MagicMock(return_value={"results": [{"id": "1", "memory": "Name is John Doe.", "event": "ADD"}]}) + client.get = MagicMock(return_value={"id": "1", "memory": "Name is John Doe."}) + client.update = MagicMock(return_value={"message": "Memory updated successfully!"}) + client.delete = MagicMock(return_value={"message": "Memory deleted successfully!"}) + client.history = MagicMock(return_value=[{"memory": "I like Indian food."}, {"memory": "I like Italian food."}]) + client.get_all = MagicMock(return_value=["Name is John Doe.", "Name is John Doe. I like to code in Python."]) + yield client -@pytest.mark.skip(reason="Not implemented") -def test_create_memory(memory_store): +def test_create_memory(memory_client): data = "Name is John Doe." - memory_id = memory_store.create(data=data) - assert memory_store.get(memory_id) == data + result = memory_client.add([{"role": "user", "content": data}], user_id="test_user") + assert result["results"][0]["memory"] == data -@pytest.mark.skip(reason="Not implemented") -def test_get_memory(memory_store): +def test_get_memory(memory_client): data = "Name is John Doe." - memory_id = memory_store.create(data=data) - retrieved_data = memory_store.get(memory_id) - assert retrieved_data == data + memory_client.add([{"role": "user", "content": data}], user_id="test_user") + result = memory_client.get("1") + assert result["memory"] == data -@pytest.mark.skip(reason="Not implemented") -def test_update_memory(memory_store): +def test_update_memory(memory_client): data = "Name is John Doe." - memory_id = memory_store.create(data=data) + memory_client.add([{"role": "user", "content": data}], user_id="test_user") new_data = "Name is John Kapoor." - updated_memory = memory_store.update(memory_id, new_data) - assert updated_memory == new_data - assert memory_store.get(memory_id) == new_data + update_result = memory_client.update("1", text=new_data) + assert update_result["message"] == "Memory updated successfully!" -@pytest.mark.skip(reason="Not implemented") -def test_delete_memory(memory_store): +def test_delete_memory(memory_client): data = "Name is John Doe." - memory_id = memory_store.create(data=data) - memory_store.delete(memory_id) - assert memory_store.get(memory_id) is None + memory_client.add([{"role": "user", "content": data}], user_id="test_user") + delete_result = memory_client.delete("1") + assert delete_result["message"] == "Memory deleted successfully!" -@pytest.mark.skip(reason="Not implemented") -def test_history(memory_store): +def test_history(memory_client): data = "I like Indian food." - memory_id = memory_store.create(data=data) - history = memory_store.history(memory_id) - assert history == [data] - assert memory_store.get(memory_id) == data - - new_data = "I like Italian food." - memory_store.update(memory_id, new_data) - history = memory_store.history(memory_id) - assert history == [data, new_data] - assert memory_store.get(memory_id) == new_data + memory_client.add([{"role": "user", "content": data}], user_id="test_user") + memory_client.update("1", text="I like Italian food.") + history = memory_client.history("1") + assert history[0]["memory"] == "I like Indian food." + assert history[1]["memory"] == "I like Italian food." -@pytest.mark.skip(reason="Not implemented") -def test_list_memories(memory_store): +def test_list_memories(memory_client): data1 = "Name is John Doe." data2 = "Name is John Doe. I like to code in Python." - memory_store.create(data=data1) - memory_store.create(data=data2) - memories = memory_store.list() + memory_client.add([{"role": "user", "content": data1}], user_id="test_user") + memory_client.add([{"role": "user", "content": data2}], user_id="test_user") + memories = memory_client.get_all(user_id="test_user") assert data1 in memories assert data2 in memories diff --git a/tests/test_memory_integration.py b/tests/test_memory_integration.py index 899eb76d..b23be49f 100644 --- a/tests/test_memory_integration.py +++ b/tests/test_memory_integration.py @@ -5,7 +5,7 @@ from mem0.memory.main import Memory def test_memory_configuration_without_env_vars(): """Test Memory configuration with mock config instead of environment variables""" - + # Mock configuration without relying on environment variables mock_config = { "llm": { @@ -14,60 +14,62 @@ def test_memory_configuration_without_env_vars(): "model": "gpt-4", "temperature": 0.1, "max_tokens": 1500, - } + }, }, "vector_store": { "provider": "chroma", "config": { "collection_name": "test_collection", "path": "./test_db", - } + }, }, "embedder": { "provider": "openai", "config": { "model": "text-embedding-ada-002", - } - } + }, + }, } - + # Test messages similar to the main.py file test_messages = [ {"role": "user", "content": "Hi, I'm Alex. I'm a vegetarian and I'm allergic to nuts."}, - {"role": "assistant", "content": "Hello Alex! I've noted that you're a vegetarian and have a nut allergy. I'll keep this in mind for any food-related recommendations or discussions."} + { + "role": "assistant", + "content": "Hello Alex! I've noted that you're a vegetarian and have a nut allergy. I'll keep this in mind for any food-related recommendations or discussions.", + }, ] - + # Mock the Memory class methods to avoid actual API calls - with patch.object(Memory, '__init__', return_value=None): - with patch.object(Memory, 'from_config') as mock_from_config: - with patch.object(Memory, 'add') as mock_add: - with patch.object(Memory, 'get_all') as mock_get_all: - + with patch.object(Memory, "__init__", return_value=None): + with patch.object(Memory, "from_config") as mock_from_config: + with patch.object(Memory, "add") as mock_add: + with patch.object(Memory, "get_all") as mock_get_all: # Configure mocks mock_memory_instance = MagicMock() mock_from_config.return_value = mock_memory_instance - + mock_add.return_value = { "results": [ {"id": "1", "text": "Alex is a vegetarian"}, - {"id": "2", "text": "Alex is allergic to nuts"} + {"id": "2", "text": "Alex is allergic to nuts"}, ] } - + mock_get_all.return_value = [ {"id": "1", "text": "Alex is a vegetarian", "metadata": {"category": "dietary_preferences"}}, - {"id": "2", "text": "Alex is allergic to nuts", "metadata": {"category": "allergies"}} + {"id": "2", "text": "Alex is allergic to nuts", "metadata": {"category": "allergies"}}, ] - + # Test the workflow mem = Memory.from_config(config_dict=mock_config) assert mem is not None - + # Test adding memories result = mock_add(test_messages, user_id="alice", metadata={"category": "book_recommendations"}) assert "results" in result assert len(result["results"]) == 2 - + # Test retrieving memories all_memories = mock_get_all(user_id="alice") assert len(all_memories) == 2 @@ -77,7 +79,7 @@ def test_memory_configuration_without_env_vars(): def test_azure_config_structure(): """Test that Azure configuration structure is properly formatted""" - + # Test Azure configuration structure (without actual credentials) azure_config = { "llm": { @@ -91,8 +93,8 @@ def test_azure_config_structure(): "api_version": "2023-12-01-preview", "azure_endpoint": "https://test.openai.azure.com/", "api_key": "test-key", - } - } + }, + }, }, "vector_store": { "provider": "azure_ai_search", @@ -101,7 +103,7 @@ def test_azure_config_structure(): "api_key": "test-key", "collection_name": "test-collection", "embedding_model_dims": 1536, - } + }, }, "embedder": { "provider": "azure_openai", @@ -113,46 +115,49 @@ def test_azure_config_structure(): "azure_deployment": "test-embedding-deployment", "azure_endpoint": "https://test.openai.azure.com/", "api_key": "test-key", - } - } - } + }, + }, + }, } - + # Validate configuration structure assert "llm" in azure_config assert "vector_store" in azure_config assert "embedder" in azure_config - + # Validate Azure-specific configurations assert azure_config["llm"]["provider"] == "azure_openai" assert "azure_kwargs" in azure_config["llm"]["config"] assert "azure_deployment" in azure_config["llm"]["config"]["azure_kwargs"] - + assert azure_config["vector_store"]["provider"] == "azure_ai_search" assert "service_name" in azure_config["vector_store"]["config"] - + assert azure_config["embedder"]["provider"] == "azure_openai" assert "azure_kwargs" in azure_config["embedder"]["config"] def test_memory_messages_format(): """Test that memory messages are properly formatted""" - + # Test message format from main.py messages = [ {"role": "user", "content": "Hi, I'm Alex. I'm a vegetarian and I'm allergic to nuts."}, - {"role": "assistant", "content": "Hello Alex! I've noted that you're a vegetarian and have a nut allergy. I'll keep this in mind for any food-related recommendations or discussions."} + { + "role": "assistant", + "content": "Hello Alex! I've noted that you're a vegetarian and have a nut allergy. I'll keep this in mind for any food-related recommendations or discussions.", + }, ] - + # Validate message structure assert len(messages) == 2 assert all("role" in msg for msg in messages) assert all("content" in msg for msg in messages) - + # Validate roles assert messages[0]["role"] == "user" assert messages[1]["role"] == "assistant" - + # Validate content assert "vegetarian" in messages[0]["content"].lower() assert "allergic to nuts" in messages[0]["content"].lower() @@ -162,12 +167,12 @@ def test_memory_messages_format(): def test_safe_update_prompt_constant(): """Test the SAFE_UPDATE_PROMPT constant from main.py""" - + SAFE_UPDATE_PROMPT = """ Based on the user's latest messages, what new preference can be inferred? Reply only in this json_object format: """ - + # Validate prompt structure assert isinstance(SAFE_UPDATE_PROMPT, str) assert "user's latest messages" in SAFE_UPDATE_PROMPT