Files
zclaw_openfang/crates/zclaw-hands/src/hands/collector.rs
iven cc7ee3189d test(hands): add unit tests for CollectorHand + fix HTML extraction position tracking
Fix extract_visible_text to use proper byte position tracking (pos += char_len)
instead of iterating chars without position context, which caused script/style
tag detection to fail on multi-byte content. Also adds script/style stripping
logic and raises truncation limit to 10000 chars.

Adds 9 unit tests covering:
- Config identity verification
- OutputFormat serialization round-trip
- HTML text extraction (basic, script stripping, style stripping, empty input)
- Aggregate action with empty URLs
- CollectorAction deserialization (Collect/Aggregate/Extract)
- CollectionTarget deserialization
2026-04-01 23:21:43 +08:00

594 lines
20 KiB
Rust

//! Collector Hand - Data collection and aggregation capabilities
//!
//! This hand provides web scraping, data extraction, and aggregation features.
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use zclaw_types::Result;
use crate::{Hand, HandConfig, HandContext, HandResult};
/// Output format options
///
/// Serialized in lowercase via serde (e.g. `Json` <-> `"json"`).
/// `Json` is the default format.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum OutputFormat {
    /// Structured JSON output (the default).
    #[default]
    Json,
    /// Comma-separated values.
    Csv,
    /// Markdown-formatted text.
    Markdown,
    /// Plain text.
    Text,
}
/// Collection target configuration
///
/// Deserialized from camelCase JSON (e.g. `maxItems`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CollectionTarget {
    /// URL to collect from (the only required field)
    pub url: String,
    /// CSS selector for items; `None` when omitted from the input
    #[serde(default)]
    pub selector: Option<String>,
    /// Fields to extract: field name -> selector/pattern string
    #[serde(default)]
    pub fields: HashMap<String, String>,
    /// Maximum items to collect (defaults to 100 via `default_max_items`)
    #[serde(default = "default_max_items")]
    pub max_items: usize,
}
/// Serde default for [`CollectionTarget::max_items`]: collect at most 100 items.
fn default_max_items() -> usize {
    100
}
/// Collected item
///
/// One extracted record, tagged with its origin and collection time.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CollectedItem {
    /// Source URL the data was fetched from
    pub source_url: String,
    /// Collected data: field name -> extracted value
    pub data: HashMap<String, Value>,
    /// Collection timestamp (RFC 3339, UTC)
    pub collected_at: String,
}
/// Collection result
///
/// Summary of a single `collect` run against one target URL.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CollectionResult {
    /// Target URL
    pub url: String,
    /// Collected items
    pub items: Vec<CollectedItem>,
    /// Total items collected (length of `items`)
    pub total_items: usize,
    /// Output format requested for this collection
    pub format: OutputFormat,
    /// Collection timestamp (RFC 3339, UTC)
    pub collected_at: String,
    /// Wall-clock duration of the collection in milliseconds
    pub duration_ms: u64,
}
/// Aggregation configuration
///
/// Deserialized from camelCase JSON (e.g. `aggregateFields`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct AggregationConfig {
    /// URLs to aggregate (required; only the first 10 are fetched)
    pub urls: Vec<String>,
    /// Fields to aggregate from each page; empty means "visible text only"
    #[serde(default)]
    pub aggregate_fields: Vec<String>,
}
/// Collector action types
///
/// Internally tagged on the `"action"` JSON field: `"collect"`,
/// `"aggregate"`, or `"extract"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "action")]
pub enum CollectorAction {
    /// Collect data from a single target URL, optionally in a given format.
    #[serde(rename = "collect")]
    Collect { target: CollectionTarget, format: Option<OutputFormat> },
    /// Aggregate fields across multiple URLs.
    #[serde(rename = "aggregate")]
    Aggregate { config: AggregationConfig },
    /// Extract named fields from one URL using selector patterns.
    #[serde(rename = "extract")]
    Extract { url: String, selectors: HashMap<String, String> },
}
/// Collector Hand implementation
pub struct CollectorHand {
    /// Static hand configuration (id, name, input schema, tags).
    config: HandConfig,
    /// Shared HTTP client (30s timeout, custom user agent; see `new`).
    client: reqwest::Client,
    /// URL -> raw HTML cache; never evicted for the hand's lifetime.
    cache: Arc<RwLock<HashMap<String, String>>>,
}
impl CollectorHand {
    /// Create a new collector hand
    ///
    /// Builds the static hand config (id `"collector"`, a JSON Schema
    /// advertising the three supported actions), an HTTP client with a
    /// 30-second timeout and a `ZCLAW-Collector/1.0` user agent, and an
    /// empty page cache.
    pub fn new() -> Self {
        Self {
            config: HandConfig {
                id: "collector".to_string(),
                name: "数据采集器".to_string(),
                description: "从网页源收集和聚合数据".to_string(),
                needs_approval: false,
                dependencies: vec!["network".to_string()],
                // JSON Schema mirroring the CollectorAction enum: one `oneOf`
                // branch per action tag.
                input_schema: Some(serde_json::json!({
                    "type": "object",
                    "oneOf": [
                        {
                            "properties": {
                                "action": { "const": "collect" },
                                "target": {
                                    "type": "object",
                                    "properties": {
                                        "url": { "type": "string" },
                                        "selector": { "type": "string" },
                                        "fields": { "type": "object" },
                                        "maxItems": { "type": "integer" }
                                    },
                                    "required": ["url"]
                                },
                                "format": { "type": "string", "enum": ["json", "csv", "markdown", "text"] }
                            },
                            "required": ["action", "target"]
                        },
                        {
                            "properties": {
                                "action": { "const": "extract" },
                                "url": { "type": "string" },
                                "selectors": { "type": "object" }
                            },
                            "required": ["action", "url", "selectors"]
                        },
                        {
                            "properties": {
                                "action": { "const": "aggregate" },
                                "config": {
                                    "type": "object",
                                    "properties": {
                                        "urls": { "type": "array", "items": { "type": "string" } },
                                        "aggregateFields": { "type": "array", "items": { "type": "string" } }
                                    },
                                    "required": ["urls"]
                                }
                            },
                            "required": ["action", "config"]
                        }
                    ]
                })),
                tags: vec!["data".to_string(), "collection".to_string(), "scraping".to_string()],
                enabled: true,
            },
            // Fall back to a default client if the builder fails (it should
            // not for these options), so construction itself cannot fail.
            client: reqwest::Client::builder()
                .timeout(std::time::Duration::from_secs(30))
                .user_agent("ZCLAW-Collector/1.0")
                .build()
                .unwrap_or_else(|_| reqwest::Client::new()),
            cache: Arc::new(RwLock::new(HashMap::new())),
        }
    }
    /// Fetch a page, serving repeated requests for the same URL from the
    /// in-memory cache.
    ///
    /// # Errors
    /// `HandError` when the request fails or the body cannot be read.
    ///
    /// NOTE(review): the cache is never evicted and entries never expire, so
    /// memory use grows with the number of distinct URLs fetched.
    /// NOTE(review): HTTP error statuses (e.g. 404) are not treated as
    /// failures — the error page body is returned and cached.
    async fn fetch_page(&self, url: &str) -> Result<String> {
        // Check cache
        // (read lock is dropped at the end of this scope, before the fetch)
        {
            let cache = self.cache.read().await;
            if let Some(cached) = cache.get(url) {
                return Ok(cached.clone());
            }
        }
        let response = self.client
            .get(url)
            .send()
            .await
            .map_err(|e| zclaw_types::ZclawError::HandError(format!("Request failed: {}", e)))?;
        let html = response.text().await
            .map_err(|e| zclaw_types::ZclawError::HandError(format!("Failed to read response: {}", e)))?;
        // Cache the result (write lock held only for the insert)
        {
            let mut cache = self.cache.write().await;
            cache.insert(url.to_string(), html.clone());
        }
        Ok(html)
    }
/// Extract text by simple pattern matching
fn extract_by_pattern(&self, html: &str, pattern: &str) -> String {
// Simple implementation: find text between tags
if pattern.contains("title") || pattern.contains("h1") {
if let Some(start) = html.find("<title>") {
if let Some(end) = html[start..].find("</title>") {
return html[start + 7..start + end]
.replace("&amp;", "&")
.replace("&lt;", "<")
.replace("&gt;", ">")
.trim()
.to_string();
}
}
}
// Extract meta description
if pattern.contains("description") || pattern.contains("meta") {
if let Some(start) = html.find("name=\"description\"") {
let rest = &html[start..];
if let Some(content_start) = rest.find("content=\"") {
let content = &rest[content_start + 9..];
if let Some(end) = content.find('"') {
return content[..end].trim().to_string();
}
}
}
}
// Default: extract visible text
self.extract_visible_text(html)
}
/// Extract visible text from HTML, stripping scripts and styles
fn extract_visible_text(&self, html: &str) -> String {
let html_lower = html.to_lowercase();
let mut text = String::new();
let mut in_tag = false;
let mut in_script = false;
let mut in_style = false;
let mut pos: usize = 0;
for c in html.chars() {
let char_len = c.len_utf8();
match c {
'<' => {
let remaining = &html_lower[pos..];
if remaining.starts_with("</script") {
in_script = false;
} else if remaining.starts_with("</style") {
in_style = false;
}
if remaining.starts_with("<script") {
in_script = true;
} else if remaining.starts_with("<style") {
in_style = true;
}
in_tag = true;
}
'>' => {
in_tag = false;
}
_ if in_tag => {}
_ if in_script || in_style => {}
' ' | '\n' | '\t' | '\r' => {
if !text.ends_with(' ') && !text.is_empty() {
text.push(' ');
}
}
_ => text.push(c),
}
pos += char_len;
}
if text.len() > 10000 {
text.truncate(10000);
text.push_str("...");
}
text.trim().to_string()
}
    /// Execute collection
    ///
    /// Fetches the target URL and extracts the configured fields (or, when no
    /// fields are configured, the page title plus the visible text).
    ///
    /// NOTE(review): `target.selector` and `target.max_items` are currently
    /// ignored — exactly one `CollectedItem` is produced per call.
    async fn execute_collect(&self, target: &CollectionTarget, format: OutputFormat) -> Result<CollectionResult> {
        let start = std::time::Instant::now();
        let html = self.fetch_page(&target.url).await?;
        let mut items = Vec::new();
        let mut data = HashMap::new();
        // Extract fields
        for (field_name, selector) in &target.fields {
            let value = self.extract_by_pattern(&html, selector);
            data.insert(field_name.clone(), Value::String(value));
        }
        // If no fields specified, extract basic info
        if data.is_empty() {
            data.insert("title".to_string(), Value::String(self.extract_by_pattern(&html, "title")));
            data.insert("content".to_string(), Value::String(self.extract_visible_text(&html)));
        }
        items.push(CollectedItem {
            source_url: target.url.clone(),
            data,
            collected_at: chrono::Utc::now().to_rfc3339(),
        });
        Ok(CollectionResult {
            url: target.url.clone(),
            total_items: items.len(),
            items,
            format,
            collected_at: chrono::Utc::now().to_rfc3339(),
            duration_ms: start.elapsed().as_millis() as u64,
        })
    }
    /// Execute aggregation
    ///
    /// Fetches at most the first 10 configured URLs sequentially and extracts
    /// each aggregate field (or the visible text when no fields are given).
    /// Failed fetches are logged and skipped rather than aborting the run.
    ///
    /// NOTE(review): `source_count` reports ALL configured URLs, not the
    /// number successfully fetched (and not the take(10) cap).
    async fn execute_aggregate(&self, config: &AggregationConfig) -> Result<Value> {
        let start = std::time::Instant::now();
        let mut results = Vec::new();
        // Hard cap of 10 URLs per aggregation; fetched one at a time.
        for url in config.urls.iter().take(10) {
            match self.fetch_page(url).await {
                Ok(html) => {
                    let mut data = HashMap::new();
                    for field in &config.aggregate_fields {
                        let value = self.extract_by_pattern(&html, field);
                        data.insert(field.clone(), Value::String(value));
                    }
                    // No fields configured: fall back to the page's visible text.
                    if data.is_empty() {
                        data.insert("content".to_string(), Value::String(self.extract_visible_text(&html)));
                    }
                    results.push(data);
                }
                Err(e) => {
                    // Best-effort: a failed URL is skipped, not fatal.
                    tracing::warn!(target: "collector", url = url, error = %e, "Failed to fetch");
                }
            }
        }
        Ok(json!({
            "results": results,
            "source_count": config.urls.len(),
            "duration_ms": start.elapsed().as_millis()
        }))
    }
/// Execute extraction
async fn execute_extract(&self, url: &str, selectors: &HashMap<String, String>) -> Result<HashMap<String, String>> {
let html = self.fetch_page(url).await?;
let mut results = HashMap::new();
for (field_name, selector) in selectors {
let value = self.extract_by_pattern(&html, selector);
results.insert(field_name.clone(), value);
}
Ok(results)
}
}
impl Default for CollectorHand {
    /// Delegates to [`CollectorHand::new`].
    fn default() -> Self {
        Self::new()
    }
}
#[async_trait]
impl Hand for CollectorHand {
    /// The hand's static configuration.
    fn config(&self) -> &HandConfig {
        &self.config
    }

