Add Neo4j base label config (#2675)
This commit is contained in:
@@ -10,6 +10,7 @@ class Neo4jConfig(BaseModel):
|
||||
username: Optional[str] = Field(None, description="Username for the graph database")
|
||||
password: Optional[str] = Field(None, description="Password for the graph database")
|
||||
database: Optional[str] = Field(None, description="Database for the graph database")
|
||||
base_label: Optional[bool] = Field(None, description="Whether to use base node label __Entity__ for all entities")
|
||||
|
||||
@model_validator(mode="before")
|
||||
def check_host_port_or_path(cls, values):
|
||||
|
||||
@@ -40,6 +40,20 @@ class MemoryGraph:
|
||||
self.embedding_model = EmbedderFactory.create(
|
||||
self.config.embedder.provider, self.config.embedder.config, self.config.vector_store.config
|
||||
)
|
||||
self.node_label = ":`__Entity__`" if self.config.graph_store.config.base_label else ""
|
||||
|
||||
if self.config.graph_store.config.base_label:
|
||||
# Safely add user_id index
|
||||
try:
|
||||
self.graph.query(f"CREATE INDEX entity_single IF NOT EXISTS FOR (n {self.node_label}) ON (n.user_id)")
|
||||
except Exception:
|
||||
pass
|
||||
try: # Safely try to add composite index (Enterprise only)
|
||||
self.graph.query(
|
||||
f"CREATE INDEX entity_composite IF NOT EXISTS FOR (n {self.node_label}) ON (n.name, n.user_id)"
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
self.llm_provider = "openai_structured"
|
||||
if self.config.llm.provider:
|
||||
@@ -108,8 +122,8 @@ class MemoryGraph:
|
||||
return search_results
|
||||
|
||||
def delete_all(self, filters):
|
||||
cypher = """
|
||||
MATCH (n {user_id: $user_id})
|
||||
cypher = f"""
|
||||
MATCH (n {self.node_label} {{user_id: $user_id}})
|
||||
DETACH DELETE n
|
||||
"""
|
||||
params = {"user_id": filters["user_id"]}
|
||||
@@ -127,10 +141,9 @@ class MemoryGraph:
|
||||
- 'contexts': The base data store response for each memory.
|
||||
- 'entities': A list of strings representing the nodes and relationships
|
||||
"""
|
||||
|
||||
# return all nodes and relationships
|
||||
query = """
|
||||
MATCH (n {user_id: $user_id})-[r]->(m {user_id: $user_id})
|
||||
query = f"""
|
||||
MATCH (n {self.node_label} {{user_id: $user_id}})-[r]->(m {self.node_label} {{user_id: $user_id}})
|
||||
RETURN n.name AS source, type(r) AS relationship, m.name AS target
|
||||
LIMIT $limit
|
||||
"""
|
||||
@@ -224,22 +237,21 @@ class MemoryGraph:
|
||||
def _search_graph_db(self, node_list, filters, limit=100):
|
||||
"""Search similar nodes among and their respective incoming and outgoing relations."""
|
||||
result_relations = []
|
||||
|
||||
for node in node_list:
|
||||
n_embedding = self.embedding_model.embed(node)
|
||||
|
||||
cypher_query = """
|
||||
MATCH (n)
|
||||
cypher_query = f"""
|
||||
MATCH (n {self.node_label})
|
||||
WHERE n.embedding IS NOT NULL AND n.user_id = $user_id
|
||||
WITH n, round(2 * vector.similarity.cosine(n.embedding, $n_embedding) - 1, 4) AS similarity // denormalize for backward compatibility
|
||||
WHERE similarity >= $threshold
|
||||
CALL (n) {
|
||||
CALL (n) {{
|
||||
MATCH (n)-[r]->(m)
|
||||
RETURN n.name AS source, elementId(n) AS source_id, type(r) AS relationship, elementId(r) AS relation_id, m.name AS destination, elementId(m) AS destination_id
|
||||
UNION
|
||||
MATCH (m)-[r]->(n)
|
||||
RETURN m.name AS source, elementId(m) AS source_id, type(r) AS relationship, elementId(r) AS relation_id, n.name AS destination, elementId(n) AS destination_id
|
||||
}
|
||||
}}
|
||||
WITH distinct source, source_id, relationship, relation_id, destination, destination_id, similarity //deduplicate
|
||||
RETURN source, source_id, relationship, relation_id, destination, destination_id, similarity
|
||||
ORDER BY similarity DESC
|
||||
@@ -294,9 +306,9 @@ class MemoryGraph:
|
||||
|
||||
# Delete the specific relationship between nodes
|
||||
cypher = f"""
|
||||
MATCH (n {{name: $source_name, user_id: $user_id}})
|
||||
MATCH (n {self.node_label} {{name: $source_name, user_id: $user_id}})
|
||||
-[r:{relationship}]->
|
||||
(m {{name: $dest_name, user_id: $user_id}})
|
||||
(m {self.node_label} {{name: $dest_name, user_id: $user_id}})
|
||||
DELETE r
|
||||
RETURN
|
||||
n.name AS source,
|
||||
@@ -323,7 +335,11 @@ class MemoryGraph:
|
||||
|
||||
# types
|
||||
source_type = entity_type_map.get(source, "__User__")
|
||||
source_label = self.node_label if self.node_label else f":`{source_type}`"
|
||||
source_extra_set = f", source:`{source_type}`" if self.node_label else ""
|
||||
destination_type = entity_type_map.get(destination, "__User__")
|
||||
destination_label = self.node_label if self.node_label else f":`{destination_type}`"
|
||||
destination_extra_set = f", destination:`{destination_type}`" if self.node_label else ""
|
||||
|
||||
# embeddings
|
||||
source_embedding = self.embedding_model.embed(source)
|
||||
@@ -340,10 +356,11 @@ class MemoryGraph:
|
||||
WHERE elementId(source) = $source_id
|
||||
SET source.mentions = coalesce(source.mentions, 0) + 1
|
||||
WITH source
|
||||
MERGE (destination:{destination_type} {{name: $destination_name, user_id: $user_id}})
|
||||
MERGE (destination {destination_label} {{name: $destination_name, user_id: $user_id}})
|
||||
ON CREATE SET
|
||||
destination.created = timestamp(),
|
||||
destination.mentions = 1
|
||||
{destination_extra_set}
|
||||
ON MATCH SET
|
||||
destination.mentions = coalesce(destination.mentions, 0) + 1
|
||||
WITH source, destination
|
||||
@@ -370,10 +387,11 @@ class MemoryGraph:
|
||||
WHERE elementId(destination) = $destination_id
|
||||
SET destination.mentions = coalesce(destination.mentions, 0) + 1
|
||||
WITH destination
|
||||
MERGE (source:{source_type} {{name: $source_name, user_id: $user_id}})
|
||||
MERGE (source {source_label} {{name: $source_name, user_id: $user_id}})
|
||||
ON CREATE SET
|
||||
source.created = timestamp(),
|
||||
source.mentions = 1
|
||||
{source_extra_set}
|
||||
ON MATCH SET
|
||||
source.mentions = coalesce(source.mentions, 0) + 1
|
||||
WITH source, destination
|
||||
@@ -420,24 +438,26 @@ class MemoryGraph:
|
||||
}
|
||||
else:
|
||||
cypher = f"""
|
||||
MERGE (n:{source_type} {{name: $source_name, user_id: $user_id}})
|
||||
ON CREATE SET n.created = timestamp(),
|
||||
n.mentions = 1
|
||||
ON MATCH SET n.mentions = coalesce(n.mentions, 0) + 1
|
||||
WITH n
|
||||
CALL db.create.setNodeVectorProperty(n, 'embedding', $source_embedding)
|
||||
WITH n
|
||||
MERGE (m:{destination_type} {{name: $dest_name, user_id: $user_id}})
|
||||
ON CREATE SET m.created = timestamp(),
|
||||
m.mentions = 1
|
||||
ON MATCH SET m.mentions = coalesce(m.mentions, 0) + 1
|
||||
WITH n, m
|
||||
CALL db.create.setNodeVectorProperty(m, 'embedding', $source_embedding)
|
||||
WITH n, m
|
||||
MERGE (n)-[rel:{relationship}]->(m)
|
||||
MERGE (source {source_label} {{name: $source_name, user_id: $user_id}})
|
||||
ON CREATE SET source.created = timestamp(),
|
||||
source.mentions = 1
|
||||
{source_extra_set}
|
||||
ON MATCH SET source.mentions = coalesce(source.mentions, 0) + 1
|
||||
WITH source
|
||||
CALL db.create.setNodeVectorProperty(source, 'embedding', $source_embedding)
|
||||
WITH source
|
||||
MERGE (destination {destination_label} {{name: $dest_name, user_id: $user_id}})
|
||||
ON CREATE SET destination.created = timestamp(),
|
||||
destination.mentions = 1
|
||||
{destination_extra_set}
|
||||
ON MATCH SET destination.mentions = coalesce(destination.mentions, 0) + 1
|
||||
WITH source, destination
|
||||
CALL db.create.setNodeVectorProperty(destination, 'embedding', $source_embedding)
|
||||
WITH source, destination
|
||||
MERGE (source)-[rel:{relationship}]->(destination)
|
||||
ON CREATE SET rel.created = timestamp(), rel.mentions = 1
|
||||
ON MATCH SET rel.mentions = coalesce(rel.mentions, 0) + 1
|
||||
RETURN n.name AS source, type(rel) AS relationship, m.name AS target
|
||||
RETURN source.name AS source, type(rel) AS relationship, destination.name AS target
|
||||
"""
|
||||
params = {
|
||||
"source_name": source,
|
||||
@@ -458,8 +478,8 @@ class MemoryGraph:
|
||||
return entity_list
|
||||
|
||||
def _search_source_node(self, source_embedding, user_id, threshold=0.9):
|
||||
cypher = """
|
||||
MATCH (source_candidate)
|
||||
cypher = f"""
|
||||
MATCH (source_candidate {self.node_label})
|
||||
WHERE source_candidate.embedding IS NOT NULL
|
||||
AND source_candidate.user_id = $user_id
|
||||
|
||||
@@ -484,8 +504,8 @@ class MemoryGraph:
|
||||
return result
|
||||
|
||||
def _search_destination_node(self, destination_embedding, user_id, threshold=0.9):
|
||||
cypher = """
|
||||
MATCH (destination_candidate)
|
||||
cypher = f"""
|
||||
MATCH (destination_candidate {self.node_label})
|
||||
WHERE destination_candidate.embedding IS NOT NULL
|
||||
AND destination_candidate.user_id = $user_id
|
||||
|
||||
|
||||
Reference in New Issue
Block a user