Add neo4j base label config (#2675)

This commit is contained in:
Tomaz Bratanic
2025-05-20 03:22:20 +02:00
committed by GitHub
parent 12a268da30
commit 1786d907f7
2 changed files with 56 additions and 35 deletions

View File

@@ -10,6 +10,7 @@ class Neo4jConfig(BaseModel):
username: Optional[str] = Field(None, description="Username for the graph database")
password: Optional[str] = Field(None, description="Password for the graph database")
database: Optional[str] = Field(None, description="Database for the graph database")
base_label: Optional[bool] = Field(None, description="Whether to use base node label __Entity__ for all entities")
@model_validator(mode="before")
def check_host_port_or_path(cls, values):

View File

@@ -40,6 +40,20 @@ class MemoryGraph:
self.embedding_model = EmbedderFactory.create(
self.config.embedder.provider, self.config.embedder.config, self.config.vector_store.config
)
self.node_label = ":`__Entity__`" if self.config.graph_store.config.base_label else ""
if self.config.graph_store.config.base_label:
# Safely add user_id index
try:
self.graph.query(f"CREATE INDEX entity_single IF NOT EXISTS FOR (n {self.node_label}) ON (n.user_id)")
except Exception:
pass
try: # Safely try to add composite index (Enterprise only)
self.graph.query(
f"CREATE INDEX entity_composite IF NOT EXISTS FOR (n {self.node_label}) ON (n.name, n.user_id)"
)
except Exception:
pass
self.llm_provider = "openai_structured"
if self.config.llm.provider:
@@ -108,8 +122,8 @@ class MemoryGraph:
return search_results
def delete_all(self, filters):
cypher = """
MATCH (n {user_id: $user_id})
cypher = f"""
MATCH (n {self.node_label} {{user_id: $user_id}})
DETACH DELETE n
"""
params = {"user_id": filters["user_id"]}
@@ -127,10 +141,9 @@ class MemoryGraph:
- 'contexts': The base data store response for each memory.
- 'entities': A list of strings representing the nodes and relationships
"""
# return all nodes and relationships
query = """
MATCH (n {user_id: $user_id})-[r]->(m {user_id: $user_id})
query = f"""
MATCH (n {self.node_label} {{user_id: $user_id}})-[r]->(m {self.node_label} {{user_id: $user_id}})
RETURN n.name AS source, type(r) AS relationship, m.name AS target
LIMIT $limit
"""
@@ -224,22 +237,21 @@ class MemoryGraph:
def _search_graph_db(self, node_list, filters, limit=100):
"""Search similar nodes among and their respective incoming and outgoing relations."""
result_relations = []
for node in node_list:
n_embedding = self.embedding_model.embed(node)
cypher_query = """
MATCH (n)
cypher_query = f"""
MATCH (n {self.node_label})
WHERE n.embedding IS NOT NULL AND n.user_id = $user_id
WITH n, round(2 * vector.similarity.cosine(n.embedding, $n_embedding) - 1, 4) AS similarity // denormalize for backward compatibility
WHERE similarity >= $threshold
CALL (n) {
CALL (n) {{
MATCH (n)-[r]->(m)
RETURN n.name AS source, elementId(n) AS source_id, type(r) AS relationship, elementId(r) AS relation_id, m.name AS destination, elementId(m) AS destination_id
UNION
MATCH (m)-[r]->(n)
RETURN m.name AS source, elementId(m) AS source_id, type(r) AS relationship, elementId(r) AS relation_id, n.name AS destination, elementId(n) AS destination_id
}
}}
WITH distinct source, source_id, relationship, relation_id, destination, destination_id, similarity //deduplicate
RETURN source, source_id, relationship, relation_id, destination, destination_id, similarity
ORDER BY similarity DESC
@@ -294,9 +306,9 @@ class MemoryGraph:
# Delete the specific relationship between nodes
cypher = f"""
MATCH (n {{name: $source_name, user_id: $user_id}})
MATCH (n {self.node_label} {{name: $source_name, user_id: $user_id}})
-[r:{relationship}]->
(m {{name: $dest_name, user_id: $user_id}})
(m {self.node_label} {{name: $dest_name, user_id: $user_id}})
DELETE r
RETURN
n.name AS source,
@@ -323,7 +335,11 @@ class MemoryGraph:
# types
source_type = entity_type_map.get(source, "__User__")
source_label = self.node_label if self.node_label else f":`{source_type}`"
source_extra_set = f", source:`{source_type}`" if self.node_label else ""
destination_type = entity_type_map.get(destination, "__User__")
destination_label = self.node_label if self.node_label else f":`{destination_type}`"
destination_extra_set = f", destination:`{destination_type}`" if self.node_label else ""
# embeddings
source_embedding = self.embedding_model.embed(source)
@@ -340,10 +356,11 @@ class MemoryGraph:
WHERE elementId(source) = $source_id
SET source.mentions = coalesce(source.mentions, 0) + 1
WITH source
MERGE (destination:{destination_type} {{name: $destination_name, user_id: $user_id}})
MERGE (destination {destination_label} {{name: $destination_name, user_id: $user_id}})
ON CREATE SET
destination.created = timestamp(),
destination.mentions = 1
{destination_extra_set}
ON MATCH SET
destination.mentions = coalesce(destination.mentions, 0) + 1
WITH source, destination
@@ -370,10 +387,11 @@ class MemoryGraph:
WHERE elementId(destination) = $destination_id
SET destination.mentions = coalesce(destination.mentions, 0) + 1
WITH destination
MERGE (source:{source_type} {{name: $source_name, user_id: $user_id}})
MERGE (source {source_label} {{name: $source_name, user_id: $user_id}})
ON CREATE SET
source.created = timestamp(),
source.mentions = 1
{source_extra_set}
ON MATCH SET
source.mentions = coalesce(source.mentions, 0) + 1
WITH source, destination
@@ -420,24 +438,26 @@ class MemoryGraph:
}
else:
cypher = f"""
MERGE (n:{source_type} {{name: $source_name, user_id: $user_id}})
ON CREATE SET n.created = timestamp(),
n.mentions = 1
ON MATCH SET n.mentions = coalesce(n.mentions, 0) + 1
WITH n
CALL db.create.setNodeVectorProperty(n, 'embedding', $source_embedding)
WITH n
MERGE (m:{destination_type} {{name: $dest_name, user_id: $user_id}})
ON CREATE SET m.created = timestamp(),
m.mentions = 1
ON MATCH SET m.mentions = coalesce(m.mentions, 0) + 1
WITH n, m
CALL db.create.setNodeVectorProperty(m, 'embedding', $source_embedding)
WITH n, m
MERGE (n)-[rel:{relationship}]->(m)
MERGE (source {source_label} {{name: $source_name, user_id: $user_id}})
ON CREATE SET source.created = timestamp(),
source.mentions = 1
{source_extra_set}
ON MATCH SET source.mentions = coalesce(source.mentions, 0) + 1
WITH source
CALL db.create.setNodeVectorProperty(source, 'embedding', $source_embedding)
WITH source
MERGE (destination {destination_label} {{name: $dest_name, user_id: $user_id}})
ON CREATE SET destination.created = timestamp(),
destination.mentions = 1
{destination_extra_set}
ON MATCH SET destination.mentions = coalesce(destination.mentions, 0) + 1
WITH source, destination
CALL db.create.setNodeVectorProperty(destination, 'embedding', $source_embedding)
WITH source, destination
MERGE (source)-[rel:{relationship}]->(destination)
ON CREATE SET rel.created = timestamp(), rel.mentions = 1
ON MATCH SET rel.mentions = coalesce(rel.mentions, 0) + 1
RETURN n.name AS source, type(rel) AS relationship, m.name AS target
RETURN source.name AS source, type(rel) AS relationship, destination.name AS target
"""
params = {
"source_name": source,
@@ -458,8 +478,8 @@ class MemoryGraph:
return entity_list
def _search_source_node(self, source_embedding, user_id, threshold=0.9):
cypher = """
MATCH (source_candidate)
cypher = f"""
MATCH (source_candidate {self.node_label})
WHERE source_candidate.embedding IS NOT NULL
AND source_candidate.user_id = $user_id
@@ -484,8 +504,8 @@ class MemoryGraph:
return result
def _search_destination_node(self, destination_embedding, user_id, threshold=0.9):
cypher = """
MATCH (destination_candidate)
cypher = f"""
MATCH (destination_candidate {self.node_label})
WHERE destination_candidate.embedding IS NOT NULL
AND destination_candidate.user_id = $user_id