Files
zclaw_openfang/crates/zclaw-saas/src/auth/mod.rs
iven 7de486bfca
Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled
test(saas): Phase 1 integration tests — billing + scheduled_task + knowledge (68 tests)
- Fix TIMESTAMPTZ decode errors: add ::TEXT cast to all SELECT queries
  where Row structs use String for TIMESTAMPTZ columns (~22 locations)
- Fix Axum 0.7 route params: {id} → :id in billing/knowledge/scheduled_task routes
- Fix JSONB bind: scheduled_task INSERT uses ::jsonb cast for input_payload
- Add billing_test.rs (14 tests): plans, subscription, usage, payments, invoices
- Add scheduled_task_test.rs (12 tests): CRUD, validation, isolation
- Add knowledge_test.rs (20 tests): categories, items, versions, search, analytics, permissions
- Fix auth test regression: 6 tests were failing due to TIMESTAMPTZ type mismatch

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-07 14:25:34 +08:00

228 lines
7.8 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

//! 认证模块
pub mod jwt;
pub mod password;
pub mod types;
pub mod handlers;
pub mod totp;
use axum::{
extract::{Request, State},
http::header,
middleware::Next,
response::{IntoResponse, Response},
extract::ConnectInfo,
};
use secrecy::ExposeSecret;
use crate::error::SaasError;
use crate::state::AppState;
use types::AuthContext;
use std::net::SocketAddr;
/// Authenticates a request using an API token.
///
/// Flow: SHA-256 hash the raw token → look up a non-revoked, non-expired row in
/// `api_tokens` → load the owning account's role (the account must be `active`) →
/// merge the role's permissions with the token's own permissions → dispatch a
/// background `last_used_at` update.
///
/// # Errors
/// Returns [`SaasError::Unauthorized`] when the token is unknown, revoked or expired,
/// or when its account is missing or not `active`. Database errors are propagated
/// through `?`.
async fn verify_api_token(state: &AppState, raw_token: &str, client_ip: Option<String>) -> Result<AuthContext, SaasError> {
    use sha2::{Digest, Sha256};

    let token_hash = hex::encode(Sha256::digest(raw_token.as_bytes()));

    // Expiry is enforced by the database instead of re-parsing `expires_at::TEXT` in Rust.
    // PostgreSQL renders TIMESTAMPTZ as e.g. `2026-04-07 14:25:34.123+08`. That is not
    // RFC 3339 (space separator, hour-only offset), so an RFC 3339 parse can fail, and the
    // old `if let Ok(..)` then skipped the check and accepted expired tokens. Comparing
    // in SQL cannot fail open that way.
    let row: Option<(String, String)> = sqlx::query_as(
        "SELECT account_id, permissions FROM api_tokens
         WHERE token_hash = $1 AND revoked_at IS NULL
           AND (expires_at IS NULL OR expires_at > NOW())"
    )
    .bind(&token_hash)
    .fetch_optional(&state.db)
    .await?;
    let (account_id, permissions_json) = row.ok_or(SaasError::Unauthorized)?;

    // Load the role of the owning account; inactive or missing accounts are rejected.
    let (role,): (String,) = sqlx::query_as(
        "SELECT role FROM accounts WHERE id = $1 AND status = 'active'"
    )
    .bind(&account_id)
    .fetch_optional(&state.db)
    .await?
    .ok_or(SaasError::Unauthorized)?;

    // Merge token permissions into the role permissions (de-duplicated, role entries first).
    let mut permissions =
        handlers::get_role_permissions(&state.db, &state.role_permissions_cache, &role).await?;
    // A malformed permissions column grants no extra permissions (fail closed), but is logged
    // instead of being silently ignored.
    let token_permissions: Vec<String> =
        serde_json::from_str(&permissions_json).unwrap_or_else(|e| {
            tracing::warn!("Ignoring malformed api_tokens.permissions: {}", e);
            Vec::new()
        });
    for p in token_permissions {
        if !permissions.contains(&p) {
            permissions.push(p);
        }
    }

    // Update last_used_at asynchronously. The job goes through the worker dispatcher
    // (gated by SpawnLimiter per the original design) instead of an unbounded per-request
    // tokio::spawn. Best effort: a dispatch failure does not fail authentication.
    {
        use crate::workers::update_last_used::UpdateLastUsedArgs;
        // `token_hash` is not needed any more, so move it rather than cloning it.
        let args = UpdateLastUsedArgs { token_hash };
        if let Err(e) = state.worker_dispatcher.dispatch("update_last_used", args).await {
            tracing::debug!("Failed to dispatch update_last_used: {}", e);
        }
    }

    Ok(AuthContext {
        account_id,
        role,
        permissions,
        client_ip,
    })
}
/// Determines the client IP of a request (secure variant).
///
/// `X-Forwarded-For` / `X-Real-IP` are only consulted when the direct TCP peer
/// (taken from [`ConnectInfo`]) is listed in `trusted_proxies`. Otherwise the peer
/// address is returned unchanged. Returns `None` when `ConnectInfo` is not present
/// in the request extensions.
///
/// `X-Forwarded-For` is walked from the right, starting with the entry appended by the
/// proxy closest to us. The first address that is not itself a trusted proxy is
/// returned. The leftmost entry is supplied by the client and can be forged, so it must
/// not be taken blindly. Each entry must parse as an IP address (optionally `ip:port`).
/// A malformed entry stops the walk, and the function then tries `X-Real-IP` and
/// finally falls back to the peer address.
fn extract_client_ip(req: &Request, trusted_proxies: &[String]) -> Option<String> {
    // Direct connection address; without it there is nothing to trust.
    let peer_ip = req
        .extensions()
        .get::<ConnectInfo<SocketAddr>>()
        .map(|ConnectInfo(addr)| addr.ip().to_string())?;

    let is_trusted = |ip: &str| trusted_proxies.iter().any(|p| p == ip);

    if !is_trusted(&peer_ip) {
        // Untrusted source: forwarding headers could be forged, use the socket address.
        return Some(peer_ip);
    }

    // Accept `ip` or `ip:port` (some proxies append the port); the result is normalised
    // through `IpAddr` so arbitrary header text is never returned.
    let parse_ip = |raw: &str| -> Option<String> {
        let raw = raw.trim();
        raw.parse::<std::net::IpAddr>()
            .or_else(|_| raw.parse::<SocketAddr>().map(|sa| sa.ip()))
            .ok()
            .map(|ip| ip.to_string())
    };

    if let Some(forwarded) = req
        .headers()
        .get("x-forwarded-for")
        .and_then(|v| v.to_str().ok())
    {
        for hop in forwarded.rsplit(',') {
            match parse_ip(hop) {
                // Another of our own proxies: keep walking towards the client.
                Some(ip) if is_trusted(&ip) => continue,
                Some(ip) => return Some(ip),
                // Malformed entry: do not trust anything to its left.
                None => break,
            }
        }
    }

    // Fall back to X-Real-IP, as set by the trusted proxy.
    if let Some(ip) = req
        .headers()
        .get("x-real-ip")
        .and_then(|v| v.to_str().ok())
        .and_then(parse_ip)
    {
        return Some(ip);
    }

    // Trusted proxy but no usable forwarding header: use the direct connection IP.
    Some(peer_ip)
}
/// 认证中间件: 从 JWT Cookie / Authorization Header / API Token 提取身份
pub async fn auth_middleware(
State(state): State<AppState>,
jar: axum_extra::extract::cookie::CookieJar,
mut req: Request,
next: Next,
) -> Response {
let client_ip = {
let config = state.config.read().await;
extract_client_ip(&req, &config.server.trusted_proxies)
};
let auth_header = req.headers()
.get(header::AUTHORIZATION)
.and_then(|v| v.to_str().ok());
// 尝试从 Authorization header 提取 token
let header_token = auth_header.and_then(|auth| auth.strip_prefix("Bearer "));
// 尝试从 HttpOnly cookie 提取 token (仅当 header 不存在时)
let cookie_token = jar.get("zclaw_access_token").map(|c| c.value().to_string());
let token = header_token
.or(cookie_token.as_deref());
let result = if let Some(token) = token {
if token.starts_with("zclaw_") {
// API Token 路径
verify_api_token(&state, token, client_ip.clone()).await
} else {
// JWT 路径
match jwt::verify_token(token, state.jwt_secret.expose_secret()) {
Ok(claims) => {
// H1: 验证 password_version — 密码变更后旧 token 失效
let pwv_row: Option<(i32,)> = sqlx::query_as(
"SELECT password_version FROM accounts WHERE id = $1"
)
.bind(&claims.sub)
.fetch_optional(&state.db)
.await
.ok()
.flatten();
match pwv_row {
Some((current_pwv,)) if (current_pwv as u32) == claims.pwv => {
Ok(AuthContext {
account_id: claims.sub,
role: claims.role,
permissions: claims.permissions,
client_ip,
})
}
_ => {
tracing::warn!(
account_id = %claims.sub,
token_pwv = claims.pwv,
"Token rejected: password_version mismatch or account not found"
);
Err(SaasError::Unauthorized)
}
}
}
Err(_) => Err(SaasError::Unauthorized),
}
}
} else {
Err(SaasError::Unauthorized)
};
match result {
Ok(ctx) => {
req.extensions_mut().insert(ctx);
next.run(req).await
}
Err(e) => e.into_response(),
}
}
/// Public authentication routes; these endpoints do not require authentication.
///
/// Covers registration, login, token refresh and logout, all mounted as `POST`.
pub fn routes() -> axum::Router<AppState> {
    use axum::routing::{post, MethodRouter};

    let public_endpoints: [(&str, MethodRouter<AppState>); 4] = [
        ("/api/v1/auth/register", post(handlers::register)),
        ("/api/v1/auth/login", post(handlers::login)),
        ("/api/v1/auth/refresh", post(handlers::refresh)),
        ("/api/v1/auth/logout", post(handlers::logout)),
    ];

    public_endpoints
        .into_iter()
        .fold(axum::Router::new(), |router, (path, method_router)| {
            router.route(path, method_router)
        })
}
/// Authentication routes that require an authenticated caller.
///
/// Covers the current-user profile, password change, and TOTP setup/verify/disable.
pub fn protected_routes() -> axum::Router<AppState> {
    use axum::routing::{get, post, put, MethodRouter};

    let authenticated_endpoints: [(&str, MethodRouter<AppState>); 5] = [
        ("/api/v1/auth/me", get(handlers::me)),
        ("/api/v1/auth/password", put(handlers::change_password)),
        ("/api/v1/auth/totp/setup", post(totp::setup_totp)),
        ("/api/v1/auth/totp/verify", post(totp::verify_totp)),
        ("/api/v1/auth/totp/disable", post(totp::disable_totp)),
    ];

    authenticated_endpoints
        .into_iter()
        .fold(axum::Router::new(), |router, (path, method_router)| {
            router.route(path, method_router)
        })
}