[improvement]: Graph nodes extraction improved (#2035)
This commit is contained in:
@@ -15,16 +15,16 @@ except ImportError:
|
||||
from mem0.graphs.tools import (
|
||||
ADD_MEMORY_STRUCT_TOOL_GRAPH,
|
||||
ADD_MEMORY_TOOL_GRAPH,
|
||||
ADD_MESSAGE_STRUCT_TOOL,
|
||||
ADD_MESSAGE_TOOL,
|
||||
EXTRACT_ENTITIES_STRUCT_TOOL,
|
||||
EXTRACT_ENTITIES_TOOL,
|
||||
NOOP_STRUCT_TOOL,
|
||||
NOOP_TOOL,
|
||||
SEARCH_STRUCT_TOOL,
|
||||
SEARCH_TOOL,
|
||||
RELATIONS_STRUCT_TOOL,
|
||||
RELATIONS_TOOL,
|
||||
UPDATE_MEMORY_STRUCT_TOOL_GRAPH,
|
||||
UPDATE_MEMORY_TOOL_GRAPH,
|
||||
)
|
||||
from mem0.graphs.utils import EXTRACT_ENTITIES_PROMPT, get_update_memory_messages
|
||||
from mem0.graphs.utils import EXTRACT_RELATIONS_PROMPT, get_update_memory_messages
|
||||
from mem0.utils.factory import EmbedderFactory, LlmFactory
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -60,44 +60,14 @@ class MemoryGraph:
|
||||
"""
|
||||
|
||||
# retrieve the search results
|
||||
search_output = self._search(data, filters)
|
||||
search_output, entity_type_map = self._search(data, filters)
|
||||
|
||||
if self.config.graph_store.custom_prompt:
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": EXTRACT_ENTITIES_PROMPT.replace("USER_ID", self.user_id).replace(
|
||||
"CUSTOM_PROMPT", f"4. {self.config.graph_store.custom_prompt}"
|
||||
),
|
||||
},
|
||||
{"role": "user", "content": data},
|
||||
]
|
||||
else:
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": EXTRACT_ENTITIES_PROMPT.replace("USER_ID", self.user_id),
|
||||
},
|
||||
{"role": "user", "content": data},
|
||||
]
|
||||
|
||||
_tools = [ADD_MESSAGE_TOOL]
|
||||
if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
|
||||
_tools = [ADD_MESSAGE_STRUCT_TOOL]
|
||||
|
||||
extracted_entities = self.llm.generate_response(
|
||||
messages=messages,
|
||||
tools=_tools,
|
||||
)
|
||||
|
||||
if extracted_entities["tool_calls"]:
|
||||
extracted_entities = extracted_entities["tool_calls"][0]["arguments"]["entities"]
|
||||
else:
|
||||
extracted_entities = []
|
||||
|
||||
logger.debug(f"Extracted entities: {extracted_entities}")
|
||||
# extract relations
|
||||
extracted_relations = self._extract_relations(data, filters, entity_type_map)
|
||||
|
||||
search_output_string = format_entities(search_output)
|
||||
update_memory_prompt = get_update_memory_messages(search_output_string, extracted_entities)
|
||||
extracted_relations_string = format_entities(extracted_relations)
|
||||
update_memory_prompt = get_update_memory_messages(search_output_string, extracted_relations_string)
|
||||
|
||||
_tools = [UPDATE_MEMORY_TOOL_GRAPH, ADD_MEMORY_TOOL_GRAPH, NOOP_TOOL]
|
||||
if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
|
||||
@@ -170,37 +140,33 @@ class MemoryGraph:
|
||||
return returned_entities
|
||||
|
||||
def _search(self, query, filters, limit=100):
|
||||
_tools = [SEARCH_TOOL]
|
||||
_tools = [EXTRACT_ENTITIES_TOOL]
|
||||
if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
|
||||
_tools = [SEARCH_STRUCT_TOOL]
|
||||
_tools = [EXTRACT_ENTITIES_STRUCT_TOOL]
|
||||
search_results = self.llm.generate_response(
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": f"You are a smart assistant who understands the entities, their types, and relations in a given text. If user message contains self reference such as 'I', 'me', 'my' etc. then use {filters['user_id']} as the source node. Extract the entities. ***DO NOT*** answer the question itself if the given text is a question.",
|
||||
"content": f"You are a smart assistant who understands entities and their types in a given text. If user message contains self reference such as 'I', 'me', 'my' etc. then use {filters['user_id']} as the source entity. Extract all the entities from the text. ***DO NOT*** answer the question itself if the given text is a question.",
|
||||
},
|
||||
{"role": "user", "content": query},
|
||||
],
|
||||
tools=_tools,
|
||||
)
|
||||
|
||||
node_list = []
|
||||
entity_type_map = {}
|
||||
|
||||
for item in search_results["tool_calls"]:
|
||||
if item["name"] == "search":
|
||||
try:
|
||||
node_list.extend(item["arguments"]["nodes"])
|
||||
except Exception as e:
|
||||
logger.error(f"Error in search tool: {e}")
|
||||
try:
|
||||
for item in search_results["tool_calls"][0]["arguments"]["entities"]:
|
||||
entity_type_map[item["entity"]] = item["entity_type"]
|
||||
except Exception as e:
|
||||
logger.error(f"Error in search tool: {e}")
|
||||
|
||||
node_list = list(set(node_list))
|
||||
node_list = [node.lower().replace(" ", "_") for node in node_list]
|
||||
|
||||
logger.debug(f"Node list for search query : {node_list}")
|
||||
logger.debug(f"Entity type map: {entity_type_map}")
|
||||
|
||||
result_relations = []
|
||||
|
||||
for node in node_list:
|
||||
for node in list(entity_type_map.keys()):
|
||||
n_embedding = self.embedding_model.embed(node)
|
||||
|
||||
cypher_query = """
|
||||
@@ -235,7 +201,7 @@ class MemoryGraph:
|
||||
ans = self.graph.query(cypher_query, params=params)
|
||||
result_relations.extend(ans)
|
||||
|
||||
return result_relations
|
||||
return result_relations, entity_type_map
|
||||
|
||||
def search(self, query, filters, limit=100):
|
||||
"""
|
||||
@@ -252,7 +218,7 @@ class MemoryGraph:
|
||||
- "entities": List of related graph data based on the query.
|
||||
"""
|
||||
|
||||
search_output = self._search(query, filters, limit)
|
||||
search_output, entity_type_map = self._search(query, filters, limit)
|
||||
|
||||
if not search_output:
|
||||
return []
|
||||
@@ -314,6 +280,45 @@ class MemoryGraph:
|
||||
|
||||
return final_results
|
||||
|
||||
def _extract_relations(self, data, filters, entity_type_map, limit=100):
|
||||
|
||||
if self.config.graph_store.custom_prompt:
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": EXTRACT_RELATIONS_PROMPT.replace("USER_ID", filters["user_id"]).replace(
|
||||
"CUSTOM_PROMPT", f"4. {self.config.graph_store.custom_prompt}"
|
||||
),
|
||||
},
|
||||
{"role": "user", "content": data},
|
||||
]
|
||||
else:
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": EXTRACT_RELATIONS_PROMPT.replace("USER_ID", filters["user_id"]),
|
||||
},
|
||||
{"role": "user", "content": f"List of entities: {list(entity_type_map.keys())}. \n\nText: {data}"},
|
||||
]
|
||||
|
||||
_tools = [RELATIONS_TOOL]
|
||||
if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
|
||||
_tools = [RELATIONS_STRUCT_TOOL]
|
||||
|
||||
extracted_entities = self.llm.generate_response(
|
||||
messages=messages,
|
||||
tools=_tools,
|
||||
)
|
||||
|
||||
if extracted_entities["tool_calls"]:
|
||||
extracted_entities = extracted_entities["tool_calls"][0]["arguments"]["entities"]
|
||||
else:
|
||||
extracted_entities = []
|
||||
|
||||
logger.debug(f"Extracted entities: {extracted_entities}")
|
||||
|
||||
return extracted_entities
|
||||
|
||||
def _update_relationship(self, source, target, relationship, filters):
|
||||
"""
|
||||
Update or create a relationship between two nodes in the graph.
|
||||
|
||||
@@ -24,11 +24,7 @@ def format_entities(entities):
|
||||
|
||||
formatted_lines = []
|
||||
for entity in entities:
|
||||
simplified = {
|
||||
"source": entity["source"],
|
||||
"relation": entity["relation"],
|
||||
"destination": entity["destination"]
|
||||
}
|
||||
formatted_lines.append(json.dumps(simplified))
|
||||
simplified = f"{entity['source']} -- {entity['relation'].upper()} -- {entity['destination']}"
|
||||
formatted_lines.append(simplified)
|
||||
|
||||
return "\n".join(formatted_lines)
|
||||
Reference in New Issue
Block a user