Files
hms/crates/erp-ai/src/agent/orchestrator.rs
iven 5ba28ea349 feat(ai): Phase 1B 角色沙箱 — 三级权限隔离 + Tool 过滤 + 输出控制
- 新增 agent/sandbox.rs: UserRole/SandboxConfig/OutputFilter 三级模型
- resolve_role() 从 JWT roles 解析为 Patient/MedicalStaff/Admin
- ToolRegistry.tool_definitions_filtered() 按角色白名单过滤
- orchestrator.run() 新增 allowed_tools 参数,Tool 执行时二次校验
- chat_handler 集成沙箱:角色 Prompt 后缀 + 患者免责声明追加
2026-05-18 23:28:30 +08:00

150 lines
4.8 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
use std::sync::Arc;
use super::registry::ToolRegistry;
use super::tool::ToolContext;
use crate::dto::{ChatMessage, ChatMessageRole};
use crate::error::AiResult;
use crate::provider::AiProvider;
/// Runtime parameters for a single agent run.
#[derive(Debug, Clone, PartialEq)]
pub struct AgentRunParams {
    /// Model identifier forwarded to the provider on every generation call.
    pub model: String,
    /// Sampling temperature forwarded to the provider.
    pub temperature: f32,
    /// Maximum number of tokens the provider may generate per call.
    pub max_tokens: u32,
    /// Iteration budget for the ReAct loop; once reached, the orchestrator stops
    /// executing tools and asks the model to summarize instead.
    pub max_iterations: usize,
}

impl Default for AgentRunParams {
    /// Defaults: `claude-sonnet-4-6`, temperature 0.7, 2048 max tokens, 5 iterations.
    fn default() -> Self {
        Self {
            model: "claude-sonnet-4-6".to_string(),
            temperature: 0.7,
            max_tokens: 2048,
            max_iterations: 5,
        }
    }
}
/// Agent orchestrator: drives the ReAct (reason + act) loop.
///
/// Owns shared handles to the model provider used for generation and to the
/// registry of tools the model is allowed to invoke.
pub struct AgentOrchestrator {
    /// Backend that produces model responses (text and/or tool calls).
    provider: Arc<dyn AiProvider>,
    /// Source of tool definitions and of executable tools looked up by name.
    tool_registry: Arc<ToolRegistry>,
}
/// Outcome of a completed agent run.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AgentRunResult {
    /// Final text reply produced by the model.
    pub reply: String,
    /// Sum of provider-reported input tokens over all iterations.
    pub total_input_tokens: u32,
    /// Sum of provider-reported output tokens over all iterations.
    pub total_output_tokens: u32,
    /// Number of provider calls made during the run.
    pub iterations: usize,
}
impl AgentOrchestrator {
    /// Creates an orchestrator that generates with `provider` and dispatches
    /// tool calls through `tool_registry`.
    pub fn new(provider: Arc<dyn AiProvider>, tool_registry: Arc<ToolRegistry>) -> Self {
        Self {
            provider,
            tool_registry,
        }
    }

    /// Runs the agent ReAct loop until the model produces a final reply.
    ///
    /// Each iteration sends the conversation to the provider. If the response
    /// contains tool calls, the assistant turn and one tool-result message per
    /// call are appended to `messages`, and the loop repeats. The loop ends when
    /// the response contains no tool calls.
    ///
    /// When `params.max_iterations` is reached, pending tool calls are dropped
    /// and a single "summarize now" nudge is appended. The next response is then
    /// returned as the final reply even if it still asks for tools. This bounds
    /// the run to at most `max_iterations + 1` provider calls.
    ///
    /// # Parameters
    /// - `system_prompt`: forwarded to the provider on every call.
    /// - `messages`: conversation history; mutated in place. The final reply
    ///   itself is not appended.
    /// - `ctx`: passed to every executed tool.
    /// - `params`: model, sampling and iteration settings.
    /// - `allowed_tools`: sandbox whitelist. `None` means all registered tools.
    ///   When `Some`, only these tools are advertised to the model, and every
    ///   requested call is re-checked before execution.
    ///
    /// # Errors
    /// Propagates any error returned by the provider's `generate_with_tools`.
    pub async fn run(
        &self,
        system_prompt: &str,
        messages: &mut Vec<ChatMessage>,
        ctx: &ToolContext,
        params: &AgentRunParams,
        allowed_tools: Option<&std::collections::HashSet<String>>,
    ) -> AiResult<AgentRunResult> {
        // Only advertise the tools the caller's role may use.
        let tools = match allowed_tools {
            Some(allowed) => self.tool_registry.tool_definitions_filtered(allowed),
            None => self.tool_registry.tool_definitions(),
        };

        let mut iterations = 0;
        let mut total_input_tokens = 0u32;
        let mut total_output_tokens = 0u32;
        // Set once the iteration budget is exhausted and the stop nudge was sent;
        // the next response is then final, which guarantees termination.
        let mut summary_requested = false;

        loop {
            iterations += 1;
            let response = self
                .provider
                .generate_with_tools(
                    messages.clone(),
                    tools.clone(),
                    system_prompt,
                    &params.model,
                    params.temperature,
                    params.max_tokens,
                )
                .await?;

            if let Some(ref usage) = response.usage {
                // Saturate instead of overflowing (panic in debug, wrap in release).
                total_input_tokens = total_input_tokens.saturating_add(usage.input);
                total_output_tokens = total_output_tokens.saturating_add(usage.output);
            }

            // No tool calls, or the model ignored the stop nudge: this is the final reply.
            let tool_calls = match response.tool_calls {
                Some(tc) if !tc.is_empty() && !summary_requested => tc,
                _ => {
                    return Ok(AgentRunResult {
                        reply: response.content.unwrap_or_default(),
                        total_input_tokens,
                        total_output_tokens,
                        iterations,
                    });
                }
            };

            // Budget exhausted: drop the pending tool calls and ask for a summary exactly once.
            if iterations >= params.max_iterations {
                messages.push(ChatMessage {
                    role: ChatMessageRole::User,
                    content: "(系统提示:已收集足够信息,请直接总结回复用户,不要再调用工具)"
                        .to_string(),
                    tool_calls: None,
                    tool_call_id: None,
                });
                summary_requested = true;
                continue;
            }

            // Record the assistant turn that requested the tools.
            messages.push(ChatMessage {
                role: ChatMessageRole::Assistant,
                content: response.content.unwrap_or_default(),
                tool_calls: Some(tool_calls.clone()),
                tool_call_id: None,
            });

            // Execute each tool call, subject to the sandbox whitelist.
            for tc in &tool_calls {
                // Authorize before the registry lookup so a restricted role cannot
                // probe which tool names exist.
                let permitted = match allowed_tools {
                    Some(allowed) => allowed.contains(tc.name.as_str()),
                    None => true,
                };
                let tool_result = if !permitted {
                    format!("Tool '{}' 在当前角色下不可用", tc.name)
                } else if let Some(tool) = self.tool_registry.get(&tc.name) {
                    tool.execute(ctx, tc.arguments.clone()).await.output
                } else {
                    format!("未知 Tool: {}", tc.name)
                };

                messages.push(ChatMessage {
                    role: ChatMessageRole::Tool,
                    content: tool_result,
                    tool_calls: None,
                    tool_call_id: Some(tc.id.clone()),
                });
            }
        }
    }
}