Support async client (#1980)
This commit is contained in:
@@ -56,12 +56,12 @@ class MemoryClient:
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: Optional[str] = None,
|
||||
host: Optional[str] = None,
|
||||
organization: Optional[str] = None,
|
||||
project: Optional[str] = None
|
||||
):
|
||||
self,
|
||||
api_key: Optional[str] = None,
|
||||
host: Optional[str] = None,
|
||||
organization: Optional[str] = None,
|
||||
project: Optional[str] = None,
|
||||
):
|
||||
"""Initialize the MemoryClient.
|
||||
|
||||
Args:
|
||||
@@ -275,9 +275,7 @@ class MemoryClient:
|
||||
params = {"org_name": self.organization, "project_name": self.project}
|
||||
entities = self.users()
|
||||
for entity in entities["results"]:
|
||||
response = self.client.delete(
|
||||
f"/v1/entities/{entity['type']}/{entity['id']}/", params=params
|
||||
)
|
||||
response = self.client.delete(f"/v1/entities/{entity['type']}/{entity['id']}/", params=params)
|
||||
response.raise_for_status()
|
||||
|
||||
capture_client_event("client.delete_users", self)
|
||||
@@ -362,3 +360,120 @@ class MemoryClient:
|
||||
kwargs["run_id"] = kwargs.pop("session_id")
|
||||
|
||||
return {k: v for k, v in kwargs.items() if v is not None}
|
||||
|
||||
|
||||
class AsyncMemoryClient:
|
||||
"""Asynchronous client for interacting with the Mem0 API."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: Optional[str] = None,
|
||||
host: Optional[str] = None,
|
||||
organization: Optional[str] = None,
|
||||
project: Optional[str] = None,
|
||||
):
|
||||
self.sync_client = MemoryClient(api_key, host, organization, project)
|
||||
self.async_client = httpx.AsyncClient(
|
||||
base_url=self.sync_client.host,
|
||||
headers=self.sync_client.client.headers,
|
||||
timeout=60,
|
||||
)
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
await self.async_client.aclose()
|
||||
|
||||
@api_error_handler
|
||||
async def add(self, messages: Union[str, List[Dict[str, str]]], **kwargs) -> Dict[str, Any]:
|
||||
payload = self.sync_client._prepare_payload(messages, kwargs)
|
||||
response = await self.async_client.post("/v1/memories/", json=payload)
|
||||
response.raise_for_status()
|
||||
capture_client_event("async_client.add", self.sync_client)
|
||||
return response.json()
|
||||
|
||||
@api_error_handler
|
||||
async def get(self, memory_id: str) -> Dict[str, Any]:
|
||||
response = await self.async_client.get(f"/v1/memories/{memory_id}/")
|
||||
response.raise_for_status()
|
||||
capture_client_event("async_client.get", self.sync_client)
|
||||
return response.json()
|
||||
|
||||
@api_error_handler
|
||||
async def get_all(self, version: str = "v1", **kwargs) -> List[Dict[str, Any]]:
|
||||
params = self.sync_client._prepare_params(kwargs)
|
||||
if version == "v1":
|
||||
response = await self.async_client.get(f"/{version}/memories/", params=params)
|
||||
elif version == "v2":
|
||||
response = await self.async_client.post(f"/{version}/memories/", json=params)
|
||||
response.raise_for_status()
|
||||
capture_client_event(
|
||||
"async_client.get_all", self.sync_client, {"filters": len(params), "limit": kwargs.get("limit", 100)}
|
||||
)
|
||||
return response.json()
|
||||
|
||||
@api_error_handler
|
||||
async def search(self, query: str, version: str = "v1", **kwargs) -> List[Dict[str, Any]]:
|
||||
payload = {"query": query}
|
||||
payload.update(self.sync_client._prepare_params(kwargs))
|
||||
response = await self.async_client.post(f"/{version}/memories/search/", json=payload)
|
||||
response.raise_for_status()
|
||||
capture_client_event("async_client.search", self.sync_client, {"limit": kwargs.get("limit", 100)})
|
||||
return response.json()
|
||||
|
||||
@api_error_handler
|
||||
async def update(self, memory_id: str, data: str) -> Dict[str, Any]:
|
||||
response = await self.async_client.put(f"/v1/memories/{memory_id}/", json={"text": data})
|
||||
response.raise_for_status()
|
||||
capture_client_event("async_client.update", self.sync_client)
|
||||
return response.json()
|
||||
|
||||
@api_error_handler
|
||||
async def delete(self, memory_id: str) -> Dict[str, Any]:
|
||||
response = await self.async_client.delete(f"/v1/memories/{memory_id}/")
|
||||
response.raise_for_status()
|
||||
capture_client_event("async_client.delete", self.sync_client)
|
||||
return response.json()
|
||||
|
||||
@api_error_handler
|
||||
async def delete_all(self, **kwargs) -> Dict[str, str]:
|
||||
params = self.sync_client._prepare_params(kwargs)
|
||||
response = await self.async_client.delete("/v1/memories/", params=params)
|
||||
response.raise_for_status()
|
||||
capture_client_event("async_client.delete_all", self.sync_client, {"params": len(params)})
|
||||
return response.json()
|
||||
|
||||
@api_error_handler
|
||||
async def history(self, memory_id: str) -> List[Dict[str, Any]]:
|
||||
response = await self.async_client.get(f"/v1/memories/{memory_id}/history/")
|
||||
response.raise_for_status()
|
||||
capture_client_event("async_client.history", self.sync_client)
|
||||
return response.json()
|
||||
|
||||
@api_error_handler
|
||||
async def users(self) -> Dict[str, Any]:
|
||||
params = {"org_name": self.sync_client.organization, "project_name": self.sync_client.project}
|
||||
response = await self.async_client.get("/v1/entities/", params=params)
|
||||
response.raise_for_status()
|
||||
capture_client_event("async_client.users", self.sync_client)
|
||||
return response.json()
|
||||
|
||||
@api_error_handler
|
||||
async def delete_users(self) -> Dict[str, str]:
|
||||
params = {"org_name": self.sync_client.organization, "project_name": self.sync_client.project}
|
||||
entities = await self.users()
|
||||
for entity in entities["results"]:
|
||||
response = await self.async_client.delete(f"/v1/entities/{entity['type']}/{entity['id']}/", params=params)
|
||||
response.raise_for_status()
|
||||
capture_client_event("async_client.delete_users", self.sync_client)
|
||||
return {"message": "All users, agents, and sessions deleted."}
|
||||
|
||||
@api_error_handler
|
||||
async def reset(self) -> Dict[str, str]:
|
||||
await self.delete_users()
|
||||
capture_client_event("async_client.reset", self.sync_client)
|
||||
return {"message": "Client reset successful. All users and memories deleted."}
|
||||
|
||||
async def chat(self):
|
||||
raise NotImplementedError("Chat is not implemented yet")
|
||||
|
||||
Reference in New Issue
Block a user