feat(ai): Phase 1B 角色沙箱 — 三级权限隔离 + Tool 过滤 + 输出控制
- 新增 agent/sandbox.rs: UserRole/SandboxConfig/OutputFilter 三级模型 - resolve_role() 从 JWT roles 解析为 Patient/MedicalStaff/Admin - ToolRegistry.tool_definitions_filtered() 按角色白名单过滤 - orchestrator.run() 新增 allowed_tools 参数,Tool 执行时二次校验 - chat_handler 集成沙箱:角色 Prompt 后缀 + 患者免责声明追加
This commit is contained in:
@@ -1,8 +1,10 @@
|
||||
pub mod orchestrator;
|
||||
pub mod registry;
|
||||
pub mod sandbox;
|
||||
pub mod tool;
|
||||
pub mod tools;
|
||||
|
||||
pub use orchestrator::AgentOrchestrator;
|
||||
pub use registry::ToolRegistry;
|
||||
pub use sandbox::{OutputFilter, SandboxConfig, UserRole, get_sandbox_config, resolve_role};
|
||||
pub use tool::{AgentTool, DisplayHint, ToolContext, ToolResult};
|
||||
|
||||
@@ -54,8 +54,12 @@ impl AgentOrchestrator {
|
||||
messages: &mut Vec<ChatMessage>,
|
||||
ctx: &ToolContext,
|
||||
params: &AgentRunParams,
|
||||
allowed_tools: Option<&std::collections::HashSet<String>>,
|
||||
) -> AiResult<AgentRunResult> {
|
||||
let tools = self.tool_registry.tool_definitions();
|
||||
let tools = match allowed_tools {
|
||||
Some(allowed) => self.tool_registry.tool_definitions_filtered(allowed),
|
||||
None => self.tool_registry.tool_definitions(),
|
||||
};
|
||||
let mut iterations = 0;
|
||||
let mut total_input_tokens = 0u32;
|
||||
let mut total_output_tokens = 0u32;
|
||||
@@ -113,12 +117,22 @@ impl AgentOrchestrator {
|
||||
tool_call_id: None,
|
||||
});
|
||||
|
||||
// 执行每个 Tool Call
|
||||
// 执行每个 Tool Call(受沙箱 allowed_tools 约束)
|
||||
for tc in &tool_calls {
|
||||
let tool_result = match self.tool_registry.get(&tc.name) {
|
||||
Some(tool) => {
|
||||
let result = tool.execute(ctx, tc.arguments.clone()).await;
|
||||
result.output
|
||||
// 沙箱过滤:如果 allowed_tools 存在且不包含此 Tool,拒绝执行
|
||||
if let Some(allowed) = allowed_tools {
|
||||
if !allowed.contains(tc.name.as_str()) {
|
||||
format!("Tool '{}' 在当前角色下不可用", tc.name)
|
||||
} else {
|
||||
let result = tool.execute(ctx, tc.arguments.clone()).await;
|
||||
result.output
|
||||
}
|
||||
} else {
|
||||
let result = tool.execute(ctx, tc.arguments.clone()).await;
|
||||
result.output
|
||||
}
|
||||
}
|
||||
None => format!("未知 Tool: {}", tc.name),
|
||||
};
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
use std::collections::HashMap;
|
||||
use std::collections::HashSet;
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::tool::AgentTool;
|
||||
@@ -33,6 +34,22 @@ impl ToolRegistry {
|
||||
self.tools.values().collect()
|
||||
}
|
||||
|
||||
/// 根据角色允许的 Tool 列表过滤,返回过滤后的 ToolDefinition
|
||||
pub fn tool_definitions_filtered(
|
||||
&self,
|
||||
allowed_tools: &HashSet<String>,
|
||||
) -> Vec<crate::dto::ToolDefinition> {
|
||||
self.tools
|
||||
.values()
|
||||
.filter(|t| allowed_tools.contains(t.name()))
|
||||
.map(|t| crate::dto::ToolDefinition {
|
||||
name: t.name().to_string(),
|
||||
description: t.description().to_string(),
|
||||
parameters: t.parameters_schema(),
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// 生成传给 LLM 的 ToolDefinition 列表
|
||||
pub fn tool_definitions(&self) -> Vec<crate::dto::ToolDefinition> {
|
||||
self.tools
|
||||
|
||||
97
crates/erp-ai/src/agent/sandbox.rs
Normal file
97
crates/erp-ai/src/agent/sandbox.rs
Normal file
@@ -0,0 +1,97 @@
|
||||
use std::collections::HashSet;
|
||||
|
||||
/// 用户角色(从 JWT claims 提取)
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
pub enum UserRole {
|
||||
Patient,
|
||||
MedicalStaff,
|
||||
Admin,
|
||||
}
|
||||
|
||||
impl UserRole {
|
||||
pub fn from_role_name(role: &str) -> Self {
|
||||
match role {
|
||||
"patient" => Self::Patient,
|
||||
"admin" => Self::Admin,
|
||||
_ => Self::MedicalStaff, // doctor, nurse, etc.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// 沙箱配置 — 定义某个角色可用的工具集和输出策略
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SandboxConfig {
|
||||
pub role: UserRole,
|
||||
pub allowed_tools: HashSet<String>,
|
||||
pub system_prompt_suffix: &'static str,
|
||||
pub output_filter: OutputFilter,
|
||||
}
|
||||
|
||||
/// 输出过滤策略
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct OutputFilter {
|
||||
pub append_disclaimer: bool,
|
||||
pub redact_diagnosis_terms: bool,
|
||||
pub disclaimer_text: &'static str,
|
||||
}
|
||||
|
||||
/// 获取指定角色的沙箱配置
|
||||
pub fn get_sandbox_config(role: &UserRole) -> SandboxConfig {
|
||||
match role {
|
||||
UserRole::Patient => SandboxConfig {
|
||||
role: role.clone(),
|
||||
allowed_tools: HashSet::from(["query_patient_vitals".into()]),
|
||||
system_prompt_suffix: PATIENT_PROMPT_SUFFIX,
|
||||
output_filter: OutputFilter {
|
||||
append_disclaimer: true,
|
||||
redact_diagnosis_terms: true,
|
||||
disclaimer_text: DISCLAIMER_PATIENT,
|
||||
},
|
||||
},
|
||||
UserRole::MedicalStaff => SandboxConfig {
|
||||
role: role.clone(),
|
||||
allowed_tools: HashSet::from(["query_patient_vitals".into()]),
|
||||
system_prompt_suffix: MEDICAL_STAFF_PROMPT_SUFFIX,
|
||||
output_filter: OutputFilter {
|
||||
append_disclaimer: false,
|
||||
redact_diagnosis_terms: false,
|
||||
disclaimer_text: "",
|
||||
},
|
||||
},
|
||||
UserRole::Admin => SandboxConfig {
|
||||
role: role.clone(),
|
||||
allowed_tools: HashSet::from(["query_patient_vitals".into()]),
|
||||
system_prompt_suffix: ADMIN_PROMPT_SUFFIX,
|
||||
output_filter: OutputFilter::default(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// 根据 JWT 角色列表解析为沙箱角色(取最高权限角色)
|
||||
pub fn resolve_role(roles: &[String]) -> UserRole {
|
||||
for r in roles {
|
||||
if r == "admin" {
|
||||
return UserRole::Admin;
|
||||
}
|
||||
}
|
||||
for r in roles {
|
||||
if r == "doctor" || r == "nurse" {
|
||||
return UserRole::MedicalStaff;
|
||||
}
|
||||
}
|
||||
for r in roles {
|
||||
if r == "patient" {
|
||||
return UserRole::Patient;
|
||||
}
|
||||
}
|
||||
// 默认按患者(最小权限)
|
||||
UserRole::Patient
|
||||
}
|
||||
|
||||
static PATIENT_PROMPT_SUFFIX: &str = "\n\n注意:你是面向患者的 AI 健康助手。请使用通俗易懂的语言,避免使用专业医学术语。回答应温和且包含情绪安抚。不得做出具体的诊断或用药建议。如果用户描述了严重症状,建议立即就医。";
|
||||
|
||||
static MEDICAL_STAFF_PROMPT_SUFFIX: &str = "\n\n注意:你是面向医护人员的 AI 助手。回答应专业简洁,引用数据来源。可提供辅助诊断参考和风险评分,但最终诊断和治疗方案由医生决定。";
|
||||
|
||||
static ADMIN_PROMPT_SUFFIX: &str = "\n\n注意:你是面向管理人员的 AI 助手。主要提供用量统计、成本分析和运营洞察。不提供个体患者数据。";
|
||||
|
||||
static DISCLAIMER_PATIENT: &str = "\n\n---\n*以上内容由 AI 生成,仅供参考,不构成医疗诊断或治疗建议。如有健康问题请咨询专业医生。*";
|
||||
@@ -5,6 +5,7 @@ use erp_core::types::{ApiResponse, TenantContext};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::agent::orchestrator::AgentRunParams;
|
||||
use crate::agent::sandbox::{get_sandbox_config, resolve_role};
|
||||
use crate::agent::tool::ToolContext;
|
||||
use crate::agent::tools::QueryPatientVitalsTool;
|
||||
use crate::agent::{AgentOrchestrator, ToolRegistry};
|
||||
@@ -113,10 +114,22 @@ where
|
||||
)
|
||||
})?;
|
||||
|
||||
// 构建 ToolRegistry — Phase 0 只有 query_patient_vitals
|
||||
// 构建全局 ToolRegistry(所有已注册 Tool)
|
||||
let mut registry = ToolRegistry::new();
|
||||
registry.register(std::sync::Arc::new(QueryPatientVitalsTool));
|
||||
|
||||
// 根据用户角色获取沙箱配置
|
||||
let user_role = resolve_role(&ctx.roles);
|
||||
let sandbox = get_sandbox_config(&user_role);
|
||||
|
||||
tracing::info!(
|
||||
tenant_id = %ctx.tenant_id,
|
||||
user_id = %ctx.user_id,
|
||||
role = ?user_role,
|
||||
allowed_tools = ?sandbox.allowed_tools,
|
||||
"Sandbox resolved"
|
||||
);
|
||||
|
||||
let tool_ctx = ToolContext {
|
||||
tenant_id: ctx.tenant_id,
|
||||
user_id: ctx.user_id,
|
||||
@@ -125,6 +138,12 @@ where
|
||||
health_provider: ai_state.health_provider.clone(),
|
||||
};
|
||||
|
||||
// system_prompt 追加角色沙箱的 Prompt 后缀
|
||||
let system_prompt = format!(
|
||||
"{}{}",
|
||||
config.agent.system_prompt, sandbox.system_prompt_suffix
|
||||
);
|
||||
|
||||
let run_params = AgentRunParams {
|
||||
model: config.agent.model,
|
||||
temperature: config.agent.temperature,
|
||||
@@ -146,14 +165,15 @@ where
|
||||
|
||||
let provider_name = provider_arc.name().to_string();
|
||||
|
||||
// 执行 Agent ReAct 循环
|
||||
// 执行 Agent ReAct 循环(使用角色沙箱过滤后的 Tool 和 Prompt)
|
||||
let orchestrator = AgentOrchestrator::new(provider_arc, std::sync::Arc::new(registry));
|
||||
let result = orchestrator
|
||||
let mut result = orchestrator
|
||||
.run(
|
||||
&config.agent.system_prompt,
|
||||
&system_prompt,
|
||||
&mut messages,
|
||||
&tool_ctx,
|
||||
&run_params,
|
||||
Some(&sandbox.allowed_tools),
|
||||
)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
@@ -161,6 +181,11 @@ where
|
||||
erp_core::error::AppError::Internal("AI 服务暂时不可用,请稍后再试".into())
|
||||
})?;
|
||||
|
||||
// 输出过滤:患者角色追加免责声明
|
||||
if sandbox.output_filter.append_disclaimer && !result.reply.is_empty() {
|
||||
result.reply.push_str(sandbox.output_filter.disclaimer_text);
|
||||
}
|
||||
|
||||
let message_id = uuid::Uuid::now_v7().to_string();
|
||||
|
||||
tracing::info!(
|
||||
|
||||
Reference in New Issue
Block a user