Formatting (#2750)

This commit is contained in:
Dev Khant
2025-05-22 01:17:29 +05:30
committed by GitHub
parent dff91154a7
commit d85fcda037
71 changed files with 1391 additions and 1823 deletions

View File

@@ -5,16 +5,12 @@ from mem0.memory.utils import format_entities
try:
from langchain_memgraph import Memgraph
except ImportError:
raise ImportError(
"langchain_memgraph is not installed. Please install it using pip install langchain-memgraph"
)
raise ImportError("langchain_memgraph is not installed. Please install it using pip install langchain-memgraph")
try:
from rank_bm25 import BM25Okapi
except ImportError:
raise ImportError(
"rank_bm25 is not installed. Please install it using pip install rank-bm25"
)
raise ImportError("rank_bm25 is not installed. Please install it using pip install rank-bm25")
from mem0.graphs.tools import (
DELETE_MEMORY_STRUCT_TOOL_GRAPH,
@@ -74,22 +70,14 @@ class MemoryGraph:
filters (dict): A dictionary containing filters to be applied during the addition.
"""
entity_type_map = self._retrieve_nodes_from_data(data, filters)
to_be_added = self._establish_nodes_relations_from_data(
data, filters, entity_type_map
)
search_output = self._search_graph_db(
node_list=list(entity_type_map.keys()), filters=filters
)
to_be_deleted = self._get_delete_entities_from_search_output(
search_output, data, filters
)
to_be_added = self._establish_nodes_relations_from_data(data, filters, entity_type_map)
search_output = self._search_graph_db(node_list=list(entity_type_map.keys()), filters=filters)
to_be_deleted = self._get_delete_entities_from_search_output(search_output, data, filters)
# TODO: Batch queries with APOC plugin
# TODO: Add more filter support
deleted_entities = self._delete_entities(to_be_deleted, filters["user_id"])
added_entities = self._add_entities(
to_be_added, filters["user_id"], entity_type_map
)
added_entities = self._add_entities(to_be_added, filters["user_id"], entity_type_map)
return {"deleted_entities": deleted_entities, "added_entities": added_entities}
@@ -108,16 +96,13 @@ class MemoryGraph:
- "entities": List of related graph data based on the query.
"""
entity_type_map = self._retrieve_nodes_from_data(query, filters)
search_output = self._search_graph_db(
node_list=list(entity_type_map.keys()), filters=filters
)
search_output = self._search_graph_db(node_list=list(entity_type_map.keys()), filters=filters)
if not search_output:
return []
search_outputs_sequence = [
[item["source"], item["relationship"], item["destination"]]
for item in search_output
[item["source"], item["relationship"], item["destination"]] for item in search_output
]
bm25 = BM25Okapi(search_outputs_sequence)
@@ -126,9 +111,7 @@ class MemoryGraph:
search_results = []
for item in reranked_results:
search_results.append(
{"source": item[0], "relationship": item[1], "destination": item[2]}
)
search_results.append({"source": item[0], "relationship": item[1], "destination": item[2]})
logger.info(f"Returned {len(search_results)} search results")
@@ -161,9 +144,7 @@ class MemoryGraph:
RETURN n.name AS source, type(r) AS relationship, m.name AS target
LIMIT $limit
"""
results = self.graph.query(
query, params={"user_id": filters["user_id"], "limit": limit}
)
results = self.graph.query(query, params={"user_id": filters["user_id"], "limit": limit})
final_results = []
for result in results:
@@ -208,13 +189,8 @@ class MemoryGraph:
f"Error in search tool: {e}, llm_provider={self.llm_provider}, search_results={search_results}"
)
entity_type_map = {
k.lower().replace(" ", "_"): v.lower().replace(" ", "_")
for k, v in entity_type_map.items()
}
logger.debug(
f"Entity type map: {entity_type_map}\n search_results={search_results}"
)
entity_type_map = {k.lower().replace(" ", "_"): v.lower().replace(" ", "_") for k, v in entity_type_map.items()}
logger.debug(f"Entity type map: {entity_type_map}\n search_results={search_results}")
return entity_type_map
def _establish_nodes_relations_from_data(self, data, filters, entity_type_map):
@@ -223,9 +199,7 @@ class MemoryGraph:
messages = [
{
"role": "system",
"content": EXTRACT_RELATIONS_PROMPT.replace(
"USER_ID", filters["user_id"]
).replace(
"content": EXTRACT_RELATIONS_PROMPT.replace("USER_ID", filters["user_id"]).replace(
"CUSTOM_PROMPT", f"4. {self.config.graph_store.custom_prompt}"
),
},
@@ -235,9 +209,7 @@ class MemoryGraph:
messages = [
{
"role": "system",
"content": EXTRACT_RELATIONS_PROMPT.replace(
"USER_ID", filters["user_id"]
),
"content": EXTRACT_RELATIONS_PROMPT.replace("USER_ID", filters["user_id"]),
},
{
"role": "user",
@@ -304,9 +276,7 @@ class MemoryGraph:
def _get_delete_entities_from_search_output(self, search_output, data, filters):
"""Get the entities to be deleted from the search output."""
search_output_string = format_entities(search_output)
system_prompt, user_prompt = get_delete_messages(
search_output_string, data, filters["user_id"]
)
system_prompt, user_prompt = get_delete_messages(search_output_string, data, filters["user_id"])
_tools = [DELETE_MEMORY_TOOL_GRAPH]
if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
@@ -379,12 +349,8 @@ class MemoryGraph:
# search for the nodes with the closest embeddings; this is basically
# comparison of one embedding to all embeddings in a graph -> vector
# search with cosine similarity metric
source_node_search_result = self._search_source_node(
source_embedding, user_id, threshold=0.9
)
destination_node_search_result = self._search_destination_node(
dest_embedding, user_id, threshold=0.9
)
source_node_search_result = self._search_source_node(source_embedding, user_id, threshold=0.9)
destination_node_search_result = self._search_destination_node(dest_embedding, user_id, threshold=0.9)
# TODO: Create a cypher query and common params for all the cases
if not destination_node_search_result and source_node_search_result:
@@ -424,9 +390,7 @@ class MemoryGraph:
"""
params = {
"destination_id": destination_node_search_result[0][
"id(destination_candidate)"
],
"destination_id": destination_node_search_result[0]["id(destination_candidate)"],
"source_name": source,
"source_embedding": source_embedding,
"user_id": user_id,
@@ -445,9 +409,7 @@ class MemoryGraph:
"""
params = {
"source_id": source_node_search_result[0]["id(source_candidate)"],
"destination_id": destination_node_search_result[0][
"id(destination_candidate)"
],
"destination_id": destination_node_search_result[0]["id(destination_candidate)"],
"user_id": user_id,
}
else: