"""OpenMemory MCP server: memory tools (add, search, list, delete all) served over SSE via FastAPI."""
import logging
|
|
import json
|
|
from mcp.server.fastmcp import FastMCP
|
|
from mcp.server.sse import SseServerTransport
|
|
from app.utils.memory import get_memory_client
|
|
from fastapi import FastAPI, Request
|
|
from fastapi.routing import APIRouter
|
|
import contextvars
|
|
import os
|
|
from dotenv import load_dotenv
|
|
from app.database import SessionLocal
|
|
from app.models import Memory, MemoryState, MemoryStatusHistory, MemoryAccessLog
|
|
from app.utils.db import get_user_and_app
|
|
import uuid
|
|
import datetime
|
|
from app.utils.permissions import check_memory_access_permissions
|
|
from qdrant_client import models as qdrant_models
|
|
|
|
# Load environment variables (including any defined in a local .env file).
load_dotenv()

# MCP server instance; the @mcp.tool functions in this module register on it.
mcp = FastMCP("mem0-mcp-server")

# Fail fast at import time if the OpenAI API key is missing. This check runs
# before the memory client is created.
if not os.getenv("OPENAI_API_KEY"):
    raise RuntimeError("OPENAI_API_KEY is not set in .env file")

memory_client = get_memory_client()

# Per-connection identity. The SSE handler sets these and the tool functions
# read them with .get(None).
user_id_var: contextvars.ContextVar[str] = contextvars.ContextVar("user_id")
client_name_var: contextvars.ContextVar[str] = contextvars.ContextVar("client_name")

# Router for the MCP HTTP endpoints; every route below is mounted under /mcp.
mcp_router = APIRouter(prefix="/mcp")

# SSE transport; "/mcp/messages/" is the endpoint path it advertises to
# clients for posting their messages.
sse = SseServerTransport("/mcp/messages/")
|
|
|
|
@mcp.tool(description="Add a new memory. This method is called every time the user informs anything about themselves, their preferences, or anything that has any relevant information which can be useful in the future conversation. This can also be called when the user asks you to remember something.")
async def add_memories(text: str) -> str:
    """Store ``text`` for the current user and mirror the outcome in the database.

    The user id and client name are read from ``user_id_var`` and
    ``client_name_var`` (set by the SSE handler). ``text`` is passed to
    ``memory_client.add``. If the response is a dict with a ``results`` list,
    each ``ADD`` or ``DELETE`` event is reflected in the ``Memory`` table and
    gets a ``MemoryStatusHistory`` row.

    Returns:
        The raw response of ``memory_client.add`` on success. Otherwise an
        error string: when the user id or client name is missing, when the
        app is paused, or when any step raises.
    """
    uid = user_id_var.get(None)
    client_name = client_name_var.get(None)

    if not uid:
        return "Error: user_id not provided"
    if not client_name:
        return "Error: client_name not provided"

    try:
        db = SessionLocal()
        try:
            # Get or create user and app
            user, app = get_user_and_app(db, user_id=uid, app_id=client_name)

            # Paused apps are not allowed to create new memories.
            if not app.is_active:
                return f"Error: App {app.name} is currently paused on OpenMemory. Cannot create new memories."

            response = memory_client.add(
                text,
                user_id=uid,
                metadata={
                    "source_app": "openmemory",
                    "mcp_client": client_name,
                },
            )

            # Mirror the events reported by the memory client into the database.
            if isinstance(response, dict) and 'results' in response:
                for result in response['results']:
                    memory_id = uuid.UUID(result['id'])
                    memory = db.query(Memory).filter(Memory.id == memory_id).first()

                    if result['event'] == 'ADD':
                        if not memory:
                            memory = Memory(
                                id=memory_id,
                                user_id=user.id,
                                app_id=app.id,
                                content=result['memory'],
                                state=MemoryState.active,
                            )
                            db.add(memory)
                        else:
                            memory.state = MemoryState.active
                            memory.content = result['memory']

                        history = MemoryStatusHistory(
                            memory_id=memory_id,
                            changed_by=user.id,
                            # NOTE(review): at this point `memory` is always set, so a
                            # "None for brand-new rows" variant was never in effect and
                            # every ADD is recorded as coming from `deleted`. Recording
                            # None for new rows needs a check that old_state is nullable
                            # in the schema first.
                            old_state=MemoryState.deleted,
                            new_state=MemoryState.active,
                        )
                        db.add(history)

                    elif result['event'] == 'DELETE':
                        if memory:
                            # Capture the real prior state before overwriting it so
                            # the history row is accurate (e.g. for paused memories).
                            previous_state = memory.state
                            memory.state = MemoryState.deleted
                            memory.deleted_at = datetime.datetime.now(datetime.UTC)

                            history = MemoryStatusHistory(
                                memory_id=memory_id,
                                changed_by=user.id,
                                old_state=previous_state,
                                new_state=MemoryState.deleted,
                            )
                            db.add(history)

            db.commit()

            return response
        finally:
            db.close()
    except Exception as e:
        logging.exception("Error adding to memory")
        return f"Error adding to memory: {e}"
