feat: 新增管理后台前端项目及安全加固
Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled

refactor(saas): 重构认证中间件与限流策略
- 登录限流调整为5次/分钟/IP
- 注册限流调整为3次/小时/IP
- GET请求不计入限流

fix(saas): 修复调度器时间戳处理
- 使用NOW()替代文本时间戳
- 兼容TEXT和TIMESTAMPTZ列类型

feat(saas): 实现环境变量插值
- 支持${ENV_VAR}语法解析
- 数据库密码支持环境变量注入

chore: 新增前端管理界面
- 基于React+Ant Design Pro
- 包含路由守卫/错误边界
- 对接58个API端点

docs: 更新安全加固文档
- 新增密钥管理规范
- 记录P0安全项审计结果
- 补充TLS终止说明

test: 完善配置解析单元测试
- 新增环境变量插值测试用例
This commit is contained in:
iven
2026-03-31 00:11:33 +08:00
parent 6821df5f44
commit eb956d0dce
129 changed files with 11913 additions and 863 deletions

View File

@@ -1,10 +1,15 @@
//! Director - Multi-Agent Orchestration
//! Director - Multi-Agent Orchestration (Experimental)
//!
//! The Director manages multi-agent conversations by:
//! - Determining which agent speaks next
//! - Managing conversation state and turn order
//! - Supporting multiple scheduling strategies
//! - Coordinating agent responses
//!
//! **Status**: This module is fully implemented but gated behind the `multi-agent` feature.
//! The desktop build does not currently enable this feature. When multi-agent support
//! is ready for production, add Tauri commands to create and interact with the Director,
//! and enable the feature in `desktop/src-tauri/Cargo.toml`.
use std::sync::Arc;
use serde::{Deserialize, Serialize};

View File

