Add organizations/projects support (#1857)

This commit is contained in:
Dev Khant
2024-09-20 10:51:02 +05:30
committed by GitHub
parent 6102aa76bb
commit db5cb1986a
2 changed files with 23 additions and 5 deletions

View File

@@ -33,8 +33,8 @@ Example with the mem0 Python package:
from mem0 import MemoryClient from mem0 import MemoryClient
client = MemoryClient( client = MemoryClient(
organization_name='YOUR_ORG_NAME', organization='YOUR_ORG_NAME',
project_name='YOUR_PROJECT_NAME', project='YOUR_PROJECT_NAME',
) )
``` ```

View File

@@ -55,19 +55,29 @@ class MemoryClient:
client (httpx.Client): The HTTP client used for making API requests. client (httpx.Client): The HTTP client used for making API requests.
""" """
def __init__(self, api_key: Optional[str] = None, host: Optional[str] = None): def __init__(
self,
api_key: Optional[str] = None,
host: Optional[str] = None,
organization: Optional[str] = None,
project: Optional[str] = None
):
"""Initialize the MemoryClient. """Initialize the MemoryClient.
Args: Args:
api_key: The API key for authenticating with the Mem0 API. If not provided, api_key: The API key for authenticating with the Mem0 API. If not provided,
it will attempt to use the MEM0_API_KEY environment variable. it will attempt to use the MEM0_API_KEY environment variable.
host: The base URL for the Mem0 API. Defaults to "https://api.mem0.ai". host: The base URL for the Mem0 API. Defaults to "https://api.mem0.ai".
org_name: The name of the organization. Optional.
project_name: The name of the project. Optional.
Raises: Raises:
ValueError: If no API key is provided or found in the environment. ValueError: If no API key is provided or found in the environment.
""" """
self.api_key = api_key or os.getenv("MEM0_API_KEY") self.api_key = api_key or os.getenv("MEM0_API_KEY")
self.host = host or "https://api.mem0.ai" self.host = host or "https://api.mem0.ai"
self.organization = organization
self.project = project
self.user_id = get_user_id() self.user_id = get_user_id()
if not self.api_key: if not self.api_key:
@@ -103,6 +113,7 @@ class MemoryClient:
Raises: Raises:
APIError: If the API request fails. APIError: If the API request fails.
""" """
kwargs.update({"org_name": self.organization, "project_name": self.project})
payload = self._prepare_payload(messages, kwargs) payload = self._prepare_payload(messages, kwargs)
response = self.client.post("/v1/memories/", json=payload) response = self.client.post("/v1/memories/", json=payload)
response.raise_for_status() response.raise_for_status()
@@ -140,6 +151,7 @@ class MemoryClient:
Raises: Raises:
APIError: If the API request fails. APIError: If the API request fails.
""" """
kwargs.update({"org_name": self.organization, "project_name": self.project})
params = self._prepare_params(kwargs) params = self._prepare_params(kwargs)
response = self.client.get("/v1/memories/", params=params) response = self.client.get("/v1/memories/", params=params)
response.raise_for_status() response.raise_for_status()
@@ -166,6 +178,7 @@ class MemoryClient:
APIError: If the API request fails. APIError: If the API request fails.
""" """
payload = {"query": query} payload = {"query": query}
kwargs.update({"org_name": self.organization, "project_name": self.project})
payload.update({k: v for k, v in kwargs.items() if v is not None}) payload.update({k: v for k, v in kwargs.items() if v is not None})
response = self.client.post(f"/{version}/memories/search/", json=payload) response = self.client.post(f"/{version}/memories/search/", json=payload)
response.raise_for_status() response.raise_for_status()
@@ -217,6 +230,7 @@ class MemoryClient:
Raises: Raises:
APIError: If the API request fails. APIError: If the API request fails.
""" """
kwargs.update({"org_name": self.organization, "project_name": self.project})
params = self._prepare_params(kwargs) params = self._prepare_params(kwargs)
response = self.client.delete("/v1/memories/", params=params) response = self.client.delete("/v1/memories/", params=params)
response.raise_for_status() response.raise_for_status()
@@ -244,7 +258,8 @@ class MemoryClient:
@api_error_handler @api_error_handler
def users(self): def users(self):
"""Get all users, agents, and sessions for which memories exist.""" """Get all users, agents, and sessions for which memories exist."""
response = self.client.get("/v1/entities/") params = {"org_name": self.organization, "project_name": self.project}
response = self.client.get("/v1/entities/", params=params)
response.raise_for_status() response.raise_for_status()
capture_client_event("client.users", self) capture_client_event("client.users", self)
return response.json() return response.json()
@@ -252,9 +267,12 @@ class MemoryClient:
@api_error_handler @api_error_handler
def delete_users(self) -> Dict[str, str]: def delete_users(self) -> Dict[str, str]:
"""Delete all users, agents, or sessions.""" """Delete all users, agents, or sessions."""
params = {"org_name": self.organization, "project_name": self.project}
entities = self.users() entities = self.users()
for entity in entities["results"]: for entity in entities["results"]:
response = self.client.delete(f"/v1/entities/{entity['type']}/{entity['id']}/") response = self.client.delete(
f"/v1/entities/{entity['type']}/{entity['id']}/", params=params
)
response.raise_for_status() response.raise_for_status()
capture_client_event("client.delete_users", self) capture_client_event("client.delete_users", self)