Version bump and client fixes (#2017)

This commit is contained in:
Dev Khant
2024-11-07 11:36:56 +05:30
committed by GitHub
parent 549e5e3ce8
commit 3731965537
2 changed files with 15 additions and 12 deletions

View File

@@ -1,8 +1,8 @@
import logging
import os
import warnings
from functools import wraps
from typing import Any, Dict, List, Optional, Union
import warnings
import httpx
@@ -122,7 +122,7 @@ class MemoryClient:
Raises:
APIError: If the API request fails.
"""
kwargs.update({"org_name": self.organization, "project_name": self.project})
kwargs = self._prepare_params(kwargs)
payload = self._prepare_payload(messages, kwargs)
response = self.client.post("/v1/memories/", json=payload)
response.raise_for_status()
@@ -163,7 +163,6 @@ class MemoryClient:
Raises:
APIError: If the API request fails.
"""
kwargs.update({"org_name": self.organization, "project_name": self.project})
params = self._prepare_params(kwargs)
if version == "v1":
response = self.client.get(f"/{version}/memories/", params=params)
@@ -195,8 +194,8 @@ class MemoryClient:
APIError: If the API request fails.
"""
payload = {"query": query}
kwargs.update({"org_name": self.organization, "project_name": self.project})
payload.update({k: v for k, v in kwargs.items() if v is not None})
params = self._prepare_params(kwargs)
payload.update(params)
response = self.client.post(f"/{version}/memories/search/", json=payload)
response.raise_for_status()
if "metadata" in kwargs:
@@ -250,7 +249,6 @@ class MemoryClient:
Raises:
APIError: If the API request fails.
"""
kwargs.update({"org_name": self.organization, "project_name": self.project})
params = self._prepare_params(kwargs)
response = self.client.delete("/v1/memories/", params=params)
response.raise_for_status()
@@ -278,7 +276,7 @@ class MemoryClient:
@api_error_handler
def users(self) -> Dict[str, Any]:
"""Get all users, agents, and sessions for which memories exist."""
params = {"org_name": self.organization, "project_name": self.project}
params = self._prepare_params()
response = self.client.get("/v1/entities/", params=params)
response.raise_for_status()
capture_client_event("client.users", self)
@@ -287,7 +285,7 @@ class MemoryClient:
@api_error_handler
def delete_users(self) -> Dict[str, str]:
"""Delete all users, agents, or sessions."""
params = {"org_name": self.organization, "project_name": self.project}
params = self._prepare_params()
entities = self.users()
for entity in entities["results"]:
response = self.client.delete(f"/v1/entities/{entity['type']}/{entity['id']}/", params=params)
@@ -344,7 +342,7 @@ class MemoryClient:
payload.update({k: v for k, v in kwargs.items() if v is not None})
return payload
def _prepare_params(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
def _prepare_params(self, kwargs: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
"""Prepare query parameters for API requests.
Args:
@@ -356,6 +354,10 @@ class MemoryClient:
Raises:
ValueError: If both org_id/project_id and org_name/project_name are provided.
"""
if kwargs is None:
kwargs = {}
has_new = bool(self.org_id or self.project_id)
has_old = bool(self.organization or self.project)
@@ -414,6 +416,7 @@ class AsyncMemoryClient:
@api_error_handler
async def add(self, messages: Union[str, List[Dict[str, str]]], **kwargs) -> Dict[str, Any]:
kwargs = self.sync_client._prepare_params(kwargs)
payload = self.sync_client._prepare_payload(messages, kwargs)
response = await self.async_client.post("/v1/memories/", json=payload)
response.raise_for_status()
@@ -488,7 +491,7 @@ class AsyncMemoryClient:
@api_error_handler
async def users(self) -> Dict[str, Any]:
params = {"org_name": self.sync_client.organization, "project_name": self.sync_client.project}
params = self.sync_client._prepare_params()
response = await self.async_client.get("/v1/entities/", params=params)
response.raise_for_status()
capture_client_event("async_client.users", self.sync_client)
@@ -496,7 +499,7 @@ class AsyncMemoryClient:
@api_error_handler
async def delete_users(self) -> Dict[str, str]:
params = {"org_name": self.sync_client.organization, "project_name": self.sync_client.project}
params = self.sync_client._prepare_params()
entities = await self.users()
for entity in entities["results"]:
response = await self.async_client.delete(f"/v1/entities/{entity['type']}/{entity['id']}/", params=params)

View File

@@ -1,6 +1,6 @@
[tool.poetry]
name = "mem0ai"
version = "0.1.28"
version = "0.1.29"
description = "Long-term memory for AI Agents"
authors = ["Mem0 <founders@mem0.ai>"]
exclude = [