feat(ai): Phase 1B 角色沙箱 — 三级权限隔离 + Tool 过滤 + 输出控制

- 新增 agent/sandbox.rs: UserRole/SandboxConfig/OutputFilter 三级模型
- resolve_role() 从 JWT roles 解析为 Patient/MedicalStaff/Admin
- ToolRegistry.tool_definitions_filtered() 按角色白名单过滤
- orchestrator.run() 新增 allowed_tools 参数,Tool 执行时二次校验
- chat_handler 集成沙箱:角色 Prompt 后缀 + 患者免责声明追加
This commit is contained in:
iven
2026-05-18 23:28:30 +08:00
parent 7e3d27ecf3
commit 5ba28ea349
5 changed files with 163 additions and 8 deletions

View File

@@ -1,8 +1,10 @@
pub mod orchestrator;
pub mod registry;
pub mod sandbox;
pub mod tool;
pub mod tools;
pub use orchestrator::AgentOrchestrator;
pub use registry::ToolRegistry;
pub use sandbox::{OutputFilter, SandboxConfig, UserRole, get_sandbox_config, resolve_role};
pub use tool::{AgentTool, DisplayHint, ToolContext, ToolResult};

View File

@@ -54,8 +54,12 @@ impl AgentOrchestrator {
messages: &mut Vec<ChatMessage>,
ctx: &ToolContext,
params: &AgentRunParams,
allowed_tools: Option<&std::collections::HashSet<String>>,
) -> AiResult<AgentRunResult> {
let tools = self.tool_registry.tool_definitions();
let tools = match allowed_tools {
Some(allowed) => self.tool_registry.tool_definitions_filtered(allowed),
None => self.tool_registry.tool_definitions(),
};
let mut iterations = 0;
let mut total_input_tokens = 0u32;
let mut total_output_tokens = 0u32;
@@ -113,12 +117,22 @@ impl AgentOrchestrator {
tool_call_id: None,
});
// 执行每个 Tool Call
// 执行每个 Tool Call(受沙箱 allowed_tools 约束)
for tc in &tool_calls {
let tool_result = match self.tool_registry.get(&tc.name) {
Some(tool) => {
let result = tool.execute(ctx, tc.arguments.clone()).await;
result.output
// 沙箱过滤:如果 allowed_tools 存在且不包含此 Tool拒绝执行
if let Some(allowed) = allowed_tools {
if !allowed.contains(tc.name.as_str()) {
format!("Tool '{}' 在当前角色下不可用", tc.name)
} else {
let result = tool.execute(ctx, tc.arguments.clone()).await;
result.output
}
} else {
let result = tool.execute(ctx, tc.arguments.clone()).await;
result.output
}
}
None => format!("未知 Tool: {}", tc.name),
};

View File

@@ -1,4 +1,5 @@
use std::collections::HashMap;
use std::collections::HashSet;
use std::sync::Arc;
use super::tool::AgentTool;
@@ -33,6 +34,22 @@ impl ToolRegistry {
self.tools.values().collect()
}
/// 根据角色允许的 Tool 列表过滤,返回过滤后的 ToolDefinition
pub fn tool_definitions_filtered(
&self,
allowed_tools: &HashSet<String>,
) -> Vec<crate::dto::ToolDefinition> {
self.tools
.values()
.filter(|t| allowed_tools.contains(t.name()))
.map(|t| crate::dto::ToolDefinition {
name: t.name().to_string(),
description: t.description().to_string(),
parameters: t.parameters_schema(),
})
.collect()
}
/// 生成传给 LLM 的 ToolDefinition 列表
pub fn tool_definitions(&self) -> Vec<crate::dto::ToolDefinition> {
self.tools

View File

@@ -0,0 +1,97 @@
use std::collections::HashSet;
/// 用户角色(从 JWT claims 提取)
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum UserRole {
Patient,
MedicalStaff,
Admin,
}
impl UserRole {
pub fn from_role_name(role: &str) -> Self {
match role {
"patient" => Self::Patient,
"admin" => Self::Admin,
_ => Self::MedicalStaff, // doctor, nurse, etc.
}
}
}
/// 沙箱配置 — 定义某个角色可用的工具集和输出策略
#[derive(Debug, Clone)]
pub struct SandboxConfig {
pub role: UserRole,
pub allowed_tools: HashSet<String>,
pub system_prompt_suffix: &'static str,
pub output_filter: OutputFilter,
}
/// 输出过滤策略
#[derive(Debug, Clone, Default)]
pub struct OutputFilter {
pub append_disclaimer: bool,
pub redact_diagnosis_terms: bool,
pub disclaimer_text: &'static str,
}
/// 获取指定角色的沙箱配置
pub fn get_sandbox_config(role: &UserRole) -> SandboxConfig {
match role {
UserRole::Patient => SandboxConfig {
role: role.clone(),
allowed_tools: HashSet::from(["query_patient_vitals".into()]),
system_prompt_suffix: PATIENT_PROMPT_SUFFIX,
output_filter: OutputFilter {
append_disclaimer: true,
redact_diagnosis_terms: true,
disclaimer_text: DISCLAIMER_PATIENT,
},
},
UserRole::MedicalStaff => SandboxConfig {
role: role.clone(),
allowed_tools: HashSet::from(["query_patient_vitals".into()]),
system_prompt_suffix: MEDICAL_STAFF_PROMPT_SUFFIX,
output_filter: OutputFilter {
append_disclaimer: false,
redact_diagnosis_terms: false,
disclaimer_text: "",
},
},
UserRole::Admin => SandboxConfig {
role: role.clone(),
allowed_tools: HashSet::from(["query_patient_vitals".into()]),
system_prompt_suffix: ADMIN_PROMPT_SUFFIX,
output_filter: OutputFilter::default(),
},
}
}
/// 根据 JWT 角色列表解析为沙箱角色(取最高权限角色)
pub fn resolve_role(roles: &[String]) -> UserRole {
for r in roles {
if r == "admin" {
return UserRole::Admin;
}
}
for r in roles {
if r == "doctor" || r == "nurse" {
return UserRole::MedicalStaff;
}
}
for r in roles {
if r == "patient" {
return UserRole::Patient;
}
}
// 默认按患者(最小权限)
UserRole::Patient
}
static PATIENT_PROMPT_SUFFIX: &str = "\n\n注意:你是面向患者的 AI 健康助手。请使用通俗易懂的语言,避免使用专业医学术语。回答应温和且包含情绪安抚。不得做出具体的诊断或用药建议。如果用户描述了严重症状,建议立即就医。";
static MEDICAL_STAFF_PROMPT_SUFFIX: &str = "\n\n注意:你是面向医护人员的 AI 助手。回答应专业简洁,引用数据来源。可提供辅助诊断参考和风险评分,但最终诊断和治疗方案由医生决定。";
static ADMIN_PROMPT_SUFFIX: &str = "\n\n注意:你是面向管理人员的 AI 助手。主要提供用量统计、成本分析和运营洞察。不提供个体患者数据。";
static DISCLAIMER_PATIENT: &str = "\n\n---\n*以上内容由 AI 生成,仅供参考,不构成医疗诊断或治疗建议。如有健康问题请咨询专业医生。*";

View File

@@ -5,6 +5,7 @@ use erp_core::types::{ApiResponse, TenantContext};
use serde::{Deserialize, Serialize};
use crate::agent::orchestrator::AgentRunParams;
use crate::agent::sandbox::{get_sandbox_config, resolve_role};
use crate::agent::tool::ToolContext;
use crate::agent::tools::QueryPatientVitalsTool;
use crate::agent::{AgentOrchestrator, ToolRegistry};
@@ -113,10 +114,22 @@ where
)
})?;
// 构建 ToolRegistry — Phase 0 只有 query_patient_vitals
// 构建全局 ToolRegistry(所有已注册 Tool
let mut registry = ToolRegistry::new();
registry.register(std::sync::Arc::new(QueryPatientVitalsTool));
// 根据用户角色获取沙箱配置
let user_role = resolve_role(&ctx.roles);
let sandbox = get_sandbox_config(&user_role);
tracing::info!(
tenant_id = %ctx.tenant_id,
user_id = %ctx.user_id,
role = ?user_role,
allowed_tools = ?sandbox.allowed_tools,
"Sandbox resolved"
);
let tool_ctx = ToolContext {
tenant_id: ctx.tenant_id,
user_id: ctx.user_id,
@@ -125,6 +138,12 @@ where
health_provider: ai_state.health_provider.clone(),
};
// system_prompt 追加角色沙箱的 Prompt 后缀
let system_prompt = format!(
"{}{}",
config.agent.system_prompt, sandbox.system_prompt_suffix
);
let run_params = AgentRunParams {
model: config.agent.model,
temperature: config.agent.temperature,
@@ -146,14 +165,15 @@ where
let provider_name = provider_arc.name().to_string();
// 执行 Agent ReAct 循环
// 执行 Agent ReAct 循环(使用角色沙箱过滤后的 Tool 和 Prompt
let orchestrator = AgentOrchestrator::new(provider_arc, std::sync::Arc::new(registry));
let result = orchestrator
let mut result = orchestrator
.run(
&config.agent.system_prompt,
&system_prompt,
&mut messages,
&tool_ctx,
&run_params,
Some(&sandbox.allowed_tools),
)
.await
.map_err(|e| {
@@ -161,6 +181,11 @@ where
erp_core::error::AppError::Internal("AI 服务暂时不可用,请稍后再试".into())
})?;
// 输出过滤:患者角色追加免责声明
if sandbox.output_filter.append_disclaimer && !result.reply.is_empty() {
result.reply.push_str(sandbox.output_filter.disclaimer_text);
}
let message_id = uuid::Uuid::now_v7().to_string();
tracing::info!(