fix(saas): SSE relay token capture 修复 — stream_done 标志 + 前缀兼容
Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled
Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled
- SseUsageCapture 增加 stream_done 标志,[DONE] 和 stream 结束时设置
- parse_sse_line 兼容 "data:" 和 "data: " 两种前缀
- 增加 total_tokens 兜底解析(某些 provider 不返回 prompt_tokens)
- 轮询逻辑优先检测 stream_done,而非依赖 total > 0 条件
- 超时时增加 warn 日志记录实际 token 值

根因: 上游 provider 不在 SSE chunk 中返回 usage 时,轮询稳定逻辑 (total > 0 条件) 永远不满足,导致 token 始终为 0。
This commit is contained in:
@@ -192,22 +192,40 @@ pub async fn update_task_status(
|
||||
/// Token usage captured while relaying an upstream SSE stream.
///
/// Fields are updated by `parse_sse_line` from the `usage` object of each
/// `data:` payload and read back by the polling task once the stream ends.
struct SseUsageCapture {
|
||||
/// Input token count, taken from `usage.prompt_tokens`.
input_tokens: i64,
|
||||
/// Output token count, taken from `usage.completion_tokens`.
output_tokens: i64,
|
||||
/// Marks whether the upstream stream has ended (channel closed or `[DONE]` received).
|
||||
stream_done: bool,
|
||||
}
|
||||
|
||||
impl SseUsageCapture {
|
||||
fn parse_sse_line(&mut self, line: &str) {
|
||||
if let Some(data) = line.strip_prefix("data: ") {
|
||||
// 兼容 "data: " 和 "data:" 两种前缀
|
||||
let data = if let Some(d) = line.strip_prefix("data: ") {
|
||||
d
|
||||
} else if let Some(d) = line.strip_prefix("data:") {
|
||||
d.trim_start()
|
||||
} else {
|
||||
return;
|
||||
};
|
||||
|
||||
if data == "[DONE]" {
|
||||
self.stream_done = true;
|
||||
return;
|
||||
}
|
||||
|
||||
if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(data) {
|
||||
if let Some(usage) = parsed.get("usage") {
|
||||
// 标准 OpenAI 格式: prompt_tokens / completion_tokens
|
||||
if let Some(input) = usage.get("prompt_tokens").and_then(|v| v.as_i64()) {
|
||||
self.input_tokens = input;
|
||||
}
|
||||
if let Some(output) = usage.get("completion_tokens").and_then(|v| v.as_i64()) {
|
||||
self.output_tokens = output;
|
||||
}
|
||||
// 兜底: 某些 provider 只返回 total_tokens
|
||||
if self.input_tokens == 0 && self.output_tokens > 0 {
|
||||
if let Some(total) = usage.get("total_tokens").and_then(|v| v.as_i64()) {
|
||||
self.input_tokens = (total - self.output_tokens).max(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -350,6 +368,11 @@ pub async fn execute_relay(
|
||||
}
|
||||
}
|
||||
}
|
||||
// Stream 结束后设置 stream_done 标志,通知 usage 轮询任务
|
||||
{
|
||||
let mut capture = usage_capture_clone.lock().await;
|
||||
capture.stream_done = true;
|
||||
}
|
||||
});
|
||||
|
||||
// Build StreamBridge: wraps the bounded receiver with heartbeat,
|
||||
@@ -371,8 +394,8 @@ pub async fn execute_relay(
|
||||
|
||||
tokio::spawn(async move {
|
||||
let _permit = permit; // 持有 permit 直到任务完成
|
||||
// 等待 SSE 流结束 — 等待 capture 稳定(tokens 不再增长)
|
||||
// 替代原来固定 500ms 的 race condition
|
||||
// 等待 SSE 流结束 — 优先等待 stream_done 标志,
|
||||
// 兜底使用 token 稳定检测 + 最大等待时间
|
||||
let max_wait = std::time::Duration::from_secs(120);
|
||||
let poll_interval = std::time::Duration::from_millis(500);
|
||||
let start = tokio::time::Instant::now();
|
||||
@@ -381,11 +404,15 @@ pub async fn execute_relay(
|
||||
let (input, output) = loop {
|
||||
tokio::time::sleep(poll_interval).await;
|
||||
let capture = usage_capture.lock().await;
|
||||
// 优先: stream_done 标志表示上游已结束
|
||||
if capture.stream_done {
|
||||
break (capture.input_tokens, capture.output_tokens);
|
||||
}
|
||||
let total = capture.input_tokens + capture.output_tokens;
|
||||
// 兜底: token 数稳定检测(兼容不发送 [DONE] 的 provider)
|
||||
if total == last_tokens && total > 0 {
|
||||
stable_count += 1;
|
||||
if stable_count >= 3 {
|
||||
// 连续 3 次稳定(1.5s),认为流结束
|
||||
break (capture.input_tokens, capture.output_tokens);
|
||||
}
|
||||
} else {
|
||||
@@ -393,8 +420,13 @@ pub async fn execute_relay(
|
||||
last_tokens = total;
|
||||
}
|
||||
drop(capture);
|
||||
// 最终兜底: 超时保护
|
||||
if start.elapsed() >= max_wait {
|
||||
let capture = usage_capture.lock().await;
|
||||
tracing::warn!(
|
||||
"SSE usage capture timed out for task {}, tokens: in={} out={}",
|
||||
task_id_clone, capture.input_tokens, capture.output_tokens
|
||||
);
|
||||
break (capture.input_tokens, capture.output_tokens);
|
||||
}
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user