feat: add internal ZCLAW kernel crates to git tracking
This commit is contained in:
103
crates/zclaw-runtime/src/loop_guard.rs
Normal file
103
crates/zclaw-runtime/src/loop_guard.rs
Normal file
@@ -0,0 +1,103 @@
|
||||
//! Loop guard to prevent infinite tool loops
|
||||
|
||||
use sha2::{Sha256, Digest};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Configuration for loop guard
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct LoopGuardConfig {
|
||||
/// Warn after this many repetitions
|
||||
pub warn_threshold: u32,
|
||||
/// Block tool call after this many repetitions
|
||||
pub block_threshold: u32,
|
||||
/// Terminate loop after this many total repetitions
|
||||
pub circuit_breaker: u32,
|
||||
}
|
||||
|
||||
impl Default for LoopGuardConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
warn_threshold: 3,
|
||||
block_threshold: 5,
|
||||
circuit_breaker: 30,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Loop guard state
|
||||
#[derive(Debug)]
|
||||
pub struct LoopGuard {
|
||||
config: LoopGuardConfig,
|
||||
/// Hash of (tool_name, params) -> count
|
||||
call_counts: HashMap<String, u32>,
|
||||
/// Total calls in this session
|
||||
total_calls: u32,
|
||||
}
|
||||
|
||||
impl LoopGuard {
|
||||
pub fn new(config: LoopGuardConfig) -> Self {
|
||||
Self {
|
||||
config,
|
||||
call_counts: HashMap::new(),
|
||||
total_calls: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if a tool call should be allowed
|
||||
pub fn check(&mut self, tool_name: &str, params: &serde_json::Value) -> LoopGuardResult {
|
||||
let hash = self.hash_call(tool_name, params);
|
||||
let count = self.call_counts.entry(hash).or_insert(0);
|
||||
|
||||
self.total_calls += 1;
|
||||
*count += 1;
|
||||
|
||||
// Check circuit breaker first
|
||||
if self.total_calls > self.config.circuit_breaker {
|
||||
return LoopGuardResult::CircuitBreaker;
|
||||
}
|
||||
|
||||
// Check block threshold
|
||||
if *count > self.config.block_threshold {
|
||||
return LoopGuardResult::Blocked;
|
||||
}
|
||||
|
||||
// Check warn threshold
|
||||
if *count > self.config.warn_threshold {
|
||||
return LoopGuardResult::Warn;
|
||||
}
|
||||
|
||||
LoopGuardResult::Allowed
|
||||
}
|
||||
|
||||
/// Reset the guard state
|
||||
pub fn reset(&mut self) {
|
||||
self.call_counts.clear();
|
||||
self.total_calls = 0;
|
||||
}
|
||||
|
||||
fn hash_call(&self, tool_name: &str, params: &serde_json::Value) -> String {
|
||||
let mut hasher = Sha256::new();
|
||||
hasher.update(tool_name.as_bytes());
|
||||
hasher.update(params.to_string().as_bytes());
|
||||
format!("{:x}", hasher.finalize())
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for LoopGuard {
|
||||
fn default() -> Self {
|
||||
Self::new(LoopGuardConfig::default())
|
||||
}
|
||||
}
|
||||
|
||||
/// Result of loop guard check
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum LoopGuardResult {
|
||||
/// Call is allowed
|
||||
Allowed,
|
||||
/// Call is allowed but should warn
|
||||
Warn,
|
||||
/// Call should be blocked
|
||||
Blocked,
|
||||
/// Loop should be terminated
|
||||
CircuitBreaker,
|
||||
}
|
||||
Reference in New Issue
Block a user