diff --git a/crates/zclaw-saas/src/auth/handlers.rs b/crates/zclaw-saas/src/auth/handlers.rs index 1ec448d..b8a0af5 100644 --- a/crates/zclaw-saas/src/auth/handlers.rs +++ b/crates/zclaw-saas/src/auth/handlers.rs @@ -36,7 +36,7 @@ pub async fn register( let password_hash = hash_password(&req.password)?; let account_id = uuid::Uuid::new_v4().to_string(); - let role = req.role.unwrap_or_else(|| "user".into()); + let role = "user".to_string(); // 注册固定为普通用户,角色由管理员分配 let display_name = req.display_name.unwrap_or_default(); let now = chrono::Utc::now().to_rfc3339(); diff --git a/crates/zclaw-saas/src/auth/types.rs b/crates/zclaw-saas/src/auth/types.rs index babc48e..d50cdb6 100644 --- a/crates/zclaw-saas/src/auth/types.rs +++ b/crates/zclaw-saas/src/auth/types.rs @@ -24,7 +24,6 @@ pub struct RegisterRequest { pub email: String, pub password: String, pub display_name: Option, - pub role: Option, } /// 公开账号信息 (无敏感数据) diff --git a/crates/zclaw-saas/src/main.rs b/crates/zclaw-saas/src/main.rs index a314f83..4ed16b3 100644 --- a/crates/zclaw-saas/src/main.rs +++ b/crates/zclaw-saas/src/main.rs @@ -37,11 +37,19 @@ fn build_router(state: AppState) -> axum::Router { use axum::http::HeaderValue; let cors = { let config = state.config.blocking_read(); + let is_dev = std::env::var("ZCLAW_SAAS_DEV") + .map(|v| v == "true" || v == "1") + .unwrap_or(false); if config.server.cors_origins.is_empty() { - CorsLayer::new() - .allow_origin(Any) - .allow_methods(Any) - .allow_headers(Any) + if is_dev { + CorsLayer::new() + .allow_origin(Any) + .allow_methods(Any) + .allow_headers(Any) + } else { + tracing::error!("生产环境必须配置 server.cors_origins,不能使用 allow_origin(Any)"); + panic!("生产环境必须配置 server.cors_origins 白名单。开发环境可设置 ZCLAW_SAAS_DEV=true 绕过。"); + } } else { let origins: Vec = config.server.cors_origins.iter() .filter_map(|o: &String| o.parse::().ok()) diff --git a/crates/zclaw-saas/src/relay/service.rs b/crates/zclaw-saas/src/relay/service.rs index 5a796c1..bcac77a 100644 --- a/crates/zclaw-saas/src/relay/service.rs +++ b/crates/zclaw-saas/src/relay/service.rs @@ -228,12 +228,65 @@ fn validate_provider_url(url: &str) -> SaasResult<()> { Some(h) => h, None => return Err(SaasError::InvalidInput("provider URL 缺少 host".into())), }; - let blocked = ["127.0.0.1", "0.0.0.0", "localhost", "::1", "169.254.169.254", "metadata.google.internal"]; - for blocked_host in &blocked { - if host == *blocked_host || host.ends_with(&format!(".{}", blocked_host)) { + + // 精确匹配的阻止列表 + let blocked_exact = [ + "127.0.0.1", "0.0.0.0", "localhost", "::1", "::ffff:127.0.0.1", + "0:0:0:0:0:ffff:7f00:1", "169.254.169.254", "metadata.google.internal", + "10.0.0.1", "172.16.0.1", "192.168.0.1", + ]; + if blocked_exact.contains(&host) { + return Err(SaasError::InvalidInput(format!("provider URL 指向禁止的内网地址: {}", host))); + } + + // 后缀匹配 (阻止子域名) + let blocked_suffixes = ["localhost", "internal", "local", "localhost.localdomain"]; + for suffix in &blocked_suffixes { + if host.ends_with(&format!(".{}", suffix)) { return Err(SaasError::InvalidInput(format!("provider URL 指向禁止的内网地址: {}", host))); } } + // 阻止 IPv4 私有网段 (通过解析 IP) + if let Ok(ip) = host.parse::() { + if is_private_ip(&ip) { + return Err(SaasError::InvalidInput(format!("provider URL 指向私有 IP 地址: {}", host))); + } + } + + // 阻止纯数字 host (可能是十进制 IP 表示法,如 2130706433 = 127.0.0.1) + if host.parse::().is_ok() { + return Err(SaasError::InvalidInput(format!("provider URL 使用了不允许的 IP 格式: {}", host))); + } + Ok(()) } + +/// 检查 IP 是否属于私有/内网地址范围 +fn is_private_ip(ip: &std::net::IpAddr) -> bool { + match ip { + std::net::IpAddr::V4(v4) => { + let octets = v4.octets(); + // 10.0.0.0/8 + octets[0] == 10 + // 172.16.0.0/12 + || (octets[0] == 172 && octets[1] >= 16 && octets[1] <= 31) + // 192.168.0.0/16 + || (octets[0] == 192 && octets[1] == 168) + // 127.0.0.0/8 (loopback) + || octets[0] == 127 + // 169.254.0.0/16 (link-local) + || (octets[0] == 169 && octets[1] == 254) + // 0.0.0.0/8 + || octets[0] == 0 + } + std::net::IpAddr::V6(v6) => { + // ::1 (loopback) + v6.is_loopback() + // ::ffff:x.x.x.x (IPv6-mapped IPv4) + || v6.to_ipv4_mapped().map_or(false, |v4| is_private_ip(&std::net::IpAddr::V4(v4))) + // fe80::/10 (link-local) + || (v6.segments()[0] & 0xffc0) == 0xfe80 + } + } +}