Files
t6_mem0/openmemory/api/app/mcp_server.py
2025-05-16 22:11:22 +05:30

393 lines
15 KiB
Python

import logging
import json
from mcp.server.fastmcp import FastMCP
from mcp.server.sse import SseServerTransport
from app.utils.memory import get_memory_client
from fastapi import FastAPI, Request
from fastapi.routing import APIRouter
import contextvars
import os
from dotenv import load_dotenv
from app.database import SessionLocal
from app.models import Memory, MemoryState, MemoryStatusHistory, MemoryAccessLog
from app.utils.db import get_user_and_app
import uuid
import datetime
from app.utils.permissions import check_memory_access_permissions
from qdrant_client import models as qdrant_models
# Load environment variables from a local .env file (if present) before any
# configuration below is read from the environment.
load_dotenv()

# MCP server instance; the tools below register themselves on it via @mcp.tool.
mcp = FastMCP("mem0-mcp-server")

# Fail fast at import time when the OpenAI key is missing (presumably required
# by the memory client's LLM/embedder configuration — TODO confirm).
# RuntimeError subclasses Exception, so existing `except Exception` handlers
# around the import keep working.
if not os.getenv("OPENAI_API_KEY"):
    raise RuntimeError("OPENAI_API_KEY is not set in .env file")

memory_client = get_memory_client()

# Per-session identity: set by handle_sse from the URL path and read by the
# MCP tools through user_id_var.get(None) / client_name_var.get(None).
user_id_var: contextvars.ContextVar[str] = contextvars.ContextVar("user_id")
client_name_var: contextvars.ContextVar[str] = contextvars.ContextVar("client_name")

# Router holding the MCP SSE and message endpoints, mounted under /mcp.
mcp_router = APIRouter(prefix="/mcp")

# SSE transport; "/mcp/messages/" is the endpoint it advertises to clients for
# POSTing their messages.
sse = SseServerTransport("/mcp/messages/")
@mcp.tool(description="Add a new memory. This method is called everytime the user informs anything about themselves, their preferences, or anything that has any relevent information whcih can be useful in the future conversation. This can also be called when the user asks you to remember something.")
async def add_memories(text: str) -> str:
    """Store ``text`` through mem0 and mirror the resulting events in the database.

    The user and client are read from ``user_id_var`` / ``client_name_var``.
    For each result returned by ``memory_client.add``:

    * ``ADD``    - create the ``Memory`` row (or reactivate an existing one) and
      record a ``MemoryStatusHistory`` entry;
    * ``UPDATE`` - copy the rewritten text into the existing ``Memory`` row;
    * ``DELETE`` - mark the row deleted and record a history entry.

    Returns:
        The raw response of ``memory_client.add`` on success, otherwise an
        ``"Error: ..."`` string.
    """
    uid = user_id_var.get(None)
    client_name = client_name_var.get(None)
    if not uid:
        return "Error: user_id not provided"
    if not client_name:
        return "Error: client_name not provided"

    try:
        db = SessionLocal()
        try:
            # Get or create user and app
            user, app = get_user_and_app(db, user_id=uid, app_id=client_name)

            # Paused apps may not create new memories.
            if not app.is_active:
                return f"Error: App {app.name} is currently paused on OpenMemory. Cannot create new memories."

            response = memory_client.add(
                text,
                user_id=uid,
                metadata={
                    "source_app": "openmemory",
                    "mcp_client": client_name,
                },
            )

            if isinstance(response, dict) and 'results' in response:
                for result in response['results']:
                    memory_id = uuid.UUID(result['id'])
                    memory = db.query(Memory).filter(Memory.id == memory_id).first()
                    event = result.get('event')

                    if event == 'ADD':
                        # Capture the previous state BEFORE touching the row. The old
                        # code tested `memory` after assigning it, so it always logged
                        # `deleted`. For a brand-new row, `deleted` is still used as a
                        # placeholder because the history column may be NOT NULL
                        # (TODO confirm against the model).
                        old_state = memory.state if memory is not None else MemoryState.deleted
                        if memory is None:
                            memory = Memory(
                                id=memory_id,
                                user_id=user.id,
                                app_id=app.id,
                                content=result['memory'],
                                state=MemoryState.active,
                            )
                            db.add(memory)
                        else:
                            memory.state = MemoryState.active
                            memory.content = result['memory']

                        db.add(MemoryStatusHistory(
                            memory_id=memory_id,
                            changed_by=user.id,
                            old_state=old_state,
                            new_state=MemoryState.active,
                        ))

                    elif event == 'UPDATE':
                        # mem0 rewrote an existing memory; keep the DB copy in sync.
                        # The state is unchanged, so no history entry is written.
                        if memory is not None and 'memory' in result:
                            memory.content = result['memory']

                    elif event == 'DELETE':
                        if memory is not None:
                            old_state = memory.state
                            memory.state = MemoryState.deleted
                            memory.deleted_at = datetime.datetime.now(datetime.UTC)
                            db.add(MemoryStatusHistory(
                                memory_id=memory_id,
                                changed_by=user.id,
                                old_state=old_state,
                                new_state=MemoryState.deleted,
                            ))

                db.commit()

            return response
        finally:
            db.close()
    except Exception as e:
        logging.exception("Error adding to memory")
        return f"Error adding to memory: {e}"
