Update Client (#2640)

This commit is contained in:
Dev Khant
2025-05-08 00:09:43 +05:30
committed by GitHub
parent c01221d4aa
commit 326f33757b

View File

@@ -6,6 +6,7 @@ from functools import wraps
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union
import httpx import httpx
import requests
from mem0.memory.setup import get_user_id, setup_config from mem0.memory.setup import get_user_id, setup_config
from mem0.memory.telemetry import capture_client_event from mem0.memory.telemetry import capture_client_event
@@ -62,6 +63,7 @@ class MemoryClient:
host: Optional[str] = None, host: Optional[str] = None,
org_id: Optional[str] = None, org_id: Optional[str] = None,
project_id: Optional[str] = None, project_id: Optional[str] = None,
client: Optional[httpx.Client] = None,
): ):
"""Initialize the MemoryClient. """Initialize the MemoryClient.
@@ -71,6 +73,8 @@ class MemoryClient:
host: The base URL for the Mem0 API. Defaults to "https://api.mem0.ai". host: The base URL for the Mem0 API. Defaults to "https://api.mem0.ai".
org_id: The ID of the organization. org_id: The ID of the organization.
project_id: The ID of the project. project_id: The ID of the project.
client: A custom httpx.Client instance. If provided, it will be used instead of creating a new one.
Note that base_url and headers will be set/overridden as needed.
Raises: Raises:
ValueError: If no API key is provided or found in the environment. ValueError: If no API key is provided or found in the environment.
@@ -87,11 +91,20 @@ class MemoryClient:
# Create MD5 hash of API key for user_id # Create MD5 hash of API key for user_id
self.user_id = hashlib.md5(self.api_key.encode()).hexdigest() self.user_id = hashlib.md5(self.api_key.encode()).hexdigest()
self.client = httpx.Client( if client is not None:
base_url=self.host, self.client = client
headers={"Authorization": f"Token {self.api_key}", "Mem0-User-ID": self.user_id}, # Ensure the client has the correct base_url and headers
timeout=300, self.client.base_url = httpx.URL(self.host)
) self.client.headers.update({
"Authorization": f"Token {self.api_key}",
"Mem0-User-ID": self.user_id
})
else:
self.client = httpx.Client(
base_url=self.host,
headers={"Authorization": f"Token {self.api_key}", "Mem0-User-ID": self.user_id},
timeout=300,
)
self.user_email = self._validate_api_key() self.user_email = self._validate_api_key()
capture_client_event("client.init", self, {"sync_type": "sync"}) capture_client_event("client.init", self, {"sync_type": "sync"})
@@ -696,10 +709,6 @@ class AsyncMemoryClient:
This class provides asynchronous versions of all MemoryClient methods. This class provides asynchronous versions of all MemoryClient methods.
It uses httpx.AsyncClient for making non-blocking API requests. It uses httpx.AsyncClient for making non-blocking API requests.
Attributes:
sync_client (MemoryClient): Underlying synchronous client instance.
async_client (httpx.AsyncClient): Async HTTP client for making API requests.
""" """
def __init__( def __init__(
@@ -708,13 +717,121 @@ class AsyncMemoryClient:
host: Optional[str] = None, host: Optional[str] = None,
org_id: Optional[str] = None, org_id: Optional[str] = None,
project_id: Optional[str] = None, project_id: Optional[str] = None,
client: Optional[httpx.AsyncClient] = None,
): ):
self.sync_client = MemoryClient(api_key, host, org_id, project_id) """Initialize the AsyncMemoryClient.
self.async_client = httpx.AsyncClient(
base_url=self.sync_client.host, Args:
headers=self.sync_client.client.headers, api_key: The API key for authenticating with the Mem0 API. If not provided,
timeout=300, it will attempt to use the MEM0_API_KEY environment variable.
) host: The base URL for the Mem0 API. Defaults to "https://api.mem0.ai".
org_id: The ID of the organization.
project_id: The ID of the project.
client: A custom httpx.AsyncClient instance. If provided, it will be used instead
of creating a new one. Note that base_url and headers will be set/overridden
as needed.
Raises:
ValueError: If no API key is provided or found in the environment.
"""
self.api_key = api_key or os.getenv("MEM0_API_KEY")
self.host = host or "https://api.mem0.ai"
self.org_id = org_id
self.project_id = project_id
self.user_id = get_user_id()
if not self.api_key:
raise ValueError("Mem0 API Key not provided. Please provide an API Key.")
# Create MD5 hash of API key for user_id
self.user_id = hashlib.md5(self.api_key.encode()).hexdigest()
if client is not None:
self.async_client = client
# Ensure the client has the correct base_url and headers
self.async_client.base_url = httpx.URL(self.host)
self.async_client.headers.update({
"Authorization": f"Token {self.api_key}",
"Mem0-User-ID": self.user_id
})
else:
self.async_client = httpx.AsyncClient(
base_url=self.host,
headers={"Authorization": f"Token {self.api_key}", "Mem0-User-ID": self.user_id},
timeout=300,
)
self.user_email = self._validate_api_key()
capture_client_event("client.init", self, {"sync_type": "async"})
def _validate_api_key(self):
"""Validate the API key by making a test request."""
try:
params = self._prepare_params()
response = requests.get(f"{self.host}/v1/ping/", headers={"Authorization": f"Token {self.api_key}", "Mem0-User-ID": self.user_id}, params=params)
data = response.json()
response.raise_for_status()
if data.get("org_id") and data.get("project_id"):
self.org_id = data.get("org_id")
self.project_id = data.get("project_id")
return data.get("user_email")
except requests.HTTPStatusError as e:
try:
error_data = e.response.json()
error_message = error_data.get("detail", str(e))
except Exception:
error_message = str(e)
raise ValueError(f"Error: {error_message}")
def _prepare_payload(
self, messages: Union[str, List[Dict[str, str]], None], kwargs: Dict[str, Any]
) -> Dict[str, Any]:
"""Prepare the payload for API requests.
Args:
messages: The messages to include in the payload.
kwargs: Additional keyword arguments to include in the payload.
Returns:
A dictionary containing the prepared payload.
"""
payload = {}
if isinstance(messages, str):
payload["messages"] = [{"role": "user", "content": messages}]
elif isinstance(messages, list):
payload["messages"] = messages
payload.update({k: v for k, v in kwargs.items() if v is not None})
return payload
def _prepare_params(self, kwargs: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
"""Prepare query parameters for API requests.
Args:
kwargs: Keyword arguments to include in the parameters.
Returns:
A dictionary containing the prepared parameters.
Raises:
ValueError: If either org_id or project_id is provided but not both.
"""
if kwargs is None:
kwargs = {}
# Add org_id and project_id if both are available
if self.org_id and self.project_id:
kwargs["org_id"] = self.org_id
kwargs["project_id"] = self.project_id
elif self.org_id or self.project_id:
raise ValueError("Please provide both org_id and project_id")
return {k: v for k, v in kwargs.items() if v is not None}
async def __aenter__(self): async def __aenter__(self):
return self return self
@@ -724,89 +841,102 @@ class AsyncMemoryClient:
@api_error_handler @api_error_handler
async def add(self, messages: Union[str, List[Dict[str, str]]], **kwargs) -> Dict[str, Any]: async def add(self, messages: Union[str, List[Dict[str, str]]], **kwargs) -> Dict[str, Any]:
kwargs = self.sync_client._prepare_params(kwargs) kwargs = self._prepare_params(kwargs)
payload = self.sync_client._prepare_payload(messages, kwargs) if kwargs.get("output_format") != "v1.1":
kwargs["output_format"] = "v1.1"
warnings.warn(
"output_format='v1.0' is deprecated therefore setting it to 'v1.1' by default."
"Check out the docs for more information: https://docs.mem0.ai/platform/quickstart#4-1-create-memories",
DeprecationWarning,
stacklevel=2,
)
kwargs["version"] = "v2"
payload = self._prepare_payload(messages, kwargs)
response = await self.async_client.post("/v1/memories/", json=payload) response = await self.async_client.post("/v1/memories/", json=payload)
response.raise_for_status() response.raise_for_status()
if "metadata" in kwargs: if "metadata" in kwargs:
del kwargs["metadata"] del kwargs["metadata"]
capture_client_event("client.add", self.sync_client, {"keys": list(kwargs.keys()), "sync_type": "async"}) capture_client_event("client.add", self, {"keys": list(kwargs.keys()), "sync_type": "async"})
return response.json() return response.json()
@api_error_handler @api_error_handler
async def get(self, memory_id: str) -> Dict[str, Any]: async def get(self, memory_id: str) -> Dict[str, Any]:
params = self.sync_client._prepare_params() params = self._prepare_params()
response = await self.async_client.get(f"/v1/memories/{memory_id}/", params=params) response = await self.async_client.get(f"/v1/memories/{memory_id}/", params=params)
response.raise_for_status() response.raise_for_status()
capture_client_event("client.get", self.sync_client, {"memory_id": memory_id, "sync_type": "async"}) capture_client_event("client.get", self, {"memory_id": memory_id, "sync_type": "async"})
return response.json() return response.json()
@api_error_handler @api_error_handler
async def get_all(self, version: str = "v1", **kwargs) -> List[Dict[str, Any]]: async def get_all(self, version: str = "v1", **kwargs) -> List[Dict[str, Any]]:
params = self.sync_client._prepare_params(kwargs) params = self._prepare_params(kwargs)
if version == "v1": if version == "v1":
response = await self.async_client.get(f"/{version}/memories/", params=params) response = await self.async_client.get(f"/{version}/memories/", params=params)
elif version == "v2": elif version == "v2":
response = await self.async_client.post(f"/{version}/memories/", json=params) if "page" in params and "page_size" in params:
query_params = {"page": params.pop("page"), "page_size": params.pop("page_size")}
response = await self.async_client.post(f"/{version}/memories/", json=params, params=query_params)
else:
response = await self.async_client.post(f"/{version}/memories/", json=params)
response.raise_for_status() response.raise_for_status()
if "metadata" in kwargs: if "metadata" in kwargs:
del kwargs["metadata"] del kwargs["metadata"]
capture_client_event( capture_client_event(
"client.get_all", self.sync_client, {"api_version": version, "keys": list(kwargs.keys()), "sync_type": "async"} "client.get_all", self, {"api_version": version, "keys": list(kwargs.keys()), "sync_type": "async"}
) )
return response.json() return response.json()
@api_error_handler @api_error_handler
async def search(self, query: str, version: str = "v1", **kwargs) -> List[Dict[str, Any]]: async def search(self, query: str, version: str = "v1", **kwargs) -> List[Dict[str, Any]]:
payload = {"query": query} payload = {"query": query}
payload.update(self.sync_client._prepare_params(kwargs)) payload.update(self._prepare_params(kwargs))
response = await self.async_client.post(f"/{version}/memories/search/", json=payload) response = await self.async_client.post(f"/{version}/memories/search/", json=payload)
response.raise_for_status() response.raise_for_status()
if "metadata" in kwargs: if "metadata" in kwargs:
del kwargs["metadata"] del kwargs["metadata"]
capture_client_event( capture_client_event(
"client.search", self.sync_client, {"api_version": version, "keys": list(kwargs.keys()), "sync_type": "async"} "client.search", self, {"api_version": version, "keys": list(kwargs.keys()), "sync_type": "async"}
) )
return response.json() return response.json()
@api_error_handler @api_error_handler
async def update(self, memory_id: str, data: str) -> Dict[str, Any]: async def update(self, memory_id: str, data: str) -> Dict[str, Any]:
params = self.sync_client._prepare_params() params = self._prepare_params()
response = await self.async_client.put(f"/v1/memories/{memory_id}/", json={"text": data}, params=params) response = await self.async_client.put(f"/v1/memories/{memory_id}/", json={"text": data}, params=params)
response.raise_for_status() response.raise_for_status()
capture_client_event("client.update", self.sync_client, {"memory_id": memory_id, "sync_type": "async"}) capture_client_event("client.update", self, {"memory_id": memory_id, "sync_type": "async"})
return response.json() return response.json()
@api_error_handler @api_error_handler
async def delete(self, memory_id: str) -> Dict[str, Any]: async def delete(self, memory_id: str) -> Dict[str, Any]:
params = self.sync_client._prepare_params() params = self._prepare_params()
response = await self.async_client.delete(f"/v1/memories/{memory_id}/", params=params) response = await self.async_client.delete(f"/v1/memories/{memory_id}/", params=params)
response.raise_for_status() response.raise_for_status()
capture_client_event("client.delete", self.sync_client, {"memory_id": memory_id, "sync_type": "async"}) capture_client_event("client.delete", self, {"memory_id": memory_id, "sync_type": "async"})
return response.json() return response.json()
@api_error_handler @api_error_handler
async def delete_all(self, **kwargs) -> Dict[str, str]: async def delete_all(self, **kwargs) -> Dict[str, str]:
params = self.sync_client._prepare_params(kwargs) params = self._prepare_params(kwargs)
response = await self.async_client.delete("/v1/memories/", params=params) response = await self.async_client.delete("/v1/memories/", params=params)
response.raise_for_status() response.raise_for_status()
capture_client_event("client.delete_all", self.sync_client, {"keys": list(kwargs.keys()), "sync_type": "async"}) capture_client_event("client.delete_all", self, {"keys": list(kwargs.keys()), "sync_type": "async"})
return response.json() return response.json()
@api_error_handler @api_error_handler
async def history(self, memory_id: str) -> List[Dict[str, Any]]: async def history(self, memory_id: str) -> List[Dict[str, Any]]:
params = self.sync_client._prepare_params() params = self._prepare_params()
response = await self.async_client.get(f"/v1/memories/{memory_id}/history/", params=params) response = await self.async_client.get(f"/v1/memories/{memory_id}/history/", params=params)
response.raise_for_status() response.raise_for_status()
capture_client_event("client.history", self.sync_client, {"memory_id": memory_id, "sync_type": "async"}) capture_client_event("client.history", self, {"memory_id": memory_id, "sync_type": "async"})
return response.json() return response.json()
@api_error_handler @api_error_handler
async def users(self) -> Dict[str, Any]: async def users(self) -> Dict[str, Any]:
params = self.sync_client._prepare_params() params = self._prepare_params()
response = await self.async_client.get("/v1/entities/", params=params) response = await self.async_client.get("/v1/entities/", params=params)
response.raise_for_status() response.raise_for_status()
capture_client_event("client.users", self.sync_client, {"sync_type": "async"}) capture_client_event("client.users", self, {"sync_type": "async"})
return response.json() return response.json()
@api_error_handler @api_error_handler
@@ -848,7 +978,7 @@ class AsyncMemoryClient:
for entity in entities["results"] for entity in entities["results"]
] ]
params = self.sync_client._prepare_params() params = self._prepare_params()
if not to_delete: if not to_delete:
raise ValueError("No entities to delete") raise ValueError("No entities to delete")
@@ -858,7 +988,7 @@ class AsyncMemoryClient:
response = await self.async_client.delete(f"/v2/entities/{entity['type']}/{entity['name']}/", params=params) response = await self.async_client.delete(f"/v2/entities/{entity['type']}/{entity['name']}/", params=params)
response.raise_for_status() response.raise_for_status()
capture_client_event("client.delete_users", self.sync_client, {"sync_type": "async"}) capture_client_event("client.delete_users", self, {"user_id": user_id, "agent_id": agent_id, "app_id": app_id, "run_id": run_id, "sync_type": "async"})
return { return {
"message": "Entity deleted successfully." "message": "Entity deleted successfully."
if (user_id or agent_id or app_id or run_id) if (user_id or agent_id or app_id or run_id)
@@ -868,7 +998,7 @@ class AsyncMemoryClient:
@api_error_handler @api_error_handler
async def reset(self) -> Dict[str, str]: async def reset(self) -> Dict[str, str]:
await self.delete_users() await self.delete_users()
capture_client_event("client.reset", self.sync_client, {"sync_type": "async"}) capture_client_event("client.reset", self, {"sync_type": "async"})
return {"message": "Client reset successful. All users and memories deleted."} return {"message": "Client reset successful. All users and memories deleted."}
@api_error_handler @api_error_handler
@@ -889,7 +1019,7 @@ class AsyncMemoryClient:
response = await self.async_client.put("/v1/batch/", json={"memories": memories}) response = await self.async_client.put("/v1/batch/", json={"memories": memories})
response.raise_for_status() response.raise_for_status()
capture_client_event("client.batch_update", self.sync_client, {"sync_type": "async"}) capture_client_event("client.batch_update", self, {"sync_type": "async"})
return response.json() return response.json()
@api_error_handler @api_error_handler
@@ -909,7 +1039,7 @@ class AsyncMemoryClient:
response = await self.async_client.request("DELETE", "/v1/batch/", json={"memories": memories}) response = await self.async_client.request("DELETE", "/v1/batch/", json={"memories": memories})
response.raise_for_status() response.raise_for_status()
capture_client_event("client.batch_delete", self.sync_client, {"sync_type": "async"}) capture_client_event("client.batch_delete", self, {"sync_type": "async"})
return response.json() return response.json()
@api_error_handler @api_error_handler
@@ -926,7 +1056,7 @@ class AsyncMemoryClient:
response = await self.async_client.post("/v1/exports/", json={"schema": schema, **self._prepare_params(kwargs)}) response = await self.async_client.post("/v1/exports/", json={"schema": schema, **self._prepare_params(kwargs)})
response.raise_for_status() response.raise_for_status()
capture_client_event( capture_client_event(
"client.create_memory_export", self.sync_client, {"schema": schema, "keys": list(kwargs.keys()), "sync_type": "async"} "client.create_memory_export", self, {"schema": schema, "keys": list(kwargs.keys()), "sync_type": "async"}
) )
return response.json() return response.json()
@@ -942,21 +1072,21 @@ class AsyncMemoryClient:
""" """
response = await self.async_client.post("/v1/exports/get/", json=self._prepare_params(kwargs)) response = await self.async_client.post("/v1/exports/get/", json=self._prepare_params(kwargs))
response.raise_for_status() response.raise_for_status()
capture_client_event("client.get_memory_export", self.sync_client, {"keys": list(kwargs.keys()), "sync_type": "async"}) capture_client_event("client.get_memory_export", self, {"keys": list(kwargs.keys()), "sync_type": "async"})
return response.json() return response.json()
@api_error_handler @api_error_handler
async def get_project(self, fields: Optional[List[str]] = None) -> Dict[str, Any]: async def get_project(self, fields: Optional[List[str]] = None) -> Dict[str, Any]:
if not (self.sync_client.org_id and self.sync_client.project_id): if not (self.org_id and self.project_id):
raise ValueError("org_id and project_id must be set to access instructions or categories") raise ValueError("org_id and project_id must be set to access instructions or categories")
params = self.sync_client._prepare_params({"fields": fields}) params = self._prepare_params({"fields": fields})
response = await self.async_client.get( response = await self.async_client.get(
f"/api/v1/orgs/organizations/{self.sync_client.org_id}/projects/{self.sync_client.project_id}/", f"/api/v1/orgs/organizations/{self.org_id}/projects/{self.project_id}/",
params=params, params=params,
) )
response.raise_for_status() response.raise_for_status()
capture_client_event("client.get_project", self.sync_client, {"fields": fields, "sync_type": "async"}) capture_client_event("client.get_project", self, {"fields": fields, "sync_type": "async"})
return response.json() return response.json()
@api_error_handler @api_error_handler
@@ -964,7 +1094,7 @@ class AsyncMemoryClient:
self, custom_instructions: Optional[str] = None, custom_categories: Optional[List[str]] = None, self, custom_instructions: Optional[str] = None, custom_categories: Optional[List[str]] = None,
retrieval_criteria: Optional[List[Dict[str, Any]]] = None retrieval_criteria: Optional[List[Dict[str, Any]]] = None
) -> Dict[str, Any]: ) -> Dict[str, Any]:
if not (self.sync_client.org_id and self.sync_client.project_id): if not (self.org_id and self.project_id):
raise ValueError("org_id and project_id must be set to update instructions or categories") raise ValueError("org_id and project_id must be set to update instructions or categories")
if custom_instructions is None and custom_categories is None and retrieval_criteria is None: if custom_instructions is None and custom_categories is None and retrieval_criteria is None:
@@ -972,17 +1102,17 @@ class AsyncMemoryClient:
"Currently we only support updating custom_instructions or custom_categories or retrieval_criteria, so you must provide at least one of them" "Currently we only support updating custom_instructions or custom_categories or retrieval_criteria, so you must provide at least one of them"
) )
payload = self.sync_client._prepare_params( payload = self._prepare_params(
{"custom_instructions": custom_instructions, "custom_categories": custom_categories, "retrieval_criteria": retrieval_criteria} {"custom_instructions": custom_instructions, "custom_categories": custom_categories, "retrieval_criteria": retrieval_criteria}
) )
response = await self.async_client.patch( response = await self.async_client.patch(
f"/api/v1/orgs/organizations/{self.sync_client.org_id}/projects/{self.sync_client.project_id}/", f"/api/v1/orgs/organizations/{self.org_id}/projects/{self.project_id}/",
json=payload, json=payload,
) )
response.raise_for_status() response.raise_for_status()
capture_client_event( capture_client_event(
"client.update_project", "client.update_project",
self.sync_client, self,
{"custom_instructions": custom_instructions, "custom_categories": custom_categories, "retrieval_criteria": retrieval_criteria, "sync_type": "async"}, {"custom_instructions": custom_instructions, "custom_categories": custom_categories, "retrieval_criteria": retrieval_criteria, "sync_type": "async"},
) )
return response.json() return response.json()
@@ -996,7 +1126,7 @@ class AsyncMemoryClient:
f"api/v1/webhooks/projects/{project_id}/", f"api/v1/webhooks/projects/{project_id}/",
) )
response.raise_for_status() response.raise_for_status()
capture_client_event("client.get_webhook", self.sync_client, {"sync_type": "async"}) capture_client_event("client.get_webhook", self, {"sync_type": "async"})
return response.json() return response.json()
@api_error_handler @api_error_handler
@@ -1004,7 +1134,7 @@ class AsyncMemoryClient:
payload = {"url": url, "name": name, "event_types": event_types} payload = {"url": url, "name": name, "event_types": event_types}
response = await self.async_client.post(f"api/v1/webhooks/projects/{project_id}/", json=payload) response = await self.async_client.post(f"api/v1/webhooks/projects/{project_id}/", json=payload)
response.raise_for_status() response.raise_for_status()
capture_client_event("client.create_webhook", self.sync_client, {"sync_type": "async"}) capture_client_event("client.create_webhook", self, {"sync_type": "async"})
return response.json() return response.json()
@api_error_handler @api_error_handler
@@ -1018,14 +1148,14 @@ class AsyncMemoryClient:
payload = {k: v for k, v in {"name": name, "url": url, "event_types": event_types}.items() if v is not None} payload = {k: v for k, v in {"name": name, "url": url, "event_types": event_types}.items() if v is not None}
response = await self.async_client.put(f"api/v1/webhooks/{webhook_id}/", json=payload) response = await self.async_client.put(f"api/v1/webhooks/{webhook_id}/", json=payload)
response.raise_for_status() response.raise_for_status()
capture_client_event("client.update_webhook", self.sync_client, {"webhook_id": webhook_id, "sync_type": "async"}) capture_client_event("client.update_webhook", self, {"webhook_id": webhook_id, "sync_type": "async"})
return response.json() return response.json()
@api_error_handler @api_error_handler
async def delete_webhook(self, webhook_id: int) -> Dict[str, str]: async def delete_webhook(self, webhook_id: int) -> Dict[str, str]:
response = await self.async_client.delete(f"api/v1/webhooks/{webhook_id}/") response = await self.async_client.delete(f"api/v1/webhooks/{webhook_id}/")
response.raise_for_status() response.raise_for_status()
capture_client_event("client.delete_webhook", self.sync_client, {"webhook_id": webhook_id, "sync_type": "async"}) capture_client_event("client.delete_webhook", self, {"webhook_id": webhook_id, "sync_type": "async"})
return response.json() return response.json()
@api_error_handler @api_error_handler
@@ -1042,5 +1172,6 @@ class AsyncMemoryClient:
response = await self.async_client.post("/v1/feedback/", json=data) response = await self.async_client.post("/v1/feedback/", json=data)
response.raise_for_status() response.raise_for_status()
capture_client_event("client.feedback", self.sync_client, data, {"sync_type": "async"}) capture_client_event("client.feedback", self, data, {"sync_type": "async"})
return response.json() return response.json()