- ChatResponse 增加 display_hints 字段,Orchestrator 收集 Tool 产生的 DisplayHint - DisplayHint 实现 utoipa::ToSchema - Web ChatResponse 类型同步,DisplayHint 8 种联合类型 - RichMessage 组件:InsightCard/RiskAlert/LabReportCard/TrendChart/PatientProfile/VitalCard - AiSidebar 消息中渲染 display_hints 富消息 - 小程序 AiChatResponse 类型同步
203 lines
6.9 KiB
Rust
203 lines
6.9 KiB
Rust
use std::sync::Arc;
|
||
|
||
use super::registry::ToolRegistry;
|
||
use super::tool::ToolContext;
|
||
use crate::dto::{ChatMessage, ChatMessageRole};
|
||
use crate::error::AiResult;
|
||
use crate::provider::AiProvider;
|
||
|
||
/// Record of one tool invocation handled during an agent run.
#[derive(Debug, Clone)]
pub struct ToolCallLog {
    /// Name of the tool the model requested.
    pub tool_name: String,
    /// Elapsed time spent handling the call, in milliseconds.
    pub duration_ms: u64,
    /// `true` when the tool was found, permitted and executed; `false` when the
    /// requested tool was unknown or not allowed for the caller.
    /// NOTE(review): does not reflect the tool's own success/failure — confirm
    /// whether tool results carry an error flag that should be surfaced here.
    pub success: bool,
}
|
||
|
||
/// Runtime parameters for a single agent run.
///
/// Derives `Debug`/`Clone`/`PartialEq` so callers can log, copy and compare
/// configurations.
#[derive(Debug, Clone, PartialEq)]
pub struct AgentRunParams {
    /// Model identifier forwarded to the provider on every call.
    pub model: String,
    /// Sampling temperature forwarded to the provider.
    pub temperature: f32,
    /// Per-request `max_tokens` limit forwarded to the provider.
    pub max_tokens: u32,
    /// Iteration count at which requested tools are no longer executed and the
    /// model is instead asked to summarize a final reply.
    pub max_iterations: usize,
    /// Optional cumulative token budget (input + output). Once reached, requested
    /// tools are no longer executed and the model is asked for a final reply.
    pub token_budget: Option<u32>,
}
|
||
|
||
impl Default for AgentRunParams {
|
||
fn default() -> Self {
|
||
Self {
|
||
model: "claude-sonnet-4-6".to_string(),
|
||
temperature: 0.7,
|
||
max_tokens: 2048,
|
||
max_iterations: 5,
|
||
token_budget: None,
|
||
}
|
||
}
|
||
}
|
||
|
||
/// Agent Orchestrator — 执行 ReAct 循环
|
||
pub struct AgentOrchestrator {
|
||
provider: Arc<dyn AiProvider>,
|
||
tool_registry: Arc<ToolRegistry>,
|
||
}
|
||
|
||
/// Agent 运行结果
|
||
pub struct AgentRunResult {
|
||
pub reply: String,
|
||
pub total_input_tokens: u32,
|
||
pub total_output_tokens: u32,
|
||
pub iterations: usize,
|
||
pub tool_calls: Vec<ToolCallLog>,
|
||
pub display_hints: Vec<super::tool::DisplayHint>,
|
||
}
|
||
|
||
impl AgentOrchestrator {
|
||
pub fn new(provider: Arc<dyn AiProvider>, tool_registry: Arc<ToolRegistry>) -> Self {
|
||
Self {
|
||
provider,
|
||
tool_registry,
|
||
}
|
||
}
|
||
|
||
/// 执行 Agent ReAct 循环
|
||
pub async fn run(
|
||
&self,
|
||
system_prompt: &str,
|
||
messages: &mut Vec<ChatMessage>,
|
||
ctx: &ToolContext,
|
||
params: &AgentRunParams,
|
||
allowed_tools: Option<&std::collections::HashSet<String>>,
|
||
) -> AiResult<AgentRunResult> {
|
||
let tools = match allowed_tools {
|
||
Some(allowed) => self.tool_registry.tool_definitions_filtered(allowed),
|
||
None => self.tool_registry.tool_definitions(),
|
||
};
|
||
let mut iterations = 0;
|
||
let mut total_input_tokens = 0u32;
|
||
let mut total_output_tokens = 0u32;
|
||
let mut tool_call_logs: Vec<ToolCallLog> = Vec::new();
|
||
let mut display_hints: Vec<super::tool::DisplayHint> = Vec::new();
|
||
|
||
loop {
|
||
iterations += 1;
|
||
|
||
let response = self
|
||
.provider
|
||
.generate_with_tools(
|
||
messages.clone(),
|
||
tools.clone(),
|
||
system_prompt,
|
||
¶ms.model,
|
||
params.temperature,
|
||
params.max_tokens,
|
||
)
|
||
.await?;
|
||
|
||
if let Some(ref usage) = response.usage {
|
||
total_input_tokens += usage.input;
|
||
total_output_tokens += usage.output;
|
||
}
|
||
|
||
// 如果没有 tool_calls,Agent 给出最终回复
|
||
let tool_calls = match response.tool_calls {
|
||
Some(tc) if !tc.is_empty() => tc,
|
||
_ => {
|
||
return Ok(AgentRunResult {
|
||
reply: response.content.unwrap_or_default(),
|
||
total_input_tokens,
|
||
total_output_tokens,
|
||
iterations,
|
||
tool_calls: tool_call_logs,
|
||
display_hints,
|
||
});
|
||
}
|
||
};
|
||
|
||
// 达到上限:强制结束
|
||
if iterations >= params.max_iterations {
|
||
messages.push(ChatMessage {
|
||
role: ChatMessageRole::User,
|
||
content: "(系统提示:已收集足够信息,请直接总结回复用户,不要再调用工具)"
|
||
.to_string(),
|
||
tool_calls: None,
|
||
tool_call_id: None,
|
||
});
|
||
continue;
|
||
}
|
||
|
||
// Token 预算检查:超出后强制结束
|
||
if let Some(budget) = params.token_budget {
|
||
let total = total_input_tokens + total_output_tokens;
|
||
if total >= budget {
|
||
tracing::warn!(
|
||
total_tokens = total,
|
||
budget = budget,
|
||
iterations = iterations,
|
||
"Token budget exhausted, forcing final reply"
|
||
);
|
||
messages.push(ChatMessage {
|
||
role: ChatMessageRole::User,
|
||
content: "(系统提示:Token 预算已用尽,请立即基于已有信息总结回复用户,不要再调用工具)"
|
||
.to_string(),
|
||
tool_calls: None,
|
||
tool_call_id: None,
|
||
});
|
||
continue;
|
||
}
|
||
}
|
||
|
||
// 将 assistant 的 tool_calls 加入消息历史
|
||
messages.push(ChatMessage {
|
||
role: ChatMessageRole::Assistant,
|
||
content: response.content.unwrap_or_default(),
|
||
tool_calls: Some(tool_calls.clone()),
|
||
tool_call_id: None,
|
||
});
|
||
|
||
// 执行每个 Tool Call(受沙箱 allowed_tools 约束)
|
||
for tc in &tool_calls {
|
||
let start = std::time::Instant::now();
|
||
let (tool_result, success, hint) = match self.tool_registry.get(&tc.name) {
|
||
Some(tool) => {
|
||
if let Some(allowed) = allowed_tools {
|
||
if !allowed.contains(tc.name.as_str()) {
|
||
(
|
||
format!("Tool '{}' 在当前角色下不可用", tc.name),
|
||
false,
|
||
None,
|
||
)
|
||
} else {
|
||
let result = tool.execute(ctx, tc.arguments.clone()).await;
|
||
(result.output, true, result.display_hint)
|
||
}
|
||
} else {
|
||
let result = tool.execute(ctx, tc.arguments.clone()).await;
|
||
(result.output, true, result.display_hint)
|
||
}
|
||
}
|
||
None => (format!("未知 Tool: {}", tc.name), false, None),
|
||
};
|
||
let duration = start.elapsed();
|
||
|
||
tool_call_logs.push(ToolCallLog {
|
||
tool_name: tc.name.clone(),
|
||
duration_ms: duration.as_millis() as u64,
|
||
success,
|
||
});
|
||
|
||
if let Some(h) = hint {
|
||
display_hints.push(h);
|
||
}
|
||
|
||
messages.push(ChatMessage {
|
||
role: ChatMessageRole::Tool,
|
||
content: tool_result,
|
||
tool_calls: None,
|
||
tool_call_id: Some(tc.id.clone()),
|
||
});
|
||
}
|
||
}
|
||
}
|
||
}
|