diff --git a/crates/zclaw-saas/src/relay/service.rs b/crates/zclaw-saas/src/relay/service.rs index 598b998..d1b476e 100644 --- a/crates/zclaw-saas/src/relay/service.rs +++ b/crates/zclaw-saas/src/relay/service.rs @@ -192,21 +192,39 @@ pub async fn update_task_status( struct SseUsageCapture { input_tokens: i64, output_tokens: i64, + /// 标记上游 stream 是否已结束(channel 关闭或收到 [DONE]) + stream_done: bool, } impl SseUsageCapture { fn parse_sse_line(&mut self, line: &str) { - if let Some(data) = line.strip_prefix("data: ") { - if data == "[DONE]" { - return; - } - if let Ok(parsed) = serde_json::from_str::(data) { - if let Some(usage) = parsed.get("usage") { - if let Some(input) = usage.get("prompt_tokens").and_then(|v| v.as_i64()) { - self.input_tokens = input; - } - if let Some(output) = usage.get("completion_tokens").and_then(|v| v.as_i64()) { - self.output_tokens = output; + // 兼容 "data: " 和 "data:" 两种前缀 + let data = if let Some(d) = line.strip_prefix("data: ") { + d + } else if let Some(d) = line.strip_prefix("data:") { + d.trim_start() + } else { + return; + }; + + if data == "[DONE]" { + self.stream_done = true; + return; + } + + if let Ok(parsed) = serde_json::from_str::(data) { + if let Some(usage) = parsed.get("usage") { + // 标准 OpenAI 格式: prompt_tokens / completion_tokens + if let Some(input) = usage.get("prompt_tokens").and_then(|v| v.as_i64()) { + self.input_tokens = input; + } + if let Some(output) = usage.get("completion_tokens").and_then(|v| v.as_i64()) { + self.output_tokens = output; + } + // 兜底: 某些 provider 只返回 total_tokens + if self.input_tokens == 0 && self.output_tokens > 0 { + if let Some(total) = usage.get("total_tokens").and_then(|v| v.as_i64()) { + self.input_tokens = (total - self.output_tokens).max(0); } } } @@ -350,6 +368,11 @@ pub async fn execute_relay( } } } + // Stream 结束后设置 stream_done 标志,通知 usage 轮询任务 + { + let mut capture = usage_capture_clone.lock().await; + capture.stream_done = true; + } }); // Build StreamBridge: wraps the bounded receiver with heartbeat, @@ -371,8 +394,8 @@ pub async fn execute_relay( tokio::spawn(async move { let _permit = permit; // 持有 permit 直到任务完成 - // 等待 SSE 流结束 — 等待 capture 稳定(tokens 不再增长) - // 替代原来固定 500ms 的 race condition + // 等待 SSE 流结束 — 优先等待 stream_done 标志, + // 兜底使用 token 稳定检测 + 最大等待时间 let max_wait = std::time::Duration::from_secs(120); let poll_interval = std::time::Duration::from_millis(500); let start = tokio::time::Instant::now(); @@ -381,11 +404,15 @@ pub async fn execute_relay( let (input, output) = loop { tokio::time::sleep(poll_interval).await; let capture = usage_capture.lock().await; + // 优先: stream_done 标志表示上游已结束 + if capture.stream_done { + break (capture.input_tokens, capture.output_tokens); + } let total = capture.input_tokens + capture.output_tokens; + // 兜底: token 数稳定检测(兼容不发送 [DONE] 的 provider) if total == last_tokens && total > 0 { stable_count += 1; if stable_count >= 3 { - // 连续 3 次稳定(1.5s),认为流结束 break (capture.input_tokens, capture.output_tokens); } } else { @@ -393,8 +420,13 @@ pub async fn execute_relay( last_tokens = total; } drop(capture); + // 最终兜底: 超时保护 if start.elapsed() >= max_wait { let capture = usage_capture.lock().await; + tracing::warn!( + "SSE usage capture timed out for task {}, tokens: in={} out={}", + task_id_clone, capture.input_tokens, capture.output_tokens + ); break (capture.input_tokens, capture.output_tokens); } };