|
|
|
|
|
|
@mcp.tool(description="Search through stored memories. This method is called EVERYTIME the user asks anything.")
async def search_memory(query: str) -> str:
    """Vector-search the current user's memories that this app may access.

    The search is restricted to this user's points whose ids pass
    ``check_memory_access_permissions`` for the calling app, and returns at
    most 10 hits. A ``MemoryAccessLog`` row with access type ``search`` is
    written for every returned memory.

    Returns:
        A JSON array (indent=2) of dicts with keys ``id``, ``memory``, ``hash``,
        ``created_at``, ``updated_at`` and ``score``. If no memory is accessible,
        returns an empty array. Otherwise returns an error string when the
        context is incomplete or any step raises.
    """
    uid = user_id_var.get(None)
    client_name = client_name_var.get(None)
    if not uid:
        return "Error: user_id not provided"
    if not client_name:
        return "Error: client_name not provided"

    try:
        db = SessionLocal()
        try:
            # Get or create user and app
            user, app = get_user_and_app(db, user_id=uid, app_id=client_name)

            # Collect ids (as strings, as Qdrant expects) of memories this app may read.
            user_memories = db.query(Memory).filter(Memory.user_id == user.id).all()
            accessible_memory_ids = [
                str(memory.id)
                for memory in user_memories
                if check_memory_access_permissions(db, memory, app.id)
            ]

            # Without a HasIdCondition the query below would match every point of
            # this user, i.e. bypass the ACL. Nothing is accessible -> nothing found.
            if not accessible_memory_ids:
                return json.dumps([], indent=2)

            filters = qdrant_models.Filter(
                must=[
                    qdrant_models.FieldCondition(
                        key="user_id", match=qdrant_models.MatchValue(value=uid)
                    ),
                    qdrant_models.HasIdCondition(has_id=accessible_memory_ids),
                ]
            )
            embeddings = memory_client.embedding_model.embed(query, "search")

            hits = memory_client.vector_store.client.query_points(
                collection_name=memory_client.vector_store.collection_name,
                query=embeddings,
                query_filter=filters,
                limit=10,
            )

            memories = [
                {
                    "id": point.id,
                    "memory": point.payload["data"],
                    "hash": point.payload.get("hash"),
                    "created_at": point.payload.get("created_at"),
                    "updated_at": point.payload.get("updated_at"),
                    "score": point.score,
                }
                for point in hits.points
            ]

            # Log memory access for each memory found
            for memory in memories:
                access_log = MemoryAccessLog(
                    memory_id=uuid.UUID(memory['id']),
                    app_id=app.id,
                    access_type="search",
                    metadata_={
                        "query": query,
                        "score": memory.get('score'),
                        "hash": memory.get('hash'),
                    },
                )
                db.add(access_log)
            db.commit()

            return json.dumps(memories, indent=2)
        finally:
            db.close()
    except Exception as e:
        logging.exception(e)
        return f"Error searching memory: {e}"
|
|
|
|
|
|
@mcp.tool(description="List all memories in the user's memory")
async def list_memories() -> str:
    """List the current user's memories that the calling app may access.

    The memories are fetched with ``memory_client.get_all``. The result may be
    either a ``{"results": [...]}`` dict or a plain list; both are handled. Each
    memory is kept only if its id belongs to this user's ``Memory`` rows that
    pass ``check_memory_access_permissions`` for the app. A ``MemoryAccessLog``
    row with access type ``list`` is written for every memory kept.

    Returns:
        A JSON array (indent=2) of the memory dicts as returned by the client.
        Returns an error string when the context is incomplete or any step
        raises.
    """
    uid = user_id_var.get(None)
    client_name = client_name_var.get(None)
    if not uid:
        return "Error: user_id not provided"
    if not client_name:
        return "Error: client_name not provided"

    try:
        db = SessionLocal()
        try:
            # Get or create user and app
            user, app = get_user_and_app(db, user_id=uid, app_id=client_name)

            memories = memory_client.get_all(user_id=uid)
            # Normalise the two possible response shapes into one list.
            if isinstance(memories, dict) and 'results' in memories:
                candidates = memories['results']
            else:
                candidates = memories

            # Set of ids this app may see: O(1) membership checks, and always
            # scoped to the current user.
            user_memories = db.query(Memory).filter(Memory.user_id == user.id).all()
            accessible_memory_ids = {
                memory.id
                for memory in user_memories
                if check_memory_access_permissions(db, memory, app.id)
            }

            filtered_memories = []
            for memory_data in candidates:
                if 'id' not in memory_data:
                    continue
                memory_id = uuid.UUID(memory_data['id'])
                if memory_id not in accessible_memory_ids:
                    continue

                access_log = MemoryAccessLog(
                    memory_id=memory_id,
                    app_id=app.id,
                    access_type="list",
                    metadata_={
                        "hash": memory_data.get('hash')
                    },
                )
                db.add(access_log)
                filtered_memories.append(memory_data)
            db.commit()

            return json.dumps(filtered_memories, indent=2)
        finally:
            db.close()
    except Exception as e:
        logging.exception("Error getting memories")
        return f"Error getting memories: {e}"
