From 1786d907f70de8eb2631a5ae37bf32c816921cc9 Mon Sep 17 00:00:00 2001 From: Tomaz Bratanic Date: Tue, 20 May 2025 03:22:20 +0200 Subject: [PATCH] Add neo4j base label config (#2675) --- mem0/graphs/configs.py | 1 + mem0/memory/graph_memory.py | 90 ++++++++++++++++++++++--------------- 2 files changed, 56 insertions(+), 35 deletions(-) diff --git a/mem0/graphs/configs.py b/mem0/graphs/configs.py index 50c585a1..bbfcca8a 100644 --- a/mem0/graphs/configs.py +++ b/mem0/graphs/configs.py @@ -10,6 +10,7 @@ class Neo4jConfig(BaseModel): username: Optional[str] = Field(None, description="Username for the graph database") password: Optional[str] = Field(None, description="Password for the graph database") database: Optional[str] = Field(None, description="Database for the graph database") + base_label: Optional[bool] = Field(None, description="Whether to use base node label __Entity__ for all entities") @model_validator(mode="before") def check_host_port_or_path(cls, values): diff --git a/mem0/memory/graph_memory.py b/mem0/memory/graph_memory.py index 8167bddf..ff50c221 100644 --- a/mem0/memory/graph_memory.py +++ b/mem0/memory/graph_memory.py @@ -35,11 +35,25 @@ class MemoryGraph: self.config.graph_store.config.password, self.config.graph_store.config.database, refresh_schema=False, - driver_config={"notifications_min_severity":"OFF"}, + driver_config={"notifications_min_severity": "OFF"}, ) self.embedding_model = EmbedderFactory.create( self.config.embedder.provider, self.config.embedder.config, self.config.vector_store.config ) + self.node_label = ":`__Entity__`" if self.config.graph_store.config.base_label else "" + + if self.config.graph_store.config.base_label: + # Safely add user_id index + try: + self.graph.query(f"CREATE INDEX entity_single IF NOT EXISTS FOR (n {self.node_label}) ON (n.user_id)") + except Exception: + pass + try: # Safely try to add composite index (Enterprise only) + self.graph.query( + f"CREATE INDEX entity_composite IF NOT EXISTS FOR (n {self.node_label}) ON (n.name, n.user_id)" + ) + except Exception: + pass self.llm_provider = "openai_structured" if self.config.llm.provider: @@ -108,8 +122,8 @@ class MemoryGraph: return search_results def delete_all(self, filters): - cypher = """ - MATCH (n {user_id: $user_id}) + cypher = f""" + MATCH (n {self.node_label} {{user_id: $user_id}}) DETACH DELETE n """ params = {"user_id": filters["user_id"]} @@ -127,10 +141,9 @@ class MemoryGraph: - 'contexts': The base data store response for each memory. - 'entities': A list of strings representing the nodes and relationships """ - # return all nodes and relationships - query = """ - MATCH (n {user_id: $user_id})-[r]->(m {user_id: $user_id}) + query = f""" + MATCH (n {self.node_label} {{user_id: $user_id}})-[r]->(m {self.node_label} {{user_id: $user_id}}) RETURN n.name AS source, type(r) AS relationship, m.name AS target LIMIT $limit """ @@ -224,22 +237,21 @@ class MemoryGraph: def _search_graph_db(self, node_list, filters, limit=100): """Search similar nodes among and their respective incoming and outgoing relations.""" result_relations = [] - for node in node_list: n_embedding = self.embedding_model.embed(node) - cypher_query = """ - MATCH (n) + cypher_query = f""" + MATCH (n {self.node_label}) WHERE n.embedding IS NOT NULL AND n.user_id = $user_id WITH n, round(2 * vector.similarity.cosine(n.embedding, $n_embedding) - 1, 4) AS similarity // denormalize for backward compatibility WHERE similarity >= $threshold - CALL (n) { + CALL (n) {{ MATCH (n)-[r]->(m) RETURN n.name AS source, elementId(n) AS source_id, type(r) AS relationship, elementId(r) AS relation_id, m.name AS destination, elementId(m) AS destination_id UNION MATCH (m)-[r]->(n) RETURN m.name AS source, elementId(m) AS source_id, type(r) AS relationship, elementId(r) AS relation_id, n.name AS destination, elementId(n) AS destination_id - } + }} WITH distinct source, source_id, relationship, relation_id, destination, destination_id, similarity //deduplicate RETURN source, source_id, relationship, relation_id, destination, destination_id, similarity ORDER BY similarity DESC @@ -294,9 +306,9 @@ class MemoryGraph: # Delete the specific relationship between nodes cypher = f""" - MATCH (n {{name: $source_name, user_id: $user_id}}) + MATCH (n {self.node_label} {{name: $source_name, user_id: $user_id}}) -[r:{relationship}]-> - (m {{name: $dest_name, user_id: $user_id}}) + (m {self.node_label} {{name: $dest_name, user_id: $user_id}}) DELETE r RETURN n.name AS source, @@ -323,7 +335,11 @@ class MemoryGraph: # types source_type = entity_type_map.get(source, "__User__") + source_label = self.node_label if self.node_label else f":`{source_type}`" + source_extra_set = f", source:`{source_type}`" if self.node_label else "" destination_type = entity_type_map.get(destination, "__User__") + destination_label = self.node_label if self.node_label else f":`{destination_type}`" + destination_extra_set = f", destination:`{destination_type}`" if self.node_label else "" # embeddings source_embedding = self.embedding_model.embed(source) @@ -340,10 +356,11 @@ class MemoryGraph: WHERE elementId(source) = $source_id SET source.mentions = coalesce(source.mentions, 0) + 1 WITH source - MERGE (destination:{destination_type} {{name: $destination_name, user_id: $user_id}}) + MERGE (destination {destination_label} {{name: $destination_name, user_id: $user_id}}) ON CREATE SET destination.created = timestamp(), destination.mentions = 1 + {destination_extra_set} ON MATCH SET destination.mentions = coalesce(destination.mentions, 0) + 1 WITH source, destination @@ -370,10 +387,11 @@ class MemoryGraph: WHERE elementId(destination) = $destination_id SET destination.mentions = coalesce(destination.mentions, 0) + 1 WITH destination - MERGE (source:{source_type} {{name: $source_name, user_id: $user_id}}) + MERGE (source {source_label} {{name: $source_name, user_id: $user_id}}) ON CREATE SET source.created = timestamp(), source.mentions = 1 + {source_extra_set} ON MATCH SET source.mentions = coalesce(source.mentions, 0) + 1 WITH source, destination @@ -420,24 +438,26 @@ class MemoryGraph: } else: cypher = f""" - MERGE (n:{source_type} {{name: $source_name, user_id: $user_id}}) - ON CREATE SET n.created = timestamp(), - n.mentions = 1 - ON MATCH SET n.mentions = coalesce(n.mentions, 0) + 1 - WITH n - CALL db.create.setNodeVectorProperty(n, 'embedding', $source_embedding) - WITH n - MERGE (m:{destination_type} {{name: $dest_name, user_id: $user_id}}) - ON CREATE SET m.created = timestamp(), - m.mentions = 1 - ON MATCH SET m.mentions = coalesce(m.mentions, 0) + 1 - WITH n, m - CALL db.create.setNodeVectorProperty(m, 'embedding', $source_embedding) - WITH n, m - MERGE (n)-[rel:{relationship}]->(m) + MERGE (source {source_label} {{name: $source_name, user_id: $user_id}}) + ON CREATE SET source.created = timestamp(), + source.mentions = 1 + {source_extra_set} + ON MATCH SET source.mentions = coalesce(source.mentions, 0) + 1 + WITH source + CALL db.create.setNodeVectorProperty(source, 'embedding', $source_embedding) + WITH source + MERGE (destination {destination_label} {{name: $dest_name, user_id: $user_id}}) + ON CREATE SET destination.created = timestamp(), + destination.mentions = 1 + {destination_extra_set} + ON MATCH SET destination.mentions = coalesce(destination.mentions, 0) + 1 + WITH source, destination + CALL db.create.setNodeVectorProperty(destination, 'embedding', $source_embedding) + WITH source, destination + MERGE (source)-[rel:{relationship}]->(destination) ON CREATE SET rel.created = timestamp(), rel.mentions = 1 ON MATCH SET rel.mentions = coalesce(rel.mentions, 0) + 1 - RETURN n.name AS source, type(rel) AS relationship, m.name AS target + RETURN source.name AS source, type(rel) AS relationship, destination.name AS target """ params = { "source_name": source, @@ -458,8 +478,8 @@ class MemoryGraph: return entity_list def _search_source_node(self, source_embedding, user_id, threshold=0.9): - cypher = """ - MATCH (source_candidate) + cypher = f""" + MATCH (source_candidate {self.node_label}) WHERE source_candidate.embedding IS NOT NULL AND source_candidate.user_id = $user_id @@ -484,8 +504,8 @@ class MemoryGraph: return result def _search_destination_node(self, destination_embedding, user_id, threshold=0.9): - cypher = """ - MATCH (destination_candidate) + cypher = f""" + MATCH (destination_candidate {self.node_label}) WHERE destination_candidate.embedding IS NOT NULL AND destination_candidate.user_id = $user_id