feat(ai): AgentTool trait + ToolRegistry + AgentOrchestrator — ReAct 循环(最多 5 轮 Tool Call)
This commit is contained in:
117
crates/erp-ai/src/agent/orchestrator.rs
Normal file
117
crates/erp-ai/src/agent/orchestrator.rs
Normal file
@@ -0,0 +1,117 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::registry::ToolRegistry;
|
||||
use super::tool::ToolContext;
|
||||
use crate::dto::{ChatMessage, ChatMessageRole};
|
||||
use crate::error::AiResult;
|
||||
use crate::provider::AiProvider;
|
||||
|
||||
/// Agent Orchestrator — 执行 ReAct 循环
|
||||
pub struct AgentOrchestrator {
|
||||
provider: Arc<dyn AiProvider>,
|
||||
tool_registry: Arc<ToolRegistry>,
|
||||
max_iterations: usize,
|
||||
}
|
||||
|
||||
/// Agent 运行结果
|
||||
pub struct AgentRunResult {
|
||||
pub reply: String,
|
||||
pub total_input_tokens: u32,
|
||||
pub total_output_tokens: u32,
|
||||
pub iterations: usize,
|
||||
}
|
||||
|
||||
impl AgentOrchestrator {
|
||||
pub fn new(provider: Arc<dyn AiProvider>, tool_registry: Arc<ToolRegistry>) -> Self {
|
||||
Self {
|
||||
provider,
|
||||
tool_registry,
|
||||
max_iterations: 5,
|
||||
}
|
||||
}
|
||||
|
||||
/// 执行 Agent ReAct 循环
|
||||
pub async fn run(
|
||||
&self,
|
||||
system_prompt: &str,
|
||||
messages: &mut Vec<ChatMessage>,
|
||||
ctx: &ToolContext,
|
||||
) -> AiResult<AgentRunResult> {
|
||||
let tools = self.tool_registry.tool_definitions();
|
||||
let mut iterations = 0;
|
||||
let mut total_input_tokens = 0u32;
|
||||
let mut total_output_tokens = 0u32;
|
||||
|
||||
loop {
|
||||
iterations += 1;
|
||||
|
||||
let response = self
|
||||
.provider
|
||||
.generate_with_tools(
|
||||
messages.clone(),
|
||||
tools.clone(),
|
||||
system_prompt,
|
||||
"auto",
|
||||
0.7,
|
||||
2048,
|
||||
)
|
||||
.await?;
|
||||
|
||||
if let Some(ref usage) = response.usage {
|
||||
total_input_tokens += usage.input;
|
||||
total_output_tokens += usage.output;
|
||||
}
|
||||
|
||||
// 如果没有 tool_calls,Agent 给出最终回复
|
||||
let tool_calls = match response.tool_calls {
|
||||
Some(tc) if !tc.is_empty() => tc,
|
||||
_ => {
|
||||
return Ok(AgentRunResult {
|
||||
reply: response.content.unwrap_or_default(),
|
||||
total_input_tokens,
|
||||
total_output_tokens,
|
||||
iterations,
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
// 达到上限:强制结束
|
||||
if iterations >= self.max_iterations {
|
||||
messages.push(ChatMessage {
|
||||
role: ChatMessageRole::User,
|
||||
content: "(系统提示:已收集足够信息,请直接总结回复用户,不要再调用工具)"
|
||||
.to_string(),
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
});
|
||||
continue;
|
||||
}
|
||||
|
||||
// 将 assistant 的 tool_calls 加入消息历史
|
||||
messages.push(ChatMessage {
|
||||
role: ChatMessageRole::Assistant,
|
||||
content: response.content.unwrap_or_default(),
|
||||
tool_calls: Some(tool_calls.clone()),
|
||||
tool_call_id: None,
|
||||
});
|
||||
|
||||
// 执行每个 Tool Call
|
||||
for tc in &tool_calls {
|
||||
let tool_result = match self.tool_registry.get(&tc.name) {
|
||||
Some(tool) => {
|
||||
let result = tool.execute(ctx, tc.arguments.clone()).await;
|
||||
result.output
|
||||
}
|
||||
None => format!("未知 Tool: {}", tc.name),
|
||||
};
|
||||
|
||||
messages.push(ChatMessage {
|
||||
role: ChatMessageRole::Tool,
|
||||
content: tool_result,
|
||||
tool_calls: None,
|
||||
tool_call_id: Some(tc.id.clone()),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user