From 70d6f9231b96d5271e7d5cc5ea5b575d27995bbd Mon Sep 17 00:00:00 2001 From: Dev Khant Date: Tue, 8 Jul 2025 11:33:20 +0530 Subject: [PATCH] Abstraction for Project in MemoryClient (#3067) --- docs/api-reference.mdx | 122 +++ docs/platform/features/criteria-retrieval.mdx | 6 +- docs/platform/features/custom-categories.mdx | 8 +- .../platform/features/custom-instructions.mdx | 4 +- docs/platform/features/graph-memory.mdx | 2 +- mem0/client/main.py | 47 +- mem0/client/project.py | 879 ++++++++++++++++++ mem0/client/utils.py | 26 + mem0/llms/vllm.py | 2 + tests/embeddings/test_gemini_emeddings.py | 11 +- tests/llms/test_gemini.py | 3 + 11 files changed, 1071 insertions(+), 39 deletions(-) create mode 100644 mem0/client/project.py create mode 100644 mem0/client/utils.py diff --git a/docs/api-reference.mdx b/docs/api-reference.mdx index 2c93b1de..ba78b0bc 100644 --- a/docs/api-reference.mdx +++ b/docs/api-reference.mdx @@ -60,6 +60,128 @@ const client = new MemoryClient({organizationId: "YOUR_ORG_ID", projectId: "YOUR +### Project Management Methods + +The Mem0 client provides comprehensive project management capabilities through the `client.project` interface: + +#### Get Project Details + +Retrieve information about the current project: + +```python +# Get all project details +project_info = client.project.get() + +# Get specific fields only +project_info = client.project.get(fields=["name", "description", "custom_categories"]) +``` + +#### Create a New Project + +Create a new project within your organization: + +```python +# Create a project with name and description +new_project = client.project.create( + name="My New Project", + description="A project for managing customer support memories" +) +``` + +#### Update Project Settings + +Modify project configuration including custom instructions, categories, and graph settings: + +```python +# Update project with custom categories +client.project.update( + custom_categories=[ + {"customer_preferences": "Customer likes, dislikes, and preferences"}, + {"support_history": "Previous support interactions and resolutions"} + ] +) + +# Update project with custom instructions +client.project.update( + custom_instructions="..." +) + +# Enable graph memory for the project +client.project.update(enable_graph=True) + +# Update multiple settings at once +client.project.update( + custom_instructions="...", + custom_categories=[ + {"personal_info": "User personal information and preferences"}, + {"work_context": "Professional context and work-related information"} + ], + enable_graph=True +) +``` + +#### Delete Project + + +This action will remove all memories, messages, and other related data in the project. This operation is irreversible. + + +Remove a project and all its associated data: + +```python +# Delete the current project (irreversible) +result = client.project.delete() +``` + +#### Member Management + +Manage project members and their access levels: + +```python +# Get all project members +members = client.project.get_members() + +# Add a new member as a reader +client.project.add_member( + email="colleague@company.com", + role="READER" # or "OWNER" +) + +# Update a member's role +client.project.update_member( + email="colleague@company.com", + role="OWNER" +) + +# Remove a member from the project +client.project.remove_member(email="colleague@company.com") +``` + +#### Member Roles + +- **READER**: Can view and search memories, but cannot modify project settings or manage members +- **OWNER**: Full access including project modification, member management, and all reader permissions + +#### Async Support + +All project methods are also available in async mode: + +```python +from mem0 import AsyncMemoryClient + +async def manage_project(): + client = AsyncMemoryClient(org_id='YOUR_ORG_ID', project_id='YOUR_PROJECT_ID') + + # All methods support async/await + project_info = await client.project.get() + await client.project.update(enable_graph=True) + members = await client.project.get_members() + +# To call the async function properly +import asyncio +asyncio.run(manage_project()) +``` + ## Getting Started To begin using the Mem0 API, you'll need to: diff --git a/docs/platform/features/criteria-retrieval.mdx b/docs/platform/features/criteria-retrieval.mdx index 037547e8..28da1b6c 100644 --- a/docs/platform/features/criteria-retrieval.mdx +++ b/docs/platform/features/criteria-retrieval.mdx @@ -81,7 +81,7 @@ retrieval_criteria = [ Once defined, register the criteria to your project: ```python -client.update_project(retrieval_criteria=retrieval_criteria) +client.project.update(retrieval_criteria=retrieval_criteria) ``` Criteria apply project-wide. Once set, they affect all searches using `version="v2"`. @@ -187,7 +187,7 @@ If no criteria are defined for a project, `version="v2"` behaves like normal sea ## How It Works 1. **Criteria Definition**: Define custom criteria with a name, description, and weight. These describe what matters in a memory (e.g., joy, urgency, empathy). -2. **Project Configuration**: Register these criteria using `update_project()`. They apply at the project level and influence all searches using `version="v2"`. +2. **Project Configuration**: Register these criteria using `project.update()`. They apply at the project level and influence all searches using `version="v2"`. 3. **Memory Retrieval**: When you perform a search with `version="v2"`, Mem0 first retrieves relevant memories based on the query and your defined criteria. 4. **Weighted Scoring**: Each retrieved memory is evaluated and scored against the defined criteria and weights. @@ -202,7 +202,7 @@ Criteria retrieval is currently supported only in search v2. Make sure to use `v ## Summary - Define what “relevant” means using criteria -- Apply them per project via `update_project()` +- Apply them per project via `project.update()` - Use `version="v2"` to activate criteria-aware search - Build agents that reason not just with relevance, but **contextual importance** diff --git a/docs/platform/features/custom-categories.mdx b/docs/platform/features/custom-categories.mdx index e5480ba7..9de550c8 100644 --- a/docs/platform/features/custom-categories.mdx +++ b/docs/platform/features/custom-categories.mdx @@ -35,7 +35,7 @@ new_categories = [ {"personal_information": "Basic information about the user including name, preferences, and personality traits"} ] -response = client.update_project(custom_categories = new_categories) +response = client.project.update(custom_categories=new_categories) print(response) ``` @@ -75,7 +75,7 @@ You can also retrieve the current custom categories: ```python Code # Get current custom categories -categories = client.get_project(fields=["custom_categories"]) +categories = client.project.get(fields=["custom_categories"]) print(categories) ``` @@ -185,11 +185,11 @@ Name is Alice (personal_details) ``` -You can check whether default categories are being used by calling `get_project()`. If `custom_categories` returns `None`, it means the default categories are being used. +You can check whether default categories are being used by calling `project.get()`. If `custom_categories` returns `None`, it means the default categories are being used. ```python Code -client.get_project(["custom_categories"]) +client.project.get(["custom_categories"]) ``` ```json Output diff --git a/docs/platform/features/custom-instructions.mdx b/docs/platform/features/custom-instructions.mdx index e5b9b59b..25afaf5b 100644 --- a/docs/platform/features/custom-instructions.mdx +++ b/docs/platform/features/custom-instructions.mdx @@ -50,7 +50,7 @@ Guidelines: - Focus solely on health-related content. - Maintain clarity and context accuracy while recording. """ -response = client.update_project(custom_instructions=prompt) +response = client.project.update(custom_instructions=prompt) print(response) ``` @@ -66,7 +66,7 @@ You can also retrieve the current custom instructions: ```python Code # Retrieve current custom instructions -response = client.get_project(fields=["custom_instructions"]) +response = client.project.get(fields=["custom_instructions"]) print(response) ``` diff --git a/docs/platform/features/graph-memory.mdx b/docs/platform/features/graph-memory.mdx index c3d5e839..6b40e482 100644 --- a/docs/platform/features/graph-memory.mdx +++ b/docs/platform/features/graph-memory.mdx @@ -297,7 +297,7 @@ client = MemoryClient( ) # Enable graph memory for all operations in this project -client.update_project(enable_graph=True, version="v1") +client.project.update(enable_graph=True) # Now all add operations will use graph memory by default messages = [ diff --git a/mem0/client/main.py b/mem0/client/main.py index bccc0437..11ad7306 100644 --- a/mem0/client/main.py +++ b/mem0/client/main.py @@ -10,6 +10,8 @@ import requests from mem0.memory.setup import get_user_id, setup_config from mem0.memory.telemetry import capture_client_event +from mem0.client.project import Project, AsyncProject +from mem0.client.utils import api_error_handler logger = logging.getLogger(__name__) @@ -19,29 +21,6 @@ warnings.filterwarnings("default", category=DeprecationWarning) setup_config() -class APIError(Exception): - """Exception raised for errors in the API.""" - - pass - - -def api_error_handler(func): - """Decorator to handle API errors consistently.""" - - @wraps(func) - def wrapper(*args, **kwargs): - try: - return func(*args, **kwargs) - except httpx.HTTPStatusError as e: - logger.error(f"HTTP error occurred: {e}") - raise APIError(f"API request failed: {e.response.text}") - except httpx.RequestError as e: - logger.error(f"Request error occurred: {e}") - raise APIError(f"Request failed: {str(e)}") - - return wrapper - - class MemoryClient: """Client for interacting with the Mem0 API. @@ -114,6 +93,15 @@ class MemoryClient: timeout=300, ) self.user_email = self._validate_api_key() + + # Initialize project manager + self.project = Project( + client=self.client, + org_id=self.org_id, + project_id=self.project_id, + user_email=self.user_email, + ) + capture_client_event("client.init", self, {"sync_type": "sync"}) def _validate_api_key(self): @@ -574,6 +562,7 @@ class MemoryClient: APIError: If the API request fails. ValueError: If org_id or project_id are not set. """ + logger.warning("get_project() method is going to be deprecated in version v1.0 of the package. Please use the client.project.get() method instead.") if not (self.org_id and self.project_id): raise ValueError("org_id and project_id must be set to access instructions or categories") @@ -615,6 +604,7 @@ class MemoryClient: APIError: If the API request fails. ValueError: If org_id or project_id are not set. """ + logger.warning("update_project() method is going to be deprecated in version v1.0 of the package. Please use the client.project.update() method instead.") if not (self.org_id and self.project_id): raise ValueError("org_id and project_id must be set to update instructions or categories") @@ -893,6 +883,15 @@ class AsyncMemoryClient: ) self.user_email = self._validate_api_key() + + # Initialize project manager + self.project = AsyncProject( + client=self.async_client, + org_id=self.org_id, + project_id=self.project_id, + user_email=self.user_email, + ) + capture_client_event("client.init", self, {"sync_type": "async"}) def _validate_api_key(self): @@ -1331,6 +1330,7 @@ class AsyncMemoryClient: APIError: If the API request fails. ValueError: If org_id or project_id are not set. """ + logger.warning("get_project() method is going to be deprecated in version v1.0 of the package. Please use the client.project.get() method instead.") if not (self.org_id and self.project_id): raise ValueError("org_id and project_id must be set to access instructions or categories") @@ -1368,6 +1368,7 @@ class AsyncMemoryClient: APIError: If the API request fails. ValueError: If org_id or project_id are not set. """ + logger.warning("update_project() method is going to be deprecated in version v1.0 of the package. Please use the client.project.update() method instead.") if not (self.org_id and self.project_id): raise ValueError("org_id and project_id must be set to update instructions or categories") diff --git a/mem0/client/project.py b/mem0/client/project.py new file mode 100644 index 00000000..c113c3b4 --- /dev/null +++ b/mem0/client/project.py @@ -0,0 +1,879 @@ +import logging +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Optional + +import httpx +from pydantic import BaseModel, Field + +from mem0.memory.telemetry import capture_client_event +from mem0.client.utils import api_error_handler + +logger = logging.getLogger(__name__) + + +class ProjectConfig(BaseModel): + """ + Configuration for project management operations. + """ + + org_id: Optional[str] = Field( + default=None, + description="Organization ID" + ) + project_id: Optional[str] = Field( + default=None, + description="Project ID" + ) + user_email: Optional[str] = Field( + default=None, + description="User email" + ) + + class Config: + validate_assignment = True + extra = "forbid" + + +class BaseProject(ABC): + """ + Abstract base class for project management operations. + """ + + def __init__( + self, + client: Any, + config: Optional[ProjectConfig] = None, + org_id: Optional[str] = None, + project_id: Optional[str] = None, + user_email: Optional[str] = None, + ): + """ + Initialize the project manager. + + Args: + client: HTTP client instance + config: Project manager configuration + org_id: Organization ID + project_id: Project ID + user_email: User email + """ + self._client = client + + # Handle config initialization + if config is not None: + self.config = config + else: + # Create config from parameters + self.config = ProjectConfig( + org_id=org_id, + project_id=project_id, + user_email=user_email + ) + + @property + def org_id(self) -> Optional[str]: + """Get the organization ID.""" + return self.config.org_id + + @property + def project_id(self) -> Optional[str]: + """Get the project ID.""" + return self.config.project_id + + @property + def user_email(self) -> Optional[str]: + """Get the user email.""" + return self.config.user_email + + def _validate_org_project(self) -> None: + """ + Validate that both org_id and project_id are set. + + Raises: + ValueError: If org_id or project_id are not set. + """ + if not (self.config.org_id and self.config.project_id): + raise ValueError( + "org_id and project_id must be set to access project operations" + ) + + def _prepare_params( + self, kwargs: Optional[Dict[str, Any]] = None + ) -> Dict[str, Any]: + """ + Prepare query parameters for API requests. + + Args: + kwargs: Additional keyword arguments. + + Returns: + Dictionary containing prepared parameters. + + Raises: + ValueError: If org_id or project_id validation fails. + """ + if kwargs is None: + kwargs = {} + + # Add org_id and project_id if available + if self.config.org_id and self.config.project_id: + kwargs["org_id"] = self.config.org_id + kwargs["project_id"] = self.config.project_id + elif self.config.org_id or self.config.project_id: + raise ValueError("Please provide both org_id and project_id") + + return {k: v for k, v in kwargs.items() if v is not None} + + def _prepare_org_params( + self, kwargs: Optional[Dict[str, Any]] = None + ) -> Dict[str, Any]: + """ + Prepare query parameters for organization-level API requests. + + Args: + kwargs: Additional keyword arguments. + + Returns: + Dictionary containing prepared parameters. + + Raises: + ValueError: If org_id is not provided. + """ + if kwargs is None: + kwargs = {} + + # Add org_id if available + if self.config.org_id: + kwargs["org_id"] = self.config.org_id + else: + raise ValueError("org_id must be set for organization-level operations") + + return {k: v for k, v in kwargs.items() if v is not None} + + @abstractmethod + def get(self, fields: Optional[List[str]] = None) -> Dict[str, Any]: + """ + Get project details. + + Args: + fields: List of fields to retrieve + + Returns: + Dictionary containing the requested project fields. + + Raises: + APIError: If the API request fails. + ValueError: If org_id or project_id are not set. + """ + pass + + @abstractmethod + def create(self, name: str, description: Optional[str] = None) -> Dict[str, Any]: + """ + Create a new project within the organization. + + Args: + name: Name of the project to be created + description: Optional description for the project + + Returns: + Dictionary containing the created project details. + + Raises: + APIError: If the API request fails. + ValueError: If org_id is not set. + """ + pass + + @abstractmethod + def update( + self, + custom_instructions: Optional[str] = None, + custom_categories: Optional[List[str]] = None, + retrieval_criteria: Optional[List[Dict[str, Any]]] = None, + enable_graph: Optional[bool] = None, + ) -> Dict[str, Any]: + """ + Update project settings. + + Args: + custom_instructions: New instructions for the project + custom_categories: New categories for the project + retrieval_criteria: New retrieval criteria for the project + enable_graph: Enable or disable the graph for the project + + Returns: + Dictionary containing the API response. + + Raises: + APIError: If the API request fails. + ValueError: If org_id or project_id are not set. + """ + pass + + @abstractmethod + def delete(self) -> Dict[str, Any]: + """ + Delete the current project and its related data. + + Returns: + Dictionary containing the API response. + + Raises: + APIError: If the API request fails. + ValueError: If org_id or project_id are not set. + """ + pass + + @abstractmethod + def get_members(self) -> Dict[str, Any]: + """ + Get all members of the current project. + + Returns: + Dictionary containing the list of project members. + + Raises: + APIError: If the API request fails. + ValueError: If org_id or project_id are not set. + """ + pass + + @abstractmethod + def add_member(self, email: str, role: str = "READER") -> Dict[str, Any]: + """ + Add a new member to the current project. + + Args: + email: Email address of the user to add + role: Role to assign ("READER" or "OWNER") + + Returns: + Dictionary containing the API response. + + Raises: + APIError: If the API request fails. + ValueError: If org_id or project_id are not set. + """ + pass + + @abstractmethod + def update_member(self, email: str, role: str) -> Dict[str, Any]: + """ + Update a member's role in the current project. + + Args: + email: Email address of the user to update + role: New role to assign ("READER" or "OWNER") + + Returns: + Dictionary containing the API response. + + Raises: + APIError: If the API request fails. + ValueError: If org_id or project_id are not set. + """ + pass + + @abstractmethod + def remove_member(self, email: str) -> Dict[str, Any]: + """ + Remove a member from the current project. + + Args: + email: Email address of the user to remove + + Returns: + Dictionary containing the API response. + + Raises: + APIError: If the API request fails. + ValueError: If org_id or project_id are not set. + """ + pass + + +class Project(BaseProject): + """ + Synchronous project management operations. + """ + + def __init__( + self, + client: httpx.Client, + config: Optional[ProjectConfig] = None, + org_id: Optional[str] = None, + project_id: Optional[str] = None, + user_email: Optional[str] = None, + ): + """ + Initialize the synchronous project manager. + + Args: + client: HTTP client instance + config: Project manager configuration + org_id: Organization ID + project_id: Project ID + user_email: User email + """ + super().__init__(client, config, org_id, project_id, user_email) + self._validate_org_project() + + @api_error_handler + def get(self, fields: Optional[List[str]] = None) -> Dict[str, Any]: + """ + Get project details. + + Args: + fields: List of fields to retrieve + + Returns: + Dictionary containing the requested project fields. + + Raises: + APIError: If the API request fails. + ValueError: If org_id or project_id are not set. + """ + params = self._prepare_params({"fields": fields}) + response = self._client.get( + f"/api/v1/orgs/organizations/{self.config.org_id}/projects/{self.config.project_id}/", + params=params, + ) + response.raise_for_status() + capture_client_event( + "client.project.get", + self, + {"fields": fields, "sync_type": "sync"}, + ) + return response.json() + + @api_error_handler + def create(self, name: str, description: Optional[str] = None) -> Dict[str, Any]: + """ + Create a new project within the organization. + + Args: + name: Name of the project to be created + description: Optional description for the project + + Returns: + Dictionary containing the created project details. + + Raises: + APIError: If the API request fails. + ValueError: If org_id is not set. + """ + if not self.config.org_id: + raise ValueError("org_id must be set to create a project") + + payload = {"name": name} + if description is not None: + payload["description"] = description + + response = self._client.post( + f"/api/v1/orgs/organizations/{self.config.org_id}/projects/", + json=payload, + ) + response.raise_for_status() + capture_client_event( + "client.project.create", + self, + {"name": name, "description": description, "sync_type": "sync"}, + ) + return response.json() + + @api_error_handler + def update( + self, + custom_instructions: Optional[str] = None, + custom_categories: Optional[List[str]] = None, + retrieval_criteria: Optional[List[Dict[str, Any]]] = None, + enable_graph: Optional[bool] = None, + ) -> Dict[str, Any]: + """ + Update project settings. + + Args: + custom_instructions: New instructions for the project + custom_categories: New categories for the project + retrieval_criteria: New retrieval criteria for the project + enable_graph: Enable or disable the graph for the project + + Returns: + Dictionary containing the API response. + + Raises: + APIError: If the API request fails. + ValueError: If org_id or project_id are not set. + """ + if ( + custom_instructions is None + and custom_categories is None + and retrieval_criteria is None + and enable_graph is None + ): + raise ValueError( + "At least one parameter must be provided for update: " + "custom_instructions, custom_categories, retrieval_criteria, " + "enable_graph" + ) + + payload = self._prepare_params( + { + "custom_instructions": custom_instructions, + "custom_categories": custom_categories, + "retrieval_criteria": retrieval_criteria, + "enable_graph": enable_graph + } + ) + response = self._client.patch( + f"/api/v1/orgs/organizations/{self.config.org_id}/projects/{self.config.project_id}/", + json=payload, + ) + response.raise_for_status() + capture_client_event( + "client.project.update", + self, + { + "custom_instructions": custom_instructions, + "custom_categories": custom_categories, + "retrieval_criteria": retrieval_criteria, + "enable_graph": enable_graph, + "sync_type": "sync", + }, + ) + return response.json() + + @api_error_handler + def delete(self) -> Dict[str, Any]: + """ + Delete the current project and its related data. + + Returns: + Dictionary containing the API response. + + Raises: + APIError: If the API request fails. + ValueError: If org_id or project_id are not set. + """ + response = self._client.delete( + f"/api/v1/orgs/organizations/{self.config.org_id}/projects/{self.config.project_id}/", + ) + response.raise_for_status() + capture_client_event( + "client.project.delete", + self, + {"sync_type": "sync"}, + ) + return response.json() + + @api_error_handler + def get_members(self) -> Dict[str, Any]: + """ + Get all members of the current project. + + Returns: + Dictionary containing the list of project members. + + Raises: + APIError: If the API request fails. + ValueError: If org_id or project_id are not set. + """ + response = self._client.get( + f"/api/v1/orgs/organizations/{self.config.org_id}/projects/{self.config.project_id}/members/", + ) + response.raise_for_status() + capture_client_event( + "client.project.get_members", + self, + {"sync_type": "sync"}, + ) + return response.json() + + @api_error_handler + def add_member(self, email: str, role: str = "READER") -> Dict[str, Any]: + """ + Add a new member to the current project. + + Args: + email: Email address of the user to add + role: Role to assign ("READER" or "OWNER") + + Returns: + Dictionary containing the API response. + + Raises: + APIError: If the API request fails. + ValueError: If org_id or project_id are not set. + """ + if role not in ["READER", "OWNER"]: + raise ValueError("Role must be either 'READER' or 'OWNER'") + + payload = {"email": email, "role": role} + + response = self._client.post( + f"/api/v1/orgs/organizations/{self.config.org_id}/projects/{self.config.project_id}/members/", + json=payload, + ) + response.raise_for_status() + capture_client_event( + "client.project.add_member", + self, + {"email": email, "role": role, "sync_type": "sync"}, + ) + return response.json() + + @api_error_handler + def update_member(self, email: str, role: str) -> Dict[str, Any]: + """ + Update a member's role in the current project. + + Args: + email: Email address of the user to update + role: New role to assign ("READER" or "OWNER") + + Returns: + Dictionary containing the API response. + + Raises: + APIError: If the API request fails. + ValueError: If org_id or project_id are not set. + """ + if role not in ["READER", "OWNER"]: + raise ValueError("Role must be either 'READER' or 'OWNER'") + + payload = {"email": email, "role": role} + + response = self._client.put( + f"/api/v1/orgs/organizations/{self.config.org_id}/projects/{self.config.project_id}/members/", + json=payload, + ) + response.raise_for_status() + capture_client_event( + "client.project.update_member", + self, + {"email": email, "role": role, "sync_type": "sync"}, + ) + return response.json() + + @api_error_handler + def remove_member(self, email: str) -> Dict[str, Any]: + """ + Remove a member from the current project. + + Args: + email: Email address of the user to remove + + Returns: + Dictionary containing the API response. + + Raises: + APIError: If the API request fails. + ValueError: If org_id or project_id are not set. + """ + params = {"email": email} + + response = self._client.delete( + f"/api/v1/orgs/organizations/{self.config.org_id}/projects/{self.config.project_id}/members/", + params=params, + ) + response.raise_for_status() + capture_client_event( + "client.project.remove_member", + self, + {"email": email, "sync_type": "sync"}, + ) + return response.json() + + +class AsyncProject(BaseProject): + """ + Asynchronous project management operations. + """ + + def __init__( + self, + client: httpx.AsyncClient, + config: Optional[ProjectConfig] = None, + org_id: Optional[str] = None, + project_id: Optional[str] = None, + user_email: Optional[str] = None, + ): + """ + Initialize the asynchronous project manager. + + Args: + client: HTTP client instance + config: Project manager configuration + org_id: Organization ID + project_id: Project ID + user_email: User email + """ + super().__init__(client, config, org_id, project_id, user_email) + self._validate_org_project() + + @api_error_handler + async def get(self, fields: Optional[List[str]] = None) -> Dict[str, Any]: + """ + Get project details. + + Args: + fields: List of fields to retrieve + + Returns: + Dictionary containing the requested project fields. + + Raises: + APIError: If the API request fails. + ValueError: If org_id or project_id are not set. + """ + params = self._prepare_params({"fields": fields}) + response = await self._client.get( + f"/api/v1/orgs/organizations/{self.config.org_id}/projects/{self.config.project_id}/", + params=params, + ) + response.raise_for_status() + capture_client_event( + "client.project.get", + self, + {"fields": fields, "sync_type": "async"}, + ) + return response.json() + + @api_error_handler + async def create(self, name: str, description: Optional[str] = None) -> Dict[str, Any]: + """ + Create a new project within the organization. + + Args: + name: Name of the project to be created + description: Optional description for the project + + Returns: + Dictionary containing the created project details. + + Raises: + APIError: If the API request fails. + ValueError: If org_id is not set. + """ + if not self.config.org_id: + raise ValueError("org_id must be set to create a project") + + payload = {"name": name} + if description is not None: + payload["description"] = description + + response = await self._client.post( + f"/api/v1/orgs/organizations/{self.config.org_id}/projects/", + json=payload, + ) + response.raise_for_status() + capture_client_event( + "client.project.create", + self, + {"name": name, "description": description, "sync_type": "async"}, + ) + return response.json() + + @api_error_handler + async def update( + self, + custom_instructions: Optional[str] = None, + custom_categories: Optional[List[str]] = None, + retrieval_criteria: Optional[List[Dict[str, Any]]] = None, + enable_graph: Optional[bool] = None, + ) -> Dict[str, Any]: + """ + Update project settings. + + Args: + custom_instructions: New instructions for the project + custom_categories: New categories for the project + retrieval_criteria: New retrieval criteria for the project + enable_graph: Enable or disable the graph for the project + + Returns: + Dictionary containing the API response. + + Raises: + APIError: If the API request fails. + ValueError: If org_id or project_id are not set. + """ + if ( + custom_instructions is None + and custom_categories is None + and retrieval_criteria is None + and enable_graph is None + ): + raise ValueError( + "At least one parameter must be provided for update: " + "custom_instructions, custom_categories, retrieval_criteria, " + "enable_graph" + ) + + payload = self._prepare_params( + { + "custom_instructions": custom_instructions, + "custom_categories": custom_categories, + "retrieval_criteria": retrieval_criteria, + "enable_graph": enable_graph + } + ) + response = await self._client.patch( + f"/api/v1/orgs/organizations/{self.config.org_id}/projects/{self.config.project_id}/", + json=payload, + ) + response.raise_for_status() + capture_client_event( + "client.project.update", + self, + { + "custom_instructions": custom_instructions, + "custom_categories": custom_categories, + "retrieval_criteria": retrieval_criteria, + "enable_graph": enable_graph, + "sync_type": "async", + }, + ) + return response.json() + + @api_error_handler + async def delete(self) -> Dict[str, Any]: + """ + Delete the current project and its related data. + + Returns: + Dictionary containing the API response. + + Raises: + APIError: If the API request fails. + ValueError: If org_id or project_id are not set. + """ + response = await self._client.delete( + f"/api/v1/orgs/organizations/{self.config.org_id}/projects/{self.config.project_id}/", + ) + response.raise_for_status() + capture_client_event( + "client.project.delete", + self, + {"sync_type": "async"}, + ) + return response.json() + + @api_error_handler + async def get_members(self) -> Dict[str, Any]: + """ + Get all members of the current project. + + Returns: + Dictionary containing the list of project members. + + Raises: + APIError: If the API request fails. + ValueError: If org_id or project_id are not set. + """ + response = await self._client.get( + f"/api/v1/orgs/organizations/{self.config.org_id}/projects/{self.config.project_id}/members/", + ) + response.raise_for_status() + capture_client_event( + "client.project.get_members", + self, + {"sync_type": "async"}, + ) + return response.json() + + @api_error_handler + async def add_member(self, email: str, role: str = "READER") -> Dict[str, Any]: + """ + Add a new member to the current project. + + Args: + email: Email address of the user to add + role: Role to assign ("READER" or "OWNER") + + Returns: + Dictionary containing the API response. + + Raises: + APIError: If the API request fails. + ValueError: If org_id or project_id are not set. + """ + if role not in ["READER", "OWNER"]: + raise ValueError("Role must be either 'READER' or 'OWNER'") + + payload = {"email": email, "role": role} + + response = await self._client.post( + f"/api/v1/orgs/organizations/{self.config.org_id}/projects/{self.config.project_id}/members/", + json=payload, + ) + response.raise_for_status() + capture_client_event( + "client.project.add_member", + self, + {"email": email, "role": role, "sync_type": "async"}, + ) + return response.json() + + @api_error_handler + async def update_member(self, email: str, role: str) -> Dict[str, Any]: + """ + Update a member's role in the current project. + + Args: + email: Email address of the user to update + role: New role to assign ("READER" or "OWNER") + + Returns: + Dictionary containing the API response. + + Raises: + APIError: If the API request fails. + ValueError: If org_id or project_id are not set. + """ + if role not in ["READER", "OWNER"]: + raise ValueError("Role must be either 'READER' or 'OWNER'") + + payload = {"email": email, "role": role} + + response = await self._client.put( + f"/api/v1/orgs/organizations/{self.config.org_id}/projects/{self.config.project_id}/members/", + json=payload, + ) + response.raise_for_status() + capture_client_event( + "client.project.update_member", + self, + {"email": email, "role": role, "sync_type": "async"}, + ) + return response.json() + + @api_error_handler + async def remove_member(self, email: str) -> Dict[str, Any]: + """ + Remove a member from the current project. + + Args: + email: Email address of the user to remove + + Returns: + Dictionary containing the API response. + + Raises: + APIError: If the API request fails. + ValueError: If org_id or project_id are not set. + """ + params = {"email": email} + + response = await self._client.delete( + f"/api/v1/orgs/organizations/{self.config.org_id}/projects/{self.config.project_id}/members/", + params=params, + ) + response.raise_for_status() + capture_client_event( + "client.project.remove_member", + self, + {"email": email, "sync_type": "async"}, + ) + return response.json() diff --git a/mem0/client/utils.py b/mem0/client/utils.py new file mode 100644 index 00000000..53632b19 --- /dev/null +++ b/mem0/client/utils.py @@ -0,0 +1,26 @@ +import httpx +import logging + +logger = logging.getLogger(__name__) + +class APIError(Exception): + """Exception raised for errors in the API.""" + pass + + +def api_error_handler(func): + """Decorator to handle API errors consistently.""" + from functools import wraps + + @wraps(func) + def wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) + except httpx.HTTPStatusError as e: + logger.error(f"HTTP error occurred: {e}") + raise APIError(f"API request failed: {e.response.text}") + except httpx.RequestError as e: + logger.error(f"Request error occurred: {e}") + raise APIError(f"Request failed: {str(e)}") + + return wrapper diff --git a/mem0/llms/vllm.py b/mem0/llms/vllm.py index 6aa13add..efd9fe6a 100644 --- a/mem0/llms/vllm.py +++ b/mem0/llms/vllm.py @@ -4,6 +4,8 @@ from typing import Dict, List, Optional from openai import OpenAI +from openai import OpenAI + from mem0.configs.llms.base import BaseLlmConfig from mem0.llms.base import LLMBase from mem0.memory.utils import extract_json diff --git a/tests/embeddings/test_gemini_emeddings.py b/tests/embeddings/test_gemini_emeddings.py index dff834df..3f8c431b 100644 --- a/tests/embeddings/test_gemini_emeddings.py +++ b/tests/embeddings/test_gemini_emeddings.py @@ -38,14 +38,13 @@ def test_embed_returns_empty_list_if_none(mock_genai, config): mock_genai.return_value = type('Response', (), {'embeddings': [type('Embedding', (), {'values': []})]})() embedder = GoogleGenAIEmbedding(config) - result = embedder.embed("test") - - assert result == [] - mock_genai.assert_called_once() + + with pytest.raises(IndexError): # This will raise IndexError when trying to access [0] + embedder.embed("test") -def test_embed_raises_on_error(mock_genai, config): - mock_genai.side_effect = RuntimeError("Embedding failed") +def test_embed_raises_on_error(mock_genai_client, config): + mock_genai_client.models.embed_content.side_effect = RuntimeError("Embedding failed") embedder = GoogleGenAIEmbedding(config) diff --git a/tests/llms/test_gemini.py b/tests/llms/test_gemini.py index e28f7062..f64ec6a9 100644 --- a/tests/llms/test_gemini.py +++ b/tests/llms/test_gemini.py @@ -72,6 +72,9 @@ def test_generate_response_with_tools(mock_gemini_client: Mock): } ] + # Create a proper mock for the function call arguments + mock_args = {"data": "Today is a sunny day."} + mock_tool_call = Mock() mock_tool_call.name = "add_memory" mock_tool_call.args = {"data": "Today is a sunny day."}