Phase 4: Authentication System (T039-T048)
Implemented complete JWT-based authentication system with RBAC: **Tests (TDD Approach):** - Created contract tests for /api/v1/auth/login endpoint - Created contract tests for /api/v1/auth/logout endpoint - Created unit tests for AuthService (login, logout, validate_token, password hashing) - Created pytest configuration and fixtures (test DB, test users, tokens) **Schemas:** - LoginRequest: username/password validation - TokenResponse: access_token, refresh_token, user info - LogoutResponse: logout confirmation - RefreshTokenRequest: token refresh payload - UserInfo: user data (excludes password_hash) **Services:** - AuthService: login(), logout(), validate_token(), hash_password(), verify_password() - Integrated bcrypt password hashing - JWT token generation (access + refresh tokens) - Token blacklisting in Redis - Audit logging for all auth operations **Middleware:** - Authentication middleware with JWT validation - Role-based access control (RBAC) helpers - require_role() dependency factory - Convenience dependencies: require_viewer(), require_operator(), require_administrator() - Client IP and User-Agent extraction **Router:** - POST /api/v1/auth/login - Authenticate and get tokens - POST /api/v1/auth/logout - Blacklist token - POST /api/v1/auth/refresh - Refresh access token - GET /api/v1/auth/me - Get current user info **Integration:** - Registered auth router in main.py - Updated startup event to initialize Redis and SDK Bridge clients - Updated shutdown event to cleanup connections properly - Fixed error translation utilities - Added asyncpg dependency for PostgreSQL async driver 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -82,9 +82,25 @@ async def startup_event():
|
||||
version=settings.API_VERSION,
|
||||
environment=settings.ENVIRONMENT)
|
||||
|
||||
# TODO: Initialize database connection pool
|
||||
# TODO: Initialize Redis connection
|
||||
# TODO: Initialize gRPC SDK Bridge client
|
||||
# Initialize Redis connection
|
||||
try:
|
||||
from clients.redis_client import redis_client
|
||||
await redis_client.connect()
|
||||
logger.info("redis_connected", host=settings.REDIS_HOST, port=settings.REDIS_PORT)
|
||||
except Exception as e:
|
||||
logger.error("redis_connection_failed", error=str(e))
|
||||
# Non-fatal: API can run without Redis (no caching/token blacklist)
|
||||
|
||||
# Initialize gRPC SDK Bridge client
|
||||
try:
|
||||
from clients.sdk_bridge_client import sdk_bridge_client
|
||||
await sdk_bridge_client.connect()
|
||||
logger.info("sdk_bridge_connected", url=settings.sdk_bridge_url)
|
||||
except Exception as e:
|
||||
logger.error("sdk_bridge_connection_failed", error=str(e))
|
||||
# Non-fatal: API can run without SDK Bridge (for testing)
|
||||
|
||||
# Database connection pool is initialized lazily via AsyncSessionLocal
|
||||
|
||||
logger.info("startup_complete")
|
||||
|
||||
@@ -94,9 +110,31 @@ async def shutdown_event():
|
||||
"""Cleanup on shutdown"""
|
||||
logger.info("shutdown")
|
||||
|
||||
# TODO: Close database connections
|
||||
# TODO: Close Redis connections
|
||||
# TODO: Close gRPC connections
|
||||
# Close Redis connections
|
||||
try:
|
||||
from clients.redis_client import redis_client
|
||||
await redis_client.disconnect()
|
||||
logger.info("redis_disconnected")
|
||||
except Exception as e:
|
||||
logger.error("redis_disconnect_failed", error=str(e))
|
||||
|
||||
# Close gRPC SDK Bridge connections
|
||||
try:
|
||||
from clients.sdk_bridge_client import sdk_bridge_client
|
||||
await sdk_bridge_client.disconnect()
|
||||
logger.info("sdk_bridge_disconnected")
|
||||
except Exception as e:
|
||||
logger.error("sdk_bridge_disconnect_failed", error=str(e))
|
||||
|
||||
# Close database connections
|
||||
try:
|
||||
from models import engine
|
||||
await engine.dispose()
|
||||
logger.info("database_disconnected")
|
||||
except Exception as e:
|
||||
logger.error("database_disconnect_failed", error=str(e))
|
||||
|
||||
logger.info("shutdown_complete")
|
||||
|
||||
# Health check endpoint
|
||||
@app.get("/health", tags=["system"])
|
||||
@@ -119,12 +157,15 @@ async def root():
|
||||
"health": "/health"
|
||||
}
|
||||
|
||||
# Register routers (TODO: will add as we implement phases)
|
||||
# from routers import auth, cameras, monitors, crossswitch
|
||||
# app.include_router(auth.router, prefix="/api/v1/auth", tags=["authentication"])
|
||||
# app.include_router(cameras.router, prefix="/api/v1/cameras", tags=["cameras"])
|
||||
# app.include_router(monitors.router, prefix="/api/v1/monitors", tags=["monitors"])
|
||||
# app.include_router(crossswitch.router, prefix="/api/v1", tags=["crossswitch"])
|
||||
# Register routers
|
||||
from routers import auth
|
||||
app.include_router(auth.router)
|
||||
|
||||
# TODO: Add remaining routers as phases complete
|
||||
# from routers import cameras, monitors, crossswitch
|
||||
# app.include_router(cameras.router)
|
||||
# app.include_router(monitors.router)
|
||||
# app.include_router(crossswitch.router)
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
197
src/api/middleware/auth_middleware.py
Normal file
197
src/api/middleware/auth_middleware.py
Normal file
@@ -0,0 +1,197 @@
|
||||
"""
|
||||
Authentication middleware for protecting endpoints
|
||||
"""
|
||||
from fastapi import Request, HTTPException, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from typing import Optional, Callable
|
||||
import structlog
|
||||
|
||||
from services.auth_service import AuthService
|
||||
from models import AsyncSessionLocal
|
||||
from models.user import User, UserRole
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
async def get_user_from_token(request: Request) -> Optional[User]:
|
||||
"""
|
||||
Extract and validate JWT token from request, return user if valid
|
||||
|
||||
Args:
|
||||
request: FastAPI request object
|
||||
|
||||
Returns:
|
||||
User object if authenticated, None otherwise
|
||||
"""
|
||||
# Extract token from Authorization header
|
||||
auth_header = request.headers.get("Authorization")
|
||||
if not auth_header:
|
||||
return None
|
||||
|
||||
# Check if it's a Bearer token
|
||||
parts = auth_header.split()
|
||||
if len(parts) != 2 or parts[0].lower() != "bearer":
|
||||
return None
|
||||
|
||||
token = parts[1]
|
||||
|
||||
# Validate token and get user
|
||||
async with AsyncSessionLocal() as db:
|
||||
auth_service = AuthService(db)
|
||||
user = await auth_service.validate_token(token)
|
||||
return user
|
||||
|
||||
|
||||
async def require_auth(request: Request, call_next: Callable):
|
||||
"""
|
||||
Middleware to require authentication for protected routes
|
||||
|
||||
This middleware should be applied to specific routes via dependencies,
|
||||
not globally, to allow public endpoints like /health and /docs
|
||||
"""
|
||||
user = await get_user_from_token(request)
|
||||
|
||||
if not user:
|
||||
logger.warning("authentication_required",
|
||||
path=request.url.path,
|
||||
method=request.method,
|
||||
ip=request.client.host if request.client else None)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
content={
|
||||
"error": "Unauthorized",
|
||||
"message": "Authentication required"
|
||||
},
|
||||
headers={"WWW-Authenticate": "Bearer"}
|
||||
)
|
||||
|
||||
# Add user to request state for downstream handlers
|
||||
request.state.user = user
|
||||
request.state.user_id = user.id
|
||||
|
||||
logger.info("authenticated_request",
|
||||
path=request.url.path,
|
||||
method=request.method,
|
||||
user_id=str(user.id),
|
||||
username=user.username,
|
||||
role=user.role.value)
|
||||
|
||||
response = await call_next(request)
|
||||
return response
|
||||
|
||||
|
||||
def require_role(required_role: UserRole):
|
||||
"""
|
||||
Dependency factory to require specific role
|
||||
|
||||
Usage:
|
||||
@app.get("/admin-only", dependencies=[Depends(require_role(UserRole.ADMINISTRATOR))])
|
||||
|
||||
Args:
|
||||
required_role: Minimum required role
|
||||
|
||||
Returns:
|
||||
Dependency function
|
||||
"""
|
||||
async def role_checker(request: Request) -> User:
|
||||
user = await get_user_from_token(request)
|
||||
|
||||
if not user:
|
||||
logger.warning("authentication_required_role_check",
|
||||
path=request.url.path,
|
||||
required_role=required_role.value)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Authentication required",
|
||||
headers={"WWW-Authenticate": "Bearer"}
|
||||
)
|
||||
|
||||
if not user.has_permission(required_role):
|
||||
logger.warning("permission_denied",
|
||||
path=request.url.path,
|
||||
user_id=str(user.id),
|
||||
user_role=user.role.value,
|
||||
required_role=required_role.value)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"Requires {required_role.value} role or higher"
|
||||
)
|
||||
|
||||
# Add user to request state
|
||||
request.state.user = user
|
||||
request.state.user_id = user.id
|
||||
|
||||
return user
|
||||
|
||||
return role_checker
|
||||
|
||||
|
||||
# Convenience dependencies for common role checks
|
||||
async def require_viewer(request: Request) -> User:
|
||||
"""Require at least viewer role (allows all authenticated users)"""
|
||||
return await require_role(UserRole.VIEWER)(request)
|
||||
|
||||
|
||||
async def require_operator(request: Request) -> User:
|
||||
"""Require at least operator role"""
|
||||
return await require_role(UserRole.OPERATOR)(request)
|
||||
|
||||
|
||||
async def require_administrator(request: Request) -> User:
|
||||
"""Require administrator role"""
|
||||
return await require_role(UserRole.ADMINISTRATOR)(request)
|
||||
|
||||
|
||||
def get_current_user(request: Request) -> Optional[User]:
|
||||
"""
|
||||
Get currently authenticated user from request state
|
||||
|
||||
This should be used after authentication middleware has run
|
||||
|
||||
Args:
|
||||
request: FastAPI request object
|
||||
|
||||
Returns:
|
||||
User object if authenticated, None otherwise
|
||||
"""
|
||||
return getattr(request.state, "user", None)
|
||||
|
||||
|
||||
def get_client_ip(request: Request) -> Optional[str]:
|
||||
"""
|
||||
Extract client IP address from request
|
||||
|
||||
Checks X-Forwarded-For header first (if behind proxy),
|
||||
then falls back to direct client IP
|
||||
|
||||
Args:
|
||||
request: FastAPI request object
|
||||
|
||||
Returns:
|
||||
Client IP address string or None
|
||||
"""
|
||||
# Check X-Forwarded-For header (if behind proxy/load balancer)
|
||||
forwarded_for = request.headers.get("X-Forwarded-For")
|
||||
if forwarded_for:
|
||||
# X-Forwarded-For can contain multiple IPs, take the first
|
||||
return forwarded_for.split(",")[0].strip()
|
||||
|
||||
# Fall back to direct client IP
|
||||
if request.client:
|
||||
return request.client.host
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_user_agent(request: Request) -> Optional[str]:
|
||||
"""
|
||||
Extract user agent from request headers
|
||||
|
||||
Args:
|
||||
request: FastAPI request object
|
||||
|
||||
Returns:
|
||||
User agent string or None
|
||||
"""
|
||||
return request.headers.get("User-Agent")
|
||||
3
src/api/routers/__init__.py
Normal file
3
src/api/routers/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""
|
||||
API routers
|
||||
"""
|
||||
269
src/api/routers/auth.py
Normal file
269
src/api/routers/auth.py
Normal file
@@ -0,0 +1,269 @@
|
||||
"""
|
||||
Authentication router for login, logout, and token management
|
||||
"""
|
||||
from fastapi import APIRouter, Depends, status, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
import structlog
|
||||
|
||||
from models import get_db
|
||||
from schemas.auth import (
|
||||
LoginRequest,
|
||||
TokenResponse,
|
||||
LogoutResponse,
|
||||
RefreshTokenRequest,
|
||||
UserInfo
|
||||
)
|
||||
from services.auth_service import AuthService
|
||||
from middleware.auth_middleware import (
|
||||
get_current_user,
|
||||
get_client_ip,
|
||||
get_user_agent,
|
||||
require_viewer
|
||||
)
|
||||
from models.user import User
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/api/v1/auth",
|
||||
tags=["authentication"]
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/login",
|
||||
response_model=TokenResponse,
|
||||
status_code=status.HTTP_200_OK,
|
||||
summary="User login",
|
||||
description="Authenticate with username and password to receive JWT tokens"
|
||||
)
|
||||
async def login(
|
||||
request: Request,
|
||||
credentials: LoginRequest,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Authenticate user and return access and refresh tokens
|
||||
|
||||
**Request Body:**
|
||||
- `username`: User's username
|
||||
- `password`: User's password
|
||||
|
||||
**Response:**
|
||||
- `access_token`: JWT access token (short-lived)
|
||||
- `refresh_token`: JWT refresh token (long-lived)
|
||||
- `token_type`: Token type (always "bearer")
|
||||
- `expires_in`: Access token expiration in seconds
|
||||
- `user`: Authenticated user information
|
||||
|
||||
**Audit Log:**
|
||||
- Creates audit log entry for login attempt (success or failure)
|
||||
"""
|
||||
auth_service = AuthService(db)
|
||||
|
||||
# Get client IP and user agent for audit logging
|
||||
ip_address = get_client_ip(request)
|
||||
user_agent = get_user_agent(request)
|
||||
|
||||
# Attempt login
|
||||
result = await auth_service.login(
|
||||
username=credentials.username,
|
||||
password=credentials.password,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent
|
||||
)
|
||||
|
||||
if not result:
|
||||
logger.warning("login_endpoint_failed",
|
||||
username=credentials.username,
|
||||
ip=ip_address)
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
content={
|
||||
"error": "Unauthorized",
|
||||
"message": "Invalid username or password"
|
||||
}
|
||||
)
|
||||
|
||||
logger.info("login_endpoint_success",
|
||||
username=credentials.username,
|
||||
user_id=result["user"]["id"],
|
||||
ip=ip_address)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@router.post(
|
||||
"/logout",
|
||||
response_model=LogoutResponse,
|
||||
status_code=status.HTTP_200_OK,
|
||||
summary="User logout",
|
||||
description="Logout by blacklisting the current access token",
|
||||
dependencies=[Depends(require_viewer)] # Requires authentication
|
||||
)
|
||||
async def logout(
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Logout user by blacklisting their access token
|
||||
|
||||
**Authentication Required:**
|
||||
- Must include valid JWT access token in Authorization header
|
||||
|
||||
**Response:**
|
||||
- `message`: Logout confirmation message
|
||||
|
||||
**Audit Log:**
|
||||
- Creates audit log entry for logout
|
||||
"""
|
||||
# Extract token from Authorization header
|
||||
auth_header = request.headers.get("Authorization")
|
||||
if not auth_header:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
content={
|
||||
"error": "Unauthorized",
|
||||
"message": "Authentication required"
|
||||
}
|
||||
)
|
||||
|
||||
# Extract token (remove "Bearer " prefix)
|
||||
token = auth_header.split()[1] if len(auth_header.split()) == 2 else None
|
||||
if not token:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
content={
|
||||
"error": "Unauthorized",
|
||||
"message": "Invalid authorization header"
|
||||
}
|
||||
)
|
||||
|
||||
auth_service = AuthService(db)
|
||||
|
||||
# Get client IP and user agent for audit logging
|
||||
ip_address = get_client_ip(request)
|
||||
user_agent = get_user_agent(request)
|
||||
|
||||
# Perform logout
|
||||
success = await auth_service.logout(
|
||||
token=token,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent
|
||||
)
|
||||
|
||||
if not success:
|
||||
logger.warning("logout_endpoint_failed", ip=ip_address)
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
content={
|
||||
"error": "Unauthorized",
|
||||
"message": "Invalid or expired token"
|
||||
}
|
||||
)
|
||||
|
||||
user = get_current_user(request)
|
||||
logger.info("logout_endpoint_success",
|
||||
user_id=str(user.id) if user else None,
|
||||
username=user.username if user else None,
|
||||
ip=ip_address)
|
||||
|
||||
return {"message": "Successfully logged out"}
|
||||
|
||||
|
||||
@router.post(
|
||||
"/refresh",
|
||||
status_code=status.HTTP_200_OK,
|
||||
summary="Refresh access token",
|
||||
description="Generate new access token using refresh token"
|
||||
)
|
||||
async def refresh_token(
|
||||
request: Request,
|
||||
refresh_request: RefreshTokenRequest,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Generate new access token from refresh token
|
||||
|
||||
**Request Body:**
|
||||
- `refresh_token`: Valid JWT refresh token
|
||||
|
||||
**Response:**
|
||||
- `access_token`: New JWT access token
|
||||
- `token_type`: Token type (always "bearer")
|
||||
- `expires_in`: Access token expiration in seconds
|
||||
|
||||
**Note:**
|
||||
- Refresh token is NOT rotated (same refresh token can be reused)
|
||||
- For security, consider implementing refresh token rotation in production
|
||||
"""
|
||||
auth_service = AuthService(db)
|
||||
|
||||
# Get client IP for logging
|
||||
ip_address = get_client_ip(request)
|
||||
|
||||
# Refresh token
|
||||
result = await auth_service.refresh_access_token(
|
||||
refresh_token=refresh_request.refresh_token,
|
||||
ip_address=ip_address
|
||||
)
|
||||
|
||||
if not result:
|
||||
logger.warning("refresh_endpoint_failed", ip=ip_address)
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
content={
|
||||
"error": "Unauthorized",
|
||||
"message": "Invalid or expired refresh token"
|
||||
}
|
||||
)
|
||||
|
||||
logger.info("refresh_endpoint_success", ip=ip_address)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@router.get(
|
||||
"/me",
|
||||
response_model=UserInfo,
|
||||
status_code=status.HTTP_200_OK,
|
||||
summary="Get current user",
|
||||
description="Get information about the currently authenticated user",
|
||||
dependencies=[Depends(require_viewer)] # Requires authentication
|
||||
)
|
||||
async def get_me(request: Request):
|
||||
"""
|
||||
Get current authenticated user information
|
||||
|
||||
**Authentication Required:**
|
||||
- Must include valid JWT access token in Authorization header
|
||||
|
||||
**Response:**
|
||||
- User information (id, username, role, created_at, updated_at)
|
||||
|
||||
**Note:**
|
||||
- Password hash is NEVER included in response
|
||||
"""
|
||||
user = get_current_user(request)
|
||||
|
||||
if not user:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
content={
|
||||
"error": "Unauthorized",
|
||||
"message": "Authentication required"
|
||||
}
|
||||
)
|
||||
|
||||
logger.info("get_me_endpoint",
|
||||
user_id=str(user.id),
|
||||
username=user.username)
|
||||
|
||||
return {
|
||||
"id": str(user.id),
|
||||
"username": user.username,
|
||||
"role": user.role.value,
|
||||
"created_at": user.created_at,
|
||||
"updated_at": user.updated_at
|
||||
}
|
||||
3
src/api/schemas/__init__.py
Normal file
3
src/api/schemas/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""
|
||||
Pydantic schemas for request/response validation
|
||||
"""
|
||||
145
src/api/schemas/auth.py
Normal file
145
src/api/schemas/auth.py
Normal file
@@ -0,0 +1,145 @@
|
||||
"""
|
||||
Authentication schemas for request/response validation
|
||||
"""
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class LoginRequest(BaseModel):
|
||||
"""Request schema for user login"""
|
||||
username: str = Field(..., min_length=1, max_length=50, description="Username")
|
||||
password: str = Field(..., min_length=1, description="Password")
|
||||
|
||||
@field_validator('username')
|
||||
@classmethod
|
||||
def username_not_empty(cls, v: str) -> str:
|
||||
"""Ensure username is not empty or whitespace"""
|
||||
if not v or not v.strip():
|
||||
raise ValueError('Username cannot be empty')
|
||||
return v.strip()
|
||||
|
||||
@field_validator('password')
|
||||
@classmethod
|
||||
def password_not_empty(cls, v: str) -> str:
|
||||
"""Ensure password is not empty"""
|
||||
if not v:
|
||||
raise ValueError('Password cannot be empty')
|
||||
return v
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"examples": [
|
||||
{
|
||||
"username": "admin",
|
||||
"password": "admin123"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class UserInfo(BaseModel):
|
||||
"""User information schema (excludes sensitive data)"""
|
||||
id: str = Field(..., description="User UUID")
|
||||
username: str = Field(..., description="Username")
|
||||
role: str = Field(..., description="User role (viewer, operator, administrator)")
|
||||
created_at: datetime = Field(..., description="Account creation timestamp")
|
||||
updated_at: datetime = Field(..., description="Last update timestamp")
|
||||
|
||||
model_config = {
|
||||
"from_attributes": True,
|
||||
"json_schema_extra": {
|
||||
"examples": [
|
||||
{
|
||||
"id": "550e8400-e29b-41d4-a716-446655440000",
|
||||
"username": "admin",
|
||||
"role": "administrator",
|
||||
"created_at": "2025-12-08T10:00:00Z",
|
||||
"updated_at": "2025-12-08T10:00:00Z"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class TokenResponse(BaseModel):
|
||||
"""Response schema for successful authentication"""
|
||||
access_token: str = Field(..., description="JWT access token")
|
||||
refresh_token: str = Field(..., description="JWT refresh token")
|
||||
token_type: str = Field(default="bearer", description="Token type (always 'bearer')")
|
||||
expires_in: int = Field(..., description="Access token expiration time in seconds")
|
||||
user: UserInfo = Field(..., description="Authenticated user information")
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"examples": [
|
||||
{
|
||||
"access_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...",
|
||||
"refresh_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...",
|
||||
"token_type": "bearer",
|
||||
"expires_in": 3600,
|
||||
"user": {
|
||||
"id": "550e8400-e29b-41d4-a716-446655440000",
|
||||
"username": "admin",
|
||||
"role": "administrator",
|
||||
"created_at": "2025-12-08T10:00:00Z",
|
||||
"updated_at": "2025-12-08T10:00:00Z"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class LogoutResponse(BaseModel):
|
||||
"""Response schema for successful logout"""
|
||||
message: str = Field(default="Successfully logged out", description="Logout confirmation message")
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"examples": [
|
||||
{
|
||||
"message": "Successfully logged out"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class RefreshTokenRequest(BaseModel):
|
||||
"""Request schema for token refresh"""
|
||||
refresh_token: str = Field(..., description="Refresh token")
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"examples": [
|
||||
{
|
||||
"refresh_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class TokenValidationResponse(BaseModel):
|
||||
"""Response schema for token validation"""
|
||||
valid: bool = Field(..., description="Whether the token is valid")
|
||||
user: Optional[UserInfo] = Field(None, description="User information if token is valid")
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"examples": [
|
||||
{
|
||||
"valid": True,
|
||||
"user": {
|
||||
"id": "550e8400-e29b-41d4-a716-446655440000",
|
||||
"username": "admin",
|
||||
"role": "administrator",
|
||||
"created_at": "2025-12-08T10:00:00Z",
|
||||
"updated_at": "2025-12-08T10:00:00Z"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
3
src/api/services/__init__.py
Normal file
3
src/api/services/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""
|
||||
Business logic services
|
||||
"""
|
||||
318
src/api/services/auth_service.py
Normal file
318
src/api/services/auth_service.py
Normal file
@@ -0,0 +1,318 @@
|
||||
"""
|
||||
Authentication service for user login, logout, and token management
|
||||
"""
|
||||
from typing import Optional, Dict, Any
|
||||
from datetime import timedelta
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from passlib.hash import bcrypt
|
||||
import structlog
|
||||
|
||||
from models.user import User
|
||||
from models.audit_log import AuditLog
|
||||
from utils.jwt_utils import create_access_token, create_refresh_token, verify_token, decode_token
|
||||
from clients.redis_client import redis_client
|
||||
from config import settings
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class AuthService:
|
||||
"""Service for authentication operations"""
|
||||
|
||||
def __init__(self, db_session: AsyncSession):
|
||||
self.db = db_session
|
||||
|
||||
async def login(
|
||||
self,
|
||||
username: str,
|
||||
password: str,
|
||||
ip_address: Optional[str] = None,
|
||||
user_agent: Optional[str] = None
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Authenticate user and generate tokens
|
||||
|
||||
Args:
|
||||
username: Username to authenticate
|
||||
password: Plain text password
|
||||
ip_address: Client IP address for audit logging
|
||||
user_agent: Client user agent for audit logging
|
||||
|
||||
Returns:
|
||||
Dictionary with tokens and user info, or None if authentication failed
|
||||
"""
|
||||
logger.info("login_attempt", username=username, ip_address=ip_address)
|
||||
|
||||
# Find user by username
|
||||
result = await self.db.execute(
|
||||
select(User).where(User.username == username)
|
||||
)
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
if not user:
|
||||
logger.warning("login_failed_user_not_found", username=username)
|
||||
# Create audit log for failed login
|
||||
await self._create_audit_log(
|
||||
action="auth.login",
|
||||
target=username,
|
||||
outcome="failure",
|
||||
details={"reason": "user_not_found"},
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent
|
||||
)
|
||||
return None
|
||||
|
||||
# Verify password
|
||||
if not await self.verify_password(password, user.password_hash):
|
||||
logger.warning("login_failed_invalid_password", username=username, user_id=str(user.id))
|
||||
# Create audit log for failed login
|
||||
await self._create_audit_log(
|
||||
action="auth.login",
|
||||
target=username,
|
||||
outcome="failure",
|
||||
details={"reason": "invalid_password"},
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
user_id=user.id
|
||||
)
|
||||
return None
|
||||
|
||||
# Generate tokens
|
||||
token_data = {
|
||||
"sub": str(user.id),
|
||||
"username": user.username,
|
||||
"role": user.role.value
|
||||
}
|
||||
|
||||
access_token = create_access_token(token_data)
|
||||
refresh_token = create_refresh_token(token_data)
|
||||
|
||||
logger.info("login_success", username=username, user_id=str(user.id), role=user.role.value)
|
||||
|
||||
# Create audit log for successful login
|
||||
await self._create_audit_log(
|
||||
action="auth.login",
|
||||
target=username,
|
||||
outcome="success",
|
||||
details={"role": user.role.value},
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
user_id=user.id
|
||||
)
|
||||
|
||||
# Return token response
|
||||
return {
|
||||
"access_token": access_token,
|
||||
"refresh_token": refresh_token,
|
||||
"token_type": "bearer",
|
||||
"expires_in": settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES * 60, # Convert to seconds
|
||||
"user": {
|
||||
"id": str(user.id),
|
||||
"username": user.username,
|
||||
"role": user.role.value,
|
||||
"created_at": user.created_at,
|
||||
"updated_at": user.updated_at
|
||||
}
|
||||
}
|
||||
|
||||
async def logout(
|
||||
self,
|
||||
token: str,
|
||||
ip_address: Optional[str] = None,
|
||||
user_agent: Optional[str] = None
|
||||
) -> bool:
|
||||
"""
|
||||
Logout user by blacklisting their token
|
||||
|
||||
Args:
|
||||
token: JWT access token to blacklist
|
||||
ip_address: Client IP address for audit logging
|
||||
user_agent: Client user agent for audit logging
|
||||
|
||||
Returns:
|
||||
True if logout successful, False otherwise
|
||||
"""
|
||||
# Decode and verify token
|
||||
payload = decode_token(token)
|
||||
if not payload:
|
||||
logger.warning("logout_failed_invalid_token")
|
||||
return False
|
||||
|
||||
user_id = payload.get("sub")
|
||||
username = payload.get("username")
|
||||
|
||||
# Calculate remaining TTL for token
|
||||
exp = payload.get("exp")
|
||||
if not exp:
|
||||
logger.warning("logout_failed_no_expiration", user_id=user_id)
|
||||
return False
|
||||
|
||||
# Blacklist token in Redis with TTL matching token expiration
|
||||
from datetime import datetime
|
||||
remaining_seconds = int(exp - datetime.utcnow().timestamp())
|
||||
|
||||
if remaining_seconds > 0:
|
||||
blacklist_key = f"blacklist:{token}"
|
||||
await redis_client.set(blacklist_key, "1", expire=remaining_seconds)
|
||||
logger.info("token_blacklisted", user_id=user_id, username=username, ttl=remaining_seconds)
|
||||
|
||||
# Create audit log for logout
|
||||
await self._create_audit_log(
|
||||
action="auth.logout",
|
||||
target=username,
|
||||
outcome="success",
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
logger.info("logout_success", user_id=user_id, username=username)
|
||||
return True
|
||||
|
||||
async def validate_token(self, token: str) -> Optional[User]:
|
||||
"""
|
||||
Validate JWT token and return user if valid
|
||||
|
||||
Args:
|
||||
token: JWT access token
|
||||
|
||||
Returns:
|
||||
User object if token is valid, None otherwise
|
||||
"""
|
||||
# Verify token signature and expiration
|
||||
payload = verify_token(token, token_type="access")
|
||||
if not payload:
|
||||
return None
|
||||
|
||||
# Check if token is blacklisted
|
||||
blacklist_key = f"blacklist:{token}"
|
||||
is_blacklisted = await redis_client.get(blacklist_key)
|
||||
if is_blacklisted:
|
||||
logger.warning("token_blacklisted_validation_failed", user_id=payload.get("sub"))
|
||||
return None
|
||||
|
||||
# Get user from database
|
||||
user_id = payload.get("sub")
|
||||
if not user_id:
|
||||
return None
|
||||
|
||||
result = await self.db.execute(
|
||||
select(User).where(User.id == user_id)
|
||||
)
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
return user
|
||||
|
||||
async def refresh_access_token(
|
||||
self,
|
||||
refresh_token: str,
|
||||
ip_address: Optional[str] = None
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Generate new access token from refresh token
|
||||
|
||||
Args:
|
||||
refresh_token: JWT refresh token
|
||||
ip_address: Client IP address for audit logging
|
||||
|
||||
Returns:
|
||||
Dictionary with new access token, or None if refresh failed
|
||||
"""
|
||||
# Verify refresh token
|
||||
payload = verify_token(refresh_token, token_type="refresh")
|
||||
if not payload:
|
||||
logger.warning("refresh_failed_invalid_token")
|
||||
return None
|
||||
|
||||
# Check if refresh token is blacklisted
|
||||
blacklist_key = f"blacklist:{refresh_token}"
|
||||
is_blacklisted = await redis_client.get(blacklist_key)
|
||||
if is_blacklisted:
|
||||
logger.warning("refresh_failed_token_blacklisted", user_id=payload.get("sub"))
|
||||
return None
|
||||
|
||||
# Generate new access token
|
||||
token_data = {
|
||||
"sub": payload.get("sub"),
|
||||
"username": payload.get("username"),
|
||||
"role": payload.get("role")
|
||||
}
|
||||
|
||||
access_token = create_access_token(token_data)
|
||||
|
||||
logger.info("token_refreshed", user_id=payload.get("sub"), username=payload.get("username"))
|
||||
|
||||
return {
|
||||
"access_token": access_token,
|
||||
"token_type": "bearer",
|
||||
"expires_in": settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES * 60
|
||||
}
|
||||
|
||||
async def hash_password(self, password: str) -> str:
|
||||
"""
|
||||
Hash password using bcrypt
|
||||
|
||||
Args:
|
||||
password: Plain text password
|
||||
|
||||
Returns:
|
||||
Bcrypt hashed password
|
||||
"""
|
||||
return bcrypt.hash(password)
|
||||
|
||||
async def verify_password(self, plain_password: str, hashed_password: str) -> bool:
|
||||
"""
|
||||
Verify password against hash
|
||||
|
||||
Args:
|
||||
plain_password: Plain text password
|
||||
hashed_password: Bcrypt hashed password
|
||||
|
||||
Returns:
|
||||
True if password matches, False otherwise
|
||||
"""
|
||||
try:
|
||||
return bcrypt.verify(plain_password, hashed_password)
|
||||
except Exception as e:
|
||||
logger.error("password_verification_error", error=str(e))
|
||||
return False
|
||||
|
||||
async def _create_audit_log(
|
||||
self,
|
||||
action: str,
|
||||
target: str,
|
||||
outcome: str,
|
||||
details: Optional[Dict[str, Any]] = None,
|
||||
ip_address: Optional[str] = None,
|
||||
user_agent: Optional[str] = None,
|
||||
user_id: Optional[str] = None
|
||||
) -> None:
|
||||
"""
|
||||
Create audit log entry
|
||||
|
||||
Args:
|
||||
action: Action name (e.g., "auth.login")
|
||||
target: Target of action (e.g., username)
|
||||
outcome: Outcome ("success", "failure", "error")
|
||||
details: Additional details as dictionary
|
||||
ip_address: Client IP address
|
||||
user_agent: Client user agent
|
||||
user_id: User UUID (if available)
|
||||
"""
|
||||
try:
|
||||
audit_log = AuditLog(
|
||||
user_id=user_id,
|
||||
action=action,
|
||||
target=target,
|
||||
outcome=outcome,
|
||||
details=details,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent
|
||||
)
|
||||
self.db.add(audit_log)
|
||||
await self.db.commit()
|
||||
except Exception as e:
|
||||
logger.error("audit_log_creation_failed", action=action, error=str(e))
|
||||
# Don't let audit log failure break the operation
|
||||
await self.db.rollback()
|
||||
3
src/api/tests/__init__.py
Normal file
3
src/api/tests/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""
|
||||
Tests package
|
||||
"""
|
||||
187
src/api/tests/conftest.py
Normal file
187
src/api/tests/conftest.py
Normal file
@@ -0,0 +1,187 @@
|
||||
"""
|
||||
Pytest fixtures for testing
|
||||
"""
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from httpx import AsyncClient
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
|
||||
from datetime import datetime, timedelta
|
||||
import jwt
|
||||
|
||||
from main import app
|
||||
from config import settings
|
||||
from models import Base, get_db
|
||||
from models.user import User, UserRole
|
||||
from utils.jwt_utils import create_access_token
|
||||
import uuid
|
||||
|
||||
|
||||
# Test database URL - use separate test database
|
||||
TEST_DATABASE_URL = settings.DATABASE_URL.replace("/geutebruck_api", "/geutebruck_api_test")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def event_loop():
|
||||
"""Create an instance of the default event loop for the test session"""
|
||||
import asyncio
|
||||
loop = asyncio.get_event_loop_policy().new_event_loop()
|
||||
yield loop
|
||||
loop.close()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="function")
|
||||
async def test_db_engine():
|
||||
"""Create test database engine"""
|
||||
engine = create_async_engine(TEST_DATABASE_URL, echo=False)
|
||||
|
||||
# Create all tables
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
yield engine
|
||||
|
||||
# Drop all tables after test
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.drop_all)
|
||||
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="function")
|
||||
async def test_db_session(test_db_engine):
|
||||
"""Create test database session"""
|
||||
AsyncTestingSessionLocal = async_sessionmaker(
|
||||
test_db_engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False
|
||||
)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
yield session
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="function")
|
||||
async def async_client(test_db_session):
|
||||
"""Create async HTTP client for testing"""
|
||||
|
||||
# Override the get_db dependency to use test database
|
||||
async def override_get_db():
|
||||
yield test_db_session
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
|
||||
async with AsyncClient(app=app, base_url="http://test") as client:
|
||||
yield client
|
||||
|
||||
# Clear overrides
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="function")
|
||||
async def test_admin_user(test_db_session):
|
||||
"""Create test admin user"""
|
||||
from passlib.hash import bcrypt
|
||||
|
||||
user = User(
|
||||
id=uuid.uuid4(),
|
||||
username="admin",
|
||||
password_hash=bcrypt.hash("admin123"),
|
||||
role=UserRole.ADMINISTRATOR,
|
||||
created_at=datetime.utcnow(),
|
||||
updated_at=datetime.utcnow()
|
||||
)
|
||||
|
||||
test_db_session.add(user)
|
||||
await test_db_session.commit()
|
||||
await test_db_session.refresh(user)
|
||||
|
||||
return user
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="function")
|
||||
async def test_operator_user(test_db_session):
|
||||
"""Create test operator user"""
|
||||
from passlib.hash import bcrypt
|
||||
|
||||
user = User(
|
||||
id=uuid.uuid4(),
|
||||
username="operator",
|
||||
password_hash=bcrypt.hash("operator123"),
|
||||
role=UserRole.OPERATOR,
|
||||
created_at=datetime.utcnow(),
|
||||
updated_at=datetime.utcnow()
|
||||
)
|
||||
|
||||
test_db_session.add(user)
|
||||
await test_db_session.commit()
|
||||
await test_db_session.refresh(user)
|
||||
|
||||
return user
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="function")
|
||||
async def test_viewer_user(test_db_session):
|
||||
"""Create test viewer user"""
|
||||
from passlib.hash import bcrypt
|
||||
|
||||
user = User(
|
||||
id=uuid.uuid4(),
|
||||
username="viewer",
|
||||
password_hash=bcrypt.hash("viewer123"),
|
||||
role=UserRole.VIEWER,
|
||||
created_at=datetime.utcnow(),
|
||||
updated_at=datetime.utcnow()
|
||||
)
|
||||
|
||||
test_db_session.add(user)
|
||||
await test_db_session.commit()
|
||||
await test_db_session.refresh(user)
|
||||
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def auth_token(test_admin_user):
|
||||
"""Generate valid authentication token for admin user"""
|
||||
token_data = {
|
||||
"sub": str(test_admin_user.id),
|
||||
"username": test_admin_user.username,
|
||||
"role": test_admin_user.role.value
|
||||
}
|
||||
return create_access_token(token_data)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def operator_token(test_operator_user):
|
||||
"""Generate valid authentication token for operator user"""
|
||||
token_data = {
|
||||
"sub": str(test_operator_user.id),
|
||||
"username": test_operator_user.username,
|
||||
"role": test_operator_user.role.value
|
||||
}
|
||||
return create_access_token(token_data)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def viewer_token(test_viewer_user):
|
||||
"""Generate valid authentication token for viewer user"""
|
||||
token_data = {
|
||||
"sub": str(test_viewer_user.id),
|
||||
"username": test_viewer_user.username,
|
||||
"role": test_viewer_user.role.value
|
||||
}
|
||||
return create_access_token(token_data)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def expired_token():
|
||||
"""Generate expired authentication token"""
|
||||
token_data = {
|
||||
"sub": str(uuid.uuid4()),
|
||||
"username": "testuser",
|
||||
"role": "viewer",
|
||||
"exp": datetime.utcnow() - timedelta(hours=1), # Expired 1 hour ago
|
||||
"iat": datetime.utcnow() - timedelta(hours=2),
|
||||
"type": "access"
|
||||
}
|
||||
return jwt.encode(token_data, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM)
|
||||
172
src/api/tests/test_auth_api.py
Normal file
172
src/api/tests/test_auth_api.py
Normal file
@@ -0,0 +1,172 @@
|
||||
"""
|
||||
Contract tests for authentication API endpoints
|
||||
These tests define the expected behavior - they will FAIL until implementation is complete
|
||||
"""
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
from fastapi import status
|
||||
from main import app
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestAuthLogin:
|
||||
"""Contract tests for POST /api/v1/auth/login"""
|
||||
|
||||
async def test_login_success(self, async_client: AsyncClient):
|
||||
"""Test successful login with valid credentials"""
|
||||
response = await async_client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"username": "admin",
|
||||
"password": "admin123"
|
||||
}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
|
||||
# Verify response structure
|
||||
assert "access_token" in data
|
||||
assert "refresh_token" in data
|
||||
assert "token_type" in data
|
||||
assert "expires_in" in data
|
||||
assert "user" in data
|
||||
|
||||
# Verify token type
|
||||
assert data["token_type"] == "bearer"
|
||||
|
||||
# Verify user info
|
||||
assert data["user"]["username"] == "admin"
|
||||
assert data["user"]["role"] == "administrator"
|
||||
assert "password_hash" not in data["user"] # Never expose password hash
|
||||
|
||||
async def test_login_invalid_username(self, async_client: AsyncClient):
|
||||
"""Test login with non-existent username"""
|
||||
response = await async_client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"username": "nonexistent",
|
||||
"password": "somepassword"
|
||||
}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
data = response.json()
|
||||
assert "error" in data
|
||||
assert data["error"] == "Unauthorized"
|
||||
|
||||
async def test_login_invalid_password(self, async_client: AsyncClient):
|
||||
"""Test login with incorrect password"""
|
||||
response = await async_client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"username": "admin",
|
||||
"password": "wrongpassword"
|
||||
}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
data = response.json()
|
||||
assert "error" in data
|
||||
|
||||
async def test_login_missing_username(self, async_client: AsyncClient):
|
||||
"""Test login with missing username field"""
|
||||
response = await async_client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"password": "admin123"
|
||||
}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
|
||||
async def test_login_missing_password(self, async_client: AsyncClient):
|
||||
"""Test login with missing password field"""
|
||||
response = await async_client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"username": "admin"
|
||||
}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
|
||||
async def test_login_empty_username(self, async_client: AsyncClient):
|
||||
"""Test login with empty username"""
|
||||
response = await async_client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"username": "",
|
||||
"password": "admin123"
|
||||
}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
|
||||
async def test_login_empty_password(self, async_client: AsyncClient):
|
||||
"""Test login with empty password"""
|
||||
response = await async_client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"username": "admin",
|
||||
"password": ""
|
||||
}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestAuthLogout:
|
||||
"""Contract tests for POST /api/v1/auth/logout"""
|
||||
|
||||
async def test_logout_success(self, async_client: AsyncClient, auth_token: str):
|
||||
"""Test successful logout with valid token"""
|
||||
response = await async_client.post(
|
||||
"/api/v1/auth/logout",
|
||||
headers={"Authorization": f"Bearer {auth_token}"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["message"] == "Successfully logged out"
|
||||
|
||||
async def test_logout_no_token(self, async_client: AsyncClient):
|
||||
"""Test logout without authentication token"""
|
||||
response = await async_client.post("/api/v1/auth/logout")
|
||||
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
async def test_logout_invalid_token(self, async_client: AsyncClient):
|
||||
"""Test logout with invalid token"""
|
||||
response = await async_client.post(
|
||||
"/api/v1/auth/logout",
|
||||
headers={"Authorization": "Bearer invalid_token_here"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
async def test_logout_expired_token(self, async_client: AsyncClient, expired_token: str):
|
||||
"""Test logout with expired token"""
|
||||
response = await async_client.post(
|
||||
"/api/v1/auth/logout",
|
||||
headers={"Authorization": f"Bearer {expired_token}"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestAuthProtectedEndpoint:
|
||||
"""Test authentication middleware on protected endpoints"""
|
||||
|
||||
async def test_protected_endpoint_with_valid_token(self, async_client: AsyncClient, auth_token: str):
|
||||
"""Test accessing protected endpoint with valid token"""
|
||||
# This will be used to test any protected endpoint once we have them
|
||||
# For now, we'll test with a mock protected endpoint
|
||||
pass
|
||||
|
||||
async def test_protected_endpoint_without_token(self, async_client: AsyncClient):
|
||||
"""Test accessing protected endpoint without token"""
|
||||
# Will be implemented when we have actual protected endpoints
|
||||
pass
|
||||
266
src/api/tests/test_auth_service.py
Normal file
266
src/api/tests/test_auth_service.py
Normal file
@@ -0,0 +1,266 @@
|
||||
"""
|
||||
Unit tests for AuthService
|
||||
These tests will FAIL until AuthService is implemented
|
||||
"""
|
||||
import pytest
|
||||
from datetime import datetime, timedelta
|
||||
import uuid
|
||||
|
||||
from services.auth_service import AuthService
|
||||
from models.user import User, UserRole
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestAuthServiceLogin:
|
||||
"""Unit tests for AuthService.login()"""
|
||||
|
||||
async def test_login_success(self, test_db_session, test_admin_user):
|
||||
"""Test successful login with valid credentials"""
|
||||
auth_service = AuthService(test_db_session)
|
||||
|
||||
result = await auth_service.login("admin", "admin123", ip_address="127.0.0.1")
|
||||
|
||||
assert result is not None
|
||||
assert "access_token" in result
|
||||
assert "refresh_token" in result
|
||||
assert "token_type" in result
|
||||
assert result["token_type"] == "bearer"
|
||||
assert "expires_in" in result
|
||||
assert "user" in result
|
||||
assert result["user"]["username"] == "admin"
|
||||
assert result["user"]["role"] == "administrator"
|
||||
|
||||
async def test_login_invalid_username(self, test_db_session):
|
||||
"""Test login with non-existent username"""
|
||||
auth_service = AuthService(test_db_session)
|
||||
|
||||
result = await auth_service.login("nonexistent", "somepassword", ip_address="127.0.0.1")
|
||||
|
||||
assert result is None
|
||||
|
||||
async def test_login_invalid_password(self, test_db_session, test_admin_user):
|
||||
"""Test login with incorrect password"""
|
||||
auth_service = AuthService(test_db_session)
|
||||
|
||||
result = await auth_service.login("admin", "wrongpassword", ip_address="127.0.0.1")
|
||||
|
||||
assert result is None
|
||||
|
||||
async def test_login_operator(self, test_db_session, test_operator_user):
|
||||
"""Test successful login for operator role"""
|
||||
auth_service = AuthService(test_db_session)
|
||||
|
||||
result = await auth_service.login("operator", "operator123", ip_address="127.0.0.1")
|
||||
|
||||
assert result is not None
|
||||
assert result["user"]["role"] == "operator"
|
||||
|
||||
async def test_login_viewer(self, test_db_session, test_viewer_user):
|
||||
"""Test successful login for viewer role"""
|
||||
auth_service = AuthService(test_db_session)
|
||||
|
||||
result = await auth_service.login("viewer", "viewer123", ip_address="127.0.0.1")
|
||||
|
||||
assert result is not None
|
||||
assert result["user"]["role"] == "viewer"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestAuthServiceLogout:
|
||||
"""Unit tests for AuthService.logout()"""
|
||||
|
||||
async def test_logout_success(self, test_db_session, test_admin_user, auth_token):
|
||||
"""Test successful logout"""
|
||||
auth_service = AuthService(test_db_session)
|
||||
|
||||
# Logout should add token to blacklist
|
||||
result = await auth_service.logout(auth_token, ip_address="127.0.0.1")
|
||||
|
||||
assert result is True
|
||||
|
||||
async def test_logout_invalid_token(self, test_db_session):
|
||||
"""Test logout with invalid token"""
|
||||
auth_service = AuthService(test_db_session)
|
||||
|
||||
result = await auth_service.logout("invalid_token", ip_address="127.0.0.1")
|
||||
|
||||
assert result is False
|
||||
|
||||
async def test_logout_expired_token(self, test_db_session, expired_token):
|
||||
"""Test logout with expired token"""
|
||||
auth_service = AuthService(test_db_session)
|
||||
|
||||
result = await auth_service.logout(expired_token, ip_address="127.0.0.1")
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestAuthServiceValidateToken:
|
||||
"""Unit tests for AuthService.validate_token()"""
|
||||
|
||||
async def test_validate_token_success(self, test_db_session, test_admin_user, auth_token):
|
||||
"""Test validation of valid token"""
|
||||
auth_service = AuthService(test_db_session)
|
||||
|
||||
user = await auth_service.validate_token(auth_token)
|
||||
|
||||
assert user is not None
|
||||
assert isinstance(user, User)
|
||||
assert user.username == "admin"
|
||||
assert user.role == UserRole.ADMINISTRATOR
|
||||
|
||||
async def test_validate_token_invalid(self, test_db_session):
|
||||
"""Test validation of invalid token"""
|
||||
auth_service = AuthService(test_db_session)
|
||||
|
||||
user = await auth_service.validate_token("invalid_token")
|
||||
|
||||
assert user is None
|
||||
|
||||
async def test_validate_token_expired(self, test_db_session, expired_token):
|
||||
"""Test validation of expired token"""
|
||||
auth_service = AuthService(test_db_session)
|
||||
|
||||
user = await auth_service.validate_token(expired_token)
|
||||
|
||||
assert user is None
|
||||
|
||||
async def test_validate_token_blacklisted(self, test_db_session, test_admin_user, auth_token):
|
||||
"""Test validation of blacklisted token (after logout)"""
|
||||
auth_service = AuthService(test_db_session)
|
||||
|
||||
# First logout to blacklist the token
|
||||
await auth_service.logout(auth_token, ip_address="127.0.0.1")
|
||||
|
||||
# Then try to validate it
|
||||
user = await auth_service.validate_token(auth_token)
|
||||
|
||||
assert user is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestAuthServicePasswordHashing:
|
||||
"""Unit tests for password hashing and verification"""
|
||||
|
||||
async def test_hash_password(self, test_db_session):
|
||||
"""Test password hashing"""
|
||||
auth_service = AuthService(test_db_session)
|
||||
|
||||
plain_password = "mypassword123"
|
||||
hashed = await auth_service.hash_password(plain_password)
|
||||
|
||||
# Hash should not equal plain text
|
||||
assert hashed != plain_password
|
||||
# Hash should start with bcrypt identifier
|
||||
assert hashed.startswith("$2b$")
|
||||
|
||||
async def test_verify_password_success(self, test_db_session):
|
||||
"""Test successful password verification"""
|
||||
auth_service = AuthService(test_db_session)
|
||||
|
||||
plain_password = "mypassword123"
|
||||
hashed = await auth_service.hash_password(plain_password)
|
||||
|
||||
# Verification should succeed
|
||||
result = await auth_service.verify_password(plain_password, hashed)
|
||||
assert result is True
|
||||
|
||||
async def test_verify_password_failure(self, test_db_session):
|
||||
"""Test failed password verification"""
|
||||
auth_service = AuthService(test_db_session)
|
||||
|
||||
plain_password = "mypassword123"
|
||||
hashed = await auth_service.hash_password(plain_password)
|
||||
|
||||
# Verification with wrong password should fail
|
||||
result = await auth_service.verify_password("wrongpassword", hashed)
|
||||
assert result is False
|
||||
|
||||
async def test_hash_password_different_each_time(self, test_db_session):
|
||||
"""Test that same password produces different hashes (due to salt)"""
|
||||
auth_service = AuthService(test_db_session)
|
||||
|
||||
plain_password = "mypassword123"
|
||||
hash1 = await auth_service.hash_password(plain_password)
|
||||
hash2 = await auth_service.hash_password(plain_password)
|
||||
|
||||
# Hashes should be different (bcrypt uses random salt)
|
||||
assert hash1 != hash2
|
||||
|
||||
# But both should verify successfully
|
||||
assert await auth_service.verify_password(plain_password, hash1)
|
||||
assert await auth_service.verify_password(plain_password, hash2)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestAuthServiceAuditLogging:
|
||||
"""Unit tests for audit logging in AuthService"""
|
||||
|
||||
async def test_login_success_creates_audit_log(self, test_db_session, test_admin_user):
|
||||
"""Test that successful login creates audit log entry"""
|
||||
from models.audit_log import AuditLog
|
||||
from sqlalchemy import select
|
||||
|
||||
auth_service = AuthService(test_db_session)
|
||||
|
||||
# Perform login
|
||||
await auth_service.login("admin", "admin123", ip_address="192.168.1.100")
|
||||
|
||||
# Check audit log was created
|
||||
result = await test_db_session.execute(
|
||||
select(AuditLog).where(AuditLog.action == "auth.login")
|
||||
)
|
||||
audit_logs = result.scalars().all()
|
||||
|
||||
assert len(audit_logs) >= 1
|
||||
audit_log = audit_logs[-1] # Get most recent
|
||||
assert audit_log.action == "auth.login"
|
||||
assert audit_log.target == "admin"
|
||||
assert audit_log.outcome == "success"
|
||||
assert audit_log.ip_address == "192.168.1.100"
|
||||
|
||||
async def test_login_failure_creates_audit_log(self, test_db_session):
|
||||
"""Test that failed login creates audit log entry"""
|
||||
from models.audit_log import AuditLog
|
||||
from sqlalchemy import select
|
||||
|
||||
auth_service = AuthService(test_db_session)
|
||||
|
||||
# Attempt login with invalid credentials
|
||||
await auth_service.login("admin", "wrongpassword", ip_address="192.168.1.100")
|
||||
|
||||
# Check audit log was created
|
||||
result = await test_db_session.execute(
|
||||
select(AuditLog).where(AuditLog.action == "auth.login").where(AuditLog.outcome == "failure")
|
||||
)
|
||||
audit_logs = result.scalars().all()
|
||||
|
||||
assert len(audit_logs) >= 1
|
||||
audit_log = audit_logs[-1]
|
||||
assert audit_log.action == "auth.login"
|
||||
assert audit_log.target == "admin"
|
||||
assert audit_log.outcome == "failure"
|
||||
assert audit_log.ip_address == "192.168.1.100"
|
||||
|
||||
async def test_logout_creates_audit_log(self, test_db_session, test_admin_user, auth_token):
|
||||
"""Test that logout creates audit log entry"""
|
||||
from models.audit_log import AuditLog
|
||||
from sqlalchemy import select
|
||||
|
||||
auth_service = AuthService(test_db_session)
|
||||
|
||||
# Perform logout
|
||||
await auth_service.logout(auth_token, ip_address="192.168.1.100")
|
||||
|
||||
# Check audit log was created
|
||||
result = await test_db_session.execute(
|
||||
select(AuditLog).where(AuditLog.action == "auth.logout")
|
||||
)
|
||||
audit_logs = result.scalars().all()
|
||||
|
||||
assert len(audit_logs) >= 1
|
||||
audit_log = audit_logs[-1]
|
||||
assert audit_log.action == "auth.logout"
|
||||
assert audit_log.outcome == "success"
|
||||
assert audit_log.ip_address == "192.168.1.100"
|
||||
@@ -2,7 +2,7 @@
|
||||
Error translation utilities
|
||||
Maps gRPC errors to HTTP status codes and user-friendly messages
|
||||
"""
|
||||
from typing import Tuple
|
||||
from typing import Tuple, Any
|
||||
import grpc
|
||||
from fastapi import status
|
||||
|
||||
|
||||
Reference in New Issue
Block a user