383 lines
15 KiB
Python
383 lines
15 KiB
Python
import logging
|
|
import json
|
|
from mcp.server.fastmcp import FastMCP
|
|
from mcp.server.sse import SseServerTransport
|
|
from app.utils.memory import get_memory_client
|
|
from fastapi import FastAPI, Request
|
|
from fastapi.routing import APIRouter
|
|
import contextvars
|
|
import os
|
|
from dotenv import load_dotenv
|
|
from app.database import SessionLocal
|
|
from app.models import Memory, MemoryState, MemoryStatusHistory, MemoryAccessLog
|
|
from app.utils.db import get_user_and_app
|
|
import uuid
|
|
import datetime
|
|
from app.utils.permissions import check_memory_access_permissions
|
|
from qdrant_client import models as qdrant_models
|
|
|
|
# Load environment variables before any of the configuration below is read.
load_dotenv()

# Initialize the MCP server; the tools below register themselves on it via @mcp.tool.
mcp = FastMCP("mem0-mcp-server")

# Fail fast at import time if the OpenAI API key is missing. This check runs
# before the memory client is constructed.
if not os.getenv("OPENAI_API_KEY"):
    raise RuntimeError("OPENAI_API_KEY is not set in .env file")

# Shared memory client used by every tool in this module.
memory_client = get_memory_client()

# Per-connection caller identity. handle_sse sets these from the URL path; the
# tools read them with .get(None) and reject the call when they are unset/empty.
user_id_var: contextvars.ContextVar[str] = contextvars.ContextVar("user_id")
client_name_var: contextvars.ContextVar[str] = contextvars.ContextVar("client_name")

# Router for the MCP HTTP endpoints (all mounted under /mcp).
mcp_router = APIRouter(prefix="/mcp")

# SSE transport; "/mcp/messages/" is the message endpoint path given to the transport.
sse = SseServerTransport("/mcp/messages/")
|
|
|
|
@mcp.tool(description="Add new memories to the user's memory")
async def add_memories(text: str) -> str:
    """Store ``text`` through the memory client and mirror the outcome in the DB.

    The caller is identified by ``user_id_var`` / ``client_name_var``. For each
    entry in the ``results`` list returned by ``memory_client.add``:

    * ``ADD``    -> the ``Memory`` row is created, or re-activated with its
      content refreshed, and a ``MemoryStatusHistory`` entry is recorded.
    * ``DELETE`` -> an existing ``Memory`` row is marked deleted (with
      ``deleted_at`` set) and a ``MemoryStatusHistory`` entry is recorded.

    Returns:
        The response from ``memory_client.add`` on success (despite the ``str``
        annotation this is returned unchanged and may be a dict), or an
        ``"Error: ..."`` string when the caller context is missing, the app is
        paused, or any exception occurs.
    """
    uid = user_id_var.get(None)
    client_name = client_name_var.get(None)

    if not uid:
        return "Error: user_id not provided"
    if not client_name:
        return "Error: client_name not provided"

    try:
        db = SessionLocal()
        try:
            # Get or create user and app
            user, app = get_user_and_app(db, user_id=uid, app_id=client_name)

            # Paused apps may not create new memories.
            if not app.is_active:
                return f"Error: App {app.name} is currently paused on OpenMemory. Cannot create new memories."

            response = memory_client.add(
                text,
                user_id=uid,
                metadata={
                    "source_app": "openmemory",
                    "mcp_client": client_name,
                },
            )

            # Mirror the memory client's ADD/DELETE decisions into the relational DB.
            if isinstance(response, dict) and 'results' in response:
                for result in response['results']:
                    memory_id = uuid.UUID(result['id'])
                    memory = db.query(Memory).filter(Memory.id == memory_id).first()
                    # Capture the state *before* mutating the row so the history
                    # entry records the transition that actually happened.
                    previous_state = memory.state if memory is not None else None

                    if result['event'] == 'ADD':
                        if memory is None:
                            memory = Memory(
                                id=memory_id,
                                user_id=user.id,
                                app_id=app.id,
                                content=result['memory'],
                                state=MemoryState.active,
                            )
                            db.add(memory)
                        else:
                            memory.state = MemoryState.active
                            memory.content = result['memory']

                        history = MemoryStatusHistory(
                            memory_id=memory_id,
                            changed_by=user.id,
                            # NOTE(review): for a brand-new memory the previous code
                            # effectively always recorded `deleted` here. That is kept,
                            # because the nullability of old_state is not visible from
                            # this module. Confirm, and switch to None if the column allows it.
                            old_state=(
                                previous_state
                                if previous_state is not None
                                else MemoryState.deleted
                            ),
                            new_state=MemoryState.active,
                        )
                        db.add(history)

                    elif result['event'] == 'DELETE':
                        if memory is not None:
                            memory.state = MemoryState.deleted
                            memory.deleted_at = datetime.datetime.now(datetime.UTC)
                            history = MemoryStatusHistory(
                                memory_id=memory_id,
                                changed_by=user.id,
                                old_state=previous_state,
                                new_state=MemoryState.deleted,
                            )
                            db.add(history)

                db.commit()

            return response
        finally:
            db.close()
    except Exception as e:
        logging.exception(e)
        return f"Error adding to memory: {e}"
