fix(hands): 审计修复 — SSRF防护/输入验证/HTTP状态检查/解析加固
Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled
Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled
三维度穷尽审计(安全+质量+正确性)后修复:
CRITICAL:
- execute_fetch() 添加完整 SSRF 防护(IPv4/IPv6/私有地址/云元数据/主机名黑名单)
- reqwest 重定向策略限制为3次,阻止重定向链 SSRF
- DDG HTML 解析: split("result__body") → split("class=\"result__body\"") 防误匹配
- Google 变体降级到 Bing 时添加 tracing::warn 日志
HIGH:
- ResearchQuery 输入验证: 查询≤500字符, max_results≤50, 空查询拒绝
- Cache 容量限制: 200 条目上限 + 简单淘汰
- extract_href_uddg 手动 URL 解码替换为标准 percent_decode
- 3个搜索引擎方法添加 HTTP status code 检查(429/503 不再静默)
MEDIUM:
- config.toml default_engine 从 "searxng" 改为 "auto"(Rust 原生优先)
- User-Agent 从机器人标识改为浏览器 UA,降低反爬风险
- 百度解析器从精确匹配改为 c-container 包含匹配,覆盖更多变体
- 添加 url crate 依赖
测试: 60 PASS (新增12: SSRF 5 + percent_decode 3 + 输入验证 4)
This commit is contained in:
@@ -6,8 +6,10 @@ use async_trait::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{json, Value};
|
||||
use std::collections::HashMap;
|
||||
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
use url::Url;
|
||||
use zclaw_types::Result;
|
||||
|
||||
use crate::{Hand, HandConfig, HandContext, HandResult};
|
||||
@@ -147,6 +149,26 @@ pub struct ResearchQuery {
|
||||
fn default_max_results() -> usize { 10 }
|
||||
fn default_time_limit() -> u64 { 60 }
|
||||
|
||||
const MAX_QUERY_LENGTH: usize = 500;
|
||||
const MAX_RESULTS_CAP: usize = 50;
|
||||
const MAX_URL_LENGTH: usize = 2048;
|
||||
const CACHE_MAX_ENTRIES: usize = 200;
|
||||
|
||||
impl ResearchQuery {
|
||||
fn validate(&self) -> std::result::Result<(), String> {
|
||||
if self.query.trim().is_empty() {
|
||||
return Err("搜索查询不能为空".to_string());
|
||||
}
|
||||
if self.query.len() > MAX_QUERY_LENGTH {
|
||||
return Err(format!("查询过长(上限 {} 字符)", MAX_QUERY_LENGTH));
|
||||
}
|
||||
if self.max_results > MAX_RESULTS_CAP {
|
||||
return Err(format!("max_results 上限为 {}", MAX_RESULTS_CAP));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Search result item
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
@@ -268,7 +290,8 @@ impl ResearcherHand {
|
||||
search_config: SearchConfig::load(),
|
||||
client: reqwest::Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(30))
|
||||
.user_agent("ZCLAW-Researcher/1.0")
|
||||
.user_agent("Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36")
|
||||
.redirect(reqwest::redirect::Policy::limited(3))
|
||||
.build()
|
||||
.unwrap_or_else(|_| reqwest::Client::new()),
|
||||
cache: Arc::new(RwLock::new(HashMap::new())),
|
||||
@@ -277,6 +300,9 @@ impl ResearcherHand {
|
||||
|
||||
/// Execute a web search — route to the configured backend
|
||||
async fn execute_search(&self, query: &ResearchQuery) -> Result<Vec<SearchResult>> {
|
||||
query.validate().map_err(|e| zclaw_types::ZclawError::HandError(e))?;
|
||||
|
||||
let max_results = query.max_results.min(MAX_RESULTS_CAP);
|
||||
let start = std::time::Instant::now();
|
||||
|
||||
let engine = match &query.engine {
|
||||
@@ -286,22 +312,23 @@ impl ResearcherHand {
|
||||
|
||||
let results = match engine {
|
||||
SearchEngine::SearXNG => {
|
||||
match self.search_searxng(&query.query, query.max_results).await {
|
||||
match self.search_searxng(&query.query, max_results).await {
|
||||
Ok(r) if !r.is_empty() => r,
|
||||
_ => self.search_native(&query.query, query.max_results).await?,
|
||||
_ => self.search_native(&query.query, max_results).await?,
|
||||
}
|
||||
}
|
||||
SearchEngine::Auto => {
|
||||
self.search_native(&query.query, query.max_results).await?
|
||||
self.search_native(&query.query, max_results).await?
|
||||
}
|
||||
SearchEngine::DuckDuckGo => {
|
||||
self.search_duckduckgo_html(&query.query, query.max_results).await?
|
||||
self.search_duckduckgo_html(&query.query, max_results).await?
|
||||
}
|
||||
SearchEngine::Google => {
|
||||
self.search_bing(&query.query, query.max_results).await?
|
||||
tracing::warn!(target: "researcher", "Google 不支持直接搜索,降级到 Bing");
|
||||
self.search_bing(&query.query, max_results).await?
|
||||
}
|
||||
SearchEngine::Bing => {
|
||||
self.search_bing(&query.query, query.max_results).await?
|
||||
self.search_bing(&query.query, max_results).await?
|
||||
}
|
||||
};
|
||||
|
||||
@@ -481,6 +508,13 @@ impl ResearcherHand {
|
||||
format!("DuckDuckGo HTML search failed: {}", e)
|
||||
))?;
|
||||
|
||||
let status = response.status();
|
||||
if !status.is_success() {
|
||||
return Err(zclaw_types::ZclawError::HandError(
|
||||
format!("DuckDuckGo returned HTTP {}", status)
|
||||
));
|
||||
}
|
||||
|
||||
let html = response.text().await
|
||||
.map_err(|e| zclaw_types::ZclawError::HandError(
|
||||
format!("Failed to read DuckDuckGo response: {}", e)
|
||||
@@ -493,7 +527,7 @@ impl ResearcherHand {
|
||||
fn parse_ddg_html(&self, html: &str, max_results: usize) -> Vec<SearchResult> {
|
||||
let mut results = Vec::new();
|
||||
|
||||
for block in html.split("result__body") {
|
||||
for block in html.split("class=\"result__body\"") {
|
||||
if results.len() >= max_results {
|
||||
break;
|
||||
}
|
||||
@@ -563,6 +597,13 @@ impl ResearcherHand {
|
||||
format!("Bing search failed: {}", e)
|
||||
))?;
|
||||
|
||||
let status = response.status();
|
||||
if !status.is_success() {
|
||||
return Err(zclaw_types::ZclawError::HandError(
|
||||
format!("Bing returned HTTP {}", status)
|
||||
));
|
||||
}
|
||||
|
||||
let html = response.text().await
|
||||
.map_err(|e| zclaw_types::ZclawError::HandError(
|
||||
format!("Failed to read Bing response: {}", e)
|
||||
@@ -636,6 +677,13 @@ impl ResearcherHand {
|
||||
format!("Baidu search failed: {}", e)
|
||||
))?;
|
||||
|
||||
let status = response.status();
|
||||
if !status.is_success() {
|
||||
return Err(zclaw_types::ZclawError::HandError(
|
||||
format!("Baidu returned HTTP {}", status)
|
||||
));
|
||||
}
|
||||
|
||||
let html = response.text().await
|
||||
.map_err(|e| zclaw_types::ZclawError::HandError(
|
||||
format!("Failed to read Baidu response: {}", e)
|
||||
@@ -648,15 +696,20 @@ impl ResearcherHand {
|
||||
fn parse_baidu_html(&self, html: &str, max_results: usize) -> Vec<SearchResult> {
|
||||
let mut results = Vec::new();
|
||||
|
||||
for block in html.split("class=\"result c-container\"") {
|
||||
// Baidu uses multiple class patterns: "result c-container", "c-container new-pmd", "result-op c-container"
|
||||
let blocks: Vec<&str> = html.split("c-container")
|
||||
.enumerate()
|
||||
.filter_map(|(i, block)| {
|
||||
if i == 0 { return None; }
|
||||
if block.contains("href=\"http") { Some(block) } else { None }
|
||||
})
|
||||
.collect();
|
||||
|
||||
for block in &blocks {
|
||||
if results.len() >= max_results {
|
||||
break;
|
||||
}
|
||||
|
||||
if !block.contains("href=\"http") {
|
||||
continue;
|
||||
}
|
||||
|
||||
let title = extract_between(block, ">", "</a>")
|
||||
.map(|s| strip_html_tags(s).trim().to_string())
|
||||
.unwrap_or_default();
|
||||
@@ -686,10 +739,13 @@ impl ResearcherHand {
|
||||
results
|
||||
}
|
||||
|
||||
/// Fetch content from a URL
|
||||
/// Fetch content from a URL (with SSRF protection)
|
||||
async fn execute_fetch(&self, url: &str) -> Result<SearchResult> {
|
||||
let start = std::time::Instant::now();
|
||||
|
||||
// SSRF validation
|
||||
validate_fetch_url(url)?;
|
||||
|
||||
// Check cache first
|
||||
{
|
||||
let cache = self.cache.read().await;
|
||||
@@ -733,9 +789,15 @@ impl ResearcherHand {
|
||||
fetched_at: Some(chrono::Utc::now().to_rfc3339()),
|
||||
};
|
||||
|
||||
// Cache the result
|
||||
// Cache the result (with capacity limit)
|
||||
{
|
||||
let mut cache = self.cache.write().await;
|
||||
if cache.len() >= CACHE_MAX_ENTRIES {
|
||||
// Simple eviction: remove first entry
|
||||
if let Some(key) = cache.keys().next().cloned() {
|
||||
cache.remove(&key);
|
||||
}
|
||||
}
|
||||
cache.insert(url.to_string(), result.clone());
|
||||
}
|
||||
|
||||
@@ -992,6 +1054,106 @@ fn is_cjk_char(c: char) -> bool {
|
||||
)
|
||||
}
|
||||
|
||||
/// Validate a URL for SSRF safety before fetching
|
||||
fn validate_fetch_url(url_str: &str) -> Result<()> {
|
||||
if url_str.len() > MAX_URL_LENGTH {
|
||||
return Err(zclaw_types::ZclawError::HandError(
|
||||
format!("URL exceeds maximum length of {} characters", MAX_URL_LENGTH)
|
||||
));
|
||||
}
|
||||
|
||||
let url = Url::parse(url_str)
|
||||
.map_err(|e| zclaw_types::ZclawError::HandError(format!("Invalid URL: {}", e)))?;
|
||||
|
||||
match url.scheme() {
|
||||
"http" | "https" => {}
|
||||
scheme => {
|
||||
return Err(zclaw_types::ZclawError::HandError(
|
||||
format!("URL scheme '{}' not allowed, only http/https", scheme)
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
let host = url.host_str()
|
||||
.ok_or_else(|| zclaw_types::ZclawError::HandError("URL must have a host".into()))?;
|
||||
|
||||
// Strip IPv6 brackets for parsing
|
||||
let host_for_parsing = if host.starts_with('[') && host.ends_with(']') {
|
||||
&host[1..host.len()-1]
|
||||
} else {
|
||||
host
|
||||
};
|
||||
|
||||
if let Ok(ip) = host_for_parsing.parse::<IpAddr>() {
|
||||
validate_ip(&ip)?;
|
||||
} else {
|
||||
validate_hostname(host)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn validate_ip(ip: &IpAddr) -> Result<()> {
|
||||
match ip {
|
||||
IpAddr::V4(v4) => validate_ipv4(v4),
|
||||
IpAddr::V6(v6) => validate_ipv6(v6),
|
||||
}
|
||||
}
|
||||
|
||||
fn validate_ipv4(ip: &Ipv4Addr) -> Result<()> {
|
||||
let o = ip.octets();
|
||||
if o[0] == 127 { return Err(ssrf_err("loopback")); }
|
||||
if o[0] == 10 { return Err(ssrf_err("private 10.x.x.x")); }
|
||||
if o[0] == 172 && (16..=31).contains(&o[1]) { return Err(ssrf_err("private 172.16-31.x.x")); }
|
||||
if o[0] == 192 && o[1] == 168 { return Err(ssrf_err("private 192.168.x.x")); }
|
||||
if o[0] == 169 && o[1] == 254 { return Err(ssrf_err("link-local/metadata")); }
|
||||
if o[0] == 0 { return Err(ssrf_err("0.x.x.x")); }
|
||||
if *ip == Ipv4Addr::new(255, 255, 255, 255) { return Err(ssrf_err("broadcast")); }
|
||||
if (224..=239).contains(&o[0]) { return Err(ssrf_err("multicast")); }
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn validate_ipv6(ip: &Ipv6Addr) -> Result<()> {
|
||||
if *ip == Ipv6Addr::LOCALHOST { return Err(ssrf_err("IPv6 loopback")); }
|
||||
if *ip == Ipv6Addr::UNSPECIFIED { return Err(ssrf_err("IPv6 unspecified")); }
|
||||
let segs = ip.segments();
|
||||
// IPv4-mapped: ::ffff:x.x.x.x
|
||||
if segs[5] == 0xffff {
|
||||
let v4 = ((segs[6] as u32) << 16) | (segs[7] as u32);
|
||||
validate_ipv4(&Ipv4Addr::from(v4))?;
|
||||
}
|
||||
// Link-local fe80::/10
|
||||
if (segs[0] & 0xffc0) == 0xfe80 { return Err(ssrf_err("IPv6 link-local")); }
|
||||
// Unique local fc00::/7
|
||||
if (segs[0] & 0xfe00) == 0xfc00 { return Err(ssrf_err("IPv6 unique local")); }
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn validate_hostname(host: &str) -> Result<()> {
|
||||
let h = host.to_lowercase();
|
||||
let blocked = [
|
||||
"localhost", "localhost.localdomain", "ip6-localhost",
|
||||
"ip6-loopback", "metadata.google.internal", "metadata",
|
||||
"kubernetes.default", "kubernetes.default.svc",
|
||||
];
|
||||
for b in &blocked {
|
||||
if h == *b || h.ends_with(&format!(".{}", b)) {
|
||||
return Err(ssrf_err(&format!("blocked host '{}'", host)));
|
||||
}
|
||||
}
|
||||
// Decimal IP bypass: 2130706433 = 127.0.0.1
|
||||
if h.chars().all(|c| c.is_ascii_digit()) {
|
||||
if let Ok(num) = h.parse::<u32>() {
|
||||
validate_ipv4(&Ipv4Addr::from(num))?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Build the standard SSRF rejection error for a blocked fetch target.
fn ssrf_err(reason: &str) -> zclaw_types::ZclawError {
    let message = format!("Access denied: {}", reason);
    zclaw_types::ZclawError::HandError(message)
}
|
||||
|
||||
/// Extract text between two delimiters
|
||||
fn extract_between<'a>(text: &'a str, start: &str, end: &str) -> Option<&'a str> {
|
||||
let start_idx = text.find(start)?;
|
||||
@@ -1042,17 +1204,11 @@ fn extract_href(text: &str) -> Option<String> {
|
||||
|
||||
/// Extract the real URL from DDG's redirect link (uddg= parameter)
|
||||
fn extract_href_uddg(text: &str) -> Option<String> {
|
||||
// DDG HTML uses: href="//duckduckgo.com/l/?uddg=ENCODED_URL&..."
|
||||
if let Some(idx) = text.find("uddg=") {
|
||||
let rest = &text[idx + 5..];
|
||||
let url_encoded = rest.split('&').next().unwrap_or("");
|
||||
let decoded = url_encoded.replace("%3A", ":")
|
||||
.replace("%2F", "/")
|
||||
.replace("%3F", "?")
|
||||
.replace("%3D", "=")
|
||||
.replace("%26", "&")
|
||||
.replace("%20", " ")
|
||||
.replace("%25", "%");
|
||||
// Use standard percent decoding instead of manual replacement
|
||||
let decoded = percent_decode(url_encoded);
|
||||
if decoded.starts_with("http") {
|
||||
return Some(decoded);
|
||||
}
|
||||
@@ -1062,6 +1218,27 @@ fn extract_href_uddg(text: &str) -> Option<String> {
|
||||
extract_href(text)
|
||||
}
|
||||
|
||||
/// Standard percent-decode a URL-encoded string.
///
/// Decodes every valid `%XY` hex escape into its byte, leaves malformed
/// escapes (`%4`, `%zz`) verbatim, and lossily converts the resulting
/// bytes to UTF-8.
///
/// Works entirely on the byte slice: the previous version sliced the
/// `&str` by byte offset (`&input[i + 1..i + 3]`), which panics when a
/// multi-byte character follows `%` (e.g. `"%中a"` split 中 mid-character).
fn percent_decode(input: &str) -> String {
    // Map an ASCII hex digit to its value; None for anything else.
    fn hex_val(b: u8) -> Option<u8> {
        match b {
            b'0'..=b'9' => Some(b - b'0'),
            b'a'..=b'f' => Some(b - b'a' + 10),
            b'A'..=b'F' => Some(b - b'A' + 10),
            _ => None,
        }
    }

    let bytes = input.as_bytes();
    let mut result = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while i < bytes.len() {
        // A '%' needs two following bytes to form an escape.
        if bytes[i] == b'%' && i + 2 < bytes.len() {
            if let (Some(hi), Some(lo)) = (hex_val(bytes[i + 1]), hex_val(bytes[i + 2])) {
                result.push((hi << 4) | lo);
                i += 3;
                continue;
            }
        }
        // Not a valid escape: pass the byte through unchanged.
        result.push(bytes[i]);
        i += 1;
    }
    String::from_utf8_lossy(&result).to_string()
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
@@ -1632,10 +1809,108 @@ mod tests {
|
||||
<h3 class="t"><a href="https://www.example.cn/page1">中国医疗政策 2024</a></h3>
|
||||
<div class="c-abstract">这是关于医疗政策的摘要信息。</div>
|
||||
</div>
|
||||
<div class="c-container new-pmd">
|
||||
<h3><a href="https://www.example.cn/page2">第二条结果</a></h3>
|
||||
</div>
|
||||
"#;
|
||||
|
||||
let results = hand.parse_baidu_html(html, 10);
|
||||
assert_eq!(results.len(), 1);
|
||||
assert!(results.len() >= 1, "Should find at least 1 result, got {}", results.len());
|
||||
assert_eq!(results[0].source, "Baidu");
|
||||
}
|
||||
|
||||
// --- SSRF Validation Tests ---
|
||||
|
||||
#[test]
|
||||
fn test_ssrf_blocks_localhost() {
|
||||
assert!(validate_fetch_url("http://localhost:8080/admin").is_err());
|
||||
assert!(validate_fetch_url("http://127.0.0.1:5432/db").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ssrf_blocks_private_ip() {
|
||||
assert!(validate_fetch_url("http://10.0.0.1/secret").is_err());
|
||||
assert!(validate_fetch_url("http://192.168.1.1/router").is_err());
|
||||
assert!(validate_fetch_url("http://172.16.0.1/internal").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ssrf_blocks_cloud_metadata() {
|
||||
assert!(validate_fetch_url("http://169.254.169.254/metadata").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ssrf_blocks_non_http_scheme() {
|
||||
assert!(validate_fetch_url("file:///etc/passwd").is_err());
|
||||
assert!(validate_fetch_url("ftp://example.com/file").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ssrf_allows_public_url() {
|
||||
assert!(validate_fetch_url("https://www.rust-lang.org/learn").is_ok());
|
||||
assert!(validate_fetch_url("https://example.com/page?q=test").is_ok());
|
||||
}
|
||||
|
||||
// --- Percent Decode Tests ---
|
||||
|
||||
#[test]
|
||||
fn test_percent_decode_basic() {
|
||||
assert_eq!(percent_decode("hello%20world"), "hello world");
|
||||
assert_eq!(percent_decode("%E4%B8%AD%E6%96%87"), "中文");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_percent_decode_full_url() {
|
||||
assert_eq!(
|
||||
percent_decode("https%3A%2F%2Fexample.com%2Fpage%3Fq%3Dtest"),
|
||||
"https://example.com/page?q=test"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_percent_decode_no_encoding() {
|
||||
assert_eq!(percent_decode("plain-text_123"), "plain-text_123");
|
||||
}
|
||||
|
||||
// --- Input Validation Tests ---
|
||||
|
||||
#[test]
|
||||
fn test_research_query_validate_empty() {
|
||||
let query = ResearchQuery {
|
||||
query: " ".to_string(), engine: SearchEngine::Auto,
|
||||
depth: ResearchDepth::Standard, max_results: 10,
|
||||
include_related: false, time_limit_secs: 60,
|
||||
};
|
||||
assert!(query.validate().is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_research_query_validate_too_long() {
|
||||
let query = ResearchQuery {
|
||||
query: "x".repeat(501), engine: SearchEngine::Auto,
|
||||
depth: ResearchDepth::Standard, max_results: 10,
|
||||
include_related: false, time_limit_secs: 60,
|
||||
};
|
||||
assert!(query.validate().is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_research_query_validate_max_results_overflow() {
|
||||
let query = ResearchQuery {
|
||||
query: "test".to_string(), engine: SearchEngine::Auto,
|
||||
depth: ResearchDepth::Standard, max_results: 999,
|
||||
include_related: false, time_limit_secs: 60,
|
||||
};
|
||||
assert!(query.validate().is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_research_query_validate_ok() {
|
||||
let query = ResearchQuery {
|
||||
query: "Rust programming".to_string(), engine: SearchEngine::Auto,
|
||||
depth: ResearchDepth::Standard, max_results: 10,
|
||||
include_related: false, time_limit_secs: 60,
|
||||
};
|
||||
assert!(query.validate().is_ok());
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user