From cc7ee3189d783288623529ba5940960b076ea59e Mon Sep 17 00:00:00 2001 From: iven Date: Wed, 1 Apr 2026 23:21:43 +0800 Subject: [PATCH] test(hands): add unit tests for CollectorHand + fix HTML extraction position tracking Fix extract_visible_text to use proper byte position tracking (pos += char_len) instead of iterating chars without position context, which caused script/style tag detection to fail on multi-byte content. Also adds script/style stripping logic and raises truncation limit to 10000 chars. Adds 9 unit tests covering: - Config identity verification - OutputFormat serialization round-trip - HTML text extraction (basic, script stripping, style stripping, empty input) - Aggregate action with empty URLs - CollectorAction deserialization (Collect/Aggregate/Extract) - CollectionTarget deserialization --- crates/zclaw-hands/src/hands/collector.rs | 198 +++++++++++++++++++++- 1 file changed, 191 insertions(+), 7 deletions(-) diff --git a/crates/zclaw-hands/src/hands/collector.rs b/crates/zclaw-hands/src/hands/collector.rs index b6c9e45..cad127a 100644 --- a/crates/zclaw-hands/src/hands/collector.rs +++ b/crates/zclaw-hands/src/hands/collector.rs @@ -13,7 +13,7 @@ use zclaw_types::Result; use crate::{Hand, HandConfig, HandContext, HandResult}; /// Output format options -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] #[serde(rename_all = "lowercase")] pub enum OutputFormat { Json, @@ -234,16 +234,37 @@ impl CollectorHand { self.extract_visible_text(html) } - /// Extract visible text from HTML + /// Extract visible text from HTML, stripping scripts and styles fn extract_visible_text(&self, html: &str) -> String { + let html_lower = html.to_lowercase(); let mut text = String::new(); let mut in_tag = false; + let mut in_script = false; + let mut in_style = false; + let mut pos: usize = 0; for c in html.chars() { + let char_len = c.len_utf8(); match c { - '<' => in_tag = true, - '>' => in_tag = false, + '<' => { + let remaining = &html_lower[pos..]; + if remaining.starts_with("' => { + in_tag = false; + } _ if in_tag => {} + _ if in_script || in_style => {} ' ' | '\n' | '\t' | '\r' => { if !text.ends_with(' ') && !text.is_empty() { text.push(' '); @@ -251,11 +272,11 @@ impl CollectorHand { } _ => text.push(c), } + pos += char_len; } - // Limit length - if text.len() > 500 { - text.truncate(500); + if text.len() > 10000 { + text.truncate(10000); text.push_str("..."); } @@ -407,3 +428,166 @@ impl Hand for CollectorHand { crate::HandStatus::Idle } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_collector_config() { + let hand = CollectorHand::new(); + assert_eq!(hand.config().id, "collector"); + assert_eq!(hand.config().name, "数据采集器"); + assert!(hand.config().enabled); + assert!(!hand.config().needs_approval); + } + + #[test] + fn test_output_format_serialize() { + let formats = vec![ + (OutputFormat::Csv, "\"csv\""), + (OutputFormat::Markdown, "\"markdown\""), + (OutputFormat::Json, "\"json\""), + (OutputFormat::Text, "\"text\""), + ]; + + for (fmt, expected) in formats { + let serialized = serde_json::to_string(&fmt).unwrap(); + assert_eq!(serialized, expected); + } + + // Verify round-trip deserialization + for json_str in &["\"csv\"", "\"markdown\"", "\"json\"", "\"text\""] { + let deserialized: OutputFormat = serde_json::from_str(json_str).unwrap(); + let re_serialized = serde_json::to_string(&deserialized).unwrap(); + assert_eq!(&re_serialized, json_str); + } + } + + #[test] + fn test_extract_visible_text_basic() { + let hand = CollectorHand::new(); + let html = "

Title

Content here