|
|
|
|
|
|
@mcp.tool(description="Search the user's memory for memories that match the query")
async def search_memory(query: str) -> str:
    """Semantic search over the memories this client app is allowed to read.

    The ACL is applied first. Only memory IDs for which
    ``check_memory_access_permissions`` passes are eligible. The vector-store
    query is then restricted to those IDs and to the caller's ``user_id``.
    One ``MemoryAccessLog`` entry (``access_type="search"``) is written per hit.

    Returns:
        A JSON array (``indent=2``) of at most 10 hits, each with ``id``,
        ``memory``, ``hash``, ``created_at``, ``updated_at`` and ``score``.
        It is ``"[]"`` when the app can access no memories. An ``"Error: ..."``
        string is returned on missing context or failure.
    """
    uid = user_id_var.get(None)
    client_name = client_name_var.get(None)
    if not uid:
        return "Error: user_id not provided"
    if not client_name:
        return "Error: client_name not provided"

    try:
        db = SessionLocal()
        try:
            # Get or create user and app
            user, app = get_user_and_app(db, user_id=uid, app_id=client_name)

            # Get accessible memory IDs based on ACL
            user_memories = db.query(Memory).filter(Memory.user_id == user.id).all()
            accessible_memory_ids = [
                memory.id
                for memory in user_memories
                if check_memory_access_permissions(db, memory, app.id)
            ]
            logging.debug("Accessible memory IDs: %s", accessible_memory_ids)

            # Without any accessible IDs, the HasIdCondition below would be omitted,
            # and the query would match *every* vector owned by this user. That
            # would bypass the ACL. Nothing is visible, so return no results.
            if not accessible_memory_ids:
                return json.dumps([], indent=2)

            filters = qdrant_models.Filter(
                must=[
                    qdrant_models.FieldCondition(
                        key="user_id", match=qdrant_models.MatchValue(value=uid)
                    ),
                    # Convert UUIDs to strings for Qdrant
                    qdrant_models.HasIdCondition(
                        has_id=[str(memory_id) for memory_id in accessible_memory_ids]
                    ),
                ]
            )
            logging.debug("Filters: %s", filters)

            embeddings = memory_client.embedding_model.embed(query, "search")
            hits = memory_client.vector_store.client.query_points(
                collection_name=memory_client.vector_store.collection_name,
                query=embeddings,
                query_filter=filters,
                limit=10,
            )

            memories = [
                {
                    "id": point.id,
                    "memory": point.payload["data"],
                    "hash": point.payload.get("hash"),
                    "created_at": point.payload.get("created_at"),
                    "updated_at": point.payload.get("updated_at"),
                    "score": point.score,
                }
                for point in hits.points
            ]

            # Log memory access for each memory found
            for memory in memories:
                access_log = MemoryAccessLog(
                    memory_id=uuid.UUID(memory['id']),
                    app_id=app.id,
                    access_type="search",
                    metadata_={
                        "query": query,
                        "score": memory.get('score'),
                        "hash": memory.get('hash'),
                    },
                )
                db.add(access_log)
            db.commit()

            return json.dumps(memories, indent=2)
        finally:
            db.close()
    except Exception as e:
        logging.exception(e)
        return f"Error searching memory: {e}"