    /// Parse the JSON `input` into a [`CollectorAction`] and dispatch it.
    ///
    /// Each arm builds a summary JSON payload containing the action name,
    /// elapsed time, and the action-specific result.
    ///
    /// # Errors
    /// `HandError` when `input` is not a valid action, or when the underlying
    /// fetch/collection fails.
    async fn execute(&self, _context: &HandContext, input: Value) -> Result<HandResult> {
        // NOTE(review): `input.clone()` looks unnecessary — `input` is not
        // used again after this point.
        let action: CollectorAction = serde_json::from_value(input.clone())
            .map_err(|e| zclaw_types::ZclawError::HandError(format!("Invalid action: {}", e)))?;
        let start = std::time::Instant::now();
        let result = match action {
            CollectorAction::Collect { target, format } => {
                // Default to JSON output when no format was requested.
                let fmt = format.unwrap_or(OutputFormat::Json);
                let collection = self.execute_collect(&target, fmt.clone()).await?;
                json!({
                    "action": "collect",
                    "url": target.url,
                    "total_items": collection.total_items,
                    "duration_ms": start.elapsed().as_millis(),
                    "items": collection.items
                })
            }
            CollectorAction::Aggregate { config } => {
                let aggregation = self.execute_aggregate(&config).await?;
                json!({
                    "action": "aggregate",
                    "duration_ms": start.elapsed().as_millis(),
                    "result": aggregation
                })
            }
            CollectorAction::Extract { url, selectors } => {
                let extracted = self.execute_extract(&url, &selectors).await?;
                json!({
                    "action": "extract",
                    "url": url,
                    "duration_ms": start.elapsed().as_millis(),
                    "data": extracted
                })
            }
        };
        Ok(HandResult::success(result))
    }

