From 5ba28ea349cad220d6e8ca6deb35de8871c64118 Mon Sep 17 00:00:00 2001 From: iven Date: Mon, 18 May 2026 23:28:30 +0800 Subject: [PATCH] =?UTF-8?q?feat(ai):=20Phase=201B=20=E8=A7=92=E8=89=B2?= =?UTF-8?q?=E6=B2=99=E7=AE=B1=20=E2=80=94=20=E4=B8=89=E7=BA=A7=E6=9D=83?= =?UTF-8?q?=E9=99=90=E9=9A=94=E7=A6=BB=20+=20Tool=20=E8=BF=87=E6=BB=A4=20+?= =?UTF-8?q?=20=E8=BE=93=E5=87=BA=E6=8E=A7=E5=88=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增 agent/sandbox.rs: UserRole/SandboxConfig/OutputFilter 三级模型 - resolve_role() 从 JWT roles 解析为 Patient/MedicalStaff/Admin - ToolRegistry.tool_definitions_filtered() 按角色白名单过滤 - orchestrator.run() 新增 allowed_tools 参数,Tool 执行时二次校验 - chat_handler 集成沙箱:角色 Prompt 后缀 + 患者免责声明追加 --- crates/erp-ai/src/agent/mod.rs | 2 + crates/erp-ai/src/agent/orchestrator.rs | 22 ++++- crates/erp-ai/src/agent/registry.rs | 17 ++++ crates/erp-ai/src/agent/sandbox.rs | 97 +++++++++++++++++++++++ crates/erp-ai/src/handler/chat_handler.rs | 33 +++++++- 5 files changed, 163 insertions(+), 8 deletions(-) create mode 100644 crates/erp-ai/src/agent/sandbox.rs diff --git a/crates/erp-ai/src/agent/mod.rs b/crates/erp-ai/src/agent/mod.rs index a734572..c4d7bcb 100644 --- a/crates/erp-ai/src/agent/mod.rs +++ b/crates/erp-ai/src/agent/mod.rs @@ -1,8 +1,10 @@ pub mod orchestrator; pub mod registry; +pub mod sandbox; pub mod tool; pub mod tools; pub use orchestrator::AgentOrchestrator; pub use registry::ToolRegistry; +pub use sandbox::{OutputFilter, SandboxConfig, UserRole, get_sandbox_config, resolve_role}; pub use tool::{AgentTool, DisplayHint, ToolContext, ToolResult}; diff --git a/crates/erp-ai/src/agent/orchestrator.rs b/crates/erp-ai/src/agent/orchestrator.rs index de01129..4ff88ba 100644 --- a/crates/erp-ai/src/agent/orchestrator.rs +++ b/crates/erp-ai/src/agent/orchestrator.rs @@ -54,8 +54,12 @@ impl AgentOrchestrator { messages: &mut Vec, ctx: &ToolContext, params: &AgentRunParams, + allowed_tools: Option<&std::collections::HashSet>, ) -> AiResult { - let tools = self.tool_registry.tool_definitions(); + let tools = match allowed_tools { + Some(allowed) => self.tool_registry.tool_definitions_filtered(allowed), + None => self.tool_registry.tool_definitions(), + }; let mut iterations = 0; let mut total_input_tokens = 0u32; let mut total_output_tokens = 0u32; @@ -113,12 +117,22 @@ impl AgentOrchestrator { tool_call_id: None, }); - // 执行每个 Tool Call + // 执行每个 Tool Call(受沙箱 allowed_tools 约束) for tc in &tool_calls { let tool_result = match self.tool_registry.get(&tc.name) { Some(tool) => { - let result = tool.execute(ctx, tc.arguments.clone()).await; - result.output + // 沙箱过滤:如果 allowed_tools 存在且不包含此 Tool,拒绝执行 + if let Some(allowed) = allowed_tools { + if !allowed.contains(tc.name.as_str()) { + format!("Tool '{}' 在当前角色下不可用", tc.name) + } else { + let result = tool.execute(ctx, tc.arguments.clone()).await; + result.output + } + } else { + let result = tool.execute(ctx, tc.arguments.clone()).await; + result.output + } } None => format!("未知 Tool: {}", tc.name), }; diff --git a/crates/erp-ai/src/agent/registry.rs b/crates/erp-ai/src/agent/registry.rs index 7754772..df16864 100644 --- a/crates/erp-ai/src/agent/registry.rs +++ b/crates/erp-ai/src/agent/registry.rs @@ -1,4 +1,5 @@ use std::collections::HashMap; +use std::collections::HashSet; use std::sync::Arc; use super::tool::AgentTool; @@ -33,6 +34,22 @@ impl ToolRegistry { self.tools.values().collect() } + /// 根据角色允许的 Tool 列表过滤,返回过滤后的 ToolDefinition + pub fn tool_definitions_filtered( + &self, + allowed_tools: &HashSet, + ) -> Vec { + self.tools + .values() + .filter(|t| allowed_tools.contains(t.name())) + .map(|t| crate::dto::ToolDefinition { + name: t.name().to_string(), + description: t.description().to_string(), + parameters: t.parameters_schema(), + }) + .collect() + } + /// 生成传给 LLM 的 ToolDefinition 列表 pub fn tool_definitions(&self) -> Vec { self.tools diff --git a/crates/erp-ai/src/agent/sandbox.rs b/crates/erp-ai/src/agent/sandbox.rs new file mode 100644 index 0000000..2db169a --- /dev/null +++ b/crates/erp-ai/src/agent/sandbox.rs @@ -0,0 +1,97 @@ +use std::collections::HashSet; + +/// 用户角色(从 JWT claims 提取) +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum UserRole { + Patient, + MedicalStaff, + Admin, +} + +impl UserRole { + pub fn from_role_name(role: &str) -> Self { + match role { + "patient" => Self::Patient, + "admin" => Self::Admin, + _ => Self::MedicalStaff, // doctor, nurse, etc. + } + } +} + +/// 沙箱配置 — 定义某个角色可用的工具集和输出策略 +#[derive(Debug, Clone)] +pub struct SandboxConfig { + pub role: UserRole, + pub allowed_tools: HashSet, + pub system_prompt_suffix: &'static str, + pub output_filter: OutputFilter, +} + +/// 输出过滤策略 +#[derive(Debug, Clone, Default)] +pub struct OutputFilter { + pub append_disclaimer: bool, + pub redact_diagnosis_terms: bool, + pub disclaimer_text: &'static str, +} + +/// 获取指定角色的沙箱配置 +pub fn get_sandbox_config(role: &UserRole) -> SandboxConfig { + match role { + UserRole::Patient => SandboxConfig { + role: role.clone(), + allowed_tools: HashSet::from(["query_patient_vitals".into()]), + system_prompt_suffix: PATIENT_PROMPT_SUFFIX, + output_filter: OutputFilter { + append_disclaimer: true, + redact_diagnosis_terms: true, + disclaimer_text: DISCLAIMER_PATIENT, + }, + }, + UserRole::MedicalStaff => SandboxConfig { + role: role.clone(), + allowed_tools: HashSet::from(["query_patient_vitals".into()]), + system_prompt_suffix: MEDICAL_STAFF_PROMPT_SUFFIX, + output_filter: OutputFilter { + append_disclaimer: false, + redact_diagnosis_terms: false, + disclaimer_text: "", + }, + }, + UserRole::Admin => SandboxConfig { + role: role.clone(), + allowed_tools: HashSet::from(["query_patient_vitals".into()]), + system_prompt_suffix: ADMIN_PROMPT_SUFFIX, + output_filter: OutputFilter::default(), + }, + } +} + +/// 根据 JWT 角色列表解析为沙箱角色(取最高权限角色) +pub fn resolve_role(roles: &[String]) -> UserRole { + for r in roles { + if r == "admin" { + return UserRole::Admin; + } + } + for r in roles { + if r == "doctor" || r == "nurse" { + return UserRole::MedicalStaff; + } + } + for r in roles { + if r == "patient" { + return UserRole::Patient; + } + } + // 默认按患者(最小权限) + UserRole::Patient +} + +static PATIENT_PROMPT_SUFFIX: &str = "\n\n注意:你是面向患者的 AI 健康助手。请使用通俗易懂的语言,避免使用专业医学术语。回答应温和且包含情绪安抚。不得做出具体的诊断或用药建议。如果用户描述了严重症状,建议立即就医。"; + +static MEDICAL_STAFF_PROMPT_SUFFIX: &str = "\n\n注意:你是面向医护人员的 AI 助手。回答应专业简洁,引用数据来源。可提供辅助诊断参考和风险评分,但最终诊断和治疗方案由医生决定。"; + +static ADMIN_PROMPT_SUFFIX: &str = "\n\n注意:你是面向管理人员的 AI 助手。主要提供用量统计、成本分析和运营洞察。不提供个体患者数据。"; + +static DISCLAIMER_PATIENT: &str = "\n\n---\n*以上内容由 AI 生成,仅供参考,不构成医疗诊断或治疗建议。如有健康问题请咨询专业医生。*"; diff --git a/crates/erp-ai/src/handler/chat_handler.rs b/crates/erp-ai/src/handler/chat_handler.rs index 8c3b64c..54fd7dc 100644 --- a/crates/erp-ai/src/handler/chat_handler.rs +++ b/crates/erp-ai/src/handler/chat_handler.rs @@ -5,6 +5,7 @@ use erp_core::types::{ApiResponse, TenantContext}; use serde::{Deserialize, Serialize}; use crate::agent::orchestrator::AgentRunParams; +use crate::agent::sandbox::{get_sandbox_config, resolve_role}; use crate::agent::tool::ToolContext; use crate::agent::tools::QueryPatientVitalsTool; use crate::agent::{AgentOrchestrator, ToolRegistry}; @@ -113,10 +114,22 @@ where ) })?; - // 构建 ToolRegistry — Phase 0 只有 query_patient_vitals + // 构建全局 ToolRegistry(所有已注册 Tool) let mut registry = ToolRegistry::new(); registry.register(std::sync::Arc::new(QueryPatientVitalsTool)); + // 根据用户角色获取沙箱配置 + let user_role = resolve_role(&ctx.roles); + let sandbox = get_sandbox_config(&user_role); + + tracing::info!( + tenant_id = %ctx.tenant_id, + user_id = %ctx.user_id, + role = ?user_role, + allowed_tools = ?sandbox.allowed_tools, + "Sandbox resolved" + ); + let tool_ctx = ToolContext { tenant_id: ctx.tenant_id, user_id: ctx.user_id, @@ -125,6 +138,12 @@ where health_provider: ai_state.health_provider.clone(), }; + // system_prompt 追加角色沙箱的 Prompt 后缀 + let system_prompt = format!( + "{}{}", + config.agent.system_prompt, sandbox.system_prompt_suffix + ); + let run_params = AgentRunParams { model: config.agent.model, temperature: config.agent.temperature, @@ -146,14 +165,15 @@ where let provider_name = provider_arc.name().to_string(); - // 执行 Agent ReAct 循环 + // 执行 Agent ReAct 循环(使用角色沙箱过滤后的 Tool 和 Prompt) let orchestrator = AgentOrchestrator::new(provider_arc, std::sync::Arc::new(registry)); - let result = orchestrator + let mut result = orchestrator .run( - &config.agent.system_prompt, + &system_prompt, &mut messages, &tool_ctx, &run_params, + Some(&sandbox.allowed_tools), ) .await .map_err(|e| { @@ -161,6 +181,11 @@ where erp_core::error::AppError::Internal("AI 服务暂时不可用,请稍后再试".into()) })?; + // 输出过滤:患者角色追加免责声明 + if sandbox.output_filter.append_disclaimer && !result.reply.is_empty() { + result.reply.push_str(sandbox.output_filter.disclaimer_text); + } + let message_id = uuid::Uuid::now_v7().to_string(); tracing::info!(