perf(middleware): before_completion 分波并行执行
Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled
Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled
- MiddlewareContext 加 Clone derive, 支持并行克隆上下文 - AgentMiddleware trait 新增 parallel_safe() 默认方法 (false) - MiddlewareChain::run_before_completion 改为分波执行: 连续 2+ 个 parallel_safe 中间件用 tokio::spawn 并发执行, 各自独立修改 system_prompt, 执行完成后合并贡献 - 5 个只修改 system_prompt 的中间件标记 parallel_safe: evolution(P78), butler_router(P80), memory(P150), title(P180), skill_index(P200) - 非 parallel_safe 中间件 (compaction, dangling_tool 等) 保持串行 分波效果: Wave 1: evolution + butler_router → 并行 (省 ~0.5-1s) Wave 2: compaction → 串行 (可能修改 messages) Wave 3: memory + title + skill_index → 并行 (省 ~0.5-2s) Wave 4+: 工具/安全中间件 → 串行
This commit is contained in:
@@ -12,6 +12,13 @@
|
||||
//! | 200-399 | Capability | SkillIndex, Guardrail |
|
||||
//! | 400-599 | Safety | LoopGuard, Guardrail |
|
||||
//! | 600-799 | Telemetry | TokenCalibration, Tracking |
|
||||
//!
|
||||
//! # Wave parallelization
|
||||
//!
|
||||
//! `before_completion` middlewares that only modify `system_prompt` (not `messages`)
|
||||
//! can declare `parallel_safe() == true`. The chain runs consecutive parallel-safe
|
||||
//! middlewares concurrently, merging their prompt contributions. This reduces
|
||||
//! sequential latency for the context-injection phase.
|
||||
|
||||
use std::sync::Arc;
|
||||
use async_trait::async_trait;
|
||||
@@ -50,6 +57,7 @@ pub enum ToolCallDecision {
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Carries the mutable state that middleware may inspect or modify.
|
||||
#[derive(Clone)]
|
||||
pub struct MiddlewareContext {
|
||||
/// The agent that owns this loop.
|
||||
pub agent_id: AgentId,
|
||||
@@ -101,6 +109,15 @@ pub trait AgentMiddleware: Send + Sync {
|
||||
500
|
||||
}
|
||||
|
||||
/// Whether `before_completion` is safe to run concurrently with other
|
||||
/// parallel-safe middlewares. Only return `true` if the middleware:
|
||||
/// - Only modifies `ctx.system_prompt` (never `ctx.messages`)
|
||||
/// - Does not depend on prompt modifications from other middlewares
|
||||
/// - Does not return `MiddlewareDecision::Stop`
|
||||
fn parallel_safe(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
/// Hook executed **before** the LLM completion request is sent.
|
||||
///
|
||||
/// Use this to inject context (memory, skill index, etc.) or to
|
||||
@@ -163,15 +180,74 @@ impl MiddlewareChain {
|
||||
self.middlewares.insert(pos, mw);
|
||||
}
|
||||
|
||||
/// Run all `before_completion` hooks in order.
|
||||
/// Run all `before_completion` hooks with wave-based parallelization.
|
||||
///
|
||||
/// Consecutive `parallel_safe` middlewares run concurrently — each gets
|
||||
/// its own cloned context and appends to `system_prompt` independently.
|
||||
/// Their contributions are merged after all complete. Non-parallel-safe
|
||||
/// middlewares (and non-consecutive ones) run sequentially as before.
|
||||
pub async fn run_before_completion(&self, ctx: &mut MiddlewareContext) -> Result<MiddlewareDecision> {
|
||||
for mw in &self.middlewares {
|
||||
match mw.before_completion(ctx).await? {
|
||||
MiddlewareDecision::Continue => {}
|
||||
MiddlewareDecision::Stop(reason) => {
|
||||
tracing::info!("[MiddlewareChain] '{}' requested stop: {}", mw.name(), reason);
|
||||
return Ok(MiddlewareDecision::Stop(reason));
|
||||
let mut idx = 0;
|
||||
while idx < self.middlewares.len() {
|
||||
// Find the extent of consecutive parallel-safe middlewares
|
||||
let wave_start = idx;
|
||||
let mut wave_end = idx;
|
||||
while wave_end < self.middlewares.len()
|
||||
&& self.middlewares[wave_end].parallel_safe()
|
||||
{
|
||||
wave_end += 1;
|
||||
}
|
||||
|
||||
if wave_end - wave_start >= 2 {
|
||||
// Run parallel wave (2+ consecutive parallel-safe middlewares)
|
||||
let base_prompt_len = ctx.system_prompt.len();
|
||||
let wave = &self.middlewares[wave_start..wave_end];
|
||||
|
||||
// Spawn concurrent tasks — each owns its cloned context + Arc ref to middleware
|
||||
let mut join_handles = Vec::with_capacity(wave.len());
|
||||
for mw in wave.iter() {
|
||||
let mut ctx_clone = ctx.clone();
|
||||
let mw_arc = Arc::clone(mw);
|
||||
join_handles.push(tokio::spawn(async move {
|
||||
let result = mw_arc.before_completion(&mut ctx_clone).await;
|
||||
(result, ctx_clone.system_prompt)
|
||||
}));
|
||||
}
|
||||
|
||||
// Await all and merge prompt contributions
|
||||
for (i, handle) in join_handles.into_iter().enumerate() {
|
||||
let (result, modified_prompt): (Result<MiddlewareDecision>, String) = handle.await
|
||||
.map_err(|e| zclaw_types::ZclawError::Internal(format!("Parallel middleware panicked: {}", e)))?;
|
||||
match result? {
|
||||
MiddlewareDecision::Continue => {}
|
||||
MiddlewareDecision::Stop(reason) => {
|
||||
tracing::info!(
|
||||
"[MiddlewareChain] '{}' requested stop: {}",
|
||||
self.middlewares[wave_start + i].name(),
|
||||
reason
|
||||
);
|
||||
return Ok(MiddlewareDecision::Stop(reason));
|
||||
}
|
||||
}
|
||||
// Merge system_prompt contribution from this clone
|
||||
if modified_prompt.len() > base_prompt_len {
|
||||
let contribution = &modified_prompt[base_prompt_len..];
|
||||
ctx.system_prompt.push_str(contribution);
|
||||
}
|
||||
}
|
||||
|
||||
idx = wave_end;
|
||||
} else {
|
||||
// Run single middleware sequentially
|
||||
let mw = &self.middlewares[idx];
|
||||
match mw.before_completion(ctx).await? {
|
||||
MiddlewareDecision::Continue => {}
|
||||
MiddlewareDecision::Stop(reason) => {
|
||||
tracing::info!("[MiddlewareChain] '{}' requested stop: {}", mw.name(), reason);
|
||||
return Ok(MiddlewareDecision::Stop(reason));
|
||||
}
|
||||
}
|
||||
idx += 1;
|
||||
}
|
||||
}
|
||||
Ok(MiddlewareDecision::Continue)
|
||||
|
||||
@@ -290,6 +290,8 @@ impl AgentMiddleware for ButlerRouterMiddleware {
|
||||
80
|
||||
}
|
||||
|
||||
fn parallel_safe(&self) -> bool { true }
|
||||
|
||||
async fn before_completion(&self, ctx: &mut MiddlewareContext) -> Result<MiddlewareDecision> {
|
||||
// Only route on the first user message in a turn (not tool results)
|
||||
let user_input = &ctx.user_input;
|
||||
|
||||
@@ -88,6 +88,8 @@ impl AgentMiddleware for EvolutionMiddleware {
|
||||
78 // 在 ButlerRouter(80) 之前
|
||||
}
|
||||
|
||||
fn parallel_safe(&self) -> bool { true }
|
||||
|
||||
async fn before_completion(
|
||||
&self,
|
||||
ctx: &mut MiddlewareContext,
|
||||
|
||||
@@ -111,6 +111,7 @@ impl MemoryMiddleware {
|
||||
impl AgentMiddleware for MemoryMiddleware {
|
||||
fn name(&self) -> &str { "memory" }
|
||||
fn priority(&self) -> i32 { 150 }
|
||||
fn parallel_safe(&self) -> bool { true }
|
||||
|
||||
async fn before_completion(&self, ctx: &mut MiddlewareContext) -> Result<MiddlewareDecision> {
|
||||
tracing::debug!(
|
||||
|
||||
@@ -40,6 +40,7 @@ impl SkillIndexMiddleware {
|
||||
impl AgentMiddleware for SkillIndexMiddleware {
|
||||
fn name(&self) -> &str { "skill_index" }
|
||||
fn priority(&self) -> i32 { 200 }
|
||||
fn parallel_safe(&self) -> bool { true }
|
||||
|
||||
async fn before_completion(&self, ctx: &mut MiddlewareContext) -> Result<MiddlewareDecision> {
|
||||
if self.entries.is_empty() {
|
||||
|
||||
@@ -41,6 +41,7 @@ impl Default for TitleMiddleware {
|
||||
impl AgentMiddleware for TitleMiddleware {
|
||||
fn name(&self) -> &str { "title" }
|
||||
fn priority(&self) -> i32 { 180 }
|
||||
fn parallel_safe(&self) -> bool { true }
|
||||
|
||||
// All hooks default to Continue — placeholder until LLM driver is wired in.
|
||||
async fn before_completion(&self, _ctx: &mut crate::middleware::MiddlewareContext) -> zclaw_types::Result<MiddlewareDecision> {
|
||||
|
||||
Reference in New Issue
Block a user