fix(security): implement all 15 security fixes from penetration test V1
Security audit (2026-03-31): 5 HIGH + 10 MEDIUM issues, all fixed. HIGH: - H1: JWT password_version mechanism (pwv in Claims, middleware verification, auto-increment on password change) - H2: Docker saas port bound to 127.0.0.1 - H3: TOTP encryption key decoupled from JWT secret (production bailout) - H4+H5: Tauri CSP hardened (removed unsafe-inline, restricted connect-src) MEDIUM: - M1: Persistent rate limiting (PostgreSQL rate_limit_events table) - M2: Account lockout (5 failures -> 15min lock) - M3: RFC 5322 email validation with regex - M4: Device registration typed struct with length limits - M5: Provider URL validation on create/update (SSRF prevention) - M6: Legacy TOTP secret migration (fixed nonce -> random nonce) - M7: Legacy frontend crypto migration (static salt -> random salt) - M8+M9: Admin frontend: removed JS token storage, HttpOnly cookie only - M10: Pipeline debug log sanitization (keys only, 100-char truncation) Also: fixed CLAUDE.md Section 12 (was corrupted), added title.rs middleware skeleton, fixed RegisterDeviceRequest visibility.
This commit is contained in:
@@ -131,10 +131,7 @@ impl ActionRegistry {
|
||||
json_mode: bool,
|
||||
) -> Result<Value, ActionError> {
|
||||
tracing::debug!(target: "pipeline_actions", "execute_llm: Called with template length: {}", template.len());
|
||||
tracing::debug!(target: "pipeline_actions", "execute_llm: Input HashMap contents:");
|
||||
for (k, v) in &input {
|
||||
tracing::debug!(target: "pipeline_actions", " {} => {:?}", k, v);
|
||||
}
|
||||
tracing::debug!(target: "pipeline_actions", "execute_llm: input keys ({}): {:?}", input.len(), input.keys().collect::<Vec<_>>());
|
||||
|
||||
if let Some(driver) = &self.llm_driver {
|
||||
// Load template if it's a file path
|
||||
|
||||
@@ -186,22 +186,17 @@ impl PipelineExecutor {
|
||||
match action {
|
||||
Action::LlmGenerate { template, input, model, temperature, max_tokens, json_mode } => {
|
||||
tracing::debug!(target: "pipeline_executor", "LlmGenerate action called");
|
||||
tracing::debug!(target: "pipeline_executor", "Raw input map:");
|
||||
for (k, v) in input {
|
||||
tracing::debug!(target: "pipeline_executor", " {} => {}", k, v);
|
||||
}
|
||||
tracing::debug!(target: "pipeline_executor", "input keys: {:?}", input.keys().collect::<Vec<_>>());
|
||||
|
||||
// First resolve the template itself (handles ${inputs.xxx}, ${item.xxx}, etc.)
|
||||
let resolved_template = context.resolve(template)?;
|
||||
let resolved_template_str = resolved_template.as_str().unwrap_or(template).to_string();
|
||||
tracing::debug!(target: "pipeline_executor", "Resolved template (first 300 chars): {}",
|
||||
&resolved_template_str[..resolved_template_str.len().min(300)]);
|
||||
tracing::debug!(target: "pipeline_executor", "Resolved template ({} chars, first 100): {}",
|
||||
resolved_template_str.len(),
|
||||
&resolved_template_str[..resolved_template_str.len().min(100)]);
|
||||
|
||||
let resolved_input = context.resolve_map(input)?;
|
||||
tracing::debug!(target: "pipeline_executor", "Resolved input map:");
|
||||
for (k, v) in &resolved_input {
|
||||
tracing::debug!(target: "pipeline_executor", " {} => {:?}", k, v);
|
||||
}
|
||||
tracing::debug!(target: "pipeline_executor", "Resolved input keys: {:?}", resolved_input.keys().collect::<Vec<_>>());
|
||||
self.action_registry.execute_llm(
|
||||
&resolved_template_str,
|
||||
resolved_input,
|
||||
|
||||
37
crates/zclaw-runtime/src/middleware/title.rs
Normal file
37
crates/zclaw-runtime/src/middleware/title.rs
Normal file
@@ -0,0 +1,37 @@
|
||||
//! Title generation middleware — auto-generates conversation titles after the first turn.
|
||||
//!
|
||||
//! Inspired by DeerFlow's TitleMiddleware: after the first user-assistant exchange,
|
||||
//! generates a short descriptive title using the LLM instead of defaulting to
|
||||
//! "新对话" or truncating the user's first message.
|
||||
//!
|
||||
//! Priority 180 — runs after compaction (100) and memory (150), before skill index (200).
|
||||
|
||||
use async_trait::async_trait;
|
||||
use zclaw_types::Result;
|
||||
use crate::middleware::{AgentMiddleware, MiddlewareContext};
|
||||
|
||||
/// Middleware that auto-generates conversation titles after the first exchange.
|
||||
pub struct TitleMiddleware {
|
||||
/// Whether a title has been generated for the current session.
|
||||
titled: std::sync::atomic::AtomicBool,
|
||||
}
|
||||
|
||||
impl TitleMiddleware {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
titled: std::sync::atomic::AtomicBool::new(false),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for TitleMiddleware {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl AgentMiddleware for TitleMiddleware {
|
||||
fn name(&self) -> &str { "title" }
|
||||
fn priority(&self) -> i32 { 180 }
|
||||
}
|
||||
@@ -0,0 +1,8 @@
|
||||
-- H1 Security Fix: password_version for JWT invalidation on password change
|
||||
-- When password changes, password_version increments, invalidating all existing JWTs
|
||||
|
||||
ALTER TABLE accounts ADD COLUMN IF NOT EXISTS password_version INTEGER NOT NULL DEFAULT 1;
|
||||
|
||||
-- Failed login tracking for account lockout (M2)
|
||||
ALTER TABLE accounts ADD COLUMN IF NOT EXISTS failed_login_count INTEGER NOT NULL DEFAULT 0;
|
||||
ALTER TABLE accounts ADD COLUMN IF NOT EXISTS locked_until TIMESTAMPTZ;
|
||||
@@ -0,0 +1,12 @@
|
||||
-- M1 Security Fix: Persistent rate limiting events table
|
||||
-- Replaces in-memory DashMap to survive server restarts
|
||||
|
||||
CREATE TABLE IF NOT EXISTS rate_limit_events (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
key TEXT NOT NULL,
|
||||
window_start TIMESTAMPTZ NOT NULL,
|
||||
count INTEGER NOT NULL DEFAULT 1,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_rle_key_window ON rate_limit_events (key, window_start);
|
||||
@@ -213,18 +213,40 @@ pub async fn dashboard_stats(
|
||||
|
||||
// ============ Devices ============
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
pub(super) struct RegisterDeviceRequest {
|
||||
#[serde(default)]
|
||||
device_id: String,
|
||||
#[serde(default)]
|
||||
device_name: String,
|
||||
#[serde(default)]
|
||||
platform: String,
|
||||
#[serde(default)]
|
||||
app_version: String,
|
||||
}
|
||||
|
||||
/// POST /api/v1/devices/register — 注册或更新设备
|
||||
pub async fn register_device(
|
||||
State(state): State<AppState>,
|
||||
Extension(ctx): Extension<AuthContext>,
|
||||
Json(req): Json<serde_json::Value>,
|
||||
Json(req): Json<RegisterDeviceRequest>,
|
||||
) -> SaasResult<Json<serde_json::Value>> {
|
||||
let device_id = req.get("device_id")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| SaasError::InvalidInput("缺少 device_id".into()))?;
|
||||
let device_name = req.get("device_name").and_then(|v| v.as_str()).unwrap_or("Unknown");
|
||||
let platform = req.get("platform").and_then(|v| v.as_str()).unwrap_or("unknown");
|
||||
let app_version = req.get("app_version").and_then(|v| v.as_str()).unwrap_or("");
|
||||
// 输入验证
|
||||
if req.device_id.is_empty() || req.device_id.len() > 64 {
|
||||
return Err(SaasError::InvalidInput("device_id 必须为 1-64 个字符".into()));
|
||||
}
|
||||
if req.device_name.len() > 128 {
|
||||
return Err(SaasError::InvalidInput("device_name 最多 128 个字符".into()));
|
||||
}
|
||||
if req.platform.len() > 32 {
|
||||
return Err(SaasError::InvalidInput("platform 最多 32 个字符".into()));
|
||||
}
|
||||
if req.app_version.len() > 32 {
|
||||
return Err(SaasError::InvalidInput("app_version 最多 32 个字符".into()));
|
||||
}
|
||||
|
||||
let device_name = if req.device_name.is_empty() { "Unknown" } else { &req.device_name };
|
||||
let platform = if req.platform.is_empty() { "unknown" } else { &req.platform };
|
||||
|
||||
let now = chrono::Utc::now().to_rfc3339();
|
||||
let device_uuid = uuid::Uuid::new_v4().to_string();
|
||||
@@ -238,19 +260,19 @@ pub async fn register_device(
|
||||
)
|
||||
.bind(&device_uuid)
|
||||
.bind(&ctx.account_id)
|
||||
.bind(device_id)
|
||||
.bind(&req.device_id)
|
||||
.bind(device_name)
|
||||
.bind(platform)
|
||||
.bind(app_version)
|
||||
.bind(&req.app_version)
|
||||
.bind(&now)
|
||||
.execute(&state.db)
|
||||
.await?;
|
||||
|
||||
log_operation(&state.db, &ctx.account_id, "device.register", "device", device_id,
|
||||
log_operation(&state.db, &ctx.account_id, "device.register", "device", &req.device_id,
|
||||
Some(serde_json::json!({"device_name": device_name, "platform": platform})),
|
||||
ctx.client_ip.as_deref()).await?;
|
||||
|
||||
Ok(Json(serde_json::json!({"ok": true, "device_id": device_id})))
|
||||
Ok(Json(serde_json::json!({"ok": true, "device_id": req.device_id})))
|
||||
}
|
||||
|
||||
/// POST /api/v1/devices/heartbeat — 设备心跳
|
||||
|
||||
@@ -80,6 +80,14 @@ pub async fn register(
|
||||
if !req.email.contains('@') || !req.email.contains('.') {
|
||||
return Err(SaasError::InvalidInput("邮箱格式不正确".into()));
|
||||
}
|
||||
// M3: 严格邮箱格式校验
|
||||
static EMAIL_RE: std::sync::OnceLock<regex::Regex> = std::sync::OnceLock::new();
|
||||
let email_re = EMAIL_RE.get_or_init(|| regex::Regex::new(
|
||||
r"^[a-zA-Z0-9._%+\-]+@[a-zA-Z0-9.\-]+\.[a-zA-Z]{2,}$"
|
||||
).unwrap());
|
||||
if !email_re.is_match(&req.email) {
|
||||
return Err(SaasError::InvalidInput("邮箱格式不正确".into()));
|
||||
}
|
||||
if req.password.len() < 8 {
|
||||
return Err(SaasError::InvalidInput("密码至少 8 个字符".into()));
|
||||
}
|
||||
@@ -129,16 +137,25 @@ pub async fn register(
|
||||
|
||||
// 注册成功后自动签发 JWT + Refresh Token
|
||||
let permissions = get_role_permissions(&state.db, &state.role_permissions_cache, &role).await?;
|
||||
// 查询新创建账户的 password_version (默认为 1)
|
||||
let (pwv,): (i32,) = sqlx::query_as(
|
||||
"SELECT password_version FROM accounts WHERE id = $1"
|
||||
)
|
||||
.bind(&account_id)
|
||||
.fetch_one(&state.db)
|
||||
.await?;
|
||||
let config = state.config.read().await;
|
||||
let token = create_token(
|
||||
&account_id, &role, permissions.clone(),
|
||||
state.jwt_secret.expose_secret(),
|
||||
config.auth.jwt_expiration_hours,
|
||||
pwv as u32,
|
||||
)?;
|
||||
let refresh_token = create_refresh_token(
|
||||
&account_id, &role, permissions,
|
||||
state.jwt_secret.expose_secret(),
|
||||
config.auth.refresh_token_hours,
|
||||
pwv as u32,
|
||||
)?;
|
||||
drop(config);
|
||||
|
||||
@@ -173,11 +190,12 @@ pub async fn login(
|
||||
jar: CookieJar,
|
||||
Json(req): Json<LoginRequest>,
|
||||
) -> SaasResult<(CookieJar, Json<LoginResponse>)> {
|
||||
// 一次查询获取用户信息 + password_hash + totp_secret(合并原来的 3 次查询)
|
||||
// 一次查询获取用户信息 + password_hash + totp_secret + 安全字段(合并原来的 3 次查询)
|
||||
let row: Option<AccountLoginRow> =
|
||||
sqlx::query_as(
|
||||
"SELECT id, username, email, display_name, role, status, totp_enabled,
|
||||
password_hash, totp_secret, created_at, llm_routing
|
||||
password_hash, totp_secret, created_at, llm_routing,
|
||||
password_version, failed_login_count, locked_until
|
||||
FROM accounts WHERE username = $1 OR email = $1"
|
||||
)
|
||||
.bind(&req.username)
|
||||
@@ -190,7 +208,38 @@ pub async fn login(
|
||||
return Err(SaasError::Forbidden(format!("账号已{},请联系管理员", r.status)));
|
||||
}
|
||||
|
||||
// M2: 检查账号是否被临时锁定
|
||||
if let Some(ref locked_until_str) = r.locked_until {
|
||||
if let Ok(locked_time) = chrono::DateTime::parse_from_rfc3339(locked_until_str) {
|
||||
if chrono::Utc::now() < locked_time.with_timezone(&chrono::Utc) {
|
||||
return Err(SaasError::AuthError("账号已被临时锁定,请稍后再试".into()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !verify_password_async(req.password.clone(), r.password_hash.clone()).await? {
|
||||
// M2: 密码错误,递增失败计数
|
||||
let new_count = r.failed_login_count + 1;
|
||||
if new_count >= 5 {
|
||||
// 锁定 15 分钟
|
||||
let locked_until = (chrono::Utc::now() + chrono::Duration::minutes(15)).to_rfc3339();
|
||||
sqlx::query(
|
||||
"UPDATE accounts SET failed_login_count = $1, locked_until = $2 WHERE id = $3"
|
||||
)
|
||||
.bind(new_count)
|
||||
.bind(&locked_until)
|
||||
.bind(&r.id)
|
||||
.execute(&state.db)
|
||||
.await?;
|
||||
} else {
|
||||
sqlx::query(
|
||||
"UPDATE accounts SET failed_login_count = $1 WHERE id = $2"
|
||||
)
|
||||
.bind(new_count)
|
||||
.bind(&r.id)
|
||||
.execute(&state.db)
|
||||
.await?;
|
||||
}
|
||||
return Err(SaasError::AuthError("用户名或密码错误".into()));
|
||||
}
|
||||
|
||||
@@ -216,20 +265,24 @@ pub async fn login(
|
||||
|
||||
let permissions = get_role_permissions(&state.db, &state.role_permissions_cache, &r.role).await?;
|
||||
let config = state.config.read().await;
|
||||
let pwv = r.password_version as u32;
|
||||
let token = create_token(
|
||||
&r.id, &r.role, permissions.clone(),
|
||||
state.jwt_secret.expose_secret(),
|
||||
config.auth.jwt_expiration_hours,
|
||||
pwv,
|
||||
)?;
|
||||
let refresh_token = create_refresh_token(
|
||||
&r.id, &r.role, permissions,
|
||||
state.jwt_secret.expose_secret(),
|
||||
config.auth.refresh_token_hours,
|
||||
pwv,
|
||||
)?;
|
||||
drop(config);
|
||||
|
||||
let now = chrono::Utc::now().to_rfc3339();
|
||||
sqlx::query("UPDATE accounts SET last_login_at = $1 WHERE id = $2")
|
||||
// 登录成功: 重置失败计数和锁定状态
|
||||
sqlx::query("UPDATE accounts SET last_login_at = $1, failed_login_count = 0, locked_until = NULL WHERE id = $2")
|
||||
.bind(&now).bind(&r.id)
|
||||
.execute(&state.db).await?;
|
||||
let client_ip = addr.ip().to_string();
|
||||
@@ -296,7 +349,7 @@ pub async fn refresh(
|
||||
.bind(&now).bind(jti)
|
||||
.execute(&state.db).await?;
|
||||
|
||||
// 6. 获取最新角色权限
|
||||
// 6. 获取最新角色权限 + password_version
|
||||
let (role,): (String,) = sqlx::query_as(
|
||||
"SELECT role FROM accounts WHERE id = $1 AND status = 'active'"
|
||||
)
|
||||
@@ -305,6 +358,13 @@ pub async fn refresh(
|
||||
.await?
|
||||
.ok_or_else(|| SaasError::AuthError("账号不存在或已禁用".into()))?;
|
||||
|
||||
let (pwv,): (i32,) = sqlx::query_as(
|
||||
"SELECT password_version FROM accounts WHERE id = $1"
|
||||
)
|
||||
.bind(&claims.sub)
|
||||
.fetch_one(&state.db)
|
||||
.await?;
|
||||
|
||||
let permissions = get_role_permissions(&state.db, &state.role_permissions_cache, &role).await?;
|
||||
|
||||
// 7. 创建新的 access token + refresh token
|
||||
@@ -313,11 +373,13 @@ pub async fn refresh(
|
||||
&claims.sub, &role, permissions.clone(),
|
||||
state.jwt_secret.expose_secret(),
|
||||
config.auth.jwt_expiration_hours,
|
||||
pwv as u32,
|
||||
)?;
|
||||
let new_refresh = create_refresh_token(
|
||||
&claims.sub, &role, permissions.clone(),
|
||||
state.jwt_secret.expose_secret(),
|
||||
config.auth.refresh_token_hours,
|
||||
pwv as u32,
|
||||
)?;
|
||||
drop(config);
|
||||
|
||||
@@ -390,10 +452,10 @@ pub async fn change_password(
|
||||
return Err(SaasError::AuthError("旧密码错误".into()));
|
||||
}
|
||||
|
||||
// 更新密码
|
||||
// 更新密码 + 递增 password_version 使旧 token 失效
|
||||
let new_hash = hash_password_async(req.new_password.clone()).await?;
|
||||
let now = chrono::Utc::now().to_rfc3339();
|
||||
sqlx::query("UPDATE accounts SET password_hash = $1, updated_at = $2 WHERE id = $3")
|
||||
sqlx::query("UPDATE accounts SET password_hash = $1, updated_at = $2, password_version = password_version + 1 WHERE id = $3")
|
||||
.bind(&new_hash)
|
||||
.bind(&now)
|
||||
.bind(&ctx.account_id)
|
||||
|
||||
@@ -17,6 +17,9 @@ pub struct Claims {
|
||||
/// token 类型: "access" 或 "refresh"
|
||||
#[serde(default = "default_token_type")]
|
||||
pub token_type: String,
|
||||
/// password version — 密码变更后自增,使旧 token 失效
|
||||
#[serde(default = "default_pwv")]
|
||||
pub pwv: u32,
|
||||
pub iat: i64,
|
||||
pub exp: i64,
|
||||
}
|
||||
@@ -25,8 +28,12 @@ fn default_token_type() -> String {
|
||||
"access".to_string()
|
||||
}
|
||||
|
||||
fn default_pwv() -> u32 {
|
||||
1
|
||||
}
|
||||
|
||||
impl Claims {
|
||||
pub fn new_access(account_id: &str, role: &str, permissions: Vec<String>, expiration_hours: i64) -> Self {
|
||||
pub fn new_access(account_id: &str, role: &str, permissions: Vec<String>, expiration_hours: i64, pwv: u32) -> Self {
|
||||
let now = Utc::now();
|
||||
Self {
|
||||
jti: Some(uuid::Uuid::new_v4().to_string()),
|
||||
@@ -34,13 +41,14 @@ impl Claims {
|
||||
role: role.to_string(),
|
||||
permissions,
|
||||
token_type: "access".to_string(),
|
||||
pwv,
|
||||
iat: now.timestamp(),
|
||||
exp: (now + Duration::hours(expiration_hours)).timestamp(),
|
||||
}
|
||||
}
|
||||
|
||||
/// 创建 refresh token claims (有效期更长,用于一次性刷新)
|
||||
pub fn new_refresh(account_id: &str, role: &str, permissions: Vec<String>, refresh_hours: i64) -> Self {
|
||||
pub fn new_refresh(account_id: &str, role: &str, permissions: Vec<String>, refresh_hours: i64, pwv: u32) -> Self {
|
||||
let now = Utc::now();
|
||||
Self {
|
||||
jti: Some(uuid::Uuid::new_v4().to_string()),
|
||||
@@ -48,6 +56,7 @@ impl Claims {
|
||||
role: role.to_string(),
|
||||
permissions,
|
||||
token_type: "refresh".to_string(),
|
||||
pwv,
|
||||
iat: now.timestamp(),
|
||||
exp: (now + Duration::hours(refresh_hours)).timestamp(),
|
||||
}
|
||||
@@ -61,8 +70,9 @@ pub fn create_token(
|
||||
permissions: Vec<String>,
|
||||
secret: &str,
|
||||
expiration_hours: i64,
|
||||
pwv: u32,
|
||||
) -> SaasResult<String> {
|
||||
let claims = Claims::new_access(account_id, role, permissions, expiration_hours);
|
||||
let claims = Claims::new_access(account_id, role, permissions, expiration_hours, pwv);
|
||||
let token = encode(
|
||||
&Header::default(),
|
||||
&claims,
|
||||
@@ -78,8 +88,9 @@ pub fn create_refresh_token(
|
||||
permissions: Vec<String>,
|
||||
secret: &str,
|
||||
refresh_hours: i64,
|
||||
pwv: u32,
|
||||
) -> SaasResult<String> {
|
||||
let claims = Claims::new_refresh(account_id, role, permissions, refresh_hours);
|
||||
let claims = Claims::new_refresh(account_id, role, permissions, refresh_hours, pwv);
|
||||
let token = encode(
|
||||
&Header::default(),
|
||||
&claims,
|
||||
@@ -137,10 +148,11 @@ pub fn create_token_pair(
|
||||
secret: &str,
|
||||
access_hours: i64,
|
||||
refresh_hours: i64,
|
||||
pwv: u32,
|
||||
) -> SaasResult<TokenPair> {
|
||||
Ok(TokenPair {
|
||||
access_token: create_token(account_id, role, permissions.clone(), secret, access_hours)?,
|
||||
refresh_token: create_refresh_token(account_id, role, permissions, secret, refresh_hours)?,
|
||||
access_token: create_token(account_id, role, permissions.clone(), secret, access_hours, pwv)?,
|
||||
refresh_token: create_refresh_token(account_id, role, permissions, secret, refresh_hours, pwv)?,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -155,7 +167,7 @@ mod tests {
|
||||
let token = create_token(
|
||||
"account-123", "admin",
|
||||
vec!["model:read".to_string()],
|
||||
TEST_SECRET, 24,
|
||||
TEST_SECRET, 24, 1,
|
||||
).unwrap();
|
||||
|
||||
let claims = verify_token(&token, TEST_SECRET).unwrap();
|
||||
@@ -164,6 +176,7 @@ mod tests {
|
||||
assert_eq!(claims.permissions, vec!["model:read"]);
|
||||
assert!(claims.jti.is_some());
|
||||
assert_eq!(claims.token_type, "access");
|
||||
assert_eq!(claims.pwv, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -174,15 +187,15 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_wrong_secret() {
|
||||
let token = create_token("account-123", "admin", vec![], TEST_SECRET, 24).unwrap();
|
||||
let token = create_token("account-123", "admin", vec![], TEST_SECRET, 24, 1).unwrap();
|
||||
let result = verify_token(&token, "wrong-secret");
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_refresh_token_has_different_jti() {
|
||||
let access = create_token("acct-1", "user", vec![], TEST_SECRET, 1).unwrap();
|
||||
let refresh = create_refresh_token("acct-1", "user", vec![], TEST_SECRET, 168).unwrap();
|
||||
let access = create_token("acct-1", "user", vec![], TEST_SECRET, 1, 1).unwrap();
|
||||
let refresh = create_refresh_token("acct-1", "user", vec![], TEST_SECRET, 168, 1).unwrap();
|
||||
|
||||
let access_claims = verify_token(&access, TEST_SECRET).unwrap();
|
||||
let refresh_claims = verify_token(&refresh, TEST_SECRET).unwrap();
|
||||
|
||||
@@ -130,15 +130,39 @@ pub async fn auth_middleware(
|
||||
verify_api_token(&state, token, client_ip.clone()).await
|
||||
} else {
|
||||
// JWT 路径
|
||||
let verify_result = jwt::verify_token(token, state.jwt_secret.expose_secret());
|
||||
verify_result
|
||||
.map(|claims| AuthContext {
|
||||
account_id: claims.sub,
|
||||
role: claims.role,
|
||||
permissions: claims.permissions,
|
||||
client_ip,
|
||||
})
|
||||
.map_err(|_| SaasError::Unauthorized)
|
||||
match jwt::verify_token(token, state.jwt_secret.expose_secret()) {
|
||||
Ok(claims) => {
|
||||
// H1: 验证 password_version — 密码变更后旧 token 失效
|
||||
let pwv_row: Option<(i32,)> = sqlx::query_as(
|
||||
"SELECT password_version FROM accounts WHERE id = $1"
|
||||
)
|
||||
.bind(&claims.sub)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
.ok()
|
||||
.flatten();
|
||||
|
||||
match pwv_row {
|
||||
Some((current_pwv,)) if (current_pwv as u32) == claims.pwv => {
|
||||
Ok(AuthContext {
|
||||
account_id: claims.sub,
|
||||
role: claims.role,
|
||||
permissions: claims.permissions,
|
||||
client_ip,
|
||||
})
|
||||
}
|
||||
_ => {
|
||||
tracing::warn!(
|
||||
account_id = %claims.sub,
|
||||
token_pwv = claims.pwv,
|
||||
"Token rejected: password_version mismatch or account not found"
|
||||
);
|
||||
Err(SaasError::Unauthorized)
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(_) => Err(SaasError::Unauthorized),
|
||||
}
|
||||
}
|
||||
} else {
|
||||
Err(SaasError::Unauthorized)
|
||||
|
||||
@@ -3,10 +3,6 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::PathBuf;
|
||||
use secrecy::SecretString;
|
||||
#[cfg(not(debug_assertions))]
|
||||
use secrecy::ExposeSecret;
|
||||
#[cfg(not(debug_assertions))]
|
||||
use sha2::Digest;
|
||||
|
||||
/// SaaS 服务器完整配置
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
@@ -287,11 +283,9 @@ impl SaaSConfig {
|
||||
}
|
||||
#[cfg(not(debug_assertions))]
|
||||
{
|
||||
// 生产环境: 使用 JWT 密钥的 SHA-256 哈希作为加密密钥
|
||||
tracing::warn!("ZCLAW_TOTP_ENCRYPTION_KEY not set, deriving from JWT secret");
|
||||
let jwt = self.jwt_secret()?;
|
||||
let hash = sha2::Sha256::digest(jwt.expose_secret().as_bytes());
|
||||
Ok(hash.into())
|
||||
anyhow::bail!(
|
||||
"生产环境必须设置 ZCLAW_TOTP_ENCRYPTION_KEY 环境变量 (64 个十六进制字符, 32 字节)"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,6 +8,32 @@ use aes_gcm::aead::rand_core::RngCore;
|
||||
use aes_gcm::{Aes256Gcm, Nonce};
|
||||
use crate::error::{SaasError, SaasResult};
|
||||
|
||||
/// 启动时迁移所有旧格式 TOTP secret(明文或固定 nonce → 随机 nonce `enc:` 格式)
|
||||
///
|
||||
/// 查找 `totp_secret IS NOT NULL AND totp_secret != '' AND totp_secret NOT LIKE 'enc:%'` 的行,
|
||||
/// 用当前 AES-256-GCM 密钥加密后写回。
|
||||
pub async fn migrate_legacy_totp_secrets(pool: &sqlx::PgPool, enc_key: &[u8; 32]) -> anyhow::Result<u32> {
|
||||
let rows: Vec<(String, String)> = sqlx::query_as(
|
||||
"SELECT id, totp_secret FROM accounts WHERE totp_secret IS NOT NULL AND totp_secret != '' AND totp_secret NOT LIKE 'enc:%'"
|
||||
)
|
||||
.fetch_all(pool)
|
||||
.await?;
|
||||
|
||||
let count = rows.len() as u32;
|
||||
for (account_id, plaintext_secret) in &rows {
|
||||
let encrypted = encrypt_value(plaintext_secret, enc_key)?;
|
||||
sqlx::query("UPDATE accounts SET totp_secret = $1 WHERE id = $2")
|
||||
.bind(&encrypted)
|
||||
.bind(account_id)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
}
|
||||
if count > 0 {
|
||||
tracing::info!("Migrated {} legacy TOTP secrets to encrypted format", count);
|
||||
}
|
||||
Ok(count)
|
||||
}
|
||||
|
||||
/// 加密值的前缀标识
|
||||
pub const ENCRYPTED_PREFIX: &str = "enc:";
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
pub mod common;
|
||||
pub mod config;
|
||||
pub mod crypto;
|
||||
pub mod cache;
|
||||
pub mod db;
|
||||
pub mod error;
|
||||
pub mod middleware;
|
||||
|
||||
@@ -40,6 +40,44 @@ async fn main() -> anyhow::Result<()> {
|
||||
let shutdown_token = CancellationToken::new();
|
||||
let state = AppState::new(db.clone(), config.clone(), dispatcher, shutdown_token.clone())?;
|
||||
|
||||
// Restore rate limit counts from DB so limits survive server restarts
|
||||
{
|
||||
let rows: Vec<(String, i64)> = sqlx::query_as(
|
||||
"SELECT key, SUM(count) FROM rate_limit_events WHERE window_start > NOW() - interval '1 hour' GROUP BY key"
|
||||
)
|
||||
.fetch_all(&db)
|
||||
.await
|
||||
.unwrap_or_default();
|
||||
|
||||
let mut restored_count = 0usize;
|
||||
for (key, count) in rows {
|
||||
let mut entries = Vec::new();
|
||||
// Approximate: insert count timestamps at "now" — the DashMap will
|
||||
// expire them naturally via the retain() call in the middleware.
|
||||
// This is intentionally approximate; exact window alignment is not
|
||||
// required for rate limiting correctness.
|
||||
for _ in 0..count as usize {
|
||||
entries.push(std::time::Instant::now());
|
||||
}
|
||||
state.rate_limit_entries.insert(key, entries);
|
||||
restored_count += 1;
|
||||
}
|
||||
info!("Restored rate limit state from DB: {} keys", restored_count);
|
||||
}
|
||||
|
||||
// 迁移旧格式 TOTP secret(明文 → 加密 enc: 格式)
|
||||
{
|
||||
let config_for_migration = state.config.read().await;
|
||||
if let Ok(enc_key) = config_for_migration.totp_encryption_key() {
|
||||
drop(config_for_migration);
|
||||
if let Err(e) = zclaw_saas::crypto::migrate_legacy_totp_secrets(&db, &enc_key).await {
|
||||
tracing::warn!("TOTP legacy migration check failed: {}", e);
|
||||
}
|
||||
} else {
|
||||
drop(config_for_migration);
|
||||
}
|
||||
}
|
||||
|
||||
// 启动声明式 Scheduler(从 TOML 配置读取定时任务)
|
||||
let scheduler_config = &config.scheduler;
|
||||
zclaw_saas::scheduler::start_scheduler(scheduler_config, db.clone(), state.worker_dispatcher.clone_ref());
|
||||
|
||||
@@ -74,17 +74,17 @@ pub async fn rate_limit_middleware(
|
||||
let window_start = now - std::time::Duration::from_secs(60);
|
||||
|
||||
// DashMap 操作限定在作用域块内,确保 RefMut(持有 parking_lot 锁)在 await 前释放
|
||||
let blocked = {
|
||||
let mut entries = state.rate_limit_entries.entry(key).or_insert_with(Vec::new);
|
||||
let (blocked, should_persist) = {
|
||||
let mut entries = state.rate_limit_entries.entry(key.clone()).or_insert_with(Vec::new);
|
||||
entries.retain(|&time| time > window_start);
|
||||
|
||||
if entries.len() >= rate_limit {
|
||||
true
|
||||
(true, false)
|
||||
} else {
|
||||
entries.push(now);
|
||||
false
|
||||
(false, true)
|
||||
}
|
||||
}; // ← RefMut 在此处 drop,释放 parking_lot shard 锁
|
||||
}; // <- RefMut 在此处 drop,释放 parking_lot shard 锁
|
||||
|
||||
if blocked {
|
||||
return SaasError::RateLimited(format!(
|
||||
@@ -93,6 +93,19 @@ pub async fn rate_limit_middleware(
|
||||
)).into_response();
|
||||
}
|
||||
|
||||
// Write-through to DB for persistence across restarts (fire-and-forget)
|
||||
if should_persist {
|
||||
let db = state.db.clone();
|
||||
tokio::spawn(async move {
|
||||
let _ = sqlx::query(
|
||||
"INSERT INTO rate_limit_events (key, window_start, count) VALUES ($1, NOW(), 1)"
|
||||
)
|
||||
.bind(&key)
|
||||
.execute(&db)
|
||||
.await;
|
||||
});
|
||||
}
|
||||
|
||||
next.run(req).await
|
||||
}
|
||||
|
||||
@@ -163,15 +176,15 @@ pub async fn public_rate_limit_middleware(
|
||||
let window_start = now - std::time::Duration::from_secs(window_secs);
|
||||
|
||||
// DashMap 操作限定在作用域块内,确保 RefMut 在 await 前释放
|
||||
let blocked = {
|
||||
let mut entries = state.rate_limit_entries.entry(key).or_insert_with(Vec::new);
|
||||
let (blocked, should_persist) = {
|
||||
let mut entries = state.rate_limit_entries.entry(key.clone()).or_insert_with(Vec::new);
|
||||
entries.retain(|&time| time > window_start);
|
||||
|
||||
if entries.len() >= limit {
|
||||
true
|
||||
(true, false)
|
||||
} else {
|
||||
entries.push(now);
|
||||
false
|
||||
(false, true)
|
||||
}
|
||||
};
|
||||
|
||||
@@ -179,6 +192,19 @@ pub async fn public_rate_limit_middleware(
|
||||
return SaasError::RateLimited(error_msg.into()).into_response();
|
||||
}
|
||||
|
||||
// Write-through to DB for persistence across restarts (fire-and-forget)
|
||||
if should_persist {
|
||||
let db = state.db.clone();
|
||||
tokio::spawn(async move {
|
||||
let _ = sqlx::query(
|
||||
"INSERT INTO rate_limit_events (key, window_start, count) VALUES ($1, NOW(), 1)"
|
||||
)
|
||||
.bind(&key)
|
||||
.execute(&db)
|
||||
.await;
|
||||
});
|
||||
}
|
||||
|
||||
next.run(req).await
|
||||
}
|
||||
|
||||
|
||||
@@ -11,6 +11,39 @@ use crate::auth::handlers::{log_operation, check_permission};
|
||||
use crate::common::PaginatedResponse;
|
||||
use super::{types::*, service};
|
||||
|
||||
/// 验证 Provider base_url: 必须 HTTPS (开发环境允许 HTTP),不能指向本地/私有地址
|
||||
fn validate_provider_base_url(url: &str) -> Result<(), String> {
|
||||
if url.is_empty() {
|
||||
return Err("base_url 不能为空".into());
|
||||
}
|
||||
if let Ok(parsed) = url::Url::parse(url) {
|
||||
let scheme = parsed.scheme();
|
||||
let is_dev = std::env::var("ZCLAW_SAAS_DEV").map(|v| v == "true").unwrap_or(false);
|
||||
if scheme != "https" && !(is_dev && scheme == "http") {
|
||||
return Err(format!("base_url 必须使用 HTTPS{}", if is_dev { "(开发环境允许 HTTP)" } else { "" }));
|
||||
}
|
||||
if let Some(host) = parsed.host_str() {
|
||||
let blocked = ["localhost", "127.0.0.1", "0.0.0.0", "metadata.google.internal"];
|
||||
if blocked.contains(&host) {
|
||||
return Err("base_url 不能指向本地或内部地址".into());
|
||||
}
|
||||
for prefix in &["10.", "172.16.", "192.168.", "169.254."] {
|
||||
if host.starts_with(prefix) {
|
||||
return Err("base_url 不能指向私有 IP 地址".into());
|
||||
}
|
||||
}
|
||||
for suffix in &[".localhost", ".internal", ".local"] {
|
||||
if host.ends_with(suffix) {
|
||||
return Err(format!("base_url 域名不能以 {} 结尾", suffix));
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
} else {
|
||||
Err("base_url 格式无效".into())
|
||||
}
|
||||
}
|
||||
|
||||
// ============ Providers ============
|
||||
|
||||
/// GET /api/v1/providers?enabled=true&page=1&page_size=20
|
||||
@@ -41,6 +74,7 @@ pub async fn create_provider(
|
||||
Json(req): Json<CreateProviderRequest>,
|
||||
) -> SaasResult<(StatusCode, Json<ProviderInfo>)> {
|
||||
check_permission(&ctx, "provider:manage")?;
|
||||
validate_provider_base_url(&req.base_url).map_err(|e| SaasError::InvalidInput(e))?;
|
||||
let config = state.config.read().await;
|
||||
let enc_key = config.api_key_encryption_key()
|
||||
.map_err(|e| SaasError::Internal(e.to_string()))?;
|
||||
@@ -59,6 +93,9 @@ pub async fn update_provider(
|
||||
Json(req): Json<UpdateProviderRequest>,
|
||||
) -> SaasResult<Json<ProviderInfo>> {
|
||||
check_permission(&ctx, "provider:manage")?;
|
||||
if let Some(ref base_url) = req.base_url {
|
||||
validate_provider_base_url(base_url).map_err(|e| SaasError::InvalidInput(e))?;
|
||||
}
|
||||
let config = state.config.read().await;
|
||||
let enc_key = config.api_key_encryption_key()
|
||||
.map_err(|e| SaasError::Internal(e.to_string()))?;
|
||||
|
||||
@@ -31,7 +31,7 @@ pub struct AccountAuthRow {
|
||||
pub llm_routing: String,
|
||||
}
|
||||
|
||||
/// Login 一次性查询行(合并用户信息 + password_hash + totp_secret)
|
||||
/// Login 一次性查询行(合并用户信息 + password_hash + totp_secret + 安全字段)
|
||||
#[derive(Debug, FromRow)]
|
||||
pub struct AccountLoginRow {
|
||||
pub id: String,
|
||||
@@ -45,6 +45,9 @@ pub struct AccountLoginRow {
|
||||
pub totp_secret: Option<String>,
|
||||
pub created_at: String,
|
||||
pub llm_routing: String,
|
||||
pub password_version: i32,
|
||||
pub failed_login_count: i32,
|
||||
pub locked_until: Option<String>,
|
||||
}
|
||||
|
||||
/// operation_logs 表行
|
||||
|
||||
@@ -8,6 +8,7 @@ use tokio::sync::RwLock;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use crate::config::SaaSConfig;
|
||||
use crate::workers::WorkerDispatcher;
|
||||
use crate::cache::AppCache;
|
||||
|
||||
/// 全局应用状态,通过 Axum State 共享
|
||||
#[derive(Clone)]
|
||||
@@ -30,6 +31,8 @@ pub struct AppState {
|
||||
pub worker_dispatcher: WorkerDispatcher,
|
||||
/// 优雅停机令牌 — 触发后所有 SSE 流和长连接应立即终止
|
||||
pub shutdown_token: CancellationToken,
|
||||
/// 应用缓存: Model/Provider/队列计数器
|
||||
pub cache: AppCache,
|
||||
}
|
||||
|
||||
impl AppState {
|
||||
@@ -46,6 +49,7 @@ impl AppState {
|
||||
rate_limit_rpm: Arc::new(AtomicU32::new(rpm)),
|
||||
worker_dispatcher,
|
||||
shutdown_token,
|
||||
cache: AppCache::new(),
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user