@mcp.tool(description="Search through stored memories. This method is called EVERYTIME the user asks anything.")
async def search_memory(query: str) -> str:
    """Semantic search over the memories this client is allowed to see.

    The query is embedded and run against the vector store. Results are
    restricted to points whose payload ``user_id`` equals the caller and whose
    id is among the memories that pass ``check_memory_access_permissions``.
    Each hit is recorded in ``MemoryAccessLog``.

    Returns:
        A JSON array (indent=2) of up to 10 hits with ``id``, ``memory``,
        ``hash``, ``created_at``, ``updated_at`` and ``score``, or an
        ``"Error..."`` string.
    """
    uid = user_id_var.get(None)
    client_name = client_name_var.get(None)
    if not uid:
        return "Error: user_id not provided"
    if not client_name:
        return "Error: client_name not provided"

    try:
        db = SessionLocal()
        try:
            # Get or create user and app
            user, app = get_user_and_app(db, user_id=uid, app_id=client_name)

            # Get accessible memory IDs based on ACL
            user_memories = db.query(Memory).filter(Memory.user_id == user.id).all()
            accessible_memory_ids = [
                memory.id
                for memory in user_memories
                if check_memory_access_permissions(db, memory, app.id)
            ]

            # Without this guard the filter below would match on user_id only.
            # An app with access to nothing would then see every memory.
            if not accessible_memory_ids:
                return json.dumps([], indent=2)

            filters = qdrant_models.Filter(must=[
                qdrant_models.FieldCondition(
                    key="user_id", match=qdrant_models.MatchValue(value=uid)
                ),
                # Convert UUIDs to strings for Qdrant
                qdrant_models.HasIdCondition(
                    has_id=[str(memory_id) for memory_id in accessible_memory_ids]
                ),
            ])

            embeddings = memory_client.embedding_model.embed(query, "search")
            hits = memory_client.vector_store.client.query_points(
                collection_name=memory_client.vector_store.collection_name,
                query=embeddings,
                query_filter=filters,
                limit=10,
            )

            memories = [
                {
                    "id": point.id,
                    "memory": point.payload["data"],
                    "hash": point.payload.get("hash"),
                    "created_at": point.payload.get("created_at"),
                    "updated_at": point.payload.get("updated_at"),
                    "score": point.score,
                }
                for point in hits.points
            ]

            # Log memory access for each memory found
            for memory in memories:
                db.add(MemoryAccessLog(
                    memory_id=uuid.UUID(memory['id']),
                    app_id=app.id,
                    access_type="search",
                    metadata_={
                        "query": query,
                        "score": memory.get('score'),
                        "hash": memory.get('hash'),
                    },
                ))
            db.commit()

            return json.dumps(memories, indent=2)
        finally:
            db.close()
    except Exception as e:
        logging.exception(e)
        return f"Error searching memory: {e}"
