Files
zclaw_openfang/crates/zclaw-saas/src/main.rs
iven eb956d0dce
Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled
feat: 新增管理后台前端项目及安全加固
refactor(saas): 重构认证中间件与限流策略
- 登录限流调整为5次/分钟/IP
- 注册限流调整为3次/小时/IP
- GET请求不计入限流

fix(saas): 修复调度器时间戳处理
- 使用NOW()替代文本时间戳
- 兼容TEXT和TIMESTAMPTZ列类型

feat(saas): 实现环境变量插值
- 支持${ENV_VAR}语法解析
- 数据库密码支持环境变量注入

chore: 新增前端管理界面
- 基于React+Ant Design Pro
- 包含路由守卫/错误边界
- 对接58个API端点

docs: 更新安全加固文档
- 新增密钥管理规范
- 记录P0安全项审计结果
- 补充TLS终止说明

test: 完善配置解析单元测试
- 新增环境变量插值测试用例
2026-03-31 00:11:33 +08:00

276 lines
11 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

//! ZCLAW SaaS 服务入口
use axum::extract::State;
use tokio_util::sync::CancellationToken;
use tower_http::timeout::TimeoutLayer;
use tracing::info;
use zclaw_saas::{config::SaaSConfig, db::init_db, state::AppState};
use zclaw_saas::workers::WorkerDispatcher;
use zclaw_saas::workers::log_operation::LogOperationWorker;
use zclaw_saas::workers::cleanup_refresh_tokens::CleanupRefreshTokensWorker;
use zclaw_saas::workers::cleanup_rate_limit::CleanupRateLimitWorker;
use zclaw_saas::workers::record_usage::RecordUsageWorker;
use zclaw_saas::workers::update_last_used::UpdateLastUsedWorker;
/// Service entry point.
///
/// Initialises tracing, loads the SaaS configuration, connects to the database,
/// registers the background workers, starts the schedulers and finally serves the
/// HTTP router until a shutdown signal is received.
///
/// # Errors
/// Returns an error if configuration loading, database initialisation, application
/// state construction, listener creation or the HTTP server itself fails.
#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // Honour RUST_LOG when set; otherwise fall back to debug output for this crate
    // and tower_http.
    tracing_subscriber::fmt()
        .with_env_filter(
            tracing_subscriber::EnvFilter::try_from_default_env()
                .unwrap_or_else(|_| "zclaw_saas=debug,tower_http=debug".into()),
        )
        .init();

    let config = SaaSConfig::load()?;
    info!("SaaS config loaded: {}:{}", config.server.host, config.server.port);

    let db = init_db(&config.database.url).await?;
    info!("Database initialized");

    // Initialise the worker dispatcher and register every worker.
    let mut dispatcher = WorkerDispatcher::new(db.clone());
    dispatcher.register(LogOperationWorker);
    dispatcher.register(CleanupRefreshTokensWorker);
    dispatcher.register(CleanupRateLimitWorker);
    dispatcher.register(RecordUsageWorker);
    dispatcher.register(UpdateLastUsedWorker);
    info!("Worker dispatcher initialized (5 workers registered)");

    // Graceful-shutdown token shared with the application state. It is cancelled
    // below once a shutdown signal arrives; it is intended to end SSE streams and
    // other long-lived connections (the consumers live outside this file).
    let shutdown_token = CancellationToken::new();
    let state = AppState::new(db.clone(), config.clone(), dispatcher, shutdown_token.clone())?;

    // Start the declarative scheduler with the jobs found in the configuration.
    let scheduler_config = &config.scheduler;
    zclaw_saas::scheduler::start_scheduler(
        scheduler_config,
        db.clone(),
        state.worker_dispatcher.clone_ref(),
    );
    info!("Scheduler started with {} jobs", scheduler_config.jobs.len());

    // Start the built-in DB cleanup tasks (work that does not go through a worker).
    zclaw_saas::scheduler::start_db_cleanup_tasks(db.clone());

    // Start the scheduling loop for user-defined scheduled tasks.
    zclaw_saas::scheduler::start_user_task_scheduler(db.clone());

    // Purge in-memory rate-limit entries every 300 seconds.
    let rate_limit_state = state.clone();
    tokio::spawn(async move {
        let mut interval = tokio::time::interval(std::time::Duration::from_secs(300));
        loop {
            interval.tick().await;
            rate_limit_state.cleanup_rate_limit_entries();
        }
    });

    let app = build_router(state).await;

    // Listener with TCP keepalive and a short SO_LINGER (see `create_listener`).
    let listener = create_listener(&config.server.host, config.server.port)?;
    info!("SaaS server listening on {}:{}", config.server.host, config.server.port);

    // Graceful shutdown: signal -> cancel the token -> let axum drain connections.
    let token = shutdown_token.clone();
    axum::serve(listener, app.into_make_service_with_connect_info::<std::net::SocketAddr>())
        .with_graceful_shutdown(async move {
            shutdown_signal().await;
            info!("Received shutdown signal, cancelling SSE streams and draining connections...");
            token.cancel();
        })
        .await?;
    Ok(())
}

