diff --git a/mem0/client/main.py b/mem0/client/main.py index 35b25d10..891e64a0 100644 --- a/mem0/client/main.py +++ b/mem0/client/main.py @@ -6,6 +6,7 @@ from functools import wraps from typing import Any, Dict, List, Optional, Union import httpx +import requests from mem0.memory.setup import get_user_id, setup_config from mem0.memory.telemetry import capture_client_event @@ -62,6 +63,7 @@ class MemoryClient: host: Optional[str] = None, org_id: Optional[str] = None, project_id: Optional[str] = None, + client: Optional[httpx.Client] = None, ): """Initialize the MemoryClient. @@ -71,6 +73,8 @@ class MemoryClient: host: The base URL for the Mem0 API. Defaults to "https://api.mem0.ai". org_id: The ID of the organization. project_id: The ID of the project. + client: A custom httpx.Client instance. If provided, it will be used instead of creating a new one. + Note that base_url and headers will be set/overridden as needed. Raises: ValueError: If no API key is provided or found in the environment. @@ -87,11 +91,20 @@ class MemoryClient: # Create MD5 hash of API key for user_id self.user_id = hashlib.md5(self.api_key.encode()).hexdigest() - self.client = httpx.Client( - base_url=self.host, - headers={"Authorization": f"Token {self.api_key}", "Mem0-User-ID": self.user_id}, - timeout=300, - ) + if client is not None: + self.client = client + # Ensure the client has the correct base_url and headers + self.client.base_url = httpx.URL(self.host) + self.client.headers.update({ + "Authorization": f"Token {self.api_key}", + "Mem0-User-ID": self.user_id + }) + else: + self.client = httpx.Client( + base_url=self.host, + headers={"Authorization": f"Token {self.api_key}", "Mem0-User-ID": self.user_id}, + timeout=300, + ) self.user_email = self._validate_api_key() capture_client_event("client.init", self, {"sync_type": "sync"}) @@ -696,10 +709,6 @@ class AsyncMemoryClient: This class provides asynchronous versions of all MemoryClient methods. It uses httpx.AsyncClient for making non-blocking API requests. - - Attributes: - sync_client (MemoryClient): Underlying synchronous client instance. - async_client (httpx.AsyncClient): Async HTTP client for making API requests. """ def __init__( @@ -708,13 +717,121 @@ class AsyncMemoryClient: host: Optional[str] = None, org_id: Optional[str] = None, project_id: Optional[str] = None, + client: Optional[httpx.AsyncClient] = None, ): - self.sync_client = MemoryClient(api_key, host, org_id, project_id) - self.async_client = httpx.AsyncClient( - base_url=self.sync_client.host, - headers=self.sync_client.client.headers, - timeout=300, - ) + """Initialize the AsyncMemoryClient. + + Args: + api_key: The API key for authenticating with the Mem0 API. If not provided, + it will attempt to use the MEM0_API_KEY environment variable. + host: The base URL for the Mem0 API. Defaults to "https://api.mem0.ai". + org_id: The ID of the organization. + project_id: The ID of the project. + client: A custom httpx.AsyncClient instance. If provided, it will be used instead + of creating a new one. Note that base_url and headers will be set/overridden + as needed. + + Raises: + ValueError: If no API key is provided or found in the environment. + """ + self.api_key = api_key or os.getenv("MEM0_API_KEY") + self.host = host or "https://api.mem0.ai" + self.org_id = org_id + self.project_id = project_id + self.user_id = get_user_id() + + if not self.api_key: + raise ValueError("Mem0 API Key not provided. Please provide an API Key.") + + # Create MD5 hash of API key for user_id + self.user_id = hashlib.md5(self.api_key.encode()).hexdigest() + + if client is not None: + self.async_client = client + # Ensure the client has the correct base_url and headers + self.async_client.base_url = httpx.URL(self.host) + self.async_client.headers.update({ + "Authorization": f"Token {self.api_key}", + "Mem0-User-ID": self.user_id + }) + else: + self.async_client = httpx.AsyncClient( + base_url=self.host, + headers={"Authorization": f"Token {self.api_key}", "Mem0-User-ID": self.user_id}, + timeout=300, + ) + + self.user_email = self._validate_api_key() + capture_client_event("client.init", self, {"sync_type": "async"}) + + def _validate_api_key(self): + """Validate the API key by making a test request.""" + try: + params = self._prepare_params() + response = requests.get(f"{self.host}/v1/ping/", headers={"Authorization": f"Token {self.api_key}", "Mem0-User-ID": self.user_id}, params=params) + data = response.json() + + response.raise_for_status() + + if data.get("org_id") and data.get("project_id"): + self.org_id = data.get("org_id") + self.project_id = data.get("project_id") + + return data.get("user_email") + + except requests.HTTPStatusError as e: + try: + error_data = e.response.json() + error_message = error_data.get("detail", str(e)) + except Exception: + error_message = str(e) + raise ValueError(f"Error: {error_message}") + + def _prepare_payload( + self, messages: Union[str, List[Dict[str, str]], None], kwargs: Dict[str, Any] + ) -> Dict[str, Any]: + """Prepare the payload for API requests. + + Args: + messages: The messages to include in the payload. + kwargs: Additional keyword arguments to include in the payload. + + Returns: + A dictionary containing the prepared payload. + """ + payload = {} + if isinstance(messages, str): + payload["messages"] = [{"role": "user", "content": messages}] + elif isinstance(messages, list): + payload["messages"] = messages + + payload.update({k: v for k, v in kwargs.items() if v is not None}) + return payload + + def _prepare_params(self, kwargs: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: + """Prepare query parameters for API requests. + + Args: + kwargs: Keyword arguments to include in the parameters. + + Returns: + A dictionary containing the prepared parameters. + + Raises: + ValueError: If either org_id or project_id is provided but not both. + """ + + if kwargs is None: + kwargs = {} + + # Add org_id and project_id if both are available + if self.org_id and self.project_id: + kwargs["org_id"] = self.org_id + kwargs["project_id"] = self.project_id + elif self.org_id or self.project_id: + raise ValueError("Please provide both org_id and project_id") + + return {k: v for k, v in kwargs.items() if v is not None} async def __aenter__(self): return self @@ -724,89 +841,102 @@ class AsyncMemoryClient: @api_error_handler async def add(self, messages: Union[str, List[Dict[str, str]]], **kwargs) -> Dict[str, Any]: - kwargs = self.sync_client._prepare_params(kwargs) - payload = self.sync_client._prepare_payload(messages, kwargs) + kwargs = self._prepare_params(kwargs) + if kwargs.get("output_format") != "v1.1": + kwargs["output_format"] = "v1.1" + warnings.warn( + "output_format='v1.0' is deprecated therefore setting it to 'v1.1' by default." + "Check out the docs for more information: https://docs.mem0.ai/platform/quickstart#4-1-create-memories", + DeprecationWarning, + stacklevel=2, + ) + kwargs["version"] = "v2" + payload = self._prepare_payload(messages, kwargs) response = await self.async_client.post("/v1/memories/", json=payload) response.raise_for_status() if "metadata" in kwargs: del kwargs["metadata"] - capture_client_event("client.add", self.sync_client, {"keys": list(kwargs.keys()), "sync_type": "async"}) + capture_client_event("client.add", self, {"keys": list(kwargs.keys()), "sync_type": "async"}) return response.json() @api_error_handler async def get(self, memory_id: str) -> Dict[str, Any]: - params = self.sync_client._prepare_params() + params = self._prepare_params() response = await self.async_client.get(f"/v1/memories/{memory_id}/", params=params) response.raise_for_status() - capture_client_event("client.get", self.sync_client, {"memory_id": memory_id, "sync_type": "async"}) + capture_client_event("client.get", self, {"memory_id": memory_id, "sync_type": "async"}) return response.json() @api_error_handler async def get_all(self, version: str = "v1", **kwargs) -> List[Dict[str, Any]]: - params = self.sync_client._prepare_params(kwargs) + params = self._prepare_params(kwargs) if version == "v1": response = await self.async_client.get(f"/{version}/memories/", params=params) elif version == "v2": - response = await self.async_client.post(f"/{version}/memories/", json=params) + if "page" in params and "page_size" in params: + query_params = {"page": params.pop("page"), "page_size": params.pop("page_size")} + response = await self.async_client.post(f"/{version}/memories/", json=params, params=query_params) + else: + response = await self.async_client.post(f"/{version}/memories/", json=params) response.raise_for_status() if "metadata" in kwargs: del kwargs["metadata"] capture_client_event( - "client.get_all", self.sync_client, {"api_version": version, "keys": list(kwargs.keys()), "sync_type": "async"} + "client.get_all", self, {"api_version": version, "keys": list(kwargs.keys()), "sync_type": "async"} ) return response.json() @api_error_handler async def search(self, query: str, version: str = "v1", **kwargs) -> List[Dict[str, Any]]: payload = {"query": query} - payload.update(self.sync_client._prepare_params(kwargs)) + payload.update(self._prepare_params(kwargs)) response = await self.async_client.post(f"/{version}/memories/search/", json=payload) response.raise_for_status() if "metadata" in kwargs: del kwargs["metadata"] capture_client_event( - "client.search", self.sync_client, {"api_version": version, "keys": list(kwargs.keys()), "sync_type": "async"} + "client.search", self, {"api_version": version, "keys": list(kwargs.keys()), "sync_type": "async"} ) return response.json() @api_error_handler async def update(self, memory_id: str, data: str) -> Dict[str, Any]: - params = self.sync_client._prepare_params() + params = self._prepare_params() response = await self.async_client.put(f"/v1/memories/{memory_id}/", json={"text": data}, params=params) response.raise_for_status() - capture_client_event("client.update", self.sync_client, {"memory_id": memory_id, "sync_type": "async"}) + capture_client_event("client.update", self, {"memory_id": memory_id, "sync_type": "async"}) return response.json() @api_error_handler async def delete(self, memory_id: str) -> Dict[str, Any]: - params = self.sync_client._prepare_params() + params = self._prepare_params() response = await self.async_client.delete(f"/v1/memories/{memory_id}/", params=params) response.raise_for_status() - capture_client_event("client.delete", self.sync_client, {"memory_id": memory_id, "sync_type": "async"}) + capture_client_event("client.delete", self, {"memory_id": memory_id, "sync_type": "async"}) return response.json() @api_error_handler async def delete_all(self, **kwargs) -> Dict[str, str]: - params = self.sync_client._prepare_params(kwargs) + params = self._prepare_params(kwargs) response = await self.async_client.delete("/v1/memories/", params=params) response.raise_for_status() - capture_client_event("client.delete_all", self.sync_client, {"keys": list(kwargs.keys()), "sync_type": "async"}) + capture_client_event("client.delete_all", self, {"keys": list(kwargs.keys()), "sync_type": "async"}) return response.json() @api_error_handler async def history(self, memory_id: str) -> List[Dict[str, Any]]: - params = self.sync_client._prepare_params() + params = self._prepare_params() response = await self.async_client.get(f"/v1/memories/{memory_id}/history/", params=params) response.raise_for_status() - capture_client_event("client.history", self.sync_client, {"memory_id": memory_id, "sync_type": "async"}) + capture_client_event("client.history", self, {"memory_id": memory_id, "sync_type": "async"}) return response.json() @api_error_handler async def users(self) -> Dict[str, Any]: - params = self.sync_client._prepare_params() + params = self._prepare_params() response = await self.async_client.get("/v1/entities/", params=params) response.raise_for_status() - capture_client_event("client.users", self.sync_client, {"sync_type": "async"}) + capture_client_event("client.users", self, {"sync_type": "async"}) return response.json() @api_error_handler @@ -848,7 +978,7 @@ class AsyncMemoryClient: for entity in entities["results"] ] - params = self.sync_client._prepare_params() + params = self._prepare_params() if not to_delete: raise ValueError("No entities to delete") @@ -858,7 +988,7 @@ class AsyncMemoryClient: response = await self.async_client.delete(f"/v2/entities/{entity['type']}/{entity['name']}/", params=params) response.raise_for_status() - capture_client_event("client.delete_users", self.sync_client, {"sync_type": "async"}) + capture_client_event("client.delete_users", self, {"user_id": user_id, "agent_id": agent_id, "app_id": app_id, "run_id": run_id, "sync_type": "async"}) return { "message": "Entity deleted successfully." if (user_id or agent_id or app_id or run_id) @@ -868,7 +998,7 @@ class AsyncMemoryClient: @api_error_handler async def reset(self) -> Dict[str, str]: await self.delete_users() - capture_client_event("client.reset", self.sync_client, {"sync_type": "async"}) + capture_client_event("client.reset", self, {"sync_type": "async"}) return {"message": "Client reset successful. All users and memories deleted."} @api_error_handler @@ -889,7 +1019,7 @@ class AsyncMemoryClient: response = await self.async_client.put("/v1/batch/", json={"memories": memories}) response.raise_for_status() - capture_client_event("client.batch_update", self.sync_client, {"sync_type": "async"}) + capture_client_event("client.batch_update", self, {"sync_type": "async"}) return response.json() @api_error_handler @@ -909,7 +1039,7 @@ class AsyncMemoryClient: response = await self.async_client.request("DELETE", "/v1/batch/", json={"memories": memories}) response.raise_for_status() - capture_client_event("client.batch_delete", self.sync_client, {"sync_type": "async"}) + capture_client_event("client.batch_delete", self, {"sync_type": "async"}) return response.json() @api_error_handler @@ -926,7 +1056,7 @@ class AsyncMemoryClient: response = await self.async_client.post("/v1/exports/", json={"schema": schema, **self._prepare_params(kwargs)}) response.raise_for_status() capture_client_event( - "client.create_memory_export", self.sync_client, {"schema": schema, "keys": list(kwargs.keys()), "sync_type": "async"} + "client.create_memory_export", self, {"schema": schema, "keys": list(kwargs.keys()), "sync_type": "async"} ) return response.json() @@ -942,21 +1072,21 @@ class AsyncMemoryClient: """ response = await self.async_client.post("/v1/exports/get/", json=self._prepare_params(kwargs)) response.raise_for_status() - capture_client_event("client.get_memory_export", self.sync_client, {"keys": list(kwargs.keys()), "sync_type": "async"}) + capture_client_event("client.get_memory_export", self, {"keys": list(kwargs.keys()), "sync_type": "async"}) return response.json() @api_error_handler async def get_project(self, fields: Optional[List[str]] = None) -> Dict[str, Any]: - if not (self.sync_client.org_id and self.sync_client.project_id): + if not (self.org_id and self.project_id): raise ValueError("org_id and project_id must be set to access instructions or categories") - params = self.sync_client._prepare_params({"fields": fields}) + params = self._prepare_params({"fields": fields}) response = await self.async_client.get( - f"/api/v1/orgs/organizations/{self.sync_client.org_id}/projects/{self.sync_client.project_id}/", + f"/api/v1/orgs/organizations/{self.org_id}/projects/{self.project_id}/", params=params, ) response.raise_for_status() - capture_client_event("client.get_project", self.sync_client, {"fields": fields, "sync_type": "async"}) + capture_client_event("client.get_project", self, {"fields": fields, "sync_type": "async"}) return response.json() @api_error_handler @@ -964,7 +1094,7 @@ class AsyncMemoryClient: self, custom_instructions: Optional[str] = None, custom_categories: Optional[List[str]] = None, retrieval_criteria: Optional[List[Dict[str, Any]]] = None ) -> Dict[str, Any]: - if not (self.sync_client.org_id and self.sync_client.project_id): + if not (self.org_id and self.project_id): raise ValueError("org_id and project_id must be set to update instructions or categories") if custom_instructions is None and custom_categories is None and retrieval_criteria is None: @@ -972,17 +1102,17 @@ class AsyncMemoryClient: "Currently we only support updating custom_instructions or custom_categories or retrieval_criteria, so you must provide at least one of them" ) - payload = self.sync_client._prepare_params( + payload = self._prepare_params( {"custom_instructions": custom_instructions, "custom_categories": custom_categories, "retrieval_criteria": retrieval_criteria} ) response = await self.async_client.patch( - f"/api/v1/orgs/organizations/{self.sync_client.org_id}/projects/{self.sync_client.project_id}/", + f"/api/v1/orgs/organizations/{self.org_id}/projects/{self.project_id}/", json=payload, ) response.raise_for_status() capture_client_event( "client.update_project", - self.sync_client, + self, {"custom_instructions": custom_instructions, "custom_categories": custom_categories, "retrieval_criteria": retrieval_criteria, "sync_type": "async"}, ) return response.json() @@ -996,7 +1126,7 @@ class AsyncMemoryClient: f"api/v1/webhooks/projects/{project_id}/", ) response.raise_for_status() - capture_client_event("client.get_webhook", self.sync_client, {"sync_type": "async"}) + capture_client_event("client.get_webhook", self, {"sync_type": "async"}) return response.json() @api_error_handler @@ -1004,7 +1134,7 @@ class AsyncMemoryClient: payload = {"url": url, "name": name, "event_types": event_types} response = await self.async_client.post(f"api/v1/webhooks/projects/{project_id}/", json=payload) response.raise_for_status() - capture_client_event("client.create_webhook", self.sync_client, {"sync_type": "async"}) + capture_client_event("client.create_webhook", self, {"sync_type": "async"}) return response.json() @api_error_handler @@ -1018,14 +1148,14 @@ class AsyncMemoryClient: payload = {k: v for k, v in {"name": name, "url": url, "event_types": event_types}.items() if v is not None} response = await self.async_client.put(f"api/v1/webhooks/{webhook_id}/", json=payload) response.raise_for_status() - capture_client_event("client.update_webhook", self.sync_client, {"webhook_id": webhook_id, "sync_type": "async"}) + capture_client_event("client.update_webhook", self, {"webhook_id": webhook_id, "sync_type": "async"}) return response.json() @api_error_handler async def delete_webhook(self, webhook_id: int) -> Dict[str, str]: response = await self.async_client.delete(f"api/v1/webhooks/{webhook_id}/") response.raise_for_status() - capture_client_event("client.delete_webhook", self.sync_client, {"webhook_id": webhook_id, "sync_type": "async"}) + capture_client_event("client.delete_webhook", self, {"webhook_id": webhook_id, "sync_type": "async"}) return response.json() @api_error_handler @@ -1042,5 +1172,6 @@ class AsyncMemoryClient: response = await self.async_client.post("/v1/feedback/", json=data) response.raise_for_status() - capture_client_event("client.feedback", self.sync_client, data, {"sync_type": "async"}) + capture_client_event("client.feedback", self, data, {"sync_type": "async"}) return response.json() +