//! Web fetch tool with SSRF protection //! //! This module provides a secure web fetching capability with comprehensive //! SSRF (Server-Side Request Forgery) protection including: //! - Private IP range blocking (RFC 1918) //! - Cloud metadata endpoint blocking (169.254.169.254) //! - Localhost/loopback blocking //! - Redirect protection with recursive checks //! - Timeout control //! - Response size limits use async_trait::async_trait; use reqwest::redirect::Policy; use serde_json::{json, Value}; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; use std::time::Duration; use url::Url; use zclaw_types::{Result, ZclawError}; use crate::tool::{Tool, ToolContext}; /// Maximum response size in bytes (10 MB) const MAX_RESPONSE_SIZE: u64 = 10 * 1024 * 1024; /// Request timeout in seconds const REQUEST_TIMEOUT_SECS: u64 = 30; /// Maximum number of redirect hops allowed const MAX_REDIRECT_HOPS: usize = 5; /// Maximum URL length const MAX_URL_LENGTH: usize = 2048; pub struct WebFetchTool { client: reqwest::Client, } impl WebFetchTool { pub fn new() -> Self { // Build a client with redirect policy that we control // We'll handle redirects manually to validate each target let client = reqwest::Client::builder() .timeout(Duration::from_secs(REQUEST_TIMEOUT_SECS)) .redirect(Policy::none()) // Handle redirects manually for SSRF validation .user_agent("ZCLAW/1.0") .build() .unwrap_or_else(|_| reqwest::Client::new()); Self { client } } /// Validate a URL for SSRF safety /// /// This checks: /// - URL scheme (only http/https allowed) /// - Private IP ranges (RFC 1918) /// - Loopback addresses /// - Cloud metadata endpoints /// - Link-local addresses fn validate_url(&self, url_str: &str) -> Result { // Check URL length if url_str.len() > MAX_URL_LENGTH { return Err(ZclawError::InvalidInput(format!( "URL exceeds maximum length of {} characters", MAX_URL_LENGTH ))); } // Parse URL let url = Url::parse(url_str) .map_err(|e| ZclawError::InvalidInput(format!("Invalid URL: {}", e)))?; // Check scheme - only allow http and https match url.scheme() { "http" | "https" => {} scheme => { return Err(ZclawError::InvalidInput(format!( "URL scheme '{}' is not allowed. Only http and https are permitted.", scheme ))); } } // Extract host - for IPv6, url.host_str() returns the address without brackets // But url::Url also provides host() which gives us the parsed Host type let host = url .host_str() .ok_or_else(|| ZclawError::InvalidInput("URL must have a host".into()))?; // Check if host is an IP address or domain // For IPv6 in URLs, host_str returns the address with brackets, e.g., "[::1]" // We need to strip the brackets for parsing let host_for_parsing = if host.starts_with('[') && host.ends_with(']') { &host[1..host.len()-1] } else { host }; if let Ok(ip) = host_for_parsing.parse::() { self.validate_ip_address(&ip)?; } else { // For domain names, we need to resolve and check the IP // This is handled during the actual request, but we do basic checks here self.validate_hostname(host)?; } Ok(url) } /// Validate an IP address for SSRF safety fn validate_ip_address(&self, ip: &IpAddr) -> Result<()> { match ip { IpAddr::V4(ipv4) => self.validate_ipv4(ipv4)?, IpAddr::V6(ipv6) => self.validate_ipv6(ipv6)?, } Ok(()) } /// Validate IPv4 address fn validate_ipv4(&self, ip: &Ipv4Addr) -> Result<()> { let octets = ip.octets(); // Block loopback (127.0.0.0/8) if octets[0] == 127 { return Err(ZclawError::InvalidInput( "Access to loopback addresses (127.x.x.x) is not allowed".into(), )); } // Block private ranges (RFC 1918) // 10.0.0.0/8 if octets[0] == 10 { return Err(ZclawError::InvalidInput( "Access to private IP range 10.x.x.x is not allowed".into(), )); } // 172.16.0.0/12 (172.16.0.0 - 172.31.255.255) if octets[0] == 172 && (16..=31).contains(&octets[1]) { return Err(ZclawError::InvalidInput( "Access to private IP range 172.16-31.x.x is not allowed".into(), )); } // 192.168.0.0/16 if octets[0] == 192 && octets[1] == 168 { return Err(ZclawError::InvalidInput( "Access to private IP range 192.168.x.x is not allowed".into(), )); } // Block cloud metadata endpoint (169.254.169.254) if octets[0] == 169 && octets[1] == 254 && octets[2] == 169 && octets[3] == 254 { return Err(ZclawError::InvalidInput( "Access to cloud metadata endpoint (169.254.169.254) is not allowed".into(), )); } // Block link-local addresses (169.254.0.0/16) if octets[0] == 169 && octets[1] == 254 { return Err(ZclawError::InvalidInput( "Access to link-local addresses (169.254.x.x) is not allowed".into(), )); } // Block 0.0.0.0/8 (current network) if octets[0] == 0 { return Err(ZclawError::InvalidInput( "Access to 0.x.x.x addresses is not allowed".into(), )); } // Block broadcast address if *ip == Ipv4Addr::new(255, 255, 255, 255) { return Err(ZclawError::InvalidInput( "Access to broadcast address is not allowed".into(), )); } // Block multicast addresses (224.0.0.0/4) if (224..=239).contains(&octets[0]) { return Err(ZclawError::InvalidInput( "Access to multicast addresses is not allowed".into(), )); } Ok(()) } /// Validate IPv6 address fn validate_ipv6(&self, ip: &Ipv6Addr) -> Result<()> { // Block loopback (::1) if *ip == Ipv6Addr::LOCALHOST { return Err(ZclawError::InvalidInput( "Access to IPv6 loopback address (::1) is not allowed".into(), )); } // Block unspecified address (::) if *ip == Ipv6Addr::UNSPECIFIED { return Err(ZclawError::InvalidInput( "Access to unspecified IPv6 address (::) is not allowed".into(), )); } // Block IPv4-mapped IPv6 addresses (::ffff:0:0/96) // These could bypass IPv4 checks if ip.to_string().starts_with("::ffff:") { // Extract the embedded IPv4 and validate it let segments = ip.segments(); // IPv4-mapped format: 0:0:0:0:0:ffff:xxxx:xxxx if segments[5] == 0xffff { let v4_addr = ((segments[6] as u32) << 16) | (segments[7] as u32); let ipv4 = Ipv4Addr::from(v4_addr); self.validate_ipv4(&ipv4)?; } } // Block link-local IPv6 (fe80::/10) let segments = ip.segments(); if (segments[0] & 0xffc0) == 0xfe80 { return Err(ZclawError::InvalidInput( "Access to IPv6 link-local addresses is not allowed".into(), )); } // Block unique local addresses (fc00::/7) - IPv6 equivalent of private ranges if (segments[0] & 0xfe00) == 0xfc00 { return Err(ZclawError::InvalidInput( "Access to IPv6 unique local addresses is not allowed".into(), )); } Ok(()) } /// Validate a hostname for potential SSRF attacks fn validate_hostname(&self, host: &str) -> Result<()> { let host_lower = host.to_lowercase(); // Block localhost variants let blocked_hosts = [ "localhost", "localhost.localdomain", "ip6-localhost", "ip6-loopback", "metadata.google.internal", "metadata", "kubernetes.default", "kubernetes.default.svc", ]; for blocked in &blocked_hosts { if host_lower == *blocked || host_lower.ends_with(&format!(".{}", blocked)) { return Err(ZclawError::InvalidInput(format!( "Access to '{}' is not allowed", host ))); } } // Block hostnames that look like IP addresses (decimal, octal, hex encoding) // These could be used to bypass IP checks self.check_hostname_ip_bypass(&host_lower)?; Ok(()) } /// Check for hostname-based IP bypass attempts fn check_hostname_ip_bypass(&self, host: &str) -> Result<()> { // Check for decimal IP encoding (e.g., 2130706433 = 127.0.0.1) if host.chars().all(|c| c.is_ascii_digit()) { if let Ok(num) = host.parse::() { let ip = Ipv4Addr::from(num); self.validate_ipv4(&ip)?; } } // Check for domains that might resolve to private IPs // This is not exhaustive but catches common patterns // The actual DNS resolution check happens during the request Ok(()) } /// Follow redirects with SSRF validation async fn follow_redirects_safe(&self, url: Url, max_hops: usize) -> Result<(Url, reqwest::Response)> { let mut current_url = url; let mut hops = 0; loop { // Validate the current URL current_url = self.validate_url(current_url.as_str())?; // Make the request let response = self .client .get(current_url.clone()) .send() .await .map_err(|e| ZclawError::ToolError(format!("Request failed: {}", e)))?; // Check if it's a redirect let status = response.status(); if status.is_redirection() { hops += 1; if hops > max_hops { return Err(ZclawError::InvalidInput(format!( "Too many redirects (max {})", max_hops ))); } // Get the redirect location let location = response .headers() .get(reqwest::header::LOCATION) .and_then(|h| h.to_str().ok()) .ok_or_else(|| { ZclawError::ToolError("Redirect without Location header".into()) })?; // Resolve the location against the current URL let new_url = current_url.join(location).map_err(|e| { ZclawError::InvalidInput(format!("Invalid redirect location: {}", e)) })?; tracing::debug!( "Following redirect {} -> {}", current_url.as_str(), new_url.as_str() ); current_url = new_url; // Continue loop to validate and follow } else { // Not a redirect, return the response return Ok((current_url, response)); } } } } #[async_trait] impl Tool for WebFetchTool { fn name(&self) -> &str { "web_fetch" } fn description(&self) -> &str { "Fetch content from a URL with SSRF protection" } fn input_schema(&self) -> Value { json!({ "type": "object", "properties": { "url": { "type": "string", "description": "The URL to fetch (must be http or https)" }, "method": { "type": "string", "enum": ["GET", "POST"], "description": "HTTP method (default: GET)" }, "headers": { "type": "object", "description": "Optional HTTP headers (key-value pairs)", "additionalProperties": { "type": "string" } }, "body": { "type": "string", "description": "Request body for POST requests" }, "timeout": { "type": "integer", "description": "Timeout in seconds (default: 30, max: 60)", "minimum": 1, "maximum": 60 } }, "required": ["url"] }) } async fn execute(&self, input: Value, _context: &ToolContext) -> Result { let url_str = input["url"] .as_str() .ok_or_else(|| ZclawError::InvalidInput("Missing 'url' parameter".into()))?; let method = input["method"].as_str().unwrap_or("GET").to_uppercase(); let timeout_secs = input["timeout"].as_u64().unwrap_or(REQUEST_TIMEOUT_SECS).min(60); // Validate URL for SSRF let url = self.validate_url(url_str)?; tracing::info!("WebFetch: Fetching {} with method {}", url.as_str(), method); // Build request with validated URL let mut request_builder = match method.as_str() { "GET" => self.client.get(url.clone()), "POST" => { let mut builder = self.client.post(url.clone()); if let Some(body) = input["body"].as_str() { builder = builder.body(body.to_string()); } builder } _ => { return Err(ZclawError::InvalidInput(format!( "Unsupported HTTP method: {}", method ))); } }; // Add custom headers if provided if let Some(headers) = input["headers"].as_object() { for (key, value) in headers { if let Some(value_str) = value.as_str() { // Block dangerous headers let key_lower = key.to_lowercase(); if key_lower == "host" { continue; // Don't allow overriding host } if key_lower.starts_with("x-forwarded") { continue; // Block proxy header injection } let header_name = reqwest::header::HeaderName::try_from(key.as_str()) .map_err(|e| { ZclawError::InvalidInput(format!("Invalid header name '{}': {}", key, e)) })?; let header_value = reqwest::header::HeaderValue::from_str(value_str) .map_err(|e| { ZclawError::InvalidInput(format!("Invalid header value: {}", e)) })?; request_builder = request_builder.header(header_name, header_value); } } } // Set timeout let request_builder = request_builder.timeout(Duration::from_secs(timeout_secs)); // Execute with redirect handling let response = request_builder .send() .await .map_err(|e| { let error_msg = e.to_string(); // Provide user-friendly error messages if error_msg.contains("dns") || error_msg.contains("resolve") { ZclawError::ToolError(format!( "Failed to resolve hostname: {}. Please check the URL.", url.host_str().unwrap_or("unknown") )) } else if error_msg.contains("timeout") { ZclawError::ToolError(format!( "Request timed out after {} seconds", timeout_secs )) } else if error_msg.contains("connection refused") { ZclawError::ToolError( "Connection refused. The server may be down or unreachable.".into(), ) } else { ZclawError::ToolError(format!("Request failed: {}", error_msg)) } })?; // Handle redirects manually with SSRF validation let (final_url, response) = if response.status().is_redirection() { // Start redirect following process let location = response .headers() .get(reqwest::header::LOCATION) .and_then(|h| h.to_str().ok()) .ok_or_else(|| { ZclawError::ToolError("Redirect without Location header".into()) })?; let redirect_url = url.join(location).map_err(|e| { ZclawError::InvalidInput(format!("Invalid redirect location: {}", e)) })?; self.follow_redirects_safe(redirect_url, MAX_REDIRECT_HOPS).await? } else { (url, response) }; // Check response status let status = response.status(); let status_code = status.as_u16(); // Check content length before reading body if let Some(content_length) = response.content_length() { if content_length > MAX_RESPONSE_SIZE { return Err(ZclawError::ToolError(format!( "Response too large: {} bytes (max: {} bytes)", content_length, MAX_RESPONSE_SIZE ))); } } // Get content type BEFORE consuming response with bytes() let content_type = response .headers() .get(reqwest::header::CONTENT_TYPE) .and_then(|h| h.to_str().ok()) .unwrap_or("text/plain") .to_string(); // Read response body with size limit let bytes = response.bytes().await.map_err(|e| { ZclawError::ToolError(format!("Failed to read response body: {}", e)) })?; // Double-check size after reading if bytes.len() as u64 > MAX_RESPONSE_SIZE { return Err(ZclawError::ToolError(format!( "Response too large: {} bytes (max: {} bytes)", bytes.len(), MAX_RESPONSE_SIZE ))); } // Try to decode as UTF-8, fall back to base64 for binary let content = String::from_utf8(bytes.to_vec()).unwrap_or_else(|_| { use base64::Engine; base64::engine::general_purpose::STANDARD.encode(&bytes) }); tracing::info!( "WebFetch: Successfully fetched {} bytes from {} (status: {})", content.len(), final_url.as_str(), status_code ); Ok(json!({ "status": status_code, "url": final_url.as_str(), "content_type": content_type, "content": content, "size": content.len() })) } } impl Default for WebFetchTool { fn default() -> Self { Self::new() } } #[cfg(test)] mod tests { use super::*; #[test] fn test_validate_localhost() { let tool = WebFetchTool::new(); // Test localhost assert!(tool.validate_url("http://localhost/test").is_err()); assert!(tool.validate_url("http://127.0.0.1/test").is_err()); assert!(tool.validate_url("http://127.0.0.2/test").is_err()); } #[test] fn test_validate_private_ips() { let tool = WebFetchTool::new(); // Test 10.x.x.x assert!(tool.validate_url("http://10.0.0.1/test").is_err()); assert!(tool.validate_url("http://10.255.255.255/test").is_err()); // Test 172.16-31.x.x assert!(tool.validate_url("http://172.16.0.1/test").is_err()); assert!(tool.validate_url("http://172.31.255.255/test").is_err()); // 172.15.x.x should be allowed assert!(tool.validate_url("http://172.15.0.1/test").is_ok()); // Test 192.168.x.x assert!(tool.validate_url("http://192.168.0.1/test").is_err()); assert!(tool.validate_url("http://192.168.255.255/test").is_err()); } #[test] fn test_validate_cloud_metadata() { let tool = WebFetchTool::new(); // Test cloud metadata endpoint assert!(tool.validate_url("http://169.254.169.254/metadata").is_err()); } #[test] fn test_validate_ipv6() { let tool = WebFetchTool::new(); // Test IPv6 loopback assert!(tool.validate_url("http://[::1]/test").is_err()); // Test IPv6 unspecified assert!(tool.validate_url("http://[::]/test").is_err()); // Test IPv4-mapped loopback assert!(tool.validate_url("http://[::ffff:127.0.0.1]/test").is_err()); } #[test] fn test_validate_scheme() { let tool = WebFetchTool::new(); // Only http and https allowed assert!(tool.validate_url("ftp://example.com/test").is_err()); assert!(tool.validate_url("file:///etc/passwd").is_err()); assert!(tool.validate_url("javascript:alert(1)").is_err()); // http and https should be allowed (URL parsing succeeds) assert!(tool.validate_url("http://example.com/test").is_ok()); assert!(tool.validate_url("https://example.com/test").is_ok()); } #[test] fn test_validate_blocked_hostnames() { let tool = WebFetchTool::new(); assert!(tool.validate_url("http://localhost/test").is_err()); assert!(tool.validate_url("http://metadata.google.internal/test").is_err()); assert!(tool.validate_url("http://kubernetes.default/test").is_err()); } #[test] fn test_validate_url_length() { let tool = WebFetchTool::new(); // Create a URL that's too long let long_url = format!("http://example.com/{}", "a".repeat(3000)); assert!(tool.validate_url(&long_url).is_err()); } }