Add loggers for debugging (#1796)
This commit is contained in:
@@ -1,10 +1,12 @@
|
||||
from langchain_community.graphs import Neo4jGraph
|
||||
import json
|
||||
import logging
|
||||
from langchain_community.graphs import Neo4jGraph
|
||||
from rank_bm25 import BM25Okapi
|
||||
from mem0.utils.factory import LlmFactory, EmbedderFactory
|
||||
from mem0.graphs.utils import get_update_memory_messages, EXTRACT_ENTITIES_PROMPT
|
||||
from mem0.graphs.tools import UPDATE_MEMORY_TOOL_GRAPH, ADD_MEMORY_TOOL_GRAPH, NOOP_TOOL, ADD_MESSAGE_TOOL, SEARCH_TOOL
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class MemoryGraph:
|
||||
def __init__(self, config):
|
||||
@@ -36,7 +38,7 @@ class MemoryGraph:
|
||||
Returns:
|
||||
dict: A dictionary containing the entities added to the graph.
|
||||
"""
|
||||
|
||||
|
||||
# retrieve the search results
|
||||
search_output = self._search(data, filters)
|
||||
|
||||
@@ -50,17 +52,19 @@ class MemoryGraph:
|
||||
{"role": "system", "content": EXTRACT_ENTITIES_PROMPT.replace("USER_ID", self.user_id)},
|
||||
{"role": "user", "content": data},
|
||||
]
|
||||
|
||||
|
||||
extracted_entities = self.llm.generate_response(
|
||||
messages=messages,
|
||||
tools = [ADD_MESSAGE_TOOL],
|
||||
)
|
||||
|
||||
if extracted_entities['tool_calls']:
|
||||
if extracted_entities['tool_calls']:
|
||||
extracted_entities = extracted_entities['tool_calls'][0]['arguments']['entities']
|
||||
else:
|
||||
extracted_entities = []
|
||||
|
||||
|
||||
logger.debug(f"Extracted entities: {extracted_entities}")
|
||||
|
||||
update_memory_prompt = get_update_memory_messages(search_output, extracted_entities)
|
||||
|
||||
memory_updates = self.llm.generate_response(
|
||||
@@ -112,6 +116,8 @@ class MemoryGraph:
|
||||
|
||||
_ = self.graph.query(cypher, params=params)
|
||||
|
||||
logger.info(f"Added {len(to_be_added)} new memories to the graph")
|
||||
|
||||
|
||||
def _search(self, query, filters):
|
||||
search_results = self.llm.generate_response(
|
||||
@@ -136,6 +142,8 @@ class MemoryGraph:
|
||||
node_list = [node.lower().replace(" ", "_") for node in node_list]
|
||||
relation_list = [relation.lower().replace(" ", "_") for relation in relation_list]
|
||||
|
||||
logger.debug(f"Node list for search query : {node_list}")
|
||||
|
||||
result_relations = []
|
||||
|
||||
for node in node_list:
|
||||
@@ -168,7 +176,7 @@ class MemoryGraph:
|
||||
result_relations.extend(ans)
|
||||
|
||||
return result_relations
|
||||
|
||||
|
||||
|
||||
def search(self, query, filters):
|
||||
"""
|
||||
@@ -202,6 +210,8 @@ class MemoryGraph:
|
||||
"destination": item[2]
|
||||
})
|
||||
|
||||
logger.info(f"Returned {len(search_results)} search results")
|
||||
|
||||
return search_results
|
||||
|
||||
|
||||
@@ -212,8 +222,8 @@ class MemoryGraph:
|
||||
"""
|
||||
params = {"user_id": filters["user_id"]}
|
||||
self.graph.query(cypher, params=params)
|
||||
|
||||
|
||||
|
||||
def get_all(self, filters):
|
||||
"""
|
||||
Retrieves all nodes and relationships from the graph database based on optional filtering criteria.
|
||||
@@ -241,6 +251,8 @@ class MemoryGraph:
|
||||
"target": result['target']
|
||||
})
|
||||
|
||||
logger.info(f"Retrieved {len(final_results)} relationships")
|
||||
|
||||
return final_results
|
||||
|
||||
|
||||
@@ -256,6 +268,8 @@ class MemoryGraph:
|
||||
Raises:
|
||||
Exception: If the operation fails.
|
||||
"""
|
||||
logger.info(f"Updating relationship: {source} -{relationship}-> {target}")
|
||||
|
||||
relationship = relationship.lower().replace(" ", "_")
|
||||
|
||||
# Check if nodes exist and create them if they don't
|
||||
|
||||
Reference in New Issue
Block a user