From a1c9a63074cb3c9291aac0eb78c92062cf86ccd4 Mon Sep 17 00:00:00 2001
From: Chaithanya Kumar
Date: Fri, 16 May 2025 23:38:36 +0530
Subject: [PATCH] feat: Add Group Chat Memory Feature support to Python SDK enhancing mem0 (#2669)

---
 docs/docs.json                             |    1 +
 docs/examples/collaborative-task-agent.mdx |  273 +++++
 mem0/memory/main.py                        | 1192 ++++++++++++--------
 mem0/memory/storage.py                     |  248 ++--
 4 files changed, 1121 insertions(+), 593 deletions(-)
 create mode 100644 docs/examples/collaborative-task-agent.mdx

diff --git a/docs/docs.json b/docs/docs.json
index e6891007..fbb9153a 100644
--- a/docs/docs.json
+++ b/docs/docs.json
@@ -203,6 +203,7 @@
       "examples/aws_example",
       "examples/mem0-demo",
       "examples/ai_companion_js",
+      "examples/collaborative-task-agent",
       "examples/eliza_os",
       "examples/mem0-mastra",
       "examples/mem0-with-ollama",
diff --git a/docs/examples/collaborative-task-agent.mdx b/docs/examples/collaborative-task-agent.mdx
new file mode 100644
index 00000000..f170f305
--- /dev/null
+++ b/docs/examples/collaborative-task-agent.mdx
@@ -0,0 +1,273 @@
+---
+title: Collaborative Task Agent
+---
+
+# Building a Collaborative Task Management System with Mem0
+
+## Overview
+
+Mem0's attribution capabilities now allow you to create multi-user, multi-agent collaborative chat systems by attaching an **`actor_id`** to each memory. By setting the user's name in `message["name"]`, you can build powerful team collaboration tools where every contribution is properly attributed to its author.
+
+When using `infer=False`, messages are stored exactly as provided while still preserving actor metadata—making this approach ideal for:
+
+- Multi-user chat applications
+- Team brainstorming sessions
+- Any collaborative "shared canvas" scenario
+
+> **ℹ️ Note**
+> Actor attribution works today with `infer=False` mode.
+> Full attribution support for the fact-extraction pipeline (`infer=True`) will be available in an upcoming release.
+
+## Key Concepts
+
+### Session Context
+
+Session context is defined by one of three identifiers:
+
+- **`user_id`**: Ideal for personal memory or user-specific data
+- **`agent_id`**: Used for agent-specific memory storage
+- **`run_id`**: Best for shared task contexts or collaborative spaces
+
+Developers choose whichever identifier best represents their use case. In this example, we use `run_id` to create a shared project space where all team members can collaborate.
+
+### Actor Attribution
+
+Actor attribution is derived internally from:
+
+- **`message["name"]`**: Becomes the `actor_id` in the memory's metadata
+- **`message["role"]`**: Stored as the `role` in the memory's metadata
+
+Note that `actor_id` is not a top-level parameter of the `add()` method; it is extracted from the message itself.
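+
+A minimal sketch of how the `name` field flows into attribution (assuming a local `Memory()` instance with default settings):
+
+```python
+from mem0 import Memory
+
+mem = Memory()
+
+# The message's `name` is stored as `actor_id` in the memory's
+# metadata; `role` is stored alongside it (raw mode, infer=False).
+mem.add(
+    [{"role": "user", "name": "alice", "content": "I'll draft the hero copy."}],
+    run_id="landing-v1",
+    infer=False,
+)
+```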
+
+### Memory Filtering
+
+When retrieving memories, you can filter by actor using the `filters` parameter:
+
+```python
+# Get all memories from a specific actor
+memories = mem.search("query", run_id="landing-v1", filters={"actor_id": "alice"})
+
+# Get all memories from all team members
+all_memories = mem.get_all(run_id="landing-v1")
+```
+
+## Upcoming Features
+
+Mem0 will soon support full actor attribution with `infer=True`, enabling automatic extraction of actor names during the fact-extraction process. This enhancement will allow the system to:
+
+1. Maintain attribution information when converting raw messages to semantic facts
+2. Associate extracted knowledge with its original source
+3. Track the provenance of information across complex interactions
+
+## Use Cases
+
+Mem0's actor attribution system can power a wide range of advanced conversation and agent scenarios:
+
+### Conversation Scenarios
+
+| Scenario | Description |
+|----------|-------------|
+| **Simple Chat** | One-to-one conversation between user and assistant |
+| **Multi-User Chat** | Multiple users conversing with a single assistant |
+| **Multi-Agent Chat** | Multiple AI assistants with distinct personas or capabilities |
+| **Group Chat** | Complex interactions between multiple humans and assistants |
+
+### Agent-Based Applications
+
+The collaborative task agent uses a simple but powerful architecture:
+
+* A **shared project space** identified by a single `run_id`
+* Each participant (user or AI) writes with their own **unique name**, which becomes the `actor_id` in Mem0
+* All memories can be searched, filtered, or visualized by actor
+
+## Implementation
+
+Below is a complete implementation of a collaborative task agent that demonstrates how to build team-oriented applications with Mem0.
+
+```python
+import os
+from datetime import datetime  # For parsing and formatting timestamps
+
+from openai import OpenAI
+from mem0 import Memory
+
+# Configuration
+os.environ["OPENAI_API_KEY"] = "sk-your-key"  # Replace with your key
+client = OpenAI()
+
+RUN_ID = "landing-v1"       # Shared project context
+APP_ID = "task-agent-demo"  # Application identifier
+
+# Initialize Mem0 with default settings (local Qdrant + SQLite)
+# Ensure the path is writable if not using in-memory
+mem = Memory()
+
+class TaskAgent:
+    def __init__(self, run_id: str):
+        """
+        Initialize a collaborative task agent for a specific project.
+
+        Args:
+            run_id: Unique identifier for this project workspace
+        """
+        self.run_id = run_id
+        self.mem = mem
+
+    def add_message(self, role: str, speaker: str, content: str):
+        """
+        Store a chat message with proper attribution.
+
+        Args:
+            role: Message role (user, assistant, system)
+            speaker: Name of the person/agent speaking (becomes actor_id)
+            content: The actual message content
+        """
+        msg = {"role": role, "name": speaker, "content": content}
+        # created_at is stored by Mem0 automatically.
+        self.mem.add(
+            [msg],
+            run_id=self.run_id,
+            metadata={"app_id": APP_ID},
+            infer=False
+        )
+
+    def brainstorm(self, prompt: str, speaker: str = "assistant", search_limit: int = 10, exclude_assistant_context: bool = False):
+        """
+        Generate a response based on project context and team input.
+
+        Args:
+            prompt: The question or task to address
+            speaker: Name to attribute the assistant's response to
+            search_limit: Max number of memories to retrieve for context
+            exclude_assistant_context: If True, filters out the assistant's own messages from context
+
+        Returns:
+            str: The assistant's response
+        """
+        # Retrieve relevant context from the team's shared memory.
+        # Fetch a few extra results if we plan to filter, so enough user messages remain.
+        fetch_limit = search_limit + 5 if exclude_assistant_context else search_limit
+        retrieved_memories = self.mem.search(prompt, run_id=self.run_id, limit=fetch_limit)["results"]
+
+        # Client-side sort by 'created_at' to prioritize recent memories for context.
+        # Mem0 stores created_at as ISO-format strings, so they compare correctly.
+        retrieved_memories.sort(key=lambda m: m.get('created_at', ''), reverse=True)
+
+        ctx_for_llm = []
+        if exclude_assistant_context:
+            for m in retrieved_memories:
+                if m.get("role") != "assistant":
+                    ctx_for_llm.append(m)
+                if len(ctx_for_llm) >= search_limit:
+                    break
+        else:
+            ctx_for_llm = retrieved_memories[:search_limit]
+
+        context_parts = []
+        for m in ctx_for_llm:
+            actor = m.get('actor_id') or "Unknown"
+            # Attempt to parse and format the timestamp for better readability
+            try:
+                ts_iso = m.get('created_at', '')
+                if ts_iso:
+                    ts_obj = datetime.fromisoformat(ts_iso.replace('Z', '+00:00'))  # Handle Zulu time
+                    formatted_ts = ts_obj.strftime('%Y-%m-%d %H:%M:%S %Z')
+                else:
+                    formatted_ts = "Timestamp N/A"
+            except ValueError:
+                formatted_ts = ts_iso  # Fall back to the raw string if parsing fails
+            context_parts.append(f"- {m['memory']} (by {actor} at {formatted_ts})")
+
+        context_str = "\n".join(context_parts)
+
+        # Generate a response with context-aware prompting
+        sys_prompt = "You are the team's project assistant. Use the provided memory context, paying attention to timestamps for recency, to answer the user's query or perform the task."
+        user_prompt_with_context = f"Query: {prompt}\n\nRelevant Context (most recent first):\n{context_str}"
+
+        msgs = [
+            {"role": "system", "content": sys_prompt},
+            {"role": "user", "content": user_prompt_with_context}
+        ]
+
+        reply = client.chat.completions.create(
+            model="gpt-4o-mini",
+            messages=msgs
+        ).choices[0].message.content.strip()
+
+        # Store the assistant's response with attribution
+        self.add_message("assistant", speaker, reply)
+        return reply
+
+    def dump(self, sort_by_time: bool = True, group_by_speaker: bool = False):
+        """
+        Display all messages in the shared project space with attribution.
+        Results can be sorted by time and/or grouped by speaker.
+        """
+        results = self.mem.get_all(run_id=self.run_id)["results"]
+
+        if not results:
+            print("No memories found for this run.")
+            return
+
+        # Sort by 'created_at' if requested
+        if sort_by_time:
+            results.sort(key=lambda m: m.get('created_at', ''))
+            print(f"\n--- Project memory (run_id: {self.run_id}, sorted by time) ---")
+        else:
+            print(f"\n--- Project memory (run_id: {self.run_id}) ---")
+
+        if group_by_speaker:
+            from collections import defaultdict
+            grouped_memories = defaultdict(list)
+            for m in results:  # Uses the (possibly time-sorted) results
+                grouped_memories[m.get("actor_id") or "Unknown"].append(m)
+
+            for speaker, mem_list in grouped_memories.items():
+                print(f"\n=== Speaker: {speaker} ===")
+                # If sort_by_time was True, each group is already in chronological order.
+                for m_item in mem_list:
+                    timestamp_str = m_item.get('created_at', 'Timestamp N/A')
+                    try:
+                        # Basic parsing for display; adjust as needed
+                        dt_obj = datetime.fromisoformat(timestamp_str.replace('Z', '+00:00'))
+                        formatted_time = dt_obj.strftime('%Y-%m-%d %H:%M:%S')
+                    except ValueError:
+                        formatted_time = timestamp_str  # Fallback
+                    print(f"[{formatted_time:19}] {m_item['memory']}")
+        else:  # Not grouping by speaker
+            for m in results:
+                who = m.get("actor_id") or "Unknown"
+                timestamp_str = m.get('created_at', 'Timestamp N/A')
+                try:
+                    dt_obj = datetime.fromisoformat(timestamp_str.replace('Z', '+00:00'))
+                    formatted_time = dt_obj.strftime('%Y-%m-%d %H:%M:%S')
+                except ValueError:
+                    formatted_time = timestamp_str  # Fallback
+                print(f"[{formatted_time:19}][{who:8}] {m['memory']}")
+
+# Demo Usage
+agent = TaskAgent(RUN_ID)
+
+# Team collaboration session
+agent.add_message("user", "alice", "Let's list tasks for the new landing page.")
+agent.add_message("user", "bob", "I'll own the hero section copy. Maybe tomorrow.")
+agent.add_message("user", "carol", "I'll choose three product screenshots later today.")
+agent.add_message("user", "alice", "Actually, I will work on the hero section copy today.")
+
+print("\nAssistant brainstorm reply (default settings):\n")
+print(agent.brainstorm("What are the current open tasks related to the hero section?"))
+
+print("\nAssistant brainstorm reply (excluding its own prior context):\n")
+print(agent.brainstorm("Summarize what Alice is working on.", exclude_assistant_context=True))
+
+print("\n--- Dump (sorted by time by default) ---")
+agent.dump()
+
+print("\n--- Dump (grouped by speaker, also sorted by time globally) ---")
+agent.dump(group_by_speaker=True)
+
+print("\n--- Dump (default order, not sorted explicitly) ---")
+agent.dump(sort_by_time=False)
+```
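+
+To review a single contributor's activity, retrieval can also be scoped with the `filters` argument (a sketch reusing the `RUN_ID` from the implementation above):
+
+```python
+# List only Bob's contributions to the shared project space.
+bob_memories = mem.get_all(run_id=RUN_ID, filters={"actor_id": "bob"})["results"]
+for m in bob_memories:
+    print(m["memory"])
+```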
diff --git a/mem0/memory/main.py b/mem0/memory/main.py
index 8ec74f8c..3744724d 100644
--- a/mem0/memory/main.py
+++ b/mem0/memory/main.py
@@ -9,7 +9,7 @@ import uuid
 import warnings
 from copy import deepcopy
 from datetime import datetime
-from typing import Any, Dict
+from typing import Any, Dict, Optional
 
 import pytz
 from pydantic import ValidationError
@@ -32,9 +32,80 @@ from mem0.memory.utils import (
 )
 from mem0.utils.factory import EmbedderFactory, LlmFactory, VectorStoreFactory
 
-# Setup user config
-setup_config()
+def _build_filters_and_metadata(
+    *,  # Enforce keyword-only arguments
+    user_id: Optional[str] = None,
+    agent_id: Optional[str] = None,
+    run_id: Optional[str] = None,
+    actor_id: Optional[str] = None,  # For query-time filtering
+    input_metadata: Optional[Dict[str, Any]] = None,
+    input_filters: Optional[Dict[str, Any]] = None,
+) -> tuple[Dict[str, Any], Dict[str, Any]]:
+    """
+    Constructs metadata for storage and filters for querying, based on session and actor identifiers.
+
+    This helper ties every memory/query to exactly one session id (`user_id`, `agent_id`,
+    or `run_id`) and optionally narrows queries to a specific `actor_id`. It returns two dicts:
+
+    1. `base_metadata_template`: Used as a template for metadata when storing new memories.
+       It includes the primary session identifier(s) and any `input_metadata`.
+    2. `effective_query_filters`: Used for querying existing memories. It includes the
+       primary session identifier(s), any `input_filters`, and a resolved actor
+       identifier for targeted filtering if specified by any actor-related inputs.
+
+    Actor filtering precedence: the explicit `actor_id` argument wins over `filters["actor_id"]`.
+    The resolved actor ID is used for querying but is not added to `base_metadata_template`,
+    as the actor for storage is typically derived from message content at a later stage.
+
+    Args:
+        user_id (Optional[str]): User identifier, used for session scoping.
+        agent_id (Optional[str]): Agent identifier, used for session scoping.
+        run_id (Optional[str]): Run identifier, used for session scoping.
+        actor_id (Optional[str]): Explicit actor identifier, used as a potential source for
+            actor-specific filtering. See the actor resolution precedence above.
+        input_metadata (Optional[Dict[str, Any]]): Base dictionary to be augmented with
+            session identifiers for the storage metadata template. Defaults to an empty dict.
+        input_filters (Optional[Dict[str, Any]]): Base dictionary to be augmented with
+            session and actor identifiers for query filters. Defaults to an empty dict.
+
+    Returns:
+        tuple[Dict[str, Any], Dict[str, Any]]: A tuple containing:
+            - base_metadata_template (Dict[str, Any]): Metadata template for storing memories,
+              scoped to the determined session.
+            - effective_query_filters (Dict[str, Any]): Filters for querying memories,
+              scoped to the determined session and potentially a resolved actor.
+    """
+
+    base_metadata_template = deepcopy(input_metadata) if input_metadata else {}
+    effective_query_filters = deepcopy(input_filters) if input_filters else {}
+
+    # ---------- resolve session id (mandatory) ----------
+    session_key, session_val = None, None
+    if user_id:
+        session_key, session_val = "user_id", user_id
+    elif agent_id:
+        session_key, session_val = "agent_id", agent_id
+    elif run_id:
+        session_key, session_val = "run_id", run_id
+
+    if session_key is None:
+        raise ValueError("One of 'user_id', 'agent_id', or 'run_id' must be provided.")
+
+    base_metadata_template[session_key] = session_val
+    effective_query_filters[session_key] = session_val
+
+    # ---------- optional actor filter ----------
+    resolved_actor_id = actor_id or effective_query_filters.get("actor_id")
+    if resolved_actor_id:
+        effective_query_filters["actor_id"] = resolved_actor_id
+
+    return base_metadata_template, effective_query_filters
+
+
+setup_config()
 
 logger = logging.getLogger(__name__)
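For orientation, a minimal sketch of the helper's contract (illustrative only, reusing values from the example doc above):

```python
# Illustrative call against the helper defined above.
metadata, filters = _build_filters_and_metadata(
    run_id="landing-v1",
    actor_id="alice",
    input_metadata={"app_id": "task-agent-demo"},
)

# The session id lands in both dicts; the actor id only in the query filters:
# metadata == {"app_id": "task-agent-demo", "run_id": "landing-v1"}
# filters  == {"run_id": "landing-v1", "actor_id": "alice"}
```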
@@ -107,55 +178,52 @@ class Memory(MemoryBase):
     def add(
         self,
         messages,
-        user_id=None,
-        agent_id=None,
-        run_id=None,
-        metadata=None,
-        filters=None,
-        infer=True,
-        memory_type=None,
-        prompt=None,
+        *,
+        user_id: Optional[str] = None,
+        agent_id: Optional[str] = None,
+        run_id: Optional[str] = None,
+        metadata: Optional[Dict[str, Any]] = None,
+        infer: bool = True,
+        memory_type: Optional[str] = None,
+        prompt: Optional[str] = None,
    ):
         """
         Create a new memory.
+
+        Adds new memories scoped to a single session id (`user_id`, `agent_id`, or `run_id`).
+        One of those ids is required.
 
         Args:
-            messages (str or List[Dict[str, str]]): Messages to store in the memory.
+            messages (str or List[Dict[str, str]]): The message content or list of messages
+                (e.g., `[{"role": "user", "content": "Hello"}, {"role": "assistant", "content": "Hi"}]`)
+                to be processed and stored.
             user_id (str, optional): ID of the user creating the memory. Defaults to None.
             agent_id (str, optional): ID of the agent creating the memory. Defaults to None.
             run_id (str, optional): ID of the run creating the memory. Defaults to None.
             metadata (dict, optional): Metadata to store with the memory. Defaults to None.
-            filters (dict, optional): Filters to apply to the search. Defaults to None.
-            infer (bool, optional): Whether to infer the memories. Defaults to True.
-            memory_type (str, optional): Type of memory to create. Defaults to None. By default, it creates the short term memories and long term (semantic and episodic) memories. Pass "procedural_memory" to create procedural memories.
+            infer (bool, optional): If True (default), an LLM is used to extract key facts from
+                'messages' and decide whether to add, update, or delete related memories.
+                If False, 'messages' are added as raw memories directly.
+            memory_type (str, optional): Specifies the type of memory. Currently, only
+                `MemoryType.PROCEDURAL.value` ("procedural_memory") is explicitly handled for
+                creating procedural memories (typically requires 'agent_id'). Otherwise, memories
+                are treated as general conversational/factual memories.
             prompt (str, optional): Prompt to use for the memory creation. Defaults to None.
 
         Returns:
-            dict: A dictionary containing the result of the memory addition operation.
-            result: dict of affected events with each dict has the following key:
-              'memories': affected memories
-              'graph': affected graph memories
-
-              'memories' and 'graph' is a dict, each with following subkeys:
-                'add': added memory
-                'update': updated memory
-                'delete': deleted memory
-
+            dict: A dictionary containing the result of the memory addition operation, typically
+                including a list of memory items affected (added, updated) under a "results" key,
+                and potentially "relations" if graph store is enabled.
+                Example for v1.1+: `{"results": [{"id": "...", "memory": "...", "event": "ADD"}]}`
         """
-        if metadata is None:
-            metadata = {}
-
-        filters = filters or {}
-        if user_id:
-            filters["user_id"] = metadata["user_id"] = user_id
-        if agent_id:
-            filters["agent_id"] = metadata["agent_id"] = agent_id
-        if run_id:
-            filters["run_id"] = metadata["run_id"] = run_id
-
-        if not any(key in filters for key in ("user_id", "agent_id", "run_id")):
-            raise ValueError("One of the filters: user_id, agent_id or run_id is required!")
-
+        processed_metadata, effective_filters = _build_filters_and_metadata(
+            user_id=user_id,
+            agent_id=agent_id,
+            run_id=run_id,
+            input_metadata=metadata,
+        )
+
         if memory_type is not None and memory_type != MemoryType.PROCEDURAL.value:
             raise ValueError(
                 f"Invalid 'memory_type'. Please pass {MemoryType.PROCEDURAL.value} to create procedural memories."
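With `infer=False`, the loop in `_add_to_vector_store` below returns one entry per stored message, including the promoted attribution fields. A sketch of the resulting shape (assuming the v1.1 output format and no graph store):

```python
# Hypothetical call against the add() defined above (infer=False path).
result = memory.add(
    [{"role": "user", "name": "alice", "content": "Ship the landing page."}],
    run_id="landing-v1",
    infer=False,
)
# result == {
#     "results": [
#         {"id": "<uuid>", "memory": "Ship the landing page.",
#          "event": "ADD", "actor_id": "alice", "role": "user"}
#     ]
# }
```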
@@ -163,9 +231,15 @@ class Memory(MemoryBase): if isinstance(messages, str): messages = [{"role": "user", "content": messages}] + + elif isinstance(messages, dict): + messages = [messages] + + elif not isinstance(messages, list): + raise ValueError("messages must be str, dict, or list[dict]") if agent_id is not None and memory_type == MemoryType.PROCEDURAL.value: - results = self._create_procedural_memory(messages, metadata=metadata, prompt=prompt) + results = self._create_procedural_memory(messages, metadata=processed_metadata, prompt=prompt) return results if self.config.llm.config.get("enable_vision"): @@ -174,14 +248,14 @@ class Memory(MemoryBase): messages = parse_vision_messages(messages) with concurrent.futures.ThreadPoolExecutor() as executor: - future1 = executor.submit(self._add_to_vector_store, messages, metadata, filters, infer) - future2 = executor.submit(self._add_to_graph, messages, filters) + future1 = executor.submit(self._add_to_vector_store, messages, processed_metadata, effective_filters, infer) + future2 = executor.submit(self._add_to_graph, messages, effective_filters) concurrent.futures.wait([future1, future2]) vector_store_result = future1.result() graph_result = future2.result() - + if self.api_version == "v1.0": warnings.warn( "The current add API output format is deprecated. " @@ -203,15 +277,42 @@ class Memory(MemoryBase): def _add_to_vector_store(self, messages, metadata, filters, infer): if not infer: returned_memories = [] - for message in messages: - if message["role"] != "system": - message_embeddings = self.embedding_model.embed(message["content"], "add") - memory_id = self._create_memory(message["content"], message_embeddings, metadata) - returned_memories.append({"id": memory_id, "memory": message["content"], "event": "ADD"}) + for message_dict in messages: + if not isinstance(message_dict, dict) or \ + message_dict.get("role") is None or \ + message_dict.get("content") is None: + logger.warning(f"Skipping invalid message format: {message_dict}") + continue + + if message_dict["role"] == "system": + continue + + + per_msg_meta = deepcopy(metadata) + per_msg_meta["role"] = message_dict["role"] + + + actor_name = message_dict.get("name") + if actor_name: + per_msg_meta["actor_id"] = actor_name + + msg_content = message_dict["content"] + msg_embeddings = self.embedding_model.embed(msg_content, "add") + mem_id = self._create_memory(msg_content, msg_embeddings, per_msg_meta) + + returned_memories.append( + { + "id": mem_id, + "memory": msg_content, + "event": "ADD", + "actor_id": actor_name if actor_name else None, + "role": message_dict["role"], + } + ) return returned_memories - parsed_messages = parse_messages(messages) - + parsed_messages = parse_messages(messages) + if self.config.custom_fact_extraction_prompt: system_prompt = self.config.custom_fact_extraction_prompt user_prompt = f"Input:\n{parsed_messages}" @@ -235,7 +336,7 @@ class Memory(MemoryBase): retrieved_old_memory = [] new_message_embeddings = {} - for new_mem in new_retrieved_facts: + for new_mem in new_retrieved_facts: messages_embeddings = self.embedding_model.embed(new_mem, "add") new_message_embeddings[new_mem] = messages_embeddings existing_memories = self.vector_store.search( @@ -246,6 +347,7 @@ class Memory(MemoryBase): ) for mem in existing_memories: retrieved_old_memory.append({"id": mem.id, "text": mem.payload["data"]}) + unique_data = {} for item in retrieved_old_memory: unique_data[item["id"]] = item @@ -283,59 +385,48 @@ class Memory(MemoryBase): for resp in 
new_memories_with_actions.get("memory", []): logging.info(resp) try: - if not resp.get("text"): + action_text = resp.get("text") + if not action_text: logging.info("Skipping memory entry because of empty `text` field.") continue - elif resp.get("event") == "ADD": + + event_type = resp.get("event") + if event_type == "ADD": memory_id = self._create_memory( - data=resp.get("text"), + data=action_text, existing_embeddings=new_message_embeddings, metadata=deepcopy(metadata), ) - returned_memories.append( - { - "id": memory_id, - "memory": resp.get("text"), - "event": resp.get("event"), - } - ) - elif resp.get("event") == "UPDATE": + returned_memories.append({"id": memory_id, "memory": action_text, "event": event_type}) + elif event_type == "UPDATE": self._update_memory( - memory_id=temp_uuid_mapping[resp["id"]], - data=resp.get("text"), + memory_id=temp_uuid_mapping[resp.get("id")], + data=action_text, existing_embeddings=new_message_embeddings, metadata=deepcopy(metadata), ) - returned_memories.append( - { - "id": temp_uuid_mapping[resp.get("id")], - "memory": resp.get("text"), - "event": resp.get("event"), - "previous_memory": resp.get("old_memory"), - } - ) - elif resp.get("event") == "DELETE": + returned_memories.append({ + "id": temp_uuid_mapping[resp.get("id")], "memory": action_text, + "event": event_type, "previous_memory": resp.get("old_memory"), + }) + elif event_type == "DELETE": self._delete_memory(memory_id=temp_uuid_mapping[resp.get("id")]) - returned_memories.append( - { - "id": temp_uuid_mapping[resp.get("id")], - "memory": resp.get("text"), - "event": resp.get("event"), - } - ) - elif resp.get("event") == "NONE": + returned_memories.append({ + "id": temp_uuid_mapping[resp.get("id")], "memory": action_text, + "event": event_type, + }) + elif event_type == "NONE": logging.info("NOOP for Memory.") except Exception as e: - logging.error(f"Error in new_memories_with_actions: {e}") + logging.error(f"Error processing memory action: {resp}, Error: {e}") except Exception as e: - logging.error(f"Error in new_memories_with_actions: {e}") + logging.error(f"Error iterating new_memories_with_actions: {e}") capture_event( "mem0.add", self, {"version": self.api_version, "keys": list(filters.keys()), "sync_type": "sync"}, ) - return returned_memories def _add_to_graph(self, messages, filters): @@ -364,148 +455,194 @@ class Memory(MemoryBase): if not memory: return None - filters = {key: memory.payload[key] for key in ["user_id", "agent_id", "run_id"] if memory.payload.get(key)} + promoted_payload_keys = [ + "user_id", + "agent_id", + "run_id", + "actor_id", + "role", + ] + + core_and_promoted_keys = { + "data", "hash", "created_at", "updated_at", "id", + *promoted_payload_keys + } - # Prepare base memory item - memory_item = MemoryItem( + result_item = MemoryItem( id=memory.id, memory=memory.payload["data"], hash=memory.payload.get("hash"), created_at=memory.payload.get("created_at"), updated_at=memory.payload.get("updated_at"), - ).model_dump(exclude={"score"}) + ).model_dump() - # Add metadata if there are additional keys - excluded_keys = { - "user_id", - "agent_id", - "run_id", - "hash", - "data", - "created_at", - "updated_at", - "id", + for key in promoted_payload_keys: + if key in memory.payload: + result_item[key] = memory.payload[key] + + additional_metadata = { + k: v for k, v in memory.payload.items() if k not in core_and_promoted_keys } - additional_metadata = {k: v for k, v in memory.payload.items() if k not in excluded_keys} if additional_metadata: - memory_item["metadata"] = 
additional_metadata + result_item["metadata"] = additional_metadata + + return result_item - result = {**memory_item, **filters} - - return result - - def get_all(self, user_id=None, agent_id=None, run_id=None, limit=100): + def get_all( + self, + *, + user_id: Optional[str] = None, + agent_id: Optional[str] = None, + run_id: Optional[str] = None, + filters: Optional[Dict[str, Any]] = None, + limit: int = 100, + ): """ List all memories. - Returns: - list: List of all memories. - """ - filters = {} - if user_id: - filters["user_id"] = user_id - if agent_id: - filters["agent_id"] = agent_id - if run_id: - filters["run_id"] = run_id + Args: + user_id (str, optional): user id + agent_id (str, optional): agent id + run_id (str, optional): run id + filters (dict, optional): Additional custom key-value filters to apply to the search. + These are merged with the ID-based scoping filters. For example, + `filters={"actor_id": "some_user"}`. + limit (int, optional): The maximum number of memories to return. Defaults to 100. - capture_event("mem0.get_all", self, {"limit": limit, "keys": list(filters.keys()), "sync_type": "sync"}) + Returns: + dict: A dictionary containing a list of memories under the "results" key, + and potentially "relations" if graph store is enabled. For API v1.0, + it might return a direct list (see deprecation warning). + Example for v1.1+: `{"results": [{"id": "...", "memory": "...", ...}]}` + """ + + _, effective_filters = _build_filters_and_metadata( + user_id=user_id, + agent_id=agent_id, + run_id=run_id, + input_filters=filters + ) + + if not any(key in effective_filters for key in ("user_id", "agent_id", "run_id")): + raise ValueError("At least one of 'user_id', 'agent_id', or 'run_id' must be specified.") + + capture_event( + "mem0.get_all", + self, + {"limit": limit, "keys": list(effective_filters.keys()), "sync_type": "sync"} + ) with concurrent.futures.ThreadPoolExecutor() as executor: - future_memories = executor.submit(self._get_all_from_vector_store, filters, limit) - future_graph_entities = executor.submit(self.graph.get_all, filters, limit) if self.enable_graph else None + future_memories = executor.submit(self._get_all_from_vector_store, effective_filters, limit) + future_graph_entities = ( + executor.submit(self.graph.get_all, effective_filters, limit) if self.enable_graph else None + ) concurrent.futures.wait( [future_memories, future_graph_entities] if future_graph_entities else [future_memories] ) - all_memories = future_memories.result() - graph_entities = future_graph_entities.result() if future_graph_entities else None - + all_memories_result = future_memories.result() + graph_entities_result = future_graph_entities.result() if future_graph_entities else None + if self.enable_graph: - return {"results": all_memories, "relations": graph_entities} + return {"results": all_memories_result, "relations": graph_entities_result} if self.api_version == "v1.0": warnings.warn( "The current get_all API output format is deprecated. " - "To use the latest format, set `api_version='v1.1'`. " - "The current format will be removed in mem0ai 1.1.0 and later versions.", + "To use the latest format, set `api_version='v1.1'` (which returns a dict with a 'results' key). 
" + "The current format (direct list for v1.0) will be removed in mem0ai 1.1.0 and later versions.", category=DeprecationWarning, stacklevel=2, ) - return all_memories + return all_memories_result else: - return {"results": all_memories} + return {"results": all_memories_result} def _get_all_from_vector_store(self, filters, limit): - memories = self.vector_store.list(filters=filters, limit=limit) + memories_result = self.vector_store.list(filters=filters, limit=limit) + actual_memories = memories_result[0] if isinstance(memories_result, tuple) and len(memories_result) > 0 else memories_result - excluded_keys = { - "user_id", - "agent_id", - "run_id", - "hash", - "data", - "created_at", - "updated_at", - "id", - } - all_memories = [ - { - **MemoryItem( - id=mem.id, - memory=mem.payload["data"], - hash=mem.payload.get("hash"), - created_at=mem.payload.get("created_at"), - updated_at=mem.payload.get("updated_at"), - ).model_dump(exclude={"score"}), - **{key: mem.payload[key] for key in ["user_id", "agent_id", "run_id"] if key in mem.payload}, - **( - {"metadata": {k: v for k, v in mem.payload.items() if k not in excluded_keys}} - if any(k for k in mem.payload if k not in excluded_keys) - else {} - ), - } - for mem in memories[0] + promoted_payload_keys = [ + "user_id", "agent_id", "run_id", + "actor_id", + "role", ] - return all_memories + core_and_promoted_keys = { + "data", "hash", "created_at", "updated_at", "id", + *promoted_payload_keys + } - def search(self, query, user_id=None, agent_id=None, run_id=None, limit=100, filters=None): + formatted_memories = [] + for mem in actual_memories: + memory_item_dict = MemoryItem( + id=mem.id, + memory=mem.payload["data"], + hash=mem.payload.get("hash"), + created_at=mem.payload.get("created_at"), + updated_at=mem.payload.get("updated_at"), + ).model_dump(exclude={"score"}) + + for key in promoted_payload_keys: + if key in mem.payload: + memory_item_dict[key] = mem.payload[key] + + additional_metadata = { + k: v for k, v in mem.payload.items() if k not in core_and_promoted_keys + } + if additional_metadata: + memory_item_dict["metadata"] = additional_metadata + + formatted_memories.append(memory_item_dict) + + return formatted_memories + + def search( + self, + query: str, + *, + user_id: Optional[str] = None, + agent_id: Optional[str] = None, + run_id: Optional[str] = None, + limit: int = 100, + filters: Optional[Dict[str, Any]] = None, + ): """ - Search for memories. - + Searches for memories based on a query Args: query (str): Query to search for. user_id (str, optional): ID of the user to search for. Defaults to None. agent_id (str, optional): ID of the agent to search for. Defaults to None. run_id (str, optional): ID of the run to search for. Defaults to None. limit (int, optional): Limit the number of results. Defaults to 100. - filters (dict, optional): Filters to apply to the search. Defaults to None. + filters (dict, optional): Filters to apply to the search. Defaults to None.. Returns: - list: List of search results. + dict: A dictionary containing the search results, typically under a "results" key, + and potentially "relations" if graph store is enabled. 
+ Example for v1.1+: `{"results": [{"id": "...", "memory": "...", "score": 0.8, ...}]}` """ - filters = filters or {} - if user_id: - filters["user_id"] = user_id - if agent_id: - filters["agent_id"] = agent_id - if run_id: - filters["run_id"] = run_id - - if not any(key in filters for key in ("user_id", "agent_id", "run_id")): - raise ValueError("One of the filters: user_id, agent_id or run_id is required!") + _, effective_filters = _build_filters_and_metadata( + user_id=user_id, + agent_id=agent_id, + run_id=run_id, + input_filters=filters + ) + + if not any(key in effective_filters for key in ("user_id", "agent_id", "run_id")): + raise ValueError("At least one of 'user_id', 'agent_id', or 'run_id' must be specified.") capture_event( "mem0.search", self, - {"limit": limit, "version": self.api_version, "keys": list(filters.keys()), "sync_type": "sync"}, + {"limit": limit, "version": self.api_version, "keys": list(effective_filters.keys()), "sync_type": "sync"}, ) with concurrent.futures.ThreadPoolExecutor() as executor: - future_memories = executor.submit(self._search_vector_store, query, filters, limit) + future_memories = executor.submit(self._search_vector_store, query, effective_filters, limit) future_graph_entities = ( - executor.submit(self.graph.search, query, filters, limit) if self.enable_graph else None + executor.submit(self.graph.search, query, effective_filters, limit) if self.enable_graph else None ) concurrent.futures.wait( @@ -514,19 +651,19 @@ class Memory(MemoryBase): original_memories = future_memories.result() graph_entities = future_graph_entities.result() if future_graph_entities else None - + if self.enable_graph: return {"results": original_memories, "relations": graph_entities} if self.api_version == "v1.0": warnings.warn( - "The current get_all API output format is deprecated. " + "The current search API output format is deprecated. " "To use the latest format, set `api_version='v1.1'`. 
" "The current format will be removed in mem0ai 1.1.0 and later versions.", category=DeprecationWarning, stacklevel=2, ) - return original_memories + return {"results": original_memories} else: return {"results": original_memories} @@ -534,36 +671,41 @@ class Memory(MemoryBase): embeddings = self.embedding_model.embed(query, "search") memories = self.vector_store.search(query=query, vectors=embeddings, limit=limit, filters=filters) - excluded_keys = { + promoted_payload_keys = [ "user_id", "agent_id", "run_id", - "hash", - "data", - "created_at", - "updated_at", - "id", + "actor_id", + "role", + ] + + core_and_promoted_keys = { + "data", "hash", "created_at", "updated_at", "id", + *promoted_payload_keys } - original_memories = [ - { - **MemoryItem( - id=mem.id, - memory=mem.payload["data"], - hash=mem.payload.get("hash"), - created_at=mem.payload.get("created_at"), - updated_at=mem.payload.get("updated_at"), - score=mem.score, - ).model_dump(), - **{key: mem.payload[key] for key in ["user_id", "agent_id", "run_id"] if key in mem.payload}, - **( - {"metadata": {k: v for k, v in mem.payload.items() if k not in excluded_keys}} - if any(k for k in mem.payload if k not in excluded_keys) - else {} - ), + original_memories = [] + for mem in memories: + memory_item_dict = MemoryItem( + id=mem.id, + memory=mem.payload["data"], + hash=mem.payload.get("hash"), + created_at=mem.payload.get("created_at"), + updated_at=mem.payload.get("updated_at"), + score=mem.score, + ).model_dump() + + for key in promoted_payload_keys: + if key in mem.payload: + memory_item_dict[key] = mem.payload[key] + + additional_metadata = { + k: v for k, v in mem.payload.items() if k not in core_and_promoted_keys } - for mem in memories - ] + if additional_metadata: + memory_item_dict["metadata"] = additional_metadata + + original_memories.append(memory_item_dict) return original_memories @@ -596,7 +738,7 @@ class Memory(MemoryBase): self._delete_memory(memory_id) return {"message": "Memory deleted successfully!"} - def delete_all(self, user_id=None, agent_id=None, run_id=None): + def delete_all(self, user_id:Optional[str]=None, agent_id:Optional[str]=None, run_id:Optional[str]=None): """ Delete all memories. @@ -605,7 +747,7 @@ class Memory(MemoryBase): agent_id (str, optional): ID of the agent to delete memories for. Defaults to None. run_id (str, optional): ID of the run to delete memories for. Defaults to None. 
""" - filters = {} + filters: Dict[str, Any] = {} if user_id: filters["user_id"] = user_id if agent_id: @@ -660,7 +802,15 @@ class Memory(MemoryBase): ids=[memory_id], payloads=[metadata], ) - self.db.add_history(memory_id, None, data, "ADD", created_at=metadata["created_at"]) + self.db.add_history( + memory_id, + None, + data, + "ADD", + created_at=metadata.get("created_at"), + actor_id=metadata.get("actor_id"), + role=metadata.get("role"), + ) capture_event("mem0._create_memory", self, {"memory_id": memory_id, "sync_type": "sync"}) return memory_id @@ -694,13 +844,10 @@ class Memory(MemoryBase): raise ValueError("Metadata cannot be done for procedural memory.") metadata["memory_type"] = MemoryType.PROCEDURAL.value - # Generate embeddings for the summary embeddings = self.embedding_model.embed(procedural_memory, memory_action="add") - # Create the memory memory_id = self._create_memory(procedural_memory, {procedural_memory: embeddings}, metadata=metadata) capture_event("mem0._create_procedural_memory", self, {"memory_id": memory_id, "sync_type": "sync"}) - # Return results in the same format as add() result = {"results": [{"id": memory_id, "memory": procedural_memory, "event": "ADD"}]} return result @@ -711,10 +858,13 @@ class Memory(MemoryBase): try: existing_memory = self.vector_store.get(vector_id=memory_id) except Exception: + logger.error(f"Error getting memory with ID {memory_id} during update.") raise ValueError(f"Error getting memory with ID {memory_id}. Please provide a valid 'memory_id'") + prev_value = existing_memory.payload.get("data") - new_metadata = metadata or {} + new_metadata = deepcopy(metadata) if metadata is not None else {} + new_metadata["data"] = data new_metadata["hash"] = hashlib.md5(data.encode()).hexdigest() new_metadata["created_at"] = existing_memory.payload.get("created_at") @@ -725,18 +875,24 @@ class Memory(MemoryBase): if "agent_id" in existing_memory.payload: new_metadata["agent_id"] = existing_memory.payload["agent_id"] if "run_id" in existing_memory.payload: - new_metadata["run_id"] = existing_memory.payload["run_id"] + new_metadata["run_id"] = existing_memory.payload["run_id"] + if "actor_id" in existing_memory.payload: + new_metadata["actor_id"] = existing_memory.payload["actor_id"] + if "role" in existing_memory.payload: + new_metadata["role"] = existing_memory.payload["role"] if data in existing_embeddings: embeddings = existing_embeddings[data] else: embeddings = self.embedding_model.embed(data, "update") + self.vector_store.update( vector_id=memory_id, vector=embeddings, payload=new_metadata, ) logger.info(f"Updating memory with ID {memory_id=} with {data=}") + self.db.add_history( memory_id, prev_value, @@ -744,6 +900,8 @@ class Memory(MemoryBase): "UPDATE", created_at=new_metadata["created_at"], updated_at=new_metadata["updated_at"], + actor_id=new_metadata.get("actor_id"), + role=new_metadata.get("role"), ) capture_event("mem0._update_memory", self, {"memory_id": memory_id, "sync_type": "sync"}) return memory_id @@ -753,7 +911,15 @@ class Memory(MemoryBase): existing_memory = self.vector_store.get(vector_id=memory_id) prev_value = existing_memory.payload["data"] self.vector_store.delete(vector_id=memory_id) - self.db.add_history(memory_id, prev_value, None, "DELETE", is_deleted=1) + self.db.add_history( + memory_id, + prev_value, + None, + "DELETE", + actor_id=existing_memory.payload.get("actor_id"), + role=existing_memory.payload.get("role"), + is_deleted=1, + ) capture_event("mem0._delete_memory", self, {"memory_id": memory_id, "sync_type": 
"sync"}) return memory_id @@ -766,7 +932,6 @@ class Memory(MemoryBase): """ logger.warning("Resetting all memories") - # Close the old connection if possible if hasattr(self.db, "connection") and self.db.connection: self.db.connection.execute("DROP TABLE IF EXISTS history") self.db.connection.close() @@ -844,14 +1009,14 @@ class AsyncMemory(MemoryBase): async def add( self, messages, - user_id=None, - agent_id=None, - run_id=None, - metadata=None, - filters=None, - infer=True, - memory_type=None, - prompt=None, + *, + user_id: Optional[str] = None, + agent_id: Optional[str] = None, + run_id: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + infer: bool = True, + memory_type: Optional[str] = None, + prompt: Optional[str] = None, llm=None, ): """ @@ -859,40 +1024,25 @@ class AsyncMemory(MemoryBase): Args: messages (str or List[Dict[str, str]]): Messages to store in the memory. - user_id (str, optional): ID of the user creating the memory. Defaults to None. + user_id (str, optional): ID of the user creating the memory. agent_id (str, optional): ID of the agent creating the memory. Defaults to None. run_id (str, optional): ID of the run creating the memory. Defaults to None. metadata (dict, optional): Metadata to store with the memory. Defaults to None. - filters (dict, optional): Filters to apply to the search. Defaults to None. infer (bool, optional): Whether to infer the memories. Defaults to True. - memory_type (str, optional): Type of memory to create. Defaults to None. By default, it creates the short term memories and long term (semantic and episodic) memories. Pass "procedural_memory" to create procedural memories. + memory_type (str, optional): Type of memory to create. Defaults to None. + Pass "procedural_memory" to create procedural memories. prompt (str, optional): Prompt to use for the memory creation. Defaults to None. llm (BaseChatModel, optional): LLM class to use for generating procedural memories. Defaults to None. Useful when user is using LangChain ChatModel. Returns: dict: A dictionary containing the result of the memory addition operation. - result: dict of affected events with each dict has the following key: - 'memories': affected memories - 'graph': affected graph memories - - 'memories' and 'graph' is a dict, each with following subkeys: - 'add': added memory - 'update': updated memory - 'delete': deleted memory """ - if metadata is None: - metadata = {} - - filters = filters or {} - if user_id: - filters["user_id"] = metadata["user_id"] = user_id - if agent_id: - filters["agent_id"] = metadata["agent_id"] = agent_id - if run_id: - filters["run_id"] = metadata["run_id"] = run_id - - if not any(key in filters for key in ("user_id", "agent_id", "run_id")): - raise ValueError("One of the filters: user_id, agent_id or run_id is required!") - + processed_metadata, effective_filters = _build_filters_and_metadata( + user_id=user_id, + agent_id=agent_id, + run_id=run_id, + input_metadata=metadata + ) + if memory_type is not None and memory_type != MemoryType.PROCEDURAL.value: raise ValueError( f"Invalid 'memory_type'. Please pass {MemoryType.PROCEDURAL.value} to create procedural memories." 
@@ -900,9 +1050,15 @@ class AsyncMemory(MemoryBase): if isinstance(messages, str): messages = [{"role": "user", "content": messages}] + + elif isinstance(messages, dict): + messages = [messages] + + elif not isinstance(messages, list): + raise ValueError("messages must be str, dict, or list[dict]") if agent_id is not None and memory_type == MemoryType.PROCEDURAL.value: - results = await self._create_procedural_memory(messages, metadata=metadata, llm=llm, prompt=prompt) + results = await self._create_procedural_memory(messages, metadata=processed_metadata, prompt=prompt, llm=llm) return results if self.config.llm.config.get("enable_vision"): @@ -910,9 +1066,8 @@ class AsyncMemory(MemoryBase): else: messages = parse_vision_messages(messages) - # Run vector store and graph operations concurrently - vector_store_task = asyncio.create_task(self._add_to_vector_store(messages, metadata, filters, infer)) - graph_task = asyncio.create_task(self._add_to_graph(messages, filters)) + vector_store_task = asyncio.create_task(self._add_to_vector_store(messages, processed_metadata, effective_filters, infer)) + graph_task = asyncio.create_task(self._add_to_graph(messages, effective_filters)) vector_store_result, graph_result = await asyncio.gather(vector_store_task, graph_task) @@ -934,18 +1089,44 @@ class AsyncMemory(MemoryBase): return {"results": vector_store_result} - async def _add_to_vector_store(self, messages, metadata, filters, infer): + async def _add_to_vector_store( + self, + messages: list, + metadata: dict, + filters: dict, + infer: bool, + ): if not infer: returned_memories = [] - for message in messages: - if message["role"] != "system": - message_embeddings = await asyncio.to_thread(self.embedding_model.embed, message["content"], "add") - memory_id = await self._create_memory(message["content"], message_embeddings, metadata) - returned_memories.append({"id": memory_id, "memory": message["content"], "event": "ADD"}) + for message_dict in messages: + if not isinstance(message_dict, dict) or \ + message_dict.get("role") is None or \ + message_dict.get("content") is None: + logger.warning(f"Skipping invalid message format (async): {message_dict}") + continue + + if message_dict["role"] == "system": + continue + + per_msg_meta = deepcopy(metadata) + per_msg_meta["role"] = message_dict["role"] + + actor_name = message_dict.get("name") + if actor_name: + per_msg_meta["actor_id"] = actor_name + + msg_content = message_dict["content"] + msg_embeddings = await asyncio.to_thread(self.embedding_model.embed, msg_content, "add") + mem_id = await self._create_memory(msg_content, msg_embeddings, per_msg_meta) + + returned_memories.append({ + "id": mem_id, "memory": msg_content, "event": "ADD", + "actor_id": actor_name if actor_name else None, + "role": message_dict["role"], + }) return returned_memories parsed_messages = parse_messages(messages) - if self.config.custom_fact_extraction_prompt: system_prompt = self.config.custom_fact_extraction_prompt user_prompt = f"Input:\n{parsed_messages}" @@ -954,51 +1135,36 @@ class AsyncMemory(MemoryBase): response = await asyncio.to_thread( self.llm.generate_response, - messages=[ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": user_prompt}, - ], + messages=[{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}], response_format={"type": "json_object"}, ) - try: response = remove_code_blocks(response) new_retrieved_facts = json.loads(response)["facts"] except Exception as e: - logging.error(f"Error in 
new_retrieved_facts: {e}") - new_retrieved_facts = [] + logging.error(f"Error in new_retrieved_facts: {e}"); new_retrieved_facts = [] retrieved_old_memory = [] new_message_embeddings = {} - - # Process all facts concurrently - async def process_fact(new_mem): - messages_embeddings = await asyncio.to_thread(self.embedding_model.embed, new_mem, "add") - new_message_embeddings[new_mem] = messages_embeddings - existing_memories = await asyncio.to_thread( - self.vector_store.search, - query=new_mem, - vectors=messages_embeddings, - limit=5, - filters=filters, + + async def process_fact_for_search(new_mem_content): + embeddings = await asyncio.to_thread(self.embedding_model.embed, new_mem_content, "add") + new_message_embeddings[new_mem_content] = embeddings + existing_mems = await asyncio.to_thread( + self.vector_store.search, query=new_mem_content, vectors=embeddings, + limit=5, filters=filters, # 'filters' is query_filters_for_inference ) - return [(mem.id, mem.payload["data"]) for mem in existing_memories] - - fact_tasks = [process_fact(fact) for fact in new_retrieved_facts] - fact_results = await asyncio.gather(*fact_tasks) - - # Flatten results and build retrieved_old_memory - for result in fact_results: - for mem_id, mem_data in result: - retrieved_old_memory.append({"id": mem_id, "text": mem_data}) + return [{"id": mem.id, "text": mem.payload["data"]} for mem in existing_mems] + search_tasks = [process_fact_for_search(fact) for fact in new_retrieved_facts] + search_results_list = await asyncio.gather(*search_tasks) + for result_group in search_results_list: + retrieved_old_memory.extend(result_group) + unique_data = {} - for item in retrieved_old_memory: - unique_data[item["id"]] = item + for item in retrieved_old_memory: unique_data[item["id"]] = item retrieved_old_memory = list(unique_data.values()) logging.info(f"Total existing memories: {len(retrieved_old_memory)}") - - # mapping UUIDs with integers for handling UUID hallucinations temp_uuid_mapping = {} for idx, item in enumerate(retrieved_old_memory): temp_uuid_mapping[str(idx)] = item["id"] @@ -1007,99 +1173,76 @@ class AsyncMemory(MemoryBase): function_calling_prompt = get_update_memory_messages( retrieved_old_memory, new_retrieved_facts, self.config.custom_update_memory_prompt ) - try: - response: str = await asyncio.to_thread( + response = await asyncio.to_thread( self.llm.generate_response, messages=[{"role": "user", "content": function_calling_prompt}], response_format={"type": "json_object"}, ) except Exception as e: - logging.error(f"Error in new memory actions response: {e}") - response = "" - + logging.error(f"Error in new memory actions response: {e}"); response = "" + try: response = remove_code_blocks(response) new_memories_with_actions = json.loads(response) except Exception as e: - logging.error(f"Invalid JSON response: {e}") - new_memories_with_actions = {} + logging.error(f"Invalid JSON response: {e}"); new_memories_with_actions = {} - returned_memories = [] + returned_memories = [] try: memory_tasks = [] for resp in new_memories_with_actions.get("memory", []): logging.info(resp) try: - if not resp.get("text"): - logging.info("Skipping memory entry because of empty `text` field.") - continue - elif resp.get("event") == "ADD": - task = asyncio.create_task( - self._create_memory( - data=resp.get("text"), - existing_embeddings=new_message_embeddings, - metadata=deepcopy(metadata), - ) - ) - memory_tasks.append((task, resp, "ADD", None)) - elif resp.get("event") == "UPDATE": - task = asyncio.create_task( - 
self._update_memory( - memory_id=temp_uuid_mapping[resp["id"]], - data=resp.get("text"), - existing_embeddings=new_message_embeddings, - metadata=deepcopy(metadata), - ) - ) - memory_tasks.append((task, resp, "UPDATE", temp_uuid_mapping[resp["id"]])) - elif resp.get("event") == "DELETE": - task = asyncio.create_task(self._delete_memory(memory_id=temp_uuid_mapping[resp.get("id")])) - memory_tasks.append((task, resp, "DELETE", temp_uuid_mapping[resp["id"]])) - elif resp.get("event") == "NONE": - logging.info("NOOP for Memory.") - except Exception as e: - logging.error(f"Error in new_memories_with_actions: {e}") + action_text = resp.get("text") + if not action_text: continue + event_type = resp.get("event") - # Wait for all memory operations to complete + if event_type == "ADD": + task = asyncio.create_task(self._create_memory( + data=action_text, existing_embeddings=new_message_embeddings, + metadata=deepcopy(metadata) + )) + memory_tasks.append((task, resp, "ADD", None)) + elif event_type == "UPDATE": + task = asyncio.create_task(self._update_memory( + memory_id=temp_uuid_mapping[resp["id"]], data=action_text, + existing_embeddings=new_message_embeddings, metadata=deepcopy(metadata) + )) + memory_tasks.append((task, resp, "UPDATE", temp_uuid_mapping[resp["id"]])) + elif event_type == "DELETE": + task = asyncio.create_task(self._delete_memory(memory_id=temp_uuid_mapping[resp.get("id")])) + memory_tasks.append((task, resp, "DELETE", temp_uuid_mapping[resp.get("id")])) + elif event_type == "NONE": + logging.info("NOOP for Memory (async).") + except Exception as e: + logging.error(f"Error processing memory action (async): {resp}, Error: {e}") + for task, resp, event_type, mem_id in memory_tasks: try: result_id = await task if event_type == "ADD": - returned_memories.append( - { - "id": result_id, - "memory": resp.get("text"), - "event": resp.get("event"), - } - ) + returned_memories.append({ + "id": result_id, "memory": resp.get("text"), "event": event_type + }) elif event_type == "UPDATE": - returned_memories.append( - { - "id": mem_id, - "memory": resp.get("text"), - "event": resp.get("event"), - "previous_memory": resp.get("old_memory"), - } - ) + returned_memories.append({ + "id": mem_id, "memory": resp.get("text"), + "event": event_type, "previous_memory": resp.get("old_memory") + }) elif event_type == "DELETE": - returned_memories.append( - { - "id": mem_id, - "memory": resp.get("text"), - "event": resp.get("event"), - } - ) + returned_memories.append({ + "id": mem_id, "memory": resp.get("text"), "event": event_type + }) except Exception as e: - logging.error(f"Error processing memory task: {e}") - + logging.error(f"Error awaiting memory task (async): {e}") except Exception as e: - logging.error(f"Error in new_memories_with_actions: {e}") - + logging.error(f"Error in memory processing loop (async): {e}") + capture_event( - "mem0.add", self, {"version": self.api_version, "keys": list(filters.keys()), "sync_type": "async"} + "mem0.add", self, + {"version": self.api_version, "keys": list(filters.keys()), "sync_type": "async"} ) - return returned_memories async def _add_to_graph(self, messages, filters): @@ -1128,156 +1271,223 @@ class AsyncMemory(MemoryBase): if not memory: return None - filters = {key: memory.payload[key] for key in ["user_id", "agent_id", "run_id"] if memory.payload.get(key)} + promoted_payload_keys = [ + "user_id", + "agent_id", + "run_id", + "actor_id", + "role", + ] + + core_and_promoted_keys = { + "data", "hash", "created_at", "updated_at", "id", + 
*promoted_payload_keys
        }
 
-        # Prepare base memory item
-        memory_item = MemoryItem(
+        result_item = MemoryItem(
             id=memory.id,
             memory=memory.payload["data"],
             hash=memory.payload.get("hash"),
             created_at=memory.payload.get("created_at"),
             updated_at=memory.payload.get("updated_at"),
-        ).model_dump(exclude={"score"})
+        ).model_dump()
 
-        # Add metadata if there are additional keys
-        excluded_keys = {"user_id", "agent_id", "run_id", "hash", "data", "created_at", "updated_at", "id"}
-        additional_metadata = {k: v for k, v in memory.payload.items() if k not in excluded_keys}
+        for key in promoted_payload_keys:
+            if key in memory.payload:
+                result_item[key] = memory.payload[key]
+
+        additional_metadata = {
+            k: v for k, v in memory.payload.items() if k not in core_and_promoted_keys
+        }
         if additional_metadata:
-            memory_item["metadata"] = additional_metadata
+            result_item["metadata"] = additional_metadata
+
+        return result_item
 
-        result = {**memory_item, **filters}
-
-        return result
-
-    async def get_all(self, user_id=None, agent_id=None, run_id=None, limit=100):
+    async def get_all(
+        self,
+        *,
+        user_id: Optional[str] = None,
+        agent_id: Optional[str] = None,
+        run_id: Optional[str] = None,
+        filters: Optional[Dict[str, Any]] = None,
+        limit: int = 100,
+    ):
         """
-        List all memories asynchronously.
+        List all memories.
+
+        Args:
+            user_id (str, optional): user id
+            agent_id (str, optional): agent id
+            run_id (str, optional): run id
+            filters (dict, optional): Additional custom key-value filters to apply to the search.
+                These are merged with the ID-based scoping filters. For example,
+                `filters={"actor_id": "some_user"}`.
+            limit (int, optional): The maximum number of memories to return. Defaults to 100.
 
         Returns:
-            list: List of all memories.
+            dict: A dictionary containing a list of memories under the "results" key,
+                and potentially "relations" if graph store is enabled. For API v1.0,
+                it might return a direct list (see deprecation warning).
+                Example for v1.1+: `{"results": [{"id": "...", "memory": "...", ...}]}`
         """
-        filters = {}
-        if user_id:
-            filters["user_id"] = user_id
-        if agent_id:
-            filters["agent_id"] = agent_id
-        if run_id:
-            filters["run_id"] = run_id
+        _, effective_filters = _build_filters_and_metadata(
+            user_id=user_id,
+            agent_id=agent_id,
+            run_id=run_id,
+            input_filters=filters
+        )
 
-        capture_event("mem0.get_all", self, {"limit": limit, "keys": list(filters.keys()), "sync_type": "async"})
+        if not any(key in effective_filters for key in ("user_id", "agent_id", "run_id")):
+            raise ValueError(
+                "At least one of 'user_id', 'agent_id', or 'run_id' must be specified for get_all."

-    async def get_all(self, user_id=None, agent_id=None, run_id=None, limit=100):
+    async def get_all(
+        self,
+        *,
+        user_id: Optional[str] = None,
+        agent_id: Optional[str] = None,
+        run_id: Optional[str] = None,
+        filters: Optional[Dict[str, Any]] = None,
+        limit: int = 100,
+    ):
         """
-        List all memories asynchronously.
+        List all memories.
+
+        Args:
+            user_id (str, optional): user id
+            agent_id (str, optional): agent id
+            run_id (str, optional): run id
+            filters (dict, optional): Additional custom key-value filters to apply to the search.
+                These are merged with the ID-based scoping filters. For example,
+                `filters={"actor_id": "some_user"}`.
+            limit (int, optional): The maximum number of memories to return. Defaults to 100.

         Returns:
-            list: List of all memories.
+            dict: A dictionary with a list of memories under the "results" key, plus
+                "relations" when the graph store is enabled. For api_version="v1.0" a
+                direct list is returned instead (see the deprecation warning below).
+                Example for v1.1+: `{"results": [{"id": "...", "memory": "...", ...}]}`
         """
-        filters = {}
-        if user_id:
-            filters["user_id"] = user_id
-        if agent_id:
-            filters["agent_id"] = agent_id
-        if run_id:
-            filters["run_id"] = run_id
+        _, effective_filters = _build_filters_and_metadata(
+            user_id=user_id,
+            agent_id=agent_id,
+            run_id=run_id,
+            input_filters=filters
+        )

-        capture_event("mem0.get_all", self, {"limit": limit, "keys": list(filters.keys()), "sync_type": "async"})
+        if not any(key in effective_filters for key in ("user_id", "agent_id", "run_id")):
+            raise ValueError(
+                "At least one of 'user_id', 'agent_id', or 'run_id' must be specified for get_all."
+            )

-        # Run vector store and graph operations concurrently
-        vector_store_task = asyncio.create_task(self._get_all_from_vector_store(filters, limit))
+        capture_event(
+            "mem0.get_all",
+            self,
+            {"limit": limit, "keys": list(effective_filters.keys()), "sync_type": "async"}
+        )

+        # _get_all_from_vector_store is a coroutine, so it must be awaited on the
+        # event loop; only the synchronous graph call is pushed to a worker thread.
+        # (Submitting the coroutine function to a ThreadPoolExecutor would return an
+        # un-awaited coroutine and block the loop, so asyncio primitives are used here.)
+        vector_store_task = asyncio.create_task(self._get_all_from_vector_store(effective_filters, limit))
         if self.enable_graph:
-            graph_task = asyncio.create_task(asyncio.to_thread(self.graph.get_all, filters, limit))
-            all_memories, graph_entities = await asyncio.gather(vector_store_task, graph_task)
-        else:
-            all_memories = await vector_store_task
-            graph_entities = None
-
-        if self.enable_graph:
-            return {"results": all_memories, "relations": graph_entities}
+            graph_task = asyncio.create_task(asyncio.to_thread(self.graph.get_all, effective_filters, limit))
+            all_memories_result, graph_entities_result = await asyncio.gather(vector_store_task, graph_task)
+            return {"results": all_memories_result, "relations": graph_entities_result}
+
+        all_memories_result = await vector_store_task

         if self.api_version == "v1.0":
             warnings.warn(
                 "The current get_all API output format is deprecated. "
-                "To use the latest format, set `api_version='v1.1'`. "
-                "The current format will be removed in mem0ai 1.1.0 and later versions.",
+                "To use the latest format, set `api_version='v1.1'` (which returns a dict with a 'results' key). "
+                "The current format (direct list for v1.0) will be removed in mem0ai 1.1.0 and later versions.",
                 category=DeprecationWarning,
                 stacklevel=2,
             )
-            return all_memories
+            return all_memories_result
         else:
-            return {"results": all_memories}
+            return {"results": all_memories_result}
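+
+    # Usage sketch (illustrative; assumes an initialized AsyncMemory instance `mem`):
+    #   everything = await mem.get_all(run_id="landing-v1")
+    #   from_alice = await mem.get_all(run_id="landing-v1", filters={"actor_id": "alice"})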

     async def _get_all_from_vector_store(self, filters, limit):
-        memories = await asyncio.to_thread(self.vector_store.list, filters=filters, limit=limit)
+        memories_result = await asyncio.to_thread(self.vector_store.list, filters=filters, limit=limit)
+        actual_memories = (
+            memories_result[0]
+            if isinstance(memories_result, tuple) and len(memories_result) > 0
+            else memories_result
+        )

-        excluded_keys = {
-            "user_id",
-            "agent_id",
-            "run_id",
-            "hash",
-            "data",
-            "created_at",
-            "updated_at",
-            "id",
-        }
-        all_memories = [
-            {
-                **MemoryItem(
-                    id=mem.id,
-                    memory=mem.payload["data"],
-                    hash=mem.payload.get("hash"),
-                    created_at=mem.payload.get("created_at"),
-                    updated_at=mem.payload.get("updated_at"),
-                ).model_dump(exclude={"score"}),
-                **{key: mem.payload[key] for key in ["user_id", "agent_id", "run_id"] if key in mem.payload},
-                **(
-                    {"metadata": {k: v for k, v in mem.payload.items() if k not in excluded_keys}}
-                    if any(k for k in mem.payload if k not in excluded_keys)
-                    else {}
-                ),
-            }
-            for mem in memories[0]
-        ]
-        return all_memories
+        promoted_payload_keys = [
+            "user_id", "agent_id", "run_id",
+            "actor_id",
+            "role",
+        ]
+        core_and_promoted_keys = {
+            "data", "hash", "created_at", "updated_at", "id",
+            *promoted_payload_keys
+        }

+        formatted_memories = []
+        for mem in actual_memories:
+            memory_item_dict = MemoryItem(
+                id=mem.id,
+                memory=mem.payload["data"],
+                hash=mem.payload.get("hash"),
+                created_at=mem.payload.get("created_at"),
+                updated_at=mem.payload.get("updated_at"),
+            ).model_dump(exclude={"score"})
+
+            for key in promoted_payload_keys:
+                if key in mem.payload:
+                    memory_item_dict[key] = mem.payload[key]
+
+            additional_metadata = {
+                k: v for k, v in mem.payload.items() if k not in core_and_promoted_keys
+            }
+            if additional_metadata:
+                memory_item_dict["metadata"] = additional_metadata
+
+            formatted_memories.append(memory_item_dict)
+
+        return formatted_memories
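+
+    # For example (sketch), a stored payload
+    #   {"data": "...", "hash": "...", "created_at": "...",
+    #    "run_id": "landing-v1", "actor_id": "bob", "role": "user", "app_id": "task-agent-demo"}
+    # is returned with run_id/actor_id/role promoted to top-level keys and the
+    # remaining custom keys (here app_id) collected under "metadata".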
" "The current format will be removed in mem0ai 1.1.0 and later versions.", category=DeprecationWarning, stacklevel=2, ) - return original_memories + return {"results": original_memories} else: return {"results": original_memories} @@ -1287,36 +1497,41 @@ class AsyncMemory(MemoryBase): self.vector_store.search, query=query, vectors=embeddings, limit=limit, filters=filters ) - excluded_keys = { + promoted_payload_keys = [ "user_id", "agent_id", "run_id", - "hash", - "data", - "created_at", - "updated_at", - "id", + "actor_id", + "role", + ] + + core_and_promoted_keys = { + "data", "hash", "created_at", "updated_at", "id", + *promoted_payload_keys } - original_memories = [ - { - **MemoryItem( - id=mem.id, - memory=mem.payload["data"], - hash=mem.payload.get("hash"), - created_at=mem.payload.get("created_at"), - updated_at=mem.payload.get("updated_at"), - score=mem.score, - ).model_dump(), - **{key: mem.payload[key] for key in ["user_id", "agent_id", "run_id"] if key in mem.payload}, - **( - {"metadata": {k: v for k, v in mem.payload.items() if k not in excluded_keys}} - if any(k for k in mem.payload if k not in excluded_keys) - else {} - ), + original_memories = [] + for mem in memories: + memory_item_dict = MemoryItem( + id=mem.id, + memory=mem.payload["data"], + hash=mem.payload.get("hash"), + created_at=mem.payload.get("created_at"), + updated_at=mem.payload.get("updated_at"), + score=mem.score, + ).model_dump() + + for key in promoted_payload_keys: + if key in mem.payload: + memory_item_dict[key] = mem.payload[key] + + additional_metadata = { + k: v for k, v in mem.payload.items() if k not in core_and_promoted_keys } - for mem in memories - ] + if additional_metadata: + memory_item_dict["metadata"] = additional_metadata + + original_memories.append(memory_item_dict) return original_memories @@ -1421,19 +1636,28 @@ class AsyncMemory(MemoryBase): payloads=[metadata], ) - await asyncio.to_thread(self.db.add_history, memory_id, None, data, "ADD", created_at=metadata["created_at"]) + await asyncio.to_thread( + self.db.add_history, + memory_id, + None, + data, + "ADD", + created_at=metadata.get("created_at"), + actor_id=metadata.get("actor_id"), + role=metadata.get("role"), + ) capture_event("mem0._create_memory", self, {"memory_id": memory_id, "sync_type": "async"}) return memory_id - async def _create_procedural_memory(self, messages, metadata=None, llm=None, prompt=None): + async def _create_procedural_memory(self, messages, metadata=None,llm=None ,prompt=None): """ Create a procedural memory asynchronously Args: messages (list): List of messages to create a procedural memory from. metadata (dict): Metadata to create a procedural memory from. - llm (BaseChatModel, optional): LLM class to use for generating procedural memories. Defaults to None. Useful when user is using LangChain ChatModel. + llm (llm, optional): LLM to use for the procedural memory creation. Defaults to None. prompt (str, optional): Prompt to use for the procedural memory creation. Defaults to None. 
""" try: @@ -1469,13 +1693,10 @@ class AsyncMemory(MemoryBase): raise ValueError("Metadata cannot be done for procedural memory.") metadata["memory_type"] = MemoryType.PROCEDURAL.value - # Generate embeddings for the summary embeddings = await asyncio.to_thread(self.embedding_model.embed, procedural_memory, memory_action="add") - # Create the memory memory_id = await self._create_memory(procedural_memory, {procedural_memory: embeddings}, metadata=metadata) capture_event("mem0._create_procedural_memory", self, {"memory_id": memory_id, "sync_type": "async"}) - # Return results in the same format as add() result = {"results": [{"id": memory_id, "memory": procedural_memory, "event": "ADD"}]} return result @@ -1486,11 +1707,13 @@ class AsyncMemory(MemoryBase): try: existing_memory = await asyncio.to_thread(self.vector_store.get, vector_id=memory_id) except Exception: + logger.error(f"Error getting memory with ID {memory_id} during update.") raise ValueError(f"Error getting memory with ID {memory_id}. Please provide a valid 'memory_id'") - + prev_value = existing_memory.payload.get("data") - new_metadata = metadata or {} + new_metadata = deepcopy(metadata) if metadata is not None else {} + new_metadata["data"] = data new_metadata["hash"] = hashlib.md5(data.encode()).hexdigest() new_metadata["created_at"] = existing_memory.payload.get("created_at") @@ -1502,21 +1725,26 @@ class AsyncMemory(MemoryBase): new_metadata["agent_id"] = existing_memory.payload["agent_id"] if "run_id" in existing_memory.payload: new_metadata["run_id"] = existing_memory.payload["run_id"] + + + if "actor_id" in existing_memory.payload: + new_metadata["actor_id"] = existing_memory.payload["actor_id"] + if "role" in existing_memory.payload: + new_metadata["role"] = existing_memory.payload["role"] if data in existing_embeddings: embeddings = existing_embeddings[data] else: embeddings = await asyncio.to_thread(self.embedding_model.embed, data, "update") - + await asyncio.to_thread( self.vector_store.update, vector_id=memory_id, vector=embeddings, payload=new_metadata, ) - logger.info(f"Updating memory with ID {memory_id=} with {data=}") - + await asyncio.to_thread( self.db.add_history, memory_id, @@ -1525,8 +1753,9 @@ class AsyncMemory(MemoryBase): "UPDATE", created_at=new_metadata["created_at"], updated_at=new_metadata["updated_at"], + actor_id=new_metadata.get("actor_id"), + role=new_metadata.get("role"), ) - capture_event("mem0._update_memory", self, {"memory_id": memory_id, "sync_type": "async"}) return memory_id @@ -1536,7 +1765,16 @@ class AsyncMemory(MemoryBase): prev_value = existing_memory.payload["data"] await asyncio.to_thread(self.vector_store.delete, vector_id=memory_id) - await asyncio.to_thread(self.db.add_history, memory_id, prev_value, None, "DELETE", is_deleted=1) + await asyncio.to_thread( + self.db.add_history, + memory_id, + prev_value, + None, + "DELETE", + actor_id=existing_memory.payload.get("actor_id"), + role=existing_memory.payload.get("role"), + is_deleted=1, + ) capture_event("mem0._delete_memory", self, {"memory_id": memory_id, "sync_type": "async"}) return memory_id diff --git a/mem0/memory/storage.py b/mem0/memory/storage.py index b02749cb..982ee020 100644 --- a/mem0/memory/storage.py +++ b/mem0/memory/storage.py @@ -1,144 +1,160 @@ import sqlite3 import threading import uuid +import logging +from typing import List, Dict, Any, Optional + +logger = logging.getLogger(__name__) class SQLiteManager: - def __init__(self, db_path=":memory:"): - self.connection = sqlite3.connect(db_path, 

@@ -1536,7 +1765,16 @@ class AsyncMemory(MemoryBase):
         prev_value = existing_memory.payload["data"]

         await asyncio.to_thread(self.vector_store.delete, vector_id=memory_id)
-        await asyncio.to_thread(self.db.add_history, memory_id, prev_value, None, "DELETE", is_deleted=1)
+        await asyncio.to_thread(
+            self.db.add_history,
+            memory_id,
+            prev_value,
+            None,
+            "DELETE",
+            actor_id=existing_memory.payload.get("actor_id"),
+            role=existing_memory.payload.get("role"),
+            is_deleted=1,
+        )
         capture_event("mem0._delete_memory", self, {"memory_id": memory_id, "sync_type": "async"})
         return memory_id

diff --git a/mem0/memory/storage.py b/mem0/memory/storage.py
index b02749cb..982ee020 100644
--- a/mem0/memory/storage.py
+++ b/mem0/memory/storage.py
@@ -1,144 +1,160 @@
 import sqlite3
 import threading
 import uuid
+import logging
+from typing import List, Dict, Any, Optional
+
+logger = logging.getLogger(__name__)


 class SQLiteManager:
-    def __init__(self, db_path=":memory:"):
-        self.connection = sqlite3.connect(db_path, check_same_thread=False)
-        self._lock = threading.Lock()
+    def __init__(self, db_path: str = ":memory:"):
+        self.db_path = db_path
+        self.connection = sqlite3.connect(self.db_path, check_same_thread=False)
+        # RLock: _migrate_history_table() calls _create_history_table() while the
+        # lock is already held; a plain (non-reentrant) Lock would deadlock there.
+        self._lock = threading.RLock()
         self._migrate_history_table()
         self._create_history_table()

-    def _migrate_history_table(self):
-        with self._lock:
-            with self.connection:
-                cursor = self.connection.cursor()
+    def _migrate_history_table(self) -> None:
+        """
+        If a pre-existing history table uses an outdated schema, rename it,
+        create the new schema (which adds the actor_id and role attribution
+        columns), copy the intersecting columns over, then drop the old table.
+        """
+        with self._lock, self.connection:
+            cur = self.connection.cursor()
+            cur.execute(
+                "SELECT name FROM sqlite_master WHERE type='table' AND name='history'"
+            )
+            if cur.fetchone() is None:
+                return  # nothing to migrate

-                cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='history'")
-                table_exists = cursor.fetchone() is not None
+            cur.execute("PRAGMA table_info(history)")
+            old_cols = {row[1] for row in cur.fetchall()}

-                if table_exists:
-                    # Get the current schema of the history table
-                    cursor.execute("PRAGMA table_info(history)")
-                    current_schema = {row[1]: row[2] for row in cursor.fetchall()}
+            expected_cols = {
+                "id",
+                "memory_id",
+                "old_memory",
+                "new_memory",
+                "event",
+                "created_at",
+                "updated_at",
+                "is_deleted",
+                "actor_id",
+                "role",
+            }

-                    # Define the expected schema
-                    expected_schema = {
-                        "id": "TEXT",
-                        "memory_id": "TEXT",
-                        "old_memory": "TEXT",
-                        "new_memory": "TEXT",
-                        "new_value": "TEXT",
-                        "event": "TEXT",
-                        "created_at": "DATETIME",
-                        "updated_at": "DATETIME",
-                        "is_deleted": "INTEGER",
-                    }
+            if old_cols == expected_cols:
+                return

-                    # Check if the schemas are the same
-                    if current_schema != expected_schema:
-                        # Rename the old table
-                        cursor.execute("ALTER TABLE history RENAME TO old_history")
+            logger.info("Migrating history table to the new schema (adds actor_id/role columns).")
+            cur.execute("ALTER TABLE history RENAME TO history_old")

-                        cursor.execute(
-                            """
-                            CREATE TABLE IF NOT EXISTS history (
-                                id TEXT PRIMARY KEY,
-                                memory_id TEXT,
-                                old_memory TEXT,
-                                new_memory TEXT,
-                                new_value TEXT,
-                                event TEXT,
-                                created_at DATETIME,
-                                updated_at DATETIME,
-                                is_deleted INTEGER
-                            )
-                            """
-                        )
+            self._create_history_table()

-                        # Copy data from the old table to the new table
-                        cursor.execute(
-                            """
-                            INSERT INTO history (id, memory_id, old_memory, new_memory, new_value, event, created_at, updated_at, is_deleted)
-                            SELECT id, memory_id, prev_value, new_value, new_value, event, timestamp, timestamp, is_deleted
-                            FROM old_history
-                            """  # noqa: E501
-                        )
+            intersecting = sorted(expected_cols & old_cols)
+            cols_csv = ", ".join(intersecting)
+            cur.execute(
+                f"INSERT INTO history ({cols_csv}) SELECT {cols_csv} FROM history_old"
+            )
+            cur.execute("DROP TABLE history_old")

-                        cursor.execute("DROP TABLE old_history")
-
-        self.connection.commit()
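+
+    # Migration example (sketch): a legacy table with columns
+    #   {id, memory_id, old_memory, new_memory, event, created_at, updated_at, is_deleted}
+    # is renamed to history_old, the new table (with actor_id/role) is created,
+    # the eight intersecting columns are copied across, and history_old is dropped.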

-    def _create_history_table(self):
-        with self._lock:
-            with self.connection:
-                self.connection.execute(
-                    """
-                    CREATE TABLE IF NOT EXISTS history (
-                        id TEXT PRIMARY KEY,
-                        memory_id TEXT,
-                        old_memory TEXT,
-                        new_memory TEXT,
-                        new_value TEXT,
-                        event TEXT,
-                        created_at DATETIME,
-                        updated_at DATETIME,
-                        is_deleted INTEGER
-                    )
-                    """
-                )
+    def _create_history_table(self) -> None:
+        with self._lock, self.connection:
+            self.connection.execute(
+                """
+                CREATE TABLE IF NOT EXISTS history (
+                    id TEXT PRIMARY KEY,
+                    memory_id TEXT,
+                    old_memory TEXT,
+                    new_memory TEXT,
+                    event TEXT,
+                    created_at DATETIME,
+                    updated_at DATETIME,
+                    is_deleted INTEGER,
+                    actor_id TEXT,
+                    role TEXT
+                )
+                """
+            )

     def add_history(
         self,
-        memory_id,
-        old_memory,
-        new_memory,
-        event,
-        created_at=None,
-        updated_at=None,
-        is_deleted=0,
-    ):
-        with self._lock:
-            with self.connection:
-                self.connection.execute(
-                    """
-                    INSERT INTO history (id, memory_id, old_memory, new_memory, event, created_at, updated_at, is_deleted)
-                    VALUES (?, ?, ?, ?, ?, ?, ?, ?)
-                    """,
-                    (
-                        str(uuid.uuid4()),
-                        memory_id,
-                        old_memory,
-                        new_memory,
-                        event,
-                        created_at,
-                        updated_at,
-                        is_deleted,
-                    ),
-                )
+        memory_id: str,
+        old_memory: Optional[str],
+        new_memory: Optional[str],
+        event: str,
+        *,
+        created_at: Optional[str] = None,
+        updated_at: Optional[str] = None,
+        is_deleted: int = 0,
+        actor_id: Optional[str] = None,
+        role: Optional[str] = None,
+    ) -> None:
+        with self._lock, self.connection:
+            self.connection.execute(
+                """
+                INSERT INTO history (
+                    id, memory_id, old_memory, new_memory, event,
+                    created_at, updated_at, is_deleted, actor_id, role
+                )
+                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
+                """,
+                (
+                    str(uuid.uuid4()),
+                    memory_id,
+                    old_memory,
+                    new_memory,
+                    event,
+                    created_at,
+                    updated_at,
+                    is_deleted,
+                    actor_id,
+                    role,
+                ),
+            )

-    def get_history(self, memory_id):
-        with self._lock:
-            cursor = self.connection.execute(
-                """
-                SELECT id, memory_id, old_memory, new_memory, event, created_at, updated_at
-                FROM history
-                WHERE memory_id = ?
-                ORDER BY updated_at ASC
-                """,
-                (memory_id,),
-            )
-            rows = cursor.fetchall()
-            return [
-                {
-                    "id": row[0],
-                    "memory_id": row[1],
-                    "old_memory": row[2],
-                    "new_memory": row[3],
-                    "event": row[4],
-                    "created_at": row[5],
-                    "updated_at": row[6],
-                }
-                for row in rows
-            ]
+    def get_history(self, memory_id: str) -> List[Dict[str, Any]]:
+        with self._lock:
+            cur = self.connection.execute(
+                """
+                SELECT id, memory_id, old_memory, new_memory, event,
+                       created_at, updated_at, is_deleted, actor_id, role
+                FROM history
+                WHERE memory_id = ?
+                ORDER BY DATETIME(created_at) ASC, DATETIME(updated_at) ASC
+                """,
+                (memory_id,),
+            )
+            rows = cur.fetchall()
+
+        return [
+            {
+                "id": r[0],
+                "memory_id": r[1],
+                "old_memory": r[2],
+                "new_memory": r[3],
+                "event": r[4],
+                "created_at": r[5],
+                "updated_at": r[6],
+                "is_deleted": bool(r[7]),
+                "actor_id": r[8],
+                "role": r[9],
+            }
+            for r in rows
+        ]
+
+    def reset(self) -> None:
+        """Drop and recreate the history table."""
+        with self._lock, self.connection:
+            self.connection.execute("DROP TABLE IF EXISTS history")
+        self._create_history_table()
+
+    def close(self) -> None:
+        if self.connection:
+            self.connection.close()
+            self.connection = None
+
+    def __del__(self):
+        self.close()
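+
+# Usage sketch (illustrative only, not part of the public API):
+#   db = SQLiteManager("history.db")
+#   db.add_history("mem-1", None, "Alice prefers dark mode", "ADD",
+#                  created_at="2025-05-16T23:38:36", actor_id="alice", role="user")
+#   rows = db.get_history("mem-1")  # -> list of dicts including actor_id and role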