Security audit (2026-03-31): 5 HIGH + 10 MEDIUM issues, all fixed. HIGH: - H1: JWT password_version mechanism (pwv in Claims, middleware verification, auto-increment on password change) - H2: Docker saas port bound to 127.0.0.1 - H3: TOTP encryption key decoupled from JWT secret (production bailout) - H4+H5: Tauri CSP hardened (removed unsafe-inline, restricted connect-src) MEDIUM: - M1: Persistent rate limiting (PostgreSQL rate_limit_events table) - M2: Account lockout (5 failures -> 15min lock) - M3: RFC 5322 email validation with regex - M4: Device registration typed struct with length limits - M5: Provider URL validation on create/update (SSRF prevention) - M6: Legacy TOTP secret migration (fixed nonce -> random nonce) - M7: Legacy frontend crypto migration (static salt -> random salt) - M8+M9: Admin frontend: removed JS token storage, HttpOnly cookie only - M10: Pipeline debug log sanitization (keys only, 100-char truncation) Also: fixed CLAUDE.md Section 12 (was corrupted), added title.rs middleware skeleton, fixed RegisterDeviceRequest visibility.
332 lines
13 KiB
Rust
332 lines
13 KiB
Rust
//! ZCLAW SaaS service entry point
|
||
|
||
use axum::extract::State;
|
||
use tokio_util::sync::CancellationToken;
|
||
use tower_http::timeout::TimeoutLayer;
|
||
use tracing::info;
|
||
use zclaw_saas::{config::SaaSConfig, db::init_db, state::AppState};
|
||
use zclaw_saas::workers::WorkerDispatcher;
|
||
use zclaw_saas::workers::log_operation::LogOperationWorker;
|
||
use zclaw_saas::workers::cleanup_refresh_tokens::CleanupRefreshTokensWorker;
|
||
use zclaw_saas::workers::cleanup_rate_limit::CleanupRateLimitWorker;
|
||
use zclaw_saas::workers::record_usage::RecordUsageWorker;
|
||
use zclaw_saas::workers::update_last_used::UpdateLastUsedWorker;
|
||
|
||
#[tokio::main]
|
||
async fn main() -> anyhow::Result<()> {
|
||
tracing_subscriber::fmt()
|
||
.with_env_filter(
|
||
tracing_subscriber::EnvFilter::try_from_default_env()
|
||
.unwrap_or_else(|_| "zclaw_saas=debug,tower_http=debug".into()),
|
||
)
|
||
.init();
|
||
|
||
let config = SaaSConfig::load()?;
|
||
info!("SaaS config loaded: {}:{}", config.server.host, config.server.port);
|
||
|
||
let db = init_db(&config.database.url).await?;
|
||
info!("Database initialized");
|
||
|
||
// 初始化 Worker 调度器 + 注册所有 Worker
|
||
let mut dispatcher = WorkerDispatcher::new(db.clone());
|
||
dispatcher.register(LogOperationWorker);
|
||
dispatcher.register(CleanupRefreshTokensWorker);
|
||
dispatcher.register(CleanupRateLimitWorker);
|
||
dispatcher.register(RecordUsageWorker);
|
||
dispatcher.register(UpdateLastUsedWorker);
|
||
info!("Worker dispatcher initialized (5 workers registered)");
|
||
|
||
// 优雅停机令牌 — 取消后所有 SSE 流和长连接立即终止
|
||
let shutdown_token = CancellationToken::new();
|
||
let state = AppState::new(db.clone(), config.clone(), dispatcher, shutdown_token.clone())?;
|
||
|
||
// Restore rate limit counts from DB so limits survive server restarts
|
||
{
|
||
let rows: Vec<(String, i64)> = sqlx::query_as(
|
||
"SELECT key, SUM(count) FROM rate_limit_events WHERE window_start > NOW() - interval '1 hour' GROUP BY key"
|
||
)
|
||
.fetch_all(&db)
|
||
.await
|
||
.unwrap_or_default();
|
||
|
||
let mut restored_count = 0usize;
|
||
for (key, count) in rows {
|
||
let mut entries = Vec::new();
|
||
// Approximate: insert count timestamps at "now" — the DashMap will
|
||
// expire them naturally via the retain() call in the middleware.
|
||
// This is intentionally approximate; exact window alignment is not
|
||
// required for rate limiting correctness.
|
||
for _ in 0..count as usize {
|
||
entries.push(std::time::Instant::now());
|
||
}
|
||
state.rate_limit_entries.insert(key, entries);
|
||
restored_count += 1;
|
||
}
|
||
info!("Restored rate limit state from DB: {} keys", restored_count);
|
||
}
|
||
|
||
// 迁移旧格式 TOTP secret(明文 → 加密 enc: 格式)
|
||
{
|
||
let config_for_migration = state.config.read().await;
|
||
if let Ok(enc_key) = config_for_migration.totp_encryption_key() {
|
||
drop(config_for_migration);
|
||
if let Err(e) = zclaw_saas::crypto::migrate_legacy_totp_secrets(&db, &enc_key).await {
|
||
tracing::warn!("TOTP legacy migration check failed: {}", e);
|
||
}
|
||
} else {
|
||
drop(config_for_migration);
|
||
}
|
||
}
|
||
|
||
// 启动声明式 Scheduler(从 TOML 配置读取定时任务)
|
||
let scheduler_config = &config.scheduler;
|
||
zclaw_saas::scheduler::start_scheduler(scheduler_config, db.clone(), state.worker_dispatcher.clone_ref());
|
||
info!("Scheduler started with {} jobs", scheduler_config.jobs.len());
|
||
|
||
// 启动内置 DB 清理任务(设备清理等不通过 Worker 的任务)
|
||
zclaw_saas::scheduler::start_db_cleanup_tasks(db.clone());
|
||
|
||
// 启动用户定时任务调度循环(30s 轮询 scheduled_tasks 表)
|
||
zclaw_saas::scheduler::start_user_task_scheduler(db.clone());
|
||
|
||
// 启动内存中的 rate limit 条目清理
|
||
let rate_limit_state = state.clone();
|
||
tokio::spawn(async move {
|
||
let mut interval = tokio::time::interval(std::time::Duration::from_secs(300));
|
||
loop {
|
||
interval.tick().await;
|
||
rate_limit_state.cleanup_rate_limit_entries();
|
||
}
|
||
});
|
||
|
||
// 初始化缓存并启动定时刷新 (60s)
|
||
state.cache.load_from_db(&db).await.map_err(|e| anyhow::anyhow!("{}", e))?;
|
||
info!("Cache initialized: {} providers, {} models", state.cache.providers.len(), state.cache.models.len());
|
||
{
|
||
let cache_state = state.clone();
|
||
let db_clone = db.clone();
|
||
tokio::spawn(async move {
|
||
let mut interval = tokio::time::interval(std::time::Duration::from_secs(60));
|
||
loop {
|
||
interval.tick().await;
|
||
if let Err(e) = cache_state.cache.load_from_db(&db_clone).await {
|
||
tracing::warn!("Cache refresh failed: {}", e);
|
||
}
|
||
cache_state.cache.calibrate_queue_counts(&db_clone).await;
|
||
}
|
||
});
|
||
}
|
||
|
||
let app = build_router(state).await;
|
||
|
||
// 配置 TCP keepalive + 短 SO_LINGER,防止 CLOSE_WAIT 累积
|
||
let listener = create_listener(&config.server.host, config.server.port)?;
|
||
info!("SaaS server listening on {}:{}", config.server.host, config.server.port);
|
||
|
||
// 优雅停机: Ctrl+C → 取消 CancellationToken → SSE 流终止 → 连接排空
|
||
let token = shutdown_token.clone();
|
||
axum::serve(listener, app.into_make_service_with_connect_info::<std::net::SocketAddr>())
|
||
.with_graceful_shutdown(async move {
|
||
tokio::signal::ctrl_c()
|
||
.await
|
||
.expect("Failed to install Ctrl+C handler");
|
||
info!("Received shutdown signal, cancelling SSE streams and draining connections...");
|
||
token.cancel();
|
||
})
|
||
.await?;
|
||
Ok(())
|
||
}
|
||
|
||
/// 创建带 TCP keepalive 和短 SO_LINGER 的 TcpListener,防止 CLOSE_WAIT 累积
|
||
fn create_listener(host: &str, port: u16) -> anyhow::Result<tokio::net::TcpListener> {
|
||
let addr = format!("{}:{}", host, port);
|
||
let socket = socket2::Socket::new(
|
||
socket2::Domain::for_address(addr.parse::<std::net::SocketAddr>()?),
|
||
socket2::Type::STREAM,
|
||
Some(socket2::Protocol::TCP),
|
||
)?;
|
||
|
||
// SO_REUSEADDR: 允许快速重启时复用 TIME_WAIT 端口
|
||
socket.set_reuse_address(true)?;
|
||
|
||
// TCP keepalive: 60s 空闲后每 10s 探测,连续 3 次无响应则关闭
|
||
// 防止已断开但对端未发 FIN 的连接永远留在 CLOSE_WAIT
|
||
let keepalive = socket2::SockRef::from(&socket);
|
||
keepalive.set_tcp_keepalive(
|
||
&socket2::TcpKeepalive::new()
|
||
.with_time(std::time::Duration::from_secs(60))
|
||
.with_interval(std::time::Duration::from_secs(10)),
|
||
)?;
|
||
|
||
// 短 SO_LINGER (1s): 关闭时最多等 1 秒即 RST,避免大量 TIME_WAIT
|
||
socket.set_linger(Some(std::time::Duration::from_secs(1)))?;
|
||
|
||
socket.bind(&addr.parse::<std::net::SocketAddr>()?.into())?;
|
||
socket.listen(1024)?;
|
||
socket.set_nonblocking(true)?;
|
||
|
||
Ok(tokio::net::TcpListener::from_std(socket.into())?)
|
||
}
|
||
|
||
async fn health_handler(
|
||
State(state): State<AppState>,
|
||
) -> (axum::http::StatusCode, axum::Json<serde_json::Value> ) {
|
||
// health 必须独立快速返回,用 3s 超时避免连接池满时阻塞
|
||
let db_healthy = tokio::time::timeout(
|
||
std::time::Duration::from_secs(3),
|
||
sqlx::query_scalar::<_, i32>("SELECT 1").fetch_one(&state.db),
|
||
)
|
||
.await
|
||
.map(|r| r.is_ok())
|
||
.unwrap_or(false);
|
||
|
||
// 连接池容量检查: 使用率 >= 80% 返回 503 (degraded)
|
||
let pool = &state.db;
|
||
let total = pool.options().get_max_connections() as usize;
|
||
if total > 0 {
|
||
let idle = pool.num_idle() as usize;
|
||
let used = total - idle;
|
||
let ratio = used * 100 / total;
|
||
if ratio >= 80 {
|
||
return (
|
||
axum::http::StatusCode::SERVICE_UNAVAILABLE,
|
||
axum::Json(serde_json::json!({
|
||
"status": "degraded",
|
||
"database": true,
|
||
"database_pool": {
|
||
"usage_pct": ratio,
|
||
"used": used,
|
||
"total": total,
|
||
},
|
||
"timestamp": chrono::Utc::now().to_rfc3339(),
|
||
"version": env!("CARGO_PKG_VERSION"),
|
||
})),
|
||
);
|
||
}
|
||
}
|
||
|
||
let status = if db_healthy { "healthy" } else { "degraded" };
|
||
let code = if db_healthy {
|
||
axum::http::StatusCode::OK } else { axum::http::StatusCode::SERVICE_UNAVAILABLE };
|
||
|
||
(code, axum::Json(serde_json::json!({
|
||
"status": status,
|
||
"database": db_healthy,
|
||
"timestamp": chrono::Utc::now().to_rfc3339(),
|
||
"version": env!("CARGO_PKG_VERSION"),
|
||
})))
|
||
}
|
||
|
||
async fn build_router(state: AppState) -> axum::Router {
|
||
use axum::middleware;
|
||
use tower_http::cors::{Any, CorsLayer};
|
||
use tower_http::trace::TraceLayer;
|
||
|
||
use axum::http::HeaderValue;
|
||
let cors = {
|
||
let config = state.config.read().await;
|
||
let is_dev = std::env::var("ZCLAW_SAAS_DEV")
|
||
.map(|v| v == "true" || v == "1")
|
||
.unwrap_or(false);
|
||
if config.server.cors_origins.is_empty() {
|
||
if is_dev {
|
||
CorsLayer::new()
|
||
.allow_origin(Any)
|
||
.allow_methods(Any)
|
||
.allow_headers(Any)
|
||
.allow_credentials(true)
|
||
} else {
|
||
tracing::error!("生产环境必须配置 server.cors_origins,不能使用 allow_origin(Any)");
|
||
panic!("生产环境必须配置 server.cors_origins 白名单。开发环境可设置 ZCLAW_SAAS_DEV=true 绕过。");
|
||
}
|
||
} else {
|
||
let origins: Vec<HeaderValue> = config.server.cors_origins.iter()
|
||
.filter_map(|o: &String| o.parse::<HeaderValue>().ok())
|
||
.collect();
|
||
CorsLayer::new()
|
||
.allow_origin(origins)
|
||
.allow_methods([
|
||
axum::http::Method::GET,
|
||
axum::http::Method::POST,
|
||
axum::http::Method::PUT,
|
||
axum::http::Method::PATCH,
|
||
axum::http::Method::DELETE,
|
||
axum::http::Method::OPTIONS,
|
||
])
|
||
.allow_headers([
|
||
axum::http::header::AUTHORIZATION,
|
||
axum::http::header::CONTENT_TYPE,
|
||
axum::http::header::COOKIE,
|
||
axum::http::HeaderName::from_static("x-request-id"),
|
||
])
|
||
.allow_credentials(true)
|
||
}
|
||
};
|
||
|
||
let public_routes = zclaw_saas::auth::routes()
|
||
.route("/api/health", axum::routing::get(health_handler))
|
||
.layer(middleware::from_fn_with_state(
|
||
state.clone(),
|
||
zclaw_saas::middleware::public_rate_limit_middleware,
|
||
));
|
||
|
||
let protected_routes = zclaw_saas::auth::protected_routes()
|
||
.merge(zclaw_saas::account::routes())
|
||
.merge(zclaw_saas::model_config::routes())
|
||
// relay::routes() 不在此合并 — SSE 端点需要更长超时,在最终 Router 单独合并
|
||
.merge(zclaw_saas::migration::routes())
|
||
.merge(zclaw_saas::role::routes())
|
||
.merge(zclaw_saas::prompt::routes())
|
||
.merge(zclaw_saas::agent_template::routes())
|
||
.merge(zclaw_saas::scheduled_task::routes())
|
||
.merge(zclaw_saas::telemetry::routes())
|
||
.layer(middleware::from_fn_with_state(
|
||
state.clone(),
|
||
zclaw_saas::middleware::api_version_middleware,
|
||
))
|
||
.layer(middleware::from_fn_with_state(
|
||
state.clone(),
|
||
zclaw_saas::middleware::request_id_middleware,
|
||
))
|
||
.layer(middleware::from_fn_with_state(
|
||
state.clone(),
|
||
zclaw_saas::middleware::rate_limit_middleware,
|
||
))
|
||
.layer(middleware::from_fn_with_state(
|
||
state.clone(),
|
||
zclaw_saas::auth::auth_middleware,
|
||
));
|
||
|
||
// 非流式路由应用全局 15s 超时(relay SSE 端点需要更长超时)
|
||
let non_streaming_routes = axum::Router::new()
|
||
.merge(public_routes)
|
||
.merge(protected_routes)
|
||
.layer(TimeoutLayer::new(std::time::Duration::from_secs(15)));
|
||
|
||
// Relay 路由需要独立的认证中间件(因为被排除在 15s 超时层之外)
|
||
let relay_routes = zclaw_saas::relay::routes()
|
||
.layer(middleware::from_fn_with_state(
|
||
state.clone(),
|
||
zclaw_saas::middleware::api_version_middleware,
|
||
))
|
||
.layer(middleware::from_fn_with_state(
|
||
state.clone(),
|
||
zclaw_saas::middleware::request_id_middleware,
|
||
))
|
||
.layer(middleware::from_fn_with_state(
|
||
state.clone(),
|
||
zclaw_saas::middleware::rate_limit_middleware,
|
||
))
|
||
.layer(middleware::from_fn_with_state(
|
||
state.clone(),
|
||
zclaw_saas::auth::auth_middleware,
|
||
));
|
||
|
||
axum::Router::new()
|
||
.merge(non_streaming_routes)
|
||
.merge(relay_routes)
|
||
.layer(TraceLayer::new_for_http())
|
||
.layer(cors)
|
||
.with_state(state)
|
||
}
|