Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled
refactor(saas): 重构认证中间件与限流策略
- 登录限流调整为5次/分钟/IP
- 注册限流调整为3次/小时/IP
- GET请求不计入限流
fix(saas): 修复调度器时间戳处理
- 使用NOW()替代文本时间戳
- 兼容TEXT和TIMESTAMPTZ列类型
feat(saas): 实现环境变量插值
- 支持${ENV_VAR}语法解析
- 数据库密码支持环境变量注入
chore: 新增前端管理界面
- 基于React+Ant Design Pro
- 包含路由守卫/错误边界
- 对接58个API端点
docs: 更新安全加固文档
- 新增密钥管理规范
- 记录P0安全项审计结果
- 补充TLS终止说明
test: 完善配置解析单元测试
- 新增环境变量插值测试用例
770 lines
22 KiB
Rust
770 lines
22 KiB
Rust
//! Intent Router System
//!
//! Routes user input to the appropriate pipeline using:
//! 1. Quick matching (keywords + patterns, < 10ms)
//! 2. Semantic matching (LLM-based, ~200ms)
//!
//! # Flow
//!
//! ```text
//! User Input
//!     ↓
//! Quick Match (keywords/patterns)
//!     ├─→ Match found → Prepare execution
//!     └─→ No match → Semantic Match (LLM)
//!                      ├─→ Match found → Prepare execution
//!                      └─→ No match → Return suggestions
//! ```
//!
//! # Example
//!
//! ```rust,ignore
//! use zclaw_pipeline::{IntentRouter, LlmIntentDriver, RouteResult, TriggerParser};
//!
//! async fn example(trigger_parser: TriggerParser, llm_driver: Box<dyn LlmIntentDriver>) {
//!     // `new` takes only the trigger parser; the LLM driver is optional and attached separately.
//!     let router = IntentRouter::new(trigger_parser).with_llm_driver(llm_driver);
//!
//!     // `route` always yields a `RouteResult`; it is not fallible.
//!     let result = router.route("帮我做一个Python入门课程").await;
//!
//!     match result {
//!         RouteResult::Matched { pipeline_id, params, mode, .. } => {
//!             // Start pipeline execution
//!         }
//!         RouteResult::Ambiguous { candidates } => {
//!             // Let the user pick one of the candidates
//!         }
//!         RouteResult::NoMatch { suggestions } => {
//!             // Show user available options
//!         }
//!         RouteResult::NeedMoreInfo { prompt, .. } => {
//!             // Ask user for clarification
//!         }
//!     }
//! }
//! ```
|
||
|
||
use crate::trigger::{CompiledTrigger, MatchType, TriggerMatch, TriggerParser, TriggerParam};
|
||
use async_trait::async_trait;
|
||
use serde::{Deserialize, Serialize};
|
||
use std::collections::HashMap;
|
||
|
||
/// Intent router - main entry point for user input.
///
/// Combines a trigger parser (quick keyword/pattern matching), an optional LLM
/// driver (semantic matching) and the routing configuration.
pub struct IntentRouter {
    /// Trigger parser for quick matching; it also holds the registered triggers that
    /// are offered to the LLM and used to build suggestions.
    trigger_parser: TriggerParser,

    /// LLM driver for semantic matching; when `None`, routing skips the semantic stage.
    llm_driver: Option<Box<dyn LlmIntentDriver>>,

    /// Routing configuration (thresholds, suggestion count, feature switches).
    config: RouterConfig,
}
|
||
|
||
/// Tunable settings for [`IntentRouter`].
#[derive(Debug, Clone)]
pub struct RouterConfig {
    /// Minimum confidence threshold for auto-matching.
    pub confidence_threshold: f32,

    /// How many pipelines to offer when there is no clear match.
    pub suggestion_count: usize,

    /// Whether the LLM-based semantic matching stage may run.
    pub enable_semantic_matching: bool,
}

