From 65056311a62249a7c82a9aaeb3ce01a60606b764 Mon Sep 17 00:00:00 2001 From: Prateek Chhikara <46902268+prateekchhikara@users.noreply.github.com> Date: Tue, 3 Sep 2024 09:47:35 -0700 Subject: [PATCH] Added user_id support for graph memory --- mem0/memory/main.py | 8 +++--- mem0/memory/main_graph.py | 56 ++++++++++++++++++++------------------- 2 files changed, 33 insertions(+), 31 deletions(-) diff --git a/mem0/memory/main.py b/mem0/memory/main.py index 53eae320..d0ec2279 100644 --- a/mem0/memory/main.py +++ b/mem0/memory/main.py @@ -177,7 +177,7 @@ class Memory(MemoryBase): self.graph.user_id = user_id else: self.graph.user_id = "USER" - added_entities = self.graph.add(data) + added_entities = self.graph.add(data, filters) return {"message": "ok"} @@ -281,7 +281,7 @@ class Memory(MemoryBase): if self.version == "v1.1": if self.enable_graph: - graph_entities = self.graph.get_all() + graph_entities = self.graph.get_all(filters) return {"memories": all_memories, "entities": graph_entities} else: return {"memories" : all_memories} @@ -374,7 +374,7 @@ class Memory(MemoryBase): if self.version == "v1.1": if self.enable_graph: - graph_entities = self.graph.search(query) + graph_entities = self.graph.search(query, filters) return {"memories": original_memories, "entities": graph_entities} else: return {"memories" : original_memories} @@ -442,7 +442,7 @@ class Memory(MemoryBase): self._delete_memory_tool(memory.id) if self.version == "v1.1" and self.enable_graph: - self.graph.delete_all() + self.graph.delete_all(filters) return {'message': 'Memories deleted successfully!'} diff --git a/mem0/memory/main_graph.py b/mem0/memory/main_graph.py index 8b84ab21..a4ecfc81 100644 --- a/mem0/memory/main_graph.py +++ b/mem0/memory/main_graph.py @@ -25,7 +25,7 @@ class MemoryGraph: self.user_id = None self.threshold = 0.7 - def add(self, data): + def add(self, data, filters): """ Adds data to the graph. @@ -38,7 +38,7 @@ class MemoryGraph: """ # retrieve the search results - search_output = self._search(data) + search_output = self._search(data, filters) if self.config.graph_store.custom_prompt: messages=[ @@ -74,7 +74,7 @@ class MemoryGraph: if item['name'] == "add_graph_memory": to_be_added.append(item['arguments']) elif item['name'] == "update_graph_memory": - self._update_relationship(item['arguments']['source'], item['arguments']['destination'], item['arguments']['relationship']) + self._update_relationship(item['arguments']['source'], item['arguments']['destination'], item['arguments']['relationship'], filters) elif item['name'] == "noop": continue @@ -91,10 +91,10 @@ class MemoryGraph: # Updated Cypher query to include node types and embeddings cypher = f""" - MERGE (n:{source_type} {{name: $source_name}}) + MERGE (n:{source_type} {{name: $source_name, user_id: $user_id}}) ON CREATE SET n.created = timestamp(), n.embedding = $source_embedding ON MATCH SET n.embedding = $source_embedding - MERGE (m:{destination_type} {{name: $dest_name}}) + MERGE (m:{destination_type} {{name: $dest_name, user_id: $user_id}}) ON CREATE SET m.created = timestamp(), m.embedding = $dest_embedding ON MATCH SET m.embedding = $dest_embedding MERGE (n)-[rel:{relation}]->(m) @@ -106,16 +106,17 @@ class MemoryGraph: "source_name": source, "dest_name": destination, "source_embedding": source_embedding, - "dest_embedding": dest_embedding + "dest_embedding": dest_embedding, + "user_id": filters["user_id"] } _ = self.graph.query(cypher, params=params) - def _search(self, query): + def _search(self, query, filters): search_results = self.llm.generate_response( messages=[ - {"role": "system", "content": f"You are a smart assistant who understands the entities, their types, and relations in a given text. If user message contains self reference such as 'I', 'me', 'my' etc. then use {self.user_id} as the source node. Extract the entities."}, + {"role": "system", "content": f"You are a smart assistant who understands the entities, their types, and relations in a given text. If user message contains self reference such as 'I', 'me', 'my' etc. then use {filters['user_id']} as the source node. Extract the entities."}, {"role": "user", "content": query}, ], tools = [SEARCH_TOOL] @@ -142,7 +143,7 @@ class MemoryGraph: cypher_query = """ MATCH (n) - WHERE n.embedding IS NOT NULL + WHERE n.embedding IS NOT NULL AND n.user_id = $user_id WITH n, round(reduce(dot = 0.0, i IN range(0, size(n.embedding)-1) | dot + n.embedding[i] * $n_embedding[i]) / (sqrt(reduce(l2 = 0.0, i IN range(0, size(n.embedding)-1) | l2 + n.embedding[i] * n.embedding[i])) * @@ -152,7 +153,7 @@ class MemoryGraph: RETURN n.name AS source, elementId(n) AS source_id, type(r) AS relation, elementId(r) AS relation_id, m.name AS destination, elementId(m) AS destination_id, similarity UNION MATCH (n) - WHERE n.embedding IS NOT NULL + WHERE n.embedding IS NOT NULL AND n.user_id = $user_id WITH n, round(reduce(dot = 0.0, i IN range(0, size(n.embedding)-1) | dot + n.embedding[i] * $n_embedding[i]) / (sqrt(reduce(l2 = 0.0, i IN range(0, size(n.embedding)-1) | l2 + n.embedding[i] * n.embedding[i])) * @@ -162,14 +163,14 @@ class MemoryGraph: RETURN m.name AS source, elementId(m) AS source_id, type(r) AS relation, elementId(r) AS relation_id, n.name AS destination, elementId(n) AS destination_id, similarity ORDER BY similarity DESC """ - params = {"n_embedding": n_embedding, "threshold": self.threshold} + params = {"n_embedding": n_embedding, "threshold": self.threshold, "user_id": filters["user_id"]} ans = self.graph.query(cypher_query, params=params) result_relations.extend(ans) return result_relations - def search(self, query): + def search(self, query, filters): """ Search for memories and related graph data. @@ -182,7 +183,7 @@ class MemoryGraph: - "entities": List of related graph data based on the query. """ - search_output = self._search(query) + search_output = self._search(query, filters) if not search_output: return [] @@ -204,15 +205,16 @@ class MemoryGraph: return search_results - def delete_all(self): + def delete_all(self, filters): cypher = """ - MATCH (n) + MATCH (n {user_id: $user_id}) DETACH DELETE n """ - self.graph.query(cypher) + params = {"user_id": filters["user_id"]} + self.graph.query(cypher, params=params) - def get_all(self): + def get_all(self, filters): """ Retrieves all nodes and relationships from the graph database based on optional filtering criteria. @@ -226,10 +228,10 @@ class MemoryGraph: # return all nodes and relationships query = """ - MATCH (n)-[r]->(m) + MATCH (n {user_id: $user_id})-[r]->(m {user_id: $user_id}) RETURN n.name AS source, type(r) AS relationship, m.name AS target """ - results = self.graph.query(query) + results = self.graph.query(query, params={"user_id": filters["user_id"]}) final_results = [] for result in results: @@ -242,7 +244,7 @@ class MemoryGraph: return final_results - def _update_relationship(self, source, target, relationship): + def _update_relationship(self, source, target, relationship, filters): """ Update or create a relationship between two nodes in the graph. @@ -258,25 +260,25 @@ class MemoryGraph: # Check if nodes exist and create them if they don't check_and_create_query = """ - MERGE (n1 {name: $source}) - MERGE (n2 {name: $target}) + MERGE (n1 {name: $source, user_id: $user_id}) + MERGE (n2 {name: $target, user_id: $user_id}) """ - self.graph.query(check_and_create_query, params={"source": source, "target": target}) + self.graph.query(check_and_create_query, params={"source": source, "target": target, "user_id": filters["user_id"]}) # Delete any existing relationship between the nodes delete_query = """ - MATCH (n1 {name: $source})-[r]->(n2 {name: $target}) + MATCH (n1 {name: $source, user_id: $user_id})-[r]->(n2 {name: $target, user_id: $user_id}) DELETE r """ - self.graph.query(delete_query, params={"source": source, "target": target}) + self.graph.query(delete_query, params={"source": source, "target": target, "user_id": filters["user_id"]}) # Create the new relationship create_query = f""" - MATCH (n1 {{name: $source}}), (n2 {{name: $target}}) + MATCH (n1 {{name: $source, user_id: $user_id}}), (n2 {{name: $target, user_id: $user_id}}) CREATE (n1)-[r:{relationship}]->(n2) RETURN n1, r, n2 """ - result = self.graph.query(create_query, params={"source": source, "target": target}) + result = self.graph.query(create_query, params={"source": source, "target": target, "user_id": filters["user_id"]}) if not result: raise Exception(f"Failed to update or create relationship between {source} and {target}")