@mcp.tool(description="List all memories in the user's memory")
async def list_memories() -> str:
    """List the caller's memories that this client is permitted to access.

    Memories come from ``memory_client.get_all``. Only those whose id belongs
    to one of the user's ``Memory`` rows that passes
    ``check_memory_access_permissions`` are kept. Each kept memory is recorded
    in ``MemoryAccessLog`` with ``access_type="list"``.

    Returns:
        A JSON array (indent=2) of the accessible memory dicts, or an
        ``"Error..."`` string.
    """
    uid = user_id_var.get(None)
    client_name = client_name_var.get(None)
    if not uid:
        return "Error: user_id not provided"
    if not client_name:
        return "Error: client_name not provided"

    try:
        db = SessionLocal()
        try:
            # Get or create user and app
            user, app = get_user_and_app(db, user_id=uid, app_id=client_name)

            memories = memory_client.get_all(user_id=uid)

            # A set gives O(1) membership tests (the old list made this O(n^2)).
            # It also replaces the per-memory DB lookup in the bare-list branch.
            user_memories = db.query(Memory).filter(Memory.user_id == user.id).all()
            accessible_memory_ids = {
                memory.id
                for memory in user_memories
                if check_memory_access_permissions(db, memory, app.id)
            }

            # get_all may return either {"results": [...]} or a bare list of
            # memory dicts; both shapes are filtered the same way.
            if isinstance(memories, dict) and 'results' in memories:
                candidates = memories['results']
            else:
                candidates = memories

            filtered_memories = []
            for memory_data in candidates:
                if 'id' not in memory_data:
                    continue
                memory_id = uuid.UUID(memory_data['id'])
                if memory_id not in accessible_memory_ids:
                    continue
                db.add(MemoryAccessLog(
                    memory_id=memory_id,
                    app_id=app.id,
                    access_type="list",
                    metadata_={"hash": memory_data.get('hash')},
                ))
                filtered_memories.append(memory_data)
            db.commit()

            return json.dumps(filtered_memories, indent=2)
        finally:
            db.close()
    except Exception as e:
        logging.exception("Error getting memories")
        return f"Error getting memories: {e}"
@mcp.tool(description="Delete all memories in the user's memory")
async def delete_all_memories() -> str:
    """Delete every memory of the caller that this client may access.

    For each accessible ``Memory`` row, the memory is first removed through
    ``memory_client.delete``. Only when that succeeds is the row marked
    deleted, with a ``MemoryStatusHistory`` entry and a ``MemoryAccessLog``
    entry written, so the DB never claims a deletion that did not happen. A
    failure on one memory is logged and does not stop the others.

    Returns:
        ``"Successfully deleted all memories"`` when every deletion succeeded,
        otherwise an ``"Error deleting memories: ..."`` string.
    """
    uid = user_id_var.get(None)
    client_name = client_name_var.get(None)
    if not uid:
        return "Error: user_id not provided"
    if not client_name:
        return "Error: client_name not provided"

    try:
        db = SessionLocal()
        try:
            # Get or create user and app
            user, app = get_user_and_app(db, user_id=uid, app_id=client_name)

            # Reuse the loaded rows instead of re-querying each one by id.
            user_memories = db.query(Memory).filter(Memory.user_id == user.id).all()
            accessible_memories = [
                memory
                for memory in user_memories
                if check_memory_access_permissions(db, memory, app.id)
            ]

            now = datetime.datetime.now(datetime.UTC)
            failed_ids = []
            for memory in accessible_memories:
                try:
                    # The memory client takes string ids, not uuid.UUID objects.
                    memory_client.delete(str(memory.id))
                except Exception:
                    logging.exception("Failed to delete memory %s from the memory store", memory.id)
                    failed_ids.append(str(memory.id))
                    continue

                # Record the real previous state rather than assuming `active`.
                old_state = memory.state
                memory.state = MemoryState.deleted
                memory.deleted_at = now
                db.add(MemoryStatusHistory(
                    memory_id=memory.id,
                    changed_by=user.id,
                    old_state=old_state,
                    new_state=MemoryState.deleted,
                ))
                db.add(MemoryAccessLog(
                    memory_id=memory.id,
                    app_id=app.id,
                    access_type="delete_all",
                    metadata_={"operation": "bulk_delete"},
                ))
            db.commit()

            if failed_ids:
                return (
                    f"Error deleting memories: failed to delete {len(failed_ids)} "
                    f"of {len(accessible_memories)} memories: {', '.join(failed_ids)}"
                )
            return "Successfully deleted all memories"
        finally:
            db.close()
    except Exception as e:
        logging.exception("Error deleting memories")
        return f"Error deleting memories: {e}"
