[Misc] Lint code and fix code smells (#1871)
@@ -3,30 +3,28 @@ import logging
 from langchain_community.graphs import Neo4jGraph
 from rank_bm25 import BM25Okapi
 
-from mem0.graphs.tools import (
-    ADD_MEMORY_TOOL_GRAPH,
-    ADD_MESSAGE_TOOL,
-    NOOP_TOOL,
-    SEARCH_TOOL,
-    UPDATE_MEMORY_TOOL_GRAPH,
-    UPDATE_MEMORY_STRUCT_TOOL_GRAPH,
-    ADD_MEMORY_STRUCT_TOOL_GRAPH,
-    NOOP_STRUCT_TOOL,
-    ADD_MESSAGE_STRUCT_TOOL,
-    SEARCH_STRUCT_TOOL
-)
-from mem0.graphs.utils import EXTRACT_ENTITIES_PROMPT, get_update_memory_messages
+from mem0.graphs.tools import (ADD_MEMORY_STRUCT_TOOL_GRAPH,
+                               ADD_MEMORY_TOOL_GRAPH, ADD_MESSAGE_STRUCT_TOOL,
+                               ADD_MESSAGE_TOOL, NOOP_STRUCT_TOOL, NOOP_TOOL,
+                               SEARCH_STRUCT_TOOL, SEARCH_TOOL,
+                               UPDATE_MEMORY_STRUCT_TOOL_GRAPH,
+                               UPDATE_MEMORY_TOOL_GRAPH)
+from mem0.graphs.utils import (EXTRACT_ENTITIES_PROMPT,
+                               get_update_memory_messages)
 from mem0.utils.factory import EmbedderFactory, LlmFactory
 
 logger = logging.getLogger(__name__)
 
 
 class MemoryGraph:
     def __init__(self, config):
         self.config = config
-        self.graph = Neo4jGraph(self.config.graph_store.config.url, self.config.graph_store.config.username, self.config.graph_store.config.password)
-        self.embedding_model = EmbedderFactory.create(
-            self.config.embedder.provider, self.config.embedder.config
-        )
+        self.graph = Neo4jGraph(
+            self.config.graph_store.config.url,
+            self.config.graph_store.config.username,
+            self.config.graph_store.config.password,
+        )
+        self.embedding_model = EmbedderFactory.create(self.config.embedder.provider, self.config.embedder.config)
 
         self.llm_provider = "openai_structured"
         if self.config.llm.provider:
@@ -51,15 +49,23 @@ class MemoryGraph:
         search_output = self._search(data, filters)
 
         if self.config.graph_store.custom_prompt:
-            messages=[
-                {"role": "system", "content": EXTRACT_ENTITIES_PROMPT.replace("USER_ID", self.user_id).replace("CUSTOM_PROMPT", f"4. {self.config.graph_store.custom_prompt}")},
+            messages = [
+                {
+                    "role": "system",
+                    "content": EXTRACT_ENTITIES_PROMPT.replace("USER_ID", self.user_id).replace(
+                        "CUSTOM_PROMPT", f"4. {self.config.graph_store.custom_prompt}"
+                    ),
+                },
                 {"role": "user", "content": data},
             ]
         else:
-            messages=[
-                {"role": "system", "content": EXTRACT_ENTITIES_PROMPT.replace("USER_ID", self.user_id)},
+            messages = [
+                {
+                    "role": "system",
+                    "content": EXTRACT_ENTITIES_PROMPT.replace("USER_ID", self.user_id),
+                },
                 {"role": "user", "content": data},
             ]
 
         _tools = [ADD_MESSAGE_TOOL]
         if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
@@ -67,11 +73,11 @@ class MemoryGraph:
 
         extracted_entities = self.llm.generate_response(
             messages=messages,
-            tools = _tools,
+            tools=_tools,
         )
 
-        if extracted_entities['tool_calls']:
-            extracted_entities = extracted_entities['tool_calls'][0]['arguments']['entities']
+        if extracted_entities["tool_calls"]:
+            extracted_entities = extracted_entities["tool_calls"][0]["arguments"]["entities"]
         else:
             extracted_entities = []
@@ -79,9 +85,13 @@ class MemoryGraph:
 
         update_memory_prompt = get_update_memory_messages(search_output, extracted_entities)
 
-        _tools=[UPDATE_MEMORY_TOOL_GRAPH, ADD_MEMORY_TOOL_GRAPH, NOOP_TOOL]
-        if self.llm_provider in ["azure_openai_structured","openai_structured"]:
-            _tools = [UPDATE_MEMORY_STRUCT_TOOL_GRAPH, ADD_MEMORY_STRUCT_TOOL_GRAPH, NOOP_STRUCT_TOOL]
+        _tools = [UPDATE_MEMORY_TOOL_GRAPH, ADD_MEMORY_TOOL_GRAPH, NOOP_TOOL]
+        if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
+            _tools = [
+                UPDATE_MEMORY_STRUCT_TOOL_GRAPH,
+                ADD_MEMORY_STRUCT_TOOL_GRAPH,
+                NOOP_STRUCT_TOOL,
+            ]
 
         memory_updates = self.llm.generate_response(
             messages=update_memory_prompt,
@@ -90,28 +100,29 @@ class MemoryGraph:
 
         to_be_added = []
 
-        for item in memory_updates['tool_calls']:
-            if item['name'] == "add_graph_memory":
-                to_be_added.append(item['arguments'])
-            elif item['name'] == "update_graph_memory":
-                self._update_relationship(item['arguments']['source'], item['arguments']['destination'], item['arguments']['relationship'], filters)
-            elif item['name'] == "noop":
+        for item in memory_updates["tool_calls"]:
+            if item["name"] == "add_graph_memory":
+                to_be_added.append(item["arguments"])
+            elif item["name"] == "update_graph_memory":
+                self._update_relationship(
+                    item["arguments"]["source"],
+                    item["arguments"]["destination"],
+                    item["arguments"]["relationship"],
+                    filters,
+                )
+            elif item["name"] == "noop":
                 continue
 
         returned_entities = []
 
         for item in to_be_added:
-            source = item['source'].lower().replace(" ", "_")
-            source_type = item['source_type'].lower().replace(" ", "_")
-            relation = item['relationship'].lower().replace(" ", "_")
-            destination = item['destination'].lower().replace(" ", "_")
-            destination_type = item['destination_type'].lower().replace(" ", "_")
+            source = item["source"].lower().replace(" ", "_")
+            source_type = item["source_type"].lower().replace(" ", "_")
+            relation = item["relationship"].lower().replace(" ", "_")
+            destination = item["destination"].lower().replace(" ", "_")
+            destination_type = item["destination_type"].lower().replace(" ", "_")
 
-            returned_entities.append({
-                "source" : source,
-                "relationship" : relation,
-                "target" : destination
-            })
+            returned_entities.append({"source": source, "relationship": relation, "target": destination})
 
             # Create embeddings
             source_embedding = self.embedding_model.embed(source)
@@ -135,7 +146,7 @@ class MemoryGraph:
                 "dest_name": destination,
                 "source_embedding": source_embedding,
                 "dest_embedding": dest_embedding,
-                "user_id": filters["user_id"]
+                "user_id": filters["user_id"],
             }
 
             _ = self.graph.query(cypher, params=params)
@@ -150,19 +161,22 @@ class MemoryGraph:
            _tools = [SEARCH_STRUCT_TOOL]
         search_results = self.llm.generate_response(
             messages=[
-                {"role": "system", "content": f"You are a smart assistant who understands the entities, their types, and relations in a given text. If user message contains self reference such as 'I', 'me', 'my' etc. then use {filters['user_id']} as the source node. Extract the entities."},
+                {
+                    "role": "system",
+                    "content": f"You are a smart assistant who understands the entities, their types, and relations in a given text. If user message contains self reference such as 'I', 'me', 'my' etc. then use {filters['user_id']} as the source node. Extract the entities.",
+                },
                 {"role": "user", "content": query},
             ],
-            tools = _tools
+            tools=_tools,
         )
 
         node_list = []
         relation_list = []
 
-        for item in search_results['tool_calls']:
-            if item['name'] == "search":
+        for item in search_results["tool_calls"]:
+            if item["name"] == "search":
                 try:
-                    node_list.extend(item['arguments']['nodes'])
+                    node_list.extend(item["arguments"]["nodes"])
                 except Exception as e:
                     logger.error(f"Error in search tool: {e}")
@@ -201,13 +215,16 @@ class MemoryGraph:
             RETURN m.name AS source, elementId(m) AS source_id, type(r) AS relation, elementId(r) AS relation_id, n.name AS destination, elementId(n) AS destination_id, similarity
             ORDER BY similarity DESC
             """
-            params = {"n_embedding": n_embedding, "threshold": self.threshold, "user_id": filters["user_id"]}
+            params = {
+                "n_embedding": n_embedding,
+                "threshold": self.threshold,
+                "user_id": filters["user_id"],
+            }
             ans = self.graph.query(cypher_query, params=params)
             result_relations.extend(ans)
 
         return result_relations
 
     def search(self, query, filters):
         """
         Search for memories and related graph data.
@@ -235,17 +252,12 @@ class MemoryGraph:
 
         search_results = []
         for item in reranked_results:
-            search_results.append({
-                "source": item[0],
-                "relationship": item[1],
-                "target": item[2]
-            })
+            search_results.append({"source": item[0], "relationship": item[1], "target": item[2]})
 
         logger.info(f"Returned {len(search_results)} search results")
 
         return search_results
 
-
     def delete_all(self, filters):
         cypher = """
         MATCH (n {user_id: $user_id})
@@ -254,7 +266,6 @@ class MemoryGraph:
         params = {"user_id": filters["user_id"]}
         self.graph.query(cypher, params=params)
 
-
     def get_all(self, filters):
         """
         Retrieves all nodes and relationships from the graph database based on optional filtering criteria.
@@ -276,17 +287,18 @@ class MemoryGraph:
 
         final_results = []
         for result in results:
-            final_results.append({
-                "source": result['source'],
-                "relationship": result['relationship'],
-                "target": result['target']
-            })
+            final_results.append(
+                {
+                    "source": result["source"],
+                    "relationship": result["relationship"],
+                    "target": result["target"],
+                }
+            )
 
         logger.info(f"Retrieved {len(final_results)} relationships")
 
         return final_results
 
-
     def _update_relationship(self, source, target, relationship, filters):
        """
         Update or create a relationship between two nodes in the graph.
@@ -309,14 +321,20 @@ class MemoryGraph:
         MERGE (n1 {name: $source, user_id: $user_id})
         MERGE (n2 {name: $target, user_id: $user_id})
         """
-        self.graph.query(check_and_create_query, params={"source": source, "target": target, "user_id": filters["user_id"]})
+        self.graph.query(
+            check_and_create_query,
+            params={"source": source, "target": target, "user_id": filters["user_id"]},
+        )
 
         # Delete any existing relationship between the nodes
         delete_query = """
         MATCH (n1 {name: $source, user_id: $user_id})-[r]->(n2 {name: $target, user_id: $user_id})
         DELETE r
         """
-        self.graph.query(delete_query, params={"source": source, "target": target, "user_id": filters["user_id"]})
+        self.graph.query(
+            delete_query,
+            params={"source": source, "target": target, "user_id": filters["user_id"]},
+        )
 
         # Create the new relationship
         create_query = f"""
@@ -324,7 +342,10 @@ class MemoryGraph:
         CREATE (n1)-[r:{relationship}]->(n2)
         RETURN n1, r, n2
         """
-        result = self.graph.query(create_query, params={"source": source, "target": target, "user_id": filters["user_id"]})
+        result = self.graph.query(
+            create_query,
+            params={"source": source, "target": target, "user_id": filters["user_id"]},
+        )
 
         if not result:
             raise Exception(f"Failed to update or create relationship between {source} and {target}")
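Note on the hunks above: every change in this file is formatting only (double quotes, wrapped call arguments, normalized spacing around `=` and `:`); the style is consistent with a black/isort pass, though the exact toolchain is not named in the diff. A minimal runnable sketch of the entity normalization that the add flow applies before writing graph nodes, with illustrative values not taken from the diff:

# Sketch: normalization applied to each extracted entity before the
# Cypher MERGE/CREATE queries run. The sample dict is hypothetical.
item = {
    "source": "John Doe",
    "relationship": "Works At",
    "destination": "Acme Corp",
}

source = item["source"].lower().replace(" ", "_")            # john_doe
relation = item["relationship"].lower().replace(" ", "_")    # works_at
destination = item["destination"].lower().replace(" ", "_")  # acme_corp

# This triple is the shape that search()/get_all() later return:
print({"source": source, "relationship": relation, "target": destination})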
@@ -10,14 +10,14 @@ from typing import Any, Dict
 import pytz
 from pydantic import ValidationError
 
-from mem0.configs.base import MemoryItem, MemoryConfig
+from mem0.configs.base import MemoryConfig, MemoryItem
 from mem0.configs.prompts import get_update_memory_messages
 from mem0.memory.base import MemoryBase
 from mem0.memory.setup import setup_config
 from mem0.memory.storage import SQLiteManager
 from mem0.memory.telemetry import capture_event
 from mem0.memory.utils import get_fact_retrieval_messages, parse_messages
-from mem0.utils.factory import LlmFactory, EmbedderFactory, VectorStoreFactory
+from mem0.utils.factory import EmbedderFactory, LlmFactory, VectorStoreFactory
 
 # Setup user config
 setup_config()
@@ -30,9 +30,7 @@ class Memory(MemoryBase):
         self.config = config
 
         self.custom_prompt = self.config.custom_prompt
-        self.embedding_model = EmbedderFactory.create(
-            self.config.embedder.provider, self.config.embedder.config
-        )
+        self.embedding_model = EmbedderFactory.create(self.config.embedder.provider, self.config.embedder.config)
         self.vector_store = VectorStoreFactory.create(
             self.config.vector_store.provider, self.config.vector_store.config
         )
@@ -45,12 +43,12 @@ class Memory(MemoryBase):
 
         if self.version == "v1.1" and self.config.graph_store.config:
             from mem0.memory.graph_memory import MemoryGraph
+
             self.graph = MemoryGraph(self.config)
             self.enable_graph = True
 
         capture_event("mem0.init", self)
 
-
     @classmethod
     def from_config(cls, config_dict: Dict[str, Any]):
         try:
@@ -60,7 +58,6 @@ class Memory(MemoryBase):
             raise
         return cls(config)
 
-
     def add(
         self,
         messages,
@@ -98,9 +95,7 @@ class Memory(MemoryBase):
             filters["run_id"] = metadata["run_id"] = run_id
 
         if not any(key in filters for key in ("user_id", "agent_id", "run_id")):
-            raise ValueError(
-                "One of the filters: user_id, agent_id or run_id is required!"
-            )
+            raise ValueError("One of the filters: user_id, agent_id or run_id is required!")
 
         if isinstance(messages, str):
             messages = [{"role": "user", "content": messages}]
@@ -116,8 +111,8 @@ class Memory(MemoryBase):
 
         if self.version == "v1.1":
             return {
-                "results" : vector_store_result,
-                "relations" : graph_result,
+                "results": vector_store_result,
+                "relations": graph_result,
             }
         else:
             warnings.warn(
@@ -125,29 +120,29 @@ class Memory(MemoryBase):
                 "To use the latest format, set `api_version='v1.1'`. "
                 "The current format will be removed in mem0ai 1.1.0 and later versions.",
                 category=DeprecationWarning,
-                stacklevel=2
+                stacklevel=2,
             )
             return {"message": "ok"}
 
-
     def _add_to_vector_store(self, messages, metadata, filters):
         parsed_messages = parse_messages(messages)
 
         if self.custom_prompt:
-            system_prompt=self.custom_prompt
-            user_prompt=f"Input: {parsed_messages}"
+            system_prompt = self.custom_prompt
+            user_prompt = f"Input: {parsed_messages}"
         else:
             system_prompt, user_prompt = get_fact_retrieval_messages(parsed_messages)
 
         response = self.llm.generate_response(
-            messages=[{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}],
+            messages=[
+                {"role": "system", "content": system_prompt},
+                {"role": "user", "content": user_prompt},
+            ],
             response_format={"type": "json_object"},
         )
 
         try:
-            new_retrieved_facts = json.loads(response)[
-                "facts"
-            ]
+            new_retrieved_facts = json.loads(response)["facts"]
         except Exception as e:
             logging.error(f"Error in new_retrieved_facts: {e}")
             new_retrieved_facts = []
@@ -178,24 +173,30 @@ class Memory(MemoryBase):
             logging.info(resp)
             try:
                 if resp["event"] == "ADD":
-                    memory_id = self._create_memory(data=resp["text"], metadata=metadata)
-                    returned_memories.append({
-                        "memory" : resp["text"],
-                        "event" : resp["event"],
-                    })
+                    _ = self._create_memory(data=resp["text"], metadata=metadata)
+                    returned_memories.append(
+                        {
+                            "memory": resp["text"],
+                            "event": resp["event"],
+                        }
+                    )
                 elif resp["event"] == "UPDATE":
                     self._update_memory(memory_id=resp["id"], data=resp["text"], metadata=metadata)
-                    returned_memories.append({
-                        "memory" : resp["text"],
-                        "event" : resp["event"],
-                        "previous_memory" : resp["old_memory"],
-                    })
+                    returned_memories.append(
+                        {
+                            "memory": resp["text"],
+                            "event": resp["event"],
+                            "previous_memory": resp["old_memory"],
+                        }
+                    )
                 elif resp["event"] == "DELETE":
                     self._delete_memory(memory_id=resp["id"])
-                    returned_memories.append({
-                        "memory" : resp["text"],
-                        "event" : resp["event"],
-                    })
+                    returned_memories.append(
+                        {
+                            "memory": resp["text"],
+                            "event": resp["event"],
+                        }
+                    )
                 elif resp["event"] == "NONE":
                     logging.info("NOOP for Memory.")
             except Exception as e:
@@ -206,7 +207,6 @@ class Memory(MemoryBase):
         capture_event("mem0.add", self)
 
         return returned_memories
 
-
     def _add_to_graph(self, messages, filters):
         added_entities = []
@@ -220,7 +220,6 @@ class Memory(MemoryBase):
 
         return added_entities
 
-
     def get(self, memory_id):
         """
         Retrieve a memory by ID.
@@ -236,11 +235,7 @@ class Memory(MemoryBase):
         if not memory:
             return None
 
-        filters = {
-            key: memory.payload[key]
-            for key in ["user_id", "agent_id", "run_id"]
-            if memory.payload.get(key)
-        }
+        filters = {key: memory.payload[key] for key in ["user_id", "agent_id", "run_id"] if memory.payload.get(key)}
 
         # Prepare base memory item
         memory_item = MemoryItem(
@@ -261,9 +256,7 @@ class Memory(MemoryBase):
             "created_at",
             "updated_at",
         }
-        additional_metadata = {
-            k: v for k, v in memory.payload.items() if k not in excluded_keys
-        }
+        additional_metadata = {k: v for k, v in memory.payload.items() if k not in excluded_keys}
         if additional_metadata:
             memory_item["metadata"] = additional_metadata
@@ -271,7 +264,6 @@ class Memory(MemoryBase):
 
         return result
 
-
     def get_all(self, user_id=None, agent_id=None, run_id=None, limit=100):
         """
         List all memories.
@@ -288,10 +280,12 @@ class Memory(MemoryBase):
             filters["run_id"] = run_id
 
         capture_event("mem0.get_all", self, {"filters": len(filters), "limit": limit})
 
         with concurrent.futures.ThreadPoolExecutor() as executor:
             future_memories = executor.submit(self._get_all_from_vector_store, filters, limit)
-            future_graph_entities = executor.submit(self.graph.get_all, filters) if self.version == "v1.1" and self.enable_graph else None
+            future_graph_entities = (
+                executor.submit(self.graph.get_all, filters) if self.version == "v1.1" and self.enable_graph else None
+            )
 
             all_memories = future_memories.result()
             graph_entities = future_graph_entities.result() if future_graph_entities else None
@@ -307,15 +301,22 @@ class Memory(MemoryBase):
                 "To use the latest format, set `api_version='v1.1'`. "
                 "The current format will be removed in mem0ai 1.1.0 and later versions.",
                 category=DeprecationWarning,
-                stacklevel=2
+                stacklevel=2,
             )
             return all_memories
 
-
     def _get_all_from_vector_store(self, filters, limit):
         memories = self.vector_store.list(filters=filters, limit=limit)
 
-        excluded_keys = {"user_id", "agent_id", "run_id", "hash", "data", "created_at", "updated_at"}
+        excluded_keys = {
+            "user_id",
+            "agent_id",
+            "run_id",
+            "hash",
+            "data",
+            "created_at",
+            "updated_at",
+        }
         all_memories = [
             {
                 **MemoryItem(
@@ -325,19 +326,9 @@ class Memory(MemoryBase):
                     created_at=mem.payload.get("created_at"),
                     updated_at=mem.payload.get("updated_at"),
                 ).model_dump(exclude={"score"}),
-                **{
-                    key: mem.payload[key]
-                    for key in ["user_id", "agent_id", "run_id"]
-                    if key in mem.payload
-                },
+                **{key: mem.payload[key] for key in ["user_id", "agent_id", "run_id"] if key in mem.payload},
                 **(
-                    {
-                        "metadata": {
-                            k: v
-                            for k, v in mem.payload.items()
-                            if k not in excluded_keys
-                        }
-                    }
+                    {"metadata": {k: v for k, v in mem.payload.items() if k not in excluded_keys}}
                     if any(k for k in mem.payload if k not in excluded_keys)
                     else {}
                 ),
@@ -346,10 +337,7 @@ class Memory(MemoryBase):
         ]
         return all_memories
 
-
-    def search(
-        self, query, user_id=None, agent_id=None, run_id=None, limit=100, filters=None
-    ):
+    def search(self, query, user_id=None, agent_id=None, run_id=None, limit=100, filters=None):
         """
         Search for memories.
 
@@ -373,15 +361,21 @@ class Memory(MemoryBase):
             filters["run_id"] = run_id
 
         if not any(key in filters for key in ("user_id", "agent_id", "run_id")):
-            raise ValueError(
-                "One of the filters: user_id, agent_id or run_id is required!"
-            )
+            raise ValueError("One of the filters: user_id, agent_id or run_id is required!")
 
-        capture_event("mem0.search", self, {"filters": len(filters), "limit": limit, "version": self.version})
+        capture_event(
+            "mem0.search",
+            self,
+            {"filters": len(filters), "limit": limit, "version": self.version},
+        )
 
         with concurrent.futures.ThreadPoolExecutor() as executor:
             future_memories = executor.submit(self._search_vector_store, query, filters, limit)
-            future_graph_entities = executor.submit(self.graph.search, query, filters) if self.version == "v1.1" and self.enable_graph else None
+            future_graph_entities = (
+                executor.submit(self.graph.search, query, filters)
+                if self.version == "v1.1" and self.enable_graph
+                else None
+            )
 
             original_memories = future_memories.result()
             graph_entities = future_graph_entities.result() if future_graph_entities else None
@@ -390,23 +384,20 @@ class Memory(MemoryBase):
             if self.enable_graph:
                 return {"results": original_memories, "relations": graph_entities}
             else:
-                return {"results" : original_memories}
+                return {"results": original_memories}
         else:
             warnings.warn(
                 "The current get_all API output format is deprecated. "
                 "To use the latest format, set `api_version='v1.1'`. "
                 "The current format will be removed in mem0ai 1.1.0 and later versions.",
                 category=DeprecationWarning,
-                stacklevel=2
+                stacklevel=2,
             )
             return original_memories
 
-
     def _search_vector_store(self, query, filters, limit):
         embeddings = self.embedding_model.embed(query)
-        memories = self.vector_store.search(
-            query=embeddings, limit=limit, filters=filters
-        )
+        memories = self.vector_store.search(query=embeddings, limit=limit, filters=filters)
 
         excluded_keys = {
             "user_id",
@@ -428,19 +419,9 @@ class Memory(MemoryBase):
                     updated_at=mem.payload.get("updated_at"),
                     score=mem.score,
                 ).model_dump(),
-                **{
-                    key: mem.payload[key]
-                    for key in ["user_id", "agent_id", "run_id"]
-                    if key in mem.payload
-                },
+                **{key: mem.payload[key] for key in ["user_id", "agent_id", "run_id"] if key in mem.payload},
                 **(
-                    {
-                        "metadata": {
-                            k: v
-                            for k, v in mem.payload.items()
-                            if k not in excluded_keys
-                        }
-                    }
+                    {"metadata": {k: v for k, v in mem.payload.items() if k not in excluded_keys}}
                     if any(k for k in mem.payload if k not in excluded_keys)
                     else {}
                 ),
@@ -450,7 +431,6 @@ class Memory(MemoryBase):
 
         return original_memories
 
-
     def update(self, memory_id, data):
         """
         Update a memory by ID.
@@ -466,7 +446,6 @@ class Memory(MemoryBase):
         self._update_memory(memory_id, data)
         return {"message": "Memory updated successfully!"}
 
-
     def delete(self, memory_id):
         """
         Delete a memory by ID.
@@ -478,7 +457,6 @@ class Memory(MemoryBase):
         self._delete_memory(memory_id)
         return {"message": "Memory deleted successfully!"}
 
-
    def delete_all(self, user_id=None, agent_id=None, run_id=None):
         """
         Delete all memories.
@@ -511,8 +489,7 @@ class Memory(MemoryBase):
         if self.version == "v1.1" and self.enable_graph:
             self.graph.delete_all(filters)
 
-        return {'message': 'Memories deleted successfully!'}
-
+        return {"message": "Memories deleted successfully!"}
 
     def history(self, memory_id):
         """
@@ -527,7 +504,6 @@ class Memory(MemoryBase):
         capture_event("mem0.history", self, {"memory_id": memory_id})
         return self.db.get_history(memory_id)
 
-
     def _create_memory(self, data, metadata=None):
         logging.info(f"Creating memory with {data=}")
         embeddings = self.embedding_model.embed(data)
@@ -542,12 +518,9 @@ class Memory(MemoryBase):
             ids=[memory_id],
             payloads=[metadata],
         )
-        self.db.add_history(
-            memory_id, None, data, "ADD", created_at=metadata["created_at"]
-        )
+        self.db.add_history(memory_id, None, data, "ADD", created_at=metadata["created_at"])
         return memory_id
 
-
     def _update_memory(self, memory_id, data, metadata=None):
         logger.info(f"Updating memory with {data=}")
         existing_memory = self.vector_store.get(vector_id=memory_id)
@@ -557,9 +530,7 @@ class Memory(MemoryBase):
         new_metadata["data"] = data
         new_metadata["hash"] = existing_memory.payload.get("hash")
         new_metadata["created_at"] = existing_memory.payload.get("created_at")
-        new_metadata["updated_at"] = datetime.now(
-            pytz.timezone("US/Pacific")
-        ).isoformat()
+        new_metadata["updated_at"] = datetime.now(pytz.timezone("US/Pacific")).isoformat()
 
         if "user_id" in existing_memory.payload:
             new_metadata["user_id"] = existing_memory.payload["user_id"]
@@ -584,7 +555,6 @@ class Memory(MemoryBase):
             updated_at=new_metadata["updated_at"],
         )
 
-
     def _delete_memory(self, memory_id):
         logging.info(f"Deleting memory with {memory_id=}")
         existing_memory = self.vector_store.get(vector_id=memory_id)
@@ -592,7 +562,6 @@ class Memory(MemoryBase):
         self.vector_store.delete(vector_id=memory_id)
         self.db.add_history(memory_id, prev_value, None, "DELETE", is_deleted=1)
 
-
    def reset(self):
         """
         Reset the memory store.
@@ -602,6 +571,5 @@ class Memory(MemoryBase):
         self.db.reset()
         capture_event("mem0.reset", self)
 
-
     def chat(self, query):
         raise NotImplementedError("Chat function not implemented yet.")
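Note on the hunks above: the `Memory` class changes are likewise mechanical (collapsed one-argument wraps, trailing commas, doubled blank lines between methods removed). The densest rewrites are the dict comprehensions in `_get_all_from_vector_store` and `_search_vector_store`; a self-contained sketch of that conditional `**`-merge pattern, using a made-up payload:

# Sketch of the dict merge used when building memory items: optional keys
# are spliced in with **, and "metadata" is only added when the payload
# has keys outside the excluded set. The payload here is illustrative.
excluded_keys = {"user_id", "agent_id", "run_id", "hash", "data",
                 "created_at", "updated_at"}
payload = {"user_id": "u1", "data": "likes tea", "topic": "beverages"}

item = {
    "memory": payload["data"],
    **{key: payload[key] for key in ["user_id", "agent_id", "run_id"] if key in payload},
    **(
        {"metadata": {k: v for k, v in payload.items() if k not in excluded_keys}}
        if any(k for k in payload if k not in excluded_keys)
        else {}
    ),
}
print(item)  # {'memory': 'likes tea', 'user_id': 'u1', 'metadata': {'topic': 'beverages'}}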
@@ -12,9 +12,7 @@ class SQLiteManager:
         with self.connection:
             cursor = self.connection.cursor()
 
-            cursor.execute(
-                "SELECT name FROM sqlite_master WHERE type='table' AND name='history'"
-            )
+            cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='history'")
             table_exists = cursor.fetchone() is not None
 
             if table_exists:
@@ -62,7 +60,7 @@ class SQLiteManager:
                 INSERT INTO history (id, memory_id, old_memory, new_memory, new_value, event, created_at, updated_at, is_deleted)
                 SELECT id, memory_id, prev_value, new_value, new_value, event, timestamp, timestamp, is_deleted
                 FROM old_history
-            """
+            """  # noqa: E501
             )
 
             cursor.execute("DROP TABLE old_history")
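Note: the only substantive marker here is `# noqa: E501`, which suppresses the line-too-long lint warning on the wide `INSERT INTO history ...` column list rather than wrapping SQL mid-statement. The surrounding migration is the standard SQLite copy-and-drop pattern; a standalone sketch with a shortened schema:

import sqlite3

# Sketch of the rename-and-copy migration pattern used by SQLiteManager:
# create the new schema, copy rows across with INSERT ... SELECT, drop the old table.
conn = sqlite3.connect(":memory:")
cur = conn.cursor()
cur.execute("CREATE TABLE old_history (id TEXT, prev_value TEXT, new_value TEXT)")
cur.execute("INSERT INTO old_history VALUES ('1', NULL, 'hello')")
cur.execute("CREATE TABLE history (id TEXT, old_memory TEXT, new_memory TEXT)")
cur.execute(
    "INSERT INTO history (id, old_memory, new_memory) "
    "SELECT id, prev_value, new_value FROM old_history"
)
cur.execute("DROP TABLE old_history")
print(cur.execute("SELECT * FROM history").fetchall())  # [('1', None, 'hello')]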
@@ -1,7 +1,7 @@
 import logging
+import os
 import platform
 import sys
-import os
 
 from posthog import Posthog
 
@@ -15,8 +15,9 @@ if isinstance(MEM0_TELEMETRY, str):
 if not isinstance(MEM0_TELEMETRY, bool):
     raise ValueError("MEM0_TELEMETRY must be a boolean value.")
 
-logging.getLogger('posthog').setLevel(logging.CRITICAL + 1)
-logging.getLogger('urllib3').setLevel(logging.CRITICAL + 1)
+logging.getLogger("posthog").setLevel(logging.CRITICAL + 1)
+logging.getLogger("urllib3").setLevel(logging.CRITICAL + 1)
+
 
 class AnonymousTelemetry:
     def __init__(self, project_api_key, host):
@@ -24,9 +25,8 @@ class AnonymousTelemetry:
         # Call setup config to ensure that the user_id is generated
         setup_config()
         self.user_id = get_user_id()
-        # Optional
-        if not MEM0_TELEMETRY:
-            self.posthog.disabled = True
+        if not MEM0_TELEMETRY:
+            self.posthog.disabled = True
 
     def capture_event(self, event_name, properties=None):
         if properties is None:
@@ -40,9 +40,7 @@ class AnonymousTelemetry:
             "machine": platform.machine(),
             **properties,
         }
-        self.posthog.capture(
-            distinct_id=self.user_id, event=event_name, properties=properties
-        )
+        self.posthog.capture(distinct_id=self.user_id, event=event_name, properties=properties)
 
     def identify_user(self, user_id, properties=None):
         if properties is None:
@@ -65,6 +63,7 @@ def capture_event(event_name, memory_instance, additional_data=None):
         "collection": memory_instance.collection_name,
         "vector_size": memory_instance.embedding_model.config.embedding_dims,
         "history_store": "sqlite",
+        "graph_store": f"{memory_instance.graph.__class__.__module__}.{memory_instance.graph.__class__.__name__}" if memory_instance.config.graph_store.config else None,
         "vector_store": f"{memory_instance.vector_store.__class__.__module__}.{memory_instance.vector_store.__class__.__name__}",
         "llm": f"{memory_instance.llm.__class__.__module__}.{memory_instance.llm.__class__.__name__}",
         "embedding_model": f"{memory_instance.embedding_model.__class__.__module__}.{memory_instance.embedding_model.__class__.__name__}",
@@ -76,7 +75,6 @@ def capture_event(event_name, memory_instance, additional_data=None):
     telemetry.capture_event(event_name, event_data)
 
 
-
 def capture_client_event(event_name, instance, additional_data=None):
     event_data = {
         "function": f"{instance.__class__.__module__}.{instance.__class__.__name__}",
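Note: besides quote normalization, these hunks sort `import os` into place, drop a stray `# Optional` comment, and add a `graph_store` entry to the telemetry payload. The `setLevel(logging.CRITICAL + 1)` idiom used on the posthog and urllib3 loggers mutes them entirely, since no record can exceed CRITICAL; a minimal demonstration:

import logging

# Setting a logger's level above CRITICAL suppresses every record it
# would otherwise emit, including CRITICAL itself.
logging.basicConfig(level=logging.DEBUG)
noisy = logging.getLogger("posthog")
noisy.setLevel(logging.CRITICAL + 1)

noisy.critical("this is dropped")               # filtered out
logging.getLogger("app").info("still logged")   # other loggers unaffected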
@@ -4,13 +4,14 @@ from mem0.configs.prompts import FACT_RETRIEVAL_PROMPT
 def get_fact_retrieval_messages(message):
     return FACT_RETRIEVAL_PROMPT, f"Input: {message}"
 
+
 def parse_messages(messages):
     response = ""
     for msg in messages:
         if msg["role"] == "system":
             response += f"system: {msg['content']}\n"
         if msg["role"] == "user":
             response += f"user: {msg['content']}\n"
         if msg["role"] == "assistant":
             response += f"assistant: {msg['content']}\n"
     return response
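Note: `parse_messages` itself is unchanged; the hunk only inserts the second blank line the formatter requires between top-level functions (the side-by-side view repeated the shifted body, which is collapsed here). Usage, assuming the function as defined above:

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Remind me to buy tea."},
]
print(parse_messages(messages))
# system: You are a helpful assistant.
# user: Remind me to buy tea.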