"""
Authentication and authorization for the API.

Provides bearer-token API-key authentication (regular and admin keys loaded
from environment variables) and a simple in-memory sliding-window rate
limiter, each exposed through FastAPI dependency functions.
"""
import hashlib
import hmac
import logging
import math
import os
import time
from typing import List, Optional

from fastapi import Depends, HTTPException, Security
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from starlette.status import (
    HTTP_401_UNAUTHORIZED,
    HTTP_403_FORBIDDEN,
    HTTP_429_TOO_MANY_REQUESTS,
)

logger = logging.getLogger(__name__)

# Fallback keys, used only when the corresponding environment variable is unset.
# SECURITY: these values are public (they are in source code), so anyone can
# authenticate with them -- including as admin. Always set API_KEYS and
# ADMIN_API_KEYS in any non-development deployment.
_DEFAULT_API_KEYS = "mem0_dev_key_123456789"
_DEFAULT_ADMIN_API_KEYS = "mem0_admin_key_987654321"

# With auto_error=False, HTTPBearer yields None (instead of raising its own
# error) when the Authorization header is missing or is not a Bearer token,
# so the structured 401 response in APIKeyAuth.get_api_key is reachable.
_bearer_scheme = HTTPBearer(auto_error=False)


class APIKeyAuth:
    """API key authentication handler.

    Keys are read once, at construction, from the comma-separated
    ``API_KEYS`` and ``ADMIN_API_KEYS`` environment variables. Admin keys are
    accepted anywhere a regular key is.
    """

    def __init__(self):
        self.api_keys = self._load_api_keys()
        self.admin_keys = self._load_admin_keys()
        # Not used by this class's own validation; kept as a public attribute.
        self.security = HTTPBearer()

    @staticmethod
    def _load_keys(env_var: str, default: str) -> List[str]:
        """Parse a comma-separated key list from *env_var*, dropping blanks.

        Falls back to *default* (logging a warning) only when the variable is
        unset; a set-but-empty variable yields an empty list.
        """
        keys_str = os.getenv(env_var)
        if keys_str is None:
            logger.warning(
                "%s is not set; falling back to a built-in development key. "
                "Do not run like this in production.",
                env_var,
            )
            keys_str = default
        return [key.strip() for key in keys_str.split(",") if key.strip()]

    def _load_api_keys(self) -> List[str]:
        """Load regular API keys from the ``API_KEYS`` environment variable."""
        return self._load_keys("API_KEYS", _DEFAULT_API_KEYS)

    def _load_admin_keys(self) -> List[str]:
        """Load admin API keys from the ``ADMIN_API_KEYS`` environment variable."""
        return self._load_keys("ADMIN_API_KEYS", _DEFAULT_ADMIN_API_KEYS)

    def _validate_api_key_format(self, api_key: str) -> bool:
        """Return True if *api_key* starts with ``mem0_`` and has >= 15 chars."""
        if not api_key.startswith("mem0_"):
            return False
        if len(api_key) < 15:  # "mem0_" + at least 10 chars
            return False
        return True

    @staticmethod
    def _matches_any(candidate: str, keys: List[str]) -> bool:
        """Return True if *candidate* equals any entry of *keys*.

        Every key is checked (no short-circuit) with ``hmac.compare_digest``
        so response timing does not reveal how much of a guessed key was
        correct. Values are compared as UTF-8 bytes because ``compare_digest``
        raises ``TypeError`` for ``str`` arguments containing non-ASCII chars.
        """
        candidate_bytes = candidate.encode("utf-8")
        matched = False
        for key in keys:
            matched |= hmac.compare_digest(candidate_bytes, key.encode("utf-8"))
        return matched

    def _is_valid_key(self, api_key: str) -> bool:
        """Return True if *api_key* is a configured regular or admin key."""
        is_regular_key = self._matches_any(api_key, self.api_keys)
        is_admin_key = self._matches_any(api_key, self.admin_keys)
        return is_regular_key or is_admin_key

    def _is_admin_key(self, api_key: str) -> bool:
        """Return True if *api_key* is a configured admin key."""
        return self._matches_any(api_key, self.admin_keys)

    async def get_api_key(
        self,
        credentials: Optional[HTTPAuthorizationCredentials] = Security(_bearer_scheme),
    ) -> str:
        """Extract and validate the API key from the Bearer credentials.

        Returns:
            The validated API key string.

        Raises:
            HTTPException: 401 with code ``UNAUTHORIZED`` when no Bearer
                credentials are present, ``INVALID_API_KEY_FORMAT`` when the
                key fails the format check, or ``INVALID_API_KEY`` when it is
                not a configured key.
        """
        if not credentials:
            raise HTTPException(
                status_code=HTTP_401_UNAUTHORIZED,
                detail={
                    "code": "UNAUTHORIZED",
                    "message": "Missing authorization header",
                    "details": {"required_format": "Bearer mem0_your_api_key"},
                },
            )

        api_key = credentials.credentials

        # Cheap structural check before comparing against configured keys.
        if not self._validate_api_key_format(api_key):
            raise HTTPException(
                status_code=HTTP_401_UNAUTHORIZED,
                detail={
                    "code": "INVALID_API_KEY_FORMAT",
                    "message": "Invalid API key format",
                    "details": {"expected_format": "mem0_"},
                },
            )

        if not self._is_valid_key(api_key):
            raise HTTPException(
                status_code=HTTP_401_UNAUTHORIZED,
                detail={
                    "code": "INVALID_API_KEY",
                    "message": "Invalid API key",
                    "details": {},
                },
            )

        return api_key

    # NOTE(review): inside the class body, `get_api_key` names the *unbound*
    # method above, so this Depends default cannot be resolved by FastAPI if
    # the bound method is ever used directly as a dependency. The module-level
    # `get_admin_api_key` wrapper always passes `api_key` explicitly instead.
    async def get_admin_api_key(self, api_key: str = Depends(get_api_key)) -> str:
        """Ensure an already-validated *api_key* has admin privileges.

        Returns:
            The same API key.

        Raises:
            HTTPException: 403 ``INSUFFICIENT_PERMISSIONS`` if the key is not
                an admin key.
        """
        if not self._is_admin_key(api_key):
            raise HTTPException(
                status_code=HTTP_403_FORBIDDEN,
                detail={
                    "code": "INSUFFICIENT_PERMISSIONS",
                    "message": "Admin API key required",
                    "details": {},
                },
            )
        return api_key