/// Completes once the process has been asked to shut down.
///
/// Waits for Ctrl+C on every platform and, on Unix, also for SIGTERM, so that a
/// termination request from a service manager or container runtime takes the same
/// graceful path as an interactive Ctrl+C.
///
/// # Panics
/// Panics if a signal handler cannot be installed.
async fn shutdown_signal() {
    let ctrl_c = async {
        tokio::signal::ctrl_c()
            .await
            .expect("Failed to install Ctrl+C handler");
    };

    #[cfg(unix)]
    let terminate = async {
        tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
            .expect("Failed to install SIGTERM handler")
            .recv()
            .await;
    };

    // No SIGTERM outside Unix: this branch never completes.
    #[cfg(not(unix))]
    let terminate = std::future::pending::<()>();

    tokio::select! {
        _ = ctrl_c => {}
        _ = terminate => {}
    }
}
/// Resolves `host` and `port` into the socket address to bind.
///
/// The `"{host}:{port}"` string is parsed first, which covers IPv4 literals and
/// bracketed IPv6 literals. When that fails the pair is resolved through
/// [`std::net::ToSocketAddrs`], which additionally accepts unbracketed IPv6
/// literals (e.g. `::1`) and host names (e.g. `localhost`); the first resolved
/// address is used.
///
/// # Errors
/// Returns the resolver's I/O error, or `AddrNotAvailable` when resolution yields
/// no address.
fn resolve_bind_addr(host: &str, port: u16) -> std::io::Result<std::net::SocketAddr> {
    use std::net::ToSocketAddrs;

    if let Ok(addr) = format!("{}:{}", host, port).parse::<std::net::SocketAddr>() {
        return Ok(addr);
    }

    (host, port).to_socket_addrs()?.next().ok_or_else(|| {
        std::io::Error::new(
            std::io::ErrorKind::AddrNotAvailable,
            format!("no socket address found for host '{}'", host),
        )
    })
}

