Modified the return statement for ADD call | Added tests to main.py and graph_memory.py (#1812)
This commit is contained in:
@@ -2,7 +2,6 @@ import concurrent
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
import uuid
|
||||
import warnings
|
||||
from datetime import datetime
|
||||
@@ -11,14 +10,14 @@ from typing import Any, Dict
|
||||
import pytz
|
||||
from pydantic import ValidationError
|
||||
|
||||
from mem0.configs.base import MemoryConfig, MemoryItem
|
||||
from mem0.configs.prompts import get_update_memory_messages
|
||||
from mem0.memory.base import MemoryBase
|
||||
from mem0.memory.setup import setup_config
|
||||
from mem0.memory.storage import SQLiteManager
|
||||
from mem0.memory.telemetry import capture_event
|
||||
from mem0.memory.utils import get_fact_retrieval_messages, parse_messages
|
||||
from mem0.utils.factory import EmbedderFactory, LlmFactory, VectorStoreFactory
|
||||
from mem0.utils.factory import LlmFactory, EmbedderFactory, VectorStoreFactory
|
||||
from mem0.configs.base import MemoryItem, MemoryConfig
|
||||
|
||||
# Setup user config
|
||||
setup_config()
|
||||
@@ -49,6 +48,7 @@ class Memory(MemoryBase):
|
||||
|
||||
capture_event("mem0.init", self)
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config_dict: Dict[str, Any]):
|
||||
try:
|
||||
@@ -58,6 +58,7 @@ class Memory(MemoryBase):
|
||||
raise
|
||||
return cls(config)
|
||||
|
||||
|
||||
def add(
|
||||
self,
|
||||
messages,
|
||||
@@ -81,7 +82,7 @@ class Memory(MemoryBase):
|
||||
prompt (str, optional): Prompt to use for memory deduction. Defaults to None.
|
||||
|
||||
Returns:
|
||||
dict: Memory addition operation message.
|
||||
dict: A dictionary containing the result of the memory addition operation.
|
||||
"""
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
@@ -102,17 +103,31 @@ class Memory(MemoryBase):
|
||||
if isinstance(messages, str):
|
||||
messages = [{"role": "user", "content": messages}]
|
||||
|
||||
thread1 = threading.Thread(target=self._add_to_vector_store, args=(messages, metadata, filters))
|
||||
thread2 = threading.Thread(target=self._add_to_graph, args=(messages, filters))
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
future1 = executor.submit(self._add_to_vector_store, messages, metadata, filters)
|
||||
future2 = executor.submit(self._add_to_graph, messages, filters)
|
||||
|
||||
thread1.start()
|
||||
thread2.start()
|
||||
concurrent.futures.wait([future1, future2])
|
||||
|
||||
thread1.join()
|
||||
thread2.join()
|
||||
vector_store_result = future1.result()
|
||||
graph_result = future2.result()
|
||||
|
||||
if self.version == "v1.1":
|
||||
return {
|
||||
"results" : vector_store_result,
|
||||
"relations" : graph_result,
|
||||
}
|
||||
else:
|
||||
warnings.warn(
|
||||
"The current add API output format is deprecated. "
|
||||
"To use the latest format, set `api_version='v1.1'`. "
|
||||
"The current format will be removed in mem0ai 1.1.0 and later versions.",
|
||||
category=DeprecationWarning,
|
||||
stacklevel=2
|
||||
)
|
||||
return {"message": "ok"}
|
||||
|
||||
|
||||
return {"message": "ok"}
|
||||
|
||||
def _add_to_vector_store(self, messages, metadata, filters):
|
||||
parsed_messages = parse_messages(messages)
|
||||
|
||||
@@ -151,16 +166,30 @@ class Memory(MemoryBase):
|
||||
)
|
||||
new_memories_with_actions = json.loads(new_memories_with_actions)
|
||||
|
||||
returned_memories = []
|
||||
try:
|
||||
for resp in new_memories_with_actions["memory"]:
|
||||
logging.info(resp)
|
||||
try:
|
||||
if resp["event"] == "ADD":
|
||||
self._create_memory(data=resp["text"], metadata=metadata)
|
||||
memory_id = self._create_memory(data=resp["text"], metadata=metadata)
|
||||
returned_memories.append({
|
||||
"memory" : resp["text"],
|
||||
"event" : resp["event"],
|
||||
})
|
||||
elif resp["event"] == "UPDATE":
|
||||
self._update_memory(memory_id=resp["id"], data=resp["text"], metadata=metadata)
|
||||
returned_memories.append({
|
||||
"memory" : resp["text"],
|
||||
"event" : resp["event"],
|
||||
"previous_memory" : resp["old_memory"],
|
||||
})
|
||||
elif resp["event"] == "DELETE":
|
||||
self._delete_memory(memory_id=resp["id"])
|
||||
returned_memories.append({
|
||||
"memory" : resp["text"],
|
||||
"event" : resp["event"],
|
||||
})
|
||||
elif resp["event"] == "NONE":
|
||||
logging.info("NOOP for Memory.")
|
||||
except Exception as e:
|
||||
@@ -170,7 +199,11 @@ class Memory(MemoryBase):
|
||||
|
||||
capture_event("mem0.add", self)
|
||||
|
||||
return returned_memories
|
||||
|
||||
|
||||
def _add_to_graph(self, messages, filters):
|
||||
added_entities = []
|
||||
if self.version == "v1.1" and self.enable_graph:
|
||||
if filters["user_id"]:
|
||||
self.graph.user_id = filters["user_id"]
|
||||
@@ -179,6 +212,9 @@ class Memory(MemoryBase):
|
||||
data = "\n".join([msg["content"] for msg in messages if "content" in msg and msg["role"] != "system"])
|
||||
self.graph.add(data, filters)
|
||||
|
||||
return added_entities
|
||||
|
||||
|
||||
def get(self, memory_id):
|
||||
"""
|
||||
Retrieve a memory by ID.
|
||||
@@ -229,6 +265,7 @@ class Memory(MemoryBase):
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def get_all(self, user_id=None, agent_id=None, run_id=None, limit=100):
|
||||
"""
|
||||
List all memories.
|
||||
@@ -255,9 +292,9 @@ class Memory(MemoryBase):
|
||||
|
||||
if self.version == "v1.1":
|
||||
if self.enable_graph:
|
||||
return {"memories": all_memories, "entities": graph_entities}
|
||||
return {"results": all_memories, "relations": graph_entities}
|
||||
else:
|
||||
return {"memories": all_memories}
|
||||
return {"results": all_memories}
|
||||
else:
|
||||
warnings.warn(
|
||||
"The current get_all API output format is deprecated. "
|
||||
@@ -268,6 +305,7 @@ class Memory(MemoryBase):
|
||||
)
|
||||
return all_memories
|
||||
|
||||
|
||||
def _get_all_from_vector_store(self, filters, limit):
|
||||
memories = self.vector_store.list(filters=filters, limit=limit)
|
||||
|
||||
@@ -302,6 +340,7 @@ class Memory(MemoryBase):
|
||||
]
|
||||
return all_memories
|
||||
|
||||
|
||||
def search(
|
||||
self, query, user_id=None, agent_id=None, run_id=None, limit=100, filters=None
|
||||
):
|
||||
@@ -343,9 +382,9 @@ class Memory(MemoryBase):
|
||||
|
||||
if self.version == "v1.1":
|
||||
if self.enable_graph:
|
||||
return {"memories": original_memories, "entities": graph_entities}
|
||||
return {"results": original_memories, "relations": graph_entities}
|
||||
else:
|
||||
return {"memories" : original_memories}
|
||||
return {"results" : original_memories}
|
||||
else:
|
||||
warnings.warn(
|
||||
"The current get_all API output format is deprecated. "
|
||||
@@ -356,6 +395,7 @@ class Memory(MemoryBase):
|
||||
)
|
||||
return original_memories
|
||||
|
||||
|
||||
def _search_vector_store(self, query, filters, limit):
|
||||
embeddings = self.embedding_model.embed(query)
|
||||
memories = self.vector_store.search(
|
||||
@@ -404,6 +444,7 @@ class Memory(MemoryBase):
|
||||
|
||||
return original_memories
|
||||
|
||||
|
||||
def update(self, memory_id, data):
|
||||
"""
|
||||
Update a memory by ID.
|
||||
@@ -419,6 +460,7 @@ class Memory(MemoryBase):
|
||||
self._update_memory(memory_id, data)
|
||||
return {"message": "Memory updated successfully!"}
|
||||
|
||||
|
||||
def delete(self, memory_id):
|
||||
"""
|
||||
Delete a memory by ID.
|
||||
@@ -430,6 +472,7 @@ class Memory(MemoryBase):
|
||||
self._delete_memory(memory_id)
|
||||
return {"message": "Memory deleted successfully!"}
|
||||
|
||||
|
||||
def delete_all(self, user_id=None, agent_id=None, run_id=None):
|
||||
"""
|
||||
Delete all memories.
|
||||
@@ -464,6 +507,7 @@ class Memory(MemoryBase):
|
||||
|
||||
return {'message': 'Memories deleted successfully!'}
|
||||
|
||||
|
||||
def history(self, memory_id):
|
||||
"""
|
||||
Get the history of changes for a memory by ID.
|
||||
@@ -477,6 +521,7 @@ class Memory(MemoryBase):
|
||||
capture_event("mem0.history", self, {"memory_id": memory_id})
|
||||
return self.db.get_history(memory_id)
|
||||
|
||||
|
||||
def _create_memory(self, data, metadata=None):
|
||||
logging.info(f"Creating memory with {data=}")
|
||||
embeddings = self.embedding_model.embed(data)
|
||||
@@ -496,6 +541,7 @@ class Memory(MemoryBase):
|
||||
)
|
||||
return memory_id
|
||||
|
||||
|
||||
def _update_memory(self, memory_id, data, metadata=None):
|
||||
logger.info(f"Updating memory with {data=}")
|
||||
existing_memory = self.vector_store.get(vector_id=memory_id)
|
||||
@@ -532,6 +578,7 @@ class Memory(MemoryBase):
|
||||
updated_at=new_metadata["updated_at"],
|
||||
)
|
||||
|
||||
|
||||
def _delete_memory(self, memory_id):
|
||||
logging.info(f"Deleting memory with {memory_id=}")
|
||||
existing_memory = self.vector_store.get(vector_id=memory_id)
|
||||
@@ -539,6 +586,7 @@ class Memory(MemoryBase):
|
||||
self.vector_store.delete(vector_id=memory_id)
|
||||
self.db.add_history(memory_id, prev_value, None, "DELETE", is_deleted=1)
|
||||
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
Reset the memory store.
|
||||
@@ -548,5 +596,6 @@ class Memory(MemoryBase):
|
||||
self.db.reset()
|
||||
capture_event("mem0.reset", self)
|
||||
|
||||
|
||||
def chat(self, query):
|
||||
raise NotImplementedError("Chat function not implemented yet.")
|
||||
|
||||
Reference in New Issue
Block a user