Fix failing tests (#3162)
This commit is contained in:
@@ -2,16 +2,15 @@ import hashlib
|
||||
import logging
|
||||
import os
|
||||
import warnings
|
||||
from functools import wraps
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import httpx
|
||||
import requests
|
||||
|
||||
from mem0.client.project import AsyncProject, Project
|
||||
from mem0.client.utils import api_error_handler
|
||||
from mem0.memory.setup import get_user_id, setup_config
|
||||
from mem0.memory.telemetry import capture_client_event
|
||||
from mem0.client.project import Project, AsyncProject
|
||||
from mem0.client.utils import api_error_handler
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -562,7 +561,9 @@ class MemoryClient:
|
||||
APIError: If the API request fails.
|
||||
ValueError: If org_id or project_id are not set.
|
||||
"""
|
||||
logger.warning("get_project() method is going to be deprecated in version v1.0 of the package. Please use the client.project.get() method instead.")
|
||||
logger.warning(
|
||||
"get_project() method is going to be deprecated in version v1.0 of the package. Please use the client.project.get() method instead."
|
||||
)
|
||||
if not (self.org_id and self.project_id):
|
||||
raise ValueError("org_id and project_id must be set to access instructions or categories")
|
||||
|
||||
@@ -604,7 +605,9 @@ class MemoryClient:
|
||||
APIError: If the API request fails.
|
||||
ValueError: If org_id or project_id are not set.
|
||||
"""
|
||||
logger.warning("update_project() method is going to be deprecated in version v1.0 of the package. Please use the client.project.update() method instead.")
|
||||
logger.warning(
|
||||
"update_project() method is going to be deprecated in version v1.0 of the package. Please use the client.project.update() method instead."
|
||||
)
|
||||
if not (self.org_id and self.project_id):
|
||||
raise ValueError("org_id and project_id must be set to update instructions or categories")
|
||||
|
||||
@@ -1330,7 +1333,9 @@ class AsyncMemoryClient:
|
||||
APIError: If the API request fails.
|
||||
ValueError: If org_id or project_id are not set.
|
||||
"""
|
||||
logger.warning("get_project() method is going to be deprecated in version v1.0 of the package. Please use the client.project.get() method instead.")
|
||||
logger.warning(
|
||||
"get_project() method is going to be deprecated in version v1.0 of the package. Please use the client.project.get() method instead."
|
||||
)
|
||||
if not (self.org_id and self.project_id):
|
||||
raise ValueError("org_id and project_id must be set to access instructions or categories")
|
||||
|
||||
@@ -1368,7 +1373,9 @@ class AsyncMemoryClient:
|
||||
APIError: If the API request fails.
|
||||
ValueError: If org_id or project_id are not set.
|
||||
"""
|
||||
logger.warning("update_project() method is going to be deprecated in version v1.0 of the package. Please use the client.project.update() method instead.")
|
||||
logger.warning(
|
||||
"update_project() method is going to be deprecated in version v1.0 of the package. Please use the client.project.update() method instead."
|
||||
)
|
||||
if not (self.org_id and self.project_id):
|
||||
raise ValueError("org_id and project_id must be set to update instructions or categories")
|
||||
|
||||
|
||||
@@ -5,8 +5,8 @@ from typing import Any, Dict, List, Optional
|
||||
import httpx
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from mem0.memory.telemetry import capture_client_event
|
||||
from mem0.client.utils import api_error_handler
|
||||
from mem0.memory.telemetry import capture_client_event
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -16,18 +16,9 @@ class ProjectConfig(BaseModel):
|
||||
Configuration for project management operations.
|
||||
"""
|
||||
|
||||
org_id: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Organization ID"
|
||||
)
|
||||
project_id: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Project ID"
|
||||
)
|
||||
user_email: Optional[str] = Field(
|
||||
default=None,
|
||||
description="User email"
|
||||
)
|
||||
org_id: Optional[str] = Field(default=None, description="Organization ID")
|
||||
project_id: Optional[str] = Field(default=None, description="Project ID")
|
||||
user_email: Optional[str] = Field(default=None, description="User email")
|
||||
|
||||
class Config:
|
||||
validate_assignment = True
|
||||
@@ -64,11 +55,7 @@ class BaseProject(ABC):
|
||||
self.config = config
|
||||
else:
|
||||
# Create config from parameters
|
||||
self.config = ProjectConfig(
|
||||
org_id=org_id,
|
||||
project_id=project_id,
|
||||
user_email=user_email
|
||||
)
|
||||
self.config = ProjectConfig(org_id=org_id, project_id=project_id, user_email=user_email)
|
||||
|
||||
@property
|
||||
def org_id(self) -> Optional[str]:
|
||||
@@ -93,13 +80,9 @@ class BaseProject(ABC):
|
||||
ValueError: If org_id or project_id are not set.
|
||||
"""
|
||||
if not (self.config.org_id and self.config.project_id):
|
||||
raise ValueError(
|
||||
"org_id and project_id must be set to access project operations"
|
||||
)
|
||||
raise ValueError("org_id and project_id must be set to access project operations")
|
||||
|
||||
def _prepare_params(
|
||||
self, kwargs: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
def _prepare_params(self, kwargs: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Prepare query parameters for API requests.
|
||||
|
||||
@@ -124,9 +107,7 @@ class BaseProject(ABC):
|
||||
|
||||
return {k: v for k, v in kwargs.items() if v is not None}
|
||||
|
||||
def _prepare_org_params(
|
||||
self, kwargs: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
def _prepare_org_params(self, kwargs: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Prepare query parameters for organization-level API requests.
|
||||
|
||||
@@ -423,7 +404,7 @@ class Project(BaseProject):
|
||||
"custom_instructions": custom_instructions,
|
||||
"custom_categories": custom_categories,
|
||||
"retrieval_criteria": retrieval_criteria,
|
||||
"enable_graph": enable_graph
|
||||
"enable_graph": enable_graph,
|
||||
}
|
||||
)
|
||||
response = self._client.patch(
|
||||
@@ -716,7 +697,7 @@ class AsyncProject(BaseProject):
|
||||
"custom_instructions": custom_instructions,
|
||||
"custom_categories": custom_categories,
|
||||
"retrieval_criteria": retrieval_criteria,
|
||||
"enable_graph": enable_graph
|
||||
"enable_graph": enable_graph,
|
||||
}
|
||||
)
|
||||
response = await self._client.patch(
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
import httpx
|
||||
import logging
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class APIError(Exception):
|
||||
"""Exception raised for errors in the API."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
|
||||
@@ -82,11 +82,11 @@ class AzureOpenAILLM(LLMBase):
|
||||
str: The generated response.
|
||||
"""
|
||||
|
||||
user_prompt = messages[-1]['content']
|
||||
user_prompt = messages[-1]["content"]
|
||||
|
||||
user_prompt = user_prompt.replace("assistant", "ai")
|
||||
|
||||
messages[-1]['content'] = user_prompt
|
||||
messages[-1]["content"] = user_prompt
|
||||
|
||||
common_params = {
|
||||
"model": self.config.model,
|
||||
|
||||
@@ -49,11 +49,11 @@ class AzureOpenAIStructuredLLM(LLMBase):
|
||||
str: The generated response.
|
||||
"""
|
||||
|
||||
user_prompt = messages[-1]['content']
|
||||
user_prompt = messages[-1]["content"]
|
||||
|
||||
user_prompt = user_prompt.replace("assistant", "ai")
|
||||
|
||||
messages[-1]['content'] = user_prompt
|
||||
messages[-1]["content"] = user_prompt
|
||||
|
||||
params = {
|
||||
"model": self.config.model,
|
||||
|
||||
@@ -4,8 +4,6 @@ from typing import Dict, List, Optional
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
from mem0.configs.llms.base import BaseLlmConfig
|
||||
from mem0.llms.base import LLMBase
|
||||
from mem0.memory.utils import extract_json
|
||||
|
||||
Reference in New Issue
Block a user