Fix failing tests (#3162)

This commit is contained in:
Parshva Daftari
2025-07-25 00:58:45 +05:30
committed by GitHub
parent 37ee3c5eb2
commit 4433666117
11 changed files with 144 additions and 169 deletions

View File

@@ -8,35 +8,30 @@ You need MEM0_API_KEY and OPENAI_API_KEY to run the example.
""" """
import asyncio import asyncio
from datetime import datetime
from dotenv import load_dotenv
import logging import logging
from datetime import datetime
from dotenv import load_dotenv
# LlamaIndex imports # LlamaIndex imports
from llama_index.core.agent.workflow import AgentWorkflow, FunctionAgent from llama_index.core.agent.workflow import AgentWorkflow, FunctionAgent
from llama_index.llms.openai import OpenAI
from llama_index.core.tools import FunctionTool from llama_index.core.tools import FunctionTool
from llama_index.llms.openai import OpenAI
# Memory integration # Memory integration
from llama_index.memory.mem0 import Mem0Memory from llama_index.memory.mem0 import Mem0Memory
import warnings
load_dotenv() load_dotenv()
# Configure logging # Configure logging
logging.basicConfig( logging.basicConfig(
level=logging.INFO, level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
handlers=[ handlers=[logging.StreamHandler(), logging.FileHandler("learning_system.log")],
logging.StreamHandler(),
logging.FileHandler('learning_system.log')
]
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class MultiAgentLearningSystem: class MultiAgentLearningSystem:
""" """
Multi-Agent Architecture: Multi-Agent Architecture:
@@ -51,9 +46,7 @@ class MultiAgentLearningSystem:
# Memory context for this student # Memory context for this student
self.memory_context = {"user_id": student_id, "app": "learning_assistant"} self.memory_context = {"user_id": student_id, "app": "learning_assistant"}
self.memory = Mem0Memory.from_client( self.memory = Mem0Memory.from_client(context=self.memory_context)
context=self.memory_context
)
self._setup_agents() self._setup_agents()
@@ -84,7 +77,7 @@ class MultiAgentLearningSystem:
# Convert to FunctionTools # Convert to FunctionTools
tools = [ tools = [
FunctionTool.from_defaults(async_fn=assess_understanding), FunctionTool.from_defaults(async_fn=assess_understanding),
FunctionTool.from_defaults(async_fn=track_progress) FunctionTool.from_defaults(async_fn=track_progress),
] ]
# === AGENTS === # === AGENTS ===
@@ -111,7 +104,7 @@ class MultiAgentLearningSystem:
""", """,
tools=tools, tools=tools,
llm=self.llm, llm=self.llm,
can_handoff_to=["PracticeAgent"] can_handoff_to=["PracticeAgent"],
) )
# Practice Agent - Exercises and reinforcement # Practice Agent - Exercises and reinforcement
@@ -137,7 +130,7 @@ class MultiAgentLearningSystem:
""", """,
tools=tools, tools=tools,
llm=self.llm, llm=self.llm,
can_handoff_to=["TutorAgent"] can_handoff_to=["TutorAgent"],
) )
# Create the multi-agent workflow # Create the multi-agent workflow
@@ -148,8 +141,8 @@ class MultiAgentLearningSystem:
"current_topic": "", "current_topic": "",
"student_level": "beginner", "student_level": "beginner",
"learning_style": "unknown", "learning_style": "unknown",
"session_goals": [] "session_goals": [],
} },
) )
async def start_learning_session(self, topic: str, student_message: str = "") -> str: async def start_learning_session(self, topic: str, student_message: str = "") -> str:
@@ -163,10 +156,7 @@ class MultiAgentLearningSystem:
request = f"I want to learn about {topic}." request = f"I want to learn about {topic}."
# The magic happens here - multi-agent memory is automatically shared! # The magic happens here - multi-agent memory is automatically shared!
response = await self.workflow.run( response = await self.workflow.run(user_msg=request, memory=self.memory)
user_msg=request,
memory=self.memory
)
return str(response) return str(response)
@@ -174,10 +164,7 @@ class MultiAgentLearningSystem:
"""Show what the system remembers about this student""" """Show what the system remembers about this student"""
try: try:
# Search memory for learning patterns # Search memory for learning patterns
memories = self.memory.search( memories = self.memory.search(user_id=self.student_id, query="learning machine learning")
user_id=self.student_id,
query="learning machine learning"
)
if memories and len(memories): if memories and len(memories):
history = "\n".join(f"- {m['memory']}" for m in memories) history = "\n".join(f"- {m['memory']}" for m in memories)
@@ -190,20 +177,19 @@ class MultiAgentLearningSystem:
async def run_learning_agent(): async def run_learning_agent():
learning_system = MultiAgentLearningSystem(student_id="Alexander") learning_system = MultiAgentLearningSystem(student_id="Alexander")
# First session # First session
logger.info("Session 1:") logger.info("Session 1:")
response = await learning_system.start_learning_session( response = await learning_system.start_learning_session(
"Vision Language Models", "Vision Language Models",
"I'm new to machine learning but I have good hold on Python and have 4 years of work experience.") "I'm new to machine learning but I have good hold on Python and have 4 years of work experience.",
)
logger.info(response) logger.info(response)
# Second session - multi-agent memory will remember the first # Second session - multi-agent memory will remember the first
logger.info("\nSession 2:") logger.info("\nSession 2:")
response2 = await learning_system.start_learning_session( response2 = await learning_system.start_learning_session("Machine Learning", "what all did I cover so far?")
"Machine Learning", "what all did I cover so far?")
logger.info(response2) logger.info(response2)
# Show what the multi-agent system remembers # Show what the multi-agent system remembers

