Files
hms/crates/erp-ai/src/agent/orchestrator.rs
iven bcff978ea0 feat(ai): Day 5 — ChatResponse display_hints + Web RichMessage 渲染
- ChatResponse 增加 display_hints 字段,Orchestrator 收集 Tool 产生的 DisplayHint
- DisplayHint 实现 utoipa::ToSchema
- Web ChatResponse 类型同步,DisplayHint 8 种联合类型
- RichMessage 组件:InsightCard/RiskAlert/LabReportCard/TrendChart/PatientProfile/VitalCard
- AiSidebar 消息中渲染 display_hints 富消息
- 小程序 AiChatResponse 类型同步
2026-05-19 11:10:07 +08:00

203 lines
6.9 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
use std::sync::Arc;
use super::registry::ToolRegistry;
use super::tool::ToolContext;
use crate::dto::{ChatMessage, ChatMessageRole};
use crate::error::AiResult;
use crate::provider::AiProvider;
/// Record of one tool invocation made during an Agent run.
#[derive(Clone, Debug)]
pub struct ToolCallLog {
    /// Name of the tool the model asked for.
    pub tool_name: String,
    /// Wall-clock time spent handling the call, in milliseconds.
    pub duration_ms: u64,
    /// Whether the tool was found, permitted, and executed.
    pub success: bool,
}
/// Runtime parameters for a single Agent run.
#[derive(Debug, Clone)]
pub struct AgentRunParams {
    /// Model identifier forwarded to the provider on every iteration.
    pub model: String,
    /// Sampling temperature forwarded to the provider.
    pub temperature: f32,
    /// Per-request `max_tokens` forwarded to the provider.
    pub max_tokens: u32,
    /// Iteration count at which the orchestrator stops running tools and asks the model
    /// for a final reply.
    pub max_iterations: usize,
    /// Optional cumulative token budget (input + output, summed over all iterations).
    /// Once reached, the orchestrator stops running tools and asks for a final reply.
    pub token_budget: Option<u32>,
}

impl Default for AgentRunParams {
    /// Defaults: `claude-sonnet-4-6`, temperature 0.7, 2048 max tokens,
    /// 5 iterations, and no token budget.
    fn default() -> Self {
        Self {
            model: "claude-sonnet-4-6".to_string(),
            temperature: 0.7,
            max_tokens: 2048,
            max_iterations: 5,
            token_budget: None,
        }
    }
}
/// Agent orchestrator: drives the ReAct loop by alternating provider calls with
/// execution of the tools the model requests.
pub struct AgentOrchestrator {
    /// LLM backend queried on every iteration.
    provider: Arc<dyn AiProvider>,
    /// Source of tool definitions and implementations.
    tool_registry: Arc<ToolRegistry>,
}
/// Outcome of a completed Agent run.
pub struct AgentRunResult {
    /// Final text reply from the model (empty if the model sent no content).
    pub reply: String,
    /// Input tokens reported by the provider, summed over all iterations.
    pub total_input_tokens: u32,
    /// Output tokens reported by the provider, summed over all iterations.
    pub total_output_tokens: u32,
    /// Number of provider calls made.
    pub iterations: usize,
    /// One entry per tool call handled, in execution order.
    pub tool_calls: Vec<ToolCallLog>,
    /// Display hints emitted by executed tools, in execution order.
    pub display_hints: Vec<super::tool::DisplayHint>,
}
impl AgentOrchestrator {
    /// Creates an orchestrator that queries `provider` and resolves tool calls
    /// against `tool_registry`.
    pub fn new(provider: Arc<dyn AiProvider>, tool_registry: Arc<ToolRegistry>) -> Self {
        Self {
            provider,
            tool_registry,
        }
    }