|
|
|
|
|
|
@mcp.tool(description="Delete all memories in the user's memory")
async def delete_all_memories() -> str:
    """Delete every memory of the current user that the calling app may access.

    Only memories that pass ``check_memory_access_permissions`` for the app are
    affected. Each one is first removed through ``memory_client.delete``. Then
    its ``Memory`` row is marked ``deleted`` with a shared ``deleted_at``
    timestamp, and a ``MemoryStatusHistory`` row and a ``MemoryAccessLog`` row
    (``delete_all``) are added. All DB changes are committed together at the end.

    Returns:
        ``"Successfully deleted all memories"`` on success. Otherwise an error
        string: when the user id or client name is missing, or when any step
        raises. If a step raises, no DB changes are committed.
    """
    uid = user_id_var.get(None)
    client_name = client_name_var.get(None)
    if not uid:
        return "Error: user_id not provided"
    if not client_name:
        return "Error: client_name not provided"

    try:
        db = SessionLocal()
        try:
            # Get or create user and app
            user, app = get_user_and_app(db, user_id=uid, app_id=client_name)

            user_memories = db.query(Memory).filter(Memory.user_id == user.id).all()
            accessible_memories = [
                memory
                for memory in user_memories
                if check_memory_access_permissions(db, memory, app.id)
            ]

            # Delete the accessible memories only. The DB holds uuid.UUID
            # objects, so the client gets the string form of each id.
            for memory in accessible_memories:
                memory_client.delete(str(memory.id))

            # Update each memory's state and create history/access entries,
            # reusing the ORM objects already loaded above.
            now = datetime.datetime.now(datetime.UTC)
            for memory in accessible_memories:
                previous_state = memory.state
                memory.state = MemoryState.deleted
                memory.deleted_at = now

                history = MemoryStatusHistory(
                    memory_id=memory.id,
                    changed_by=user.id,
                    old_state=previous_state,
                    new_state=MemoryState.deleted,
                )
                db.add(history)

                access_log = MemoryAccessLog(
                    memory_id=memory.id,
                    app_id=app.id,
                    access_type="delete_all",
                    metadata_={"operation": "bulk_delete"},
                )
                db.add(access_log)

            db.commit()
            return "Successfully deleted all memories"
        finally:
            db.close()
    except Exception as e:
        logging.exception("Error deleting memories")
        return f"Error deleting memories: {e}"
|
|
|
|
|
|
@mcp_router.get("/{client_name}/sse/{user_id}")
async def handle_sse(request: Request):
    """Serve one MCP session over SSE for the user/client named in the URL.

    The ``user_id`` and ``client_name`` path parameters are published through
    ``user_id_var`` / ``client_name_var`` (as "" when absent) so that the tool
    functions can read them. Both context variables are reset to their previous
    values once the session ends, whether or not it ended cleanly.
    """
    path_params = request.path_params
    user_token = user_id_var.set(path_params.get("user_id") or "")
    client_token = client_name_var.set(path_params.get("client_name") or "")

    try:
        server = mcp._mcp_server
        # NOTE(review): `request._send` is a private Starlette attribute used
        # to obtain the raw ASGI send callable for the transport.
        async with sse.connect_sse(request.scope, request.receive, request._send) as streams:
            read_stream, write_stream = streams
            await server.run(
                read_stream,
                write_stream,
                server.create_initialization_options(),
            )
    finally:
        user_id_var.reset(user_token)
        client_name_var.reset(client_token)
|
|
|
|
|
|
@mcp_router.post("/messages/")
async def handle_get_message(request: Request):
    """Forward a POST on the shared ``/mcp/messages/`` path to the SSE transport."""
    return await _forward_post_message(request)


@mcp_router.post("/{client_name}/sse/{user_id}/messages/")
async def handle_post_message(request: Request):
    """Forward a POST on the per-client/per-user messages path to the SSE transport."""
    return await _forward_post_message(request)


async def _forward_post_message(request: Request):
    """Replay the request body into ``sse.handle_post_message`` and report success.

    The body is read once and handed back to the transport through a synthetic
    ``receive``. The ASGI messages the transport emits through ``send`` are
    discarded, and ``{"status": "ok"}`` is returned instead.
    """
    body = await request.body()

    async def receive():
        # The real ASGI channel has already been drained by request.body(),
        # so serve the whole body as a single final message.
        return {"type": "http.request", "body": body, "more_body": False}

    async def send(message):
        # NOTE(review): this also drops any error status the transport sends
        # (e.g. for an unknown session), so callers always see {"status": "ok"}.
        return {}

    await sse.handle_post_message(request.scope, receive, send)

    return {"status": "ok"}
|
|
|
|
def setup_mcp_server(app: FastAPI):
    """Attach the MCP endpoints to a FastAPI application.

    Sets the underlying MCP server's name to ``"mem0-mcp-server"`` and mounts
    ``mcp_router`` (all routes under ``/mcp``) on ``app``.
    """
    mcp._mcp_server.name = "mem0-mcp-server"

    # Include MCP router in the FastAPI app
    app.include_router(mcp_router)
|