Fix failing tests (#3162)

This commit is contained in:
Parshva Daftari
2025-07-25 00:58:45 +05:30
committed by GitHub
parent 37ee3c5eb2
commit 4433666117
11 changed files with 144 additions and 169 deletions

View File

@@ -2,16 +2,15 @@ import hashlib
import logging
import os
import warnings
from functools import wraps
from typing import Any, Dict, List, Optional
import httpx
import requests
from mem0.client.project import AsyncProject, Project
from mem0.client.utils import api_error_handler
from mem0.memory.setup import get_user_id, setup_config
from mem0.memory.telemetry import capture_client_event
from mem0.client.project import Project, AsyncProject
from mem0.client.utils import api_error_handler
logger = logging.getLogger(__name__)
@@ -562,7 +561,9 @@ class MemoryClient:
APIError: If the API request fails.
ValueError: If org_id or project_id are not set.
"""
logger.warning("get_project() method is going to be deprecated in version v1.0 of the package. Please use the client.project.get() method instead.")
logger.warning(
"get_project() method is going to be deprecated in version v1.0 of the package. Please use the client.project.get() method instead."
)
if not (self.org_id and self.project_id):
raise ValueError("org_id and project_id must be set to access instructions or categories")
@@ -604,7 +605,9 @@ class MemoryClient:
APIError: If the API request fails.
ValueError: If org_id or project_id are not set.
"""
logger.warning("update_project() method is going to be deprecated in version v1.0 of the package. Please use the client.project.update() method instead.")
logger.warning(
"update_project() method is going to be deprecated in version v1.0 of the package. Please use the client.project.update() method instead."
)
if not (self.org_id and self.project_id):
raise ValueError("org_id and project_id must be set to update instructions or categories")
@@ -1330,7 +1333,9 @@ class AsyncMemoryClient:
APIError: If the API request fails.
ValueError: If org_id or project_id are not set.
"""
logger.warning("get_project() method is going to be deprecated in version v1.0 of the package. Please use the client.project.get() method instead.")
logger.warning(
"get_project() method is going to be deprecated in version v1.0 of the package. Please use the client.project.get() method instead."
)
if not (self.org_id and self.project_id):
raise ValueError("org_id and project_id must be set to access instructions or categories")
@@ -1368,7 +1373,9 @@ class AsyncMemoryClient:
APIError: If the API request fails.
ValueError: If org_id or project_id are not set.
"""
logger.warning("update_project() method is going to be deprecated in version v1.0 of the package. Please use the client.project.update() method instead.")
logger.warning(
"update_project() method is going to be deprecated in version v1.0 of the package. Please use the client.project.update() method instead."
)
if not (self.org_id and self.project_id):
raise ValueError("org_id and project_id must be set to update instructions or categories")

View File

@@ -5,8 +5,8 @@ from typing import Any, Dict, List, Optional
import httpx
from pydantic import BaseModel, Field
from mem0.memory.telemetry import capture_client_event
from mem0.client.utils import api_error_handler
from mem0.memory.telemetry import capture_client_event
logger = logging.getLogger(__name__)
@@ -16,18 +16,9 @@ class ProjectConfig(BaseModel):
Configuration for project management operations.
"""
org_id: Optional[str] = Field(
default=None,
description="Organization ID"
)
project_id: Optional[str] = Field(
default=None,
description="Project ID"
)
user_email: Optional[str] = Field(
default=None,
description="User email"
)
org_id: Optional[str] = Field(default=None, description="Organization ID")
project_id: Optional[str] = Field(default=None, description="Project ID")
user_email: Optional[str] = Field(default=None, description="User email")
class Config:
validate_assignment = True
@@ -64,11 +55,7 @@ class BaseProject(ABC):
self.config = config
else:
# Create config from parameters
self.config = ProjectConfig(
org_id=org_id,
project_id=project_id,
user_email=user_email
)
self.config = ProjectConfig(org_id=org_id, project_id=project_id, user_email=user_email)
@property
def org_id(self) -> Optional[str]:
@@ -93,13 +80,9 @@ class BaseProject(ABC):
ValueError: If org_id or project_id are not set.
"""
if not (self.config.org_id and self.config.project_id):
raise ValueError(
"org_id and project_id must be set to access project operations"
)
raise ValueError("org_id and project_id must be set to access project operations")
def _prepare_params(
self, kwargs: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
def _prepare_params(self, kwargs: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
"""
Prepare query parameters for API requests.
@@ -124,9 +107,7 @@ class BaseProject(ABC):
return {k: v for k, v in kwargs.items() if v is not None}
def _prepare_org_params(
self, kwargs: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
def _prepare_org_params(self, kwargs: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
"""
Prepare query parameters for organization-level API requests.
@@ -423,7 +404,7 @@ class Project(BaseProject):
"custom_instructions": custom_instructions,
"custom_categories": custom_categories,
"retrieval_criteria": retrieval_criteria,
"enable_graph": enable_graph
"enable_graph": enable_graph,
}
)
response = self._client.patch(
@@ -716,7 +697,7 @@ class AsyncProject(BaseProject):
"custom_instructions": custom_instructions,
"custom_categories": custom_categories,
"retrieval_criteria": retrieval_criteria,
"enable_graph": enable_graph
"enable_graph": enable_graph,
}
)
response = await self._client.patch(

View File

@@ -1,10 +1,13 @@
import httpx
import logging
import httpx
logger = logging.getLogger(__name__)
class APIError(Exception):
"""Exception raised for errors in the API."""
pass

View File

@@ -82,11 +82,11 @@ class AzureOpenAILLM(LLMBase):
str: The generated response.
"""
user_prompt = messages[-1]['content']
user_prompt = messages[-1]["content"]
user_prompt = user_prompt.replace("assistant", "ai")
messages[-1]['content'] = user_prompt
messages[-1]["content"] = user_prompt
common_params = {
"model": self.config.model,

View File

@@ -49,11 +49,11 @@ class AzureOpenAIStructuredLLM(LLMBase):
str: The generated response.
"""
user_prompt = messages[-1]['content']
user_prompt = messages[-1]["content"]
user_prompt = user_prompt.replace("assistant", "ai")
messages[-1]['content'] = user_prompt
messages[-1]["content"] = user_prompt
params = {
"model": self.config.model,

View File

@@ -4,8 +4,6 @@ from typing import Dict, List, Optional
from openai import OpenAI
from openai import OpenAI
from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.base import LLMBase
from mem0.memory.utils import extract_json