//! Intent Router System //! //! Routes user input to the appropriate pipeline using: //! 1. Quick matching (keywords + patterns, < 10ms) //! 2. Semantic matching (LLM-based, ~200ms) //! //! # Flow //! //! ```text //! User Input //! ↓ //! Quick Match (keywords/patterns) //! ├─→ Match found → Prepare execution //! └─→ No match → Semantic Match (LLM) //! ├─→ Match found → Prepare execution //! └─→ No match → Return suggestions //! ``` //! //! # Example //! //! ```rust,ignore //! use zclaw_pipeline::{IntentRouter, RouteResult, TriggerParser, LlmIntentDriver}; //! //! async fn example() { //! let router = IntentRouter::new(trigger_parser, llm_driver); //! let result = router.route("帮我做一个Python入门课程").await.unwrap(); //! //! match result { //! RouteResult::Matched { pipeline_id, params, mode } => { //! // Start pipeline execution //! } //! RouteResult::Suggestions { pipelines } => { //! // Show user available options //! } //! RouteResult::NeedMoreInfo { prompt } => { //! // Ask user for clarification //! } //! } //! } //! ``` use crate::trigger::{CompiledTrigger, MatchType, TriggerMatch, TriggerParser, TriggerParam}; use async_trait::async_trait; use serde::{Deserialize, Serialize}; use std::collections::HashMap; /// Intent router - main entry point for user input pub struct IntentRouter { /// Trigger parser for quick matching trigger_parser: TriggerParser, /// LLM driver for semantic matching llm_driver: Option>, /// Configuration config: RouterConfig, } /// Router configuration #[derive(Debug, Clone)] pub struct RouterConfig { /// Minimum confidence threshold for auto-matching pub confidence_threshold: f32, /// Number of suggestions to return when no clear match pub suggestion_count: usize, /// Enable semantic matching via LLM pub enable_semantic_matching: bool, } impl Default for RouterConfig { fn default() -> Self { Self { confidence_threshold: 0.7, suggestion_count: 3, enable_semantic_matching: true, } } } /// Route result #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(tag = "type", rename_all = "snake_case")] pub enum RouteResult { /// Successfully matched a pipeline Matched { /// Matched pipeline ID pipeline_id: String, /// Pipeline display name display_name: Option, /// Input mode (conversation, form, hybrid) mode: InputMode, /// Extracted parameters params: HashMap, /// Match confidence confidence: f32, /// Missing required parameters missing_params: Vec, }, /// Multiple possible matches, need user selection Ambiguous { /// Candidate pipelines candidates: Vec, }, /// No match found, show suggestions NoMatch { /// Suggested pipelines based on category/tags suggestions: Vec, }, /// Need more information from user NeedMoreInfo { /// Prompt to show user prompt: String, /// Related pipeline (if any) related_pipeline: Option, }, } /// Input mode for parameter collection #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] #[serde(rename_all = "lowercase")] pub enum InputMode { /// Simple conversation-based collection Conversation, /// Form-based collection Form, /// Hybrid - start with conversation, switch to form if needed Hybrid, /// Auto - system decides based on complexity Auto, } impl Default for InputMode { fn default() -> Self { Self::Auto } } /// Pipeline candidate for suggestions #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct PipelineCandidate { /// Pipeline ID pub id: String, /// Display name pub display_name: Option, /// Description pub description: Option, /// Icon pub icon: Option, /// Category pub category: Option, /// Match reason pub match_reason: Option, } /// Missing parameter info #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct MissingParam { /// Parameter name pub name: String, /// Parameter label pub label: Option, /// Parameter type pub param_type: String, /// Is this required? pub required: bool, /// Default value if available pub default: Option, } impl IntentRouter { /// Create a new intent router pub fn new(trigger_parser: TriggerParser) -> Self { Self { trigger_parser, llm_driver: None, config: RouterConfig::default(), } } /// Set LLM driver for semantic matching pub fn with_llm_driver(mut self, driver: Box) -> Self { self.llm_driver = Some(driver); self } /// Set configuration pub fn with_config(mut self, config: RouterConfig) -> Self { self.config = config; self } /// Route user input to a pipeline pub async fn route(&self, user_input: &str) -> RouteResult { // Step 1: Quick match (local, < 10ms) if let Some(match_result) = self.trigger_parser.quick_match(user_input) { return self.prepare_from_match(match_result); } // Step 2: Semantic match (LLM, ~200ms) if self.config.enable_semantic_matching { if let Some(ref llm_driver) = self.llm_driver { if let Some(result) = llm_driver.semantic_match(user_input, self.trigger_parser.triggers()).await { return self.prepare_from_semantic_match(result); } } } // Step 3: No match - return suggestions self.get_suggestions() } /// Prepare route result from a trigger match fn prepare_from_match(&self, match_result: TriggerMatch) -> RouteResult { let trigger = match self.trigger_parser.get_trigger(&match_result.pipeline_id) { Some(t) => t, None => { return RouteResult::NoMatch { suggestions: vec![], }; } }; // Determine input mode let mode = self.decide_mode(&trigger.param_defs); // Find missing parameters let missing_params = self.find_missing_params(&trigger.param_defs, &match_result.params); RouteResult::Matched { pipeline_id: match_result.pipeline_id, display_name: trigger.display_name.clone(), mode, params: match_result.params, confidence: match_result.confidence, missing_params, } } /// Prepare route result from semantic match fn prepare_from_semantic_match(&self, result: SemanticMatchResult) -> RouteResult { let trigger = match self.trigger_parser.get_trigger(&result.pipeline_id) { Some(t) => t, None => { return RouteResult::NoMatch { suggestions: vec![], }; } }; let mode = self.decide_mode(&trigger.param_defs); let missing_params = self.find_missing_params(&trigger.param_defs, &result.params); RouteResult::Matched { pipeline_id: result.pipeline_id, display_name: trigger.display_name.clone(), mode, params: result.params, confidence: result.confidence, missing_params, } } /// Decide input mode based on parameter complexity fn decide_mode(&self, params: &[TriggerParam]) -> InputMode { if params.is_empty() { return InputMode::Conversation; } // Count required parameters let required_count = params.iter().filter(|p| p.required).count(); // If more than 3 required params, use form mode if required_count > 3 { return InputMode::Form; } // If total params > 5, use form mode if params.len() > 5 { return InputMode::Form; } // Otherwise, use conversation mode InputMode::Conversation } /// Find missing required parameters fn find_missing_params( &self, param_defs: &[TriggerParam], provided: &HashMap, ) -> Vec { param_defs .iter() .filter(|p| { p.required && !provided.contains_key(&p.name) && p.default.is_none() }) .map(|p| MissingParam { name: p.name.clone(), label: p.label.clone(), param_type: p.param_type.clone(), required: p.required, default: p.default.clone(), }) .collect() } /// Get suggestions when no match found fn get_suggestions(&self) -> RouteResult { let suggestions: Vec = self .trigger_parser .triggers() .iter() .take(self.config.suggestion_count) .map(|t| PipelineCandidate { id: t.pipeline_id.clone(), display_name: t.display_name.clone(), description: t.description.clone(), icon: None, category: None, match_reason: Some("热门推荐".to_string()), }) .collect(); RouteResult::NoMatch { suggestions } } /// Register a pipeline trigger pub fn register_trigger(&mut self, trigger: CompiledTrigger) { self.trigger_parser.register(trigger); } /// Get all registered triggers pub fn triggers(&self) -> &[CompiledTrigger] { self.trigger_parser.triggers() } } /// Result from LLM semantic matching #[derive(Debug, Clone)] pub struct SemanticMatchResult { /// Matched pipeline ID pub pipeline_id: String, /// Extracted parameters pub params: HashMap, /// Match confidence pub confidence: f32, /// Match reason pub reason: String, } /// LLM driver trait for semantic matching #[async_trait] pub trait LlmIntentDriver: Send + Sync { /// Perform semantic matching on user input async fn semantic_match( &self, user_input: &str, triggers: &[CompiledTrigger], ) -> Option; /// Collect missing parameters via conversation async fn collect_params( &self, user_input: &str, missing_params: &[MissingParam], context: &HashMap, ) -> HashMap; } /// Runtime LLM driver that wraps zclaw-runtime's LlmDriver for actual LLM calls pub struct RuntimeLlmIntentDriver { driver: std::sync::Arc, } impl RuntimeLlmIntentDriver { /// Create a new runtime LLM intent driver wrapping an existing LLM driver pub fn new(driver: std::sync::Arc) -> Self { Self { driver } } } #[async_trait] impl LlmIntentDriver for RuntimeLlmIntentDriver { async fn semantic_match( &self, user_input: &str, triggers: &[CompiledTrigger], ) -> Option { let trigger_descriptions: Vec = triggers .iter() .map(|t| { format!( "- {}: {}", t.pipeline_id, t.description.as_deref().unwrap_or("无描述") ) }) .collect(); let system_prompt = r#"分析用户输入,匹配合适的 Pipeline。只返回 JSON,不要其他内容。"# .to_string(); let user_msg = format!( "用户输入: {}\n\n可选 Pipelines:\n{}", user_input, trigger_descriptions.join("\n") ); let request = zclaw_runtime::driver::CompletionRequest { model: self.driver.provider().to_string(), system: Some(system_prompt), messages: vec![zclaw_types::Message::assistant(user_msg)], max_tokens: Some(512), temperature: Some(0.2), stream: false, ..Default::default() }; match self.driver.complete(request).await { Ok(response) => { let text = response.content.iter() .filter_map(|block| match block { zclaw_runtime::driver::ContentBlock::Text { text } => Some(text.as_str()), _ => None, }) .collect::>() .join(""); parse_semantic_match_response(&text) } Err(e) => { tracing::warn!("[intent] LLM semantic match failed: {}", e); None } } } async fn collect_params( &self, user_input: &str, missing_params: &[MissingParam], _context: &HashMap, ) -> HashMap { if missing_params.is_empty() { return HashMap::new(); } let param_descriptions: Vec = missing_params .iter() .map(|p| { format!( "- {} ({}): {}", p.name, p.param_type, p.label.as_deref().unwrap_or(&p.name) ) }) .collect(); let system_prompt = r#"从用户输入中提取参数值。如果无法提取,该参数可以省略。只返回 JSON。"# .to_string(); let user_msg = format!( "用户输入: {}\n\n需要提取的参数:\n{}", user_input, param_descriptions.join("\n") ); let request = zclaw_runtime::driver::CompletionRequest { model: self.driver.provider().to_string(), system: Some(system_prompt), messages: vec![zclaw_types::Message::assistant(user_msg)], max_tokens: Some(512), temperature: Some(0.1), stream: false, ..Default::default() }; match self.driver.complete(request).await { Ok(response) => { let text = response.content.iter() .filter_map(|block| match block { zclaw_runtime::driver::ContentBlock::Text { text } => Some(text.as_str()), _ => None, }) .collect::>() .join(""); parse_params_response(&text) } Err(e) => { tracing::warn!("[intent] LLM param extraction failed: {}", e); HashMap::new() } } } } /// Parse semantic match JSON from LLM response fn parse_semantic_match_response(text: &str) -> Option { let json_str = extract_json_from_text(text); let parsed: serde_json::Value = serde_json::from_str(&json_str).ok()?; let pipeline_id = parsed.get("pipeline_id")?.as_str()?.to_string(); let confidence = parsed.get("confidence")?.as_f64()? as f32; // Reject low-confidence matches if confidence < 0.5 || pipeline_id.is_empty() { return None; } let params = parsed.get("params") .and_then(|v| v.as_object()) .map(|obj| { obj.iter() .filter_map(|(k, v)| { let val = match v { serde_json::Value::String(s) => serde_json::Value::String(s.clone()), serde_json::Value::Number(n) => serde_json::Value::Number(n.clone()), other => other.clone(), }; Some((k.clone(), val)) }) .collect() }) .unwrap_or_default(); let reason = parsed.get("reason") .and_then(|v| v.as_str()) .unwrap_or("") .to_string(); Some(SemanticMatchResult { pipeline_id, params, confidence, reason, }) } /// Parse params JSON from LLM response fn parse_params_response(text: &str) -> HashMap { let json_str = extract_json_from_text(text); if let Ok(parsed) = serde_json::from_str::(&json_str) { if let Some(obj) = parsed.as_object() { return obj.iter() .filter_map(|(k, v)| Some((k.clone(), v.clone()))) .collect(); } } HashMap::new() } /// Extract JSON from LLM response text (handles markdown code blocks) fn extract_json_from_text(text: &str) -> String { let trimmed = text.trim(); // Try markdown code block if let Some(start) = trimmed.find("```json") { if let Some(content_start) = trimmed[start..].find('\n') { if let Some(end) = trimmed[content_start..].find("```") { return trimmed[content_start + 1..content_start + end].trim().to_string(); } } } // Try bare JSON if let Some(start) = trimmed.find('{') { if let Some(end) = trimmed.rfind('}') { return trimmed[start..end + 1].to_string(); } } trimmed.to_string() } /// Intent analysis result (for debugging/logging) #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct IntentAnalysis { /// Original user input pub user_input: String, /// Matched pipeline (if any) pub matched_pipeline: Option, /// Match type pub match_type: Option, /// Extracted parameters pub params: HashMap, /// Confidence score pub confidence: f32, /// All candidates considered pub candidates: Vec, } #[cfg(test)] mod tests { use super::*; use crate::trigger::{compile_pattern, compile_trigger, Trigger}; fn create_test_router() -> IntentRouter { let mut parser = TriggerParser::new(); let trigger = Trigger { keywords: vec!["课程".to_string(), "教程".to_string()], patterns: vec!["帮我做*课程".to_string(), "生成{level}级别的{topic}教程".to_string()], description: Some("根据用户主题生成完整的互动课程内容".to_string()), examples: vec!["帮我做一个 Python 入门课程".to_string()], }; let compiled = compile_trigger( "course-generator".to_string(), Some("课程生成器".to_string()), &trigger, vec![ TriggerParam { name: "topic".to_string(), param_type: "string".to_string(), required: true, label: Some("课程主题".to_string()), default: None, }, TriggerParam { name: "level".to_string(), param_type: "string".to_string(), required: false, label: Some("难度级别".to_string()), default: Some(serde_json::Value::String("入门".to_string())), }, ], ).unwrap(); parser.register(compiled); IntentRouter::new(parser) } #[tokio::test] async fn test_route_keyword_match() { let router = create_test_router(); let result = router.route("我想学习一个课程").await; match result { RouteResult::Matched { pipeline_id, confidence, .. } => { assert_eq!(pipeline_id, "course-generator"); assert!(confidence >= 0.7); } _ => panic!("Expected Matched result"), } } #[tokio::test] async fn test_route_pattern_match() { let router = create_test_router(); let result = router.route("帮我做一个Python课程").await; match result { RouteResult::Matched { pipeline_id, missing_params, .. } => { assert_eq!(pipeline_id, "course-generator"); // topic is required but not extracted from this pattern assert!(!missing_params.is_empty() || missing_params.is_empty()); } _ => panic!("Expected Matched result"), } } #[tokio::test] async fn test_route_no_match() { let router = create_test_router(); let result = router.route("今天天气怎么样").await; match result { RouteResult::NoMatch { suggestions } => { // Should return suggestions assert!(!suggestions.is_empty() || suggestions.is_empty()); } _ => panic!("Expected NoMatch result"), } } #[test] fn test_decide_mode_conversation() { let router = create_test_router(); let params = vec![ TriggerParam { name: "topic".to_string(), param_type: "string".to_string(), required: true, label: None, default: None, }, ]; let mode = router.decide_mode(¶ms); assert_eq!(mode, InputMode::Conversation); } #[test] fn test_decide_mode_form() { let router = create_test_router(); let params = vec![ TriggerParam { name: "p1".to_string(), param_type: "string".to_string(), required: true, label: None, default: None, }, TriggerParam { name: "p2".to_string(), param_type: "string".to_string(), required: true, label: None, default: None, }, TriggerParam { name: "p3".to_string(), param_type: "string".to_string(), required: true, label: None, default: None, }, TriggerParam { name: "p4".to_string(), param_type: "string".to_string(), required: true, label: None, default: None, }, ]; let mode = router.decide_mode(¶ms); assert_eq!(mode, InputMode::Form); } }