Add OpenMemory (#2676)
Co-authored-by: Saket Aryan <94069182+whysosaket@users.noreply.github.com> Co-authored-by: Saket Aryan <saketaryan2002@gmail.com>
This commit is contained in:
382
openmemory/api/app/mcp_server.py
Normal file
382
openmemory/api/app/mcp_server.py
Normal file
@@ -0,0 +1,382 @@
|
||||
import logging
|
||||
import json
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
from mcp.server.sse import SseServerTransport
|
||||
from app.utils.memory import get_memory_client
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.routing import APIRouter
|
||||
import contextvars
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
from app.database import SessionLocal
|
||||
from app.models import Memory, MemoryState, MemoryStatusHistory, MemoryAccessLog
|
||||
from app.utils.db import get_user_and_app
|
||||
import uuid
|
||||
import datetime
|
||||
from app.utils.permissions import check_memory_access_permissions
|
||||
from qdrant_client import models as qdrant_models
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
|
||||
# Initialize MCP and memory client
|
||||
mcp = FastMCP("mem0-mcp-server")
|
||||
|
||||
# Check if OpenAI API key is set
|
||||
if not os.getenv("OPENAI_API_KEY"):
|
||||
raise Exception("OPENAI_API_KEY is not set in .env file")
|
||||
|
||||
memory_client = get_memory_client()
|
||||
|
||||
# Context variables for user_id and client_name
|
||||
user_id_var: contextvars.ContextVar[str] = contextvars.ContextVar("user_id")
|
||||
client_name_var: contextvars.ContextVar[str] = contextvars.ContextVar("client_name")
|
||||
|
||||
# Create a router for MCP endpoints
|
||||
mcp_router = APIRouter(prefix="/mcp")
|
||||
|
||||
# Initialize SSE transport
|
||||
sse = SseServerTransport("/mcp/messages/")
|
||||
|
||||
@mcp.tool(description="Add new memories to the user's memory")
|
||||
async def add_memories(text: str) -> str:
|
||||
uid = user_id_var.get(None)
|
||||
client_name = client_name_var.get(None)
|
||||
|
||||
if not uid:
|
||||
return "Error: user_id not provided"
|
||||
if not client_name:
|
||||
return "Error: client_name not provided"
|
||||
|
||||
try:
|
||||
db = SessionLocal()
|
||||
try:
|
||||
# Get or create user and app
|
||||
user, app = get_user_and_app(db, user_id=uid, app_id=client_name)
|
||||
|
||||
# Check if app is active
|
||||
if not app.is_active:
|
||||
return f"Error: App {app.name} is currently paused on OpenMemory. Cannot create new memories."
|
||||
|
||||
response = memory_client.add(text,
|
||||
user_id=uid,
|
||||
metadata={
|
||||
"source_app": "openmemory",
|
||||
"mcp_client": client_name,
|
||||
})
|
||||
|
||||
# Process the response and update database
|
||||
if isinstance(response, dict) and 'results' in response:
|
||||
for result in response['results']:
|
||||
memory_id = uuid.UUID(result['id'])
|
||||
memory = db.query(Memory).filter(Memory.id == memory_id).first()
|
||||
|
||||
if result['event'] == 'ADD':
|
||||
if not memory:
|
||||
memory = Memory(
|
||||
id=memory_id,
|
||||
user_id=user.id,
|
||||
app_id=app.id,
|
||||
content=result['memory'],
|
||||
state=MemoryState.active
|
||||
)
|
||||
db.add(memory)
|
||||
else:
|
||||
memory.state = MemoryState.active
|
||||
memory.content = result['memory']
|
||||
|
||||
# Create history entry
|
||||
history = MemoryStatusHistory(
|
||||
memory_id=memory_id,
|
||||
changed_by=user.id,
|
||||
old_state=MemoryState.deleted if memory else None,
|
||||
new_state=MemoryState.active
|
||||
)
|
||||
db.add(history)
|
||||
|
||||
elif result['event'] == 'DELETE':
|
||||
if memory:
|
||||
memory.state = MemoryState.deleted
|
||||
memory.deleted_at = datetime.datetime.now(datetime.UTC)
|
||||
# Create history entry
|
||||
history = MemoryStatusHistory(
|
||||
memory_id=memory_id,
|
||||
changed_by=user.id,
|
||||
old_state=MemoryState.active,
|
||||
new_state=MemoryState.deleted
|
||||
)
|
||||
db.add(history)
|
||||
|
||||
db.commit()
|
||||
|
||||
return response
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
return f"Error adding to memory: {e}"
|
||||
|
||||
|
||||
@mcp.tool(description="Search the user's memory for memories that match the query")
|
||||
async def search_memory(query: str) -> str:
|
||||
uid = user_id_var.get(None)
|
||||
client_name = client_name_var.get(None)
|
||||
if not uid:
|
||||
return "Error: user_id not provided"
|
||||
if not client_name:
|
||||
return "Error: client_name not provided"
|
||||
try:
|
||||
db = SessionLocal()
|
||||
try:
|
||||
# Get or create user and app
|
||||
user, app = get_user_and_app(db, user_id=uid, app_id=client_name)
|
||||
|
||||
# Get accessible memory IDs based on ACL
|
||||
user_memories = db.query(Memory).filter(Memory.user_id == user.id).all()
|
||||
accessible_memory_ids = [memory.id for memory in user_memories if check_memory_access_permissions(db, memory, app.id)]
|
||||
conditions = [qdrant_models.FieldCondition(key="user_id", match=qdrant_models.MatchValue(value=uid))]
|
||||
logging.info(f"Accessible memory IDs: {accessible_memory_ids}")
|
||||
logging.info(f"Conditions: {conditions}")
|
||||
if accessible_memory_ids:
|
||||
# Convert UUIDs to strings for Qdrant
|
||||
accessible_memory_ids_str = [str(memory_id) for memory_id in accessible_memory_ids]
|
||||
conditions.append(qdrant_models.HasIdCondition(has_id=accessible_memory_ids_str))
|
||||
filters = qdrant_models.Filter(must=conditions)
|
||||
logging.info(f"Filters: {filters}")
|
||||
embeddings = memory_client.embedding_model.embed(query, "search")
|
||||
hits = memory_client.vector_store.client.query_points(
|
||||
collection_name=memory_client.vector_store.collection_name,
|
||||
query=embeddings,
|
||||
query_filter=filters,
|
||||
limit=10,
|
||||
)
|
||||
|
||||
memories = hits.points
|
||||
memories = [
|
||||
{
|
||||
"id": memory.id,
|
||||
"memory": memory.payload["data"],
|
||||
"hash": memory.payload.get("hash"),
|
||||
"created_at": memory.payload.get("created_at"),
|
||||
"updated_at": memory.payload.get("updated_at"),
|
||||
"score": memory.score,
|
||||
}
|
||||
for memory in memories
|
||||
]
|
||||
|
||||
# Log memory access for each memory found
|
||||
if isinstance(memories, dict) and 'results' in memories:
|
||||
print(f"Memories: {memories}")
|
||||
for memory_data in memories['results']:
|
||||
if 'id' in memory_data:
|
||||
memory_id = uuid.UUID(memory_data['id'])
|
||||
# Create access log entry
|
||||
access_log = MemoryAccessLog(
|
||||
memory_id=memory_id,
|
||||
app_id=app.id,
|
||||
access_type="search",
|
||||
metadata_={
|
||||
"query": query,
|
||||
"score": memory_data.get('score'),
|
||||
"hash": memory_data.get('hash')
|
||||
}
|
||||
)
|
||||
db.add(access_log)
|
||||
db.commit()
|
||||
else:
|
||||
for memory in memories:
|
||||
memory_id = uuid.UUID(memory['id'])
|
||||
# Create access log entry
|
||||
access_log = MemoryAccessLog(
|
||||
memory_id=memory_id,
|
||||
app_id=app.id,
|
||||
access_type="search",
|
||||
metadata_={
|
||||
"query": query,
|
||||
"score": memory.get('score'),
|
||||
"hash": memory.get('hash')
|
||||
}
|
||||
)
|
||||
db.add(access_log)
|
||||
db.commit()
|
||||
return json.dumps(memories, indent=2)
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
return f"Error searching memory: {e}"
|
||||
|
||||
|
||||
@mcp.tool(description="List all memories in the user's memory")
|
||||
async def list_memories() -> str:
|
||||
uid = user_id_var.get(None)
|
||||
client_name = client_name_var.get(None)
|
||||
if not uid:
|
||||
return "Error: user_id not provided"
|
||||
if not client_name:
|
||||
return "Error: client_name not provided"
|
||||
try:
|
||||
db = SessionLocal()
|
||||
try:
|
||||
# Get or create user and app
|
||||
user, app = get_user_and_app(db, user_id=uid, app_id=client_name)
|
||||
|
||||
# Get all memories
|
||||
memories = memory_client.get_all(user_id=uid)
|
||||
filtered_memories = []
|
||||
|
||||
# Filter memories based on permissions
|
||||
user_memories = db.query(Memory).filter(Memory.user_id == user.id).all()
|
||||
accessible_memory_ids = [memory.id for memory in user_memories if check_memory_access_permissions(db, memory, app.id)]
|
||||
if isinstance(memories, dict) and 'results' in memories:
|
||||
for memory_data in memories['results']:
|
||||
if 'id' in memory_data:
|
||||
memory_id = uuid.UUID(memory_data['id'])
|
||||
if memory_id in accessible_memory_ids:
|
||||
# Create access log entry
|
||||
access_log = MemoryAccessLog(
|
||||
memory_id=memory_id,
|
||||
app_id=app.id,
|
||||
access_type="list",
|
||||
metadata_={
|
||||
"hash": memory_data.get('hash')
|
||||
}
|
||||
)
|
||||
db.add(access_log)
|
||||
filtered_memories.append(memory_data)
|
||||
db.commit()
|
||||
else:
|
||||
for memory in memories:
|
||||
memory_id = uuid.UUID(memory['id'])
|
||||
memory_obj = db.query(Memory).filter(Memory.id == memory_id).first()
|
||||
if memory_obj and check_memory_access_permissions(db, memory_obj, app.id):
|
||||
# Create access log entry
|
||||
access_log = MemoryAccessLog(
|
||||
memory_id=memory_id,
|
||||
app_id=app.id,
|
||||
access_type="list",
|
||||
metadata_={
|
||||
"hash": memory.get('hash')
|
||||
}
|
||||
)
|
||||
db.add(access_log)
|
||||
filtered_memories.append(memory)
|
||||
db.commit()
|
||||
return json.dumps(filtered_memories, indent=2)
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
return f"Error getting memories: {e}"
|
||||
|
||||
|
||||
@mcp.tool(description="Delete all memories in the user's memory")
|
||||
async def delete_all_memories() -> str:
|
||||
uid = user_id_var.get(None)
|
||||
client_name = client_name_var.get(None)
|
||||
if not uid:
|
||||
return "Error: user_id not provided"
|
||||
if not client_name:
|
||||
return "Error: client_name not provided"
|
||||
try:
|
||||
db = SessionLocal()
|
||||
try:
|
||||
# Get or create user and app
|
||||
user, app = get_user_and_app(db, user_id=uid, app_id=client_name)
|
||||
|
||||
user_memories = db.query(Memory).filter(Memory.user_id == user.id).all()
|
||||
accessible_memory_ids = [memory.id for memory in user_memories if check_memory_access_permissions(db, memory, app.id)]
|
||||
|
||||
# delete the accessible memories only
|
||||
for memory_id in accessible_memory_ids:
|
||||
memory_client.delete(memory_id)
|
||||
|
||||
# Update each memory's state and create history entries
|
||||
now = datetime.datetime.now(datetime.UTC)
|
||||
for memory_id in accessible_memory_ids:
|
||||
memory = db.query(Memory).filter(Memory.id == memory_id).first()
|
||||
# Update memory state
|
||||
memory.state = MemoryState.deleted
|
||||
memory.deleted_at = now
|
||||
|
||||
# Create history entry
|
||||
history = MemoryStatusHistory(
|
||||
memory_id=memory_id,
|
||||
changed_by=user.id,
|
||||
old_state=MemoryState.active,
|
||||
new_state=MemoryState.deleted
|
||||
)
|
||||
db.add(history)
|
||||
|
||||
# Create access log entry
|
||||
access_log = MemoryAccessLog(
|
||||
memory_id=memory_id,
|
||||
app_id=app.id,
|
||||
access_type="delete_all",
|
||||
metadata_={"operation": "bulk_delete"}
|
||||
)
|
||||
db.add(access_log)
|
||||
|
||||
db.commit()
|
||||
return "Successfully deleted all memories"
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
return f"Error deleting memories: {e}"
|
||||
|
||||
|
||||
@mcp_router.get("/{client_name}/sse/{user_id}")
|
||||
async def handle_sse(request: Request):
|
||||
"""Handle SSE connections for a specific user and client"""
|
||||
# Extract user_id and client_name from path parameters
|
||||
uid = request.path_params.get("user_id")
|
||||
user_token = user_id_var.set(uid or "")
|
||||
client_name = request.path_params.get("client_name")
|
||||
client_token = client_name_var.set(client_name or "")
|
||||
|
||||
try:
|
||||
# Handle SSE connection
|
||||
async with sse.connect_sse(
|
||||
request.scope,
|
||||
request.receive,
|
||||
request._send,
|
||||
) as (read_stream, write_stream):
|
||||
await mcp._mcp_server.run(
|
||||
read_stream,
|
||||
write_stream,
|
||||
mcp._mcp_server.create_initialization_options(),
|
||||
)
|
||||
finally:
|
||||
# Clean up context variables
|
||||
user_id_var.reset(user_token)
|
||||
client_name_var.reset(client_token)
|
||||
|
||||
|
||||
@mcp_router.post("/messages/")
|
||||
async def handle_post_message(request: Request):
|
||||
"""Handle POST messages for SSE"""
|
||||
try:
|
||||
body = await request.body()
|
||||
|
||||
# Create a simple receive function that returns the body
|
||||
async def receive():
|
||||
return {"type": "http.request", "body": body, "more_body": False}
|
||||
|
||||
# Create a simple send function that does nothing
|
||||
async def send(message):
|
||||
pass
|
||||
|
||||
# Call handle_post_message with the correct arguments
|
||||
await sse.handle_post_message(request.scope, receive, send)
|
||||
|
||||
# Return a success response
|
||||
return {"status": "ok"}
|
||||
finally:
|
||||
pass
|
||||
# Clean up context variable
|
||||
# client_name_var.reset(client_token)
|
||||
|
||||
def setup_mcp_server(app: FastAPI):
|
||||
"""Setup MCP server with the FastAPI application"""
|
||||
mcp._mcp_server.name = f"mem0-mcp-server"
|
||||
|
||||
# Include MCP router in the FastAPI app
|
||||
app.include_router(mcp_router)
|
||||
Reference in New Issue
Block a user