Formatting (#2526)

This commit is contained in:
Dev Khant
2025-04-10 11:42:25 +05:30
committed by GitHub
parent 616313b8b5
commit 07462adc9a
2 changed files with 79 additions and 85 deletions

View File

@@ -165,7 +165,7 @@ class MemoryGraph:
try: try:
for tool_call in search_results["tool_calls"]: for tool_call in search_results["tool_calls"]:
if tool_call['name'] != "extract_entities": if tool_call["name"] != "extract_entities":
continue continue
for item in tool_call["arguments"]["entities"]: for item in tool_call["arguments"]["entities"]:
entity_type_map[item["entity"]] = item["entity_type"] entity_type_map[item["entity"]] = item["entity_type"]

View File

@@ -783,7 +783,7 @@ class AsyncMemory(MemoryBase):
self.graph = MemoryGraph(self.config) self.graph = MemoryGraph(self.config)
self.enable_graph = True self.enable_graph = True
capture_event("mem0.init", self) capture_event("async_mem0.init", self)
@classmethod @classmethod
async def from_config(cls, config_dict: Dict[str, Any]): async def from_config(cls, config_dict: Dict[str, Any]):
@@ -1004,24 +1004,24 @@ class AsyncMemory(MemoryBase):
logging.info("Skipping memory entry because of empty `text` field.") logging.info("Skipping memory entry because of empty `text` field.")
continue continue
elif resp.get("event") == "ADD": elif resp.get("event") == "ADD":
task = asyncio.create_task(self._create_memory( task = asyncio.create_task(
data=resp.get("text"), self._create_memory(
existing_embeddings=new_message_embeddings, data=resp.get("text"), existing_embeddings=new_message_embeddings, metadata=metadata
metadata=metadata )
)) )
memory_tasks.append((task, resp, "ADD", None)) memory_tasks.append((task, resp, "ADD", None))
elif resp.get("event") == "UPDATE": elif resp.get("event") == "UPDATE":
task = asyncio.create_task(self._update_memory( task = asyncio.create_task(
self._update_memory(
memory_id=temp_uuid_mapping[resp["id"]], memory_id=temp_uuid_mapping[resp["id"]],
data=resp.get("text"), data=resp.get("text"),
existing_embeddings=new_message_embeddings, existing_embeddings=new_message_embeddings,
metadata=metadata, metadata=metadata,
)) )
)
memory_tasks.append((task, resp, "UPDATE", temp_uuid_mapping[resp["id"]])) memory_tasks.append((task, resp, "UPDATE", temp_uuid_mapping[resp["id"]]))
elif resp.get("event") == "DELETE": elif resp.get("event") == "DELETE":
task = asyncio.create_task(self._delete_memory( task = asyncio.create_task(self._delete_memory(memory_id=temp_uuid_mapping[resp.get("id")]))
memory_id=temp_uuid_mapping[resp.get("id")]
))
memory_tasks.append((task, resp, "DELETE", temp_uuid_mapping[resp["id"]])) memory_tasks.append((task, resp, "DELETE", temp_uuid_mapping[resp["id"]]))
elif resp.get("event") == "NONE": elif resp.get("event") == "NONE":
logging.info("NOOP for Memory.") logging.info("NOOP for Memory.")
@@ -1033,31 +1033,37 @@ class AsyncMemory(MemoryBase):
try: try:
result_id = await task result_id = await task
if event_type == "ADD": if event_type == "ADD":
returned_memories.append({ returned_memories.append(
{
"id": result_id, "id": result_id,
"memory": resp.get("text"), "memory": resp.get("text"),
"event": resp.get("event"), "event": resp.get("event"),
}) }
)
elif event_type == "UPDATE": elif event_type == "UPDATE":
returned_memories.append({ returned_memories.append(
{
"id": mem_id, "id": mem_id,
"memory": resp.get("text"), "memory": resp.get("text"),
"event": resp.get("event"), "event": resp.get("event"),
"previous_memory": resp.get("old_memory"), "previous_memory": resp.get("old_memory"),
}) }
)
elif event_type == "DELETE": elif event_type == "DELETE":
returned_memories.append({ returned_memories.append(
{
"id": mem_id, "id": mem_id,
"memory": resp.get("text"), "memory": resp.get("text"),
"event": resp.get("event"), "event": resp.get("event"),
}) }
)
except Exception as e: except Exception as e:
logging.error(f"Error processing memory task: {e}") logging.error(f"Error processing memory task: {e}")
except Exception as e: except Exception as e:
logging.error(f"Error in new_memories_with_actions: {e}") logging.error(f"Error in new_memories_with_actions: {e}")
capture_event("mem0.add", self, {"version": self.api_version, "keys": list(filters.keys())}) capture_event("async_mem0.add", self, {"version": self.api_version, "keys": list(filters.keys())})
return returned_memories return returned_memories
@@ -1082,7 +1088,7 @@ class AsyncMemory(MemoryBase):
Returns: Returns:
dict: Retrieved memory. dict: Retrieved memory.
""" """
capture_event("mem0.get", self, {"memory_id": memory_id}) capture_event("async_mem0.get", self, {"memory_id": memory_id})
memory = await asyncio.to_thread(self.vector_store.get, vector_id=memory_id) memory = await asyncio.to_thread(self.vector_store.get, vector_id=memory_id)
if not memory: if not memory:
return None return None
@@ -1123,7 +1129,7 @@ class AsyncMemory(MemoryBase):
if run_id: if run_id:
filters["run_id"] = run_id filters["run_id"] = run_id
capture_event("mem0.get_all", self, {"limit": limit, "keys": list(filters.keys())}) capture_event("async_mem0.get_all", self, {"limit": limit, "keys": list(filters.keys())})
# Run vector store and graph operations concurrently # Run vector store and graph operations concurrently
vector_store_task = asyncio.create_task(self._get_all_from_vector_store(filters, limit)) vector_store_task = asyncio.create_task(self._get_all_from_vector_store(filters, limit))
@@ -1210,7 +1216,7 @@ class AsyncMemory(MemoryBase):
raise ValueError("One of the filters: user_id, agent_id or run_id is required!") raise ValueError("One of the filters: user_id, agent_id or run_id is required!")
capture_event( capture_event(
"mem0.search", "async_mem0.search",
self, self,
{"limit": limit, "version": self.api_version, "keys": list(filters.keys())}, {"limit": limit, "version": self.api_version, "keys": list(filters.keys())},
) )
@@ -1243,11 +1249,7 @@ class AsyncMemory(MemoryBase):
async def _search_vector_store(self, query, filters, limit): async def _search_vector_store(self, query, filters, limit):
embeddings = await asyncio.to_thread(self.embedding_model.embed, query, "search") embeddings = await asyncio.to_thread(self.embedding_model.embed, query, "search")
memories = await asyncio.to_thread( memories = await asyncio.to_thread(
self.vector_store.search, self.vector_store.search, query=query, vectors=embeddings, limit=limit, filters=filters
query=query,
vectors=embeddings,
limit=limit,
filters=filters
) )
excluded_keys = { excluded_keys = {
@@ -1294,7 +1296,7 @@ class AsyncMemory(MemoryBase):
Returns: Returns:
dict: Updated memory. dict: Updated memory.
""" """
capture_event("mem0.update", self, {"memory_id": memory_id}) capture_event("async_mem0.update", self, {"memory_id": memory_id})
embeddings = await asyncio.to_thread(self.embedding_model.embed, data, "update") embeddings = await asyncio.to_thread(self.embedding_model.embed, data, "update")
existing_embeddings = {data: embeddings} existing_embeddings = {data: embeddings}
@@ -1309,7 +1311,7 @@ class AsyncMemory(MemoryBase):
Args: Args:
memory_id (str): ID of the memory to delete. memory_id (str): ID of the memory to delete.
""" """
capture_event("mem0.delete", self, {"memory_id": memory_id}) capture_event("async_mem0.delete", self, {"memory_id": memory_id})
await self._delete_memory(memory_id) await self._delete_memory(memory_id)
return {"message": "Memory deleted successfully!"} return {"message": "Memory deleted successfully!"}
@@ -1335,7 +1337,7 @@ class AsyncMemory(MemoryBase):
"At least one filter is required to delete all memories. If you want to delete all memories, use the `reset()` method." "At least one filter is required to delete all memories. If you want to delete all memories, use the `reset()` method."
) )
capture_event("mem0.delete_all", self, {"keys": list(filters.keys())}) capture_event("async_mem0.delete_all", self, {"keys": list(filters.keys())})
memories = await asyncio.to_thread(self.vector_store.list, filters=filters) memories = await asyncio.to_thread(self.vector_store.list, filters=filters)
delete_tasks = [] delete_tasks = []
@@ -1361,7 +1363,7 @@ class AsyncMemory(MemoryBase):
Returns: Returns:
list: List of changes for the memory. list: List of changes for the memory.
""" """
capture_event("mem0.history", self, {"memory_id": memory_id}) capture_event("async_mem0.history", self, {"memory_id": memory_id})
return await asyncio.to_thread(self.db.get_history, memory_id) return await asyncio.to_thread(self.db.get_history, memory_id)
async def _create_memory(self, data, existing_embeddings, metadata=None): async def _create_memory(self, data, existing_embeddings, metadata=None):
@@ -1384,16 +1386,9 @@ class AsyncMemory(MemoryBase):
payloads=[metadata], payloads=[metadata],
) )
await asyncio.to_thread( await asyncio.to_thread(self.db.add_history, memory_id, None, data, "ADD", created_at=metadata["created_at"])
self.db.add_history,
memory_id,
None,
data,
"ADD",
created_at=metadata["created_at"]
)
capture_event("mem0._create_memory", self, {"memory_id": memory_id}) capture_event("async_mem0._create_memory", self, {"memory_id": memory_id})
return memory_id return memory_id
async def _create_procedural_memory(self, messages, metadata=None, llm=None, prompt=None): async def _create_procedural_memory(self, messages, metadata=None, llm=None, prompt=None):
@@ -1411,7 +1406,9 @@ class AsyncMemory(MemoryBase):
convert_to_messages, # type: ignore convert_to_messages, # type: ignore
) )
except Exception: except Exception:
logger.error("Import error while loading langchain-core. Please install 'langchain-core' to use procedural memory.") logger.error(
"Import error while loading langchain-core. Please install 'langchain-core' to use procedural memory."
)
raise raise
logger.info("Creating procedural memory") logger.info("Creating procedural memory")
@@ -1428,10 +1425,7 @@ class AsyncMemory(MemoryBase):
response = await asyncio.to_thread(llm.invoke, input=parsed_messages) response = await asyncio.to_thread(llm.invoke, input=parsed_messages)
procedural_memory = response.content procedural_memory = response.content
else: else:
procedural_memory = await asyncio.to_thread( procedural_memory = await asyncio.to_thread(self.llm.generate_response, messages=parsed_messages)
self.llm.generate_response,
messages=parsed_messages
)
except Exception as e: except Exception as e:
logger.error(f"Error generating procedural memory summary: {e}") logger.error(f"Error generating procedural memory summary: {e}")
raise raise
@@ -1444,7 +1438,7 @@ class AsyncMemory(MemoryBase):
embeddings = await asyncio.to_thread(self.embedding_model.embed, procedural_memory, memory_action="add") embeddings = await asyncio.to_thread(self.embedding_model.embed, procedural_memory, memory_action="add")
# Create the memory # Create the memory
memory_id = await self._create_memory(procedural_memory, {procedural_memory: embeddings}, metadata=metadata) memory_id = await self._create_memory(procedural_memory, {procedural_memory: embeddings}, metadata=metadata)
capture_event("mem0._create_procedural_memory", self, {"memory_id": memory_id}) capture_event("async_mem0._create_procedural_memory", self, {"memory_id": memory_id})
# Return results in the same format as add() # Return results in the same format as add()
result = {"results": [{"id": memory_id, "memory": procedural_memory, "event": "ADD"}]} result = {"results": [{"id": memory_id, "memory": procedural_memory, "event": "ADD"}]}
@@ -1498,7 +1492,7 @@ class AsyncMemory(MemoryBase):
updated_at=new_metadata["updated_at"], updated_at=new_metadata["updated_at"],
) )
capture_event("mem0._update_memory", self, {"memory_id": memory_id}) capture_event("async_mem0._update_memory", self, {"memory_id": memory_id})
return memory_id return memory_id
async def _delete_memory(self, memory_id): async def _delete_memory(self, memory_id):
@@ -1509,7 +1503,7 @@ class AsyncMemory(MemoryBase):
await asyncio.to_thread(self.vector_store.delete, vector_id=memory_id) await asyncio.to_thread(self.vector_store.delete, vector_id=memory_id)
await asyncio.to_thread(self.db.add_history, memory_id, prev_value, None, "DELETE", is_deleted=1) await asyncio.to_thread(self.db.add_history, memory_id, prev_value, None, "DELETE", is_deleted=1)
capture_event("mem0._delete_memory", self, {"memory_id": memory_id}) capture_event("async_mem0._delete_memory", self, {"memory_id": memory_id})
return memory_id return memory_id
async def reset(self): async def reset(self):
@@ -1522,7 +1516,7 @@ class AsyncMemory(MemoryBase):
self.config.vector_store.provider, self.config.vector_store.config self.config.vector_store.provider, self.config.vector_store.config
) )
await asyncio.to_thread(self.db.reset) await asyncio.to_thread(self.db.reset)
capture_event("mem0.reset", self) capture_event("async_mem0.reset", self)
async def chat(self, query): async def chat(self, query):
raise NotImplementedError("Chat function not implemented yet.") raise NotImplementedError("Chat function not implemented yet.")