Code formatting (#1986)

This commit is contained in:
Dev Khant
2024-10-29 11:32:07 +05:30
committed by GitHub
parent dca74a1ec0
commit 605558da9d
13 changed files with 119 additions and 149 deletions

View File

@@ -37,11 +37,11 @@ class Memory(MemoryBase):
self.llm = LlmFactory.create(self.config.llm.provider, self.config.llm.config)
self.db = SQLiteManager(self.config.history_db_path)
self.collection_name = self.config.vector_store.config.collection_name
self.version = self.config.version
self.api_version = self.config.version
self.enable_graph = False
if self.version == "v1.1" and self.config.graph_store.config:
if self.api_version == "v1.1" and self.config.graph_store.config:
from mem0.memory.graph_memory import MemoryGraph
self.graph = MemoryGraph(self.config)
@@ -119,7 +119,7 @@ class Memory(MemoryBase):
vector_store_result = future1.result()
graph_result = future2.result()
if self.version == "v1.1":
if self.api_version == "v1.1":
return {
"results": vector_store_result,
"relations": graph_result,
@@ -226,13 +226,13 @@ class Memory(MemoryBase):
except Exception as e:
logging.error(f"Error in new_memories_with_actions: {e}")
capture_event("mem0.add", self)
capture_event("mem0.add", self, {"version": self.api_version, "keys": list(filters.keys())})
return returned_memories
def _add_to_graph(self, messages, filters):
added_entities = []
if self.version == "v1.1" and self.enable_graph:
if self.api_version == "v1.1" and self.enable_graph:
if filters["user_id"]:
self.graph.user_id = filters["user_id"]
elif filters["agent_id"]:
@@ -305,13 +305,13 @@ class Memory(MemoryBase):
if run_id:
filters["run_id"] = run_id
capture_event("mem0.get_all", self, {"filters": len(filters), "limit": limit})
capture_event("mem0.get_all", self, {"limit": limit, "keys": list(filters.keys())})
with concurrent.futures.ThreadPoolExecutor() as executor:
future_memories = executor.submit(self._get_all_from_vector_store, filters, limit)
future_graph_entities = (
executor.submit(self.graph.get_all, filters, limit)
if self.version == "v1.1" and self.enable_graph
if self.api_version == "v1.1" and self.enable_graph
else None
)
@@ -322,7 +322,7 @@ class Memory(MemoryBase):
all_memories = future_memories.result()
graph_entities = future_graph_entities.result() if future_graph_entities else None
if self.version == "v1.1":
if self.api_version == "v1.1":
if self.enable_graph:
return {"results": all_memories, "relations": graph_entities}
else:
@@ -398,14 +398,14 @@ class Memory(MemoryBase):
capture_event(
"mem0.search",
self,
{"filters": len(filters), "limit": limit, "version": self.version},
{"limit": limit, "version": self.api_version, "keys": list(filters.keys())},
)
with concurrent.futures.ThreadPoolExecutor() as executor:
future_memories = executor.submit(self._search_vector_store, query, filters, limit)
future_graph_entities = (
executor.submit(self.graph.search, query, filters, limit)
if self.version == "v1.1" and self.enable_graph
if self.api_version == "v1.1" and self.enable_graph
else None
)
@@ -416,7 +416,7 @@ class Memory(MemoryBase):
original_memories = future_memories.result()
graph_entities = future_graph_entities.result() if future_graph_entities else None
if self.version == "v1.1":
if self.api_version == "v1.1":
if self.enable_graph:
return {"results": original_memories, "relations": graph_entities}
else:
@@ -518,14 +518,14 @@ class Memory(MemoryBase):
"At least one filter is required to delete all memories. If you want to delete all memories, use the `reset()` method."
)
capture_event("mem0.delete_all", self, {"filters": len(filters)})
capture_event("mem0.delete_all", self, {"keys": list(filters.keys())})
memories = self.vector_store.list(filters=filters)[0]
for memory in memories:
self._delete_memory(memory.id)
logger.info(f"Deleted {len(memories)} memories")
if self.version == "v1.1" and self.enable_graph:
if self.api_version == "v1.1" and self.enable_graph:
self.graph.delete_all(filters)
return {"message": "Memories deleted successfully!"}
@@ -561,6 +561,7 @@ class Memory(MemoryBase):
payloads=[metadata],
)
self.db.add_history(memory_id, None, data, "ADD", created_at=metadata["created_at"])
capture_event("mem0._create_memory", self, {"memory_id": memory_id})
return memory_id
def _update_memory(self, memory_id, data, existing_embeddings, metadata=None):
@@ -603,6 +604,7 @@ class Memory(MemoryBase):
created_at=new_metadata["created_at"],
updated_at=new_metadata["updated_at"],
)
capture_event("mem0._update_memory", self, {"memory_id": memory_id})
return memory_id
def _delete_memory(self, memory_id):
@@ -611,6 +613,7 @@ class Memory(MemoryBase):
prev_value = existing_memory.payload["data"]
self.vector_store.delete(vector_id=memory_id)
self.db.add_history(memory_id, prev_value, None, "DELETE", is_deleted=1)
capture_event("mem0._delete_memory", self, {"memory_id": memory_id})
return memory_id
def reset(self):