    /// Runs the Agent ReAct loop until the model produces a final reply.
    ///
    /// Each iteration sends the full `messages` history and the tool definitions to the
    /// provider. When `allowed_tools` is `Some`, only those tools are advertised. Calls to
    /// tools outside that set are not executed; an error string is returned to the model
    /// instead. Requested tools are executed with `ctx`, and their outputs are appended to
    /// `messages` as `Tool` messages before the next iteration.
    ///
    /// Once `params.max_iterations` is reached, or cumulative token usage reaches
    /// `params.token_budget`, the requested tools are not run. Instead, a user-role
    /// instruction asking for a final summary is appended. If the model still requests
    /// tools after that instruction, the loop stops and returns that response's text
    /// content (possibly empty), so the run always terminates.
    ///
    /// `messages` is mutated in place: assistant tool-call turns, tool outputs, and any
    /// final-reply instruction are appended.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the provider's `generate_with_tools`.
    pub async fn run(
        &self,
        system_prompt: &str,
        messages: &mut Vec<ChatMessage>,
        ctx: &ToolContext,
        params: &AgentRunParams,
        allowed_tools: Option<&std::collections::HashSet<String>>,
    ) -> AiResult<AgentRunResult> {
        let tools = match allowed_tools {
            Some(allowed) => self.tool_registry.tool_definitions_filtered(allowed),
            None => self.tool_registry.tool_definitions(),
        };
        let mut iterations = 0usize;
        let mut total_input_tokens = 0u32;
        let mut total_output_tokens = 0u32;
        let mut tool_call_logs: Vec<ToolCallLog> = Vec::new();
        let mut display_hints: Vec<super::tool::DisplayHint> = Vec::new();
        // Set once a "stop calling tools" instruction has been appended. The next response is
        // then treated as final even if the model ignores the instruction, which prevents an
        // unbounded loop of provider calls.
        let mut final_reply_requested = false;

        loop {
            iterations += 1;
            let response = self
                .provider
                .generate_with_tools(
                    messages.clone(),
                    tools.clone(),
                    system_prompt,
                    &params.model,
                    params.temperature,
                    params.max_tokens,
                )
                .await?;

            if let Some(ref usage) = response.usage {
                // Saturate instead of overflowing (a plain `+=` panics in debug builds).
                total_input_tokens = total_input_tokens.saturating_add(usage.input);
                total_output_tokens = total_output_tokens.saturating_add(usage.output);
            }

            // No tool calls, or a final reply was already demanded: this response is the answer.
            let tool_calls = match response.tool_calls {
                Some(tc) if !tc.is_empty() && !final_reply_requested => tc,
                remaining => {
                    if matches!(&remaining, Some(calls) if !calls.is_empty()) {
                        tracing::warn!(
                            iterations = iterations,
                            "Model still requested tools after being asked for a final reply; returning its text content"
                        );
                    }
                    return Ok(AgentRunResult {
                        reply: response.content.unwrap_or_default(),
                        total_input_tokens,
                        total_output_tokens,
                        iterations,
                        tool_calls: tool_call_logs,
                        display_hints,
                    });
                }
            };

            // Iteration limit reached: skip the requested tools and ask for a summary.
            if iterations >= params.max_iterations {
                messages.push(Self::instruction_message(
                    "(系统提示:已收集足够信息,请直接总结回复用户,不要再调用工具)",
                ));
                final_reply_requested = true;
                continue;
            }

            // Token budget exhausted: skip the requested tools and ask for a summary.
            if let Some(budget) = params.token_budget {
                let total = total_input_tokens.saturating_add(total_output_tokens);
                if total >= budget {
                    tracing::warn!(
                        total_tokens = total,
                        budget = budget,
                        iterations = iterations,
                        "Token budget exhausted, forcing final reply"
                    );
                    messages.push(Self::instruction_message(
                        "(系统提示:Token 预算已用尽,请立即基于已有信息总结回复用户,不要再调用工具)",
                    ));
                    final_reply_requested = true;
                    continue;
                }
            }

            // Record the assistant's tool-call turn so tool results can reference it.
            messages.push(ChatMessage {
                role: ChatMessageRole::Assistant,
                content: response.content.unwrap_or_default(),
                tool_calls: Some(tool_calls.clone()),
                tool_call_id: None,
            });

            // Execute each tool call, enforcing the `allowed_tools` sandbox.
            for tc in &tool_calls {
                let start = std::time::Instant::now();
                let permitted =
                    allowed_tools.map_or(true, |allowed| allowed.contains(tc.name.as_str()));
                let (tool_result, success, hint) = match self.tool_registry.get(&tc.name) {
                    None => (format!("未知 Tool: {}", tc.name), false, None),
                    Some(_) if !permitted => (
                        format!("Tool '{}' 在当前角色下不可用", tc.name),
                        false,
                        None,
                    ),
                    Some(tool) => {
                        let result = tool.execute(ctx, tc.arguments.clone()).await;
                        (result.output, true, result.display_hint)
                    }
                };

                tool_call_logs.push(ToolCallLog {
                    tool_name: tc.name.clone(),
                    duration_ms: u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX),
                    success,
                });
                display_hints.extend(hint);
                messages.push(ChatMessage {
                    role: ChatMessageRole::Tool,
                    content: tool_result,
                    tool_calls: None,
                    tool_call_id: Some(tc.id.clone()),
                });
            }
        }
    }

    /// Builds a user-role message carrying an orchestrator instruction (no tool metadata).
    fn instruction_message(content: &str) -> ChatMessage {
        ChatMessage {
            role: ChatMessageRole::User,
            content: content.to_string(),
            tool_calls: None,
            tool_call_id: None,
        }
    }
}