Add loggers for debugging (#1796)
This commit is contained in:
@@ -1,10 +1,12 @@
|
||||
from langchain_community.graphs import Neo4jGraph
|
||||
import json
|
||||
import logging
|
||||
from langchain_community.graphs import Neo4jGraph
|
||||
from rank_bm25 import BM25Okapi
|
||||
from mem0.utils.factory import LlmFactory, EmbedderFactory
|
||||
from mem0.graphs.utils import get_update_memory_messages, EXTRACT_ENTITIES_PROMPT
|
||||
from mem0.graphs.tools import UPDATE_MEMORY_TOOL_GRAPH, ADD_MEMORY_TOOL_GRAPH, NOOP_TOOL, ADD_MESSAGE_TOOL, SEARCH_TOOL
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class MemoryGraph:
|
||||
def __init__(self, config):
|
||||
@@ -36,7 +38,7 @@ class MemoryGraph:
|
||||
Returns:
|
||||
dict: A dictionary containing the entities added to the graph.
|
||||
"""
|
||||
|
||||
|
||||
# retrieve the search results
|
||||
search_output = self._search(data, filters)
|
||||
|
||||
@@ -50,17 +52,19 @@ class MemoryGraph:
|
||||
{"role": "system", "content": EXTRACT_ENTITIES_PROMPT.replace("USER_ID", self.user_id)},
|
||||
{"role": "user", "content": data},
|
||||
]
|
||||
|
||||
|
||||
extracted_entities = self.llm.generate_response(
|
||||
messages=messages,
|
||||
tools = [ADD_MESSAGE_TOOL],
|
||||
)
|
||||
|
||||
if extracted_entities['tool_calls']:
|
||||
if extracted_entities['tool_calls']:
|
||||
extracted_entities = extracted_entities['tool_calls'][0]['arguments']['entities']
|
||||
else:
|
||||
extracted_entities = []
|
||||
|
||||
|
||||
logger.debug(f"Extracted entities: {extracted_entities}")
|
||||
|
||||
update_memory_prompt = get_update_memory_messages(search_output, extracted_entities)
|
||||
|
||||
memory_updates = self.llm.generate_response(
|
||||
@@ -112,6 +116,8 @@ class MemoryGraph:
|
||||
|
||||
_ = self.graph.query(cypher, params=params)
|
||||
|
||||
logger.info(f"Added {len(to_be_added)} new memories to the graph")
|
||||
|
||||
|
||||
def _search(self, query, filters):
|
||||
search_results = self.llm.generate_response(
|
||||
@@ -136,6 +142,8 @@ class MemoryGraph:
|
||||
node_list = [node.lower().replace(" ", "_") for node in node_list]
|
||||
relation_list = [relation.lower().replace(" ", "_") for relation in relation_list]
|
||||
|
||||
logger.debug(f"Node list for search query : {node_list}")
|
||||
|
||||
result_relations = []
|
||||
|
||||
for node in node_list:
|
||||
@@ -168,7 +176,7 @@ class MemoryGraph:
|
||||
result_relations.extend(ans)
|
||||
|
||||
return result_relations
|
||||
|
||||
|
||||
|
||||
def search(self, query, filters):
|
||||
"""
|
||||
@@ -202,6 +210,8 @@ class MemoryGraph:
|
||||
"destination": item[2]
|
||||
})
|
||||
|
||||
logger.info(f"Returned {len(search_results)} search results")
|
||||
|
||||
return search_results
|
||||
|
||||
|
||||
@@ -212,8 +222,8 @@ class MemoryGraph:
|
||||
"""
|
||||
params = {"user_id": filters["user_id"]}
|
||||
self.graph.query(cypher, params=params)
|
||||
|
||||
|
||||
|
||||
def get_all(self, filters):
|
||||
"""
|
||||
Retrieves all nodes and relationships from the graph database based on optional filtering criteria.
|
||||
@@ -241,6 +251,8 @@ class MemoryGraph:
|
||||
"target": result['target']
|
||||
})
|
||||
|
||||
logger.info(f"Retrieved {len(final_results)} relationships")
|
||||
|
||||
return final_results
|
||||
|
||||
|
||||
@@ -256,6 +268,8 @@ class MemoryGraph:
|
||||
Raises:
|
||||
Exception: If the operation fails.
|
||||
"""
|
||||
logger.info(f"Updating relationship: {source} -{relationship}-> {target}")
|
||||
|
||||
relationship = relationship.lower().replace(" ", "_")
|
||||
|
||||
# Check if nodes exist and create them if they don't
|
||||
|
||||
@@ -21,6 +21,8 @@ import concurrent
|
||||
# Setup user config
|
||||
setup_config()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Memory(MemoryBase):
|
||||
def __init__(self, config: MemoryConfig = MemoryConfig()):
|
||||
@@ -50,7 +52,7 @@ class Memory(MemoryBase):
|
||||
try:
|
||||
config = MemoryConfig(**config_dict)
|
||||
except ValidationError as e:
|
||||
logging.error(f"Configuration validation error: {e}")
|
||||
logger.error(f"Configuration validation error: {e}")
|
||||
raise
|
||||
return cls(config)
|
||||
|
||||
@@ -453,6 +455,8 @@ class Memory(MemoryBase):
|
||||
for memory in memories:
|
||||
self._delete_memory(memory.id)
|
||||
|
||||
logger.info(f"Deleted {len(memories)} memories")
|
||||
|
||||
if self.version == "v1.1" and self.enable_graph:
|
||||
self.graph.delete_all(filters)
|
||||
|
||||
@@ -491,6 +495,7 @@ class Memory(MemoryBase):
|
||||
return memory_id
|
||||
|
||||
def _update_memory(self, memory_id, data, metadata=None):
|
||||
logger.info(f"Updating memory with {data=}")
|
||||
existing_memory = self.vector_store.get(vector_id=memory_id)
|
||||
prev_value = existing_memory.payload.get("data")
|
||||
|
||||
@@ -515,7 +520,7 @@ class Memory(MemoryBase):
|
||||
vector=embeddings,
|
||||
payload=new_metadata,
|
||||
)
|
||||
logging.info(f"Updating memory with ID {memory_id=} with {data=}")
|
||||
logger.info(f"Updating memory with ID {memory_id=} with {data=}")
|
||||
self.db.add_history(
|
||||
memory_id,
|
||||
prev_value,
|
||||
@@ -536,6 +541,7 @@ class Memory(MemoryBase):
|
||||
"""
|
||||
Reset the memory store.
|
||||
"""
|
||||
logger.warning("Resetting all memories")
|
||||
self.vector_store.delete_col()
|
||||
self.db.reset()
|
||||
capture_event("mem0.reset", self)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import logging
|
||||
import subprocess
|
||||
import sys
|
||||
import httpx
|
||||
@@ -23,6 +24,8 @@ from mem0.memory.telemetry import capture_client_event
|
||||
from mem0 import Memory, MemoryClient
|
||||
from mem0.configs.prompts import MEMORY_ANSWER_PROMPT
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Mem0:
|
||||
def __init__(
|
||||
@@ -107,6 +110,7 @@ class Completions:
|
||||
relevant_memories = self._fetch_relevant_memories(
|
||||
messages, user_id, agent_id, run_id, filters, limit
|
||||
)
|
||||
logger.debug(f"Retrieved {len(relevant_memories)} relevant memories")
|
||||
prepared_messages[-1]["content"] = self._format_query_with_memories(
|
||||
messages, relevant_memories
|
||||
)
|
||||
@@ -155,6 +159,7 @@ class Completions:
|
||||
self, messages, user_id, agent_id, run_id, metadata, filters
|
||||
):
|
||||
def add_task():
|
||||
logger.debug("Adding to memory asynchronously")
|
||||
self.mem0_client.add(
|
||||
messages=messages,
|
||||
user_id=user_id,
|
||||
|
||||
@@ -24,6 +24,8 @@ except ImportError:
|
||||
|
||||
from mem0.vector_stores.base import VectorStoreBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OutputData(BaseModel):
|
||||
id: Optional[str] # memory id
|
||||
@@ -151,6 +153,7 @@ class ChromaDB(VectorStoreBase):
|
||||
payloads (Optional[List[Dict]], optional): List of payloads corresponding to vectors. Defaults to None.
|
||||
ids (Optional[List[str]], optional): List of IDs corresponding to vectors. Defaults to None.
|
||||
"""
|
||||
logger.info(f"Inserting {len(vectors)} vectors into collection {self.collection_name}")
|
||||
self.collection.add(ids=ids, embeddings=vectors, metadatas=payloads)
|
||||
|
||||
def search(
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import subprocess
|
||||
import sys
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional, List
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -24,6 +25,7 @@ except ImportError:
|
||||
|
||||
from mem0.vector_stores.base import VectorStoreBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class OutputData(BaseModel):
|
||||
id: Optional[str]
|
||||
@@ -102,6 +104,7 @@ class PGVector(VectorStoreBase):
|
||||
payloads (List[Dict], optional): List of payloads corresponding to vectors.
|
||||
ids (List[str], optional): List of IDs corresponding to vectors.
|
||||
"""
|
||||
logger.info(f"Inserting {len(vectors)} vectors into collection {self.collection_name}")
|
||||
json_payloads = [json.dumps(payload) for payload in payloads]
|
||||
|
||||
data = [
|
||||
|
||||
@@ -16,6 +16,8 @@ from qdrant_client.models import (
|
||||
|
||||
from mem0.vector_stores.base import VectorStoreBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Qdrant(VectorStoreBase):
|
||||
def __init__(
|
||||
@@ -102,6 +104,7 @@ class Qdrant(VectorStoreBase):
|
||||
payloads (list, optional): List of payloads corresponding to vectors. Defaults to None.
|
||||
ids (list, optional): List of IDs corresponding to vectors. Defaults to None.
|
||||
"""
|
||||
logger.info(f"Inserting {len(vectors)} vectors into collection {self.collection_name}")
|
||||
points = [
|
||||
PointStruct(
|
||||
id=idx if ids is None else ids[idx],
|
||||
|
||||
Reference in New Issue
Block a user