diff --git a/mem0/graphs/tools.py b/mem0/graphs/tools.py index 1fdbe91f..90268123 100644 --- a/mem0/graphs/tools.py +++ b/mem0/graphs/tools.py @@ -82,11 +82,11 @@ NOOP_TOOL = { } -ADD_MESSAGE_TOOL = { +RELATIONS_TOOL = { "type": "function", "function": { - "name": "add_query", - "description": "Add new entities and relationships to the graph based on the provided query.", + "name": "establish_relations", + "description": "Establish relationships among the entities based on the provided text.", "parameters": { "type": "object", "properties": { @@ -95,18 +95,23 @@ ADD_MESSAGE_TOOL = { "items": { "type": "object", "properties": { - "source_node": {"type": "string"}, - "source_type": {"type": "string"}, - "relation": {"type": "string"}, - "destination_node": {"type": "string"}, - "destination_type": {"type": "string"}, + "source": { + "type": "string", + "description": "The source entity of the relationship." + }, + "relation": { + "type": "string", + "description": "The relationship between the source and destination entities." + }, + "destination": { + "type": "string", + "description": "The destination entity of the relationship." + }, }, - "required": [ - "source_node", - "source_type", + "required": [ + "source_entity", "relation", - "destination_node", - "destination_type", + "destination_entity", ], "additionalProperties": False, }, @@ -119,29 +124,38 @@ ADD_MESSAGE_TOOL = { } -SEARCH_TOOL = { +EXTRACT_ENTITIES_TOOL = { "type": "function", "function": { - "name": "search", - "description": "Search for nodes and relations in the graph.", + "name": "extract_entities", + "description": "Extract entities and their types from the text.", "parameters": { "type": "object", "properties": { - "nodes": { + "entities": { "type": "array", - "items": {"type": "string"}, - "description": "List of nodes to search for.", - }, - "relations": { - "type": "array", - "items": {"type": "string"}, - "description": "List of relations to search for.", - }, + "items": { + "type": "object", + "properties": { + "entity": { + "type": "string", + "description": "The name or identifier of the entity." + }, + "entity_type": { + "type": "string", + "description": "The type or category of the entity." + } + }, + "required": ["entity", "entity_type"], + "additionalProperties": False + }, + "description": "An array of entities with their types." + } }, - "required": ["nodes", "relations"], - "additionalProperties": False, - }, - }, + "required": ["entities"], + "additionalProperties": False + } + } } UPDATE_MEMORY_STRUCT_TOOL_GRAPH = { @@ -230,12 +244,11 @@ NOOP_STRUCT_TOOL = { }, } - -ADD_MESSAGE_STRUCT_TOOL = { +RELATIONS_STRUCT_TOOL = { "type": "function", "function": { - "name": "add_query", - "description": "Add new entities and relationships to the graph based on the provided query.", + "name": "establish_relations", + "description": "Establish relationships among the entities based on the provided text.", "strict": True, "parameters": { "type": "object", @@ -245,18 +258,23 @@ ADD_MESSAGE_STRUCT_TOOL = { "items": { "type": "object", "properties": { - "source_node": {"type": "string"}, - "source_type": {"type": "string"}, - "relation": {"type": "string"}, - "destination_node": {"type": "string"}, - "destination_type": {"type": "string"}, + "source_entity": { + "type": "string", + "description": "The source entity of the relationship." + }, + "relation": { + "type": "string", + "description": "The relationship between the source and destination entities." + }, + "destination_entity": { + "type": "string", + "description": "The destination entity of the relationship." + }, }, - "required": [ - "source_node", - "source_type", + "required": [ + "source_entity", "relation", - "destination_node", - "destination_type", + "destination_entity", ], "additionalProperties": False, }, @@ -269,28 +287,37 @@ ADD_MESSAGE_STRUCT_TOOL = { } -SEARCH_STRUCT_TOOL = { +EXTRACT_ENTITIES_STRUCT_TOOL = { "type": "function", "function": { - "name": "search", - "description": "Search for nodes and relations in the graph.", + "name": "extract_entities", + "description": "Extract entities and their types from the text.", "strict": True, "parameters": { "type": "object", "properties": { - "nodes": { + "entities": { "type": "array", - "items": {"type": "string"}, - "description": "List of nodes to search for.", - }, - "relations": { - "type": "array", - "items": {"type": "string"}, - "description": "List of relations to search for.", - }, + "items": { + "type": "object", + "properties": { + "entity": { + "type": "string", + "description": "The name or identifier of the entity." + }, + "entity_type": { + "type": "string", + "description": "The type or category of the entity." + } + }, + "required": ["entity", "entity_type"], + "additionalProperties": False + }, + "description": "An array of entities with their types." + } }, - "required": ["nodes", "relations"], - "additionalProperties": False, - }, - }, + "required": ["entities"], + "additionalProperties": False + } + } } diff --git a/mem0/graphs/utils.py b/mem0/graphs/utils.py index efc14db0..4bfc431b 100644 --- a/mem0/graphs/utils.py +++ b/mem0/graphs/utils.py @@ -18,50 +18,51 @@ Guidelines: 7. Relationship Refinement: Look for opportunities to refine relationship descriptions for greater precision or clarity. 8. Redundancy Elimination: Identify and merge any redundant or highly similar relationships that may result from the update. +Memory Format: +source -- RELATIONSHIP -- destination + Task Details: -- Existing Graph Memories: +======= Existing Graph Memories:======= {existing_memories} -- New Graph Memory: {memory} +======= New Graph Memory:======= +{new_memories} Output: Provide a list of update instructions, each specifying the source, target, and the new relationship to be set. Only include memories that require updates. """ -EXTRACT_ENTITIES_PROMPT = """ +EXTRACT_RELATIONS_PROMPT = """ -You are an advanced algorithm designed to extract structured information from text to construct knowledge graphs. Your goal is to capture comprehensive information while maintaining accuracy. Follow these key principles: +You are an advanced algorithm designed to extract structured information from text to construct knowledge graphs. Your goal is to capture comprehensive and accurate information. Follow these key principles: 1. Extract only explicitly stated information from the text. -2. Identify nodes (entities/concepts), their types, and relationships. -3. Use "USER_ID" as the source node for any self-references (I, me, my, etc.) in user messages. +2. Establish relationships among the entities provided. +3. Use "USER_ID" as the source entity for any self-references (e.g., "I," "me," "my," etc.) in user messages. CUSTOM_PROMPT -Nodes and Types: -- Aim for simplicity and clarity in node representation. -- Use basic, general types for node labels (e.g. "person" instead of "mathematician"). - Relationships: -- Use consistent, general, and timeless relationship types. -- Example: Prefer "PROFESSOR" over "BECAME_PROFESSOR". + - Use consistent, general, and timeless relationship types. + - Example: Prefer "PROFESSOR" over "BECAME_PROFESSOR." + - Relationships should only be established among the entities explicitly mentioned in the user message. Entity Consistency: -- Use the most complete identifier for entities mentioned multiple times. -- Example: Always use "John Doe" instead of variations like "Joe" or pronouns. + - Ensure that relationships are coherent and logically align with the context of the message. + - Maintain consistent naming for entities across the extracted data. -Strive for a coherent, easily understandable knowledge graph by maintaining consistency in entity references and relationship types. +Strive to construct a coherent and easily understandable knowledge graph by eshtablishing all the relationships among the entities and adherence to the user’s context. Adhere strictly to these guidelines to ensure high-quality knowledge graph extraction.""" -def get_update_memory_prompt(existing_memories, memory, template): - return template.format(existing_memories=existing_memories, memory=memory) +def get_update_memory_prompt(existing_memories, new_memories, template): + return template.format(existing_memories=existing_memories, new_memories=new_memories) -def get_update_memory_messages(existing_memories, memory): +def get_update_memory_messages(existing_memories, new_memories): return [ { "role": "user", - "content": get_update_memory_prompt(existing_memories, memory, UPDATE_GRAPH_PROMPT), + "content": get_update_memory_prompt(existing_memories, new_memories, UPDATE_GRAPH_PROMPT), }, ] diff --git a/mem0/memory/graph_memory.py b/mem0/memory/graph_memory.py index 5fad61a4..e3c65b00 100644 --- a/mem0/memory/graph_memory.py +++ b/mem0/memory/graph_memory.py @@ -15,16 +15,16 @@ except ImportError: from mem0.graphs.tools import ( ADD_MEMORY_STRUCT_TOOL_GRAPH, ADD_MEMORY_TOOL_GRAPH, - ADD_MESSAGE_STRUCT_TOOL, - ADD_MESSAGE_TOOL, + EXTRACT_ENTITIES_STRUCT_TOOL, + EXTRACT_ENTITIES_TOOL, NOOP_STRUCT_TOOL, NOOP_TOOL, - SEARCH_STRUCT_TOOL, - SEARCH_TOOL, + RELATIONS_STRUCT_TOOL, + RELATIONS_TOOL, UPDATE_MEMORY_STRUCT_TOOL_GRAPH, UPDATE_MEMORY_TOOL_GRAPH, ) -from mem0.graphs.utils import EXTRACT_ENTITIES_PROMPT, get_update_memory_messages +from mem0.graphs.utils import EXTRACT_RELATIONS_PROMPT, get_update_memory_messages from mem0.utils.factory import EmbedderFactory, LlmFactory logger = logging.getLogger(__name__) @@ -60,44 +60,14 @@ class MemoryGraph: """ # retrieve the search results - search_output = self._search(data, filters) + search_output, entity_type_map = self._search(data, filters) - if self.config.graph_store.custom_prompt: - messages = [ - { - "role": "system", - "content": EXTRACT_ENTITIES_PROMPT.replace("USER_ID", self.user_id).replace( - "CUSTOM_PROMPT", f"4. {self.config.graph_store.custom_prompt}" - ), - }, - {"role": "user", "content": data}, - ] - else: - messages = [ - { - "role": "system", - "content": EXTRACT_ENTITIES_PROMPT.replace("USER_ID", self.user_id), - }, - {"role": "user", "content": data}, - ] - - _tools = [ADD_MESSAGE_TOOL] - if self.llm_provider in ["azure_openai_structured", "openai_structured"]: - _tools = [ADD_MESSAGE_STRUCT_TOOL] - - extracted_entities = self.llm.generate_response( - messages=messages, - tools=_tools, - ) - - if extracted_entities["tool_calls"]: - extracted_entities = extracted_entities["tool_calls"][0]["arguments"]["entities"] - else: - extracted_entities = [] - - logger.debug(f"Extracted entities: {extracted_entities}") + # extract relations + extracted_relations = self._extract_relations(data, filters, entity_type_map) + search_output_string = format_entities(search_output) - update_memory_prompt = get_update_memory_messages(search_output_string, extracted_entities) + extracted_relations_string = format_entities(extracted_relations) + update_memory_prompt = get_update_memory_messages(search_output_string, extracted_relations_string) _tools = [UPDATE_MEMORY_TOOL_GRAPH, ADD_MEMORY_TOOL_GRAPH, NOOP_TOOL] if self.llm_provider in ["azure_openai_structured", "openai_structured"]: @@ -170,37 +140,33 @@ class MemoryGraph: return returned_entities def _search(self, query, filters, limit=100): - _tools = [SEARCH_TOOL] + _tools = [EXTRACT_ENTITIES_TOOL] if self.llm_provider in ["azure_openai_structured", "openai_structured"]: - _tools = [SEARCH_STRUCT_TOOL] + _tools = [EXTRACT_ENTITIES_STRUCT_TOOL] search_results = self.llm.generate_response( messages=[ { "role": "system", - "content": f"You are a smart assistant who understands the entities, their types, and relations in a given text. If user message contains self reference such as 'I', 'me', 'my' etc. then use {filters['user_id']} as the source node. Extract the entities. ***DO NOT*** answer the question itself if the given text is a question.", + "content": f"You are a smart assistant who understands entities and their types in a given text. If user message contains self reference such as 'I', 'me', 'my' etc. then use {filters['user_id']} as the source entity. Extract all the entities from the text. ***DO NOT*** answer the question itself if the given text is a question.", }, {"role": "user", "content": query}, ], tools=_tools, ) - node_list = [] + entity_type_map = {} - for item in search_results["tool_calls"]: - if item["name"] == "search": - try: - node_list.extend(item["arguments"]["nodes"]) - except Exception as e: - logger.error(f"Error in search tool: {e}") + try: + for item in search_results["tool_calls"][0]["arguments"]["entities"]: + entity_type_map[item["entity"]] = item["entity_type"] + except Exception as e: + logger.error(f"Error in search tool: {e}") - node_list = list(set(node_list)) - node_list = [node.lower().replace(" ", "_") for node in node_list] - - logger.debug(f"Node list for search query : {node_list}") + logger.debug(f"Entity type map: {entity_type_map}") result_relations = [] - for node in node_list: + for node in list(entity_type_map.keys()): n_embedding = self.embedding_model.embed(node) cypher_query = """ @@ -235,7 +201,7 @@ class MemoryGraph: ans = self.graph.query(cypher_query, params=params) result_relations.extend(ans) - return result_relations + return result_relations, entity_type_map def search(self, query, filters, limit=100): """ @@ -252,7 +218,7 @@ class MemoryGraph: - "entities": List of related graph data based on the query. """ - search_output = self._search(query, filters, limit) + search_output, entity_type_map = self._search(query, filters, limit) if not search_output: return [] @@ -314,6 +280,45 @@ class MemoryGraph: return final_results + def _extract_relations(self, data, filters, entity_type_map, limit=100): + + if self.config.graph_store.custom_prompt: + messages = [ + { + "role": "system", + "content": EXTRACT_RELATIONS_PROMPT.replace("USER_ID", filters["user_id"]).replace( + "CUSTOM_PROMPT", f"4. {self.config.graph_store.custom_prompt}" + ), + }, + {"role": "user", "content": data}, + ] + else: + messages = [ + { + "role": "system", + "content": EXTRACT_RELATIONS_PROMPT.replace("USER_ID", filters["user_id"]), + }, + {"role": "user", "content": f"List of entities: {list(entity_type_map.keys())}. \n\nText: {data}"}, + ] + + _tools = [RELATIONS_TOOL] + if self.llm_provider in ["azure_openai_structured", "openai_structured"]: + _tools = [RELATIONS_STRUCT_TOOL] + + extracted_entities = self.llm.generate_response( + messages=messages, + tools=_tools, + ) + + if extracted_entities["tool_calls"]: + extracted_entities = extracted_entities["tool_calls"][0]["arguments"]["entities"] + else: + extracted_entities = [] + + logger.debug(f"Extracted entities: {extracted_entities}") + + return extracted_entities + def _update_relationship(self, source, target, relationship, filters): """ Update or create a relationship between two nodes in the graph. diff --git a/mem0/memory/utils.py b/mem0/memory/utils.py index 5b0a2a1c..f889abe9 100644 --- a/mem0/memory/utils.py +++ b/mem0/memory/utils.py @@ -24,11 +24,7 @@ def format_entities(entities): formatted_lines = [] for entity in entities: - simplified = { - "source": entity["source"], - "relation": entity["relation"], - "destination": entity["destination"] - } - formatted_lines.append(json.dumps(simplified)) + simplified = f"{entity['source']} -- {entity['relation'].upper()} -- {entity['destination']}" + formatted_lines.append(simplified) return "\n".join(formatted_lines) \ No newline at end of file