//! Semantic skill router
//!
//! Routes user queries to the most relevant skill using a hybrid approach:
//! 1. TF-IDF based text similarity (always available, no external deps)
//! 2. Optional embedding similarity (when an Embedder is configured)
//! 3. Optional LLM fallback for ambiguous cases
//!
//! **Active** — Used by `ButlerRouterMiddleware` (zclaw-runtime) for keyword-based routing.
//! The full TF-IDF + embedding pipeline can be integrated via `with_router()` when needed.

use std::collections::HashMap;
use std::sync::Arc;

use async_trait::async_trait;
use serde::{Deserialize, Serialize};

use crate::SkillManifest;
use crate::registry::SkillRegistry;

/// Embedder trait — abstracts embedding computation.
///
/// Default implementation uses TF-IDF. Real embedding APIs (OpenAI, local models)
/// are adapted at the kernel layer where zclaw-growth is available.
#[async_trait]
pub trait Embedder: Send + Sync {
    /// Compute embedding vector for text.
    /// Returns `None` if embedding is unavailable (falls back to TF-IDF).
    async fn embed(&self, text: &str) -> Option<Vec<f32>>;
}

/// Runtime LLM intent resolution trait.
///
/// When TF-IDF + embedding confidence is below the threshold, the router
/// delegates to an LLM to pick the best skill from top candidates.
#[async_trait]
pub trait RuntimeLlmIntent: Send + Sync {
    /// Ask the LLM to select the best skill for a query.
    ///
    /// Returns `None` if the LLM cannot determine a match (e.g. query is
    /// genuinely unrelated to all candidates).
    async fn resolve_skill(
        &self,
        query: &str,
        candidates: &[ScoredCandidate],
    ) -> Option<RoutingResult>;
}

