From 2d62605812cb1e6556d2ee0e85a0b19146cdf1fa Mon Sep 17 00:00:00 2001 From: iven Date: Mon, 18 May 2026 02:56:26 +0800 Subject: [PATCH] =?UTF-8?q?feat(ai):=20AgentTool=20trait=20+=20ToolRegistr?= =?UTF-8?q?y=20+=20AgentOrchestrator=20=E2=80=94=20ReAct=20=E5=BE=AA?= =?UTF-8?q?=E7=8E=AF=EF=BC=88=E6=9C=80=E5=A4=9A=205=20=E8=BD=AE=20Tool=20C?= =?UTF-8?q?all=EF=BC=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- crates/erp-ai/src/agent/mod.rs | 8 ++ crates/erp-ai/src/agent/orchestrator.rs | 117 ++++++++++++++++++++++++ crates/erp-ai/src/agent/registry.rs | 47 ++++++++++ crates/erp-ai/src/agent/tool.rs | 55 +++++++++++ crates/erp-ai/src/lib.rs | 1 + 5 files changed, 228 insertions(+) create mode 100644 crates/erp-ai/src/agent/mod.rs create mode 100644 crates/erp-ai/src/agent/orchestrator.rs create mode 100644 crates/erp-ai/src/agent/registry.rs create mode 100644 crates/erp-ai/src/agent/tool.rs diff --git a/crates/erp-ai/src/agent/mod.rs b/crates/erp-ai/src/agent/mod.rs new file mode 100644 index 0000000..a734572 --- /dev/null +++ b/crates/erp-ai/src/agent/mod.rs @@ -0,0 +1,8 @@ +pub mod orchestrator; +pub mod registry; +pub mod tool; +pub mod tools; + +pub use orchestrator::AgentOrchestrator; +pub use registry::ToolRegistry; +pub use tool::{AgentTool, DisplayHint, ToolContext, ToolResult}; diff --git a/crates/erp-ai/src/agent/orchestrator.rs b/crates/erp-ai/src/agent/orchestrator.rs new file mode 100644 index 0000000..4c4a860 --- /dev/null +++ b/crates/erp-ai/src/agent/orchestrator.rs @@ -0,0 +1,117 @@ +use std::sync::Arc; + +use super::registry::ToolRegistry; +use super::tool::ToolContext; +use crate::dto::{ChatMessage, ChatMessageRole}; +use crate::error::AiResult; +use crate::provider::AiProvider; + +/// Agent Orchestrator — 执行 ReAct 循环 +pub struct AgentOrchestrator { + provider: Arc, + tool_registry: Arc, + max_iterations: usize, +} + +/// Agent 运行结果 +pub struct AgentRunResult { + pub reply: String, + pub total_input_tokens: u32, + pub total_output_tokens: u32, + pub iterations: usize, +} + +impl AgentOrchestrator { + pub fn new(provider: Arc, tool_registry: Arc) -> Self { + Self { + provider, + tool_registry, + max_iterations: 5, + } + } + + /// 执行 Agent ReAct 循环 + pub async fn run( + &self, + system_prompt: &str, + messages: &mut Vec, + ctx: &ToolContext, + ) -> AiResult { + let tools = self.tool_registry.tool_definitions(); + let mut iterations = 0; + let mut total_input_tokens = 0u32; + let mut total_output_tokens = 0u32; + + loop { + iterations += 1; + + let response = self + .provider + .generate_with_tools( + messages.clone(), + tools.clone(), + system_prompt, + "auto", + 0.7, + 2048, + ) + .await?; + + if let Some(ref usage) = response.usage { + total_input_tokens += usage.input; + total_output_tokens += usage.output; + } + + // 如果没有 tool_calls,Agent 给出最终回复 + let tool_calls = match response.tool_calls { + Some(tc) if !tc.is_empty() => tc, + _ => { + return Ok(AgentRunResult { + reply: response.content.unwrap_or_default(), + total_input_tokens, + total_output_tokens, + iterations, + }); + } + }; + + // 达到上限:强制结束 + if iterations >= self.max_iterations { + messages.push(ChatMessage { + role: ChatMessageRole::User, + content: "(系统提示:已收集足够信息,请直接总结回复用户,不要再调用工具)" + .to_string(), + tool_calls: None, + tool_call_id: None, + }); + continue; + } + + // 将 assistant 的 tool_calls 加入消息历史 + messages.push(ChatMessage { + role: ChatMessageRole::Assistant, + content: response.content.unwrap_or_default(), + tool_calls: Some(tool_calls.clone()), + tool_call_id: None, + }); + + // 执行每个 Tool Call + for tc in &tool_calls { + let tool_result = match self.tool_registry.get(&tc.name) { + Some(tool) => { + let result = tool.execute(ctx, tc.arguments.clone()).await; + result.output + } + None => format!("未知 Tool: {}", tc.name), + }; + + messages.push(ChatMessage { + role: ChatMessageRole::Tool, + content: tool_result, + tool_calls: None, + tool_call_id: Some(tc.id.clone()), + }); + } + } + } +} diff --git a/crates/erp-ai/src/agent/registry.rs b/crates/erp-ai/src/agent/registry.rs new file mode 100644 index 0000000..7754772 --- /dev/null +++ b/crates/erp-ai/src/agent/registry.rs @@ -0,0 +1,47 @@ +use std::collections::HashMap; +use std::sync::Arc; + +use super::tool::AgentTool; + +/// Tool 注册表 — 管理所有可用的 Agent Tool +pub struct ToolRegistry { + tools: HashMap>, +} + +impl Default for ToolRegistry { + fn default() -> Self { + Self::new() + } +} + +impl ToolRegistry { + pub fn new() -> Self { + Self { + tools: HashMap::new(), + } + } + + pub fn register(&mut self, tool: Arc) { + self.tools.insert(tool.name().to_string(), tool); + } + + pub fn get(&self, name: &str) -> Option<&Arc> { + self.tools.get(name) + } + + pub fn all_tools(&self) -> Vec<&Arc> { + self.tools.values().collect() + } + + /// 生成传给 LLM 的 ToolDefinition 列表 + pub fn tool_definitions(&self) -> Vec { + self.tools + .values() + .map(|t| crate::dto::ToolDefinition { + name: t.name().to_string(), + description: t.description().to_string(), + parameters: t.parameters_schema(), + }) + .collect() + } +} diff --git a/crates/erp-ai/src/agent/tool.rs b/crates/erp-ai/src/agent/tool.rs new file mode 100644 index 0000000..5813d92 --- /dev/null +++ b/crates/erp-ai/src/agent/tool.rs @@ -0,0 +1,55 @@ +use async_trait::async_trait; +use erp_core::health_provider::HealthDataProvider; +use sea_orm::DatabaseConnection; +use serde::{Deserialize, Serialize}; +use std::sync::Arc; +use uuid::Uuid; + +/// Agent Tool trait — 所有 Agent 可调用的工具都实现此 trait +#[async_trait] +pub trait AgentTool: Send + Sync { + fn name(&self) -> &str; + fn description(&self) -> &str; + fn parameters_schema(&self) -> serde_json::Value; + async fn execute(&self, ctx: &ToolContext, params: serde_json::Value) -> ToolResult; +} + +/// Tool 执行上下文 — 包含安全过滤后的租户/用户信息 +pub struct ToolContext { + pub tenant_id: Uuid, + pub user_id: Uuid, + pub patient_id: Option, + pub db: DatabaseConnection, + pub health_provider: Arc, +} + +/// Tool 执行结果 +pub struct ToolResult { + pub output: String, + pub display_hint: Option, +} + +/// 前端渲染提示 — 告诉前端如何富化展示 Tool 返回的数据 +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum DisplayHint { + VitalCard { + indicator_type: String, + values: Vec<(String, f64)>, + unit: String, + }, + LabReportCard { + report_date: String, + abnormal_count: usize, + }, + ActionConfirm { + action_type: String, + summary: String, + confirm_payload: serde_json::Value, + }, + RiskAlert { + level: String, + message: String, + }, + Text, +} diff --git a/crates/erp-ai/src/lib.rs b/crates/erp-ai/src/lib.rs index 3dc3728..dd69a5c 100644 --- a/crates/erp-ai/src/lib.rs +++ b/crates/erp-ai/src/lib.rs @@ -1,3 +1,4 @@ +pub mod agent; pub mod config; pub mod copilot; pub mod dto;