Add neo4j base label config (#2675)

This commit is contained in:
Tomaz Bratanic
2025-05-20 03:22:20 +02:00
committed by GitHub
parent 12a268da30
commit 1786d907f7
2 changed files with 56 additions and 35 deletions

View File

@@ -10,6 +10,7 @@ class Neo4jConfig(BaseModel):
username: Optional[str] = Field(None, description="Username for the graph database") username: Optional[str] = Field(None, description="Username for the graph database")
password: Optional[str] = Field(None, description="Password for the graph database") password: Optional[str] = Field(None, description="Password for the graph database")
database: Optional[str] = Field(None, description="Database for the graph database") database: Optional[str] = Field(None, description="Database for the graph database")
base_label: Optional[bool] = Field(None, description="Whether to use base node label __Entity__ for all entities")
@model_validator(mode="before") @model_validator(mode="before")
def check_host_port_or_path(cls, values): def check_host_port_or_path(cls, values):

View File

@@ -35,11 +35,25 @@ class MemoryGraph:
self.config.graph_store.config.password, self.config.graph_store.config.password,
self.config.graph_store.config.database, self.config.graph_store.config.database,
refresh_schema=False, refresh_schema=False,
driver_config={"notifications_min_severity":"OFF"}, driver_config={"notifications_min_severity": "OFF"},
) )
self.embedding_model = EmbedderFactory.create( self.embedding_model = EmbedderFactory.create(
self.config.embedder.provider, self.config.embedder.config, self.config.vector_store.config self.config.embedder.provider, self.config.embedder.config, self.config.vector_store.config
) )
self.node_label = ":`__Entity__`" if self.config.graph_store.config.base_label else ""
if self.config.graph_store.config.base_label:
# Safely add user_id index
try:
self.graph.query(f"CREATE INDEX entity_single IF NOT EXISTS FOR (n {self.node_label}) ON (n.user_id)")
except Exception:
pass
try: # Safely try to add composite index (Enterprise only)
self.graph.query(
f"CREATE INDEX entity_composite IF NOT EXISTS FOR (n {self.node_label}) ON (n.name, n.user_id)"
)
except Exception:
pass
self.llm_provider = "openai_structured" self.llm_provider = "openai_structured"
if self.config.llm.provider: if self.config.llm.provider:
@@ -108,8 +122,8 @@ class MemoryGraph:
return search_results return search_results
def delete_all(self, filters): def delete_all(self, filters):
cypher = """ cypher = f"""
MATCH (n {user_id: $user_id}) MATCH (n {self.node_label} {{user_id: $user_id}})
DETACH DELETE n DETACH DELETE n
""" """
params = {"user_id": filters["user_id"]} params = {"user_id": filters["user_id"]}
@@ -127,10 +141,9 @@ class MemoryGraph:
- 'contexts': The base data store response for each memory. - 'contexts': The base data store response for each memory.
- 'entities': A list of strings representing the nodes and relationships - 'entities': A list of strings representing the nodes and relationships
""" """
# return all nodes and relationships # return all nodes and relationships
query = """ query = f"""
MATCH (n {user_id: $user_id})-[r]->(m {user_id: $user_id}) MATCH (n {self.node_label} {{user_id: $user_id}})-[r]->(m {self.node_label} {{user_id: $user_id}})
RETURN n.name AS source, type(r) AS relationship, m.name AS target RETURN n.name AS source, type(r) AS relationship, m.name AS target
LIMIT $limit LIMIT $limit
""" """
@@ -224,22 +237,21 @@ class MemoryGraph:
def _search_graph_db(self, node_list, filters, limit=100): def _search_graph_db(self, node_list, filters, limit=100):
"""Search similar nodes among and their respective incoming and outgoing relations.""" """Search similar nodes among and their respective incoming and outgoing relations."""
result_relations = [] result_relations = []
for node in node_list: for node in node_list:
n_embedding = self.embedding_model.embed(node) n_embedding = self.embedding_model.embed(node)
cypher_query = """ cypher_query = f"""
MATCH (n) MATCH (n {self.node_label})
WHERE n.embedding IS NOT NULL AND n.user_id = $user_id WHERE n.embedding IS NOT NULL AND n.user_id = $user_id
WITH n, round(2 * vector.similarity.cosine(n.embedding, $n_embedding) - 1, 4) AS similarity // denormalize for backward compatibility WITH n, round(2 * vector.similarity.cosine(n.embedding, $n_embedding) - 1, 4) AS similarity // denormalize for backward compatibility
WHERE similarity >= $threshold WHERE similarity >= $threshold
CALL (n) { CALL (n) {{
MATCH (n)-[r]->(m) MATCH (n)-[r]->(m)
RETURN n.name AS source, elementId(n) AS source_id, type(r) AS relationship, elementId(r) AS relation_id, m.name AS destination, elementId(m) AS destination_id RETURN n.name AS source, elementId(n) AS source_id, type(r) AS relationship, elementId(r) AS relation_id, m.name AS destination, elementId(m) AS destination_id
UNION UNION
MATCH (m)-[r]->(n) MATCH (m)-[r]->(n)
RETURN m.name AS source, elementId(m) AS source_id, type(r) AS relationship, elementId(r) AS relation_id, n.name AS destination, elementId(n) AS destination_id RETURN m.name AS source, elementId(m) AS source_id, type(r) AS relationship, elementId(r) AS relation_id, n.name AS destination, elementId(n) AS destination_id
} }}
WITH distinct source, source_id, relationship, relation_id, destination, destination_id, similarity //deduplicate WITH distinct source, source_id, relationship, relation_id, destination, destination_id, similarity //deduplicate
RETURN source, source_id, relationship, relation_id, destination, destination_id, similarity RETURN source, source_id, relationship, relation_id, destination, destination_id, similarity
ORDER BY similarity DESC ORDER BY similarity DESC
@@ -294,9 +306,9 @@ class MemoryGraph:
# Delete the specific relationship between nodes # Delete the specific relationship between nodes
cypher = f""" cypher = f"""
MATCH (n {{name: $source_name, user_id: $user_id}}) MATCH (n {self.node_label} {{name: $source_name, user_id: $user_id}})
-[r:{relationship}]-> -[r:{relationship}]->
(m {{name: $dest_name, user_id: $user_id}}) (m {self.node_label} {{name: $dest_name, user_id: $user_id}})
DELETE r DELETE r
RETURN RETURN
n.name AS source, n.name AS source,
@@ -323,7 +335,11 @@ class MemoryGraph:
# types # types
source_type = entity_type_map.get(source, "__User__") source_type = entity_type_map.get(source, "__User__")
source_label = self.node_label if self.node_label else f":`{source_type}`"
source_extra_set = f", source:`{source_type}`" if self.node_label else ""
destination_type = entity_type_map.get(destination, "__User__") destination_type = entity_type_map.get(destination, "__User__")
destination_label = self.node_label if self.node_label else f":`{destination_type}`"
destination_extra_set = f", destination:`{destination_type}`" if self.node_label else ""
# embeddings # embeddings
source_embedding = self.embedding_model.embed(source) source_embedding = self.embedding_model.embed(source)
@@ -340,10 +356,11 @@ class MemoryGraph:
WHERE elementId(source) = $source_id WHERE elementId(source) = $source_id
SET source.mentions = coalesce(source.mentions, 0) + 1 SET source.mentions = coalesce(source.mentions, 0) + 1
WITH source WITH source
MERGE (destination:{destination_type} {{name: $destination_name, user_id: $user_id}}) MERGE (destination {destination_label} {{name: $destination_name, user_id: $user_id}})
ON CREATE SET ON CREATE SET
destination.created = timestamp(), destination.created = timestamp(),
destination.mentions = 1 destination.mentions = 1
{destination_extra_set}
ON MATCH SET ON MATCH SET
destination.mentions = coalesce(destination.mentions, 0) + 1 destination.mentions = coalesce(destination.mentions, 0) + 1
WITH source, destination WITH source, destination
@@ -370,10 +387,11 @@ class MemoryGraph:
WHERE elementId(destination) = $destination_id WHERE elementId(destination) = $destination_id
SET destination.mentions = coalesce(destination.mentions, 0) + 1 SET destination.mentions = coalesce(destination.mentions, 0) + 1
WITH destination WITH destination
MERGE (source:{source_type} {{name: $source_name, user_id: $user_id}}) MERGE (source {source_label} {{name: $source_name, user_id: $user_id}})
ON CREATE SET ON CREATE SET
source.created = timestamp(), source.created = timestamp(),
source.mentions = 1 source.mentions = 1
{source_extra_set}
ON MATCH SET ON MATCH SET
source.mentions = coalesce(source.mentions, 0) + 1 source.mentions = coalesce(source.mentions, 0) + 1
WITH source, destination WITH source, destination
@@ -420,24 +438,26 @@ class MemoryGraph:
} }
else: else:
cypher = f""" cypher = f"""
MERGE (n:{source_type} {{name: $source_name, user_id: $user_id}}) MERGE (source {source_label} {{name: $source_name, user_id: $user_id}})
ON CREATE SET n.created = timestamp(), ON CREATE SET source.created = timestamp(),
n.mentions = 1 source.mentions = 1
ON MATCH SET n.mentions = coalesce(n.mentions, 0) + 1 {source_extra_set}
WITH n ON MATCH SET source.mentions = coalesce(source.mentions, 0) + 1
CALL db.create.setNodeVectorProperty(n, 'embedding', $source_embedding) WITH source
WITH n CALL db.create.setNodeVectorProperty(source, 'embedding', $source_embedding)
MERGE (m:{destination_type} {{name: $dest_name, user_id: $user_id}}) WITH source
ON CREATE SET m.created = timestamp(), MERGE (destination {destination_label} {{name: $dest_name, user_id: $user_id}})
m.mentions = 1 ON CREATE SET destination.created = timestamp(),
ON MATCH SET m.mentions = coalesce(m.mentions, 0) + 1 destination.mentions = 1
WITH n, m {destination_extra_set}
CALL db.create.setNodeVectorProperty(m, 'embedding', $source_embedding) ON MATCH SET destination.mentions = coalesce(destination.mentions, 0) + 1
WITH n, m WITH source, destination
MERGE (n)-[rel:{relationship}]->(m) CALL db.create.setNodeVectorProperty(destination, 'embedding', $source_embedding)
WITH source, destination
MERGE (source)-[rel:{relationship}]->(destination)
ON CREATE SET rel.created = timestamp(), rel.mentions = 1 ON CREATE SET rel.created = timestamp(), rel.mentions = 1
ON MATCH SET rel.mentions = coalesce(rel.mentions, 0) + 1 ON MATCH SET rel.mentions = coalesce(rel.mentions, 0) + 1
RETURN n.name AS source, type(rel) AS relationship, m.name AS target RETURN source.name AS source, type(rel) AS relationship, destination.name AS target
""" """
params = { params = {
"source_name": source, "source_name": source,
@@ -458,8 +478,8 @@ class MemoryGraph:
return entity_list return entity_list
def _search_source_node(self, source_embedding, user_id, threshold=0.9): def _search_source_node(self, source_embedding, user_id, threshold=0.9):
cypher = """ cypher = f"""
MATCH (source_candidate) MATCH (source_candidate {self.node_label})
WHERE source_candidate.embedding IS NOT NULL WHERE source_candidate.embedding IS NOT NULL
AND source_candidate.user_id = $user_id AND source_candidate.user_id = $user_id
@@ -484,8 +504,8 @@ class MemoryGraph:
return result return result
def _search_destination_node(self, destination_embedding, user_id, threshold=0.9): def _search_destination_node(self, destination_embedding, user_id, threshold=0.9):
cypher = """ cypher = f"""
MATCH (destination_candidate) MATCH (destination_candidate {self.node_label})
WHERE destination_candidate.embedding IS NOT NULL WHERE destination_candidate.embedding IS NOT NULL
AND destination_candidate.user_id = $user_id AND destination_candidate.user_id = $user_id