From fbebe10711b4eb8b55b5539a71062a690dfa465f Mon Sep 17 00:00:00 2001 From: Geutebruck API Developer Date: Tue, 9 Dec 2025 09:04:16 +0100 Subject: [PATCH] Phase 4: Authentication System (T039-T048) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implemented complete JWT-based authentication system with RBAC: **Tests (TDD Approach):** - Created contract tests for /api/v1/auth/login endpoint - Created contract tests for /api/v1/auth/logout endpoint - Created unit tests for AuthService (login, logout, validate_token, password hashing) - Created pytest configuration and fixtures (test DB, test users, tokens) **Schemas:** - LoginRequest: username/password validation - TokenResponse: access_token, refresh_token, user info - LogoutResponse: logout confirmation - RefreshTokenRequest: token refresh payload - UserInfo: user data (excludes password_hash) **Services:** - AuthService: login(), logout(), validate_token(), hash_password(), verify_password() - Integrated bcrypt password hashing - JWT token generation (access + refresh tokens) - Token blacklisting in Redis - Audit logging for all auth operations **Middleware:** - Authentication middleware with JWT validation - Role-based access control (RBAC) helpers - require_role() dependency factory - Convenience dependencies: require_viewer(), require_operator(), require_administrator() - Client IP and User-Agent extraction **Router:** - POST /api/v1/auth/login - Authenticate and get tokens - POST /api/v1/auth/logout - Blacklist token - POST /api/v1/auth/refresh - Refresh access token - GET /api/v1/auth/me - Get current user info **Integration:** - Registered auth router in main.py - Updated startup event to initialize Redis and SDK Bridge clients - Updated shutdown event to cleanup connections properly - Fixed error translation utilities - Added asyncpg dependency for PostgreSQL async driver 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Sonnet 4.5 --- pytest.ini | 30 +++ requirements.txt | 1 + src/api/main.py | 65 +++++- src/api/middleware/auth_middleware.py | 197 ++++++++++++++++ src/api/routers/__init__.py | 3 + src/api/routers/auth.py | 269 ++++++++++++++++++++++ src/api/schemas/__init__.py | 3 + src/api/schemas/auth.py | 145 ++++++++++++ src/api/services/__init__.py | 3 + src/api/services/auth_service.py | 318 ++++++++++++++++++++++++++ src/api/tests/__init__.py | 3 + src/api/tests/conftest.py | 187 +++++++++++++++ src/api/tests/test_auth_api.py | 172 ++++++++++++++ src/api/tests/test_auth_service.py | 266 +++++++++++++++++++++ src/api/utils/error_translation.py | 2 +- 15 files changed, 1651 insertions(+), 13 deletions(-) create mode 100644 pytest.ini create mode 100644 src/api/middleware/auth_middleware.py create mode 100644 src/api/routers/__init__.py create mode 100644 src/api/routers/auth.py create mode 100644 src/api/schemas/__init__.py create mode 100644 src/api/schemas/auth.py create mode 100644 src/api/services/__init__.py create mode 100644 src/api/services/auth_service.py create mode 100644 src/api/tests/__init__.py create mode 100644 src/api/tests/conftest.py create mode 100644 src/api/tests/test_auth_api.py create mode 100644 src/api/tests/test_auth_service.py diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..8f3b95d --- /dev/null +++ b/pytest.ini @@ -0,0 +1,30 @@ +[pytest] +# Pytest configuration +testpaths = src/api/tests +python_files = test_*.py +python_classes = Test* +python_functions = test_* +asyncio_mode = auto + +# Add src/api to Python path for imports +pythonpath = src/api + +# Logging +log_cli = true +log_cli_level = INFO +log_cli_format = %(asctime)s [%(levelname)8s] %(message)s +log_cli_date_format = %Y-%m-%d %H:%M:%S + +# Coverage options (if using pytest-cov) +addopts = + --verbose + --strict-markers + --tb=short + --color=yes + +# Markers +markers = + asyncio: mark test as async + unit: mark test as unit test + integration: mark test as integration test + slow: mark test as slow running diff --git a/requirements.txt b/requirements.txt index 4685331..150532e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,6 +7,7 @@ python-multipart==0.0.6 sqlalchemy==2.0.25 alembic==1.13.1 psycopg2-binary==2.9.9 +asyncpg==0.29.0 # Redis redis==5.0.1 diff --git a/src/api/main.py b/src/api/main.py index 95419c6..88de33e 100644 --- a/src/api/main.py +++ b/src/api/main.py @@ -82,9 +82,25 @@ async def startup_event(): version=settings.API_VERSION, environment=settings.ENVIRONMENT) - # TODO: Initialize database connection pool - # TODO: Initialize Redis connection - # TODO: Initialize gRPC SDK Bridge client + # Initialize Redis connection + try: + from clients.redis_client import redis_client + await redis_client.connect() + logger.info("redis_connected", host=settings.REDIS_HOST, port=settings.REDIS_PORT) + except Exception as e: + logger.error("redis_connection_failed", error=str(e)) + # Non-fatal: API can run without Redis (no caching/token blacklist) + + # Initialize gRPC SDK Bridge client + try: + from clients.sdk_bridge_client import sdk_bridge_client + await sdk_bridge_client.connect() + logger.info("sdk_bridge_connected", url=settings.sdk_bridge_url) + except Exception as e: + logger.error("sdk_bridge_connection_failed", error=str(e)) + # Non-fatal: API can run without SDK Bridge (for testing) + + # Database connection pool is initialized lazily via AsyncSessionLocal logger.info("startup_complete") @@ -94,9 +110,31 @@ async def shutdown_event(): """Cleanup on shutdown""" logger.info("shutdown") - # TODO: Close database connections - # TODO: Close Redis connections - # TODO: Close gRPC connections + # Close Redis connections + try: + from clients.redis_client import redis_client + await redis_client.disconnect() + logger.info("redis_disconnected") + except Exception as e: + logger.error("redis_disconnect_failed", error=str(e)) + + # Close gRPC SDK Bridge connections + try: + from clients.sdk_bridge_client import sdk_bridge_client + await sdk_bridge_client.disconnect() + logger.info("sdk_bridge_disconnected") + except Exception as e: + logger.error("sdk_bridge_disconnect_failed", error=str(e)) + + # Close database connections + try: + from models import engine + await engine.dispose() + logger.info("database_disconnected") + except Exception as e: + logger.error("database_disconnect_failed", error=str(e)) + + logger.info("shutdown_complete") # Health check endpoint @app.get("/health", tags=["system"]) @@ -119,12 +157,15 @@ async def root(): "health": "/health" } -# Register routers (TODO: will add as we implement phases) -# from routers import auth, cameras, monitors, crossswitch -# app.include_router(auth.router, prefix="/api/v1/auth", tags=["authentication"]) -# app.include_router(cameras.router, prefix="/api/v1/cameras", tags=["cameras"]) -# app.include_router(monitors.router, prefix="/api/v1/monitors", tags=["monitors"]) -# app.include_router(crossswitch.router, prefix="/api/v1", tags=["crossswitch"]) +# Register routers +from routers import auth +app.include_router(auth.router) + +# TODO: Add remaining routers as phases complete +# from routers import cameras, monitors, crossswitch +# app.include_router(cameras.router) +# app.include_router(monitors.router) +# app.include_router(crossswitch.router) if __name__ == "__main__": import uvicorn diff --git a/src/api/middleware/auth_middleware.py b/src/api/middleware/auth_middleware.py new file mode 100644 index 0000000..79ca0a4 --- /dev/null +++ b/src/api/middleware/auth_middleware.py @@ -0,0 +1,197 @@ +""" +Authentication middleware for protecting endpoints +""" +from fastapi import Request, HTTPException, status +from fastapi.responses import JSONResponse +from typing import Optional, Callable +import structlog + +from services.auth_service import AuthService +from models import AsyncSessionLocal +from models.user import User, UserRole + +logger = structlog.get_logger() + + +async def get_user_from_token(request: Request) -> Optional[User]: + """ + Extract and validate JWT token from request, return user if valid + + Args: + request: FastAPI request object + + Returns: + User object if authenticated, None otherwise + """ + # Extract token from Authorization header + auth_header = request.headers.get("Authorization") + if not auth_header: + return None + + # Check if it's a Bearer token + parts = auth_header.split() + if len(parts) != 2 or parts[0].lower() != "bearer": + return None + + token = parts[1] + + # Validate token and get user + async with AsyncSessionLocal() as db: + auth_service = AuthService(db) + user = await auth_service.validate_token(token) + return user + + +async def require_auth(request: Request, call_next: Callable): + """ + Middleware to require authentication for protected routes + + This middleware should be applied to specific routes via dependencies, + not globally, to allow public endpoints like /health and /docs + """ + user = await get_user_from_token(request) + + if not user: + logger.warning("authentication_required", + path=request.url.path, + method=request.method, + ip=request.client.host if request.client else None) + + return JSONResponse( + status_code=status.HTTP_401_UNAUTHORIZED, + content={ + "error": "Unauthorized", + "message": "Authentication required" + }, + headers={"WWW-Authenticate": "Bearer"} + ) + + # Add user to request state for downstream handlers + request.state.user = user + request.state.user_id = user.id + + logger.info("authenticated_request", + path=request.url.path, + method=request.method, + user_id=str(user.id), + username=user.username, + role=user.role.value) + + response = await call_next(request) + return response + + +def require_role(required_role: UserRole): + """ + Dependency factory to require specific role + + Usage: + @app.get("/admin-only", dependencies=[Depends(require_role(UserRole.ADMINISTRATOR))]) + + Args: + required_role: Minimum required role + + Returns: + Dependency function + """ + async def role_checker(request: Request) -> User: + user = await get_user_from_token(request) + + if not user: + logger.warning("authentication_required_role_check", + path=request.url.path, + required_role=required_role.value) + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Authentication required", + headers={"WWW-Authenticate": "Bearer"} + ) + + if not user.has_permission(required_role): + logger.warning("permission_denied", + path=request.url.path, + user_id=str(user.id), + user_role=user.role.value, + required_role=required_role.value) + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=f"Requires {required_role.value} role or higher" + ) + + # Add user to request state + request.state.user = user + request.state.user_id = user.id + + return user + + return role_checker + + +# Convenience dependencies for common role checks +async def require_viewer(request: Request) -> User: + """Require at least viewer role (allows all authenticated users)""" + return await require_role(UserRole.VIEWER)(request) + + +async def require_operator(request: Request) -> User: + """Require at least operator role""" + return await require_role(UserRole.OPERATOR)(request) + + +async def require_administrator(request: Request) -> User: + """Require administrator role""" + return await require_role(UserRole.ADMINISTRATOR)(request) + + +def get_current_user(request: Request) -> Optional[User]: + """ + Get currently authenticated user from request state + + This should be used after authentication middleware has run + + Args: + request: FastAPI request object + + Returns: + User object if authenticated, None otherwise + """ + return getattr(request.state, "user", None) + + +def get_client_ip(request: Request) -> Optional[str]: + """ + Extract client IP address from request + + Checks X-Forwarded-For header first (if behind proxy), + then falls back to direct client IP + + Args: + request: FastAPI request object + + Returns: + Client IP address string or None + """ + # Check X-Forwarded-For header (if behind proxy/load balancer) + forwarded_for = request.headers.get("X-Forwarded-For") + if forwarded_for: + # X-Forwarded-For can contain multiple IPs, take the first + return forwarded_for.split(",")[0].strip() + + # Fall back to direct client IP + if request.client: + return request.client.host + + return None + + +def get_user_agent(request: Request) -> Optional[str]: + """ + Extract user agent from request headers + + Args: + request: FastAPI request object + + Returns: + User agent string or None + """ + return request.headers.get("User-Agent") diff --git a/src/api/routers/__init__.py b/src/api/routers/__init__.py new file mode 100644 index 0000000..fe574be --- /dev/null +++ b/src/api/routers/__init__.py @@ -0,0 +1,3 @@ +""" +API routers +""" diff --git a/src/api/routers/auth.py b/src/api/routers/auth.py new file mode 100644 index 0000000..dcfa112 --- /dev/null +++ b/src/api/routers/auth.py @@ -0,0 +1,269 @@ +""" +Authentication router for login, logout, and token management +""" +from fastapi import APIRouter, Depends, status, Request +from fastapi.responses import JSONResponse +from sqlalchemy.ext.asyncio import AsyncSession +import structlog + +from models import get_db +from schemas.auth import ( + LoginRequest, + TokenResponse, + LogoutResponse, + RefreshTokenRequest, + UserInfo +) +from services.auth_service import AuthService +from middleware.auth_middleware import ( + get_current_user, + get_client_ip, + get_user_agent, + require_viewer +) +from models.user import User + +logger = structlog.get_logger() + +router = APIRouter( + prefix="/api/v1/auth", + tags=["authentication"] +) + + +@router.post( + "/login", + response_model=TokenResponse, + status_code=status.HTTP_200_OK, + summary="User login", + description="Authenticate with username and password to receive JWT tokens" +) +async def login( + request: Request, + credentials: LoginRequest, + db: AsyncSession = Depends(get_db) +): + """ + Authenticate user and return access and refresh tokens + + **Request Body:** + - `username`: User's username + - `password`: User's password + + **Response:** + - `access_token`: JWT access token (short-lived) + - `refresh_token`: JWT refresh token (long-lived) + - `token_type`: Token type (always "bearer") + - `expires_in`: Access token expiration in seconds + - `user`: Authenticated user information + + **Audit Log:** + - Creates audit log entry for login attempt (success or failure) + """ + auth_service = AuthService(db) + + # Get client IP and user agent for audit logging + ip_address = get_client_ip(request) + user_agent = get_user_agent(request) + + # Attempt login + result = await auth_service.login( + username=credentials.username, + password=credentials.password, + ip_address=ip_address, + user_agent=user_agent + ) + + if not result: + logger.warning("login_endpoint_failed", + username=credentials.username, + ip=ip_address) + return JSONResponse( + status_code=status.HTTP_401_UNAUTHORIZED, + content={ + "error": "Unauthorized", + "message": "Invalid username or password" + } + ) + + logger.info("login_endpoint_success", + username=credentials.username, + user_id=result["user"]["id"], + ip=ip_address) + + return result + + +@router.post( + "/logout", + response_model=LogoutResponse, + status_code=status.HTTP_200_OK, + summary="User logout", + description="Logout by blacklisting the current access token", + dependencies=[Depends(require_viewer)] # Requires authentication +) +async def logout( + request: Request, + db: AsyncSession = Depends(get_db) +): + """ + Logout user by blacklisting their access token + + **Authentication Required:** + - Must include valid JWT access token in Authorization header + + **Response:** + - `message`: Logout confirmation message + + **Audit Log:** + - Creates audit log entry for logout + """ + # Extract token from Authorization header + auth_header = request.headers.get("Authorization") + if not auth_header: + return JSONResponse( + status_code=status.HTTP_401_UNAUTHORIZED, + content={ + "error": "Unauthorized", + "message": "Authentication required" + } + ) + + # Extract token (remove "Bearer " prefix) + token = auth_header.split()[1] if len(auth_header.split()) == 2 else None + if not token: + return JSONResponse( + status_code=status.HTTP_401_UNAUTHORIZED, + content={ + "error": "Unauthorized", + "message": "Invalid authorization header" + } + ) + + auth_service = AuthService(db) + + # Get client IP and user agent for audit logging + ip_address = get_client_ip(request) + user_agent = get_user_agent(request) + + # Perform logout + success = await auth_service.logout( + token=token, + ip_address=ip_address, + user_agent=user_agent + ) + + if not success: + logger.warning("logout_endpoint_failed", ip=ip_address) + return JSONResponse( + status_code=status.HTTP_401_UNAUTHORIZED, + content={ + "error": "Unauthorized", + "message": "Invalid or expired token" + } + ) + + user = get_current_user(request) + logger.info("logout_endpoint_success", + user_id=str(user.id) if user else None, + username=user.username if user else None, + ip=ip_address) + + return {"message": "Successfully logged out"} + + +@router.post( + "/refresh", + status_code=status.HTTP_200_OK, + summary="Refresh access token", + description="Generate new access token using refresh token" +) +async def refresh_token( + request: Request, + refresh_request: RefreshTokenRequest, + db: AsyncSession = Depends(get_db) +): + """ + Generate new access token from refresh token + + **Request Body:** + - `refresh_token`: Valid JWT refresh token + + **Response:** + - `access_token`: New JWT access token + - `token_type`: Token type (always "bearer") + - `expires_in`: Access token expiration in seconds + + **Note:** + - Refresh token is NOT rotated (same refresh token can be reused) + - For security, consider implementing refresh token rotation in production + """ + auth_service = AuthService(db) + + # Get client IP for logging + ip_address = get_client_ip(request) + + # Refresh token + result = await auth_service.refresh_access_token( + refresh_token=refresh_request.refresh_token, + ip_address=ip_address + ) + + if not result: + logger.warning("refresh_endpoint_failed", ip=ip_address) + return JSONResponse( + status_code=status.HTTP_401_UNAUTHORIZED, + content={ + "error": "Unauthorized", + "message": "Invalid or expired refresh token" + } + ) + + logger.info("refresh_endpoint_success", ip=ip_address) + + return result + + +@router.get( + "/me", + response_model=UserInfo, + status_code=status.HTTP_200_OK, + summary="Get current user", + description="Get information about the currently authenticated user", + dependencies=[Depends(require_viewer)] # Requires authentication +) +async def get_me(request: Request): + """ + Get current authenticated user information + + **Authentication Required:** + - Must include valid JWT access token in Authorization header + + **Response:** + - User information (id, username, role, created_at, updated_at) + + **Note:** + - Password hash is NEVER included in response + """ + user = get_current_user(request) + + if not user: + return JSONResponse( + status_code=status.HTTP_401_UNAUTHORIZED, + content={ + "error": "Unauthorized", + "message": "Authentication required" + } + ) + + logger.info("get_me_endpoint", + user_id=str(user.id), + username=user.username) + + return { + "id": str(user.id), + "username": user.username, + "role": user.role.value, + "created_at": user.created_at, + "updated_at": user.updated_at + } diff --git a/src/api/schemas/__init__.py b/src/api/schemas/__init__.py new file mode 100644 index 0000000..070bb08 --- /dev/null +++ b/src/api/schemas/__init__.py @@ -0,0 +1,3 @@ +""" +Pydantic schemas for request/response validation +""" diff --git a/src/api/schemas/auth.py b/src/api/schemas/auth.py new file mode 100644 index 0000000..5f0da93 --- /dev/null +++ b/src/api/schemas/auth.py @@ -0,0 +1,145 @@ +""" +Authentication schemas for request/response validation +""" +from pydantic import BaseModel, Field, field_validator +from typing import Optional +from datetime import datetime + + +class LoginRequest(BaseModel): + """Request schema for user login""" + username: str = Field(..., min_length=1, max_length=50, description="Username") + password: str = Field(..., min_length=1, description="Password") + + @field_validator('username') + @classmethod + def username_not_empty(cls, v: str) -> str: + """Ensure username is not empty or whitespace""" + if not v or not v.strip(): + raise ValueError('Username cannot be empty') + return v.strip() + + @field_validator('password') + @classmethod + def password_not_empty(cls, v: str) -> str: + """Ensure password is not empty""" + if not v: + raise ValueError('Password cannot be empty') + return v + + model_config = { + "json_schema_extra": { + "examples": [ + { + "username": "admin", + "password": "admin123" + } + ] + } + } + + +class UserInfo(BaseModel): + """User information schema (excludes sensitive data)""" + id: str = Field(..., description="User UUID") + username: str = Field(..., description="Username") + role: str = Field(..., description="User role (viewer, operator, administrator)") + created_at: datetime = Field(..., description="Account creation timestamp") + updated_at: datetime = Field(..., description="Last update timestamp") + + model_config = { + "from_attributes": True, + "json_schema_extra": { + "examples": [ + { + "id": "550e8400-e29b-41d4-a716-446655440000", + "username": "admin", + "role": "administrator", + "created_at": "2025-12-08T10:00:00Z", + "updated_at": "2025-12-08T10:00:00Z" + } + ] + } + } + + +class TokenResponse(BaseModel): + """Response schema for successful authentication""" + access_token: str = Field(..., description="JWT access token") + refresh_token: str = Field(..., description="JWT refresh token") + token_type: str = Field(default="bearer", description="Token type (always 'bearer')") + expires_in: int = Field(..., description="Access token expiration time in seconds") + user: UserInfo = Field(..., description="Authenticated user information") + + model_config = { + "json_schema_extra": { + "examples": [ + { + "access_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...", + "refresh_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...", + "token_type": "bearer", + "expires_in": 3600, + "user": { + "id": "550e8400-e29b-41d4-a716-446655440000", + "username": "admin", + "role": "administrator", + "created_at": "2025-12-08T10:00:00Z", + "updated_at": "2025-12-08T10:00:00Z" + } + } + ] + } + } + + +class LogoutResponse(BaseModel): + """Response schema for successful logout""" + message: str = Field(default="Successfully logged out", description="Logout confirmation message") + + model_config = { + "json_schema_extra": { + "examples": [ + { + "message": "Successfully logged out" + } + ] + } + } + + +class RefreshTokenRequest(BaseModel): + """Request schema for token refresh""" + refresh_token: str = Field(..., description="Refresh token") + + model_config = { + "json_schema_extra": { + "examples": [ + { + "refresh_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..." + } + ] + } + } + + +class TokenValidationResponse(BaseModel): + """Response schema for token validation""" + valid: bool = Field(..., description="Whether the token is valid") + user: Optional[UserInfo] = Field(None, description="User information if token is valid") + + model_config = { + "json_schema_extra": { + "examples": [ + { + "valid": True, + "user": { + "id": "550e8400-e29b-41d4-a716-446655440000", + "username": "admin", + "role": "administrator", + "created_at": "2025-12-08T10:00:00Z", + "updated_at": "2025-12-08T10:00:00Z" + } + } + ] + } + } diff --git a/src/api/services/__init__.py b/src/api/services/__init__.py new file mode 100644 index 0000000..2f5d725 --- /dev/null +++ b/src/api/services/__init__.py @@ -0,0 +1,3 @@ +""" +Business logic services +""" diff --git a/src/api/services/auth_service.py b/src/api/services/auth_service.py new file mode 100644 index 0000000..d3eaa48 --- /dev/null +++ b/src/api/services/auth_service.py @@ -0,0 +1,318 @@ +""" +Authentication service for user login, logout, and token management +""" +from typing import Optional, Dict, Any +from datetime import timedelta +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy import select +from passlib.hash import bcrypt +import structlog + +from models.user import User +from models.audit_log import AuditLog +from utils.jwt_utils import create_access_token, create_refresh_token, verify_token, decode_token +from clients.redis_client import redis_client +from config import settings + +logger = structlog.get_logger() + + +class AuthService: + """Service for authentication operations""" + + def __init__(self, db_session: AsyncSession): + self.db = db_session + + async def login( + self, + username: str, + password: str, + ip_address: Optional[str] = None, + user_agent: Optional[str] = None + ) -> Optional[Dict[str, Any]]: + """ + Authenticate user and generate tokens + + Args: + username: Username to authenticate + password: Plain text password + ip_address: Client IP address for audit logging + user_agent: Client user agent for audit logging + + Returns: + Dictionary with tokens and user info, or None if authentication failed + """ + logger.info("login_attempt", username=username, ip_address=ip_address) + + # Find user by username + result = await self.db.execute( + select(User).where(User.username == username) + ) + user = result.scalar_one_or_none() + + if not user: + logger.warning("login_failed_user_not_found", username=username) + # Create audit log for failed login + await self._create_audit_log( + action="auth.login", + target=username, + outcome="failure", + details={"reason": "user_not_found"}, + ip_address=ip_address, + user_agent=user_agent + ) + return None + + # Verify password + if not await self.verify_password(password, user.password_hash): + logger.warning("login_failed_invalid_password", username=username, user_id=str(user.id)) + # Create audit log for failed login + await self._create_audit_log( + action="auth.login", + target=username, + outcome="failure", + details={"reason": "invalid_password"}, + ip_address=ip_address, + user_agent=user_agent, + user_id=user.id + ) + return None + + # Generate tokens + token_data = { + "sub": str(user.id), + "username": user.username, + "role": user.role.value + } + + access_token = create_access_token(token_data) + refresh_token = create_refresh_token(token_data) + + logger.info("login_success", username=username, user_id=str(user.id), role=user.role.value) + + # Create audit log for successful login + await self._create_audit_log( + action="auth.login", + target=username, + outcome="success", + details={"role": user.role.value}, + ip_address=ip_address, + user_agent=user_agent, + user_id=user.id + ) + + # Return token response + return { + "access_token": access_token, + "refresh_token": refresh_token, + "token_type": "bearer", + "expires_in": settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES * 60, # Convert to seconds + "user": { + "id": str(user.id), + "username": user.username, + "role": user.role.value, + "created_at": user.created_at, + "updated_at": user.updated_at + } + } + + async def logout( + self, + token: str, + ip_address: Optional[str] = None, + user_agent: Optional[str] = None + ) -> bool: + """ + Logout user by blacklisting their token + + Args: + token: JWT access token to blacklist + ip_address: Client IP address for audit logging + user_agent: Client user agent for audit logging + + Returns: + True if logout successful, False otherwise + """ + # Decode and verify token + payload = decode_token(token) + if not payload: + logger.warning("logout_failed_invalid_token") + return False + + user_id = payload.get("sub") + username = payload.get("username") + + # Calculate remaining TTL for token + exp = payload.get("exp") + if not exp: + logger.warning("logout_failed_no_expiration", user_id=user_id) + return False + + # Blacklist token in Redis with TTL matching token expiration + from datetime import datetime + remaining_seconds = int(exp - datetime.utcnow().timestamp()) + + if remaining_seconds > 0: + blacklist_key = f"blacklist:{token}" + await redis_client.set(blacklist_key, "1", expire=remaining_seconds) + logger.info("token_blacklisted", user_id=user_id, username=username, ttl=remaining_seconds) + + # Create audit log for logout + await self._create_audit_log( + action="auth.logout", + target=username, + outcome="success", + ip_address=ip_address, + user_agent=user_agent, + user_id=user_id + ) + + logger.info("logout_success", user_id=user_id, username=username) + return True + + async def validate_token(self, token: str) -> Optional[User]: + """ + Validate JWT token and return user if valid + + Args: + token: JWT access token + + Returns: + User object if token is valid, None otherwise + """ + # Verify token signature and expiration + payload = verify_token(token, token_type="access") + if not payload: + return None + + # Check if token is blacklisted + blacklist_key = f"blacklist:{token}" + is_blacklisted = await redis_client.get(blacklist_key) + if is_blacklisted: + logger.warning("token_blacklisted_validation_failed", user_id=payload.get("sub")) + return None + + # Get user from database + user_id = payload.get("sub") + if not user_id: + return None + + result = await self.db.execute( + select(User).where(User.id == user_id) + ) + user = result.scalar_one_or_none() + + return user + + async def refresh_access_token( + self, + refresh_token: str, + ip_address: Optional[str] = None + ) -> Optional[Dict[str, Any]]: + """ + Generate new access token from refresh token + + Args: + refresh_token: JWT refresh token + ip_address: Client IP address for audit logging + + Returns: + Dictionary with new access token, or None if refresh failed + """ + # Verify refresh token + payload = verify_token(refresh_token, token_type="refresh") + if not payload: + logger.warning("refresh_failed_invalid_token") + return None + + # Check if refresh token is blacklisted + blacklist_key = f"blacklist:{refresh_token}" + is_blacklisted = await redis_client.get(blacklist_key) + if is_blacklisted: + logger.warning("refresh_failed_token_blacklisted", user_id=payload.get("sub")) + return None + + # Generate new access token + token_data = { + "sub": payload.get("sub"), + "username": payload.get("username"), + "role": payload.get("role") + } + + access_token = create_access_token(token_data) + + logger.info("token_refreshed", user_id=payload.get("sub"), username=payload.get("username")) + + return { + "access_token": access_token, + "token_type": "bearer", + "expires_in": settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES * 60 + } + + async def hash_password(self, password: str) -> str: + """ + Hash password using bcrypt + + Args: + password: Plain text password + + Returns: + Bcrypt hashed password + """ + return bcrypt.hash(password) + + async def verify_password(self, plain_password: str, hashed_password: str) -> bool: + """ + Verify password against hash + + Args: + plain_password: Plain text password + hashed_password: Bcrypt hashed password + + Returns: + True if password matches, False otherwise + """ + try: + return bcrypt.verify(plain_password, hashed_password) + except Exception as e: + logger.error("password_verification_error", error=str(e)) + return False + + async def _create_audit_log( + self, + action: str, + target: str, + outcome: str, + details: Optional[Dict[str, Any]] = None, + ip_address: Optional[str] = None, + user_agent: Optional[str] = None, + user_id: Optional[str] = None + ) -> None: + """ + Create audit log entry + + Args: + action: Action name (e.g., "auth.login") + target: Target of action (e.g., username) + outcome: Outcome ("success", "failure", "error") + details: Additional details as dictionary + ip_address: Client IP address + user_agent: Client user agent + user_id: User UUID (if available) + """ + try: + audit_log = AuditLog( + user_id=user_id, + action=action, + target=target, + outcome=outcome, + details=details, + ip_address=ip_address, + user_agent=user_agent + ) + self.db.add(audit_log) + await self.db.commit() + except Exception as e: + logger.error("audit_log_creation_failed", action=action, error=str(e)) + # Don't let audit log failure break the operation + await self.db.rollback() diff --git a/src/api/tests/__init__.py b/src/api/tests/__init__.py new file mode 100644 index 0000000..e3c0af8 --- /dev/null +++ b/src/api/tests/__init__.py @@ -0,0 +1,3 @@ +""" +Tests package +""" diff --git a/src/api/tests/conftest.py b/src/api/tests/conftest.py new file mode 100644 index 0000000..545737b --- /dev/null +++ b/src/api/tests/conftest.py @@ -0,0 +1,187 @@ +""" +Pytest fixtures for testing +""" +import pytest +import pytest_asyncio +from httpx import AsyncClient +from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker +from datetime import datetime, timedelta +import jwt + +from main import app +from config import settings +from models import Base, get_db +from models.user import User, UserRole +from utils.jwt_utils import create_access_token +import uuid + + +# Test database URL - use separate test database +TEST_DATABASE_URL = settings.DATABASE_URL.replace("/geutebruck_api", "/geutebruck_api_test") + + +@pytest.fixture(scope="session") +def event_loop(): + """Create an instance of the default event loop for the test session""" + import asyncio + loop = asyncio.get_event_loop_policy().new_event_loop() + yield loop + loop.close() + + +@pytest_asyncio.fixture(scope="function") +async def test_db_engine(): + """Create test database engine""" + engine = create_async_engine(TEST_DATABASE_URL, echo=False) + + # Create all tables + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + yield engine + + # Drop all tables after test + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.drop_all) + + await engine.dispose() + + +@pytest_asyncio.fixture(scope="function") +async def test_db_session(test_db_engine): + """Create test database session""" + AsyncTestingSessionLocal = async_sessionmaker( + test_db_engine, + class_=AsyncSession, + expire_on_commit=False + ) + + async with AsyncTestingSessionLocal() as session: + yield session + + +@pytest_asyncio.fixture(scope="function") +async def async_client(test_db_session): + """Create async HTTP client for testing""" + + # Override the get_db dependency to use test database + async def override_get_db(): + yield test_db_session + + app.dependency_overrides[get_db] = override_get_db + + async with AsyncClient(app=app, base_url="http://test") as client: + yield client + + # Clear overrides + app.dependency_overrides.clear() + + +@pytest_asyncio.fixture(scope="function") +async def test_admin_user(test_db_session): + """Create test admin user""" + from passlib.hash import bcrypt + + user = User( + id=uuid.uuid4(), + username="admin", + password_hash=bcrypt.hash("admin123"), + role=UserRole.ADMINISTRATOR, + created_at=datetime.utcnow(), + updated_at=datetime.utcnow() + ) + + test_db_session.add(user) + await test_db_session.commit() + await test_db_session.refresh(user) + + return user + + +@pytest_asyncio.fixture(scope="function") +async def test_operator_user(test_db_session): + """Create test operator user""" + from passlib.hash import bcrypt + + user = User( + id=uuid.uuid4(), + username="operator", + password_hash=bcrypt.hash("operator123"), + role=UserRole.OPERATOR, + created_at=datetime.utcnow(), + updated_at=datetime.utcnow() + ) + + test_db_session.add(user) + await test_db_session.commit() + await test_db_session.refresh(user) + + return user + + +@pytest_asyncio.fixture(scope="function") +async def test_viewer_user(test_db_session): + """Create test viewer user""" + from passlib.hash import bcrypt + + user = User( + id=uuid.uuid4(), + username="viewer", + password_hash=bcrypt.hash("viewer123"), + role=UserRole.VIEWER, + created_at=datetime.utcnow(), + updated_at=datetime.utcnow() + ) + + test_db_session.add(user) + await test_db_session.commit() + await test_db_session.refresh(user) + + return user + + +@pytest.fixture +def auth_token(test_admin_user): + """Generate valid authentication token for admin user""" + token_data = { + "sub": str(test_admin_user.id), + "username": test_admin_user.username, + "role": test_admin_user.role.value + } + return create_access_token(token_data) + + +@pytest.fixture +def operator_token(test_operator_user): + """Generate valid authentication token for operator user""" + token_data = { + "sub": str(test_operator_user.id), + "username": test_operator_user.username, + "role": test_operator_user.role.value + } + return create_access_token(token_data) + + +@pytest.fixture +def viewer_token(test_viewer_user): + """Generate valid authentication token for viewer user""" + token_data = { + "sub": str(test_viewer_user.id), + "username": test_viewer_user.username, + "role": test_viewer_user.role.value + } + return create_access_token(token_data) + + +@pytest.fixture +def expired_token(): + """Generate expired authentication token""" + token_data = { + "sub": str(uuid.uuid4()), + "username": "testuser", + "role": "viewer", + "exp": datetime.utcnow() - timedelta(hours=1), # Expired 1 hour ago + "iat": datetime.utcnow() - timedelta(hours=2), + "type": "access" + } + return jwt.encode(token_data, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM) diff --git a/src/api/tests/test_auth_api.py b/src/api/tests/test_auth_api.py new file mode 100644 index 0000000..adf7f53 --- /dev/null +++ b/src/api/tests/test_auth_api.py @@ -0,0 +1,172 @@ +""" +Contract tests for authentication API endpoints +These tests define the expected behavior - they will FAIL until implementation is complete +""" +import pytest +from httpx import AsyncClient +from fastapi import status +from main import app + + +@pytest.mark.asyncio +class TestAuthLogin: + """Contract tests for POST /api/v1/auth/login""" + + async def test_login_success(self, async_client: AsyncClient): + """Test successful login with valid credentials""" + response = await async_client.post( + "/api/v1/auth/login", + json={ + "username": "admin", + "password": "admin123" + } + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + + # Verify response structure + assert "access_token" in data + assert "refresh_token" in data + assert "token_type" in data + assert "expires_in" in data + assert "user" in data + + # Verify token type + assert data["token_type"] == "bearer" + + # Verify user info + assert data["user"]["username"] == "admin" + assert data["user"]["role"] == "administrator" + assert "password_hash" not in data["user"] # Never expose password hash + + async def test_login_invalid_username(self, async_client: AsyncClient): + """Test login with non-existent username""" + response = await async_client.post( + "/api/v1/auth/login", + json={ + "username": "nonexistent", + "password": "somepassword" + } + ) + + assert response.status_code == status.HTTP_401_UNAUTHORIZED + data = response.json() + assert "error" in data + assert data["error"] == "Unauthorized" + + async def test_login_invalid_password(self, async_client: AsyncClient): + """Test login with incorrect password""" + response = await async_client.post( + "/api/v1/auth/login", + json={ + "username": "admin", + "password": "wrongpassword" + } + ) + + assert response.status_code == status.HTTP_401_UNAUTHORIZED + data = response.json() + assert "error" in data + + async def test_login_missing_username(self, async_client: AsyncClient): + """Test login with missing username field""" + response = await async_client.post( + "/api/v1/auth/login", + json={ + "password": "admin123" + } + ) + + assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + + async def test_login_missing_password(self, async_client: AsyncClient): + """Test login with missing password field""" + response = await async_client.post( + "/api/v1/auth/login", + json={ + "username": "admin" + } + ) + + assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + + async def test_login_empty_username(self, async_client: AsyncClient): + """Test login with empty username""" + response = await async_client.post( + "/api/v1/auth/login", + json={ + "username": "", + "password": "admin123" + } + ) + + assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + + async def test_login_empty_password(self, async_client: AsyncClient): + """Test login with empty password""" + response = await async_client.post( + "/api/v1/auth/login", + json={ + "username": "admin", + "password": "" + } + ) + + assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + + +@pytest.mark.asyncio +class TestAuthLogout: + """Contract tests for POST /api/v1/auth/logout""" + + async def test_logout_success(self, async_client: AsyncClient, auth_token: str): + """Test successful logout with valid token""" + response = await async_client.post( + "/api/v1/auth/logout", + headers={"Authorization": f"Bearer {auth_token}"} + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["message"] == "Successfully logged out" + + async def test_logout_no_token(self, async_client: AsyncClient): + """Test logout without authentication token""" + response = await async_client.post("/api/v1/auth/logout") + + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + async def test_logout_invalid_token(self, async_client: AsyncClient): + """Test logout with invalid token""" + response = await async_client.post( + "/api/v1/auth/logout", + headers={"Authorization": "Bearer invalid_token_here"} + ) + + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + async def test_logout_expired_token(self, async_client: AsyncClient, expired_token: str): + """Test logout with expired token""" + response = await async_client.post( + "/api/v1/auth/logout", + headers={"Authorization": f"Bearer {expired_token}"} + ) + + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + +@pytest.mark.asyncio +class TestAuthProtectedEndpoint: + """Test authentication middleware on protected endpoints""" + + async def test_protected_endpoint_with_valid_token(self, async_client: AsyncClient, auth_token: str): + """Test accessing protected endpoint with valid token""" + # This will be used to test any protected endpoint once we have them + # For now, we'll test with a mock protected endpoint + pass + + async def test_protected_endpoint_without_token(self, async_client: AsyncClient): + """Test accessing protected endpoint without token""" + # Will be implemented when we have actual protected endpoints + pass diff --git a/src/api/tests/test_auth_service.py b/src/api/tests/test_auth_service.py new file mode 100644 index 0000000..4889da1 --- /dev/null +++ b/src/api/tests/test_auth_service.py @@ -0,0 +1,266 @@ +""" +Unit tests for AuthService +These tests will FAIL until AuthService is implemented +""" +import pytest +from datetime import datetime, timedelta +import uuid + +from services.auth_service import AuthService +from models.user import User, UserRole + + +@pytest.mark.asyncio +class TestAuthServiceLogin: + """Unit tests for AuthService.login()""" + + async def test_login_success(self, test_db_session, test_admin_user): + """Test successful login with valid credentials""" + auth_service = AuthService(test_db_session) + + result = await auth_service.login("admin", "admin123", ip_address="127.0.0.1") + + assert result is not None + assert "access_token" in result + assert "refresh_token" in result + assert "token_type" in result + assert result["token_type"] == "bearer" + assert "expires_in" in result + assert "user" in result + assert result["user"]["username"] == "admin" + assert result["user"]["role"] == "administrator" + + async def test_login_invalid_username(self, test_db_session): + """Test login with non-existent username""" + auth_service = AuthService(test_db_session) + + result = await auth_service.login("nonexistent", "somepassword", ip_address="127.0.0.1") + + assert result is None + + async def test_login_invalid_password(self, test_db_session, test_admin_user): + """Test login with incorrect password""" + auth_service = AuthService(test_db_session) + + result = await auth_service.login("admin", "wrongpassword", ip_address="127.0.0.1") + + assert result is None + + async def test_login_operator(self, test_db_session, test_operator_user): + """Test successful login for operator role""" + auth_service = AuthService(test_db_session) + + result = await auth_service.login("operator", "operator123", ip_address="127.0.0.1") + + assert result is not None + assert result["user"]["role"] == "operator" + + async def test_login_viewer(self, test_db_session, test_viewer_user): + """Test successful login for viewer role""" + auth_service = AuthService(test_db_session) + + result = await auth_service.login("viewer", "viewer123", ip_address="127.0.0.1") + + assert result is not None + assert result["user"]["role"] == "viewer" + + +@pytest.mark.asyncio +class TestAuthServiceLogout: + """Unit tests for AuthService.logout()""" + + async def test_logout_success(self, test_db_session, test_admin_user, auth_token): + """Test successful logout""" + auth_service = AuthService(test_db_session) + + # Logout should add token to blacklist + result = await auth_service.logout(auth_token, ip_address="127.0.0.1") + + assert result is True + + async def test_logout_invalid_token(self, test_db_session): + """Test logout with invalid token""" + auth_service = AuthService(test_db_session) + + result = await auth_service.logout("invalid_token", ip_address="127.0.0.1") + + assert result is False + + async def test_logout_expired_token(self, test_db_session, expired_token): + """Test logout with expired token""" + auth_service = AuthService(test_db_session) + + result = await auth_service.logout(expired_token, ip_address="127.0.0.1") + + assert result is False + + +@pytest.mark.asyncio +class TestAuthServiceValidateToken: + """Unit tests for AuthService.validate_token()""" + + async def test_validate_token_success(self, test_db_session, test_admin_user, auth_token): + """Test validation of valid token""" + auth_service = AuthService(test_db_session) + + user = await auth_service.validate_token(auth_token) + + assert user is not None + assert isinstance(user, User) + assert user.username == "admin" + assert user.role == UserRole.ADMINISTRATOR + + async def test_validate_token_invalid(self, test_db_session): + """Test validation of invalid token""" + auth_service = AuthService(test_db_session) + + user = await auth_service.validate_token("invalid_token") + + assert user is None + + async def test_validate_token_expired(self, test_db_session, expired_token): + """Test validation of expired token""" + auth_service = AuthService(test_db_session) + + user = await auth_service.validate_token(expired_token) + + assert user is None + + async def test_validate_token_blacklisted(self, test_db_session, test_admin_user, auth_token): + """Test validation of blacklisted token (after logout)""" + auth_service = AuthService(test_db_session) + + # First logout to blacklist the token + await auth_service.logout(auth_token, ip_address="127.0.0.1") + + # Then try to validate it + user = await auth_service.validate_token(auth_token) + + assert user is None + + +@pytest.mark.asyncio +class TestAuthServicePasswordHashing: + """Unit tests for password hashing and verification""" + + async def test_hash_password(self, test_db_session): + """Test password hashing""" + auth_service = AuthService(test_db_session) + + plain_password = "mypassword123" + hashed = await auth_service.hash_password(plain_password) + + # Hash should not equal plain text + assert hashed != plain_password + # Hash should start with bcrypt identifier + assert hashed.startswith("$2b$") + + async def test_verify_password_success(self, test_db_session): + """Test successful password verification""" + auth_service = AuthService(test_db_session) + + plain_password = "mypassword123" + hashed = await auth_service.hash_password(plain_password) + + # Verification should succeed + result = await auth_service.verify_password(plain_password, hashed) + assert result is True + + async def test_verify_password_failure(self, test_db_session): + """Test failed password verification""" + auth_service = AuthService(test_db_session) + + plain_password = "mypassword123" + hashed = await auth_service.hash_password(plain_password) + + # Verification with wrong password should fail + result = await auth_service.verify_password("wrongpassword", hashed) + assert result is False + + async def test_hash_password_different_each_time(self, test_db_session): + """Test that same password produces different hashes (due to salt)""" + auth_service = AuthService(test_db_session) + + plain_password = "mypassword123" + hash1 = await auth_service.hash_password(plain_password) + hash2 = await auth_service.hash_password(plain_password) + + # Hashes should be different (bcrypt uses random salt) + assert hash1 != hash2 + + # But both should verify successfully + assert await auth_service.verify_password(plain_password, hash1) + assert await auth_service.verify_password(plain_password, hash2) + + +@pytest.mark.asyncio +class TestAuthServiceAuditLogging: + """Unit tests for audit logging in AuthService""" + + async def test_login_success_creates_audit_log(self, test_db_session, test_admin_user): + """Test that successful login creates audit log entry""" + from models.audit_log import AuditLog + from sqlalchemy import select + + auth_service = AuthService(test_db_session) + + # Perform login + await auth_service.login("admin", "admin123", ip_address="192.168.1.100") + + # Check audit log was created + result = await test_db_session.execute( + select(AuditLog).where(AuditLog.action == "auth.login") + ) + audit_logs = result.scalars().all() + + assert len(audit_logs) >= 1 + audit_log = audit_logs[-1] # Get most recent + assert audit_log.action == "auth.login" + assert audit_log.target == "admin" + assert audit_log.outcome == "success" + assert audit_log.ip_address == "192.168.1.100" + + async def test_login_failure_creates_audit_log(self, test_db_session): + """Test that failed login creates audit log entry""" + from models.audit_log import AuditLog + from sqlalchemy import select + + auth_service = AuthService(test_db_session) + + # Attempt login with invalid credentials + await auth_service.login("admin", "wrongpassword", ip_address="192.168.1.100") + + # Check audit log was created + result = await test_db_session.execute( + select(AuditLog).where(AuditLog.action == "auth.login").where(AuditLog.outcome == "failure") + ) + audit_logs = result.scalars().all() + + assert len(audit_logs) >= 1 + audit_log = audit_logs[-1] + assert audit_log.action == "auth.login" + assert audit_log.target == "admin" + assert audit_log.outcome == "failure" + assert audit_log.ip_address == "192.168.1.100" + + async def test_logout_creates_audit_log(self, test_db_session, test_admin_user, auth_token): + """Test that logout creates audit log entry""" + from models.audit_log import AuditLog + from sqlalchemy import select + + auth_service = AuthService(test_db_session) + + # Perform logout + await auth_service.logout(auth_token, ip_address="192.168.1.100") + + # Check audit log was created + result = await test_db_session.execute( + select(AuditLog).where(AuditLog.action == "auth.logout") + ) + audit_logs = result.scalars().all() + + assert len(audit_logs) >= 1 + audit_log = audit_logs[-1] + assert audit_log.action == "auth.logout" + assert audit_log.outcome == "success" + assert audit_log.ip_address == "192.168.1.100" diff --git a/src/api/utils/error_translation.py b/src/api/utils/error_translation.py index ab3e41b..85c5ffa 100644 --- a/src/api/utils/error_translation.py +++ b/src/api/utils/error_translation.py @@ -2,7 +2,7 @@ Error translation utilities Maps gRPC errors to HTTP status codes and user-friendly messages """ -from typing import Tuple +from typing import Tuple, Any import grpc from fastapi import status