Improve neo4j queries (#2654)

This commit is contained in:
Tomaz Bratanic
2025-05-08 20:11:46 +02:00
committed by GitHub
parent 84910b40da
commit 0d895b28ae

View File

@@ -92,7 +92,7 @@ class MemoryGraph:
return [] return []
search_outputs_sequence = [ search_outputs_sequence = [
[item["source"], item["relatationship"], item["destination"]] for item in search_output [item["source"], item["relationship"], item["destination"]] for item in search_output
] ]
bm25 = BM25Okapi(search_outputs_sequence) bm25 = BM25Okapi(search_outputs_sequence)
@@ -231,23 +231,17 @@ class MemoryGraph:
cypher_query = """ cypher_query = """
MATCH (n) MATCH (n)
WHERE n.embedding IS NOT NULL AND n.user_id = $user_id WHERE n.embedding IS NOT NULL AND n.user_id = $user_id
WITH n, WITH n, round(2 * vector.similarity.cosine(n.embedding, $n_embedding) - 1, 4) AS similarity // denormalize for backward compatibility
round(reduce(dot = 0.0, i IN range(0, size(n.embedding)-1) | dot + n.embedding[i] * $n_embedding[i]) /
(sqrt(reduce(l2 = 0.0, i IN range(0, size(n.embedding)-1) | l2 + n.embedding[i] * n.embedding[i])) *
sqrt(reduce(l2 = 0.0, i IN range(0, size($n_embedding)-1) | l2 + $n_embedding[i] * $n_embedding[i]))), 4) AS similarity
WHERE similarity >= $threshold WHERE similarity >= $threshold
CALL (n) {
MATCH (n)-[r]->(m) MATCH (n)-[r]->(m)
RETURN n.name AS source, elementId(n) AS source_id, type(r) AS relatationship, elementId(r) AS relation_id, m.name AS destination, elementId(m) AS destination_id, similarity RETURN n.name AS source, elementId(n) AS source_id, type(r) AS relationship, elementId(r) AS relation_id, m.name AS destination, elementId(m) AS destination_id
UNION UNION
MATCH (n)
WHERE n.embedding IS NOT NULL AND n.user_id = $user_id
WITH n,
round(reduce(dot = 0.0, i IN range(0, size(n.embedding)-1) | dot + n.embedding[i] * $n_embedding[i]) /
(sqrt(reduce(l2 = 0.0, i IN range(0, size(n.embedding)-1) | l2 + n.embedding[i] * n.embedding[i])) *
sqrt(reduce(l2 = 0.0, i IN range(0, size($n_embedding)-1) | l2 + $n_embedding[i] * $n_embedding[i]))), 4) AS similarity
WHERE similarity >= $threshold
MATCH (m)-[r]->(n) MATCH (m)-[r]->(n)
RETURN m.name AS source, elementId(m) AS source_id, type(r) AS relatationship, elementId(r) AS relation_id, n.name AS destination, elementId(n) AS destination_id, similarity RETURN m.name AS source, elementId(m) AS source_id, type(r) AS relationship, elementId(r) AS relation_id, n.name AS destination, elementId(n) AS destination_id
}
WITH distinct source, source_id, relationship, relation_id, destination, destination_id, similarity //deduplicate
RETURN source, source_id, relationship, relation_id, destination, destination_id, similarity
ORDER BY similarity DESC ORDER BY similarity DESC
LIMIT $limit LIMIT $limit
""" """
@@ -280,6 +274,7 @@ class MemoryGraph:
], ],
tools=_tools, tools=_tools,
) )
to_be_deleted = [] to_be_deleted = []
for item in memory_updates["tool_calls"]: for item in memory_updates["tool_calls"]:
if item["name"] == "delete_graph_memory": if item["name"] == "delete_graph_memory":
@@ -295,12 +290,12 @@ class MemoryGraph:
for item in to_be_deleted: for item in to_be_deleted:
source = item["source"] source = item["source"]
destination = item["destination"] destination = item["destination"]
relatationship = item["relationship"] relationship = item["relationship"]
# Delete the specific relationship between nodes # Delete the specific relationship between nodes
cypher = f""" cypher = f"""
MATCH (n {{name: $source_name, user_id: $user_id}}) MATCH (n {{name: $source_name, user_id: $user_id}})
-[r:{relatationship}]-> -[r:{relationship}]->
(m {{name: $dest_name, user_id: $user_id}}) (m {{name: $dest_name, user_id: $user_id}})
DELETE r DELETE r
RETURN RETURN
@@ -345,8 +340,10 @@ class MemoryGraph:
WHERE elementId(source) = $source_id WHERE elementId(source) = $source_id
MERGE (destination:{destination_type} {{name: $destination_name, user_id: $user_id}}) MERGE (destination:{destination_type} {{name: $destination_name, user_id: $user_id}})
ON CREATE SET ON CREATE SET
destination.created = timestamp(), destination.created = timestamp()
destination.embedding = $destination_embedding WITH source, destination
CALL db.create.setNodeVectorProperty(destination, 'embedding', $destination_embedding)
WITH source, destination
MERGE (source)-[r:{relationship}]->(destination) MERGE (source)-[r:{relationship}]->(destination)
ON CREATE SET ON CREATE SET
r.created = timestamp() r.created = timestamp()
@@ -365,8 +362,10 @@ class MemoryGraph:
WHERE elementId(destination) = $destination_id WHERE elementId(destination) = $destination_id
MERGE (source:{source_type} {{name: $source_name, user_id: $user_id}}) MERGE (source:{source_type} {{name: $source_name, user_id: $user_id}})
ON CREATE SET ON CREATE SET
source.created = timestamp(), source.created = timestamp()
source.embedding = $source_embedding WITH source, destination
CALL db.create.setNodeVectorProperty(source, 'embedding', $source_embedding)
WITH source, destination
MERGE (source)-[r:{relationship}]->(destination) MERGE (source)-[r:{relationship}]->(destination)
ON CREATE SET ON CREATE SET
r.created = timestamp() r.created = timestamp()
@@ -389,6 +388,8 @@ class MemoryGraph:
ON CREATE SET ON CREATE SET
r.created_at = timestamp(), r.created_at = timestamp(),
r.updated_at = timestamp() r.updated_at = timestamp()
RETURN source.name AS source, type(r) AS relationship, destination.name AS target RETURN source.name AS source, type(r) AS relationship, destination.name AS target
""" """
params = { params = {
@@ -399,11 +400,15 @@ class MemoryGraph:
else: else:
cypher = f""" cypher = f"""
MERGE (n:{source_type} {{name: $source_name, user_id: $user_id}}) MERGE (n:{source_type} {{name: $source_name, user_id: $user_id}})
ON CREATE SET n.created = timestamp(), n.embedding = $source_embedding ON CREATE SET n.created = timestamp()
ON MATCH SET n.embedding = $source_embedding WITH n
CALL db.create.setNodeVectorProperty(n, 'embedding', $source_embedding)
WITH n
MERGE (m:{destination_type} {{name: $dest_name, user_id: $user_id}}) MERGE (m:{destination_type} {{name: $dest_name, user_id: $user_id}})
ON CREATE SET m.created = timestamp(), m.embedding = $dest_embedding ON CREATE SET m.created = timestamp()
ON MATCH SET m.embedding = $dest_embedding WITH n, m
CALL db.create.setNodeVectorProperty(m, 'embedding', $source_embedding)
WITH n, m
MERGE (n)-[rel:{relationship}]->(m) MERGE (n)-[rel:{relationship}]->(m)
ON CREATE SET rel.created = timestamp() ON CREATE SET rel.created = timestamp()
RETURN n.name AS source, type(rel) AS relationship, m.name AS target RETURN n.name AS source, type(rel) AS relationship, m.name AS target
@@ -433,14 +438,7 @@ class MemoryGraph:
AND source_candidate.user_id = $user_id AND source_candidate.user_id = $user_id
WITH source_candidate, WITH source_candidate,
round( round(2 * vector.similarity.cosine(source_candidate.embedding, $source_embedding) - 1, 4) AS source_similarity // denormalize for backward compatibility
reduce(dot = 0.0, i IN range(0, size(source_candidate.embedding)-1) |
dot + source_candidate.embedding[i] * $source_embedding[i]) /
(sqrt(reduce(l2 = 0.0, i IN range(0, size(source_candidate.embedding)-1) |
l2 + source_candidate.embedding[i] * source_candidate.embedding[i])) *
sqrt(reduce(l2 = 0.0, i IN range(0, size($source_embedding)-1) |
l2 + $source_embedding[i] * $source_embedding[i])))
, 4) AS source_similarity
WHERE source_similarity >= $threshold WHERE source_similarity >= $threshold
WITH source_candidate, source_similarity WITH source_candidate, source_similarity
@@ -466,14 +464,8 @@ class MemoryGraph:
AND destination_candidate.user_id = $user_id AND destination_candidate.user_id = $user_id
WITH destination_candidate, WITH destination_candidate,
round( round(2 * vector.similarity.cosine(destination_candidate.embedding, $destination_embedding) - 1, 4) AS destination_similarity // denormalize for backward compatibility
reduce(dot = 0.0, i IN range(0, size(destination_candidate.embedding)-1) |
dot + destination_candidate.embedding[i] * $destination_embedding[i]) /
(sqrt(reduce(l2 = 0.0, i IN range(0, size(destination_candidate.embedding)-1) |
l2 + destination_candidate.embedding[i] * destination_candidate.embedding[i])) *
sqrt(reduce(l2 = 0.0, i IN range(0, size($destination_embedding)-1) |
l2 + $destination_embedding[i] * $destination_embedding[i])))
, 4) AS destination_similarity
WHERE destination_similarity >= $threshold WHERE destination_similarity >= $threshold
WITH destination_candidate, destination_similarity WITH destination_candidate, destination_similarity