Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled
- Fix TIMESTAMPTZ decode errors: add ::TEXT cast to all SELECT queries
where Row structs use String for TIMESTAMPTZ columns (~22 locations)
- Fix Axum 0.7 route params: {id} → :id in billing/knowledge/scheduled_task routes
- Fix JSONB bind: scheduled_task INSERT uses ::jsonb cast for input_payload
- Add billing_test.rs (14 tests): plans, subscription, usage, payments, invoices
- Add scheduled_task_test.rs (12 tests): CRUD, validation, isolation
- Add knowledge_test.rs (20 tests): categories, items, versions, search, analytics, permissions
- Fix auth test regression: 6 tests were failing due to TIMESTAMPTZ type mismatch
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
616 lines
22 KiB
Rust
616 lines
22 KiB
Rust
//! 认证 HTTP 处理器
|
||
|
||
use axum::{extract::{State, ConnectInfo}, Json};
|
||
use axum_extra::extract::CookieJar;
|
||
use axum_extra::extract::cookie::{Cookie, SameSite};
|
||
use std::net::SocketAddr;
|
||
use secrecy::ExposeSecret;
|
||
use crate::state::AppState;
|
||
use crate::error::{SaasError, SaasResult};
|
||
use crate::models::{AccountAuthRow, AccountLoginRow};
|
||
use super::{
|
||
jwt::{create_token, create_refresh_token, verify_token, verify_token_skip_expiry},
|
||
password::{hash_password_async, verify_password_async},
|
||
types::{AuthContext, LoginRequest, LoginResponse, RegisterRequest, ChangePasswordRequest, AccountPublic, RefreshRequest},
|
||
};
|
||
|
||
/// Name of the cookie that carries the access token (JWT).
const ACCESS_TOKEN_COOKIE: &str = "zclaw_access_token";
/// Name of the cookie that carries the refresh token.
const REFRESH_TOKEN_COOKIE: &str = "zclaw_refresh_token";
|
||
|
||
/// 构建 auth cookies 并附加到 CookieJar
|
||
/// secure 标记在开发环境 (ZCLAW_SAAS_DEV=true) 设为 false,生产设为 true
|
||
fn set_auth_cookies(jar: CookieJar, token: &str, refresh_token: &str) -> CookieJar {
|
||
let access_max_age = std::time::Duration::from_secs(2 * 3600); // 2h
|
||
let refresh_max_age = std::time::Duration::from_secs(7 * 86400); // 7d
|
||
let secure = !is_dev_mode();
|
||
|
||
// cookie crate 需要 time::Duration,从 std 转换
|
||
let access = Cookie::build((ACCESS_TOKEN_COOKIE, token.to_string()))
|
||
.http_only(true)
|
||
.secure(secure)
|
||
.same_site(SameSite::Strict)
|
||
.path("/api")
|
||
.max_age(access_max_age.try_into().unwrap_or_else(|_| std::time::Duration::from_secs(3600).try_into().unwrap()))
|
||
.build();
|
||
|
||
let refresh = Cookie::build((REFRESH_TOKEN_COOKIE, refresh_token.to_string()))
|
||
.http_only(true)
|
||
.secure(secure)
|
||
.same_site(SameSite::Strict)
|
||
.path("/api/v1/auth")
|
||
.max_age(refresh_max_age.try_into().unwrap_or_else(|_| std::time::Duration::from_secs(86400).try_into().unwrap()))
|
||
.build();
|
||
|
||
jar.add(access).add(refresh)
|
||
}
|
||
|
||
/// Report whether the server runs in development mode.
///
/// Development mode is on only when `ZCLAW_SAAS_DEV` is exactly `"true"` or `"1"`.
/// Any other value, or an unset or non-UTF-8 variable, means production. Security
/// policies such as the cookie `Secure` flag depend on this.
fn is_dev_mode() -> bool {
    matches!(std::env::var("ZCLAW_SAAS_DEV").as_deref(), Ok("true" | "1"))
}
|
||
|
||
/// 清除 auth cookies
|
||
fn clear_auth_cookies(jar: CookieJar) -> CookieJar {
|
||
jar.remove(Cookie::build(ACCESS_TOKEN_COOKIE).path("/api"))
|
||
.remove(Cookie::build(REFRESH_TOKEN_COOKIE).path("/api/v1/auth"))
|
||
}
|
||
|
||
/// POST /api/v1/auth/register
|
||
/// 注册成功后自动签发 JWT,返回与 login 一致的 LoginResponse
|
||
pub async fn register(
|
||
State(state): State<AppState>,
|
||
ConnectInfo(addr): ConnectInfo<SocketAddr>,
|
||
jar: CookieJar,
|
||
Json(req): Json<RegisterRequest>,
|
||
) -> SaasResult<(CookieJar, Json<LoginResponse>)> {
|
||
if req.username.len() < 3 {
|
||
return Err(SaasError::InvalidInput("用户名至少 3 个字符".into()));
|
||
}
|
||
if req.username.len() > 32 {
|
||
return Err(SaasError::InvalidInput("用户名最多 32 个字符".into()));
|
||
}
|
||
static USERNAME_RE: std::sync::OnceLock<regex::Regex> = std::sync::OnceLock::new();
|
||
let username_re = USERNAME_RE.get_or_init(|| regex::Regex::new(r"^[a-zA-Z0-9_-]+$").unwrap());
|
||
if !username_re.is_match(&req.username) {
|
||
return Err(SaasError::InvalidInput("用户名只能包含字母、数字、下划线和连字符".into()));
|
||
}
|
||
if !req.email.contains('@') || !req.email.contains('.') {
|
||
return Err(SaasError::InvalidInput("邮箱格式不正确".into()));
|
||
}
|
||
// M3: 严格邮箱格式校验
|
||
static EMAIL_RE: std::sync::OnceLock<regex::Regex> = std::sync::OnceLock::new();
|
||
let email_re = EMAIL_RE.get_or_init(|| regex::Regex::new(
|
||
r"^[a-zA-Z0-9._%+\-]+@[a-zA-Z0-9.\-]+\.[a-zA-Z]{2,}$"
|
||
).unwrap());
|
||
if !email_re.is_match(&req.email) {
|
||
return Err(SaasError::InvalidInput("邮箱格式不正确".into()));
|
||
}
|
||
if req.password.len() < 8 {
|
||
return Err(SaasError::InvalidInput("密码至少 8 个字符".into()));
|
||
}
|
||
if req.password.len() > 128 {
|
||
return Err(SaasError::InvalidInput("密码最多 128 个字符".into()));
|
||
}
|
||
if let Some(ref name) = req.display_name {
|
||
if name.len() > 64 {
|
||
return Err(SaasError::InvalidInput("显示名称最多 64 个字符".into()));
|
||
}
|
||
}
|
||
|
||
let existing: Vec<(String,)> = sqlx::query_as(
|
||
"SELECT id FROM accounts WHERE username = $1 OR email = $2"
|
||
)
|
||
.bind(&req.username)
|
||
.bind(&req.email)
|
||
.fetch_all(&state.db)
|
||
.await?;
|
||
|
||
if !existing.is_empty() {
|
||
return Err(SaasError::AlreadyExists("用户名或邮箱已存在".into()));
|
||
}
|
||
|
||
let password_hash = hash_password_async(req.password.clone()).await?;
|
||
let account_id = uuid::Uuid::new_v4().to_string();
|
||
let role = "user".to_string(); // 注册固定为普通用户,角色由管理员分配
|
||
let display_name = req.display_name.unwrap_or_default();
|
||
let now = chrono::Utc::now();
|
||
|
||
sqlx::query(
|
||
"INSERT INTO accounts (id, username, email, password_hash, display_name, role, status, created_at, updated_at, llm_routing)
|
||
VALUES ($1, $2, $3, $4, $5, $6, 'active', $7, $7, 'local')"
|
||
)
|
||
.bind(&account_id)
|
||
.bind(&req.username)
|
||
.bind(&req.email)
|
||
.bind(&password_hash)
|
||
.bind(&display_name)
|
||
.bind(&role)
|
||
.bind(&now)
|
||
.execute(&state.db)
|
||
.await?;
|
||
|
||
let client_ip = addr.ip().to_string();
|
||
log_operation(&state.db, &account_id, "account.create", "account", &account_id, None, Some(&client_ip)).await?;
|
||
|
||
// 注册成功后自动签发 JWT + Refresh Token
|
||
let permissions = get_role_permissions(&state.db, &state.role_permissions_cache, &role).await?;
|
||
// 查询新创建账户的 password_version (默认为 1)
|
||
let (pwv,): (i32,) = sqlx::query_as(
|
||
"SELECT password_version FROM accounts WHERE id = $1"
|
||
)
|
||
.bind(&account_id)
|
||
.fetch_one(&state.db)
|
||
.await?;
|
||
let config = state.config.read().await;
|
||
let token = create_token(
|
||
&account_id, &role, permissions.clone(),
|
||
state.jwt_secret.expose_secret(),
|
||
config.auth.jwt_expiration_hours,
|
||
pwv as u32,
|
||
)?;
|
||
let refresh_token = create_refresh_token(
|
||
&account_id, &role, permissions,
|
||
state.jwt_secret.expose_secret(),
|
||
config.auth.refresh_token_hours,
|
||
pwv as u32,
|
||
)?;
|
||
drop(config);
|
||
|
||
store_refresh_token(
|
||
&state.db, &account_id, &refresh_token,
|
||
state.jwt_secret.expose_secret(), 168,
|
||
).await?;
|
||
|
||
let resp = LoginResponse {
|
||
token,
|
||
refresh_token: refresh_token.clone(),
|
||
account: AccountPublic {
|
||
id: account_id,
|
||
username: req.username,
|
||
email: req.email,
|
||
display_name,
|
||
role,
|
||
status: "active".into(),
|
||
totp_enabled: false,
|
||
created_at: now.to_rfc3339(),
|
||
llm_routing: "local".into(),
|
||
},
|
||
};
|
||
let jar = set_auth_cookies(jar, &resp.token, &refresh_token);
|
||
Ok((jar, Json(resp)))
|
||
}
|
||
|
||
/// POST /api/v1/auth/login
|
||
pub async fn login(
|
||
State(state): State<AppState>,
|
||
ConnectInfo(addr): ConnectInfo<SocketAddr>,
|
||
jar: CookieJar,
|
||
Json(req): Json<LoginRequest>,
|
||
) -> SaasResult<(CookieJar, Json<LoginResponse>)> {
|
||
// 一次查询获取用户信息 + password_hash + totp_secret + 安全字段(合并原来的 3 次查询)
|
||
let row: Option<AccountLoginRow> =
|
||
sqlx::query_as(
|
||
"SELECT id, username, email, display_name, role, status, totp_enabled,
|
||
password_hash, totp_secret, created_at::TEXT, llm_routing,
|
||
password_version, failed_login_count, locked_until::TEXT
|
||
FROM accounts WHERE username = $1 OR email = $1"
|
||
)
|
||
.bind(&req.username)
|
||
.fetch_optional(&state.db)
|
||
.await?;
|
||
|
||
let r = row.ok_or_else(|| SaasError::AuthError("用户名或密码错误".into()))?;
|
||
|
||
if r.status != "active" {
|
||
return Err(SaasError::Forbidden(format!("账号已{},请联系管理员", r.status)));
|
||
}
|
||
|
||
// M2: 检查账号是否被临时锁定
|
||
if let Some(ref locked_until_str) = r.locked_until {
|
||
if let Ok(locked_time) = chrono::DateTime::parse_from_rfc3339(locked_until_str) {
|
||
if chrono::Utc::now() < locked_time.with_timezone(&chrono::Utc) {
|
||
return Err(SaasError::AuthError("账号已被临时锁定,请稍后再试".into()));
|
||
}
|
||
}
|
||
}
|
||
|
||
if !verify_password_async(req.password.clone(), r.password_hash.clone()).await? {
|
||
// M2: 密码错误,递增失败计数
|
||
let new_count = r.failed_login_count + 1;
|
||
if new_count >= 5 {
|
||
// 锁定 15 分钟
|
||
let locked_until = chrono::Utc::now() + chrono::Duration::minutes(15);
|
||
sqlx::query(
|
||
"UPDATE accounts SET failed_login_count = $1, locked_until = $2 WHERE id = $3"
|
||
)
|
||
.bind(new_count)
|
||
.bind(&locked_until)
|
||
.bind(&r.id)
|
||
.execute(&state.db)
|
||
.await?;
|
||
} else {
|
||
sqlx::query(
|
||
"UPDATE accounts SET failed_login_count = $1 WHERE id = $2"
|
||
)
|
||
.bind(new_count)
|
||
.bind(&r.id)
|
||
.execute(&state.db)
|
||
.await?;
|
||
}
|
||
return Err(SaasError::AuthError("用户名或密码错误".into()));
|
||
}
|
||
|
||
// TOTP 验证: 如果用户已启用 2FA,必须提供有效 TOTP 码
|
||
if r.totp_enabled {
|
||
let code = req.totp_code.as_deref()
|
||
.ok_or_else(|| SaasError::Totp("此账号已启用双因素认证,请提供 TOTP 码".into()))?;
|
||
|
||
let secret = r.totp_secret.clone().ok_or_else(|| {
|
||
SaasError::Internal("TOTP 已启用但密钥丢失,请联系管理员".into())
|
||
})?;
|
||
|
||
// 解密 TOTP secret (兼容旧的明文格式)
|
||
let config = state.config.read().await;
|
||
let enc_key = config.totp_encryption_key()
|
||
.map_err(|e| SaasError::Internal(e.to_string()))?;
|
||
let secret = super::totp::decrypt_totp_for_login(&secret, &enc_key)?;
|
||
|
||
if !super::totp::verify_totp_code(&secret, code) {
|
||
return Err(SaasError::Totp("TOTP 码错误或已过期".into()));
|
||
}
|
||
}
|
||
|
||
let permissions = get_role_permissions(&state.db, &state.role_permissions_cache, &r.role).await?;
|
||
let config = state.config.read().await;
|
||
let pwv = r.password_version as u32;
|
||
let token = create_token(
|
||
&r.id, &r.role, permissions.clone(),
|
||
state.jwt_secret.expose_secret(),
|
||
config.auth.jwt_expiration_hours,
|
||
pwv,
|
||
)?;
|
||
let refresh_token = create_refresh_token(
|
||
&r.id, &r.role, permissions,
|
||
state.jwt_secret.expose_secret(),
|
||
config.auth.refresh_token_hours,
|
||
pwv,
|
||
)?;
|
||
drop(config);
|
||
|
||
let now = chrono::Utc::now();
|
||
// 登录成功: 重置失败计数和锁定状态
|
||
sqlx::query("UPDATE accounts SET last_login_at = $1, failed_login_count = 0, locked_until = NULL WHERE id = $2")
|
||
.bind(&now).bind(&r.id)
|
||
.execute(&state.db).await?;
|
||
let client_ip = addr.ip().to_string();
|
||
log_operation(&state.db, &r.id, "account.login", "account", &r.id, None, Some(&client_ip)).await?;
|
||
|
||
store_refresh_token(
|
||
&state.db, &r.id, &refresh_token,
|
||
state.jwt_secret.expose_secret(), 168,
|
||
).await?;
|
||
|
||
let resp = LoginResponse {
|
||
token,
|
||
refresh_token: refresh_token.clone(),
|
||
account: AccountPublic {
|
||
id: r.id, username: r.username, email: r.email, display_name: r.display_name,
|
||
role: r.role, status: r.status, totp_enabled: r.totp_enabled, created_at: r.created_at,
|
||
llm_routing: r.llm_routing,
|
||
},
|
||
};
|
||
let jar = set_auth_cookies(jar, &resp.token, &refresh_token);
|
||
Ok((jar, Json(resp)))
|
||
}
|
||
|
||
/// POST /api/v1/auth/refresh
|
||
/// 使用 refresh_token 换取新的 access + refresh token 对
|
||
/// refresh_token 一次性使用,使用后立即失效
|
||
pub async fn refresh(
|
||
State(state): State<AppState>,
|
||
jar: CookieJar,
|
||
Json(req): Json<RefreshRequest>,
|
||
) -> SaasResult<(CookieJar, Json<serde_json::Value>)> {
|
||
// 1. 验证 refresh token 签名 (跳过过期检查,但有 7 天窗口限制)
|
||
let claims = verify_token_skip_expiry(&req.refresh_token, state.jwt_secret.expose_secret())?;
|
||
|
||
// 2. 确认是 refresh 类型 token
|
||
if claims.token_type != "refresh" {
|
||
return Err(SaasError::AuthError("无效的 refresh token".into()));
|
||
}
|
||
|
||
let jti = claims.jti.as_deref()
|
||
.ok_or_else(|| SaasError::AuthError("refresh token 缺少 jti".into()))?;
|
||
|
||
// 3. 从 DB 查找 refresh token,确保未被使用
|
||
let row: Option<(String,)> = sqlx::query_as(
|
||
"SELECT account_id FROM refresh_tokens WHERE jti = $1 AND used_at IS NULL AND expires_at > $2"
|
||
)
|
||
.bind(jti)
|
||
.bind(&chrono::Utc::now())
|
||
.fetch_optional(&state.db)
|
||
.await?;
|
||
|
||
let token_account_id = row
|
||
.ok_or_else(|| SaasError::AuthError("refresh token 已使用、已过期或不存在".into()))?
|
||
.0;
|
||
|
||
// 4. 验证 token 中的 account_id 与 DB 中的一致
|
||
if token_account_id != claims.sub {
|
||
return Err(SaasError::AuthError("refresh token 账号不匹配".into()));
|
||
}
|
||
|
||
// 5. 标记旧 refresh token 为已使用 (一次性)
|
||
let now = chrono::Utc::now();
|
||
sqlx::query("UPDATE refresh_tokens SET used_at = $1 WHERE jti = $2")
|
||
.bind(&now).bind(jti)
|
||
.execute(&state.db).await?;
|
||
|
||
// 6. 获取最新角色权限 + password_version
|
||
let (role,): (String,) = sqlx::query_as(
|
||
"SELECT role FROM accounts WHERE id = $1 AND status = 'active'"
|
||
)
|
||
.bind(&claims.sub)
|
||
.fetch_optional(&state.db)
|
||
.await?
|
||
.ok_or_else(|| SaasError::AuthError("账号不存在或已禁用".into()))?;
|
||
|
||
let (pwv,): (i32,) = sqlx::query_as(
|
||
"SELECT password_version FROM accounts WHERE id = $1"
|
||
)
|
||
.bind(&claims.sub)
|
||
.fetch_one(&state.db)
|
||
.await?;
|
||
|
||
let permissions = get_role_permissions(&state.db, &state.role_permissions_cache, &role).await?;
|
||
|
||
// 7. 创建新的 access token + refresh token
|
||
let config = state.config.read().await;
|
||
let new_access = create_token(
|
||
&claims.sub, &role, permissions.clone(),
|
||
state.jwt_secret.expose_secret(),
|
||
config.auth.jwt_expiration_hours,
|
||
pwv as u32,
|
||
)?;
|
||
let new_refresh = create_refresh_token(
|
||
&claims.sub, &role, permissions.clone(),
|
||
state.jwt_secret.expose_secret(),
|
||
config.auth.refresh_token_hours,
|
||
pwv as u32,
|
||
)?;
|
||
drop(config);
|
||
|
||
// 8. 存储新 refresh token 到 DB
|
||
let new_claims = verify_token(&new_refresh, state.jwt_secret.expose_secret())?;
|
||
let new_jti = new_claims.jti.unwrap_or_default();
|
||
let new_id = uuid::Uuid::new_v4().to_string();
|
||
let refresh_expires = chrono::Utc::now() + chrono::Duration::hours(168);
|
||
sqlx::query(
|
||
"INSERT INTO refresh_tokens (id, account_id, jti, token_hash, expires_at, created_at)
|
||
VALUES ($1, $2, $3, $4, $5, $6)"
|
||
)
|
||
.bind(&new_id).bind(&claims.sub).bind(&new_jti)
|
||
.bind(sha256_hex(&new_refresh)).bind(&refresh_expires).bind(&now)
|
||
.execute(&state.db).await?;
|
||
|
||
// 9. 清理过期/已使用的 refresh tokens 已迁移到 Scheduler 定期执行
|
||
// 不再在每次 refresh 时阻塞请求
|
||
|
||
let jar = set_auth_cookies(jar, &new_access, &new_refresh);
|
||
Ok((jar, Json(serde_json::json!({
|
||
"token": new_access,
|
||
"refresh_token": new_refresh,
|
||
}))))
|
||
}
|
||
|
||
/// GET /api/v1/auth/me — 返回当前认证用户的公开信息
|
||
pub async fn me(
|
||
State(state): State<AppState>,
|
||
axum::extract::Extension(ctx): axum::extract::Extension<AuthContext>,
|
||
) -> SaasResult<Json<AccountPublic>> {
|
||
let row: Option<AccountAuthRow> =
|
||
sqlx::query_as(
|
||
"SELECT id, username, email, display_name, role, status, totp_enabled, created_at::TEXT, llm_routing
|
||
FROM accounts WHERE id = $1"
|
||
)
|
||
.bind(&ctx.account_id)
|
||
.fetch_optional(&state.db)
|
||
.await?;
|
||
|
||
let r = row.ok_or_else(|| SaasError::NotFound("账号不存在".into()))?;
|
||
|
||
Ok(Json(AccountPublic {
|
||
id: r.id, username: r.username, email: r.email, display_name: r.display_name,
|
||
role: r.role, status: r.status, totp_enabled: r.totp_enabled, created_at: r.created_at,
|
||
llm_routing: r.llm_routing,
|
||
}))
|
||
}
|
||
|
||
/// PUT /api/v1/auth/password — 修改密码
|
||
pub async fn change_password(
|
||
State(state): State<AppState>,
|
||
axum::extract::Extension(ctx): axum::extract::Extension<AuthContext>,
|
||
Json(req): Json<ChangePasswordRequest>,
|
||
) -> SaasResult<Json<serde_json::Value>> {
|
||
if req.new_password.len() < 8 {
|
||
return Err(SaasError::InvalidInput("新密码至少 8 个字符".into()));
|
||
}
|
||
|
||
// 获取当前密码哈希
|
||
let (password_hash,): (String,) = sqlx::query_as(
|
||
"SELECT password_hash FROM accounts WHERE id = $1"
|
||
)
|
||
.bind(&ctx.account_id)
|
||
.fetch_one(&state.db)
|
||
.await?;
|
||
|
||
// 验证旧密码
|
||
if !verify_password_async(req.old_password.clone(), password_hash.clone()).await? {
|
||
return Err(SaasError::AuthError("旧密码错误".into()));
|
||
}
|
||
|
||
// 更新密码 + 递增 password_version 使旧 token 失效
|
||
let new_hash = hash_password_async(req.new_password.clone()).await?;
|
||
let now = chrono::Utc::now();
|
||
sqlx::query("UPDATE accounts SET password_hash = $1, updated_at = $2, password_version = password_version + 1 WHERE id = $3")
|
||
.bind(&new_hash)
|
||
.bind(&now)
|
||
.bind(&ctx.account_id)
|
||
.execute(&state.db)
|
||
.await?;
|
||
|
||
log_operation(&state.db, &ctx.account_id, "account.change_password", "account", &ctx.account_id,
|
||
None, ctx.client_ip.as_deref()).await?;
|
||
|
||
Ok(Json(serde_json::json!({"ok": true, "message": "密码修改成功"})))
|
||
}
|
||
|
||
pub(crate) async fn get_role_permissions(
|
||
db: &sqlx::PgPool,
|
||
cache: &dashmap::DashMap<String, Vec<String>>,
|
||
role: &str,
|
||
) -> SaasResult<Vec<String>> {
|
||
// Check cache first
|
||
if let Some(cached) = cache.get(role) {
|
||
return Ok(cached.clone());
|
||
}
|
||
|
||
let row: Option<(String,)> = sqlx::query_as(
|
||
"SELECT permissions FROM roles WHERE id = $1"
|
||
)
|
||
.bind(role)
|
||
.fetch_optional(db)
|
||
.await?;
|
||
|
||
let permissions_str = row
|
||
.ok_or_else(|| SaasError::Internal(format!("角色 {} 不存在", role)))?
|
||
.0;
|
||
|
||
let permissions: Vec<String> = serde_json::from_str(&permissions_str)?;
|
||
cache.insert(role.to_string(), permissions.clone());
|
||
Ok(permissions)
|
||
}
|
||
|
||
/// 检查权限 (admin:full 自动通过所有检查)
|
||
pub fn check_permission(ctx: &AuthContext, permission: &str) -> SaasResult<()> {
|
||
if ctx.permissions.contains(&"admin:full".to_string()) {
|
||
return Ok(());
|
||
}
|
||
if !ctx.permissions.contains(&permission.to_string()) {
|
||
return Err(SaasError::Forbidden(format!("需要 {} 权限", permission)));
|
||
}
|
||
Ok(())
|
||
}
|
||
|
||
/// 记录操作日志
|
||
pub async fn log_operation(
|
||
db: &sqlx::PgPool,
|
||
account_id: &str,
|
||
action: &str,
|
||
target_type: &str,
|
||
target_id: &str,
|
||
details: Option<serde_json::Value>,
|
||
ip_address: Option<&str>,
|
||
) -> SaasResult<()> {
|
||
let now = chrono::Utc::now();
|
||
sqlx::query(
|
||
"INSERT INTO operation_logs (account_id, action, target_type, target_id, details, ip_address, created_at)
|
||
VALUES ($1, $2, $3, $4, $5, $6, $7)"
|
||
)
|
||
.bind(account_id)
|
||
.bind(action)
|
||
.bind(target_type)
|
||
.bind(target_id)
|
||
.bind(details.map(|d| d.to_string()))
|
||
.bind(ip_address)
|
||
.bind(&now)
|
||
.execute(db)
|
||
.await?;
|
||
Ok(())
|
||
}
|
||
|
||
/// 存储 refresh token 到 DB
|
||
async fn store_refresh_token(
|
||
db: &sqlx::PgPool,
|
||
account_id: &str,
|
||
refresh_token: &str,
|
||
secret: &str,
|
||
refresh_hours: i64,
|
||
) -> SaasResult<()> {
|
||
let claims = verify_token(refresh_token, secret)?;
|
||
let jti = claims.jti.unwrap_or_default();
|
||
let id = uuid::Uuid::new_v4().to_string();
|
||
let now = chrono::Utc::now();
|
||
let expires_at = chrono::Utc::now() + chrono::Duration::hours(refresh_hours);
|
||
|
||
sqlx::query(
|
||
"INSERT INTO refresh_tokens (id, account_id, jti, token_hash, expires_at, created_at)
|
||
VALUES ($1, $2, $3, $4, $5, $6)"
|
||
)
|
||
.bind(&id).bind(account_id).bind(&jti)
|
||
.bind(sha256_hex(refresh_token)).bind(&expires_at).bind(&now)
|
||
.execute(db).await?;
|
||
Ok(())
|
||
}
|
||
|
||
/// 清理过期和已使用的 refresh tokens
|
||
/// 注意: 现已迁移到 Worker/Scheduler 定期执行,此函数保留作为备用
|
||
#[allow(dead_code)]
|
||
async fn cleanup_expired_refresh_tokens(db: &sqlx::PgPool) -> SaasResult<()> {
|
||
let now = chrono::Utc::now();
|
||
// 删除过期超过 30 天的已使用 token (减少 DB 膨胀)
|
||
sqlx::query(
|
||
"DELETE FROM refresh_tokens WHERE (used_at IS NOT NULL AND used_at < $1) OR (expires_at < $1)"
|
||
)
|
||
.bind(&now)
|
||
.execute(db).await?;
|
||
Ok(())
|
||
}
|
||
|
||
/// SHA-256 hex digest
|
||
fn sha256_hex(input: &str) -> String {
|
||
use sha2::{Sha256, Digest};
|
||
hex::encode(Sha256::digest(input.as_bytes()))
|
||
}
|
||
|
||
/// POST /api/v1/auth/logout — 撤销 refresh token 并清除 auth cookies
|
||
pub async fn logout(
|
||
State(state): State<AppState>,
|
||
jar: CookieJar,
|
||
) -> (CookieJar, axum::http::StatusCode) {
|
||
// 尝试从 cookie 中获取 refresh token 并撤销
|
||
if let Some(refresh_cookie) = jar.get(REFRESH_TOKEN_COOKIE) {
|
||
let token = refresh_cookie.value();
|
||
if let Ok(claims) = verify_token_skip_expiry(token, state.jwt_secret.expose_secret()) {
|
||
if claims.token_type == "refresh" {
|
||
if let Some(jti) = claims.jti {
|
||
let now = chrono::Utc::now();
|
||
// 标记 refresh token 为已使用(等效于撤销/黑名单)
|
||
let result = sqlx::query(
|
||
"UPDATE refresh_tokens SET used_at = $1 WHERE jti = $2 AND used_at IS NULL"
|
||
)
|
||
.bind(&now).bind(&jti)
|
||
.execute(&state.db)
|
||
.await;
|
||
|
||
match result {
|
||
Ok(r) => {
|
||
if r.rows_affected() > 0 {
|
||
tracing::info!(account_id = %claims.sub, jti = %jti, "Refresh token revoked on logout");
|
||
}
|
||
}
|
||
Err(e) => {
|
||
tracing::warn!(jti = %jti, error = %e, "Failed to revoke refresh token on logout");
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
(clear_auth_cookies(jar), axum::http::StatusCode::NO_CONTENT)
|
||
}
|