refactor: 重构数据库连接使用PostgreSQL替代SQLite feat(auth): 增加JWT验证的audience和issuer检查 feat(crypto): 添加AES-256-GCM字段加密支持 feat(api): 集成utoipa实现OpenAPI文档 fix(admin): 修复配置项表单验证逻辑 style: 统一代码格式与类型定义 docs: 更新技术栈文档说明PostgreSQL
259 lines
8.1 KiB
Rust
259 lines
8.1 KiB
Rust
//! TOTP 双因素认证
|
||
|
||
use axum::{
|
||
extract::{Extension, State},
|
||
Json,
|
||
};
|
||
use crate::state::AppState;
|
||
use crate::error::{SaasError, SaasResult};
|
||
use crate::auth::types::AuthContext;
|
||
use crate::auth::handlers::log_operation;
|
||
use serde::{Deserialize, Serialize};
|
||
|
||
/// TOTP 设置响应
|
||
#[derive(Debug, Serialize)]
|
||
pub struct TotpSetupResponse {
|
||
/// otpauth:// URI,用于扫码绑定
|
||
pub otpauth_uri: String,
|
||
/// Base32 编码的密钥(备用手动输入)
|
||
pub secret: String,
|
||
/// issuer 名称
|
||
pub issuer: String,
|
||
}
|
||
|
||
/// TOTP 验证请求
|
||
#[derive(Debug, Deserialize)]
|
||
pub struct TotpVerifyRequest {
|
||
pub code: String,
|
||
}
|
||
|
||
/// TOTP 禁用请求
|
||
#[derive(Debug, Deserialize)]
|
||
pub struct TotpDisableRequest {
|
||
pub password: String,
|
||
}
|
||
|
||
/// 生成随机 Base32 密钥 (20 字节 = 32 字符 Base32)
|
||
fn generate_random_secret() -> String {
|
||
use rand::Rng;
|
||
let mut bytes = [0u8; 20];
|
||
rand::thread_rng().fill(&mut bytes);
|
||
data_encoding::BASE32.encode(&bytes)
|
||
}
|
||
|
||
/// Base32 解码
|
||
fn base32_decode(data: &str) -> Option<Vec<u8>> {
|
||
data_encoding::BASE32.decode(data.as_bytes()).ok()
|
||
}
|
||
|
||
/// 生成 TOTP 密钥并返回 otpauth URI
|
||
pub fn generate_totp_secret(issuer: &str, account_name: &str) -> TotpSetupResponse {
|
||
let secret = generate_random_secret();
|
||
let otpauth_uri = format!(
|
||
"otpauth://totp/{}:{}?secret={}&issuer={}&algorithm=SHA1&digits=6&period=30",
|
||
urlencoding::encode(issuer),
|
||
urlencoding::encode(account_name),
|
||
secret,
|
||
urlencoding::encode(issuer),
|
||
);
|
||
|
||
TotpSetupResponse {
|
||
otpauth_uri,
|
||
secret,
|
||
issuer: issuer.to_string(),
|
||
}
|
||
}
|
||
|
||
/// 验证 TOTP 6 位码
|
||
pub fn verify_totp_code(secret: &str, code: &str) -> bool {
|
||
let secret_bytes = match base32_decode(secret) {
|
||
Some(b) => b,
|
||
None => return false,
|
||
};
|
||
|
||
let totp = match totp_rs::TOTP::new(
|
||
totp_rs::Algorithm::SHA1,
|
||
6, // digits
|
||
1, // skew (允许 1 个周期偏差)
|
||
30, // step (秒)
|
||
secret_bytes,
|
||
) {
|
||
Ok(t) => t,
|
||
Err(_) => return false,
|
||
};
|
||
|
||
totp.check_current(code).unwrap_or(false)
|
||
}
|
||
|
||
/// POST /api/v1/auth/totp/setup
|
||
/// 生成 TOTP 密钥并返回 otpauth URI
|
||
/// 用户扫码后需要调用 /verify 验证一个码才能激活
|
||
pub async fn setup_totp(
|
||
State(state): State<AppState>,
|
||
Extension(ctx): Extension<AuthContext>,
|
||
) -> SaasResult<Json<TotpSetupResponse>> {
|
||
// 如果已启用 TOTP,先清除旧密钥
|
||
let (username,): (String,) = sqlx::query_as(
|
||
"SELECT username FROM accounts WHERE id = $1"
|
||
)
|
||
.bind(&ctx.account_id)
|
||
.fetch_one(&state.db)
|
||
.await?;
|
||
|
||
let config = state.config.read().await;
|
||
let setup = generate_totp_secret(&config.auth.totp_issuer, &username);
|
||
|
||
// 加密 TOTP 密钥后存储 (但不启用,需要 /verify 确认)
|
||
let encrypted_secret = state.field_encryption.encrypt(&setup.secret)?;
|
||
sqlx::query("UPDATE accounts SET totp_secret = $1 WHERE id = $2")
|
||
.bind(&encrypted_secret)
|
||
.bind(&ctx.account_id)
|
||
.execute(&state.db)
|
||
.await?;
|
||
|
||
log_operation(&state.db, &ctx.account_id, "totp.setup", "account", &ctx.account_id,
|
||
None, ctx.client_ip.as_deref()).await?;
|
||
|
||
Ok(Json(setup))
|
||
}
|
||
|
||
/// POST /api/v1/auth/totp/verify
|
||
/// 验证 TOTP 码并启用 2FA
|
||
pub async fn verify_totp(
|
||
State(state): State<AppState>,
|
||
Extension(ctx): Extension<AuthContext>,
|
||
Json(req): Json<TotpVerifyRequest>,
|
||
) -> SaasResult<Json<serde_json::Value>> {
|
||
let code = req.code.trim();
|
||
if code.len() != 6 || !code.chars().all(|c| c.is_ascii_digit()) {
|
||
return Err(SaasError::InvalidInput("TOTP 码必须是 6 位数字".into()));
|
||
}
|
||
|
||
// 获取存储的密钥
|
||
let (totp_secret,): (Option<String>,) = sqlx::query_as(
|
||
"SELECT totp_secret FROM accounts WHERE id = $1"
|
||
)
|
||
.bind(&ctx.account_id)
|
||
.fetch_one(&state.db)
|
||
.await?;
|
||
|
||
let secret = totp_secret.ok_or_else(|| {
|
||
SaasError::InvalidInput("请先调用 /totp/setup 获取密钥".into())
|
||
})?;
|
||
|
||
// 解密 TOTP 密钥(兼容迁移期间的明文数据)
|
||
let decrypted_secret = state.field_encryption.decrypt_or_plaintext(&secret);
|
||
|
||
if !verify_totp_code(&decrypted_secret, code) {
|
||
return Err(SaasError::Totp("TOTP 码验证失败".into()));
|
||
}
|
||
|
||
// 验证成功 → 启用 TOTP
|
||
let now = chrono::Utc::now();
|
||
sqlx::query("UPDATE accounts SET totp_enabled = true, updated_at = $1 WHERE id = $2")
|
||
.bind(now)
|
||
.bind(&ctx.account_id)
|
||
.execute(&state.db)
|
||
.await?;
|
||
|
||
log_operation(&state.db, &ctx.account_id, "totp.verify", "account", &ctx.account_id,
|
||
None, ctx.client_ip.as_deref()).await?;
|
||
|
||
Ok(Json(serde_json::json!({"ok": true, "totp_enabled": true, "message": "TOTP 已启用"})))
|
||
}
|
||
|
||
/// POST /api/v1/auth/totp/disable
|
||
/// 禁用 TOTP (需要密码确认)
|
||
pub async fn disable_totp(
|
||
State(state): State<AppState>,
|
||
Extension(ctx): Extension<AuthContext>,
|
||
Json(req): Json<TotpDisableRequest>,
|
||
) -> SaasResult<Json<serde_json::Value>> {
|
||
// 验证密码
|
||
let (password_hash,): (String,) = sqlx::query_as(
|
||
"SELECT password_hash FROM accounts WHERE id = $1"
|
||
)
|
||
.bind(&ctx.account_id)
|
||
.fetch_one(&state.db)
|
||
.await?;
|
||
|
||
if !crate::auth::password::verify_password(&req.password, &password_hash)? {
|
||
return Err(SaasError::AuthError("密码错误".into()));
|
||
}
|
||
|
||
// 清除 TOTP
|
||
let now = chrono::Utc::now();
|
||
sqlx::query("UPDATE accounts SET totp_enabled = false, totp_secret = NULL, updated_at = $1 WHERE id = $2")
|
||
.bind(now)
|
||
.bind(&ctx.account_id)
|
||
.execute(&state.db)
|
||
.await?;
|
||
|
||
log_operation(&state.db, &ctx.account_id, "totp.disable", "account", &ctx.account_id,
|
||
None, ctx.client_ip.as_deref()).await?;
|
||
|
||
Ok(Json(serde_json::json!({"ok": true, "totp_enabled": false, "message": "TOTP 已禁用"})))
|
||
}
|
||
|
||
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `TOTP` with the same parameters `verify_totp_code` uses
    /// (SHA1, 6 digits, skew 1, 30-second step).
    fn totp_for(secret: &str) -> totp_rs::TOTP {
        let secret_bytes = data_encoding::BASE32.decode(secret.as_bytes()).unwrap();
        totp_rs::TOTP::new(totp_rs::Algorithm::SHA1, 6, 1, 30, secret_bytes).unwrap()
    }

    /// Current Unix time in seconds.
    fn now_secs() -> i64 {
        chrono::Utc::now().timestamp()
    }

    /// Returns a well-formed 6-digit code that is guaranteed NOT to be
    /// accepted for `secret`.
    ///
    /// Codes for the five steps around now (±2) are excluded. That covers the
    /// ±1-step skew window even if the clock crosses a step boundary while the
    /// test runs. Fixed guesses such as "000000" could instead collide with a
    /// valid code and make the test flaky.
    fn rejected_code_for(secret: &str) -> String {
        let totp = totp_for(secret);
        let now = now_secs();
        let nearby: Vec<String> = (-2i64..=2)
            .map(|step| totp.generate((now + step * 30) as u64))
            .collect();
        (0..1_000_000u32)
            .map(|n| format!("{:06}", n))
            .find(|candidate| !nearby.contains(candidate))
            .expect("at most five codes are excluded")
    }

    #[test]
    fn test_generate_totp_secret_format() {
        let result = generate_totp_secret("TestIssuer", "user@example.com");
        assert!(result.otpauth_uri.starts_with("otpauth://totp/"));
        assert!(result.otpauth_uri.contains("secret="));
        assert!(result.otpauth_uri.contains("issuer=TestIssuer"));
        assert!(result.otpauth_uri.contains("algorithm=SHA1"));
        assert!(result.otpauth_uri.contains("digits=6"));
        assert!(result.otpauth_uri.contains("period=30"));
        // 20 bytes of Base32 = 32 characters.
        assert_eq!(result.secret.len(), 32);
        assert_eq!(result.issuer, "TestIssuer");
    }

    #[test]
    fn test_generate_totp_secret_special_chars() {
        let result = generate_totp_secret("My App", "user@domain:8080");
        // Both the issuer and the account name must be percent-encoded.
        assert!(!result.otpauth_uri.contains("user@domain:8080"));
        assert!(result.otpauth_uri.contains("user%40domain"));
        assert!(result.otpauth_uri.contains("issuer=My%20App"));
    }

    #[test]
    fn test_generate_totp_secret_round_trips() {
        // The secret handed to the client must verify codes generated from it.
        let setup = generate_totp_secret("TestIssuer", "user@example.com");
        assert!(setup.otpauth_uri.contains(&format!("secret={}", setup.secret)));
        let code = totp_for(&setup.secret).generate(now_secs() as u64);
        assert!(verify_totp_code(&setup.secret, &code));
    }

    #[test]
    fn test_verify_totp_code_valid() {
        let secret = generate_random_secret();
        let valid_code = totp_for(&secret).generate(now_secs() as u64);
        assert!(verify_totp_code(&secret, &valid_code));
    }

    #[test]
    fn test_verify_totp_code_invalid() {
        let secret = generate_random_secret();
        assert!(!verify_totp_code(&secret, &rejected_code_for(&secret)));
        assert!(!verify_totp_code(&secret, "abcdef"));
    }

    #[test]
    fn test_verify_totp_code_invalid_secret() {
        assert!(!verify_totp_code("not-valid-base32!!!", "123456"));
        assert!(!verify_totp_code("", "123456"));
        assert!(!verify_totp_code("短", "123456"));
    }

    #[test]
    fn test_verify_totp_code_empty() {
        // Use a full-strength secret. A short one such as "JBSWY3DPEHPK3PXP"
        // (10 bytes) is rejected by `totp_rs::TOTP::new`, which would make
        // these assertions pass without the code ever being checked.
        let secret = generate_random_secret();
        let valid = totp_for(&secret).generate(now_secs() as u64);
        assert!(!verify_totp_code(&secret, ""));
        assert!(!verify_totp_code(&secret, &valid[..5]));
        assert!(!verify_totp_code(&secret, &format!("{}0", valid)));
    }
}
|