Formatting and version bump -> 0.1.107 (#2927)
@@ -526,7 +526,13 @@ class MemoryClient:
         if not (self.org_id and self.project_id):
             raise ValueError("org_id and project_id must be set to update instructions or categories")
 
-        if custom_instructions is None and custom_categories is None and retrieval_criteria is None and enable_graph is None and version is None:
+        if (
+            custom_instructions is None
+            and custom_categories is None
+            and retrieval_criteria is None
+            and enable_graph is None
+            and version is None
+        ):
             raise ValueError(
                 "Currently we only support updating custom_instructions or custom_categories or retrieval_criteria, so you must provide at least one of them"
             )
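Taken together, the two guards above make the project-settings updater fail fast when it has nothing to do. A minimal sketch of the expected behavior, assuming a MemoryClient constructed with org_id and project_id; the surrounding method is not visible in this hunk, so the name update_project is an assumption:

    from mem0 import MemoryClient

    client = MemoryClient(api_key="...", org_id="my-org", project_id="my-proj")

    try:
        client.update_project()  # no fields supplied -> the second guard fires
    except ValueError as err:
        print(err)  # "Currently we only support updating custom_instructions or ..."

    # Supplying any one supported field passes validation.
    client.update_project(custom_instructions="Track only dietary preferences.")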
@@ -675,9 +681,7 @@ class MemoryClient:
         capture_client_event("client.feedback", self, data, {"sync_type": "sync"})
         return response.json()
 
-    def _prepare_payload(
-        self, messages: List[Dict[str, str]], kwargs: Dict[str, Any]
-    ) -> Dict[str, Any]:
+    def _prepare_payload(self, messages: List[Dict[str, str]], kwargs: Dict[str, Any]) -> Dict[str, Any]:
         """Prepare the payload for API requests.
 
         Args:
@@ -803,9 +807,7 @@ class AsyncMemoryClient:
             error_message = str(e)
             raise ValueError(f"Error: {error_message}")
 
-    def _prepare_payload(
-        self, messages: List[Dict[str, str]], kwargs: Dict[str, Any]
-    ) -> Dict[str, Any]:
+    def _prepare_payload(self, messages: List[Dict[str, str]], kwargs: Dict[str, Any]) -> Dict[str, Any]:
         """Prepare the payload for API requests.
 
         Args:
@@ -1115,7 +1117,13 @@ class AsyncMemoryClient:
         if not (self.org_id and self.project_id):
             raise ValueError("org_id and project_id must be set to update instructions or categories")
 
-        if custom_instructions is None and custom_categories is None and retrieval_criteria is None and enable_graph is None and version is None:
+        if (
+            custom_instructions is None
+            and custom_categories is None
+            and retrieval_criteria is None
+            and enable_graph is None
+            and version is None
+        ):
             raise ValueError(
                 "Currently we only support updating custom_instructions or custom_categories or retrieval_criteria, so you must provide at least one of them"
             )
@@ -18,22 +18,15 @@ class SarvamLLM(LLMBase):
 
         if not self.api_key:
             raise ValueError(
-                "Sarvam API key is required. Set SARVAM_API_KEY environment variable "
-                "or provide api_key in config."
+                "Sarvam API key is required. Set SARVAM_API_KEY environment variable " "or provide api_key in config."
             )
 
         # Set base URL - use config value or environment or default
         self.base_url = (
-            getattr(self.config, 'sarvam_base_url', None) or
-            os.getenv("SARVAM_API_BASE") or
-            "https://api.sarvam.ai/v1"
+            getattr(self.config, "sarvam_base_url", None) or os.getenv("SARVAM_API_BASE") or "https://api.sarvam.ai/v1"
         )
 
-    def generate_response(
-        self,
-        messages: List[Dict[str, str]],
-        response_format=None
-    ) -> str:
+    def generate_response(self, messages: List[Dict[str, str]], response_format=None) -> str:
         """
         Generate a response based on the given messages using Sarvam-M.
 
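The reflowed base_url expression preserves the original three-step fallback: an explicit sarvam_base_url attribute on the config wins, then the SARVAM_API_BASE environment variable, then the public default. A standalone sketch of that precedence, using a stand-in config object rather than the real LLMBase config:

    import os

    class _StubConfig:
        sarvam_base_url = None  # set a value here to override everything else

    config = _StubConfig()
    base_url = (
        getattr(config, "sarvam_base_url", None) or os.getenv("SARVAM_API_BASE") or "https://api.sarvam.ai/v1"
    )
    print(base_url)  # env value if SARVAM_API_BASE is set, otherwise the default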
@@ -47,10 +40,7 @@ class SarvamLLM(LLMBase):
         """
         url = f"{self.base_url}/chat/completions"
 
-        headers = {
-            "Authorization": f"Bearer {self.api_key}",
-            "Content-Type": "application/json"
-        }
+        headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
 
         # Prepare the request payload
         params = {
@@ -74,10 +64,7 @@ class SarvamLLM(LLMBase):
             params["model"] = self.config.model.get("name", "sarvam-m")
 
         # Add Sarvam-specific parameters
-        sarvam_specific_params = [
-            'reasoning_effort', 'frequency_penalty', 'presence_penalty',
-            'seed', 'stop', 'n'
-        ]
+        sarvam_specific_params = ["reasoning_effort", "frequency_penalty", "presence_penalty", "seed", "stop", "n"]
 
         for param in sarvam_specific_params:
             if param in self.config.model:
@@ -89,8 +76,8 @@ class SarvamLLM(LLMBase):
 
         result = response.json()
 
-        if 'choices' in result and len(result['choices']) > 0:
-            return result['choices'][0]['message']['content']
+        if "choices" in result and len(result["choices"]) > 0:
+            return result["choices"][0]["message"]["content"]
         else:
             raise ValueError("No response choices found in Sarvam API response")
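The quote-style change above does not alter the parsing logic: the client still expects an OpenAI-style completion body and takes the first choice's message content. A sketch with a fabricated response body showing the happy path and the failure path:

    # Fabricated response shape for illustration only.
    result = {"choices": [{"message": {"content": "Namaste!"}}]}

    if "choices" in result and len(result["choices"]) > 0:
        print(result["choices"][0]["message"]["content"])  # -> "Namaste!"
    else:
        raise ValueError("No response choices found in Sarvam API response")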
@@ -29,6 +29,7 @@ from mem0.memory.utils import (
     parse_messages,
     parse_vision_messages,
     remove_code_blocks,
+    process_telemetry_filters,
 )
 from mem0.utils.factory import EmbedderFactory, LlmFactory, VectorStoreFactory
@@ -45,13 +46,13 @@ def _build_filters_and_metadata(
     """
     Constructs metadata for storage and filters for querying based on session and actor identifiers.
 
-    This helper ties every memory/query to exactly one session id (`user_id`, `agent_id`, or `run_id`) and optionally narrows queries to a specific `actor_id`. It returns two dicts:
+    This helper supports multiple session identifiers (`user_id`, `agent_id`, and/or `run_id`)
+    for flexible session scoping and optionally narrows queries to a specific `actor_id`. It returns two dicts:
 
     1. `base_metadata_template`: Used as a template for metadata when storing new memories.
-       It includes the primary session identifier(s) and any `input_metadata`.
-    2. `effective_query_filters`: Used for querying existing memories. It includes the
-       primary session identifier(s), any `input_filters`, and a resolved actor
+       It includes all provided session identifier(s) and any `input_metadata`.
+    2. `effective_query_filters`: Used for querying existing memories. It includes all
+       provided session identifier(s), any `input_filters`, and a resolved actor
        identifier for targeted filtering if specified by any actor-related inputs.
 
     Actor filtering precedence: explicit `actor_id` arg → `filters["actor_id"]`
@@ -59,11 +60,9 @@ def _build_filters_and_metadata(
     as the actor for storage is typically derived from message content at a later stage.
 
     Args:
-        user_id (Optional[str]): User identifier, primarily for Classic Mode session scoping.
-        agent_id (Optional[str]): Agent identifier, for Classic Mode session scoping or
-            as auxiliary information in Group Mode.
-        run_id (Optional[str]): Run identifier, for Classic Mode session scoping or
-            as auxiliary information in Group Mode.
+        user_id (Optional[str]): User identifier, for session scoping.
+        agent_id (Optional[str]): Agent identifier, for session scoping.
+        run_id (Optional[str]): Run identifier, for session scoping.
         actor_id (Optional[str]): Explicit actor identifier, used as a potential source for
             actor-specific filtering. See actor resolution precedence in the main description.
         input_metadata (Optional[Dict[str, Any]]): Base dictionary to be augmented with
@@ -74,28 +73,34 @@ def _build_filters_and_metadata(
     Returns:
         tuple[Dict[str, Any], Dict[str, Any]]: A tuple containing:
             - base_metadata_template (Dict[str, Any]): Metadata template for storing memories,
-              scoped to the determined session.
+              scoped to the provided session(s).
             - effective_query_filters (Dict[str, Any]): Filters for querying memories,
-              scoped to the determined session and potentially a resolved actor.
+              scoped to the provided session(s) and potentially a resolved actor.
     """
 
     base_metadata_template = deepcopy(input_metadata) if input_metadata else {}
     effective_query_filters = deepcopy(input_filters) if input_filters else {}
 
-    # ---------- resolve session id (mandatory) ----------
-    session_key, session_val = None, None
+    # ---------- add all provided session ids ----------
+    session_ids_provided = []
+
     if user_id:
-        session_key, session_val = "user_id", user_id
-    elif agent_id:
-        session_key, session_val = "agent_id", agent_id
-    elif run_id:
-        session_key, session_val = "run_id", run_id
+        base_metadata_template["user_id"] = user_id
+        effective_query_filters["user_id"] = user_id
+        session_ids_provided.append("user_id")
 
-    if session_key is None:
-        raise ValueError("One of 'user_id', 'agent_id', or 'run_id' must be provided.")
+    if agent_id:
+        base_metadata_template["agent_id"] = agent_id
+        effective_query_filters["agent_id"] = agent_id
+        session_ids_provided.append("agent_id")
 
-    base_metadata_template[session_key] = session_val
-    effective_query_filters[session_key] = session_val
+    if run_id:
+        base_metadata_template["run_id"] = run_id
+        effective_query_filters["run_id"] = run_id
+        session_ids_provided.append("run_id")
+
+    if not session_ids_provided:
+        raise ValueError("At least one of 'user_id', 'agent_id', or 'run_id' must be provided.")
 
     # ---------- optional actor filter ----------
     resolved_actor_id = actor_id or effective_query_filters.get("actor_id")
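This is the main behavioral change in the file: the helper previously bound each memory to exactly one session id (user over agent over run in the if/elif chain), whereas now every id that is supplied lands in both the metadata template and the query filters. A sketch of the new behavior, calling the module-private helper directly for illustration:

    from mem0.memory.main import _build_filters_and_metadata  # private; direct use is illustrative

    metadata, filters = _build_filters_and_metadata(
        user_id="alice",
        run_id="session-42",
        input_metadata={"source": "chat"},
    )
    print(metadata)  # {'source': 'chat', 'user_id': 'alice', 'run_id': 'session-42'}
    print(filters)   # {'user_id': 'alice', 'run_id': 'session-42'}

Under the old code the run_id would have been silently dropped because user_id won the if/elif chain; now both identifiers scope the stored memory and the query.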
@@ -433,10 +438,11 @@ class Memory(MemoryBase):
         except Exception as e:
             logging.error(f"Error iterating new_memories_with_actions: {e}")
 
+        keys, encoded_ids = process_telemetry_filters(filters)
         capture_event(
             "mem0.add",
             self,
-            {"version": self.api_version, "keys": list(filters.keys()), "sync_type": "sync"},
+            {"version": self.api_version, "keys": keys, "encoded_ids": encoded_ids, "sync_type": "sync"},
         )
         return returned_memories
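From here on, every capture_event call site swaps the raw list(filters.keys()) for the pair returned by process_telemetry_filters, so telemetry reports which filter keys were used plus md5 hashes of the session ids rather than the key list alone. A sketch of the payload difference for one call, assuming a filters dict with a single user:

    from mem0.memory.utils import process_telemetry_filters

    filters = {"user_id": "alice"}
    keys, encoded_ids = process_telemetry_filters(filters)
    # keys        -> ['user_id']
    # encoded_ids -> {'user_id': <32-char md5 hex digest of 'alice'>}

    # "v1.1" stands in for self.api_version here.
    payload = {"version": "v1.1", "keys": keys, "encoded_ids": encoded_ids, "sync_type": "sync"}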
@@ -529,8 +535,9 @@ class Memory(MemoryBase):
         if not any(key in effective_filters for key in ("user_id", "agent_id", "run_id")):
             raise ValueError("At least one of 'user_id', 'agent_id', or 'run_id' must be specified.")
 
+        keys, encoded_ids = process_telemetry_filters(effective_filters)
         capture_event(
-            "mem0.get_all", self, {"limit": limit, "keys": list(effective_filters.keys()), "sync_type": "sync"}
+            "mem0.get_all", self, {"limit": limit, "keys": keys, "encoded_ids": encoded_ids, "sync_type": "sync"}
         )
 
         with concurrent.futures.ThreadPoolExecutor() as executor:
@@ -564,7 +571,9 @@ class Memory(MemoryBase):
     def _get_all_from_vector_store(self, filters, limit):
         memories_result = self.vector_store.list(filters=filters, limit=limit)
         actual_memories = (
-            memories_result[0] if isinstance(memories_result, (tuple, list)) and len(memories_result) > 0 else memories_result
+            memories_result[0]
+            if isinstance(memories_result, (tuple, list)) and len(memories_result) > 0
+            else memories_result
         )
 
         promoted_payload_keys = [
@@ -632,10 +641,18 @@ class Memory(MemoryBase):
         if not any(key in effective_filters for key in ("user_id", "agent_id", "run_id")):
             raise ValueError("At least one of 'user_id', 'agent_id', or 'run_id' must be specified.")
 
+        keys, encoded_ids = process_telemetry_filters(effective_filters)
         capture_event(
             "mem0.search",
             self,
-            {"limit": limit, "version": self.api_version, "keys": list(effective_filters.keys()), "sync_type": "sync", "threshold": threshold},
+            {
+                "limit": limit,
+                "version": self.api_version,
+                "keys": keys,
+                "encoded_ids": encoded_ids,
+                "sync_type": "sync",
+                "threshold": threshold,
+            },
         )
 
         with concurrent.futures.ThreadPoolExecutor() as executor:
@@ -755,7 +772,8 @@ class Memory(MemoryBase):
                 "At least one filter is required to delete all memories. If you want to delete all memories, use the `reset()` method."
             )
 
-        capture_event("mem0.delete_all", self, {"keys": list(filters.keys()), "sync_type": "sync"})
+        keys, encoded_ids = process_telemetry_filters(filters)
+        capture_event("mem0.delete_all", self, {"keys": keys, "encoded_ids": encoded_ids, "sync_type": "sync"})
         memories = self.vector_store.list(filters=filters)[0]
         for memory in memories:
             self._delete_memory(memory.id)
@@ -1089,7 +1107,7 @@ class AsyncMemory(MemoryBase):
         self,
         messages: list,
         metadata: dict,
-        filters: dict,
+        effective_filters: dict,
         infer: bool,
     ):
         if not infer:
@@ -1163,7 +1181,7 @@ class AsyncMemory(MemoryBase):
             query=new_mem_content,
             vectors=embeddings,
             limit=5,
-            filters=filters,  # 'filters' is query_filters_for_inference
+            filters=effective_filters,  # 'filters' is query_filters_for_inference
         )
         return [{"id": mem.id, "text": mem.payload["data"]} for mem in existing_mems]
@@ -1192,7 +1210,6 @@ class AsyncMemory(MemoryBase):
                 response_format={"type": "json_object"},
             )
         except Exception as e:
-            response = ""
             logging.error(f"Error in new memory actions response: {e}")
             response = ""
@@ -1200,7 +1217,6 @@ class AsyncMemory(MemoryBase):
             response = remove_code_blocks(response)
             new_memories_with_actions = json.loads(response)
         except Exception as e:
-            new_memories_with_actions = {}
 
         if not new_memories_with_actions:
@@ -1210,7 +1226,6 @@ class AsyncMemory(MemoryBase):
             logging.error(f"Invalid JSON response: {e}")
             new_memories_with_actions = {}
 
-
         returned_memories = []
         try:
             memory_tasks = []
@@ -1270,8 +1285,11 @@ class AsyncMemory(MemoryBase):
         except Exception as e:
             logging.error(f"Error in memory processing loop (async): {e}")
 
+        keys, encoded_ids = process_telemetry_filters(effective_filters)
         capture_event(
-            "mem0.add", self, {"version": self.api_version, "keys": list(filters.keys()), "sync_type": "async"}
+            "mem0.add",
+            self,
+            {"version": self.api_version, "keys": keys, "encoded_ids": encoded_ids, "sync_type": "async"},
         )
         return returned_memories
@@ -1367,8 +1385,9 @@ class AsyncMemory(MemoryBase):
                 "at least one of 'user_id', 'agent_id', or 'run_id' must be specified for get_all."
             )
 
+        keys, encoded_ids = process_telemetry_filters(effective_filters)
         capture_event(
-            "mem0.get_all", self, {"limit": limit, "keys": list(effective_filters.keys()), "sync_type": "async"}
+            "mem0.get_all", self, {"limit": limit, "keys": keys, "encoded_ids": encoded_ids, "sync_type": "async"}
         )
 
         with concurrent.futures.ThreadPoolExecutor() as executor:
@@ -1471,10 +1490,18 @@ class AsyncMemory(MemoryBase):
         if not any(key in effective_filters for key in ("user_id", "agent_id", "run_id")):
             raise ValueError("at least one of 'user_id', 'agent_id', or 'run_id' must be specified ")
 
+        keys, encoded_ids = process_telemetry_filters(effective_filters)
         capture_event(
             "mem0.search",
             self,
-            {"limit": limit, "version": self.api_version, "keys": list(effective_filters.keys()), "sync_type": "async", "threshold": threshold},
+            {
+                "limit": limit,
+                "version": self.api_version,
+                "keys": keys,
+                "encoded_ids": encoded_ids,
+                "sync_type": "async",
+                "threshold": threshold,
+            },
         )
 
         vector_store_task = asyncio.create_task(self._search_vector_store(query, effective_filters, limit, threshold))
@@ -1599,7 +1626,8 @@ class AsyncMemory(MemoryBase):
                 "At least one filter is required to delete all memories. If you want to delete all memories, use the `reset()` method."
             )
 
-        capture_event("mem0.delete_all", self, {"keys": list(filters.keys()), "sync_type": "async"})
+        keys, encoded_ids = process_telemetry_filters(filters)
+        capture_event("mem0.delete_all", self, {"keys": keys, "encoded_ids": encoded_ids, "sync_type": "async"})
         memories = await asyncio.to_thread(self.vector_store.list, filters=filters)
 
         delete_tasks = []
@@ -1,4 +1,5 @@
 import re
+import hashlib
 
 from mem0.configs.prompts import FACT_RETRIEVAL_PROMPT
@@ -98,3 +99,21 @@ def parse_vision_messages(messages, llm=None, vision_details="auto"):
             returned_messages.append(msg)
 
     return returned_messages
+
+
+def process_telemetry_filters(filters):
+    """
+    Process the telemetry filters
+    """
+    if filters is None:
+        return {}
+
+    encoded_ids = {}
+    if "user_id" in filters:
+        encoded_ids["user_id"] = hashlib.md5(filters["user_id"].encode()).hexdigest()
+    if "agent_id" in filters:
+        encoded_ids["agent_id"] = hashlib.md5(filters["agent_id"].encode()).hexdigest()
+    if "run_id" in filters:
+        encoded_ids["run_id"] = hashlib.md5(filters["run_id"].encode()).hexdigest()
+
+    return list(filters.keys()), encoded_ids
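This new helper is what every capture_event call site above unpacks. It returns all filter keys verbatim but hashes only the three session identifiers, so raw ids never leave the process. A short usage sketch:

    from mem0.memory.utils import process_telemetry_filters

    keys, encoded_ids = process_telemetry_filters({"user_id": "alice", "agent_id": "bot-1"})
    # keys        -> ['user_id', 'agent_id']
    # encoded_ids -> {'user_id': <md5 of 'alice'>, 'agent_id': <md5 of 'bot-1'>}

    # Edge case worth noting: passing None returns a bare {} rather than a
    # (keys, encoded_ids) tuple, so callers that unpack two values must not pass None.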