feat(ai): AgentTool trait + ToolRegistry + AgentOrchestrator — ReAct 循环(最多 5 轮 Tool Call)

This commit is contained in:
iven
2026-05-18 02:56:26 +08:00
parent 877e9831f6
commit 2d62605812
5 changed files with 228 additions and 0 deletions

View File

@@ -0,0 +1,8 @@
pub mod orchestrator;
pub mod registry;
pub mod tool;
pub mod tools;
pub use orchestrator::AgentOrchestrator;
pub use registry::ToolRegistry;
pub use tool::{AgentTool, DisplayHint, ToolContext, ToolResult};

View File

@@ -0,0 +1,117 @@
use std::sync::Arc;
use super::registry::ToolRegistry;
use super::tool::ToolContext;
use crate::dto::{ChatMessage, ChatMessageRole};
use crate::error::AiResult;
use crate::provider::AiProvider;
/// Agent Orchestrator — 执行 ReAct 循环
pub struct AgentOrchestrator {
provider: Arc<dyn AiProvider>,
tool_registry: Arc<ToolRegistry>,
max_iterations: usize,
}
/// Agent 运行结果
pub struct AgentRunResult {
pub reply: String,
pub total_input_tokens: u32,
pub total_output_tokens: u32,
pub iterations: usize,
}
impl AgentOrchestrator {
pub fn new(provider: Arc<dyn AiProvider>, tool_registry: Arc<ToolRegistry>) -> Self {
Self {
provider,
tool_registry,
max_iterations: 5,
}
}
/// 执行 Agent ReAct 循环
pub async fn run(
&self,
system_prompt: &str,
messages: &mut Vec<ChatMessage>,
ctx: &ToolContext,
) -> AiResult<AgentRunResult> {
let tools = self.tool_registry.tool_definitions();
let mut iterations = 0;
let mut total_input_tokens = 0u32;
let mut total_output_tokens = 0u32;
loop {
iterations += 1;
let response = self
.provider
.generate_with_tools(
messages.clone(),
tools.clone(),
system_prompt,
"auto",
0.7,
2048,
)
.await?;
if let Some(ref usage) = response.usage {
total_input_tokens += usage.input;
total_output_tokens += usage.output;
}
// 如果没有 tool_callsAgent 给出最终回复
let tool_calls = match response.tool_calls {
Some(tc) if !tc.is_empty() => tc,
_ => {
return Ok(AgentRunResult {
reply: response.content.unwrap_or_default(),
total_input_tokens,
total_output_tokens,
iterations,
});
}
};
// 达到上限:强制结束
if iterations >= self.max_iterations {
messages.push(ChatMessage {
role: ChatMessageRole::User,
content: "(系统提示:已收集足够信息,请直接总结回复用户,不要再调用工具)"
.to_string(),
tool_calls: None,
tool_call_id: None,
});
continue;
}
// 将 assistant 的 tool_calls 加入消息历史
messages.push(ChatMessage {
role: ChatMessageRole::Assistant,
content: response.content.unwrap_or_default(),
tool_calls: Some(tool_calls.clone()),
tool_call_id: None,
});
// 执行每个 Tool Call
for tc in &tool_calls {
let tool_result = match self.tool_registry.get(&tc.name) {
Some(tool) => {
let result = tool.execute(ctx, tc.arguments.clone()).await;
result.output
}
None => format!("未知 Tool: {}", tc.name),
};
messages.push(ChatMessage {
role: ChatMessageRole::Tool,
content: tool_result,
tool_calls: None,
tool_call_id: Some(tc.id.clone()),
});
}
}
}
}

View File

@@ -0,0 +1,47 @@
use std::collections::HashMap;
use std::sync::Arc;
use super::tool::AgentTool;
/// Tool 注册表 — 管理所有可用的 Agent Tool
pub struct ToolRegistry {
tools: HashMap<String, Arc<dyn AgentTool>>,
}
impl Default for ToolRegistry {
fn default() -> Self {
Self::new()
}
}
impl ToolRegistry {
pub fn new() -> Self {
Self {
tools: HashMap::new(),
}
}
pub fn register(&mut self, tool: Arc<dyn AgentTool>) {
self.tools.insert(tool.name().to_string(), tool);
}
pub fn get(&self, name: &str) -> Option<&Arc<dyn AgentTool>> {
self.tools.get(name)
}
pub fn all_tools(&self) -> Vec<&Arc<dyn AgentTool>> {
self.tools.values().collect()
}
/// 生成传给 LLM 的 ToolDefinition 列表
pub fn tool_definitions(&self) -> Vec<crate::dto::ToolDefinition> {
self.tools
.values()
.map(|t| crate::dto::ToolDefinition {
name: t.name().to_string(),
description: t.description().to_string(),
parameters: t.parameters_schema(),
})
.collect()
}
}

View File

@@ -0,0 +1,55 @@
use async_trait::async_trait;
use erp_core::health_provider::HealthDataProvider;
use sea_orm::DatabaseConnection;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use uuid::Uuid;
/// Agent Tool trait — 所有 Agent 可调用的工具都实现此 trait
#[async_trait]
pub trait AgentTool: Send + Sync {
fn name(&self) -> &str;
fn description(&self) -> &str;
fn parameters_schema(&self) -> serde_json::Value;
async fn execute(&self, ctx: &ToolContext, params: serde_json::Value) -> ToolResult;
}
/// Tool 执行上下文 — 包含安全过滤后的租户/用户信息
pub struct ToolContext {
pub tenant_id: Uuid,
pub user_id: Uuid,
pub patient_id: Option<Uuid>,
pub db: DatabaseConnection,
pub health_provider: Arc<dyn HealthDataProvider>,
}
/// Tool 执行结果
pub struct ToolResult {
pub output: String,
pub display_hint: Option<DisplayHint>,
}
/// 前端渲染提示 — 告诉前端如何富化展示 Tool 返回的数据
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum DisplayHint {
VitalCard {
indicator_type: String,
values: Vec<(String, f64)>,
unit: String,
},
LabReportCard {
report_date: String,
abnormal_count: usize,
},
ActionConfirm {
action_type: String,
summary: String,
confirm_payload: serde_json::Value,
},
RiskAlert {
level: String,
message: String,
},
Text,
}

View File

@@ -1,3 +1,4 @@
pub mod agent;
pub mod config;
pub mod copilot;
pub mod dto;