feat(ai): AgentTool trait + ToolRegistry + AgentOrchestrator — ReAct 循环(最多 5 轮 Tool Call)
This commit is contained in:
8
crates/erp-ai/src/agent/mod.rs
Normal file
8
crates/erp-ai/src/agent/mod.rs
Normal file
@@ -0,0 +1,8 @@
|
||||
pub mod orchestrator;
|
||||
pub mod registry;
|
||||
pub mod tool;
|
||||
pub mod tools;
|
||||
|
||||
pub use orchestrator::AgentOrchestrator;
|
||||
pub use registry::ToolRegistry;
|
||||
pub use tool::{AgentTool, DisplayHint, ToolContext, ToolResult};
|
||||
117
crates/erp-ai/src/agent/orchestrator.rs
Normal file
117
crates/erp-ai/src/agent/orchestrator.rs
Normal file
@@ -0,0 +1,117 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::registry::ToolRegistry;
|
||||
use super::tool::ToolContext;
|
||||
use crate::dto::{ChatMessage, ChatMessageRole};
|
||||
use crate::error::AiResult;
|
||||
use crate::provider::AiProvider;
|
||||
|
||||
/// Agent Orchestrator — 执行 ReAct 循环
|
||||
pub struct AgentOrchestrator {
|
||||
provider: Arc<dyn AiProvider>,
|
||||
tool_registry: Arc<ToolRegistry>,
|
||||
max_iterations: usize,
|
||||
}
|
||||
|
||||
/// Agent 运行结果
|
||||
pub struct AgentRunResult {
|
||||
pub reply: String,
|
||||
pub total_input_tokens: u32,
|
||||
pub total_output_tokens: u32,
|
||||
pub iterations: usize,
|
||||
}
|
||||
|
||||
impl AgentOrchestrator {
|
||||
pub fn new(provider: Arc<dyn AiProvider>, tool_registry: Arc<ToolRegistry>) -> Self {
|
||||
Self {
|
||||
provider,
|
||||
tool_registry,
|
||||
max_iterations: 5,
|
||||
}
|
||||
}
|
||||
|
||||
/// 执行 Agent ReAct 循环
|
||||
pub async fn run(
|
||||
&self,
|
||||
system_prompt: &str,
|
||||
messages: &mut Vec<ChatMessage>,
|
||||
ctx: &ToolContext,
|
||||
) -> AiResult<AgentRunResult> {
|
||||
let tools = self.tool_registry.tool_definitions();
|
||||
let mut iterations = 0;
|
||||
let mut total_input_tokens = 0u32;
|
||||
let mut total_output_tokens = 0u32;
|
||||
|
||||
loop {
|
||||
iterations += 1;
|
||||
|
||||
let response = self
|
||||
.provider
|
||||
.generate_with_tools(
|
||||
messages.clone(),
|
||||
tools.clone(),
|
||||
system_prompt,
|
||||
"auto",
|
||||
0.7,
|
||||
2048,
|
||||
)
|
||||
.await?;
|
||||
|
||||
if let Some(ref usage) = response.usage {
|
||||
total_input_tokens += usage.input;
|
||||
total_output_tokens += usage.output;
|
||||
}
|
||||
|
||||
// 如果没有 tool_calls,Agent 给出最终回复
|
||||
let tool_calls = match response.tool_calls {
|
||||
Some(tc) if !tc.is_empty() => tc,
|
||||
_ => {
|
||||
return Ok(AgentRunResult {
|
||||
reply: response.content.unwrap_or_default(),
|
||||
total_input_tokens,
|
||||
total_output_tokens,
|
||||
iterations,
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
// 达到上限:强制结束
|
||||
if iterations >= self.max_iterations {
|
||||
messages.push(ChatMessage {
|
||||
role: ChatMessageRole::User,
|
||||
content: "(系统提示:已收集足够信息,请直接总结回复用户,不要再调用工具)"
|
||||
.to_string(),
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
});
|
||||
continue;
|
||||
}
|
||||
|
||||
// 将 assistant 的 tool_calls 加入消息历史
|
||||
messages.push(ChatMessage {
|
||||
role: ChatMessageRole::Assistant,
|
||||
content: response.content.unwrap_or_default(),
|
||||
tool_calls: Some(tool_calls.clone()),
|
||||
tool_call_id: None,
|
||||
});
|
||||
|
||||
// 执行每个 Tool Call
|
||||
for tc in &tool_calls {
|
||||
let tool_result = match self.tool_registry.get(&tc.name) {
|
||||
Some(tool) => {
|
||||
let result = tool.execute(ctx, tc.arguments.clone()).await;
|
||||
result.output
|
||||
}
|
||||
None => format!("未知 Tool: {}", tc.name),
|
||||
};
|
||||
|
||||
messages.push(ChatMessage {
|
||||
role: ChatMessageRole::Tool,
|
||||
content: tool_result,
|
||||
tool_calls: None,
|
||||
tool_call_id: Some(tc.id.clone()),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
47
crates/erp-ai/src/agent/registry.rs
Normal file
47
crates/erp-ai/src/agent/registry.rs
Normal file
@@ -0,0 +1,47 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::tool::AgentTool;
|
||||
|
||||
/// Tool 注册表 — 管理所有可用的 Agent Tool
|
||||
pub struct ToolRegistry {
|
||||
tools: HashMap<String, Arc<dyn AgentTool>>,
|
||||
}
|
||||
|
||||
impl Default for ToolRegistry {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl ToolRegistry {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
tools: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn register(&mut self, tool: Arc<dyn AgentTool>) {
|
||||
self.tools.insert(tool.name().to_string(), tool);
|
||||
}
|
||||
|
||||
pub fn get(&self, name: &str) -> Option<&Arc<dyn AgentTool>> {
|
||||
self.tools.get(name)
|
||||
}
|
||||
|
||||
pub fn all_tools(&self) -> Vec<&Arc<dyn AgentTool>> {
|
||||
self.tools.values().collect()
|
||||
}
|
||||
|
||||
/// 生成传给 LLM 的 ToolDefinition 列表
|
||||
pub fn tool_definitions(&self) -> Vec<crate::dto::ToolDefinition> {
|
||||
self.tools
|
||||
.values()
|
||||
.map(|t| crate::dto::ToolDefinition {
|
||||
name: t.name().to_string(),
|
||||
description: t.description().to_string(),
|
||||
parameters: t.parameters_schema(),
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
55
crates/erp-ai/src/agent/tool.rs
Normal file
55
crates/erp-ai/src/agent/tool.rs
Normal file
@@ -0,0 +1,55 @@
|
||||
use async_trait::async_trait;
|
||||
use erp_core::health_provider::HealthDataProvider;
|
||||
use sea_orm::DatabaseConnection;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::Arc;
|
||||
use uuid::Uuid;
|
||||
|
||||
/// Agent Tool trait — 所有 Agent 可调用的工具都实现此 trait
|
||||
#[async_trait]
|
||||
pub trait AgentTool: Send + Sync {
|
||||
fn name(&self) -> &str;
|
||||
fn description(&self) -> &str;
|
||||
fn parameters_schema(&self) -> serde_json::Value;
|
||||
async fn execute(&self, ctx: &ToolContext, params: serde_json::Value) -> ToolResult;
|
||||
}
|
||||
|
||||
/// Tool 执行上下文 — 包含安全过滤后的租户/用户信息
|
||||
pub struct ToolContext {
|
||||
pub tenant_id: Uuid,
|
||||
pub user_id: Uuid,
|
||||
pub patient_id: Option<Uuid>,
|
||||
pub db: DatabaseConnection,
|
||||
pub health_provider: Arc<dyn HealthDataProvider>,
|
||||
}
|
||||
|
||||
/// Tool 执行结果
|
||||
pub struct ToolResult {
|
||||
pub output: String,
|
||||
pub display_hint: Option<DisplayHint>,
|
||||
}
|
||||
|
||||
/// 前端渲染提示 — 告诉前端如何富化展示 Tool 返回的数据
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum DisplayHint {
|
||||
VitalCard {
|
||||
indicator_type: String,
|
||||
values: Vec<(String, f64)>,
|
||||
unit: String,
|
||||
},
|
||||
LabReportCard {
|
||||
report_date: String,
|
||||
abnormal_count: usize,
|
||||
},
|
||||
ActionConfirm {
|
||||
action_type: String,
|
||||
summary: String,
|
||||
confirm_payload: serde_json::Value,
|
||||
},
|
||||
RiskAlert {
|
||||
level: String,
|
||||
message: String,
|
||||
},
|
||||
Text,
|
||||
}
|
||||
@@ -1,3 +1,4 @@
|
||||
pub mod agent;
|
||||
pub mod config;
|
||||
pub mod copilot;
|
||||
pub mod dto;
|
||||
|
||||
Reference in New Issue
Block a user