/// Creates the listening `TcpListener` with SO_REUSEADDR, TCP keepalive and a
/// short SO_LINGER configured on the socket.
///
/// The options are set on the listening socket itself.
/// NOTE(review): whether accepted connections inherit keepalive/linger from the
/// listening socket is OS-specific — confirm for the deployment targets.
///
/// # Errors
/// Returns an error if the address cannot be resolved, or if creating,
/// configuring, binding or registering the socket with tokio fails.
fn create_listener(host: &str, port: u16) -> anyhow::Result<tokio::net::TcpListener> {
    // Resolve once and reuse the result for both the socket domain and the bind.
    let addr = resolve_bind_addr(host, port)?;

    let socket = socket2::Socket::new(
        socket2::Domain::for_address(addr),
        socket2::Type::STREAM,
        Some(socket2::Protocol::TCP),
    )?;

    // SO_REUSEADDR: allow rebinding the port quickly after a restart.
    socket.set_reuse_address(true)?;

    // TCP keepalive: start probing after 60 s of idle time, one probe every 10 s.
    // The number of failed probes before the connection is dropped is not set
    // here, so the operating system default applies.
    socket.set_tcp_keepalive(
        &socket2::TcpKeepalive::new()
            .with_time(std::time::Duration::from_secs(60))
            .with_interval(std::time::Duration::from_secs(10)),
    )?;

    // Short SO_LINGER: wait at most 1 s on close.
    socket.set_linger(Some(std::time::Duration::from_secs(1)))?;

    socket.bind(&addr.into())?;
    socket.listen(1024)?;
    // tokio requires the std listener to be in non-blocking mode.
    socket.set_nonblocking(true)?;
    Ok(tokio::net::TcpListener::from_std(socket.into())?)
}
/// Health endpoint: reports database reachability and connection-pool pressure.
///
/// Responses:
/// - `200` with `"status": "healthy"` when the `SELECT 1` probe succeeds and fewer
///   than 80% of the pool's maximum connections are checked out.
/// - `503` with `"status": "degraded"` when the probe fails or times out, or when
///   pool usage is at or above 80% (the body then also carries `database_pool`).
async fn health_handler(
    State(state): State<AppState>,
) -> (axum::http::StatusCode, axum::Json<serde_json::Value>) {
    // The probe is capped at 3 s so the handler still answers when the pool is
    // exhausted; both a timeout and a query error count as unhealthy.
    let db_healthy = matches!(
        tokio::time::timeout(
            std::time::Duration::from_secs(3),
            sqlx::query_scalar::<_, i32>("SELECT 1").fetch_one(&state.db),
        )
        .await,
        Ok(Ok(_))
    );

    // Pool capacity check: usage >= 80% of max_connections is reported as degraded.
    let pool = &state.db;
    let total = pool.options().get_max_connections() as usize;
    if total > 0 {
        let idle = pool.num_idle() as usize;
        // Connections in use = currently open connections minus the idle ones.
        // Subtracting idle from `max_connections` instead would count slots that
        // were never opened as "used" and flag a lightly loaded pool as saturated.
        let open = pool.size() as usize;
        let used = open.saturating_sub(idle);
        let ratio = used * 100 / total;
        if ratio >= 80 {
            return (
                axum::http::StatusCode::SERVICE_UNAVAILABLE,
                axum::Json(serde_json::json!({
                    "status": "degraded",
                    // Report the real probe result rather than assuming success.
                    "database": db_healthy,
                    "database_pool": {
                        "usage_pct": ratio,
                        "used": used,
                        "total": total,
                    },
                    "timestamp": chrono::Utc::now().to_rfc3339(),
                    "version": env!("CARGO_PKG_VERSION"),
                })),
            );
        }
    }

    let (status, code) = if db_healthy {
        ("healthy", axum::http::StatusCode::OK)
    } else {
        ("degraded", axum::http::StatusCode::SERVICE_UNAVAILABLE)
    };
    (
        code,
        axum::Json(serde_json::json!({
            "status": status,
            "database": db_healthy,
            "timestamp": chrono::Utc::now().to_rfc3339(),
            "version": env!("CARGO_PKG_VERSION"),
        })),
    )
}
/// Assembles the application router: CORS policy, public routes, protected
/// routes, relay routes and the shared tracing layer.
///
/// CORS policy:
/// - `server.cors_origins` non-empty: only those origins are allowed, with a fixed
///   set of methods and headers, and credentials enabled.
/// - `server.cors_origins` empty and `ZCLAW_SAAS_DEV` is `true`/`1`: the request's
///   own origin, method and headers are echoed back, with credentials enabled.
/// - otherwise: startup is aborted.
///
/// # Panics
/// Panics when `server.cors_origins` is empty and `ZCLAW_SAAS_DEV` is not enabled.
async fn build_router(state: AppState) -> axum::Router {
    use axum::http::HeaderValue;
    use axum::middleware;
    use tower_http::cors::{AllowHeaders, AllowMethods, AllowOrigin, CorsLayer};
    use tower_http::trace::TraceLayer;

    let cors = {
        let config = state.config.read().await;
        let is_dev = std::env::var("ZCLAW_SAAS_DEV")
            .map(|v| v == "true" || v == "1")
            .unwrap_or(false);
        if config.server.cors_origins.is_empty() {
            if is_dev {
                // Credentials cannot be combined with wildcard (`*`) origin, methods
                // or headers: tower-http rejects that configuration with a panic.
                // Mirroring the request keeps the dev setup permissive while staying
                // valid together with `allow_credentials(true)`.
                CorsLayer::new()
                    .allow_origin(AllowOrigin::mirror_request())
                    .allow_methods(AllowMethods::mirror_request())
                    .allow_headers(AllowHeaders::mirror_request())
                    .allow_credentials(true)
            } else {
                tracing::error!("生产环境必须配置 server.cors_origins不能使用 allow_origin(Any)");
                panic!("生产环境必须配置 server.cors_origins 白名单。开发环境可设置 ZCLAW_SAAS_DEV=true 绕过。");
            }
        } else {
            // Entries that are not valid header values are skipped, but logged so a
            // typo in the configuration does not silently disable an origin.
            let origins: Vec<HeaderValue> = config
                .server
                .cors_origins
                .iter()
                .filter_map(|origin: &String| match origin.parse::<HeaderValue>() {
                    Ok(value) => Some(value),
                    Err(err) => {
                        tracing::warn!(
                            origin = %origin,
                            error = %err,
                            "Ignoring invalid entry in server.cors_origins"
                        );
                        None
                    }
                })
                .collect();
            CorsLayer::new()
                .allow_origin(origins)
                .allow_methods([
                    axum::http::Method::GET,
                    axum::http::Method::POST,
                    axum::http::Method::PUT,
                    axum::http::Method::PATCH,
                    axum::http::Method::DELETE,
                    axum::http::Method::OPTIONS,
                ])
                .allow_headers([
                    axum::http::header::AUTHORIZATION,
                    axum::http::header::CONTENT_TYPE,
                    axum::http::header::COOKIE,
                    axum::http::HeaderName::from_static("x-request-id"),
                ])
                .allow_credentials(true)
        }
    };

    // Public routes (auth endpoints + health) only get the public rate limiter.
    let public_routes = zclaw_saas::auth::routes()
        .route("/api/health", axum::routing::get(health_handler))
        .layer(middleware::from_fn_with_state(
            state.clone(),
            zclaw_saas::middleware::public_rate_limit_middleware,
        ));

    // Protected routes. Each `.layer` wraps everything added before it, so the
    // layer added last (auth) is the outermost one.
    let protected_routes = zclaw_saas::auth::protected_routes()
        .merge(zclaw_saas::account::routes())
        .merge(zclaw_saas::model_config::routes())
        // relay::routes() is intentionally not merged here: its SSE endpoints need
        // a longer timeout and are merged into the final router separately.
        .merge(zclaw_saas::migration::routes())
        .merge(zclaw_saas::role::routes())
        .merge(zclaw_saas::prompt::routes())
        .merge(zclaw_saas::agent_template::routes())
        .merge(zclaw_saas::scheduled_task::routes())
        .merge(zclaw_saas::telemetry::routes())
        .layer(middleware::from_fn_with_state(
            state.clone(),
            zclaw_saas::middleware::api_version_middleware,
        ))
        .layer(middleware::from_fn_with_state(
            state.clone(),
            zclaw_saas::middleware::request_id_middleware,
        ))
        .layer(middleware::from_fn_with_state(
            state.clone(),
            zclaw_saas::middleware::rate_limit_middleware,
        ))
        .layer(middleware::from_fn_with_state(
            state.clone(),
            zclaw_saas::auth::auth_middleware,
        ));

    // Non-streaming routes share a global 15 s timeout.
    let non_streaming_routes = axum::Router::new()
        .merge(public_routes)
        .merge(protected_routes)
        .layer(TimeoutLayer::new(std::time::Duration::from_secs(15)));

    // Relay routes sit outside the 15 s timeout layer, so they carry their own copy
    // of the protected middleware stack.
    let relay_routes = zclaw_saas::relay::routes()
        .layer(middleware::from_fn_with_state(
            state.clone(),
            zclaw_saas::middleware::api_version_middleware,
        ))
        .layer(middleware::from_fn_with_state(
            state.clone(),
            zclaw_saas::middleware::request_id_middleware,
        ))
        .layer(middleware::from_fn_with_state(
            state.clone(),
            zclaw_saas::middleware::rate_limit_middleware,
        ))
        .layer(middleware::from_fn_with_state(
            state.clone(),
            zclaw_saas::auth::auth_middleware,
        ));

    axum::Router::new()
        .merge(non_streaming_routes)
        .merge(relay_routes)
        .layer(TraceLayer::new_for_http())
        .layer(cors)
        .with_state(state)
}