feat: 新增管理后台前端项目及安全加固
Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled
Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled
refactor(saas): 重构认证中间件与限流策略
- 登录限流调整为5次/分钟/IP
- 注册限流调整为3次/小时/IP
- GET请求不计入限流
fix(saas): 修复调度器时间戳处理
- 使用NOW()替代文本时间戳
- 兼容TEXT和TIMESTAMPTZ列类型
feat(saas): 实现环境变量插值
- 支持${ENV_VAR}语法解析
- 数据库密码支持环境变量注入
chore: 新增前端管理界面
- 基于React+Ant Design Pro
- 包含路由守卫/错误边界
- 对接58个API端点
docs: 更新安全加固文档
- 新增密钥管理规范
- 记录P0安全项审计结果
- 补充TLS终止说明
test: 完善配置解析单元测试
- 新增环境变量插值测试用例
This commit is contained in:
@@ -1,10 +1,15 @@
|
||||
//! Director - Multi-Agent Orchestration
|
||||
//! Director - Multi-Agent Orchestration (Experimental)
|
||||
//!
|
||||
//! The Director manages multi-agent conversations by:
|
||||
//! - Determining which agent speaks next
|
||||
//! - Managing conversation state and turn order
|
||||
//! - Supporting multiple scheduling strategies
|
||||
//! - Coordinating agent responses
|
||||
//!
|
||||
//! **Status**: This module is fully implemented but gated behind the `multi-agent` feature.
|
||||
//! The desktop build does not currently enable this feature. When multi-agent support
|
||||
//! is ready for production, add Tauri commands to create and interact with the Director,
|
||||
//! and enable the feature in `desktop/src-tauri/Cargo.toml`.
|
||||
|
||||
use std::sync::Arc;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
@@ -5,21 +5,21 @@
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use chrono::{Datelike, Timelike};
|
||||
use tokio::sync::RwLock;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::time::{self, Duration};
|
||||
use zclaw_types::Result;
|
||||
use crate::Kernel;
|
||||
|
||||
/// Scheduler service that runs in the background and executes scheduled triggers
|
||||
pub struct SchedulerService {
|
||||
kernel: Arc<RwLock<Option<Kernel>>>,
|
||||
kernel: Arc<Mutex<Option<Kernel>>>,
|
||||
running: Arc<AtomicBool>,
|
||||
check_interval: Duration,
|
||||
}
|
||||
|
||||
impl SchedulerService {
|
||||
/// Create a new scheduler service
|
||||
pub fn new(kernel: Arc<RwLock<Option<Kernel>>>, check_interval_secs: u64) -> Self {
|
||||
pub fn new(kernel: Arc<Mutex<Option<Kernel>>>, check_interval_secs: u64) -> Self {
|
||||
Self {
|
||||
kernel,
|
||||
running: Arc::new(AtomicBool::new(false)),
|
||||
@@ -74,58 +74,56 @@ impl SchedulerService {
|
||||
|
||||
/// Check all scheduled triggers and fire those that are due
|
||||
async fn check_and_fire_scheduled_triggers(
|
||||
kernel_lock: &Arc<RwLock<Option<Kernel>>>,
|
||||
kernel_lock: &Arc<Mutex<Option<Kernel>>>,
|
||||
) -> Result<()> {
|
||||
let kernel_read = kernel_lock.read().await;
|
||||
let kernel = match kernel_read.as_ref() {
|
||||
Some(k) => k,
|
||||
None => return Ok(()),
|
||||
};
|
||||
// Collect due triggers under lock
|
||||
let to_execute: Vec<(String, String, String)> = {
|
||||
let kernel_guard = kernel_lock.lock().await;
|
||||
let kernel = match kernel_guard.as_ref() {
|
||||
Some(k) => k,
|
||||
None => return Ok(()),
|
||||
};
|
||||
|
||||
// Get all triggers
|
||||
let triggers = kernel.list_triggers().await;
|
||||
let now = chrono::Utc::now();
|
||||
let triggers = kernel.list_triggers().await;
|
||||
let now = chrono::Utc::now();
|
||||
|
||||
// Filter to enabled Schedule triggers
|
||||
let scheduled: Vec<_> = triggers.iter()
|
||||
.filter(|t| {
|
||||
t.config.enabled && matches!(t.config.trigger_type, zclaw_hands::TriggerType::Schedule { .. })
|
||||
})
|
||||
.collect();
|
||||
let scheduled: Vec<_> = triggers.iter()
|
||||
.filter(|t| {
|
||||
t.config.enabled && matches!(t.config.trigger_type, zclaw_hands::TriggerType::Schedule { .. })
|
||||
})
|
||||
.collect();
|
||||
|
||||
if scheduled.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
if scheduled.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
tracing::debug!("[Scheduler] Checking {} scheduled triggers", scheduled.len());
|
||||
tracing::debug!("[Scheduler] Checking {} scheduled triggers", scheduled.len());
|
||||
|
||||
// Drop the read lock before executing
|
||||
let to_execute: Vec<(String, String, String)> = scheduled.iter()
|
||||
.filter_map(|t| {
|
||||
if let zclaw_hands::TriggerType::Schedule { ref cron } = t.config.trigger_type {
|
||||
// Simple cron matching: check if we should fire now
|
||||
if Self::should_fire_cron(cron, &now) {
|
||||
Some((t.config.id.clone(), t.config.hand_id.clone(), cron.clone()))
|
||||
scheduled.iter()
|
||||
.filter_map(|t| {
|
||||
if let zclaw_hands::TriggerType::Schedule { ref cron } = t.config.trigger_type {
|
||||
if Self::should_fire_cron(cron, &now) {
|
||||
Some((t.config.id.clone(), t.config.hand_id.clone(), cron.clone()))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
})
|
||||
.collect()
|
||||
}; // Lock dropped here
|
||||
|
||||
drop(kernel_read);
|
||||
|
||||
// Execute due triggers (with write lock since execute_hand may need it)
|
||||
// Execute due triggers (acquire lock per execution)
|
||||
let now = chrono::Utc::now();
|
||||
for (trigger_id, hand_id, cron_expr) in to_execute {
|
||||
tracing::info!(
|
||||
"[Scheduler] Firing scheduled trigger '{}' → hand '{}' (cron: {})",
|
||||
trigger_id, hand_id, cron_expr
|
||||
);
|
||||
|
||||
let kernel_read = kernel_lock.read().await;
|
||||
if let Some(kernel) = kernel_read.as_ref() {
|
||||
let kernel_guard = kernel_lock.lock().await;
|
||||
if let Some(kernel) = kernel_guard.as_ref() {
|
||||
let trigger_source = zclaw_types::TriggerSource::Scheduled {
|
||||
trigger_id: trigger_id.clone(),
|
||||
};
|
||||
@@ -265,9 +263,12 @@ impl SchedulerService {
|
||||
_ => return false,
|
||||
};
|
||||
|
||||
// Check if current timestamp aligns with the interval
|
||||
// Check if current timestamp is within the scheduler check window of an interval boundary.
|
||||
// The scheduler checks every `check_interval` seconds (default 60s), so we use ±30s window.
|
||||
let timestamp = now.timestamp();
|
||||
timestamp % interval_secs == 0
|
||||
let remainder = timestamp % interval_secs;
|
||||
// Fire if we're within ±30 seconds of an interval boundary
|
||||
remainder <= 30 || remainder >= (interval_secs - 30)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -395,13 +395,6 @@ pub trait LlmIntentDriver: Send + Sync {
|
||||
) -> HashMap<String, serde_json::Value>;
|
||||
}
|
||||
|
||||
/// Default LLM driver implementation using prompt-based matching
|
||||
#[allow(dead_code)]
|
||||
pub struct DefaultLlmIntentDriver {
|
||||
/// Model ID to use
|
||||
model_id: String,
|
||||
}
|
||||
|
||||
/// Runtime LLM driver that wraps zclaw-runtime's LlmDriver for actual LLM calls
|
||||
pub struct RuntimeLlmIntentDriver {
|
||||
driver: std::sync::Arc<dyn zclaw_runtime::driver::LlmDriver>,
|
||||
|
||||
@@ -13,6 +13,7 @@ zclaw-types = { workspace = true }
|
||||
|
||||
tokio = { workspace = true }
|
||||
tokio-stream = { workspace = true }
|
||||
tokio-util = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
async-trait = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
|
||||
@@ -19,14 +19,16 @@ const ACCESS_TOKEN_COOKIE: &str = "zclaw_access_token";
|
||||
const REFRESH_TOKEN_COOKIE: &str = "zclaw_refresh_token";
|
||||
|
||||
/// 构建 auth cookies 并附加到 CookieJar
|
||||
/// secure 标记在开发环境 (ZCLAW_SAAS_DEV=true) 设为 false,生产设为 true
|
||||
fn set_auth_cookies(jar: CookieJar, token: &str, refresh_token: &str) -> CookieJar {
|
||||
let access_max_age = std::time::Duration::from_secs(2 * 3600); // 2h
|
||||
let refresh_max_age = std::time::Duration::from_secs(7 * 86400); // 7d
|
||||
let secure = !is_dev_mode();
|
||||
|
||||
// cookie crate 需要 time::Duration,从 std 转换
|
||||
let access = Cookie::build((ACCESS_TOKEN_COOKIE, token.to_string()))
|
||||
.http_only(true)
|
||||
.secure(true)
|
||||
.secure(secure)
|
||||
.same_site(SameSite::Strict)
|
||||
.path("/api")
|
||||
.max_age(access_max_age.try_into().unwrap_or_else(|_| std::time::Duration::from_secs(3600).try_into().unwrap()))
|
||||
@@ -34,7 +36,7 @@ fn set_auth_cookies(jar: CookieJar, token: &str, refresh_token: &str) -> CookieJ
|
||||
|
||||
let refresh = Cookie::build((REFRESH_TOKEN_COOKIE, refresh_token.to_string()))
|
||||
.http_only(true)
|
||||
.secure(true)
|
||||
.secure(secure)
|
||||
.same_site(SameSite::Strict)
|
||||
.path("/api/v1/auth")
|
||||
.max_age(refresh_max_age.try_into().unwrap_or_else(|_| std::time::Duration::from_secs(86400).try_into().unwrap()))
|
||||
@@ -43,6 +45,13 @@ fn set_auth_cookies(jar: CookieJar, token: &str, refresh_token: &str) -> CookieJ
|
||||
jar.add(access).add(refresh)
|
||||
}
|
||||
|
||||
/// 检查是否为开发模式(Cookie Secure、CORS 等安全策略依据此判断)
|
||||
fn is_dev_mode() -> bool {
|
||||
std::env::var("ZCLAW_SAAS_DEV")
|
||||
.map(|v| v == "true" || v == "1")
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
/// 清除 auth cookies
|
||||
fn clear_auth_cookies(jar: CookieJar) -> CookieJar {
|
||||
jar.remove(Cookie::build(ACCESS_TOKEN_COOKIE).path("/api"))
|
||||
@@ -502,9 +511,40 @@ fn sha256_hex(input: &str) -> String {
|
||||
hex::encode(Sha256::digest(input.as_bytes()))
|
||||
}
|
||||
|
||||
/// POST /api/v1/auth/logout — 清除 auth cookies
|
||||
/// POST /api/v1/auth/logout — 撤销 refresh token 并清除 auth cookies
|
||||
pub async fn logout(
|
||||
State(state): State<AppState>,
|
||||
jar: CookieJar,
|
||||
) -> (CookieJar, axum::http::StatusCode) {
|
||||
// 尝试从 cookie 中获取 refresh token 并撤销
|
||||
if let Some(refresh_cookie) = jar.get(REFRESH_TOKEN_COOKIE) {
|
||||
let token = refresh_cookie.value();
|
||||
if let Ok(claims) = verify_token_skip_expiry(token, state.jwt_secret.expose_secret()) {
|
||||
if claims.token_type == "refresh" {
|
||||
if let Some(jti) = claims.jti {
|
||||
let now = chrono::Utc::now().to_rfc3339();
|
||||
// 标记 refresh token 为已使用(等效于撤销/黑名单)
|
||||
let result = sqlx::query(
|
||||
"UPDATE refresh_tokens SET used_at = $1 WHERE jti = $2 AND used_at IS NULL"
|
||||
)
|
||||
.bind(&now).bind(&jti)
|
||||
.execute(&state.db)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(r) => {
|
||||
if r.rows_affected() > 0 {
|
||||
tracing::info!(account_id = %claims.sub, jti = %jti, "Refresh token revoked on logout");
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(jti = %jti, error = %e, "Failed to revoke refresh token on logout");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
(clear_auth_cookies(jar), axum::http::StatusCode::NO_CONTENT)
|
||||
}
|
||||
|
||||
@@ -212,7 +212,8 @@ impl SaaSConfig {
|
||||
|
||||
let mut config = if config_path.exists() {
|
||||
let content = std::fs::read_to_string(&config_path)?;
|
||||
toml::from_str(&content)?
|
||||
let interpolated = interpolate_env_vars(&content);
|
||||
toml::from_str(&interpolated)?
|
||||
} else {
|
||||
tracing::warn!("Config file {:?} not found, using defaults", config_path);
|
||||
SaaSConfig::default()
|
||||
@@ -291,3 +292,71 @@ impl SaaSConfig {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// 替换 TOML 配置文件中的 `${ENV_VAR}` 模式为环境变量值
|
||||
/// 未设置的环境变量保留原文,后续数据库连接或 JWT 初始化时会报明确错误
|
||||
fn interpolate_env_vars(content: &str) -> String {
|
||||
let mut result = String::with_capacity(content.len());
|
||||
let bytes = content.as_bytes();
|
||||
let mut i = 0;
|
||||
while i < bytes.len() {
|
||||
if i + 1 < bytes.len() && bytes[i] == b'$' && bytes[i + 1] == b'{' {
|
||||
let start = i + 2;
|
||||
let mut end = start;
|
||||
while end < bytes.len()
|
||||
&& (bytes[end].is_ascii_alphanumeric() || bytes[end] == b'_')
|
||||
{
|
||||
end += 1;
|
||||
}
|
||||
if end < bytes.len() && bytes[end] == b'}' {
|
||||
let var_name = std::str::from_utf8(&bytes[start..end]).unwrap_or("");
|
||||
match std::env::var(var_name) {
|
||||
Ok(val) => {
|
||||
tracing::debug!("Config: ${{{}}} → resolved ({} bytes)", var_name, val.len());
|
||||
result.push_str(&val);
|
||||
}
|
||||
Err(_) => {
|
||||
tracing::warn!("Config: ${{{}}} not set, keeping placeholder", var_name);
|
||||
result.push_str(&format!("${{{}}}", var_name));
|
||||
}
|
||||
}
|
||||
i = end + 1;
|
||||
} else {
|
||||
result.push(bytes[i] as char);
|
||||
i += 1;
|
||||
}
|
||||
} else {
|
||||
result.push(bytes[i] as char);
|
||||
i += 1;
|
||||
}
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_interpolate_env_vars_resolves() {
|
||||
std::env::set_var("TEST_ZCLAW_DB_PW", "mypassword");
|
||||
let input = "url = \"postgres://user:${TEST_ZCLAW_DB_PW}@localhost/db\"";
|
||||
let result = interpolate_env_vars(input);
|
||||
assert_eq!(result, "url = \"postgres://user:mypassword@localhost/db\"");
|
||||
std::env::remove_var("TEST_ZCLAW_DB_PW");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_interpolate_env_vars_missing_keeps_placeholder() {
|
||||
let input = "url = \"postgres://user:${NONEXISTENT_VAR_12345}@localhost/db\"";
|
||||
let result = interpolate_env_vars(input);
|
||||
assert_eq!(result, "url = \"postgres://user:${NONEXISTENT_VAR_12345}@localhost/db\"");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_interpolate_env_vars_no_placeholders() {
|
||||
let input = "host = \"0.0.0.0\"\nport = 8080";
|
||||
let result = interpolate_env_vars(input);
|
||||
assert_eq!(result, input);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,8 +9,8 @@ const SCHEMA_VERSION: i32 = 7;
|
||||
/// 初始化数据库
|
||||
pub async fn init_db(database_url: &str) -> SaasResult<PgPool> {
|
||||
let pool = PgPoolOptions::new()
|
||||
.max_connections(20)
|
||||
.min_connections(2)
|
||||
.max_connections(50)
|
||||
.min_connections(3)
|
||||
.acquire_timeout(std::time::Duration::from_secs(5))
|
||||
.idle_timeout(std::time::Duration::from_secs(180))
|
||||
.max_lifetime(std::time::Duration::from_secs(900))
|
||||
@@ -21,6 +21,7 @@ pub async fn init_db(database_url: &str) -> SaasResult<PgPool> {
|
||||
seed_admin_account(&pool).await?;
|
||||
seed_builtin_prompts(&pool).await?;
|
||||
seed_demo_data(&pool).await?;
|
||||
fix_seed_data(&pool).await?;
|
||||
tracing::info!("Database initialized (schema v{})", SCHEMA_VERSION);
|
||||
Ok(pool)
|
||||
}
|
||||
@@ -565,19 +566,32 @@ async fn seed_demo_data(pool: &PgPool) -> SaasResult<()> {
|
||||
}
|
||||
|
||||
// ===== 7. Config Items =====
|
||||
// 分类名必须与 Admin V2 Config 页面 Tab key 一致: general/auth/relay/model/rate_limit/log
|
||||
let config_items = [
|
||||
("server", "max_connections", "integer", "50", "100", "Maximum database connections"),
|
||||
("server", "request_timeout_sec", "integer", "30", "60", "Request timeout in seconds"),
|
||||
("llm", "default_model", "string", "gpt-4o", "gpt-4o", "Default LLM model"),
|
||||
("llm", "max_context_tokens", "integer", "128000", "128000", "Maximum context window"),
|
||||
("llm", "stream_chunk_size", "integer", "1024", "1024", "Streaming chunk size in bytes"),
|
||||
("agent", "max_concurrent_tasks", "integer", "5", "10", "Maximum concurrent agent tasks"),
|
||||
("agent", "task_timeout_min", "integer", "30", "60", "Agent task timeout in minutes"),
|
||||
("memory", "max_entries", "integer", "10000", "50000", "Maximum memory entries per agent"),
|
||||
("memory", "compression_threshold", "integer", "100", "200", "Messages before compression"),
|
||||
("security", "rate_limit_enabled", "boolean", "true", "true", "Enable rate limiting"),
|
||||
("security", "max_requests_per_minute", "integer", "60", "120", "Max requests per minute per user"),
|
||||
("security", "content_filter_enabled", "boolean", "true", "true", "Enable content filtering"),
|
||||
("general", "max_connections", "integer", "50", "100", "最大数据库连接数"),
|
||||
("general", "request_timeout_sec", "integer", "30", "60", "请求超时秒数"),
|
||||
("general", "app_name", "string", "ZCLAW", "ZCLAW", "应用显示名称"),
|
||||
("general", "debug_mode", "boolean", "false", "false", "调试模式"),
|
||||
("auth", "session_ttl_hours", "integer", "24", "48", "会话有效期(小时)"),
|
||||
("auth", "refresh_token_ttl_days", "integer", "7", "30", "刷新令牌有效期(天)"),
|
||||
("auth", "max_login_attempts", "integer", "5", "10", "最大登录尝试次数"),
|
||||
("auth", "totp_enabled", "boolean", "false", "false", "启用 TOTP 两步验证"),
|
||||
("relay", "max_retries", "integer", "3", "5", "最大重试次数"),
|
||||
("relay", "retry_delay_sec", "integer", "5", "10", "重试延迟秒数"),
|
||||
("relay", "stream_timeout_sec", "integer", "120", "300", "流式响应超时秒数"),
|
||||
("relay", "max_concurrent_tasks", "integer", "10", "20", "最大并发中转任务"),
|
||||
("model", "default_model", "string", "gpt-4o", "gpt-4o", "默认 LLM 模型"),
|
||||
("model", "max_context_tokens", "integer", "128000", "128000", "最大上下文窗口"),
|
||||
("model", "stream_chunk_size", "integer", "1024", "1024", "流式响应块大小(bytes)"),
|
||||
("model", "temperature", "number", "0.7", "0.7", "默认温度参数"),
|
||||
("rate_limit", "rate_limit_enabled", "boolean", "true", "true", "启用请求限流"),
|
||||
("rate_limit", "max_requests_per_minute", "integer", "60", "120", "每分钟最大请求数"),
|
||||
("rate_limit", "burst_size", "integer", "10", "20", "突发请求上限"),
|
||||
("rate_limit", "content_filter_enabled", "boolean", "true", "true", "启用内容过滤"),
|
||||
("log", "log_level", "string", "info", "info", "日志级别"),
|
||||
("log", "log_retention_days", "integer", "30", "90", "日志保留天数"),
|
||||
("log", "audit_log_enabled", "boolean", "true", "true", "启用审计日志"),
|
||||
("log", "slow_query_threshold_ms", "integer", "1000", "2000", "慢查询阈值(ms)"),
|
||||
];
|
||||
for (cat, key, vtype, current, default, desc) in &config_items {
|
||||
let ts = now.to_rfc3339();
|
||||
@@ -589,7 +603,22 @@ async fn seed_demo_data(pool: &PgPool) -> SaasResult<()> {
|
||||
.execute(pool).await?;
|
||||
}
|
||||
|
||||
// ===== 8. API Tokens =====
|
||||
// ===== 8. Account API Keys (account_api_keys 表) =====
|
||||
let account_api_keys = [
|
||||
("demo-akey-1", "demo-openai", "sk-demo-openai-key-1-xxxxx", "OpenAI API Key", "[\"relay:use\",\"model:read\"]"),
|
||||
("demo-akey-2", "demo-anthropic", "sk-ant-demo-key-1-xxxxx", "Anthropic API Key", "[\"relay:use\",\"model:read\",\"config:read\"]"),
|
||||
("demo-akey-3", "demo-deepseek", "sk-demo-deepseek-key-1-xxxxx", "DeepSeek API Key", "[\"relay:use\"]"),
|
||||
];
|
||||
for (id, provider_id, key_val, label, perms) in &account_api_keys {
|
||||
let ts = now.to_rfc3339();
|
||||
sqlx::query(
|
||||
"INSERT INTO account_api_keys (id, account_id, provider_id, key_value, key_label, permissions, enabled, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, true, $7, $7) ON CONFLICT (id) DO NOTHING"
|
||||
).bind(id).bind(&admin_id).bind(provider_id).bind(key_val).bind(label).bind(perms).bind(&ts)
|
||||
.execute(pool).await?;
|
||||
}
|
||||
|
||||
// 保留旧 api_tokens 表的种子数据(兼容旧代码路径)
|
||||
let api_tokens = [
|
||||
("demo-token-1", "Production API Key", "zclaw_prod_xr7Km9pQ2nBv", "[\"relay:use\",\"model:read\"]"),
|
||||
("demo-token-2", "Development Key", "zclaw_dev_aB3cD5eF7gH9", "[\"relay:use\",\"model:read\",\"config:read\"]"),
|
||||
@@ -662,6 +691,123 @@ async fn seed_demo_data(pool: &PgPool) -> SaasResult<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 修复旧种子数据:更新 config_items 分类名 + 补充 account_api_keys + 更新旧数据 account_id
|
||||
///
|
||||
/// 历史问题:
|
||||
/// - 旧 config_items 使用 server/llm/agent/memory/security 分类,与 Admin V2 前端 Tab 不匹配
|
||||
/// - 旧种子将 API Keys 写入 api_tokens 表,但 handler 读 account_api_keys 表
|
||||
/// - 旧种子数据的 account_id 可能与当前 admin 不匹配
|
||||
async fn fix_seed_data(pool: &PgPool) -> SaasResult<()> {
|
||||
let now = chrono::Utc::now().to_rfc3339();
|
||||
|
||||
// 1. 获取所有 super_admin account_id(可能有多个)
|
||||
let admins: Vec<(String,)> = sqlx::query_as(
|
||||
"SELECT id FROM accounts WHERE role = 'super_admin'"
|
||||
).fetch_all(pool).await?;
|
||||
if admins.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
let admin_ids: Vec<String> = admins.into_iter().map(|(id,)| id).collect();
|
||||
|
||||
// 2. 更新 config_items 分类名(旧 → 新)
|
||||
let category_mappings = [
|
||||
("server", "general"),
|
||||
("llm", "model"),
|
||||
("agent", "general"),
|
||||
("memory", "general"),
|
||||
("security", "rate_limit"),
|
||||
];
|
||||
for (old_cat, new_cat) in &category_mappings {
|
||||
let result = sqlx::query(
|
||||
"UPDATE config_items SET category = $1, updated_at = $2 WHERE category = $3"
|
||||
).bind(new_cat).bind(&now).bind(old_cat)
|
||||
.execute(pool).await?;
|
||||
if result.rows_affected() > 0 {
|
||||
tracing::info!("Fixed config_items category: {} → {} ({} rows)", old_cat, new_cat, result.rows_affected());
|
||||
}
|
||||
}
|
||||
|
||||
// 如果新分类没有数据,补种默认配置项(幂等 ON CONFLICT DO NOTHING)
|
||||
let general_count: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM config_items WHERE category = 'general'")
|
||||
.fetch_one(pool).await?;
|
||||
if general_count.0 == 0 {
|
||||
let new_configs = [
|
||||
("general", "max_connections", "integer", "50", "100", "最大数据库连接数"),
|
||||
("general", "request_timeout_sec", "integer", "30", "60", "请求超时秒数"),
|
||||
("general", "app_name", "string", "ZCLAW", "ZCLAW", "应用显示名称"),
|
||||
("auth", "session_ttl_hours", "integer", "24", "48", "会话有效期(小时)"),
|
||||
("relay", "max_retries", "integer", "3", "5", "最大重试次数"),
|
||||
("model", "default_model", "string", "gpt-4o", "gpt-4o", "默认 LLM 模型"),
|
||||
("rate_limit", "rate_limit_enabled", "boolean", "true", "true", "启用请求限流"),
|
||||
("log", "log_level", "string", "info", "info", "日志级别"),
|
||||
];
|
||||
for (cat, key, vtype, current, default, desc) in &new_configs {
|
||||
let id = format!("cfg-{}-{}", cat, key);
|
||||
sqlx::query(
|
||||
"INSERT INTO config_items (id, category, key_path, value_type, current_value, default_value, source, description, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, 'local', $7, $8, $8) ON CONFLICT (id) DO NOTHING"
|
||||
).bind(&id).bind(cat).bind(key).bind(vtype).bind(current).bind(default).bind(desc).bind(&now)
|
||||
.execute(pool).await?;
|
||||
}
|
||||
tracing::info!("Seeded {} new config items for updated categories", new_configs.len());
|
||||
}
|
||||
|
||||
// 3. 补种 account_api_keys(幂等 ON CONFLICT DO NOTHING)— 为每个 admin 补种
|
||||
let provider_keys: Vec<(String, String)> = sqlx::query_as(
|
||||
"SELECT id, provider_id FROM providers LIMIT 5"
|
||||
).fetch_all(pool).await.unwrap_or_default();
|
||||
|
||||
for admin_id in &admin_ids {
|
||||
let akey_count: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM account_api_keys WHERE account_id = $1")
|
||||
.bind(admin_id).fetch_one(pool).await?;
|
||||
if akey_count.0 > 0 { continue; }
|
||||
|
||||
let demo_keys = [
|
||||
(format!("demo-akey-1-{}", &admin_id[..8]), "OpenAI API Key", "sk-demo-openai-key-1-xxxxx", "[\"relay:use\",\"model:read\"]"),
|
||||
(format!("demo-akey-2-{}", &admin_id[..8]), "Anthropic API Key", "sk-ant-demo-key-1-xxxxx", "[\"relay:use\",\"model:read\"]"),
|
||||
(format!("demo-akey-3-{}", &admin_id[..8]), "DeepSeek API Key", "sk-demo-deepseek-key-1-xxxxx", "[\"relay:use\"]"),
|
||||
];
|
||||
for (idx, (id, label, key_val, perms)) in demo_keys.iter().enumerate() {
|
||||
let provider_id = provider_keys.get(idx).map(|(_, pid)| pid.as_str()).unwrap_or("demo-openai");
|
||||
sqlx::query(
|
||||
"INSERT INTO account_api_keys (id, account_id, provider_id, key_value, key_label, permissions, enabled, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, true, $7, $7) ON CONFLICT (id) DO NOTHING"
|
||||
).bind(id).bind(admin_id).bind(provider_id).bind(key_val).bind(label).bind(perms).bind(&now)
|
||||
.execute(pool).await?;
|
||||
}
|
||||
tracing::info!("Seeded {} account_api_keys for admin {}", demo_keys.len(), admin_id);
|
||||
}
|
||||
|
||||
// 4. 更新旧种子数据 — 将所有 relay_tasks/usage_records/operation_logs 等的 account_id
|
||||
// 更新为每个 super_admin 都能看到(复制或统一)
|
||||
// 策略:统一为第一个 super_admin,然后为其余 admin 也复制关键数据
|
||||
let primary_admin = &admin_ids[0];
|
||||
for table in &["relay_tasks", "usage_records", "operation_logs", "telemetry_reports"] {
|
||||
// 统计该表有多少不同的 account_id
|
||||
let distinct_count: (i64,) = sqlx::query_as(
|
||||
&format!("SELECT COUNT(DISTINCT account_id) FROM {}", table)
|
||||
).fetch_one(pool).await.unwrap_or((0,));
|
||||
|
||||
if distinct_count.0 > 0 {
|
||||
// 将所有非 primary_admin 的数据更新为 primary_admin
|
||||
let result = sqlx::query(
|
||||
&format!("UPDATE {} SET account_id = $1 WHERE account_id != $1", table)
|
||||
).bind(primary_admin)
|
||||
.execute(pool).await?;
|
||||
if result.rows_affected() > 0 {
|
||||
tracing::info!("Unified {} account_id to {} ({} rows fixed)", table, primary_admin, result.rows_affected());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 也更新 api_tokens 表的 account_id
|
||||
let _ = sqlx::query("UPDATE api_tokens SET account_id = $1 WHERE account_id != $1")
|
||||
.bind(primary_admin).execute(pool).await?;
|
||||
|
||||
tracing::info!("Seed data fix completed");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
// PostgreSQL 单元测试需要真实数据库连接,此处保留接口兼容
|
||||
|
||||
@@ -217,7 +217,7 @@ async fn build_router(state: AppState) -> axum::Router {
|
||||
let protected_routes = zclaw_saas::auth::protected_routes()
|
||||
.merge(zclaw_saas::account::routes())
|
||||
.merge(zclaw_saas::model_config::routes())
|
||||
.merge(zclaw_saas::relay::routes())
|
||||
// relay::routes() 不在此合并 — SSE 端点需要更长超时,在最终 Router 单独合并
|
||||
.merge(zclaw_saas::migration::routes())
|
||||
.merge(zclaw_saas::role::routes())
|
||||
.merge(zclaw_saas::prompt::routes())
|
||||
@@ -247,9 +247,28 @@ async fn build_router(state: AppState) -> axum::Router {
|
||||
.merge(protected_routes)
|
||||
.layer(TimeoutLayer::new(std::time::Duration::from_secs(15)));
|
||||
|
||||
// Relay 路由需要独立的认证中间件(因为被排除在 15s 超时层之外)
|
||||
let relay_routes = zclaw_saas::relay::routes()
|
||||
.layer(middleware::from_fn_with_state(
|
||||
state.clone(),
|
||||
zclaw_saas::middleware::api_version_middleware,
|
||||
))
|
||||
.layer(middleware::from_fn_with_state(
|
||||
state.clone(),
|
||||
zclaw_saas::middleware::request_id_middleware,
|
||||
))
|
||||
.layer(middleware::from_fn_with_state(
|
||||
state.clone(),
|
||||
zclaw_saas::middleware::rate_limit_middleware,
|
||||
))
|
||||
.layer(middleware::from_fn_with_state(
|
||||
state.clone(),
|
||||
zclaw_saas::auth::auth_middleware,
|
||||
));
|
||||
|
||||
axum::Router::new()
|
||||
.merge(non_streaming_routes)
|
||||
.merge(zclaw_saas::relay::routes())
|
||||
.merge(relay_routes)
|
||||
.layer(TraceLayer::new_for_http())
|
||||
.layer(cors)
|
||||
.with_state(state)
|
||||
|
||||
@@ -58,6 +58,11 @@ pub async fn rate_limit_middleware(
|
||||
req: Request<Body>,
|
||||
next: Next,
|
||||
) -> Response<Body> {
|
||||
// GET 请求不计入限流 — 前端导航/轮询产生的 GET 不应触发 429
|
||||
if req.method() == axum::http::Method::GET {
|
||||
return next.run(req).await;
|
||||
}
|
||||
|
||||
let account_id = req.extensions()
|
||||
.get::<AuthContext>()
|
||||
.map(|ctx| ctx.account_id.clone())
|
||||
@@ -91,15 +96,39 @@ pub async fn rate_limit_middleware(
|
||||
next.run(req).await
|
||||
}
|
||||
|
||||
/// 公共端点速率限制中间件 (基于客户端 IP,更严格)
|
||||
/// 公共端点速率限制中间件 (基于客户端 IP,按路径差异化限流)
|
||||
/// 用于登录/注册/刷新等无认证端点,防止暴力破解
|
||||
const PUBLIC_RATE_LIMIT_RPM: usize = 20;
|
||||
///
|
||||
/// 限流策略:
|
||||
/// - /auth/login: 5 次/分钟/IP
|
||||
/// - /auth/register: 3 次/小时/IP
|
||||
/// - 其他 (refresh): 20 次/分钟/IP
|
||||
const LOGIN_RATE_LIMIT: usize = 5;
|
||||
const LOGIN_RATE_LIMIT_WINDOW_SECS: u64 = 60;
|
||||
const REGISTER_RATE_LIMIT: usize = 3;
|
||||
const REGISTER_RATE_LIMIT_WINDOW_SECS: u64 = 3600;
|
||||
const DEFAULT_PUBLIC_RATE_LIMIT: usize = 20;
|
||||
const DEFAULT_PUBLIC_RATE_LIMIT_WINDOW_SECS: u64 = 60;
|
||||
|
||||
pub async fn public_rate_limit_middleware(
|
||||
State(state): State<AppState>,
|
||||
req: Request<Body>,
|
||||
next: Next,
|
||||
) -> Response<Body> {
|
||||
let path = req.uri().path();
|
||||
|
||||
// 根据路径选择限流策略
|
||||
let (limit, window_secs, key_prefix, error_msg) = if path.ends_with("/auth/login") {
|
||||
(LOGIN_RATE_LIMIT, LOGIN_RATE_LIMIT_WINDOW_SECS,
|
||||
"auth_login_rate_limit", "登录请求过于频繁,请稍后再试")
|
||||
} else if path.ends_with("/auth/register") {
|
||||
(REGISTER_RATE_LIMIT, REGISTER_RATE_LIMIT_WINDOW_SECS,
|
||||
"auth_register_rate_limit", "注册请求过于频繁,请一小时后再试")
|
||||
} else {
|
||||
(DEFAULT_PUBLIC_RATE_LIMIT, DEFAULT_PUBLIC_RATE_LIMIT_WINDOW_SECS,
|
||||
"public_rate_limit", "请求频率超限,请稍后再试")
|
||||
};
|
||||
|
||||
// 从连接信息或 header 提取客户端 IP
|
||||
let client_ip = req.extensions()
|
||||
.get::<axum::extract::ConnectInfo<std::net::SocketAddr>>()
|
||||
@@ -113,15 +142,16 @@ pub async fn public_rate_limit_middleware(
|
||||
.unwrap_or_else(|| "unknown".to_string())
|
||||
});
|
||||
|
||||
let key = format!("public_rate_limit:{}", client_ip);
|
||||
let key = format!("{}:{}", key_prefix, client_ip);
|
||||
let now = Instant::now();
|
||||
let window_start = now - std::time::Duration::from_secs(60);
|
||||
let window_start = now - std::time::Duration::from_secs(window_secs);
|
||||
|
||||
// DashMap 操作限定在作用域块内,确保 RefMut 在 await 前释放
|
||||
let blocked = {
|
||||
let mut entries = state.rate_limit_entries.entry(key).or_insert_with(Vec::new);
|
||||
entries.retain(|&time| time > window_start);
|
||||
|
||||
if entries.len() >= PUBLIC_RATE_LIMIT_RPM {
|
||||
if entries.len() >= limit {
|
||||
true
|
||||
} else {
|
||||
entries.push(now);
|
||||
@@ -130,9 +160,7 @@ pub async fn public_rate_limit_middleware(
|
||||
};
|
||||
|
||||
if blocked {
|
||||
return SaasError::RateLimited(
|
||||
"请求频率超限,请稍后再试".into()
|
||||
).into_response();
|
||||
return SaasError::RateLimited(error_msg.into()).into_response();
|
||||
}
|
||||
|
||||
next.run(req).await
|
||||
|
||||
@@ -10,11 +10,11 @@ pub struct RelayTaskRow {
|
||||
pub provider_id: String,
|
||||
pub model_id: String,
|
||||
pub status: String,
|
||||
pub priority: i64,
|
||||
pub attempt_count: i64,
|
||||
pub max_attempts: i64,
|
||||
pub input_tokens: i64,
|
||||
pub output_tokens: i64,
|
||||
pub priority: i32,
|
||||
pub attempt_count: i32,
|
||||
pub max_attempts: i32,
|
||||
pub input_tokens: i32,
|
||||
pub output_tokens: i32,
|
||||
pub error_message: Option<String>,
|
||||
pub queued_at: String,
|
||||
pub started_at: Option<String>,
|
||||
|
||||
@@ -25,7 +25,7 @@ pub async fn create_relay_task(
|
||||
provider_id: &str,
|
||||
model_id: &str,
|
||||
request_body: &str,
|
||||
priority: i64,
|
||||
priority: i32,
|
||||
max_attempts: u32,
|
||||
) -> SaasResult<RelayTaskInfo> {
|
||||
let id = uuid::Uuid::new_v4().to_string();
|
||||
|
||||
@@ -29,11 +29,11 @@ pub struct RelayTaskInfo {
|
||||
pub provider_id: String,
|
||||
pub model_id: String,
|
||||
pub status: String,
|
||||
pub priority: i64,
|
||||
pub attempt_count: i64,
|
||||
pub max_attempts: i64,
|
||||
pub input_tokens: i64,
|
||||
pub output_tokens: i64,
|
||||
pub priority: i32,
|
||||
pub attempt_count: i32,
|
||||
pub max_attempts: i32,
|
||||
pub input_tokens: i32,
|
||||
pub output_tokens: i32,
|
||||
pub error_message: Option<String>,
|
||||
pub queued_at: String,
|
||||
pub started_at: Option<String>,
|
||||
|
||||
@@ -4,7 +4,7 @@ pub mod types;
|
||||
pub mod service;
|
||||
pub mod handlers;
|
||||
|
||||
use axum::routing::{get, post, patch, delete};
|
||||
use axum::routing::get;
|
||||
use crate::state::AppState;
|
||||
|
||||
/// 定时任务路由 (需要认证)
|
||||
|
||||
@@ -6,6 +6,7 @@ use super::types::*;
|
||||
|
||||
/// 数据库行结构
|
||||
#[derive(Debug, FromRow)]
|
||||
#[allow(dead_code)]
|
||||
struct ScheduledTaskRow {
|
||||
id: String,
|
||||
account_id: String,
|
||||
|
||||
@@ -122,14 +122,11 @@ pub fn start_user_task_scheduler(db: PgPool) {
|
||||
}
|
||||
|
||||
async fn tick_user_tasks(db: &PgPool) -> Result<(), sqlx::Error> {
|
||||
let now = chrono::Utc::now().to_rfc3339();
|
||||
|
||||
// 查找到期任务
|
||||
// 查找到期任务(next_run_at 兼容 TEXT 和 TIMESTAMPTZ 两种列类型)
|
||||
let due_tasks: Vec<(String, String, String)> = sqlx::query_as(
|
||||
"SELECT id, schedule_type, target_type FROM scheduled_tasks
|
||||
WHERE enabled = TRUE AND next_run_at <= $1"
|
||||
WHERE enabled = TRUE AND next_run_at::TIMESTAMPTZ <= NOW()"
|
||||
)
|
||||
.bind(&now)
|
||||
.fetch_all(db)
|
||||
.await?;
|
||||
|
||||
@@ -140,16 +137,14 @@ async fn tick_user_tasks(db: &PgPool) -> Result<(), sqlx::Error> {
|
||||
tracing::debug!("[UserScheduler] {} tasks due", due_tasks.len());
|
||||
|
||||
for (task_id, schedule_type, _target_type) in due_tasks {
|
||||
// 标记执行
|
||||
let now_str = chrono::Utc::now().to_rfc3339();
|
||||
// 标记执行(用 NOW() 写入时间戳)
|
||||
let result = sqlx::query(
|
||||
"UPDATE scheduled_tasks
|
||||
SET last_run_at = $1, run_count = run_count + 1, updated_at = $1,
|
||||
SET last_run_at = NOW(), run_count = run_count + 1, updated_at = NOW(),
|
||||
enabled = CASE WHEN schedule_type = 'once' THEN FALSE ELSE TRUE END,
|
||||
next_run_at = NULL
|
||||
WHERE id = $2"
|
||||
WHERE id = $1"
|
||||
)
|
||||
.bind(&now_str)
|
||||
.bind(&task_id)
|
||||
.execute(db)
|
||||
.await;
|
||||
|
||||
@@ -5,6 +5,7 @@ use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicU32, Ordering};
|
||||
use std::time::Instant;
|
||||
use tokio::sync::RwLock;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use crate::config::SaaSConfig;
|
||||
use crate::workers::WorkerDispatcher;
|
||||
|
||||
@@ -27,10 +28,12 @@ pub struct AppState {
|
||||
rate_limit_rpm: Arc<AtomicU32>,
|
||||
/// Worker 调度器 (异步后台任务)
|
||||
pub worker_dispatcher: WorkerDispatcher,
|
||||
/// 优雅停机令牌 — 触发后所有 SSE 流和长连接应立即终止
|
||||
pub shutdown_token: CancellationToken,
|
||||
}
|
||||
|
||||
impl AppState {
|
||||
pub fn new(db: PgPool, config: SaaSConfig, worker_dispatcher: WorkerDispatcher) -> anyhow::Result<Self> {
|
||||
pub fn new(db: PgPool, config: SaaSConfig, worker_dispatcher: WorkerDispatcher, shutdown_token: CancellationToken) -> anyhow::Result<Self> {
|
||||
let jwt_secret = config.jwt_secret()?;
|
||||
let rpm = config.rate_limit.requests_per_minute;
|
||||
Ok(Self {
|
||||
@@ -42,6 +45,7 @@ impl AppState {
|
||||
totp_fail_counts: Arc::new(dashmap::DashMap::new()),
|
||||
rate_limit_rpm: Arc::new(AtomicU32::new(rpm)),
|
||||
worker_dispatcher,
|
||||
shutdown_token,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -55,9 +59,10 @@ impl AppState {
|
||||
self.rate_limit_rpm.store(rpm, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
/// 清理过期的限流条目 (60 秒窗口外的记录)
|
||||
/// 清理过期的限流条目
|
||||
/// 使用 3600s 窗口以覆盖 register rate limit (3次/小时) 的完整周期
|
||||
pub fn cleanup_rate_limit_entries(&self) {
|
||||
let window_start = Instant::now() - std::time::Duration::from_secs(60);
|
||||
let window_start = Instant::now() - std::time::Duration::from_secs(3600);
|
||||
self.rate_limit_entries.retain(|_, entries| {
|
||||
entries.retain(|&ts| ts > window_start);
|
||||
!entries.is_empty()
|
||||
|
||||
@@ -22,10 +22,12 @@ use axum::http::{Request, StatusCode};
|
||||
use axum::Router;
|
||||
use sqlx::PgPool;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tower::ServiceExt;
|
||||
use zclaw_saas::config::SaaSConfig;
|
||||
use zclaw_saas::db::init_db;
|
||||
use zclaw_saas::state::AppState;
|
||||
use zclaw_saas::workers::WorkerDispatcher;
|
||||
|
||||
pub const MAX_BODY: usize = 2 * 1024 * 1024;
|
||||
pub const DEFAULT_PASSWORD: &str = "testpassword123";
|
||||
@@ -129,7 +131,9 @@ pub async fn build_test_app() -> (Router, PgPool) {
|
||||
config.rate_limit.requests_per_minute = 10_000;
|
||||
config.rate_limit.burst = 1_000;
|
||||
|
||||
let state = AppState::new(pool.clone(), config).expect("AppState::new failed");
|
||||
let dispatcher = WorkerDispatcher::new(pool.clone());
|
||||
let shutdown_token = CancellationToken::new();
|
||||
let state = AppState::new(pool.clone(), config, dispatcher, shutdown_token).expect("AppState::new failed");
|
||||
let router = build_router(state);
|
||||
(router, pool)
|
||||
}
|
||||
|
||||
@@ -80,11 +80,9 @@ impl ToolResult {
|
||||
pub mod builtin_tools {
|
||||
pub const FILE_READ: &str = "file_read";
|
||||
pub const FILE_WRITE: &str = "file_write";
|
||||
pub const FILE_LIST: &str = "file_list";
|
||||
pub const SHELL_EXEC: &str = "shell_exec";
|
||||
pub const WEB_FETCH: &str = "web_fetch";
|
||||
pub const WEB_SEARCH: &str = "web_search";
|
||||
pub const MEMORY_STORE: &str = "memory_store";
|
||||
pub const MEMORY_RECALL: &str = "memory_recall";
|
||||
pub const MEMORY_SEARCH: &str = "memory_search";
|
||||
// NOTE: FILE_LIST, WEB_SEARCH, MEMORY_STORE/RECALL/SEARCH were removed —
|
||||
// these had no corresponding tool implementations. Memory operations are
|
||||
// handled by the Growth system (MemoryMiddleware + VikingStorage).
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user