|
|
|
|
|
|
@mcp.tool(description="List all memories in the user's memory")
async def list_memories() -> str:
    """List every memory of the current user that this client app may access.

    ``memory_client.get_all`` may return either ``{"results": [...]}`` or a bare
    list. Both shapes are handled. Entries without an ``id``, or whose ID is not
    in the app's ACL-accessible set, are skipped. One ``MemoryAccessLog`` entry
    (``access_type="list"``) is written per returned memory.

    Returns:
        A JSON array (``indent=2``) of the accessible memory entries as returned
        by the memory client, or an ``"Error: ..."`` string.
    """
    uid = user_id_var.get(None)
    client_name = client_name_var.get(None)
    if not uid:
        return "Error: user_id not provided"
    if not client_name:
        return "Error: client_name not provided"

    try:
        db = SessionLocal()
        try:
            # Get or create user and app
            user, app = get_user_and_app(db, user_id=uid, app_id=client_name)

            # Get all memories
            memories = memory_client.get_all(user_id=uid)

            # ACL filter, built once as a set for O(1) membership tests below.
            user_memories = db.query(Memory).filter(Memory.user_id == user.id).all()
            accessible_memory_ids = {
                memory.id
                for memory in user_memories
                if check_memory_access_permissions(db, memory, app.id)
            }

            # Normalize the two possible get_all result shapes into one list.
            if isinstance(memories, dict):
                memory_items = memories.get('results') or []
            else:
                memory_items = memories or []

            filtered_memories = []
            for memory_data in memory_items:
                if 'id' not in memory_data:
                    continue
                memory_id = uuid.UUID(memory_data['id'])
                if memory_id not in accessible_memory_ids:
                    continue

                access_log = MemoryAccessLog(
                    memory_id=memory_id,
                    app_id=app.id,
                    access_type="list",
                    metadata_={
                        "hash": memory_data.get('hash')
                    },
                )
                db.add(access_log)
                filtered_memories.append(memory_data)
            db.commit()

            return json.dumps(filtered_memories, indent=2)
        finally:
            db.close()
    except Exception as e:
        logging.exception(e)
        return f"Error getting memories: {e}"
