//! Loop guard to prevent infinite tool loops use sha2::{Sha256, Digest}; use std::collections::HashMap; /// Configuration for loop guard #[derive(Debug, Clone)] pub struct LoopGuardConfig { /// Warn after this many repetitions pub warn_threshold: u32, /// Block tool call after this many repetitions pub block_threshold: u32, /// Terminate loop after this many total repetitions pub circuit_breaker: u32, } impl Default for LoopGuardConfig { fn default() -> Self { Self { warn_threshold: 3, block_threshold: 5, circuit_breaker: 30, } } } /// Loop guard state #[derive(Debug, Clone)] pub struct LoopGuard { config: LoopGuardConfig, /// Hash of (tool_name, params) -> count call_counts: HashMap, /// Total calls in this session total_calls: u32, } impl LoopGuard { pub fn new(config: LoopGuardConfig) -> Self { Self { config, call_counts: HashMap::new(), total_calls: 0, } } /// Check if a tool call should be allowed pub fn check(&mut self, tool_name: &str, params: &serde_json::Value) -> LoopGuardResult { let hash = self.hash_call(tool_name, params); let count = self.call_counts.entry(hash).or_insert(0); self.total_calls += 1; *count += 1; // Check circuit breaker first if self.total_calls > self.config.circuit_breaker { return LoopGuardResult::CircuitBreaker; } // Check block threshold if *count > self.config.block_threshold { return LoopGuardResult::Blocked; } // Check warn threshold if *count > self.config.warn_threshold { return LoopGuardResult::Warn; } LoopGuardResult::Allowed } /// Reset the guard state pub fn reset(&mut self) { self.call_counts.clear(); self.total_calls = 0; } fn hash_call(&self, tool_name: &str, params: &serde_json::Value) -> String { let mut hasher = Sha256::new(); hasher.update(tool_name.as_bytes()); hasher.update(params.to_string().as_bytes()); format!("{:x}", hasher.finalize()) } } impl Default for LoopGuard { fn default() -> Self { Self::new(LoopGuardConfig::default()) } } /// Result of loop guard check #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum LoopGuardResult { /// Call is allowed Allowed, /// Call is allowed but should warn Warn, /// Call should be blocked Blocked, /// Loop should be terminated CircuitBreaker, }