impl Default for RouterConfig {
    /// Threshold `0.7`, three suggestions, semantic matching enabled.
    fn default() -> Self {
        let enable_semantic_matching = true;
        let suggestion_count = 3;
        let confidence_threshold = 0.7;

        RouterConfig {
            confidence_threshold,
            suggestion_count,
            enable_semantic_matching,
        }
    }
}
|
||
|
||
/// Outcome of routing one piece of user input.
///
/// Serialized as an internally tagged object: the variant name, converted to
/// snake_case, is stored under the `"type"` key (e.g. `{"type": "no_match", ...}`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum RouteResult {
    /// Successfully matched a pipeline.
    Matched {
        /// ID of the matched pipeline.
        pipeline_id: String,

        /// Human-readable pipeline name, if the trigger defines one.
        display_name: Option<String>,

        /// How the remaining parameters should be collected (conversation, form, ...).
        mode: InputMode,

        /// Parameters extracted from the input so far.
        params: HashMap<String, serde_json::Value>,

        /// Match confidence as reported by the matcher that produced this result.
        confidence: f32,

        /// Required parameters that were neither provided nor given a default.
        missing_params: Vec<MissingParam>,
    },

    /// Multiple plausible matches; the user needs to choose.
    ///
    /// NOTE(review): nothing in this file currently produces this variant.
    Ambiguous {
        /// Candidate pipelines.
        candidates: Vec<PipelineCandidate>,
    },

    /// No pipeline matched; `suggestions` lists pipelines to offer instead (may be empty).
    NoMatch {
        /// Suggested pipelines.
        suggestions: Vec<PipelineCandidate>,
    },

    /// More information is needed from the user.
    ///
    /// NOTE(review): nothing in this file currently produces this variant.
    NeedMoreInfo {
        /// Prompt to show the user.
        prompt: String,

        /// Related pipeline ID, if any.
        related_pipeline: Option<String>,
    },
}
|
||
|
||
/// Input mode for parameter collection.
///
/// Serialized as a lowercase string: `"conversation"`, `"form"`, `"hybrid"`, `"auto"`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum InputMode {
    /// Collect parameters through the conversation.
    Conversation,

    /// Collect parameters through a form.
    Form,

    /// Start with conversation and switch to a form if needed.
    Hybrid,

    /// Let the system decide based on complexity.
    Auto,
}
|
||
|
||
impl Default for InputMode {
|
||
fn default() -> Self {
|
||
Self::Auto
|
||
}
|
||
}
|
||
|
||
/// A pipeline offered to the user as a suggestion or candidate.
///
/// Serialized with camelCase keys (`id`, `displayName`, `description`, `icon`,
/// `category`, `matchReason`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PipelineCandidate {
    /// Pipeline ID.
    pub id: String,

    /// Human-readable name, if known.
    pub display_name: Option<String>,

    /// Description, if known.
    pub description: Option<String>,

    /// Icon, if any.
    pub icon: Option<String>,

    /// Category, if any.
    pub category: Option<String>,

    /// Why this pipeline is being offered.
    pub match_reason: Option<String>,
}
|
||
|
||
/// A pipeline parameter that still has to be collected from the user.
///
/// Serialized with camelCase keys (`name`, `label`, `paramType`, `required`, `default`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct MissingParam {
    /// Parameter name (the key used in the params map).
    pub name: String,

    /// Human-readable label, if the trigger defines one.
    pub label: Option<String>,

    /// Parameter type as declared by the trigger (e.g. `"string"`).
    pub param_type: String,

    /// Whether the parameter is required.
    pub required: bool,

    /// Default value, if available.
    pub default: Option<serde_json::Value>,
}
|
||
|
||
impl IntentRouter {
|
||
/// Create a new intent router
|
||
pub fn new(trigger_parser: TriggerParser) -> Self {
|
||
Self {
|
||
trigger_parser,
|
||
llm_driver: None,
|
||
config: RouterConfig::default(),
|
||
}
|
||
}
|
||
|
||
/// Set LLM driver for semantic matching
|
||
pub fn with_llm_driver(mut self, driver: Box<dyn LlmIntentDriver>) -> Self {
|
||
self.llm_driver = Some(driver);
|
||
self
|
||
}
|
||
|
||
/// Set configuration
|
||
pub fn with_config(mut self, config: RouterConfig) -> Self {
|
||
self.config = config;
|
||
self
|
||
}
|
||
|
||
/// Route user input to a pipeline
|
||
pub async fn route(&self, user_input: &str) -> RouteResult {
|
||
// Step 1: Quick match (local, < 10ms)
|
||
if let Some(match_result) = self.trigger_parser.quick_match(user_input) {
|
||
return self.prepare_from_match(match_result);
|
||
}
|
||
|
||
// Step 2: Semantic match (LLM, ~200ms)
|
||
if self.config.enable_semantic_matching {
|
||
if let Some(ref llm_driver) = self.llm_driver {
|
||
if let Some(result) = llm_driver.semantic_match(user_input, self.trigger_parser.triggers()).await {
|
||
return self.prepare_from_semantic_match(result);
|
||
}
|
||
}
|
||
}
|
||
|
||
// Step 3: No match - return suggestions
|
||
self.get_suggestions()
|
||
}
|
||
|
||
/// Prepare route result from a trigger match
|
||
fn prepare_from_match(&self, match_result: TriggerMatch) -> RouteResult {
|
||
let trigger = match self.trigger_parser.get_trigger(&match_result.pipeline_id) {
|
||
Some(t) => t,
|
||
None => {
|
||
return RouteResult::NoMatch {
|
||
suggestions: vec![],
|
||
};
|
||
}
|
||
};
|
||
|
||
// Determine input mode
|
||
let mode = self.decide_mode(&trigger.param_defs);
|
||
|
||
// Find missing parameters
|
||
let missing_params = self.find_missing_params(&trigger.param_defs, &match_result.params);
|
||
|
||
RouteResult::Matched {
|
||
pipeline_id: match_result.pipeline_id,
|
||
display_name: trigger.display_name.clone(),
|
||
mode,
|
||
params: match_result.params,
|
||
confidence: match_result.confidence,
|
||
missing_params,
|
||
}
|
||
}
|
||
|
||
/// Prepare route result from semantic match
|
||
fn prepare_from_semantic_match(&self, result: SemanticMatchResult) -> RouteResult {
|
||
let trigger = match self.trigger_parser.get_trigger(&result.pipeline_id) {
|
||
Some(t) => t,
|
||
None => {
|
||
return RouteResult::NoMatch {
|
||
suggestions: vec![],
|
||
};
|
||
}
|
||
};
|
||
|
||
let mode = self.decide_mode(&trigger.param_defs);
|
||
let missing_params = self.find_missing_params(&trigger.param_defs, &result.params);
|
||
|
||
RouteResult::Matched {
|
||
pipeline_id: result.pipeline_id,
|
||
display_name: trigger.display_name.clone(),
|
||
mode,
|
||
params: result.params,
|
||
confidence: result.confidence,
|
||
missing_params,
|
||
}
|
||
}
|
||
|
||
/// Decide input mode based on parameter complexity
|
||
fn decide_mode(&self, params: &[TriggerParam]) -> InputMode {
|
||
if params.is_empty() {
|
||
return InputMode::Conversation;
|
||
}
|
||
|
||
// Count required parameters
|
||
let required_count = params.iter().filter(|p| p.required).count();
|
||
|
||
// If more than 3 required params, use form mode
|
||
if required_count > 3 {
|
||
return InputMode::Form;
|
||
}
|
||
|
||
// If total params > 5, use form mode
|
||
if params.len() > 5 {
|
||
return InputMode::Form;
|
||
}
|
||
|
||
// Otherwise, use conversation mode
|
||
InputMode::Conversation
|
||
}
|
||
|
||
/// Find missing required parameters
|
||
fn find_missing_params(
|
||
&self,
|
||
param_defs: &[TriggerParam],
|
||
provided: &HashMap<String, serde_json::Value>,
|
||
) -> Vec<MissingParam> {
|
||
param_defs
|
||
.iter()
|
||
.filter(|p| {
|
||
p.required && !provided.contains_key(&p.name) && p.default.is_none()
|
||
})
|
||
.map(|p| MissingParam {
|
||
name: p.name.clone(),
|
||
label: p.label.clone(),
|
||
param_type: p.param_type.clone(),
|
||
required: p.required,
|
||
default: p.default.clone(),
|
||
})
|
||
.collect()
|
||
}
|
||
|
||
/// Get suggestions when no match found
|
||
fn get_suggestions(&self) -> RouteResult {
|
||
let suggestions: Vec<PipelineCandidate> = self
|
||
.trigger_parser
|
||
.triggers()
|
||
.iter()
|
||
.take(self.config.suggestion_count)
|
||
.map(|t| PipelineCandidate {
|
||
id: t.pipeline_id.clone(),
|
||
display_name: t.display_name.clone(),
|
||
description: t.description.clone(),
|
||
icon: None,
|
||
category: None,
|
||
match_reason: Some("热门推荐".to_string()),
|
||
})
|
||
.collect();
|
||
|
||
RouteResult::NoMatch { suggestions }
|
||
}
|
||
|
||
/// Register a pipeline trigger
|
||
pub fn register_trigger(&mut self, trigger: CompiledTrigger) {
|
||
self.trigger_parser.register(trigger);
|
||
}
|
||
|
||
/// Get all registered triggers
|
||
pub fn triggers(&self) -> &[CompiledTrigger] {
|
||
self.trigger_parser.triggers()
|
||
}
|
||
}
|
||
|
||
/// Result of LLM semantic matching, as produced by an [`LlmIntentDriver`].
#[derive(Debug, Clone)]
pub struct SemanticMatchResult {
    /// ID of the pipeline the LLM selected.
    pub pipeline_id: String,

