|
|
|
|
|
|
|
|
|
//! Semantic skill router
|
|
|
|
|
//!
|
|
|
|
|
//! Routes user queries to the most relevant skill using a hybrid approach:
|
|
|
|
|
//! 1. TF-IDF based text similarity (always available, no external deps)
|
|
|
|
|
//! 2. Optional embedding similarity (when an Embedder is configured)
|
|
|
|
|
//! 3. Optional LLM fallback for ambiguous cases
|
|
|
|
|
|
|
|
|
|
use std::collections::HashMap;
|
|
|
|
|
use std::sync::Arc;
|
|
|
|
|
use async_trait::async_trait;
|
|
|
|
|
use serde::{Deserialize, Serialize};
|
|
|
|
|
|
|
|
|
|
use crate::SkillManifest;
|
|
|
|
|
use crate::registry::SkillRegistry;
|
|
|
|
|
|
|
|
|
|
/// Embedder trait — abstracts embedding computation.
///
/// Default implementation uses TF-IDF. Real embedding APIs (OpenAI, local models)
/// are adapted at the kernel layer where zclaw-growth is available.
///
/// Implementations must be `Send + Sync` because the router stores them
/// behind an `Arc<dyn Embedder>` shared across async tasks.
#[async_trait]
pub trait Embedder: Send + Sync {
    /// Compute embedding vector for text.
    /// Returns `None` if embedding is unavailable (falls back to TF-IDF).
    async fn embed(&self, text: &str) -> Option<Vec<f32>>;
}
|
|
|
|
|
|
|
|
|
|
/// No-op embedder that always returns None (forces TF-IDF fallback).
pub struct NoOpEmbedder;

