//! 学习触发信号系统 //! //! 规则驱动的低成本触发判断,在 `post_conversation_hook` 中调用。 //! 有信号时才进入 LLM 经验提取,无信号则零成本跳过。 use super::experience::{CompletionStatus, detect_implicit_feedback}; /// 触发信号类型 #[derive(Debug, Clone, PartialEq, Eq)] pub enum TriggerSignal { /// 痛点确认(confidence >= 0.7) PainConfirmed, /// 隐式正反馈("谢谢/解决了/对了/好了") PositiveFeedback, /// 复杂工具链(单次对话 3+ tool calls) ComplexToolChain, /// 用户纠正(含"不对/不是/应该是") UserCorrection, /// 行业模式(同一行业关键词在多轮出现) IndustryPattern, } /// 触发信号判断的输入 pub struct TriggerContext { /// 最新用户消息 pub user_message: String, /// 本轮工具调用次数 pub tool_call_count: usize, /// 对话中累计的用户消息(用于行业关键词统计) pub conversation_messages: Vec, /// 检测到的痛点置信度(如有) pub pain_confidence: Option, /// 用户授权的行业关键词 pub industry_keywords: Vec, } /// 判断是否触发学习信号(纯规则,零 LLM 调用) /// /// 返回匹配到的所有触发信号。空 Vec = 无信号,跳过。 pub fn evaluate_triggers(ctx: &TriggerContext) -> Vec { let mut signals = Vec::new(); // 1. 痛点确认 if let Some(confidence) = ctx.pain_confidence { if confidence >= 0.7 { signals.push(TriggerSignal::PainConfirmed); } } // 2. 隐式正反馈 if let Some(status) = detect_implicit_feedback(&ctx.user_message) { if status == CompletionStatus::Success { signals.push(TriggerSignal::PositiveFeedback); } } // 3. 复杂工具链 if ctx.tool_call_count >= 3 { signals.push(TriggerSignal::ComplexToolChain); } // 4. 用户纠正 if is_user_correction(&ctx.user_message) { signals.push(TriggerSignal::UserCorrection); } // 5. 行业模式 if detects_industry_pattern(&ctx.conversation_messages, &ctx.industry_keywords) { signals.push(TriggerSignal::IndustryPattern); } signals } /// 检测用户纠正信号 fn is_user_correction(message: &str) -> bool { let lower = message.to_lowercase(); let correction_patterns = [ "不对", "不是", "应该是", "错了", "重新", "换一个", "不是这个", "搞错了", "你理解错了", "我的意思是", ]; correction_patterns.iter().any(|p| lower.contains(p)) } /// 检测行业关键词在多轮对话中反复出现 fn detects_industry_pattern(messages: &[String], industry_keywords: &[String]) -> bool { if messages.len() < 3 || industry_keywords.is_empty() { return false; } // 统计行业关键词在所有消息中的出现次数 let mut keyword_hits: std::collections::HashMap<&str, usize> = std::collections::HashMap::new(); for msg in messages { let lower = msg.to_lowercase(); for kw in industry_keywords { if lower.contains(kw.to_lowercase().as_str()) { *keyword_hits.entry(kw).or_default() += 1; } } } // 至少有 1 个关键词在 3+ 轮中出现 keyword_hits.values().any(|&count| count >= 3) } /// 触发信号的可读描述(用于日志) pub fn signal_description(signal: &TriggerSignal) -> &'static str { match signal { TriggerSignal::PainConfirmed => "痛点确认", TriggerSignal::PositiveFeedback => "隐式正反馈", TriggerSignal::ComplexToolChain => "复杂工具链", TriggerSignal::UserCorrection => "用户纠正", TriggerSignal::IndustryPattern => "行业模式", } } // --------------------------------------------------------------------------- // Tests // --------------------------------------------------------------------------- #[cfg(test)] mod tests { use super::*; #[test] fn test_pain_confirmed_trigger() { let ctx = TriggerContext { user_message: "这个报表还是有问题".to_string(), tool_call_count: 1, conversation_messages: vec!["报表太慢".into()], pain_confidence: Some(0.8), industry_keywords: vec![], }; let signals = evaluate_triggers(&ctx); assert!(signals.contains(&TriggerSignal::PainConfirmed)); } #[test] fn test_positive_feedback_trigger() { let ctx = TriggerContext { user_message: "好了,解决了!谢谢".to_string(), tool_call_count: 0, conversation_messages: vec![], pain_confidence: None, industry_keywords: vec![], }; let signals = evaluate_triggers(&ctx); assert!(signals.contains(&TriggerSignal::PositiveFeedback)); } #[test] fn test_complex_tool_chain_trigger() { let ctx = TriggerContext { user_message: "帮我处理一下".to_string(), tool_call_count: 4, conversation_messages: vec![], pain_confidence: None, industry_keywords: vec![], }; let signals = evaluate_triggers(&ctx); assert!(signals.contains(&TriggerSignal::ComplexToolChain)); } #[test] fn test_user_correction_trigger() { let ctx = TriggerContext { user_message: "不对,应该是另一个方案".to_string(), tool_call_count: 0, conversation_messages: vec![], pain_confidence: None, industry_keywords: vec![], }; let signals = evaluate_triggers(&ctx); assert!(signals.contains(&TriggerSignal::UserCorrection)); } #[test] fn test_industry_pattern_trigger() { let ctx = TriggerContext { user_message: "库存又不够了".to_string(), tool_call_count: 0, conversation_messages: vec![ "帮我查库存".into(), "库存数据怎么看".into(), "库存预警设置".into(), "库存又不够了".into(), ], pain_confidence: None, industry_keywords: vec!["库存".to_string(), "SKU".to_string(), "GMV".to_string()], }; let signals = evaluate_triggers(&ctx); assert!(signals.contains(&TriggerSignal::IndustryPattern)); } #[test] fn test_no_trigger() { let ctx = TriggerContext { user_message: "今天天气怎么样".to_string(), tool_call_count: 0, conversation_messages: vec![], pain_confidence: None, industry_keywords: vec![], }; let signals = evaluate_triggers(&ctx); assert!(signals.is_empty()); } #[test] fn test_multiple_triggers() { let ctx = TriggerContext { user_message: "不对,帮我重新做一下库存报表".to_string(), tool_call_count: 3, conversation_messages: vec![ "库存报表".into(), "帮我做库存报表".into(), "库存报表数据".into(), "不对,帮我重新做一下库存报表".into(), ], pain_confidence: Some(0.8), industry_keywords: vec!["库存".to_string()], }; let signals = evaluate_triggers(&ctx); assert!(signals.contains(&TriggerSignal::PainConfirmed)); assert!(signals.contains(&TriggerSignal::ComplexToolChain)); assert!(signals.contains(&TriggerSignal::UserCorrection)); assert!(signals.contains(&TriggerSignal::IndustryPattern)); assert!(signals.len() >= 3); } }