Files
zclaw_openfang/crates/zclaw-saas/src/auth/handlers.rs
iven 7de486bfca
Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled
test(saas): Phase 1 integration tests — billing + scheduled_task + knowledge (68 tests)
- Fix TIMESTAMPTZ decode errors: add ::TEXT cast to all SELECT queries
  where Row structs use String for TIMESTAMPTZ columns (~22 locations)
- Fix Axum 0.7 route params: {id} → :id in billing/knowledge/scheduled_task routes
- Fix JSONB bind: scheduled_task INSERT uses ::jsonb cast for input_payload
- Add billing_test.rs (14 tests): plans, subscription, usage, payments, invoices
- Add scheduled_task_test.rs (12 tests): CRUD, validation, isolation
- Add knowledge_test.rs (20 tests): categories, items, versions, search, analytics, permissions
- Fix auth test regression: 6 tests were failing due to TIMESTAMPTZ type mismatch

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-07 14:25:34 +08:00

616 lines
22 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

//! 认证 HTTP 处理器
use axum::{extract::{State, ConnectInfo}, Json};
use axum_extra::extract::CookieJar;
use axum_extra::extract::cookie::{Cookie, SameSite};
use std::net::SocketAddr;
use secrecy::ExposeSecret;
use crate::state::AppState;
use crate::error::{SaasError, SaasResult};
use crate::models::{AccountAuthRow, AccountLoginRow};
use super::{
jwt::{create_token, create_refresh_token, verify_token, verify_token_skip_expiry},
password::{hash_password_async, verify_password_async},
types::{AuthContext, LoginRequest, LoginResponse, RegisterRequest, ChangePasswordRequest, AccountPublic, RefreshRequest},
};
/// Cookie name for the short-lived access JWT; `set_auth_cookies` scopes it to `/api`.
const ACCESS_TOKEN_COOKIE: &str = "zclaw_access_token";
/// Cookie name for the refresh JWT; `set_auth_cookies` scopes it to `/api/v1/auth`.
const REFRESH_TOKEN_COOKIE: &str = "zclaw_refresh_token";
/// Builds the access and refresh auth cookies and adds them to `jar`.
///
/// Both cookies are `HttpOnly` and `SameSite=Strict`. `Secure` is cleared only when
/// [`is_dev_mode`] reports development mode (`ZCLAW_SAAS_DEV` = `true`/`1`) and is set otherwise.
/// The access cookie is scoped to `/api` with a 2h max-age; the refresh cookie is scoped to
/// `/api/v1/auth` with a 7d max-age.
///
/// NOTE(review): these max-ages are fixed here and are not derived from
/// `config.auth.jwt_expiration_hours` / `config.auth.refresh_token_hours` used when the tokens
/// are minted — confirm they are meant to stay in sync.
fn set_auth_cookies(jar: CookieJar, token: &str, refresh_token: &str) -> CookieJar {
    let access_max_age = std::time::Duration::from_secs(2 * 3600); // 2h
    let refresh_max_age = std::time::Duration::from_secs(7 * 86400); // 7d
    let secure = !is_dev_mode();

    // The cookie crate takes `time::Duration`. Converting from `std::time::Duration` can only
    // fail for spans beyond i64::MAX seconds, so a failure here would be a programming error
    // rather than something to paper over with a different (shorter) lifetime.
    let access = Cookie::build((ACCESS_TOKEN_COOKIE, token.to_string()))
        .http_only(true)
        .secure(secure)
        .same_site(SameSite::Strict)
        .path("/api")
        .max_age(access_max_age.try_into().expect("2h always fits in time::Duration"))
        .build();
    let refresh = Cookie::build((REFRESH_TOKEN_COOKIE, refresh_token.to_string()))
        .http_only(true)
        .secure(secure)
        .same_site(SameSite::Strict)
        .path("/api/v1/auth")
        .max_age(refresh_max_age.try_into().expect("7d always fits in time::Duration"))
        .build();
    jar.add(access).add(refresh)
}
/// Reports whether the server runs in development mode.
///
/// True only when the `ZCLAW_SAAS_DEV` environment variable is exactly `"true"` or `"1"`;
/// an unset, non-UTF-8, or any other value means production. Security policies such as the
/// cookie `Secure` flag (and, per the original note, CORS) are chosen based on this.
fn is_dev_mode() -> bool {
    matches!(std::env::var("ZCLAW_SAAS_DEV").as_deref(), Ok("true" | "1"))
}
/// Removes both auth cookies from `jar`.
///
/// The removal cookies use the same paths `set_auth_cookies` sets (`/api` for the access
/// cookie, `/api/v1/auth` for the refresh cookie) so they target the cookies that were issued.
fn clear_auth_cookies(jar: CookieJar) -> CookieJar {
    let expired_access = Cookie::build(ACCESS_TOKEN_COOKIE).path("/api");
    let expired_refresh = Cookie::build(REFRESH_TOKEN_COOKIE).path("/api/v1/auth");
    let jar = jar.remove(expired_access);
    jar.remove(expired_refresh)
}
/// POST /api/v1/auth/register
/// 注册成功后自动签发 JWT返回与 login 一致的 LoginResponse
pub async fn register(
State(state): State<AppState>,
ConnectInfo(addr): ConnectInfo<SocketAddr>,
jar: CookieJar,
Json(req): Json<RegisterRequest>,
) -> SaasResult<(CookieJar, Json<LoginResponse>)> {
if req.username.len() < 3 {
return Err(SaasError::InvalidInput("用户名至少 3 个字符".into()));
}
if req.username.len() > 32 {
return Err(SaasError::InvalidInput("用户名最多 32 个字符".into()));
}
static USERNAME_RE: std::sync::OnceLock<regex::Regex> = std::sync::OnceLock::new();
let username_re = USERNAME_RE.get_or_init(|| regex::Regex::new(r"^[a-zA-Z0-9_-]+$").unwrap());
if !username_re.is_match(&req.username) {
return Err(SaasError::InvalidInput("用户名只能包含字母、数字、下划线和连字符".into()));
}
if !req.email.contains('@') || !req.email.contains('.') {
return Err(SaasError::InvalidInput("邮箱格式不正确".into()));
}
// M3: 严格邮箱格式校验
static EMAIL_RE: std::sync::OnceLock<regex::Regex> = std::sync::OnceLock::new();
let email_re = EMAIL_RE.get_or_init(|| regex::Regex::new(
r"^[a-zA-Z0-9._%+\-]+@[a-zA-Z0-9.\-]+\.[a-zA-Z]{2,}$"
).unwrap());
if !email_re.is_match(&req.email) {
return Err(SaasError::InvalidInput("邮箱格式不正确".into()));
}
if req.password.len() < 8 {
return Err(SaasError::InvalidInput("密码至少 8 个字符".into()));
}
if req.password.len() > 128 {
return Err(SaasError::InvalidInput("密码最多 128 个字符".into()));
}
if let Some(ref name) = req.display_name {
if name.len() > 64 {
return Err(SaasError::InvalidInput("显示名称最多 64 个字符".into()));
}
}
let existing: Vec<(String,)> = sqlx::query_as(
"SELECT id FROM accounts WHERE username = $1 OR email = $2"
)
.bind(&req.username)
.bind(&req.email)
.fetch_all(&state.db)
.await?;
if !existing.is_empty() {
return Err(SaasError::AlreadyExists("用户名或邮箱已存在".into()));
}
let password_hash = hash_password_async(req.password.clone()).await?;
let account_id = uuid::Uuid::new_v4().to_string();
let role = "user".to_string(); // 注册固定为普通用户,角色由管理员分配
let display_name = req.display_name.unwrap_or_default();
let now = chrono::Utc::now();
sqlx::query(
"INSERT INTO accounts (id, username, email, password_hash, display_name, role, status, created_at, updated_at, llm_routing)
VALUES ($1, $2, $3, $4, $5, $6, 'active', $7, $7, 'local')"
)
.bind(&account_id)
.bind(&req.username)
.bind(&req.email)
.bind(&password_hash)
.bind(&display_name)
.bind(&role)
.bind(&now)
.execute(&state.db)
.await?;
let client_ip = addr.ip().to_string();
log_operation(&state.db, &account_id, "account.create", "account", &account_id, None, Some(&client_ip)).await?;
// 注册成功后自动签发 JWT + Refresh Token
let permissions = get_role_permissions(&state.db, &state.role_permissions_cache, &role).await?;
// 查询新创建账户的 password_version (默认为 1)
let (pwv,): (i32,) = sqlx::query_as(
"SELECT password_version FROM accounts WHERE id = $1"
)
.bind(&account_id)
.fetch_one(&state.db)
.await?;
let config = state.config.read().await;
let token = create_token(
&account_id, &role, permissions.clone(),
state.jwt_secret.expose_secret(),
config.auth.jwt_expiration_hours,
pwv as u32,
)?;
let refresh_token = create_refresh_token(
&account_id, &role, permissions,
state.jwt_secret.expose_secret(),
config.auth.refresh_token_hours,
pwv as u32,
)?;
drop(config);
store_refresh_token(
&state.db, &account_id, &refresh_token,
state.jwt_secret.expose_secret(), 168,
).await?;
let resp = LoginResponse {
token,
refresh_token: refresh_token.clone(),
account: AccountPublic {
id: account_id,
username: req.username,
email: req.email,
display_name,
role,
status: "active".into(),
totp_enabled: false,
created_at: now.to_rfc3339(),
llm_routing: "local".into(),
},
};
let jar = set_auth_cookies(jar, &resp.token, &refresh_token);
Ok((jar, Json(resp)))
}
/// POST /api/v1/auth/login
pub async fn login(
State(state): State<AppState>,
ConnectInfo(addr): ConnectInfo<SocketAddr>,
jar: CookieJar,
Json(req): Json<LoginRequest>,
) -> SaasResult<(CookieJar, Json<LoginResponse>)> {
// 一次查询获取用户信息 + password_hash + totp_secret + 安全字段(合并原来的 3 次查询)
let row: Option<AccountLoginRow> =
sqlx::query_as(
"SELECT id, username, email, display_name, role, status, totp_enabled,
password_hash, totp_secret, created_at::TEXT, llm_routing,
password_version, failed_login_count, locked_until::TEXT
FROM accounts WHERE username = $1 OR email = $1"
)
.bind(&req.username)
.fetch_optional(&state.db)
.await?;
let r = row.ok_or_else(|| SaasError::AuthError("用户名或密码错误".into()))?;
if r.status != "active" {
return Err(SaasError::Forbidden(format!("账号已{},请联系管理员", r.status)));
}
// M2: 检查账号是否被临时锁定
if let Some(ref locked_until_str) = r.locked_until {
if let Ok(locked_time) = chrono::DateTime::parse_from_rfc3339(locked_until_str) {
if chrono::Utc::now() < locked_time.with_timezone(&chrono::Utc) {
return Err(SaasError::AuthError("账号已被临时锁定,请稍后再试".into()));
}
}
}
if !verify_password_async(req.password.clone(), r.password_hash.clone()).await? {
// M2: 密码错误,递增失败计数
let new_count = r.failed_login_count + 1;
if new_count >= 5 {
// 锁定 15 分钟
let locked_until = chrono::Utc::now() + chrono::Duration::minutes(15);
sqlx::query(
"UPDATE accounts SET failed_login_count = $1, locked_until = $2 WHERE id = $3"
)
.bind(new_count)
.bind(&locked_until)
.bind(&r.id)
.execute(&state.db)
.await?;
} else {
sqlx::query(
"UPDATE accounts SET failed_login_count = $1 WHERE id = $2"
)
.bind(new_count)
.bind(&r.id)
.execute(&state.db)
.await?;
}
return Err(SaasError::AuthError("用户名或密码错误".into()));
}
// TOTP 验证: 如果用户已启用 2FA必须提供有效 TOTP 码
if r.totp_enabled {
let code = req.totp_code.as_deref()
.ok_or_else(|| SaasError::Totp("此账号已启用双因素认证,请提供 TOTP 码".into()))?;
let secret = r.totp_secret.clone().ok_or_else(|| {
SaasError::Internal("TOTP 已启用但密钥丢失,请联系管理员".into())
})?;
// 解密 TOTP secret (兼容旧的明文格式)
let config = state.config.read().await;
let enc_key = config.totp_encryption_key()
.map_err(|e| SaasError::Internal(e.to_string()))?;
let secret = super::totp::decrypt_totp_for_login(&secret, &enc_key)?;
if !super::totp::verify_totp_code(&secret, code) {
return Err(SaasError::Totp("TOTP 码错误或已过期".into()));
}
}
let permissions = get_role_permissions(&state.db, &state.role_permissions_cache, &r.role).await?;
let config = state.config.read().await;
let pwv = r.password_version as u32;
let token = create_token(
&r.id, &r.role, permissions.clone(),
state.jwt_secret.expose_secret(),
config.auth.jwt_expiration_hours,
pwv,
)?;
let refresh_token = create_refresh_token(
&r.id, &r.role, permissions,
state.jwt_secret.expose_secret(),
config.auth.refresh_token_hours,
pwv,
)?;
drop(config);
let now = chrono::Utc::now();
// 登录成功: 重置失败计数和锁定状态
sqlx::query("UPDATE accounts SET last_login_at = $1, failed_login_count = 0, locked_until = NULL WHERE id = $2")
.bind(&now).bind(&r.id)
.execute(&state.db).await?;
let client_ip = addr.ip().to_string();
log_operation(&state.db, &r.id, "account.login", "account", &r.id, None, Some(&client_ip)).await?;
store_refresh_token(
&state.db, &r.id, &refresh_token,
state.jwt_secret.expose_secret(), 168,
).await?;
let resp = LoginResponse {
token,
refresh_token: refresh_token.clone(),
account: AccountPublic {
id: r.id, username: r.username, email: r.email, display_name: r.display_name,
role: r.role, status: r.status, totp_enabled: r.totp_enabled, created_at: r.created_at,
llm_routing: r.llm_routing,
},
};
let jar = set_auth_cookies(jar, &resp.token, &refresh_token);
Ok((jar, Json(resp)))
}
/// POST /api/v1/auth/refresh
/// 使用 refresh_token 换取新的 access + refresh token 对
/// refresh_token 一次性使用,使用后立即失效
pub async fn refresh(
State(state): State<AppState>,
jar: CookieJar,
Json(req): Json<RefreshRequest>,
) -> SaasResult<(CookieJar, Json<serde_json::Value>)> {
// 1. 验证 refresh token 签名 (跳过过期检查,但有 7 天窗口限制)
let claims = verify_token_skip_expiry(&req.refresh_token, state.jwt_secret.expose_secret())?;
// 2. 确认是 refresh 类型 token
if claims.token_type != "refresh" {
return Err(SaasError::AuthError("无效的 refresh token".into()));
}
let jti = claims.jti.as_deref()
.ok_or_else(|| SaasError::AuthError("refresh token 缺少 jti".into()))?;
// 3. 从 DB 查找 refresh token确保未被使用
let row: Option<(String,)> = sqlx::query_as(
"SELECT account_id FROM refresh_tokens WHERE jti = $1 AND used_at IS NULL AND expires_at > $2"
)
.bind(jti)
.bind(&chrono::Utc::now())
.fetch_optional(&state.db)
.await?;
let token_account_id = row
.ok_or_else(|| SaasError::AuthError("refresh token 已使用、已过期或不存在".into()))?
.0;
// 4. 验证 token 中的 account_id 与 DB 中的一致
if token_account_id != claims.sub {
return Err(SaasError::AuthError("refresh token 账号不匹配".into()));
}
// 5. 标记旧 refresh token 为已使用 (一次性)
let now = chrono::Utc::now();
sqlx::query("UPDATE refresh_tokens SET used_at = $1 WHERE jti = $2")
.bind(&now).bind(jti)
.execute(&state.db).await?;
// 6. 获取最新角色权限 + password_version
let (role,): (String,) = sqlx::query_as(
"SELECT role FROM accounts WHERE id = $1 AND status = 'active'"
)
.bind(&claims.sub)
.fetch_optional(&state.db)
.await?
.ok_or_else(|| SaasError::AuthError("账号不存在或已禁用".into()))?;
let (pwv,): (i32,) = sqlx::query_as(
"SELECT password_version FROM accounts WHERE id = $1"
)
.bind(&claims.sub)
.fetch_one(&state.db)
.await?;
let permissions = get_role_permissions(&state.db, &state.role_permissions_cache, &role).await?;
// 7. 创建新的 access token + refresh token
let config = state.config.read().await;
let new_access = create_token(
&claims.sub, &role, permissions.clone(),
state.jwt_secret.expose_secret(),
config.auth.jwt_expiration_hours,
pwv as u32,
)?;
let new_refresh = create_refresh_token(
&claims.sub, &role, permissions.clone(),
state.jwt_secret.expose_secret(),
config.auth.refresh_token_hours,
pwv as u32,
)?;
drop(config);
// 8. 存储新 refresh token 到 DB
let new_claims = verify_token(&new_refresh, state.jwt_secret.expose_secret())?;
let new_jti = new_claims.jti.unwrap_or_default();
let new_id = uuid::Uuid::new_v4().to_string();
let refresh_expires = chrono::Utc::now() + chrono::Duration::hours(168);
sqlx::query(
"INSERT INTO refresh_tokens (id, account_id, jti, token_hash, expires_at, created_at)
VALUES ($1, $2, $3, $4, $5, $6)"
)
.bind(&new_id).bind(&claims.sub).bind(&new_jti)
.bind(sha256_hex(&new_refresh)).bind(&refresh_expires).bind(&now)
.execute(&state.db).await?;
// 9. 清理过期/已使用的 refresh tokens 已迁移到 Scheduler 定期执行
// 不再在每次 refresh 时阻塞请求
let jar = set_auth_cookies(jar, &new_access, &new_refresh);
Ok((jar, Json(serde_json::json!({
"token": new_access,
"refresh_token": new_refresh,
}))))
}
/// GET /api/v1/auth/me
///
/// Returns the public profile of the account named by the request's `AuthContext`, or
/// `NotFound("账号不存在")` when that account row no longer exists.
pub async fn me(
    State(state): State<AppState>,
    axum::extract::Extension(ctx): axum::extract::Extension<AuthContext>,
) -> SaasResult<Json<AccountPublic>> {
    let account: AccountAuthRow = sqlx::query_as(
        "SELECT id, username, email, display_name, role, status, totp_enabled, created_at::TEXT, llm_routing
         FROM accounts WHERE id = $1"
    )
    .bind(&ctx.account_id)
    .fetch_optional(&state.db)
    .await?
    .ok_or_else(|| SaasError::NotFound("账号不存在".into()))?;

    let profile = AccountPublic {
        llm_routing: account.llm_routing,
        created_at: account.created_at,
        totp_enabled: account.totp_enabled,
        status: account.status,
        role: account.role,
        display_name: account.display_name,
        email: account.email,
        username: account.username,
        id: account.id,
    };
    Ok(Json(profile))
}
/// PUT /api/v1/auth/password
///
/// Changes the caller's password after verifying the old one, then records the change in
/// `operation_logs`.
///
/// The update increments `password_version`, so tokens carrying the old version become
/// invalid. In the same statement it also revokes every unused refresh token of the account.
/// Without that, an old refresh token could still be redeemed through `refresh`, which always
/// mints tokens with the account's *current* `password_version`, and the old session would
/// survive the password change.
///
/// # Errors
/// - `InvalidInput` when the new password is shorter than 8 or longer than 128 bytes (the same
///   bounds `register` applies).
/// - `NotFound` when the caller's account no longer exists.
/// - `AuthError` when the old password is wrong.
/// - Database and hashing errors are propagated.
pub async fn change_password(
    State(state): State<AppState>,
    axum::extract::Extension(ctx): axum::extract::Extension<AuthContext>,
    Json(req): Json<ChangePasswordRequest>,
) -> SaasResult<Json<serde_json::Value>> {
    if req.new_password.len() < 8 {
        return Err(SaasError::InvalidInput("新密码至少 8 个字符".into()));
    }
    // Same upper bound as registration; also caps the input handed to the password hasher.
    if req.new_password.len() > 128 {
        return Err(SaasError::InvalidInput("新密码最多 128 个字符".into()));
    }

    // Current password hash.
    let (password_hash,): (String,) = sqlx::query_as(
        "SELECT password_hash FROM accounts WHERE id = $1"
    )
    .bind(&ctx.account_id)
    .fetch_optional(&state.db)
    .await?
    .ok_or_else(|| SaasError::NotFound("账号不存在".into()))?;

    if !verify_password_async(req.old_password.clone(), password_hash).await? {
        return Err(SaasError::AuthError("旧密码错误".into()));
    }

    let new_hash = hash_password_async(req.new_password.clone()).await?;
    let now = chrono::Utc::now();
    // Data-modifying CTE: the password update and the refresh-token revocation run as one
    // statement and therefore commit (or fail) together.
    sqlx::query(
        "WITH updated AS (
             UPDATE accounts
             SET password_hash = $1, updated_at = $2, password_version = password_version + 1
             WHERE id = $3
             RETURNING id
         )
         UPDATE refresh_tokens SET used_at = $2
         WHERE account_id IN (SELECT id FROM updated) AND used_at IS NULL"
    )
    .bind(&new_hash)
    .bind(&now)
    .bind(&ctx.account_id)
    .execute(&state.db)
    .await?;

    log_operation(&state.db, &ctx.account_id, "account.change_password", "account", &ctx.account_id,
        None, ctx.client_ip.as_deref()).await?;
    Ok(Json(serde_json::json!({"ok": true, "message": "密码修改成功"})))
}
/// Returns the permission list for `role`, consulting `cache` first.
///
/// On a cache miss it reads `roles.permissions` (a JSON array stored as text), parses it into a
/// `Vec<String>`, stores it in `cache` and returns it. Nothing here ever refreshes or evicts
/// a cached entry.
///
/// # Errors
/// `Internal("角色 … 不存在")` when no such role exists; database and JSON errors are
/// propagated.
pub(crate) async fn get_role_permissions(
    db: &sqlx::PgPool,
    cache: &dashmap::DashMap<String, Vec<String>>,
    role: &str,
) -> SaasResult<Vec<String>> {
    // Copy out of the map immediately so no shard guard outlives this statement.
    let hit = cache.get(role).map(|entry| entry.value().clone());
    if let Some(permissions) = hit {
        return Ok(permissions);
    }

    let (raw,): (String,) = sqlx::query_as("SELECT permissions FROM roles WHERE id = $1")
        .bind(role)
        .fetch_optional(db)
        .await?
        .ok_or_else(|| SaasError::Internal(format!("角色 {} 不存在", role)))?;

    let parsed: Vec<String> = serde_json::from_str(&raw)?;
    cache.insert(role.to_owned(), parsed.clone());
    Ok(parsed)
}
/// Checks that `ctx` holds `permission`; `admin:full` satisfies every check.
///
/// # Errors
/// `Forbidden("需要 {permission} 权限")` when the context holds neither `admin:full` nor
/// `permission`.
pub fn check_permission(ctx: &AuthContext, permission: &str) -> SaasResult<()> {
    // Compare as &str: the previous `contains(&"...".to_string())` allocated two Strings per
    // call on what is a per-request hot path.
    let granted = ctx
        .permissions
        .iter()
        .any(|p| p == "admin:full" || p == permission);
    if granted {
        Ok(())
    } else {
        Err(SaasError::Forbidden(format!("需要 {} 权限", permission)))
    }
}
/// 记录操作日志
pub async fn log_operation(
db: &sqlx::PgPool,
account_id: &str,
action: &str,
target_type: &str,
target_id: &str,
details: Option<serde_json::Value>,
ip_address: Option<&str>,
) -> SaasResult<()> {
let now = chrono::Utc::now();
sqlx::query(
"INSERT INTO operation_logs (account_id, action, target_type, target_id, details, ip_address, created_at)
VALUES ($1, $2, $3, $4, $5, $6, $7)"
)
.bind(account_id)
.bind(action)
.bind(target_type)
.bind(target_id)
.bind(details.map(|d| d.to_string()))
.bind(ip_address)
.bind(&now)
.execute(db)
.await?;
Ok(())
}
/// 存储 refresh token 到 DB
async fn store_refresh_token(
db: &sqlx::PgPool,
account_id: &str,
refresh_token: &str,
secret: &str,
refresh_hours: i64,
) -> SaasResult<()> {
let claims = verify_token(refresh_token, secret)?;
let jti = claims.jti.unwrap_or_default();
let id = uuid::Uuid::new_v4().to_string();
let now = chrono::Utc::now();
let expires_at = chrono::Utc::now() + chrono::Duration::hours(refresh_hours);
sqlx::query(
"INSERT INTO refresh_tokens (id, account_id, jti, token_hash, expires_at, created_at)
VALUES ($1, $2, $3, $4, $5, $6)"
)
.bind(&id).bind(account_id).bind(&jti)
.bind(sha256_hex(refresh_token)).bind(&expires_at).bind(&now)
.execute(db).await?;
Ok(())
}
/// 清理过期和已使用的 refresh tokens
/// 注意: 现已迁移到 Worker/Scheduler 定期执行,此函数保留作为备用
#[allow(dead_code)]
async fn cleanup_expired_refresh_tokens(db: &sqlx::PgPool) -> SaasResult<()> {
let now = chrono::Utc::now();
// 删除过期超过 30 天的已使用 token (减少 DB 膨胀)
sqlx::query(
"DELETE FROM refresh_tokens WHERE (used_at IS NOT NULL AND used_at < $1) OR (expires_at < $1)"
)
.bind(&now)
.execute(db).await?;
Ok(())
}
/// Returns the SHA-256 digest of `input`'s UTF-8 bytes, hex-encoded with `hex::encode`.
fn sha256_hex(input: &str) -> String {
    use sha2::{Digest, Sha256};
    let mut hasher = Sha256::new();
    hasher.update(input.as_bytes());
    let digest = hasher.finalize();
    hex::encode(digest)
}
/// POST /api/v1/auth/logout
///
/// Best-effort revocation of the refresh token held in the refresh cookie, followed by
/// clearing both auth cookies. Always answers `204 No Content`; a failed revocation is only
/// logged.
pub async fn logout(
    State(state): State<AppState>,
    jar: CookieJar,
) -> (CookieJar, axum::http::StatusCode) {
    if let Some(refresh_cookie) = jar.get(REFRESH_TOKEN_COOKIE) {
        revoke_refresh_cookie_token(&state, refresh_cookie.value()).await;
    }
    (clear_auth_cookies(jar), axum::http::StatusCode::NO_CONTENT)
}

/// Marks the refresh token's `jti` as used, which blocks any later redemption.
///
/// Does nothing if the token fails signature verification, is not a refresh token, or has no
/// `jti`. The outcome is reported only through `tracing`: an info line when a row was revoked,
/// a warning when the UPDATE failed.
async fn revoke_refresh_cookie_token(state: &AppState, token: &str) {
    let Ok(claims) = verify_token_skip_expiry(token, state.jwt_secret.expose_secret()) else {
        return;
    };
    if claims.token_type != "refresh" {
        return;
    }
    let Some(jti) = claims.jti.as_deref() else {
        return;
    };

    let revoked_at = chrono::Utc::now();
    let outcome = sqlx::query(
        "UPDATE refresh_tokens SET used_at = $1 WHERE jti = $2 AND used_at IS NULL"
    )
    .bind(&revoked_at)
    .bind(jti)
    .execute(&state.db)
    .await;

    match outcome {
        Ok(done) if done.rows_affected() > 0 => {
            tracing::info!(account_id = %claims.sub, jti = %jti, "Refresh token revoked on logout");
        }
        Ok(_) => {}
        Err(e) => {
            tracing::warn!(jti = %jti, error = %e, "Failed to revoke refresh token on logout");
        }
    }
}