Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled
P2 code quality (SEC2-P2-01~10): - P2-04: Replace vague TODO with detailed Phase 2 design note in generate_embedding.rs - P2-05: Add NOTE(fire-and-forget) annotations to 4 long-running tokio::spawn in main.rs - P2-07: Add DESIGN NOTE to scheduler explaining sequential execution rationale - P2-08: Add compile-time table name whitelist + runtime char validation in db.rs - P2-02: Verified N/A (only zclaw-pipeline uses serde_yaml_bw, no inconsistency) - P2-06: Verified N/A (bind loop correctly matches 6-column placeholders) - P2-03: Remains OPEN (requires upstream sqlx release) Config HTTP method alignment (B3-4): - Fix admin-v2 config.ts: request.patch -> request.put to match backend .put() route - Fix backend handler doc comment: PATCH -> PUT - Add @reserved annotations to 6 config handlers without frontend callers
447 lines
18 KiB
Rust
447 lines
18 KiB
Rust
//! ZCLAW SaaS 服务入口
|
||
|
||
use axum::extract::State;
|
||
use tokio_util::sync::CancellationToken;
|
||
use tower_http::timeout::TimeoutLayer;
|
||
use tracing::info;
|
||
use zclaw_saas::{config::SaaSConfig, db::init_db, state::AppState};
|
||
use zclaw_saas::workers::WorkerDispatcher;
|
||
use zclaw_saas::workers::log_operation::LogOperationWorker;
|
||
use zclaw_saas::workers::cleanup_refresh_tokens::CleanupRefreshTokensWorker;
|
||
use zclaw_saas::workers::cleanup_rate_limit::CleanupRateLimitWorker;
|
||
use zclaw_saas::workers::record_usage::RecordUsageWorker;
|
||
use zclaw_saas::workers::update_last_used::UpdateLastUsedWorker;
|
||
use zclaw_saas::workers::aggregate_usage::AggregateUsageWorker;
|
||
use zclaw_saas::workers::generate_embedding::GenerateEmbeddingWorker;
|
||
|
||
#[tokio::main]
|
||
async fn main() -> anyhow::Result<()> {
|
||
// Load .env file from project root (walk up from current dir)
|
||
load_dotenv();
|
||
|
||
tracing_subscriber::fmt()
|
||
.with_env_filter(
|
||
tracing_subscriber::EnvFilter::try_from_default_env()
|
||
.unwrap_or_else(|_| "zclaw_saas=debug,tower_http=debug".into()),
|
||
)
|
||
.init();
|
||
|
||
let config = SaaSConfig::load()?;
|
||
info!("SaaS config loaded: {}:{}", config.server.host, config.server.port);
|
||
|
||
let db = init_db(&config.database).await?;
|
||
info!("Database initialized");
|
||
|
||
// 创建 Worker spawn 限制器(门控并发 DB 操作数量)
|
||
let worker_limiter = zclaw_saas::state::SpawnLimiter::new(
|
||
"worker",
|
||
config.database.worker_concurrency,
|
||
);
|
||
info!("Worker spawn limiter: {} permits", config.database.worker_concurrency);
|
||
|
||
// 初始化 Worker 调度器 + 注册所有 Worker
|
||
let mut dispatcher = WorkerDispatcher::new(db.clone(), worker_limiter.clone());
|
||
dispatcher.register(LogOperationWorker);
|
||
dispatcher.register(CleanupRefreshTokensWorker);
|
||
dispatcher.register(CleanupRateLimitWorker);
|
||
dispatcher.register(RecordUsageWorker);
|
||
dispatcher.register(UpdateLastUsedWorker);
|
||
dispatcher.register(AggregateUsageWorker);
|
||
dispatcher.register(GenerateEmbeddingWorker);
|
||
info!("Worker dispatcher initialized (7 workers registered)");
|
||
|
||
// 优雅停机令牌 — 取消后所有 SSE 流和长连接立即终止
|
||
let shutdown_token = CancellationToken::new();
|
||
let state = AppState::new(db.clone(), config.clone(), dispatcher, shutdown_token.clone(), worker_limiter.clone())?;
|
||
|
||
// Restore rate limit counts from DB so limits survive server restarts
|
||
// 仅恢复最近 60s 的计数(与 middleware 的 60s 滑动窗口一致),避免过于保守的限流
|
||
{
|
||
let rows: Vec<(String, i64)> = sqlx::query_as(
|
||
"SELECT key, SUM(count) FROM rate_limit_events WHERE window_start > NOW() - interval '60 seconds' GROUP BY key"
|
||
)
|
||
.fetch_all(&db)
|
||
.await
|
||
.unwrap_or_default();
|
||
|
||
let mut restored_count = 0usize;
|
||
for (key, count) in rows {
|
||
// 限制恢复计数不超过 RPM 配额,避免重启后过于保守
|
||
let rpm = state.rate_limit_rpm() as usize;
|
||
let capped = (count as usize).min(rpm);
|
||
let mut entries = Vec::with_capacity(capped);
|
||
for _ in 0..capped {
|
||
entries.push(std::time::Instant::now());
|
||
}
|
||
state.rate_limit_entries.insert(key, entries);
|
||
restored_count += 1;
|
||
}
|
||
info!("Restored rate limit state from DB: {} keys (60s window, capped at RPM)", restored_count);
|
||
}
|
||
|
||
// 迁移旧格式 TOTP secret(明文 → 加密 enc: 格式)
|
||
{
|
||
let config_for_migration = state.config.read().await;
|
||
if let Ok(enc_key) = config_for_migration.totp_encryption_key() {
|
||
drop(config_for_migration);
|
||
if let Err(e) = zclaw_saas::crypto::migrate_legacy_totp_secrets(&db, &enc_key).await {
|
||
tracing::warn!("TOTP legacy migration check failed: {}", e);
|
||
}
|
||
} else {
|
||
drop(config_for_migration);
|
||
}
|
||
}
|
||
|
||
// 启动声明式 Scheduler(从 TOML 配置读取定时任务)
|
||
let scheduler_config = &config.scheduler;
|
||
zclaw_saas::scheduler::start_scheduler(scheduler_config, db.clone(), state.worker_dispatcher.clone_ref());
|
||
info!("Scheduler started with {} jobs", scheduler_config.jobs.len());
|
||
|
||
// 启动内置 DB 清理任务(设备清理等不通过 Worker 的任务)
|
||
zclaw_saas::scheduler::start_db_cleanup_tasks(db.clone());
|
||
|
||
// 启动用户定时任务调度循环(30s 轮询 scheduled_tasks 表)
|
||
zclaw_saas::scheduler::start_user_task_scheduler(db.clone());
|
||
|
||
// 启动内存中的 rate limit 条目清理
|
||
// NOTE (fire-and-forget): Long-running background service task. JoinHandle not bound
|
||
// because this task runs for the lifetime of the process and is cancelled on shutdown.
|
||
let rate_limit_state = state.clone();
|
||
tokio::spawn(async move {
|
||
let mut interval = tokio::time::interval(std::time::Duration::from_secs(300));
|
||
loop {
|
||
interval.tick().await;
|
||
rate_limit_state.cleanup_rate_limit_entries();
|
||
}
|
||
});
|
||
|
||
// 初始化缓存并启动定时刷新 (60s)
|
||
state.cache.load_from_db(&db).await.map_err(|e| anyhow::anyhow!("{}", e))?;
|
||
info!("Cache initialized: {} providers, {} models", state.cache.providers.len(), state.cache.models.len());
|
||
{
|
||
let cache_state = state.clone();
|
||
let db_clone = db.clone();
|
||
// NOTE (fire-and-forget): Long-running cache refresh service. JoinHandle not bound.
|
||
tokio::spawn(async move {
|
||
let mut interval = tokio::time::interval(std::time::Duration::from_secs(60));
|
||
loop {
|
||
interval.tick().await;
|
||
if let Err(e) = cache_state.cache.load_from_db(&db_clone).await {
|
||
tracing::warn!("Cache refresh failed: {}", e);
|
||
}
|
||
cache_state.cache.calibrate_queue_counts(&db_clone).await;
|
||
}
|
||
});
|
||
}
|
||
|
||
// 限流事件批量 flush (可配置间隔,默认 5s)
|
||
{
|
||
let flush_state = state.clone();
|
||
let batch_interval = config.database.rate_limit_batch_interval_secs;
|
||
let batch_max = config.database.rate_limit_batch_max_size;
|
||
// NOTE (fire-and-forget): Long-running rate limit flush service. JoinHandle not bound.
|
||
tokio::spawn(async move {
|
||
let mut interval = tokio::time::interval(std::time::Duration::from_secs(batch_interval));
|
||
loop {
|
||
interval.tick().await;
|
||
flush_state.flush_rate_limit_batch(batch_max).await;
|
||
}
|
||
});
|
||
}
|
||
|
||
// 连接池可观测性 (30s 指标日志)
|
||
{
|
||
let metrics_db = db.clone();
|
||
// NOTE (fire-and-forget): Long-running pool metrics service. JoinHandle not bound.
|
||
tokio::spawn(async move {
|
||
let mut interval = tokio::time::interval(std::time::Duration::from_secs(30));
|
||
loop {
|
||
interval.tick().await;
|
||
let pool = &metrics_db;
|
||
let total = pool.options().get_max_connections() as usize;
|
||
let idle = pool.num_idle() as usize;
|
||
let used = total.saturating_sub(idle);
|
||
let usage_pct = if total > 0 { used * 100 / total } else { 0 };
|
||
tracing::info!(
|
||
"[PoolMetrics] total={} idle={} used={} usage_pct={}%",
|
||
total, idle, used, usage_pct,
|
||
);
|
||
if usage_pct >= 80 {
|
||
tracing::warn!(
|
||
"[PoolMetrics] HIGH USAGE: {}% of connections in use!",
|
||
usage_pct,
|
||
);
|
||
}
|
||
}
|
||
});
|
||
}
|
||
|
||
let app = build_router(state.clone()).await;
|
||
|
||
// 配置 TCP keepalive + 短 SO_LINGER,防止 CLOSE_WAIT 累积
|
||
let listener = create_listener(&config.server.host, config.server.port)?;
|
||
info!("SaaS server listening on {}:{}", config.server.host, config.server.port);
|
||
|
||
// 优雅停机: Ctrl+C → 最终批量 flush → 取消 CancellationToken → SSE 流终止 → 连接排空
|
||
let token = shutdown_token.clone();
|
||
let flush_state = state;
|
||
let batch_max = config.database.rate_limit_batch_max_size;
|
||
axum::serve(listener, app.into_make_service_with_connect_info::<std::net::SocketAddr>())
|
||
.with_graceful_shutdown(async move {
|
||
tokio::signal::ctrl_c()
|
||
.await
|
||
.expect("Failed to install Ctrl+C handler");
|
||
info!("Received shutdown signal, flushing pending rate limit batch...");
|
||
flush_state.flush_rate_limit_batch(batch_max).await;
|
||
info!("Cancelling SSE streams and draining connections...");
|
||
token.cancel();
|
||
})
|
||
.await?;
|
||
Ok(())
|
||
}
|
||
|
||
/// 创建带 TCP keepalive 和短 SO_LINGER 的 TcpListener,防止 CLOSE_WAIT 累积
|
||
fn create_listener(host: &str, port: u16) -> anyhow::Result<tokio::net::TcpListener> {
|
||
let addr = format!("{}:{}", host, port);
|
||
let socket = socket2::Socket::new(
|
||
socket2::Domain::for_address(addr.parse::<std::net::SocketAddr>()?),
|
||
socket2::Type::STREAM,
|
||
Some(socket2::Protocol::TCP),
|
||
)?;
|
||
|
||
// SO_REUSEADDR: 允许快速重启时复用 TIME_WAIT 端口
|
||
socket.set_reuse_address(true)?;
|
||
|
||
// TCP keepalive: 60s 空闲后每 10s 探测,连续 3 次无响应则关闭
|
||
// 防止已断开但对端未发 FIN 的连接永远留在 CLOSE_WAIT
|
||
let keepalive = socket2::SockRef::from(&socket);
|
||
keepalive.set_tcp_keepalive(
|
||
&socket2::TcpKeepalive::new()
|
||
.with_time(std::time::Duration::from_secs(60))
|
||
.with_interval(std::time::Duration::from_secs(10)),
|
||
)?;
|
||
|
||
// 短 SO_LINGER (1s): 关闭时最多等 1 秒即 RST,避免大量 TIME_WAIT
|
||
socket.set_linger(Some(std::time::Duration::from_secs(1)))?;
|
||
|
||
socket.bind(&addr.parse::<std::net::SocketAddr>()?.into())?;
|
||
socket.listen(1024)?;
|
||
socket.set_nonblocking(true)?;
|
||
|
||
Ok(tokio::net::TcpListener::from_std(socket.into())?)
|
||
}
|
||
|
||
async fn health_handler(
|
||
State(state): State<AppState>,
|
||
) -> (axum::http::StatusCode, axum::Json<serde_json::Value> ) {
|
||
// health 必须独立快速返回,用 3s 超时避免连接池满时阻塞
|
||
let db_healthy = tokio::time::timeout(
|
||
std::time::Duration::from_secs(3),
|
||
sqlx::query_scalar::<_, i32>("SELECT 1").fetch_one(&state.db),
|
||
)
|
||
.await
|
||
.map(|r| r.is_ok())
|
||
.unwrap_or(false);
|
||
|
||
// 连接池容量检查: 使用率 >= 80% 返回 503 (degraded)
|
||
let pool = &state.db;
|
||
let total = pool.options().get_max_connections() as usize;
|
||
if total > 0 {
|
||
let idle = pool.num_idle() as usize;
|
||
let used = total - idle;
|
||
let ratio = used * 100 / total;
|
||
if ratio >= 80 {
|
||
return (
|
||
axum::http::StatusCode::SERVICE_UNAVAILABLE,
|
||
axum::Json(serde_json::json!({
|
||
"status": "degraded",
|
||
"database": true,
|
||
"database_pool": {
|
||
"usage_pct": ratio,
|
||
"used": used,
|
||
"total": total,
|
||
},
|
||
"timestamp": chrono::Utc::now().to_rfc3339(),
|
||
"version": env!("CARGO_PKG_VERSION"),
|
||
})),
|
||
);
|
||
}
|
||
}
|
||
|
||
let status = if db_healthy { "healthy" } else { "degraded" };
|
||
let code = if db_healthy {
|
||
axum::http::StatusCode::OK } else { axum::http::StatusCode::SERVICE_UNAVAILABLE };
|
||
|
||
(code, axum::Json(serde_json::json!({
|
||
"status": status,
|
||
"database": db_healthy,
|
||
"timestamp": chrono::Utc::now().to_rfc3339(),
|
||
"version": env!("CARGO_PKG_VERSION"),
|
||
})))
|
||
}
|
||
|
||
async fn build_router(state: AppState) -> axum::Router {
|
||
use axum::middleware;
|
||
use tower_http::cors::{Any, CorsLayer};
|
||
use tower_http::trace::TraceLayer;
|
||
|
||
use axum::http::HeaderValue;
|
||
let cors = {
|
||
let config = state.config.read().await;
|
||
let is_dev = std::env::var("ZCLAW_SAAS_DEV")
|
||
.map(|v| v == "true" || v == "1")
|
||
.unwrap_or(false);
|
||
if config.server.cors_origins.is_empty() {
|
||
if is_dev {
|
||
CorsLayer::new()
|
||
.allow_origin(Any)
|
||
.allow_methods(Any)
|
||
.allow_headers(Any)
|
||
.allow_credentials(true)
|
||
} else {
|
||
tracing::error!("生产环境必须配置 server.cors_origins,不能使用 allow_origin(Any)");
|
||
panic!("生产环境必须配置 server.cors_origins 白名单。开发环境可设置 ZCLAW_SAAS_DEV=true 绕过。");
|
||
}
|
||
} else {
|
||
let origins: Vec<HeaderValue> = config.server.cors_origins.iter()
|
||
.filter_map(|o: &String| o.parse::<HeaderValue>().ok())
|
||
.collect();
|
||
CorsLayer::new()
|
||
.allow_origin(origins)
|
||
.allow_methods([
|
||
axum::http::Method::GET,
|
||
axum::http::Method::POST,
|
||
axum::http::Method::PUT,
|
||
axum::http::Method::PATCH,
|
||
axum::http::Method::DELETE,
|
||
axum::http::Method::OPTIONS,
|
||
])
|
||
.allow_headers([
|
||
axum::http::header::AUTHORIZATION,
|
||
axum::http::header::CONTENT_TYPE,
|
||
axum::http::header::COOKIE,
|
||
axum::http::HeaderName::from_static("x-request-id"),
|
||
])
|
||
.allow_credentials(true)
|
||
}
|
||
};
|
||
|
||
let public_routes = zclaw_saas::auth::routes()
|
||
.route("/api/health", axum::routing::get(health_handler))
|
||
.merge(zclaw_saas::billing::callback_routes())
|
||
.layer(middleware::from_fn_with_state(
|
||
state.clone(),
|
||
zclaw_saas::middleware::public_rate_limit_middleware,
|
||
));
|
||
|
||
let protected_routes = zclaw_saas::auth::protected_routes()
|
||
.merge(zclaw_saas::account::routes())
|
||
.merge(zclaw_saas::model_config::routes())
|
||
// relay::routes() 不在此合并 — SSE 端点需要更长超时,在最终 Router 单独合并
|
||
.merge(zclaw_saas::migration::routes())
|
||
.merge(zclaw_saas::role::routes())
|
||
.merge(zclaw_saas::prompt::routes())
|
||
.merge(zclaw_saas::agent_template::routes())
|
||
.merge(zclaw_saas::scheduled_task::routes())
|
||
.merge(zclaw_saas::telemetry::routes())
|
||
.merge(zclaw_saas::billing::routes())
|
||
.merge(zclaw_saas::knowledge::routes())
|
||
.layer(middleware::from_fn_with_state(
|
||
state.clone(),
|
||
zclaw_saas::middleware::api_version_middleware,
|
||
))
|
||
.layer(middleware::from_fn_with_state(
|
||
state.clone(),
|
||
zclaw_saas::middleware::request_id_middleware,
|
||
))
|
||
.layer(middleware::from_fn_with_state(
|
||
state.clone(),
|
||
zclaw_saas::middleware::rate_limit_middleware,
|
||
))
|
||
.layer(middleware::from_fn_with_state(
|
||
state.clone(),
|
||
zclaw_saas::auth::auth_middleware,
|
||
));
|
||
|
||
// 非流式路由应用全局 15s 超时(relay SSE 端点需要更长超时)
|
||
let non_streaming_routes = axum::Router::new()
|
||
.merge(public_routes)
|
||
.merge(protected_routes)
|
||
.layer(TimeoutLayer::new(std::time::Duration::from_secs(15)));
|
||
|
||
// Relay 路由需要独立的认证中间件(因为被排除在 15s 超时层之外)
|
||
let relay_routes = zclaw_saas::relay::routes()
|
||
.layer(middleware::from_fn_with_state(
|
||
state.clone(),
|
||
zclaw_saas::middleware::api_version_middleware,
|
||
))
|
||
.layer(middleware::from_fn_with_state(
|
||
state.clone(),
|
||
zclaw_saas::middleware::request_id_middleware,
|
||
))
|
||
.layer(middleware::from_fn_with_state(
|
||
state.clone(),
|
||
zclaw_saas::middleware::quota_check_middleware,
|
||
))
|
||
.layer(middleware::from_fn_with_state(
|
||
state.clone(),
|
||
zclaw_saas::middleware::rate_limit_middleware,
|
||
))
|
||
.layer(middleware::from_fn_with_state(
|
||
state.clone(),
|
||
zclaw_saas::auth::auth_middleware,
|
||
));
|
||
|
||
let mut router = axum::Router::new()
|
||
.merge(non_streaming_routes)
|
||
.merge(relay_routes);
|
||
|
||
// 开发模式挂载 mock 支付页面
|
||
{
|
||
let is_dev = std::env::var("ZCLAW_SAAS_DEV")
|
||
.map(|v| v == "true" || v == "1")
|
||
.unwrap_or(false);
|
||
if is_dev {
|
||
router = router.merge(zclaw_saas::billing::mock_routes());
|
||
info!("Mock payment routes mounted (dev mode)");
|
||
}
|
||
}
|
||
|
||
router
|
||
.layer(TraceLayer::new_for_http())
|
||
.layer(cors)
|
||
.with_state(state)
|
||
}
|
||
|
||
/// Load `.env` file from project root by walking up from current directory.
|
||
/// Sets environment variables that are not already set (does not override).
|
||
fn load_dotenv() {
|
||
let mut dir = std::env::current_dir().unwrap_or_default();
|
||
loop {
|
||
let env_path = dir.join(".env");
|
||
if env_path.is_file() {
|
||
if let Ok(content) = std::fs::read_to_string(&env_path) {
|
||
for line in content.lines() {
|
||
let line = line.trim();
|
||
if line.is_empty() || line.starts_with('#') {
|
||
continue;
|
||
}
|
||
if let Some((key, value)) = line.split_once('=') {
|
||
let key = key.trim();
|
||
let value = value.trim();
|
||
// Only set if not already defined in environment
|
||
if std::env::var(key).is_err() {
|
||
std::env::set_var(key, value);
|
||
}
|
||
}
|
||
}
|
||
tracing::debug!("Loaded .env from {}", env_path.display());
|
||
}
|
||
return;
|
||
}
|
||
if !dir.pop() {
|
||
break;
|
||
}
|
||
}
|
||
}
|