[Mem0] Integrate Graph Memory (#1718)

Co-authored-by: Deshraj Yadav <deshrajdry@gmail.com>
This commit is contained in:
Prateek Chhikara
2024-08-20 16:37:38 -07:00
committed by GitHub
parent 9b7a882d57
commit c64e0824da
22 changed files with 867 additions and 26 deletions

View File

@@ -4,9 +4,8 @@ import uuid
import pytz
from datetime import datetime
from typing import Any, Dict
import warnings
from pydantic import ValidationError
from mem0.llms.utils.tools import (
ADD_MEMORY_TOOL,
DELETE_MEMORY_TOOL,
@@ -37,7 +36,15 @@ class Memory(MemoryBase):
self.llm = LlmFactory.create(self.config.llm.provider, self.config.llm.config)
self.db = SQLiteManager(self.config.history_db_path)
self.collection_name = self.config.vector_store.config.collection_name
self.version = self.config.version
self.enable_graph = False
if self.version == "v1.1" and self.config.graph_store.config:
from mem0.memory.main_graph import MemoryGraph
self.graph = MemoryGraph(self.config)
self.enable_graph = True
capture_event("mem0.init", self)
@classmethod
@@ -164,6 +171,14 @@ class Memory(MemoryBase):
{"memory_id": function_result, "function_name": function_name},
)
capture_event("mem0.add", self)
if self.version == "v1.1" and self.enable_graph:
if user_id:
self.graph.user_id = user_id
else:
self.graph.user_id = "USER"
added_entities = self.graph.add(data)
return {"message": "ok"}
def get(self, memory_id):
@@ -234,16 +249,8 @@ class Memory(MemoryBase):
capture_event("mem0.get_all", self, {"filters": len(filters), "limit": limit})
memories = self.vector_store.list(filters=filters, limit=limit)
excluded_keys = {
"user_id",
"agent_id",
"run_id",
"hash",
"data",
"created_at",
"updated_at",
}
return [
excluded_keys = {"user_id", "agent_id", "run_id", "hash", "data", "created_at", "updated_at"}
all_memories = [
{
**MemoryItem(
id=mem.id,
@@ -271,6 +278,23 @@ class Memory(MemoryBase):
}
for mem in memories[0]
]
if self.version == "v1.1":
if self.enable_graph:
graph_entities = self.graph.get_all()
return {"memories": all_memories, "entities": graph_entities}
else:
return {"memories" : all_memories}
else:
warnings.warn(
"The current get_all API output format is deprecated. "
"To use the latest format, set `api_version='v1.1'`. "
"The current format will be removed in mem0ai 1.1.0 and later versions.",
category=DeprecationWarning,
stacklevel=2
)
return all_memories
def search(
self, query, user_id=None, agent_id=None, run_id=None, limit=100, filters=None
@@ -302,7 +326,7 @@ class Memory(MemoryBase):
"One of the filters: user_id, agent_id or run_id is required!"
)
capture_event("mem0.search", self, {"filters": len(filters), "limit": limit})
capture_event("mem0.search", self, {"filters": len(filters), "limit": limit, "version": self.version})
embeddings = self.embedding_model.embed(query)
memories = self.vector_store.search(
query=embeddings, limit=limit, filters=filters
@@ -318,7 +342,7 @@ class Memory(MemoryBase):
"updated_at",
}
return [
original_memories = [
{
**MemoryItem(
id=mem.id,
@@ -348,6 +372,22 @@ class Memory(MemoryBase):
for mem in memories
]
if self.version == "v1.1":
if self.enable_graph:
graph_entities = self.graph.search(query)
return {"memories": original_memories, "entities": graph_entities}
else:
return {"memories" : original_memories}
else:
warnings.warn(
"The current get_all API output format is deprecated. "
"To use the latest format, set `api_version='v1.1'`. "
"The current format will be removed in mem0ai 1.1.0 and later versions.",
category=DeprecationWarning,
stacklevel=2
)
return original_memories
def update(self, memory_id, data):
"""
Update a memory by ID.
@@ -400,7 +440,11 @@ class Memory(MemoryBase):
memories = self.vector_store.list(filters=filters)[0]
for memory in memories:
self._delete_memory_tool(memory.id)
return {"message": "Memories deleted successfully!"}
if self.version == "v1.1" and self.enable_graph:
self.graph.delete_all()
return {'message': 'Memories deleted successfully!'}
def history(self, memory_id):
"""