feat(runtime): DeerFlow 模式中间件链 Phase 1-4 全部完成
借鉴 DeerFlow 架构,实现完整中间件链系统: Phase 1 - Agent 中间件链基础设施 - MiddlewareChain Clone 支持 - LoopRunner 双路径集成 (middleware/legacy) - Kernel create_middleware_chain() 工厂方法 Phase 2 - 技能按需注入 - SkillIndexMiddleware (priority 200) - SkillLoadTool 工具 - SkillDetail/SkillIndexEntry 结构体 - KernelSkillExecutor trait 扩展 Phase 3 - Guardrail 安全护栏 - GuardrailMiddleware (priority 400, fail_open) - ShellExecRule / FileWriteRule / WebFetchRule Phase 4 - 记忆闭环统一 - MemoryMiddleware (priority 150, 30s 防抖) - after_completion 双路径调用 中间件注册顺序: 100 Compaction | 150 Memory | 200 SkillIndex 400 Guardrail | 500 LoopGuard | 700 TokenCalibration 向后兼容:Option<MiddlewareChain> 默认 None 走旧路径
This commit is contained in:
61
crates/zclaw-runtime/src/middleware/compaction.rs
Normal file
61
crates/zclaw-runtime/src/middleware/compaction.rs
Normal file
@@ -0,0 +1,61 @@
|
||||
//! Compaction middleware — wraps the existing compaction module.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use zclaw_types::Result;
|
||||
use crate::middleware::{AgentMiddleware, MiddlewareContext, MiddlewareDecision};
|
||||
use crate::compaction::{self, CompactionConfig};
|
||||
use crate::growth::GrowthIntegration;
|
||||
use crate::driver::LlmDriver;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Middleware that compresses conversation history when it exceeds a token threshold.
|
||||
pub struct CompactionMiddleware {
|
||||
threshold: usize,
|
||||
config: CompactionConfig,
|
||||
/// Optional LLM driver for async compaction (LLM summarisation, memory flush).
|
||||
driver: Option<Arc<dyn LlmDriver>>,
|
||||
/// Optional growth integration for memory flushing during compaction.
|
||||
growth: Option<GrowthIntegration>,
|
||||
}
|
||||
|
||||
impl CompactionMiddleware {
|
||||
pub fn new(
|
||||
threshold: usize,
|
||||
config: CompactionConfig,
|
||||
driver: Option<Arc<dyn LlmDriver>>,
|
||||
growth: Option<GrowthIntegration>,
|
||||
) -> Self {
|
||||
Self { threshold, config, driver, growth }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl AgentMiddleware for CompactionMiddleware {
|
||||
fn name(&self) -> &str { "compaction" }
|
||||
fn priority(&self) -> i32 { 100 }
|
||||
|
||||
async fn before_completion(&self, ctx: &mut MiddlewareContext) -> Result<MiddlewareDecision> {
|
||||
if self.threshold == 0 {
|
||||
return Ok(MiddlewareDecision::Continue);
|
||||
}
|
||||
|
||||
let needs_async = self.config.use_llm || self.config.memory_flush_enabled;
|
||||
if needs_async {
|
||||
let outcome = compaction::maybe_compact_with_config(
|
||||
ctx.messages.clone(),
|
||||
self.threshold,
|
||||
&self.config,
|
||||
&ctx.agent_id,
|
||||
&ctx.session_id,
|
||||
self.driver.as_ref(),
|
||||
self.growth.as_ref(),
|
||||
)
|
||||
.await;
|
||||
ctx.messages = outcome.messages;
|
||||
} else {
|
||||
ctx.messages = compaction::maybe_compact(ctx.messages.clone(), self.threshold);
|
||||
}
|
||||
|
||||
Ok(MiddlewareDecision::Continue)
|
||||
}
|
||||
}
|
||||
223
crates/zclaw-runtime/src/middleware/guardrail.rs
Normal file
223
crates/zclaw-runtime/src/middleware/guardrail.rs
Normal file
@@ -0,0 +1,223 @@
|
||||
//! Guardrail middleware — configurable safety rules for tool call evaluation.
|
||||
//!
|
||||
//! This middleware inspects tool calls before execution and can block or
|
||||
//! modify them based on configurable rules. Inspired by DeerFlow's safety
|
||||
//! evaluation hooks.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
use zclaw_types::Result;
|
||||
use crate::middleware::{AgentMiddleware, MiddlewareContext, ToolCallDecision};
|
||||
|
||||
/// A single guardrail rule that can inspect and decide on tool calls.
|
||||
pub trait GuardrailRule: Send + Sync {
|
||||
/// Human-readable name for logging.
|
||||
fn name(&self) -> &str;
|
||||
|
||||
/// Evaluate a tool call.
|
||||
fn evaluate(&self, tool_name: &str, tool_input: &Value) -> GuardrailVerdict;
|
||||
}
|
||||
|
||||
/// Decision returned by a guardrail rule.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum GuardrailVerdict {
|
||||
/// Allow the tool call to proceed.
|
||||
Allow,
|
||||
/// Block the call and return *message* as an error to the LLM.
|
||||
Block(String),
|
||||
}
|
||||
|
||||
/// Middleware that evaluates tool calls against a set of configurable safety rules.
|
||||
///
|
||||
/// Rules are grouped by tool name. When a tool call is made, all rules for
|
||||
/// that tool are evaluated in order. If any rule returns `Block`, the call
|
||||
/// is blocked. This is a "deny-by-exception" model — calls are allowed unless
|
||||
/// a rule explicitly blocks them.
|
||||
pub struct GuardrailMiddleware {
|
||||
/// Rules keyed by tool name.
|
||||
rules: HashMap<String, Vec<Box<dyn GuardrailRule>>>,
|
||||
/// Default policy for tools with no specific rules: true = allow, false = block.
|
||||
fail_open: bool,
|
||||
}
|
||||
|
||||
impl GuardrailMiddleware {
|
||||
pub fn new(fail_open: bool) -> Self {
|
||||
Self {
|
||||
rules: HashMap::new(),
|
||||
fail_open,
|
||||
}
|
||||
}
|
||||
|
||||
/// Register a guardrail rule for a specific tool.
|
||||
pub fn add_rule(&mut self, tool_name: impl Into<String>, rule: Box<dyn GuardrailRule>) {
|
||||
self.rules.entry(tool_name.into()).or_default().push(rule);
|
||||
}
|
||||
|
||||
/// Register built-in safety rules (shell_exec, file_write, web_fetch).
|
||||
pub fn with_builtin_rules(mut self) -> Self {
|
||||
self.add_rule("shell_exec", Box::new(ShellExecRule));
|
||||
self.add_rule("file_write", Box::new(FileWriteRule));
|
||||
self.add_rule("web_fetch", Box::new(WebFetchRule));
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl AgentMiddleware for GuardrailMiddleware {
|
||||
fn name(&self) -> &str { "guardrail" }
|
||||
fn priority(&self) -> i32 { 400 }
|
||||
|
||||
async fn before_tool_call(
|
||||
&self,
|
||||
_ctx: &MiddlewareContext,
|
||||
tool_name: &str,
|
||||
tool_input: &Value,
|
||||
) -> Result<ToolCallDecision> {
|
||||
if let Some(rules) = self.rules.get(tool_name) {
|
||||
for rule in rules {
|
||||
match rule.evaluate(tool_name, tool_input) {
|
||||
GuardrailVerdict::Allow => {}
|
||||
GuardrailVerdict::Block(msg) => {
|
||||
tracing::warn!(
|
||||
"[GuardrailMiddleware] Rule '{}' blocked tool '{}': {}",
|
||||
rule.name(),
|
||||
tool_name,
|
||||
msg
|
||||
);
|
||||
return Ok(ToolCallDecision::Block(msg));
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if !self.fail_open {
|
||||
// fail-closed: unknown tools are blocked
|
||||
tracing::warn!(
|
||||
"[GuardrailMiddleware] No rules for tool '{}', fail-closed policy blocks it",
|
||||
tool_name
|
||||
);
|
||||
return Ok(ToolCallDecision::Block(format!(
|
||||
"工具 '{}' 未注册安全规则,fail-closed 策略阻止执行",
|
||||
tool_name
|
||||
)));
|
||||
}
|
||||
Ok(ToolCallDecision::Allow)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Built-in rules
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Rule that blocks dangerous shell commands.
|
||||
pub struct ShellExecRule;
|
||||
|
||||
impl GuardrailRule for ShellExecRule {
|
||||
fn name(&self) -> &str { "shell_exec_dangerous_commands" }
|
||||
|
||||
fn evaluate(&self, _tool_name: &str, tool_input: &Value) -> GuardrailVerdict {
|
||||
let cmd = tool_input["command"].as_str().unwrap_or("");
|
||||
let dangerous = [
|
||||
"rm -rf /",
|
||||
"rm -rf ~",
|
||||
"del /s /q C:\\",
|
||||
"format ",
|
||||
"mkfs.",
|
||||
"dd if=",
|
||||
":(){ :|:& };:", // fork bomb
|
||||
"> /dev/sda",
|
||||
"shutdown",
|
||||
"reboot",
|
||||
];
|
||||
let cmd_lower = cmd.to_lowercase();
|
||||
for pattern in &dangerous {
|
||||
if cmd_lower.contains(pattern) {
|
||||
return GuardrailVerdict::Block(format!(
|
||||
"危险命令被安全护栏拦截: 包含 '{}'",
|
||||
pattern
|
||||
));
|
||||
}
|
||||
}
|
||||
GuardrailVerdict::Allow
|
||||
}
|
||||
}
|
||||
|
||||
/// Rule that blocks writes to critical system directories.
|
||||
pub struct FileWriteRule;
|
||||
|
||||
impl GuardrailRule for FileWriteRule {
|
||||
fn name(&self) -> &str { "file_write_critical_dirs" }
|
||||
|
||||
fn evaluate(&self, _tool_name: &str, tool_input: &Value) -> GuardrailVerdict {
|
||||
let path = tool_input["path"].as_str().unwrap_or("");
|
||||
let critical_prefixes = [
|
||||
"/etc/",
|
||||
"/usr/",
|
||||
"/bin/",
|
||||
"/sbin/",
|
||||
"/boot/",
|
||||
"/System/",
|
||||
"/Library/",
|
||||
"C:\\Windows\\",
|
||||
"C:\\Program Files\\",
|
||||
"C:\\ProgramData\\",
|
||||
];
|
||||
let path_lower = path.to_lowercase();
|
||||
for prefix in &critical_prefixes {
|
||||
if path_lower.starts_with(&prefix.to_lowercase()) {
|
||||
return GuardrailVerdict::Block(format!(
|
||||
"写入系统关键目录被拦截: {}",
|
||||
path
|
||||
));
|
||||
}
|
||||
}
|
||||
GuardrailVerdict::Allow
|
||||
}
|
||||
}
|
||||
|
||||
/// Rule that blocks web requests to internal/private network addresses.
|
||||
pub struct WebFetchRule;
|
||||
|
||||
impl GuardrailRule for WebFetchRule {
|
||||
fn name(&self) -> &str { "web_fetch_private_network" }
|
||||
|
||||
fn evaluate(&self, _tool_name: &str, tool_input: &Value) -> GuardrailVerdict {
|
||||
let url = tool_input["url"].as_str().unwrap_or("");
|
||||
let blocked = [
|
||||
"localhost",
|
||||
"127.0.0.1",
|
||||
"0.0.0.0",
|
||||
"10.",
|
||||
"172.16.",
|
||||
"172.17.",
|
||||
"172.18.",
|
||||
"172.19.",
|
||||
"172.20.",
|
||||
"172.21.",
|
||||
"172.22.",
|
||||
"172.23.",
|
||||
"172.24.",
|
||||
"172.25.",
|
||||
"172.26.",
|
||||
"172.27.",
|
||||
"172.28.",
|
||||
"172.29.",
|
||||
"172.30.",
|
||||
"172.31.",
|
||||
"192.168.",
|
||||
"::1",
|
||||
"169.254.",
|
||||
"metadata.google",
|
||||
"metadata.azure",
|
||||
];
|
||||
let url_lower = url.to_lowercase();
|
||||
for prefix in &blocked {
|
||||
if url_lower.contains(prefix) {
|
||||
return GuardrailVerdict::Block(format!(
|
||||
"请求内网/私有地址被拦截: {}",
|
||||
url
|
||||
));
|
||||
}
|
||||
}
|
||||
GuardrailVerdict::Allow
|
||||
}
|
||||
}
|
||||
57
crates/zclaw-runtime/src/middleware/loop_guard.rs
Normal file
57
crates/zclaw-runtime/src/middleware/loop_guard.rs
Normal file
@@ -0,0 +1,57 @@
|
||||
//! Loop guard middleware — extracts loop detection into a middleware hook.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use serde_json::Value;
|
||||
use zclaw_types::Result;
|
||||
use crate::middleware::{AgentMiddleware, MiddlewareContext, ToolCallDecision};
|
||||
use crate::loop_guard::{LoopGuard, LoopGuardConfig, LoopGuardResult};
|
||||
use std::sync::Mutex;
|
||||
|
||||
/// Middleware that detects and blocks repetitive tool-call loops.
|
||||
pub struct LoopGuardMiddleware {
|
||||
guard: Mutex<LoopGuard>,
|
||||
}
|
||||
|
||||
impl LoopGuardMiddleware {
|
||||
pub fn new(config: LoopGuardConfig) -> Self {
|
||||
Self {
|
||||
guard: Mutex::new(LoopGuard::new(config)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_defaults() -> Self {
|
||||
Self {
|
||||
guard: Mutex::new(LoopGuard::default()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl AgentMiddleware for LoopGuardMiddleware {
|
||||
fn name(&self) -> &str { "loop_guard" }
|
||||
fn priority(&self) -> i32 { 500 }
|
||||
|
||||
async fn before_tool_call(
|
||||
&self,
|
||||
_ctx: &MiddlewareContext,
|
||||
tool_name: &str,
|
||||
tool_input: &Value,
|
||||
) -> Result<ToolCallDecision> {
|
||||
let result = self.guard.lock().unwrap().check(tool_name, tool_input);
|
||||
match result {
|
||||
LoopGuardResult::CircuitBreaker => {
|
||||
tracing::warn!("[LoopGuardMiddleware] Circuit breaker triggered by tool '{}'", tool_name);
|
||||
Ok(ToolCallDecision::Block("检测到工具调用循环,已自动终止".to_string()))
|
||||
}
|
||||
LoopGuardResult::Blocked => {
|
||||
tracing::warn!("[LoopGuardMiddleware] Tool '{}' blocked", tool_name);
|
||||
Ok(ToolCallDecision::Block("工具调用被循环防护拦截".to_string()))
|
||||
}
|
||||
LoopGuardResult::Warn => {
|
||||
tracing::warn!("[LoopGuardMiddleware] Tool '{}' triggered warning", tool_name);
|
||||
Ok(ToolCallDecision::Allow)
|
||||
}
|
||||
LoopGuardResult::Allowed => Ok(ToolCallDecision::Allow),
|
||||
}
|
||||
}
|
||||
}
|
||||
115
crates/zclaw-runtime/src/middleware/memory.rs
Normal file
115
crates/zclaw-runtime/src/middleware/memory.rs
Normal file
@@ -0,0 +1,115 @@
|
||||
//! Memory middleware — unified pre/post hooks for memory retrieval and extraction.
|
||||
//!
|
||||
//! This middleware unifies the memory lifecycle:
|
||||
//! - `before_completion`: retrieves relevant memories and injects them into the system prompt
|
||||
//! - `after_completion`: extracts learnings from the conversation and stores them
|
||||
//!
|
||||
//! It replaces both the inline `GrowthIntegration` calls in `AgentLoop` and the
|
||||
//! `intelligence_hooks` calls in the Tauri desktop layer.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use zclaw_types::Result;
|
||||
use crate::growth::GrowthIntegration;
|
||||
use crate::middleware::{AgentMiddleware, MiddlewareContext, MiddlewareDecision};
|
||||
|
||||
/// Middleware that handles memory retrieval (pre-completion) and extraction (post-completion).
|
||||
///
|
||||
/// Wraps `GrowthIntegration` and delegates:
|
||||
/// - `before_completion` → `enhance_prompt()` for memory injection
|
||||
/// - `after_completion` → `process_conversation()` for memory extraction
|
||||
pub struct MemoryMiddleware {
|
||||
growth: GrowthIntegration,
|
||||
/// Minimum seconds between extractions for the same agent (debounce).
|
||||
debounce_secs: u64,
|
||||
/// Timestamp of last extraction per agent (for debouncing).
|
||||
last_extraction: std::sync::Mutex<std::collections::HashMap<String, std::time::Instant>>,
|
||||
}
|
||||
|
||||
impl MemoryMiddleware {
|
||||
pub fn new(growth: GrowthIntegration) -> Self {
|
||||
Self {
|
||||
growth,
|
||||
debounce_secs: 30,
|
||||
last_extraction: std::sync::Mutex::new(std::collections::HashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the debounce interval in seconds.
|
||||
pub fn with_debounce_secs(mut self, secs: u64) -> Self {
|
||||
self.debounce_secs = secs;
|
||||
self
|
||||
}
|
||||
|
||||
/// Check if enough time has passed since the last extraction for this agent.
|
||||
fn should_extract(&self, agent_id: &str) -> bool {
|
||||
let now = std::time::Instant::now();
|
||||
let mut map = self.last_extraction.lock().unwrap();
|
||||
if let Some(last) = map.get(agent_id) {
|
||||
if now.duration_since(*last).as_secs() < self.debounce_secs {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
map.insert(agent_id.to_string(), now);
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl AgentMiddleware for MemoryMiddleware {
|
||||
fn name(&self) -> &str { "memory" }
|
||||
fn priority(&self) -> i32 { 150 }
|
||||
|
||||
async fn before_completion(&self, ctx: &mut MiddlewareContext) -> Result<MiddlewareDecision> {
|
||||
match self.growth.enhance_prompt(
|
||||
&ctx.agent_id,
|
||||
&ctx.system_prompt,
|
||||
&ctx.user_input,
|
||||
).await {
|
||||
Ok(enhanced) => {
|
||||
ctx.system_prompt = enhanced;
|
||||
Ok(MiddlewareDecision::Continue)
|
||||
}
|
||||
Err(e) => {
|
||||
// Non-fatal: memory retrieval failure should not block the loop
|
||||
tracing::warn!("[MemoryMiddleware] Prompt enhancement failed: {}", e);
|
||||
Ok(MiddlewareDecision::Continue)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn after_completion(&self, ctx: &MiddlewareContext) -> Result<()> {
|
||||
// Debounce: skip extraction if called too recently for this agent
|
||||
let agent_key = ctx.agent_id.to_string();
|
||||
if !self.should_extract(&agent_key) {
|
||||
tracing::debug!(
|
||||
"[MemoryMiddleware] Skipping extraction for agent {} (debounced)",
|
||||
agent_key
|
||||
);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if ctx.messages.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
match self.growth.process_conversation(
|
||||
&ctx.agent_id,
|
||||
&ctx.messages,
|
||||
ctx.session_id.clone(),
|
||||
).await {
|
||||
Ok(count) => {
|
||||
tracing::info!(
|
||||
"[MemoryMiddleware] Extracted {} memories for agent {}",
|
||||
count,
|
||||
agent_key
|
||||
);
|
||||
}
|
||||
Err(e) => {
|
||||
// Non-fatal: extraction failure should not affect the response
|
||||
tracing::warn!("[MemoryMiddleware] Memory extraction failed: {}", e);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
62
crates/zclaw-runtime/src/middleware/skill_index.rs
Normal file
62
crates/zclaw-runtime/src/middleware/skill_index.rs
Normal file
@@ -0,0 +1,62 @@
|
||||
//! Skill index middleware — injects a lightweight skill index into the system prompt.
|
||||
//!
|
||||
//! Instead of embedding full skill descriptions (which can consume ~2000 tokens for 70+ skills),
|
||||
//! this middleware injects only skill IDs and one-line triggers (~600 tokens). The LLM can then
|
||||
//! call the `skill_load` tool on demand to retrieve full skill details when needed.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use zclaw_types::Result;
|
||||
use crate::middleware::{AgentMiddleware, MiddlewareContext, MiddlewareDecision};
|
||||
use crate::tool::{SkillIndexEntry, SkillExecutor};
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Middleware that injects a lightweight skill index into the system prompt.
|
||||
///
|
||||
/// The index format is compact:
|
||||
/// ```text
|
||||
/// ## Skills (index — use skill_load for details)
|
||||
/// - finance-tracker: 财务分析、财报解读 [数据分析]
|
||||
/// - senior-developer: 代码开发、架构设计 [开发工程]
|
||||
/// ```
|
||||
pub struct SkillIndexMiddleware {
|
||||
/// Pre-built skill index entries, constructed at chain creation time.
|
||||
entries: Vec<SkillIndexEntry>,
|
||||
}
|
||||
|
||||
impl SkillIndexMiddleware {
|
||||
pub fn new(entries: Vec<SkillIndexEntry>) -> Self {
|
||||
Self { entries }
|
||||
}
|
||||
|
||||
/// Build index entries from a skill executor that supports listing.
|
||||
pub fn from_executor(executor: &Arc<dyn SkillExecutor>) -> Self {
|
||||
Self {
|
||||
entries: executor.list_skill_index(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl AgentMiddleware for SkillIndexMiddleware {
|
||||
fn name(&self) -> &str { "skill_index" }
|
||||
fn priority(&self) -> i32 { 200 }
|
||||
|
||||
async fn before_completion(&self, ctx: &mut MiddlewareContext) -> Result<MiddlewareDecision> {
|
||||
if self.entries.is_empty() {
|
||||
return Ok(MiddlewareDecision::Continue);
|
||||
}
|
||||
|
||||
let mut index = String::from("\n\n## Skills (index — call skill_load for details)\n\n");
|
||||
for entry in &self.entries {
|
||||
let triggers = if entry.triggers.is_empty() {
|
||||
String::new()
|
||||
} else {
|
||||
format!(" — {}", entry.triggers.join(", "))
|
||||
};
|
||||
index.push_str(&format!("- **{}**: {}{}\n", entry.id, entry.description, triggers));
|
||||
}
|
||||
|
||||
ctx.system_prompt.push_str(&index);
|
||||
Ok(MiddlewareDecision::Continue)
|
||||
}
|
||||
}
|
||||
52
crates/zclaw-runtime/src/middleware/token_calibration.rs
Normal file
52
crates/zclaw-runtime/src/middleware/token_calibration.rs
Normal file
@@ -0,0 +1,52 @@
|
||||
//! Token calibration middleware — calibrates token estimation after first LLM response.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use zclaw_types::Result;
|
||||
use crate::middleware::{AgentMiddleware, MiddlewareContext, MiddlewareDecision};
|
||||
use crate::compaction;
|
||||
|
||||
/// Middleware that calibrates the global token estimation factor based on
|
||||
/// actual API-returned token counts from the first LLM response.
|
||||
pub struct TokenCalibrationMiddleware {
|
||||
/// Whether calibration has already been applied in this session.
|
||||
calibrated: std::sync::atomic::AtomicBool,
|
||||
}
|
||||
|
||||
impl TokenCalibrationMiddleware {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
calibrated: std::sync::atomic::AtomicBool::new(false),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for TokenCalibrationMiddleware {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl AgentMiddleware for TokenCalibrationMiddleware {
|
||||
fn name(&self) -> &str { "token_calibration" }
|
||||
fn priority(&self) -> i32 { 700 }
|
||||
|
||||
async fn before_completion(&self, _ctx: &mut MiddlewareContext) -> Result<MiddlewareDecision> {
|
||||
// Calibration happens in after_completion when we have actual token counts.
|
||||
// Before-completion is a no-op.
|
||||
Ok(MiddlewareDecision::Continue)
|
||||
}
|
||||
|
||||
async fn after_completion(&self, ctx: &MiddlewareContext) -> Result<()> {
|
||||
if ctx.input_tokens > 0 && !self.calibrated.load(std::sync::atomic::Ordering::Relaxed) {
|
||||
let estimated = compaction::estimate_messages_tokens(&ctx.messages);
|
||||
compaction::update_calibration(estimated, ctx.input_tokens);
|
||||
self.calibrated.store(true, std::sync::atomic::Ordering::Relaxed);
|
||||
tracing::debug!(
|
||||
"[TokenCalibrationMiddleware] Calibrated: estimated={}, actual={}",
|
||||
estimated, ctx.input_tokens
|
||||
);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user