Files
zclaw_openfang/crates/zclaw-saas/src/main.rs
iven e3b93ff96d fix(security): implement all 15 security fixes from penetration test V1
Security audit (2026-03-31): 5 HIGH + 10 MEDIUM issues, all fixed.

HIGH:
- H1: JWT password_version mechanism (pwv in Claims, middleware verification,
  auto-increment on password change)
- H2: Docker saas port bound to 127.0.0.1
- H3: TOTP encryption key decoupled from JWT secret (production bailout)
- H4+H5: Tauri CSP hardened (removed unsafe-inline, restricted connect-src)

MEDIUM:
- M1: Persistent rate limiting (PostgreSQL rate_limit_events table)
- M2: Account lockout (5 failures -> 15min lock)
- M3: RFC 5322 email validation with regex
- M4: Device registration typed struct with length limits
- M5: Provider URL validation on create/update (SSRF prevention)
- M6: Legacy TOTP secret migration (fixed nonce -> random nonce)
- M7: Legacy frontend crypto migration (static salt -> random salt)
- M8+M9: Admin frontend: removed JS token storage, HttpOnly cookie only
- M10: Pipeline debug log sanitization (keys only, 100-char truncation)

Also: fixed CLAUDE.md Section 12 (was corrupted), added title.rs middleware
skeleton, fixed RegisterDeviceRequest visibility.
2026-04-01 08:38:37 +08:00

332 lines
13 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

//! ZCLAW SaaS 服务入口
use axum::extract::State;
use tokio_util::sync::CancellationToken;
use tower_http::timeout::TimeoutLayer;
use tracing::info;
use zclaw_saas::{config::SaaSConfig, db::init_db, state::AppState};
use zclaw_saas::workers::WorkerDispatcher;
use zclaw_saas::workers::log_operation::LogOperationWorker;
use zclaw_saas::workers::cleanup_refresh_tokens::CleanupRefreshTokensWorker;
use zclaw_saas::workers::cleanup_rate_limit::CleanupRateLimitWorker;
use zclaw_saas::workers::record_usage::RecordUsageWorker;
use zclaw_saas::workers::update_last_used::UpdateLastUsedWorker;
/// ZCLAW SaaS server entry point.
///
/// Start-up sequence: initialise tracing, load the configuration, connect to
/// the database, register background workers, build the shared [`AppState`],
/// restore rate-limit counters, run the legacy TOTP migration, start the
/// schedulers and periodic maintenance tasks, warm the cache, then serve HTTP
/// until a shutdown signal arrives.
///
/// # Errors
/// Returns an error if the configuration cannot be loaded, the database or
/// application state cannot be initialised, the initial cache load fails, the
/// listener cannot be created, or the server terminates with an error.
#[tokio::main]
async fn main() -> anyhow::Result<()> {
    tracing_subscriber::fmt()
        .with_env_filter(
            tracing_subscriber::EnvFilter::try_from_default_env()
                .unwrap_or_else(|_| "zclaw_saas=debug,tower_http=debug".into()),
        )
        .init();

    let config = SaaSConfig::load()?;
    info!("SaaS config loaded: {}:{}", config.server.host, config.server.port);

    let db = init_db(&config.database.url).await?;
    info!("Database initialized");

    // Initialise the worker dispatcher and register every worker.
    let mut dispatcher = WorkerDispatcher::new(db.clone());
    dispatcher.register(LogOperationWorker);
    dispatcher.register(CleanupRefreshTokensWorker);
    dispatcher.register(CleanupRateLimitWorker);
    dispatcher.register(RecordUsageWorker);
    dispatcher.register(UpdateLastUsedWorker);
    info!("Worker dispatcher initialized (5 workers registered)");

    // Graceful-shutdown token: once cancelled, SSE streams and other
    // long-lived connections are expected to terminate.
    let shutdown_token = CancellationToken::new();
    let state = AppState::new(db.clone(), config.clone(), dispatcher, shutdown_token.clone())?;

    // Restore rate limit counts from DB so limits survive server restarts.
    {
        let rows = match sqlx::query_as::<_, (String, i64)>(
            "SELECT key, SUM(count) FROM rate_limit_events WHERE window_start > NOW() - interval '1 hour' GROUP BY key"
        )
        .fetch_all(&db)
        .await
        {
            Ok(rows) => rows,
            Err(e) => {
                // Best effort: the server still starts, but the failure must be
                // visible — otherwise limits silently reset on every restart.
                tracing::warn!("Failed to restore rate limit state from DB: {}", e);
                Vec::new()
            }
        };

        let mut restored_count = 0usize;
        for (key, count) in rows {
            // A negative sum would wrap to a gigantic length with an `as` cast;
            // treat it as "no recorded hits" instead.
            let hits = usize::try_from(count).unwrap_or(0);
            // Approximate: record `hits` timestamps at "now"; the middleware's
            // own expiry logic ages them out. Exact window alignment is
            // intentionally not preserved.
            let now = std::time::Instant::now();
            state.rate_limit_entries.insert(key, vec![now; hits]);
            restored_count += 1;
        }
        info!("Restored rate limit state from DB: {} keys", restored_count);
    }

    // Migrate legacy-format TOTP secrets (plaintext -> encrypted `enc:` format).
    {
        // The config read guard is a temporary of this statement, so it is
        // released before the migration awaits on the database.
        let enc_key = state.config.read().await.totp_encryption_key();
        if let Ok(enc_key) = enc_key {
            if let Err(e) = zclaw_saas::crypto::migrate_legacy_totp_secrets(&db, &enc_key).await {
                tracing::warn!("TOTP legacy migration check failed: {}", e);
            }
        }
    }

    // Start the declarative scheduler (jobs come from the TOML configuration).
    let scheduler_config = &config.scheduler;
    zclaw_saas::scheduler::start_scheduler(scheduler_config, db.clone(), state.worker_dispatcher.clone_ref());
    info!("Scheduler started with {} jobs", scheduler_config.jobs.len());

    // Start built-in DB cleanup tasks (device cleanup etc. that do not go
    // through a worker).
    zclaw_saas::scheduler::start_db_cleanup_tasks(db.clone());

    // Start the user scheduled-task loop (polls the scheduled_tasks table).
    zclaw_saas::scheduler::start_user_task_scheduler(db.clone());

    // Periodically purge in-memory rate limit entries (every 300 s).
    let rate_limit_state = state.clone();
    tokio::spawn(async move {
        let mut interval = tokio::time::interval(std::time::Duration::from_secs(300));
        loop {
            interval.tick().await;
            rate_limit_state.cleanup_rate_limit_entries();
        }
    });

    // Warm the cache, then refresh it every 60 s in the background.
    state.cache.load_from_db(&db).await.map_err(|e| anyhow::anyhow!("{}", e))?;
    info!("Cache initialized: {} providers, {} models", state.cache.providers.len(), state.cache.models.len());
    {
        let cache_state = state.clone();
        let db_clone = db.clone();
        tokio::spawn(async move {
            let mut interval = tokio::time::interval(std::time::Duration::from_secs(60));
            loop {
                interval.tick().await;
                if let Err(e) = cache_state.cache.load_from_db(&db_clone).await {
                    tracing::warn!("Cache refresh failed: {}", e);
                }
                cache_state.cache.calibrate_queue_counts(&db_clone).await;
            }
        });
    }

    let app = build_router(state).await;

    // Listener configured with TCP keepalive and a short SO_LINGER to keep
    // CLOSE_WAIT sockets from accumulating.
    let listener = create_listener(&config.server.host, config.server.port)?;
    info!("SaaS server listening on {}:{}", config.server.host, config.server.port);

    // Graceful shutdown: signal -> cancel the token -> SSE streams end ->
    // in-flight connections drain.
    let token = shutdown_token.clone();
    axum::serve(listener, app.into_make_service_with_connect_info::<std::net::SocketAddr>())
        .with_graceful_shutdown(async move {
            shutdown_signal().await;
            info!("Received shutdown signal, cancelling SSE streams and draining connections...");
            token.cancel();
        })
        .await?;
    Ok(())
}

