Add OpenMemory (#2676)
Co-authored-by: Saket Aryan <94069182+whysosaket@users.noreply.github.com> Co-authored-by: Saket Aryan <saketaryan2002@gmail.com>
This commit is contained in:
1
openmemory/api/app/__init__.py
Normal file
1
openmemory/api/app/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# This file makes the app directory a Python package
|
||||
4
openmemory/api/app/config.py
Normal file
4
openmemory/api/app/config.py
Normal file
@@ -0,0 +1,4 @@
|
||||
import os
|
||||
|
||||
USER_ID = os.getenv("USER", "default_user")
|
||||
DEFAULT_APP_ID = "openmemory"
|
||||
29
openmemory/api/app/database.py
Normal file
29
openmemory/api/app/database.py
Normal file
@@ -0,0 +1,29 @@
|
||||
import os
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import declarative_base, sessionmaker
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# load .env file (make sure you have DATABASE_URL set)
|
||||
load_dotenv()
|
||||
|
||||
DATABASE_URL = os.getenv("DATABASE_URL", "sqlite:///./openmemory.db")
|
||||
if not DATABASE_URL:
|
||||
raise RuntimeError("DATABASE_URL is not set in environment")
|
||||
|
||||
# SQLAlchemy engine & session
|
||||
engine = create_engine(
|
||||
DATABASE_URL,
|
||||
connect_args={"check_same_thread": False} # Needed for SQLite
|
||||
)
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
|
||||
# Base class for models
|
||||
Base = declarative_base()
|
||||
|
||||
# Dependency for FastAPI
|
||||
def get_db():
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
382
openmemory/api/app/mcp_server.py
Normal file
382
openmemory/api/app/mcp_server.py
Normal file
@@ -0,0 +1,382 @@
|
||||
import logging
|
||||
import json
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
from mcp.server.sse import SseServerTransport
|
||||
from app.utils.memory import get_memory_client
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.routing import APIRouter
|
||||
import contextvars
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
from app.database import SessionLocal
|
||||
from app.models import Memory, MemoryState, MemoryStatusHistory, MemoryAccessLog
|
||||
from app.utils.db import get_user_and_app
|
||||
import uuid
|
||||
import datetime
|
||||
from app.utils.permissions import check_memory_access_permissions
|
||||
from qdrant_client import models as qdrant_models
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
|
||||
# Initialize MCP and memory client
|
||||
mcp = FastMCP("mem0-mcp-server")
|
||||
|
||||
# Check if OpenAI API key is set
|
||||
if not os.getenv("OPENAI_API_KEY"):
|
||||
raise Exception("OPENAI_API_KEY is not set in .env file")
|
||||
|
||||
memory_client = get_memory_client()
|
||||
|
||||
# Context variables for user_id and client_name
|
||||
user_id_var: contextvars.ContextVar[str] = contextvars.ContextVar("user_id")
|
||||
client_name_var: contextvars.ContextVar[str] = contextvars.ContextVar("client_name")
|
||||
|
||||
# Create a router for MCP endpoints
|
||||
mcp_router = APIRouter(prefix="/mcp")
|
||||
|
||||
# Initialize SSE transport
|
||||
sse = SseServerTransport("/mcp/messages/")
|
||||
|
||||
@mcp.tool(description="Add new memories to the user's memory")
|
||||
async def add_memories(text: str) -> str:
|
||||
uid = user_id_var.get(None)
|
||||
client_name = client_name_var.get(None)
|
||||
|
||||
if not uid:
|
||||
return "Error: user_id not provided"
|
||||
if not client_name:
|
||||
return "Error: client_name not provided"
|
||||
|
||||
try:
|
||||
db = SessionLocal()
|
||||
try:
|
||||
# Get or create user and app
|
||||
user, app = get_user_and_app(db, user_id=uid, app_id=client_name)
|
||||
|
||||
# Check if app is active
|
||||
if not app.is_active:
|
||||
return f"Error: App {app.name} is currently paused on OpenMemory. Cannot create new memories."
|
||||
|
||||
response = memory_client.add(text,
|
||||
user_id=uid,
|
||||
metadata={
|
||||
"source_app": "openmemory",
|
||||
"mcp_client": client_name,
|
||||
})
|
||||
|
||||
# Process the response and update database
|
||||
if isinstance(response, dict) and 'results' in response:
|
||||
for result in response['results']:
|
||||
memory_id = uuid.UUID(result['id'])
|
||||
memory = db.query(Memory).filter(Memory.id == memory_id).first()
|
||||
|
||||
if result['event'] == 'ADD':
|
||||
if not memory:
|
||||
memory = Memory(
|
||||
id=memory_id,
|
||||
user_id=user.id,
|
||||
app_id=app.id,
|
||||
content=result['memory'],
|
||||
state=MemoryState.active
|
||||
)
|
||||
db.add(memory)
|
||||
else:
|
||||
memory.state = MemoryState.active
|
||||
memory.content = result['memory']
|
||||
|
||||
# Create history entry
|
||||
history = MemoryStatusHistory(
|
||||
memory_id=memory_id,
|
||||
changed_by=user.id,
|
||||
old_state=MemoryState.deleted if memory else None,
|
||||
new_state=MemoryState.active
|
||||
)
|
||||
db.add(history)
|
||||
|
||||
elif result['event'] == 'DELETE':
|
||||
if memory:
|
||||
memory.state = MemoryState.deleted
|
||||
memory.deleted_at = datetime.datetime.now(datetime.UTC)
|
||||
# Create history entry
|
||||
history = MemoryStatusHistory(
|
||||
memory_id=memory_id,
|
||||
changed_by=user.id,
|
||||
old_state=MemoryState.active,
|
||||
new_state=MemoryState.deleted
|
||||
)
|
||||
db.add(history)
|
||||
|
||||
db.commit()
|
||||
|
||||
return response
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
return f"Error adding to memory: {e}"
|
||||
|
||||
|
||||
@mcp.tool(description="Search the user's memory for memories that match the query")
|
||||
async def search_memory(query: str) -> str:
|
||||
uid = user_id_var.get(None)
|
||||
client_name = client_name_var.get(None)
|
||||
if not uid:
|
||||
return "Error: user_id not provided"
|
||||
if not client_name:
|
||||
return "Error: client_name not provided"
|
||||
try:
|
||||
db = SessionLocal()
|
||||
try:
|
||||
# Get or create user and app
|
||||
user, app = get_user_and_app(db, user_id=uid, app_id=client_name)
|
||||
|
||||
# Get accessible memory IDs based on ACL
|
||||
user_memories = db.query(Memory).filter(Memory.user_id == user.id).all()
|
||||
accessible_memory_ids = [memory.id for memory in user_memories if check_memory_access_permissions(db, memory, app.id)]
|
||||
conditions = [qdrant_models.FieldCondition(key="user_id", match=qdrant_models.MatchValue(value=uid))]
|
||||
logging.info(f"Accessible memory IDs: {accessible_memory_ids}")
|
||||
logging.info(f"Conditions: {conditions}")
|
||||
if accessible_memory_ids:
|
||||
# Convert UUIDs to strings for Qdrant
|
||||
accessible_memory_ids_str = [str(memory_id) for memory_id in accessible_memory_ids]
|
||||
conditions.append(qdrant_models.HasIdCondition(has_id=accessible_memory_ids_str))
|
||||
filters = qdrant_models.Filter(must=conditions)
|
||||
logging.info(f"Filters: {filters}")
|
||||
embeddings = memory_client.embedding_model.embed(query, "search")
|
||||
hits = memory_client.vector_store.client.query_points(
|
||||
collection_name=memory_client.vector_store.collection_name,
|
||||
query=embeddings,
|
||||
query_filter=filters,
|
||||
limit=10,
|
||||
)
|
||||
|
||||
memories = hits.points
|
||||
memories = [
|
||||
{
|
||||
"id": memory.id,
|
||||
"memory": memory.payload["data"],
|
||||
"hash": memory.payload.get("hash"),
|
||||
"created_at": memory.payload.get("created_at"),
|
||||
"updated_at": memory.payload.get("updated_at"),
|
||||
"score": memory.score,
|
||||
}
|
||||
for memory in memories
|
||||
]
|
||||
|
||||
# Log memory access for each memory found
|
||||
if isinstance(memories, dict) and 'results' in memories:
|
||||
print(f"Memories: {memories}")
|
||||
for memory_data in memories['results']:
|
||||
if 'id' in memory_data:
|
||||
memory_id = uuid.UUID(memory_data['id'])
|
||||
# Create access log entry
|
||||
access_log = MemoryAccessLog(
|
||||
memory_id=memory_id,
|
||||
app_id=app.id,
|
||||
access_type="search",
|
||||
metadata_={
|
||||
"query": query,
|
||||
"score": memory_data.get('score'),
|
||||
"hash": memory_data.get('hash')
|
||||
}
|
||||
)
|
||||
db.add(access_log)
|
||||
db.commit()
|
||||
else:
|
||||
for memory in memories:
|
||||
memory_id = uuid.UUID(memory['id'])
|
||||
# Create access log entry
|
||||
access_log = MemoryAccessLog(
|
||||
memory_id=memory_id,
|
||||
app_id=app.id,
|
||||
access_type="search",
|
||||
metadata_={
|
||||
"query": query,
|
||||
"score": memory.get('score'),
|
||||
"hash": memory.get('hash')
|
||||
}
|
||||
)
|
||||
db.add(access_log)
|
||||
db.commit()
|
||||
return json.dumps(memories, indent=2)
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
return f"Error searching memory: {e}"
|
||||
|
||||
|
||||
@mcp.tool(description="List all memories in the user's memory")
|
||||
async def list_memories() -> str:
|
||||
uid = user_id_var.get(None)
|
||||
client_name = client_name_var.get(None)
|
||||
if not uid:
|
||||
return "Error: user_id not provided"
|
||||
if not client_name:
|
||||
return "Error: client_name not provided"
|
||||
try:
|
||||
db = SessionLocal()
|
||||
try:
|
||||
# Get or create user and app
|
||||
user, app = get_user_and_app(db, user_id=uid, app_id=client_name)
|
||||
|
||||
# Get all memories
|
||||
memories = memory_client.get_all(user_id=uid)
|
||||
filtered_memories = []
|
||||
|
||||
# Filter memories based on permissions
|
||||
user_memories = db.query(Memory).filter(Memory.user_id == user.id).all()
|
||||
accessible_memory_ids = [memory.id for memory in user_memories if check_memory_access_permissions(db, memory, app.id)]
|
||||
if isinstance(memories, dict) and 'results' in memories:
|
||||
for memory_data in memories['results']:
|
||||
if 'id' in memory_data:
|
||||
memory_id = uuid.UUID(memory_data['id'])
|
||||
if memory_id in accessible_memory_ids:
|
||||
# Create access log entry
|
||||
access_log = MemoryAccessLog(
|
||||
memory_id=memory_id,
|
||||
app_id=app.id,
|
||||
access_type="list",
|
||||
metadata_={
|
||||
"hash": memory_data.get('hash')
|
||||
}
|
||||
)
|
||||
db.add(access_log)
|
||||
filtered_memories.append(memory_data)
|
||||
db.commit()
|
||||
else:
|
||||
for memory in memories:
|
||||
memory_id = uuid.UUID(memory['id'])
|
||||
memory_obj = db.query(Memory).filter(Memory.id == memory_id).first()
|
||||
if memory_obj and check_memory_access_permissions(db, memory_obj, app.id):
|
||||
# Create access log entry
|
||||
access_log = MemoryAccessLog(
|
||||
memory_id=memory_id,
|
||||
app_id=app.id,
|
||||
access_type="list",
|
||||
metadata_={
|
||||
"hash": memory.get('hash')
|
||||
}
|
||||
)
|
||||
db.add(access_log)
|
||||
filtered_memories.append(memory)
|
||||
db.commit()
|
||||
return json.dumps(filtered_memories, indent=2)
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
return f"Error getting memories: {e}"
|
||||
|
||||
|
||||
@mcp.tool(description="Delete all memories in the user's memory")
|
||||
async def delete_all_memories() -> str:
|
||||
uid = user_id_var.get(None)
|
||||
client_name = client_name_var.get(None)
|
||||
if not uid:
|
||||
return "Error: user_id not provided"
|
||||
if not client_name:
|
||||
return "Error: client_name not provided"
|
||||
try:
|
||||
db = SessionLocal()
|
||||
try:
|
||||
# Get or create user and app
|
||||
user, app = get_user_and_app(db, user_id=uid, app_id=client_name)
|
||||
|
||||
user_memories = db.query(Memory).filter(Memory.user_id == user.id).all()
|
||||
accessible_memory_ids = [memory.id for memory in user_memories if check_memory_access_permissions(db, memory, app.id)]
|
||||
|
||||
# delete the accessible memories only
|
||||
for memory_id in accessible_memory_ids:
|
||||
memory_client.delete(memory_id)
|
||||
|
||||
# Update each memory's state and create history entries
|
||||
now = datetime.datetime.now(datetime.UTC)
|
||||
for memory_id in accessible_memory_ids:
|
||||
memory = db.query(Memory).filter(Memory.id == memory_id).first()
|
||||
# Update memory state
|
||||
memory.state = MemoryState.deleted
|
||||
memory.deleted_at = now
|
||||
|
||||
# Create history entry
|
||||
history = MemoryStatusHistory(
|
||||
memory_id=memory_id,
|
||||
changed_by=user.id,
|
||||
old_state=MemoryState.active,
|
||||
new_state=MemoryState.deleted
|
||||
)
|
||||
db.add(history)
|
||||
|
||||
# Create access log entry
|
||||
access_log = MemoryAccessLog(
|
||||
memory_id=memory_id,
|
||||
app_id=app.id,
|
||||
access_type="delete_all",
|
||||
metadata_={"operation": "bulk_delete"}
|
||||
)
|
||||
db.add(access_log)
|
||||
|
||||
db.commit()
|
||||
return "Successfully deleted all memories"
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
return f"Error deleting memories: {e}"
|
||||
|
||||
|
||||
@mcp_router.get("/{client_name}/sse/{user_id}")
|
||||
async def handle_sse(request: Request):
|
||||
"""Handle SSE connections for a specific user and client"""
|
||||
# Extract user_id and client_name from path parameters
|
||||
uid = request.path_params.get("user_id")
|
||||
user_token = user_id_var.set(uid or "")
|
||||
client_name = request.path_params.get("client_name")
|
||||
client_token = client_name_var.set(client_name or "")
|
||||
|
||||
try:
|
||||
# Handle SSE connection
|
||||
async with sse.connect_sse(
|
||||
request.scope,
|
||||
request.receive,
|
||||
request._send,
|
||||
) as (read_stream, write_stream):
|
||||
await mcp._mcp_server.run(
|
||||
read_stream,
|
||||
write_stream,
|
||||
mcp._mcp_server.create_initialization_options(),
|
||||
)
|
||||
finally:
|
||||
# Clean up context variables
|
||||
user_id_var.reset(user_token)
|
||||
client_name_var.reset(client_token)
|
||||
|
||||
|
||||
@mcp_router.post("/messages/")
|
||||
async def handle_post_message(request: Request):
|
||||
"""Handle POST messages for SSE"""
|
||||
try:
|
||||
body = await request.body()
|
||||
|
||||
# Create a simple receive function that returns the body
|
||||
async def receive():
|
||||
return {"type": "http.request", "body": body, "more_body": False}
|
||||
|
||||
# Create a simple send function that does nothing
|
||||
async def send(message):
|
||||
pass
|
||||
|
||||
# Call handle_post_message with the correct arguments
|
||||
await sse.handle_post_message(request.scope, receive, send)
|
||||
|
||||
# Return a success response
|
||||
return {"status": "ok"}
|
||||
finally:
|
||||
pass
|
||||
# Clean up context variable
|
||||
# client_name_var.reset(client_token)
|
||||
|
||||
def setup_mcp_server(app: FastAPI):
|
||||
"""Setup MCP server with the FastAPI application"""
|
||||
mcp._mcp_server.name = f"mem0-mcp-server"
|
||||
|
||||
# Include MCP router in the FastAPI app
|
||||
app.include_router(mcp_router)
|
||||
217
openmemory/api/app/models.py
Normal file
217
openmemory/api/app/models.py
Normal file
@@ -0,0 +1,217 @@
|
||||
import enum
|
||||
import uuid
|
||||
import datetime
|
||||
from sqlalchemy import (
|
||||
Column, String, Boolean, ForeignKey, Enum, Table,
|
||||
DateTime, JSON, Integer, UUID, Index, event
|
||||
)
|
||||
from sqlalchemy.orm import relationship
|
||||
from app.database import Base
|
||||
from sqlalchemy.orm import Session
|
||||
from app.utils.categorization import get_categories_for_memory
|
||||
|
||||
|
||||
def get_current_utc_time():
|
||||
"""Get current UTC time"""
|
||||
return datetime.datetime.now(datetime.UTC)
|
||||
|
||||
|
||||
class MemoryState(enum.Enum):
|
||||
active = "active"
|
||||
paused = "paused"
|
||||
archived = "archived"
|
||||
deleted = "deleted"
|
||||
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = "users"
|
||||
id = Column(UUID, primary_key=True, default=lambda: uuid.uuid4())
|
||||
user_id = Column(String, nullable=False, unique=True, index=True)
|
||||
name = Column(String, nullable=True, index=True)
|
||||
email = Column(String, unique=True, nullable=True, index=True)
|
||||
metadata_ = Column('metadata', JSON, default=dict)
|
||||
created_at = Column(DateTime, default=get_current_utc_time, index=True)
|
||||
updated_at = Column(DateTime,
|
||||
default=get_current_utc_time,
|
||||
onupdate=get_current_utc_time)
|
||||
|
||||
apps = relationship("App", back_populates="owner")
|
||||
memories = relationship("Memory", back_populates="user")
|
||||
|
||||
|
||||
class App(Base):
|
||||
__tablename__ = "apps"
|
||||
id = Column(UUID, primary_key=True, default=lambda: uuid.uuid4())
|
||||
owner_id = Column(UUID, ForeignKey("users.id"), nullable=False, index=True)
|
||||
name = Column(String, unique=True, nullable=False, index=True)
|
||||
description = Column(String)
|
||||
metadata_ = Column('metadata', JSON, default=dict)
|
||||
is_active = Column(Boolean, default=True, index=True)
|
||||
created_at = Column(DateTime, default=get_current_utc_time, index=True)
|
||||
updated_at = Column(DateTime,
|
||||
default=get_current_utc_time,
|
||||
onupdate=get_current_utc_time)
|
||||
|
||||
owner = relationship("User", back_populates="apps")
|
||||
memories = relationship("Memory", back_populates="app")
|
||||
|
||||
|
||||
class Memory(Base):
|
||||
__tablename__ = "memories"
|
||||
id = Column(UUID, primary_key=True, default=lambda: uuid.uuid4())
|
||||
user_id = Column(UUID, ForeignKey("users.id"), nullable=False, index=True)
|
||||
app_id = Column(UUID, ForeignKey("apps.id"), nullable=False, index=True)
|
||||
content = Column(String, nullable=False)
|
||||
vector = Column(String)
|
||||
metadata_ = Column('metadata', JSON, default=dict)
|
||||
state = Column(Enum(MemoryState), default=MemoryState.active, index=True)
|
||||
created_at = Column(DateTime, default=get_current_utc_time, index=True)
|
||||
updated_at = Column(DateTime,
|
||||
default=get_current_utc_time,
|
||||
onupdate=get_current_utc_time)
|
||||
archived_at = Column(DateTime, nullable=True, index=True)
|
||||
deleted_at = Column(DateTime, nullable=True, index=True)
|
||||
|
||||
user = relationship("User", back_populates="memories")
|
||||
app = relationship("App", back_populates="memories")
|
||||
categories = relationship("Category", secondary="memory_categories", back_populates="memories")
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_memory_user_state', 'user_id', 'state'),
|
||||
Index('idx_memory_app_state', 'app_id', 'state'),
|
||||
Index('idx_memory_user_app', 'user_id', 'app_id'),
|
||||
)
|
||||
|
||||
|
||||
class Category(Base):
|
||||
__tablename__ = "categories"
|
||||
id = Column(UUID, primary_key=True, default=lambda: uuid.uuid4())
|
||||
name = Column(String, unique=True, nullable=False, index=True)
|
||||
description = Column(String)
|
||||
created_at = Column(DateTime, default=datetime.datetime.now(datetime.UTC), index=True)
|
||||
updated_at = Column(DateTime,
|
||||
default=get_current_utc_time,
|
||||
onupdate=get_current_utc_time)
|
||||
|
||||
memories = relationship("Memory", secondary="memory_categories", back_populates="categories")
|
||||
|
||||
memory_categories = Table(
|
||||
"memory_categories", Base.metadata,
|
||||
Column("memory_id", UUID, ForeignKey("memories.id"), primary_key=True, index=True),
|
||||
Column("category_id", UUID, ForeignKey("categories.id"), primary_key=True, index=True),
|
||||
Index('idx_memory_category', 'memory_id', 'category_id')
|
||||
)
|
||||
|
||||
|
||||
class AccessControl(Base):
|
||||
__tablename__ = "access_controls"
|
||||
id = Column(UUID, primary_key=True, default=lambda: uuid.uuid4())
|
||||
subject_type = Column(String, nullable=False, index=True)
|
||||
subject_id = Column(UUID, nullable=True, index=True)
|
||||
object_type = Column(String, nullable=False, index=True)
|
||||
object_id = Column(UUID, nullable=True, index=True)
|
||||
effect = Column(String, nullable=False, index=True)
|
||||
created_at = Column(DateTime, default=get_current_utc_time, index=True)
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_access_subject', 'subject_type', 'subject_id'),
|
||||
Index('idx_access_object', 'object_type', 'object_id'),
|
||||
)
|
||||
|
||||
|
||||
class ArchivePolicy(Base):
|
||||
__tablename__ = "archive_policies"
|
||||
id = Column(UUID, primary_key=True, default=lambda: uuid.uuid4())
|
||||
criteria_type = Column(String, nullable=False, index=True)
|
||||
criteria_id = Column(UUID, nullable=True, index=True)
|
||||
days_to_archive = Column(Integer, nullable=False)
|
||||
created_at = Column(DateTime, default=get_current_utc_time, index=True)
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_policy_criteria', 'criteria_type', 'criteria_id'),
|
||||
)
|
||||
|
||||
|
||||
class MemoryStatusHistory(Base):
|
||||
__tablename__ = "memory_status_history"
|
||||
id = Column(UUID, primary_key=True, default=lambda: uuid.uuid4())
|
||||
memory_id = Column(UUID, ForeignKey("memories.id"), nullable=False, index=True)
|
||||
changed_by = Column(UUID, ForeignKey("users.id"), nullable=False, index=True)
|
||||
old_state = Column(Enum(MemoryState), nullable=False, index=True)
|
||||
new_state = Column(Enum(MemoryState), nullable=False, index=True)
|
||||
changed_at = Column(DateTime, default=get_current_utc_time, index=True)
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_history_memory_state', 'memory_id', 'new_state'),
|
||||
Index('idx_history_user_time', 'changed_by', 'changed_at'),
|
||||
)
|
||||
|
||||
|
||||
class MemoryAccessLog(Base):
|
||||
__tablename__ = "memory_access_logs"
|
||||
id = Column(UUID, primary_key=True, default=lambda: uuid.uuid4())
|
||||
memory_id = Column(UUID, ForeignKey("memories.id"), nullable=False, index=True)
|
||||
app_id = Column(UUID, ForeignKey("apps.id"), nullable=False, index=True)
|
||||
accessed_at = Column(DateTime, default=get_current_utc_time, index=True)
|
||||
access_type = Column(String, nullable=False, index=True)
|
||||
metadata_ = Column('metadata', JSON, default=dict)
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_access_memory_time', 'memory_id', 'accessed_at'),
|
||||
Index('idx_access_app_time', 'app_id', 'accessed_at'),
|
||||
)
|
||||
|
||||
def categorize_memory(memory: Memory, db: Session) -> None:
|
||||
"""Categorize a memory using OpenAI and store the categories in the database."""
|
||||
try:
|
||||
# Get categories from OpenAI
|
||||
categories = get_categories_for_memory(memory.content)
|
||||
|
||||
# Get or create categories in the database
|
||||
for category_name in categories:
|
||||
category = db.query(Category).filter(Category.name == category_name).first()
|
||||
if not category:
|
||||
category = Category(
|
||||
name=category_name,
|
||||
description=f"Automatically created category for {category_name}"
|
||||
)
|
||||
db.add(category)
|
||||
db.flush() # Flush to get the category ID
|
||||
|
||||
# Check if the memory-category association already exists
|
||||
existing = db.execute(
|
||||
memory_categories.select().where(
|
||||
(memory_categories.c.memory_id == memory.id) &
|
||||
(memory_categories.c.category_id == category.id)
|
||||
)
|
||||
).first()
|
||||
|
||||
if not existing:
|
||||
# Create the association
|
||||
db.execute(
|
||||
memory_categories.insert().values(
|
||||
memory_id=memory.id,
|
||||
category_id=category.id
|
||||
)
|
||||
)
|
||||
|
||||
db.commit()
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
print(f"Error categorizing memory: {e}")
|
||||
|
||||
|
||||
@event.listens_for(Memory, 'after_insert')
|
||||
def after_memory_insert(mapper, connection, target):
|
||||
"""Trigger categorization after a memory is inserted."""
|
||||
db = Session(bind=connection)
|
||||
categorize_memory(target, db)
|
||||
db.close()
|
||||
|
||||
|
||||
@event.listens_for(Memory, 'after_update')
|
||||
def after_memory_update(mapper, connection, target):
|
||||
"""Trigger categorization after a memory is updated."""
|
||||
db = Session(bind=connection)
|
||||
categorize_memory(target, db)
|
||||
db.close()
|
||||
5
openmemory/api/app/routers/__init__.py
Normal file
5
openmemory/api/app/routers/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from .memories import router as memories_router
|
||||
from .apps import router as apps_router
|
||||
from .stats import router as stats_router
|
||||
|
||||
__all__ = ["memories_router", "apps_router", "stats_router"]
|
||||
223
openmemory/api/app/routers/apps.py
Normal file
223
openmemory/api/app/routers/apps.py
Normal file
@@ -0,0 +1,223 @@
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
from sqlalchemy import func, desc
|
||||
|
||||
from app.database import get_db
|
||||
from app.models import App, Memory, MemoryAccessLog, MemoryState
|
||||
|
||||
router = APIRouter(prefix="/api/v1/apps", tags=["apps"])
|
||||
|
||||
# Helper functions
|
||||
def get_app_or_404(db: Session, app_id: UUID) -> App:
|
||||
app = db.query(App).filter(App.id == app_id).first()
|
||||
if not app:
|
||||
raise HTTPException(status_code=404, detail="App not found")
|
||||
return app
|
||||
|
||||
# List all apps with filtering
|
||||
@router.get("/")
|
||||
async def list_apps(
|
||||
name: Optional[str] = None,
|
||||
is_active: Optional[bool] = None,
|
||||
sort_by: str = 'name',
|
||||
sort_direction: str = 'asc',
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(10, ge=1, le=100),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
# Create a subquery for memory counts
|
||||
memory_counts = db.query(
|
||||
Memory.app_id,
|
||||
func.count(Memory.id).label('memory_count')
|
||||
).filter(
|
||||
Memory.state.in_([MemoryState.active, MemoryState.paused, MemoryState.archived])
|
||||
).group_by(Memory.app_id).subquery()
|
||||
|
||||
# Create a subquery for access counts
|
||||
access_counts = db.query(
|
||||
MemoryAccessLog.app_id,
|
||||
func.count(func.distinct(MemoryAccessLog.memory_id)).label('access_count')
|
||||
).group_by(MemoryAccessLog.app_id).subquery()
|
||||
|
||||
# Base query
|
||||
query = db.query(
|
||||
App,
|
||||
func.coalesce(memory_counts.c.memory_count, 0).label('total_memories_created'),
|
||||
func.coalesce(access_counts.c.access_count, 0).label('total_memories_accessed')
|
||||
)
|
||||
|
||||
# Join with subqueries
|
||||
query = query.outerjoin(
|
||||
memory_counts,
|
||||
App.id == memory_counts.c.app_id
|
||||
).outerjoin(
|
||||
access_counts,
|
||||
App.id == access_counts.c.app_id
|
||||
)
|
||||
|
||||
if name:
|
||||
query = query.filter(App.name.ilike(f"%{name}%"))
|
||||
|
||||
if is_active is not None:
|
||||
query = query.filter(App.is_active == is_active)
|
||||
|
||||
# Apply sorting
|
||||
if sort_by == 'name':
|
||||
sort_field = App.name
|
||||
elif sort_by == 'memories':
|
||||
sort_field = func.coalesce(memory_counts.c.memory_count, 0)
|
||||
elif sort_by == 'memories_accessed':
|
||||
sort_field = func.coalesce(access_counts.c.access_count, 0)
|
||||
else:
|
||||
sort_field = App.name # default sort
|
||||
|
||||
if sort_direction == 'desc':
|
||||
query = query.order_by(desc(sort_field))
|
||||
else:
|
||||
query = query.order_by(sort_field)
|
||||
|
||||
total = query.count()
|
||||
apps = query.offset((page - 1) * page_size).limit(page_size).all()
|
||||
|
||||
return {
|
||||
"total": total,
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
"apps": [
|
||||
{
|
||||
"id": app[0].id,
|
||||
"name": app[0].name,
|
||||
"is_active": app[0].is_active,
|
||||
"total_memories_created": app[1],
|
||||
"total_memories_accessed": app[2]
|
||||
}
|
||||
for app in apps
|
||||
]
|
||||
}
|
||||
|
||||
# Get app details
|
||||
@router.get("/{app_id}")
|
||||
async def get_app_details(
|
||||
app_id: UUID,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
app = get_app_or_404(db, app_id)
|
||||
|
||||
# Get memory access statistics
|
||||
access_stats = db.query(
|
||||
func.count(MemoryAccessLog.id).label("total_memories_accessed"),
|
||||
func.min(MemoryAccessLog.accessed_at).label("first_accessed"),
|
||||
func.max(MemoryAccessLog.accessed_at).label("last_accessed")
|
||||
).filter(MemoryAccessLog.app_id == app_id).first()
|
||||
|
||||
return {
|
||||
"is_active": app.is_active,
|
||||
"total_memories_created": db.query(Memory)
|
||||
.filter(Memory.app_id == app_id)
|
||||
.count(),
|
||||
"total_memories_accessed": access_stats.total_memories_accessed or 0,
|
||||
"first_accessed": access_stats.first_accessed,
|
||||
"last_accessed": access_stats.last_accessed
|
||||
}
|
||||
|
||||
# List memories created by app
|
||||
@router.get("/{app_id}/memories")
|
||||
async def list_app_memories(
|
||||
app_id: UUID,
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(10, ge=1, le=100),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
get_app_or_404(db, app_id)
|
||||
query = db.query(Memory).filter(
|
||||
Memory.app_id == app_id,
|
||||
Memory.state.in_([MemoryState.active, MemoryState.paused, MemoryState.archived])
|
||||
)
|
||||
# Add eager loading for categories
|
||||
query = query.options(joinedload(Memory.categories))
|
||||
total = query.count()
|
||||
memories = query.order_by(Memory.created_at.desc()).offset((page - 1) * page_size).limit(page_size).all()
|
||||
|
||||
return {
|
||||
"total": total,
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
"memories": [
|
||||
{
|
||||
"id": memory.id,
|
||||
"content": memory.content,
|
||||
"created_at": memory.created_at,
|
||||
"state": memory.state.value,
|
||||
"app_id": memory.app_id,
|
||||
"categories": [category.name for category in memory.categories],
|
||||
"metadata_": memory.metadata_
|
||||
}
|
||||
for memory in memories
|
||||
]
|
||||
}
|
||||
|
||||
# List memories accessed by app
|
||||
@router.get("/{app_id}/accessed")
|
||||
async def list_app_accessed_memories(
|
||||
app_id: UUID,
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(10, ge=1, le=100),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
|
||||
# Get memories with access counts
|
||||
query = db.query(
|
||||
Memory,
|
||||
func.count(MemoryAccessLog.id).label("access_count")
|
||||
).join(
|
||||
MemoryAccessLog,
|
||||
Memory.id == MemoryAccessLog.memory_id
|
||||
).filter(
|
||||
MemoryAccessLog.app_id == app_id
|
||||
).group_by(
|
||||
Memory.id
|
||||
).order_by(
|
||||
desc("access_count")
|
||||
)
|
||||
|
||||
# Add eager loading for categories
|
||||
query = query.options(joinedload(Memory.categories))
|
||||
|
||||
total = query.count()
|
||||
results = query.offset((page - 1) * page_size).limit(page_size).all()
|
||||
|
||||
return {
|
||||
"total": total,
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
"memories": [
|
||||
{
|
||||
"memory": {
|
||||
"id": memory.id,
|
||||
"content": memory.content,
|
||||
"created_at": memory.created_at,
|
||||
"state": memory.state.value,
|
||||
"app_id": memory.app_id,
|
||||
"app_name": memory.app.name if memory.app else None,
|
||||
"categories": [category.name for category in memory.categories],
|
||||
"metadata_": memory.metadata_
|
||||
},
|
||||
"access_count": count
|
||||
}
|
||||
for memory, count in results
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@router.put("/{app_id}")
|
||||
async def update_app_details(
|
||||
app_id: UUID,
|
||||
is_active: bool,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
app = get_app_or_404(db, app_id)
|
||||
app.is_active = is_active
|
||||
db.commit()
|
||||
return {"status": "success", "message": "Updated app details successfully"}
|
||||
575
openmemory/api/app/routers/memories.py
Normal file
575
openmemory/api/app/routers/memories.py
Normal file
@@ -0,0 +1,575 @@
|
||||
from datetime import datetime, UTC
|
||||
from typing import List, Optional, Set
|
||||
from uuid import UUID
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
from fastapi_pagination import Page, Params
|
||||
from fastapi_pagination.ext.sqlalchemy import paginate as sqlalchemy_paginate
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import or_, func
|
||||
|
||||
from app.database import get_db
|
||||
from app.models import (
|
||||
Memory, MemoryState, MemoryAccessLog, App,
|
||||
MemoryStatusHistory, User, Category, AccessControl
|
||||
)
|
||||
from app.schemas import MemoryResponse, PaginatedMemoryResponse
|
||||
from app.utils.permissions import check_memory_access_permissions
|
||||
|
||||
router = APIRouter(prefix="/api/v1/memories", tags=["memories"])
|
||||
|
||||
|
||||
def get_memory_or_404(db: Session, memory_id: UUID) -> Memory:
|
||||
memory = db.query(Memory).filter(Memory.id == memory_id).first()
|
||||
if not memory:
|
||||
raise HTTPException(status_code=404, detail="Memory not found")
|
||||
return memory
|
||||
|
||||
|
||||
def update_memory_state(db: Session, memory_id: UUID, new_state: MemoryState, user_id: UUID):
|
||||
memory = get_memory_or_404(db, memory_id)
|
||||
old_state = memory.state
|
||||
|
||||
# Update memory state
|
||||
memory.state = new_state
|
||||
if new_state == MemoryState.archived:
|
||||
memory.archived_at = datetime.now(UTC)
|
||||
elif new_state == MemoryState.deleted:
|
||||
memory.deleted_at = datetime.now(UTC)
|
||||
|
||||
# Record state change
|
||||
history = MemoryStatusHistory(
|
||||
memory_id=memory_id,
|
||||
changed_by=user_id,
|
||||
old_state=old_state,
|
||||
new_state=new_state
|
||||
)
|
||||
db.add(history)
|
||||
db.commit()
|
||||
return memory
|
||||
|
||||
|
||||
def get_accessible_memory_ids(db: Session, app_id: UUID) -> Set[UUID]:
|
||||
"""
|
||||
Get the set of memory IDs that the app has access to based on app-level ACL rules.
|
||||
Returns all memory IDs if no specific restrictions are found.
|
||||
"""
|
||||
# Get app-level access controls
|
||||
app_access = db.query(AccessControl).filter(
|
||||
AccessControl.subject_type == "app",
|
||||
AccessControl.subject_id == app_id,
|
||||
AccessControl.object_type == "memory"
|
||||
).all()
|
||||
|
||||
# If no app-level rules exist, return None to indicate all memories are accessible
|
||||
if not app_access:
|
||||
return None
|
||||
|
||||
# Initialize sets for allowed and denied memory IDs
|
||||
allowed_memory_ids = set()
|
||||
denied_memory_ids = set()
|
||||
|
||||
# Process app-level rules
|
||||
for rule in app_access:
|
||||
if rule.effect == "allow":
|
||||
if rule.object_id: # Specific memory access
|
||||
allowed_memory_ids.add(rule.object_id)
|
||||
else: # All memories access
|
||||
return None # All memories allowed
|
||||
elif rule.effect == "deny":
|
||||
if rule.object_id: # Specific memory denied
|
||||
denied_memory_ids.add(rule.object_id)
|
||||
else: # All memories denied
|
||||
return set() # No memories accessible
|
||||
|
||||
# Remove denied memories from allowed set
|
||||
if allowed_memory_ids:
|
||||
allowed_memory_ids -= denied_memory_ids
|
||||
|
||||
return allowed_memory_ids
|
||||
|
||||
|
||||
# List all memories with filtering
|
||||
@router.get("/", response_model=Page[MemoryResponse])
|
||||
async def list_memories(
|
||||
user_id: str,
|
||||
app_id: Optional[UUID] = None,
|
||||
from_date: Optional[int] = Query(
|
||||
None,
|
||||
description="Filter memories created after this date (timestamp)",
|
||||
examples=[1718505600]
|
||||
),
|
||||
to_date: Optional[int] = Query(
|
||||
None,
|
||||
description="Filter memories created before this date (timestamp)",
|
||||
examples=[1718505600]
|
||||
),
|
||||
categories: Optional[str] = None,
|
||||
params: Params = Depends(),
|
||||
search_query: Optional[str] = None,
|
||||
sort_column: Optional[str] = Query(None, description="Column to sort by (memory, categories, app_name, created_at)"),
|
||||
sort_direction: Optional[str] = Query(None, description="Sort direction (asc or desc)"),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
user = db.query(User).filter(User.user_id == user_id).first()
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
# Build base query
|
||||
query = db.query(Memory).filter(
|
||||
Memory.user_id == user.id,
|
||||
Memory.state != MemoryState.deleted,
|
||||
Memory.state != MemoryState.archived,
|
||||
Memory.content.ilike(f"%{search_query}%") if search_query else True
|
||||
)
|
||||
|
||||
# Apply filters
|
||||
if app_id:
|
||||
query = query.filter(Memory.app_id == app_id)
|
||||
|
||||
if from_date:
|
||||
from_datetime = datetime.fromtimestamp(from_date, tz=UTC)
|
||||
query = query.filter(Memory.created_at >= from_datetime)
|
||||
|
||||
if to_date:
|
||||
to_datetime = datetime.fromtimestamp(to_date, tz=UTC)
|
||||
query = query.filter(Memory.created_at <= to_datetime)
|
||||
|
||||
# Add joins for app and categories after filtering
|
||||
query = query.outerjoin(App, Memory.app_id == App.id)
|
||||
query = query.outerjoin(Memory.categories)
|
||||
|
||||
# Apply category filter if provided
|
||||
if categories:
|
||||
category_list = [c.strip() for c in categories.split(",")]
|
||||
query = query.filter(Category.name.in_(category_list))
|
||||
|
||||
# Apply sorting if specified
|
||||
if sort_column:
|
||||
sort_field = getattr(Memory, sort_column, None)
|
||||
if sort_field:
|
||||
query = query.order_by(sort_field.desc()) if sort_direction == "desc" else query.order_by(sort_field.asc())
|
||||
|
||||
|
||||
# Get paginated results
|
||||
paginated_results = sqlalchemy_paginate(query, params)
|
||||
|
||||
# Filter results based on permissions
|
||||
filtered_items = []
|
||||
for item in paginated_results.items:
|
||||
if check_memory_access_permissions(db, item, app_id):
|
||||
filtered_items.append(item)
|
||||
|
||||
# Update paginated results with filtered items
|
||||
paginated_results.items = filtered_items
|
||||
paginated_results.total = len(filtered_items)
|
||||
|
||||
return paginated_results
|
||||
|
||||
|
||||
# Get all categories
|
||||
@router.get("/categories")
|
||||
async def get_categories(
|
||||
user_id: str,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
user = db.query(User).filter(User.user_id == user_id).first()
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
# Get unique categories associated with the user's memories
|
||||
# Get all memories
|
||||
memories = db.query(Memory).filter(Memory.user_id == user.id, Memory.state != MemoryState.deleted, Memory.state != MemoryState.archived).all()
|
||||
# Get all categories from memories
|
||||
categories = [category for memory in memories for category in memory.categories]
|
||||
# Get unique categories
|
||||
unique_categories = list(set(categories))
|
||||
|
||||
return {
|
||||
"categories": unique_categories,
|
||||
"total": len(unique_categories)
|
||||
}
|
||||
|
||||
|
||||
class CreateMemoryRequest(BaseModel):
|
||||
user_id: str
|
||||
text: str
|
||||
metadata: dict = {}
|
||||
infer: bool = True
|
||||
app: str = "openmemory"
|
||||
|
||||
|
||||
# Create new memory
|
||||
@router.post("/")
|
||||
async def create_memory(
|
||||
request: CreateMemoryRequest,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
user = db.query(User).filter(User.user_id == request.user_id).first()
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
# Get or create app
|
||||
app_obj = db.query(App).filter(App.name == request.app).first()
|
||||
if not app_obj:
|
||||
app_obj = App(name=request.app, owner_id=user.id)
|
||||
db.add(app_obj)
|
||||
db.commit()
|
||||
db.refresh(app_obj)
|
||||
|
||||
# Check if app is active
|
||||
if not app_obj.is_active:
|
||||
raise HTTPException(status_code=403, detail=f"App {request.app} is currently paused on OpenMemory. Cannot create new memories.")
|
||||
|
||||
# Create memory
|
||||
memory = Memory(
|
||||
user_id=user.id,
|
||||
app_id=app_obj.id,
|
||||
content=request.text,
|
||||
metadata_=request.metadata
|
||||
)
|
||||
db.add(memory)
|
||||
db.commit()
|
||||
db.refresh(memory)
|
||||
return memory
|
||||
|
||||
|
||||
# Get memory by ID
|
||||
@router.get("/{memory_id}")
|
||||
async def get_memory(
|
||||
memory_id: UUID,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
memory = get_memory_or_404(db, memory_id)
|
||||
return {
|
||||
"id": memory.id,
|
||||
"text": memory.content,
|
||||
"created_at": int(memory.created_at.timestamp()),
|
||||
"state": memory.state.value,
|
||||
"app_id": memory.app_id,
|
||||
"app_name": memory.app.name if memory.app else None,
|
||||
"categories": [category.name for category in memory.categories],
|
||||
"metadata_": memory.metadata_
|
||||
}
|
||||
|
||||
|
||||
class DeleteMemoriesRequest(BaseModel):
|
||||
memory_ids: List[UUID]
|
||||
user_id: str
|
||||
|
||||
# Delete multiple memories
|
||||
@router.delete("/")
|
||||
async def delete_memories(
|
||||
request: DeleteMemoriesRequest,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
user = db.query(User).filter(User.user_id == request.user_id).first()
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
for memory_id in request.memory_ids:
|
||||
update_memory_state(db, memory_id, MemoryState.deleted, user.id)
|
||||
return {"message": f"Successfully deleted {len(request.memory_ids)} memories"}
|
||||
|
||||
|
||||
# Archive memories
|
||||
@router.post("/actions/archive")
|
||||
async def archive_memories(
|
||||
memory_ids: List[UUID],
|
||||
user_id: UUID,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
for memory_id in memory_ids:
|
||||
update_memory_state(db, memory_id, MemoryState.archived, user_id)
|
||||
return {"message": f"Successfully archived {len(memory_ids)} memories"}
|
||||
|
||||
|
||||
class PauseMemoriesRequest(BaseModel):
|
||||
memory_ids: Optional[List[UUID]] = None
|
||||
category_ids: Optional[List[UUID]] = None
|
||||
app_id: Optional[UUID] = None
|
||||
all_for_app: bool = False
|
||||
global_pause: bool = False
|
||||
state: Optional[MemoryState] = None
|
||||
user_id: str
|
||||
|
||||
# Pause access to memories
|
||||
@router.post("/actions/pause")
|
||||
async def pause_memories(
|
||||
request: PauseMemoriesRequest,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
|
||||
global_pause = request.global_pause
|
||||
all_for_app = request.all_for_app
|
||||
app_id = request.app_id
|
||||
memory_ids = request.memory_ids
|
||||
category_ids = request.category_ids
|
||||
state = request.state or MemoryState.paused
|
||||
|
||||
user = db.query(User).filter(User.user_id == request.user_id).first()
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
user_id = user.id
|
||||
|
||||
if global_pause:
|
||||
# Pause all memories
|
||||
memories = db.query(Memory).filter(
|
||||
Memory.state != MemoryState.deleted,
|
||||
Memory.state != MemoryState.archived
|
||||
).all()
|
||||
for memory in memories:
|
||||
update_memory_state(db, memory.id, state, user_id)
|
||||
return {"message": "Successfully paused all memories"}
|
||||
|
||||
if app_id:
|
||||
# Pause all memories for an app
|
||||
memories = db.query(Memory).filter(
|
||||
Memory.app_id == app_id,
|
||||
Memory.user_id == user.id,
|
||||
Memory.state != MemoryState.deleted,
|
||||
Memory.state != MemoryState.archived
|
||||
).all()
|
||||
for memory in memories:
|
||||
update_memory_state(db, memory.id, state, user_id)
|
||||
return {"message": f"Successfully paused all memories for app {app_id}"}
|
||||
|
||||
if all_for_app and memory_ids:
|
||||
# Pause all memories for an app
|
||||
memories = db.query(Memory).filter(
|
||||
Memory.user_id == user.id,
|
||||
Memory.state != MemoryState.deleted,
|
||||
Memory.id.in_(memory_ids)
|
||||
).all()
|
||||
for memory in memories:
|
||||
update_memory_state(db, memory.id, state, user_id)
|
||||
return {"message": f"Successfully paused all memories"}
|
||||
|
||||
if memory_ids:
|
||||
# Pause specific memories
|
||||
for memory_id in memory_ids:
|
||||
update_memory_state(db, memory_id, state, user_id)
|
||||
return {"message": f"Successfully paused {len(memory_ids)} memories"}
|
||||
|
||||
if category_ids:
|
||||
# Pause memories by category
|
||||
memories = db.query(Memory).join(Memory.categories).filter(
|
||||
Category.id.in_(category_ids),
|
||||
Memory.state != MemoryState.deleted,
|
||||
Memory.state != MemoryState.archived
|
||||
).all()
|
||||
for memory in memories:
|
||||
update_memory_state(db, memory.id, state, user_id)
|
||||
return {"message": f"Successfully paused memories in {len(category_ids)} categories"}
|
||||
|
||||
raise HTTPException(status_code=400, detail="Invalid pause request parameters")
|
||||
|
||||
|
||||
# Get memory access logs
|
||||
@router.get("/{memory_id}/access-log")
|
||||
async def get_memory_access_log(
|
||||
memory_id: UUID,
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(10, ge=1, le=100),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
query = db.query(MemoryAccessLog).filter(MemoryAccessLog.memory_id == memory_id)
|
||||
total = query.count()
|
||||
logs = query.order_by(MemoryAccessLog.accessed_at.desc()).offset((page - 1) * page_size).limit(page_size).all()
|
||||
|
||||
# Get app name
|
||||
for log in logs:
|
||||
app = db.query(App).filter(App.id == log.app_id).first()
|
||||
log.app_name = app.name if app else None
|
||||
|
||||
return {
|
||||
"total": total,
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
"logs": logs
|
||||
}
|
||||
|
||||
|
||||
class UpdateMemoryRequest(BaseModel):
|
||||
memory_content: str
|
||||
user_id: str
|
||||
|
||||
# Update a memory
|
||||
@router.put("/{memory_id}")
|
||||
async def update_memory(
|
||||
memory_id: UUID,
|
||||
request: UpdateMemoryRequest,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
user = db.query(User).filter(User.user_id == request.user_id).first()
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
memory = get_memory_or_404(db, memory_id)
|
||||
memory.content = request.memory_content
|
||||
db.commit()
|
||||
db.refresh(memory)
|
||||
return memory
|
||||
|
||||
class FilterMemoriesRequest(BaseModel):
|
||||
user_id: str
|
||||
page: int = 1
|
||||
size: int = 10
|
||||
search_query: Optional[str] = None
|
||||
app_ids: Optional[List[UUID]] = None
|
||||
category_ids: Optional[List[UUID]] = None
|
||||
sort_column: Optional[str] = None
|
||||
sort_direction: Optional[str] = None
|
||||
from_date: Optional[int] = None
|
||||
to_date: Optional[int] = None
|
||||
show_archived: Optional[bool] = False
|
||||
|
||||
@router.post("/filter", response_model=Page[MemoryResponse])
|
||||
async def filter_memories(
|
||||
request: FilterMemoriesRequest,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
user = db.query(User).filter(User.user_id == request.user_id).first()
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
# Build base query
|
||||
query = db.query(Memory).filter(
|
||||
Memory.user_id == user.id,
|
||||
Memory.state != MemoryState.deleted,
|
||||
)
|
||||
|
||||
# Filter archived memories based on show_archived parameter
|
||||
if not request.show_archived:
|
||||
query = query.filter(Memory.state != MemoryState.archived)
|
||||
|
||||
# Apply search filter
|
||||
if request.search_query:
|
||||
query = query.filter(Memory.content.ilike(f"%{request.search_query}%"))
|
||||
|
||||
# Apply app filter
|
||||
if request.app_ids:
|
||||
query = query.filter(Memory.app_id.in_(request.app_ids))
|
||||
|
||||
# Add joins for app and categories
|
||||
query = query.outerjoin(App, Memory.app_id == App.id)
|
||||
|
||||
# Apply category filter
|
||||
if request.category_ids:
|
||||
query = query.join(Memory.categories).filter(Category.id.in_(request.category_ids))
|
||||
else:
|
||||
query = query.outerjoin(Memory.categories)
|
||||
|
||||
# Apply date filters
|
||||
if request.from_date:
|
||||
from_datetime = datetime.fromtimestamp(request.from_date, tz=UTC)
|
||||
query = query.filter(Memory.created_at >= from_datetime)
|
||||
|
||||
if request.to_date:
|
||||
to_datetime = datetime.fromtimestamp(request.to_date, tz=UTC)
|
||||
query = query.filter(Memory.created_at <= to_datetime)
|
||||
|
||||
# Apply sorting
|
||||
if request.sort_column and request.sort_direction:
|
||||
sort_direction = request.sort_direction.lower()
|
||||
if sort_direction not in ['asc', 'desc']:
|
||||
raise HTTPException(status_code=400, detail="Invalid sort direction")
|
||||
|
||||
sort_mapping = {
|
||||
'memory': Memory.content,
|
||||
'app_name': App.name,
|
||||
'created_at': Memory.created_at
|
||||
}
|
||||
|
||||
if request.sort_column not in sort_mapping:
|
||||
raise HTTPException(status_code=400, detail="Invalid sort column")
|
||||
|
||||
sort_field = sort_mapping[request.sort_column]
|
||||
if sort_direction == 'desc':
|
||||
query = query.order_by(sort_field.desc())
|
||||
else:
|
||||
query = query.order_by(sort_field.asc())
|
||||
else:
|
||||
# Default sorting
|
||||
query = query.order_by(Memory.created_at.desc())
|
||||
|
||||
# Add eager loading for categories and make the query distinct
|
||||
query = query.options(
|
||||
joinedload(Memory.categories)
|
||||
).distinct(Memory.id)
|
||||
|
||||
# Use fastapi-pagination's paginate function
|
||||
return sqlalchemy_paginate(
|
||||
query,
|
||||
Params(page=request.page, size=request.size),
|
||||
transformer=lambda items: [
|
||||
MemoryResponse(
|
||||
id=memory.id,
|
||||
content=memory.content,
|
||||
created_at=memory.created_at,
|
||||
state=memory.state.value,
|
||||
app_id=memory.app_id,
|
||||
app_name=memory.app.name if memory.app else None,
|
||||
categories=[category.name for category in memory.categories],
|
||||
metadata_=memory.metadata_
|
||||
)
|
||||
for memory in items
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{memory_id}/related", response_model=Page[MemoryResponse])
|
||||
async def get_related_memories(
|
||||
memory_id: UUID,
|
||||
user_id: str,
|
||||
params: Params = Depends(),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
# Validate user
|
||||
user = db.query(User).filter(User.user_id == user_id).first()
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
# Get the source memory
|
||||
memory = get_memory_or_404(db, memory_id)
|
||||
|
||||
# Extract category IDs from the source memory
|
||||
category_ids = [category.id for category in memory.categories]
|
||||
|
||||
if not category_ids:
|
||||
return Page.create([], total=0, params=params)
|
||||
|
||||
# Build query for related memories
|
||||
query = db.query(Memory).distinct(Memory.id).filter(
|
||||
Memory.user_id == user.id,
|
||||
Memory.id != memory_id,
|
||||
Memory.state != MemoryState.deleted
|
||||
).join(Memory.categories).filter(
|
||||
Category.id.in_(category_ids)
|
||||
).options(
|
||||
joinedload(Memory.categories),
|
||||
joinedload(Memory.app)
|
||||
).order_by(
|
||||
func.count(Category.id).desc(),
|
||||
Memory.created_at.desc()
|
||||
).group_by(Memory.id)
|
||||
|
||||
# ⚡ Force page size to be 5
|
||||
params = Params(page=params.page, size=5)
|
||||
|
||||
return sqlalchemy_paginate(
|
||||
query,
|
||||
params,
|
||||
transformer=lambda items: [
|
||||
MemoryResponse(
|
||||
id=memory.id,
|
||||
content=memory.content,
|
||||
created_at=memory.created_at,
|
||||
state=memory.state.value,
|
||||
app_id=memory.app_id,
|
||||
app_name=memory.app.name if memory.app else None,
|
||||
categories=[category.name for category in memory.categories],
|
||||
metadata_=memory.metadata_
|
||||
)
|
||||
for memory in items
|
||||
]
|
||||
)
|
||||
30
openmemory/api/app/routers/stats.py
Normal file
30
openmemory/api/app/routers/stats.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
from app.database import get_db
|
||||
from app.models import User, Memory, App, MemoryState
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/v1/stats", tags=["stats"])
|
||||
|
||||
@router.get("/")
|
||||
async def get_profile(
|
||||
user_id: str,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
user = db.query(User).filter(User.user_id == user_id).first()
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
# Get total number of memories
|
||||
total_memories = db.query(Memory).filter(Memory.user_id == user.id, Memory.state != MemoryState.deleted).count()
|
||||
|
||||
# Get total number of apps
|
||||
apps = db.query(App).filter(App.owner == user)
|
||||
total_apps = apps.count()
|
||||
|
||||
return {
|
||||
"total_memories": total_memories,
|
||||
"total_apps": total_apps,
|
||||
"apps": apps.all()
|
||||
}
|
||||
|
||||
64
openmemory/api/app/schemas.py
Normal file
64
openmemory/api/app/schemas.py
Normal file
@@ -0,0 +1,64 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional, List
|
||||
from uuid import UUID
|
||||
from pydantic import BaseModel, Field, validator
|
||||
|
||||
class MemoryBase(BaseModel):
|
||||
content: str
|
||||
metadata_: Optional[dict] = Field(default_factory=dict)
|
||||
|
||||
class MemoryCreate(MemoryBase):
|
||||
user_id: UUID
|
||||
app_id: UUID
|
||||
|
||||
|
||||
class Category(BaseModel):
|
||||
name: str
|
||||
|
||||
|
||||
class App(BaseModel):
|
||||
id: UUID
|
||||
name: str
|
||||
|
||||
|
||||
class Memory(MemoryBase):
|
||||
id: UUID
|
||||
user_id: UUID
|
||||
app_id: UUID
|
||||
created_at: datetime
|
||||
updated_at: Optional[datetime] = None
|
||||
state: str
|
||||
categories: Optional[List[Category]] = None
|
||||
app: App
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
class MemoryUpdate(BaseModel):
|
||||
content: Optional[str] = None
|
||||
metadata_: Optional[dict] = None
|
||||
state: Optional[str] = None
|
||||
|
||||
|
||||
class MemoryResponse(BaseModel):
|
||||
id: UUID
|
||||
content: str
|
||||
created_at: int
|
||||
state: str
|
||||
app_id: UUID
|
||||
app_name: str
|
||||
categories: List[str]
|
||||
metadata_: Optional[dict] = None
|
||||
|
||||
@validator('created_at', pre=True)
|
||||
def convert_to_epoch(cls, v):
|
||||
if isinstance(v, datetime):
|
||||
return int(v.timestamp())
|
||||
return v
|
||||
|
||||
class PaginatedMemoryResponse(BaseModel):
|
||||
items: List[MemoryResponse]
|
||||
total: int
|
||||
page: int
|
||||
size: int
|
||||
pages: int
|
||||
0
openmemory/api/app/utils/__init__.py
Normal file
0
openmemory/api/app/utils/__init__.py
Normal file
37
openmemory/api/app/utils/categorization.py
Normal file
37
openmemory/api/app/utils/categorization.py
Normal file
@@ -0,0 +1,37 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
from openai import OpenAI
|
||||
from typing import List
|
||||
from dotenv import load_dotenv
|
||||
from pydantic import BaseModel
|
||||
from tenacity import retry, stop_after_attempt, wait_exponential
|
||||
from app.utils.prompts import MEMORY_CATEGORIZATION_PROMPT
|
||||
|
||||
load_dotenv()
|
||||
|
||||
openai_client = OpenAI()
|
||||
|
||||
|
||||
class MemoryCategories(BaseModel):
|
||||
categories: List[str]
|
||||
|
||||
|
||||
@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=15))
|
||||
def get_categories_for_memory(memory: str) -> List[str]:
|
||||
"""Get categories for a memory."""
|
||||
try:
|
||||
response = openai_client.responses.parse(
|
||||
model="gpt-4o-mini",
|
||||
instructions=MEMORY_CATEGORIZATION_PROMPT,
|
||||
input=memory,
|
||||
temperature=0,
|
||||
text_format=MemoryCategories,
|
||||
)
|
||||
response_json =json.loads(response.output[0].content[0].text)
|
||||
categories = response_json['categories']
|
||||
categories = [cat.strip().lower() for cat in categories]
|
||||
# TODO: Validate categories later may be
|
||||
return categories
|
||||
except Exception as e:
|
||||
raise e
|
||||
32
openmemory/api/app/utils/db.py
Normal file
32
openmemory/api/app/utils/db.py
Normal file
@@ -0,0 +1,32 @@
|
||||
from sqlalchemy.orm import Session
|
||||
from app.models import User, App
|
||||
from typing import Tuple
|
||||
|
||||
|
||||
def get_or_create_user(db: Session, user_id: str) -> User:
|
||||
"""Get or create a user with the given user_id"""
|
||||
user = db.query(User).filter(User.user_id == user_id).first()
|
||||
if not user:
|
||||
user = User(user_id=user_id)
|
||||
db.add(user)
|
||||
db.commit()
|
||||
db.refresh(user)
|
||||
return user
|
||||
|
||||
|
||||
def get_or_create_app(db: Session, user: User, app_id: str) -> App:
|
||||
"""Get or create an app for the given user"""
|
||||
app = db.query(App).filter(App.owner_id == user.id, App.name == app_id).first()
|
||||
if not app:
|
||||
app = App(owner_id=user.id, name=app_id)
|
||||
db.add(app)
|
||||
db.commit()
|
||||
db.refresh(app)
|
||||
return app
|
||||
|
||||
|
||||
def get_user_and_app(db: Session, user_id: str, app_id: str) -> Tuple[User, App]:
|
||||
"""Get or create both user and their app"""
|
||||
user = get_or_create_user(db, user_id)
|
||||
app = get_or_create_app(db, user, app_id)
|
||||
return user, app
|
||||
51
openmemory/api/app/utils/memory.py
Normal file
51
openmemory/api/app/utils/memory.py
Normal file
@@ -0,0 +1,51 @@
|
||||
import os
|
||||
|
||||
from mem0 import Memory
|
||||
|
||||
|
||||
memory_client = None
|
||||
|
||||
|
||||
def get_memory_client(custom_instructions: str = None):
|
||||
"""
|
||||
Get or initialize the Mem0 client.
|
||||
|
||||
Args:
|
||||
custom_instructions: Optional instructions for the memory project.
|
||||
|
||||
Returns:
|
||||
Initialized Mem0 client instance.
|
||||
|
||||
Raises:
|
||||
Exception: If required API keys are not set.
|
||||
"""
|
||||
global memory_client
|
||||
|
||||
if memory_client is not None:
|
||||
return memory_client
|
||||
|
||||
try:
|
||||
config = {
|
||||
"vector_store": {
|
||||
"provider": "qdrant",
|
||||
"config": {
|
||||
"collection_name": "openmemory",
|
||||
"host": "mem0_store",
|
||||
"port": 6333,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
memory_client = Memory.from_config(config_dict=config)
|
||||
except Exception:
|
||||
raise Exception("Exception occurred while initializing memory client")
|
||||
|
||||
# Update project with custom instructions if provided
|
||||
if custom_instructions:
|
||||
memory_client.update_project(custom_instructions=custom_instructions)
|
||||
|
||||
return memory_client
|
||||
|
||||
|
||||
def get_default_user_id():
|
||||
return "default_user"
|
||||
52
openmemory/api/app/utils/permissions.py
Normal file
52
openmemory/api/app/utils/permissions.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
from sqlalchemy.orm import Session
|
||||
from app.models import Memory, App, MemoryState
|
||||
|
||||
|
||||
def check_memory_access_permissions(
|
||||
db: Session,
|
||||
memory: Memory,
|
||||
app_id: Optional[UUID] = None
|
||||
) -> bool:
|
||||
"""
|
||||
Check if the given app has permission to access a memory based on:
|
||||
1. Memory state (must be active)
|
||||
2. App state (must not be paused)
|
||||
3. App-specific access controls
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
memory: Memory object to check access for
|
||||
app_id: Optional app ID to check permissions for
|
||||
|
||||
Returns:
|
||||
bool: True if access is allowed, False otherwise
|
||||
"""
|
||||
# Check if memory is active
|
||||
if memory.state != MemoryState.active:
|
||||
return False
|
||||
|
||||
# If no app_id provided, only check memory state
|
||||
if not app_id:
|
||||
return True
|
||||
|
||||
# Check if app exists and is active
|
||||
app = db.query(App).filter(App.id == app_id).first()
|
||||
if not app:
|
||||
return False
|
||||
|
||||
# Check if app is paused/inactive
|
||||
if not app.is_active:
|
||||
return False
|
||||
|
||||
# Check app-specific access controls
|
||||
from app.routers.memories import get_accessible_memory_ids
|
||||
accessible_memory_ids = get_accessible_memory_ids(db, app_id)
|
||||
|
||||
# If accessible_memory_ids is None, all memories are accessible
|
||||
if accessible_memory_ids is None:
|
||||
return True
|
||||
|
||||
# Check if memory is in the accessible set
|
||||
return memory.id in accessible_memory_ids
|
||||
28
openmemory/api/app/utils/prompts.py
Normal file
28
openmemory/api/app/utils/prompts.py
Normal file
@@ -0,0 +1,28 @@
|
||||
MEMORY_CATEGORIZATION_PROMPT = """Your task is to assign each piece of information (or “memory”) to one or more of the following categories. Feel free to use multiple categories per item when appropriate.
|
||||
|
||||
- Personal: family, friends, home, hobbies, lifestyle
|
||||
- Relationships: social network, significant others, colleagues
|
||||
- Preferences: likes, dislikes, habits, favorite media
|
||||
- Health: physical fitness, mental health, diet, sleep
|
||||
- Travel: trips, commutes, favorite places, itineraries
|
||||
- Work: job roles, companies, projects, promotions
|
||||
- Education: courses, degrees, certifications, skills development
|
||||
- Projects: to‑dos, milestones, deadlines, status updates
|
||||
- AI, ML & Technology: infrastructure, algorithms, tools, research
|
||||
- Technical Support: bug reports, error logs, fixes
|
||||
- Finance: income, expenses, investments, billing
|
||||
- Shopping: purchases, wishlists, returns, deliveries
|
||||
- Legal: contracts, policies, regulations, privacy
|
||||
- Entertainment: movies, music, games, books, events
|
||||
- Messages: emails, SMS, alerts, reminders
|
||||
- Customer Support: tickets, inquiries, resolutions
|
||||
- Product Feedback: ratings, bug reports, feature requests
|
||||
- News: articles, headlines, trending topics
|
||||
- Organization: meetings, appointments, calendars
|
||||
- Goals: ambitions, KPIs, long‑term objectives
|
||||
|
||||
Guidelines:
|
||||
- Return only the categories under 'categories' key in the JSON format.
|
||||
- If you cannot categorize the memory, return an empty list with key 'categories'.
|
||||
- Don't limit yourself to the categories listed above only. Feel free to create new categories based on the memory. Make sure that it is a single phrase.
|
||||
"""
|
||||
Reference in New Issue
Block a user