- 新增 agent/sandbox.rs: UserRole/SandboxConfig/OutputFilter 三级模型 - resolve_role() 从 JWT roles 解析为 Patient/MedicalStaff/Admin - ToolRegistry.tool_definitions_filtered() 按角色白名单过滤 - orchestrator.run() 新增 allowed_tools 参数,Tool 执行时二次校验 - chat_handler 集成沙箱:角色 Prompt 后缀 + 患者免责声明追加
150 lines
4.8 KiB
Rust
150 lines
4.8 KiB
Rust
use std::sync::Arc;
|
||
|
||
use super::registry::ToolRegistry;
|
||
use super::tool::ToolContext;
|
||
use crate::dto::{ChatMessage, ChatMessageRole};
|
||
use crate::error::AiResult;
|
||
use crate::provider::AiProvider;
|
||
|
||
/// Runtime parameters for a single Agent run.
#[derive(Debug, Clone)]
pub struct AgentRunParams {
    /// Model identifier forwarded to the provider on every call.
    pub model: String,
    /// Sampling temperature forwarded to the provider on every call.
    pub temperature: f32,
    /// Forwarded to the provider as `max_tokens` on every call.
    pub max_tokens: u32,
    /// Cap on tool-calling rounds; once reached, the orchestrator asks the model to
    /// summarise instead of calling more tools.
    pub max_iterations: usize,
}

impl Default for AgentRunParams {
    /// Defaults: model `claude-sonnet-4-6`, temperature 0.7, 2048 max tokens,
    /// at most 5 iterations.
    fn default() -> Self {
        Self {
            model: String::from("claude-sonnet-4-6"),
            temperature: 0.7,
            max_tokens: 2048,
            max_iterations: 5,
        }
    }
}
|
||
|
||
/// Agent Orchestrator — 执行 ReAct 循环
|
||
pub struct AgentOrchestrator {
|
||
provider: Arc<dyn AiProvider>,
|
||
tool_registry: Arc<ToolRegistry>,
|
||
}
|
||
|
||
/// Outcome of an Agent run (`AgentOrchestrator::run`).
#[derive(Debug, Clone)]
pub struct AgentRunResult {
    /// Final assistant text; empty if the model returned no content.
    pub reply: String,
    /// Input tokens summed over every provider call that reported usage.
    pub total_input_tokens: u32,
    /// Output tokens summed over every provider call that reported usage.
    pub total_output_tokens: u32,
    /// Number of provider calls made during the run.
    pub iterations: usize,
}
|
||
|
||
impl AgentOrchestrator {
|
||
pub fn new(provider: Arc<dyn AiProvider>, tool_registry: Arc<ToolRegistry>) -> Self {
|
||
Self {
|
||
provider,
|
||
tool_registry,
|
||
}
|
||
}
|
||
|
||
/// 执行 Agent ReAct 循环
|
||
pub async fn run(
|
||
&self,
|
||
system_prompt: &str,
|
||
messages: &mut Vec<ChatMessage>,
|
||
ctx: &ToolContext,
|
||
params: &AgentRunParams,
|
||
allowed_tools: Option<&std::collections::HashSet<String>>,
|
||
) -> AiResult<AgentRunResult> {
|
||
let tools = match allowed_tools {
|
||
Some(allowed) => self.tool_registry.tool_definitions_filtered(allowed),
|
||
None => self.tool_registry.tool_definitions(),
|
||
};
|
||
let mut iterations = 0;
|
||
let mut total_input_tokens = 0u32;
|
||
let mut total_output_tokens = 0u32;
|
||
|
||
loop {
|
||
iterations += 1;
|
||
|
||
let response = self
|
||
.provider
|
||
.generate_with_tools(
|
||
messages.clone(),
|
||
tools.clone(),
|
||
system_prompt,
|
||
¶ms.model,
|
||
params.temperature,
|
||
params.max_tokens,
|
||
)
|
||
.await?;
|
||
|
||
if let Some(ref usage) = response.usage {
|
||
total_input_tokens += usage.input;
|
||
total_output_tokens += usage.output;
|
||
}
|
||
|
||
// 如果没有 tool_calls,Agent 给出最终回复
|
||
let tool_calls = match response.tool_calls {
|
||
Some(tc) if !tc.is_empty() => tc,
|
||
_ => {
|
||
return Ok(AgentRunResult {
|
||
reply: response.content.unwrap_or_default(),
|
||
total_input_tokens,
|
||
total_output_tokens,
|
||
iterations,
|
||
});
|
||
}
|
||
};
|
||
|
||
// 达到上限:强制结束
|
||
if iterations >= params.max_iterations {
|
||
messages.push(ChatMessage {
|
||
role: ChatMessageRole::User,
|
||
content: "(系统提示:已收集足够信息,请直接总结回复用户,不要再调用工具)"
|
||
.to_string(),
|
||
tool_calls: None,
|
||
tool_call_id: None,
|
||
});
|
||
continue;
|
||
}
|
||
|
||
// 将 assistant 的 tool_calls 加入消息历史
|
||
messages.push(ChatMessage {
|
||
role: ChatMessageRole::Assistant,
|
||
content: response.content.unwrap_or_default(),
|
||
tool_calls: Some(tool_calls.clone()),
|
||
tool_call_id: None,
|
||
});
|
||
|
||
// 执行每个 Tool Call(受沙箱 allowed_tools 约束)
|
||
for tc in &tool_calls {
|
||
let tool_result = match self.tool_registry.get(&tc.name) {
|
||
Some(tool) => {
|
||
// 沙箱过滤:如果 allowed_tools 存在且不包含此 Tool,拒绝执行
|
||
if let Some(allowed) = allowed_tools {
|
||
if !allowed.contains(tc.name.as_str()) {
|
||
format!("Tool '{}' 在当前角色下不可用", tc.name)
|
||
} else {
|
||
let result = tool.execute(ctx, tc.arguments.clone()).await;
|
||
result.output
|
||
}
|
||
} else {
|
||
let result = tool.execute(ctx, tc.arguments.clone()).await;
|
||
result.output
|
||
}
|
||
}
|
||
None => format!("未知 Tool: {}", tc.name),
|
||
};
|
||
|
||
messages.push(ChatMessage {
|
||
role: ChatMessageRole::Tool,
|
||
content: tool_result,
|
||
tool_calls: None,
|
||
tool_call_id: Some(tc.id.clone()),
|
||
});
|
||
}
|
||
}
|
||
}
|
||
}
|