From 0b1ca090f5a88532f01c67112ef885bdc73a8762 Mon Sep 17 00:00:00 2001 From: Dev Khant Date: Wed, 4 Sep 2024 11:16:18 +0530 Subject: [PATCH] Add loggers for debugging (#1796) --- mem0/memory/graph_memory.py | 28 +++++++++++++++++++++------- mem0/memory/main.py | 10 ++++++++-- mem0/proxy/main.py | 5 +++++ mem0/vector_stores/chroma.py | 3 +++ mem0/vector_stores/pgvector.py | 3 +++ mem0/vector_stores/qdrant.py | 3 +++ 6 files changed, 43 insertions(+), 9 deletions(-) diff --git a/mem0/memory/graph_memory.py b/mem0/memory/graph_memory.py index a4ecfc81..71352c94 100644 --- a/mem0/memory/graph_memory.py +++ b/mem0/memory/graph_memory.py @@ -1,10 +1,12 @@ -from langchain_community.graphs import Neo4jGraph import json +import logging +from langchain_community.graphs import Neo4jGraph from rank_bm25 import BM25Okapi from mem0.utils.factory import LlmFactory, EmbedderFactory from mem0.graphs.utils import get_update_memory_messages, EXTRACT_ENTITIES_PROMPT from mem0.graphs.tools import UPDATE_MEMORY_TOOL_GRAPH, ADD_MEMORY_TOOL_GRAPH, NOOP_TOOL, ADD_MESSAGE_TOOL, SEARCH_TOOL +logger = logging.getLogger(__name__) class MemoryGraph: def __init__(self, config): @@ -36,7 +38,7 @@ class MemoryGraph: Returns: dict: A dictionary containing the entities added to the graph. """ - + # retrieve the search results search_output = self._search(data, filters) @@ -50,17 +52,19 @@ class MemoryGraph: {"role": "system", "content": EXTRACT_ENTITIES_PROMPT.replace("USER_ID", self.user_id)}, {"role": "user", "content": data}, ] - + extracted_entities = self.llm.generate_response( messages=messages, tools = [ADD_MESSAGE_TOOL], ) - if extracted_entities['tool_calls']: + if extracted_entities['tool_calls']: extracted_entities = extracted_entities['tool_calls'][0]['arguments']['entities'] else: extracted_entities = [] - + + logger.debug(f"Extracted entities: {extracted_entities}") + update_memory_prompt = get_update_memory_messages(search_output, extracted_entities) memory_updates = self.llm.generate_response( @@ -112,6 +116,8 @@ class MemoryGraph: _ = self.graph.query(cypher, params=params) + logger.info(f"Added {len(to_be_added)} new memories to the graph") + def _search(self, query, filters): search_results = self.llm.generate_response( @@ -136,6 +142,8 @@ class MemoryGraph: node_list = [node.lower().replace(" ", "_") for node in node_list] relation_list = [relation.lower().replace(" ", "_") for relation in relation_list] + logger.debug(f"Node list for search query : {node_list}") + result_relations = [] for node in node_list: @@ -168,7 +176,7 @@ class MemoryGraph: result_relations.extend(ans) return result_relations - + def search(self, query, filters): """ @@ -202,6 +210,8 @@ class MemoryGraph: "destination": item[2] }) + logger.info(f"Returned {len(search_results)} search results") + return search_results @@ -212,8 +222,8 @@ class MemoryGraph: """ params = {"user_id": filters["user_id"]} self.graph.query(cypher, params=params) - + def get_all(self, filters): """ Retrieves all nodes and relationships from the graph database based on optional filtering criteria. @@ -241,6 +251,8 @@ class MemoryGraph: "target": result['target'] }) + logger.info(f"Retrieved {len(final_results)} relationships") + return final_results @@ -256,6 +268,8 @@ class MemoryGraph: Raises: Exception: If the operation fails. """ + logger.info(f"Updating relationship: {source} -{relationship}-> {target}") + relationship = relationship.lower().replace(" ", "_") # Check if nodes exist and create them if they don't diff --git a/mem0/memory/main.py b/mem0/memory/main.py index a0fe56c7..2e073cef 100644 --- a/mem0/memory/main.py +++ b/mem0/memory/main.py @@ -21,6 +21,8 @@ import concurrent # Setup user config setup_config() +logger = logging.getLogger(__name__) + class Memory(MemoryBase): def __init__(self, config: MemoryConfig = MemoryConfig()): @@ -50,7 +52,7 @@ class Memory(MemoryBase): try: config = MemoryConfig(**config_dict) except ValidationError as e: - logging.error(f"Configuration validation error: {e}") + logger.error(f"Configuration validation error: {e}") raise return cls(config) @@ -453,6 +455,8 @@ class Memory(MemoryBase): for memory in memories: self._delete_memory(memory.id) + logger.info(f"Deleted {len(memories)} memories") + if self.version == "v1.1" and self.enable_graph: self.graph.delete_all(filters) @@ -491,6 +495,7 @@ class Memory(MemoryBase): return memory_id def _update_memory(self, memory_id, data, metadata=None): + logger.info(f"Updating memory with {data=}") existing_memory = self.vector_store.get(vector_id=memory_id) prev_value = existing_memory.payload.get("data") @@ -515,7 +520,7 @@ class Memory(MemoryBase): vector=embeddings, payload=new_metadata, ) - logging.info(f"Updating memory with ID {memory_id=} with {data=}") + logger.info(f"Updating memory with ID {memory_id=} with {data=}") self.db.add_history( memory_id, prev_value, @@ -536,6 +541,7 @@ class Memory(MemoryBase): """ Reset the memory store. """ + logger.warning("Resetting all memories") self.vector_store.delete_col() self.db.reset() capture_event("mem0.reset", self) diff --git a/mem0/proxy/main.py b/mem0/proxy/main.py index 58a6a730..daf24ae1 100644 --- a/mem0/proxy/main.py +++ b/mem0/proxy/main.py @@ -1,3 +1,4 @@ +import logging import subprocess import sys import httpx @@ -23,6 +24,8 @@ from mem0.memory.telemetry import capture_client_event from mem0 import Memory, MemoryClient from mem0.configs.prompts import MEMORY_ANSWER_PROMPT +logger = logging.getLogger(__name__) + class Mem0: def __init__( @@ -107,6 +110,7 @@ class Completions: relevant_memories = self._fetch_relevant_memories( messages, user_id, agent_id, run_id, filters, limit ) + logger.debug(f"Retrieved {len(relevant_memories)} relevant memories") prepared_messages[-1]["content"] = self._format_query_with_memories( messages, relevant_memories ) @@ -155,6 +159,7 @@ class Completions: self, messages, user_id, agent_id, run_id, metadata, filters ): def add_task(): + logger.debug("Adding to memory asynchronously") self.mem0_client.add( messages=messages, user_id=user_id, diff --git a/mem0/vector_stores/chroma.py b/mem0/vector_stores/chroma.py index 791ff8c9..4904aacf 100644 --- a/mem0/vector_stores/chroma.py +++ b/mem0/vector_stores/chroma.py @@ -24,6 +24,8 @@ except ImportError: from mem0.vector_stores.base import VectorStoreBase +logger = logging.getLogger(__name__) + class OutputData(BaseModel): id: Optional[str] # memory id @@ -151,6 +153,7 @@ class ChromaDB(VectorStoreBase): payloads (Optional[List[Dict]], optional): List of payloads corresponding to vectors. Defaults to None. ids (Optional[List[str]], optional): List of IDs corresponding to vectors. Defaults to None. """ + logger.info(f"Inserting {len(vectors)} vectors into collection {self.collection_name}") self.collection.add(ids=ids, embeddings=vectors, metadatas=payloads) def search( diff --git a/mem0/vector_stores/pgvector.py b/mem0/vector_stores/pgvector.py index 9da5f4a1..7f8c2159 100644 --- a/mem0/vector_stores/pgvector.py +++ b/mem0/vector_stores/pgvector.py @@ -1,6 +1,7 @@ import subprocess import sys import json +import logging from typing import Optional, List from pydantic import BaseModel @@ -24,6 +25,7 @@ except ImportError: from mem0.vector_stores.base import VectorStoreBase +logger = logging.getLogger(__name__) class OutputData(BaseModel): id: Optional[str] @@ -102,6 +104,7 @@ class PGVector(VectorStoreBase): payloads (List[Dict], optional): List of payloads corresponding to vectors. ids (List[str], optional): List of IDs corresponding to vectors. """ + logger.info(f"Inserting {len(vectors)} vectors into collection {self.collection_name}") json_payloads = [json.dumps(payload) for payload in payloads] data = [ diff --git a/mem0/vector_stores/qdrant.py b/mem0/vector_stores/qdrant.py index 25064336..0c7000b8 100644 --- a/mem0/vector_stores/qdrant.py +++ b/mem0/vector_stores/qdrant.py @@ -16,6 +16,8 @@ from qdrant_client.models import ( from mem0.vector_stores.base import VectorStoreBase +logger = logging.getLogger(__name__) + class Qdrant(VectorStoreBase): def __init__( @@ -102,6 +104,7 @@ class Qdrant(VectorStoreBase): payloads (list, optional): List of payloads corresponding to vectors. Defaults to None. ids (list, optional): List of IDs corresponding to vectors. Defaults to None. """ + logger.info(f"Inserting {len(vectors)} vectors into collection {self.collection_name}") points = [ PointStruct( id=idx if ids is None else ids[idx],