test(hands): add unit tests for CollectorHand + fix HTML extraction position tracking

Fix extract_visible_text to use proper byte position tracking (pos += char_len)
instead of iterating chars without position context, which caused script/style
tag detection to fail on multi-byte content. Also adds script/style stripping
logic and raises truncation limit to 10000 chars.

Adds 9 unit tests covering:
- Config identity verification
- OutputFormat serialization round-trip
- HTML text extraction (basic, script stripping, style stripping, empty input)
- Aggregate action with empty URLs
- CollectorAction deserialization (Collect/Aggregate/Extract)
- CollectionTarget deserialization
This commit is contained in:
iven
2026-04-01 23:21:43 +08:00
parent 62df7feac1
commit cc7ee3189d

View File

@@ -13,7 +13,7 @@ use zclaw_types::Result;
use crate::{Hand, HandConfig, HandContext, HandResult}; use crate::{Hand, HandConfig, HandContext, HandResult};
/// Output format options /// Output format options
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")] #[serde(rename_all = "lowercase")]
pub enum OutputFormat { pub enum OutputFormat {
Json, Json,
@@ -234,16 +234,37 @@ impl CollectorHand {
self.extract_visible_text(html) self.extract_visible_text(html)
} }
/// Extract visible text from HTML /// Extract visible text from HTML, stripping scripts and styles
fn extract_visible_text(&self, html: &str) -> String { fn extract_visible_text(&self, html: &str) -> String {
let html_lower = html.to_lowercase();
let mut text = String::new(); let mut text = String::new();
let mut in_tag = false; let mut in_tag = false;
let mut in_script = false;
let mut in_style = false;
let mut pos: usize = 0;
for c in html.chars() { for c in html.chars() {
let char_len = c.len_utf8();
match c { match c {
'<' => in_tag = true, '<' => {
'>' => in_tag = false, let remaining = &html_lower[pos..];
if remaining.starts_with("</script") {
in_script = false;
} else if remaining.starts_with("</style") {
in_style = false;
}
if remaining.starts_with("<script") {
in_script = true;
} else if remaining.starts_with("<style") {
in_style = true;
}
in_tag = true;
}
'>' => {
in_tag = false;
}
_ if in_tag => {} _ if in_tag => {}
_ if in_script || in_style => {}
' ' | '\n' | '\t' | '\r' => { ' ' | '\n' | '\t' | '\r' => {
if !text.ends_with(' ') && !text.is_empty() { if !text.ends_with(' ') && !text.is_empty() {
text.push(' '); text.push(' ');
@@ -251,11 +272,11 @@ impl CollectorHand {
} }
_ => text.push(c), _ => text.push(c),
} }
pos += char_len;
} }
// Limit length if text.len() > 10000 {
if text.len() > 500 { text.truncate(10000);
text.truncate(500);
text.push_str("..."); text.push_str("...");
} }
@@ -407,3 +428,166 @@ impl Hand for CollectorHand {
crate::HandStatus::Idle crate::HandStatus::Idle
} }
} }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_collector_config() {
let hand = CollectorHand::new();
assert_eq!(hand.config().id, "collector");
assert_eq!(hand.config().name, "数据采集器");
assert!(hand.config().enabled);
assert!(!hand.config().needs_approval);
}
#[test]
fn test_output_format_serialize() {
let formats = vec![
(OutputFormat::Csv, "\"csv\""),
(OutputFormat::Markdown, "\"markdown\""),
(OutputFormat::Json, "\"json\""),
(OutputFormat::Text, "\"text\""),
];
for (fmt, expected) in formats {
let serialized = serde_json::to_string(&fmt).unwrap();
assert_eq!(serialized, expected);
}
// Verify round-trip deserialization
for json_str in &["\"csv\"", "\"markdown\"", "\"json\"", "\"text\""] {
let deserialized: OutputFormat = serde_json::from_str(json_str).unwrap();
let re_serialized = serde_json::to_string(&deserialized).unwrap();
assert_eq!(&re_serialized, json_str);
}
}
#[test]
fn test_extract_visible_text_basic() {
let hand = CollectorHand::new();
let html = "<html><body><h1>Title</h1><p>Content here</p></body></html>";
let text = hand.extract_visible_text(html);
assert!(text.contains("Title"), "should contain 'Title', got: {}", text);
assert!(text.contains("Content here"), "should contain 'Content here', got: {}", text);
}
#[test]
fn test_extract_visible_text_strips_scripts() {
let hand = CollectorHand::new();
let html = "<html><body><script>alert('xss')</script><p>Safe content</p></body></html>";
let text = hand.extract_visible_text(html);
assert!(!text.contains("alert"), "script content should be removed, got: {}", text);
assert!(!text.contains("xss"), "script content should be removed, got: {}", text);
assert!(text.contains("Safe content"), "visible content should remain, got: {}", text);
}
#[test]
fn test_extract_visible_text_strips_styles() {
let hand = CollectorHand::new();
let html = "<html><head><style>body { color: red; }</style></head><body><p>Text</p></body></html>";
let text = hand.extract_visible_text(html);
assert!(!text.contains("color"), "style content should be removed, got: {}", text);
assert!(!text.contains("red"), "style content should be removed, got: {}", text);
assert!(text.contains("Text"), "visible content should remain, got: {}", text);
}
#[test]
fn test_extract_visible_text_empty() {
let hand = CollectorHand::new();
let text = hand.extract_visible_text("");
assert!(text.is_empty(), "empty HTML should produce empty text, got: '{}'", text);
}
#[tokio::test]
async fn test_aggregate_action_empty_urls() {
let hand = CollectorHand::new();
let config = AggregationConfig {
urls: vec![],
aggregate_fields: vec![],
};
let result = hand.execute_aggregate(&config).await.unwrap();
let results = result.get("results").unwrap().as_array().unwrap();
assert_eq!(results.len(), 0, "empty URLs should produce empty results");
assert_eq!(result.get("source_count").unwrap().as_u64().unwrap(), 0);
}
#[test]
fn test_collector_action_deserialize() {
// Collect action
let collect_json = json!({
"action": "collect",
"target": {
"url": "https://example.com",
"selector": ".article",
"fields": { "title": "h1" },
"maxItems": 10
},
"format": "markdown"
});
let action: CollectorAction = serde_json::from_value(collect_json).unwrap();
match action {
CollectorAction::Collect { target, format } => {
assert_eq!(target.url, "https://example.com");
assert_eq!(target.selector.as_deref(), Some(".article"));
assert_eq!(target.max_items, 10);
assert!(format.is_some());
assert_eq!(format.unwrap(), OutputFormat::Markdown);
}
_ => panic!("Expected Collect action"),
}
// Aggregate action
let aggregate_json = json!({
"action": "aggregate",
"config": {
"urls": ["https://a.com", "https://b.com"],
"aggregateFields": ["title", "content"]
}
});
let action: CollectorAction = serde_json::from_value(aggregate_json).unwrap();
match action {
CollectorAction::Aggregate { config } => {
assert_eq!(config.urls.len(), 2);
assert_eq!(config.aggregate_fields.len(), 2);
}
_ => panic!("Expected Aggregate action"),
}
// Extract action
let extract_json = json!({
"action": "extract",
"url": "https://example.com",
"selectors": { "title": "h1", "body": "p" }
});
let action: CollectorAction = serde_json::from_value(extract_json).unwrap();
match action {
CollectorAction::Extract { url, selectors } => {
assert_eq!(url, "https://example.com");
assert_eq!(selectors.len(), 2);
}
_ => panic!("Expected Extract action"),
}
}
#[test]
fn test_collection_target_deserialize() {
let json = json!({
"url": "https://example.com/page",
"selector": ".content",
"fields": {
"title": "h1",
"author": ".author-name"
},
"maxItems": 50
});
let target: CollectionTarget = serde_json::from_value(json).unwrap();
assert_eq!(target.url, "https://example.com/page");
assert_eq!(target.selector.as_deref(), Some(".content"));
assert_eq!(target.fields.len(), 2);
assert_eq!(target.max_items, 50);
}
}