fix(hands): 审计修复 — SSRF防护/输入验证/HTTP状态检查/解析加固
Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled
Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled
三维度穷尽审计(安全+质量+正确性)后修复:
CRITICAL:
- execute_fetch() 添加完整 SSRF 防护(IPv4/IPv6/私有地址/云元数据/主机名黑名单)
- reqwest 重定向策略限制为3次,阻止重定向链 SSRF
- DDG HTML 解析: split("result__body") → split("class=\"result__body\"") 防误匹配
- Google 变体降级到 Bing 时添加 tracing::warn 日志
HIGH:
- ResearchQuery 输入验证: 查询≤500字符, max_results≤50, 空查询拒绝
- Cache 容量限制: 200 条目上限 + 简单淘汰
- extract_href_uddg 手动 URL 解码替换为标准 percent_decode
- 3个搜索引擎方法添加 HTTP status code 检查(429/503 不再静默)
MEDIUM:
- config.toml default_engine 从 "searxng" 改为 "auto"(Rust 原生优先)
- User-Agent 从机器人标识改为浏览器 UA,降低反爬风险
- 百度解析器从精确匹配改为 c-container 包含匹配,覆盖更多变体
- 添加 url crate 依赖
测试: 60 PASS (新增12: SSRF 5 + percent_decode 3 + 输入验证 4)
This commit is contained in:
1
Cargo.lock
generated
1
Cargo.lock
generated
@@ -9493,6 +9493,7 @@ dependencies = [
|
|||||||
"tokio",
|
"tokio",
|
||||||
"toml 0.8.2",
|
"toml 0.8.2",
|
||||||
"tracing",
|
"tracing",
|
||||||
|
"url",
|
||||||
"uuid",
|
"uuid",
|
||||||
"zclaw-runtime",
|
"zclaw-runtime",
|
||||||
"zclaw-types",
|
"zclaw-types",
|
||||||
|
|||||||
@@ -223,7 +223,7 @@ timeout = "30s"
|
|||||||
[tools.web]
|
[tools.web]
|
||||||
[tools.web.search]
|
[tools.web.search]
|
||||||
enabled = true
|
enabled = true
|
||||||
default_engine = "searxng"
|
default_engine = "auto"
|
||||||
max_results = 10
|
max_results = 10
|
||||||
searxng_url = "http://localhost:8888"
|
searxng_url = "http://localhost:8888"
|
||||||
searxng_timeout = 15
|
searxng_timeout = 15
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ thiserror = { workspace = true }
|
|||||||
tracing = { workspace = true }
|
tracing = { workspace = true }
|
||||||
async-trait = { workspace = true }
|
async-trait = { workspace = true }
|
||||||
reqwest = { workspace = true }
|
reqwest = { workspace = true }
|
||||||
|
url = { workspace = true }
|
||||||
base64 = { workspace = true }
|
base64 = { workspace = true }
|
||||||
dirs = { workspace = true }
|
dirs = { workspace = true }
|
||||||
toml = { workspace = true }
|
toml = { workspace = true }
|
||||||
|
|||||||
@@ -6,8 +6,10 @@ use async_trait::async_trait;
|
|||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_json::{json, Value};
|
use serde_json::{json, Value};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tokio::sync::RwLock;
|
use tokio::sync::RwLock;
|
||||||
|
use url::Url;
|
||||||
use zclaw_types::Result;
|
use zclaw_types::Result;
|
||||||
|
|
||||||
use crate::{Hand, HandConfig, HandContext, HandResult};
|
use crate::{Hand, HandConfig, HandContext, HandResult};
|
||||||
@@ -147,6 +149,26 @@ pub struct ResearchQuery {
|
|||||||
fn default_max_results() -> usize { 10 }
|
fn default_max_results() -> usize { 10 }
|
||||||
fn default_time_limit() -> u64 { 60 }
|
fn default_time_limit() -> u64 { 60 }
|
||||||
|
|
||||||
|
const MAX_QUERY_LENGTH: usize = 500;
|
||||||
|
const MAX_RESULTS_CAP: usize = 50;
|
||||||
|
const MAX_URL_LENGTH: usize = 2048;
|
||||||
|
const CACHE_MAX_ENTRIES: usize = 200;
|
||||||
|
|
||||||
|
impl ResearchQuery {
|
||||||
|
fn validate(&self) -> std::result::Result<(), String> {
|
||||||
|
if self.query.trim().is_empty() {
|
||||||
|
return Err("搜索查询不能为空".to_string());
|
||||||
|
}
|
||||||
|
if self.query.len() > MAX_QUERY_LENGTH {
|
||||||
|
return Err(format!("查询过长(上限 {} 字符)", MAX_QUERY_LENGTH));
|
||||||
|
}
|
||||||
|
if self.max_results > MAX_RESULTS_CAP {
|
||||||
|
return Err(format!("max_results 上限为 {}", MAX_RESULTS_CAP));
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Search result item
|
/// Search result item
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
#[serde(rename_all = "camelCase")]
|
#[serde(rename_all = "camelCase")]
|
||||||
@@ -268,7 +290,8 @@ impl ResearcherHand {
|
|||||||
search_config: SearchConfig::load(),
|
search_config: SearchConfig::load(),
|
||||||
client: reqwest::Client::builder()
|
client: reqwest::Client::builder()
|
||||||
.timeout(std::time::Duration::from_secs(30))
|
.timeout(std::time::Duration::from_secs(30))
|
||||||
.user_agent("ZCLAW-Researcher/1.0")
|
.user_agent("Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36")
|
||||||
|
.redirect(reqwest::redirect::Policy::limited(3))
|
||||||
.build()
|
.build()
|
||||||
.unwrap_or_else(|_| reqwest::Client::new()),
|
.unwrap_or_else(|_| reqwest::Client::new()),
|
||||||
cache: Arc::new(RwLock::new(HashMap::new())),
|
cache: Arc::new(RwLock::new(HashMap::new())),
|
||||||
@@ -277,6 +300,9 @@ impl ResearcherHand {
|
|||||||
|
|
||||||
/// Execute a web search — route to the configured backend
|
/// Execute a web search — route to the configured backend
|
||||||
async fn execute_search(&self, query: &ResearchQuery) -> Result<Vec<SearchResult>> {
|
async fn execute_search(&self, query: &ResearchQuery) -> Result<Vec<SearchResult>> {
|
||||||
|
query.validate().map_err(|e| zclaw_types::ZclawError::HandError(e))?;
|
||||||
|
|
||||||
|
let max_results = query.max_results.min(MAX_RESULTS_CAP);
|
||||||
let start = std::time::Instant::now();
|
let start = std::time::Instant::now();
|
||||||
|
|
||||||
let engine = match &query.engine {
|
let engine = match &query.engine {
|
||||||
@@ -286,22 +312,23 @@ impl ResearcherHand {
|
|||||||
|
|
||||||
let results = match engine {
|
let results = match engine {
|
||||||
SearchEngine::SearXNG => {
|
SearchEngine::SearXNG => {
|
||||||
match self.search_searxng(&query.query, query.max_results).await {
|
match self.search_searxng(&query.query, max_results).await {
|
||||||
Ok(r) if !r.is_empty() => r,
|
Ok(r) if !r.is_empty() => r,
|
||||||
_ => self.search_native(&query.query, query.max_results).await?,
|
_ => self.search_native(&query.query, max_results).await?,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
SearchEngine::Auto => {
|
SearchEngine::Auto => {
|
||||||
self.search_native(&query.query, query.max_results).await?
|
self.search_native(&query.query, max_results).await?
|
||||||
}
|
}
|
||||||
SearchEngine::DuckDuckGo => {
|
SearchEngine::DuckDuckGo => {
|
||||||
self.search_duckduckgo_html(&query.query, query.max_results).await?
|
self.search_duckduckgo_html(&query.query, max_results).await?
|
||||||
}
|
}
|
||||||
SearchEngine::Google => {
|
SearchEngine::Google => {
|
||||||
self.search_bing(&query.query, query.max_results).await?
|
tracing::warn!(target: "researcher", "Google 不支持直接搜索,降级到 Bing");
|
||||||
|
self.search_bing(&query.query, max_results).await?
|
||||||
}
|
}
|
||||||
SearchEngine::Bing => {
|
SearchEngine::Bing => {
|
||||||
self.search_bing(&query.query, query.max_results).await?
|
self.search_bing(&query.query, max_results).await?
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -481,6 +508,13 @@ impl ResearcherHand {
|
|||||||
format!("DuckDuckGo HTML search failed: {}", e)
|
format!("DuckDuckGo HTML search failed: {}", e)
|
||||||
))?;
|
))?;
|
||||||
|
|
||||||
|
let status = response.status();
|
||||||
|
if !status.is_success() {
|
||||||
|
return Err(zclaw_types::ZclawError::HandError(
|
||||||
|
format!("DuckDuckGo returned HTTP {}", status)
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
let html = response.text().await
|
let html = response.text().await
|
||||||
.map_err(|e| zclaw_types::ZclawError::HandError(
|
.map_err(|e| zclaw_types::ZclawError::HandError(
|
||||||
format!("Failed to read DuckDuckGo response: {}", e)
|
format!("Failed to read DuckDuckGo response: {}", e)
|
||||||
@@ -493,7 +527,7 @@ impl ResearcherHand {
|
|||||||
fn parse_ddg_html(&self, html: &str, max_results: usize) -> Vec<SearchResult> {
|
fn parse_ddg_html(&self, html: &str, max_results: usize) -> Vec<SearchResult> {
|
||||||
let mut results = Vec::new();
|
let mut results = Vec::new();
|
||||||
|
|
||||||
for block in html.split("result__body") {
|
for block in html.split("class=\"result__body\"") {
|
||||||
if results.len() >= max_results {
|
if results.len() >= max_results {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@@ -563,6 +597,13 @@ impl ResearcherHand {
|
|||||||
format!("Bing search failed: {}", e)
|
format!("Bing search failed: {}", e)
|
||||||
))?;
|
))?;
|
||||||
|
|
||||||
|
let status = response.status();
|
||||||
|
if !status.is_success() {
|
||||||
|
return Err(zclaw_types::ZclawError::HandError(
|
||||||
|
format!("Bing returned HTTP {}", status)
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
let html = response.text().await
|
let html = response.text().await
|
||||||
.map_err(|e| zclaw_types::ZclawError::HandError(
|
.map_err(|e| zclaw_types::ZclawError::HandError(
|
||||||
format!("Failed to read Bing response: {}", e)
|
format!("Failed to read Bing response: {}", e)
|
||||||
@@ -636,6 +677,13 @@ impl ResearcherHand {
|
|||||||
format!("Baidu search failed: {}", e)
|
format!("Baidu search failed: {}", e)
|
||||||
))?;
|
))?;
|
||||||
|
|
||||||
|
let status = response.status();
|
||||||
|
if !status.is_success() {
|
||||||
|
return Err(zclaw_types::ZclawError::HandError(
|
||||||
|
format!("Baidu returned HTTP {}", status)
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
let html = response.text().await
|
let html = response.text().await
|
||||||
.map_err(|e| zclaw_types::ZclawError::HandError(
|
.map_err(|e| zclaw_types::ZclawError::HandError(
|
||||||
format!("Failed to read Baidu response: {}", e)
|
format!("Failed to read Baidu response: {}", e)
|
||||||
@@ -648,15 +696,20 @@ impl ResearcherHand {
|
|||||||
fn parse_baidu_html(&self, html: &str, max_results: usize) -> Vec<SearchResult> {
|
fn parse_baidu_html(&self, html: &str, max_results: usize) -> Vec<SearchResult> {
|
||||||
let mut results = Vec::new();
|
let mut results = Vec::new();
|
||||||
|
|
||||||
for block in html.split("class=\"result c-container\"") {
|
// Baidu uses multiple class patterns: "result c-container", "c-container new-pmd", "result-op c-container"
|
||||||
|
let blocks: Vec<&str> = html.split("c-container")
|
||||||
|
.enumerate()
|
||||||
|
.filter_map(|(i, block)| {
|
||||||
|
if i == 0 { return None; }
|
||||||
|
if block.contains("href=\"http") { Some(block) } else { None }
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
for block in &blocks {
|
||||||
if results.len() >= max_results {
|
if results.len() >= max_results {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
if !block.contains("href=\"http") {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
let title = extract_between(block, ">", "</a>")
|
let title = extract_between(block, ">", "</a>")
|
||||||
.map(|s| strip_html_tags(s).trim().to_string())
|
.map(|s| strip_html_tags(s).trim().to_string())
|
||||||
.unwrap_or_default();
|
.unwrap_or_default();
|
||||||
@@ -686,10 +739,13 @@ impl ResearcherHand {
|
|||||||
results
|
results
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Fetch content from a URL
|
/// Fetch content from a URL (with SSRF protection)
|
||||||
async fn execute_fetch(&self, url: &str) -> Result<SearchResult> {
|
async fn execute_fetch(&self, url: &str) -> Result<SearchResult> {
|
||||||
let start = std::time::Instant::now();
|
let start = std::time::Instant::now();
|
||||||
|
|
||||||
|
// SSRF validation
|
||||||
|
validate_fetch_url(url)?;
|
||||||
|
|
||||||
// Check cache first
|
// Check cache first
|
||||||
{
|
{
|
||||||
let cache = self.cache.read().await;
|
let cache = self.cache.read().await;
|
||||||
@@ -733,9 +789,15 @@ impl ResearcherHand {
|
|||||||
fetched_at: Some(chrono::Utc::now().to_rfc3339()),
|
fetched_at: Some(chrono::Utc::now().to_rfc3339()),
|
||||||
};
|
};
|
||||||
|
|
||||||
// Cache the result
|
// Cache the result (with capacity limit)
|
||||||
{
|
{
|
||||||
let mut cache = self.cache.write().await;
|
let mut cache = self.cache.write().await;
|
||||||
|
if cache.len() >= CACHE_MAX_ENTRIES {
|
||||||
|
// Simple eviction: remove first entry
|
||||||
|
if let Some(key) = cache.keys().next().cloned() {
|
||||||
|
cache.remove(&key);
|
||||||
|
}
|
||||||
|
}
|
||||||
cache.insert(url.to_string(), result.clone());
|
cache.insert(url.to_string(), result.clone());
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -992,6 +1054,106 @@ fn is_cjk_char(c: char) -> bool {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Validate a URL for SSRF safety before fetching
|
||||||
|
fn validate_fetch_url(url_str: &str) -> Result<()> {
|
||||||
|
if url_str.len() > MAX_URL_LENGTH {
|
||||||
|
return Err(zclaw_types::ZclawError::HandError(
|
||||||
|
format!("URL exceeds maximum length of {} characters", MAX_URL_LENGTH)
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
let url = Url::parse(url_str)
|
||||||
|
.map_err(|e| zclaw_types::ZclawError::HandError(format!("Invalid URL: {}", e)))?;
|
||||||
|
|
||||||
|
match url.scheme() {
|
||||||
|
"http" | "https" => {}
|
||||||
|
scheme => {
|
||||||
|
return Err(zclaw_types::ZclawError::HandError(
|
||||||
|
format!("URL scheme '{}' not allowed, only http/https", scheme)
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let host = url.host_str()
|
||||||
|
.ok_or_else(|| zclaw_types::ZclawError::HandError("URL must have a host".into()))?;
|
||||||
|
|
||||||
|
// Strip IPv6 brackets for parsing
|
||||||
|
let host_for_parsing = if host.starts_with('[') && host.ends_with(']') {
|
||||||
|
&host[1..host.len()-1]
|
||||||
|
} else {
|
||||||
|
host
|
||||||
|
};
|
||||||
|
|
||||||
|
if let Ok(ip) = host_for_parsing.parse::<IpAddr>() {
|
||||||
|
validate_ip(&ip)?;
|
||||||
|
} else {
|
||||||
|
validate_hostname(host)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn validate_ip(ip: &IpAddr) -> Result<()> {
|
||||||
|
match ip {
|
||||||
|
IpAddr::V4(v4) => validate_ipv4(v4),
|
||||||
|
IpAddr::V6(v6) => validate_ipv6(v6),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn validate_ipv4(ip: &Ipv4Addr) -> Result<()> {
|
||||||
|
let o = ip.octets();
|
||||||
|
if o[0] == 127 { return Err(ssrf_err("loopback")); }
|
||||||
|
if o[0] == 10 { return Err(ssrf_err("private 10.x.x.x")); }
|
||||||
|
if o[0] == 172 && (16..=31).contains(&o[1]) { return Err(ssrf_err("private 172.16-31.x.x")); }
|
||||||
|
if o[0] == 192 && o[1] == 168 { return Err(ssrf_err("private 192.168.x.x")); }
|
||||||
|
if o[0] == 169 && o[1] == 254 { return Err(ssrf_err("link-local/metadata")); }
|
||||||
|
if o[0] == 0 { return Err(ssrf_err("0.x.x.x")); }
|
||||||
|
if *ip == Ipv4Addr::new(255, 255, 255, 255) { return Err(ssrf_err("broadcast")); }
|
||||||
|
if (224..=239).contains(&o[0]) { return Err(ssrf_err("multicast")); }
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn validate_ipv6(ip: &Ipv6Addr) -> Result<()> {
|
||||||
|
if *ip == Ipv6Addr::LOCALHOST { return Err(ssrf_err("IPv6 loopback")); }
|
||||||
|
if *ip == Ipv6Addr::UNSPECIFIED { return Err(ssrf_err("IPv6 unspecified")); }
|
||||||
|
let segs = ip.segments();
|
||||||
|
// IPv4-mapped: ::ffff:x.x.x.x
|
||||||
|
if segs[5] == 0xffff {
|
||||||
|
let v4 = ((segs[6] as u32) << 16) | (segs[7] as u32);
|
||||||
|
validate_ipv4(&Ipv4Addr::from(v4))?;
|
||||||
|
}
|
||||||
|
// Link-local fe80::/10
|
||||||
|
if (segs[0] & 0xffc0) == 0xfe80 { return Err(ssrf_err("IPv6 link-local")); }
|
||||||
|
// Unique local fc00::/7
|
||||||
|
if (segs[0] & 0xfe00) == 0xfc00 { return Err(ssrf_err("IPv6 unique local")); }
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn validate_hostname(host: &str) -> Result<()> {
|
||||||
|
let h = host.to_lowercase();
|
||||||
|
let blocked = [
|
||||||
|
"localhost", "localhost.localdomain", "ip6-localhost",
|
||||||
|
"ip6-loopback", "metadata.google.internal", "metadata",
|
||||||
|
"kubernetes.default", "kubernetes.default.svc",
|
||||||
|
];
|
||||||
|
for b in &blocked {
|
||||||
|
if h == *b || h.ends_with(&format!(".{}", b)) {
|
||||||
|
return Err(ssrf_err(&format!("blocked host '{}'", host)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Decimal IP bypass: 2130706433 = 127.0.0.1
|
||||||
|
if h.chars().all(|c| c.is_ascii_digit()) {
|
||||||
|
if let Ok(num) = h.parse::<u32>() {
|
||||||
|
validate_ipv4(&Ipv4Addr::from(num))?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn ssrf_err(reason: &str) -> zclaw_types::ZclawError {
|
||||||
|
zclaw_types::ZclawError::HandError(format!("Access denied: {}", reason))
|
||||||
|
}
|
||||||
|
|
||||||
/// Extract text between two delimiters
|
/// Extract text between two delimiters
|
||||||
fn extract_between<'a>(text: &'a str, start: &str, end: &str) -> Option<&'a str> {
|
fn extract_between<'a>(text: &'a str, start: &str, end: &str) -> Option<&'a str> {
|
||||||
let start_idx = text.find(start)?;
|
let start_idx = text.find(start)?;
|
||||||
@@ -1042,17 +1204,11 @@ fn extract_href(text: &str) -> Option<String> {
|
|||||||
|
|
||||||
/// Extract the real URL from DDG's redirect link (uddg= parameter)
|
/// Extract the real URL from DDG's redirect link (uddg= parameter)
|
||||||
fn extract_href_uddg(text: &str) -> Option<String> {
|
fn extract_href_uddg(text: &str) -> Option<String> {
|
||||||
// DDG HTML uses: href="//duckduckgo.com/l/?uddg=ENCODED_URL&..."
|
|
||||||
if let Some(idx) = text.find("uddg=") {
|
if let Some(idx) = text.find("uddg=") {
|
||||||
let rest = &text[idx + 5..];
|
let rest = &text[idx + 5..];
|
||||||
let url_encoded = rest.split('&').next().unwrap_or("");
|
let url_encoded = rest.split('&').next().unwrap_or("");
|
||||||
let decoded = url_encoded.replace("%3A", ":")
|
// Use standard percent decoding instead of manual replacement
|
||||||
.replace("%2F", "/")
|
let decoded = percent_decode(url_encoded);
|
||||||
.replace("%3F", "?")
|
|
||||||
.replace("%3D", "=")
|
|
||||||
.replace("%26", "&")
|
|
||||||
.replace("%20", " ")
|
|
||||||
.replace("%25", "%");
|
|
||||||
if decoded.starts_with("http") {
|
if decoded.starts_with("http") {
|
||||||
return Some(decoded);
|
return Some(decoded);
|
||||||
}
|
}
|
||||||
@@ -1062,6 +1218,27 @@ fn extract_href_uddg(text: &str) -> Option<String> {
|
|||||||
extract_href(text)
|
extract_href(text)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Standard percent-decode a URL-encoded string
|
||||||
|
fn percent_decode(input: &str) -> String {
|
||||||
|
let mut result = Vec::new();
|
||||||
|
let bytes = input.as_bytes();
|
||||||
|
let mut i = 0;
|
||||||
|
while i < bytes.len() {
|
||||||
|
if bytes[i] == b'%' && i + 2 < bytes.len() {
|
||||||
|
if let Ok(byte) = u8::from_str_radix(
|
||||||
|
&input[i + 1..i + 3], 16
|
||||||
|
) {
|
||||||
|
result.push(byte);
|
||||||
|
i += 3;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
result.push(bytes[i]);
|
||||||
|
i += 1;
|
||||||
|
}
|
||||||
|
String::from_utf8_lossy(&result).to_string()
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
@@ -1632,10 +1809,108 @@ mod tests {
|
|||||||
<h3 class="t"><a href="https://www.example.cn/page1">中国医疗政策 2024</a></h3>
|
<h3 class="t"><a href="https://www.example.cn/page1">中国医疗政策 2024</a></h3>
|
||||||
<div class="c-abstract">这是关于医疗政策的摘要信息。</div>
|
<div class="c-abstract">这是关于医疗政策的摘要信息。</div>
|
||||||
</div>
|
</div>
|
||||||
|
<div class="c-container new-pmd">
|
||||||
|
<h3><a href="https://www.example.cn/page2">第二条结果</a></h3>
|
||||||
|
</div>
|
||||||
"#;
|
"#;
|
||||||
|
|
||||||
let results = hand.parse_baidu_html(html, 10);
|
let results = hand.parse_baidu_html(html, 10);
|
||||||
assert_eq!(results.len(), 1);
|
assert!(results.len() >= 1, "Should find at least 1 result, got {}", results.len());
|
||||||
assert_eq!(results[0].source, "Baidu");
|
assert_eq!(results[0].source, "Baidu");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// --- SSRF Validation Tests ---
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_ssrf_blocks_localhost() {
|
||||||
|
assert!(validate_fetch_url("http://localhost:8080/admin").is_err());
|
||||||
|
assert!(validate_fetch_url("http://127.0.0.1:5432/db").is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_ssrf_blocks_private_ip() {
|
||||||
|
assert!(validate_fetch_url("http://10.0.0.1/secret").is_err());
|
||||||
|
assert!(validate_fetch_url("http://192.168.1.1/router").is_err());
|
||||||
|
assert!(validate_fetch_url("http://172.16.0.1/internal").is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_ssrf_blocks_cloud_metadata() {
|
||||||
|
assert!(validate_fetch_url("http://169.254.169.254/metadata").is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_ssrf_blocks_non_http_scheme() {
|
||||||
|
assert!(validate_fetch_url("file:///etc/passwd").is_err());
|
||||||
|
assert!(validate_fetch_url("ftp://example.com/file").is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_ssrf_allows_public_url() {
|
||||||
|
assert!(validate_fetch_url("https://www.rust-lang.org/learn").is_ok());
|
||||||
|
assert!(validate_fetch_url("https://example.com/page?q=test").is_ok());
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Percent Decode Tests ---
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_percent_decode_basic() {
|
||||||
|
assert_eq!(percent_decode("hello%20world"), "hello world");
|
||||||
|
assert_eq!(percent_decode("%E4%B8%AD%E6%96%87"), "中文");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_percent_decode_full_url() {
|
||||||
|
assert_eq!(
|
||||||
|
percent_decode("https%3A%2F%2Fexample.com%2Fpage%3Fq%3Dtest"),
|
||||||
|
"https://example.com/page?q=test"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_percent_decode_no_encoding() {
|
||||||
|
assert_eq!(percent_decode("plain-text_123"), "plain-text_123");
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Input Validation Tests ---
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_research_query_validate_empty() {
|
||||||
|
let query = ResearchQuery {
|
||||||
|
query: " ".to_string(), engine: SearchEngine::Auto,
|
||||||
|
depth: ResearchDepth::Standard, max_results: 10,
|
||||||
|
include_related: false, time_limit_secs: 60,
|
||||||
|
};
|
||||||
|
assert!(query.validate().is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_research_query_validate_too_long() {
|
||||||
|
let query = ResearchQuery {
|
||||||
|
query: "x".repeat(501), engine: SearchEngine::Auto,
|
||||||
|
depth: ResearchDepth::Standard, max_results: 10,
|
||||||
|
include_related: false, time_limit_secs: 60,
|
||||||
|
};
|
||||||
|
assert!(query.validate().is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_research_query_validate_max_results_overflow() {
|
||||||
|
let query = ResearchQuery {
|
||||||
|
query: "test".to_string(), engine: SearchEngine::Auto,
|
||||||
|
depth: ResearchDepth::Standard, max_results: 999,
|
||||||
|
include_related: false, time_limit_secs: 60,
|
||||||
|
};
|
||||||
|
assert!(query.validate().is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_research_query_validate_ok() {
|
||||||
|
let query = ResearchQuery {
|
||||||
|
query: "Rust programming".to_string(), engine: SearchEngine::Auto,
|
||||||
|
depth: ResearchDepth::Standard, max_results: 10,
|
||||||
|
include_related: false, time_limit_secs: 60,
|
||||||
|
};
|
||||||
|
assert!(query.validate().is_ok());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user