Support async client (#1980)

This commit is contained in:
Dev Khant
2024-10-22 12:42:55 +05:30
committed by GitHub
parent c5d298eec8
commit fbf1d8c372
11 changed files with 213 additions and 58 deletions

View File

@@ -56,12 +56,12 @@ class MemoryClient:
"""
def __init__(
self,
api_key: Optional[str] = None,
host: Optional[str] = None,
organization: Optional[str] = None,
project: Optional[str] = None
):
self,
api_key: Optional[str] = None,
host: Optional[str] = None,
organization: Optional[str] = None,
project: Optional[str] = None,
):
"""Initialize the MemoryClient.
Args:
@@ -275,9 +275,7 @@ class MemoryClient:
params = {"org_name": self.organization, "project_name": self.project}
entities = self.users()
for entity in entities["results"]:
response = self.client.delete(
f"/v1/entities/{entity['type']}/{entity['id']}/", params=params
)
response = self.client.delete(f"/v1/entities/{entity['type']}/{entity['id']}/", params=params)
response.raise_for_status()
capture_client_event("client.delete_users", self)
@@ -362,3 +360,120 @@ class MemoryClient:
kwargs["run_id"] = kwargs.pop("session_id")
return {k: v for k, v in kwargs.items() if v is not None}
class AsyncMemoryClient:
"""Asynchronous client for interacting with the Mem0 API."""
def __init__(
self,
api_key: Optional[str] = None,
host: Optional[str] = None,
organization: Optional[str] = None,
project: Optional[str] = None,
):
self.sync_client = MemoryClient(api_key, host, organization, project)
self.async_client = httpx.AsyncClient(
base_url=self.sync_client.host,
headers=self.sync_client.client.headers,
timeout=60,
)
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.async_client.aclose()
@api_error_handler
async def add(self, messages: Union[str, List[Dict[str, str]]], **kwargs) -> Dict[str, Any]:
payload = self.sync_client._prepare_payload(messages, kwargs)
response = await self.async_client.post("/v1/memories/", json=payload)
response.raise_for_status()
capture_client_event("async_client.add", self.sync_client)
return response.json()
@api_error_handler
async def get(self, memory_id: str) -> Dict[str, Any]:
response = await self.async_client.get(f"/v1/memories/{memory_id}/")
response.raise_for_status()
capture_client_event("async_client.get", self.sync_client)
return response.json()
@api_error_handler
async def get_all(self, version: str = "v1", **kwargs) -> List[Dict[str, Any]]:
params = self.sync_client._prepare_params(kwargs)
if version == "v1":
response = await self.async_client.get(f"/{version}/memories/", params=params)
elif version == "v2":
response = await self.async_client.post(f"/{version}/memories/", json=params)
response.raise_for_status()
capture_client_event(
"async_client.get_all", self.sync_client, {"filters": len(params), "limit": kwargs.get("limit", 100)}
)
return response.json()
@api_error_handler
async def search(self, query: str, version: str = "v1", **kwargs) -> List[Dict[str, Any]]:
payload = {"query": query}
payload.update(self.sync_client._prepare_params(kwargs))
response = await self.async_client.post(f"/{version}/memories/search/", json=payload)
response.raise_for_status()
capture_client_event("async_client.search", self.sync_client, {"limit": kwargs.get("limit", 100)})
return response.json()
@api_error_handler
async def update(self, memory_id: str, data: str) -> Dict[str, Any]:
response = await self.async_client.put(f"/v1/memories/{memory_id}/", json={"text": data})
response.raise_for_status()
capture_client_event("async_client.update", self.sync_client)
return response.json()
@api_error_handler
async def delete(self, memory_id: str) -> Dict[str, Any]:
response = await self.async_client.delete(f"/v1/memories/{memory_id}/")
response.raise_for_status()
capture_client_event("async_client.delete", self.sync_client)
return response.json()
@api_error_handler
async def delete_all(self, **kwargs) -> Dict[str, str]:
params = self.sync_client._prepare_params(kwargs)
response = await self.async_client.delete("/v1/memories/", params=params)
response.raise_for_status()
capture_client_event("async_client.delete_all", self.sync_client, {"params": len(params)})
return response.json()
@api_error_handler
async def history(self, memory_id: str) -> List[Dict[str, Any]]:
response = await self.async_client.get(f"/v1/memories/{memory_id}/history/")
response.raise_for_status()
capture_client_event("async_client.history", self.sync_client)
return response.json()
@api_error_handler
async def users(self) -> Dict[str, Any]:
params = {"org_name": self.sync_client.organization, "project_name": self.sync_client.project}
response = await self.async_client.get("/v1/entities/", params=params)
response.raise_for_status()
capture_client_event("async_client.users", self.sync_client)
return response.json()
@api_error_handler
async def delete_users(self) -> Dict[str, str]:
params = {"org_name": self.sync_client.organization, "project_name": self.sync_client.project}
entities = await self.users()
for entity in entities["results"]:
response = await self.async_client.delete(f"/v1/entities/{entity['type']}/{entity['id']}/", params=params)
response.raise_for_status()
capture_client_event("async_client.delete_users", self.sync_client)
return {"message": "All users, agents, and sessions deleted."}
@api_error_handler
async def reset(self) -> Dict[str, str]:
await self.delete_users()
capture_client_event("async_client.reset", self.sync_client)
return {"message": "Client reset successful. All users and memories deleted."}
async def chat(self):
raise NotImplementedError("Chat is not implemented yet")