From 9c59e6e82a8c7b9573df03a9836e761855811c7a Mon Sep 17 00:00:00 2001 From: iven Date: Wed, 15 Apr 2026 00:15:03 +0800 Subject: [PATCH] =?UTF-8?q?fix(saas):=20SSE=20relay=20token=20capture=20?= =?UTF-8?q?=E4=BF=AE=E5=A4=8D=20=E2=80=94=20stream=5Fdone=20=E6=A0=87?= =?UTF-8?q?=E5=BF=97=20+=20=E5=89=8D=E7=BC=80=E5=85=BC=E5=AE=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - SseUsageCapture 增加 stream_done 标志,[DONE] 和 stream 结束时设置 - parse_sse_line 兼容 "data:" 和 "data: " 两种前缀 - 增加 total_tokens 兜底解析(某些 provider 不返回 prompt_tokens) - 轮询逻辑优先检测 stream_done,而非依赖 total > 0 条件 - 超时时增加 warn 日志记录实际 token 值 根因: 上游 provider 不在 SSE chunk 中返回 usage 时,轮询稳定逻辑 (total > 0 条件) 永远不满足,导致 token 始终为 0。 --- crates/zclaw-saas/src/relay/service.rs | 60 ++++++++++++++++++++------ 1 file changed, 46 insertions(+), 14 deletions(-) diff --git a/crates/zclaw-saas/src/relay/service.rs b/crates/zclaw-saas/src/relay/service.rs index 598b998..d1b476e 100644 --- a/crates/zclaw-saas/src/relay/service.rs +++ b/crates/zclaw-saas/src/relay/service.rs @@ -192,21 +192,39 @@ pub async fn update_task_status( struct SseUsageCapture { input_tokens: i64, output_tokens: i64, + /// 标记上游 stream 是否已结束(channel 关闭或收到 [DONE]) + stream_done: bool, } impl SseUsageCapture { fn parse_sse_line(&mut self, line: &str) { - if let Some(data) = line.strip_prefix("data: ") { - if data == "[DONE]" { - return; - } - if let Ok(parsed) = serde_json::from_str::(data) { - if let Some(usage) = parsed.get("usage") { - if let Some(input) = usage.get("prompt_tokens").and_then(|v| v.as_i64()) { - self.input_tokens = input; - } - if let Some(output) = usage.get("completion_tokens").and_then(|v| v.as_i64()) { - self.output_tokens = output; + // 兼容 "data: " 和 "data:" 两种前缀 + let data = if let Some(d) = line.strip_prefix("data: ") { + d + } else if let Some(d) = line.strip_prefix("data:") { + d.trim_start() + } else { + return; + }; + + if data == "[DONE]" { + self.stream_done = true; + return; + } + + if let Ok(parsed) = serde_json::from_str::(data) { + if let Some(usage) = parsed.get("usage") { + // 标准 OpenAI 格式: prompt_tokens / completion_tokens + if let Some(input) = usage.get("prompt_tokens").and_then(|v| v.as_i64()) { + self.input_tokens = input; + } + if let Some(output) = usage.get("completion_tokens").and_then(|v| v.as_i64()) { + self.output_tokens = output; + } + // 兜底: 某些 provider 只返回 total_tokens + if self.input_tokens == 0 && self.output_tokens > 0 { + if let Some(total) = usage.get("total_tokens").and_then(|v| v.as_i64()) { + self.input_tokens = (total - self.output_tokens).max(0); } } } @@ -350,6 +368,11 @@ pub async fn execute_relay( } } } + // Stream 结束后设置 stream_done 标志,通知 usage 轮询任务 + { + let mut capture = usage_capture_clone.lock().await; + capture.stream_done = true; + } }); // Build StreamBridge: wraps the bounded receiver with heartbeat, @@ -371,8 +394,8 @@ pub async fn execute_relay( tokio::spawn(async move { let _permit = permit; // 持有 permit 直到任务完成 - // 等待 SSE 流结束 — 等待 capture 稳定(tokens 不再增长) - // 替代原来固定 500ms 的 race condition + // 等待 SSE 流结束 — 优先等待 stream_done 标志, + // 兜底使用 token 稳定检测 + 最大等待时间 let max_wait = std::time::Duration::from_secs(120); let poll_interval = std::time::Duration::from_millis(500); let start = tokio::time::Instant::now(); @@ -381,11 +404,15 @@ pub async fn execute_relay( let (input, output) = loop { tokio::time::sleep(poll_interval).await; let capture = usage_capture.lock().await; + // 优先: stream_done 标志表示上游已结束 + if capture.stream_done { + break (capture.input_tokens, capture.output_tokens); + } let total = capture.input_tokens + capture.output_tokens; + // 兜底: token 数稳定检测(兼容不发送 [DONE] 的 provider) if total == last_tokens && total > 0 { stable_count += 1; if stable_count >= 3 { - // 连续 3 次稳定(1.5s),认为流结束 break (capture.input_tokens, capture.output_tokens); } } else { @@ -393,8 +420,13 @@ pub async fn execute_relay( last_tokens = total; } drop(capture); + // 最终兜底: 超时保护 if start.elapsed() >= max_wait { let capture = usage_capture.lock().await; + tracing::warn!( + "SSE usage capture timed out for task {}, tokens: in={} out={}", + task_id_clone, capture.input_tokens, capture.output_tokens + ); break (capture.input_tokens, capture.output_tokens); } };