    /// Parameters the LLM extracted from the input.
    pub params: HashMap<String, serde_json::Value>,

    /// Confidence reported by the LLM.
    pub confidence: f32,

    /// The LLM's explanation for the match (may be empty).
    pub reason: String,
}
|
||
|
||
/// LLM backend used for semantic matching and parameter extraction.
///
/// Implementations must be `Send + Sync`. The trait uses `#[async_trait]`, which lets it
/// be stored as `Box<dyn LlmIntentDriver>` (as `IntentRouter` does).
#[async_trait]
pub trait LlmIntentDriver: Send + Sync {
    /// Pick the pipeline among `triggers` that best matches `user_input`.
    ///
    /// Returns `None` when there is no acceptable match.
    async fn semantic_match(
        &self,
        user_input: &str,
        triggers: &[CompiledTrigger],
    ) -> Option<SemanticMatchResult>;

    /// Extract values for `missing_params` from `user_input`.
    ///
    /// Returns a map from parameter name to value for whatever could be extracted.
    /// `context` presumably carries already-known parameters — TODO confirm with callers.
    async fn collect_params(
        &self,
        user_input: &str,
        missing_params: &[MissingParam],
        context: &HashMap<String, serde_json::Value>,
    ) -> HashMap<String, serde_json::Value>;
}
|
||
|
||
/// Runtime LLM driver that wraps zclaw-runtime's LlmDriver for actual LLM calls
|
||
pub struct RuntimeLlmIntentDriver {
|
||
driver: std::sync::Arc<dyn zclaw_runtime::driver::LlmDriver>,
|
||
}
|
||
|
||
impl RuntimeLlmIntentDriver {
|
||
/// Create a new runtime LLM intent driver wrapping an existing LLM driver
|
||
pub fn new(driver: std::sync::Arc<dyn zclaw_runtime::driver::LlmDriver>) -> Self {
|
||
Self { driver }
|
||
}
|
||
}
|
||
|
||
#[async_trait]
|
||
impl LlmIntentDriver for RuntimeLlmIntentDriver {
|
||
async fn semantic_match(
|
||
&self,
|
||
user_input: &str,
|
||
triggers: &[CompiledTrigger],
|
||
) -> Option<SemanticMatchResult> {
|
||
let trigger_descriptions: Vec<String> = triggers
|
||
.iter()
|
||
.map(|t| {
|
||
format!(
|
||
"- {}: {}",
|
||
t.pipeline_id,
|
||
t.description.as_deref().unwrap_or("无描述")
|
||
)
|
||
})
|
||
.collect();
|
||
|
||
let system_prompt = r#"分析用户输入,匹配合适的 Pipeline。只返回 JSON,不要其他内容。"#
|
||
.to_string();
|
||
|
||
let user_msg = format!(
|
||
"用户输入: {}\n\n可选 Pipelines:\n{}",
|
||
user_input,
|
||
trigger_descriptions.join("\n")
|
||
);
|
||
|
||
let request = zclaw_runtime::driver::CompletionRequest {
|
||
model: self.driver.provider().to_string(),
|
||
system: Some(system_prompt),
|
||
messages: vec![zclaw_types::Message::assistant(user_msg)],
|
||
max_tokens: Some(512),
|
||
temperature: Some(0.2),
|
||
stream: false,
|
||
..Default::default()
|
||
};
|
||
|
||
match self.driver.complete(request).await {
|
||
Ok(response) => {
|
||
let text = response.content.iter()
|
||
.filter_map(|block| match block {
|
||
zclaw_runtime::driver::ContentBlock::Text { text } => Some(text.as_str()),
|
||
_ => None,
|
||
})
|
||
.collect::<Vec<_>>()
|
||
.join("");
|
||
|
||
parse_semantic_match_response(&text)
|
||
}
|
||
Err(e) => {
|
||
tracing::warn!("[intent] LLM semantic match failed: {}", e);
|
||
None
|
||
}
|
||
}
|
||
}
|
||
|
||
async fn collect_params(
|
||
&self,
|
||
user_input: &str,
|
||
missing_params: &[MissingParam],
|
||
_context: &HashMap<String, serde_json::Value>,
|
||
) -> HashMap<String, serde_json::Value> {
|
||
if missing_params.is_empty() {
|
||
return HashMap::new();
|
||
}
|
||
|
||
let param_descriptions: Vec<String> = missing_params
|
||
.iter()
|
||
.map(|p| {
|
||
format!(
|
||
"- {} ({}): {}",
|
||
p.name,
|
||
p.param_type,
|
||
p.label.as_deref().unwrap_or(&p.name)
|
||
)
|
||
})
|
||
.collect();
|
||
|
||
let system_prompt = r#"从用户输入中提取参数值。如果无法提取,该参数可以省略。只返回 JSON。"#
|
||
.to_string();
|
||
|
||
let user_msg = format!(
|
||
"用户输入: {}\n\n需要提取的参数:\n{}",
|
||
user_input,
|
||
param_descriptions.join("\n")
|
||
);
|
||
|
||
let request = zclaw_runtime::driver::CompletionRequest {
|
||
model: self.driver.provider().to_string(),
|
||
system: Some(system_prompt),
|
||
messages: vec![zclaw_types::Message::assistant(user_msg)],
|
||
max_tokens: Some(512),
|
||
temperature: Some(0.1),
|
||
stream: false,
|
||
..Default::default()
|
||
};
|
||
|
||
match self.driver.complete(request).await {
|
||
Ok(response) => {
|
||
let text = response.content.iter()
|
||
.filter_map(|block| match block {
|
||
zclaw_runtime::driver::ContentBlock::Text { text } => Some(text.as_str()),
|
||
_ => None,
|
||
})
|
||
.collect::<Vec<_>>()
|
||
.join("");
|
||
|
||
parse_params_response(&text)
|
||
}
|
||
Err(e) => {
|
||
tracing::warn!("[intent] LLM param extraction failed: {}", e);
|
||
HashMap::new()
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
/// Parse semantic match JSON from LLM response
|
||
fn parse_semantic_match_response(text: &str) -> Option<SemanticMatchResult> {
|
||
let json_str = extract_json_from_text(text);
|
||
let parsed: serde_json::Value = serde_json::from_str(&json_str).ok()?;
|
||
|
||
let pipeline_id = parsed.get("pipeline_id")?.as_str()?.to_string();
|
||
let confidence = parsed.get("confidence")?.as_f64()? as f32;
|
||
|
||
// Reject low-confidence matches
|
||
if confidence < 0.5 || pipeline_id.is_empty() {
|
||
return None;
|
||
}
|
||
|
||
let params = parsed.get("params")
|
||
.and_then(|v| v.as_object())
|
||
.map(|obj| {
|
||
obj.iter()
|
||
.filter_map(|(k, v)| {
|
||
let val = match v {
|
||
serde_json::Value::String(s) => serde_json::Value::String(s.clone()),
|
||
serde_json::Value::Number(n) => serde_json::Value::Number(n.clone()),
|
||
other => other.clone(),
|
||
};
|
||
Some((k.clone(), val))
|
||
})
|
||
.collect()
|
||
})
|
||
.unwrap_or_default();
|
||
|
||
let reason = parsed.get("reason")
|
||
.and_then(|v| v.as_str())
|
||
.unwrap_or("")
|
||
.to_string();
|
||
|
||
Some(SemanticMatchResult {
|
||
pipeline_id,
|
||
params,
|
||
confidence,
|
||
reason,
|
||
})
|
||
}
|
||
|
||
/// Parse params JSON from LLM response
|
||
fn parse_params_response(text: &str) -> HashMap<String, serde_json::Value> {
|
||
let json_str = extract_json_from_text(text);
|
||
if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(&json_str) {
|
||
if let Some(obj) = parsed.as_object() {
|
||
return obj.iter()
|
||
.filter_map(|(k, v)| Some((k.clone(), v.clone())))
|
||
.collect();
|
||
}
|
||
}
|
||
HashMap::new()
|
||
}
|
||
|
||
/// Extract the JSON payload from an LLM reply.
///
/// Tries, in order:
/// 1. the body of the first ```` ```json ```` fenced block, up to the next ```` ``` ````;
/// 2. the span from the first `{` to the last `}`, if that `}` comes after the `{`;
/// 3. otherwise the whole reply, trimmed.
///
/// Every slice offset comes from an ASCII marker found in the slice being cut. This
/// means slicing never splits a multi-byte character (replies are often Chinese) and
/// never panics.
fn extract_json_from_text(text: &str) -> String {
    const FENCE_OPEN: &str = "```json";
    const FENCE_CLOSE: &str = "```";

    let trimmed = text.trim();

    // Try markdown code block. Each `find` offset is relative to the slice it was
    // searched in, so re-slice before every subsequent search.
    if let Some(start) = trimmed.find(FENCE_OPEN) {
        let after_open = &trimmed[start + FENCE_OPEN.len()..];
        if let Some(newline) = after_open.find('\n') {
            let body = &after_open[newline + 1..];
            if let Some(end) = body.find(FENCE_CLOSE) {
                return body[..end].trim().to_string();
            }
        }
    }

    // Try bare JSON; guard against a `}` that precedes the first `{`.
    if let (Some(start), Some(end)) = (trimmed.find('{'), trimmed.rfind('}')) {
        if start < end {
            return trimmed[start..=end].to_string();
        }
    }

    trimmed.to_string()
}
|
||
|
||
/// Intent analysis result (for debugging/logging).
///
/// Serialized with camelCase keys. NOTE(review): not constructed anywhere in this
/// file; presumably filled in by callers for diagnostics — confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct IntentAnalysis {
    /// Original user input.
    pub user_input: String,

    /// ID of the matched pipeline, if any.
    pub matched_pipeline: Option<String>,

    /// How the match was made (see `crate::trigger::MatchType`), if there was one.
    pub match_type: Option<MatchType>,

    /// Extracted parameters.
    pub params: HashMap<String, serde_json::Value>,

    /// Confidence score.
    pub confidence: f32,

    /// IDs of all candidates considered.
    pub candidates: Vec<String>,
}
|
||
|
||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::trigger::{compile_trigger, Trigger};

    /// Router with a single `course-generator` trigger: `topic` is required with no
    /// default, `level` is optional with a default of "入门".
    fn create_test_router() -> IntentRouter {
        let mut parser = TriggerParser::new();

        let trigger = Trigger {
            keywords: vec!["课程".to_string(), "教程".to_string()],
            patterns: vec![
                "帮我做*课程".to_string(),
                "生成{level}级别的{topic}教程".to_string(),
            ],
            description: Some("根据用户主题生成完整的互动课程内容".to_string()),
            examples: vec!["帮我做一个 Python 入门课程".to_string()],
        };

        let compiled = compile_trigger(
            "course-generator".to_string(),
            Some("课程生成器".to_string()),
            &trigger,
            vec![
                TriggerParam {
                    name: "topic".to_string(),
                    param_type: "string".to_string(),
                    required: true,
                    label: Some("课程主题".to_string()),
                    default: None,
                },
                TriggerParam {
                    name: "level".to_string(),
                    param_type: "string".to_string(),
                    required: false,
                    label: Some("难度级别".to_string()),
                    default: Some(serde_json::Value::String("入门".to_string())),
                },
            ],
        )
        .unwrap();

        parser.register(compiled);

        IntentRouter::new(parser)
    }

    /// A string parameter with no label and no default.
    fn string_param(name: &str, required: bool) -> TriggerParam {
        TriggerParam {
            name: name.to_string(),
            param_type: "string".to_string(),
            required,
            label: None,
            default: None,
        }
    }

    #[tokio::test]
    async fn test_route_keyword_match() {
        let router = create_test_router();
        let result = router.route("我想学习一个课程").await;

        match result {
            RouteResult::Matched { pipeline_id, confidence, .. } => {
                assert_eq!(pipeline_id, "course-generator");
                assert!(confidence >= 0.7);
            }
            _ => panic!("Expected Matched result"),
        }
    }

    #[tokio::test]
    async fn test_route_pattern_match() {
        let router = create_test_router();
        let result = router.route("帮我做一个Python课程").await;

        match result {
            RouteResult::Matched { pipeline_id, params, missing_params, .. } => {
                assert_eq!(pipeline_id, "course-generator");
                // Only required, non-defaulted, non-provided params may be reported;
                // `level` is optional with a default, so it never is.
                for missing in &missing_params {
                    assert!(missing.required);
                    assert!(missing.default.is_none());
                    assert!(!params.contains_key(&missing.name));
                }
                assert!(missing_params.iter().all(|p| p.name != "level"));
            }
            _ => panic!("Expected Matched result"),
        }
    }

    #[tokio::test]
    async fn test_route_no_match() {
        let router = create_test_router();
        let result = router.route("今天天气怎么样").await;

        match result {
            RouteResult::NoMatch { suggestions } => {
                // Without an LLM driver the router falls back to suggestions, and the
                // single registered pipeline fits within the default suggestion count.
                assert_eq!(suggestions.len(), 1);
                assert_eq!(suggestions[0].id, "course-generator");
            }
            _ => panic!("Expected NoMatch result"),
        }
    }

    #[test]
    fn test_decide_mode_conversation() {
        let router = create_test_router();
        let params = vec![string_param("topic", true)];

        let mode = router.decide_mode(&params);
        assert_eq!(mode, InputMode::Conversation);
    }

    #[test]
    fn test_decide_mode_no_params() {
        let router = create_test_router();

        let mode = router.decide_mode(&[]);
        assert_eq!(mode, InputMode::Conversation);
    }

    #[test]
    fn test_decide_mode_boundary_stays_conversation() {
        let router = create_test_router();
        // 3 required (not > 3) and 5 total (not > 5).
        let params = vec![
            string_param("p1", true),
            string_param("p2", true),
            string_param("p3", true),
            string_param("p4", false),
            string_param("p5", false),
        ];

        let mode = router.decide_mode(&params);
        assert_eq!(mode, InputMode::Conversation);
    }

    #[test]
    fn test_decide_mode_form() {
        let router = create_test_router();
        let params: Vec<TriggerParam> = (1..=4)
            .map(|i| string_param(&format!("p{}", i), true))
            .collect();

        let mode = router.decide_mode(&params);
        assert_eq!(mode, InputMode::Form);
    }

    #[test]
    fn test_decide_mode_many_optional_params_form() {
        let router = create_test_router();
        let params: Vec<TriggerParam> = (1..=6)
            .map(|i| string_param(&format!("p{}", i), false))
            .collect();

        let mode = router.decide_mode(&params);
        assert_eq!(mode, InputMode::Form);
    }
}
|