diff --git a/crates/zclaw-saas/src/auth/handlers.rs b/crates/zclaw-saas/src/auth/handlers.rs index f539e2a..f0c5e4b 100644 --- a/crates/zclaw-saas/src/auth/handlers.rs +++ b/crates/zclaw-saas/src/auth/handlers.rs @@ -208,13 +208,17 @@ pub async fn login( return Err(SaasError::Forbidden(format!("账号已{},请联系管理员", r.status))); } - // M2: 检查账号是否被临时锁定 - if let Some(ref locked_until_str) = r.locked_until { - if let Ok(locked_time) = chrono::DateTime::parse_from_rfc3339(locked_until_str) { - if chrono::Utc::now() < locked_time.with_timezone(&chrono::Utc) { - return Err(SaasError::AuthError("账号已被临时锁定,请稍后再试".into())); - } - } + // M2: 检查账号是否被临时锁定 (直接在 SQL 层比较,避免时区解析问题) + let is_locked: bool = sqlx::query_scalar( + "SELECT locked_until IS NOT NULL AND locked_until > NOW() FROM accounts WHERE id = $1" + ) + .bind(&r.id) + .fetch_one(&state.db) + .await + .unwrap_or(false); + + if is_locked { + return Err(SaasError::AuthError("账号已被临时锁定,请稍后再试".into())); } if !verify_password_async(req.password.clone(), r.password_hash.clone()).await? { @@ -580,33 +584,49 @@ fn sha256_hex(input: &str) -> String { pub async fn logout( State(state): State, jar: CookieJar, + Json(req): Json, ) -> (CookieJar, axum::http::StatusCode) { - // 尝试从 cookie 中获取 refresh token 并撤销 - if let Some(refresh_cookie) = jar.get(REFRESH_TOKEN_COOKIE) { - let token = refresh_cookie.value(); - if let Ok(claims) = verify_token_skip_expiry(token, state.jwt_secret.expose_secret()) { - if claims.token_type == "refresh" { - if let Some(jti) = claims.jti { - let now = chrono::Utc::now(); - // 标记 refresh token 为已使用(等效于撤销/黑名单) - let result = sqlx::query( - "UPDATE refresh_tokens SET used_at = $1 WHERE jti = $2 AND used_at IS NULL" - ) - .bind(&now).bind(&jti) - .execute(&state.db) - .await; + let jwt_secret = state.jwt_secret.expose_secret(); - match result { - Ok(r) => { - if r.rows_affected() > 0 { - tracing::info!(account_id = %claims.sub, jti = %jti, "Refresh token revoked on logout"); - } - } - Err(e) => { - tracing::warn!(jti = %jti, error = %e, "Failed to revoke refresh token on logout"); - } + // 收集所有可用的 refresh token 来源 + let mut tokens_to_check: Vec = Vec::new(); + + // 来源 1: 请求 body 中的 refresh_token + if let Some(ref token) = req.refresh_token { + tokens_to_check.push(token.clone()); + } + + // 来源 2: cookie 中的 refresh_token + if let Some(refresh_cookie) = jar.get(REFRESH_TOKEN_COOKIE) { + let cookie_val = refresh_cookie.value().to_string(); + if !tokens_to_check.contains(&cookie_val) { + tokens_to_check.push(cookie_val); + } + } + + // 从任意有效的 refresh token 提取 account_id,然后撤销该账户所有 token + for token in &tokens_to_check { + if let Ok(claims) = verify_token_skip_expiry(token, jwt_secret) { + if claims.token_type == "refresh" { + let now = chrono::Utc::now(); + // 撤销该账户的所有 refresh token (不仅是当前的) + let result = sqlx::query( + "UPDATE refresh_tokens SET used_at = $1 WHERE account_id = $2 AND used_at IS NULL" + ) + .bind(&now) + .bind(&claims.sub) + .execute(&state.db) + .await; + + match result { + Ok(r) => { + tracing::info!(account_id = %claims.sub, n = r.rows_affected(), "All refresh tokens revoked on logout"); + } + Err(e) => { + tracing::warn!(account_id = %claims.sub, error = %e, "Failed to revoke refresh tokens"); } } + break; // 一次成功即可 } } } diff --git a/crates/zclaw-saas/src/auth/types.rs b/crates/zclaw-saas/src/auth/types.rs index 8928d3f..5fdc478 100644 --- a/crates/zclaw-saas/src/auth/types.rs +++ b/crates/zclaw-saas/src/auth/types.rs @@ -62,3 +62,9 @@ pub struct AuthContext { pub struct RefreshRequest { pub refresh_token: String, } + +/// 登出请求 (refresh_token 可选,不传则仅清除 cookie) +#[derive(Debug, Deserialize)] +pub struct LogoutRequest { + pub refresh_token: Option, +} diff --git a/crates/zclaw-saas/src/model_config/service.rs b/crates/zclaw-saas/src/model_config/service.rs index 94b64fd..b2c055f 100644 --- a/crates/zclaw-saas/src/model_config/service.rs +++ b/crates/zclaw-saas/src/model_config/service.rs @@ -89,11 +89,13 @@ pub async fn create_provider(db: &PgPool, req: &CreateProviderRequest, enc_key: String::new() }; + let display_name = req.display_name.as_deref().unwrap_or(&req.name); + sqlx::query( "INSERT INTO providers (id, name, display_name, api_key, base_url, api_protocol, enabled, rate_limit_rpm, rate_limit_tpm, created_at, updated_at) VALUES ($1, $2, $3, $4, $5, $6, true, $7, $8, $9, $9)" ) - .bind(&id).bind(&req.name).bind(&req.display_name).bind(&encrypted_api_key) + .bind(&id).bind(&req.name).bind(display_name).bind(&encrypted_api_key) .bind(&req.base_url).bind(&req.api_protocol).bind(&req.rate_limit_rpm).bind(&req.rate_limit_tpm).bind(&now) .execute(db).await.map_err(|e| SaasError::from_sqlx_unique(e, &format!("Provider '{}'", req.name)))?; diff --git a/crates/zclaw-saas/src/model_config/types.rs b/crates/zclaw-saas/src/model_config/types.rs index 49b6f2e..1790b12 100644 --- a/crates/zclaw-saas/src/model_config/types.rs +++ b/crates/zclaw-saas/src/model_config/types.rs @@ -21,7 +21,7 @@ pub struct ProviderInfo { #[derive(Debug, Deserialize)] pub struct CreateProviderRequest { pub name: String, - pub display_name: String, + pub display_name: Option, pub base_url: String, #[serde(default = "default_protocol")] pub api_protocol: String,