perf(middleware): before_completion 分波并行执行
Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled
Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled
- MiddlewareContext 加 Clone derive, 支持并行克隆上下文
- AgentMiddleware trait 新增 parallel_safe() 默认方法 (false)
- MiddlewareChain::run_before_completion 改为分波执行: 连续 2+ 个 parallel_safe 中间件用 tokio::spawn 并发执行, 各自独立修改 system_prompt, 执行完成后合并贡献
- 5 个只修改 system_prompt 的中间件标记 parallel_safe: evolution(P78), butler_router(P80), memory(P150), title(P180), skill_index(P200)
- 非 parallel_safe 中间件 (compaction, dangling_tool 等) 保持串行

分波效果:
  Wave 1: evolution + butler_router → 并行 (省 ~0.5-1s)
  Wave 2: compaction → 串行 (可能修改 messages)
  Wave 3: memory + title + skill_index → 并行 (省 ~0.5-2s)
  Wave 4+: 工具/安全中间件 → 串行
This commit is contained in:
@@ -12,6 +12,13 @@
|
|||||||
//! | 200-399 | Capability | SkillIndex, Guardrail |
|
//! | 200-399 | Capability | SkillIndex, Guardrail |
|
||||||
//! | 400-599 | Safety | LoopGuard, Guardrail |
|
//! | 400-599 | Safety | LoopGuard, Guardrail |
|
||||||
//! | 600-799 | Telemetry | TokenCalibration, Tracking |
|
//! | 600-799 | Telemetry | TokenCalibration, Tracking |
|
||||||
|
//!
|
||||||
|
//! # Wave parallelization
|
||||||
|
//!
|
||||||
|
//! `before_completion` middlewares that only modify `system_prompt` (not `messages`)
|
||||||
|
//! can declare `parallel_safe() == true`. The chain runs consecutive parallel-safe
|
||||||
|
//! middlewares concurrently, merging their prompt contributions. This reduces
|
||||||
|
//! sequential latency for the context-injection phase.
|
||||||
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
@@ -50,6 +57,7 @@ pub enum ToolCallDecision {
|
|||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
/// Carries the mutable state that middleware may inspect or modify.
|
/// Carries the mutable state that middleware may inspect or modify.
|
||||||
|
#[derive(Clone)]
|
||||||
pub struct MiddlewareContext {
|
pub struct MiddlewareContext {
|
||||||
/// The agent that owns this loop.
|
/// The agent that owns this loop.
|
||||||
pub agent_id: AgentId,
|
pub agent_id: AgentId,
|
||||||
@@ -101,6 +109,15 @@ pub trait AgentMiddleware: Send + Sync {
|
|||||||
500
|
500
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Whether `before_completion` is safe to run concurrently with other
|
||||||
|
/// parallel-safe middlewares. Only return `true` if the middleware:
|
||||||
|
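/// - Only appends to `ctx.system_prompt` (never rewrites existing prompt text, never touches `ctx.messages`; any other change made to the cloned context is discarded)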
|
||||||
|
/// - Does not depend on prompt modifications from other middlewares
|
||||||
|
/// - Does not return `MiddlewareDecision::Stop`
|
||||||
|
fn parallel_safe(&self) -> bool {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
|
||||||
/// Hook executed **before** the LLM completion request is sent.
|
/// Hook executed **before** the LLM completion request is sent.
|
||||||
///
|
///
|
||||||
/// Use this to inject context (memory, skill index, etc.) or to
|
/// Use this to inject context (memory, skill index, etc.) or to
|
||||||
@@ -163,9 +180,66 @@ impl MiddlewareChain {
|
|||||||
self.middlewares.insert(pos, mw);
|
self.middlewares.insert(pos, mw);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Run all `before_completion` hooks in order.
|
/// Run all `before_completion` hooks with wave-based parallelization.
|
||||||
|
///
|
||||||
|
/// Consecutive `parallel_safe` middlewares run concurrently — each gets
|
||||||
|
/// its own cloned context and appends to `system_prompt` independently.
|
||||||
|
/// The text each one appended is merged back in chain order as its task is awaited. Non-parallel-safe
|
||||||
|
/// middlewares (and non-consecutive ones) run sequentially as before.
|
||||||
pub async fn run_before_completion(&self, ctx: &mut MiddlewareContext) -> Result<MiddlewareDecision> {
|
pub async fn run_before_completion(&self, ctx: &mut MiddlewareContext) -> Result<MiddlewareDecision> {
|
||||||
for mw in &self.middlewares {
|
let mut idx = 0;
|
||||||
|
while idx < self.middlewares.len() {
|
||||||
|
// Find the extent of consecutive parallel-safe middlewares
|
||||||
|
let wave_start = idx;
|
||||||
|
let mut wave_end = idx;
|
||||||
|
while wave_end < self.middlewares.len()
|
||||||
|
&& self.middlewares[wave_end].parallel_safe()
|
||||||
|
{
|
||||||
|
wave_end += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
if wave_end - wave_start >= 2 {
|
||||||
|
// Run parallel wave (2+ consecutive parallel-safe middlewares)
|
||||||
|
let base_prompt_len = ctx.system_prompt.len();
|
||||||
|
let wave = &self.middlewares[wave_start..wave_end];
|
||||||
|
|
||||||
|
// Spawn concurrent tasks — each owns its cloned context + Arc ref to middleware
|
||||||
|
let mut join_handles = Vec::with_capacity(wave.len());
|
||||||
|
for mw in wave.iter() {
|
||||||
|
let mut ctx_clone = ctx.clone();
|
||||||
|
let mw_arc = Arc::clone(mw);
|
||||||
|
join_handles.push(tokio::spawn(async move {
|
||||||
|
let result = mw_arc.before_completion(&mut ctx_clone).await;
|
||||||
|
(result, ctx_clone.system_prompt)
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Await all and merge prompt contributions
|
||||||
|
for (i, handle) in join_handles.into_iter().enumerate() {
|
||||||
|
let (result, modified_prompt): (Result<MiddlewareDecision>, String) = handle.await
|
||||||
|
.map_err(|e| zclaw_types::ZclawError::Internal(format!("Parallel middleware panicked: {}", e)))?;
|
||||||
|
match result? {
|
||||||
|
MiddlewareDecision::Continue => {}
|
||||||
|
MiddlewareDecision::Stop(reason) => {
|
||||||
|
tracing::info!(
|
||||||
|
"[MiddlewareChain] '{}' requested stop: {}",
|
||||||
|
self.middlewares[wave_start + i].name(),
|
||||||
|
reason
|
||||||
|
);
|
||||||
|
return Ok(MiddlewareDecision::Stop(reason));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Merge system_prompt contribution from this clone
|
||||||
|
if modified_prompt.len() > base_prompt_len {
|
||||||
|
let contribution = &modified_prompt[base_prompt_len..];
|
||||||
|
ctx.system_prompt.push_str(contribution);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
idx = wave_end;
|
||||||
|
} else {
|
||||||
|
// Run single middleware sequentially
|
||||||
|
let mw = &self.middlewares[idx];
|
||||||
match mw.before_completion(ctx).await? {
|
match mw.before_completion(ctx).await? {
|
||||||
MiddlewareDecision::Continue => {}
|
MiddlewareDecision::Continue => {}
|
||||||
MiddlewareDecision::Stop(reason) => {
|
MiddlewareDecision::Stop(reason) => {
|
||||||
@@ -173,6 +247,8 @@ impl MiddlewareChain {
|
|||||||
return Ok(MiddlewareDecision::Stop(reason));
|
return Ok(MiddlewareDecision::Stop(reason));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
idx += 1;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
Ok(MiddlewareDecision::Continue)
|
Ok(MiddlewareDecision::Continue)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -290,6 +290,8 @@ impl AgentMiddleware for ButlerRouterMiddleware {
|
|||||||
80
|
80
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Opts this middleware into concurrent `before_completion` waves.
// NOTE(review): the hook body is only partially visible here — confirm it only
// appends to `ctx.system_prompt` and never returns `Stop`, as the trait contract requires.
fn parallel_safe(&self) -> bool { true }
|
||||||
|
|
||||||
async fn before_completion(&self, ctx: &mut MiddlewareContext) -> Result<MiddlewareDecision> {
|
async fn before_completion(&self, ctx: &mut MiddlewareContext) -> Result<MiddlewareDecision> {
|
||||||
// Only route on the first user message in a turn (not tool results)
|
// Only route on the first user message in a turn (not tool results)
|
||||||
let user_input = &ctx.user_input;
|
let user_input = &ctx.user_input;
|
||||||
|
|||||||
@@ -88,6 +88,8 @@ impl AgentMiddleware for EvolutionMiddleware {
|
|||||||
78 // 在 ButlerRouter(80) 之前
|
78 // 在 ButlerRouter(80) 之前
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Opts this middleware into concurrent `before_completion` waves.
// NOTE(review): priority 78 places it directly before ButlerRouter (80), so the two
// can land in the same wave — confirm neither reads the other's prompt additions.
fn parallel_safe(&self) -> bool { true }
|
||||||
|
|
||||||
async fn before_completion(
|
async fn before_completion(
|
||||||
&self,
|
&self,
|
||||||
ctx: &mut MiddlewareContext,
|
ctx: &mut MiddlewareContext,
|
||||||
|
|||||||
@@ -111,6 +111,7 @@ impl MemoryMiddleware {
|
|||||||
impl AgentMiddleware for MemoryMiddleware {
|
impl AgentMiddleware for MemoryMiddleware {
|
||||||
fn name(&self) -> &str { "memory" }
|
fn name(&self) -> &str { "memory" }
|
||||||
fn priority(&self) -> i32 { 150 }
|
fn priority(&self) -> i32 { 150 }
|
||||||
|
/// Opts the memory middleware into concurrent `before_completion` waves.
// NOTE(review): the hook body is not visible here — confirm it only appends to
// `ctx.system_prompt`; changes to any other context field are dropped after a wave.
fn parallel_safe(&self) -> bool { true }
|
||||||
|
|
||||||
async fn before_completion(&self, ctx: &mut MiddlewareContext) -> Result<MiddlewareDecision> {
|
async fn before_completion(&self, ctx: &mut MiddlewareContext) -> Result<MiddlewareDecision> {
|
||||||
tracing::debug!(
|
tracing::debug!(
|
||||||
|
|||||||
@@ -40,6 +40,7 @@ impl SkillIndexMiddleware {
|
|||||||
impl AgentMiddleware for SkillIndexMiddleware {
|
impl AgentMiddleware for SkillIndexMiddleware {
|
||||||
fn name(&self) -> &str { "skill_index" }
|
fn name(&self) -> &str { "skill_index" }
|
||||||
fn priority(&self) -> i32 { 200 }
|
fn priority(&self) -> i32 { 200 }
|
||||||
|
/// Opts the skill-index middleware into concurrent `before_completion` waves.
// NOTE(review): the hook body is only partially visible here — confirm it only
// appends to `ctx.system_prompt`, as the trait contract requires.
fn parallel_safe(&self) -> bool { true }
|
||||||
|
|
||||||
async fn before_completion(&self, ctx: &mut MiddlewareContext) -> Result<MiddlewareDecision> {
|
async fn before_completion(&self, ctx: &mut MiddlewareContext) -> Result<MiddlewareDecision> {
|
||||||
if self.entries.is_empty() {
|
if self.entries.is_empty() {
|
||||||
|
|||||||
@@ -41,6 +41,7 @@ impl Default for TitleMiddleware {
|
|||||||
impl AgentMiddleware for TitleMiddleware {
|
impl AgentMiddleware for TitleMiddleware {
|
||||||
fn name(&self) -> &str { "title" }
|
fn name(&self) -> &str { "title" }
|
||||||
fn priority(&self) -> i32 { 180 }
|
fn priority(&self) -> i32 { 180 }
|
||||||
|
/// Opts the title middleware into concurrent `before_completion` waves.
// NOTE(review): the hooks are described below as placeholders; re-check this flag
// once the real implementation is wired in.
fn parallel_safe(&self) -> bool { true }
|
||||||
|
|
||||||
// All hooks default to Continue — placeholder until LLM driver is wired in.
|
// All hooks default to Continue — placeholder until LLM driver is wired in.
|
||||||
async fn before_completion(&self, _ctx: &mut crate::middleware::MiddlewareContext) -> zclaw_types::Result<MiddlewareDecision> {
|
async fn before_completion(&self, _ctx: &mut crate::middleware::MiddlewareContext) -> zclaw_types::Result<MiddlewareDecision> {
|
||||||
|
|||||||
Reference in New Issue
Block a user