Add loggers for debugging (#1796)

This commit is contained in:
Dev Khant
2024-09-04 11:16:18 +05:30
committed by GitHub
parent bf3ad37369
commit 0b1ca090f5
6 changed files with 43 additions and 9 deletions

View File

@@ -1,10 +1,12 @@
from langchain_community.graphs import Neo4jGraph
import json import json
import logging
from langchain_community.graphs import Neo4jGraph
from rank_bm25 import BM25Okapi from rank_bm25 import BM25Okapi
from mem0.utils.factory import LlmFactory, EmbedderFactory from mem0.utils.factory import LlmFactory, EmbedderFactory
from mem0.graphs.utils import get_update_memory_messages, EXTRACT_ENTITIES_PROMPT from mem0.graphs.utils import get_update_memory_messages, EXTRACT_ENTITIES_PROMPT
from mem0.graphs.tools import UPDATE_MEMORY_TOOL_GRAPH, ADD_MEMORY_TOOL_GRAPH, NOOP_TOOL, ADD_MESSAGE_TOOL, SEARCH_TOOL from mem0.graphs.tools import UPDATE_MEMORY_TOOL_GRAPH, ADD_MEMORY_TOOL_GRAPH, NOOP_TOOL, ADD_MESSAGE_TOOL, SEARCH_TOOL
logger = logging.getLogger(__name__)
class MemoryGraph: class MemoryGraph:
def __init__(self, config): def __init__(self, config):
@@ -36,7 +38,7 @@ class MemoryGraph:
Returns: Returns:
dict: A dictionary containing the entities added to the graph. dict: A dictionary containing the entities added to the graph.
""" """
# retrieve the search results # retrieve the search results
search_output = self._search(data, filters) search_output = self._search(data, filters)
@@ -50,17 +52,19 @@ class MemoryGraph:
{"role": "system", "content": EXTRACT_ENTITIES_PROMPT.replace("USER_ID", self.user_id)}, {"role": "system", "content": EXTRACT_ENTITIES_PROMPT.replace("USER_ID", self.user_id)},
{"role": "user", "content": data}, {"role": "user", "content": data},
] ]
extracted_entities = self.llm.generate_response( extracted_entities = self.llm.generate_response(
messages=messages, messages=messages,
tools = [ADD_MESSAGE_TOOL], tools = [ADD_MESSAGE_TOOL],
) )
if extracted_entities['tool_calls']: if extracted_entities['tool_calls']:
extracted_entities = extracted_entities['tool_calls'][0]['arguments']['entities'] extracted_entities = extracted_entities['tool_calls'][0]['arguments']['entities']
else: else:
extracted_entities = [] extracted_entities = []
logger.debug(f"Extracted entities: {extracted_entities}")
update_memory_prompt = get_update_memory_messages(search_output, extracted_entities) update_memory_prompt = get_update_memory_messages(search_output, extracted_entities)
memory_updates = self.llm.generate_response( memory_updates = self.llm.generate_response(
@@ -112,6 +116,8 @@ class MemoryGraph:
_ = self.graph.query(cypher, params=params) _ = self.graph.query(cypher, params=params)
logger.info(f"Added {len(to_be_added)} new memories to the graph")
def _search(self, query, filters): def _search(self, query, filters):
search_results = self.llm.generate_response( search_results = self.llm.generate_response(
@@ -136,6 +142,8 @@ class MemoryGraph:
node_list = [node.lower().replace(" ", "_") for node in node_list] node_list = [node.lower().replace(" ", "_") for node in node_list]
relation_list = [relation.lower().replace(" ", "_") for relation in relation_list] relation_list = [relation.lower().replace(" ", "_") for relation in relation_list]
logger.debug(f"Node list for search query : {node_list}")
result_relations = [] result_relations = []
for node in node_list: for node in node_list:
@@ -168,7 +176,7 @@ class MemoryGraph:
result_relations.extend(ans) result_relations.extend(ans)
return result_relations return result_relations
def search(self, query, filters): def search(self, query, filters):
""" """
@@ -202,6 +210,8 @@ class MemoryGraph:
"destination": item[2] "destination": item[2]
}) })
logger.info(f"Returned {len(search_results)} search results")
return search_results return search_results
@@ -212,8 +222,8 @@ class MemoryGraph:
""" """
params = {"user_id": filters["user_id"]} params = {"user_id": filters["user_id"]}
self.graph.query(cypher, params=params) self.graph.query(cypher, params=params)
def get_all(self, filters): def get_all(self, filters):
""" """
Retrieves all nodes and relationships from the graph database based on optional filtering criteria. Retrieves all nodes and relationships from the graph database based on optional filtering criteria.
@@ -241,6 +251,8 @@ class MemoryGraph:
"target": result['target'] "target": result['target']
}) })
logger.info(f"Retrieved {len(final_results)} relationships")
return final_results return final_results
@@ -256,6 +268,8 @@ class MemoryGraph:
Raises: Raises:
Exception: If the operation fails. Exception: If the operation fails.
""" """
logger.info(f"Updating relationship: {source} -{relationship}-> {target}")
relationship = relationship.lower().replace(" ", "_") relationship = relationship.lower().replace(" ", "_")
# Check if nodes exist and create them if they don't # Check if nodes exist and create them if they don't

