Version bump and client fixes (#2017)

This commit is contained in:
Dev Khant
2024-11-07 11:36:56 +05:30
committed by GitHub
parent 549e5e3ce8
commit 3731965537
2 changed files with 15 additions and 12 deletions

View File

@@ -1,8 +1,8 @@
import logging import logging
import os import os
import warnings
from functools import wraps from functools import wraps
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union
import warnings
import httpx import httpx
@@ -122,7 +122,7 @@ class MemoryClient:
Raises: Raises:
APIError: If the API request fails. APIError: If the API request fails.
""" """
kwargs.update({"org_name": self.organization, "project_name": self.project}) kwargs = self._prepare_params(kwargs)
payload = self._prepare_payload(messages, kwargs) payload = self._prepare_payload(messages, kwargs)
response = self.client.post("/v1/memories/", json=payload) response = self.client.post("/v1/memories/", json=payload)
response.raise_for_status() response.raise_for_status()
@@ -163,7 +163,6 @@ class MemoryClient:
Raises: Raises:
APIError: If the API request fails. APIError: If the API request fails.
""" """
kwargs.update({"org_name": self.organization, "project_name": self.project})
params = self._prepare_params(kwargs) params = self._prepare_params(kwargs)
if version == "v1": if version == "v1":
response = self.client.get(f"/{version}/memories/", params=params) response = self.client.get(f"/{version}/memories/", params=params)
@@ -195,8 +194,8 @@ class MemoryClient:
APIError: If the API request fails. APIError: If the API request fails.
""" """
payload = {"query": query} payload = {"query": query}
kwargs.update({"org_name": self.organization, "project_name": self.project}) params = self._prepare_params(kwargs)
payload.update({k: v for k, v in kwargs.items() if v is not None}) payload.update(params)
response = self.client.post(f"/{version}/memories/search/", json=payload) response = self.client.post(f"/{version}/memories/search/", json=payload)
response.raise_for_status() response.raise_for_status()
if "metadata" in kwargs: if "metadata" in kwargs:
@@ -250,7 +249,6 @@ class MemoryClient:
Raises: Raises:
APIError: If the API request fails. APIError: If the API request fails.
""" """
kwargs.update({"org_name": self.organization, "project_name": self.project})
params = self._prepare_params(kwargs) params = self._prepare_params(kwargs)
response = self.client.delete("/v1/memories/", params=params) response = self.client.delete("/v1/memories/", params=params)
response.raise_for_status() response.raise_for_status()
@@ -278,7 +276,7 @@ class MemoryClient:
@api_error_handler @api_error_handler
def users(self) -> Dict[str, Any]: def users(self) -> Dict[str, Any]:
"""Get all users, agents, and sessions for which memories exist.""" """Get all users, agents, and sessions for which memories exist."""
params = {"org_name": self.organization, "project_name": self.project} params = self._prepare_params()
response = self.client.get("/v1/entities/", params=params) response = self.client.get("/v1/entities/", params=params)
response.raise_for_status() response.raise_for_status()
capture_client_event("client.users", self) capture_client_event("client.users", self)
@@ -287,7 +285,7 @@ class MemoryClient:
@api_error_handler @api_error_handler
def delete_users(self) -> Dict[str, str]: def delete_users(self) -> Dict[str, str]:
"""Delete all users, agents, or sessions.""" """Delete all users, agents, or sessions."""
params = {"org_name": self.organization, "project_name": self.project} params = self._prepare_params()
entities = self.users() entities = self.users()
for entity in entities["results"]: for entity in entities["results"]:
response = self.client.delete(f"/v1/entities/{entity['type']}/{entity['id']}/", params=params) response = self.client.delete(f"/v1/entities/{entity['type']}/{entity['id']}/", params=params)
@@ -344,7 +342,7 @@ class MemoryClient:
payload.update({k: v for k, v in kwargs.items() if v is not None}) payload.update({k: v for k, v in kwargs.items() if v is not None})
return payload return payload
def _prepare_params(self, kwargs: Dict[str, Any]) -> Dict[str, Any]: def _prepare_params(self, kwargs: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
"""Prepare query parameters for API requests. """Prepare query parameters for API requests.
Args: Args:
@@ -356,6 +354,10 @@ class MemoryClient:
Raises: Raises:
ValueError: If both org_id/project_id and org_name/project_name are provided. ValueError: If both org_id/project_id and org_name/project_name are provided.
""" """
if kwargs is None:
kwargs = {}
has_new = bool(self.org_id or self.project_id) has_new = bool(self.org_id or self.project_id)
has_old = bool(self.organization or self.project) has_old = bool(self.organization or self.project)
@@ -414,6 +416,7 @@ class AsyncMemoryClient:
@api_error_handler @api_error_handler
async def add(self, messages: Union[str, List[Dict[str, str]]], **kwargs) -> Dict[str, Any]: async def add(self, messages: Union[str, List[Dict[str, str]]], **kwargs) -> Dict[str, Any]:
kwargs = self.sync_client._prepare_params(kwargs)
payload = self.sync_client._prepare_payload(messages, kwargs) payload = self.sync_client._prepare_payload(messages, kwargs)
response = await self.async_client.post("/v1/memories/", json=payload) response = await self.async_client.post("/v1/memories/", json=payload)
response.raise_for_status() response.raise_for_status()
@@ -488,7 +491,7 @@ class AsyncMemoryClient:
@api_error_handler @api_error_handler
async def users(self) -> Dict[str, Any]: async def users(self) -> Dict[str, Any]:
params = {"org_name": self.sync_client.organization, "project_name": self.sync_client.project} params = self.sync_client._prepare_params()
response = await self.async_client.get("/v1/entities/", params=params) response = await self.async_client.get("/v1/entities/", params=params)
response.raise_for_status() response.raise_for_status()
capture_client_event("async_client.users", self.sync_client) capture_client_event("async_client.users", self.sync_client)
@@ -496,7 +499,7 @@ class AsyncMemoryClient:
@api_error_handler @api_error_handler
async def delete_users(self) -> Dict[str, str]: async def delete_users(self) -> Dict[str, str]:
params = {"org_name": self.sync_client.organization, "project_name": self.sync_client.project} params = self.sync_client._prepare_params()
entities = await self.users() entities = await self.users()
for entity in entities["results"]: for entity in entities["results"]:
response = await self.async_client.delete(f"/v1/entities/{entity['type']}/{entity['id']}/", params=params) response = await self.async_client.delete(f"/v1/entities/{entity['type']}/{entity['id']}/", params=params)

View File

@@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "mem0ai" name = "mem0ai"
version = "0.1.28" version = "0.1.29"
description = "Long-term memory for AI Agents" description = "Long-term memory for AI Agents"
authors = ["Mem0 <founders@mem0.ai>"] authors = ["Mem0 <founders@mem0.ai>"]
exclude = [ exclude = [