@@ -5,21 +5,21 @@
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use chrono::{Datelike, Timelike};
use tokio::sync::RwLock;
use tokio::sync::Mutex;
use tokio::time::{self, Duration};
use zclaw_types::Result;
use crate::Kernel;
/// Scheduler service that runs in the background and executes scheduled triggers
pub struct SchedulerService {
kernel: Arc<RwLock<Option<Kernel>>>,
kernel: Arc<Mutex<Option<Kernel>>>,
running: Arc<AtomicBool>,
check_interval: Duration,
}
impl SchedulerService {
/// Create a new scheduler service
pub fn new(kernel: Arc<RwLock<Option<Kernel>>>, check_interval_secs: u64) -> Self {
pub fn new(kernel: Arc<Mutex<Option<Kernel>>>, check_interval_secs: u64) -> Self {
Self {
kernel,
running: Arc::new(AtomicBool::new(false)),
@@ -74,58 +74,56 @@ impl SchedulerService {
/// Check all scheduled triggers and fire those that are due
async fn check_and_fire_scheduled_triggers(
kernel_lock: &Arc<RwLock<Option<Kernel>>>,
kernel_lock: &Arc<Mutex<Option<Kernel>>>,
) -> Result<()> {
let kernel_read = kernel_lock.read().await;
let kernel = match kernel_read.as_ref() {
Some(k) => k,
None => return Ok(()),
};
// Collect due triggers under lock
let to_execute: Vec<(String, String, String)> = {
let kernel_guard = kernel_lock.lock().await;
let kernel = match kernel_guard.as_ref() {
Some(k) => k,
None => return Ok(()),
};
// Get all triggers
let triggers = kernel.list_triggers().await;
let now = chrono::Utc::now();
let triggers = kernel.list_triggers().await;
let now = chrono::Utc::now();
// Filter to enabled Schedule triggers
let scheduled: Vec<_> = triggers.iter()
.filter(|t| {
t.config.enabled && matches!(t.config.trigger_type, zclaw_hands::TriggerType::Schedule { .. })
})
.collect();
let scheduled: Vec<_> = triggers.iter()
.filter(|t| {
t.config.enabled && matches!(t.config.trigger_type, zclaw_hands::TriggerType::Schedule { .. })
})
.collect();
if scheduled.is_empty() {
return Ok(());
}
if scheduled.is_empty() {
return Ok(());
}
tracing::debug!("[Scheduler] Checking {} scheduled triggers", scheduled.len());
tracing::debug!("[Scheduler] Checking {} scheduled triggers", scheduled.len());
// Drop the read lock before executing
let to_execute: Vec<(String, String, String)> = scheduled.iter()
.filter_map(|t| {
if let zclaw_hands::TriggerType::Schedule { ref cron } = t.config.trigger_type {
// Simple cron matching: check if we should fire now
if Self::should_fire_cron(cron, &now) {
Some((t.config.id.clone(), t.config.hand_id.clone(), cron.clone()))
scheduled.iter()
.filter_map(|t| {
if let zclaw_hands::TriggerType::Schedule { ref cron } = t.config.trigger_type {
if Self::should_fire_cron(cron, &now) {
Some((t.config.id.clone(), t.config.hand_id.clone(), cron.clone()))
} else {
None
}
} else {
None
}
} else {
None
}
})
.collect();
})
.collect()
}; // Lock dropped here
drop(kernel_read);
// Execute due triggers (with write lock since execute_hand may need it)
// Execute due triggers (acquire lock per execution)
let now = chrono::Utc::now();
for (trigger_id, hand_id, cron_expr) in to_execute {
tracing::info!(
"[Scheduler] Firing scheduled trigger '{}' → hand '{}' (cron: {})",
trigger_id, hand_id, cron_expr
);
let kernel_read = kernel_lock.read().await;
if let Some(kernel) = kernel_read.as_ref() {
let kernel_guard = kernel_lock.lock().await;
if let Some(kernel) = kernel_guard.as_ref() {
let trigger_source = zclaw_types::TriggerSource::Scheduled {
trigger_id: trigger_id.clone(),
};
@@ -265,9 +263,12 @@ impl SchedulerService {
_ => return false,
};
// Check if current timestamp aligns with the interval
// Check if current timestamp is within the scheduler check window of an interval boundary.
// The scheduler checks every `check_interval` seconds (default 60s), so we use ±30s window.
let timestamp = now.timestamp();
timestamp % interval_secs == 0
let remainder = timestamp % interval_secs;
// Fire if we're within ±30 seconds of an interval boundary
remainder <= 30 || remainder >= (interval_secs - 30)
}
}

View File

@@ -395,13 +395,6 @@ pub trait LlmIntentDriver: Send + Sync {
) -> HashMap<String, serde_json::Value>;
}
/// Default LLM driver implementation using prompt-based matching
#[allow(dead_code)]
pub struct DefaultLlmIntentDriver {
/// Model ID to use
model_id: String,
}
/// Runtime LLM driver that wraps zclaw-runtime's LlmDriver for actual LLM calls
pub struct RuntimeLlmIntentDriver {
driver: std::sync::Arc<dyn zclaw_runtime::driver::LlmDriver>,

View File

@@ -13,6 +13,7 @@ zclaw-types = { workspace = true }
tokio = { workspace = true }
tokio-stream = { workspace = true }
tokio-util = { workspace = true }
futures = { workspace = true }
async-trait = { workspace = true }
serde = { workspace = true }

View File

@@ -19,14 +19,16 @@ const ACCESS_TOKEN_COOKIE: &str = "zclaw_access_token";
const REFRESH_TOKEN_COOKIE: &str = "zclaw_refresh_token";
/// 构建 auth cookies 并附加到 CookieJar
/// secure 标记在开发环境 (ZCLAW_SAAS_DEV=true) 设为 false生产设为 true
fn set_auth_cookies(jar: CookieJar, token: &str, refresh_token: &str) -> CookieJar {
let access_max_age = std::time::Duration::from_secs(2 * 3600); // 2h
let refresh_max_age = std::time::Duration::from_secs(7 * 86400); // 7d
let secure = !is_dev_mode();
// cookie crate 需要 time::Duration从 std 转换
let access = Cookie::build((ACCESS_TOKEN_COOKIE, token.to_string()))
.http_only(true)
.secure(true)
.secure(secure)
.same_site(SameSite::Strict)
.path("/api")
.max_age(access_max_age.try_into().unwrap_or_else(|_| std::time::Duration::from_secs(3600).try_into().unwrap()))
@@ -34,7 +36,7 @@ fn set_auth_cookies(jar: CookieJar, token: &str, refresh_token: &str) -> CookieJ
let refresh = Cookie::build((REFRESH_TOKEN_COOKIE, refresh_token.to_string()))
.http_only(true)
.secure(true)
.secure(secure)
.same_site(SameSite::Strict)
.path("/api/v1/auth")
.max_age(refresh_max_age.try_into().unwrap_or_else(|_| std::time::Duration::from_secs(86400).try_into().unwrap()))
@@ -43,6 +45,13 @@ fn set_auth_cookies(jar: CookieJar, token: &str, refresh_token: &str) -> CookieJ
jar.add(access).add(refresh)
}
/// Returns whether the server is running in development mode, i.e. the
/// `ZCLAW_SAAS_DEV` environment variable is set to `"true"` or `"1"`.
///
/// Security policies (cookie `Secure` flag, CORS, etc.) are relaxed based on
/// this check.
fn is_dev_mode() -> bool {
    matches!(
        std::env::var("ZCLAW_SAAS_DEV").as_deref(),
        Ok("true") | Ok("1")
    )
}
/// 清除 auth cookies
fn clear_auth_cookies(jar: CookieJar) -> CookieJar {
jar.remove(Cookie::build(ACCESS_TOKEN_COOKIE).path("/api"))
@@ -502,9 +511,40 @@ fn sha256_hex(input: &str) -> String {
hex::encode(Sha256::digest(input.as_bytes()))
}
/// POST /api/v1/auth/logout — revoke the refresh token (when present) and
/// clear both auth cookies.
///
/// Revocation is best-effort: a missing cookie, an unverifiable token, or a
/// DB error still results in cleared cookies and `204 No Content`, so logout
/// never fails from the client's point of view.
pub async fn logout(
    State(state): State<AppState>,
    jar: CookieJar,
) -> (CookieJar, axum::http::StatusCode) {
    // Try to pull the refresh token out of the cookie jar and revoke it.
    if let Some(refresh_cookie) = jar.get(REFRESH_TOKEN_COOKIE) {
        let token = refresh_cookie.value();
        // verify_token_skip_expiry: even an already-expired token is decoded
        // so its jti can still be marked as used below.
        if let Ok(claims) = verify_token_skip_expiry(token, state.jwt_secret.expose_secret()) {
            if claims.token_type == "refresh" {
                if let Some(jti) = claims.jti {
                    let now = chrono::Utc::now().to_rfc3339();
                    // Mark the refresh token as used (equivalent to a
                    // revocation / blacklist entry). The `used_at IS NULL`
                    // guard makes repeated logouts idempotent.
                    // NOTE(review): binds an RFC 3339 *string* for used_at —
                    // confirm the column type (TEXT vs TIMESTAMPTZ) accepts it.
                    let result = sqlx::query(
                        "UPDATE refresh_tokens SET used_at = $1 WHERE jti = $2 AND used_at IS NULL"
                    )
                    .bind(&now).bind(&jti)
                    .execute(&state.db)
                    .await;
                    match result {
                        Ok(r) => {
                            // rows_affected == 0 means the token was already
                            // revoked; only log on an actual state change.
                            if r.rows_affected() > 0 {
                                tracing::info!(account_id = %claims.sub, jti = %jti, "Refresh token revoked on logout");
                            }
                        }
                        Err(e) => {
                            // Log but do not fail the request — cookies are
                            // cleared regardless.
                            tracing::warn!(jti = %jti, error = %e, "Failed to revoke refresh token on logout");
                        }
                    }
                }
            }
        }
    }
    (clear_auth_cookies(jar), axum::http::StatusCode::NO_CONTENT)
}

View File

@@ -212,7 +212,8 @@ impl SaaSConfig {
let mut config = if config_path.exists() {
let content = std::fs::read_to_string(&config_path)?;
toml::from_str(&content)?
let interpolated = interpolate_env_vars(&content);
toml::from_str(&interpolated)?
} else {
tracing::warn!("Config file {:?} not found, using defaults", config_path);
SaaSConfig::default()
@@ -291,3 +292,71 @@ impl SaaSConfig {
}
}
}
/// Replace `${ENV_VAR}` patterns in TOML config content with the value of the
/// corresponding environment variable.
///
/// Unset variables keep their literal `${...}` placeholder so that later
/// initialization (database connection, JWT setup) fails with a clear error.
/// Variable names are restricted to ASCII alphanumerics and `_`.
///
/// Bug fix: the previous implementation walked raw bytes and rebuilt the
/// output with `push(byte as char)`, which reinterprets each UTF-8 byte as a
/// Latin-1 scalar and corrupts any multi-byte text (e.g. Chinese comments in
/// the config file). This version slices the `&str` directly, so non-ASCII
/// content is copied through verbatim and the result stays valid UTF-8.
fn interpolate_env_vars(content: &str) -> String {
    let mut result = String::with_capacity(content.len());
    let mut rest = content;
    while let Some(dollar) = rest.find("${") {
        // Copy everything before the candidate placeholder verbatim.
        result.push_str(&rest[..dollar]);
        let after = &rest[dollar + 2..];
        // Length of the variable-name run ([A-Za-z0-9_]*) following "${".
        let name_len = after
            .bytes()
            .take_while(|b| b.is_ascii_alphanumeric() || *b == b'_')
            .count();
        if after.as_bytes().get(name_len) == Some(&b'}') {
            let var_name = &after[..name_len];
            match std::env::var(var_name) {
                Ok(val) => {
                    tracing::debug!("Config: ${{{}}} → resolved ({} bytes)", var_name, val.len());
                    result.push_str(&val);
                }
                Err(_) => {
                    tracing::warn!("Config: ${{{}}} not set, keeping placeholder", var_name);
                    result.push_str("${");
                    result.push_str(var_name);
                    result.push('}');
                }
            }
            // Continue scanning after the closing '}'.
            rest = &after[name_len + 1..];
        } else {
            // Not a well-formed placeholder: keep the "${" literally and
            // resume scanning right after it (matches the old behavior of
            // preserving malformed input unchanged).
            result.push_str("${");
            rest = after;
        }
    }
    result.push_str(rest);
    result
}
#[cfg(test)]
mod tests {
    use super::*;

    // NOTE(review): these tests mutate process-wide environment variables;
    // the default test harness runs tests in parallel, so each test uses a
    // distinct variable name to avoid cross-test races.

    /// A `${VAR}` placeholder is replaced with the environment value.
    #[test]
    fn test_interpolate_env_vars_resolves() {
        std::env::set_var("TEST_ZCLAW_DB_PW", "mypassword");
        let input = "url = \"postgres://user:${TEST_ZCLAW_DB_PW}@localhost/db\"";
        let result = interpolate_env_vars(input);
        assert_eq!(result, "url = \"postgres://user:mypassword@localhost/db\"");
        std::env::remove_var("TEST_ZCLAW_DB_PW");
    }

    /// An unset variable keeps the literal `${...}` placeholder.
    #[test]
    fn test_interpolate_env_vars_missing_keeps_placeholder() {
        let input = "url = \"postgres://user:${NONEXISTENT_VAR_12345}@localhost/db\"";
        let result = interpolate_env_vars(input);
        assert_eq!(result, "url = \"postgres://user:${NONEXISTENT_VAR_12345}@localhost/db\"");
    }

    /// Content without placeholders passes through unchanged.
    #[test]
    fn test_interpolate_env_vars_no_placeholders() {
        let input = "host = \"0.0.0.0\"\nport = 8080";
        let result = interpolate_env_vars(input);
        assert_eq!(result, input);
    }
}

View File

@@ -9,8 +9,8 @@ const SCHEMA_VERSION: i32 = 7;
/// 初始化数据库
pub async fn init_db(database_url: &str) -> SaasResult<PgPool> {
let pool = PgPoolOptions::new()
.max_connections(20)
.min_connections(2)
.max_connections(50)
.min_connections(3)
.acquire_timeout(std::time::Duration::from_secs(5))
.idle_timeout(std::time::Duration::from_secs(180))
.max_lifetime(std::time::Duration::from_secs(900))
@@ -21,6 +21,7 @@ pub async fn init_db(database_url: &str) -> SaasResult<PgPool> {
seed_admin_account(&pool).await?;
seed_builtin_prompts(&pool).await?;
seed_demo_data(&pool).await?;
fix_seed_data(&pool).await?;
tracing::info!("Database initialized (schema v{})", SCHEMA_VERSION);
Ok(pool)
}
@@ -565,19 +566,32 @@ async fn seed_demo_data(pool: &PgPool) -> SaasResult<()> {
}
// ===== 7. Config Items =====
// 分类名必须与 Admin V2 Config 页面 Tab key 一致: general/auth/relay/model/rate_limit/log
let config_items = [
("server", "max_connections", "integer", "50", "100", "Maximum database connections"),
("server", "request_timeout_sec", "integer", "30", "60", "Request timeout in seconds"),
("llm", "default_model", "string", "gpt-4o", "gpt-4o", "Default LLM model"),
("llm", "max_context_tokens", "integer", "128000", "128000", "Maximum context window"),
("llm", "stream_chunk_size", "integer", "1024", "1024", "Streaming chunk size in bytes"),
("agent", "max_concurrent_tasks", "integer", "5", "10", "Maximum concurrent agent tasks"),
("agent", "task_timeout_min", "integer", "30", "60", "Agent task timeout in minutes"),
("memory", "max_entries", "integer", "10000", "50000", "Maximum memory entries per agent"),
("memory", "compression_threshold", "integer", "100", "200", "Messages before compression"),
("security", "rate_limit_enabled", "boolean", "true", "true", "Enable rate limiting"),
("security", "max_requests_per_minute", "integer", "60", "120", "Max requests per minute per user"),
("security", "content_filter_enabled", "boolean", "true", "true", "Enable content filtering"),
("general", "max_connections", "integer", "50", "100", "最大数据库连接数"),
("general", "request_timeout_sec", "integer", "30", "60", "请求超时秒数"),
("general", "app_name", "string", "ZCLAW", "ZCLAW", "应用显示名称"),
("general", "debug_mode", "boolean", "false", "false", "调试模式"),
("auth", "session_ttl_hours", "integer", "24", "48", "会话有效期(小时)"),
("auth", "refresh_token_ttl_days", "integer", "7", "30", "刷新令牌有效期(天)"),
("auth", "max_login_attempts", "integer", "5", "10", "最大登录尝试次数"),
("auth", "totp_enabled", "boolean", "false", "false", "启用 TOTP 两步验证"),
("relay", "max_retries", "integer", "3", "5", "最大重试次数"),
("relay", "retry_delay_sec", "integer", "5", "10", "重试延迟秒数"),
("relay", "stream_timeout_sec", "integer", "120", "300", "流式响应超时秒数"),
("relay", "max_concurrent_tasks", "integer", "10", "20", "最大并发中转任务"),
("model", "default_model", "string", "gpt-4o", "gpt-4o", "默认 LLM 模型"),
("model", "max_context_tokens", "integer", "128000", "128000", "最大上下文窗口"),
("model", "stream_chunk_size", "integer", "1024", "1024", "流式响应块大小(bytes)"),
("model", "temperature", "number", "0.7", "0.7", "默认温度参数"),
("rate_limit", "rate_limit_enabled", "boolean", "true", "true", "启用请求限流"),
("rate_limit", "max_requests_per_minute", "integer", "60", "120", "每分钟最大请求数"),
("rate_limit", "burst_size", "integer", "10", "20", "突发请求上限"),
("rate_limit", "content_filter_enabled", "boolean", "true", "true", "启用内容过滤"),
("log", "log_level", "string", "info", "info", "日志级别"),
("log", "log_retention_days", "integer", "30", "90", "日志保留天数"),
("log", "audit_log_enabled", "boolean", "true", "true", "启用审计日志"),
("log", "slow_query_threshold_ms", "integer", "1000", "2000", "慢查询阈值(ms)"),
];
for (cat, key, vtype, current, default, desc) in &config_items {
let ts = now.to_rfc3339();
@@ -589,7 +603,22 @@ async fn seed_demo_data(pool: &PgPool) -> SaasResult<()> {
.execute(pool).await?;
}
// ===== 8. API Tokens =====
// ===== 8. Account API Keys (account_api_keys 表) =====
let account_api_keys = [
("demo-akey-1", "demo-openai", "sk-demo-openai-key-1-xxxxx", "OpenAI API Key", "[\"relay:use\",\"model:read\"]"),
("demo-akey-2", "demo-anthropic", "sk-ant-demo-key-1-xxxxx", "Anthropic API Key", "[\"relay:use\",\"model:read\",\"config:read\"]"),
("demo-akey-3", "demo-deepseek", "sk-demo-deepseek-key-1-xxxxx", "DeepSeek API Key", "[\"relay:use\"]"),
];
for (id, provider_id, key_val, label, perms) in &account_api_keys {
let ts = now.to_rfc3339();
sqlx::query(
"INSERT INTO account_api_keys (id, account_id, provider_id, key_value, key_label, permissions, enabled, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, true, $7, $7) ON CONFLICT (id) DO NOTHING"
).bind(id).bind(&admin_id).bind(provider_id).bind(key_val).bind(label).bind(perms).bind(&ts)
.execute(pool).await?;
}
// 保留旧 api_tokens 表的种子数据(兼容旧代码路径)
let api_tokens = [
("demo-token-1", "Production API Key", "zclaw_prod_xr7Km9pQ2nBv", "[\"relay:use\",\"model:read\"]"),
("demo-token-2", "Development Key", "zclaw_dev_aB3cD5eF7gH9", "[\"relay:use\",\"model:read\",\"config:read\"]"),
@@ -662,6 +691,123 @@ async fn seed_demo_data(pool: &PgPool) -> SaasResult<()> {
Ok(())
}
/// Repair legacy seed data: rename `config_items` categories, backfill
/// `account_api_keys`, and re-point old demo rows at the current admin account.
///
/// Historical issues being fixed:
/// - Old `config_items` used server/llm/agent/memory/security categories,
///   which do not match the Admin V2 frontend tab keys.
/// - Old seeds wrote API keys into the `api_tokens` table, while the handlers
///   read from `account_api_keys`.
/// - Old seed rows may carry an `account_id` that no longer matches the
///   current admin account.
async fn fix_seed_data(pool: &PgPool) -> SaasResult<()> {
    let now = chrono::Utc::now().to_rfc3339();
    // 1. Collect every super_admin account id (there may be more than one).
    let admins: Vec<(String,)> = sqlx::query_as(
        "SELECT id FROM accounts WHERE role = 'super_admin'"
    ).fetch_all(pool).await?;
    if admins.is_empty() {
        // Nothing to fix without an admin account.
        return Ok(());
    }
    let admin_ids: Vec<String> = admins.into_iter().map(|(id,)| id).collect();
    // 2. Rename config_items categories (old → new).
    let category_mappings = [
        ("server", "general"),
        ("llm", "model"),
        ("agent", "general"),
        ("memory", "general"),
        ("security", "rate_limit"),
    ];
    for (old_cat, new_cat) in &category_mappings {
        let result = sqlx::query(
            "UPDATE config_items SET category = $1, updated_at = $2 WHERE category = $3"
        ).bind(new_cat).bind(&now).bind(old_cat)
        .execute(pool).await?;
        if result.rows_affected() > 0 {
            tracing::info!("Fixed config_items category: {} → {} ({} rows)", old_cat, new_cat, result.rows_affected());
        }
    }
    // If the new categories hold no data, seed default config items
    // (idempotent via ON CONFLICT DO NOTHING).
    let general_count: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM config_items WHERE category = 'general'")
        .fetch_one(pool).await?;
    if general_count.0 == 0 {
        let new_configs = [
            ("general", "max_connections", "integer", "50", "100", "最大数据库连接数"),
            ("general", "request_timeout_sec", "integer", "30", "60", "请求超时秒数"),
            ("general", "app_name", "string", "ZCLAW", "ZCLAW", "应用显示名称"),
            ("auth", "session_ttl_hours", "integer", "24", "48", "会话有效期(小时)"),
            ("relay", "max_retries", "integer", "3", "5", "最大重试次数"),
            ("model", "default_model", "string", "gpt-4o", "gpt-4o", "默认 LLM 模型"),
            ("rate_limit", "rate_limit_enabled", "boolean", "true", "true", "启用请求限流"),
            ("log", "log_level", "string", "info", "info", "日志级别"),
        ];
        for (cat, key, vtype, current, default, desc) in &new_configs {
            let id = format!("cfg-{}-{}", cat, key);
            sqlx::query(
                "INSERT INTO config_items (id, category, key_path, value_type, current_value, default_value, source, description, created_at, updated_at)
                VALUES ($1, $2, $3, $4, $5, $6, 'local', $7, $8, $8) ON CONFLICT (id) DO NOTHING"
            ).bind(&id).bind(cat).bind(key).bind(vtype).bind(current).bind(default).bind(desc).bind(&now)
            .execute(pool).await?;
        }
        tracing::info!("Seeded {} new config items for updated categories", new_configs.len());
    }
    // 3. Backfill account_api_keys for each admin that has none
    //    (idempotent via ON CONFLICT DO NOTHING).
    let provider_keys: Vec<(String, String)> = sqlx::query_as(
        "SELECT id, provider_id FROM providers LIMIT 5"
    ).fetch_all(pool).await.unwrap_or_default();
    for admin_id in &admin_ids {
        let akey_count: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM account_api_keys WHERE account_id = $1")
            .bind(admin_id).fetch_one(pool).await?;
        if akey_count.0 > 0 { continue; }
        // NOTE(review): `&admin_id[..8]` panics if an account id is shorter
        // than 8 bytes or has a non-ASCII char at that boundary — fine for
        // UUID-style ids, but worth confirming for all id sources.
        let demo_keys = [
            (format!("demo-akey-1-{}", &admin_id[..8]), "OpenAI API Key", "sk-demo-openai-key-1-xxxxx", "[\"relay:use\",\"model:read\"]"),
            (format!("demo-akey-2-{}", &admin_id[..8]), "Anthropic API Key", "sk-ant-demo-key-1-xxxxx", "[\"relay:use\",\"model:read\"]"),
            (format!("demo-akey-3-{}", &admin_id[..8]), "DeepSeek API Key", "sk-demo-deepseek-key-1-xxxxx", "[\"relay:use\"]"),
        ];
        for (idx, (id, label, key_val, perms)) in demo_keys.iter().enumerate() {
            // Pair each demo key with a real provider row when available;
            // otherwise fall back to the "demo-openai" id.
            let provider_id = provider_keys.get(idx).map(|(_, pid)| pid.as_str()).unwrap_or("demo-openai");
            sqlx::query(
                "INSERT INTO account_api_keys (id, account_id, provider_id, key_value, key_label, permissions, enabled, created_at, updated_at)
                VALUES ($1, $2, $3, $4, $5, $6, true, $7, $7) ON CONFLICT (id) DO NOTHING"
            ).bind(id).bind(admin_id).bind(provider_id).bind(key_val).bind(label).bind(perms).bind(&now)
            .execute(pool).await?;
        }
        tracing::info!("Seeded {} account_api_keys for admin {}", demo_keys.len(), admin_id);
    }
    // 4. Re-point legacy rows: unify relay_tasks/usage_records/operation_logs/
    //    telemetry_reports onto the first super_admin's account_id.
    //    NOTE(review): the original comment also mentioned copying key data to
    //    the remaining admins, but the code below only unifies onto the first
    //    admin — confirm whether the copy step was intentionally dropped.
    let primary_admin = &admin_ids[0];
    for table in &["relay_tasks", "usage_records", "operation_logs", "telemetry_reports"] {
        // Count distinct account_ids in the table. Table names come from the
        // fixed array above, so the format! here is not an injection risk.
        let distinct_count: (i64,) = sqlx::query_as(
            &format!("SELECT COUNT(DISTINCT account_id) FROM {}", table)
        ).fetch_one(pool).await.unwrap_or((0,));
        if distinct_count.0 > 0 {
            // Move every row owned by another account to the primary admin.
            let result = sqlx::query(
                &format!("UPDATE {} SET account_id = $1 WHERE account_id != $1", table)
            ).bind(primary_admin)
            .execute(pool).await?;
            if result.rows_affected() > 0 {
                tracing::info!("Unified {} account_id to {} ({} rows fixed)", table, primary_admin, result.rows_affected());
            }
        }
    }
    // Also unify api_tokens ownership (legacy table kept for compatibility).
    let _ = sqlx::query("UPDATE api_tokens SET account_id = $1 WHERE account_id != $1")
        .bind(primary_admin).execute(pool).await?;
    tracing::info!("Seed data fix completed");
    Ok(())
}
#[cfg(test)]
mod tests {
// PostgreSQL 单元测试需要真实数据库连接,此处保留接口兼容

View File

@@ -217,7 +217,7 @@ async fn build_router(state: AppState) -> axum::Router {
let protected_routes = zclaw_saas::auth::protected_routes()
.merge(zclaw_saas::account::routes())
.merge(zclaw_saas::model_config::routes())
.merge(zclaw_saas::relay::routes())
// relay::routes() 不在此合并 — SSE 端点需要更长超时,在最终 Router 单独合并
.merge(zclaw_saas::migration::routes())
.merge(zclaw_saas::role::routes())
.merge(zclaw_saas::prompt::routes())
@@ -247,9 +247,28 @@ async fn build_router(state: AppState) -> axum::Router {
.merge(protected_routes)
.layer(TimeoutLayer::new(std::time::Duration::from_secs(15)));
// Relay 路由需要独立的认证中间件(因为被排除在 15s 超时层之外)
let relay_routes = zclaw_saas::relay::routes()
.layer(middleware::from_fn_with_state(
state.clone(),
zclaw_saas::middleware::api_version_middleware,
))
.layer(middleware::from_fn_with_state(
state.clone(),
zclaw_saas::middleware::request_id_middleware,
))
.layer(middleware::from_fn_with_state(
state.clone(),
zclaw_saas::middleware::rate_limit_middleware,
))
.layer(middleware::from_fn_with_state(
state.clone(),
zclaw_saas::auth::auth_middleware,
));
axum::Router::new()
.merge(non_streaming_routes)
.merge(zclaw_saas::relay::routes())
.merge(relay_routes)
.layer(TraceLayer::new_for_http())
.layer(cors)
.with_state(state)

View File

@@ -58,6 +58,11 @@ pub async fn rate_limit_middleware(
req: Request<Body>,
next: Next,
) -> Response<Body> {
// GET 请求不计入限流 — 前端导航/轮询产生的 GET 不应触发 429
if req.method() == axum::http::Method::GET {
return next.run(req).await;
}
let account_id = req.extensions()
.get::<AuthContext>()
.map(|ctx| ctx.account_id.clone())
@@ -91,15 +96,39 @@ pub async fn rate_limit_middleware(
next.run(req).await
}
/// 公共端点速率限制中间件 (基于客户端 IP更严格)
/// 公共端点速率限制中间件 (基于客户端 IP按路径差异化限流)
/// 用于登录/注册/刷新等无认证端点,防止暴力破解
const PUBLIC_RATE_LIMIT_RPM: usize = 20;
///
/// 限流策略:
/// - /auth/login: 5 次/分钟/IP
/// - /auth/register: 3 次/小时/IP
/// - 其他 (refresh): 20 次/分钟/IP
const LOGIN_RATE_LIMIT: usize = 5;
const LOGIN_RATE_LIMIT_WINDOW_SECS: u64 = 60;
const REGISTER_RATE_LIMIT: usize = 3;
const REGISTER_RATE_LIMIT_WINDOW_SECS: u64 = 3600;
const DEFAULT_PUBLIC_RATE_LIMIT: usize = 20;
const DEFAULT_PUBLIC_RATE_LIMIT_WINDOW_SECS: u64 = 60;
pub async fn public_rate_limit_middleware(
State(state): State<AppState>,
req: Request<Body>,
next: Next,
) -> Response<Body> {
let path = req.uri().path();
// 根据路径选择限流策略
let (limit, window_secs, key_prefix, error_msg) = if path.ends_with("/auth/login") {
(LOGIN_RATE_LIMIT, LOGIN_RATE_LIMIT_WINDOW_SECS,
"auth_login_rate_limit", "登录请求过于频繁,请稍后再试")
} else if path.ends_with("/auth/register") {
(REGISTER_RATE_LIMIT, REGISTER_RATE_LIMIT_WINDOW_SECS,
"auth_register_rate_limit", "注册请求过于频繁,请一小时后再试")
} else {
(DEFAULT_PUBLIC_RATE_LIMIT, DEFAULT_PUBLIC_RATE_LIMIT_WINDOW_SECS,
"public_rate_limit", "请求频率超限,请稍后再试")
};
// 从连接信息或 header 提取客户端 IP
let client_ip = req.extensions()
.get::<axum::extract::ConnectInfo<std::net::SocketAddr>>()
@@ -113,15 +142,16 @@ pub async fn public_rate_limit_middleware(
.unwrap_or_else(|| "unknown".to_string())
});
let key = format!("public_rate_limit:{}", client_ip);
let key = format!("{}:{}", key_prefix, client_ip);
let now = Instant::now();
let window_start = now - std::time::Duration::from_secs(60);
let window_start = now - std::time::Duration::from_secs(window_secs);
// DashMap 操作限定在作用域块内,确保 RefMut 在 await 前释放
let blocked = {
let mut entries = state.rate_limit_entries.entry(key).or_insert_with(Vec::new);
entries.retain(|&time| time > window_start);
if entries.len() >= PUBLIC_RATE_LIMIT_RPM {
if entries.len() >= limit {
true
} else {
entries.push(now);
@@ -130,9 +160,7 @@ pub async fn public_rate_limit_middleware(
};
if blocked {
return SaasError::RateLimited(
"请求频率超限,请稍后再试".into()
).into_response();
return SaasError::RateLimited(error_msg.into()).into_response();
}
next.run(req).await

View File

@@ -10,11 +10,11 @@ pub struct RelayTaskRow {
pub provider_id: String,
pub model_id: String,
pub status: String,
pub priority: i64,
pub attempt_count: i64,
pub max_attempts: i64,
pub input_tokens: i64,
pub output_tokens: i64,
pub priority: i32,
pub attempt_count: i32,
pub max_attempts: i32,
pub input_tokens: i32,
pub output_tokens: i32,
pub error_message: Option<String>,
pub queued_at: String,
pub started_at: Option<String>,

View File

@@ -25,7 +25,7 @@ pub async fn create_relay_task(
provider_id: &str,
model_id: &str,
request_body: &str,
priority: i64,
priority: i32,
max_attempts: u32,
) -> SaasResult<RelayTaskInfo> {
let id = uuid::Uuid::new_v4().to_string();

View File

@@ -29,11 +29,11 @@ pub struct RelayTaskInfo {
pub provider_id: String,
pub model_id: String,
pub status: String,
pub priority: i64,
pub attempt_count: i64,
pub max_attempts: i64,
pub input_tokens: i64,
pub output_tokens: i64,
pub priority: i32,
pub attempt_count: i32,
pub max_attempts: i32,
pub input_tokens: i32,
pub output_tokens: i32,
pub error_message: Option<String>,
pub queued_at: String,
pub started_at: Option<String>,

View File

@@ -4,7 +4,7 @@ pub mod types;
pub mod service;
pub mod handlers;
use axum::routing::{get, post, patch, delete};
use axum::routing::get;
use crate::state::AppState;
/// 定时任务路由 (需要认证)

View File

@@ -6,6 +6,7 @@ use super::types::*;
/// 数据库行结构
#[derive(Debug, FromRow)]
#[allow(dead_code)]
struct ScheduledTaskRow {
id: String,
account_id: String,

View File

@@ -122,14 +122,11 @@ pub fn start_user_task_scheduler(db: PgPool) {
}
async fn tick_user_tasks(db: &PgPool) -> Result<(), sqlx::Error> {
let now = chrono::Utc::now().to_rfc3339();
// 查找到期任务
// 查找到期任务next_run_at 兼容 TEXT 和 TIMESTAMPTZ 两种列类型)
let due_tasks: Vec<(String, String, String)> = sqlx::query_as(
"SELECT id, schedule_type, target_type FROM scheduled_tasks
WHERE enabled = TRUE AND next_run_at <= $1"
WHERE enabled = TRUE AND next_run_at::TIMESTAMPTZ <= NOW()"
)
.bind(&now)
.fetch_all(db)
.await?;
@@ -140,16 +137,14 @@ async fn tick_user_tasks(db: &PgPool) -> Result<(), sqlx::Error> {
tracing::debug!("[UserScheduler] {} tasks due", due_tasks.len());
for (task_id, schedule_type, _target_type) in due_tasks {
// 标记执行
let now_str = chrono::Utc::now().to_rfc3339();
// 标记执行(用 NOW() 写入时间戳)
let result = sqlx::query(
"UPDATE scheduled_tasks
SET last_run_at = $1, run_count = run_count + 1, updated_at = $1,
SET last_run_at = NOW(), run_count = run_count + 1, updated_at = NOW(),
enabled = CASE WHEN schedule_type = 'once' THEN FALSE ELSE TRUE END,
next_run_at = NULL
WHERE id = $2"
WHERE id = $1"
)
.bind(&now_str)
.bind(&task_id)
.execute(db)
.await;

View File

@@ -5,6 +5,7 @@ use std::sync::Arc;
use std::sync::atomic::{AtomicU32, Ordering};
use std::time::Instant;
use tokio::sync::RwLock;
use tokio_util::sync::CancellationToken;
use crate::config::SaaSConfig;
use crate::workers::WorkerDispatcher;
@@ -27,10 +28,12 @@ pub struct AppState {
rate_limit_rpm: Arc<AtomicU32>,
/// Worker 调度器 (异步后台任务)
pub worker_dispatcher: WorkerDispatcher,
/// 优雅停机令牌 — 触发后所有 SSE 流和长连接应立即终止
pub shutdown_token: CancellationToken,
}
impl AppState {
pub fn new(db: PgPool, config: SaaSConfig, worker_dispatcher: WorkerDispatcher) -> anyhow::Result<Self> {
pub fn new(db: PgPool, config: SaaSConfig, worker_dispatcher: WorkerDispatcher, shutdown_token: CancellationToken) -> anyhow::Result<Self> {
let jwt_secret = config.jwt_secret()?;
let rpm = config.rate_limit.requests_per_minute;
Ok(Self {
@@ -42,6 +45,7 @@ impl AppState {
totp_fail_counts: Arc::new(dashmap::DashMap::new()),
rate_limit_rpm: Arc::new(AtomicU32::new(rpm)),
worker_dispatcher,
shutdown_token,
})
}
@@ -55,9 +59,10 @@ impl AppState {
self.rate_limit_rpm.store(rpm, Ordering::Relaxed);
}
/// 清理过期的限流条目 (60 秒窗口外的记录)
/// 清理过期的限流条目
/// 使用 3600s 窗口以覆盖 register rate limit (3次/小时) 的完整周期
pub fn cleanup_rate_limit_entries(&self) {
let window_start = Instant::now() - std::time::Duration::from_secs(60);
let window_start = Instant::now() - std::time::Duration::from_secs(3600);
self.rate_limit_entries.retain(|_, entries| {
entries.retain(|&ts| ts > window_start);
!entries.is_empty()

View File

@@ -22,10 +22,12 @@ use axum::http::{Request, StatusCode};
use axum::Router;
use sqlx::PgPool;
use std::sync::atomic::{AtomicBool, Ordering};
use tokio_util::sync::CancellationToken;
use tower::ServiceExt;
use zclaw_saas::config::SaaSConfig;
use zclaw_saas::db::init_db;
use zclaw_saas::state::AppState;
use zclaw_saas::workers::WorkerDispatcher;
pub const MAX_BODY: usize = 2 * 1024 * 1024;
pub const DEFAULT_PASSWORD: &str = "testpassword123";
@@ -129,7 +131,9 @@ pub async fn build_test_app() -> (Router, PgPool) {
config.rate_limit.requests_per_minute = 10_000;
config.rate_limit.burst = 1_000;
let state = AppState::new(pool.clone(), config).expect("AppState::new failed");
let dispatcher = WorkerDispatcher::new(pool.clone());
let shutdown_token = CancellationToken::new();
let state = AppState::new(pool.clone(), config, dispatcher, shutdown_token).expect("AppState::new failed");
let router = build_router(state);
(router, pool)
}

View File

@@ -80,11 +80,9 @@ impl ToolResult {
pub mod builtin_tools {
pub const FILE_READ: &str = "file_read";
pub const FILE_WRITE: &str = "file_write";
pub const FILE_LIST: &str = "file_list";
pub const SHELL_EXEC: &str = "shell_exec";
pub const WEB_FETCH: &str = "web_fetch";
pub const WEB_SEARCH: &str = "web_search";
pub const MEMORY_STORE: &str = "memory_store";
pub const MEMORY_RECALL: &str = "memory_recall";
pub const MEMORY_SEARCH: &str = "memory_search";
// NOTE: FILE_LIST, WEB_SEARCH, MEMORY_STORE/RECALL/SEARCH were removed —
// these had no corresponding tool implementations. Memory operations are
// handled by the Growth system (MemoryMiddleware + VikingStorage).
}