Code formatting and doc update (#2130)
This commit is contained in:
@@ -58,12 +58,12 @@ class MemoryGraph:
|
||||
to_be_added = self._establish_nodes_relations_from_data(data, filters, entity_type_map)
|
||||
search_output = self._search_graph_db(node_list=list(entity_type_map.keys()), filters=filters)
|
||||
to_be_deleted = self._get_delete_entities_from_search_output(search_output, data, filters)
|
||||
|
||||
#TODO: Batch queries with APOC plugin
|
||||
#TODO: Add more filter support
|
||||
|
||||
# TODO: Batch queries with APOC plugin
|
||||
# TODO: Add more filter support
|
||||
deleted_entities = self._delete_entities(to_be_deleted, filters["user_id"])
|
||||
added_entities = self._add_entities(to_be_added, filters["user_id"], entity_type_map)
|
||||
|
||||
|
||||
return {"deleted_entities": deleted_entities, "added_entities": added_entities}
|
||||
|
||||
def search(self, query, filters, limit=100):
|
||||
@@ -86,7 +86,9 @@ class MemoryGraph:
|
||||
if not search_output:
|
||||
return []
|
||||
|
||||
search_outputs_sequence = [[item["source"], item["relatationship"], item["destination"]] for item in search_output]
|
||||
search_outputs_sequence = [
|
||||
[item["source"], item["relatationship"], item["destination"]] for item in search_output
|
||||
]
|
||||
bm25 = BM25Okapi(search_outputs_sequence)
|
||||
|
||||
tokenized_query = query.split(" ")
|
||||
@@ -142,7 +144,7 @@ class MemoryGraph:
|
||||
logger.info(f"Retrieved {len(final_results)} relationships")
|
||||
|
||||
return final_results
|
||||
|
||||
|
||||
def _retrieve_nodes_from_data(self, data, filters):
|
||||
"""Extracts all the entities mentioned in the query."""
|
||||
_tools = [EXTRACT_ENTITIES_TOOL]
|
||||
@@ -170,7 +172,7 @@ class MemoryGraph:
|
||||
entity_type_map = {k.lower().replace(" ", "_"): v.lower().replace(" ", "_") for k, v in entity_type_map.items()}
|
||||
logger.debug(f"Entity type map: {entity_type_map}")
|
||||
return entity_type_map
|
||||
|
||||
|
||||
def _establish_nodes_relations_from_data(self, data, filters, entity_type_map):
|
||||
"""Eshtablish relations among the extracted nodes."""
|
||||
if self.config.graph_store.custom_prompt:
|
||||
@@ -209,7 +211,7 @@ class MemoryGraph:
|
||||
extracted_entities = self._remove_spaces_from_entities(extracted_entities)
|
||||
logger.debug(f"Extracted entities: {extracted_entities}")
|
||||
return extracted_entities
|
||||
|
||||
|
||||
def _search_graph_db(self, node_list, filters, limit=100):
|
||||
"""Search similar nodes among and their respective incoming and outgoing relations."""
|
||||
result_relations = []
|
||||
@@ -250,7 +252,7 @@ class MemoryGraph:
|
||||
result_relations.extend(ans)
|
||||
|
||||
return result_relations
|
||||
|
||||
|
||||
def _get_delete_entities_from_search_output(self, search_output, data, filters):
|
||||
"""Get the entities to be deleted from the search output."""
|
||||
search_output_string = format_entities(search_output)
|
||||
@@ -273,11 +275,11 @@ class MemoryGraph:
|
||||
for item in memory_updates["tool_calls"]:
|
||||
if item["name"] == "delete_graph_memory":
|
||||
to_be_deleted.append(item["arguments"])
|
||||
#in case if it is not in the correct format
|
||||
# in case if it is not in the correct format
|
||||
to_be_deleted = self._remove_spaces_from_entities(to_be_deleted)
|
||||
logger.debug(f"Deleted relationships: {to_be_deleted}")
|
||||
return to_be_deleted
|
||||
|
||||
|
||||
def _delete_entities(self, to_be_deleted, user_id):
|
||||
"""Delete the entities from the graph."""
|
||||
results = []
|
||||
@@ -285,7 +287,7 @@ class MemoryGraph:
|
||||
source = item["source"]
|
||||
destination = item["destination"]
|
||||
relatationship = item["relationship"]
|
||||
|
||||
|
||||
# Delete the specific relationship between nodes
|
||||
cypher = f"""
|
||||
MATCH (n {{name: $source_name, user_id: $user_id}})
|
||||
@@ -305,29 +307,29 @@ class MemoryGraph:
|
||||
result = self.graph.query(cypher, params=params)
|
||||
results.append(result)
|
||||
return results
|
||||
|
||||
|
||||
def _add_entities(self, to_be_added, user_id, entity_type_map):
|
||||
"""Add the new entities to the graph. Merge the nodes if they already exist."""
|
||||
results = []
|
||||
for item in to_be_added:
|
||||
#entities
|
||||
# entities
|
||||
source = item["source"]
|
||||
destination = item["destination"]
|
||||
relationship = item["relationship"]
|
||||
|
||||
#types
|
||||
# types
|
||||
source_type = entity_type_map.get(source, "unknown")
|
||||
destination_type = entity_type_map.get(destination, "unknown")
|
||||
|
||||
#embeddings
|
||||
|
||||
# embeddings
|
||||
source_embedding = self.embedding_model.embed(source)
|
||||
dest_embedding = self.embedding_model.embed(destination)
|
||||
|
||||
#search for the nodes with the closest embeddings
|
||||
|
||||
# search for the nodes with the closest embeddings
|
||||
source_node_search_result = self._search_source_node(source_embedding, user_id, threshold=0.9)
|
||||
destination_node_search_result = self._search_destination_node(dest_embedding, user_id, threshold=0.9)
|
||||
|
||||
#TODO: Create a cypher query and common params for all the cases
|
||||
# TODO: Create a cypher query and common params for all the cases
|
||||
if not destination_node_search_result and source_node_search_result:
|
||||
cypher = f"""
|
||||
MATCH (source)
|
||||
@@ -343,7 +345,7 @@ class MemoryGraph:
|
||||
"""
|
||||
|
||||
params = {
|
||||
"source_id": source_node_search_result[0]['elementId(source_candidate)'],
|
||||
"source_id": source_node_search_result[0]["elementId(source_candidate)"],
|
||||
"destination_name": destination,
|
||||
"relationship": relationship,
|
||||
"destination_type": destination_type,
|
||||
@@ -366,9 +368,9 @@ class MemoryGraph:
|
||||
r.created = timestamp()
|
||||
RETURN source.name AS source, type(r) AS relationship, destination.name AS target
|
||||
"""
|
||||
|
||||
|
||||
params = {
|
||||
"destination_id": destination_node_search_result[0]['elementId(destination_candidate)'],
|
||||
"destination_id": destination_node_search_result[0]["elementId(destination_candidate)"],
|
||||
"source_name": source,
|
||||
"relationship": relationship,
|
||||
"source_type": source_type,
|
||||
@@ -377,7 +379,7 @@ class MemoryGraph:
|
||||
}
|
||||
resp = self.graph.query(cypher, params=params)
|
||||
results.append(resp)
|
||||
|
||||
|
||||
elif source_node_search_result and destination_node_search_result:
|
||||
cypher = f"""
|
||||
MATCH (source)
|
||||
@@ -391,8 +393,8 @@ class MemoryGraph:
|
||||
RETURN source.name AS source, type(r) AS relationship, destination.name AS target
|
||||
"""
|
||||
params = {
|
||||
"source_id": source_node_search_result[0]['elementId(source_candidate)'],
|
||||
"destination_id": destination_node_search_result[0]['elementId(destination_candidate)'],
|
||||
"source_id": source_node_search_result[0]["elementId(source_candidate)"],
|
||||
"destination_id": destination_node_search_result[0]["elementId(destination_candidate)"],
|
||||
"user_id": user_id,
|
||||
"relationship": relationship,
|
||||
}
|
||||
@@ -432,7 +434,7 @@ class MemoryGraph:
|
||||
return entity_list
|
||||
|
||||
def _search_source_node(self, source_embedding, user_id, threshold=0.9):
|
||||
cypher = f"""
|
||||
cypher = """
|
||||
MATCH (source_candidate)
|
||||
WHERE source_candidate.embedding IS NOT NULL
|
||||
AND source_candidate.user_id = $user_id
|
||||
@@ -454,7 +456,7 @@ class MemoryGraph:
|
||||
|
||||
RETURN elementId(source_candidate)
|
||||
"""
|
||||
|
||||
|
||||
params = {
|
||||
"source_embedding": source_embedding,
|
||||
"user_id": user_id,
|
||||
@@ -465,7 +467,7 @@ class MemoryGraph:
|
||||
return result
|
||||
|
||||
def _search_destination_node(self, destination_embedding, user_id, threshold=0.9):
|
||||
cypher = f"""
|
||||
cypher = """
|
||||
MATCH (destination_candidate)
|
||||
WHERE destination_candidate.embedding IS NOT NULL
|
||||
AND destination_candidate.user_id = $user_id
|
||||
@@ -494,4 +496,4 @@ class MemoryGraph:
|
||||
}
|
||||
|
||||
result = self.graph.query(cypher, params=params)
|
||||
return result
|
||||
return result
|
||||
|
||||
@@ -249,7 +249,7 @@ class Memory(MemoryBase):
|
||||
if self.api_version == "v1.1" and self.enable_graph:
|
||||
if filters.get("user_id") is None:
|
||||
filters["user_id"] = "user"
|
||||
|
||||
|
||||
data = "\n".join([msg["content"] for msg in messages if "content" in msg and msg["role"] != "system"])
|
||||
added_entities = self.graph.add(data, filters)
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import re
|
||||
import json
|
||||
|
||||
from mem0.configs.prompts import FACT_RETRIEVAL_PROMPT
|
||||
|
||||
@@ -19,6 +18,7 @@ def parse_messages(messages):
|
||||
response += f"assistant: {msg['content']}\n"
|
||||
return response
|
||||
|
||||
|
||||
def format_entities(entities):
|
||||
if not entities:
|
||||
return ""
|
||||
|
||||
Reference in New Issue
Block a user