View File

@@ -2,16 +2,15 @@ import hashlib
import logging import logging
import os import os
import warnings import warnings
from functools import wraps
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
import httpx import httpx
import requests import requests
from mem0.client.project import AsyncProject, Project
from mem0.client.utils import api_error_handler
from mem0.memory.setup import get_user_id, setup_config from mem0.memory.setup import get_user_id, setup_config
from mem0.memory.telemetry import capture_client_event from mem0.memory.telemetry import capture_client_event
from mem0.client.project import Project, AsyncProject
from mem0.client.utils import api_error_handler
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -562,7 +561,9 @@ class MemoryClient:
APIError: If the API request fails. APIError: If the API request fails.
ValueError: If org_id or project_id are not set. ValueError: If org_id or project_id are not set.
""" """
logger.warning("get_project() method is going to be deprecated in version v1.0 of the package. Please use the client.project.get() method instead.") logger.warning(
"get_project() method is going to be deprecated in version v1.0 of the package. Please use the client.project.get() method instead."
)
if not (self.org_id and self.project_id): if not (self.org_id and self.project_id):
raise ValueError("org_id and project_id must be set to access instructions or categories") raise ValueError("org_id and project_id must be set to access instructions or categories")
@@ -604,7 +605,9 @@ class MemoryClient:
APIError: If the API request fails. APIError: If the API request fails.
ValueError: If org_id or project_id are not set. ValueError: If org_id or project_id are not set.
""" """
logger.warning("update_project() method is going to be deprecated in version v1.0 of the package. Please use the client.project.update() method instead.") logger.warning(
"update_project() method is going to be deprecated in version v1.0 of the package. Please use the client.project.update() method instead."
)
if not (self.org_id and self.project_id): if not (self.org_id and self.project_id):
raise ValueError("org_id and project_id must be set to update instructions or categories") raise ValueError("org_id and project_id must be set to update instructions or categories")
@@ -1330,7 +1333,9 @@ class AsyncMemoryClient:
APIError: If the API request fails. APIError: If the API request fails.
ValueError: If org_id or project_id are not set. ValueError: If org_id or project_id are not set.
""" """
logger.warning("get_project() method is going to be deprecated in version v1.0 of the package. Please use the client.project.get() method instead.") logger.warning(
"get_project() method is going to be deprecated in version v1.0 of the package. Please use the client.project.get() method instead."
)
if not (self.org_id and self.project_id): if not (self.org_id and self.project_id):
raise ValueError("org_id and project_id must be set to access instructions or categories") raise ValueError("org_id and project_id must be set to access instructions or categories")
@@ -1368,7 +1373,9 @@ class AsyncMemoryClient:
APIError: If the API request fails. APIError: If the API request fails.
ValueError: If org_id or project_id are not set. ValueError: If org_id or project_id are not set.
""" """
logger.warning("update_project() method is going to be deprecated in version v1.0 of the package. Please use the client.project.update() method instead.") logger.warning(
"update_project() method is going to be deprecated in version v1.0 of the package. Please use the client.project.update() method instead."
)
if not (self.org_id and self.project_id): if not (self.org_id and self.project_id):
raise ValueError("org_id and project_id must be set to update instructions or categories") raise ValueError("org_id and project_id must be set to update instructions or categories")

