Formatting and version bump -> 0.1.107 (#2927)

Author: Dev Khant
Date: 2025-06-07 12:27:22 +05:30
Committed by: GitHub
Parent: 9a12ea7b3c
Commit: e1dc27276b
4 changed files with 110 additions and 68 deletions


@@ -526,7 +526,13 @@ class MemoryClient:
         if not (self.org_id and self.project_id):
             raise ValueError("org_id and project_id must be set to update instructions or categories")
-        if custom_instructions is None and custom_categories is None and retrieval_criteria is None and enable_graph is None and version is None:
+        if (
+            custom_instructions is None
+            and custom_categories is None
+            and retrieval_criteria is None
+            and enable_graph is None
+            and version is None
+        ):
             raise ValueError(
                 "Currently we only support updating custom_instructions or custom_categories or retrieval_criteria, so you must provide at least one of them"
             )
@@ -675,9 +681,7 @@ class MemoryClient:
         capture_client_event("client.feedback", self, data, {"sync_type": "sync"})
         return response.json()
 
-    def _prepare_payload(
-        self, messages: List[Dict[str, str]], kwargs: Dict[str, Any]
-    ) -> Dict[str, Any]:
+    def _prepare_payload(self, messages: List[Dict[str, str]], kwargs: Dict[str, Any]) -> Dict[str, Any]:
         """Prepare the payload for API requests.
 
         Args:
@@ -803,9 +807,7 @@ class AsyncMemoryClient:
             error_message = str(e)
             raise ValueError(f"Error: {error_message}")
 
-    def _prepare_payload(
-        self, messages: List[Dict[str, str]], kwargs: Dict[str, Any]
-    ) -> Dict[str, Any]:
+    def _prepare_payload(self, messages: List[Dict[str, str]], kwargs: Dict[str, Any]) -> Dict[str, Any]:
         """Prepare the payload for API requests.
 
         Args:
@@ -1115,7 +1117,13 @@ class AsyncMemoryClient:
         if not (self.org_id and self.project_id):
             raise ValueError("org_id and project_id must be set to update instructions or categories")
-        if custom_instructions is None and custom_categories is None and retrieval_criteria is None and enable_graph is None and version is None:
+        if (
+            custom_instructions is None
+            and custom_categories is None
+            and retrieval_criteria is None
+            and enable_graph is None
+            and version is None
+        ):
             raise ValueError(
                 "Currently we only support updating custom_instructions or custom_categories or retrieval_criteria, so you must provide at least one of them"
             )
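
Both hunks reflow the same guard without changing behavior: the update call is still rejected when every updatable field is left unset. A minimal sketch of the equivalent check (the helper name here is hypothetical, not part of the client):

def _nothing_to_update(custom_instructions=None, custom_categories=None,
                       retrieval_criteria=None, enable_graph=None, version=None):
    # True when the caller supplied none of the updatable fields.
    return all(
        value is None
        for value in (custom_instructions, custom_categories,
                      retrieval_criteria, enable_graph, version)
    )

assert _nothing_to_update() is True
assert _nothing_to_update(enable_graph=True) is False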


@@ -18,22 +18,15 @@ class SarvamLLM(LLMBase):
         if not self.api_key:
             raise ValueError(
-                "Sarvam API key is required. Set SARVAM_API_KEY environment variable "
-                "or provide api_key in config."
+                "Sarvam API key is required. Set SARVAM_API_KEY environment variable " "or provide api_key in config."
             )
 
         # Set base URL - use config value or environment or default
         self.base_url = (
-            getattr(self.config, 'sarvam_base_url', None) or
-            os.getenv("SARVAM_API_BASE") or
-            "https://api.sarvam.ai/v1"
+            getattr(self.config, "sarvam_base_url", None) or os.getenv("SARVAM_API_BASE") or "https://api.sarvam.ai/v1"
         )
 
-    def generate_response(
-        self,
-        messages: List[Dict[str, str]],
-        response_format=None
-    ) -> str:
+    def generate_response(self, messages: List[Dict[str, str]], response_format=None) -> str:
         """
         Generate a response based on the given messages using Sarvam-M.
@@ -47,10 +40,7 @@ class SarvamLLM(LLMBase):
         """
         url = f"{self.base_url}/chat/completions"
 
-        headers = {
-            "Authorization": f"Bearer {self.api_key}",
-            "Content-Type": "application/json"
-        }
+        headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
 
         # Prepare the request payload
         params = {
@@ -74,10 +64,7 @@ class SarvamLLM(LLMBase):
             params["model"] = self.config.model.get("name", "sarvam-m")
 
         # Add Sarvam-specific parameters
-        sarvam_specific_params = [
-            'reasoning_effort', 'frequency_penalty', 'presence_penalty',
-            'seed', 'stop', 'n'
-        ]
+        sarvam_specific_params = ["reasoning_effort", "frequency_penalty", "presence_penalty", "seed", "stop", "n"]
 
         for param in sarvam_specific_params:
             if param in self.config.model:
@@ -89,8 +76,8 @@ class SarvamLLM(LLMBase):
             result = response.json()
 
-            if 'choices' in result and len(result['choices']) > 0:
-                return result['choices'][0]['message']['content']
+            if "choices" in result and len(result["choices"]) > 0:
+                return result["choices"][0]["message"]["content"]
             else:
                 raise ValueError("No response choices found in Sarvam API response")
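
Taken together, these hunks describe a plain HTTP round trip. A standalone sketch of that flow under the same assumptions the diff shows (endpoint path, Bearer auth, and the choices[0].message.content response shape); payload fields beyond model and messages are omitted:

import os

import requests

api_key = os.environ["SARVAM_API_KEY"]
base_url = os.getenv("SARVAM_API_BASE", "https://api.sarvam.ai/v1")

headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
params = {"model": "sarvam-m", "messages": [{"role": "user", "content": "Hello!"}]}

# POST to the chat completions endpoint and unwrap the first choice,
# mirroring the error handling in generate_response above.
response = requests.post(f"{base_url}/chat/completions", headers=headers, json=params)
result = response.json()
if "choices" in result and len(result["choices"]) > 0:
    print(result["choices"][0]["message"]["content"])
else:
    raise ValueError("No response choices found in Sarvam API response")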


@@ -29,6 +29,7 @@ from mem0.memory.utils import (
     parse_messages,
     parse_vision_messages,
     remove_code_blocks,
+    process_telemetry_filters,
 )
 
 from mem0.utils.factory import EmbedderFactory, LlmFactory, VectorStoreFactory
@@ -45,13 +46,13 @@ def _build_filters_and_metadata(
    """
    Constructs metadata for storage and filters for querying based on session and actor identifiers.
 
-    This helper ties every memory/query to exactly one session id (`user_id`, `agent_id`, or `run_id`) and optionally narrows queries to a specific `actor_id`. It returns two dicts:
+    This helper supports multiple session identifiers (`user_id`, `agent_id`, and/or `run_id`)
+    for flexible session scoping and optionally narrows queries to a specific `actor_id`. It returns two dicts:
 
    1. `base_metadata_template`: Used as a template for metadata when storing new memories.
-       It includes the primary session identifier(s) and any `input_metadata`.
-    2. `effective_query_filters`: Used for querying existing memories. It includes the
-       primary session identifier(s), any `input_filters`, and a resolved actor
+       It includes all provided session identifier(s) and any `input_metadata`.
+    2. `effective_query_filters`: Used for querying existing memories. It includes all
+       provided session identifier(s), any `input_filters`, and a resolved actor
        identifier for targeted filtering if specified by any actor-related inputs.
 
    Actor filtering precedence: explicit `actor_id` arg → `filters["actor_id"]`
@@ -59,11 +60,9 @@
    as the actor for storage is typically derived from message content at a later stage.
 
    Args:
-        user_id (Optional[str]): User identifier, primarily for Classic Mode session scoping.
-        agent_id (Optional[str]): Agent identifier, for Classic Mode session scoping or
-            as auxiliary information in Group Mode.
-        run_id (Optional[str]): Run identifier, for Classic Mode session scoping or
-            as auxiliary information in Group Mode.
+        user_id (Optional[str]): User identifier, for session scoping.
+        agent_id (Optional[str]): Agent identifier, for session scoping.
+        run_id (Optional[str]): Run identifier, for session scoping.
        actor_id (Optional[str]): Explicit actor identifier, used as a potential source for
            actor-specific filtering. See actor resolution precedence in the main description.
        input_metadata (Optional[Dict[str, Any]]): Base dictionary to be augmented with
@@ -74,28 +73,34 @@
    Returns:
        tuple[Dict[str, Any], Dict[str, Any]]: A tuple containing:
            - base_metadata_template (Dict[str, Any]): Metadata template for storing memories,
-              scoped to the determined session.
+              scoped to the provided session(s).
            - effective_query_filters (Dict[str, Any]): Filters for querying memories,
-              scoped to the determined session and potentially a resolved actor.
+              scoped to the provided session(s) and potentially a resolved actor.
    """
 
    base_metadata_template = deepcopy(input_metadata) if input_metadata else {}
    effective_query_filters = deepcopy(input_filters) if input_filters else {}
 
-    # ---------- resolve session id (mandatory) ----------
-    session_key, session_val = None, None
+    # ---------- add all provided session ids ----------
+    session_ids_provided = []
    if user_id:
-        session_key, session_val = "user_id", user_id
-    elif agent_id:
-        session_key, session_val = "agent_id", agent_id
-    elif run_id:
-        session_key, session_val = "run_id", run_id
+        base_metadata_template["user_id"] = user_id
+        effective_query_filters["user_id"] = user_id
+        session_ids_provided.append("user_id")
 
-    if session_key is None:
-        raise ValueError("One of 'user_id', 'agent_id', or 'run_id' must be provided.")
+    if agent_id:
+        base_metadata_template["agent_id"] = agent_id
+        effective_query_filters["agent_id"] = agent_id
+        session_ids_provided.append("agent_id")
 
-    base_metadata_template[session_key] = session_val
-    effective_query_filters[session_key] = session_val
+    if run_id:
+        base_metadata_template["run_id"] = run_id
+        effective_query_filters["run_id"] = run_id
+        session_ids_provided.append("run_id")
+
+    if not session_ids_provided:
+        raise ValueError("At least one of 'user_id', 'agent_id', or 'run_id' must be provided.")
 
    # ---------- optional actor filter ----------
    resolved_actor_id = actor_id or effective_query_filters.get("actor_id")
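
With this change the function no longer picks a single session key: every identifier supplied lands in both returned dicts. A hedged sketch of the expected behavior, using only the keyword names listed in the Args section above:

metadata, query_filters = _build_filters_and_metadata(
    user_id="u1",
    run_id="r42",
    input_metadata={"source": "chat"},
)
# metadata      -> {"source": "chat", "user_id": "u1", "run_id": "r42"}
# query_filters -> {"user_id": "u1", "run_id": "r42"}

# Calling with no session id at all now raises:
# ValueError: At least one of 'user_id', 'agent_id', or 'run_id' must be provided.
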
@@ -433,10 +438,11 @@ class Memory(MemoryBase):
         except Exception as e:
             logging.error(f"Error iterating new_memories_with_actions: {e}")
 
+        keys, encoded_ids = process_telemetry_filters(filters)
         capture_event(
             "mem0.add",
             self,
-            {"version": self.api_version, "keys": list(filters.keys()), "sync_type": "sync"},
+            {"version": self.api_version, "keys": keys, "encoded_ids": encoded_ids, "sync_type": "sync"},
         )
         return returned_memories
@@ -529,8 +535,9 @@ class Memory(MemoryBase):
         if not any(key in effective_filters for key in ("user_id", "agent_id", "run_id")):
             raise ValueError("At least one of 'user_id', 'agent_id', or 'run_id' must be specified.")
 
+        keys, encoded_ids = process_telemetry_filters(effective_filters)
         capture_event(
-            "mem0.get_all", self, {"limit": limit, "keys": list(effective_filters.keys()), "sync_type": "sync"}
+            "mem0.get_all", self, {"limit": limit, "keys": keys, "encoded_ids": encoded_ids, "sync_type": "sync"}
         )
 
         with concurrent.futures.ThreadPoolExecutor() as executor:
@@ -564,7 +571,9 @@ class Memory(MemoryBase):
     def _get_all_from_vector_store(self, filters, limit):
         memories_result = self.vector_store.list(filters=filters, limit=limit)
         actual_memories = (
-            memories_result[0] if isinstance(memories_result, (tuple, list)) and len(memories_result) > 0 else memories_result
+            memories_result[0]
+            if isinstance(memories_result, (tuple, list)) and len(memories_result) > 0
+            else memories_result
         )
 
         promoted_payload_keys = [
@@ -632,10 +641,18 @@ class Memory(MemoryBase):
         if not any(key in effective_filters for key in ("user_id", "agent_id", "run_id")):
             raise ValueError("At least one of 'user_id', 'agent_id', or 'run_id' must be specified.")
 
+        keys, encoded_ids = process_telemetry_filters(effective_filters)
         capture_event(
             "mem0.search",
             self,
-            {"limit": limit, "version": self.api_version, "keys": list(effective_filters.keys()), "sync_type": "sync", "threshold": threshold},
+            {
+                "limit": limit,
+                "version": self.api_version,
+                "keys": keys,
+                "encoded_ids": encoded_ids,
+                "sync_type": "sync",
+                "threshold": threshold,
+            },
         )
 
         with concurrent.futures.ThreadPoolExecutor() as executor:
@@ -755,7 +772,8 @@ class Memory(MemoryBase):
                 "At least one filter is required to delete all memories. If you want to delete all memories, use the `reset()` method."
             )
 
-        capture_event("mem0.delete_all", self, {"keys": list(filters.keys()), "sync_type": "sync"})
+        keys, encoded_ids = process_telemetry_filters(filters)
+        capture_event("mem0.delete_all", self, {"keys": keys, "encoded_ids": encoded_ids, "sync_type": "sync"})
         memories = self.vector_store.list(filters=filters)[0]
         for memory in memories:
             self._delete_memory(memory.id)
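
The same two-line pattern repeats at every telemetry call site in this file: hash the filters first, then report the key names together with the digests instead of raw identifier values. A before/after sketch of the event payload:

filters = {"user_id": "alice"}

# before: only the key names were captured
event = {"keys": list(filters.keys()), "sync_type": "sync"}

# after: MD5 digests of the ids travel alongside the key names
keys, encoded_ids = process_telemetry_filters(filters)
event = {"keys": keys, "encoded_ids": encoded_ids, "sync_type": "sync"}
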
@@ -1089,7 +1107,7 @@ class AsyncMemory(MemoryBase):
         self,
         messages: list,
         metadata: dict,
-        filters: dict,
+        effective_filters: dict,
         infer: bool,
     ):
         if not infer:
@@ -1163,7 +1181,7 @@ class AsyncMemory(MemoryBase):
             query=new_mem_content,
             vectors=embeddings,
             limit=5,
-            filters=filters,  # 'filters' is query_filters_for_inference
+            filters=effective_filters,  # 'filters' is query_filters_for_inference
         )
 
         return [{"id": mem.id, "text": mem.payload["data"]} for mem in existing_mems]
@@ -1192,7 +1210,6 @@ class AsyncMemory(MemoryBase):
                 response_format={"type": "json_object"},
             )
         except Exception as e:
-            response = ""
             logging.error(f"Error in new memory actions response: {e}")
             response = ""
@@ -1200,7 +1217,6 @@ class AsyncMemory(MemoryBase):
             response = remove_code_blocks(response)
             new_memories_with_actions = json.loads(response)
         except Exception as e:
-            new_memories_with_actions = {}
 
         if not new_memories_with_actions:
@@ -1210,7 +1226,6 @@ class AsyncMemory(MemoryBase):
             logging.error(f"Invalid JSON response: {e}")
             new_memories_with_actions = {}
 
-
         returned_memories = []
         try:
             memory_tasks = []
@@ -1270,8 +1285,11 @@ class AsyncMemory(MemoryBase):
         except Exception as e:
             logging.error(f"Error in memory processing loop (async): {e}")
 
+        keys, encoded_ids = process_telemetry_filters(effective_filters)
         capture_event(
-            "mem0.add", self, {"version": self.api_version, "keys": list(filters.keys()), "sync_type": "async"}
+            "mem0.add",
+            self,
+            {"version": self.api_version, "keys": keys, "encoded_ids": encoded_ids, "sync_type": "async"},
         )
         return returned_memories
@@ -1367,8 +1385,9 @@ class AsyncMemory(MemoryBase):
                 "at least one of 'user_id', 'agent_id', or 'run_id' must be specified for get_all."
             )
 
+        keys, encoded_ids = process_telemetry_filters(effective_filters)
         capture_event(
-            "mem0.get_all", self, {"limit": limit, "keys": list(effective_filters.keys()), "sync_type": "async"}
+            "mem0.get_all", self, {"limit": limit, "keys": keys, "encoded_ids": encoded_ids, "sync_type": "async"}
         )
 
         with concurrent.futures.ThreadPoolExecutor() as executor:
@@ -1471,10 +1490,18 @@ class AsyncMemory(MemoryBase):
         if not any(key in effective_filters for key in ("user_id", "agent_id", "run_id")):
             raise ValueError("at least one of 'user_id', 'agent_id', or 'run_id' must be specified ")
 
+        keys, encoded_ids = process_telemetry_filters(effective_filters)
         capture_event(
             "mem0.search",
             self,
-            {"limit": limit, "version": self.api_version, "keys": list(effective_filters.keys()), "sync_type": "async", "threshold": threshold},
+            {
+                "limit": limit,
+                "version": self.api_version,
+                "keys": keys,
+                "encoded_ids": encoded_ids,
+                "sync_type": "async",
+                "threshold": threshold,
+            },
         )
 
         vector_store_task = asyncio.create_task(self._search_vector_store(query, effective_filters, limit, threshold))
@@ -1599,7 +1626,8 @@ class AsyncMemory(MemoryBase):
                 "At least one filter is required to delete all memories. If you want to delete all memories, use the `reset()` method."
             )
 
-        capture_event("mem0.delete_all", self, {"keys": list(filters.keys()), "sync_type": "async"})
+        keys, encoded_ids = process_telemetry_filters(filters)
+        capture_event("mem0.delete_all", self, {"keys": keys, "encoded_ids": encoded_ids, "sync_type": "async"})
         memories = await asyncio.to_thread(self.vector_store.list, filters=filters)
 
         delete_tasks = []


@@ -1,4 +1,5 @@
 import re
+import hashlib
 
 from mem0.configs.prompts import FACT_RETRIEVAL_PROMPT
@@ -98,3 +99,21 @@ def parse_vision_messages(messages, llm=None, vision_details="auto"):
             returned_messages.append(msg)
 
     return returned_messages
+
+
+def process_telemetry_filters(filters):
+    """
+    Hash identifying filter values (MD5) so telemetry never carries raw ids.
+    """
+    if filters is None:
+        return [], {}  # keep the (keys, encoded_ids) shape that callers unpack
+
+    encoded_ids = {}
+    if "user_id" in filters:
+        encoded_ids["user_id"] = hashlib.md5(filters["user_id"].encode()).hexdigest()
+    if "agent_id" in filters:
+        encoded_ids["agent_id"] = hashlib.md5(filters["agent_id"].encode()).hexdigest()
+    if "run_id" in filters:
+        encoded_ids["run_id"] = hashlib.md5(filters["run_id"].encode()).hexdigest()
+
+    return list(filters.keys()), encoded_ids
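
A quick usage sketch of the new helper: identifier values are hashed before they leave the process, while non-identifier keys pass through in the key list only:

import hashlib

keys, encoded_ids = process_telemetry_filters({"user_id": "alice", "limit": 10})
assert keys == ["user_id", "limit"]
assert encoded_ids == {"user_id": hashlib.md5(b"alice").hexdigest()}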