feat(ai): AgentTool trait + ToolRegistry + AgentOrchestrator — ReAct 循环(最多 5 轮 Tool Call)
This commit is contained in:
8
crates/erp-ai/src/agent/mod.rs
Normal file
8
crates/erp-ai/src/agent/mod.rs
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
pub mod orchestrator;
|
||||||
|
pub mod registry;
|
||||||
|
pub mod tool;
|
||||||
|
pub mod tools;
|
||||||
|
|
||||||
|
pub use orchestrator::AgentOrchestrator;
|
||||||
|
pub use registry::ToolRegistry;
|
||||||
|
pub use tool::{AgentTool, DisplayHint, ToolContext, ToolResult};
|
||||||
117
crates/erp-ai/src/agent/orchestrator.rs
Normal file
117
crates/erp-ai/src/agent/orchestrator.rs
Normal file
@@ -0,0 +1,117 @@
|
|||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use super::registry::ToolRegistry;
|
||||||
|
use super::tool::ToolContext;
|
||||||
|
use crate::dto::{ChatMessage, ChatMessageRole};
|
||||||
|
use crate::error::AiResult;
|
||||||
|
use crate::provider::AiProvider;
|
||||||
|
|
||||||
|
/// Agent Orchestrator — 执行 ReAct 循环
|
||||||
|
pub struct AgentOrchestrator {
|
||||||
|
provider: Arc<dyn AiProvider>,
|
||||||
|
tool_registry: Arc<ToolRegistry>,
|
||||||
|
max_iterations: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Agent 运行结果
|
||||||
|
pub struct AgentRunResult {
|
||||||
|
pub reply: String,
|
||||||
|
pub total_input_tokens: u32,
|
||||||
|
pub total_output_tokens: u32,
|
||||||
|
pub iterations: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AgentOrchestrator {
|
||||||
|
pub fn new(provider: Arc<dyn AiProvider>, tool_registry: Arc<ToolRegistry>) -> Self {
|
||||||
|
Self {
|
||||||
|
provider,
|
||||||
|
tool_registry,
|
||||||
|
max_iterations: 5,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 执行 Agent ReAct 循环
|
||||||
|
pub async fn run(
|
||||||
|
&self,
|
||||||
|
system_prompt: &str,
|
||||||
|
messages: &mut Vec<ChatMessage>,
|
||||||
|
ctx: &ToolContext,
|
||||||
|
) -> AiResult<AgentRunResult> {
|
||||||
|
let tools = self.tool_registry.tool_definitions();
|
||||||
|
let mut iterations = 0;
|
||||||
|
let mut total_input_tokens = 0u32;
|
||||||
|
let mut total_output_tokens = 0u32;
|
||||||
|
|
||||||
|
loop {
|
||||||
|
iterations += 1;
|
||||||
|
|
||||||
|
let response = self
|
||||||
|
.provider
|
||||||
|
.generate_with_tools(
|
||||||
|
messages.clone(),
|
||||||
|
tools.clone(),
|
||||||
|
system_prompt,
|
||||||
|
"auto",
|
||||||
|
0.7,
|
||||||
|
2048,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
if let Some(ref usage) = response.usage {
|
||||||
|
total_input_tokens += usage.input;
|
||||||
|
total_output_tokens += usage.output;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果没有 tool_calls,Agent 给出最终回复
|
||||||
|
let tool_calls = match response.tool_calls {
|
||||||
|
Some(tc) if !tc.is_empty() => tc,
|
||||||
|
_ => {
|
||||||
|
return Ok(AgentRunResult {
|
||||||
|
reply: response.content.unwrap_or_default(),
|
||||||
|
total_input_tokens,
|
||||||
|
total_output_tokens,
|
||||||
|
iterations,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// 达到上限:强制结束
|
||||||
|
if iterations >= self.max_iterations {
|
||||||
|
messages.push(ChatMessage {
|
||||||
|
role: ChatMessageRole::User,
|
||||||
|
content: "(系统提示:已收集足够信息,请直接总结回复用户,不要再调用工具)"
|
||||||
|
.to_string(),
|
||||||
|
tool_calls: None,
|
||||||
|
tool_call_id: None,
|
||||||
|
});
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 将 assistant 的 tool_calls 加入消息历史
|
||||||
|
messages.push(ChatMessage {
|
||||||
|
role: ChatMessageRole::Assistant,
|
||||||
|
content: response.content.unwrap_or_default(),
|
||||||
|
tool_calls: Some(tool_calls.clone()),
|
||||||
|
tool_call_id: None,
|
||||||
|
});
|
||||||
|
|
||||||
|
// 执行每个 Tool Call
|
||||||
|
for tc in &tool_calls {
|
||||||
|
let tool_result = match self.tool_registry.get(&tc.name) {
|
||||||
|
Some(tool) => {
|
||||||
|
let result = tool.execute(ctx, tc.arguments.clone()).await;
|
||||||
|
result.output
|
||||||
|
}
|
||||||
|
None => format!("未知 Tool: {}", tc.name),
|
||||||
|
};
|
||||||
|
|
||||||
|
messages.push(ChatMessage {
|
||||||
|
role: ChatMessageRole::Tool,
|
||||||
|
content: tool_result,
|
||||||
|
tool_calls: None,
|
||||||
|
tool_call_id: Some(tc.id.clone()),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
47
crates/erp-ai/src/agent/registry.rs
Normal file
47
crates/erp-ai/src/agent/registry.rs
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
use std::collections::HashMap;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use super::tool::AgentTool;
|
||||||
|
|
||||||
|
/// Tool 注册表 — 管理所有可用的 Agent Tool
|
||||||
|
pub struct ToolRegistry {
|
||||||
|
tools: HashMap<String, Arc<dyn AgentTool>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for ToolRegistry {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ToolRegistry {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
tools: HashMap::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn register(&mut self, tool: Arc<dyn AgentTool>) {
|
||||||
|
self.tools.insert(tool.name().to_string(), tool);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get(&self, name: &str) -> Option<&Arc<dyn AgentTool>> {
|
||||||
|
self.tools.get(name)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn all_tools(&self) -> Vec<&Arc<dyn AgentTool>> {
|
||||||
|
self.tools.values().collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 生成传给 LLM 的 ToolDefinition 列表
|
||||||
|
pub fn tool_definitions(&self) -> Vec<crate::dto::ToolDefinition> {
|
||||||
|
self.tools
|
||||||
|
.values()
|
||||||
|
.map(|t| crate::dto::ToolDefinition {
|
||||||
|
name: t.name().to_string(),
|
||||||
|
description: t.description().to_string(),
|
||||||
|
parameters: t.parameters_schema(),
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
}
|
||||||
55
crates/erp-ai/src/agent/tool.rs
Normal file
55
crates/erp-ai/src/agent/tool.rs
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
use async_trait::async_trait;
|
||||||
|
use erp_core::health_provider::HealthDataProvider;
|
||||||
|
use sea_orm::DatabaseConnection;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::sync::Arc;
|
||||||
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
/// Agent Tool trait — 所有 Agent 可调用的工具都实现此 trait
|
||||||
|
#[async_trait]
|
||||||
|
pub trait AgentTool: Send + Sync {
|
||||||
|
fn name(&self) -> &str;
|
||||||
|
fn description(&self) -> &str;
|
||||||
|
fn parameters_schema(&self) -> serde_json::Value;
|
||||||
|
async fn execute(&self, ctx: &ToolContext, params: serde_json::Value) -> ToolResult;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Tool 执行上下文 — 包含安全过滤后的租户/用户信息
|
||||||
|
pub struct ToolContext {
|
||||||
|
pub tenant_id: Uuid,
|
||||||
|
pub user_id: Uuid,
|
||||||
|
pub patient_id: Option<Uuid>,
|
||||||
|
pub db: DatabaseConnection,
|
||||||
|
pub health_provider: Arc<dyn HealthDataProvider>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Tool 执行结果
|
||||||
|
pub struct ToolResult {
|
||||||
|
pub output: String,
|
||||||
|
pub display_hint: Option<DisplayHint>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 前端渲染提示 — 告诉前端如何富化展示 Tool 返回的数据
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
#[serde(tag = "type", rename_all = "snake_case")]
|
||||||
|
pub enum DisplayHint {
|
||||||
|
VitalCard {
|
||||||
|
indicator_type: String,
|
||||||
|
values: Vec<(String, f64)>,
|
||||||
|
unit: String,
|
||||||
|
},
|
||||||
|
LabReportCard {
|
||||||
|
report_date: String,
|
||||||
|
abnormal_count: usize,
|
||||||
|
},
|
||||||
|
ActionConfirm {
|
||||||
|
action_type: String,
|
||||||
|
summary: String,
|
||||||
|
confirm_payload: serde_json::Value,
|
||||||
|
},
|
||||||
|
RiskAlert {
|
||||||
|
level: String,
|
||||||
|
message: String,
|
||||||
|
},
|
||||||
|
Text,
|
||||||
|
}
|
||||||
@@ -1,3 +1,4 @@
|
|||||||
|
pub mod agent;
|
||||||
pub mod config;
|
pub mod config;
|
||||||
pub mod copilot;
|
pub mod copilot;
|
||||||
pub mod dto;
|
pub mod dto;
|
||||||
|
|||||||
Reference in New Issue
Block a user