Add loggers for debugging (#1796)

This commit is contained in:
Dev Khant
2024-09-04 11:16:18 +05:30
committed by GitHub
parent bf3ad37369
commit 0b1ca090f5
6 changed files with 43 additions and 9 deletions

View File

@@ -1,10 +1,12 @@
from langchain_community.graphs import Neo4jGraph
import json
import logging
from langchain_community.graphs import Neo4jGraph
from rank_bm25 import BM25Okapi
from mem0.utils.factory import LlmFactory, EmbedderFactory
from mem0.graphs.utils import get_update_memory_messages, EXTRACT_ENTITIES_PROMPT
from mem0.graphs.tools import UPDATE_MEMORY_TOOL_GRAPH, ADD_MEMORY_TOOL_GRAPH, NOOP_TOOL, ADD_MESSAGE_TOOL, SEARCH_TOOL
logger = logging.getLogger(__name__)
class MemoryGraph:
def __init__(self, config):
@@ -36,7 +38,7 @@ class MemoryGraph:
Returns:
dict: A dictionary containing the entities added to the graph.
"""
# retrieve the search results
search_output = self._search(data, filters)
@@ -50,17 +52,19 @@ class MemoryGraph:
{"role": "system", "content": EXTRACT_ENTITIES_PROMPT.replace("USER_ID", self.user_id)},
{"role": "user", "content": data},
]
extracted_entities = self.llm.generate_response(
messages=messages,
tools = [ADD_MESSAGE_TOOL],
)
if extracted_entities['tool_calls']:
if extracted_entities['tool_calls']:
extracted_entities = extracted_entities['tool_calls'][0]['arguments']['entities']
else:
extracted_entities = []
logger.debug(f"Extracted entities: {extracted_entities}")
update_memory_prompt = get_update_memory_messages(search_output, extracted_entities)
memory_updates = self.llm.generate_response(
@@ -112,6 +116,8 @@ class MemoryGraph:
_ = self.graph.query(cypher, params=params)
logger.info(f"Added {len(to_be_added)} new memories to the graph")
def _search(self, query, filters):
search_results = self.llm.generate_response(
@@ -136,6 +142,8 @@ class MemoryGraph:
node_list = [node.lower().replace(" ", "_") for node in node_list]
relation_list = [relation.lower().replace(" ", "_") for relation in relation_list]
logger.debug(f"Node list for search query : {node_list}")
result_relations = []
for node in node_list:
@@ -168,7 +176,7 @@ class MemoryGraph:
result_relations.extend(ans)
return result_relations
def search(self, query, filters):
"""
@@ -202,6 +210,8 @@ class MemoryGraph:
"destination": item[2]
})
logger.info(f"Returned {len(search_results)} search results")
return search_results
@@ -212,8 +222,8 @@ class MemoryGraph:
"""
params = {"user_id": filters["user_id"]}
self.graph.query(cypher, params=params)
def get_all(self, filters):
"""
Retrieves all nodes and relationships from the graph database based on optional filtering criteria.
@@ -241,6 +251,8 @@ class MemoryGraph:
"target": result['target']
})
logger.info(f"Retrieved {len(final_results)} relationships")
return final_results
@@ -256,6 +268,8 @@ class MemoryGraph:
Raises:
Exception: If the operation fails.
"""
logger.info(f"Updating relationship: {source} -{relationship}-> {target}")
relationship = relationship.lower().replace(" ", "_")
# Check if nodes exist and create them if they don't