fix(saas): SSE relay token capture 修复 — stream_done 标志 + 前缀兼容
Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled
Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled
- SseUsageCapture 增加 stream_done 标志,收到 [DONE] 或 stream 结束时设置
- parse_sse_line 兼容 "data:" 和 "data: " 两种前缀
- 增加 total_tokens 兜底解析(某些 provider 不返回 prompt_tokens)
- 轮询逻辑优先检测 stream_done,而非依赖 total > 0 条件
- 超时时增加 warn 日志,记录实际 token 值

根因: 上游 provider 不在 SSE chunk 中返回 usage 时,轮询稳定逻辑 (total > 0 条件) 永远不满足,导致 token 始终为 0。
This commit is contained in:
@@ -192,21 +192,39 @@ pub async fn update_task_status(
|
|||||||
struct SseUsageCapture {
|
struct SseUsageCapture {
|
||||||
input_tokens: i64,
|
input_tokens: i64,
|
||||||
output_tokens: i64,
|
output_tokens: i64,
|
||||||
|
/// 标记上游 stream 是否已结束(channel 关闭或收到 [DONE])
|
||||||
|
stream_done: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl SseUsageCapture {
|
impl SseUsageCapture {
|
||||||
fn parse_sse_line(&mut self, line: &str) {
|
fn parse_sse_line(&mut self, line: &str) {
|
||||||
if let Some(data) = line.strip_prefix("data: ") {
|
// 兼容 "data: " 和 "data:" 两种前缀
|
||||||
if data == "[DONE]" {
|
let data = if let Some(d) = line.strip_prefix("data: ") {
|
||||||
return;
|
d
|
||||||
}
|
} else if let Some(d) = line.strip_prefix("data:") {
|
||||||
if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(data) {
|
d.trim_start()
|
||||||
if let Some(usage) = parsed.get("usage") {
|
} else {
|
||||||
if let Some(input) = usage.get("prompt_tokens").and_then(|v| v.as_i64()) {
|
return;
|
||||||
self.input_tokens = input;
|
};
|
||||||
}
|
|
||||||
if let Some(output) = usage.get("completion_tokens").and_then(|v| v.as_i64()) {
|
if data == "[DONE]" {
|
||||||
self.output_tokens = output;
|
self.stream_done = true;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(data) {
|
||||||
|
if let Some(usage) = parsed.get("usage") {
|
||||||
|
// 标准 OpenAI 格式: prompt_tokens / completion_tokens
|
||||||
|
if let Some(input) = usage.get("prompt_tokens").and_then(|v| v.as_i64()) {
|
||||||
|
self.input_tokens = input;
|
||||||
|
}
|
||||||
|
if let Some(output) = usage.get("completion_tokens").and_then(|v| v.as_i64()) {
|
||||||
|
self.output_tokens = output;
|
||||||
|
}
|
||||||
|
// 兜底: 某些 provider 只返回 total_tokens
|
||||||
|
if self.input_tokens == 0 && self.output_tokens > 0 {
|
||||||
|
if let Some(total) = usage.get("total_tokens").and_then(|v| v.as_i64()) {
|
||||||
|
self.input_tokens = (total - self.output_tokens).max(0);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -350,6 +368,11 @@ pub async fn execute_relay(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// Stream 结束后设置 stream_done 标志,通知 usage 轮询任务
|
||||||
|
{
|
||||||
|
let mut capture = usage_capture_clone.lock().await;
|
||||||
|
capture.stream_done = true;
|
||||||
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
// Build StreamBridge: wraps the bounded receiver with heartbeat,
|
// Build StreamBridge: wraps the bounded receiver with heartbeat,
|
||||||
@@ -371,8 +394,8 @@ pub async fn execute_relay(
|
|||||||
|
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
let _permit = permit; // 持有 permit 直到任务完成
|
let _permit = permit; // 持有 permit 直到任务完成
|
||||||
// 等待 SSE 流结束 — 等待 capture 稳定(tokens 不再增长)
|
// 等待 SSE 流结束 — 优先等待 stream_done 标志,
|
||||||
// 替代原来固定 500ms 的 race condition
|
// 兜底使用 token 稳定检测 + 最大等待时间
|
||||||
let max_wait = std::time::Duration::from_secs(120);
|
let max_wait = std::time::Duration::from_secs(120);
|
||||||
let poll_interval = std::time::Duration::from_millis(500);
|
let poll_interval = std::time::Duration::from_millis(500);
|
||||||
let start = tokio::time::Instant::now();
|
let start = tokio::time::Instant::now();
|
||||||
@@ -381,11 +404,15 @@ pub async fn execute_relay(
|
|||||||
let (input, output) = loop {
|
let (input, output) = loop {
|
||||||
tokio::time::sleep(poll_interval).await;
|
tokio::time::sleep(poll_interval).await;
|
||||||
let capture = usage_capture.lock().await;
|
let capture = usage_capture.lock().await;
|
||||||
|
// 优先: stream_done 标志表示上游已结束
|
||||||
|
if capture.stream_done {
|
||||||
|
break (capture.input_tokens, capture.output_tokens);
|
||||||
|
}
|
||||||
let total = capture.input_tokens + capture.output_tokens;
|
let total = capture.input_tokens + capture.output_tokens;
|
||||||
|
// 兜底: token 数稳定检测(兼容不发送 [DONE] 的 provider)
|
||||||
if total == last_tokens && total > 0 {
|
if total == last_tokens && total > 0 {
|
||||||
stable_count += 1;
|
stable_count += 1;
|
||||||
if stable_count >= 3 {
|
if stable_count >= 3 {
|
||||||
// 连续 3 次稳定(1.5s),认为流结束
|
|
||||||
break (capture.input_tokens, capture.output_tokens);
|
break (capture.input_tokens, capture.output_tokens);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@@ -393,8 +420,13 @@ pub async fn execute_relay(
|
|||||||
last_tokens = total;
|
last_tokens = total;
|
||||||
}
|
}
|
||||||
drop(capture);
|
drop(capture);
|
||||||
|
// 最终兜底: 超时保护
|
||||||
if start.elapsed() >= max_wait {
|
if start.elapsed() >= max_wait {
|
||||||
let capture = usage_capture.lock().await;
|
let capture = usage_capture.lock().await;
|
||||||
|
tracing::warn!(
|
||||||
|
"SSE usage capture timed out for task {}, tokens: in={} out={}",
|
||||||
|
task_id_clone, capture.input_tokens, capture.output_tokens
|
||||||
|
);
|
||||||
break (capture.input_tokens, capture.output_tokens);
|
break (capture.input_tokens, capture.output_tokens);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|||||||
Reference in New Issue
Block a user