[Mem0] Integrate Graph Memory (#1718)
Co-authored-by: Deshraj Yadav <deshrajdry@gmail.com>
This commit is contained in:
@@ -4,9 +4,8 @@ import uuid
|
||||
import pytz
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict
|
||||
|
||||
import warnings
|
||||
from pydantic import ValidationError
|
||||
|
||||
from mem0.llms.utils.tools import (
|
||||
ADD_MEMORY_TOOL,
|
||||
DELETE_MEMORY_TOOL,
|
||||
@@ -37,7 +36,15 @@ class Memory(MemoryBase):
|
||||
self.llm = LlmFactory.create(self.config.llm.provider, self.config.llm.config)
|
||||
self.db = SQLiteManager(self.config.history_db_path)
|
||||
self.collection_name = self.config.vector_store.config.collection_name
|
||||
self.version = self.config.version
|
||||
|
||||
self.enable_graph = False
|
||||
|
||||
if self.version == "v1.1" and self.config.graph_store.config:
|
||||
from mem0.memory.main_graph import MemoryGraph
|
||||
self.graph = MemoryGraph(self.config)
|
||||
self.enable_graph = True
|
||||
|
||||
capture_event("mem0.init", self)
|
||||
|
||||
@classmethod
|
||||
@@ -164,6 +171,14 @@ class Memory(MemoryBase):
|
||||
{"memory_id": function_result, "function_name": function_name},
|
||||
)
|
||||
capture_event("mem0.add", self)
|
||||
|
||||
if self.version == "v1.1" and self.enable_graph:
|
||||
if user_id:
|
||||
self.graph.user_id = user_id
|
||||
else:
|
||||
self.graph.user_id = "USER"
|
||||
added_entities = self.graph.add(data)
|
||||
|
||||
return {"message": "ok"}
|
||||
|
||||
def get(self, memory_id):
|
||||
@@ -234,16 +249,8 @@ class Memory(MemoryBase):
|
||||
capture_event("mem0.get_all", self, {"filters": len(filters), "limit": limit})
|
||||
memories = self.vector_store.list(filters=filters, limit=limit)
|
||||
|
||||
excluded_keys = {
|
||||
"user_id",
|
||||
"agent_id",
|
||||
"run_id",
|
||||
"hash",
|
||||
"data",
|
||||
"created_at",
|
||||
"updated_at",
|
||||
}
|
||||
return [
|
||||
excluded_keys = {"user_id", "agent_id", "run_id", "hash", "data", "created_at", "updated_at"}
|
||||
all_memories = [
|
||||
{
|
||||
**MemoryItem(
|
||||
id=mem.id,
|
||||
@@ -271,6 +278,23 @@ class Memory(MemoryBase):
|
||||
}
|
||||
for mem in memories[0]
|
||||
]
|
||||
|
||||
if self.version == "v1.1":
|
||||
if self.enable_graph:
|
||||
graph_entities = self.graph.get_all()
|
||||
return {"memories": all_memories, "entities": graph_entities}
|
||||
else:
|
||||
return {"memories" : all_memories}
|
||||
else:
|
||||
warnings.warn(
|
||||
"The current get_all API output format is deprecated. "
|
||||
"To use the latest format, set `api_version='v1.1'`. "
|
||||
"The current format will be removed in mem0ai 1.1.0 and later versions.",
|
||||
category=DeprecationWarning,
|
||||
stacklevel=2
|
||||
)
|
||||
return all_memories
|
||||
|
||||
|
||||
def search(
|
||||
self, query, user_id=None, agent_id=None, run_id=None, limit=100, filters=None
|
||||
@@ -302,7 +326,7 @@ class Memory(MemoryBase):
|
||||
"One of the filters: user_id, agent_id or run_id is required!"
|
||||
)
|
||||
|
||||
capture_event("mem0.search", self, {"filters": len(filters), "limit": limit})
|
||||
capture_event("mem0.search", self, {"filters": len(filters), "limit": limit, "version": self.version})
|
||||
embeddings = self.embedding_model.embed(query)
|
||||
memories = self.vector_store.search(
|
||||
query=embeddings, limit=limit, filters=filters
|
||||
@@ -318,7 +342,7 @@ class Memory(MemoryBase):
|
||||
"updated_at",
|
||||
}
|
||||
|
||||
return [
|
||||
original_memories = [
|
||||
{
|
||||
**MemoryItem(
|
||||
id=mem.id,
|
||||
@@ -348,6 +372,22 @@ class Memory(MemoryBase):
|
||||
for mem in memories
|
||||
]
|
||||
|
||||
if self.version == "v1.1":
|
||||
if self.enable_graph:
|
||||
graph_entities = self.graph.search(query)
|
||||
return {"memories": original_memories, "entities": graph_entities}
|
||||
else:
|
||||
return {"memories" : original_memories}
|
||||
else:
|
||||
warnings.warn(
|
||||
"The current get_all API output format is deprecated. "
|
||||
"To use the latest format, set `api_version='v1.1'`. "
|
||||
"The current format will be removed in mem0ai 1.1.0 and later versions.",
|
||||
category=DeprecationWarning,
|
||||
stacklevel=2
|
||||
)
|
||||
return original_memories
|
||||
|
||||
def update(self, memory_id, data):
|
||||
"""
|
||||
Update a memory by ID.
|
||||
@@ -400,7 +440,11 @@ class Memory(MemoryBase):
|
||||
memories = self.vector_store.list(filters=filters)[0]
|
||||
for memory in memories:
|
||||
self._delete_memory_tool(memory.id)
|
||||
return {"message": "Memories deleted successfully!"}
|
||||
|
||||
if self.version == "v1.1" and self.enable_graph:
|
||||
self.graph.delete_all()
|
||||
|
||||
return {'message': 'Memories deleted successfully!'}
|
||||
|
||||
def history(self, memory_id):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user