@mcp_router.get("/{client_name}/sse/{user_id}")
async def handle_sse(request: Request):
    """Run one MCP session over SSE for the user/client named in the URL.

    The ``user_id`` and ``client_name`` path parameters (empty string when
    absent) are bound to ``user_id_var`` / ``client_name_var`` while the
    session runs; the MCP tools read them from there. Both variables are
    restored when the session ends, whether normally or by exception.
    """
    params = request.path_params
    user_token = user_id_var.set(params.get("user_id") or "")
    client_token = client_name_var.set(params.get("client_name") or "")
    try:
        sse_session = sse.connect_sse(request.scope, request.receive, request._send)
        async with sse_session as (read_stream, write_stream):
            server = mcp._mcp_server
            init_options = server.create_initialization_options()
            await server.run(read_stream, write_stream, init_options)
    finally:
        user_id_var.reset(user_token)
        client_name_var.reset(client_token)
async def handle_post_message(request: Request):
    """Hand a client's POSTed MCP message to the SSE transport.

    The body is read from the real request and replayed to
    ``sse.handle_post_message`` through a one-shot ASGI ``receive``. The
    transport's own ASGI response is captured rather than sent. If it
    reports an error (status >= 400, e.g. a missing or unknown session or an
    unparsable message), that status and body are returned to the client
    instead of being silently replaced by a success payload.

    Returns:
        ``{"status": "ok"}`` when the transport accepted the message, otherwise
        a ``Response`` carrying the transport's error status and body.
    """
    body = await request.body()

    async def receive():
        # The real request body has already been consumed above, so replay it
        # as a single complete http.request message.
        return {"type": "http.request", "body": body, "more_body": False}

    transport_status = {"code": 200}
    transport_body = []

    async def send(message):
        # Capture, rather than emit, the transport's response; FastAPI sends
        # the response returned below.
        if message["type"] == "http.response.start":
            transport_status["code"] = message["status"]
        elif message["type"] == "http.response.body":
            transport_body.append(message.get("body", b""))
        return {}

    await sse.handle_post_message(request.scope, receive, send)

    if transport_status["code"] >= 400:
        from fastapi import Response

        return Response(content=b"".join(transport_body), status_code=transport_status["code"])
    return {"status": "ok"}


@mcp_router.post("/messages/")
async def handle_get_message(request: Request):
    """POST endpoint advertised by the SSE transport (``/mcp/messages/``)."""
    return await handle_post_message(request)


# This route's function previously shared the name `handle_post_message`, so it
# was shadowed by the real handler above. `name=` keeps the original route
# name and operation id.
@mcp_router.post("/{client_name}/sse/{user_id}/messages/", name="handle_post_message")
async def handle_client_post_message(request: Request):
    """Per-client/per-user variant of the message endpoint."""
    return await handle_post_message(request)
def setup_mcp_server(app: FastAPI):
    """Attach the MCP endpoints to a FastAPI application.

    Sets the underlying MCP server's name and mounts ``mcp_router`` (the
    ``/mcp`` SSE and message routes) on ``app``.

    Args:
        app: The FastAPI application to extend.
    """
    # Plain literal: the former f-string had no placeholders (ruff F541).
    mcp._mcp_server.name = "mem0-mcp-server"
    app.include_router(mcp_router)