chore: 提交所有工作进度 — SaaS 后端增强、Admin UI、桌面端集成

包含大量 SaaS 平台改进、Admin 管理后台更新、桌面端集成完善、
文档同步、测试文件重构等内容。为 QA 测试准备干净工作树。
This commit is contained in:
iven
2026-03-29 10:46:26 +08:00
parent 9a5fad2b59
commit 5fdf96c3f5
268 changed files with 22011 additions and 3886 deletions

View File

@@ -6,26 +6,45 @@ use secrecy::ExposeSecret;
use crate::state::AppState;
use crate::error::{SaasError, SaasResult};
use super::{
jwt::create_token,
jwt::{create_token, create_refresh_token, verify_token, verify_token_skip_expiry},
password::{hash_password, verify_password},
types::{AuthContext, LoginRequest, LoginResponse, RegisterRequest, ChangePasswordRequest, AccountPublic},
types::{AuthContext, LoginRequest, LoginResponse, RegisterRequest, ChangePasswordRequest, AccountPublic, RefreshRequest},
};
/// POST /api/v1/auth/register
/// 注册成功后自动签发 JWT返回与 login 一致的 LoginResponse
pub async fn register(
State(state): State<AppState>,
ConnectInfo(addr): ConnectInfo<SocketAddr>,
Json(req): Json<RegisterRequest>,
) -> SaasResult<(StatusCode, Json<AccountPublic>)> {
) -> SaasResult<(StatusCode, Json<LoginResponse>)> {
if req.username.len() < 3 {
return Err(SaasError::InvalidInput("用户名至少 3 个字符".into()));
}
if req.username.len() > 32 {
return Err(SaasError::InvalidInput("用户名最多 32 个字符".into()));
}
let username_re = regex::Regex::new(r"^[a-zA-Z0-9_-]+$").unwrap();
if !username_re.is_match(&req.username) {
return Err(SaasError::InvalidInput("用户名只能包含字母、数字、下划线和连字符".into()));
}
if !req.email.contains('@') || !req.email.contains('.') {
return Err(SaasError::InvalidInput("邮箱格式不正确".into()));
}
if req.password.len() < 8 {
return Err(SaasError::InvalidInput("密码至少 8 个字符".into()));
}
if req.password.len() > 128 {
return Err(SaasError::InvalidInput("密码最多 128 个字符".into()));
}
if let Some(ref name) = req.display_name {
if name.len() > 64 {
return Err(SaasError::InvalidInput("显示名称最多 64 个字符".into()));
}
}
let existing: Vec<(String,)> = sqlx::query_as(
"SELECT id FROM accounts WHERE username = ?1 OR email = ?2"
"SELECT id FROM accounts WHERE username = $1 OR email = $2"
)
.bind(&req.username)
.bind(&req.email)
@@ -44,7 +63,7 @@ pub async fn register(
sqlx::query(
"INSERT INTO accounts (id, username, email, password_hash, display_name, role, status, created_at, updated_at)
VALUES (?1, ?2, ?3, ?4, ?5, ?6, 'active', ?7, ?7)"
VALUES ($1, $2, $3, $4, $5, $6, 'active', $7, $7)"
)
.bind(&account_id)
.bind(&req.username)
@@ -59,15 +78,39 @@ pub async fn register(
let client_ip = addr.ip().to_string();
log_operation(&state.db, &account_id, "account.create", "account", &account_id, None, Some(&client_ip)).await?;
Ok((StatusCode::CREATED, Json(AccountPublic {
id: account_id,
username: req.username,
email: req.email,
display_name,
role,
status: "active".into(),
totp_enabled: false,
created_at: now,
// 注册成功后自动签发 JWT + Refresh Token
let permissions = get_role_permissions(&state.db, &role).await?;
let config = state.config.read().await;
let token = create_token(
&account_id, &role, permissions.clone(),
state.jwt_secret.expose_secret(),
config.auth.jwt_expiration_hours,
)?;
let refresh_token = create_refresh_token(
&account_id, &role, permissions,
state.jwt_secret.expose_secret(),
config.auth.refresh_token_hours,
)?;
drop(config);
store_refresh_token(
&state.db, &account_id, &refresh_token,
state.jwt_secret.expose_secret(), 168,
).await?;
Ok((StatusCode::CREATED, Json(LoginResponse {
token,
refresh_token,
account: AccountPublic {
id: account_id,
username: req.username,
email: req.email,
display_name,
role,
status: "active".into(),
totp_enabled: false,
created_at: now,
},
})))
}
@@ -80,7 +123,7 @@ pub async fn login(
let row: Option<(String, String, String, String, String, String, bool, String)> =
sqlx::query_as(
"SELECT id, username, email, display_name, role, status, totp_enabled, created_at
FROM accounts WHERE username = ?1 OR email = ?1"
FROM accounts WHERE username = $1 OR email = $1"
)
.bind(&req.username)
.fetch_optional(&state.db)
@@ -94,7 +137,7 @@ pub async fn login(
}
let (password_hash,): (String,) = sqlx::query_as(
"SELECT password_hash FROM accounts WHERE id = ?1"
"SELECT password_hash FROM accounts WHERE id = $1"
)
.bind(&id)
.fetch_one(&state.db)
@@ -110,7 +153,7 @@ pub async fn login(
.ok_or_else(|| SaasError::Totp("此账号已启用双因素认证,请提供 TOTP 码".into()))?;
let (totp_secret,): (Option<String>,) = sqlx::query_as(
"SELECT totp_secret FROM accounts WHERE id = ?1"
"SELECT totp_secret FROM accounts WHERE id = $1"
)
.bind(&id)
.fetch_one(&state.db)
@@ -120,6 +163,12 @@ pub async fn login(
SaasError::Internal("TOTP 已启用但密钥丢失,请联系管理员".into())
})?;
// 解密 TOTP secret (兼容旧的明文格式)
let config = state.config.read().await;
let enc_key = config.totp_encryption_key()
.map_err(|e| SaasError::Internal(e.to_string()))?;
let secret = super::totp::decrypt_totp_for_login(&secret, &enc_key)?;
if !super::totp::verify_totp_code(&secret, code) {
return Err(SaasError::Totp("TOTP 码错误或已过期".into()));
}
@@ -132,16 +181,28 @@ pub async fn login(
state.jwt_secret.expose_secret(),
config.auth.jwt_expiration_hours,
)?;
let refresh_token = create_refresh_token(
&id, &role, permissions,
state.jwt_secret.expose_secret(),
config.auth.refresh_token_hours,
)?;
drop(config);
let now = chrono::Utc::now().to_rfc3339();
sqlx::query("UPDATE accounts SET last_login_at = ?1 WHERE id = ?2")
sqlx::query("UPDATE accounts SET last_login_at = $1 WHERE id = $2")
.bind(&now).bind(&id)
.execute(&state.db).await?;
let client_ip = addr.ip().to_string();
log_operation(&state.db, &id, "account.login", "account", &id, None, Some(&client_ip)).await?;
store_refresh_token(
&state.db, &id, &refresh_token,
state.jwt_secret.expose_secret(), 168,
).await?;
Ok(Json(LoginResponse {
token,
refresh_token,
account: AccountPublic {
id, username, email, display_name, role, status, totp_enabled, created_at,
},
@@ -149,17 +210,92 @@ pub async fn login(
}
/// POST /api/v1/auth/refresh
/// 使用 refresh_token 换取新的 access + refresh token 对
/// refresh_token 一次性使用,使用后立即失效
pub async fn refresh(
State(state): State<AppState>,
axum::extract::Extension(ctx): axum::extract::Extension<AuthContext>,
Json(req): Json<RefreshRequest>,
) -> SaasResult<Json<serde_json::Value>> {
// 1. 验证 refresh token 签名 (跳过过期检查,但有 7 天窗口限制)
let claims = verify_token_skip_expiry(&req.refresh_token, state.jwt_secret.expose_secret())?;
// 2. 确认是 refresh 类型 token
if claims.token_type != "refresh" {
return Err(SaasError::AuthError("无效的 refresh token".into()));
}
let jti = claims.jti.as_deref()
.ok_or_else(|| SaasError::AuthError("refresh token 缺少 jti".into()))?;
// 3. 从 DB 查找 refresh token确保未被使用
let row: Option<(String,)> = sqlx::query_as(
"SELECT account_id FROM refresh_tokens WHERE jti = $1 AND used_at IS NULL AND expires_at > $2"
)
.bind(jti)
.bind(&chrono::Utc::now().to_rfc3339())
.fetch_optional(&state.db)
.await?;
let token_account_id = row
.ok_or_else(|| SaasError::AuthError("refresh token 已使用、已过期或不存在".into()))?
.0;
// 4. 验证 token 中的 account_id 与 DB 中的一致
if token_account_id != claims.sub {
return Err(SaasError::AuthError("refresh token 账号不匹配".into()));
}
// 5. 标记旧 refresh token 为已使用 (一次性)
let now = chrono::Utc::now().to_rfc3339();
sqlx::query("UPDATE refresh_tokens SET used_at = $1 WHERE jti = $2")
.bind(&now).bind(jti)
.execute(&state.db).await?;
// 6. 获取最新角色权限
let (role,): (String,) = sqlx::query_as(
"SELECT role FROM accounts WHERE id = $1 AND status = 'active'"
)
.bind(&claims.sub)
.fetch_optional(&state.db)
.await?
.ok_or_else(|| SaasError::AuthError("账号不存在或已禁用".into()))?;
let permissions = get_role_permissions(&state.db, &role).await?;
// 7. 创建新的 access token + refresh token
let config = state.config.read().await;
let token = create_token(
&ctx.account_id, &ctx.role, ctx.permissions.clone(),
let new_access = create_token(
&claims.sub, &role, permissions.clone(),
state.jwt_secret.expose_secret(),
config.auth.jwt_expiration_hours,
)?;
Ok(Json(serde_json::json!({ "token": token })))
let new_refresh = create_refresh_token(
&claims.sub, &role, permissions.clone(),
state.jwt_secret.expose_secret(),
config.auth.refresh_token_hours,
)?;
drop(config);
// 8. 存储新 refresh token 到 DB
let new_claims = verify_token(&new_refresh, state.jwt_secret.expose_secret())?;
let new_jti = new_claims.jti.unwrap_or_default();
let new_id = uuid::Uuid::new_v4().to_string();
let refresh_expires = (chrono::Utc::now() + chrono::Duration::hours(168)).to_rfc3339();
sqlx::query(
"INSERT INTO refresh_tokens (id, account_id, jti, token_hash, expires_at, created_at)
VALUES ($1, $2, $3, $4, $5, $6)"
)
.bind(&new_id).bind(&claims.sub).bind(&new_jti)
.bind(sha256_hex(&new_refresh)).bind(&refresh_expires).bind(&now)
.execute(&state.db).await?;
// 9. 清理过期/已使用的 refresh tokens (异步, 不阻塞)
cleanup_expired_refresh_tokens(&state.db).await?;
Ok(Json(serde_json::json!({
"token": new_access,
"refresh_token": new_refresh,
})))
}
/// GET /api/v1/auth/me — 返回当前认证用户的公开信息
@@ -170,7 +306,7 @@ pub async fn me(
let row: Option<(String, String, String, String, String, String, bool, String)> =
sqlx::query_as(
"SELECT id, username, email, display_name, role, status, totp_enabled, created_at
FROM accounts WHERE id = ?1"
FROM accounts WHERE id = $1"
)
.bind(&ctx.account_id)
.fetch_optional(&state.db)
@@ -196,7 +332,7 @@ pub async fn change_password(
// 获取当前密码哈希
let (password_hash,): (String,) = sqlx::query_as(
"SELECT password_hash FROM accounts WHERE id = ?1"
"SELECT password_hash FROM accounts WHERE id = $1"
)
.bind(&ctx.account_id)
.fetch_one(&state.db)
@@ -210,7 +346,7 @@ pub async fn change_password(
// 更新密码
let new_hash = hash_password(&req.new_password)?;
let now = chrono::Utc::now().to_rfc3339();
sqlx::query("UPDATE accounts SET password_hash = ?1, updated_at = ?2 WHERE id = ?3")
sqlx::query("UPDATE accounts SET password_hash = $1, updated_at = $2 WHERE id = $3")
.bind(&new_hash)
.bind(&now)
.bind(&ctx.account_id)
@@ -223,9 +359,9 @@ pub async fn change_password(
Ok(Json(serde_json::json!({"ok": true, "message": "密码修改成功"})))
}
pub(crate) async fn get_role_permissions(db: &sqlx::SqlitePool, role: &str) -> SaasResult<Vec<String>> {
pub(crate) async fn get_role_permissions(db: &sqlx::PgPool, role: &str) -> SaasResult<Vec<String>> {
let row: Option<(String,)> = sqlx::query_as(
"SELECT permissions FROM roles WHERE id = ?1"
"SELECT permissions FROM roles WHERE id = $1"
)
.bind(role)
.fetch_optional(db)
@@ -252,7 +388,7 @@ pub fn check_permission(ctx: &AuthContext, permission: &str) -> SaasResult<()> {
/// 记录操作日志
pub async fn log_operation(
db: &sqlx::SqlitePool,
db: &sqlx::PgPool,
account_id: &str,
action: &str,
target_type: &str,
@@ -263,7 +399,7 @@ pub async fn log_operation(
let now = chrono::Utc::now().to_rfc3339();
sqlx::query(
"INSERT INTO operation_logs (account_id, action, target_type, target_id, details, ip_address, created_at)
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)"
VALUES ($1, $2, $3, $4, $5, $6, $7)"
)
.bind(account_id)
.bind(action)
@@ -276,3 +412,45 @@ pub async fn log_operation(
.await?;
Ok(())
}
/// 存储 refresh token 到 DB
async fn store_refresh_token(
db: &sqlx::PgPool,
account_id: &str,
refresh_token: &str,
secret: &str,
refresh_hours: i64,
) -> SaasResult<()> {
let claims = verify_token(refresh_token, secret)?;
let jti = claims.jti.unwrap_or_default();
let id = uuid::Uuid::new_v4().to_string();
let now = chrono::Utc::now().to_rfc3339();
let expires_at = (chrono::Utc::now() + chrono::Duration::hours(refresh_hours)).to_rfc3339();
sqlx::query(
"INSERT INTO refresh_tokens (id, account_id, jti, token_hash, expires_at, created_at)
VALUES ($1, $2, $3, $4, $5, $6)"
)
.bind(&id).bind(account_id).bind(&jti)
.bind(sha256_hex(refresh_token)).bind(&expires_at).bind(&now)
.execute(db).await?;
Ok(())
}
/// 清理过期和已使用的 refresh tokens
async fn cleanup_expired_refresh_tokens(db: &sqlx::PgPool) -> SaasResult<()> {
let now = chrono::Utc::now().to_rfc3339();
// 删除过期超过 30 天的已使用 token (减少 DB 膨胀)
sqlx::query(
"DELETE FROM refresh_tokens WHERE (used_at IS NOT NULL AND used_at < $1) OR (expires_at < $1)"
)
.bind(&now)
.execute(db).await?;
Ok(())
}
/// SHA-256 hex digest
fn sha256_hex(input: &str) -> String {
use sha2::{Sha256, Digest};
hex::encode(Sha256::digest(input.as_bytes()))
}

View File

@@ -9,27 +9,52 @@ use crate::error::SaasResult;
/// JWT Claims
#[derive(Debug, Serialize, Deserialize)]
pub struct Claims {
/// JWT ID — 唯一标识,用于 token 追踪和吊销
pub jti: Option<String>,
pub sub: String,
pub role: String,
pub permissions: Vec<String>,
/// token 类型: "access" 或 "refresh"
#[serde(default = "default_token_type")]
pub token_type: String,
pub iat: i64,
pub exp: i64,
}
fn default_token_type() -> String {
"access".to_string()
}
impl Claims {
pub fn new(account_id: &str, role: &str, permissions: Vec<String>, expiration_hours: i64) -> Self {
pub fn new_access(account_id: &str, role: &str, permissions: Vec<String>, expiration_hours: i64) -> Self {
let now = Utc::now();
Self {
jti: Some(uuid::Uuid::new_v4().to_string()),
sub: account_id.to_string(),
role: role.to_string(),
permissions,
token_type: "access".to_string(),
iat: now.timestamp(),
exp: (now + Duration::hours(expiration_hours)).timestamp(),
}
}
/// 创建 refresh token claims (有效期更长,用于一次性刷新)
pub fn new_refresh(account_id: &str, role: &str, permissions: Vec<String>, refresh_hours: i64) -> Self {
let now = Utc::now();
Self {
jti: Some(uuid::Uuid::new_v4().to_string()),
sub: account_id.to_string(),
role: role.to_string(),
permissions,
token_type: "refresh".to_string(),
iat: now.timestamp(),
exp: (now + Duration::hours(refresh_hours)).timestamp(),
}
}
}
/// 创建 JWT Token
/// 创建 Access JWT Token
pub fn create_token(
account_id: &str,
role: &str,
@@ -37,7 +62,24 @@ pub fn create_token(
secret: &str,
expiration_hours: i64,
) -> SaasResult<String> {
let claims = Claims::new(account_id, role, permissions, expiration_hours);
let claims = Claims::new_access(account_id, role, permissions, expiration_hours);
let token = encode(
&Header::default(),
&claims,
&EncodingKey::from_secret(secret.as_bytes()),
)?;
Ok(token)
}
/// 创建 Refresh JWT Token (独立 jti有效期更长)
pub fn create_refresh_token(
account_id: &str,
role: &str,
permissions: Vec<String>,
secret: &str,
refresh_hours: i64,
) -> SaasResult<String> {
let claims = Claims::new_refresh(account_id, role, permissions, refresh_hours);
let token = encode(
&Header::default(),
&claims,
@@ -56,6 +98,52 @@ pub fn verify_token(token: &str, secret: &str) -> SaasResult<Claims> {
Ok(token_data.claims)
}
/// 验证 JWT Token 但跳过过期检查(仅用于 refresh token 刷新)
/// 限制: 原始 token 的 iat 必须在 7 天内
pub fn verify_token_skip_expiry(token: &str, secret: &str) -> SaasResult<Claims> {
let mut validation = Validation::default();
validation.validate_exp = false;
let token_data = decode::<Claims>(
token,
&DecodingKey::from_secret(secret.as_bytes()),
&validation,
)?;
let claims = &token_data.claims;
// 限制刷新窗口: token 签发时间必须在 7 天内
let now = Utc::now().timestamp();
let max_refresh_window = 7 * 24 * 3600; // 7 天
if now - claims.iat > max_refresh_window {
return Err(jsonwebtoken::errors::Error::from(
jsonwebtoken::errors::ErrorKind::ExpiredSignature
).into());
}
Ok(token_data.claims)
}
/// Token 对: access token + refresh token
#[derive(Debug, serde::Serialize)]
pub struct TokenPair {
pub access_token: String,
pub refresh_token: String,
}
/// 创建 access + refresh token 对
pub fn create_token_pair(
account_id: &str,
role: &str,
permissions: Vec<String>,
secret: &str,
access_hours: i64,
refresh_hours: i64,
) -> SaasResult<TokenPair> {
Ok(TokenPair {
access_token: create_token(account_id, role, permissions.clone(), secret, access_hours)?,
refresh_token: create_refresh_token(account_id, role, permissions, secret, refresh_hours)?,
})
}
#[cfg(test)]
mod tests {
use super::*;
@@ -74,6 +162,8 @@ mod tests {
assert_eq!(claims.sub, "account-123");
assert_eq!(claims.role, "admin");
assert_eq!(claims.permissions, vec!["model:read"]);
assert!(claims.jti.is_some());
assert_eq!(claims.token_type, "access");
}
#[test]
@@ -88,4 +178,17 @@ mod tests {
let result = verify_token(&token, "wrong-secret");
assert!(result.is_err());
}
#[test]
fn test_refresh_token_has_different_jti() {
let access = create_token("acct-1", "user", vec![], TEST_SECRET, 1).unwrap();
let refresh = create_refresh_token("acct-1", "user", vec![], TEST_SECRET, 168).unwrap();
let access_claims = verify_token(&access, TEST_SECRET).unwrap();
let refresh_claims = verify_token(&refresh, TEST_SECRET).unwrap();
assert_ne!(access_claims.jti, refresh_claims.jti);
assert_eq!(access_claims.token_type, "access");
assert_eq!(refresh_claims.token_type, "refresh");
}
}

View File

@@ -29,7 +29,7 @@ async fn verify_api_token(state: &AppState, raw_token: &str, client_ip: Option<S
let row: Option<(String, Option<String>, String)> = sqlx::query_as(
"SELECT account_id, expires_at, permissions FROM api_tokens
WHERE token_hash = ?1 AND revoked_at IS NULL"
WHERE token_hash = $1 AND revoked_at IS NULL"
)
.bind(&token_hash)
.fetch_optional(&state.db)
@@ -50,7 +50,7 @@ async fn verify_api_token(state: &AppState, raw_token: &str, client_ip: Option<S
// 查询关联账号的角色
let (role,): (String,) = sqlx::query_as(
"SELECT role FROM accounts WHERE id = ?1 AND status = 'active'"
"SELECT role FROM accounts WHERE id = $1 AND status = 'active'"
)
.bind(&account_id)
.fetch_optional(&state.db)
@@ -71,7 +71,7 @@ async fn verify_api_token(state: &AppState, raw_token: &str, client_ip: Option<S
let db = state.db.clone();
tokio::spawn(async move {
let now = chrono::Utc::now().to_rfc3339();
let _ = sqlx::query("UPDATE api_tokens SET last_used_at = ?1 WHERE token_hash = ?2")
let _ = sqlx::query("UPDATE api_tokens SET last_used_at = $1 WHERE token_hash = $2")
.bind(&now).bind(&token_hash)
.execute(&db).await;
});
@@ -121,7 +121,8 @@ pub async fn auth_middleware(
verify_api_token(&state, token, client_ip.clone()).await
} else {
// JWT 路径
jwt::verify_token(token, state.jwt_secret.expose_secret())
let verify_result = jwt::verify_token(token, state.jwt_secret.expose_secret());
verify_result
.map(|claims| AuthContext {
account_id: claims.sub,
role: claims.role,
@@ -153,6 +154,7 @@ pub fn routes() -> axum::Router<AppState> {
axum::Router::new()
.route("/api/v1/auth/register", post(handlers::register))
.route("/api/v1/auth/login", post(handlers::login))
.route("/api/v1/auth/refresh", post(handlers::refresh))
}
/// 需要认证的路由
@@ -160,7 +162,6 @@ pub fn protected_routes() -> axum::Router<AppState> {
use axum::routing::{get, post, put};
axum::Router::new()
.route("/api/v1/auth/refresh", post(handlers::refresh))
.route("/api/v1/auth/me", get(handlers::me))
.route("/api/v1/auth/password", put(handlers::change_password))
.route("/api/v1/auth/totp/setup", post(totp::setup_totp))

View File

@@ -8,6 +8,7 @@ use crate::state::AppState;
use crate::error::{SaasError, SaasResult};
use crate::auth::types::AuthContext;
use crate::auth::handlers::log_operation;
use crate::crypto;
use serde::{Deserialize, Serialize};
/// TOTP 设置响应
@@ -46,6 +47,21 @@ fn base32_decode(data: &str) -> Option<Vec<u8>> {
data_encoding::BASE32.decode(data.as_bytes()).ok()
}
/// 加密 TOTP secret (AES-256-GCM随机 nonce)
/// 存储格式: enc:<base64(nonce||ciphertext)>
/// 委托给 crypto::encrypt_value 统一加密
fn encrypt_totp_secret(plaintext: &str, key: &[u8; 32]) -> Result<String, SaasError> {
crate::crypto::encrypt_value(plaintext, key)
.map_err(|e| SaasError::Internal(e.to_string()))
}
/// 解密 TOTP secret (仅支持新格式: 随机 nonce)
/// 旧的固定 nonce 格式应通过启动时迁移转换。
fn decrypt_totp_secret(encrypted: &str, key: &[u8; 32]) -> Result<String, SaasError> {
crate::crypto::decrypt_value(encrypted, key)
.map_err(|e| SaasError::Internal(e.to_string()))
}
/// 生成 TOTP 密钥并返回 otpauth URI
pub fn generate_totp_secret(issuer: &str, account_name: &str) -> TotpSetupResponse {
let secret = generate_random_secret();
@@ -94,7 +110,7 @@ pub async fn setup_totp(
) -> SaasResult<Json<TotpSetupResponse>> {
// 如果已启用 TOTP先清除旧密钥
let (username,): (String,) = sqlx::query_as(
"SELECT username FROM accounts WHERE id = ?1"
"SELECT username FROM accounts WHERE id = $1"
)
.bind(&ctx.account_id)
.fetch_one(&state.db)
@@ -103,9 +119,13 @@ pub async fn setup_totp(
let config = state.config.read().await;
let setup = generate_totp_secret(&config.auth.totp_issuer, &username);
// 存储密钥 (但不启用,需要 /verify 确认)
sqlx::query("UPDATE accounts SET totp_secret = ?1 WHERE id = ?2")
.bind(&setup.secret)
// 加密后存储密钥 (但不启用,需要 /verify 确认)
let enc_key = config.totp_encryption_key()
.map_err(|e| SaasError::Internal(e.to_string()))?;
let encrypted_secret = encrypt_totp_secret(&setup.secret, &enc_key)?;
sqlx::query("UPDATE accounts SET totp_secret = $1 WHERE id = $2")
.bind(&encrypted_secret)
.bind(&ctx.account_id)
.execute(&state.db)
.await?;
@@ -130,23 +150,42 @@ pub async fn verify_totp(
// 获取存储的密钥
let (totp_secret,): (Option<String>,) = sqlx::query_as(
"SELECT totp_secret FROM accounts WHERE id = ?1"
"SELECT totp_secret FROM accounts WHERE id = $1"
)
.bind(&ctx.account_id)
.fetch_one(&state.db)
.await?;
let secret = totp_secret.ok_or_else(|| {
let encrypted_secret = totp_secret.ok_or_else(|| {
SaasError::InvalidInput("请先调用 /totp/setup 获取密钥".into())
})?;
// 解密 secret (兼容旧的明文格式)
let config = state.config.read().await;
let enc_key = config.totp_encryption_key()
.map_err(|e| SaasError::Internal(e.to_string()))?;
let secret = if encrypted_secret.starts_with(crypto::ENCRYPTED_PREFIX) {
decrypt_totp_secret(&encrypted_secret, &enc_key)?
} else {
// 旧格式: 明文存储,需要迁移
encrypted_secret.clone()
};
if !verify_totp_code(&secret, code) {
return Err(SaasError::Totp("TOTP 码验证失败".into()));
}
// 验证成功 → 启用 TOTP
// 验证成功 → 启用 TOTP,同时确保密钥已加密
let final_secret = if encrypted_secret.starts_with(crypto::ENCRYPTED_PREFIX) {
encrypted_secret
} else {
// 迁移: 加密旧明文密钥
encrypt_totp_secret(&secret, &enc_key)?
};
let now = chrono::Utc::now().to_rfc3339();
sqlx::query("UPDATE accounts SET totp_enabled = 1, updated_at = ?1 WHERE id = ?2")
sqlx::query("UPDATE accounts SET totp_enabled = true, totp_secret = $1, updated_at = $2 WHERE id = $3")
.bind(&final_secret)
.bind(&now)
.bind(&ctx.account_id)
.execute(&state.db)
@@ -167,7 +206,7 @@ pub async fn disable_totp(
) -> SaasResult<Json<serde_json::Value>> {
// 验证密码
let (password_hash,): (String,) = sqlx::query_as(
"SELECT password_hash FROM accounts WHERE id = ?1"
"SELECT password_hash FROM accounts WHERE id = $1"
)
.bind(&ctx.account_id)
.fetch_one(&state.db)
@@ -179,7 +218,7 @@ pub async fn disable_totp(
// 清除 TOTP
let now = chrono::Utc::now().to_rfc3339();
sqlx::query("UPDATE accounts SET totp_enabled = 0, totp_secret = NULL, updated_at = ?1 WHERE id = ?2")
sqlx::query("UPDATE accounts SET totp_enabled = false, totp_secret = NULL, updated_at = $1 WHERE id = $2")
.bind(&now)
.bind(&ctx.account_id)
.execute(&state.db)
@@ -190,3 +229,14 @@ pub async fn disable_totp(
Ok(Json(serde_json::json!({"ok": true, "totp_enabled": false, "message": "TOTP 已禁用"})))
}
/// 解密 TOTP secret (供 login handler 使用)
/// 返回解密后的明文 secret
pub fn decrypt_totp_for_login(encrypted_secret: &str, enc_key: &[u8; 32]) -> SaasResult<String> {
if encrypted_secret.starts_with(crypto::ENCRYPTED_PREFIX) {
decrypt_totp_secret(encrypted_secret, enc_key)
} else {
// 兼容旧的明文格式
Ok(encrypted_secret.to_string())
}
}

View File

@@ -14,6 +14,7 @@ pub struct LoginRequest {
#[derive(Debug, Serialize)]
pub struct LoginResponse {
pub token: String,
pub refresh_token: String,
pub account: AccountPublic,
}
@@ -54,3 +55,9 @@ pub struct AuthContext {
pub permissions: Vec<String>,
pub client_ip: Option<String>,
}
/// Token 刷新请求
#[derive(Debug, Deserialize)]
pub struct RefreshRequest {
pub refresh_token: String,
}