View File

@@ -5,8 +5,8 @@ from typing import Any, Dict, List, Optional
import httpx import httpx
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from mem0.memory.telemetry import capture_client_event
from mem0.client.utils import api_error_handler from mem0.client.utils import api_error_handler
from mem0.memory.telemetry import capture_client_event
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -16,18 +16,9 @@ class ProjectConfig(BaseModel):
Configuration for project management operations. Configuration for project management operations.
""" """
org_id: Optional[str] = Field( org_id: Optional[str] = Field(default=None, description="Organization ID")
default=None, project_id: Optional[str] = Field(default=None, description="Project ID")
description="Organization ID" user_email: Optional[str] = Field(default=None, description="User email")
)
project_id: Optional[str] = Field(
default=None,
description="Project ID"
)
user_email: Optional[str] = Field(
default=None,
description="User email"
)
class Config: class Config:
validate_assignment = True validate_assignment = True
@@ -64,11 +55,7 @@ class BaseProject(ABC):
self.config = config self.config = config
else: else:
# Create config from parameters # Create config from parameters
self.config = ProjectConfig( self.config = ProjectConfig(org_id=org_id, project_id=project_id, user_email=user_email)
org_id=org_id,
project_id=project_id,
user_email=user_email
)
@property @property
def org_id(self) -> Optional[str]: def org_id(self) -> Optional[str]:
@@ -93,13 +80,9 @@ class BaseProject(ABC):
ValueError: If org_id or project_id are not set. ValueError: If org_id or project_id are not set.
""" """
if not (self.config.org_id and self.config.project_id): if not (self.config.org_id and self.config.project_id):
raise ValueError( raise ValueError("org_id and project_id must be set to access project operations")
"org_id and project_id must be set to access project operations"
)
def _prepare_params( def _prepare_params(self, kwargs: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
self, kwargs: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
""" """
Prepare query parameters for API requests. Prepare query parameters for API requests.
@@ -124,9 +107,7 @@ class BaseProject(ABC):
return {k: v for k, v in kwargs.items() if v is not None} return {k: v for k, v in kwargs.items() if v is not None}
def _prepare_org_params( def _prepare_org_params(self, kwargs: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
self, kwargs: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
""" """
Prepare query parameters for organization-level API requests. Prepare query parameters for organization-level API requests.
@@ -423,7 +404,7 @@ class Project(BaseProject):
"custom_instructions": custom_instructions, "custom_instructions": custom_instructions,
"custom_categories": custom_categories, "custom_categories": custom_categories,
"retrieval_criteria": retrieval_criteria, "retrieval_criteria": retrieval_criteria,
"enable_graph": enable_graph "enable_graph": enable_graph,
} }
) )
response = self._client.patch( response = self._client.patch(
@@ -716,7 +697,7 @@ class AsyncProject(BaseProject):
"custom_instructions": custom_instructions, "custom_instructions": custom_instructions,
"custom_categories": custom_categories, "custom_categories": custom_categories,
"retrieval_criteria": retrieval_criteria, "retrieval_criteria": retrieval_criteria,
"enable_graph": enable_graph "enable_graph": enable_graph,
} }
) )
response = await self._client.patch( response = await self._client.patch(

View File

@@ -1,10 +1,13 @@
import httpx
import logging import logging
import httpx
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class APIError(Exception): class APIError(Exception):
"""Exception raised for errors in the API.""" """Exception raised for errors in the API."""
pass pass

View File

@@ -82,11 +82,11 @@ class AzureOpenAILLM(LLMBase):
str: The generated response. str: The generated response.
""" """
user_prompt = messages[-1]['content'] user_prompt = messages[-1]["content"]
user_prompt = user_prompt.replace("assistant", "ai") user_prompt = user_prompt.replace("assistant", "ai")
messages[-1]['content'] = user_prompt messages[-1]["content"] = user_prompt
common_params = { common_params = {
"model": self.config.model, "model": self.config.model,

View File

@@ -49,11 +49,11 @@ class AzureOpenAIStructuredLLM(LLMBase):
str: The generated response. str: The generated response.
""" """
user_prompt = messages[-1]['content'] user_prompt = messages[-1]["content"]
user_prompt = user_prompt.replace("assistant", "ai") user_prompt = user_prompt.replace("assistant", "ai")
messages[-1]['content'] = user_prompt messages[-1]["content"] = user_prompt
params = { params = {
"model": self.config.model, "model": self.config.model,

View File

@@ -4,8 +4,6 @@ from typing import Dict, List, Optional
from openai import OpenAI from openai import OpenAI
from openai import OpenAI
from mem0.configs.llms.base import BaseLlmConfig from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.base import LLMBase from mem0.llms.base import LLMBase
from mem0.memory.utils import extract_json from mem0.memory.utils import extract_json

View File

@@ -20,9 +20,9 @@ def config():
def test_embed_query(mock_genai, config): def test_embed_query(mock_genai, config):
mock_embedding_response = type('Response', (), { mock_embedding_response = type(
'embeddings': [type('Embedding', (), {'values': [0.1, 0.2, 0.3, 0.4]})] "Response", (), {"embeddings": [type("Embedding", (), {"values": [0.1, 0.2, 0.3, 0.4]})]}
})() )()
mock_genai.return_value = mock_embedding_response mock_genai.return_value = mock_embedding_response
embedder = GoogleGenAIEmbedding(config) embedder = GoogleGenAIEmbedding(config)
@@ -35,16 +35,16 @@ def test_embed_query(mock_genai, config):
def test_embed_returns_empty_list_if_none(mock_genai, config): def test_embed_returns_empty_list_if_none(mock_genai, config):
mock_genai.return_value = type('Response', (), {'embeddings': [type('Embedding', (), {'values': []})]})() mock_genai.return_value = type("Response", (), {"embeddings": [type("Embedding", (), {"values": []})]})()
embedder = GoogleGenAIEmbedding(config) embedder = GoogleGenAIEmbedding(config)
with pytest.raises(IndexError): # This will raise IndexError when trying to access [0] result = embedder.embed("test")
embedder.embed("test") assert result == []
def test_embed_raises_on_error(mock_genai_client, config): def test_embed_raises_on_error(mock_genai, config):
mock_genai_client.models.embed_content.side_effect = RuntimeError("Embedding failed") mock_genai.side_effect = RuntimeError("Embedding failed")
embedder = GoogleGenAIEmbedding(config) embedder = GoogleGenAIEmbedding(config)

View File

@@ -37,11 +37,11 @@ def test_generate_response_without_tools(mock_gemini_client: Mock):
call_args = mock_gemini_client.models.generate_content.call_args call_args = mock_gemini_client.models.generate_content.call_args
# Verify model and contents # Verify model and contents
assert call_args.kwargs['model'] == "gemini-2.0-flash-latest" assert call_args.kwargs["model"] == "gemini-2.0-flash-latest"
assert len(call_args.kwargs['contents']) == 1 # Only user message assert len(call_args.kwargs["contents"]) == 1 # Only user message
# Verify config has system instruction # Verify config has system instruction
config_arg = call_args.kwargs['config'] config_arg = call_args.kwargs["config"]
assert config_arg.system_instruction == "You are a helpful assistant." assert config_arg.system_instruction == "You are a helpful assistant."
assert config_arg.temperature == 0.7 assert config_arg.temperature == 0.7
assert config_arg.max_output_tokens == 100 assert config_arg.max_output_tokens == 100
@@ -72,9 +72,6 @@ def test_generate_response_with_tools(mock_gemini_client: Mock):
} }
] ]
# Create a proper mock for the function call arguments
mock_args = {"data": "Today is a sunny day."}
mock_tool_call = Mock() mock_tool_call = Mock()
mock_tool_call.name = "add_memory" mock_tool_call.name = "add_memory"
mock_tool_call.args = {"data": "Today is a sunny day."} mock_tool_call.args = {"data": "Today is a sunny day."}
@@ -104,11 +101,11 @@ def test_generate_response_with_tools(mock_gemini_client: Mock):
call_args = mock_gemini_client.models.generate_content.call_args call_args = mock_gemini_client.models.generate_content.call_args
# Verify model and contents # Verify model and contents
assert call_args.kwargs['model'] == "gemini-1.5-flash-latest" assert call_args.kwargs["model"] == "gemini-1.5-flash-latest"
assert len(call_args.kwargs['contents']) == 1 # Only user message assert len(call_args.kwargs["contents"]) == 1 # Only user message
# Verify config has system instruction and tools # Verify config has system instruction and tools
config_arg = call_args.kwargs['config'] config_arg = call_args.kwargs["config"]
assert config_arg.system_instruction == "You are a helpful assistant." assert config_arg.system_instruction == "You are a helpful assistant."
assert config_arg.temperature == 0.7 assert config_arg.temperature == 0.7
assert config_arg.max_output_tokens == 100 assert config_arg.max_output_tokens == 100

View File

@@ -1,67 +1,65 @@
from unittest.mock import MagicMock, patch
import pytest import pytest
from mem0 import Memory from mem0 import Memory
@pytest.fixture @pytest.fixture
def memory_store(): def memory_client():
return Memory() with patch.object(Memory, "__init__", return_value=None):
client = Memory()
client.add = MagicMock(return_value={"results": [{"id": "1", "memory": "Name is John Doe.", "event": "ADD"}]})
client.get = MagicMock(return_value={"id": "1", "memory": "Name is John Doe."})
client.update = MagicMock(return_value={"message": "Memory updated successfully!"})
client.delete = MagicMock(return_value={"message": "Memory deleted successfully!"})
client.history = MagicMock(return_value=[{"memory": "I like Indian food."}, {"memory": "I like Italian food."}])
client.get_all = MagicMock(return_value=["Name is John Doe.", "Name is John Doe. I like to code in Python."])
yield client
@pytest.mark.skip(reason="Not implemented") def test_create_memory(memory_client):
def test_create_memory(memory_store):
data = "Name is John Doe." data = "Name is John Doe."
memory_id = memory_store.create(data=data) result = memory_client.add([{"role": "user", "content": data}], user_id="test_user")
assert memory_store.get(memory_id) == data assert result["results"][0]["memory"] == data
@pytest.mark.skip(reason="Not implemented") def test_get_memory(memory_client):
def test_get_memory(memory_store):
data = "Name is John Doe." data = "Name is John Doe."
memory_id = memory_store.create(data=data) memory_client.add([{"role": "user", "content": data}], user_id="test_user")
retrieved_data = memory_store.get(memory_id) result = memory_client.get("1")
assert retrieved_data == data assert result["memory"] == data
@pytest.mark.skip(reason="Not implemented") def test_update_memory(memory_client):
def test_update_memory(memory_store):
data = "Name is John Doe." data = "Name is John Doe."
memory_id = memory_store.create(data=data) memory_client.add([{"role": "user", "content": data}], user_id="test_user")
new_data = "Name is John Kapoor." new_data = "Name is John Kapoor."
updated_memory = memory_store.update(memory_id, new_data) update_result = memory_client.update("1", text=new_data)
assert updated_memory == new_data assert update_result["message"] == "Memory updated successfully!"
assert memory_store.get(memory_id) == new_data
@pytest.mark.skip(reason="Not implemented") def test_delete_memory(memory_client):
def test_delete_memory(memory_store):
data = "Name is John Doe." data = "Name is John Doe."
memory_id = memory_store.create(data=data) memory_client.add([{"role": "user", "content": data}], user_id="test_user")
memory_store.delete(memory_id) delete_result = memory_client.delete("1")
assert memory_store.get(memory_id) is None assert delete_result["message"] == "Memory deleted successfully!"
@pytest.mark.skip(reason="Not implemented") def test_history(memory_client):
def test_history(memory_store):
data = "I like Indian food." data = "I like Indian food."
memory_id = memory_store.create(data=data) memory_client.add([{"role": "user", "content": data}], user_id="test_user")
history = memory_store.history(memory_id) memory_client.update("1", text="I like Italian food.")
assert history == [data] history = memory_client.history("1")
assert memory_store.get(memory_id) == data assert history[0]["memory"] == "I like Indian food."
assert history[1]["memory"] == "I like Italian food."
new_data = "I like Italian food."
memory_store.update(memory_id, new_data)
history = memory_store.history(memory_id)
assert history == [data, new_data]
assert memory_store.get(memory_id) == new_data
@pytest.mark.skip(reason="Not implemented") def test_list_memories(memory_client):
def test_list_memories(memory_store):
data1 = "Name is John Doe." data1 = "Name is John Doe."
data2 = "Name is John Doe. I like to code in Python." data2 = "Name is John Doe. I like to code in Python."
memory_store.create(data=data1) memory_client.add([{"role": "user", "content": data1}], user_id="test_user")
memory_store.create(data=data2) memory_client.add([{"role": "user", "content": data2}], user_id="test_user")
memories = memory_store.list() memories = memory_client.get_all(user_id="test_user")
assert data1 in memories assert data1 in memories
assert data2 in memories assert data2 in memories

View File

@@ -5,7 +5,7 @@ from mem0.memory.main import Memory
def test_memory_configuration_without_env_vars(): def test_memory_configuration_without_env_vars():
"""Test Memory configuration with mock config instead of environment variables""" """Test Memory configuration with mock config instead of environment variables"""
# Mock configuration without relying on environment variables # Mock configuration without relying on environment variables
mock_config = { mock_config = {
"llm": { "llm": {
@@ -14,60 +14,62 @@ def test_memory_configuration_without_env_vars():
"model": "gpt-4", "model": "gpt-4",
"temperature": 0.1, "temperature": 0.1,
"max_tokens": 1500, "max_tokens": 1500,
} },
}, },
"vector_store": { "vector_store": {
"provider": "chroma", "provider": "chroma",
"config": { "config": {
"collection_name": "test_collection", "collection_name": "test_collection",
"path": "./test_db", "path": "./test_db",
} },
}, },
"embedder": { "embedder": {
"provider": "openai", "provider": "openai",
"config": { "config": {
"model": "text-embedding-ada-002", "model": "text-embedding-ada-002",
} },
} },
} }
# Test messages similar to the main.py file # Test messages similar to the main.py file
test_messages = [ test_messages = [
{"role": "user", "content": "Hi, I'm Alex. I'm a vegetarian and I'm allergic to nuts."}, {"role": "user", "content": "Hi, I'm Alex. I'm a vegetarian and I'm allergic to nuts."},
{"role": "assistant", "content": "Hello Alex! I've noted that you're a vegetarian and have a nut allergy. I'll keep this in mind for any food-related recommendations or discussions."} {
"role": "assistant",
"content": "Hello Alex! I've noted that you're a vegetarian and have a nut allergy. I'll keep this in mind for any food-related recommendations or discussions.",
},
] ]
# Mock the Memory class methods to avoid actual API calls # Mock the Memory class methods to avoid actual API calls
with patch.object(Memory, '__init__', return_value=None): with patch.object(Memory, "__init__", return_value=None):
with patch.object(Memory, 'from_config') as mock_from_config: with patch.object(Memory, "from_config") as mock_from_config:
with patch.object(Memory, 'add') as mock_add: with patch.object(Memory, "add") as mock_add:
with patch.object(Memory, 'get_all') as mock_get_all: with patch.object(Memory, "get_all") as mock_get_all:
# Configure mocks # Configure mocks
mock_memory_instance = MagicMock() mock_memory_instance = MagicMock()
mock_from_config.return_value = mock_memory_instance mock_from_config.return_value = mock_memory_instance
mock_add.return_value = { mock_add.return_value = {
"results": [ "results": [
{"id": "1", "text": "Alex is a vegetarian"}, {"id": "1", "text": "Alex is a vegetarian"},
{"id": "2", "text": "Alex is allergic to nuts"} {"id": "2", "text": "Alex is allergic to nuts"},
] ]
} }
mock_get_all.return_value = [ mock_get_all.return_value = [
{"id": "1", "text": "Alex is a vegetarian", "metadata": {"category": "dietary_preferences"}}, {"id": "1", "text": "Alex is a vegetarian", "metadata": {"category": "dietary_preferences"}},
{"id": "2", "text": "Alex is allergic to nuts", "metadata": {"category": "allergies"}} {"id": "2", "text": "Alex is allergic to nuts", "metadata": {"category": "allergies"}},
] ]
# Test the workflow # Test the workflow
mem = Memory.from_config(config_dict=mock_config) mem = Memory.from_config(config_dict=mock_config)
assert mem is not None assert mem is not None
# Test adding memories # Test adding memories
result = mock_add(test_messages, user_id="alice", metadata={"category": "book_recommendations"}) result = mock_add(test_messages, user_id="alice", metadata={"category": "book_recommendations"})
assert "results" in result assert "results" in result
assert len(result["results"]) == 2 assert len(result["results"]) == 2
# Test retrieving memories # Test retrieving memories
all_memories = mock_get_all(user_id="alice") all_memories = mock_get_all(user_id="alice")
assert len(all_memories) == 2 assert len(all_memories) == 2
@@ -77,7 +79,7 @@ def test_memory_configuration_without_env_vars():
def test_azure_config_structure(): def test_azure_config_structure():
"""Test that Azure configuration structure is properly formatted""" """Test that Azure configuration structure is properly formatted"""
# Test Azure configuration structure (without actual credentials) # Test Azure configuration structure (without actual credentials)
azure_config = { azure_config = {
"llm": { "llm": {
@@ -91,8 +93,8 @@ def test_azure_config_structure():
"api_version": "2023-12-01-preview", "api_version": "2023-12-01-preview",
"azure_endpoint": "https://test.openai.azure.com/", "azure_endpoint": "https://test.openai.azure.com/",
"api_key": "test-key", "api_key": "test-key",
} },
} },
}, },
"vector_store": { "vector_store": {
"provider": "azure_ai_search", "provider": "azure_ai_search",
@@ -101,7 +103,7 @@ def test_azure_config_structure():
"api_key": "test-key", "api_key": "test-key",
"collection_name": "test-collection", "collection_name": "test-collection",
"embedding_model_dims": 1536, "embedding_model_dims": 1536,
} },
}, },
"embedder": { "embedder": {
"provider": "azure_openai", "provider": "azure_openai",
@@ -113,46 +115,49 @@ def test_azure_config_structure():
"azure_deployment": "test-embedding-deployment", "azure_deployment": "test-embedding-deployment",
"azure_endpoint": "https://test.openai.azure.com/", "azure_endpoint": "https://test.openai.azure.com/",
"api_key": "test-key", "api_key": "test-key",
} },
} },
} },
} }
# Validate configuration structure # Validate configuration structure
assert "llm" in azure_config assert "llm" in azure_config
assert "vector_store" in azure_config assert "vector_store" in azure_config
assert "embedder" in azure_config assert "embedder" in azure_config
# Validate Azure-specific configurations # Validate Azure-specific configurations
assert azure_config["llm"]["provider"] == "azure_openai" assert azure_config["llm"]["provider"] == "azure_openai"
assert "azure_kwargs" in azure_config["llm"]["config"] assert "azure_kwargs" in azure_config["llm"]["config"]
assert "azure_deployment" in azure_config["llm"]["config"]["azure_kwargs"] assert "azure_deployment" in azure_config["llm"]["config"]["azure_kwargs"]
assert azure_config["vector_store"]["provider"] == "azure_ai_search" assert azure_config["vector_store"]["provider"] == "azure_ai_search"
assert "service_name" in azure_config["vector_store"]["config"] assert "service_name" in azure_config["vector_store"]["config"]
assert azure_config["embedder"]["provider"] == "azure_openai" assert azure_config["embedder"]["provider"] == "azure_openai"
assert "azure_kwargs" in azure_config["embedder"]["config"] assert "azure_kwargs" in azure_config["embedder"]["config"]
def test_memory_messages_format(): def test_memory_messages_format():
"""Test that memory messages are properly formatted""" """Test that memory messages are properly formatted"""
# Test message format from main.py # Test message format from main.py
messages = [ messages = [
{"role": "user", "content": "Hi, I'm Alex. I'm a vegetarian and I'm allergic to nuts."}, {"role": "user", "content": "Hi, I'm Alex. I'm a vegetarian and I'm allergic to nuts."},
{"role": "assistant", "content": "Hello Alex! I've noted that you're a vegetarian and have a nut allergy. I'll keep this in mind for any food-related recommendations or discussions."} {
"role": "assistant",
"content": "Hello Alex! I've noted that you're a vegetarian and have a nut allergy. I'll keep this in mind for any food-related recommendations or discussions.",
},
] ]
# Validate message structure # Validate message structure
assert len(messages) == 2 assert len(messages) == 2
assert all("role" in msg for msg in messages) assert all("role" in msg for msg in messages)
assert all("content" in msg for msg in messages) assert all("content" in msg for msg in messages)
# Validate roles # Validate roles
assert messages[0]["role"] == "user" assert messages[0]["role"] == "user"
assert messages[1]["role"] == "assistant" assert messages[1]["role"] == "assistant"
# Validate content # Validate content
assert "vegetarian" in messages[0]["content"].lower() assert "vegetarian" in messages[0]["content"].lower()
assert "allergic to nuts" in messages[0]["content"].lower() assert "allergic to nuts" in messages[0]["content"].lower()
@@ -162,12 +167,12 @@ def test_memory_messages_format():
def test_safe_update_prompt_constant(): def test_safe_update_prompt_constant():
"""Test the SAFE_UPDATE_PROMPT constant from main.py""" """Test the SAFE_UPDATE_PROMPT constant from main.py"""
SAFE_UPDATE_PROMPT = """ SAFE_UPDATE_PROMPT = """
Based on the user's latest messages, what new preference can be inferred? Based on the user's latest messages, what new preference can be inferred?
Reply only in this json_object format: Reply only in this json_object format:
""" """
# Validate prompt structure # Validate prompt structure
assert isinstance(SAFE_UPDATE_PROMPT, str) assert isinstance(SAFE_UPDATE_PROMPT, str)
assert "user's latest messages" in SAFE_UPDATE_PROMPT assert "user's latest messages" in SAFE_UPDATE_PROMPT