Formatting (#2750)
This commit is contained in:
@@ -5,16 +5,12 @@ from mem0.memory.utils import format_entities
|
||||
try:
|
||||
from langchain_memgraph import Memgraph
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"langchain_memgraph is not installed. Please install it using pip install langchain-memgraph"
|
||||
)
|
||||
raise ImportError("langchain_memgraph is not installed. Please install it using pip install langchain-memgraph")
|
||||
|
||||
try:
|
||||
from rank_bm25 import BM25Okapi
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"rank_bm25 is not installed. Please install it using pip install rank-bm25"
|
||||
)
|
||||
raise ImportError("rank_bm25 is not installed. Please install it using pip install rank-bm25")
|
||||
|
||||
from mem0.graphs.tools import (
|
||||
DELETE_MEMORY_STRUCT_TOOL_GRAPH,
|
||||
@@ -74,22 +70,14 @@ class MemoryGraph:
|
||||
filters (dict): A dictionary containing filters to be applied during the addition.
|
||||
"""
|
||||
entity_type_map = self._retrieve_nodes_from_data(data, filters)
|
||||
to_be_added = self._establish_nodes_relations_from_data(
|
||||
data, filters, entity_type_map
|
||||
)
|
||||
search_output = self._search_graph_db(
|
||||
node_list=list(entity_type_map.keys()), filters=filters
|
||||
)
|
||||
to_be_deleted = self._get_delete_entities_from_search_output(
|
||||
search_output, data, filters
|
||||
)
|
||||
to_be_added = self._establish_nodes_relations_from_data(data, filters, entity_type_map)
|
||||
search_output = self._search_graph_db(node_list=list(entity_type_map.keys()), filters=filters)
|
||||
to_be_deleted = self._get_delete_entities_from_search_output(search_output, data, filters)
|
||||
|
||||
# TODO: Batch queries with APOC plugin
|
||||
# TODO: Add more filter support
|
||||
deleted_entities = self._delete_entities(to_be_deleted, filters["user_id"])
|
||||
added_entities = self._add_entities(
|
||||
to_be_added, filters["user_id"], entity_type_map
|
||||
)
|
||||
added_entities = self._add_entities(to_be_added, filters["user_id"], entity_type_map)
|
||||
|
||||
return {"deleted_entities": deleted_entities, "added_entities": added_entities}
|
||||
|
||||
@@ -108,16 +96,13 @@ class MemoryGraph:
|
||||
- "entities": List of related graph data based on the query.
|
||||
"""
|
||||
entity_type_map = self._retrieve_nodes_from_data(query, filters)
|
||||
search_output = self._search_graph_db(
|
||||
node_list=list(entity_type_map.keys()), filters=filters
|
||||
)
|
||||
search_output = self._search_graph_db(node_list=list(entity_type_map.keys()), filters=filters)
|
||||
|
||||
if not search_output:
|
||||
return []
|
||||
|
||||
search_outputs_sequence = [
|
||||
[item["source"], item["relationship"], item["destination"]]
|
||||
for item in search_output
|
||||
[item["source"], item["relationship"], item["destination"]] for item in search_output
|
||||
]
|
||||
bm25 = BM25Okapi(search_outputs_sequence)
|
||||
|
||||
@@ -126,9 +111,7 @@ class MemoryGraph:
|
||||
|
||||
search_results = []
|
||||
for item in reranked_results:
|
||||
search_results.append(
|
||||
{"source": item[0], "relationship": item[1], "destination": item[2]}
|
||||
)
|
||||
search_results.append({"source": item[0], "relationship": item[1], "destination": item[2]})
|
||||
|
||||
logger.info(f"Returned {len(search_results)} search results")
|
||||
|
||||
@@ -161,9 +144,7 @@ class MemoryGraph:
|
||||
RETURN n.name AS source, type(r) AS relationship, m.name AS target
|
||||
LIMIT $limit
|
||||
"""
|
||||
results = self.graph.query(
|
||||
query, params={"user_id": filters["user_id"], "limit": limit}
|
||||
)
|
||||
results = self.graph.query(query, params={"user_id": filters["user_id"], "limit": limit})
|
||||
|
||||
final_results = []
|
||||
for result in results:
|
||||
@@ -208,13 +189,8 @@ class MemoryGraph:
|
||||
f"Error in search tool: {e}, llm_provider={self.llm_provider}, search_results={search_results}"
|
||||
)
|
||||
|
||||
entity_type_map = {
|
||||
k.lower().replace(" ", "_"): v.lower().replace(" ", "_")
|
||||
for k, v in entity_type_map.items()
|
||||
}
|
||||
logger.debug(
|
||||
f"Entity type map: {entity_type_map}\n search_results={search_results}"
|
||||
)
|
||||
entity_type_map = {k.lower().replace(" ", "_"): v.lower().replace(" ", "_") for k, v in entity_type_map.items()}
|
||||
logger.debug(f"Entity type map: {entity_type_map}\n search_results={search_results}")
|
||||
return entity_type_map
|
||||
|
||||
def _establish_nodes_relations_from_data(self, data, filters, entity_type_map):
|
||||
@@ -223,9 +199,7 @@ class MemoryGraph:
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": EXTRACT_RELATIONS_PROMPT.replace(
|
||||
"USER_ID", filters["user_id"]
|
||||
).replace(
|
||||
"content": EXTRACT_RELATIONS_PROMPT.replace("USER_ID", filters["user_id"]).replace(
|
||||
"CUSTOM_PROMPT", f"4. {self.config.graph_store.custom_prompt}"
|
||||
),
|
||||
},
|
||||
@@ -235,9 +209,7 @@ class MemoryGraph:
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": EXTRACT_RELATIONS_PROMPT.replace(
|
||||
"USER_ID", filters["user_id"]
|
||||
),
|
||||
"content": EXTRACT_RELATIONS_PROMPT.replace("USER_ID", filters["user_id"]),
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
@@ -304,9 +276,7 @@ class MemoryGraph:
|
||||
def _get_delete_entities_from_search_output(self, search_output, data, filters):
|
||||
"""Get the entities to be deleted from the search output."""
|
||||
search_output_string = format_entities(search_output)
|
||||
system_prompt, user_prompt = get_delete_messages(
|
||||
search_output_string, data, filters["user_id"]
|
||||
)
|
||||
system_prompt, user_prompt = get_delete_messages(search_output_string, data, filters["user_id"])
|
||||
|
||||
_tools = [DELETE_MEMORY_TOOL_GRAPH]
|
||||
if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
|
||||
@@ -379,12 +349,8 @@ class MemoryGraph:
|
||||
# search for the nodes with the closest embeddings; this is basically
|
||||
# comparison of one embedding to all embeddings in a graph -> vector
|
||||
# search with cosine similarity metric
|
||||
source_node_search_result = self._search_source_node(
|
||||
source_embedding, user_id, threshold=0.9
|
||||
)
|
||||
destination_node_search_result = self._search_destination_node(
|
||||
dest_embedding, user_id, threshold=0.9
|
||||
)
|
||||
source_node_search_result = self._search_source_node(source_embedding, user_id, threshold=0.9)
|
||||
destination_node_search_result = self._search_destination_node(dest_embedding, user_id, threshold=0.9)
|
||||
|
||||
# TODO: Create a cypher query and common params for all the cases
|
||||
if not destination_node_search_result and source_node_search_result:
|
||||
@@ -424,9 +390,7 @@ class MemoryGraph:
|
||||
"""
|
||||
|
||||
params = {
|
||||
"destination_id": destination_node_search_result[0][
|
||||
"id(destination_candidate)"
|
||||
],
|
||||
"destination_id": destination_node_search_result[0]["id(destination_candidate)"],
|
||||
"source_name": source,
|
||||
"source_embedding": source_embedding,
|
||||
"user_id": user_id,
|
||||
@@ -445,9 +409,7 @@ class MemoryGraph:
|
||||
"""
|
||||
params = {
|
||||
"source_id": source_node_search_result[0]["id(source_candidate)"],
|
||||
"destination_id": destination_node_search_result[0][
|
||||
"id(destination_candidate)"
|
||||
],
|
||||
"destination_id": destination_node_search_result[0]["id(destination_candidate)"],
|
||||
"user_id": user_id,
|
||||
}
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user