use std::sync::Arc; use super::registry::ToolRegistry; use super::tool::ToolContext; use crate::dto::{ChatMessage, ChatMessageRole}; use crate::error::AiResult; use crate::provider::AiProvider; /// 单次 Tool 调用日志 #[derive(Debug, Clone)] pub struct ToolCallLog { pub tool_name: String, pub duration_ms: u64, pub success: bool, } /// Agent 运行时参数 pub struct AgentRunParams { pub model: String, pub temperature: f32, pub max_tokens: u32, pub max_iterations: usize, /// 可选:累计 Token 预算(input + output),超出后强制结束 pub token_budget: Option, } impl Default for AgentRunParams { fn default() -> Self { Self { model: "claude-sonnet-4-6".to_string(), temperature: 0.7, max_tokens: 2048, max_iterations: 5, token_budget: None, } } } /// Agent Orchestrator — 执行 ReAct 循环 pub struct AgentOrchestrator { provider: Arc, tool_registry: Arc, } /// Agent 运行结果 pub struct AgentRunResult { pub reply: String, pub total_input_tokens: u32, pub total_output_tokens: u32, pub iterations: usize, pub tool_calls: Vec, pub display_hints: Vec, } impl AgentOrchestrator { pub fn new(provider: Arc, tool_registry: Arc) -> Self { Self { provider, tool_registry, } } /// 执行 Agent ReAct 循环 pub async fn run( &self, system_prompt: &str, messages: &mut Vec, ctx: &ToolContext, params: &AgentRunParams, allowed_tools: Option<&std::collections::HashSet>, ) -> AiResult { let tools = match allowed_tools { Some(allowed) => self.tool_registry.tool_definitions_filtered(allowed), None => self.tool_registry.tool_definitions(), }; let mut iterations = 0; let mut total_input_tokens = 0u32; let mut total_output_tokens = 0u32; let mut tool_call_logs: Vec = Vec::new(); let mut display_hints: Vec = Vec::new(); loop { iterations += 1; let response = self .provider .generate_with_tools( messages.clone(), tools.clone(), system_prompt, ¶ms.model, params.temperature, params.max_tokens, ) .await?; if let Some(ref usage) = response.usage { total_input_tokens += usage.input; total_output_tokens += usage.output; } // 如果没有 tool_calls,Agent 给出最终回复 let tool_calls = match response.tool_calls { Some(tc) if !tc.is_empty() => tc, _ => { return Ok(AgentRunResult { reply: response.content.unwrap_or_default(), total_input_tokens, total_output_tokens, iterations, tool_calls: tool_call_logs, display_hints, }); } }; // 达到上限:强制结束 if iterations >= params.max_iterations { messages.push(ChatMessage { role: ChatMessageRole::User, content: "(系统提示:已收集足够信息,请直接总结回复用户,不要再调用工具)" .to_string(), tool_calls: None, tool_call_id: None, }); continue; } // Token 预算检查:超出后强制结束 if let Some(budget) = params.token_budget { let total = total_input_tokens + total_output_tokens; if total >= budget { tracing::warn!( total_tokens = total, budget = budget, iterations = iterations, "Token budget exhausted, forcing final reply" ); messages.push(ChatMessage { role: ChatMessageRole::User, content: "(系统提示:Token 预算已用尽,请立即基于已有信息总结回复用户,不要再调用工具)" .to_string(), tool_calls: None, tool_call_id: None, }); continue; } } // 将 assistant 的 tool_calls 加入消息历史 messages.push(ChatMessage { role: ChatMessageRole::Assistant, content: response.content.unwrap_or_default(), tool_calls: Some(tool_calls.clone()), tool_call_id: None, }); // 执行每个 Tool Call(受沙箱 allowed_tools 约束) for tc in &tool_calls { let start = std::time::Instant::now(); let (tool_result, success, hint) = match self.tool_registry.get(&tc.name) { Some(tool) => { if let Some(allowed) = allowed_tools { if !allowed.contains(tc.name.as_str()) { ( format!("Tool '{}' 在当前角色下不可用", tc.name), false, None, ) } else { let result = tool.execute(ctx, tc.arguments.clone()).await; (result.output, true, result.display_hint) } } else { let result = tool.execute(ctx, tc.arguments.clone()).await; (result.output, true, result.display_hint) } } None => (format!("未知 Tool: {}", tc.name), false, None), }; let duration = start.elapsed(); tool_call_logs.push(ToolCallLog { tool_name: tc.name.clone(), duration_ms: duration.as_millis() as u64, success, }); if let Some(h) = hint { display_hints.push(h); } messages.push(ChatMessage { role: ChatMessageRole::Tool, content: tool_result, tool_calls: None, tool_call_id: Some(tc.id.clone()), }); } } } }