//! TOTP 双因素认证 use axum::{ extract::{Extension, State}, Json, }; use crate::state::AppState; use crate::error::{SaasError, SaasResult}; use crate::auth::types::AuthContext; use crate::auth::handlers::log_operation; use serde::{Deserialize, Serialize}; /// TOTP 设置响应 #[derive(Debug, Serialize)] pub struct TotpSetupResponse { /// otpauth:// URI,用于扫码绑定 pub otpauth_uri: String, /// Base32 编码的密钥(备用手动输入) pub secret: String, /// issuer 名称 pub issuer: String, } /// TOTP 验证请求 #[derive(Debug, Deserialize)] pub struct TotpVerifyRequest { pub code: String, } /// TOTP 禁用请求 #[derive(Debug, Deserialize)] pub struct TotpDisableRequest { pub password: String, } /// 生成随机 Base32 密钥 (20 字节 = 32 字符 Base32) fn generate_random_secret() -> String { use rand::Rng; let mut bytes = [0u8; 20]; rand::thread_rng().fill(&mut bytes); data_encoding::BASE32.encode(&bytes) } /// Base32 解码 fn base32_decode(data: &str) -> Option> { data_encoding::BASE32.decode(data.as_bytes()).ok() } /// 生成 TOTP 密钥并返回 otpauth URI pub fn generate_totp_secret(issuer: &str, account_name: &str) -> TotpSetupResponse { let secret = generate_random_secret(); let otpauth_uri = format!( "otpauth://totp/{}:{}?secret={}&issuer={}&algorithm=SHA1&digits=6&period=30", urlencoding::encode(issuer), urlencoding::encode(account_name), secret, urlencoding::encode(issuer), ); TotpSetupResponse { otpauth_uri, secret, issuer: issuer.to_string(), } } /// 验证 TOTP 6 位码 pub fn verify_totp_code(secret: &str, code: &str) -> bool { let secret_bytes = match base32_decode(secret) { Some(b) => b, None => return false, }; let totp = match totp_rs::TOTP::new( totp_rs::Algorithm::SHA1, 6, // digits 1, // skew (允许 1 个周期偏差) 30, // step (秒) secret_bytes, ) { Ok(t) => t, Err(_) => return false, }; totp.check_current(code).unwrap_or(false) } /// POST /api/v1/auth/totp/setup /// 生成 TOTP 密钥并返回 otpauth URI /// 用户扫码后需要调用 /verify 验证一个码才能激活 pub async fn setup_totp( State(state): State, Extension(ctx): Extension, ) -> SaasResult> { // 如果已启用 TOTP,先清除旧密钥 let (username,): (String,) = sqlx::query_as( "SELECT username FROM accounts WHERE id = $1" ) .bind(&ctx.account_id) .fetch_one(&state.db) .await?; let config = state.config.read().await; let setup = generate_totp_secret(&config.auth.totp_issuer, &username); // 加密 TOTP 密钥后存储 (但不启用,需要 /verify 确认) let encrypted_secret = state.field_encryption.encrypt(&setup.secret)?; sqlx::query("UPDATE accounts SET totp_secret = $1 WHERE id = $2") .bind(&encrypted_secret) .bind(&ctx.account_id) .execute(&state.db) .await?; log_operation(&state.db, &ctx.account_id, "totp.setup", "account", &ctx.account_id, None, ctx.client_ip.as_deref()).await?; Ok(Json(setup)) } /// POST /api/v1/auth/totp/verify /// 验证 TOTP 码并启用 2FA pub async fn verify_totp( State(state): State, Extension(ctx): Extension, Json(req): Json, ) -> SaasResult> { let code = req.code.trim(); if code.len() != 6 || !code.chars().all(|c| c.is_ascii_digit()) { return Err(SaasError::InvalidInput("TOTP 码必须是 6 位数字".into())); } // 获取存储的密钥 let (totp_secret,): (Option,) = sqlx::query_as( "SELECT totp_secret FROM accounts WHERE id = $1" ) .bind(&ctx.account_id) .fetch_one(&state.db) .await?; let secret = totp_secret.ok_or_else(|| { SaasError::InvalidInput("请先调用 /totp/setup 获取密钥".into()) })?; // 解密 TOTP 密钥(兼容迁移期间的明文数据) let decrypted_secret = state.field_encryption.decrypt_or_plaintext(&secret); if !verify_totp_code(&decrypted_secret, code) { return Err(SaasError::Totp("TOTP 码验证失败".into())); } // 验证成功 → 启用 TOTP let now = chrono::Utc::now(); sqlx::query("UPDATE accounts SET totp_enabled = true, updated_at = $1 WHERE id = $2") .bind(now) .bind(&ctx.account_id) .execute(&state.db) .await?; log_operation(&state.db, &ctx.account_id, "totp.verify", "account", &ctx.account_id, None, ctx.client_ip.as_deref()).await?; Ok(Json(serde_json::json!({"ok": true, "totp_enabled": true, "message": "TOTP 已启用"}))) } /// POST /api/v1/auth/totp/disable /// 禁用 TOTP (需要密码确认) pub async fn disable_totp( State(state): State, Extension(ctx): Extension, Json(req): Json, ) -> SaasResult> { // 验证密码 let (password_hash,): (String,) = sqlx::query_as( "SELECT password_hash FROM accounts WHERE id = $1" ) .bind(&ctx.account_id) .fetch_one(&state.db) .await?; if !crate::auth::password::verify_password(&req.password, &password_hash)? { return Err(SaasError::AuthError("密码错误".into())); } // 清除 TOTP let now = chrono::Utc::now(); sqlx::query("UPDATE accounts SET totp_enabled = false, totp_secret = NULL, updated_at = $1 WHERE id = $2") .bind(now) .bind(&ctx.account_id) .execute(&state.db) .await?; log_operation(&state.db, &ctx.account_id, "totp.disable", "account", &ctx.account_id, None, ctx.client_ip.as_deref()).await?; Ok(Json(serde_json::json!({"ok": true, "totp_enabled": false, "message": "TOTP 已禁用"}))) } #[cfg(test)] mod tests { use super::*; #[test] fn test_generate_totp_secret_format() { let result = generate_totp_secret("TestIssuer", "user@example.com"); assert!(result.otpauth_uri.starts_with("otpauth://totp/")); assert!(result.otpauth_uri.contains("secret=")); assert!(result.otpauth_uri.contains("issuer=TestIssuer")); assert!(result.otpauth_uri.contains("algorithm=SHA1")); assert!(result.otpauth_uri.contains("digits=6")); assert!(result.otpauth_uri.contains("period=30")); // Base32 编码的 20 字节 = 32 字符 assert_eq!(result.secret.len(), 32); assert_eq!(result.issuer, "TestIssuer"); } #[test] fn test_generate_totp_secret_special_chars() { let result = generate_totp_secret("My App", "user@domain:8080"); // 特殊字符应被 URL 编码 assert!(!result.otpauth_uri.contains("user@domain:8080")); assert!(result.otpauth_uri.contains("user%40domain")); } #[test] fn test_verify_totp_code_valid() { // 使用 generate_random_secret 创建合法 secret,然后生成并验证码 let secret = generate_random_secret(); let secret_bytes = data_encoding::BASE32.decode(secret.as_bytes()).unwrap(); let totp = totp_rs::TOTP::new( totp_rs::Algorithm::SHA1, 6, 1, 30, secret_bytes, ).unwrap(); let valid_code = totp.generate(chrono::Utc::now().timestamp() as u64); assert!(verify_totp_code(&secret, &valid_code)); } #[test] fn test_verify_totp_code_invalid() { let secret = generate_random_secret(); assert!(!verify_totp_code(&secret, "000000")); assert!(!verify_totp_code(&secret, "999999")); assert!(!verify_totp_code(&secret, "abcdef")); } #[test] fn test_verify_totp_code_invalid_secret() { assert!(!verify_totp_code("not-valid-base32!!!", "123456")); assert!(!verify_totp_code("", "123456")); assert!(!verify_totp_code("短", "123456")); } #[test] fn test_verify_totp_code_empty() { let secret = "JBSWY3DPEHPK3PXP"; assert!(!verify_totp_code(secret, "")); assert!(!verify_totp_code(secret, "12345")); assert!(!verify_totp_code(secret, "1234567")); } }