    /// Collection never requires user approval.
    fn needs_approval(&self) -> bool {
        false
    }

    /// NOTE(review): the "network" dependency declared in the config is not
    /// actually verified here — an empty list means "all satisfied".
    fn check_dependencies(&self) -> Result<Vec<String>> {
        Ok(Vec::new())
    }

    /// NOTE(review): always reports `Idle`; in-flight work is not tracked.
    fn status(&self) -> crate::HandStatus {
        crate::HandStatus::Idle
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The hand's static identity fields match what callers expect.
    #[test]
    fn test_collector_config() {
        let hand = CollectorHand::new();
        assert_eq!(hand.config().id, "collector");
        assert_eq!(hand.config().name, "数据采集器");
        assert!(hand.config().enabled);
        assert!(!hand.config().needs_approval);
    }

    /// Every variant serializes to its lowercase string and round-trips.
    #[test]
    fn test_output_format_serialize() {
        let formats = vec![
            (OutputFormat::Csv, "\"csv\""),
            (OutputFormat::Markdown, "\"markdown\""),
            (OutputFormat::Json, "\"json\""),
            (OutputFormat::Text, "\"text\""),
        ];
        for (fmt, expected) in formats {
            let serialized = serde_json::to_string(&fmt).unwrap();
            assert_eq!(serialized, expected);
        }
        // Verify round-trip deserialization
        for json_str in &["\"csv\"", "\"markdown\"", "\"json\"", "\"text\""] {
            let deserialized: OutputFormat = serde_json::from_str(json_str).unwrap();
            let re_serialized = serde_json::to_string(&deserialized).unwrap();
            assert_eq!(&re_serialized, json_str);
        }
    }

    /// Plain markup: tag contents survive, tags themselves are stripped.
    #[test]
    fn test_extract_visible_text_basic() {
        let hand = CollectorHand::new();
        let html = "<html><body><h1>Title</h1><p>Content here</p></body></html>";
        let text = hand.extract_visible_text(html);
        assert!(text.contains("Title"), "should contain 'Title', got: {}", text);
        assert!(text.contains("Content here"), "should contain 'Content here', got: {}", text);
    }

    /// Script bodies must never leak into the extracted text.
    #[test]
    fn test_extract_visible_text_strips_scripts() {
        let hand = CollectorHand::new();
        let html = "<html><body><script>alert('xss')</script><p>Safe content</p></body></html>";
        let text = hand.extract_visible_text(html);
        assert!(!text.contains("alert"), "script content should be removed, got: {}", text);
        assert!(!text.contains("xss"), "script content should be removed, got: {}", text);
        assert!(text.contains("Safe content"), "visible content should remain, got: {}", text);
    }

    /// Style bodies must never leak into the extracted text.
    #[test]
    fn test_extract_visible_text_strips_styles() {
        let hand = CollectorHand::new();
        let html = "<html><head><style>body { color: red; }</style></head><body><p>Text</p></body></html>";
        let text = hand.extract_visible_text(html);
        assert!(!text.contains("color"), "style content should be removed, got: {}", text);
        assert!(!text.contains("red"), "style content should be removed, got: {}", text);
        assert!(text.contains("Text"), "visible content should remain, got: {}", text);
    }

    /// Empty input yields an empty result, not a panic.
    #[test]
    fn test_extract_visible_text_empty() {
        let hand = CollectorHand::new();
        let text = hand.extract_visible_text("");
        assert!(text.is_empty(), "empty HTML should produce empty text, got: '{}'", text);
    }

    /// Aggregation over zero URLs fetches nothing and reports zero sources.
    #[tokio::test]
    async fn test_aggregate_action_empty_urls() {
        let hand = CollectorHand::new();
        let config = AggregationConfig {
            urls: vec![],
            aggregate_fields: vec![],
        };
        let result = hand.execute_aggregate(&config).await.unwrap();
        let results = result.get("results").unwrap().as_array().unwrap();
        assert_eq!(results.len(), 0, "empty URLs should produce empty results");
        assert_eq!(result.get("source_count").unwrap().as_u64().unwrap(), 0);
    }

    /// All three tagged action shapes deserialize into the right variants.
    #[test]
    fn test_collector_action_deserialize() {
        // Collect action
        let collect_json = json!({
            "action": "collect",
            "target": {
                "url": "https://example.com",
                "selector": ".article",
                "fields": { "title": "h1" },
                "maxItems": 10
            },
            "format": "markdown"
        });
        let action: CollectorAction = serde_json::from_value(collect_json).unwrap();
        match action {
            CollectorAction::Collect { target, format } => {
                assert_eq!(target.url, "https://example.com");
                assert_eq!(target.selector.as_deref(), Some(".article"));
                assert_eq!(target.max_items, 10);
                assert!(format.is_some());
                assert_eq!(format.unwrap(), OutputFormat::Markdown);
            }
            _ => panic!("Expected Collect action"),
        }
        // Aggregate action
        let aggregate_json = json!({
            "action": "aggregate",
            "config": {
                "urls": ["https://a.com", "https://b.com"],
                "aggregateFields": ["title", "content"]
            }
        });
        let action: CollectorAction = serde_json::from_value(aggregate_json).unwrap();
        match action {
            CollectorAction::Aggregate { config } => {
                assert_eq!(config.urls.len(), 2);
                assert_eq!(config.aggregate_fields.len(), 2);
            }
            _ => panic!("Expected Aggregate action"),
        }
        // Extract action
        let extract_json = json!({
            "action": "extract",
            "url": "https://example.com",
            "selectors": { "title": "h1", "body": "p" }
        });
        let action: CollectorAction = serde_json::from_value(extract_json).unwrap();
        match action {
            CollectorAction::Extract { url, selectors } => {
                assert_eq!(url, "https://example.com");
                assert_eq!(selectors.len(), 2);
            }
            _ => panic!("Expected Extract action"),
        }
    }

    /// camelCase keys (maxItems) map onto snake_case fields as configured.
    #[test]
    fn test_collection_target_deserialize() {
        let json = json!({
            "url": "https://example.com/page",
            "selector": ".content",
            "fields": {
                "title": "h1",
                "author": ".author-name"
            },
            "maxItems": 50
        });
        let target: CollectionTarget = serde_json::from_value(json).unwrap();
        assert_eq!(target.url, "https://example.com/page");
        assert_eq!(target.selector.as_deref(), Some(".content"));
        assert_eq!(target.fields.len(), 2);
        assert_eq!(target.max_items, 50);
    }
}