# Global auth instance (reads the key environment variables at import time).
auth_handler = APIKeyAuth()


# Dependency functions for FastAPI
async def get_api_key(
    credentials: Optional[HTTPAuthorizationCredentials] = Security(_bearer_scheme),
) -> str:
    """FastAPI dependency: return the request's validated API key."""
    return await auth_handler.get_api_key(credentials)


async def get_admin_api_key(api_key: str = Depends(get_api_key)) -> str:
    """FastAPI dependency: return the request's validated admin API key."""
    return await auth_handler.get_admin_api_key(api_key)


class RateLimiter:
    """Simple in-memory, per-key sliding-window rate limiter.

    State lives in this process only. Configured by ``RATE_LIMIT_REQUESTS``
    (default 100) and ``RATE_LIMIT_WINDOW_MINUTES`` (default 1).
    """

    def __init__(self):
        self.requests = {}  # {api_key: [(timestamp, count), ...]}
        self.max_requests = int(os.getenv("RATE_LIMIT_REQUESTS", "100"))
        self.window_minutes = int(os.getenv("RATE_LIMIT_WINDOW_MINUTES", "1"))
        self.window_seconds = self.window_minutes * 60

    def _cleanup_old_requests(self, api_key: str, current_time: float):
        """Drop entries for *api_key* that are older than the window."""
        if api_key not in self.requests:
            return

        cutoff_time = current_time - self.window_seconds
        self.requests[api_key] = [
            (timestamp, count)
            for timestamp, count in self.requests[api_key]
            if timestamp > cutoff_time
        ]

    def check_rate_limit(self, api_key: str) -> tuple[bool, dict]:
        """Check (and, if allowed, record) one request for *api_key*.

        Returns:
            ``(allowed, info)`` where *info* has ``limit``, ``remaining``,
            ``reset`` (epoch seconds at which the oldest in-window request
            expires, i.e. when a slot next frees up), ``retry_after`` (whole
            seconds until then, at least 1) and ``window_minutes``.
        """
        current_time = time.time()

        if api_key not in self.requests:
            self.requests[api_key] = []

        self._cleanup_old_requests(api_key, current_time)
        window = self.requests[api_key]

        current_count = sum(count for _, count in window)
        remaining = max(0, self.max_requests - current_count)

        # In a sliding window a slot frees up when the oldest recorded request
        # ages out; with an empty window, the request being recorded now is
        # the oldest. Round up so clients never retry before a slot is free.
        oldest = min((timestamp for timestamp, _ in window), default=current_time)
        reset_at = oldest + self.window_seconds

        rate_limit_info = {
            "limit": self.max_requests,
            "remaining": remaining,
            "reset": math.ceil(reset_at),
            "retry_after": max(1, math.ceil(reset_at - current_time)),
            "window_minutes": self.window_minutes,
        }

        if current_count >= self.max_requests:
            return False, rate_limit_info

        window.append((current_time, 1))
        rate_limit_info["remaining"] = remaining - 1

        return True, rate_limit_info


# Global rate limiter instance
rate_limiter = RateLimiter()


async def check_rate_limit(api_key: str = Depends(get_api_key)) -> str:
    """FastAPI dependency: authenticate, then enforce the per-key rate limit.

    Returns:
        The validated API key when the request is within the limit.

    Raises:
        HTTPException: 429 ``RATE_LIMIT_EXCEEDED`` with ``X-RateLimit-*`` and
            ``Retry-After`` headers when the limit is exhausted.
    """
    allowed, rate_info = rate_limiter.check_rate_limit(api_key)
    if allowed:
        return api_key

    retry_after = rate_info["retry_after"]
    raise HTTPException(
        status_code=HTTP_429_TOO_MANY_REQUESTS,
        detail={
            "code": "RATE_LIMIT_EXCEEDED",
            "message": f"Rate limit exceeded. Maximum {rate_info['limit']} requests per {rate_info['window_minutes']} minute(s)",
            "details": {
                "limit": rate_info["limit"],
                "reset_time": rate_info["reset"],
                "retry_after": retry_after,
            },
        },
        headers={
            "X-RateLimit-Limit": str(rate_info["limit"]),
            "X-RateLimit-Remaining": str(rate_info["remaining"]),
            "X-RateLimit-Reset": str(rate_info["reset"]),
            "Retry-After": str(retry_after),
        },
    )