//! ZCLAW SaaS 服务入口 use axum::extract::State; use tokio_util::sync::CancellationToken; use tower_http::timeout::TimeoutLayer; use tracing::info; use zclaw_saas::{config::SaaSConfig, db::init_db, state::AppState}; use zclaw_saas::workers::WorkerDispatcher; use zclaw_saas::workers::log_operation::LogOperationWorker; use zclaw_saas::workers::cleanup_refresh_tokens::CleanupRefreshTokensWorker; use zclaw_saas::workers::cleanup_rate_limit::CleanupRateLimitWorker; use zclaw_saas::workers::record_usage::RecordUsageWorker; use zclaw_saas::workers::update_last_used::UpdateLastUsedWorker; #[tokio::main] async fn main() -> anyhow::Result<()> { tracing_subscriber::fmt() .with_env_filter( tracing_subscriber::EnvFilter::try_from_default_env() .unwrap_or_else(|_| "zclaw_saas=debug,tower_http=debug".into()), ) .init(); let config = SaaSConfig::load()?; info!("SaaS config loaded: {}:{}", config.server.host, config.server.port); let db = init_db(&config.database.url).await?; info!("Database initialized"); // 初始化 Worker 调度器 + 注册所有 Worker let mut dispatcher = WorkerDispatcher::new(db.clone()); dispatcher.register(LogOperationWorker); dispatcher.register(CleanupRefreshTokensWorker); dispatcher.register(CleanupRateLimitWorker); dispatcher.register(RecordUsageWorker); dispatcher.register(UpdateLastUsedWorker); info!("Worker dispatcher initialized (5 workers registered)"); // 优雅停机令牌 — 取消后所有 SSE 流和长连接立即终止 let shutdown_token = CancellationToken::new(); let state = AppState::new(db.clone(), config.clone(), dispatcher, shutdown_token.clone())?; // Restore rate limit counts from DB so limits survive server restarts { let rows: Vec<(String, i64)> = sqlx::query_as( "SELECT key, SUM(count) FROM rate_limit_events WHERE window_start > NOW() - interval '1 hour' GROUP BY key" ) .fetch_all(&db) .await .unwrap_or_default(); let mut restored_count = 0usize; for (key, count) in rows { let mut entries = Vec::new(); // Approximate: insert count timestamps at "now" — the DashMap will // expire them naturally via the retain() call in the middleware. // This is intentionally approximate; exact window alignment is not // required for rate limiting correctness. for _ in 0..count as usize { entries.push(std::time::Instant::now()); } state.rate_limit_entries.insert(key, entries); restored_count += 1; } info!("Restored rate limit state from DB: {} keys", restored_count); } // 迁移旧格式 TOTP secret(明文 → 加密 enc: 格式) { let config_for_migration = state.config.read().await; if let Ok(enc_key) = config_for_migration.totp_encryption_key() { drop(config_for_migration); if let Err(e) = zclaw_saas::crypto::migrate_legacy_totp_secrets(&db, &enc_key).await { tracing::warn!("TOTP legacy migration check failed: {}", e); } } else { drop(config_for_migration); } } // 启动声明式 Scheduler(从 TOML 配置读取定时任务) let scheduler_config = &config.scheduler; zclaw_saas::scheduler::start_scheduler(scheduler_config, db.clone(), state.worker_dispatcher.clone_ref()); info!("Scheduler started with {} jobs", scheduler_config.jobs.len()); // 启动内置 DB 清理任务(设备清理等不通过 Worker 的任务) zclaw_saas::scheduler::start_db_cleanup_tasks(db.clone()); // 启动用户定时任务调度循环(30s 轮询 scheduled_tasks 表) zclaw_saas::scheduler::start_user_task_scheduler(db.clone()); // 启动内存中的 rate limit 条目清理 let rate_limit_state = state.clone(); tokio::spawn(async move { let mut interval = tokio::time::interval(std::time::Duration::from_secs(300)); loop { interval.tick().await; rate_limit_state.cleanup_rate_limit_entries(); } }); // 初始化缓存并启动定时刷新 (60s) state.cache.load_from_db(&db).await.map_err(|e| anyhow::anyhow!("{}", e))?; info!("Cache initialized: {} providers, {} models", state.cache.providers.len(), state.cache.models.len()); { let cache_state = state.clone(); let db_clone = db.clone(); tokio::spawn(async move { let mut interval = tokio::time::interval(std::time::Duration::from_secs(60)); loop { interval.tick().await; if let Err(e) = cache_state.cache.load_from_db(&db_clone).await { tracing::warn!("Cache refresh failed: {}", e); } cache_state.cache.calibrate_queue_counts(&db_clone).await; } }); } let app = build_router(state).await; // 配置 TCP keepalive + 短 SO_LINGER,防止 CLOSE_WAIT 累积 let listener = create_listener(&config.server.host, config.server.port)?; info!("SaaS server listening on {}:{}", config.server.host, config.server.port); // 优雅停机: Ctrl+C → 取消 CancellationToken → SSE 流终止 → 连接排空 let token = shutdown_token.clone(); axum::serve(listener, app.into_make_service_with_connect_info::()) .with_graceful_shutdown(async move { tokio::signal::ctrl_c() .await .expect("Failed to install Ctrl+C handler"); info!("Received shutdown signal, cancelling SSE streams and draining connections..."); token.cancel(); }) .await?; Ok(()) } /// 创建带 TCP keepalive 和短 SO_LINGER 的 TcpListener,防止 CLOSE_WAIT 累积 fn create_listener(host: &str, port: u16) -> anyhow::Result { let addr = format!("{}:{}", host, port); let socket = socket2::Socket::new( socket2::Domain::for_address(addr.parse::()?), socket2::Type::STREAM, Some(socket2::Protocol::TCP), )?; // SO_REUSEADDR: 允许快速重启时复用 TIME_WAIT 端口 socket.set_reuse_address(true)?; // TCP keepalive: 60s 空闲后每 10s 探测,连续 3 次无响应则关闭 // 防止已断开但对端未发 FIN 的连接永远留在 CLOSE_WAIT let keepalive = socket2::SockRef::from(&socket); keepalive.set_tcp_keepalive( &socket2::TcpKeepalive::new() .with_time(std::time::Duration::from_secs(60)) .with_interval(std::time::Duration::from_secs(10)), )?; // 短 SO_LINGER (1s): 关闭时最多等 1 秒即 RST,避免大量 TIME_WAIT socket.set_linger(Some(std::time::Duration::from_secs(1)))?; socket.bind(&addr.parse::()?.into())?; socket.listen(1024)?; socket.set_nonblocking(true)?; Ok(tokio::net::TcpListener::from_std(socket.into())?) } async fn health_handler( State(state): State, ) -> (axum::http::StatusCode, axum::Json ) { // health 必须独立快速返回,用 3s 超时避免连接池满时阻塞 let db_healthy = tokio::time::timeout( std::time::Duration::from_secs(3), sqlx::query_scalar::<_, i32>("SELECT 1").fetch_one(&state.db), ) .await .map(|r| r.is_ok()) .unwrap_or(false); // 连接池容量检查: 使用率 >= 80% 返回 503 (degraded) let pool = &state.db; let total = pool.options().get_max_connections() as usize; if total > 0 { let idle = pool.num_idle() as usize; let used = total - idle; let ratio = used * 100 / total; if ratio >= 80 { return ( axum::http::StatusCode::SERVICE_UNAVAILABLE, axum::Json(serde_json::json!({ "status": "degraded", "database": true, "database_pool": { "usage_pct": ratio, "used": used, "total": total, }, "timestamp": chrono::Utc::now().to_rfc3339(), "version": env!("CARGO_PKG_VERSION"), })), ); } } let status = if db_healthy { "healthy" } else { "degraded" }; let code = if db_healthy { axum::http::StatusCode::OK } else { axum::http::StatusCode::SERVICE_UNAVAILABLE }; (code, axum::Json(serde_json::json!({ "status": status, "database": db_healthy, "timestamp": chrono::Utc::now().to_rfc3339(), "version": env!("CARGO_PKG_VERSION"), }))) } async fn build_router(state: AppState) -> axum::Router { use axum::middleware; use tower_http::cors::{Any, CorsLayer}; use tower_http::trace::TraceLayer; use axum::http::HeaderValue; let cors = { let config = state.config.read().await; let is_dev = std::env::var("ZCLAW_SAAS_DEV") .map(|v| v == "true" || v == "1") .unwrap_or(false); if config.server.cors_origins.is_empty() { if is_dev { CorsLayer::new() .allow_origin(Any) .allow_methods(Any) .allow_headers(Any) .allow_credentials(true) } else { tracing::error!("生产环境必须配置 server.cors_origins,不能使用 allow_origin(Any)"); panic!("生产环境必须配置 server.cors_origins 白名单。开发环境可设置 ZCLAW_SAAS_DEV=true 绕过。"); } } else { let origins: Vec = config.server.cors_origins.iter() .filter_map(|o: &String| o.parse::().ok()) .collect(); CorsLayer::new() .allow_origin(origins) .allow_methods([ axum::http::Method::GET, axum::http::Method::POST, axum::http::Method::PUT, axum::http::Method::PATCH, axum::http::Method::DELETE, axum::http::Method::OPTIONS, ]) .allow_headers([ axum::http::header::AUTHORIZATION, axum::http::header::CONTENT_TYPE, axum::http::header::COOKIE, axum::http::HeaderName::from_static("x-request-id"), ]) .allow_credentials(true) } }; let public_routes = zclaw_saas::auth::routes() .route("/api/health", axum::routing::get(health_handler)) .layer(middleware::from_fn_with_state( state.clone(), zclaw_saas::middleware::public_rate_limit_middleware, )); let protected_routes = zclaw_saas::auth::protected_routes() .merge(zclaw_saas::account::routes()) .merge(zclaw_saas::model_config::routes()) // relay::routes() 不在此合并 — SSE 端点需要更长超时,在最终 Router 单独合并 .merge(zclaw_saas::migration::routes()) .merge(zclaw_saas::role::routes()) .merge(zclaw_saas::prompt::routes()) .merge(zclaw_saas::agent_template::routes()) .merge(zclaw_saas::scheduled_task::routes()) .merge(zclaw_saas::telemetry::routes()) .layer(middleware::from_fn_with_state( state.clone(), zclaw_saas::middleware::api_version_middleware, )) .layer(middleware::from_fn_with_state( state.clone(), zclaw_saas::middleware::request_id_middleware, )) .layer(middleware::from_fn_with_state( state.clone(), zclaw_saas::middleware::rate_limit_middleware, )) .layer(middleware::from_fn_with_state( state.clone(), zclaw_saas::auth::auth_middleware, )); // 非流式路由应用全局 15s 超时(relay SSE 端点需要更长超时) let non_streaming_routes = axum::Router::new() .merge(public_routes) .merge(protected_routes) .layer(TimeoutLayer::new(std::time::Duration::from_secs(15))); // Relay 路由需要独立的认证中间件(因为被排除在 15s 超时层之外) let relay_routes = zclaw_saas::relay::routes() .layer(middleware::from_fn_with_state( state.clone(), zclaw_saas::middleware::api_version_middleware, )) .layer(middleware::from_fn_with_state( state.clone(), zclaw_saas::middleware::request_id_middleware, )) .layer(middleware::from_fn_with_state( state.clone(), zclaw_saas::middleware::rate_limit_middleware, )) .layer(middleware::from_fn_with_state( state.clone(), zclaw_saas::auth::auth_middleware, )); axum::Router::new() .merge(non_streaming_routes) .merge(relay_routes) .layer(TraceLayer::new_for_http()) .layer(cors) .with_state(state) }