use axum::{extract::State, Json, http::StatusCode, extract::Request, middleware::Next, response::{Response, IntoResponse}}; use axum::http::header::{SET_COOKIE, HeaderValue}; use serde::{Deserialize, Serialize}; use jsonwebtoken::{encode, decode, Header, EncodingKey, DecodingKey, Validation}; use std::sync::Arc; use std::collections::HashMap; use std::time::Instant; use tokio::sync::Mutex; use crate::AppState; use super::ApiResponse; #[derive(Debug, Serialize, Deserialize, Clone)] pub struct Claims { pub sub: i64, pub username: String, pub role: String, pub exp: u64, pub iat: u64, pub token_type: String, /// Random family ID for refresh token rotation detection pub family: String, } #[derive(Debug, Deserialize)] pub struct LoginRequest { pub username: String, pub password: String, } #[derive(Debug, Serialize)] pub struct LoginResponse { pub user: UserInfo, } #[derive(Debug, Serialize)] pub struct MeResponse { pub user: UserInfo, pub expires_at: String, } #[derive(Debug, Serialize, sqlx::FromRow)] pub struct UserInfo { pub id: i64, pub username: String, pub role: String, } #[derive(Debug, Deserialize)] pub struct ChangePasswordRequest { pub old_password: String, pub new_password: String, } // --------------------------------------------------------------------------- // Cookie helpers // --------------------------------------------------------------------------- fn is_secure_cookies() -> bool { std::env::var("CSM_DEV").is_err() } fn access_cookie_header(token: &str, ttl_secs: u64) -> HeaderValue { let secure = if is_secure_cookies() { "; Secure" } else { "" }; HeaderValue::from_str(&format!( "access_token={}; HttpOnly{}; SameSite=Strict; Path=/; Max-Age={}", token, secure, ttl_secs )).expect("valid cookie header") } fn refresh_cookie_header(token: &str, ttl_secs: u64) -> HeaderValue { let secure = if is_secure_cookies() { "; Secure" } else { "" }; HeaderValue::from_str(&format!( "refresh_token={}; HttpOnly{}; SameSite=Strict; Path=/api/auth/refresh; Max-Age={}", token, secure, ttl_secs )).expect("valid cookie header") } fn clear_cookie_headers() -> Vec { let secure = if is_secure_cookies() { "; Secure" } else { "" }; vec![ HeaderValue::from_str(&format!("access_token=; HttpOnly{}; SameSite=Strict; Path=/; Max-Age=0", secure)).expect("valid"), HeaderValue::from_str(&format!("refresh_token=; HttpOnly{}; SameSite=Strict; Path=/api/auth/refresh; Max-Age=0", secure)).expect("valid"), ] } /// Attach Set-Cookie headers to a response. fn with_cookies(mut response: Response, cookies: Vec) -> Response { for cookie in cookies { response.headers_mut().append(SET_COOKIE, cookie); } response } /// Extract a cookie value by name from the raw Cookie header. fn extract_cookie_value(headers: &axum::http::HeaderMap, name: &str) -> Option { let cookie_header = headers.get("cookie")?.to_str().ok()?; for cookie in cookie_header.split(';') { let cookie = cookie.trim(); if let Some(value) = cookie.strip_prefix(&format!("{}=", name)) { return Some(value.to_string()); } } None } // --------------------------------------------------------------------------- // Rate limiter // --------------------------------------------------------------------------- #[derive(Clone, Default)] pub struct LoginRateLimiter { attempts: Arc>>, } impl LoginRateLimiter { pub fn new() -> Self { Self::default() } pub async fn is_limited(&self, key: &str) -> bool { let mut attempts = self.attempts.lock().await; let now = Instant::now(); let window = std::time::Duration::from_secs(300); let max_attempts = 10u32; if let Some((first_attempt, count)) = attempts.get_mut(key) { if now.duration_since(*first_attempt) > window { *first_attempt = now; *count = 1; false } else if *count >= max_attempts { true } else { *count += 1; false } } else { attempts.insert(key.to_string(), (now, 1)); if attempts.len() > 1000 { let cutoff = now - window; attempts.retain(|_, (t, _)| *t > cutoff); } false } } } // --------------------------------------------------------------------------- // Endpoints // --------------------------------------------------------------------------- pub async fn login( State(state): State, Json(req): Json, ) -> impl IntoResponse { if state.login_limiter.is_limited(&req.username).await { return (StatusCode::TOO_MANY_REQUESTS, Json(ApiResponse::::error("Too many login attempts. Try again later."))).into_response(); } if state.login_limiter.is_limited("ip:default").await { return (StatusCode::TOO_MANY_REQUESTS, Json(ApiResponse::::error("Too many login attempts from this location. Try again later."))).into_response(); } let row: Option<(UserInfo, String)> = sqlx::query_as::<_, (i64, String, String, String)>( "SELECT id, username, role, password FROM users WHERE username = ?" ) .bind(&req.username) .fetch_optional(&state.db) .await .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR) .ok() .flatten() .map(|(id, username, role, password)| { (UserInfo { id, username, role }, password) }); let (user, hash) = match row { Some(r) => r, None => { let _ = bcrypt::verify("timing-constant-dummy", "$2b$12$aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"); return (StatusCode::UNAUTHORIZED, Json(ApiResponse::::error("Invalid credentials"))).into_response(); } }; if !bcrypt::verify(&req.password, &hash).unwrap_or(false) { return (StatusCode::UNAUTHORIZED, Json(ApiResponse::::error("Invalid credentials"))).into_response(); } let now = chrono::Utc::now().timestamp() as u64; let family = uuid::Uuid::new_v4().to_string(); let access_token = match create_token(&user, "access", state.config.auth.access_token_ttl_secs, now, &state.config.auth.jwt_secret, &family) { Ok(t) => t, Err(_) => return StatusCode::INTERNAL_SERVER_ERROR.into_response(), }; let refresh_token = match create_token(&user, "refresh", state.config.auth.refresh_token_ttl_secs, now, &state.config.auth.jwt_secret, &family) { Ok(t) => t, Err(_) => return StatusCode::INTERNAL_SERVER_ERROR.into_response(), }; let refresh_expires = now + state.config.auth.refresh_token_ttl_secs; let _ = sqlx::query( "INSERT INTO refresh_tokens (user_id, family, expires_at) VALUES (?, ?, datetime(?, 'unixepoch'))" ) .bind(user.id) .bind(&family) .bind(refresh_expires as i64) .execute(&state.db) .await; let _ = sqlx::query( "INSERT INTO admin_audit_log (user_id, action, detail) VALUES (?, 'login', ?)" ) .bind(user.id) .bind(format!("User {} logged in", user.username)) .execute(&state.db) .await; let response = (StatusCode::OK, Json(ApiResponse::ok(LoginResponse { user }))).into_response(); with_cookies(response, vec![ access_cookie_header(&access_token, state.config.auth.access_token_ttl_secs), refresh_cookie_header(&refresh_token, state.config.auth.refresh_token_ttl_secs), ]) } pub async fn refresh( State(state): State, headers: axum::http::HeaderMap, ) -> impl IntoResponse { let refresh_token = match extract_cookie_value(&headers, "refresh_token") { Some(t) => t, None => return with_cookies( (StatusCode::UNAUTHORIZED, Json(ApiResponse::::error("Missing refresh token"))).into_response(), clear_cookie_headers(), ), }; let claims = match decode::( &refresh_token, &DecodingKey::from_secret(state.config.auth.jwt_secret.as_bytes()), &Validation::default(), ) { Ok(c) => c.claims, Err(_) => return with_cookies( (StatusCode::UNAUTHORIZED, Json(ApiResponse::::error("Invalid refresh token"))).into_response(), clear_cookie_headers(), ), }; if claims.token_type != "refresh" { return with_cookies( (StatusCode::UNAUTHORIZED, Json(ApiResponse::::error("Invalid token type"))).into_response(), clear_cookie_headers(), ); } let mut tx = match state.db.begin().await { Ok(tx) => tx, Err(_) => return StatusCode::INTERNAL_SERVER_ERROR.into_response(), }; let revoked: bool = sqlx::query_scalar::<_, i64>( "SELECT COUNT(*) FROM revoked_token_families WHERE family = ?" ) .bind(&claims.family) .fetch_one(&mut *tx) .await .unwrap_or(0) > 0; if revoked { tx.rollback().await.ok(); tracing::warn!("Refresh token reuse detected for user {} family {}", claims.sub, claims.family); let _ = sqlx::query("DELETE FROM refresh_tokens WHERE user_id = ?") .bind(claims.sub) .execute(&state.db) .await; return with_cookies( (StatusCode::UNAUTHORIZED, Json(ApiResponse::::error("Token reuse detected. Please log in again."))).into_response(), clear_cookie_headers(), ); } let family_exists: bool = sqlx::query_scalar::<_, i64>( "SELECT COUNT(*) FROM refresh_tokens WHERE family = ? AND user_id = ?" ) .bind(&claims.family) .bind(claims.sub) .fetch_one(&mut *tx) .await .unwrap_or(0) > 0; if !family_exists { tx.rollback().await.ok(); return (StatusCode::UNAUTHORIZED, Json(ApiResponse::::error("Invalid refresh token"))).into_response(); } let user = UserInfo { id: claims.sub, username: claims.username, role: claims.role, }; let new_family = uuid::Uuid::new_v4().to_string(); let now = chrono::Utc::now().timestamp() as u64; let access_token = match create_token(&user, "access", state.config.auth.access_token_ttl_secs, now, &state.config.auth.jwt_secret, &new_family) { Ok(t) => t, Err(_) => return StatusCode::INTERNAL_SERVER_ERROR.into_response(), }; let refresh_token = match create_token(&user, "refresh", state.config.auth.refresh_token_ttl_secs, now, &state.config.auth.jwt_secret, &new_family) { Ok(t) => t, Err(_) => return StatusCode::INTERNAL_SERVER_ERROR.into_response(), }; if sqlx::query("INSERT OR IGNORE INTO revoked_token_families (family, user_id, revoked_at) VALUES (?, ?, datetime('now'))") .bind(&claims.family) .bind(claims.sub) .execute(&mut *tx) .await .is_err() { return StatusCode::INTERNAL_SERVER_ERROR.into_response(); } let refresh_expires = now + state.config.auth.refresh_token_ttl_secs; if sqlx::query( "INSERT INTO refresh_tokens (user_id, family, expires_at) VALUES (?, ?, datetime(?, 'unixepoch'))" ) .bind(user.id) .bind(&new_family) .bind(refresh_expires as i64) .execute(&mut *tx) .await .is_err() { return StatusCode::INTERNAL_SERVER_ERROR.into_response(); } if tx.commit().await.is_err() { return StatusCode::INTERNAL_SERVER_ERROR.into_response(); } let response = (StatusCode::OK, Json(ApiResponse::ok(LoginResponse { user }))).into_response(); with_cookies(response, vec![ access_cookie_header(&access_token, state.config.auth.access_token_ttl_secs), refresh_cookie_header(&refresh_token, state.config.auth.refresh_token_ttl_secs), ]) } /// Get current authenticated user info from access_token cookie. pub async fn me( State(state): State, headers: axum::http::HeaderMap, ) -> impl IntoResponse { let token = match extract_cookie_value(&headers, "access_token") { Some(t) => t, None => return (StatusCode::UNAUTHORIZED, Json(ApiResponse::::error("Not authenticated"))).into_response(), }; let claims = match decode::( &token, &DecodingKey::from_secret(state.config.auth.jwt_secret.as_bytes()), &Validation::default(), ) { Ok(c) => c.claims, Err(_) => return (StatusCode::UNAUTHORIZED, Json(ApiResponse::::error("Invalid token"))).into_response(), }; if claims.token_type != "access" { return (StatusCode::UNAUTHORIZED, Json(ApiResponse::::error("Invalid token type"))).into_response(); } let expires_at = chrono::DateTime::from_timestamp(claims.exp as i64, 0) .map(|t| t.to_rfc3339()) .unwrap_or_default(); (StatusCode::OK, Json(ApiResponse::ok(MeResponse { user: UserInfo { id: claims.sub, username: claims.username, role: claims.role, }, expires_at, }))).into_response() } /// Logout: clear auth cookies and revoke refresh token family. pub async fn logout( State(state): State, headers: axum::http::HeaderMap, ) -> impl IntoResponse { if let Some(token) = extract_cookie_value(&headers, "access_token") { if let Ok(claims) = decode::( &token, &DecodingKey::from_secret(state.config.auth.jwt_secret.as_bytes()), &Validation::default(), ) { let _ = sqlx::query("DELETE FROM refresh_tokens WHERE user_id = ?") .bind(claims.claims.sub) .execute(&state.db) .await; let _ = sqlx::query( "INSERT INTO admin_audit_log (user_id, action, detail) VALUES (?, 'logout', ?)" ) .bind(claims.claims.sub) .bind(format!("User {} logged out", claims.claims.username)) .execute(&state.db) .await; } } let response = (StatusCode::OK, Json(ApiResponse::ok(()))).into_response(); with_cookies(response, clear_cookie_headers()) } // --------------------------------------------------------------------------- // WebSocket ticket // --------------------------------------------------------------------------- #[derive(Debug, Serialize)] pub struct WsTicketResponse { pub ticket: String, pub expires_in: u64, } /// Create a one-time ticket for WebSocket authentication. /// Requires a valid access_token cookie (set by login). pub async fn create_ws_ticket( State(state): State, headers: axum::http::HeaderMap, ) -> impl IntoResponse { let token = match extract_cookie_value(&headers, "access_token") { Some(t) => t, None => return (StatusCode::UNAUTHORIZED, Json(ApiResponse::::error("Not authenticated"))).into_response(), }; let claims = match decode::( &token, &DecodingKey::from_secret(state.config.auth.jwt_secret.as_bytes()), &Validation::default(), ) { Ok(c) => c.claims, Err(_) => return (StatusCode::UNAUTHORIZED, Json(ApiResponse::::error("Invalid token"))).into_response(), }; if claims.token_type != "access" { return (StatusCode::UNAUTHORIZED, Json(ApiResponse::::error("Invalid token type"))).into_response(); } let ticket = uuid::Uuid::new_v4().to_string(); let claim = crate::ws::TicketClaim { user_id: claims.sub, username: claims.username, role: claims.role, created_at: std::time::Instant::now(), }; { let mut tickets = state.ws_tickets.lock().await; tickets.insert(ticket.clone(), claim); // Cleanup expired tickets (>30s old) on every creation tickets.retain(|_, c| c.created_at.elapsed().as_secs() < 30); } (StatusCode::OK, Json(ApiResponse::ok(WsTicketResponse { ticket, expires_in: 30, }))).into_response() } // --------------------------------------------------------------------------- // Internal helpers // --------------------------------------------------------------------------- fn create_token(user: &UserInfo, token_type: &str, ttl: u64, now: u64, secret: &str, family: &str) -> Result { let claims = Claims { sub: user.id, username: user.username.clone(), role: user.role.clone(), exp: now + ttl, iat: now, token_type: token_type.to_string(), family: family.to_string(), }; encode( &Header::default(), &claims, &EncodingKey::from_secret(secret.as_bytes()), ) .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR) } /// Axum middleware: require valid JWT access token from cookie pub async fn require_auth( State(state): State, mut request: Request, next: Next, ) -> Result { let token = extract_cookie_value(request.headers(), "access_token") .ok_or(StatusCode::UNAUTHORIZED)?; let claims = decode::( &token, &DecodingKey::from_secret(state.config.auth.jwt_secret.as_bytes()), &Validation::default(), ) .map_err(|_| StatusCode::UNAUTHORIZED)?; if claims.claims.token_type != "access" { return Err(StatusCode::UNAUTHORIZED); } request.extensions_mut().insert(claims.claims); Ok(next.run(request).await) } /// Axum middleware: require admin role for write operations + audit log pub async fn require_admin( State(state): State, request: Request, next: Next, ) -> Result { let claims = request.extensions() .get::() .ok_or(StatusCode::UNAUTHORIZED)?; if claims.role != "admin" { return Err(StatusCode::FORBIDDEN); } let method = request.method().clone(); let path = request.uri().path().to_string(); let user_id = claims.sub; let username = claims.username.clone(); let response = next.run(request).await; let status = response.status(); if status.is_success() { let action = format!("{} {}", method, path); let detail = format!("by {}", username); let _ = sqlx::query( "INSERT INTO admin_audit_log (user_id, action, detail) VALUES (?, ?, ?)" ) .bind(user_id) .bind(&action) .bind(&detail) .execute(&state.db) .await; } Ok(response) } pub async fn change_password( State(state): State, claims: axum::Extension, Json(req): Json, ) -> Result<(StatusCode, Json>), StatusCode> { if req.new_password.len() < 6 { return Ok((StatusCode::BAD_REQUEST, Json(ApiResponse::error("新密码至少6位")))); } if req.new_password.len() > 72 { return Ok((StatusCode::BAD_REQUEST, Json(ApiResponse::error("密码不能超过72位")))); } let hash: String = sqlx::query_scalar::<_, String>( "SELECT password FROM users WHERE id = ?" ) .bind(claims.sub) .fetch_one(&state.db) .await .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; if !bcrypt::verify(&req.old_password, &hash).unwrap_or(false) { return Ok((StatusCode::BAD_REQUEST, Json(ApiResponse::error("当前密码错误")))); } let new_hash = bcrypt::hash(&req.new_password, 12).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; sqlx::query("UPDATE users SET password = ? WHERE id = ?") .bind(&new_hash) .bind(claims.sub) .execute(&state.db) .await .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; let _ = sqlx::query( "INSERT INTO admin_audit_log (user_id, action, detail) VALUES (?, 'change_password', ?)" ) .bind(claims.sub) .bind(format!("User {} changed password", claims.username)) .execute(&state.db) .await; Ok((StatusCode::OK, Json(ApiResponse::ok(())))) }