Add neo4j base label config (#2675)
This commit is contained in:
@@ -10,6 +10,7 @@ class Neo4jConfig(BaseModel):
|
|||||||
username: Optional[str] = Field(None, description="Username for the graph database")
|
username: Optional[str] = Field(None, description="Username for the graph database")
|
||||||
password: Optional[str] = Field(None, description="Password for the graph database")
|
password: Optional[str] = Field(None, description="Password for the graph database")
|
||||||
database: Optional[str] = Field(None, description="Database for the graph database")
|
database: Optional[str] = Field(None, description="Database for the graph database")
|
||||||
|
base_label: Optional[bool] = Field(None, description="Whether to use base node label __Entity__ for all entities")
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
def check_host_port_or_path(cls, values):
|
def check_host_port_or_path(cls, values):
|
||||||
|
|||||||
@@ -35,11 +35,25 @@ class MemoryGraph:
|
|||||||
self.config.graph_store.config.password,
|
self.config.graph_store.config.password,
|
||||||
self.config.graph_store.config.database,
|
self.config.graph_store.config.database,
|
||||||
refresh_schema=False,
|
refresh_schema=False,
|
||||||
driver_config={"notifications_min_severity":"OFF"},
|
driver_config={"notifications_min_severity": "OFF"},
|
||||||
)
|
)
|
||||||
self.embedding_model = EmbedderFactory.create(
|
self.embedding_model = EmbedderFactory.create(
|
||||||
self.config.embedder.provider, self.config.embedder.config, self.config.vector_store.config
|
self.config.embedder.provider, self.config.embedder.config, self.config.vector_store.config
|
||||||
)
|
)
|
||||||
|
self.node_label = ":`__Entity__`" if self.config.graph_store.config.base_label else ""
|
||||||
|
|
||||||
|
if self.config.graph_store.config.base_label:
|
||||||
|
# Safely add user_id index
|
||||||
|
try:
|
||||||
|
self.graph.query(f"CREATE INDEX entity_single IF NOT EXISTS FOR (n {self.node_label}) ON (n.user_id)")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
try: # Safely try to add composite index (Enterprise only)
|
||||||
|
self.graph.query(
|
||||||
|
f"CREATE INDEX entity_composite IF NOT EXISTS FOR (n {self.node_label}) ON (n.name, n.user_id)"
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
self.llm_provider = "openai_structured"
|
self.llm_provider = "openai_structured"
|
||||||
if self.config.llm.provider:
|
if self.config.llm.provider:
|
||||||
@@ -108,8 +122,8 @@ class MemoryGraph:
|
|||||||
return search_results
|
return search_results
|
||||||
|
|
||||||
def delete_all(self, filters):
|
def delete_all(self, filters):
|
||||||
cypher = """
|
cypher = f"""
|
||||||
MATCH (n {user_id: $user_id})
|
MATCH (n {self.node_label} {{user_id: $user_id}})
|
||||||
DETACH DELETE n
|
DETACH DELETE n
|
||||||
"""
|
"""
|
||||||
params = {"user_id": filters["user_id"]}
|
params = {"user_id": filters["user_id"]}
|
||||||
@@ -127,10 +141,9 @@ class MemoryGraph:
|
|||||||
- 'contexts': The base data store response for each memory.
|
- 'contexts': The base data store response for each memory.
|
||||||
- 'entities': A list of strings representing the nodes and relationships
|
- 'entities': A list of strings representing the nodes and relationships
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# return all nodes and relationships
|
# return all nodes and relationships
|
||||||
query = """
|
query = f"""
|
||||||
MATCH (n {user_id: $user_id})-[r]->(m {user_id: $user_id})
|
MATCH (n {self.node_label} {{user_id: $user_id}})-[r]->(m {self.node_label} {{user_id: $user_id}})
|
||||||
RETURN n.name AS source, type(r) AS relationship, m.name AS target
|
RETURN n.name AS source, type(r) AS relationship, m.name AS target
|
||||||
LIMIT $limit
|
LIMIT $limit
|
||||||
"""
|
"""
|
||||||
@@ -224,22 +237,21 @@ class MemoryGraph:
|
|||||||
def _search_graph_db(self, node_list, filters, limit=100):
|
def _search_graph_db(self, node_list, filters, limit=100):
|
||||||
"""Search for similar nodes and their respective incoming and outgoing relations."""
|
"""Search for similar nodes and their respective incoming and outgoing relations."""
|
||||||
result_relations = []
|
result_relations = []
|
||||||
|
|
||||||
for node in node_list:
|
for node in node_list:
|
||||||
n_embedding = self.embedding_model.embed(node)
|
n_embedding = self.embedding_model.embed(node)
|
||||||
|
|
||||||
cypher_query = """
|
cypher_query = f"""
|
||||||
MATCH (n)
|
MATCH (n {self.node_label})
|
||||||
WHERE n.embedding IS NOT NULL AND n.user_id = $user_id
|
WHERE n.embedding IS NOT NULL AND n.user_id = $user_id
|
||||||
WITH n, round(2 * vector.similarity.cosine(n.embedding, $n_embedding) - 1, 4) AS similarity // denormalize for backward compatibility
|
WITH n, round(2 * vector.similarity.cosine(n.embedding, $n_embedding) - 1, 4) AS similarity // denormalize for backward compatibility
|
||||||
WHERE similarity >= $threshold
|
WHERE similarity >= $threshold
|
||||||
CALL (n) {
|
CALL (n) {{
|
||||||
MATCH (n)-[r]->(m)
|
MATCH (n)-[r]->(m)
|
||||||
RETURN n.name AS source, elementId(n) AS source_id, type(r) AS relationship, elementId(r) AS relation_id, m.name AS destination, elementId(m) AS destination_id
|
RETURN n.name AS source, elementId(n) AS source_id, type(r) AS relationship, elementId(r) AS relation_id, m.name AS destination, elementId(m) AS destination_id
|
||||||
UNION
|
UNION
|
||||||
MATCH (m)-[r]->(n)
|
MATCH (m)-[r]->(n)
|
||||||
RETURN m.name AS source, elementId(m) AS source_id, type(r) AS relationship, elementId(r) AS relation_id, n.name AS destination, elementId(n) AS destination_id
|
RETURN m.name AS source, elementId(m) AS source_id, type(r) AS relationship, elementId(r) AS relation_id, n.name AS destination, elementId(n) AS destination_id
|
||||||
}
|
}}
|
||||||
WITH distinct source, source_id, relationship, relation_id, destination, destination_id, similarity //deduplicate
|
WITH distinct source, source_id, relationship, relation_id, destination, destination_id, similarity //deduplicate
|
||||||
RETURN source, source_id, relationship, relation_id, destination, destination_id, similarity
|
RETURN source, source_id, relationship, relation_id, destination, destination_id, similarity
|
||||||
ORDER BY similarity DESC
|
ORDER BY similarity DESC
|
||||||
@@ -294,9 +306,9 @@ class MemoryGraph:
|
|||||||
|
|
||||||
# Delete the specific relationship between nodes
|
# Delete the specific relationship between nodes
|
||||||
cypher = f"""
|
cypher = f"""
|
||||||
MATCH (n {{name: $source_name, user_id: $user_id}})
|
MATCH (n {self.node_label} {{name: $source_name, user_id: $user_id}})
|
||||||
-[r:{relationship}]->
|
-[r:{relationship}]->
|
||||||
(m {{name: $dest_name, user_id: $user_id}})
|
(m {self.node_label} {{name: $dest_name, user_id: $user_id}})
|
||||||
DELETE r
|
DELETE r
|
||||||
RETURN
|
RETURN
|
||||||
n.name AS source,
|
n.name AS source,
|
||||||
@@ -323,7 +335,11 @@ class MemoryGraph:
|
|||||||
|
|
||||||
# types
|
# types
|
||||||
source_type = entity_type_map.get(source, "__User__")
|
source_type = entity_type_map.get(source, "__User__")
|
||||||
|
source_label = self.node_label if self.node_label else f":`{source_type}`"
|
||||||
|
source_extra_set = f", source:`{source_type}`" if self.node_label else ""
|
||||||
destination_type = entity_type_map.get(destination, "__User__")
|
destination_type = entity_type_map.get(destination, "__User__")
|
||||||
|
destination_label = self.node_label if self.node_label else f":`{destination_type}`"
|
||||||
|
destination_extra_set = f", destination:`{destination_type}`" if self.node_label else ""
|
||||||
|
|
||||||
# embeddings
|
# embeddings
|
||||||
source_embedding = self.embedding_model.embed(source)
|
source_embedding = self.embedding_model.embed(source)
|
||||||
@@ -340,10 +356,11 @@ class MemoryGraph:
|
|||||||
WHERE elementId(source) = $source_id
|
WHERE elementId(source) = $source_id
|
||||||
SET source.mentions = coalesce(source.mentions, 0) + 1
|
SET source.mentions = coalesce(source.mentions, 0) + 1
|
||||||
WITH source
|
WITH source
|
||||||
MERGE (destination:{destination_type} {{name: $destination_name, user_id: $user_id}})
|
MERGE (destination {destination_label} {{name: $destination_name, user_id: $user_id}})
|
||||||
ON CREATE SET
|
ON CREATE SET
|
||||||
destination.created = timestamp(),
|
destination.created = timestamp(),
|
||||||
destination.mentions = 1
|
destination.mentions = 1
|
||||||
|
{destination_extra_set}
|
||||||
ON MATCH SET
|
ON MATCH SET
|
||||||
destination.mentions = coalesce(destination.mentions, 0) + 1
|
destination.mentions = coalesce(destination.mentions, 0) + 1
|
||||||
WITH source, destination
|
WITH source, destination
|
||||||
@@ -370,10 +387,11 @@ class MemoryGraph:
|
|||||||
WHERE elementId(destination) = $destination_id
|
WHERE elementId(destination) = $destination_id
|
||||||
SET destination.mentions = coalesce(destination.mentions, 0) + 1
|
SET destination.mentions = coalesce(destination.mentions, 0) + 1
|
||||||
WITH destination
|
WITH destination
|
||||||
MERGE (source:{source_type} {{name: $source_name, user_id: $user_id}})
|
MERGE (source {source_label} {{name: $source_name, user_id: $user_id}})
|
||||||
ON CREATE SET
|
ON CREATE SET
|
||||||
source.created = timestamp(),
|
source.created = timestamp(),
|
||||||
source.mentions = 1
|
source.mentions = 1
|
||||||
|
{source_extra_set}
|
||||||
ON MATCH SET
|
ON MATCH SET
|
||||||
source.mentions = coalesce(source.mentions, 0) + 1
|
source.mentions = coalesce(source.mentions, 0) + 1
|
||||||
WITH source, destination
|
WITH source, destination
|
||||||
@@ -420,24 +438,26 @@ class MemoryGraph:
|
|||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
cypher = f"""
|
cypher = f"""
|
||||||
MERGE (n:{source_type} {{name: $source_name, user_id: $user_id}})
|
MERGE (source {source_label} {{name: $source_name, user_id: $user_id}})
|
||||||
ON CREATE SET n.created = timestamp(),
|
ON CREATE SET source.created = timestamp(),
|
||||||
n.mentions = 1
|
source.mentions = 1
|
||||||
ON MATCH SET n.mentions = coalesce(n.mentions, 0) + 1
|
{source_extra_set}
|
||||||
WITH n
|
ON MATCH SET source.mentions = coalesce(source.mentions, 0) + 1
|
||||||
CALL db.create.setNodeVectorProperty(n, 'embedding', $source_embedding)
|
WITH source
|
||||||
WITH n
|
CALL db.create.setNodeVectorProperty(source, 'embedding', $source_embedding)
|
||||||
MERGE (m:{destination_type} {{name: $dest_name, user_id: $user_id}})
|
WITH source
|
||||||
ON CREATE SET m.created = timestamp(),
|
MERGE (destination {destination_label} {{name: $dest_name, user_id: $user_id}})
|
||||||
m.mentions = 1
|
ON CREATE SET destination.created = timestamp(),
|
||||||
ON MATCH SET m.mentions = coalesce(m.mentions, 0) + 1
|
destination.mentions = 1
|
||||||
WITH n, m
|
{destination_extra_set}
|
||||||
CALL db.create.setNodeVectorProperty(m, 'embedding', $source_embedding)
|
ON MATCH SET destination.mentions = coalesce(destination.mentions, 0) + 1
|
||||||
WITH n, m
|
WITH source, destination
|
||||||
MERGE (n)-[rel:{relationship}]->(m)
|
CALL db.create.setNodeVectorProperty(destination, 'embedding', $source_embedding)
|
||||||
|
WITH source, destination
|
||||||
|
MERGE (source)-[rel:{relationship}]->(destination)
|
||||||
ON CREATE SET rel.created = timestamp(), rel.mentions = 1
|
ON CREATE SET rel.created = timestamp(), rel.mentions = 1
|
||||||
ON MATCH SET rel.mentions = coalesce(rel.mentions, 0) + 1
|
ON MATCH SET rel.mentions = coalesce(rel.mentions, 0) + 1
|
||||||
RETURN n.name AS source, type(rel) AS relationship, m.name AS target
|
RETURN source.name AS source, type(rel) AS relationship, destination.name AS target
|
||||||
"""
|
"""
|
||||||
params = {
|
params = {
|
||||||
"source_name": source,
|
"source_name": source,
|
||||||
@@ -458,8 +478,8 @@ class MemoryGraph:
|
|||||||
return entity_list
|
return entity_list
|
||||||
|
|
||||||
def _search_source_node(self, source_embedding, user_id, threshold=0.9):
|
def _search_source_node(self, source_embedding, user_id, threshold=0.9):
|
||||||
cypher = """
|
cypher = f"""
|
||||||
MATCH (source_candidate)
|
MATCH (source_candidate {self.node_label})
|
||||||
WHERE source_candidate.embedding IS NOT NULL
|
WHERE source_candidate.embedding IS NOT NULL
|
||||||
AND source_candidate.user_id = $user_id
|
AND source_candidate.user_id = $user_id
|
||||||
|
|
||||||
@@ -484,8 +504,8 @@ class MemoryGraph:
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
def _search_destination_node(self, destination_embedding, user_id, threshold=0.9):
|
def _search_destination_node(self, destination_embedding, user_id, threshold=0.9):
|
||||||
cypher = """
|
cypher = f"""
|
||||||
MATCH (destination_candidate)
|
MATCH (destination_candidate {self.node_label})
|
||||||
WHERE destination_candidate.embedding IS NOT NULL
|
WHERE destination_candidate.embedding IS NOT NULL
|
||||||
AND destination_candidate.user_id = $user_id
|
AND destination_candidate.user_id = $user_id
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user