use std::sync::Arc; use super::registry::ToolRegistry; use super::tool::ToolContext; use crate::dto::{ChatMessage, ChatMessageRole}; use crate::error::AiResult; use crate::provider::AiProvider; /// Agent 运行时参数 pub struct AgentRunParams { pub model: String, pub temperature: f32, pub max_tokens: u32, pub max_iterations: usize, } impl Default for AgentRunParams { fn default() -> Self { Self { model: "claude-sonnet-4-6".to_string(), temperature: 0.7, max_tokens: 2048, max_iterations: 5, } } } /// Agent Orchestrator — 执行 ReAct 循环 pub struct AgentOrchestrator { provider: Arc, tool_registry: Arc, } /// Agent 运行结果 pub struct AgentRunResult { pub reply: String, pub total_input_tokens: u32, pub total_output_tokens: u32, pub iterations: usize, } impl AgentOrchestrator { pub fn new(provider: Arc, tool_registry: Arc) -> Self { Self { provider, tool_registry, } } /// 执行 Agent ReAct 循环 pub async fn run( &self, system_prompt: &str, messages: &mut Vec, ctx: &ToolContext, params: &AgentRunParams, allowed_tools: Option<&std::collections::HashSet>, ) -> AiResult { let tools = match allowed_tools { Some(allowed) => self.tool_registry.tool_definitions_filtered(allowed), None => self.tool_registry.tool_definitions(), }; let mut iterations = 0; let mut total_input_tokens = 0u32; let mut total_output_tokens = 0u32; loop { iterations += 1; let response = self .provider .generate_with_tools( messages.clone(), tools.clone(), system_prompt, ¶ms.model, params.temperature, params.max_tokens, ) .await?; if let Some(ref usage) = response.usage { total_input_tokens += usage.input; total_output_tokens += usage.output; } // 如果没有 tool_calls,Agent 给出最终回复 let tool_calls = match response.tool_calls { Some(tc) if !tc.is_empty() => tc, _ => { return Ok(AgentRunResult { reply: response.content.unwrap_or_default(), total_input_tokens, total_output_tokens, iterations, }); } }; // 达到上限:强制结束 if iterations >= params.max_iterations { messages.push(ChatMessage { role: ChatMessageRole::User, content: "(系统提示:已收集足够信息,请直接总结回复用户,不要再调用工具)" .to_string(), tool_calls: None, tool_call_id: None, }); continue; } // 将 assistant 的 tool_calls 加入消息历史 messages.push(ChatMessage { role: ChatMessageRole::Assistant, content: response.content.unwrap_or_default(), tool_calls: Some(tool_calls.clone()), tool_call_id: None, }); // 执行每个 Tool Call(受沙箱 allowed_tools 约束) for tc in &tool_calls { let tool_result = match self.tool_registry.get(&tc.name) { Some(tool) => { // 沙箱过滤:如果 allowed_tools 存在且不包含此 Tool,拒绝执行 if let Some(allowed) = allowed_tools { if !allowed.contains(tc.name.as_str()) { format!("Tool '{}' 在当前角色下不可用", tc.name) } else { let result = tool.execute(ctx, tc.arguments.clone()).await; result.output } } else { let result = tool.execute(ctx, tc.arguments.clone()).await; result.output } } None => format!("未知 Tool: {}", tc.name), }; messages.push(ChatMessage { role: ChatMessageRole::Tool, content: tool_result, tool_calls: None, tool_call_id: Some(tc.id.clone()), }); } } } }