Code formatting (#1986)
This commit is contained in:
@@ -37,11 +37,11 @@ class Memory(MemoryBase):
|
||||
self.llm = LlmFactory.create(self.config.llm.provider, self.config.llm.config)
|
||||
self.db = SQLiteManager(self.config.history_db_path)
|
||||
self.collection_name = self.config.vector_store.config.collection_name
|
||||
self.version = self.config.version
|
||||
self.api_version = self.config.version
|
||||
|
||||
self.enable_graph = False
|
||||
|
||||
if self.version == "v1.1" and self.config.graph_store.config:
|
||||
if self.api_version == "v1.1" and self.config.graph_store.config:
|
||||
from mem0.memory.graph_memory import MemoryGraph
|
||||
|
||||
self.graph = MemoryGraph(self.config)
|
||||
@@ -119,7 +119,7 @@ class Memory(MemoryBase):
|
||||
vector_store_result = future1.result()
|
||||
graph_result = future2.result()
|
||||
|
||||
if self.version == "v1.1":
|
||||
if self.api_version == "v1.1":
|
||||
return {
|
||||
"results": vector_store_result,
|
||||
"relations": graph_result,
|
||||
@@ -226,13 +226,13 @@ class Memory(MemoryBase):
|
||||
except Exception as e:
|
||||
logging.error(f"Error in new_memories_with_actions: {e}")
|
||||
|
||||
capture_event("mem0.add", self)
|
||||
capture_event("mem0.add", self, {"version": self.api_version, "keys": list(filters.keys())})
|
||||
|
||||
return returned_memories
|
||||
|
||||
def _add_to_graph(self, messages, filters):
|
||||
added_entities = []
|
||||
if self.version == "v1.1" and self.enable_graph:
|
||||
if self.api_version == "v1.1" and self.enable_graph:
|
||||
if filters["user_id"]:
|
||||
self.graph.user_id = filters["user_id"]
|
||||
elif filters["agent_id"]:
|
||||
@@ -305,13 +305,13 @@ class Memory(MemoryBase):
|
||||
if run_id:
|
||||
filters["run_id"] = run_id
|
||||
|
||||
capture_event("mem0.get_all", self, {"filters": len(filters), "limit": limit})
|
||||
capture_event("mem0.get_all", self, {"limit": limit, "keys": list(filters.keys())})
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
future_memories = executor.submit(self._get_all_from_vector_store, filters, limit)
|
||||
future_graph_entities = (
|
||||
executor.submit(self.graph.get_all, filters, limit)
|
||||
if self.version == "v1.1" and self.enable_graph
|
||||
if self.api_version == "v1.1" and self.enable_graph
|
||||
else None
|
||||
)
|
||||
|
||||
@@ -322,7 +322,7 @@ class Memory(MemoryBase):
|
||||
all_memories = future_memories.result()
|
||||
graph_entities = future_graph_entities.result() if future_graph_entities else None
|
||||
|
||||
if self.version == "v1.1":
|
||||
if self.api_version == "v1.1":
|
||||
if self.enable_graph:
|
||||
return {"results": all_memories, "relations": graph_entities}
|
||||
else:
|
||||
@@ -398,14 +398,14 @@ class Memory(MemoryBase):
|
||||
capture_event(
|
||||
"mem0.search",
|
||||
self,
|
||||
{"filters": len(filters), "limit": limit, "version": self.version},
|
||||
{"limit": limit, "version": self.api_version, "keys": list(filters.keys())},
|
||||
)
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
future_memories = executor.submit(self._search_vector_store, query, filters, limit)
|
||||
future_graph_entities = (
|
||||
executor.submit(self.graph.search, query, filters, limit)
|
||||
if self.version == "v1.1" and self.enable_graph
|
||||
if self.api_version == "v1.1" and self.enable_graph
|
||||
else None
|
||||
)
|
||||
|
||||
@@ -416,7 +416,7 @@ class Memory(MemoryBase):
|
||||
original_memories = future_memories.result()
|
||||
graph_entities = future_graph_entities.result() if future_graph_entities else None
|
||||
|
||||
if self.version == "v1.1":
|
||||
if self.api_version == "v1.1":
|
||||
if self.enable_graph:
|
||||
return {"results": original_memories, "relations": graph_entities}
|
||||
else:
|
||||
@@ -518,14 +518,14 @@ class Memory(MemoryBase):
|
||||
"At least one filter is required to delete all memories. If you want to delete all memories, use the `reset()` method."
|
||||
)
|
||||
|
||||
capture_event("mem0.delete_all", self, {"filters": len(filters)})
|
||||
capture_event("mem0.delete_all", self, {"keys": list(filters.keys())})
|
||||
memories = self.vector_store.list(filters=filters)[0]
|
||||
for memory in memories:
|
||||
self._delete_memory(memory.id)
|
||||
|
||||
logger.info(f"Deleted {len(memories)} memories")
|
||||
|
||||
if self.version == "v1.1" and self.enable_graph:
|
||||
if self.api_version == "v1.1" and self.enable_graph:
|
||||
self.graph.delete_all(filters)
|
||||
|
||||
return {"message": "Memories deleted successfully!"}
|
||||
@@ -561,6 +561,7 @@ class Memory(MemoryBase):
|
||||
payloads=[metadata],
|
||||
)
|
||||
self.db.add_history(memory_id, None, data, "ADD", created_at=metadata["created_at"])
|
||||
capture_event("mem0._create_memory", self, {"memory_id": memory_id})
|
||||
return memory_id
|
||||
|
||||
def _update_memory(self, memory_id, data, existing_embeddings, metadata=None):
|
||||
@@ -603,6 +604,7 @@ class Memory(MemoryBase):
|
||||
created_at=new_metadata["created_at"],
|
||||
updated_at=new_metadata["updated_at"],
|
||||
)
|
||||
capture_event("mem0._update_memory", self, {"memory_id": memory_id})
|
||||
return memory_id
|
||||
|
||||
def _delete_memory(self, memory_id):
|
||||
@@ -611,6 +613,7 @@ class Memory(MemoryBase):
|
||||
prev_value = existing_memory.payload["data"]
|
||||
self.vector_store.delete(vector_id=memory_id)
|
||||
self.db.add_history(memory_id, prev_value, None, "DELETE", is_deleted=1)
|
||||
capture_event("mem0._delete_memory", self, {"memory_id": memory_id})
|
||||
return memory_id
|
||||
|
||||
def reset(self):
|
||||
|
||||
Reference in New Issue
Block a user