View File

@@ -21,6 +21,8 @@ import concurrent
# Setup user config # Setup user config
setup_config() setup_config()
logger = logging.getLogger(__name__)
class Memory(MemoryBase): class Memory(MemoryBase):
def __init__(self, config: MemoryConfig = MemoryConfig()): def __init__(self, config: MemoryConfig = MemoryConfig()):
@@ -50,7 +52,7 @@ class Memory(MemoryBase):
try: try:
config = MemoryConfig(**config_dict) config = MemoryConfig(**config_dict)
except ValidationError as e: except ValidationError as e:
logging.error(f"Configuration validation error: {e}") logger.error(f"Configuration validation error: {e}")
raise raise
return cls(config) return cls(config)
@@ -453,6 +455,8 @@ class Memory(MemoryBase):
for memory in memories: for memory in memories:
self._delete_memory(memory.id) self._delete_memory(memory.id)
logger.info(f"Deleted {len(memories)} memories")
if self.version == "v1.1" and self.enable_graph: if self.version == "v1.1" and self.enable_graph:
self.graph.delete_all(filters) self.graph.delete_all(filters)
@@ -491,6 +495,7 @@ class Memory(MemoryBase):
return memory_id return memory_id
def _update_memory(self, memory_id, data, metadata=None): def _update_memory(self, memory_id, data, metadata=None):
logger.info(f"Updating memory with {data=}")
existing_memory = self.vector_store.get(vector_id=memory_id) existing_memory = self.vector_store.get(vector_id=memory_id)
prev_value = existing_memory.payload.get("data") prev_value = existing_memory.payload.get("data")
@@ -515,7 +520,7 @@ class Memory(MemoryBase):
vector=embeddings, vector=embeddings,
payload=new_metadata, payload=new_metadata,
) )
logging.info(f"Updating memory with ID {memory_id=} with {data=}") logger.info(f"Updating memory with ID {memory_id=} with {data=}")
self.db.add_history( self.db.add_history(
memory_id, memory_id,
prev_value, prev_value,
@@ -536,6 +541,7 @@ class Memory(MemoryBase):
""" """
Reset the memory store. Reset the memory store.
""" """
logger.warning("Resetting all memories")
self.vector_store.delete_col() self.vector_store.delete_col()
self.db.reset() self.db.reset()
capture_event("mem0.reset", self) capture_event("mem0.reset", self)

View File

@@ -1,3 +1,4 @@
import logging
import subprocess import subprocess
import sys import sys
import httpx import httpx
@@ -23,6 +24,8 @@ from mem0.memory.telemetry import capture_client_event
from mem0 import Memory, MemoryClient from mem0 import Memory, MemoryClient
from mem0.configs.prompts import MEMORY_ANSWER_PROMPT from mem0.configs.prompts import MEMORY_ANSWER_PROMPT
logger = logging.getLogger(__name__)
class Mem0: class Mem0:
def __init__( def __init__(
@@ -107,6 +110,7 @@ class Completions:
relevant_memories = self._fetch_relevant_memories( relevant_memories = self._fetch_relevant_memories(
messages, user_id, agent_id, run_id, filters, limit messages, user_id, agent_id, run_id, filters, limit
) )
logger.debug(f"Retrieved {len(relevant_memories)} relevant memories")
prepared_messages[-1]["content"] = self._format_query_with_memories( prepared_messages[-1]["content"] = self._format_query_with_memories(
messages, relevant_memories messages, relevant_memories
) )
@@ -155,6 +159,7 @@ class Completions:
self, messages, user_id, agent_id, run_id, metadata, filters self, messages, user_id, agent_id, run_id, metadata, filters
): ):
def add_task(): def add_task():
logger.debug("Adding to memory asynchronously")
self.mem0_client.add( self.mem0_client.add(
messages=messages, messages=messages,
user_id=user_id, user_id=user_id,

View File

@@ -24,6 +24,8 @@ except ImportError:
from mem0.vector_stores.base import VectorStoreBase from mem0.vector_stores.base import VectorStoreBase
logger = logging.getLogger(__name__)
class OutputData(BaseModel): class OutputData(BaseModel):
id: Optional[str] # memory id id: Optional[str] # memory id
@@ -151,6 +153,7 @@ class ChromaDB(VectorStoreBase):
payloads (Optional[List[Dict]], optional): List of payloads corresponding to vectors. Defaults to None. payloads (Optional[List[Dict]], optional): List of payloads corresponding to vectors. Defaults to None.
ids (Optional[List[str]], optional): List of IDs corresponding to vectors. Defaults to None. ids (Optional[List[str]], optional): List of IDs corresponding to vectors. Defaults to None.
""" """
logger.info(f"Inserting {len(vectors)} vectors into collection {self.collection_name}")
self.collection.add(ids=ids, embeddings=vectors, metadatas=payloads) self.collection.add(ids=ids, embeddings=vectors, metadatas=payloads)
def search( def search(

View File

@@ -1,6 +1,7 @@
import subprocess import subprocess
import sys import sys
import json import json
import logging
from typing import Optional, List from typing import Optional, List
from pydantic import BaseModel from pydantic import BaseModel
@@ -24,6 +25,7 @@ except ImportError:
from mem0.vector_stores.base import VectorStoreBase from mem0.vector_stores.base import VectorStoreBase
logger = logging.getLogger(__name__)
class OutputData(BaseModel): class OutputData(BaseModel):
id: Optional[str] id: Optional[str]
@@ -102,6 +104,7 @@ class PGVector(VectorStoreBase):
payloads (List[Dict], optional): List of payloads corresponding to vectors. payloads (List[Dict], optional): List of payloads corresponding to vectors.
ids (List[str], optional): List of IDs corresponding to vectors. ids (List[str], optional): List of IDs corresponding to vectors.
""" """
logger.info(f"Inserting {len(vectors)} vectors into collection {self.collection_name}")
json_payloads = [json.dumps(payload) for payload in payloads] json_payloads = [json.dumps(payload) for payload in payloads]
data = [ data = [

View File

@@ -16,6 +16,8 @@ from qdrant_client.models import (
from mem0.vector_stores.base import VectorStoreBase from mem0.vector_stores.base import VectorStoreBase
logger = logging.getLogger(__name__)
class Qdrant(VectorStoreBase): class Qdrant(VectorStoreBase):
def __init__( def __init__(
@@ -102,6 +104,7 @@ class Qdrant(VectorStoreBase):
payloads (list, optional): List of payloads corresponding to vectors. Defaults to None. payloads (list, optional): List of payloads corresponding to vectors. Defaults to None.
ids (list, optional): List of IDs corresponding to vectors. Defaults to None. ids (list, optional): List of IDs corresponding to vectors. Defaults to None.
""" """
logger.info(f"Inserting {len(vectors)} vectors into collection {self.collection_name}")
points = [ points = [
PointStruct( PointStruct(
id=idx if ids is None else ids[idx], id=idx if ids is None else ids[idx],