#[async_trait]
impl Embedder for NoOpEmbedder {
    async fn embed(&self, _text: &str) -> Option<Vec<f32>> {
        // Always unavailable: callers then score with TF-IDF only.
        None
    }
}
|
|
|
|
|
|
|
|
|
|
/// Skill routing result.
///
/// Serialized with camelCase field names (e.g. `skillId`) for external
/// consumers.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RoutingResult {
    /// Selected skill ID
    pub skill_id: String,
    /// Confidence score (0.0 - 1.0)
    pub confidence: f32,
    /// Extracted or inferred parameters
    pub parameters: serde_json::Value,
    /// Human-readable reasoning
    pub reasoning: String,
}
|
|
|
|
|
|
|
|
|
|
/// Candidate skill with similarity score.
#[derive(Debug, Clone)]
pub struct ScoredCandidate {
    /// Full manifest of the candidate skill (cloned from the registry snapshot).
    pub manifest: SkillManifest,
    /// Hybrid similarity score in [0, 1]; higher is a better match.
    pub score: f32,
}
|
|
|
|
|
|
|
|
|
|
/// Semantic skill router
///
/// Uses a two-phase approach:
/// - Phase 1: TF-IDF + optional embedding similarity to find top-K candidates
/// - Phase 2: Optional LLM selection for ambiguous queries (threshold-based)
pub struct SemanticSkillRouter {
    /// Skill registry for manifest lookups
    registry: Arc<SkillRegistry>,
    /// Embedder (may be NoOp)
    embedder: Arc<dyn Embedder>,
    /// Pre-built TF-IDF index over skill descriptions
    tfidf_index: SkillTfidfIndex,
    /// Pre-computed embedding vectors (skill_id → embedding).
    /// Stays empty until the async `rebuild_index` runs — the sync
    /// constructor only builds the TF-IDF index.
    skill_embeddings: HashMap<String, Vec<f32>>,
    /// Confidence threshold for direct selection (skip LLM); defaults to 0.85.
    confidence_threshold: f32,
}
|
|
|
|
|
|
|
|
|
|
impl SemanticSkillRouter {
|
|
|
|
|
/// Create a new router with the given registry and embedder
|
|
|
|
|
pub fn new(registry: Arc<SkillRegistry>, embedder: Arc<dyn Embedder>) -> Self {
|
|
|
|
|
let mut router = Self {
|
|
|
|
|
registry,
|
|
|
|
|
embedder,
|
|
|
|
|
tfidf_index: SkillTfidfIndex::new(),
|
|
|
|
|
skill_embeddings: HashMap::new(),
|
|
|
|
|
confidence_threshold: 0.85,
|
|
|
|
|
};
|
|
|
|
|
router.rebuild_index_sync();
|
|
|
|
|
router
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Create with default TF-IDF only (no embedding)
|
|
|
|
|
pub fn new_tf_idf_only(registry: Arc<SkillRegistry>) -> Self {
|
|
|
|
|
Self::new(registry, Arc::new(NoOpEmbedder))
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Set confidence threshold for direct selection
|
|
|
|
|
pub fn with_confidence_threshold(mut self, threshold: f32) -> Self {
|
|
|
|
|
self.confidence_threshold = threshold.clamp(0.0, 1.0);
|
|
|
|
|
self
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Rebuild the TF-IDF index from current registry manifests
|
|
|
|
|
fn rebuild_index_sync(&mut self) {
|
|
|
|
|
let manifests = self.registry.manifests_snapshot();
|
|
|
|
|
self.tfidf_index.clear();
|
|
|
|
|
for (_, manifest) in &manifests {
|
|
|
|
|
let text = Self::skill_text(manifest);
|
|
|
|
|
self.tfidf_index.add_document(manifest.id.to_string(), &text);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Rebuild index and pre-compute embeddings (async)
|
|
|
|
|
pub async fn rebuild_index(&mut self) {
|
|
|
|
|
let manifests = self.registry.manifests_snapshot();
|
|
|
|
|
self.tfidf_index.clear();
|
|
|
|
|
self.skill_embeddings.clear();
|
|
|
|
|
|
|
|
|
|
// Phase 1: Build TF-IDF index
|
|
|
|
|
for (_, manifest) in &manifests {
|
|
|
|
|
let text = Self::skill_text(manifest);
|
|
|
|
|
self.tfidf_index.add_document(manifest.id.to_string(), &text);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Phase 2: Pre-compute embeddings
|
|
|
|
|
for (_, manifest) in &manifests {
|
|
|
|
|
let text = Self::skill_text(manifest);
|
|
|
|
|
if let Some(vec) = self.embedder.embed(&text).await {
|
|
|
|
|
self.skill_embeddings.insert(manifest.id.to_string(), vec);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
tracing::info!(
|
|
|
|
|
"[SemanticSkillRouter] Index rebuilt: {} skills, {} embeddings",
|
|
|
|
|
manifests.len(),
|
|
|
|
|
self.skill_embeddings.len()
|
|
|
|
|
);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Build searchable text from a skill manifest
|
|
|
|
|
fn skill_text(manifest: &SkillManifest) -> String {
|
|
|
|
|
let mut parts = vec![
|
|
|
|
|
manifest.name.clone(),
|
|
|
|
|
manifest.description.clone(),
|
|
|
|
|
];
|
|
|
|
|
parts.extend(manifest.triggers.iter().cloned());
|
|
|
|
|
parts.extend(manifest.capabilities.iter().cloned());
|
|
|
|
|
parts.extend(manifest.tags.iter().cloned());
|
|
|
|
|
if let Some(ref cat) = manifest.category {
|
|
|
|
|
parts.push(cat.clone());
|
|
|
|
|
}
|
|
|
|
|
parts.join(" ")
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Retrieve top-K candidate skills for a query
|
|
|
|
|
pub async fn retrieve_candidates(&self, query: &str, top_k: usize) -> Vec<ScoredCandidate> {
|
|
|
|
|
let manifests = self.registry.manifests_snapshot();
|
|
|
|
|
if manifests.is_empty() {
|
|
|
|
|
return Vec::new();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
let mut scored: Vec<ScoredCandidate> = Vec::new();
|
|
|
|
|
|
|
|
|
|
// Try embedding-based scoring first
|
|
|
|
|
let query_embedding = self.embedder.embed(query).await;
|
|
|
|
|
|
|
|
|
|
for (skill_id, manifest) in &manifests {
|
|
|
|
|
let tfidf_score = self.tfidf_index.score(query, &skill_id.to_string());
|
|
|
|
|
|
|
|
|
|
let final_score = if let Some(ref q_emb) = query_embedding {
|
|
|
|
|
// Hybrid: embedding (70%) + TF-IDF (30%)
|
|
|
|
|
if let Some(s_emb) = self.skill_embeddings.get(&skill_id.to_string()) {
|
|
|
|
|
let emb_sim = cosine_similarity(q_emb, s_emb);
|
|
|
|
|
emb_sim * 0.7 + tfidf_score * 0.3
|
|
|
|
|
} else {
|
|
|
|
|
tfidf_score
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
tfidf_score
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
scored.push(ScoredCandidate {
|
|
|
|
|
manifest: manifest.clone(),
|
|
|
|
|
score: final_score,
|
|
|
|
|
});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Sort descending by score
|
|
|
|
|
scored.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
|
|
|
|
|
scored.truncate(top_k);
|
|
|
|
|
|
|
|
|
|
scored
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Route a query to the best matching skill.
|
|
|
|
|
///
|
|
|
|
|
/// Returns `None` if no skill matches well enough.
|
|
|
|
|
/// If top candidate exceeds `confidence_threshold`, returns directly.
|
|
|
|
|
/// Otherwise returns top candidate with lower confidence (caller can invoke LLM fallback).
|
|
|
|
|
pub async fn route(&self, query: &str) -> Option<RoutingResult> {
|
|
|
|
|
let candidates = self.retrieve_candidates(query, 3).await;
|
|
|
|
|
|
|
|
|
|
if candidates.is_empty() {
|
|
|
|
|
return None;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
let best = &candidates[0];
|
|
|
|
|
|
|
|
|
|
// If score is very low, don't route
|
|
|
|
|
if best.score < 0.1 {
|
|
|
|
|
return None;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
let confidence = best.score;
|
|
|
|
|
let reasoning = if confidence >= self.confidence_threshold {
|
|
|
|
|
format!("High semantic match ({:.0}%)", confidence * 100.0)
|
|
|
|
|
} else {
|
|
|
|
|
format!("Best match ({:.0}%) — may need LLM refinement", confidence * 100.0)
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
Some(RoutingResult {
|
|
|
|
|
skill_id: best.manifest.id.to_string(),
|
|
|
|
|
confidence,
|
|
|
|
|
parameters: serde_json::json!({}),
|
|
|
|
|
reasoning,
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Get index stats
|
|
|
|
|
pub fn stats(&self) -> RouterStats {
|
|
|
|
|
RouterStats {
|
|
|
|
|
indexed_skills: self.tfidf_index.document_count(),
|
|
|
|
|
embedding_count: self.skill_embeddings.len(),
|
|
|
|
|
confidence_threshold: self.confidence_threshold,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Router statistics (serialized camelCase for external consumers).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RouterStats {
    /// Number of documents currently in the TF-IDF index.
    pub indexed_skills: usize,
    /// Number of skills with a pre-computed embedding vector.
    pub embedding_count: usize,
    /// Threshold above which routing is considered high-confidence.
    pub confidence_threshold: f32,
}
|
|
|
|
|
|
|
|
|
|
/// Cosine similarity between two equal-length dense vectors, clamped to [0, 1].
///
/// Returns `0.0` for empty inputs, mismatched lengths, or vectors whose
/// combined norm is (near-)zero.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }

    // Accumulate dot product and both squared norms in a single pass.
    let mut dot = 0.0f32;
    let mut sq_a = 0.0f32;
    let mut sq_b = 0.0f32;
    for (&x, &y) in a.iter().zip(b.iter()) {
        dot += x * y;
        sq_a += x * x;
        sq_b += y * y;
    }

    let denom = sq_a.sqrt() * sq_b.sqrt();
    // Guard against division by a vanishing norm.
    if denom < 1e-10 {
        return 0.0;
    }
    (dot / denom).clamp(0.0, 1.0)
}
|
|
|
|
|
|
|
|
|
|
// ---------------------------------------------------------------------------
|
|
|
|
|
// TF-IDF Index (lightweight, no external deps)
|
|
|
|
|
// ---------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
/// Lightweight TF-IDF index for skill descriptions
struct SkillTfidfIndex {
    /// Per-document term frequencies: doc_id → (term → tf)
    doc_tfs: HashMap<String, HashMap<String, f32>>,
    /// Document frequency: term → number of docs containing it
    doc_freq: HashMap<String, usize>,
    /// Total documents
    total_docs: usize,
    /// Stop words (English; filtered out during tokenization)
    stop_words: std::collections::HashSet<String>,
}
|
|
|
|
|
|
|
|
|
|
impl SkillTfidfIndex {
|
|
|
|
|
fn new() -> Self {
|
|
|
|
|
Self {
|
|
|
|
|
doc_tfs: HashMap::new(),
|
|
|
|
|
doc_freq: HashMap::new(),
|
|
|
|
|
total_docs: 0,
|
|
|
|
|
stop_words: Self::default_stop_words(),
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn clear(&mut self) {
|
|
|
|
|
self.doc_tfs.clear();
|
|
|
|
|
self.doc_freq.clear();
|
|
|
|
|
self.total_docs = 0;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn document_count(&self) -> usize {
|
|
|
|
|
self.total_docs
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn add_document(&mut self, doc_id: String, text: &str) {
|
|
|
|
|
let tokens = self.tokenize(text);
|
|
|
|
|
if tokens.is_empty() {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Compute TF
|
|
|
|
|
let mut tf = HashMap::new();
|
|
|
|
|
let total = tokens.len() as f32;
|
|
|
|
|
for token in &tokens {
|
|
|
|
|
*tf.entry(token.clone()).or_insert(0.0) += 1.0;
|
|
|
|
|
}
|
|
|
|
|
for count in tf.values_mut() {
|
|
|
|
|
*count /= total;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Update document frequency
|
|
|
|
|
let unique: std::collections::HashSet<_> = tokens.into_iter().collect();
|
|
|
|
|
for term in &unique {
|
|
|
|
|
*self.doc_freq.entry(term.clone()).or_insert(0) += 1;
|
|
|
|
|
}
|
|
|
|
|
self.total_docs += 1;
|
|
|
|
|
|
|
|
|
|
self.doc_tfs.insert(doc_id, tf);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Score a query against a specific document
|
|
|
|
|
fn score(&self, query: &str, doc_id: &str) -> f32 {
|
|
|
|
|
let query_tokens = self.tokenize(query);
|
|
|
|
|
if query_tokens.is_empty() {
|
|
|
|
|
return 0.0;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
let doc_tf = match self.doc_tfs.get(doc_id) {
|
|
|
|
|
Some(tf) => tf,
|
|
|
|
|
None => return 0.0,
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// Compute query TF-IDF vector
|
|
|
|
|
let mut query_vec = HashMap::new();
|
|
|
|
|
let q_total = query_tokens.len() as f32;
|
|
|
|
|
let mut q_tf = HashMap::new();
|
|
|
|
|
for token in &query_tokens {
|
|
|
|
|
*q_tf.entry(token.clone()).or_insert(0.0) += 1.0;
|
|
|
|
|
}
|
|
|
|
|
for (term, tf_val) in &q_tf {
|
|
|
|
|
let idf = self.idf(term);
|
|
|
|
|
query_vec.insert(term.clone(), (tf_val / q_total) * idf);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Compute doc TF-IDF vector (on the fly)
|
|
|
|
|
let mut doc_vec = HashMap::new();
|
|
|
|
|
for (term, tf_val) in doc_tf {
|
|
|
|
|
let idf = self.idf(term);
|
|
|
|
|
doc_vec.insert(term.clone(), tf_val * idf);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Cosine similarity
|
|
|
|
|
Self::cosine_sim_maps(&query_vec, &doc_vec)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn idf(&self, term: &str) -> f32 {
|
|
|
|
|
let df = self.doc_freq.get(term).copied().unwrap_or(0);
|
|
|
|
|
if df == 0 || self.total_docs == 0 {
|
|
|
|
|
return 0.0;
|
|
|
|
|
}
|
|
|
|
|
((self.total_docs as f32 + 1.0) / (df as f32 + 1.0)).ln() + 1.0
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn tokenize(&self, text: &str) -> Vec<String> {
|
|
|
|
|
text.to_lowercase()
|
|
|
|
|
.split(|c: char| !c.is_alphanumeric())
|
|
|
|
|
.filter(|s| !s.is_empty() && s.len() > 1 && !self.stop_words.contains(*s))
|
|
|
|
|
.map(|s| s.to_string())
|
|
|
|
|
.collect()
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn cosine_sim_maps(v1: &HashMap<String, f32>, v2: &HashMap<String, f32>) -> f32 {
|
|
|
|
|
if v1.is_empty() || v2.is_empty() {
|
|
|
|
|
return 0.0;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
let mut dot = 0.0;
|
|
|
|
|
let mut norm1 = 0.0;
|
|
|
|
|
let mut norm2 = 0.0;
|
|
|
|
|
|
|
|
|
|
for (k, v) in v1 {
|
|
|
|
|
norm1 += v * v;
|
|
|
|
|
if let Some(v2_val) = v2.get(k) {
|
|
|
|
|
dot += v * v2_val;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
for v in v2.values() {
|
|
|
|
|
norm2 += v * v;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
let denom = (norm1 * norm2).sqrt();
|
|
|
|
|
if denom < 1e-10 {
|
|
|
|
|
0.0
|
|
|
|
|
} else {
|
|
|
|
|
(dot / denom).clamp(0.0, 1.0)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn default_stop_words() -> std::collections::HashSet<String> {
|
|
|
|
|
[
|
|
|
|
|
"the", "a", "an", "is", "are", "was", "were", "be", "been", "being",
|
|
|
|
|
"have", "has", "had", "do", "does", "did", "will", "would", "could",
|
|
|
|
|
"should", "may", "might", "must", "shall", "can", "need", "to", "of",
|
|
|
|
|
"in", "for", "on", "with", "at", "by", "from", "as", "into", "through",
|
|
|
|
|
"and", "but", "if", "or", "not", "this", "that", "it", "its", "i",
|
|
|
|
|
"you", "he", "she", "we", "they", "my", "your", "his", "her", "our",
|
|
|
|
|
]
|
|
|
|
|
.iter()
|
|
|
|
|
.map(|s| s.to_string())
|
|
|
|
|
.collect()
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{SkillManifest, SkillMode};
    use zclaw_types::SkillId;

    /// Build a minimal prompt-only manifest for routing tests.
    fn make_manifest(id: &str, name: &str, desc: &str, triggers: Vec<&str>) -> SkillManifest {
        SkillManifest {
            id: SkillId::new(id),
            name: name.to_string(),
            description: desc.to_string(),
            version: "1.0.0".to_string(),
            author: None,
            mode: SkillMode::PromptOnly,
            capabilities: vec![],
            input_schema: None,
            output_schema: None,
            tags: vec![],
            category: None,
            triggers: triggers.into_iter().map(|s| s.to_string()).collect(),
            enabled: true,
        }
    }

    #[tokio::test]
    async fn test_basic_routing() {
        let registry = Arc::new(SkillRegistry::new());

        // Register test skills: a finance analyst and a senior developer
        // (names/descriptions/triggers are Chinese on purpose — routing must
        // work without whitespace-separated tokens).
        let finance = make_manifest(
            "finance-tracker",
            "财务追踪专家",
            "财务追踪专家 专注于企业财务数据分析、财报解读、盈利能力评估",
            vec!["财报", "财务分析"],
        );
        let coder = make_manifest(
            "senior-developer",
            "高级开发者",
            "代码开发、架构设计、代码审查",
            vec!["代码", "开发"],
        );

        registry.register(
            Arc::new(crate::runner::PromptOnlySkill::new(finance.clone(), String::new())),
            finance,
        ).await;
        registry.register(
            Arc::new(crate::runner::PromptOnlySkill::new(coder.clone(), String::new())),
            coder,
        ).await;

        // TF-IDF-only router (NoOpEmbedder): exercises the fallback path.
        let router = SemanticSkillRouter::new_tf_idf_only(registry);

        // Route a finance query
        let result = router.route("分析腾讯财报数据").await;
        assert!(result.is_some());
        let r = result.unwrap();
        assert_eq!(r.skill_id, "finance-tracker");

        // Route a code query
        let result2 = router.route("帮我写一段 Rust 代码").await;
        assert!(result2.is_some());
        let r2 = result2.unwrap();
        assert_eq!(r2.skill_id, "senior-developer");
    }

    #[tokio::test]
    async fn test_retrieve_candidates() {
        let registry = Arc::new(SkillRegistry::new());

        let skills = vec![
            make_manifest("s1", "Python 开发", "Python 代码开发", vec!["python"]),
            make_manifest("s2", "Rust 开发", "Rust 系统编程", vec!["rust"]),
            make_manifest("s3", "财务分析", "财务数据分析", vec!["财务"]),
        ];

        for skill in skills {
            let m = skill.clone();
            registry.register(
                Arc::new(crate::runner::PromptOnlySkill::new(m.clone(), String::new())),
                m,
            ).await;
        }

        let router = SemanticSkillRouter::new_tf_idf_only(registry);
        let candidates = router.retrieve_candidates("Rust 编程", 2).await;

        // top_k=2 truncation, best match first.
        assert_eq!(candidates.len(), 2);
        assert_eq!(candidates[0].manifest.id.as_str(), "s2");
    }

    #[test]
    fn test_cosine_similarity() {
        // Identical vectors → similarity 1.
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        assert!((cosine_similarity(&a, &b) - 1.0).abs() < 0.001);

        // Orthogonal vectors → similarity 0.
        let c = vec![0.0, 1.0, 0.0];
        assert!((cosine_similarity(&a, &c) - 0.0).abs() < 0.001);
    }
}
|