Files
zclaw_openfang/desktop/src-tauri/src/intelligence/triggers.rs
iven 29fbfbec59
Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled
feat(intelligence): Phase 2 学习循环基础 — 触发信号 + 经验行业维度
- 新增 triggers.rs: 5 种触发信号(痛点确认/正反馈/复杂工具链/用户纠正/行业模式)
- ExperienceStore 增加 industry_context + source_trigger 字段
- experience.rs format_for_injection 支持行业标签
- intelligence_hooks.rs 集成触发信号评估
- 17 个测试全通过 (7 trigger + 10 experience)
2026-04-12 15:52:29 +08:00

230 lines
7.7 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

//! 学习触发信号系统
//!
//! 规则驱动的低成本触发判断,在 `post_conversation_hook` 中调用。
//! 有信号时才进入 LLM 经验提取,无信号则零成本跳过。
use super::experience::{CompletionStatus, detect_implicit_feedback};
/// 触发信号类型
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TriggerSignal {
/// 痛点确认confidence >= 0.7
PainConfirmed,
/// 隐式正反馈("谢谢/解决了/对了/好了"
PositiveFeedback,
/// 复杂工具链(单次对话 3+ tool calls
ComplexToolChain,
/// 用户纠正(含"不对/不是/应该是"
UserCorrection,
/// 行业模式(同一行业关键词在多轮出现)
IndustryPattern,
}
/// 触发信号判断的输入
pub struct TriggerContext {
/// 最新用户消息
pub user_message: String,
/// 本轮工具调用次数
pub tool_call_count: usize,
/// 对话中累计的用户消息(用于行业关键词统计)
pub conversation_messages: Vec<String>,
/// 检测到的痛点置信度(如有)
pub pain_confidence: Option<f64>,
/// 用户授权的行业关键词
pub industry_keywords: Vec<String>,
}
/// 判断是否触发学习信号(纯规则,零 LLM 调用)
///
/// 返回匹配到的所有触发信号。空 Vec = 无信号,跳过。
pub fn evaluate_triggers(ctx: &TriggerContext) -> Vec<TriggerSignal> {
let mut signals = Vec::new();
// 1. 痛点确认
if let Some(confidence) = ctx.pain_confidence {
if confidence >= 0.7 {
signals.push(TriggerSignal::PainConfirmed);
}
}
// 2. 隐式正反馈
if let Some(status) = detect_implicit_feedback(&ctx.user_message) {
if status == CompletionStatus::Success {
signals.push(TriggerSignal::PositiveFeedback);
}
}
// 3. 复杂工具链
if ctx.tool_call_count >= 3 {
signals.push(TriggerSignal::ComplexToolChain);
}
// 4. 用户纠正
if is_user_correction(&ctx.user_message) {
signals.push(TriggerSignal::UserCorrection);
}
// 5. 行业模式
if detects_industry_pattern(&ctx.conversation_messages, &ctx.industry_keywords) {
signals.push(TriggerSignal::IndustryPattern);
}
signals
}
/// 检测用户纠正信号
fn is_user_correction(message: &str) -> bool {
let lower = message.to_lowercase();
let correction_patterns = [
"不对", "不是", "应该是", "错了", "重新", "换一个",
"不是这个", "搞错了", "你理解错了", "我的意思是",
];
correction_patterns.iter().any(|p| lower.contains(p))
}
/// 检测行业关键词在多轮对话中反复出现
fn detects_industry_pattern(messages: &[String], industry_keywords: &[String]) -> bool {
if messages.len() < 3 || industry_keywords.is_empty() {
return false;
}
// 统计行业关键词在所有消息中的出现次数
let mut keyword_hits: std::collections::HashMap<&str, usize> = std::collections::HashMap::new();
for msg in messages {
let lower = msg.to_lowercase();
for kw in industry_keywords {
if lower.contains(kw.to_lowercase().as_str()) {
*keyword_hits.entry(kw).or_default() += 1;
}
}
}
// 至少有 1 个关键词在 3+ 轮中出现
keyword_hits.values().any(|&count| count >= 3)
}
/// 触发信号的可读描述(用于日志)
pub fn signal_description(signal: &TriggerSignal) -> &'static str {
match signal {
TriggerSignal::PainConfirmed => "痛点确认",
TriggerSignal::PositiveFeedback => "隐式正反馈",
TriggerSignal::ComplexToolChain => "复杂工具链",
TriggerSignal::UserCorrection => "用户纠正",
TriggerSignal::IndustryPattern => "行业模式",
}
}
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_pain_confirmed_trigger() {
let ctx = TriggerContext {
user_message: "这个报表还是有问题".to_string(),
tool_call_count: 1,
conversation_messages: vec!["报表太慢".into()],
pain_confidence: Some(0.8),
industry_keywords: vec![],
};
let signals = evaluate_triggers(&ctx);
assert!(signals.contains(&TriggerSignal::PainConfirmed));
}
#[test]
fn test_positive_feedback_trigger() {
let ctx = TriggerContext {
user_message: "好了,解决了!谢谢".to_string(),
tool_call_count: 0,
conversation_messages: vec![],
pain_confidence: None,
industry_keywords: vec![],
};
let signals = evaluate_triggers(&ctx);
assert!(signals.contains(&TriggerSignal::PositiveFeedback));
}
#[test]
fn test_complex_tool_chain_trigger() {
let ctx = TriggerContext {
user_message: "帮我处理一下".to_string(),
tool_call_count: 4,
conversation_messages: vec![],
pain_confidence: None,
industry_keywords: vec![],
};
let signals = evaluate_triggers(&ctx);
assert!(signals.contains(&TriggerSignal::ComplexToolChain));
}
#[test]
fn test_user_correction_trigger() {
let ctx = TriggerContext {
user_message: "不对,应该是另一个方案".to_string(),
tool_call_count: 0,
conversation_messages: vec![],
pain_confidence: None,
industry_keywords: vec![],
};
let signals = evaluate_triggers(&ctx);
assert!(signals.contains(&TriggerSignal::UserCorrection));
}
#[test]
fn test_industry_pattern_trigger() {
let ctx = TriggerContext {
user_message: "库存又不够了".to_string(),
tool_call_count: 0,
conversation_messages: vec![
"帮我查库存".into(),
"库存数据怎么看".into(),
"库存预警设置".into(),
"库存又不够了".into(),
],
pain_confidence: None,
industry_keywords: vec!["库存".to_string(), "SKU".to_string(), "GMV".to_string()],
};
let signals = evaluate_triggers(&ctx);
assert!(signals.contains(&TriggerSignal::IndustryPattern));
}
#[test]
fn test_no_trigger() {
let ctx = TriggerContext {
user_message: "今天天气怎么样".to_string(),
tool_call_count: 0,
conversation_messages: vec![],
pain_confidence: None,
industry_keywords: vec![],
};
let signals = evaluate_triggers(&ctx);
assert!(signals.is_empty());
}
#[test]
fn test_multiple_triggers() {
let ctx = TriggerContext {
user_message: "不对,帮我重新做一下库存报表".to_string(),
tool_call_count: 3,
conversation_messages: vec![
"库存报表".into(),
"帮我做库存报表".into(),
"库存报表数据".into(),
"不对,帮我重新做一下库存报表".into(),
],
pain_confidence: Some(0.8),
industry_keywords: vec!["库存".to_string()],
};
let signals = evaluate_triggers(&ctx);
assert!(signals.contains(&TriggerSignal::PainConfirmed));
assert!(signals.contains(&TriggerSignal::ComplexToolChain));
assert!(signals.contains(&TriggerSignal::UserCorrection));
assert!(signals.contains(&TriggerSignal::IndustryPattern));
assert!(signals.len() >= 3);
}
}