/// Resolves when the process is asked to stop: Ctrl+C (SIGINT) on every
/// platform, plus SIGTERM on Unix so that `docker stop` / service managers
/// trigger the same graceful shutdown path instead of an abrupt kill.
async fn shutdown_signal() {
    let ctrl_c = async {
        tokio::signal::ctrl_c()
            .await
            .expect("Failed to install Ctrl+C handler");
    };

    #[cfg(unix)]
    let terminate = async {
        tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
            .expect("Failed to install SIGTERM handler")
            .recv()
            .await;
    };

    // No SIGTERM equivalent is handled here on non-Unix targets.
    #[cfg(not(unix))]
    let terminate = std::future::pending::<()>();

    tokio::select! {
        _ = ctrl_c => {},
        _ = terminate => {},
    }
}
/// Creates the server's `TcpListener` with SO_REUSEADDR, TCP keepalive and a
/// short SO_LINGER configured on the listening socket, to keep CLOSE_WAIT
/// sockets from accumulating.
///
/// `host` must be an IP literal: a plain IPv4/IPv6 address (`0.0.0.0`, `::`,
/// `::1`) or the bracketed IPv6 form (`[::1]`). Host names are not resolved.
///
/// # Errors
/// Returns an error if `host`/`port` do not form a valid socket address or if
/// any socket operation (create, option set, bind, listen, registration with
/// the tokio runtime) fails.
fn create_listener(host: &str, port: u16) -> anyhow::Result<tokio::net::TcpListener> {
    // Try the host as a bare IP first so unbracketed IPv6 literals work
    // ("::" + ":8080" is not a parseable socket address). Fall back to the
    // "host:port" string form, which covers bracketed IPv6 input.
    let addr: std::net::SocketAddr = match host.parse::<std::net::IpAddr>() {
        Ok(ip) => std::net::SocketAddr::new(ip, port),
        Err(_) => format!("{}:{}", host, port).parse::<std::net::SocketAddr>()?,
    };

    let socket = socket2::Socket::new(
        socket2::Domain::for_address(addr),
        socket2::Type::STREAM,
        Some(socket2::Protocol::TCP),
    )?;

    // SO_REUSEADDR: allow rebinding the port on a quick restart while old
    // connections are still in TIME_WAIT.
    socket.set_reuse_address(true)?;

    // TCP keepalive: start probing after 60 s idle, then probe every 10 s.
    // The probe count is not configured here, so the OS default applies.
    // Intent: peers that vanished without sending FIN get detected and closed.
    socket.set_tcp_keepalive(
        &socket2::TcpKeepalive::new()
            .with_time(std::time::Duration::from_secs(60))
            .with_interval(std::time::Duration::from_secs(10)),
    )?;

    // Short SO_LINGER (1 s): bound how long a close may wait for unsent data.
    socket.set_linger(Some(std::time::Duration::from_secs(1)))?;

    socket.bind(&addr.into())?;
    socket.listen(1024)?;
    // tokio requires the underlying std listener to be non-blocking.
    socket.set_nonblocking(true)?;
    Ok(tokio::net::TcpListener::from_std(socket.into())?)
}
/// `GET /api/health` handler.
///
/// Probes the database with `SELECT 1` under a 3 s timeout and inspects the
/// connection pool. Responds with:
/// - `200` / `"healthy"` when the probe succeeds and pool usage is below 80 %;
/// - `503` / `"degraded"` (with a `database_pool` section) when at least 80 %
///   of the pool's maximum connections are checked out;
/// - `503` / `"degraded"` when the probe fails or times out.
async fn health_handler(
    State(state): State<AppState>,
) -> (axum::http::StatusCode, axum::Json<serde_json::Value>) {
    // The health endpoint must answer quickly on its own: the timeout keeps it
    // from blocking when the pool is exhausted.
    let db_healthy = tokio::time::timeout(
        std::time::Duration::from_secs(3),
        sqlx::query_scalar::<_, i32>("SELECT 1").fetch_one(&state.db),
    )
    .await
    .map(|r| r.is_ok())
    .unwrap_or(false);

    // Pool capacity check: usage >= 80 % of max_connections is reported as
    // degraded.
    let pool = &state.db;
    let total = pool.options().get_max_connections() as usize;
    if total > 0 {
        let idle = pool.num_idle() as usize;
        // In-use connections = currently open connections minus the idle ones.
        // (Subtracting idle from `max_connections` would count connections
        // that were never opened as "in use" and flag an idle pool as full.)
        // The two counters are read separately, hence the saturating subtraction.
        let used = (pool.size() as usize).saturating_sub(idle);
        let ratio = used * 100 / total;
        if ratio >= 80 {
            return (
                axum::http::StatusCode::SERVICE_UNAVAILABLE,
                axum::Json(serde_json::json!({
                    "status": "degraded",
                    // Report the actual probe result rather than assuming success.
                    "database": db_healthy,
                    "database_pool": {
                        "usage_pct": ratio,
                        "used": used,
                        "total": total,
                    },
                    "timestamp": chrono::Utc::now().to_rfc3339(),
                    "version": env!("CARGO_PKG_VERSION"),
                })),
            );
        }
    }

    let (code, status) = if db_healthy {
        (axum::http::StatusCode::OK, "healthy")
    } else {
        (axum::http::StatusCode::SERVICE_UNAVAILABLE, "degraded")
    };
    (
        code,
        axum::Json(serde_json::json!({
            "status": status,
            "database": db_healthy,
            "timestamp": chrono::Utc::now().to_rfc3339(),
            "version": env!("CARGO_PKG_VERSION"),
        })),
    )
}
/// Assembles the application router: CORS policy, public routes, protected
/// routes and the relay routes, with tracing applied to everything.
///
/// CORS policy:
/// - `server.cors_origins` non-empty: only those origins, a fixed method and
///   header allow-list, credentials allowed.
/// - empty and `ZCLAW_SAAS_DEV` is `true`/`1`: permissive development policy.
/// - empty otherwise: refuses to start.
///
/// # Panics
/// Panics when `server.cors_origins` is empty and `ZCLAW_SAAS_DEV` is not set
/// to `true` or `1`.
async fn build_router(state: AppState) -> axum::Router {
    use axum::http::HeaderValue;
    use axum::middleware;
    use tower_http::cors::CorsLayer;
    use tower_http::trace::TraceLayer;

    let cors = {
        let config = state.config.read().await;
        let is_dev = std::env::var("ZCLAW_SAAS_DEV")
            .map(|v| v == "true" || v == "1")
            .unwrap_or(false);
        if config.server.cors_origins.is_empty() {
            if is_dev {
                // tower-http rejects (panics on) `allow_credentials(true)`
                // combined with wildcard origin/methods/headers, so the dev
                // policy mirrors the request's origin, method and headers
                // instead — equally permissive, but valid with credentials.
                CorsLayer::very_permissive()
            } else {
                tracing::error!("生产环境必须配置 server.cors_origins不能使用 allow_origin(Any)");
                panic!("生产环境必须配置 server.cors_origins 白名单。开发环境可设置 ZCLAW_SAAS_DEV=true 绕过。");
            }
        } else {
            let origins: Vec<HeaderValue> = config
                .server
                .cors_origins
                .iter()
                .filter_map(|o: &String| match o.parse::<HeaderValue>() {
                    Ok(value) => Some(value),
                    Err(e) => {
                        // Surface typos in the allow-list instead of silently
                        // dropping the entry.
                        tracing::warn!("Ignoring invalid CORS origin {:?}: {}", o, e);
                        None
                    }
                })
                .collect();
            CorsLayer::new()
                .allow_origin(origins)
                .allow_methods([
                    axum::http::Method::GET,
                    axum::http::Method::POST,
                    axum::http::Method::PUT,
                    axum::http::Method::PATCH,
                    axum::http::Method::DELETE,
                    axum::http::Method::OPTIONS,
                ])
                .allow_headers([
                    axum::http::header::AUTHORIZATION,
                    axum::http::header::CONTENT_TYPE,
                    axum::http::header::COOKIE,
                    axum::http::HeaderName::from_static("x-request-id"),
                ])
                .allow_credentials(true)
        }
    };

    // Unauthenticated routes (auth endpoints + health), behind the public
    // rate limiter.
    let public_routes = zclaw_saas::auth::routes()
        .route("/api/health", axum::routing::get(health_handler))
        .layer(middleware::from_fn_with_state(
            state.clone(),
            zclaw_saas::middleware::public_rate_limit_middleware,
        ));

    // Authenticated routes. The layer added last is the outermost one, so
    // `auth_middleware` sees each request before the other middlewares.
    let protected_routes = zclaw_saas::auth::protected_routes()
        .merge(zclaw_saas::account::routes())
        .merge(zclaw_saas::model_config::routes())
        // relay::routes() is deliberately not merged here: its SSE endpoints
        // need a longer timeout and are merged into the final router instead.
        .merge(zclaw_saas::migration::routes())
        .merge(zclaw_saas::role::routes())
        .merge(zclaw_saas::prompt::routes())
        .merge(zclaw_saas::agent_template::routes())
        .merge(zclaw_saas::scheduled_task::routes())
        .merge(zclaw_saas::telemetry::routes())
        .layer(middleware::from_fn_with_state(
            state.clone(),
            zclaw_saas::middleware::api_version_middleware,
        ))
        .layer(middleware::from_fn_with_state(
            state.clone(),
            zclaw_saas::middleware::request_id_middleware,
        ))
        .layer(middleware::from_fn_with_state(
            state.clone(),
            zclaw_saas::middleware::rate_limit_middleware,
        ))
        .layer(middleware::from_fn_with_state(
            state.clone(),
            zclaw_saas::auth::auth_middleware,
        ));

    // Non-streaming routes share a 15 s timeout (relay SSE endpoints need a
    // longer one and are therefore kept out of this group).
    let non_streaming_routes = axum::Router::new()
        .merge(public_routes)
        .merge(protected_routes)
        .layer(TimeoutLayer::new(std::time::Duration::from_secs(15)));

    // Relay routes carry their own copy of the protected middleware stack
    // because they sit outside the 15 s timeout group.
    let relay_routes = zclaw_saas::relay::routes()
        .layer(middleware::from_fn_with_state(
            state.clone(),
            zclaw_saas::middleware::api_version_middleware,
        ))
        .layer(middleware::from_fn_with_state(
            state.clone(),
            zclaw_saas::middleware::request_id_middleware,
        ))
        .layer(middleware::from_fn_with_state(
            state.clone(),
            zclaw_saas::middleware::rate_limit_middleware,
        ))
        .layer(middleware::from_fn_with_state(
            state.clone(),
            zclaw_saas::auth::auth_middleware,
        ));

    axum::Router::new()
        .merge(non_streaming_routes)
        .merge(relay_routes)
        .layer(TraceLayer::new_for_http())
        .layer(cors)
        .with_state(state)
}