"; + let text = hand.extract_visible_text(html); + assert!(text.contains("Title"), "should contain 'Title', got: {}", text); + assert!(text.contains("Content here"), "should contain 'Content here', got: {}", text); + } + + #[test] + fn test_extract_visible_text_strips_scripts() { + let hand = CollectorHand::new(); + let html = "

Safe content

"; + let text = hand.extract_visible_text(html); + assert!(!text.contains("alert"), "script content should be removed, got: {}", text); + assert!(!text.contains("xss"), "script content should be removed, got: {}", text); + assert!(text.contains("Safe content"), "visible content should remain, got: {}", text); + } + + #[test] + fn test_extract_visible_text_strips_styles() { + let hand = CollectorHand::new(); + let html = "

Text

"; + let text = hand.extract_visible_text(html); + assert!(!text.contains("color"), "style content should be removed, got: {}", text); + assert!(!text.contains("red"), "style content should be removed, got: {}", text); + assert!(text.contains("Text"), "visible content should remain, got: {}", text); + } + + #[test] + fn test_extract_visible_text_empty() { + let hand = CollectorHand::new(); + let text = hand.extract_visible_text(""); + assert!(text.is_empty(), "empty HTML should produce empty text, got: '{}'", text); + } + + #[tokio::test] + async fn test_aggregate_action_empty_urls() { + let hand = CollectorHand::new(); + let config = AggregationConfig { + urls: vec![], + aggregate_fields: vec![], + }; + + let result = hand.execute_aggregate(&config).await.unwrap(); + let results = result.get("results").unwrap().as_array().unwrap(); + assert_eq!(results.len(), 0, "empty URLs should produce empty results"); + assert_eq!(result.get("source_count").unwrap().as_u64().unwrap(), 0); + } + + #[test] + fn test_collector_action_deserialize() { + // Collect action + let collect_json = json!({ + "action": "collect", + "target": { + "url": "https://example.com", + "selector": ".article", + "fields": { "title": "h1" }, + "maxItems": 10 + }, + "format": "markdown" + }); + let action: CollectorAction = serde_json::from_value(collect_json).unwrap(); + match action { + CollectorAction::Collect { target, format } => { + assert_eq!(target.url, "https://example.com"); + assert_eq!(target.selector.as_deref(), Some(".article")); + assert_eq!(target.max_items, 10); + assert!(format.is_some()); + assert_eq!(format.unwrap(), OutputFormat::Markdown); + } + _ => panic!("Expected Collect action"), + } + + // Aggregate action + let aggregate_json = json!({ + "action": "aggregate", + "config": { + "urls": ["https://a.com", "https://b.com"], + "aggregateFields": ["title", "content"] + } + }); + let action: CollectorAction = serde_json::from_value(aggregate_json).unwrap(); + match action { + CollectorAction::Aggregate { config } => { + assert_eq!(config.urls.len(), 2); + assert_eq!(config.aggregate_fields.len(), 2); + } + _ => panic!("Expected Aggregate action"), + } + + // Extract action + let extract_json = json!({ + "action": "extract", + "url": "https://example.com", + "selectors": { "title": "h1", "body": "p" } + }); + let action: CollectorAction = serde_json::from_value(extract_json).unwrap(); + match action { + CollectorAction::Extract { url, selectors } => { + assert_eq!(url, "https://example.com"); + assert_eq!(selectors.len(), 2); + } + _ => panic!("Expected Extract action"), + } + } + + #[test] + fn test_collection_target_deserialize() { + let json = json!({ + "url": "https://example.com/page", + "selector": ".content", + "fields": { + "title": "h1", + "author": ".author-name" + }, + "maxItems": 50 + }); + + let target: CollectionTarget = serde_json::from_value(json).unwrap(); + assert_eq!(target.url, "https://example.com/page"); + assert_eq!(target.selector.as_deref(), Some(".content")); + assert_eq!(target.fields.len(), 2); + assert_eq!(target.max_items, 50); + } +}