Files
t6_mem0/openmemory/api/app/mcp_server.py
Deshraj Yadav f51b39db91 Add OpenMemory (#2676)
Co-authored-by: Saket Aryan <94069182+whysosaket@users.noreply.github.com>
Co-authored-by: Saket Aryan <saketaryan2002@gmail.com>
2025-05-13 08:30:59 -07:00

383 lines
15 KiB
Python

import logging
import json
from mcp.server.fastmcp import FastMCP
from mcp.server.sse import SseServerTransport
from app.utils.memory import get_memory_client
from fastapi import FastAPI, Request
from fastapi.routing import APIRouter
import contextvars
import os
from dotenv import load_dotenv
from app.database import SessionLocal
from app.models import Memory, MemoryState, MemoryStatusHistory, MemoryAccessLog
from app.utils.db import get_user_and_app
import uuid
import datetime
from app.utils.permissions import check_memory_access_permissions
from qdrant_client import models as qdrant_models
# Pull settings from a .env file into the process environment before anything
# below reads them.
load_dotenv()

# FastMCP server instance; the tool functions in this module register on it.
mcp = FastMCP("mem0-mcp-server")

# Refuse to import the module when the OpenAI key is missing or empty.
if not os.environ.get("OPENAI_API_KEY"):
    raise Exception("OPENAI_API_KEY is not set in .env file")

# Shared memory client used by every tool.
memory_client = get_memory_client()

# Identity of the current caller. handle_sse sets both for the lifetime of an
# SSE connection and the tools read them back with `.get(None)`.
user_id_var: contextvars.ContextVar[str] = contextvars.ContextVar("user_id")
client_name_var: contextvars.ContextVar[str] = contextvars.ContextVar("client_name")

# All MCP HTTP endpoints are mounted under /mcp.
mcp_router = APIRouter(prefix="/mcp")

# SSE transport; the path given here is the endpoint clients POST messages to.
sse = SseServerTransport("/mcp/messages/")
@mcp.tool(description="Add new memories to the user's memory")
async def add_memories(text: str) -> str:
    """Add ``text`` via the memory client and mirror the result in the database.

    The caller's identity is read from ``user_id_var`` and ``client_name_var``.
    For every entry in the client's ``results`` list, an ``ADD`` event creates
    or reactivates the matching ``Memory`` row and a ``DELETE`` event marks an
    existing row as deleted; both record a ``MemoryStatusHistory`` entry.
    Other event values are ignored.

    Args:
        text: Content passed to ``memory_client.add``.

    Returns:
        The memory client's response serialized as a JSON string, or a
        human-readable string starting with ``"Error"`` when the identity is
        missing, the app is inactive, or any exception is raised.
    """
    uid = user_id_var.get(None)
    client_name = client_name_var.get(None)

    if not uid:
        return "Error: user_id not provided"
    if not client_name:
        return "Error: client_name not provided"

    try:
        db = SessionLocal()
        try:
            # Get or create user and app
            user, app = get_user_and_app(db, user_id=uid, app_id=client_name)

            # Paused apps must not create memories.
            if not app.is_active:
                return f"Error: App {app.name} is currently paused on OpenMemory. Cannot create new memories."

            response = memory_client.add(
                text,
                user_id=uid,
                metadata={
                    "source_app": "openmemory",
                    "mcp_client": client_name,
                },
            )

            # Mirror each reported event into the relational database.
            if isinstance(response, dict) and 'results' in response:
                for result in response['results']:
                    memory_id = uuid.UUID(result['id'])
                    memory = db.query(Memory).filter(Memory.id == memory_id).first()
                    event = result['event']

                    if event == 'ADD':
                        if not memory:
                            memory = Memory(
                                id=memory_id,
                                user_id=user.id,
                                app_id=app.id,
                                content=result['memory'],
                                state=MemoryState.active,
                            )
                            db.add(memory)
                        else:
                            memory.state = MemoryState.active
                            memory.content = result['memory']

                        # NOTE(review): `memory` is always set by this point, so
                        # the expression below evaluates to MemoryState.deleted
                        # even for rows created just above. Left unchanged
                        # because whether `old_state` accepts NULL is not
                        # visible from this module -- confirm against the model.
                        history = MemoryStatusHistory(
                            memory_id=memory_id,
                            changed_by=user.id,
                            old_state=MemoryState.deleted if memory else None,
                            new_state=MemoryState.active,
                        )
                        db.add(history)

                    elif event == 'DELETE':
                        if memory:
                            memory.state = MemoryState.deleted
                            memory.deleted_at = datetime.datetime.now(datetime.UTC)
                            history = MemoryStatusHistory(
                                memory_id=memory_id,
                                changed_by=user.id,
                                old_state=MemoryState.active,
                                new_state=MemoryState.deleted,
                            )
                            db.add(history)

                db.commit()

            # The tool is declared to return `str`, so serialize the response
            # instead of handing back the raw object. `default=str` is a
            # deliberate fallback: the database work is already committed here,
            # so an exotic value in the response must not turn success into an
            # error message.
            if isinstance(response, str):
                return response
            return json.dumps(response, default=str)
        finally:
            db.close()
    except Exception as e:
        # Keep the traceback in the server log; the caller only gets a summary.
        logging.exception(e)
        return f"Error adding to memory: {e}"
@mcp.tool(description="Search the user's memory for memories that match the query")
async def search_memory(query: str) -> str:
    """Run a vector search over the memories the calling app may access.

    The caller's identity is read from ``user_id_var`` and ``client_name_var``.
    The search is restricted to points whose ``user_id`` payload matches the
    caller and whose ID belongs to a ``Memory`` row that passes
    ``check_memory_access_permissions`` for the calling app. Each hit is
    recorded as a ``MemoryAccessLog`` entry with access type ``"search"``.

    Args:
        query: Text that is embedded and used as the query vector.

    Returns:
        A JSON array (string) of up to 10 hits with ``id``, ``memory``,
        ``hash``, ``created_at``, ``updated_at`` and ``score``; ``"[]"`` when
        no memory is accessible; or a string starting with ``"Error"``.
    """
    uid = user_id_var.get(None)
    client_name = client_name_var.get(None)

    if not uid:
        return "Error: user_id not provided"
    if not client_name:
        return "Error: client_name not provided"

    try:
        db = SessionLocal()
        try:
            # Get or create user and app
            user, app = get_user_and_app(db, user_id=uid, app_id=client_name)

            # Get accessible memory IDs based on ACL
            user_memories = db.query(Memory).filter(Memory.user_id == user.id).all()
            accessible_memory_ids = [
                memory.id
                for memory in user_memories
                if check_memory_access_permissions(db, memory, app.id)
            ]
            logging.info("Accessible memory IDs: %s", accessible_memory_ids)

            # With nothing accessible there is nothing to search. Without this
            # guard the ID restriction would be left out of the filter and the
            # query would match every point of the user, ignoring the ACL
            # (list_memories already yields an empty list in this situation).
            if not accessible_memory_ids:
                return json.dumps([], indent=2)

            conditions = [
                qdrant_models.FieldCondition(
                    key="user_id",
                    match=qdrant_models.MatchValue(value=uid),
                ),
                # Convert UUIDs to strings for Qdrant
                qdrant_models.HasIdCondition(
                    has_id=[str(memory_id) for memory_id in accessible_memory_ids]
                ),
            ]
            logging.info("Conditions: %s", conditions)

            filters = qdrant_models.Filter(must=conditions)
            logging.info("Filters: %s", filters)

            embeddings = memory_client.embedding_model.embed(query, "search")
            hits = memory_client.vector_store.client.query_points(
                collection_name=memory_client.vector_store.collection_name,
                query=embeddings,
                query_filter=filters,
                limit=10,
            )

            memories = [
                {
                    "id": point.id,
                    "memory": point.payload["data"],
                    "hash": point.payload.get("hash"),
                    "created_at": point.payload.get("created_at"),
                    "updated_at": point.payload.get("updated_at"),
                    "score": point.score,
                }
                for point in hits.points
            ]

            # Log memory access for each memory found
            for memory in memories:
                access_log = MemoryAccessLog(
                    memory_id=uuid.UUID(memory['id']),
                    app_id=app.id,
                    access_type="search",
                    metadata_={
                        "query": query,
                        "score": memory.get('score'),
                        "hash": memory.get('hash'),
                    },
                )
                db.add(access_log)
            db.commit()

            return json.dumps(memories, indent=2)
        finally:
            db.close()
    except Exception as e:
        logging.exception(e)
        return f"Error searching memory: {e}"
@mcp.tool(description="List all memories in the user's memory")
async def list_memories() -> str:
    """List the caller's memories that the calling app is allowed to see.

    The caller's identity is read from ``user_id_var`` and ``client_name_var``.
    Memories are fetched with ``memory_client.get_all`` and kept only when
    their ID belongs to a ``Memory`` row of the user that passes
    ``check_memory_access_permissions`` for the calling app. Each returned
    memory is recorded as a ``MemoryAccessLog`` entry with access type
    ``"list"``.

    Returns:
        A JSON array (string) of the permitted memory entries, or a string
        starting with ``"Error"`` when the identity is missing or an exception
        is raised.
    """
    uid = user_id_var.get(None)
    client_name = client_name_var.get(None)

    if not uid:
        return "Error: user_id not provided"
    if not client_name:
        return "Error: client_name not provided"

    try:
        db = SessionLocal()
        try:
            # Get or create user and app
            user, app = get_user_and_app(db, user_id=uid, app_id=client_name)

            # Get all memories
            memories = memory_client.get_all(user_id=uid)

            # Build the permitted IDs once, as a set, so the membership test in
            # the loop below is O(1) instead of a list scan per entry.
            user_memories = db.query(Memory).filter(Memory.user_id == user.id).all()
            accessible_memory_ids = {
                memory.id
                for memory in user_memories
                if check_memory_access_permissions(db, memory, app.id)
            }

            # The client may answer with {"results": [...]} or with a bare
            # sequence of entries; normalize so both shapes go through the
            # same permission filter.
            if isinstance(memories, dict) and 'results' in memories:
                entries = memories['results']
            else:
                entries = memories

            filtered_memories = []
            for memory_data in entries:
                if 'id' not in memory_data:
                    continue
                memory_id = uuid.UUID(memory_data['id'])
                if memory_id not in accessible_memory_ids:
                    continue

                # Create access log entry
                access_log = MemoryAccessLog(
                    memory_id=memory_id,
                    app_id=app.id,
                    access_type="list",
                    metadata_={
                        "hash": memory_data.get('hash'),
                    },
                )
                db.add(access_log)
                filtered_memories.append(memory_data)
            db.commit()

            return json.dumps(filtered_memories, indent=2)
        finally:
            db.close()
    except Exception as e:
        # Keep the traceback in the server log; the caller only gets a summary.
        logging.exception(e)
        return f"Error getting memories: {e}"
@mcp.tool(description="Delete all memories in the user's memory")
async def delete_all_memories() -> str:
    """Delete every memory of the caller that the calling app may access.

    The caller's identity is read from ``user_id_var`` and ``client_name_var``.
    Each accessible memory is first removed through ``memory_client.delete``;
    afterwards its ``Memory`` row is marked deleted and a
    ``MemoryStatusHistory`` plus a ``MemoryAccessLog`` (access type
    ``"delete_all"``) entry are written in a single commit.

    Returns:
        ``"Successfully deleted all memories"`` on success, or a string
        starting with ``"Error"`` when the identity is missing or an exception
        is raised.
    """
    uid = user_id_var.get(None)
    client_name = client_name_var.get(None)

    if not uid:
        return "Error: user_id not provided"
    if not client_name:
        return "Error: client_name not provided"

    try:
        db = SessionLocal()
        try:
            # Get or create user and app
            user, app = get_user_and_app(db, user_id=uid, app_id=client_name)

            # Keep the loaded ORM objects instead of only their IDs, so the
            # update loop below does not need one extra query per memory.
            user_memories = db.query(Memory).filter(Memory.user_id == user.id).all()
            accessible_memories = [
                memory
                for memory in user_memories
                if check_memory_access_permissions(db, memory, app.id)
            ]

            # delete the accessible memories only
            for memory in accessible_memories:
                # search_memory converts UUIDs to strings before handing them
                # to the vector store; do the same here.
                memory_client.delete(str(memory.id))

            # Update each memory's state and create history entries
            now = datetime.datetime.now(datetime.UTC)
            for memory in accessible_memories:
                # Record the state the memory actually had rather than
                # assuming it was active.
                previous_state = memory.state
                memory.state = MemoryState.deleted
                memory.deleted_at = now

                history = MemoryStatusHistory(
                    memory_id=memory.id,
                    changed_by=user.id,
                    old_state=previous_state,
                    new_state=MemoryState.deleted,
                )
                db.add(history)

                access_log = MemoryAccessLog(
                    memory_id=memory.id,
                    app_id=app.id,
                    access_type="delete_all",
                    metadata_={"operation": "bulk_delete"},
                )
                db.add(access_log)

            db.commit()
            return "Successfully deleted all memories"
        finally:
            db.close()
    except Exception as e:
        # Keep the traceback in the server log; the caller only gets a summary.
        logging.exception(e)
        return f"Error deleting memories: {e}"
@mcp_router.get("/{client_name}/sse/{user_id}")
async def handle_sse(request: Request):
    """Serve one SSE connection and run the MCP server over it.

    ``user_id`` and ``client_name`` are taken from the URL path and stored in
    the module's context variables (an empty string stands in for a missing
    value) for as long as the connection is open; both are restored when the
    connection ends, whether normally or through an exception.
    """
    path_params = request.path_params
    user_token = user_id_var.set(path_params.get("user_id") or "")
    client_token = client_name_var.set(path_params.get("client_name") or "")

    try:
        # Bridge the raw ASGI channel to the MCP server's read/write streams.
        async with sse.connect_sse(request.scope, request.receive, request._send) as streams:
            read_stream, write_stream = streams
            server = mcp._mcp_server
            await server.run(
                read_stream,
                write_stream,
                server.create_initialization_options(),
            )
    finally:
        # Restore whatever the context variables held before this request.
        user_id_var.reset(user_token)
        client_name_var.reset(client_token)
@mcp_router.post("/messages/")
async def handle_post_message(request: Request):
    """Forward a client's POSTed message to the SSE transport.

    The request body is read once and replayed to the transport through a
    stand-in ASGI ``receive`` callable. Whatever the transport tries to send
    back is discarded, and the endpoint always answers ``{"status": "ok"}``
    unless an exception propagates.
    """
    payload = await request.body()

    async def replay_body():
        # Hand the already-consumed body to the transport as one final chunk.
        return {"type": "http.request", "body": payload, "more_body": False}

    async def discard(message):
        # The transport's own response is intentionally dropped.
        pass

    await sse.handle_post_message(request.scope, replay_body, discard)
    return {"status": "ok"}
def setup_mcp_server(app: FastAPI):
    """Name the underlying MCP server and mount the MCP routes on ``app``."""
    server = mcp._mcp_server
    server.name = "mem0-mcp-server"

    # Expose the /mcp endpoints through the given FastAPI application.
    app.include_router(mcp_router)