feat(ai): AgentTool trait + ToolRegistry + AgentOrchestrator — ReAct 循环(最多 5 轮 Tool Call)

This commit is contained in:
iven
2026-05-18 02:56:26 +08:00
parent 877e9831f6
commit 2d62605812
5 changed files with 228 additions and 0 deletions

View File

@@ -0,0 +1,117 @@
use std::sync::Arc;
use super::registry::ToolRegistry;
use super::tool::ToolContext;
use crate::dto::{ChatMessage, ChatMessageRole};
use crate::error::AiResult;
use crate::provider::AiProvider;
/// Agent Orchestrator — 执行 ReAct 循环
pub struct AgentOrchestrator {
provider: Arc<dyn AiProvider>,
tool_registry: Arc<ToolRegistry>,
max_iterations: usize,
}
/// Agent 运行结果
pub struct AgentRunResult {
pub reply: String,
pub total_input_tokens: u32,
pub total_output_tokens: u32,
pub iterations: usize,
}
impl AgentOrchestrator {
pub fn new(provider: Arc<dyn AiProvider>, tool_registry: Arc<ToolRegistry>) -> Self {
Self {
provider,
tool_registry,
max_iterations: 5,
}
}
/// 执行 Agent ReAct 循环
pub async fn run(
&self,
system_prompt: &str,
messages: &mut Vec<ChatMessage>,
ctx: &ToolContext,
) -> AiResult<AgentRunResult> {
let tools = self.tool_registry.tool_definitions();
let mut iterations = 0;
let mut total_input_tokens = 0u32;
let mut total_output_tokens = 0u32;
loop {
iterations += 1;
let response = self
.provider
.generate_with_tools(
messages.clone(),
tools.clone(),
system_prompt,
"auto",
0.7,
2048,
)
.await?;
if let Some(ref usage) = response.usage {
total_input_tokens += usage.input;
total_output_tokens += usage.output;
}
// 如果没有 tool_callsAgent 给出最终回复
let tool_calls = match response.tool_calls {
Some(tc) if !tc.is_empty() => tc,
_ => {
return Ok(AgentRunResult {
reply: response.content.unwrap_or_default(),
total_input_tokens,
total_output_tokens,
iterations,
});
}
};
// 达到上限:强制结束
if iterations >= self.max_iterations {
messages.push(ChatMessage {
role: ChatMessageRole::User,
content: "(系统提示:已收集足够信息,请直接总结回复用户,不要再调用工具)"
.to_string(),
tool_calls: None,
tool_call_id: None,
});
continue;
}
// 将 assistant 的 tool_calls 加入消息历史
messages.push(ChatMessage {
role: ChatMessageRole::Assistant,
content: response.content.unwrap_or_default(),
tool_calls: Some(tool_calls.clone()),
tool_call_id: None,
});
// 执行每个 Tool Call
for tc in &tool_calls {
let tool_result = match self.tool_registry.get(&tc.name) {
Some(tool) => {
let result = tool.execute(ctx, tc.arguments.clone()).await;
result.output
}
None => format!("未知 Tool: {}", tc.name),
};
messages.push(ChatMessage {
role: ChatMessageRole::Tool,
content: tool_result,
tool_calls: None,
tool_call_id: Some(tc.id.clone()),
});
}
}
}
}