Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled
- 新增 triggers.rs: 5 种触发信号(痛点确认/正反馈/复杂工具链/用户纠正/行业模式) - ExperienceStore 增加 industry_context + source_trigger 字段 - experience.rs format_for_injection 支持行业标签 - intelligence_hooks.rs 集成触发信号评估 - 17 个测试全通过 (7 trigger + 10 experience)
230 lines
7.7 KiB
Rust
230 lines
7.7 KiB
Rust
//! 学习触发信号系统
|
||
//!
|
||
//! 规则驱动的低成本触发判断,在 `post_conversation_hook` 中调用。
|
||
//! 有信号时才进入 LLM 经验提取,无信号则零成本跳过。
|
||
|
||
use super::experience::{CompletionStatus, detect_implicit_feedback};
|
||
|
||
/// 触发信号类型
|
||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||
pub enum TriggerSignal {
|
||
/// 痛点确认(confidence >= 0.7)
|
||
PainConfirmed,
|
||
/// 隐式正反馈("谢谢/解决了/对了/好了")
|
||
PositiveFeedback,
|
||
/// 复杂工具链(单次对话 3+ tool calls)
|
||
ComplexToolChain,
|
||
/// 用户纠正(含"不对/不是/应该是")
|
||
UserCorrection,
|
||
/// 行业模式(同一行业关键词在多轮出现)
|
||
IndustryPattern,
|
||
}
|
||
|
||
/// 触发信号判断的输入
|
||
pub struct TriggerContext {
|
||
/// 最新用户消息
|
||
pub user_message: String,
|
||
/// 本轮工具调用次数
|
||
pub tool_call_count: usize,
|
||
/// 对话中累计的用户消息(用于行业关键词统计)
|
||
pub conversation_messages: Vec<String>,
|
||
/// 检测到的痛点置信度(如有)
|
||
pub pain_confidence: Option<f64>,
|
||
/// 用户授权的行业关键词
|
||
pub industry_keywords: Vec<String>,
|
||
}
|
||
|
||
/// 判断是否触发学习信号(纯规则,零 LLM 调用)
|
||
///
|
||
/// 返回匹配到的所有触发信号。空 Vec = 无信号,跳过。
|
||
pub fn evaluate_triggers(ctx: &TriggerContext) -> Vec<TriggerSignal> {
|
||
let mut signals = Vec::new();
|
||
|
||
// 1. 痛点确认
|
||
if let Some(confidence) = ctx.pain_confidence {
|
||
if confidence >= 0.7 {
|
||
signals.push(TriggerSignal::PainConfirmed);
|
||
}
|
||
}
|
||
|
||
// 2. 隐式正反馈
|
||
if let Some(status) = detect_implicit_feedback(&ctx.user_message) {
|
||
if status == CompletionStatus::Success {
|
||
signals.push(TriggerSignal::PositiveFeedback);
|
||
}
|
||
}
|
||
|
||
// 3. 复杂工具链
|
||
if ctx.tool_call_count >= 3 {
|
||
signals.push(TriggerSignal::ComplexToolChain);
|
||
}
|
||
|
||
// 4. 用户纠正
|
||
if is_user_correction(&ctx.user_message) {
|
||
signals.push(TriggerSignal::UserCorrection);
|
||
}
|
||
|
||
// 5. 行业模式
|
||
if detects_industry_pattern(&ctx.conversation_messages, &ctx.industry_keywords) {
|
||
signals.push(TriggerSignal::IndustryPattern);
|
||
}
|
||
|
||
signals
|
||
}
|
||
|
||
/// 检测用户纠正信号
|
||
fn is_user_correction(message: &str) -> bool {
|
||
let lower = message.to_lowercase();
|
||
let correction_patterns = [
|
||
"不对", "不是", "应该是", "错了", "重新", "换一个",
|
||
"不是这个", "搞错了", "你理解错了", "我的意思是",
|
||
];
|
||
correction_patterns.iter().any(|p| lower.contains(p))
|
||
}
|
||
|
||
/// 检测行业关键词在多轮对话中反复出现
|
||
fn detects_industry_pattern(messages: &[String], industry_keywords: &[String]) -> bool {
|
||
if messages.len() < 3 || industry_keywords.is_empty() {
|
||
return false;
|
||
}
|
||
|
||
// 统计行业关键词在所有消息中的出现次数
|
||
let mut keyword_hits: std::collections::HashMap<&str, usize> = std::collections::HashMap::new();
|
||
for msg in messages {
|
||
let lower = msg.to_lowercase();
|
||
for kw in industry_keywords {
|
||
if lower.contains(kw.to_lowercase().as_str()) {
|
||
*keyword_hits.entry(kw).or_default() += 1;
|
||
}
|
||
}
|
||
}
|
||
|
||
// 至少有 1 个关键词在 3+ 轮中出现
|
||
keyword_hits.values().any(|&count| count >= 3)
|
||
}
|
||
|
||
/// 触发信号的可读描述(用于日志)
|
||
pub fn signal_description(signal: &TriggerSignal) -> &'static str {
|
||
match signal {
|
||
TriggerSignal::PainConfirmed => "痛点确认",
|
||
TriggerSignal::PositiveFeedback => "隐式正反馈",
|
||
TriggerSignal::ComplexToolChain => "复杂工具链",
|
||
TriggerSignal::UserCorrection => "用户纠正",
|
||
TriggerSignal::IndustryPattern => "行业模式",
|
||
}
|
||
}
|
||
|
||
// ---------------------------------------------------------------------------
|
||
// Tests
|
||
// ---------------------------------------------------------------------------
|
||
|
||
#[cfg(test)]
|
||
mod tests {
|
||
use super::*;
|
||
|
||
#[test]
|
||
fn test_pain_confirmed_trigger() {
|
||
let ctx = TriggerContext {
|
||
user_message: "这个报表还是有问题".to_string(),
|
||
tool_call_count: 1,
|
||
conversation_messages: vec!["报表太慢".into()],
|
||
pain_confidence: Some(0.8),
|
||
industry_keywords: vec![],
|
||
};
|
||
let signals = evaluate_triggers(&ctx);
|
||
assert!(signals.contains(&TriggerSignal::PainConfirmed));
|
||
}
|
||
|
||
#[test]
|
||
fn test_positive_feedback_trigger() {
|
||
let ctx = TriggerContext {
|
||
user_message: "好了,解决了!谢谢".to_string(),
|
||
tool_call_count: 0,
|
||
conversation_messages: vec![],
|
||
pain_confidence: None,
|
||
industry_keywords: vec![],
|
||
};
|
||
let signals = evaluate_triggers(&ctx);
|
||
assert!(signals.contains(&TriggerSignal::PositiveFeedback));
|
||
}
|
||
|
||
#[test]
|
||
fn test_complex_tool_chain_trigger() {
|
||
let ctx = TriggerContext {
|
||
user_message: "帮我处理一下".to_string(),
|
||
tool_call_count: 4,
|
||
conversation_messages: vec![],
|
||
pain_confidence: None,
|
||
industry_keywords: vec![],
|
||
};
|
||
let signals = evaluate_triggers(&ctx);
|
||
assert!(signals.contains(&TriggerSignal::ComplexToolChain));
|
||
}
|
||
|
||
#[test]
|
||
fn test_user_correction_trigger() {
|
||
let ctx = TriggerContext {
|
||
user_message: "不对,应该是另一个方案".to_string(),
|
||
tool_call_count: 0,
|
||
conversation_messages: vec![],
|
||
pain_confidence: None,
|
||
industry_keywords: vec![],
|
||
};
|
||
let signals = evaluate_triggers(&ctx);
|
||
assert!(signals.contains(&TriggerSignal::UserCorrection));
|
||
}
|
||
|
||
#[test]
|
||
fn test_industry_pattern_trigger() {
|
||
let ctx = TriggerContext {
|
||
user_message: "库存又不够了".to_string(),
|
||
tool_call_count: 0,
|
||
conversation_messages: vec![
|
||
"帮我查库存".into(),
|
||
"库存数据怎么看".into(),
|
||
"库存预警设置".into(),
|
||
"库存又不够了".into(),
|
||
],
|
||
pain_confidence: None,
|
||
industry_keywords: vec!["库存".to_string(), "SKU".to_string(), "GMV".to_string()],
|
||
};
|
||
let signals = evaluate_triggers(&ctx);
|
||
assert!(signals.contains(&TriggerSignal::IndustryPattern));
|
||
}
|
||
|
||
#[test]
|
||
fn test_no_trigger() {
|
||
let ctx = TriggerContext {
|
||
user_message: "今天天气怎么样".to_string(),
|
||
tool_call_count: 0,
|
||
conversation_messages: vec![],
|
||
pain_confidence: None,
|
||
industry_keywords: vec![],
|
||
};
|
||
let signals = evaluate_triggers(&ctx);
|
||
assert!(signals.is_empty());
|
||
}
|
||
|
||
#[test]
|
||
fn test_multiple_triggers() {
|
||
let ctx = TriggerContext {
|
||
user_message: "不对,帮我重新做一下库存报表".to_string(),
|
||
tool_call_count: 3,
|
||
conversation_messages: vec![
|
||
"库存报表".into(),
|
||
"帮我做库存报表".into(),
|
||
"库存报表数据".into(),
|
||
"不对,帮我重新做一下库存报表".into(),
|
||
],
|
||
pain_confidence: Some(0.8),
|
||
industry_keywords: vec!["库存".to_string()],
|
||
};
|
||
let signals = evaluate_triggers(&ctx);
|
||
assert!(signals.contains(&TriggerSignal::PainConfirmed));
|
||
assert!(signals.contains(&TriggerSignal::ComplexToolChain));
|
||
assert!(signals.contains(&TriggerSignal::UserCorrection));
|
||
assert!(signals.contains(&TriggerSignal::IndustryPattern));
|
||
assert!(signals.len() >= 3);
|
||
}
|
||
}
|