Formatting (#2750)

This commit is contained in:
Dev Khant
2025-05-22 01:17:29 +05:30
committed by GitHub
parent dff91154a7
commit d85fcda037
71 changed files with 1391 additions and 1823 deletions

View File

@@ -34,17 +34,17 @@ from mem0.utils.factory import EmbedderFactory, LlmFactory, VectorStoreFactory
def _build_filters_and_metadata(
*, # Enforce keyword-only arguments
*, # Enforce keyword-only arguments
user_id: Optional[str] = None,
agent_id: Optional[str] = None,
run_id: Optional[str] = None,
actor_id: Optional[str] = None, # For query-time filtering
actor_id: Optional[str] = None, # For query-time filtering
input_metadata: Optional[Dict[str, Any]] = None,
input_filters: Optional[Dict[str, Any]] = None,
) -> tuple[Dict[str, Any], Dict[str, Any]]:
"""
Constructs metadata for storage and filters for querying based on session and actor identifiers.
This helper ties every memory/query to exactly one session id (`user_id`, `agent_id`, or `run_id`) and optionally narrows queries to a specific `actor_id`. It returns two dicts:
@@ -78,10 +78,10 @@ def _build_filters_and_metadata(
- effective_query_filters (Dict[str, Any]): Filters for querying memories,
scoped to the determined session and potentially a resolved actor.
"""
base_metadata_template = deepcopy(input_metadata) if input_metadata else {}
effective_query_filters = deepcopy(input_filters) if input_filters else {}
# ---------- resolve session id (mandatory) ----------
session_key, session_val = None, None
if user_id:
@@ -90,20 +90,20 @@ def _build_filters_and_metadata(
session_key, session_val = "agent_id", agent_id
elif run_id:
session_key, session_val = "run_id", run_id
if session_key is None:
raise ValueError("One of 'user_id', 'agent_id', or 'run_id' must be provided.")
base_metadata_template[session_key] = session_val
effective_query_filters[session_key] = session_val
# ---------- optional actor filter ----------
resolved_actor_id = actor_id or effective_query_filters.get("actor_id")
if resolved_actor_id:
effective_query_filters["actor_id"] = resolved_actor_id
return base_metadata_template, effective_query_filters
setup_config()
logger = logging.getLogger(__name__)
@@ -189,7 +189,7 @@ class Memory(MemoryBase):
):
"""
Create a new memory.
Adds new memories scoped to a single session id (e.g. `user_id`, `agent_id`, or `run_id`). One of those ids is required.
Args:
@@ -208,7 +208,7 @@ class Memory(MemoryBase):
creating procedural memories (typically requires 'agent_id'). Otherwise, memories
are treated as general conversational/factual memories.memory_type (str, optional): Type of memory to create. Defaults to None. By default, it creates the short term memories and long term (semantic and episodic) memories. Pass "procedural_memory" to create procedural memories.
prompt (str, optional): Prompt to use for the memory creation. Defaults to None.
Returns:
dict: A dictionary containing the result of the memory addition operation, typically
@@ -216,14 +216,14 @@ class Memory(MemoryBase):
and potentially "relations" if graph store is enabled.
Example for v1.1+: `{"results": [{"id": "...", "memory": "...", "event": "ADD"}]}`
"""
processed_metadata, effective_filters = _build_filters_and_metadata(
user_id=user_id,
agent_id=agent_id,
run_id=run_id,
input_metadata=metadata,
)
if memory_type is not None and memory_type != MemoryType.PROCEDURAL.value:
raise ValueError(
f"Invalid 'memory_type'. Please pass {MemoryType.PROCEDURAL.value} to create procedural memories."
@@ -231,10 +231,10 @@ class Memory(MemoryBase):
if isinstance(messages, str):
messages = [{"role": "user", "content": messages}]
elif isinstance(messages, dict):
messages = [messages]
elif not isinstance(messages, list):
raise ValueError("messages must be str, dict, or list[dict]")
@@ -255,7 +255,7 @@ class Memory(MemoryBase):
vector_store_result = future1.result()
graph_result = future2.result()
if self.api_version == "v1.0":
warnings.warn(
"The current add API output format is deprecated. "
@@ -277,21 +277,21 @@ class Memory(MemoryBase):
def _add_to_vector_store(self, messages, metadata, filters, infer):
if not infer:
returned_memories = []
for message_dict in messages:
if not isinstance(message_dict, dict) or \
message_dict.get("role") is None or \
message_dict.get("content") is None:
for message_dict in messages:
if (
not isinstance(message_dict, dict)
or message_dict.get("role") is None
or message_dict.get("content") is None
):
logger.warning(f"Skipping invalid message format: {message_dict}")
continue
if message_dict["role"] == "system":
continue
continue
per_msg_meta = deepcopy(metadata)
per_msg_meta["role"] = message_dict["role"]
actor_name = message_dict.get("name")
if actor_name:
per_msg_meta["actor_id"] = actor_name
@@ -311,8 +311,8 @@ class Memory(MemoryBase):
)
return returned_memories
parsed_messages = parse_messages(messages)
parsed_messages = parse_messages(messages)
if self.config.custom_fact_extraction_prompt:
system_prompt = self.config.custom_fact_extraction_prompt
user_prompt = f"Input:\n{parsed_messages}"
@@ -336,7 +336,7 @@ class Memory(MemoryBase):
retrieved_old_memory = []
new_message_embeddings = {}
for new_mem in new_retrieved_facts:
for new_mem in new_retrieved_facts:
messages_embeddings = self.embedding_model.embed(new_mem, "add")
new_message_embeddings[new_mem] = messages_embeddings
existing_memories = self.vector_store.search(
@@ -347,7 +347,7 @@ class Memory(MemoryBase):
)
for mem in existing_memories:
retrieved_old_memory.append({"id": mem.id, "text": mem.payload["data"]})
unique_data = {}
for item in retrieved_old_memory:
unique_data[item["id"]] = item
@@ -389,7 +389,7 @@ class Memory(MemoryBase):
if not action_text:
logging.info("Skipping memory entry because of empty `text` field.")
continue
event_type = resp.get("event")
if event_type == "ADD":
memory_id = self._create_memory(
@@ -405,16 +405,23 @@ class Memory(MemoryBase):
existing_embeddings=new_message_embeddings,
metadata=deepcopy(metadata),
)
returned_memories.append({
"id": temp_uuid_mapping[resp.get("id")], "memory": action_text,
"event": event_type, "previous_memory": resp.get("old_memory"),
})
returned_memories.append(
{
"id": temp_uuid_mapping[resp.get("id")],
"memory": action_text,
"event": event_type,
"previous_memory": resp.get("old_memory"),
}
)
elif event_type == "DELETE":
self._delete_memory(memory_id=temp_uuid_mapping[resp.get("id")])
returned_memories.append({
"id": temp_uuid_mapping[resp.get("id")], "memory": action_text,
"event": event_type,
})
returned_memories.append(
{
"id": temp_uuid_mapping[resp.get("id")],
"memory": action_text,
"event": event_type,
}
)
elif event_type == "NONE":
logging.info("NOOP for Memory.")
except Exception as e:
@@ -462,11 +469,8 @@ class Memory(MemoryBase):
"actor_id",
"role",
]
core_and_promoted_keys = {
"data", "hash", "created_at", "updated_at", "id",
*promoted_payload_keys
}
core_and_promoted_keys = {"data", "hash", "created_at", "updated_at", "id", *promoted_payload_keys}
result_item = MemoryItem(
id=memory.id,
@@ -479,18 +483,16 @@ class Memory(MemoryBase):
for key in promoted_payload_keys:
if key in memory.payload:
result_item[key] = memory.payload[key]
additional_metadata = {
k: v for k, v in memory.payload.items() if k not in core_and_promoted_keys
}
additional_metadata = {k: v for k, v in memory.payload.items() if k not in core_and_promoted_keys}
if additional_metadata:
result_item["metadata"] = additional_metadata
return result_item
def get_all(
self,
*,
*,
user_id: Optional[str] = None,
agent_id: Optional[str] = None,
run_id: Optional[str] = None,
@@ -505,7 +507,7 @@ class Memory(MemoryBase):
agent_id (str, optional): agent id
run_id (str, optional): run id
filters (dict, optional): Additional custom key-value filters to apply to the search.
These are merged with the ID-based scoping filters. For example,
These are merged with the ID-based scoping filters. For example,
`filters={"actor_id": "some_user"}`.
limit (int, optional): The maximum number of memories to return. Defaults to 100.
@@ -515,21 +517,16 @@ class Memory(MemoryBase):
it might return a direct list (see deprecation warning).
Example for v1.1+: `{"results": [{"id": "...", "memory": "...", ...}]}`
"""
_, effective_filters = _build_filters_and_metadata(
user_id=user_id,
agent_id=agent_id,
run_id=run_id,
input_filters=filters
user_id=user_id, agent_id=agent_id, run_id=run_id, input_filters=filters
)
if not any(key in effective_filters for key in ("user_id", "agent_id", "run_id")):
raise ValueError("At least one of 'user_id', 'agent_id', or 'run_id' must be specified.")
capture_event(
"mem0.get_all",
self,
{"limit": limit, "keys": list(effective_filters.keys()), "sync_type": "sync"}
"mem0.get_all", self, {"limit": limit, "keys": list(effective_filters.keys()), "sync_type": "sync"}
)
with concurrent.futures.ThreadPoolExecutor() as executor:
@@ -542,9 +539,9 @@ class Memory(MemoryBase):
[future_memories, future_graph_entities] if future_graph_entities else [future_memories]
)
all_memories_result = future_memories.result()
all_memories_result = future_memories.result()
graph_entities_result = future_graph_entities.result() if future_graph_entities else None
if self.enable_graph:
return {"results": all_memories_result, "relations": graph_entities_result}
@@ -556,26 +553,27 @@ class Memory(MemoryBase):
category=DeprecationWarning,
stacklevel=2,
)
return all_memories_result
return all_memories_result
else:
return {"results": all_memories_result}
def _get_all_from_vector_store(self, filters, limit):
memories_result = self.vector_store.list(filters=filters, limit=limit)
actual_memories = memories_result[0] if isinstance(memories_result, tuple) and len(memories_result) > 0 else memories_result
actual_memories = (
memories_result[0] if isinstance(memories_result, tuple) and len(memories_result) > 0 else memories_result
)
promoted_payload_keys = [
"user_id", "agent_id", "run_id",
"user_id",
"agent_id",
"run_id",
"actor_id",
"role",
]
core_and_promoted_keys = {
"data", "hash", "created_at", "updated_at", "id",
*promoted_payload_keys
}
core_and_promoted_keys = {"data", "hash", "created_at", "updated_at", "id", *promoted_payload_keys}
formatted_memories = []
for mem in actual_memories:
for mem in actual_memories:
memory_item_dict = MemoryItem(
id=mem.id,
memory=mem.payload["data"],
@@ -587,15 +585,13 @@ class Memory(MemoryBase):
for key in promoted_payload_keys:
if key in mem.payload:
memory_item_dict[key] = mem.payload[key]
additional_metadata = {
k: v for k, v in mem.payload.items() if k not in core_and_promoted_keys
}
additional_metadata = {k: v for k, v in mem.payload.items() if k not in core_and_promoted_keys}
if additional_metadata:
memory_item_dict["metadata"] = additional_metadata
formatted_memories.append(memory_item_dict)
return formatted_memories
def search(
@@ -624,12 +620,9 @@ class Memory(MemoryBase):
Example for v1.1+: `{"results": [{"id": "...", "memory": "...", "score": 0.8, ...}]}`
"""
_, effective_filters = _build_filters_and_metadata(
user_id=user_id,
agent_id=agent_id,
run_id=run_id,
input_filters=filters
user_id=user_id, agent_id=agent_id, run_id=run_id, input_filters=filters
)
if not any(key in effective_filters for key in ("user_id", "agent_id", "run_id")):
raise ValueError("At least one of 'user_id', 'agent_id', or 'run_id' must be specified.")
@@ -651,7 +644,7 @@ class Memory(MemoryBase):
original_memories = future_memories.result()
graph_entities = future_graph_entities.result() if future_graph_entities else None
if self.enable_graph:
return {"results": original_memories, "relations": graph_entities}
@@ -678,11 +671,8 @@ class Memory(MemoryBase):
"actor_id",
"role",
]
core_and_promoted_keys = {
"data", "hash", "created_at", "updated_at", "id",
*promoted_payload_keys
}
core_and_promoted_keys = {"data", "hash", "created_at", "updated_at", "id", *promoted_payload_keys}
original_memories = []
for mem in memories:
@@ -693,18 +683,16 @@ class Memory(MemoryBase):
created_at=mem.payload.get("created_at"),
updated_at=mem.payload.get("updated_at"),
score=mem.score,
).model_dump()
).model_dump()
for key in promoted_payload_keys:
if key in mem.payload:
memory_item_dict[key] = mem.payload[key]
additional_metadata = {
k: v for k, v in mem.payload.items() if k not in core_and_promoted_keys
}
additional_metadata = {k: v for k, v in mem.payload.items() if k not in core_and_promoted_keys}
if additional_metadata:
memory_item_dict["metadata"] = additional_metadata
original_memories.append(memory_item_dict)
return original_memories
@@ -738,7 +726,7 @@ class Memory(MemoryBase):
self._delete_memory(memory_id)
return {"message": "Memory deleted successfully!"}
def delete_all(self, user_id:Optional[str]=None, agent_id:Optional[str]=None, run_id:Optional[str]=None):
def delete_all(self, user_id: Optional[str] = None, agent_id: Optional[str] = None, run_id: Optional[str] = None):
"""
Delete all memories.
@@ -860,11 +848,11 @@ class Memory(MemoryBase):
except Exception:
logger.error(f"Error getting memory with ID {memory_id} during update.")
raise ValueError(f"Error getting memory with ID {memory_id}. Please provide a valid 'memory_id'")
prev_value = existing_memory.payload.get("data")
new_metadata = deepcopy(metadata) if metadata is not None else {}
new_metadata["data"] = data
new_metadata["hash"] = hashlib.md5(data.encode()).hexdigest()
new_metadata["created_at"] = existing_memory.payload.get("created_at")
@@ -875,7 +863,7 @@ class Memory(MemoryBase):
if "agent_id" in existing_memory.payload:
new_metadata["agent_id"] = existing_memory.payload["agent_id"]
if "run_id" in existing_memory.payload:
new_metadata["run_id"] = existing_memory.payload["run_id"]
new_metadata["run_id"] = existing_memory.payload["run_id"]
if "actor_id" in existing_memory.payload:
new_metadata["actor_id"] = existing_memory.payload["actor_id"]
if "role" in existing_memory.payload:
@@ -885,14 +873,14 @@ class Memory(MemoryBase):
embeddings = existing_embeddings[data]
else:
embeddings = self.embedding_model.embed(data, "update")
self.vector_store.update(
vector_id=memory_id,
vector=embeddings,
payload=new_metadata,
)
logger.info(f"Updating memory with ID {memory_id=} with {data=}")
self.db.add_history(
memory_id,
prev_value,
@@ -1037,12 +1025,9 @@ class AsyncMemory(MemoryBase):
dict: A dictionary containing the result of the memory addition operation.
"""
processed_metadata, effective_filters = _build_filters_and_metadata(
user_id=user_id,
agent_id=agent_id,
run_id=run_id,
input_metadata=metadata
user_id=user_id, agent_id=agent_id, run_id=run_id, input_metadata=metadata
)
if memory_type is not None and memory_type != MemoryType.PROCEDURAL.value:
raise ValueError(
f"Invalid 'memory_type'. Please pass {MemoryType.PROCEDURAL.value} to create procedural memories."
@@ -1050,15 +1035,17 @@ class AsyncMemory(MemoryBase):
if isinstance(messages, str):
messages = [{"role": "user", "content": messages}]
elif isinstance(messages, dict):
messages = [messages]
elif not isinstance(messages, list):
raise ValueError("messages must be str, dict, or list[dict]")
if agent_id is not None and memory_type == MemoryType.PROCEDURAL.value:
results = await self._create_procedural_memory(messages, metadata=processed_metadata, prompt=prompt, llm=llm)
results = await self._create_procedural_memory(
messages, metadata=processed_metadata, prompt=prompt, llm=llm
)
return results
if self.config.llm.config.get("enable_vision"):
@@ -1066,7 +1053,9 @@ class AsyncMemory(MemoryBase):
else:
messages = parse_vision_messages(messages)
vector_store_task = asyncio.create_task(self._add_to_vector_store(messages, processed_metadata, effective_filters, infer))
vector_store_task = asyncio.create_task(
self._add_to_vector_store(messages, processed_metadata, effective_filters, infer)
)
graph_task = asyncio.create_task(self._add_to_graph(messages, effective_filters))
vector_store_result, graph_result = await asyncio.gather(vector_store_task, graph_task)
@@ -1090,8 +1079,8 @@ class AsyncMemory(MemoryBase):
return {"results": vector_store_result}
async def _add_to_vector_store(
self,
messages: list,
self,
messages: list,
metadata: dict,
filters: dict,
infer: bool,
@@ -1099,9 +1088,11 @@ class AsyncMemory(MemoryBase):
if not infer:
returned_memories = []
for message_dict in messages:
if not isinstance(message_dict, dict) or \
message_dict.get("role") is None or \
message_dict.get("content") is None:
if (
not isinstance(message_dict, dict)
or message_dict.get("role") is None
or message_dict.get("content") is None
):
logger.warning(f"Skipping invalid message format (async): {message_dict}")
continue
@@ -1110,20 +1101,24 @@ class AsyncMemory(MemoryBase):
per_msg_meta = deepcopy(metadata)
per_msg_meta["role"] = message_dict["role"]
actor_name = message_dict.get("name")
if actor_name:
per_msg_meta["actor_id"] = actor_name
msg_content = message_dict["content"]
msg_embeddings = await asyncio.to_thread(self.embedding_model.embed, msg_content, "add")
mem_id = await self._create_memory(msg_content, msg_embeddings, per_msg_meta)
returned_memories.append({
"id": mem_id, "memory": msg_content, "event": "ADD",
"actor_id": actor_name if actor_name else None,
"role": message_dict["role"],
})
returned_memories.append(
{
"id": mem_id,
"memory": msg_content,
"event": "ADD",
"actor_id": actor_name if actor_name else None,
"role": message_dict["role"],
}
)
return returned_memories
parsed_messages = parse_messages(messages)
@@ -1142,17 +1137,21 @@ class AsyncMemory(MemoryBase):
response = remove_code_blocks(response)
new_retrieved_facts = json.loads(response)["facts"]
except Exception as e:
logging.error(f"Error in new_retrieved_facts: {e}"); new_retrieved_facts = []
logging.error(f"Error in new_retrieved_facts: {e}")
new_retrieved_facts = []
retrieved_old_memory = []
new_message_embeddings = {}
async def process_fact_for_search(new_mem_content):
embeddings = await asyncio.to_thread(self.embedding_model.embed, new_mem_content, "add")
new_message_embeddings[new_mem_content] = embeddings
existing_mems = await asyncio.to_thread(
self.vector_store.search, query=new_mem_content, vectors=embeddings,
limit=5, filters=filters, # 'filters' is query_filters_for_inference
self.vector_store.search,
query=new_mem_content,
vectors=embeddings,
limit=5,
filters=filters, # 'filters' is query_filters_for_inference
)
return [{"id": mem.id, "text": mem.payload["data"]} for mem in existing_mems]
@@ -1160,9 +1159,10 @@ class AsyncMemory(MemoryBase):
search_results_list = await asyncio.gather(*search_tasks)
for result_group in search_results_list:
retrieved_old_memory.extend(result_group)
unique_data = {}
for item in retrieved_old_memory: unique_data[item["id"]] = item
for item in retrieved_old_memory:
unique_data[item["id"]] = item
retrieved_old_memory = list(unique_data.values())
logging.info(f"Total existing memories: {len(retrieved_old_memory)}")
temp_uuid_mapping = {}
@@ -1180,35 +1180,45 @@ class AsyncMemory(MemoryBase):
response_format={"type": "json_object"},
)
except Exception as e:
logging.error(f"Error in new memory actions response: {e}"); response = ""
logging.error(f"Error in new memory actions response: {e}")
response = ""
try:
response = remove_code_blocks(response)
new_memories_with_actions = json.loads(response)
except Exception as e:
logging.error(f"Invalid JSON response: {e}"); new_memories_with_actions = {}
logging.error(f"Invalid JSON response: {e}")
new_memories_with_actions = {}
returned_memories = []
returned_memories = []
try:
memory_tasks = []
for resp in new_memories_with_actions.get("memory", []):
logging.info(resp)
try:
action_text = resp.get("text")
if not action_text: continue
if not action_text:
continue
event_type = resp.get("event")
if event_type == "ADD":
task = asyncio.create_task(self._create_memory(
data=action_text, existing_embeddings=new_message_embeddings,
metadata=deepcopy(metadata)
))
task = asyncio.create_task(
self._create_memory(
data=action_text,
existing_embeddings=new_message_embeddings,
metadata=deepcopy(metadata),
)
)
memory_tasks.append((task, resp, "ADD", None))
elif event_type == "UPDATE":
task = asyncio.create_task(self._update_memory(
memory_id=temp_uuid_mapping[resp["id"]], data=action_text,
existing_embeddings=new_message_embeddings, metadata=deepcopy(metadata)
))
task = asyncio.create_task(
self._update_memory(
memory_id=temp_uuid_mapping[resp["id"]],
data=action_text,
existing_embeddings=new_message_embeddings,
metadata=deepcopy(metadata),
)
)
memory_tasks.append((task, resp, "UPDATE", temp_uuid_mapping[resp["id"]]))
elif event_type == "DELETE":
task = asyncio.create_task(self._delete_memory(memory_id=temp_uuid_mapping[resp.get("id")]))
@@ -1217,31 +1227,30 @@ class AsyncMemory(MemoryBase):
logging.info("NOOP for Memory (async).")
except Exception as e:
logging.error(f"Error processing memory action (async): {resp}, Error: {e}")
for task, resp, event_type, mem_id in memory_tasks:
try:
result_id = await task
if event_type == "ADD":
returned_memories.append({
"id": result_id, "memory": resp.get("text"), "event": event_type
})
returned_memories.append({"id": result_id, "memory": resp.get("text"), "event": event_type})
elif event_type == "UPDATE":
returned_memories.append({
"id": mem_id, "memory": resp.get("text"),
"event": event_type, "previous_memory": resp.get("old_memory")
})
returned_memories.append(
{
"id": mem_id,
"memory": resp.get("text"),
"event": event_type,
"previous_memory": resp.get("old_memory"),
}
)
elif event_type == "DELETE":
returned_memories.append({
"id": mem_id, "memory": resp.get("text"), "event": event_type
})
returned_memories.append({"id": mem_id, "memory": resp.get("text"), "event": event_type})
except Exception as e:
logging.error(f"Error awaiting memory task (async): {e}")
except Exception as e:
logging.error(f"Error in memory processing loop (async): {e}")
capture_event(
"mem0.add", self,
{"version": self.api_version, "keys": list(filters.keys()), "sync_type": "async"}
"mem0.add", self, {"version": self.api_version, "keys": list(filters.keys()), "sync_type": "async"}
)
return returned_memories
@@ -1272,17 +1281,14 @@ class AsyncMemory(MemoryBase):
return None
promoted_payload_keys = [
"user_id",
"agent_id",
"run_id",
"user_id",
"agent_id",
"run_id",
"actor_id",
"role",
]
core_and_promoted_keys = {
"data", "hash", "created_at", "updated_at", "id",
*promoted_payload_keys
}
core_and_promoted_keys = {"data", "hash", "created_at", "updated_at", "id", *promoted_payload_keys}
result_item = MemoryItem(
id=memory.id,
@@ -1295,18 +1301,16 @@ class AsyncMemory(MemoryBase):
for key in promoted_payload_keys:
if key in memory.payload:
result_item[key] = memory.payload[key]
additional_metadata = {
k: v for k, v in memory.payload.items() if k not in core_and_promoted_keys
}
additional_metadata = {k: v for k, v in memory.payload.items() if k not in core_and_promoted_keys}
if additional_metadata:
result_item["metadata"] = additional_metadata
return result_item
async def get_all(
self,
*,
*,
user_id: Optional[str] = None,
agent_id: Optional[str] = None,
run_id: Optional[str] = None,
@@ -1314,41 +1318,36 @@ class AsyncMemory(MemoryBase):
limit: int = 100,
):
"""
List all memories.
List all memories.
Args:
user_id (str, optional): user id
agent_id (str, optional): agent id
run_id (str, optional): run id
filters (dict, optional): Additional custom key-value filters to apply to the search.
These are merged with the ID-based scoping filters. For example,
`filters={"actor_id": "some_user"}`.
limit (int, optional): The maximum number of memories to return. Defaults to 100.
Args:
user_id (str, optional): user id
agent_id (str, optional): agent id
run_id (str, optional): run id
filters (dict, optional): Additional custom key-value filters to apply to the search.
These are merged with the ID-based scoping filters. For example,
`filters={"actor_id": "some_user"}`.
limit (int, optional): The maximum number of memories to return. Defaults to 100.
Returns:
dict: A dictionary containing a list of memories under the "results" key,
and potentially "relations" if graph store is enabled. For API v1.0,
it might return a direct list (see deprecation warning).
Example for v1.1+: `{"results": [{"id": "...", "memory": "...", ...}]}`
Returns:
dict: A dictionary containing a list of memories under the "results" key,
and potentially "relations" if graph store is enabled. For API v1.0,
it might return a direct list (see deprecation warning).
Example for v1.1+: `{"results": [{"id": "...", "memory": "...", ...}]}`
"""
_, effective_filters = _build_filters_and_metadata(
user_id=user_id,
agent_id=agent_id,
run_id=run_id,
input_filters=filters
user_id=user_id, agent_id=agent_id, run_id=run_id, input_filters=filters
)
if not any(key in effective_filters for key in ("user_id", "agent_id", "run_id")):
raise ValueError(
"When 'conversation_id' is not provided (classic mode), "
"at least one of 'user_id', 'agent_id', or 'run_id' must be specified for get_all."
)
raise ValueError(
"When 'conversation_id' is not provided (classic mode), "
"at least one of 'user_id', 'agent_id', or 'run_id' must be specified for get_all."
)
capture_event(
"mem0.get_all",
self,
{"limit": limit, "keys": list(effective_filters.keys()), "sync_type": "async"}
"mem0.get_all", self, {"limit": limit, "keys": list(effective_filters.keys()), "sync_type": "async"}
)
with concurrent.futures.ThreadPoolExecutor() as executor:
@@ -1361,9 +1360,9 @@ class AsyncMemory(MemoryBase):
[future_memories, future_graph_entities] if future_graph_entities else [future_memories]
)
all_memories_result = future_memories.result()
all_memories_result = future_memories.result()
graph_entities_result = future_graph_entities.result() if future_graph_entities else None
if self.enable_graph:
return {"results": all_memories_result, "relations": graph_entities_result}
@@ -1381,20 +1380,21 @@ class AsyncMemory(MemoryBase):
async def _get_all_from_vector_store(self, filters, limit):
memories_result = await asyncio.to_thread(self.vector_store.list, filters=filters, limit=limit)
actual_memories = memories_result[0] if isinstance(memories_result, tuple) and len(memories_result) > 0 else memories_result
actual_memories = (
memories_result[0] if isinstance(memories_result, tuple) and len(memories_result) > 0 else memories_result
)
promoted_payload_keys = [
"user_id", "agent_id", "run_id",
"user_id",
"agent_id",
"run_id",
"actor_id",
"role",
]
core_and_promoted_keys = {
"data", "hash", "created_at", "updated_at", "id",
*promoted_payload_keys
}
core_and_promoted_keys = {"data", "hash", "created_at", "updated_at", "id", *promoted_payload_keys}
formatted_memories = []
for mem in actual_memories:
for mem in actual_memories:
memory_item_dict = MemoryItem(
id=mem.id,
memory=mem.payload["data"],
@@ -1406,15 +1406,13 @@ class AsyncMemory(MemoryBase):
for key in promoted_payload_keys:
if key in mem.payload:
memory_item_dict[key] = mem.payload[key]
additional_metadata = {
k: v for k, v in mem.payload.items() if k not in core_and_promoted_keys
}
additional_metadata = {k: v for k, v in mem.payload.items() if k not in core_and_promoted_keys}
if additional_metadata:
memory_item_dict["metadata"] = additional_metadata
formatted_memories.append(memory_item_dict)
return formatted_memories
async def search(
@@ -1442,16 +1440,13 @@ class AsyncMemory(MemoryBase):
and potentially "relations" if graph store is enabled.
Example for v1.1+: `{"results": [{"id": "...", "memory": "...", "score": 0.8, ...}]}`
"""
_, effective_filters = _build_filters_and_metadata(
user_id=user_id,
agent_id=agent_id,
run_id=run_id,
input_filters=filters
user_id=user_id, agent_id=agent_id, run_id=run_id, input_filters=filters
)
if not any(key in effective_filters for key in ("user_id", "agent_id", "run_id")):
raise ValueError("at least one of 'user_id', 'agent_id', or 'run_id' must be specified ")
raise ValueError("at least one of 'user_id', 'agent_id', or 'run_id' must be specified ")
capture_event(
"mem0.search",
@@ -1460,22 +1455,20 @@ class AsyncMemory(MemoryBase):
)
vector_store_task = asyncio.create_task(self._search_vector_store(query, effective_filters, limit))
graph_task = None
if self.enable_graph:
if hasattr(self.graph.search, "__await__"): # Check if graph search is async
graph_task = asyncio.create_task(self.graph.search(query, effective_filters, limit))
else:
graph_task = asyncio.create_task(
asyncio.to_thread(self.graph.search, query, effective_filters, limit)
)
graph_task = asyncio.create_task(asyncio.to_thread(self.graph.search, query, effective_filters, limit))
if graph_task:
original_memories, graph_entities = await asyncio.gather(vector_store_task, graph_task)
else:
original_memories = await vector_store_task
graph_entities = None
if self.enable_graph:
return {"results": original_memories, "relations": graph_entities}
@@ -1504,11 +1497,8 @@ class AsyncMemory(MemoryBase):
"actor_id",
"role",
]
core_and_promoted_keys = {
"data", "hash", "created_at", "updated_at", "id",
*promoted_payload_keys
}
core_and_promoted_keys = {"data", "hash", "created_at", "updated_at", "id", *promoted_payload_keys}
original_memories = []
for mem in memories:
@@ -1518,19 +1508,17 @@ class AsyncMemory(MemoryBase):
hash=mem.payload.get("hash"),
created_at=mem.payload.get("created_at"),
updated_at=mem.payload.get("updated_at"),
score=mem.score,
).model_dump()
score=mem.score,
).model_dump()
for key in promoted_payload_keys:
if key in mem.payload:
memory_item_dict[key] = mem.payload[key]
additional_metadata = {
k: v for k, v in mem.payload.items() if k not in core_and_promoted_keys
}
additional_metadata = {k: v for k, v in mem.payload.items() if k not in core_and_promoted_keys}
if additional_metadata:
memory_item_dict["metadata"] = additional_metadata
original_memories.append(memory_item_dict)
return original_memories
@@ -1650,7 +1638,7 @@ class AsyncMemory(MemoryBase):
capture_event("mem0._create_memory", self, {"memory_id": memory_id, "sync_type": "async"})
return memory_id
async def _create_procedural_memory(self, messages, metadata=None,llm=None ,prompt=None):
async def _create_procedural_memory(self, messages, metadata=None, llm=None, prompt=None):
"""
Create a procedural memory asynchronously
@@ -1709,11 +1697,11 @@ class AsyncMemory(MemoryBase):
except Exception:
logger.error(f"Error getting memory with ID {memory_id} during update.")
raise ValueError(f"Error getting memory with ID {memory_id}. Please provide a valid 'memory_id'")
prev_value = existing_memory.payload.get("data")
new_metadata = deepcopy(metadata) if metadata is not None else {}
new_metadata["data"] = data
new_metadata["hash"] = hashlib.md5(data.encode()).hexdigest()
new_metadata["created_at"] = existing_memory.payload.get("created_at")
@@ -1725,8 +1713,7 @@ class AsyncMemory(MemoryBase):
new_metadata["agent_id"] = existing_memory.payload["agent_id"]
if "run_id" in existing_memory.payload:
new_metadata["run_id"] = existing_memory.payload["run_id"]
if "actor_id" in existing_memory.payload:
new_metadata["actor_id"] = existing_memory.payload["actor_id"]
if "role" in existing_memory.payload:
@@ -1736,7 +1723,7 @@ class AsyncMemory(MemoryBase):
embeddings = existing_embeddings[data]
else:
embeddings = await asyncio.to_thread(self.embedding_model.embed, data, "update")
await asyncio.to_thread(
self.vector_store.update,
vector_id=memory_id,
@@ -1744,7 +1731,7 @@ class AsyncMemory(MemoryBase):
payload=new_metadata,
)
logger.info(f"Updating memory with ID {memory_id=} with {data=}")
await asyncio.to_thread(
self.db.add_history,
memory_id,

View File

@@ -5,16 +5,12 @@ from mem0.memory.utils import format_entities
try:
from langchain_memgraph import Memgraph
except ImportError:
raise ImportError(
"langchain_memgraph is not installed. Please install it using pip install langchain-memgraph"
)
raise ImportError("langchain_memgraph is not installed. Please install it using pip install langchain-memgraph")
try:
from rank_bm25 import BM25Okapi
except ImportError:
raise ImportError(
"rank_bm25 is not installed. Please install it using pip install rank-bm25"
)
raise ImportError("rank_bm25 is not installed. Please install it using pip install rank-bm25")
from mem0.graphs.tools import (
DELETE_MEMORY_STRUCT_TOOL_GRAPH,
@@ -74,22 +70,14 @@ class MemoryGraph:
filters (dict): A dictionary containing filters to be applied during the addition.
"""
entity_type_map = self._retrieve_nodes_from_data(data, filters)
to_be_added = self._establish_nodes_relations_from_data(
data, filters, entity_type_map
)
search_output = self._search_graph_db(
node_list=list(entity_type_map.keys()), filters=filters
)
to_be_deleted = self._get_delete_entities_from_search_output(
search_output, data, filters
)
to_be_added = self._establish_nodes_relations_from_data(data, filters, entity_type_map)
search_output = self._search_graph_db(node_list=list(entity_type_map.keys()), filters=filters)
to_be_deleted = self._get_delete_entities_from_search_output(search_output, data, filters)
# TODO: Batch queries with APOC plugin
# TODO: Add more filter support
deleted_entities = self._delete_entities(to_be_deleted, filters["user_id"])
added_entities = self._add_entities(
to_be_added, filters["user_id"], entity_type_map
)
added_entities = self._add_entities(to_be_added, filters["user_id"], entity_type_map)
return {"deleted_entities": deleted_entities, "added_entities": added_entities}
@@ -108,16 +96,13 @@ class MemoryGraph:
- "entities": List of related graph data based on the query.
"""
entity_type_map = self._retrieve_nodes_from_data(query, filters)
search_output = self._search_graph_db(
node_list=list(entity_type_map.keys()), filters=filters
)
search_output = self._search_graph_db(node_list=list(entity_type_map.keys()), filters=filters)
if not search_output:
return []
search_outputs_sequence = [
[item["source"], item["relationship"], item["destination"]]
for item in search_output
[item["source"], item["relationship"], item["destination"]] for item in search_output
]
bm25 = BM25Okapi(search_outputs_sequence)
@@ -126,9 +111,7 @@ class MemoryGraph:
search_results = []
for item in reranked_results:
search_results.append(
{"source": item[0], "relationship": item[1], "destination": item[2]}
)
search_results.append({"source": item[0], "relationship": item[1], "destination": item[2]})
logger.info(f"Returned {len(search_results)} search results")
@@ -161,9 +144,7 @@ class MemoryGraph:
RETURN n.name AS source, type(r) AS relationship, m.name AS target
LIMIT $limit
"""
results = self.graph.query(
query, params={"user_id": filters["user_id"], "limit": limit}
)
results = self.graph.query(query, params={"user_id": filters["user_id"], "limit": limit})
final_results = []
for result in results:
@@ -208,13 +189,8 @@ class MemoryGraph:
f"Error in search tool: {e}, llm_provider={self.llm_provider}, search_results={search_results}"
)
entity_type_map = {
k.lower().replace(" ", "_"): v.lower().replace(" ", "_")
for k, v in entity_type_map.items()
}
logger.debug(
f"Entity type map: {entity_type_map}\n search_results={search_results}"
)
entity_type_map = {k.lower().replace(" ", "_"): v.lower().replace(" ", "_") for k, v in entity_type_map.items()}
logger.debug(f"Entity type map: {entity_type_map}\n search_results={search_results}")
return entity_type_map
def _establish_nodes_relations_from_data(self, data, filters, entity_type_map):
@@ -223,9 +199,7 @@ class MemoryGraph:
messages = [
{
"role": "system",
"content": EXTRACT_RELATIONS_PROMPT.replace(
"USER_ID", filters["user_id"]
).replace(
"content": EXTRACT_RELATIONS_PROMPT.replace("USER_ID", filters["user_id"]).replace(
"CUSTOM_PROMPT", f"4. {self.config.graph_store.custom_prompt}"
),
},
@@ -235,9 +209,7 @@ class MemoryGraph:
messages = [
{
"role": "system",
"content": EXTRACT_RELATIONS_PROMPT.replace(
"USER_ID", filters["user_id"]
),
"content": EXTRACT_RELATIONS_PROMPT.replace("USER_ID", filters["user_id"]),
},
{
"role": "user",
@@ -304,9 +276,7 @@ class MemoryGraph:
def _get_delete_entities_from_search_output(self, search_output, data, filters):
"""Get the entities to be deleted from the search output."""
search_output_string = format_entities(search_output)
system_prompt, user_prompt = get_delete_messages(
search_output_string, data, filters["user_id"]
)
system_prompt, user_prompt = get_delete_messages(search_output_string, data, filters["user_id"])
_tools = [DELETE_MEMORY_TOOL_GRAPH]
if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
@@ -379,12 +349,8 @@ class MemoryGraph:
# search for the nodes with the closest embeddings; this is basically
# comparison of one embedding to all embeddings in a graph -> vector
# search with cosine similarity metric
source_node_search_result = self._search_source_node(
source_embedding, user_id, threshold=0.9
)
destination_node_search_result = self._search_destination_node(
dest_embedding, user_id, threshold=0.9
)
source_node_search_result = self._search_source_node(source_embedding, user_id, threshold=0.9)
destination_node_search_result = self._search_destination_node(dest_embedding, user_id, threshold=0.9)
# TODO: Create a cypher query and common params for all the cases
if not destination_node_search_result and source_node_search_result:
@@ -424,9 +390,7 @@ class MemoryGraph:
"""
params = {
"destination_id": destination_node_search_result[0][
"id(destination_candidate)"
],
"destination_id": destination_node_search_result[0]["id(destination_candidate)"],
"source_name": source,
"source_embedding": source_embedding,
"user_id": user_id,
@@ -445,9 +409,7 @@ class MemoryGraph:
"""
params = {
"source_id": source_node_search_result[0]["id(source_candidate)"],
"destination_id": destination_node_search_result[0][
"id(destination_candidate)"
],
"destination_id": destination_node_search_result[0]["id(destination_candidate)"],
"user_id": user_id,
}
else:

View File

@@ -1,8 +1,8 @@
import logging
import sqlite3
import threading
import uuid
import logging
from typing import List, Dict, Any, Optional
from typing import Any, Dict, List, Optional
logger = logging.getLogger(__name__)
@@ -23,9 +23,7 @@ class SQLiteManager:
"""
with self._lock, self.connection:
cur = self.connection.cursor()
cur.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name='history'"
)
cur.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='history'")
if cur.fetchone() is None:
return # nothing to migrate
@@ -51,13 +49,11 @@ class SQLiteManager:
logger.info("Migrating history table to new schema (no convo columns).")
cur.execute("ALTER TABLE history RENAME TO history_old")
self._create_history_table()
self._create_history_table()
intersecting = list(expected_cols & old_cols)
cols_csv = ", ".join(intersecting)
cur.execute(
f"INSERT INTO history ({cols_csv}) SELECT {cols_csv} FROM history_old"
)
cur.execute(f"INSERT INTO history ({cols_csv}) SELECT {cols_csv} FROM history_old")
cur.execute("DROP TABLE history_old")
def _create_history_table(self) -> None:

View File

@@ -9,8 +9,8 @@ import mem0
from mem0.memory.setup import get_or_create_user_id
MEM0_TELEMETRY = os.environ.get("MEM0_TELEMETRY", "True")
PROJECT_API_KEY="phc_hgJkUVJFYtmaJqrvf6CYN67TIQ8yhXAkWzUn9AMU4yX"
HOST="https://us.i.posthog.com"
PROJECT_API_KEY = "phc_hgJkUVJFYtmaJqrvf6CYN67TIQ8yhXAkWzUn9AMU4yX"
HOST = "https://us.i.posthog.com"
if isinstance(MEM0_TELEMETRY, str):
MEM0_TELEMETRY = MEM0_TELEMETRY.lower() in ("true", "1", "yes")