/// No-op embedder that always returns None (forces TF-IDF fallback).
pub struct NoOpEmbedder; #[async_trait] impl Embedder for NoOpEmbedder { async fn embed(&self, _text: &str) -> Option> { None } } /// Skill routing result #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct RoutingResult { /// Selected skill ID pub skill_id: String, /// Confidence score (0.0 - 1.0) pub confidence: f32, /// Extracted or inferred parameters pub parameters: serde_json::Value, /// Human-readable reasoning pub reasoning: String, } /// Candidate skill with similarity score #[derive(Debug, Clone)] pub struct ScoredCandidate { pub manifest: SkillManifest, pub score: f32, } /// Semantic skill router /// /// Uses a two-phase approach: /// - Phase 1: TF-IDF + optional embedding similarity to find top-K candidates /// - Phase 2: Optional LLM selection for ambiguous queries (threshold-based) pub struct SemanticSkillRouter { /// Skill registry for manifest lookups registry: Arc, /// Embedder (may be NoOp) embedder: Arc, /// Pre-built TF-IDF index over skill descriptions tfidf_index: SkillTfidfIndex, /// Pre-computed embedding vectors (skill_id → embedding) skill_embeddings: HashMap>, /// Confidence threshold for direct selection (skip LLM) confidence_threshold: f32, /// LLM fallback for ambiguous queries (confidence below threshold) llm_fallback: Option>, } impl SemanticSkillRouter { /// Create a new router with the given registry and embedder pub fn new(registry: Arc, embedder: Arc) -> Self { let mut router = Self { registry, embedder, tfidf_index: SkillTfidfIndex::new(), skill_embeddings: HashMap::new(), confidence_threshold: 0.85, llm_fallback: None, }; router.rebuild_index_sync(); router } /// Create with default TF-IDF only (no embedding) pub fn new_tf_idf_only(registry: Arc) -> Self { Self::new(registry, Arc::new(NoOpEmbedder)) } /// Set confidence threshold for direct selection pub fn with_confidence_threshold(mut self, threshold: f32) -> Self { self.confidence_threshold = threshold.clamp(0.0, 1.0); self } /// Set 
LLM fallback for ambiguous queries (confidence below threshold) pub fn with_llm_fallback(mut self, fallback: Arc) -> Self { self.llm_fallback = Some(fallback); self } /// Rebuild the TF-IDF index from current registry manifests fn rebuild_index_sync(&mut self) { let manifests = self.registry.manifests_snapshot(); self.tfidf_index.clear(); for (_, manifest) in &manifests { let text = Self::skill_text(manifest); self.tfidf_index.add_document(manifest.id.to_string(), &text); } } /// Rebuild index and pre-compute embeddings (async) pub async fn rebuild_index(&mut self) { let manifests = self.registry.manifests_snapshot(); self.tfidf_index.clear(); self.skill_embeddings.clear(); // Phase 1: Build TF-IDF index for (_, manifest) in &manifests { let text = Self::skill_text(manifest); self.tfidf_index.add_document(manifest.id.to_string(), &text); } // Phase 2: Pre-compute embeddings for (_, manifest) in &manifests { let text = Self::skill_text(manifest); if let Some(vec) = self.embedder.embed(&text).await { self.skill_embeddings.insert(manifest.id.to_string(), vec); } } tracing::info!( "[SemanticSkillRouter] Index rebuilt: {} skills, {} embeddings", manifests.len(), self.skill_embeddings.len() ); } /// Build searchable text from a skill manifest fn skill_text(manifest: &SkillManifest) -> String { let mut parts = vec![ manifest.name.clone(), manifest.description.clone(), ]; parts.extend(manifest.triggers.iter().cloned()); parts.extend(manifest.capabilities.iter().cloned()); parts.extend(manifest.tags.iter().cloned()); if let Some(ref cat) = manifest.category { parts.push(cat.clone()); } parts.join(" ") } /// Retrieve top-K candidate skills for a query pub async fn retrieve_candidates(&self, query: &str, top_k: usize) -> Vec { let manifests = self.registry.manifests_snapshot(); if manifests.is_empty() { return Vec::new(); } let mut scored: Vec = Vec::new(); // Try embedding-based scoring first let query_embedding = self.embedder.embed(query).await; for (skill_id, manifest) in 
&manifests { let tfidf_score = self.tfidf_index.score(query, &skill_id.to_string()); let final_score = if let Some(ref q_emb) = query_embedding { // Hybrid: embedding (70%) + TF-IDF (30%) if let Some(s_emb) = self.skill_embeddings.get(&skill_id.to_string()) { let emb_sim = cosine_similarity(q_emb, s_emb); emb_sim * 0.7 + tfidf_score * 0.3 } else { tfidf_score } } else { tfidf_score }; scored.push(ScoredCandidate { manifest: manifest.clone(), score: final_score, }); } // Sort descending by score scored.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal)); scored.truncate(top_k); scored } /// Route a query to the best matching skill. /// /// Returns `None` if no skill matches well enough. /// If top candidate exceeds `confidence_threshold`, returns directly. /// Otherwise, if an LLM fallback is configured, delegates to it for final selection. pub async fn route(&self, query: &str) -> Option { let candidates = self.retrieve_candidates(query, 3).await; if candidates.is_empty() { return None; } let best = &candidates[0]; // If score is very low, don't route even with LLM if best.score < 0.1 { return None; } // High confidence → return directly if best.score >= self.confidence_threshold { return Some(RoutingResult { skill_id: best.manifest.id.to_string(), confidence: best.score, parameters: serde_json::json!({}), reasoning: format!("High semantic match ({:.0}%)", best.score * 100.0), }); } // Medium confidence → try LLM fallback if available if let Some(ref llm) = self.llm_fallback { if let Some(result) = llm.resolve_skill(query, &candidates).await { tracing::debug!( "[SemanticSkillRouter] LLM fallback selected '{}' (original top: '{}' at {:.0}%)", result.skill_id, best.manifest.id, best.score * 100.0 ); return Some(result); } } // No LLM fallback or LLM couldn't decide → return best TF-IDF/embedding match Some(RoutingResult { skill_id: best.manifest.id.to_string(), confidence: best.score, parameters: serde_json::json!({}), reasoning: 
format!( "Best match ({:.0}%) — below threshold, no LLM refinement", best.score * 100.0 ), }) } /// Get index stats pub fn stats(&self) -> RouterStats { RouterStats { indexed_skills: self.tfidf_index.document_count(), embedding_count: self.skill_embeddings.len(), confidence_threshold: self.confidence_threshold, } } } /// Router statistics #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct RouterStats { pub indexed_skills: usize, pub embedding_count: usize, pub confidence_threshold: f32, } /// Compute cosine similarity between two vectors pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 { if a.is_empty() || b.is_empty() || a.len() != b.len() { return 0.0; } let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum(); let norm_a: f32 = a.iter().map(|x| x * x).sum::().sqrt(); let norm_b: f32 = b.iter().map(|x| x * x).sum::().sqrt(); let denom = norm_a * norm_b; if denom < 1e-10 { 0.0 } else { (dot / denom).clamp(0.0, 1.0) } } // --------------------------------------------------------------------------- // TF-IDF Index (lightweight, no external deps) // --------------------------------------------------------------------------- /// Lightweight TF-IDF index for skill descriptions struct SkillTfidfIndex { /// Per-document term frequencies: doc_id → (term → tf) doc_tfs: HashMap>, /// Document frequency: term → number of docs containing it doc_freq: HashMap, /// Total documents total_docs: usize, /// Stop words stop_words: std::collections::HashSet, } impl SkillTfidfIndex { fn new() -> Self { Self { doc_tfs: HashMap::new(), doc_freq: HashMap::new(), total_docs: 0, stop_words: Self::default_stop_words(), } } fn clear(&mut self) { self.doc_tfs.clear(); self.doc_freq.clear(); self.total_docs = 0; } fn document_count(&self) -> usize { self.total_docs } fn add_document(&mut self, doc_id: String, text: &str) { let tokens = self.tokenize(text); if tokens.is_empty() { return; } // Compute TF let mut tf = HashMap::new(); let 
total = tokens.len() as f32; for token in &tokens { *tf.entry(token.clone()).or_insert(0.0) += 1.0; } for count in tf.values_mut() { *count /= total; } // Update document frequency let unique: std::collections::HashSet<_> = tokens.into_iter().collect(); for term in &unique { *self.doc_freq.entry(term.clone()).or_insert(0) += 1; } self.total_docs += 1; self.doc_tfs.insert(doc_id, tf); } /// Score a query against a specific document fn score(&self, query: &str, doc_id: &str) -> f32 { let query_tokens = self.tokenize(query); if query_tokens.is_empty() { return 0.0; } let doc_tf = match self.doc_tfs.get(doc_id) { Some(tf) => tf, None => return 0.0, }; // Compute query TF-IDF vector let mut query_vec = HashMap::new(); let q_total = query_tokens.len() as f32; let mut q_tf = HashMap::new(); for token in &query_tokens { *q_tf.entry(token.clone()).or_insert(0.0) += 1.0; } for (term, tf_val) in &q_tf { let idf = self.idf(term); query_vec.insert(term.clone(), (tf_val / q_total) * idf); } // Compute doc TF-IDF vector (on the fly) let mut doc_vec = HashMap::new(); for (term, tf_val) in doc_tf { let idf = self.idf(term); doc_vec.insert(term.clone(), tf_val * idf); } // Cosine similarity Self::cosine_sim_maps(&query_vec, &doc_vec) } fn idf(&self, term: &str) -> f32 { let df = self.doc_freq.get(term).copied().unwrap_or(0); if df == 0 || self.total_docs == 0 { return 0.0; } ((self.total_docs as f32 + 1.0) / (df as f32 + 1.0)).ln() + 1.0 } fn tokenize(&self, text: &str) -> Vec { let lower = text.to_lowercase(); let segments = lower.split(|c: char| !c.is_alphanumeric()) .filter(|s| !s.is_empty()) .collect::>(); let mut tokens = Vec::new(); for segment in &segments { let chars: Vec = segment.chars().collect(); // Check if segment contains CJK characters let has_cjk = chars.iter().any(|&c| Self::is_cjk(c)); if has_cjk && chars.len() >= 2 { // CJK: generate character bigrams (e.g. 
"财报解读" → ["财报", "报解", "解读"]) for window in chars.windows(2) { let bigram = format!("{}{}", window[0], window[1]); if !self.stop_words.contains(&bigram) { tokens.push(bigram); } } // Also add individual CJK chars as unigrams for shorter queries if chars.len() <= 4 { for &c in &chars { if Self::is_cjk(c) { let s = c.to_string(); if !self.stop_words.contains(&s) { tokens.push(s); } } } } } else if !has_cjk && segment.len() > 1 { // Non-CJK: use as-is (existing behavior) if !self.stop_words.contains(*segment) { tokens.push(segment.to_string()); } } } tokens } /// Check if a character is CJK (Chinese, Japanese, Korean) fn is_cjk(c: char) -> bool { matches!(c, '\u{4E00}'..='\u{9FFF}' | // CJK Unified Ideographs '\u{3400}'..='\u{4DBF}' | // CJK Extension A '\u{F900}'..='\u{FAFF}' | // CJK Compatibility Ideographs '\u{3040}'..='\u{309F}' | // Hiragana '\u{30A0}'..='\u{30FF}' | // Katakana '\u{AC00}'..='\u{D7AF}' // Hangul Syllables ) } fn cosine_sim_maps(v1: &HashMap, v2: &HashMap) -> f32 { if v1.is_empty() || v2.is_empty() { return 0.0; } let mut dot = 0.0; let mut norm1 = 0.0; let mut norm2 = 0.0; for (k, v) in v1 { norm1 += v * v; if let Some(v2_val) = v2.get(k) { dot += v * v2_val; } } for v in v2.values() { norm2 += v * v; } let denom = (norm1 * norm2).sqrt(); if denom < 1e-10 { 0.0 } else { (dot / denom).clamp(0.0, 1.0) } } fn default_stop_words() -> std::collections::HashSet { [ "the", "a", "an", "is", "are", "was", "were", "be", "been", "being", "have", "has", "had", "do", "does", "did", "will", "would", "could", "should", "may", "might", "must", "shall", "can", "need", "to", "of", "in", "for", "on", "with", "at", "by", "from", "as", "into", "through", "and", "but", "if", "or", "not", "this", "that", "it", "its", "i", "you", "he", "she", "we", "they", "my", "your", "his", "her", "our", ] .iter() .map(|s| s.to_string()) .collect() } } #[cfg(test)] mod tests { use super::*; use crate::{SkillManifest, SkillMode}; use zclaw_types::SkillId; fn make_manifest(id: &str, 
name: &str, desc: &str, triggers: Vec<&str>) -> SkillManifest { SkillManifest { id: SkillId::new(id), name: name.to_string(), description: desc.to_string(), version: "1.0.0".to_string(), author: None, mode: SkillMode::PromptOnly, capabilities: vec![], input_schema: None, output_schema: None, tags: vec![], category: None, triggers: triggers.into_iter().map(|s| s.to_string()).collect(), tools: vec![], enabled: true, } } #[tokio::test] async fn test_basic_routing() { let registry = Arc::new(SkillRegistry::new()); // Register test skills let finance = make_manifest( "finance-tracker", "财务追踪专家", "财务追踪专家 专注于企业财务数据分析、财报解读、盈利能力评估", vec!["财报", "财务分析"], ); let coder = make_manifest( "senior-developer", "高级开发者", "代码开发、架构设计、代码审查", vec!["代码", "开发"], ); registry.register( Arc::new(crate::runner::PromptOnlySkill::new(finance.clone(), String::new())), finance, ).await; registry.register( Arc::new(crate::runner::PromptOnlySkill::new(coder.clone(), String::new())), coder, ).await; let router = SemanticSkillRouter::new_tf_idf_only(registry); // Route a finance query let result = router.route("分析腾讯财报数据").await; assert!(result.is_some()); let r = result.unwrap(); assert_eq!(r.skill_id, "finance-tracker"); // Route a code query let result2 = router.route("帮我写一段 Rust 代码").await; assert!(result2.is_some()); let r2 = result2.unwrap(); assert_eq!(r2.skill_id, "senior-developer"); } #[tokio::test] async fn test_retrieve_candidates() { let registry = Arc::new(SkillRegistry::new()); let skills = vec![ make_manifest("s1", "Python 开发", "Python 代码开发", vec!["python"]), make_manifest("s2", "Rust 开发", "Rust 系统编程", vec!["rust"]), make_manifest("s3", "财务分析", "财务数据分析", vec!["财务"]), ]; for skill in skills { let m = skill.clone(); registry.register( Arc::new(crate::runner::PromptOnlySkill::new(m.clone(), String::new())), m, ).await; } let router = SemanticSkillRouter::new_tf_idf_only(registry); let candidates = router.retrieve_candidates("Rust 编程", 2).await; assert_eq!(candidates.len(), 2); 
assert_eq!(candidates[0].manifest.id.as_str(), "s2"); } #[test] fn test_cosine_similarity() { let a = vec![1.0, 0.0, 0.0]; let b = vec![1.0, 0.0, 0.0]; assert!((cosine_similarity(&a, &b) - 1.0).abs() < 0.001); let c = vec![0.0, 1.0, 0.0]; assert!((cosine_similarity(&a, &c) - 0.0).abs() < 0.001); } /// Mock LLM fallback that always picks the candidate matching target_skill_id struct MockLlmFallback { target_skill_id: String, } #[async_trait] impl RuntimeLlmIntent for MockLlmFallback { async fn resolve_skill( &self, _query: &str, candidates: &[ScoredCandidate], ) -> Option { let candidate = candidates.iter().find(|c| c.manifest.id.as_str() == self.target_skill_id)?; Some(RoutingResult { skill_id: candidate.manifest.id.to_string(), confidence: 0.75, parameters: serde_json::json!({}), reasoning: "LLM selected this skill".to_string(), }) } } #[tokio::test] async fn test_llm_fallback_invoked_when_below_threshold() { let registry = Arc::new(SkillRegistry::new()); // Register skills with very similar descriptions to force low confidence let s1 = make_manifest("skill-a", "数据分析师", "数据分析和可视化报告", vec!["数据"]); let s2 = make_manifest("skill-b", "数据工程师", "数据管道和 ETL 处理", vec!["数据"]); registry.register( Arc::new(crate::runner::PromptOnlySkill::new(s1.clone(), String::new())), s1, ).await; registry.register( Arc::new(crate::runner::PromptOnlySkill::new(s2.clone(), String::new())), s2, ).await; // Router with impossibly high threshold to force LLM fallback let router = SemanticSkillRouter::new_tf_idf_only(registry) .with_confidence_threshold(2.0) // No TF-IDF score can reach this .with_llm_fallback(Arc::new(MockLlmFallback { target_skill_id: "skill-b".to_string(), })); let result = router.route("数据处理").await; assert!(result.is_some()); let r = result.unwrap(); // LLM fallback picks skill-b regardless of TF-IDF ranking assert_eq!(r.skill_id, "skill-b"); assert_eq!(r.reasoning, "LLM selected this skill"); } #[tokio::test] async fn test_no_llm_fallback_when_high_confidence() { let 
registry = Arc::new(SkillRegistry::new()); let finance = make_manifest( "finance-tracker", "财务追踪专家", "财务追踪专家 专注于企业财务数据分析、财报解读、盈利能力评估", vec!["财报", "财务分析"], ); registry.register( Arc::new(crate::runner::PromptOnlySkill::new(finance.clone(), String::new())), finance, ).await; // Router with LLM fallback that would pick wrong answer — but high TF-IDF should skip LLM let router = SemanticSkillRouter::new_tf_idf_only(registry) .with_confidence_threshold(0.3) // Low threshold → TF-IDF should exceed it .with_llm_fallback(Arc::new(MockLlmFallback { target_skill_id: "nonexistent".to_string(), })); let result = router.route("分析腾讯财报数据").await; assert!(result.is_some()); let r = result.unwrap(); assert_eq!(r.skill_id, "finance-tracker"); // Should NOT be LLM reasoning assert!(r.reasoning.contains("High semantic match")); } #[tokio::test] async fn test_no_llm_fallback_returns_best_match() { let registry = Arc::new(SkillRegistry::new()); let s1 = make_manifest("skill-x", "数据分析师", "数据分析和可视化报告", vec!["数据"]); registry.register( Arc::new(crate::runner::PromptOnlySkill::new(s1.clone(), String::new())), s1, ).await; // No LLM fallback configured let router = SemanticSkillRouter::new_tf_idf_only(registry) .with_confidence_threshold(0.99); let result = router.route("数据分析").await; assert!(result.is_some()); // Should still return best TF-IDF match even below threshold assert_eq!(result.unwrap().skill_id, "skill-x"); } }