|
|
|
|
|
|
@mcp.tool(description="Delete all memories in the user's memory")
async def delete_all_memories() -> str:
    """Delete every memory of the current user that this client app may access.

    Each accessible memory is deleted from the memory client first. Only when
    that succeeds is the DB row marked ``deleted`` (with ``deleted_at``), and a
    ``MemoryStatusHistory`` entry and a ``MemoryAccessLog`` entry
    (``access_type="delete_all"``) are recorded for it. This keeps the vector
    store and the DB consistent when an individual deletion fails.

    Returns:
        ``"Successfully deleted all memories"`` when every deletion succeeded.
        Otherwise an ``"Error deleting memories: ..."`` string. Successful
        deletions are still committed in that case.
    """
    uid = user_id_var.get(None)
    client_name = client_name_var.get(None)
    if not uid:
        return "Error: user_id not provided"
    if not client_name:
        return "Error: client_name not provided"

    try:
        db = SessionLocal()
        try:
            # Get or create user and app
            user, app = get_user_and_app(db, user_id=uid, app_id=client_name)

            user_memories = db.query(Memory).filter(Memory.user_id == user.id).all()
            accessible_memories = [
                memory
                for memory in user_memories
                if check_memory_access_permissions(db, memory, app.id)
            ]

            now = datetime.datetime.now(datetime.UTC)
            failed_ids = []
            for memory in accessible_memories:
                # The memory client is given a string ID, the same way
                # search_memory converts UUIDs to strings for the vector store.
                try:
                    memory_client.delete(str(memory.id))
                except Exception:
                    logging.exception("Failed to delete memory %s", memory.id)
                    failed_ids.append(memory.id)
                    continue  # leave the DB row untouched so it still matches the store

                previous_state = memory.state
                memory.state = MemoryState.deleted
                memory.deleted_at = now

                db.add(MemoryStatusHistory(
                    memory_id=memory.id,
                    changed_by=user.id,
                    old_state=previous_state,
                    new_state=MemoryState.deleted,
                ))
                db.add(MemoryAccessLog(
                    memory_id=memory.id,
                    app_id=app.id,
                    access_type="delete_all",
                    metadata_={"operation": "bulk_delete"},
                ))

            db.commit()

            if failed_ids:
                return (
                    f"Error deleting memories: failed to delete {len(failed_ids)} "
                    f"of {len(accessible_memories)} memories"
                )
            return "Successfully deleted all memories"
        finally:
            db.close()
    except Exception as e:
        logging.exception(e)
        return f"Error deleting memories: {e}"
|
|
|
|
|
|
@mcp_router.get("/{client_name}/sse/{user_id}")
async def handle_sse(request: Request):
    """Run one MCP session over Server-Sent Events.

    The ``user_id`` and ``client_name`` path parameters are published through
    ``user_id_var`` / ``client_name_var`` (empty string when absent) for the
    lifetime of the connection, so the MCP tools can identify the caller. Both
    context variables are restored when the session ends, even on error.
    """
    path_params = request.path_params
    user_token = user_id_var.set(path_params.get("user_id") or "")
    client_token = client_name_var.set(path_params.get("client_name") or "")

    try:
        # connect_sse is handed the raw ASGI send callable (request._send).
        async with sse.connect_sse(request.scope, request.receive, request._send) as streams:
            read_stream, write_stream = streams
            server = mcp._mcp_server
            init_options = server.create_initialization_options()
            await server.run(read_stream, write_stream, init_options)
    finally:
        # Restore the previous context-variable values.
        client_name_var.reset(client_token)
        user_id_var.reset(user_token)
|
|
|
|
|
|
@mcp_router.post("/{client_name}/sse/{user_id}/messages/")
async def handle_post_message(request: Request):
    """Forward a message POSTed by an MCP client to the SSE transport.

    The request body is read once and replayed to ``sse.handle_post_message``
    through a one-shot ASGI ``receive`` callable. Whatever HTTP response the
    transport emits is discarded by a no-op ``send``, and this endpoint always
    answers ``{"status": "ok"}``.
    """
    body = await request.body()

    async def receive():
        # Replay the already-read body as a single, final ASGI request message.
        return {"type": "http.request", "body": body, "more_body": False}

    async def send(message):
        # NOTE(review): every response message from the transport is dropped here.
        # That presumably includes error statuses, so failures are reported to
        # the client as success. This is kept for compatibility; consider
        # propagating the transport's status.
        return {}

    await sse.handle_post_message(request.scope, receive, send)

    return {"status": "ok"}
|
|
|
|
def setup_mcp_server(app: FastAPI):
    """Attach the MCP routes to ``app``.

    Also sets the low-level MCP server's name to ``"mem0-mcp-server"``.
    """
    mcp._mcp_server.name = "mem0-mcp-server"

    # Include MCP router in the FastAPI app
    app.include_router(mcp_router)
|