From 616313b8b59319e8b60e81c3263739112399f7df Mon Sep 17 00:00:00 2001
From: Dev Khant
Date: Thu, 10 Apr 2025 11:16:44 +0530
Subject: [PATCH] Add async support (#2492)

Co-authored-by: Deshraj Yadav
---
 docs/docs.json                             |   1 +
 docs/open-source/features/async-memory.mdx | 169 +++++
 docs/open-source/python-quickstart.mdx     |  10 +
 mem0/__init__.py                           |   2 +-
 mem0/memory/main.py                        | 777 ++++++++++++++++++++-
 5 files changed, 955 insertions(+), 4 deletions(-)
 create mode 100644 docs/open-source/features/async-memory.mdx

diff --git a/docs/docs.json b/docs/docs.json
index 3ce05c79..1d479c61 100644
--- a/docs/docs.json
+++ b/docs/docs.json
@@ -73,6 +73,7 @@
         "group": "Features",
         "icon": "wrench",
         "pages": [
+          "open-source/features/async-memory",
          "features/openai_compatibility",
          "features/custom-fact-extraction-prompt",
          "features/custom-update-memory-prompt",
diff --git a/docs/open-source/features/async-memory.mdx b/docs/open-source/features/async-memory.mdx
new file mode 100644
index 00000000..a6c7860a
--- /dev/null
+++ b/docs/open-source/features/async-memory.mdx
@@ -0,0 +1,169 @@
+---
+title: Async Memory
+description: 'Asynchronous memory for Mem0'
+icon: "bolt"
+iconType: "solid"
+---
+
+## AsyncMemory
+
+The `AsyncMemory` class is a direct asynchronous interface to Mem0's in-process memory operations. Unlike the API-backed `AsyncMemoryClient`, which talks to the hosted Mem0 platform, `AsyncMemory` works directly with the underlying storage systems. This makes it ideal for applications where you want to embed Mem0 directly into your codebase.
+
+### Initialization
+
+To use `AsyncMemory`, import it from the `mem0` package:
+
+```python Python
+import asyncio
+from mem0 import AsyncMemory
+
+# Initialize with default configuration
+memory = AsyncMemory()
+
+# Or initialize with custom configuration
+from mem0.configs.base import MemoryConfig
+custom_config = MemoryConfig(
+    # Your custom configuration here
+)
+memory = AsyncMemory(config=custom_config)
+```
+
+### Key Features
+
+1. **Non-blocking Operations** - All memory operations use `asyncio` to avoid blocking the event loop
+2. **Concurrent Processing** - Parallel execution of vector store and graph operations
+3. **Efficient Resource Utilization** - Better handling of I/O-bound operations
+4. **Compatible with Async Frameworks** - Seamless integration with FastAPI, aiohttp, and other async frameworks
+
+### Methods
+
+All methods in `AsyncMemory` take the same parameters as the synchronous `Memory` class but must be called with `await`.
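+
+Note that the snippets below use bare `await` expressions, so they must run inside a coroutine. A minimal sketch of a runner for any of them (the function name `demo` is illustrative, and the usual provider configuration such as `OPENAI_API_KEY` is assumed):
+
+```python Python
+import asyncio
+from mem0 import AsyncMemory
+
+async def demo():
+    memory = AsyncMemory()
+    # Any of the method calls shown below can be awaited here
+    result = await memory.add(
+        messages=[{"role": "user", "content": "I'm travelling to SF"}],
+        user_id="alice",
+    )
+    print(result)
+
+asyncio.run(demo())
+```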
+
+#### Create memories
+
+Add a new memory asynchronously:
+
+```python Python
+await memory.add(
+    messages=[
+        {"role": "user", "content": "I'm travelling to SF"},
+        {"role": "assistant", "content": "That's great to hear!"}
+    ],
+    user_id="alice"
+)
+```
+
+#### Retrieve memories
+
+Retrieve memories related to a query:
+
+```python Python
+await memory.search(
+    query="Where am I travelling?",
+    user_id="alice"
+)
+```
+
+#### List memories
+
+List all memories for a `user_id`, `agent_id`, or `run_id`:
+
+```python Python
+await memory.get_all(user_id="alice")
+```
+
+#### Get specific memory
+
+Retrieve a specific memory by its ID:
+
+```python Python
+await memory.get(memory_id="memory-id-here")
+```
+
+#### Update memory
+
+Update an existing memory by ID:
+
+```python Python
+await memory.update(
+    memory_id="memory-id-here",
+    data="I'm travelling to Seattle"
+)
+```
+
+#### Delete memory
+
+Delete a specific memory by ID:
+
+```python Python
+await memory.delete(memory_id="memory-id-here")
+```
+
+#### Delete all memories
+
+Delete all memories for a specific user, agent, or run:
+
+```python Python
+await memory.delete_all(user_id="alice")
+```
+
+Note: At least one filter (`user_id`, `agent_id`, or `run_id`) is required when using `delete_all`.
+
+#### Memory History
+
+Get the history of changes for a specific memory:
+
+```python Python
+await memory.history(memory_id="memory-id-here")
+```
+
+### Example: Concurrent Usage with Other APIs
+
+`AsyncMemory` can be combined with other async operations without blocking. Here's an example that uses it alongside async OpenAI API calls on a single event loop:
+
+```python Python
+import asyncio
+from openai import AsyncOpenAI
+from mem0 import AsyncMemory
+
+async_openai_client = AsyncOpenAI()
+async_memory = AsyncMemory()
+
+async def chat_with_memories(message: str, user_id: str = "default_user") -> str:
+    # Retrieve relevant memories
+    search_result = await async_memory.search(query=message, user_id=user_id, limit=3)
+    relevant_memories = search_result["results"]
+    memories_str = "\n".join(f"- {entry['memory']}" for entry in relevant_memories)
+
+    # Generate assistant response
+    system_prompt = f"You are a helpful AI. 
Answer the question based on query and memories.\nUser Memories:\n{memories_str}" + messages = [{"role": "system", "content": system_prompt}, {"role": "user", "content": message}] + response = await async_openai_client.chat.completions.create(model="gpt-4o-mini", messages=messages) + assistant_response = response.choices[0].message.content + + # Create new memories from the conversation + messages.append({"role": "assistant", "content": assistant_response}) + await async_memory.add(messages, user_id=user_id) + + return assistant_response + +async def async_main(): + print("Chat with AI (type 'exit' to quit)") + while True: + user_input = input("You: ").strip() + if user_input.lower() == 'exit': + print("Goodbye!") + break + response = await chat_with_memories(user_input) + print(f"AI: {response}") + +def main(): + asyncio.run(async_main()) + +if __name__ == "__main__": + main() +``` + +If you have any questions or need further assistance, please don't hesitate to reach out: + + diff --git a/docs/open-source/python-quickstart.mdx b/docs/open-source/python-quickstart.mdx index db25efbf..e0a90642 100644 --- a/docs/open-source/python-quickstart.mdx +++ b/docs/open-source/python-quickstart.mdx @@ -28,6 +28,16 @@ from mem0 import Memory os.environ["OPENAI_API_KEY"] = "your-api-key" m = Memory() +``` + + +```python +import os +from mem0 import AsyncMemory + +os.environ["OPENAI_API_KEY"] = "your-api-key" + +m = AsyncMemory() ``` diff --git a/mem0/__init__.py b/mem0/__init__.py index bb441a50..21a818fe 100644 --- a/mem0/__init__.py +++ b/mem0/__init__.py @@ -3,4 +3,4 @@ import importlib.metadata __version__ = importlib.metadata.version("mem0ai") from mem0.client.main import AsyncMemoryClient, MemoryClient # noqa -from mem0.memory.main import Memory # noqa +from mem0.memory.main import Memory, AsyncMemory # noqa diff --git a/mem0/memory/main.py b/mem0/memory/main.py index b9d2111f..6c502309 100644 --- a/mem0/memory/main.py +++ b/mem0/memory/main.py @@ -1,3 +1,4 @@ +import asyncio import concurrent import hashlib import json @@ -196,8 +197,8 @@ class Memory(MemoryBase): parsed_messages = parse_messages(messages) - if self.custom_fact_extraction_prompt: - system_prompt = self.custom_fact_extraction_prompt + if self.config.custom_fact_extraction_prompt: + system_prompt = self.config.custom_fact_extraction_prompt user_prompt = f"Input:\n{parsed_messages}" else: system_prompt, user_prompt = get_fact_retrieval_messages(parsed_messages) @@ -243,7 +244,7 @@ class Memory(MemoryBase): retrieved_old_memory[idx]["id"] = str(idx) function_calling_prompt = get_update_memory_messages( - retrieved_old_memory, new_retrieved_facts, self.custom_update_memory_prompt + retrieved_old_memory, new_retrieved_facts, self.config.custom_update_memory_prompt ) try: @@ -755,3 +756,773 @@ class Memory(MemoryBase): def chat(self, query): raise NotImplementedError("Chat function not implemented yet.") + + +class AsyncMemory(MemoryBase): + def __init__(self, config: MemoryConfig = MemoryConfig()): + self.config = config + + self.embedding_model = EmbedderFactory.create( + self.config.embedder.provider, + self.config.embedder.config, + self.config.vector_store.config, + ) + self.vector_store = VectorStoreFactory.create( + self.config.vector_store.provider, self.config.vector_store.config + ) + self.llm = LlmFactory.create(self.config.llm.provider, self.config.llm.config) + self.db = SQLiteManager(self.config.history_db_path) + self.collection_name = self.config.vector_store.config.collection_name + self.api_version = 
self.config.version
+
+        self.enable_graph = False
+
+        if self.config.graph_store.config:
+            from mem0.memory.graph_memory import MemoryGraph
+
+            self.graph = MemoryGraph(self.config)
+            self.enable_graph = True
+
+        capture_event("mem0.init", self)
+
+    @classmethod
+    async def from_config(cls, config_dict: Dict[str, Any]):
+        try:
+            config_dict = cls._process_config(config_dict)
+            config = MemoryConfig(**config_dict)
+        except ValidationError as e:
+            logger.error(f"Configuration validation error: {e}")
+            raise
+        return cls(config)
+
+    @staticmethod
+    def _process_config(config_dict: Dict[str, Any]) -> Dict[str, Any]:
+        if "graph_store" in config_dict:
+            if "vector_store" not in config_dict and "embedder" in config_dict:
+                config_dict["vector_store"] = {}
+                config_dict["vector_store"]["config"] = {}
+                config_dict["vector_store"]["config"]["embedding_model_dims"] = config_dict["embedder"]["config"][
+                    "embedding_dims"
+                ]
+        try:
+            return config_dict
+        except ValidationError as e:
+            logger.error(f"Configuration validation error: {e}")
+            raise
+
+    async def add(
+        self,
+        messages,
+        user_id=None,
+        agent_id=None,
+        run_id=None,
+        metadata=None,
+        filters=None,
+        infer=True,
+        memory_type=None,
+        prompt=None,
+        llm=None,
+    ):
+        """
+        Create a new memory asynchronously.
+
+        Args:
+            messages (str or List[Dict[str, str]]): Messages to store in the memory.
+            user_id (str, optional): ID of the user creating the memory. Defaults to None.
+            agent_id (str, optional): ID of the agent creating the memory. Defaults to None.
+            run_id (str, optional): ID of the run creating the memory. Defaults to None.
+            metadata (dict, optional): Metadata to store with the memory. Defaults to None.
+            filters (dict, optional): Filters to apply to the search. Defaults to None.
+            infer (bool, optional): Whether to infer the memories. Defaults to True.
+            memory_type (str, optional): Type of memory to create. Defaults to None. By default, it creates short-term and long-term (semantic and episodic) memories. Pass "procedural_memory" to create procedural memories.
+            prompt (str, optional): Prompt to use for the memory creation. Defaults to None.
+            llm (BaseChatModel, optional): LLM class to use for generating procedural memories. Defaults to None. Useful when the user is using a LangChain ChatModel.
+        Returns:
+            dict: A dictionary containing the result of the memory addition operation.
+            result: dict of affected events, where each dict has the following keys:
+              'memories': affected memories
+              'graph': affected graph memories
+
+              'memories' and 'graph' are dicts, each with the following subkeys:
+                'add': added memory
+                'update': updated memory
+                'delete': deleted memory
+        """
+        if metadata is None:
+            metadata = {}
+
+        filters = filters or {}
+        if user_id:
+            filters["user_id"] = metadata["user_id"] = user_id
+        if agent_id:
+            filters["agent_id"] = metadata["agent_id"] = agent_id
+        if run_id:
+            filters["run_id"] = metadata["run_id"] = run_id
+
+        if not any(key in filters for key in ("user_id", "agent_id", "run_id")):
+            raise ValueError("One of the filters: user_id, agent_id or run_id is required!")
+
+        if memory_type is not None and memory_type != MemoryType.PROCEDURAL.value:
+            raise ValueError(
+                f"Invalid 'memory_type'. Please pass {MemoryType.PROCEDURAL.value} to create procedural memories."
+ ) + + if isinstance(messages, str): + messages = [{"role": "user", "content": messages}] + + if agent_id is not None and memory_type == MemoryType.PROCEDURAL.value: + results = await self._create_procedural_memory(messages, metadata=metadata, llm=llm, prompt=prompt) + return results + + if self.config.llm.config.get("enable_vision"): + messages = parse_vision_messages(messages, self.llm, self.config.llm.config.get("vision_details")) + else: + messages = parse_vision_messages(messages) + + # Run vector store and graph operations concurrently + vector_store_task = asyncio.create_task(self._add_to_vector_store(messages, metadata, filters, infer)) + graph_task = asyncio.create_task(self._add_to_graph(messages, filters)) + + vector_store_result, graph_result = await asyncio.gather(vector_store_task, graph_task) + + if self.api_version == "v1.0": + warnings.warn( + "The current add API output format is deprecated. " + "To use the latest format, set `api_version='v1.1'`. " + "The current format will be removed in mem0ai 1.1.0 and later versions.", + category=DeprecationWarning, + stacklevel=2, + ) + return vector_store_result + + if self.enable_graph: + return { + "results": vector_store_result, + "relations": graph_result, + } + + return {"results": vector_store_result} + + async def _add_to_vector_store(self, messages, metadata, filters, infer): + if not infer: + returned_memories = [] + for message in messages: + if message["role"] != "system": + message_embeddings = await asyncio.to_thread(self.embedding_model.embed, message["content"], "add") + memory_id = await self._create_memory(message["content"], message_embeddings, metadata) + returned_memories.append({"id": memory_id, "memory": message["content"], "event": "ADD"}) + return returned_memories + + parsed_messages = parse_messages(messages) + + if self.config.custom_fact_extraction_prompt: + system_prompt = self.config.custom_fact_extraction_prompt + user_prompt = f"Input:\n{parsed_messages}" + else: + system_prompt, user_prompt = get_fact_retrieval_messages(parsed_messages) + + response = await asyncio.to_thread( + self.llm.generate_response, + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + ], + response_format={"type": "json_object"}, + ) + + try: + response = remove_code_blocks(response) + new_retrieved_facts = json.loads(response)["facts"] + except Exception as e: + logging.error(f"Error in new_retrieved_facts: {e}") + new_retrieved_facts = [] + + retrieved_old_memory = [] + new_message_embeddings = {} + + # Process all facts concurrently + async def process_fact(new_mem): + messages_embeddings = await asyncio.to_thread(self.embedding_model.embed, new_mem, "add") + new_message_embeddings[new_mem] = messages_embeddings + existing_memories = await asyncio.to_thread( + self.vector_store.search, + query=new_mem, + vectors=messages_embeddings, + limit=5, + filters=filters, + ) + return [(mem.id, mem.payload["data"]) for mem in existing_memories] + + fact_tasks = [process_fact(fact) for fact in new_retrieved_facts] + fact_results = await asyncio.gather(*fact_tasks) + + # Flatten results and build retrieved_old_memory + for result in fact_results: + for mem_id, mem_data in result: + retrieved_old_memory.append({"id": mem_id, "text": mem_data}) + + unique_data = {} + for item in retrieved_old_memory: + unique_data[item["id"]] = item + retrieved_old_memory = list(unique_data.values()) + logging.info(f"Total existing memories: {len(retrieved_old_memory)}") + + # mapping UUIDs with 
integers for handling UUID hallucinations
+        temp_uuid_mapping = {}
+        for idx, item in enumerate(retrieved_old_memory):
+            temp_uuid_mapping[str(idx)] = item["id"]
+            retrieved_old_memory[idx]["id"] = str(idx)
+
+        function_calling_prompt = get_update_memory_messages(
+            retrieved_old_memory, new_retrieved_facts, self.config.custom_update_memory_prompt
+        )
+
+        try:
+            new_memories_with_actions = await asyncio.to_thread(
+                self.llm.generate_response,
+                messages=[{"role": "user", "content": function_calling_prompt}],
+                response_format={"type": "json_object"},
+            )
+        except Exception as e:
+            logging.error(f"Error in new_memories_with_actions: {e}")
+            new_memories_with_actions = {}  # fall back to a dict so the .get("memory", []) below stays valid
+
+        try:
+            new_memories_with_actions = remove_code_blocks(new_memories_with_actions)
+            new_memories_with_actions = json.loads(new_memories_with_actions)
+        except Exception as e:
+            logging.error(f"Invalid JSON response: {e}")
+            new_memories_with_actions = {}
+
+        returned_memories = []
+        try:
+            memory_tasks = []
+            for resp in new_memories_with_actions.get("memory", []):
+                logging.info(resp)
+                try:
+                    if not resp.get("text"):
+                        logging.info("Skipping memory entry because of empty `text` field.")
+                        continue
+                    elif resp.get("event") == "ADD":
+                        task = asyncio.create_task(self._create_memory(
+                            data=resp.get("text"),
+                            existing_embeddings=new_message_embeddings,
+                            metadata=metadata
+                        ))
+                        memory_tasks.append((task, resp, "ADD", None))
+                    elif resp.get("event") == "UPDATE":
+                        task = asyncio.create_task(self._update_memory(
+                            memory_id=temp_uuid_mapping[resp["id"]],
+                            data=resp.get("text"),
+                            existing_embeddings=new_message_embeddings,
+                            metadata=metadata,
+                        ))
+                        memory_tasks.append((task, resp, "UPDATE", temp_uuid_mapping[resp["id"]]))
+                    elif resp.get("event") == "DELETE":
+                        task = asyncio.create_task(self._delete_memory(
+                            memory_id=temp_uuid_mapping[resp["id"]]
+                        ))
+                        memory_tasks.append((task, resp, "DELETE", temp_uuid_mapping[resp["id"]]))
+                    elif resp.get("event") == "NONE":
+                        logging.info("NOOP for Memory.")
+                except Exception as e:
+                    logging.error(f"Error in new_memories_with_actions: {e}")
+
+            # Wait for all memory operations to complete
+            for task, resp, event_type, mem_id in memory_tasks:
+                try:
+                    result_id = await task
+                    if event_type == "ADD":
+                        returned_memories.append({
+                            "id": result_id,
+                            "memory": resp.get("text"),
+                            "event": resp.get("event"),
+                        })
+                    elif event_type == "UPDATE":
+                        returned_memories.append({
+                            "id": mem_id,
+                            "memory": resp.get("text"),
+                            "event": resp.get("event"),
+                            "previous_memory": resp.get("old_memory"),
+                        })
+                    elif event_type == "DELETE":
+                        returned_memories.append({
+                            "id": mem_id,
+                            "memory": resp.get("text"),
+                            "event": resp.get("event"),
+                        })
+                except Exception as e:
+                    logging.error(f"Error processing memory task: {e}")
+
+        except Exception as e:
+            logging.error(f"Error in new_memories_with_actions: {e}")
+
+        capture_event("mem0.add", self, {"version": self.api_version, "keys": list(filters.keys())})
+
+        return returned_memories
+
+    async def _add_to_graph(self, messages, filters):
+        added_entities = []
+        if self.enable_graph:
+            if filters.get("user_id") is None:
+                filters["user_id"] = "user"
+
+            data = "\n".join([msg["content"] for msg in messages if "content" in msg and msg["role"] != "system"])
+            added_entities = await asyncio.to_thread(self.graph.add, data, filters)
+
+        return added_entities
+
+    async def get(self, memory_id):
+        """
+        Retrieve a memory by ID asynchronously.
+
+        Args:
+            memory_id (str): ID of the memory to retrieve.
+
+        Returns:
+            dict: Retrieved memory.
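+                Returns None if no memory exists with the given ID.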
+ """ + capture_event("mem0.get", self, {"memory_id": memory_id}) + memory = await asyncio.to_thread(self.vector_store.get, vector_id=memory_id) + if not memory: + return None + + filters = {key: memory.payload[key] for key in ["user_id", "agent_id", "run_id"] if memory.payload.get(key)} + + # Prepare base memory item + memory_item = MemoryItem( + id=memory.id, + memory=memory.payload["data"], + hash=memory.payload.get("hash"), + created_at=memory.payload.get("created_at"), + updated_at=memory.payload.get("updated_at"), + ).model_dump(exclude={"score"}) + + # Add metadata if there are additional keys + excluded_keys = {"user_id", "agent_id", "run_id", "hash", "data", "created_at", "updated_at", "id"} + additional_metadata = {k: v for k, v in memory.payload.items() if k not in excluded_keys} + if additional_metadata: + memory_item["metadata"] = additional_metadata + + result = {**memory_item, **filters} + + return result + + async def get_all(self, user_id=None, agent_id=None, run_id=None, limit=100): + """ + List all memories asynchronously. + + Returns: + list: List of all memories. + """ + filters = {} + if user_id: + filters["user_id"] = user_id + if agent_id: + filters["agent_id"] = agent_id + if run_id: + filters["run_id"] = run_id + + capture_event("mem0.get_all", self, {"limit": limit, "keys": list(filters.keys())}) + + # Run vector store and graph operations concurrently + vector_store_task = asyncio.create_task(self._get_all_from_vector_store(filters, limit)) + + if self.enable_graph: + graph_task = asyncio.create_task(asyncio.to_thread(self.graph.get_all, filters, limit)) + all_memories, graph_entities = await asyncio.gather(vector_store_task, graph_task) + else: + all_memories = await vector_store_task + graph_entities = None + + if self.enable_graph: + return {"results": all_memories, "relations": graph_entities} + + if self.api_version == "v1.0": + warnings.warn( + "The current get_all API output format is deprecated. " + "To use the latest format, set `api_version='v1.1'`. " + "The current format will be removed in mem0ai 1.1.0 and later versions.", + category=DeprecationWarning, + stacklevel=2, + ) + return all_memories + else: + return {"results": all_memories} + + async def _get_all_from_vector_store(self, filters, limit): + memories = await asyncio.to_thread(self.vector_store.list, filters=filters, limit=limit) + + excluded_keys = { + "user_id", + "agent_id", + "run_id", + "hash", + "data", + "created_at", + "updated_at", + "id", + } + all_memories = [ + { + **MemoryItem( + id=mem.id, + memory=mem.payload["data"], + hash=mem.payload.get("hash"), + created_at=mem.payload.get("created_at"), + updated_at=mem.payload.get("updated_at"), + ).model_dump(exclude={"score"}), + **{key: mem.payload[key] for key in ["user_id", "agent_id", "run_id"] if key in mem.payload}, + **( + {"metadata": {k: v for k, v in mem.payload.items() if k not in excluded_keys}} + if any(k for k in mem.payload if k not in excluded_keys) + else {} + ), + } + for mem in memories[0] + ] + return all_memories + + async def search(self, query, user_id=None, agent_id=None, run_id=None, limit=100, filters=None): + """ + Search for memories asynchronously. + + Args: + query (str): Query to search for. + user_id (str, optional): ID of the user to search for. Defaults to None. + agent_id (str, optional): ID of the agent to search for. Defaults to None. + run_id (str, optional): ID of the run to search for. Defaults to None. + limit (int, optional): Limit the number of results. Defaults to 100. 
+            filters (dict, optional): Filters to apply to the search. Defaults to None.
+
+        Returns:
+            list: List of search results.
+        """
+        filters = filters or {}
+        if user_id:
+            filters["user_id"] = user_id
+        if agent_id:
+            filters["agent_id"] = agent_id
+        if run_id:
+            filters["run_id"] = run_id
+
+        if not any(key in filters for key in ("user_id", "agent_id", "run_id")):
+            raise ValueError("One of the filters: user_id, agent_id or run_id is required!")
+
+        capture_event(
+            "mem0.search",
+            self,
+            {"limit": limit, "version": self.api_version, "keys": list(filters.keys())},
+        )
+
+        # Run vector store and graph operations concurrently
+        vector_store_task = asyncio.create_task(self._search_vector_store(query, filters, limit))
+
+        if self.enable_graph:
+            graph_task = asyncio.create_task(asyncio.to_thread(self.graph.search, query, filters, limit))
+            original_memories, graph_entities = await asyncio.gather(vector_store_task, graph_task)
+        else:
+            original_memories = await vector_store_task
+            graph_entities = None
+
+        if self.enable_graph:
+            return {"results": original_memories, "relations": graph_entities}
+
+        if self.api_version == "v1.0":
+            warnings.warn(
+                "The current search API output format is deprecated. "
+                "To use the latest format, set `api_version='v1.1'`. "
+                "The current format will be removed in mem0ai 1.1.0 and later versions.",
+                category=DeprecationWarning,
+                stacklevel=2,
+            )
+            return original_memories
+        else:
+            return {"results": original_memories}
+
+    async def _search_vector_store(self, query, filters, limit):
+        embeddings = await asyncio.to_thread(self.embedding_model.embed, query, "search")
+        memories = await asyncio.to_thread(
+            self.vector_store.search,
+            query=query,
+            vectors=embeddings,
+            limit=limit,
+            filters=filters
+        )
+
+        excluded_keys = {
+            "user_id",
+            "agent_id",
+            "run_id",
+            "hash",
+            "data",
+            "created_at",
+            "updated_at",
+            "id",
+        }
+
+        original_memories = [
+            {
+                **MemoryItem(
+                    id=mem.id,
+                    memory=mem.payload["data"],
+                    hash=mem.payload.get("hash"),
+                    created_at=mem.payload.get("created_at"),
+                    updated_at=mem.payload.get("updated_at"),
+                    score=mem.score,
+                ).model_dump(),
+                **{key: mem.payload[key] for key in ["user_id", "agent_id", "run_id"] if key in mem.payload},
+                **(
+                    {"metadata": {k: v for k, v in mem.payload.items() if k not in excluded_keys}}
+                    if any(k for k in mem.payload if k not in excluded_keys)
+                    else {}
+                ),
+            }
+            for mem in memories
+        ]
+
+        return original_memories
+
+    async def update(self, memory_id, data):
+        """
+        Update a memory by ID asynchronously.
+
+        Args:
+            memory_id (str): ID of the memory to update.
+            data (str): Data to update the memory with.
+
+        Returns:
+            dict: Updated memory.
+        """
+        capture_event("mem0.update", self, {"memory_id": memory_id})
+
+        embeddings = await asyncio.to_thread(self.embedding_model.embed, data, "update")
+        existing_embeddings = {data: embeddings}
+
+        await self._update_memory(memory_id, data, existing_embeddings)
+        return {"message": "Memory updated successfully!"}
+
+    async def delete(self, memory_id):
+        """
+        Delete a memory by ID asynchronously.
+
+        Args:
+            memory_id (str): ID of the memory to delete.
+        """
+        capture_event("mem0.delete", self, {"memory_id": memory_id})
+        await self._delete_memory(memory_id)
+        return {"message": "Memory deleted successfully!"}
+
+    async def delete_all(self, user_id=None, agent_id=None, run_id=None):
+        """
+        Delete all memories asynchronously.
+
+        Args:
+            user_id (str, optional): ID of the user to delete memories for. Defaults to None.
+ agent_id (str, optional): ID of the agent to delete memories for. Defaults to None. + run_id (str, optional): ID of the run to delete memories for. Defaults to None. + """ + filters = {} + if user_id: + filters["user_id"] = user_id + if agent_id: + filters["agent_id"] = agent_id + if run_id: + filters["run_id"] = run_id + + if not filters: + raise ValueError( + "At least one filter is required to delete all memories. If you want to delete all memories, use the `reset()` method." + ) + + capture_event("mem0.delete_all", self, {"keys": list(filters.keys())}) + memories = await asyncio.to_thread(self.vector_store.list, filters=filters) + + delete_tasks = [] + for memory in memories[0]: + delete_tasks.append(self._delete_memory(memory.id)) + + await asyncio.gather(*delete_tasks) + + logger.info(f"Deleted {len(memories[0])} memories") + + if self.enable_graph: + await asyncio.to_thread(self.graph.delete_all, filters) + + return {"message": "Memories deleted successfully!"} + + async def history(self, memory_id): + """ + Get the history of changes for a memory by ID asynchronously. + + Args: + memory_id (str): ID of the memory to get history for. + + Returns: + list: List of changes for the memory. + """ + capture_event("mem0.history", self, {"memory_id": memory_id}) + return await asyncio.to_thread(self.db.get_history, memory_id) + + async def _create_memory(self, data, existing_embeddings, metadata=None): + logging.debug(f"Creating memory with {data=}") + if data in existing_embeddings: + embeddings = existing_embeddings[data] + else: + embeddings = await asyncio.to_thread(self.embedding_model.embed, data, memory_action="add") + + memory_id = str(uuid.uuid4()) + metadata = metadata or {} + metadata["data"] = data + metadata["hash"] = hashlib.md5(data.encode()).hexdigest() + metadata["created_at"] = datetime.now(pytz.timezone("US/Pacific")).isoformat() + + await asyncio.to_thread( + self.vector_store.insert, + vectors=[embeddings], + ids=[memory_id], + payloads=[metadata], + ) + + await asyncio.to_thread( + self.db.add_history, + memory_id, + None, + data, + "ADD", + created_at=metadata["created_at"] + ) + + capture_event("mem0._create_memory", self, {"memory_id": memory_id}) + return memory_id + + async def _create_procedural_memory(self, messages, metadata=None, llm=None, prompt=None): + """ + Create a procedural memory asynchronously + + Args: + messages (list): List of messages to create a procedural memory from. + metadata (dict): Metadata to create a procedural memory from. + llm (BaseChatModel, optional): LLM class to use for generating procedural memories. Defaults to None. Useful when user is using LangChain ChatModel. + prompt (str, optional): Prompt to use for the procedural memory creation. Defaults to None. + """ + try: + from langchain_core.messages.utils import ( + convert_to_messages, # type: ignore + ) + except Exception: + logger.error("Import error while loading langchain-core. 
Please install 'langchain-core' to use procedural memory.")
+            raise
+
+        logger.info("Creating procedural memory")
+
+        parsed_messages = [
+            {"role": "system", "content": prompt or PROCEDURAL_MEMORY_SYSTEM_PROMPT},
+            *messages,
+            {"role": "user", "content": "Create procedural memory of the above conversation."},
+        ]
+
+        try:
+            if llm is not None:
+                parsed_messages = convert_to_messages(parsed_messages)
+                response = await asyncio.to_thread(llm.invoke, input=parsed_messages)
+                procedural_memory = response.content
+            else:
+                procedural_memory = await asyncio.to_thread(
+                    self.llm.generate_response,
+                    messages=parsed_messages
+                )
+        except Exception as e:
+            logger.error(f"Error generating procedural memory summary: {e}")
+            raise
+
+        if metadata is None:
+            raise ValueError("Metadata is required for procedural memory.")
+
+        metadata["memory_type"] = MemoryType.PROCEDURAL.value
+        # Generate embeddings for the summary
+        embeddings = await asyncio.to_thread(self.embedding_model.embed, procedural_memory, memory_action="add")
+        # Create the memory
+        memory_id = await self._create_memory(procedural_memory, {procedural_memory: embeddings}, metadata=metadata)
+        capture_event("mem0._create_procedural_memory", self, {"memory_id": memory_id})
+
+        # Return results in the same format as add()
+        result = {"results": [{"id": memory_id, "memory": procedural_memory, "event": "ADD"}]}
+
+        return result
+
+    async def _update_memory(self, memory_id, data, existing_embeddings, metadata=None):
+        logger.info(f"Updating memory with {data=}")
+
+        try:
+            existing_memory = await asyncio.to_thread(self.vector_store.get, vector_id=memory_id)
+        except Exception:
+            raise ValueError(f"Error getting memory with ID {memory_id}. Please provide a valid 'memory_id'")
+
+        prev_value = existing_memory.payload.get("data")
+
+        new_metadata = metadata or {}
+        new_metadata["data"] = data
+        new_metadata["hash"] = hashlib.md5(data.encode()).hexdigest()
+        new_metadata["created_at"] = existing_memory.payload.get("created_at")
+        new_metadata["updated_at"] = datetime.now(pytz.timezone("US/Pacific")).isoformat()
+
+        if "user_id" in existing_memory.payload:
+            new_metadata["user_id"] = existing_memory.payload["user_id"]
+        if "agent_id" in existing_memory.payload:
+            new_metadata["agent_id"] = existing_memory.payload["agent_id"]
+        if "run_id" in existing_memory.payload:
+            new_metadata["run_id"] = existing_memory.payload["run_id"]
+
+        if data in existing_embeddings:
+            embeddings = existing_embeddings[data]
+        else:
+            embeddings = await asyncio.to_thread(self.embedding_model.embed, data, "update")
+
+        await asyncio.to_thread(
+            self.vector_store.update,
+            vector_id=memory_id,
+            vector=embeddings,
+            payload=new_metadata,
+        )
+
+        logger.info(f"Updating memory with ID {memory_id=} with {data=}")
+
+        await asyncio.to_thread(
+            self.db.add_history,
+            memory_id,
+            prev_value,
+            data,
+            "UPDATE",
+            created_at=new_metadata["created_at"],
+            updated_at=new_metadata["updated_at"],
+        )
+
+        capture_event("mem0._update_memory", self, {"memory_id": memory_id})
+        return memory_id
+
+    async def _delete_memory(self, memory_id):
+        logging.info(f"Deleting memory with {memory_id=}")
+        existing_memory = await asyncio.to_thread(self.vector_store.get, vector_id=memory_id)
+        prev_value = existing_memory.payload["data"]
+
+        await asyncio.to_thread(self.vector_store.delete, vector_id=memory_id)
+        await asyncio.to_thread(self.db.add_history, memory_id, prev_value, None, "DELETE", is_deleted=1)
+
+        capture_event("mem0._delete_memory", self, {"memory_id": memory_id})
+        return 
memory_id + + async def reset(self): + """ + Reset the memory store asynchronously. + """ + logger.warning("Resetting all memories") + await asyncio.to_thread(self.vector_store.delete_col) + self.vector_store = VectorStoreFactory.create( + self.config.vector_store.provider, self.config.vector_store.config + ) + await asyncio.to_thread(self.db.reset) + capture_event("mem0.reset", self) + + async def chat(self, query): + raise NotImplementedError("Chat function not implemented yet.")