Modified the return statement for ADD call | Added tests to main.py and graph_memory.py (#1812)

This commit is contained in:
Prateek Chhikara
2024-09-09 10:04:11 -07:00
committed by GitHub
parent 58f29d8781
commit b081e43b8d
5 changed files with 300 additions and 62 deletions

View File

@@ -2,7 +2,6 @@ import concurrent
import hashlib
import json
import logging
import threading
import uuid
import warnings
from datetime import datetime
@@ -11,14 +10,14 @@ from typing import Any, Dict
import pytz
from pydantic import ValidationError
from mem0.configs.base import MemoryConfig, MemoryItem
from mem0.configs.prompts import get_update_memory_messages
from mem0.memory.base import MemoryBase
from mem0.memory.setup import setup_config
from mem0.memory.storage import SQLiteManager
from mem0.memory.telemetry import capture_event
from mem0.memory.utils import get_fact_retrieval_messages, parse_messages
from mem0.utils.factory import EmbedderFactory, LlmFactory, VectorStoreFactory
from mem0.utils.factory import LlmFactory, EmbedderFactory, VectorStoreFactory
from mem0.configs.base import MemoryItem, MemoryConfig
# Setup user config
setup_config()
@@ -49,6 +48,7 @@ class Memory(MemoryBase):
capture_event("mem0.init", self)
@classmethod
def from_config(cls, config_dict: Dict[str, Any]):
try:
@@ -58,6 +58,7 @@ class Memory(MemoryBase):
raise
return cls(config)
def add(
self,
messages,
@@ -81,7 +82,7 @@ class Memory(MemoryBase):
prompt (str, optional): Prompt to use for memory deduction. Defaults to None.
Returns:
dict: Memory addition operation message.
dict: A dictionary containing the result of the memory addition operation.
"""
if metadata is None:
metadata = {}
@@ -102,17 +103,31 @@ class Memory(MemoryBase):
if isinstance(messages, str):
messages = [{"role": "user", "content": messages}]
thread1 = threading.Thread(target=self._add_to_vector_store, args=(messages, metadata, filters))
thread2 = threading.Thread(target=self._add_to_graph, args=(messages, filters))
with concurrent.futures.ThreadPoolExecutor() as executor:
future1 = executor.submit(self._add_to_vector_store, messages, metadata, filters)
future2 = executor.submit(self._add_to_graph, messages, filters)
thread1.start()
thread2.start()
concurrent.futures.wait([future1, future2])
thread1.join()
thread2.join()
vector_store_result = future1.result()
graph_result = future2.result()
if self.version == "v1.1":
return {
"results" : vector_store_result,
"relations" : graph_result,
}
else:
warnings.warn(
"The current add API output format is deprecated. "
"To use the latest format, set `api_version='v1.1'`. "
"The current format will be removed in mem0ai 1.1.0 and later versions.",
category=DeprecationWarning,
stacklevel=2
)
return {"message": "ok"}
return {"message": "ok"}
def _add_to_vector_store(self, messages, metadata, filters):
parsed_messages = parse_messages(messages)
@@ -151,16 +166,30 @@ class Memory(MemoryBase):
)
new_memories_with_actions = json.loads(new_memories_with_actions)
returned_memories = []
try:
for resp in new_memories_with_actions["memory"]:
logging.info(resp)
try:
if resp["event"] == "ADD":
self._create_memory(data=resp["text"], metadata=metadata)
memory_id = self._create_memory(data=resp["text"], metadata=metadata)
returned_memories.append({
"memory" : resp["text"],
"event" : resp["event"],
})
elif resp["event"] == "UPDATE":
self._update_memory(memory_id=resp["id"], data=resp["text"], metadata=metadata)
returned_memories.append({
"memory" : resp["text"],
"event" : resp["event"],
"previous_memory" : resp["old_memory"],
})
elif resp["event"] == "DELETE":
self._delete_memory(memory_id=resp["id"])
returned_memories.append({
"memory" : resp["text"],
"event" : resp["event"],
})
elif resp["event"] == "NONE":
logging.info("NOOP for Memory.")
except Exception as e:
@@ -170,7 +199,11 @@ class Memory(MemoryBase):
capture_event("mem0.add", self)
return returned_memories
def _add_to_graph(self, messages, filters):
added_entities = []
if self.version == "v1.1" and self.enable_graph:
if filters["user_id"]:
self.graph.user_id = filters["user_id"]
@@ -179,6 +212,9 @@ class Memory(MemoryBase):
data = "\n".join([msg["content"] for msg in messages if "content" in msg and msg["role"] != "system"])
self.graph.add(data, filters)
return added_entities
def get(self, memory_id):
"""
Retrieve a memory by ID.
@@ -229,6 +265,7 @@ class Memory(MemoryBase):
return result
def get_all(self, user_id=None, agent_id=None, run_id=None, limit=100):
"""
List all memories.
@@ -255,9 +292,9 @@ class Memory(MemoryBase):
if self.version == "v1.1":
if self.enable_graph:
return {"memories": all_memories, "entities": graph_entities}
return {"results": all_memories, "relations": graph_entities}
else:
return {"memories": all_memories}
return {"results": all_memories}
else:
warnings.warn(
"The current get_all API output format is deprecated. "
@@ -268,6 +305,7 @@ class Memory(MemoryBase):
)
return all_memories
def _get_all_from_vector_store(self, filters, limit):
memories = self.vector_store.list(filters=filters, limit=limit)
@@ -302,6 +340,7 @@ class Memory(MemoryBase):
]
return all_memories
def search(
self, query, user_id=None, agent_id=None, run_id=None, limit=100, filters=None
):
@@ -343,9 +382,9 @@ class Memory(MemoryBase):
if self.version == "v1.1":
if self.enable_graph:
return {"memories": original_memories, "entities": graph_entities}
return {"results": original_memories, "relations": graph_entities}
else:
return {"memories" : original_memories}
return {"results" : original_memories}
else:
warnings.warn(
"The current get_all API output format is deprecated. "
@@ -356,6 +395,7 @@ class Memory(MemoryBase):
)
return original_memories
def _search_vector_store(self, query, filters, limit):
embeddings = self.embedding_model.embed(query)
memories = self.vector_store.search(
@@ -404,6 +444,7 @@ class Memory(MemoryBase):
return original_memories
def update(self, memory_id, data):
"""
Update a memory by ID.
@@ -419,6 +460,7 @@ class Memory(MemoryBase):
self._update_memory(memory_id, data)
return {"message": "Memory updated successfully!"}
def delete(self, memory_id):
"""
Delete a memory by ID.
@@ -430,6 +472,7 @@ class Memory(MemoryBase):
self._delete_memory(memory_id)
return {"message": "Memory deleted successfully!"}
def delete_all(self, user_id=None, agent_id=None, run_id=None):
"""
Delete all memories.
@@ -464,6 +507,7 @@ class Memory(MemoryBase):
return {'message': 'Memories deleted successfully!'}
def history(self, memory_id):
"""
Get the history of changes for a memory by ID.
@@ -477,6 +521,7 @@ class Memory(MemoryBase):
capture_event("mem0.history", self, {"memory_id": memory_id})
return self.db.get_history(memory_id)
def _create_memory(self, data, metadata=None):
logging.info(f"Creating memory with {data=}")
embeddings = self.embedding_model.embed(data)
@@ -496,6 +541,7 @@ class Memory(MemoryBase):
)
return memory_id
def _update_memory(self, memory_id, data, metadata=None):
logger.info(f"Updating memory with {data=}")
existing_memory = self.vector_store.get(vector_id=memory_id)
@@ -532,6 +578,7 @@ class Memory(MemoryBase):
updated_at=new_metadata["updated_at"],
)
def _delete_memory(self, memory_id):
logging.info(f"Deleting memory with {memory_id=}")
existing_memory = self.vector_store.get(vector_id=memory_id)
@@ -539,6 +586,7 @@ class Memory(MemoryBase):
self.vector_store.delete(vector_id=memory_id)
self.db.add_history(memory_id, prev_value, None, "DELETE", is_deleted=1)
def reset(self):
"""
Reset the memory store.
@@ -548,5 +596,6 @@ class Memory(MemoryBase):
self.db.reset()
capture_event("mem0.reset", self)
def chat(self, query):
raise NotImplementedError("Chat function not implemented yet.")