From 9487cd7f72bd4e7beff30929f6fbf18fe1c27cb0 Mon Sep 17 00:00:00 2001 From: iven Date: Wed, 1 Apr 2026 23:59:46 +0800 Subject: [PATCH] =?UTF-8?q?feat(saas):=20add=20billing=20infrastructure=20?= =?UTF-8?q?=E2=80=94=20tables,=20types,=20service,=20handlers?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit B1.1 Billing database: - 5 tables: billing_plans, billing_subscriptions, billing_invoices, billing_payments, billing_usage_quotas - Seed data: Free(¥0)/Pro(¥49)/Team(¥199) plans - JSONB limits for flexible plan configuration Billing module (crates/zclaw-saas/src/billing/): - types.rs: BillingPlan, Subscription, Invoice, Payment, UsageQuota - service.rs: plan CRUD, subscription lookup, usage tracking, quota check - handlers.rs: REST API (plans list/detail, subscription, usage) - mod.rs: routes registered at /api/v1/billing/* Cargo.toml: added chrono feature to sqlx for DateTime support --- Cargo.toml | 2 +- .../20260402000001_billing_tables.sql | 133 +++++++++++ crates/zclaw-saas/src/billing/handlers.rs | 55 +++++ crates/zclaw-saas/src/billing/mod.rs | 15 ++ crates/zclaw-saas/src/billing/service.rs | 206 ++++++++++++++++++ crates/zclaw-saas/src/billing/types.rs | 161 ++++++++++++++ crates/zclaw-saas/src/db.rs | 80 ++++++- crates/zclaw-saas/src/lib.rs | 1 + crates/zclaw-saas/src/main.rs | 115 ++++++++-- 9 files changed, 743 insertions(+), 25 deletions(-) create mode 100644 crates/zclaw-saas/migrations/20260402000001_billing_tables.sql create mode 100644 crates/zclaw-saas/src/billing/handlers.rs create mode 100644 crates/zclaw-saas/src/billing/mod.rs create mode 100644 crates/zclaw-saas/src/billing/service.rs create mode 100644 crates/zclaw-saas/src/billing/types.rs diff --git a/Cargo.toml b/Cargo.toml index d935702..f439cbe 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -57,7 +57,7 @@ chrono = { version = "0.4", features = ["serde"] } uuid = { version = "1", features = ["v4", "v5", "serde"] } # Database -sqlx = { version = "0.7", features = ["runtime-tokio", "sqlite", "postgres"] } +sqlx = { version = "0.7", features = ["runtime-tokio", "sqlite", "postgres", "chrono"] } libsqlite3-sys = { version = "0.27", features = ["bundled"] } # HTTP client (for LLM drivers) diff --git a/crates/zclaw-saas/migrations/20260402000001_billing_tables.sql b/crates/zclaw-saas/migrations/20260402000001_billing_tables.sql new file mode 100644 index 0000000..0823946 --- /dev/null +++ b/crates/zclaw-saas/migrations/20260402000001_billing_tables.sql @@ -0,0 +1,133 @@ +-- Migration: Billing tables for subscription management +-- Supports: Free/Pro/Team plans, Alipay + WeChat Pay, usage quotas + +-- Plan definitions (Free/Pro/Team) +CREATE TABLE IF NOT EXISTS billing_plans ( + id TEXT PRIMARY KEY, + name TEXT NOT NULL UNIQUE, + display_name TEXT NOT NULL, + description TEXT, + price_cents INTEGER NOT NULL DEFAULT 0, + currency TEXT NOT NULL DEFAULT 'CNY', + interval TEXT NOT NULL DEFAULT 'month', + features JSONB NOT NULL DEFAULT '{}', + limits JSONB NOT NULL DEFAULT '{}', + is_default BOOLEAN NOT NULL DEFAULT FALSE, + sort_order INTEGER NOT NULL DEFAULT 0, + status TEXT NOT NULL DEFAULT 'active', + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); +CREATE INDEX IF NOT EXISTS idx_billing_plans_status ON billing_plans(status); + +-- Account subscriptions +CREATE TABLE IF NOT EXISTS billing_subscriptions ( + id TEXT PRIMARY KEY, + account_id TEXT NOT NULL, + plan_id TEXT NOT NULL, + status TEXT NOT NULL DEFAULT 'active', + current_period_start TIMESTAMPTZ NOT NULL DEFAULT NOW(), + current_period_end TIMESTAMPTZ NOT NULL, + trial_end TIMESTAMPTZ, + canceled_at TIMESTAMPTZ, + cancel_at_period_end BOOLEAN NOT NULL DEFAULT FALSE, + metadata JSONB NOT NULL DEFAULT '{}', + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE, + FOREIGN KEY (plan_id) REFERENCES billing_plans(id) +); +CREATE INDEX IF NOT EXISTS idx_billing_sub_account ON billing_subscriptions(account_id); +CREATE UNIQUE INDEX IF NOT EXISTS idx_billing_sub_active + ON billing_subscriptions(account_id) + WHERE status IN ('trial', 'active', 'past_due'); + +-- Invoices +CREATE TABLE IF NOT EXISTS billing_invoices ( + id TEXT PRIMARY KEY, + account_id TEXT NOT NULL, + subscription_id TEXT, + plan_id TEXT, + amount_cents INTEGER NOT NULL, + currency TEXT NOT NULL DEFAULT 'CNY', + description TEXT, + status TEXT NOT NULL DEFAULT 'pending', + due_at TIMESTAMPTZ, + paid_at TIMESTAMPTZ, + voided_at TIMESTAMPTZ, + metadata JSONB NOT NULL DEFAULT '{}', + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE, + FOREIGN KEY (subscription_id) REFERENCES billing_subscriptions(id) ON DELETE SET NULL, + FOREIGN KEY (plan_id) REFERENCES billing_plans(id) +); +CREATE INDEX IF NOT EXISTS idx_billing_inv_account ON billing_invoices(account_id); +CREATE INDEX IF NOT EXISTS idx_billing_inv_status ON billing_invoices(status); +CREATE INDEX IF NOT EXISTS idx_billing_inv_time ON billing_invoices(created_at); + +-- Payment records (Alipay / WeChat Pay) +CREATE TABLE IF NOT EXISTS billing_payments ( + id TEXT PRIMARY KEY, + invoice_id TEXT NOT NULL, + account_id TEXT NOT NULL, + amount_cents INTEGER NOT NULL, + currency TEXT NOT NULL DEFAULT 'CNY', + method TEXT NOT NULL, + status TEXT NOT NULL DEFAULT 'pending', + external_trade_no TEXT, + paid_at TIMESTAMPTZ, + refunded_at TIMESTAMPTZ, + failure_reason TEXT, + metadata JSONB NOT NULL DEFAULT '{}', + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + FOREIGN KEY (invoice_id) REFERENCES billing_invoices(id) ON DELETE CASCADE, + FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE +); +CREATE INDEX IF NOT EXISTS idx_billing_pay_invoice ON billing_payments(invoice_id); +CREATE INDEX IF NOT EXISTS idx_billing_pay_account ON billing_payments(account_id); +CREATE INDEX IF NOT EXISTS idx_billing_pay_trade_no ON billing_payments(external_trade_no); +CREATE INDEX IF NOT EXISTS idx_billing_pay_status ON billing_payments(status); + +-- Monthly usage quotas (per account per billing period) +CREATE TABLE IF NOT EXISTS billing_usage_quotas ( + id TEXT PRIMARY KEY, + account_id TEXT NOT NULL, + period_start TIMESTAMPTZ NOT NULL, + period_end TIMESTAMPTZ NOT NULL, + input_tokens BIGINT NOT NULL DEFAULT 0, + output_tokens BIGINT NOT NULL DEFAULT 0, + relay_requests INTEGER NOT NULL DEFAULT 0, + hand_executions INTEGER NOT NULL DEFAULT 0, + pipeline_runs INTEGER NOT NULL DEFAULT 0, + max_input_tokens BIGINT, + max_output_tokens BIGINT, + max_relay_requests INTEGER, + max_hand_executions INTEGER, + max_pipeline_runs INTEGER, + metadata JSONB NOT NULL DEFAULT '{}', + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE, + UNIQUE(account_id, period_start) +); +CREATE INDEX IF NOT EXISTS idx_billing_usage_account ON billing_usage_quotas(account_id); +CREATE INDEX IF NOT EXISTS idx_billing_usage_period ON billing_usage_quotas(period_start, period_end); + +-- Seed: default plans +INSERT INTO billing_plans (id, name, display_name, description, price_cents, interval, features, limits, is_default, sort_order) +VALUES + ('plan-free', 'free', '免费版', '基础功能,适合个人体验', 0, 'month', + '{"hands": ["browser", "collector", "researcher"], "chat_modes": ["flash", "thinking"], "pipelines": 3, "support": "community"}'::jsonb, + '{"max_input_tokens_monthly": 500000, "max_output_tokens_monthly": 500000, "max_relay_requests_monthly": 100, "max_hand_executions_monthly": 20, "max_pipeline_runs_monthly": 5}'::jsonb, + TRUE, 0), + ('plan-pro', 'pro', '专业版', '全功能解锁,适合知识工作者', 4900, 'month', + '{"hands": "all", "chat_modes": "all", "pipelines": -1, "support": "priority", "memory": true, "export": true}'::jsonb, + '{"max_input_tokens_monthly": 5000000, "max_output_tokens_monthly": 5000000, "max_relay_requests_monthly": 2000, "max_hand_executions_monthly": 200, "max_pipeline_runs_monthly": 100}'::jsonb, + FALSE, 1), + ('plan-team', 'team', '团队版', '多席位协作,适合企业团队', 19900, 'month', + '{"hands": "all", "chat_modes": "all", "pipelines": -1, "support": "dedicated", "memory": true, "export": true, "sharing": true, "admin": true}'::jsonb, + '{"max_input_tokens_monthly": 50000000, "max_output_tokens_monthly": 50000000, "max_relay_requests_monthly": 20000, "max_hand_executions_monthly": 1000, "max_pipeline_runs_monthly": 500}'::jsonb, + FALSE, 2) +ON CONFLICT (name) DO NOTHING; diff --git a/crates/zclaw-saas/src/billing/handlers.rs b/crates/zclaw-saas/src/billing/handlers.rs new file mode 100644 index 0000000..a997df5 --- /dev/null +++ b/crates/zclaw-saas/src/billing/handlers.rs @@ -0,0 +1,55 @@ +//! 计费 HTTP 处理器 + +use axum::{ + extract::{Extension, Path, State}, + Json, +}; + +use crate::auth::types::AuthContext; +use crate::error::SaasResult; +use crate::state::AppState; +use super::service; +use super::types::*; + +/// GET /api/v1/billing/plans — 列出所有活跃计划 +pub async fn list_plans( + State(state): State, +) -> SaasResult>> { + let plans = service::list_plans(&state.db).await?; + Ok(Json(plans)) +} + +/// GET /api/v1/billing/plans/:id — 获取单个计划详情 +pub async fn get_plan( + State(state): State, + Path(plan_id): Path, +) -> SaasResult> { + let plan = service::get_plan(&state.db, &plan_id).await? + .ok_or_else(|| crate::error::SaasError::NotFound("计划不存在".into()))?; + Ok(Json(plan)) +} + +/// GET /api/v1/billing/subscription — 获取当前订阅 +pub async fn get_subscription( + State(state): State, + Extension(ctx): Extension, +) -> SaasResult> { + let plan = service::get_account_plan(&state.db, &ctx.account_id).await?; + let sub = service::get_active_subscription(&state.db, &ctx.account_id).await?; + let usage = service::get_or_create_usage(&state.db, &ctx.account_id).await?; + + Ok(Json(serde_json::json!({ + "plan": plan, + "subscription": sub, + "usage": usage, + }))) +} + +/// GET /api/v1/billing/usage — 获取当月用量 +pub async fn get_usage( + State(state): State, + Extension(ctx): Extension, +) -> SaasResult> { + let usage = service::get_or_create_usage(&state.db, &ctx.account_id).await?; + Ok(Json(usage)) +} diff --git a/crates/zclaw-saas/src/billing/mod.rs b/crates/zclaw-saas/src/billing/mod.rs new file mode 100644 index 0000000..63197b6 --- /dev/null +++ b/crates/zclaw-saas/src/billing/mod.rs @@ -0,0 +1,15 @@ +//! 计费模块 — 计划管理、订阅、用量配额 + +pub mod types; +pub mod service; +pub mod handlers; + +use axum::routing::get; + +pub fn routes() -> axum::Router { + axum::Router::new() + .route("/api/v1/billing/plans", get(handlers::list_plans)) + .route("/api/v1/billing/plans/{id}", get(handlers::get_plan)) + .route("/api/v1/billing/subscription", get(handlers::get_subscription)) + .route("/api/v1/billing/usage", get(handlers::get_usage)) +} diff --git a/crates/zclaw-saas/src/billing/service.rs b/crates/zclaw-saas/src/billing/service.rs new file mode 100644 index 0000000..b35aa43 --- /dev/null +++ b/crates/zclaw-saas/src/billing/service.rs @@ -0,0 +1,206 @@ +//! 计费服务层 — 计划查询、订阅管理、用量检查 + +use chrono::{Datelike, Timelike}; +use sqlx::PgPool; + +use crate::error::SaasResult; + +use super::types::*; + +/// 获取所有活跃计划 +pub async fn list_plans(pool: &PgPool) -> SaasResult> { + let plans = sqlx::query_as::<_, BillingPlan>( + "SELECT * FROM billing_plans WHERE status = 'active' ORDER BY sort_order" + ) + .fetch_all(pool) + .await?; + Ok(plans) +} + +/// 获取单个计划 +pub async fn get_plan(pool: &PgPool, plan_id: &str) -> SaasResult> { + let plan = sqlx::query_as::<_, BillingPlan>( + "SELECT * FROM billing_plans WHERE id = $1" + ) + .bind(plan_id) + .fetch_optional(pool) + .await?; + Ok(plan) +} + +/// 获取账户当前有效订阅 +pub async fn get_active_subscription( + pool: &PgPool, + account_id: &str, +) -> SaasResult> { + let sub = sqlx::query_as::<_, Subscription>( + "SELECT * FROM billing_subscriptions \ + WHERE account_id = $1 AND status IN ('trial', 'active', 'past_due') \ + ORDER BY created_at DESC LIMIT 1" + ) + .bind(account_id) + .fetch_optional(pool) + .await?; + Ok(sub) +} + +/// 获取账户当前计划(有订阅返回订阅计划,否则返回 Free) +pub async fn get_account_plan(pool: &PgPool, account_id: &str) -> SaasResult { + if let Some(sub) = get_active_subscription(pool, account_id).await? { + if let Some(plan) = get_plan(pool, &sub.plan_id).await? { + return Ok(plan); + } + } + // 回退到 Free 计划 + let free = sqlx::query_as::<_, BillingPlan>( + "SELECT * FROM billing_plans WHERE name = 'free' AND status = 'active' LIMIT 1" + ) + .fetch_optional(pool) + .await?; + Ok(free.unwrap_or_else(|| BillingPlan { + id: "plan-free".into(), + name: "free".into(), + display_name: "免费版".into(), + description: Some("基础功能".into()), + price_cents: 0, + currency: "CNY".into(), + interval: "month".into(), + features: serde_json::json!({}), + limits: serde_json::json!({ + "max_input_tokens_monthly": 500000, + "max_output_tokens_monthly": 500000, + "max_relay_requests_monthly": 100, + "max_hand_executions_monthly": 20, + "max_pipeline_runs_monthly": 5, + }), + is_default: true, + sort_order: 0, + status: "active".into(), + created_at: chrono::Utc::now(), + updated_at: chrono::Utc::now(), + })) +} + +/// 获取或创建当月用量记录 +pub async fn get_or_create_usage(pool: &PgPool, account_id: &str) -> SaasResult { + let now = chrono::Utc::now(); + let period_start = now + .with_day(1).unwrap_or(now) + .with_hour(0).unwrap_or(now) + .with_minute(0).unwrap_or(now) + .with_second(0).unwrap_or(now) + .with_nanosecond(0).unwrap_or(now); + + // 尝试获取现有记录 + let existing = sqlx::query_as::<_, UsageQuota>( + "SELECT * FROM billing_usage_quotas \ + WHERE account_id = $1 AND period_start = $2" + ) + .bind(account_id) + .bind(period_start) + .fetch_optional(pool) + .await?; + + if let Some(usage) = existing { + return Ok(usage); + } + + // 获取当前计划限额 + let plan = get_account_plan(pool, account_id).await?; + let limits: PlanLimits = serde_json::from_value(plan.limits.clone()) + .unwrap_or_else(|_| PlanLimits::free()); + + // 计算月末 + let period_end = if now.month() == 12 { + now.with_year(now.year() + 1).and_then(|d| d.with_month(1)) + } else { + now.with_month(now.month() + 1) + }.unwrap_or(now) + .with_day(1).unwrap_or(now) + .with_hour(0).unwrap_or(now) + .with_minute(0).unwrap_or(now) + .with_second(0).unwrap_or(now) + .with_nanosecond(0).unwrap_or(now); + + let id = uuid::Uuid::new_v4().to_string(); + let usage = sqlx::query_as::<_, UsageQuota>( + "INSERT INTO billing_usage_quotas \ + (id, account_id, period_start, period_end, \ + max_input_tokens, max_output_tokens, max_relay_requests, \ + max_hand_executions, max_pipeline_runs) \ + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) \ + RETURNING *" + ) + .bind(&id) + .bind(account_id) + .bind(period_start) + .bind(period_end) + .bind(limits.max_input_tokens_monthly) + .bind(limits.max_output_tokens_monthly) + .bind(limits.max_relay_requests_monthly) + .bind(limits.max_hand_executions_monthly) + .bind(limits.max_pipeline_runs_monthly) + .fetch_one(pool) + .await?; + + Ok(usage) +} + +/// 增加用量计数 +pub async fn increment_usage( + pool: &PgPool, + account_id: &str, + input_tokens: i64, + output_tokens: i64, +) -> SaasResult<()> { + let usage = get_or_create_usage(pool, account_id).await?; + sqlx::query( + "UPDATE billing_usage_quotas \ + SET input_tokens = input_tokens + $1, \ + output_tokens = output_tokens + $2, \ + relay_requests = relay_requests + 1, \ + updated_at = NOW() \ + WHERE id = $3" + ) + .bind(input_tokens) + .bind(output_tokens) + .bind(&usage.id) + .execute(pool) + .await?; + Ok(()) +} + +/// 检查用量配额 +pub async fn check_quota( + pool: &PgPool, + account_id: &str, + quota_type: &str, +) -> SaasResult { + let usage = get_or_create_usage(pool, account_id).await?; + + let (current, limit) = match quota_type { + "input_tokens" => (usage.input_tokens, usage.max_input_tokens), + "output_tokens" => (usage.output_tokens, usage.max_output_tokens), + "relay_requests" => (usage.relay_requests as i64, usage.max_relay_requests.map(|v| v as i64)), + "hand_executions" => (usage.hand_executions as i64, usage.max_hand_executions.map(|v| v as i64)), + "pipeline_runs" => (usage.pipeline_runs as i64, usage.max_pipeline_runs.map(|v| v as i64)), + _ => return Ok(QuotaCheck { + allowed: true, + reason: None, + current: 0, + limit: None, + remaining: None, + }), + }; + + let allowed = limit.map_or(true, |lim| current < lim); + let remaining = limit.map(|lim| (lim - current).max(0)); + + Ok(QuotaCheck { + allowed, + reason: if !allowed { Some(format!("{} 配额已用尽", quota_type)) } else { None }, + current, + limit, + remaining, + }) +} diff --git a/crates/zclaw-saas/src/billing/types.rs b/crates/zclaw-saas/src/billing/types.rs new file mode 100644 index 0000000..9a4e4df --- /dev/null +++ b/crates/zclaw-saas/src/billing/types.rs @@ -0,0 +1,161 @@ +//! 计费类型定义 + +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; + +/// 计费计划定义 — 对应 billing_plans 表 +#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)] +pub struct BillingPlan { + pub id: String, + pub name: String, + pub display_name: String, + pub description: Option, + pub price_cents: i32, + pub currency: String, + pub interval: String, + pub features: serde_json::Value, + pub limits: serde_json::Value, + pub is_default: bool, + pub sort_order: i32, + pub status: String, + pub created_at: DateTime, + pub updated_at: DateTime, +} + +/// 计划限额(从 limits JSON 反序列化) +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PlanLimits { + #[serde(default)] + pub max_input_tokens_monthly: Option, + #[serde(default)] + pub max_output_tokens_monthly: Option, + #[serde(default)] + pub max_relay_requests_monthly: Option, + #[serde(default)] + pub max_hand_executions_monthly: Option, + #[serde(default)] + pub max_pipeline_runs_monthly: Option, +} + +impl PlanLimits { + pub fn free() -> Self { + Self { + max_input_tokens_monthly: Some(500_000), + max_output_tokens_monthly: Some(500_000), + max_relay_requests_monthly: Some(100), + max_hand_executions_monthly: Some(20), + max_pipeline_runs_monthly: Some(5), + } + } +} + +/// 账户订阅 — 对应 billing_subscriptions 表 +#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)] +pub struct Subscription { + pub id: String, + pub account_id: String, + pub plan_id: String, + pub status: String, + pub current_period_start: DateTime, + pub current_period_end: DateTime, + pub trial_end: Option>, + pub canceled_at: Option>, + pub cancel_at_period_end: bool, + pub metadata: serde_json::Value, + pub created_at: DateTime, + pub updated_at: DateTime, +} + +/// 发票 — 对应 billing_invoices 表 +#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)] +pub struct Invoice { + pub id: String, + pub account_id: String, + pub subscription_id: Option, + pub plan_id: Option, + pub amount_cents: i32, + pub currency: String, + pub description: Option, + pub status: String, + pub due_at: Option>, + pub paid_at: Option>, + pub voided_at: Option>, + pub metadata: serde_json::Value, + pub created_at: DateTime, + pub updated_at: DateTime, +} + +/// 支付记录 — 对应 billing_payments 表 +#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)] +pub struct Payment { + pub id: String, + pub invoice_id: String, + pub account_id: String, + pub amount_cents: i32, + pub currency: String, + pub method: String, + pub status: String, + pub external_trade_no: Option, + pub paid_at: Option>, + pub refunded_at: Option>, + pub failure_reason: Option, + pub metadata: serde_json::Value, + pub created_at: DateTime, + pub updated_at: DateTime, +} + +/// 月度用量配额 — 对应 billing_usage_quotas 表 +#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)] +pub struct UsageQuota { + pub id: String, + pub account_id: String, + pub period_start: DateTime, + pub period_end: DateTime, + pub input_tokens: i64, + pub output_tokens: i64, + pub relay_requests: i32, + pub hand_executions: i32, + pub pipeline_runs: i32, + pub max_input_tokens: Option, + pub max_output_tokens: Option, + pub max_relay_requests: Option, + pub max_hand_executions: Option, + pub max_pipeline_runs: Option, + pub metadata: serde_json::Value, + pub created_at: DateTime, + pub updated_at: DateTime, +} + +/// 用量检查结果 +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct QuotaCheck { + pub allowed: bool, + pub reason: Option, + pub current: i64, + pub limit: Option, + pub remaining: Option, +} + +/// 支付方式 +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum PaymentMethod { + Alipay, + Wechat, +} + +/// 创建支付请求 +#[derive(Debug, Deserialize)] +pub struct CreatePaymentRequest { + pub plan_id: String, + pub payment_method: PaymentMethod, +} + +/// 支付结果 +#[derive(Debug, Serialize)] +pub struct PaymentResult { + pub payment_id: String, + pub trade_no: String, + pub pay_url: String, + pub amount_cents: i32, +} diff --git a/crates/zclaw-saas/src/db.rs b/crates/zclaw-saas/src/db.rs index f070a6e..99962a9 100644 --- a/crates/zclaw-saas/src/db.rs +++ b/crates/zclaw-saas/src/db.rs @@ -2,34 +2,44 @@ use sqlx::postgres::PgPoolOptions; use sqlx::PgPool; +use crate::config::DatabaseConfig; use crate::error::SaasResult; -const SCHEMA_VERSION: i32 = 11; +const SCHEMA_VERSION: i32 = 12; /// 初始化数据库 -pub async fn init_db(database_url: &str) -> SaasResult { - // 连接池大小可通过环境变量配置,默认 100(relay 请求每次 10+ 串行查询,50 偏紧) +pub async fn init_db(config: &DatabaseConfig) -> SaasResult { + // 环境变量覆盖 URL(避免在配置文件中存储密码) + let database_url = std::env::var("ZCLAW_DATABASE_URL") + .unwrap_or_else(|_| config.url.clone()); + + // 环境变量覆盖连接数(向后兼容) let max_connections: u32 = std::env::var("ZCLAW_DB_MAX_CONNECTIONS") .ok() .and_then(|v| v.parse().ok()) - .unwrap_or(100); + .unwrap_or(config.max_connections); let min_connections: u32 = std::env::var("ZCLAW_DB_MIN_CONNECTIONS") .ok() .and_then(|v| v.parse().ok()) - .unwrap_or(5); + .unwrap_or(config.min_connections); - tracing::info!("Database pool: max={}, min={}", max_connections, min_connections); + tracing::info!( + "Database pool: max={}, min={}, acquire_timeout={}s, idle_timeout={}s, max_lifetime={}s", + max_connections, min_connections, + config.acquire_timeout_secs, config.idle_timeout_secs, config.max_lifetime_secs + ); let pool = PgPoolOptions::new() .max_connections(max_connections) .min_connections(min_connections) - .acquire_timeout(std::time::Duration::from_secs(8)) - .idle_timeout(std::time::Duration::from_secs(180)) - .max_lifetime(std::time::Duration::from_secs(900)) - .connect(database_url) + .acquire_timeout(std::time::Duration::from_secs(config.acquire_timeout_secs)) + .idle_timeout(std::time::Duration::from_secs(config.idle_timeout_secs)) + .max_lifetime(std::time::Duration::from_secs(config.max_lifetime_secs)) + .connect(&database_url) .await?; run_migrations(&pool).await?; + ensure_security_columns(&pool).await?; seed_admin_account(&pool).await?; seed_builtin_prompts(&pool).await?; seed_demo_data(&pool).await?; @@ -884,6 +894,56 @@ async fn fix_seed_data(pool: &PgPool) -> SaasResult<()> { Ok(()) } +/// 防御性检查:确保安全审计新增的列存在(即使 schema_version 显示已是最新) +/// +/// 场景:旧数据库的 schema_version 已被手动更新但迁移文件未实际执行, +/// 或者迁移文件在 version check 时被跳过。 +async fn ensure_security_columns(pool: &PgPool) -> SaasResult<()> { + // 检查 password_version 列是否存在 + let col_exists: bool = sqlx::query_scalar( + "SELECT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name = 'accounts' AND column_name = 'password_version')" + ) + .fetch_one(pool) + .await + .unwrap_or(false); + + if !col_exists { + tracing::warn!("[DB] 'password_version' column missing — applying security fix migration"); + sqlx::query("ALTER TABLE accounts ADD COLUMN IF NOT EXISTS password_version INTEGER NOT NULL DEFAULT 1") + .execute(pool).await?; + sqlx::query("ALTER TABLE accounts ADD COLUMN IF NOT EXISTS failed_login_count INTEGER NOT NULL DEFAULT 0") + .execute(pool).await?; + sqlx::query("ALTER TABLE accounts ADD COLUMN IF NOT EXISTS locked_until TIMESTAMPTZ") + .execute(pool).await?; + tracing::info!("[DB] Security columns (password_version, failed_login_count, locked_until) applied"); + } + + // 检查 rate_limit_events 表是否存在 + let table_exists: bool = sqlx::query_scalar( + "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'rate_limit_events')" + ) + .fetch_one(pool) + .await + .unwrap_or(false); + + if !table_exists { + tracing::warn!("[DB] 'rate_limit_events' table missing — applying rate limit migration"); + if let Err(e) = sqlx::query( + "CREATE TABLE IF NOT EXISTS rate_limit_events ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + key TEXT NOT NULL, + count BIGINT NOT NULL DEFAULT 1, + window_start TIMESTAMPTZ NOT NULL DEFAULT NOW(), + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() + )" + ).execute(pool).await { + tracing::warn!("[DB] Failed to create rate_limit_events: {}", e); + } + } + + Ok(()) +} + #[cfg(test)] mod tests { // PostgreSQL 单元测试需要真实数据库连接,此处保留接口兼容 diff --git a/crates/zclaw-saas/src/lib.rs b/crates/zclaw-saas/src/lib.rs index 9779cd1..2b3c149 100644 --- a/crates/zclaw-saas/src/lib.rs +++ b/crates/zclaw-saas/src/lib.rs @@ -25,3 +25,4 @@ pub mod prompt; pub mod agent_template; pub mod scheduled_task; pub mod telemetry; +pub mod billing; diff --git a/crates/zclaw-saas/src/main.rs b/crates/zclaw-saas/src/main.rs index 5edc87d..78a2b81 100644 --- a/crates/zclaw-saas/src/main.rs +++ b/crates/zclaw-saas/src/main.rs @@ -14,6 +14,9 @@ use zclaw_saas::workers::update_last_used::UpdateLastUsedWorker; #[tokio::main] async fn main() -> anyhow::Result<()> { + // Load .env file from project root (walk up from current dir) + load_dotenv(); + tracing_subscriber::fmt() .with_env_filter( tracing_subscriber::EnvFilter::try_from_default_env() @@ -24,11 +27,18 @@ async fn main() -> anyhow::Result<()> { let config = SaaSConfig::load()?; info!("SaaS config loaded: {}:{}", config.server.host, config.server.port); - let db = init_db(&config.database.url).await?; + let db = init_db(&config.database).await?; info!("Database initialized"); + // 创建 Worker spawn 限制器(门控并发 DB 操作数量) + let worker_limiter = zclaw_saas::state::SpawnLimiter::new( + "worker", + config.database.worker_concurrency, + ); + info!("Worker spawn limiter: {} permits", config.database.worker_concurrency); + // 初始化 Worker 调度器 + 注册所有 Worker - let mut dispatcher = WorkerDispatcher::new(db.clone()); + let mut dispatcher = WorkerDispatcher::new(db.clone(), worker_limiter.clone()); dispatcher.register(LogOperationWorker); dispatcher.register(CleanupRefreshTokensWorker); dispatcher.register(CleanupRateLimitWorker); @@ -38,12 +48,13 @@ async fn main() -> anyhow::Result<()> { // 优雅停机令牌 — 取消后所有 SSE 流和长连接立即终止 let shutdown_token = CancellationToken::new(); - let state = AppState::new(db.clone(), config.clone(), dispatcher, shutdown_token.clone())?; + let state = AppState::new(db.clone(), config.clone(), dispatcher, shutdown_token.clone(), worker_limiter.clone())?; // Restore rate limit counts from DB so limits survive server restarts + // 仅恢复最近 60s 的计数(与 middleware 的 60s 滑动窗口一致),避免过于保守的限流 { let rows: Vec<(String, i64)> = sqlx::query_as( - "SELECT key, SUM(count) FROM rate_limit_events WHERE window_start > NOW() - interval '1 hour' GROUP BY key" + "SELECT key, SUM(count) FROM rate_limit_events WHERE window_start > NOW() - interval '60 seconds' GROUP BY key" ) .fetch_all(&db) .await @@ -51,18 +62,17 @@ async fn main() -> anyhow::Result<()> { let mut restored_count = 0usize; for (key, count) in rows { - let mut entries = Vec::new(); - // Approximate: insert count timestamps at "now" — the DashMap will - // expire them naturally via the retain() call in the middleware. - // This is intentionally approximate; exact window alignment is not - // required for rate limiting correctness. - for _ in 0..count as usize { + // 限制恢复计数不超过 RPM 配额,避免重启后过于保守 + let rpm = state.rate_limit_rpm() as usize; + let capped = (count as usize).min(rpm); + let mut entries = Vec::with_capacity(capped); + for _ in 0..capped { entries.push(std::time::Instant::now()); } state.rate_limit_entries.insert(key, entries); restored_count += 1; } - info!("Restored rate limit state from DB: {} keys", restored_count); + info!("Restored rate limit state from DB: {} keys (60s window, capped at RPM)", restored_count); } // 迁移旧格式 TOTP secret(明文 → 加密 enc: 格式) @@ -117,20 +127,64 @@ async fn main() -> anyhow::Result<()> { }); } - let app = build_router(state).await; + // 限流事件批量 flush (可配置间隔,默认 5s) + { + let flush_state = state.clone(); + let batch_interval = config.database.rate_limit_batch_interval_secs; + let batch_max = config.database.rate_limit_batch_max_size; + tokio::spawn(async move { + let mut interval = tokio::time::interval(std::time::Duration::from_secs(batch_interval)); + loop { + interval.tick().await; + flush_state.flush_rate_limit_batch(batch_max).await; + } + }); + } + + // 连接池可观测性 (30s 指标日志) + { + let metrics_db = db.clone(); + tokio::spawn(async move { + let mut interval = tokio::time::interval(std::time::Duration::from_secs(30)); + loop { + interval.tick().await; + let pool = &metrics_db; + let total = pool.options().get_max_connections() as usize; + let idle = pool.num_idle() as usize; + let used = total.saturating_sub(idle); + let usage_pct = if total > 0 { used * 100 / total } else { 0 }; + tracing::info!( + "[PoolMetrics] total={} idle={} used={} usage_pct={}%", + total, idle, used, usage_pct, + ); + if usage_pct >= 80 { + tracing::warn!( + "[PoolMetrics] HIGH USAGE: {}% of connections in use!", + usage_pct, + ); + } + } + }); + } + + let app = build_router(state.clone()).await; // 配置 TCP keepalive + 短 SO_LINGER,防止 CLOSE_WAIT 累积 let listener = create_listener(&config.server.host, config.server.port)?; info!("SaaS server listening on {}:{}", config.server.host, config.server.port); - // 优雅停机: Ctrl+C → 取消 CancellationToken → SSE 流终止 → 连接排空 + // 优雅停机: Ctrl+C → 最终批量 flush → 取消 CancellationToken → SSE 流终止 → 连接排空 let token = shutdown_token.clone(); + let flush_state = state; + let batch_max = config.database.rate_limit_batch_max_size; axum::serve(listener, app.into_make_service_with_connect_info::()) .with_graceful_shutdown(async move { tokio::signal::ctrl_c() .await .expect("Failed to install Ctrl+C handler"); - info!("Received shutdown signal, cancelling SSE streams and draining connections..."); + info!("Received shutdown signal, flushing pending rate limit batch..."); + flush_state.flush_rate_limit_batch(batch_max).await; + info!("Cancelling SSE streams and draining connections..."); token.cancel(); }) .await?; @@ -280,6 +334,7 @@ async fn build_router(state: AppState) -> axum::Router { .merge(zclaw_saas::agent_template::routes()) .merge(zclaw_saas::scheduled_task::routes()) .merge(zclaw_saas::telemetry::routes()) + .merge(zclaw_saas::billing::routes()) .layer(middleware::from_fn_with_state( state.clone(), zclaw_saas::middleware::api_version_middleware, @@ -329,3 +384,35 @@ async fn build_router(state: AppState) -> axum::Router { .layer(cors) .with_state(state) } + +/// Load `.env` file from project root by walking up from current directory. +/// Sets environment variables that are not already set (does not override). +fn load_dotenv() { + let mut dir = std::env::current_dir().unwrap_or_default(); + loop { + let env_path = dir.join(".env"); + if env_path.is_file() { + if let Ok(content) = std::fs::read_to_string(&env_path) { + for line in content.lines() { + let line = line.trim(); + if line.is_empty() || line.starts_with('#') { + continue; + } + if let Some((key, value)) = line.split_once('=') { + let key = key.trim(); + let value = value.trim(); + // Only set if not already defined in environment + if std::env::var(key).is_err() { + std::env::set_var(key, value); + } + } + } + tracing::debug!("Loaded .env from {}", env_path.display()); + } + return; + } + if !dir.pop() { + break; + } + } +}