[Misc] Lint code and fix code smells (#1871)
This commit is contained in:
@@ -3,30 +3,28 @@ import logging
|
||||
from langchain_community.graphs import Neo4jGraph
|
||||
from rank_bm25 import BM25Okapi
|
||||
|
||||
from mem0.graphs.tools import (
|
||||
ADD_MEMORY_TOOL_GRAPH,
|
||||
ADD_MESSAGE_TOOL,
|
||||
NOOP_TOOL,
|
||||
SEARCH_TOOL,
|
||||
UPDATE_MEMORY_TOOL_GRAPH,
|
||||
UPDATE_MEMORY_STRUCT_TOOL_GRAPH,
|
||||
ADD_MEMORY_STRUCT_TOOL_GRAPH,
|
||||
NOOP_STRUCT_TOOL,
|
||||
ADD_MESSAGE_STRUCT_TOOL,
|
||||
SEARCH_STRUCT_TOOL
|
||||
)
|
||||
from mem0.graphs.utils import EXTRACT_ENTITIES_PROMPT, get_update_memory_messages
|
||||
from mem0.graphs.tools import (ADD_MEMORY_STRUCT_TOOL_GRAPH,
|
||||
ADD_MEMORY_TOOL_GRAPH, ADD_MESSAGE_STRUCT_TOOL,
|
||||
ADD_MESSAGE_TOOL, NOOP_STRUCT_TOOL, NOOP_TOOL,
|
||||
SEARCH_STRUCT_TOOL, SEARCH_TOOL,
|
||||
UPDATE_MEMORY_STRUCT_TOOL_GRAPH,
|
||||
UPDATE_MEMORY_TOOL_GRAPH)
|
||||
from mem0.graphs.utils import (EXTRACT_ENTITIES_PROMPT,
|
||||
get_update_memory_messages)
|
||||
from mem0.utils.factory import EmbedderFactory, LlmFactory
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MemoryGraph:
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.graph = Neo4jGraph(self.config.graph_store.config.url, self.config.graph_store.config.username, self.config.graph_store.config.password)
|
||||
self.embedding_model = EmbedderFactory.create(
|
||||
self.config.embedder.provider, self.config.embedder.config
|
||||
self.graph = Neo4jGraph(
|
||||
self.config.graph_store.config.url,
|
||||
self.config.graph_store.config.username,
|
||||
self.config.graph_store.config.password,
|
||||
)
|
||||
self.embedding_model = EmbedderFactory.create(self.config.embedder.provider, self.config.embedder.config)
|
||||
|
||||
self.llm_provider = "openai_structured"
|
||||
if self.config.llm.provider:
|
||||
@@ -51,15 +49,23 @@ class MemoryGraph:
|
||||
search_output = self._search(data, filters)
|
||||
|
||||
if self.config.graph_store.custom_prompt:
|
||||
messages=[
|
||||
{"role": "system", "content": EXTRACT_ENTITIES_PROMPT.replace("USER_ID", self.user_id).replace("CUSTOM_PROMPT", f"4. {self.config.graph_store.custom_prompt}")},
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": EXTRACT_ENTITIES_PROMPT.replace("USER_ID", self.user_id).replace(
|
||||
"CUSTOM_PROMPT", f"4. {self.config.graph_store.custom_prompt}"
|
||||
),
|
||||
},
|
||||
{"role": "user", "content": data},
|
||||
]
|
||||
else:
|
||||
messages=[
|
||||
{"role": "system", "content": EXTRACT_ENTITIES_PROMPT.replace("USER_ID", self.user_id)},
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": EXTRACT_ENTITIES_PROMPT.replace("USER_ID", self.user_id),
|
||||
},
|
||||
{"role": "user", "content": data},
|
||||
]
|
||||
]
|
||||
|
||||
_tools = [ADD_MESSAGE_TOOL]
|
||||
if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
|
||||
@@ -67,11 +73,11 @@ class MemoryGraph:
|
||||
|
||||
extracted_entities = self.llm.generate_response(
|
||||
messages=messages,
|
||||
tools = _tools,
|
||||
tools=_tools,
|
||||
)
|
||||
|
||||
if extracted_entities['tool_calls']:
|
||||
extracted_entities = extracted_entities['tool_calls'][0]['arguments']['entities']
|
||||
if extracted_entities["tool_calls"]:
|
||||
extracted_entities = extracted_entities["tool_calls"][0]["arguments"]["entities"]
|
||||
else:
|
||||
extracted_entities = []
|
||||
|
||||
@@ -79,9 +85,13 @@ class MemoryGraph:
|
||||
|
||||
update_memory_prompt = get_update_memory_messages(search_output, extracted_entities)
|
||||
|
||||
_tools=[UPDATE_MEMORY_TOOL_GRAPH, ADD_MEMORY_TOOL_GRAPH, NOOP_TOOL]
|
||||
if self.llm_provider in ["azure_openai_structured","openai_structured"]:
|
||||
_tools = [UPDATE_MEMORY_STRUCT_TOOL_GRAPH, ADD_MEMORY_STRUCT_TOOL_GRAPH, NOOP_STRUCT_TOOL]
|
||||
_tools = [UPDATE_MEMORY_TOOL_GRAPH, ADD_MEMORY_TOOL_GRAPH, NOOP_TOOL]
|
||||
if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
|
||||
_tools = [
|
||||
UPDATE_MEMORY_STRUCT_TOOL_GRAPH,
|
||||
ADD_MEMORY_STRUCT_TOOL_GRAPH,
|
||||
NOOP_STRUCT_TOOL,
|
||||
]
|
||||
|
||||
memory_updates = self.llm.generate_response(
|
||||
messages=update_memory_prompt,
|
||||
@@ -90,28 +100,29 @@ class MemoryGraph:
|
||||
|
||||
to_be_added = []
|
||||
|
||||
for item in memory_updates['tool_calls']:
|
||||
if item['name'] == "add_graph_memory":
|
||||
to_be_added.append(item['arguments'])
|
||||
elif item['name'] == "update_graph_memory":
|
||||
self._update_relationship(item['arguments']['source'], item['arguments']['destination'], item['arguments']['relationship'], filters)
|
||||
elif item['name'] == "noop":
|
||||
for item in memory_updates["tool_calls"]:
|
||||
if item["name"] == "add_graph_memory":
|
||||
to_be_added.append(item["arguments"])
|
||||
elif item["name"] == "update_graph_memory":
|
||||
self._update_relationship(
|
||||
item["arguments"]["source"],
|
||||
item["arguments"]["destination"],
|
||||
item["arguments"]["relationship"],
|
||||
filters,
|
||||
)
|
||||
elif item["name"] == "noop":
|
||||
continue
|
||||
|
||||
returned_entities = []
|
||||
|
||||
for item in to_be_added:
|
||||
source = item['source'].lower().replace(" ", "_")
|
||||
source_type = item['source_type'].lower().replace(" ", "_")
|
||||
relation = item['relationship'].lower().replace(" ", "_")
|
||||
destination = item['destination'].lower().replace(" ", "_")
|
||||
destination_type = item['destination_type'].lower().replace(" ", "_")
|
||||
source = item["source"].lower().replace(" ", "_")
|
||||
source_type = item["source_type"].lower().replace(" ", "_")
|
||||
relation = item["relationship"].lower().replace(" ", "_")
|
||||
destination = item["destination"].lower().replace(" ", "_")
|
||||
destination_type = item["destination_type"].lower().replace(" ", "_")
|
||||
|
||||
returned_entities.append({
|
||||
"source" : source,
|
||||
"relationship" : relation,
|
||||
"target" : destination
|
||||
})
|
||||
returned_entities.append({"source": source, "relationship": relation, "target": destination})
|
||||
|
||||
# Create embeddings
|
||||
source_embedding = self.embedding_model.embed(source)
|
||||
@@ -135,7 +146,7 @@ class MemoryGraph:
|
||||
"dest_name": destination,
|
||||
"source_embedding": source_embedding,
|
||||
"dest_embedding": dest_embedding,
|
||||
"user_id": filters["user_id"]
|
||||
"user_id": filters["user_id"],
|
||||
}
|
||||
|
||||
_ = self.graph.query(cypher, params=params)
|
||||
@@ -150,19 +161,22 @@ class MemoryGraph:
|
||||
_tools = [SEARCH_STRUCT_TOOL]
|
||||
search_results = self.llm.generate_response(
|
||||
messages=[
|
||||
{"role": "system", "content": f"You are a smart assistant who understands the entities, their types, and relations in a given text. If user message contains self reference such as 'I', 'me', 'my' etc. then use {filters['user_id']} as the source node. Extract the entities."},
|
||||
{
|
||||
"role": "system",
|
||||
"content": f"You are a smart assistant who understands the entities, their types, and relations in a given text. If user message contains self reference such as 'I', 'me', 'my' etc. then use {filters['user_id']} as the source node. Extract the entities.",
|
||||
},
|
||||
{"role": "user", "content": query},
|
||||
],
|
||||
tools = _tools
|
||||
tools=_tools,
|
||||
)
|
||||
|
||||
node_list = []
|
||||
relation_list = []
|
||||
|
||||
for item in search_results['tool_calls']:
|
||||
if item['name'] == "search":
|
||||
for item in search_results["tool_calls"]:
|
||||
if item["name"] == "search":
|
||||
try:
|
||||
node_list.extend(item['arguments']['nodes'])
|
||||
node_list.extend(item["arguments"]["nodes"])
|
||||
except Exception as e:
|
||||
logger.error(f"Error in search tool: {e}")
|
||||
|
||||
@@ -201,13 +215,16 @@ class MemoryGraph:
|
||||
RETURN m.name AS source, elementId(m) AS source_id, type(r) AS relation, elementId(r) AS relation_id, n.name AS destination, elementId(n) AS destination_id, similarity
|
||||
ORDER BY similarity DESC
|
||||
"""
|
||||
params = {"n_embedding": n_embedding, "threshold": self.threshold, "user_id": filters["user_id"]}
|
||||
params = {
|
||||
"n_embedding": n_embedding,
|
||||
"threshold": self.threshold,
|
||||
"user_id": filters["user_id"],
|
||||
}
|
||||
ans = self.graph.query(cypher_query, params=params)
|
||||
result_relations.extend(ans)
|
||||
|
||||
return result_relations
|
||||
|
||||
|
||||
def search(self, query, filters):
|
||||
"""
|
||||
Search for memories and related graph data.
|
||||
@@ -235,17 +252,12 @@ class MemoryGraph:
|
||||
|
||||
search_results = []
|
||||
for item in reranked_results:
|
||||
search_results.append({
|
||||
"source": item[0],
|
||||
"relationship": item[1],
|
||||
"target": item[2]
|
||||
})
|
||||
search_results.append({"source": item[0], "relationship": item[1], "target": item[2]})
|
||||
|
||||
logger.info(f"Returned {len(search_results)} search results")
|
||||
|
||||
return search_results
|
||||
|
||||
|
||||
def delete_all(self, filters):
|
||||
cypher = """
|
||||
MATCH (n {user_id: $user_id})
|
||||
@@ -254,7 +266,6 @@ class MemoryGraph:
|
||||
params = {"user_id": filters["user_id"]}
|
||||
self.graph.query(cypher, params=params)
|
||||
|
||||
|
||||
def get_all(self, filters):
|
||||
"""
|
||||
Retrieves all nodes and relationships from the graph database based on optional filtering criteria.
|
||||
@@ -276,17 +287,18 @@ class MemoryGraph:
|
||||
|
||||
final_results = []
|
||||
for result in results:
|
||||
final_results.append({
|
||||
"source": result['source'],
|
||||
"relationship": result['relationship'],
|
||||
"target": result['target']
|
||||
})
|
||||
final_results.append(
|
||||
{
|
||||
"source": result["source"],
|
||||
"relationship": result["relationship"],
|
||||
"target": result["target"],
|
||||
}
|
||||
)
|
||||
|
||||
logger.info(f"Retrieved {len(final_results)} relationships")
|
||||
|
||||
return final_results
|
||||
|
||||
|
||||
|
||||
def _update_relationship(self, source, target, relationship, filters):
|
||||
"""
|
||||
Update or create a relationship between two nodes in the graph.
|
||||
@@ -309,14 +321,20 @@ class MemoryGraph:
|
||||
MERGE (n1 {name: $source, user_id: $user_id})
|
||||
MERGE (n2 {name: $target, user_id: $user_id})
|
||||
"""
|
||||
self.graph.query(check_and_create_query, params={"source": source, "target": target, "user_id": filters["user_id"]})
|
||||
self.graph.query(
|
||||
check_and_create_query,
|
||||
params={"source": source, "target": target, "user_id": filters["user_id"]},
|
||||
)
|
||||
|
||||
# Delete any existing relationship between the nodes
|
||||
delete_query = """
|
||||
MATCH (n1 {name: $source, user_id: $user_id})-[r]->(n2 {name: $target, user_id: $user_id})
|
||||
DELETE r
|
||||
"""
|
||||
self.graph.query(delete_query, params={"source": source, "target": target, "user_id": filters["user_id"]})
|
||||
self.graph.query(
|
||||
delete_query,
|
||||
params={"source": source, "target": target, "user_id": filters["user_id"]},
|
||||
)
|
||||
|
||||
# Create the new relationship
|
||||
create_query = f"""
|
||||
@@ -324,7 +342,10 @@ class MemoryGraph:
|
||||
CREATE (n1)-[r:{relationship}]->(n2)
|
||||
RETURN n1, r, n2
|
||||
"""
|
||||
result = self.graph.query(create_query, params={"source": source, "target": target, "user_id": filters["user_id"]})
|
||||
result = self.graph.query(
|
||||
create_query,
|
||||
params={"source": source, "target": target, "user_id": filters["user_id"]},
|
||||
)
|
||||
|
||||
if not result:
|
||||
raise Exception(f"Failed to update or create relationship between {source} and {target}")
|
||||
|
||||
Reference in New Issue
Block a user