add response to m.add() call (#1732)

This commit is contained in:
femto
2024-10-16 06:53:18 +08:00
committed by GitHub
parent 2b262a65b2
commit 2cd9f94ea6

View File

@@ -4,6 +4,7 @@ import json
import logging import logging
import uuid import uuid
import warnings import warnings
from collections import defaultdict
from datetime import datetime from datetime import datetime
from typing import Any, Dict from typing import Any, Dict
@@ -82,6 +83,16 @@ class Memory(MemoryBase):
Returns: Returns:
dict: A dictionary containing the result of the memory addition operation. dict: A dictionary containing the result of the memory addition operation.
result: dict of affected events with each dict has the following key:
'memories': affected memories
'graph': affected graph memories
'memories' and 'graph' is a dict, each with following subkeys:
'add': added memory
'update': updated memory
'delete': deleted memory
""" """
if metadata is None: if metadata is None:
metadata = {} metadata = {}
@@ -175,9 +186,10 @@ class Memory(MemoryBase):
logging.info(resp) logging.info(resp)
try: try:
if resp["event"] == "ADD": if resp["event"] == "ADD":
_ = self._create_memory(data=resp["text"], existing_embeddings=new_message_embeddings, metadata=metadata) memory_id = self._create_memory(data=resp["text"], existing_embeddings=new_message_embeddings, metadata=metadata)
returned_memories.append( returned_memories.append(
{ {
"id": memory_id,
"memory": resp["text"], "memory": resp["text"],
"event": resp["event"], "event": resp["event"],
} }
@@ -186,6 +198,7 @@ class Memory(MemoryBase):
self._update_memory(memory_id=resp["id"], data=resp["text"], existing_embeddings=new_message_embeddings, metadata=metadata) self._update_memory(memory_id=resp["id"], data=resp["text"], existing_embeddings=new_message_embeddings, metadata=metadata)
returned_memories.append( returned_memories.append(
{ {
"id": resp["id"],
"memory": resp["text"], "memory": resp["text"],
"event": resp["event"], "event": resp["event"],
"previous_memory": resp["old_memory"], "previous_memory": resp["old_memory"],
@@ -195,6 +208,7 @@ class Memory(MemoryBase):
self._delete_memory(memory_id=resp["id"]) self._delete_memory(memory_id=resp["id"])
returned_memories.append( returned_memories.append(
{ {
"id": resp["id"],
"memory": resp["text"], "memory": resp["text"],
"event": resp["event"], "event": resp["event"],
} }
@@ -453,7 +467,7 @@ class Memory(MemoryBase):
dict: Updated memory. dict: Updated memory.
""" """
capture_event("mem0.update", self, {"memory_id": memory_id}) capture_event("mem0.update", self, {"memory_id": memory_id})
existing_embeddings = {data: self.embedding_model.embed(data)} existing_embeddings = {data: self.embedding_model.embed(data)}
self._update_memory(memory_id, data, existing_embeddings) self._update_memory(memory_id, data, existing_embeddings)
@@ -519,9 +533,9 @@ class Memory(MemoryBase):
def _create_memory(self, data, existing_embeddings, metadata=None): def _create_memory(self, data, existing_embeddings, metadata=None):
logging.info(f"Creating memory with {data=}") logging.info(f"Creating memory with {data=}")
if data in existing_embeddings: if data in existing_embeddings:
embeddings = existing_embeddings[data] embeddings = existing_embeddings[data]
else: else:
embeddings = self.embedding_model.embed(data) embeddings = self.embedding_model.embed(data)
memory_id = str(uuid.uuid4()) memory_id = str(uuid.uuid4())
metadata = metadata or {} metadata = metadata or {}
@@ -559,9 +573,9 @@ class Memory(MemoryBase):
if "run_id" in existing_memory.payload: if "run_id" in existing_memory.payload:
new_metadata["run_id"] = existing_memory.payload["run_id"] new_metadata["run_id"] = existing_memory.payload["run_id"]
if data in existing_embeddings: if data in existing_embeddings:
embeddings = existing_embeddings[data] embeddings = existing_embeddings[data]
else: else:
embeddings = self.embedding_model.embed(data) embeddings = self.embedding_model.embed(data)
self.vector_store.update( self.vector_store.update(
vector_id=memory_id, vector_id=memory_id,
@@ -577,6 +591,7 @@ class Memory(MemoryBase):
created_at=new_metadata["created_at"], created_at=new_metadata["created_at"],
updated_at=new_metadata["updated_at"], updated_at=new_metadata["updated_at"],
) )
return memory_id
def _delete_memory(self, memory_id): def _delete_memory(self, memory_id):
logging.info(f"Deleting memory with {memory_id=}") logging.info(f"Deleting memory with {memory_id=}")
@@ -584,6 +599,7 @@ class Memory(MemoryBase):
prev_value = existing_memory.payload["data"] prev_value = existing_memory.payload["data"]
self.vector_store.delete(vector_id=memory_id) self.vector_store.delete(vector_id=memory_id)
self.db.add_history(memory_id, prev_value, None, "DELETE", is_deleted=1) self.db.add_history(memory_id, prev_value, None, "DELETE", is_deleted=1)
return memory_id
def reset(self): def reset(self):
""" """