Add loggers for debugging (#1796)

Dev Khant
2024-09-04 11:16:18 +05:30
committed by GitHub
parent bf3ad37369
commit 0b1ca090f5
6 changed files with 43 additions and 9 deletions
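
The pattern throughout is the stdlib idiom `logger = logging.getLogger(__name__)`: one logger per module, named after the module, with output control left entirely to the embedding application. As a minimal sketch (not part of the commit), an application that wants to see the new debug messages could configure logging once at startup:

    import logging

    # One-time application-side setup: send all records, including the
    # new logger.debug(...) calls added by this commit, to stderr with
    # timestamps and the emitting module's name.
    logging.basicConfig(
        level=logging.DEBUG,
        format="%(asctime)s %(name)s %(levelname)s %(message)s",
    )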

View File

@@ -1,10 +1,12 @@
import json
import logging
from langchain_community.graphs import Neo4jGraph
from rank_bm25 import BM25Okapi
from mem0.utils.factory import LlmFactory, EmbedderFactory
from mem0.graphs.utils import get_update_memory_messages, EXTRACT_ENTITIES_PROMPT
from mem0.graphs.tools import UPDATE_MEMORY_TOOL_GRAPH, ADD_MEMORY_TOOL_GRAPH, NOOP_TOOL, ADD_MESSAGE_TOOL, SEARCH_TOOL
logger = logging.getLogger(__name__)
class MemoryGraph:
def __init__(self, config):
@@ -36,7 +38,7 @@ class MemoryGraph:
Returns:
dict: A dictionary containing the entities added to the graph.
"""
# retrieve the search results
search_output = self._search(data, filters)
@@ -50,17 +52,19 @@ class MemoryGraph:
{"role": "system", "content": EXTRACT_ENTITIES_PROMPT.replace("USER_ID", self.user_id)},
{"role": "user", "content": data},
]
extracted_entities = self.llm.generate_response(
messages=messages,
tools=[ADD_MESSAGE_TOOL],
)
if extracted_entities['tool_calls']:
extracted_entities = extracted_entities['tool_calls'][0]['arguments']['entities']
else:
extracted_entities = []
logger.debug(f"Extracted entities: {extracted_entities}")
update_memory_prompt = get_update_memory_messages(search_output, extracted_entities)
memory_updates = self.llm.generate_response(
@@ -112,6 +116,8 @@ class MemoryGraph:
_ = self.graph.query(cypher, params=params)
logger.info(f"Added {len(to_be_added)} new memories to the graph")
def _search(self, query, filters):
search_results = self.llm.generate_response(
@@ -136,6 +142,8 @@ class MemoryGraph:
node_list = [node.lower().replace(" ", "_") for node in node_list]
relation_list = [relation.lower().replace(" ", "_") for relation in relation_list]
logger.debug(f"Node list for search query : {node_list}")
result_relations = []
for node in node_list:
@@ -168,7 +176,7 @@ class MemoryGraph:
result_relations.extend(ans)
return result_relations
def search(self, query, filters):
"""
@@ -202,6 +210,8 @@ class MemoryGraph:
"destination": item[2]
})
logger.info(f"Returned {len(search_results)} search results")
return search_results
@@ -212,8 +222,8 @@ class MemoryGraph:
"""
params = {"user_id": filters["user_id"]}
self.graph.query(cypher, params=params)
def get_all(self, filters):
"""
Retrieves all nodes and relationships from the graph database based on optional filtering criteria.
@@ -241,6 +251,8 @@ class MemoryGraph:
"target": result['target']
})
logger.info(f"Retrieved {len(final_results)} relationships")
return final_results
@@ -256,6 +268,8 @@ class MemoryGraph:
Raises:
Exception: If the operation fails.
"""
logger.info(f"Updating relationship: {source} -{relationship}-> {target}")
relationship = relationship.lower().replace(" ", "_")
# Check if nodes exist and create them if they don't
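
Because every logger added above is named after its module, verbosity can be scoped per file rather than globally. A hedged sketch follows; the dotted path is inferred from the MemoryGraph class shown and may differ from the actual package layout:

    import logging

    # Assumes a handler is already installed (e.g. via basicConfig).
    # Enable DEBUG only for the graph memory module so its "Extracted
    # entities" and "Node list" messages appear ...
    logging.getLogger("mem0.memory.graph_memory").setLevel(logging.DEBUG)

    # ... while the rest of the package stays at INFO. A logger's own
    # level, when set, takes precedence over its ancestors' levels.
    logging.getLogger("mem0").setLevel(logging.INFO)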

View File

@@ -21,6 +21,8 @@ import concurrent
# Setup user config
setup_config()
logger = logging.getLogger(__name__)
class Memory(MemoryBase):
def __init__(self, config: MemoryConfig = MemoryConfig()):
@@ -50,7 +52,7 @@ class Memory(MemoryBase):
try:
config = MemoryConfig(**config_dict)
except ValidationError as e:
logging.error(f"Configuration validation error: {e}")
logger.error(f"Configuration validation error: {e}")
raise
return cls(config)
@@ -453,6 +455,8 @@ class Memory(MemoryBase):
for memory in memories:
self._delete_memory(memory.id)
logger.info(f"Deleted {len(memories)} memories")
if self.version == "v1.1" and self.enable_graph:
self.graph.delete_all(filters)
@@ -491,6 +495,7 @@ class Memory(MemoryBase):
return memory_id
def _update_memory(self, memory_id, data, metadata=None):
logger.info(f"Updating memory with {data=}")
existing_memory = self.vector_store.get(vector_id=memory_id)
prev_value = existing_memory.payload.get("data")
@@ -515,7 +520,7 @@ class Memory(MemoryBase):
vector=embeddings,
payload=new_metadata,
)
logging.info(f"Updating memory with ID {memory_id=} with {data=}")
logger.info(f"Updating memory with ID {memory_id=} with {data=}")
self.db.add_history(
memory_id,
prev_value,
@@ -536,6 +541,7 @@ class Memory(MemoryBase):
"""
Reset the memory store.
"""
logger.warning("Resetting all memories")
self.vector_store.delete_col()
self.db.reset()
capture_event("mem0.reset", self)
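
The new _update_memory messages use the f-string `=` specifier (Python 3.8+), which renders both the expression and its repr, so `{data=}` expands to `data='...'`. For example:

    data = "user prefers dark mode"
    memory_id = "a1b2c3"

    print(f"Updating memory with {data=}")
    # Updating memory with data='user prefers dark mode'

    print(f"Updating memory with ID {memory_id=} with {data=}")
    # Updating memory with ID memory_id='a1b2c3' with data='user prefers dark mode'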

View File

@@ -1,3 +1,4 @@
import logging
import subprocess
import sys
import httpx
@@ -23,6 +24,8 @@ from mem0.memory.telemetry import capture_client_event
from mem0 import Memory, MemoryClient
from mem0.configs.prompts import MEMORY_ANSWER_PROMPT
logger = logging.getLogger(__name__)
class Mem0:
def __init__(
@@ -107,6 +110,7 @@ class Completions:
relevant_memories = self._fetch_relevant_memories(
messages, user_id, agent_id, run_id, filters, limit
)
logger.debug(f"Retrieved {len(relevant_memories)} relevant memories")
prepared_messages[-1]["content"] = self._format_query_with_memories(
messages, relevant_memories
)
@@ -155,6 +159,7 @@ class Completions:
self, messages, user_id, agent_id, run_id, metadata, filters
):
def add_task():
logger.debug("Adding to memory asynchronously")
self.mem0_client.add(
messages=messages,
user_id=user_id,
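
The debug line inside add_task runs on whatever worker executes the deferred add. The diff does not show how add_task is scheduled; assuming a plain background thread, the pattern is safe because the logging module serializes handler writes with an internal lock:

    import logging
    import threading

    logger = logging.getLogger(__name__)

    def add_task():
        # Safe from any thread: each handler acquires a lock per emitted
        # record, so messages from workers never interleave mid-line.
        logger.debug("Adding to memory asynchronously")
        # ... perform the memory write here ...

    # Fire-and-forget, mirroring the deferred add in the diff above.
    threading.Thread(target=add_task, daemon=True).start()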

View File

@@ -24,6 +24,8 @@ except ImportError:
from mem0.vector_stores.base import VectorStoreBase
logger = logging.getLogger(__name__)
class OutputData(BaseModel):
id: Optional[str] # memory id
@@ -151,6 +153,7 @@ class ChromaDB(VectorStoreBase):
payloads (Optional[List[Dict]], optional): List of payloads corresponding to vectors. Defaults to None.
ids (Optional[List[str]], optional): List of IDs corresponding to vectors. Defaults to None.
"""
logger.info(f"Inserting {len(vectors)} vectors into collection {self.collection_name}")
self.collection.add(ids=ids, embeddings=vectors, metadatas=payloads)
def search(
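
A note on the insert log added above (the same line is repeated in the pgvector and qdrant diffs below): an f-string is formatted eagerly, even when INFO is disabled. For a cheap len() call the cost is negligible; where arguments are expensive to format, logging's printf-style API defers the work until a handler will actually emit the record:

    import logging

    logger = logging.getLogger(__name__)
    vectors = [[0.1, 0.2], [0.3, 0.4]]
    collection_name = "memories"

    # Eager: the string is built before logger.info decides whether to emit it.
    logger.info(f"Inserting {len(vectors)} vectors into collection {collection_name}")

    # Lazy: %-style arguments are only formatted if the record is handled.
    logger.info("Inserting %d vectors into collection %s", len(vectors), collection_name)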

View File

@@ -1,6 +1,7 @@
import subprocess
import sys
import json
import logging
from typing import Optional, List
from pydantic import BaseModel
@@ -24,6 +25,7 @@ except ImportError:
from mem0.vector_stores.base import VectorStoreBase
logger = logging.getLogger(__name__)
class OutputData(BaseModel):
id: Optional[str]
@@ -102,6 +104,7 @@ class PGVector(VectorStoreBase):
payloads (List[Dict], optional): List of payloads corresponding to vectors.
ids (List[str], optional): List of IDs corresponding to vectors.
"""
logger.info(f"Inserting {len(vectors)} vectors into collection {self.collection_name}")
json_payloads = [json.dumps(payload) for payload in payloads]
data = [

View File

@@ -16,6 +16,8 @@ from qdrant_client.models import (
from mem0.vector_stores.base import VectorStoreBase
logger = logging.getLogger(__name__)
class Qdrant(VectorStoreBase):
def __init__(
@@ -102,6 +104,7 @@ class Qdrant(VectorStoreBase):
payloads (list, optional): List of payloads corresponding to vectors. Defaults to None.
ids (list, optional): List of IDs corresponding to vectors. Defaults to None.
"""
logger.info(f"Inserting {len(vectors)} vectors into collection {self.collection_name}")
points = [
PointStruct(
id=idx if ids is None else ids[idx],
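
One library-side convention the commit leaves for a follow-up (sketched here as an assumption, not part of the diff): attaching a NullHandler to the package root, so applications that never configure logging see no "no handlers" warnings and the library itself never installs real handlers:

    import logging

    # Typically placed in the package's __init__.py; "mem0" matches the
    # import paths shown in the diffs above.
    logging.getLogger("mem0").addHandler(logging.NullHandler())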