diff --git a/crates/zclaw-runtime/src/middleware.rs b/crates/zclaw-runtime/src/middleware.rs
index 196c688..2d950af 100644
--- a/crates/zclaw-runtime/src/middleware.rs
+++ b/crates/zclaw-runtime/src/middleware.rs
@@ -12,6 +12,13 @@
 //! | 200-399 | Capability | SkillIndex, Guardrail |
 //! | 400-599 | Safety | LoopGuard, Guardrail |
 //! | 600-799 | Telemetry | TokenCalibration, Tracking |
+//!
+//! # Wave parallelization
+//!
+//! `before_completion` middlewares that only append to `system_prompt` (and never
+//! touch `messages`) can declare `parallel_safe() == true`. The chain runs consecutive
+//! parallel-safe middlewares concurrently, merging their prompt contributions. This
+//! reduces sequential latency for the context-injection phase.
 
 use std::sync::Arc;
 use async_trait::async_trait;
@@ -50,6 +57,7 @@ pub enum ToolCallDecision {
 // ---------------------------------------------------------------------------
 
 /// Carries the mutable state that middleware may inspect or modify.
+#[derive(Clone)]
 pub struct MiddlewareContext {
     /// The agent that owns this loop.
     pub agent_id: AgentId,
@@ -101,6 +109,15 @@ pub trait AgentMiddleware: Send + Sync {
         500
     }
 
+    /// Whether `before_completion` is safe to run concurrently with other
+    /// parallel-safe middlewares. Only return `true` if the middleware:
+    /// - Only appends to `ctx.system_prompt` (no rewrite, never `ctx.messages`)
+    /// - Does not depend on prompt modifications from other middlewares
+    /// - Does not return `MiddlewareDecision::Stop`
+    fn parallel_safe(&self) -> bool {
+        false
+    }
+
     /// Hook executed **before** the LLM completion request is sent.
     ///
     /// Use this to inject context (memory, skill index, etc.) or to
@@ -163,15 +180,97 @@ impl MiddlewareChain {
         self.middlewares.insert(pos, mw);
     }
 
-    /// Run all `before_completion` hooks in order.
+    /// Run all `before_completion` hooks with wave-based parallelization.
+    ///
+    /// Consecutive `parallel_safe` middlewares run concurrently — each gets
+    /// its own cloned context and appends to `system_prompt` independently.
+    /// Their contributions are merged in chain order after the whole wave has
+    /// finished. A parallel-safe middleware with no parallel-safe neighbour,
+    /// and every non-parallel-safe middleware, runs sequentially as before.
+    ///
+    /// Only text appended to `system_prompt` is carried back from a wave; any
+    /// other change a middleware makes to its cloned context is discarded.
+    ///
+    /// Returns `Stop` from the first middleware (in chain order) that asks for
+    /// it, `Continue` otherwise.
+    ///
+    /// # Errors
+    ///
+    /// Propagates the first middleware error in chain order, or an `Internal`
+    /// error when a spawned wave task fails to complete.
     pub async fn run_before_completion(&self, ctx: &mut MiddlewareContext) -> Result {
-        for mw in &self.middlewares {
-            match mw.before_completion(ctx).await? {
-                MiddlewareDecision::Continue => {}
-                MiddlewareDecision::Stop(reason) => {
-                    tracing::info!("[MiddlewareChain] '{}' requested stop: {}", mw.name(), reason);
-                    return Ok(MiddlewareDecision::Stop(reason));
+        let mut idx = 0;
+        while idx < self.middlewares.len() {
+            // Length of the run of consecutive parallel-safe middlewares at `idx`.
+            let wave_len = self.middlewares[idx..]
+                .iter()
+                .take_while(|mw| mw.parallel_safe())
+                .count();
+
+            if wave_len >= 2 {
+                let wave = &self.middlewares[idx..idx + wave_len];
+                let base_prompt_len = ctx.system_prompt.len();
+
+                // Spawn concurrent tasks — each owns its cloned context + Arc ref to middleware
+                let mut join_handles = Vec::with_capacity(wave.len());
+                for mw in wave {
+                    let mut ctx_clone = ctx.clone();
+                    let mw_arc = Arc::clone(mw);
+                    join_handles.push(tokio::spawn(async move {
+                        let result = mw_arc.before_completion(&mut ctx_clone).await;
+                        (result, ctx_clone.system_prompt)
+                    }));
                 }
+
+                // Join the whole wave before looking at any result, so an early
+                // return below never leaves a spawned task running detached.
+                let mut outputs = Vec::with_capacity(join_handles.len());
+                for handle in join_handles {
+                    outputs.push(handle.await);
+                }
+
+                // Merge in chain order so the resulting prompt is deterministic.
+                for (mw, output) in wave.iter().zip(outputs) {
+                    let (result, modified_prompt) = output
+                        .map_err(|e| zclaw_types::ZclawError::Internal(format!("Parallel middleware panicked: {}", e)))?;
+                    match result? {
+                        MiddlewareDecision::Continue => {}
+                        MiddlewareDecision::Stop(reason) => {
+                            tracing::info!(
+                                "[MiddlewareChain] '{}' requested stop: {}",
+                                mw.name(),
+                                reason
+                            );
+                            return Ok(MiddlewareDecision::Stop(reason));
+                        }
+                    }
+                    // The first `base_prompt_len` bytes of `ctx.system_prompt` are
+                    // still the pre-wave prompt: merging only ever appends. Matching
+                    // that prefix (instead of slicing at a byte offset) cannot panic
+                    // on a UTF-8 boundary and rejects non-append edits.
+                    let contribution =
+                        modified_prompt.strip_prefix(&ctx.system_prompt[..base_prompt_len]);
+                    match contribution {
+                        Some(text) => ctx.system_prompt.push_str(text),
+                        None => tracing::warn!(
+                            "[MiddlewareChain] '{}' rewrote system_prompt instead of appending; contribution dropped",
+                            mw.name()
+                        ),
+                    }
+                }
+
+                idx += wave_len;
+            } else {
+                // Run single middleware sequentially
+                let mw = &self.middlewares[idx];
+                match mw.before_completion(ctx).await? {
+                    MiddlewareDecision::Continue => {}
+                    MiddlewareDecision::Stop(reason) => {
+                        tracing::info!("[MiddlewareChain] '{}' requested stop: {}", mw.name(), reason);
+                        return Ok(MiddlewareDecision::Stop(reason));
+                    }
+                }
+                idx += 1;
             }
         }
         Ok(MiddlewareDecision::Continue)
diff --git a/crates/zclaw-runtime/src/middleware/butler_router.rs b/crates/zclaw-runtime/src/middleware/butler_router.rs
index 3d7b4a7..cfef3b9 100644
--- a/crates/zclaw-runtime/src/middleware/butler_router.rs
+++ b/crates/zclaw-runtime/src/middleware/butler_router.rs
@@ -290,6 +290,8 @@ impl AgentMiddleware for ButlerRouterMiddleware {
         80
     }
 
+    fn parallel_safe(&self) -> bool { true }
+
     async fn before_completion(&self, ctx: &mut MiddlewareContext) -> Result {
         // Only route on the first user message in a turn (not tool results)
         let user_input = &ctx.user_input;
diff --git a/crates/zclaw-runtime/src/middleware/evolution.rs b/crates/zclaw-runtime/src/middleware/evolution.rs
index 3b12f9c..3517fd2 100644
--- a/crates/zclaw-runtime/src/middleware/evolution.rs
+++ b/crates/zclaw-runtime/src/middleware/evolution.rs
@@ -88,6 +88,8 @@ impl AgentMiddleware for EvolutionMiddleware {
         78 // 在 ButlerRouter(80) 之前
     }
 
+    fn parallel_safe(&self) -> bool { true }
+
     async fn before_completion(
         &self,
         ctx: &mut MiddlewareContext,
diff --git a/crates/zclaw-runtime/src/middleware/memory.rs b/crates/zclaw-runtime/src/middleware/memory.rs
index 5417a58..627a234 100644
--- a/crates/zclaw-runtime/src/middleware/memory.rs
+++ b/crates/zclaw-runtime/src/middleware/memory.rs
@@ -111,6 +111,7 @@ impl MemoryMiddleware {
 impl AgentMiddleware for MemoryMiddleware {
     fn name(&self) -> &str { "memory" }
     fn priority(&self) -> i32 { 150 }
+    fn parallel_safe(&self) -> bool { true }
 
     async fn before_completion(&self, ctx: &mut MiddlewareContext) -> Result {
         tracing::debug!(
diff --git a/crates/zclaw-runtime/src/middleware/skill_index.rs b/crates/zclaw-runtime/src/middleware/skill_index.rs
index 749c7b7..a43f827 100644
--- a/crates/zclaw-runtime/src/middleware/skill_index.rs
+++ b/crates/zclaw-runtime/src/middleware/skill_index.rs
@@ -40,6 +40,7 @@ impl SkillIndexMiddleware {
 impl AgentMiddleware for SkillIndexMiddleware {
     fn name(&self) -> &str { "skill_index" }
     fn priority(&self) -> i32 { 200 }
+    fn parallel_safe(&self) -> bool { true }
 
     async fn before_completion(&self, ctx: &mut MiddlewareContext) -> Result {
         if self.entries.is_empty() {
diff --git a/crates/zclaw-runtime/src/middleware/title.rs b/crates/zclaw-runtime/src/middleware/title.rs
index c7dc371..627ba38 100644
--- a/crates/zclaw-runtime/src/middleware/title.rs
+++ b/crates/zclaw-runtime/src/middleware/title.rs
@@ -41,6 +41,7 @@ impl Default for TitleMiddleware {
 impl AgentMiddleware for TitleMiddleware {
     fn name(&self) -> &str { "title" }
     fn priority(&self) -> i32 { 180 }
+    fn parallel_safe(&self) -> bool { true }
 
     // All hooks default to Continue — placeholder until LLM driver is wired in.
     async fn before_completion(&self, _ctx: &mut crate::middleware::MiddlewareContext) -> zclaw_types::Result {