Fix extract_visible_text to use proper byte position tracking (pos += char_len) instead of iterating chars without position context, which caused script/style tag detection to fail on multi-byte content. Also adds script/style stripping logic and raises truncation limit to 10000 chars. Adds 9 unit tests covering: - Config identity verification - OutputFormat serialization round-trip - HTML text extraction (basic, script stripping, style stripping, empty input) - Aggregate action with empty URLs - CollectorAction deserialization (Collect/Aggregate/Extract) - CollectionTarget deserialization
594 lines
20 KiB
Rust
//! Collector Hand - Data collection and aggregation capabilities
//!
//! This hand provides web scraping, data extraction, and aggregation features.
|
use async_trait::async_trait;
|
|
use serde::{Deserialize, Serialize};
|
|
use serde_json::{json, Value};
|
|
use std::collections::HashMap;
|
|
use std::sync::Arc;
|
|
use tokio::sync::RwLock;
|
|
use zclaw_types::Result;
|
|
|
|
use crate::{Hand, HandConfig, HandContext, HandResult};
|
|
|
|
/// Output format options
|
|
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
|
#[serde(rename_all = "lowercase")]
|
|
pub enum OutputFormat {
|
|
Json,
|
|
Csv,
|
|
Markdown,
|
|
Text,
|
|
}
|
|
|
|
impl Default for OutputFormat {
|
|
fn default() -> Self {
|
|
Self::Json
|
|
}
|
|
}
|
|
|
|
/// Collection target configuration
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
#[serde(rename_all = "camelCase")]
|
|
pub struct CollectionTarget {
|
|
/// URL to collect from
|
|
pub url: String,
|
|
/// CSS selector for items
|
|
#[serde(default)]
|
|
pub selector: Option<String>,
|
|
/// Fields to extract
|
|
#[serde(default)]
|
|
pub fields: HashMap<String, String>,
|
|
/// Maximum items to collect
|
|
#[serde(default = "default_max_items")]
|
|
pub max_items: usize,
|
|
}
|
|
|
|
fn default_max_items() -> usize { 100 }
|
|
|
|
/// Collected item
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
#[serde(rename_all = "camelCase")]
|
|
pub struct CollectedItem {
|
|
/// Source URL
|
|
pub source_url: String,
|
|
/// Collected data
|
|
pub data: HashMap<String, Value>,
|
|
/// Collection timestamp
|
|
pub collected_at: String,
|
|
}
|
|
|
|
/// Collection result
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
#[serde(rename_all = "camelCase")]
|
|
pub struct CollectionResult {
|
|
/// Target URL
|
|
pub url: String,
|
|
/// Collected items
|
|
pub items: Vec<CollectedItem>,
|
|
/// Total items collected
|
|
pub total_items: usize,
|
|
/// Output format
|
|
pub format: OutputFormat,
|
|
/// Collection timestamp
|
|
pub collected_at: String,
|
|
/// Duration in ms
|
|
pub duration_ms: u64,
|
|
}
|
|
|
|
/// Aggregation configuration
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
#[serde(rename_all = "camelCase")]
|
|
pub struct AggregationConfig {
|
|
/// URLs to aggregate
|
|
pub urls: Vec<String>,
|
|
/// Fields to aggregate
|
|
#[serde(default)]
|
|
pub aggregate_fields: Vec<String>,
|
|
}
|
|
|
|
/// Collector action types
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
#[serde(tag = "action")]
|
|
pub enum CollectorAction {
|
|
#[serde(rename = "collect")]
|
|
Collect { target: CollectionTarget, format: Option<OutputFormat> },
|
|
#[serde(rename = "aggregate")]
|
|
Aggregate { config: AggregationConfig },
|
|
#[serde(rename = "extract")]
|
|
Extract { url: String, selectors: HashMap<String, String> },
|
|
}
|
|
|
|
/// Collector Hand implementation
|
|
pub struct CollectorHand {
|
|
config: HandConfig,
|
|
client: reqwest::Client,
|
|
cache: Arc<RwLock<HashMap<String, String>>>,
|
|
}
|
|
|
|
impl CollectorHand {
|
|
/// Create a new collector hand
|
|
pub fn new() -> Self {
|
|
Self {
|
|
config: HandConfig {
|
|
id: "collector".to_string(),
|
|
name: "数据采集器".to_string(),
|
|
description: "从网页源收集和聚合数据".to_string(),
|
|
needs_approval: false,
|
|
dependencies: vec!["network".to_string()],
|
|
input_schema: Some(serde_json::json!({
|
|
"type": "object",
|
|
"oneOf": [
|
|
{
|
|
"properties": {
|
|
"action": { "const": "collect" },
|
|
"target": {
|
|
"type": "object",
|
|
"properties": {
|
|
"url": { "type": "string" },
|
|
"selector": { "type": "string" },
|
|
"fields": { "type": "object" },
|
|
"maxItems": { "type": "integer" }
|
|
},
|
|
"required": ["url"]
|
|
},
|
|
"format": { "type": "string", "enum": ["json", "csv", "markdown", "text"] }
|
|
},
|
|
"required": ["action", "target"]
|
|
},
|
|
{
|
|
"properties": {
|
|
"action": { "const": "extract" },
|
|
"url": { "type": "string" },
|
|
"selectors": { "type": "object" }
|
|
},
|
|
"required": ["action", "url", "selectors"]
|
|
},
|
|
{
|
|
"properties": {
|
|
"action": { "const": "aggregate" },
|
|
"config": {
|
|
"type": "object",
|
|
"properties": {
|
|
"urls": { "type": "array", "items": { "type": "string" } },
|
|
"aggregateFields": { "type": "array", "items": { "type": "string" } }
|
|
},
|
|
"required": ["urls"]
|
|
}
|
|
},
|
|
"required": ["action", "config"]
|
|
}
|
|
]
|
|
})),
|
|
tags: vec!["data".to_string(), "collection".to_string(), "scraping".to_string()],
|
|
enabled: true,
|
|
},
|
|
client: reqwest::Client::builder()
|
|
.timeout(std::time::Duration::from_secs(30))
|
|
.user_agent("ZCLAW-Collector/1.0")
|
|
.build()
|
|
.unwrap_or_else(|_| reqwest::Client::new()),
|
|
cache: Arc::new(RwLock::new(HashMap::new())),
|
|
}
|
|
}
|
|
|
|
/// Fetch a page
|
|
async fn fetch_page(&self, url: &str) -> Result<String> {
|
|
// Check cache
|
|
{
|
|
let cache = self.cache.read().await;
|
|
if let Some(cached) = cache.get(url) {
|
|
return Ok(cached.clone());
|
|
}
|
|
}
|
|
|
|
let response = self.client
|
|
.get(url)
|
|
.send()
|
|
.await
|
|
.map_err(|e| zclaw_types::ZclawError::HandError(format!("Request failed: {}", e)))?;
|
|
|
|
let html = response.text().await
|
|
.map_err(|e| zclaw_types::ZclawError::HandError(format!("Failed to read response: {}", e)))?;
|
|
|
|
// Cache the result
|
|
{
|
|
let mut cache = self.cache.write().await;
|
|
cache.insert(url.to_string(), html.clone());
|
|
}
|
|
|
|
Ok(html)
|
|
}
|
|
|
|
/// Extract text by simple pattern matching
|
|
fn extract_by_pattern(&self, html: &str, pattern: &str) -> String {
|
|
// Simple implementation: find text between tags
|
|
if pattern.contains("title") || pattern.contains("h1") {
|
|
if let Some(start) = html.find("<title>") {
|
|
if let Some(end) = html[start..].find("</title>") {
|
|
return html[start + 7..start + end]
|
|
.replace("&", "&")
|
|
.replace("<", "<")
|
|
.replace(">", ">")
|
|
.trim()
|
|
.to_string();
|
|
}
|
|
}
|
|
}
|
|
|
|
// Extract meta description
|
|
if pattern.contains("description") || pattern.contains("meta") {
|
|
if let Some(start) = html.find("name=\"description\"") {
|
|
let rest = &html[start..];
|
|
if let Some(content_start) = rest.find("content=\"") {
|
|
let content = &rest[content_start + 9..];
|
|
if let Some(end) = content.find('"') {
|
|
return content[..end].trim().to_string();
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Default: extract visible text
|
|
self.extract_visible_text(html)
|
|
}
|
|
|
|
/// Extract visible text from HTML, stripping scripts and styles
|
|
fn extract_visible_text(&self, html: &str) -> String {
|
|
let html_lower = html.to_lowercase();
|
|
let mut text = String::new();
|
|
let mut in_tag = false;
|
|
let mut in_script = false;
|
|
let mut in_style = false;
|
|
let mut pos: usize = 0;
|
|
|
|
for c in html.chars() {
|
|
let char_len = c.len_utf8();
|
|
match c {
|
|
'<' => {
|
|
let remaining = &html_lower[pos..];
|
|
if remaining.starts_with("</script") {
|
|
in_script = false;
|
|
} else if remaining.starts_with("</style") {
|
|
in_style = false;
|
|
}
|
|
if remaining.starts_with("<script") {
|
|
in_script = true;
|
|
} else if remaining.starts_with("<style") {
|
|
in_style = true;
|
|
}
|
|
in_tag = true;
|
|
}
|
|
'>' => {
|
|
in_tag = false;
|
|
}
|
|
_ if in_tag => {}
|
|
_ if in_script || in_style => {}
|
|
' ' | '\n' | '\t' | '\r' => {
|
|
if !text.ends_with(' ') && !text.is_empty() {
|
|
text.push(' ');
|
|
}
|
|
}
|
|
_ => text.push(c),
|
|
}
|
|
pos += char_len;
|
|
}
|
|
|
|
if text.len() > 10000 {
|
|
text.truncate(10000);
|
|
text.push_str("...");
|
|
}
|
|
|
|
text.trim().to_string()
|
|
}
|
|
|
|
/// Execute collection
|
|
async fn execute_collect(&self, target: &CollectionTarget, format: OutputFormat) -> Result<CollectionResult> {
|
|
let start = std::time::Instant::now();
|
|
let html = self.fetch_page(&target.url).await?;
|
|
|
|
let mut items = Vec::new();
|
|
let mut data = HashMap::new();
|
|
|
|
// Extract fields
|
|
for (field_name, selector) in &target.fields {
|
|
let value = self.extract_by_pattern(&html, selector);
|
|
data.insert(field_name.clone(), Value::String(value));
|
|
}
|
|
|
|
// If no fields specified, extract basic info
|
|
if data.is_empty() {
|
|
data.insert("title".to_string(), Value::String(self.extract_by_pattern(&html, "title")));
|
|
data.insert("content".to_string(), Value::String(self.extract_visible_text(&html)));
|
|
}
|
|
|
|
items.push(CollectedItem {
|
|
source_url: target.url.clone(),
|
|
data,
|
|
collected_at: chrono::Utc::now().to_rfc3339(),
|
|
});
|
|
|
|
Ok(CollectionResult {
|
|
url: target.url.clone(),
|
|
total_items: items.len(),
|
|
items,
|
|
format,
|
|
collected_at: chrono::Utc::now().to_rfc3339(),
|
|
duration_ms: start.elapsed().as_millis() as u64,
|
|
})
|
|
}
|
|
|
|
/// Execute aggregation
|
|
async fn execute_aggregate(&self, config: &AggregationConfig) -> Result<Value> {
|
|
let start = std::time::Instant::now();
|
|
let mut results = Vec::new();
|
|
|
|
for url in config.urls.iter().take(10) {
|
|
match self.fetch_page(url).await {
|
|
Ok(html) => {
|
|
let mut data = HashMap::new();
|
|
for field in &config.aggregate_fields {
|
|
let value = self.extract_by_pattern(&html, field);
|
|
data.insert(field.clone(), Value::String(value));
|
|
}
|
|
if data.is_empty() {
|
|
data.insert("content".to_string(), Value::String(self.extract_visible_text(&html)));
|
|
}
|
|
results.push(data);
|
|
}
|
|
Err(e) => {
|
|
tracing::warn!(target: "collector", url = url, error = %e, "Failed to fetch");
|
|
}
|
|
}
|
|
}
|
|
|
|
Ok(json!({
|
|
"results": results,
|
|
"source_count": config.urls.len(),
|
|
"duration_ms": start.elapsed().as_millis()
|
|
}))
|
|
}
|
|
|
|
/// Execute extraction
|
|
async fn execute_extract(&self, url: &str, selectors: &HashMap<String, String>) -> Result<HashMap<String, String>> {
|
|
let html = self.fetch_page(url).await?;
|
|
let mut results = HashMap::new();
|
|
|
|
for (field_name, selector) in selectors {
|
|
let value = self.extract_by_pattern(&html, selector);
|
|
results.insert(field_name.clone(), value);
|
|
}
|
|
|
|
Ok(results)
|
|
}
|
|
}
|
|
|
|
impl Default for CollectorHand {
|
|
fn default() -> Self {
|
|
Self::new()
|
|
}
|
|
}
|
|
|
|
#[async_trait]
|
|
impl Hand for CollectorHand {
|
|
fn config(&self) -> &HandConfig {
|
|
&self.config
|
|
}
|
|
|
|
async fn execute(&self, _context: &HandContext, input: Value) -> Result<HandResult> {
|
|
let action: CollectorAction = serde_json::from_value(input.clone())
|
|
.map_err(|e| zclaw_types::ZclawError::HandError(format!("Invalid action: {}", e)))?;
|
|
|
|
let start = std::time::Instant::now();
|
|
|
|
let result = match action {
|
|
CollectorAction::Collect { target, format } => {
|
|
let fmt = format.unwrap_or(OutputFormat::Json);
|
|
let collection = self.execute_collect(&target, fmt.clone()).await?;
|
|
json!({
|
|
"action": "collect",
|
|
"url": target.url,
|
|
"total_items": collection.total_items,
|
|
"duration_ms": start.elapsed().as_millis(),
|
|
"items": collection.items
|
|
})
|
|
}
|
|
CollectorAction::Aggregate { config } => {
|
|
let aggregation = self.execute_aggregate(&config).await?;
|
|
json!({
|
|
"action": "aggregate",
|
|
"duration_ms": start.elapsed().as_millis(),
|
|
"result": aggregation
|
|
})
|
|
}
|
|
CollectorAction::Extract { url, selectors } => {
|
|
let extracted = self.execute_extract(&url, &selectors).await?;
|
|
json!({
|
|
"action": "extract",
|
|
"url": url,
|
|
"duration_ms": start.elapsed().as_millis(),
|
|
"data": extracted
|
|
})
|
|
}
|
|
};
|
|
|
|
Ok(HandResult::success(result))
|
|
}
|
|
|
|
fn needs_approval(&self) -> bool {
|
|
false
|
|
}
|
|
|
|
fn check_dependencies(&self) -> Result<Vec<String>> {
|
|
Ok(Vec::new())
|
|
}
|
|
|
|
fn status(&self) -> crate::HandStatus {
|
|
crate::HandStatus::Idle
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
|
|
mod tests {
|
|
use super::*;
|
|
|
|
#[test]
|
|
fn test_collector_config() {
|
|
let hand = CollectorHand::new();
|
|
assert_eq!(hand.config().id, "collector");
|
|
assert_eq!(hand.config().name, "数据采集器");
|
|
assert!(hand.config().enabled);
|
|
assert!(!hand.config().needs_approval);
|
|
}
|
|
|
|
#[test]
|
|
fn test_output_format_serialize() {
|
|
let formats = vec![
|
|
(OutputFormat::Csv, "\"csv\""),
|
|
(OutputFormat::Markdown, "\"markdown\""),
|
|
(OutputFormat::Json, "\"json\""),
|
|
(OutputFormat::Text, "\"text\""),
|
|
];
|
|
|
|
for (fmt, expected) in formats {
|
|
let serialized = serde_json::to_string(&fmt).unwrap();
|
|
assert_eq!(serialized, expected);
|
|
}
|
|
|
|
// Verify round-trip deserialization
|
|
for json_str in &["\"csv\"", "\"markdown\"", "\"json\"", "\"text\""] {
|
|
let deserialized: OutputFormat = serde_json::from_str(json_str).unwrap();
|
|
let re_serialized = serde_json::to_string(&deserialized).unwrap();
|
|
assert_eq!(&re_serialized, json_str);
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
fn test_extract_visible_text_basic() {
|
|
let hand = CollectorHand::new();
|
|
let html = "<html><body><h1>Title</h1><p>Content here</p></body></html>";
|
|
let text = hand.extract_visible_text(html);
|
|
assert!(text.contains("Title"), "should contain 'Title', got: {}", text);
|
|
assert!(text.contains("Content here"), "should contain 'Content here', got: {}", text);
|
|
}
|
|
|
|
#[test]
|
|
fn test_extract_visible_text_strips_scripts() {
|
|
let hand = CollectorHand::new();
|
|
let html = "<html><body><script>alert('xss')</script><p>Safe content</p></body></html>";
|
|
let text = hand.extract_visible_text(html);
|
|
assert!(!text.contains("alert"), "script content should be removed, got: {}", text);
|
|
assert!(!text.contains("xss"), "script content should be removed, got: {}", text);
|
|
assert!(text.contains("Safe content"), "visible content should remain, got: {}", text);
|
|
}
|
|
|
|
#[test]
|
|
fn test_extract_visible_text_strips_styles() {
|
|
let hand = CollectorHand::new();
|
|
let html = "<html><head><style>body { color: red; }</style></head><body><p>Text</p></body></html>";
|
|
let text = hand.extract_visible_text(html);
|
|
assert!(!text.contains("color"), "style content should be removed, got: {}", text);
|
|
assert!(!text.contains("red"), "style content should be removed, got: {}", text);
|
|
assert!(text.contains("Text"), "visible content should remain, got: {}", text);
|
|
}
|
|
|
|
#[test]
|
|
fn test_extract_visible_text_empty() {
|
|
let hand = CollectorHand::new();
|
|
let text = hand.extract_visible_text("");
|
|
assert!(text.is_empty(), "empty HTML should produce empty text, got: '{}'", text);
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_aggregate_action_empty_urls() {
|
|
let hand = CollectorHand::new();
|
|
let config = AggregationConfig {
|
|
urls: vec![],
|
|
aggregate_fields: vec![],
|
|
};
|
|
|
|
let result = hand.execute_aggregate(&config).await.unwrap();
|
|
let results = result.get("results").unwrap().as_array().unwrap();
|
|
assert_eq!(results.len(), 0, "empty URLs should produce empty results");
|
|
assert_eq!(result.get("source_count").unwrap().as_u64().unwrap(), 0);
|
|
}
|
|
|
|
#[test]
|
|
fn test_collector_action_deserialize() {
|
|
// Collect action
|
|
let collect_json = json!({
|
|
"action": "collect",
|
|
"target": {
|
|
"url": "https://example.com",
|
|
"selector": ".article",
|
|
"fields": { "title": "h1" },
|
|
"maxItems": 10
|
|
},
|
|
"format": "markdown"
|
|
});
|
|
let action: CollectorAction = serde_json::from_value(collect_json).unwrap();
|
|
match action {
|
|
CollectorAction::Collect { target, format } => {
|
|
assert_eq!(target.url, "https://example.com");
|
|
assert_eq!(target.selector.as_deref(), Some(".article"));
|
|
assert_eq!(target.max_items, 10);
|
|
assert!(format.is_some());
|
|
assert_eq!(format.unwrap(), OutputFormat::Markdown);
|
|
}
|
|
_ => panic!("Expected Collect action"),
|
|
}
|
|
|
|
// Aggregate action
|
|
let aggregate_json = json!({
|
|
"action": "aggregate",
|
|
"config": {
|
|
"urls": ["https://a.com", "https://b.com"],
|
|
"aggregateFields": ["title", "content"]
|
|
}
|
|
});
|
|
let action: CollectorAction = serde_json::from_value(aggregate_json).unwrap();
|
|
match action {
|
|
CollectorAction::Aggregate { config } => {
|
|
assert_eq!(config.urls.len(), 2);
|
|
assert_eq!(config.aggregate_fields.len(), 2);
|
|
}
|
|
_ => panic!("Expected Aggregate action"),
|
|
}
|
|
|
|
// Extract action
|
|
let extract_json = json!({
|
|
"action": "extract",
|
|
"url": "https://example.com",
|
|
"selectors": { "title": "h1", "body": "p" }
|
|
});
|
|
let action: CollectorAction = serde_json::from_value(extract_json).unwrap();
|
|
match action {
|
|
CollectorAction::Extract { url, selectors } => {
|
|
assert_eq!(url, "https://example.com");
|
|
assert_eq!(selectors.len(), 2);
|
|
}
|
|
_ => panic!("Expected Extract action"),
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
fn test_collection_target_deserialize() {
|
|
let json = json!({
|
|
"url": "https://example.com/page",
|
|
"selector": ".content",
|
|
"fields": {
|
|
"title": "h1",
|
|
"author": ".author-name"
|
|
},
|
|
"maxItems": 50
|
|
});
|
|
|
|
let target: CollectionTarget = serde_json::from_value(json).unwrap();
|
|
assert_eq!(target.url, "https://example.com/page");
|
|
assert_eq!(target.selector.as_deref(), Some(".content"));
|
|
assert_eq!(target.fields.len(), 2);
|
|
assert_eq!(target.max_items, 50);
|
|
}
|
|
}
|