[Mem0] Update platform client, improve deduction logic and update client docs (#1510)

This commit is contained in:
Deshraj Yadav
2024-07-20 02:44:33 -07:00
committed by GitHub
parent c27ab0585c
commit c7b9498693
3 changed files with 422 additions and 378 deletions

View File

@@ -1,8 +1,10 @@
import httpx
import os
import logging
import warnings
from typing import Optional, Dict, Any
import os
from functools import wraps
from typing import Any, Dict, List, Optional, Union
import httpx
from mem0.memory.setup import setup_config
from mem0.memory.telemetry import capture_client_event
@@ -12,232 +14,258 @@ logger = logging.getLogger(__name__)
setup_config()
class APIError(Exception):
"""Exception raised for errors in the API."""
pass
def api_error_handler(func):
"""Decorator to handle API errors consistently."""
@wraps(func)
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except httpx.HTTPStatusError as e:
logger.error(f"HTTP error occurred: {e}")
raise APIError(f"API request failed: {e.response.text}")
except httpx.RequestError as e:
logger.error(f"Request error occurred: {e}")
raise APIError(f"Request failed: {str(e)}")
return wrapper
class MemoryClient:
"""Client for interacting with the Mem0 API.
This class provides methods to create, retrieve, search, and delete memories
using the Mem0 API.
Attributes:
api_key (str): The API key for authenticating with the Mem0 API.
host (str): The base URL for the Mem0 API.
client (httpx.Client): The HTTP client used for making API requests.
"""
def __init__(self, api_key: Optional[str] = None, host: Optional[str] = None):
"""
Initialize the Mem0 client.
"""Initialize the MemoryClient.
Args:
api_key (Optional[str]): API Key from Mem0 Platform. Defaults to environment variable 'MEM0_API_KEY' if not provided.
host (Optional[str]): API host URL. Defaults to 'https://api.mem0.ai/v1'.
api_key: The API key for authenticating with the Mem0 API. If not provided,
it will attempt to use the MEM0_API_KEY environment variable.
host: The base URL for the Mem0 API. Defaults to "https://api.mem0.ai/v1".
Raises:
ValueError: If no API key is provided or found in the environment.
"""
self.api_key = api_key or os.getenv("MEM0_API_KEY")
self.host = host or "https://api.mem0.ai/v1"
if not self.api_key:
raise ValueError("API Key not provided. Please provide an API Key.")
self.client = httpx.Client(
base_url=self.host,
headers={"Authorization": f"Token {self.api_key}"},
timeout=60,
)
self._validate_api_key()
capture_client_event("client.init", self)
def _validate_api_key(self):
if not self.api_key:
warnings.warn("API Key not provided. Please provide an API Key.")
response = self.client.get("/memories/", params={"user_id": "test"})
if response.status_code != 200:
"""Validate the API key by making a test request."""
try:
response = self.client.get("/memories/", params={"user_id": "test"})
response.raise_for_status()
except httpx.HTTPStatusError:
raise ValueError(
"Invalid API Key. Please get a valid API Key from https://app.mem0.ai"
)
@api_error_handler
def add(
self,
data: str,
user_id: Optional[str] = None,
agent_id: Optional[str] = None,
session_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
filters: Optional[Dict[str, Any]] = None,
self, messages: Union[str, List[Dict[str, str]]], **kwargs
) -> Dict[str, Any]:
"""
Create a new memory.
"""Add a new memory.
Args:
data (str): The data to be stored in the memory.
user_id (Optional[str]): User ID to save the memory specific to a user. Defaults to None.
agent_id (Optional[str]): Agent ID for agent-specific memory. Defaults to None.
session_id (Optional[str]): Run ID to save memory for a specific session. Defaults to None.
metadata (Optional[Dict[str, Any]]): Metadata to be saved with the memory. Defaults to None.
filters (Optional[Dict[str, Any]]): Filters to apply to the memory. Defaults to None.
messages: Either a string message or a list of message dictionaries.
**kwargs: Additional parameters such as user_id, agent_id, session_id, metadata, filters.
Returns:
Dict[str, Any]: The response from the server.
A dictionary containing the API response.
Raises:
APIError: If the API request fails.
"""
payload = self._prepare_payload(messages, kwargs)
response = self.client.post("/memories/", json=payload)
response.raise_for_status()
capture_client_event("client.add", self)
payload = {"text": data}
if metadata:
payload["metadata"] = metadata
if filters:
payload["filters"] = filters
if user_id:
payload["user_id"] = user_id
if agent_id:
payload["agent_id"] = agent_id
if session_id:
payload["run_id"] = session_id
response = self.client.post("/memories/", json=payload, timeout=60)
if response.status_code != 200:
logger.error(response.json())
raise ValueError(f"Failed to add memory. Response: {response.json()}")
return response.json()
@api_error_handler
def get(self, memory_id: str) -> Dict[str, Any]:
"""
Get a memory by ID.
"""Retrieve a specific memory by ID.
Args:
memory_id (str): Memory ID.
memory_id: The ID of the memory to retrieve.
Returns:
Dict[str, Any]: The memory data.
A dictionary containing the memory data.
Raises:
APIError: If the API request fails.
"""
capture_client_event("client.get", self)
response = self.client.get(f"/memories/{memory_id}/")
response.raise_for_status()
capture_client_event("client.get", self)
return response.json()
def get_all(
self,
user_id: Optional[str] = None,
agent_id: Optional[str] = None,
session_id: Optional[str] = None,
limit: int = 100,
) -> Dict[str, Any]:
"""
Get all memories.
@api_error_handler
def get_all(self, **kwargs) -> Dict[str, Any]:
"""Retrieve all memories, with optional filtering.
Args:
user_id (Optional[str]): User ID to filter memories. Defaults to None.
agent_id (Optional[str]): Agent ID to filter memories. Defaults to None.
session_id (Optional[str]): Run ID to filter memories. Defaults to None.
limit (int): Number of memories to return. Defaults to 100.
**kwargs: Optional parameters for filtering (user_id, agent_id, session_id, limit).
Returns:
Dict[str, Any]: The list of memories.
A dictionary containing the list of memories.
Raises:
APIError: If the API request fails.
"""
params = {
"user_id": user_id,
"agent_id": agent_id,
"run_id": session_id,
"limit": limit,
}
response = self.client.get(
"/memories/", params={k: v for k, v in params.items() if v is not None}
)
params = self._prepare_params(kwargs)
response = self.client.get("/memories/", params=params)
response.raise_for_status()
capture_client_event(
"client.get_all", self, {"filters": len(params), "limit": limit}
"client.get_all",
self,
{"filters": len(params), "limit": kwargs.get("limit", 100)},
)
return response.json()
def search(
self,
query: str,
user_id: Optional[str] = None,
agent_id: Optional[str] = None,
session_id: Optional[str] = None,
limit: int = 100,
filters: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
"""
Search memories.
@api_error_handler
def search(self, query: str, **kwargs) -> Dict[str, Any]:
"""Search memories based on a query.
Args:
query (str): Query to search for in the memories.
user_id (Optional[str]): User ID to filter memories. Defaults to None.
agent_id (Optional[str]): Agent ID to filter memories. Defaults to None.
session_id (Optional[str]): Run ID to filter memories. Defaults to None.
limit (int): Number of memories to return. Defaults to 100.
filters (Optional[Dict[str, Any]]): Filters to apply to the search. Defaults to None.
query: The search query string.
**kwargs: Additional parameters such as user_id, agent_id, session_id, limit, filters.
Returns:
Dict[str, Any]: The search results.
A dictionary containing the search results.
Raises:
APIError: If the API request fails.
"""
payload = {
"text": query,
"limit": limit,
"filters": filters,
"user_id": user_id,
"agent_id": agent_id,
"run_id": session_id,
}
payload = {"query": query}
payload.update({k: v for k, v in kwargs.items() if v is not None})
response = self.client.post("/memories/search/", json=payload)
capture_client_event("client.search", self, {"limit": limit})
return response.json()
def update(self, memory_id: str, data: str) -> Dict[str, Any]:
"""
Update a memory by ID.
Args:
memory_id (str): Memory ID.
data (str): Data to update in the memory.
Returns:
Dict[str, Any]: The response from the server.
"""
capture_client_event("client.update", self)
response = self.client.put(f"/memories/{memory_id}/", json={"text": data})
response.raise_for_status()
capture_client_event("client.search", self, {"limit": kwargs.get("limit", 100)})
return response.json()
@api_error_handler
def delete(self, memory_id: str) -> Dict[str, Any]:
"""
Delete a memory by ID.
"""Delete a specific memory by ID.
Args:
memory_id (str): Memory ID.
memory_id: The ID of the memory to delete.
Returns:
Dict[str, Any]: The response from the server.
A dictionary containing the API response.
Raises:
APIError: If the API request fails.
"""
capture_client_event("client.delete", self)
response = self.client.delete(f"/memories/{memory_id}/")
response.raise_for_status()
capture_client_event("client.delete", self)
return response.json()
def delete_all(
self,
user_id: Optional[str] = None,
agent_id: Optional[str] = None,
session_id: Optional[str] = None,
) -> Dict[str, Any]:
"""
Delete all memories.
@api_error_handler
def delete_all(self, **kwargs) -> Dict[str, Any]:
"""Delete all memories, with optional filtering.
Args:
user_id (Optional[str]): User ID to filter memories. Defaults to None.
agent_id (Optional[str]): Agent ID to filter memories. Defaults to None.
session_id (Optional[str]): Run ID to filter memories. Defaults to None.
**kwargs: Optional parameters for filtering (user_id, agent_id, session_id).
Returns:
Dict[str, Any]: The response from the server.
A dictionary containing the API response.
Raises:
APIError: If the API request fails.
"""
params = {"user_id": user_id, "agent_id": agent_id, "run_id": session_id}
response = self.client.delete(
"/memories/", params={k: v for k, v in params.items() if v is not None}
)
params = self._prepare_params(kwargs)
response = self.client.delete("/memories/", params=params)
response.raise_for_status()
capture_client_event("client.delete_all", self, {"params": len(params)})
return response.json()
@api_error_handler
def history(self, memory_id: str) -> Dict[str, Any]:
"""
Get history of a memory by ID.
"""Retrieve the history of a specific memory.
Args:
memory_id (str): Memory ID.
memory_id: The ID of the memory to retrieve history for.
Returns:
Dict[str, Any]: The memory history.
A dictionary containing the memory history.
Raises:
APIError: If the API request fails.
"""
response = self.client.get(f"/memories/{memory_id}/history/")
response.raise_for_status()
capture_client_event("client.history", self)
return response.json()
def reset(self):
"""
Reset the client. (Not implemented yet)
"""Reset the client. (Not implemented)
Raises:
NotImplementedError: This method is not implemented yet.
"""
raise NotImplementedError("Reset is not implemented yet")
def chat(self):
"""
Start a chat with the Mem0 AI. (Not implemented yet)
"""Start a chat with the Mem0 AI. (Not implemented)
Raises:
NotImplementedError: This method is not implemented yet.
"""
raise NotImplementedError("Chat is not implemented yet")
def _prepare_payload(
self, messages: Union[str, List[Dict[str, str]], None], kwargs: Dict[str, Any]
) -> Dict[str, Any]:
"""Prepare the payload for API requests.
Args:
messages: The messages to include in the payload.
kwargs: Additional keyword arguments to include in the payload.
Returns:
A dictionary containing the prepared payload.
"""
payload = {}
if isinstance(messages, str):
payload["messages"] = [{"role": "user", "content": messages}]
elif isinstance(messages, list):
payload["messages"] = messages
payload.update({k: v for k, v in kwargs.items() if v is not None})
return payload
def _prepare_params(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
"""Prepare query parameters for API requests.
Args:
kwargs: Keyword arguments to include in the parameters.
Returns:
A dictionary containing the prepared parameters.
"""
return {k: v for k, v in kwargs.items() if v is not None}