Files
zclaw_openfang/crates/zclaw-saas/src/relay/service.rs
fix(relay): audit fixes — abort signal, model selector guard, SSE CRLF, SQL format
Addresses findings from deep code audit:

H-1: Pass abortController.signal to saasClient.chatCompletion() so
     user-cancelled streams actually abort the HTTP connection (was only
     stopping the read loop, leaving server-side SSE connection open).

H-2: ModelSelector now shows only when (!isTauriRuntime() || isLoggedIn).
     Prevents decorative model list in Tauri local kernel mode where model
     selection has no effect (violates CLAUDE.md §5.2).

M-1: Normalize CRLF to LF before SSE event boundary parsing (\n\n).
     Prevents the parse buffer from growing without bound behind nginx/CDN
     proxies that emit CRLF line endings.

M-2: SQL window_minute comparison uses to_char(NOW()-interval, format)
     instead of (NOW()-interval)::TEXT, matching the stored format exactly.

M-3: sort_candidates_by_quota uses same sliding 60s window as select_best_key.

LOW: Fix misleading invalidate_cache doc comment.
2026-04-09 19:51:34 +08:00
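
As a rough illustration of the M-1 normalization described above (the actual fix lives in the frontend stream reader, not in this file), a minimal Rust sketch assuming upstream chunks are appended to a `buffer: String` before splitting on the `\n\n` event boundary; all names here are hypothetical:

    /// Drain complete SSE events from `buffer`, tolerating CRLF line endings.
    fn drain_sse_events(buffer: &mut String) -> Vec<String> {
        // Proxies such as nginx or some CDNs may rewrite line endings to CRLF;
        // without this normalization the "\n\n" boundary is never found and the
        // buffer keeps growing instead of flushing events.
        let normalized = buffer.replace("\r\n", "\n");
        *buffer = normalized;
        let mut events = Vec::new();
        // Split off every complete event; the trailing partial event stays buffered.
        while let Some(pos) = buffer.find("\n\n") {
            let event: String = buffer.drain(..pos + 2).collect();
            events.push(event.trim_end().to_string());
        }
        events
    }

A production reader would also need to handle a CR split across chunk boundaries; this sketch ignores that case.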


//! Core logic of the relay service
use sqlx::PgPool;
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::OnceLock;
use std::time::Duration;
use tokio::sync::Mutex;
use crate::error::{SaasError, SaasResult};
use crate::models::RelayTaskRow;
use super::types::*;
// ============ StreamBridge backpressure constants ============
/// Interval for emitting an SSE heartbeat comment line while the upstream produces no data
const STREAMBRIDGE_HEARTBEAT_INTERVAL: Duration = Duration::from_secs(15);
/// Timeout after which the connection is dropped when the upstream produces no data (180s = 12 heartbeats).
/// In practice, Kimi for Coding's thinking→content gap can exceed 60s, so the timeout must be generous.
const STREAMBRIDGE_TIMEOUT: Duration = Duration::from_secs(180);
/// Delayed-cleanup window after the stream ends (shortened to 5s; only used to release Arc references)
const STREAMBRIDGE_CLEANUP_DELAY: Duration = Duration::from_secs(5);
/// Returns whether an HTTP status code is a retryable transient error (5xx + 429)
fn is_retryable_status(status: u16) -> bool {
status == 429 || (500..600).contains(&status)
}
/// Returns whether a reqwest error is a retryable network error
fn is_retryable_error(e: &reqwest::Error) -> bool {
e.is_timeout() || e.is_connect() || e.is_request()
}
// ============ Relay Task Management ============
/// Returns whether a sqlx error is a retryable transient error (pool exhaustion, temporary network failure)
fn is_transient_db_error(e: &sqlx::Error) -> bool {
matches!(e, sqlx::Error::PoolTimedOut | sqlx::Error::Io(_))
}
pub async fn create_relay_task(
db: &PgPool,
account_id: &str,
provider_id: &str,
model_id: &str,
request_body: &str,
priority: i32,
max_attempts: u32,
) -> SaasResult<RelayTaskInfo> {
let id = uuid::Uuid::new_v4().to_string();
let now = chrono::Utc::now();
let request_hash = hash_request(request_body);
let max_attempts = max_attempts.max(1).min(5);
let query = sqlx::query_as::<_, RelayTaskRow>(
"INSERT INTO relay_tasks (id, account_id, provider_id, model_id, request_hash, request_body, status, priority, attempt_count, max_attempts, queued_at, created_at)
VALUES ($1, $2, $3, $4, $5, $6, 'queued', $7, 0, $8, $9, $9)
RETURNING id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at::TEXT, started_at::TEXT, completed_at::TEXT, created_at::TEXT"
)
.bind(&id).bind(account_id).bind(provider_id).bind(model_id)
.bind(&request_hash).bind(request_body).bind(priority).bind(max_attempts as i64).bind(&now);
// Retry once on transient DB errors (pool exhaustion / timeout)
let row = match query.fetch_one(db).await {
Ok(row) => row,
Err(e) if is_transient_db_error(&e) => {
tracing::warn!("Transient DB error in create_relay_task, retrying: {}", e);
tokio::time::sleep(Duration::from_millis(200)).await;
sqlx::query_as::<_, RelayTaskRow>(
"INSERT INTO relay_tasks (id, account_id, provider_id, model_id, request_hash, request_body, status, priority, attempt_count, max_attempts, queued_at, created_at)
VALUES ($1, $2, $3, $4, $5, $6, 'queued', $7, 0, $8, $9, $9)
RETURNING id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at::TEXT, started_at::TEXT, completed_at::TEXT, created_at::TEXT"
)
.bind(&id).bind(account_id).bind(provider_id).bind(model_id)
.bind(&request_hash).bind(request_body).bind(priority).bind(max_attempts as i64).bind(&now)
.fetch_one(db)
.await?
}
Err(e) => return Err(e.into()),
};
Ok(RelayTaskInfo {
id: row.id, account_id: row.account_id, provider_id: row.provider_id, model_id: row.model_id,
status: row.status, priority: row.priority, attempt_count: row.attempt_count,
max_attempts: row.max_attempts, input_tokens: row.input_tokens, output_tokens: row.output_tokens,
error_message: row.error_message, queued_at: row.queued_at, started_at: row.started_at,
completed_at: row.completed_at, created_at: row.created_at,
})
}
pub async fn get_relay_task(db: &PgPool, task_id: &str) -> SaasResult<RelayTaskInfo> {
let row: Option<RelayTaskRow> =
sqlx::query_as(
"SELECT id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at::TEXT, started_at::TEXT, completed_at::TEXT, created_at::TEXT
FROM relay_tasks WHERE id = $1"
)
.bind(task_id)
.fetch_optional(db)
.await?;
let r = row.ok_or_else(|| SaasError::NotFound(format!("中转任务 {} 不存在", task_id)))?;
Ok(RelayTaskInfo {
id: r.id, account_id: r.account_id, provider_id: r.provider_id, model_id: r.model_id, status: r.status, priority: r.priority,
attempt_count: r.attempt_count, max_attempts: r.max_attempts, input_tokens: r.input_tokens, output_tokens: r.output_tokens,
error_message: r.error_message, queued_at: r.queued_at, started_at: r.started_at, completed_at: r.completed_at, created_at: r.created_at,
})
}
pub async fn list_relay_tasks(
db: &PgPool, account_id: &str, query: &RelayTaskQuery,
) -> SaasResult<crate::common::PaginatedResponse<RelayTaskInfo>> {
let page = query.page.unwrap_or(1).max(1) as u32;
let page_size = query.page_size.unwrap_or(20).min(100) as u32;
let offset = ((page - 1) * page_size) as i64;
let (count_sql, data_sql) = if query.status.is_some() {
(
"SELECT COUNT(*) FROM relay_tasks WHERE account_id = $1 AND status = $2",
"SELECT id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at::TEXT, started_at::TEXT, completed_at::TEXT, created_at::TEXT
FROM relay_tasks WHERE account_id = $1 AND status = $2 ORDER BY created_at DESC LIMIT $3 OFFSET $4"
)
} else {
(
"SELECT COUNT(*) FROM relay_tasks WHERE account_id = $1",
"SELECT id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at::TEXT, started_at::TEXT, completed_at::TEXT, created_at::TEXT
FROM relay_tasks WHERE account_id = $1 ORDER BY created_at DESC LIMIT $2 OFFSET $3"
)
};
let total: i64 = if query.status.is_some() {
sqlx::query_scalar(count_sql).bind(account_id).bind(query.status.as_ref().unwrap()).fetch_one(db).await?
} else {
sqlx::query_scalar(count_sql).bind(account_id).fetch_one(db).await?
};
let mut query_builder = sqlx::query_as::<_, RelayTaskRow>(data_sql)
.bind(account_id);
if let Some(ref status) = query.status {
query_builder = query_builder.bind(status);
}
let rows = query_builder.bind(page_size as i64).bind(offset).fetch_all(db).await?;
let items: Vec<RelayTaskInfo> = rows.into_iter().map(|r| {
RelayTaskInfo { id: r.id, account_id: r.account_id, provider_id: r.provider_id, model_id: r.model_id, status: r.status, priority: r.priority, attempt_count: r.attempt_count, max_attempts: r.max_attempts, input_tokens: r.input_tokens, output_tokens: r.output_tokens, error_message: r.error_message, queued_at: r.queued_at, started_at: r.started_at, completed_at: r.completed_at, created_at: r.created_at }
}).collect();
Ok(crate::common::PaginatedResponse { items, total, page, page_size })
}
pub async fn update_task_status(
db: &PgPool, task_id: &str, status: &str,
input_tokens: Option<i64>, output_tokens: Option<i64>,
error_message: Option<&str>,
) -> SaasResult<()> {
let now = chrono::Utc::now();
match status {
"processing" => {
sqlx::query(
"UPDATE relay_tasks SET started_at = $1, status = 'processing', attempt_count = attempt_count + 1 WHERE id = $2"
)
.bind(&now).bind(task_id)
.execute(db).await?;
}
"completed" => {
sqlx::query(
"UPDATE relay_tasks SET completed_at = $1, status = 'completed', input_tokens = COALESCE($2, input_tokens), output_tokens = COALESCE($3, output_tokens) WHERE id = $4"
)
.bind(&now).bind(input_tokens).bind(output_tokens).bind(task_id)
.execute(db).await?;
}
"failed" => {
sqlx::query(
"UPDATE relay_tasks SET completed_at = $1, status = 'failed', error_message = $2 WHERE id = $3"
)
.bind(&now).bind(error_message).bind(task_id)
.execute(db).await?;
}
_ => return Err(SaasError::InvalidInput(format!("无效任务状态: {}", status))),
}
Ok(())
}
// ============ Relay Execution ============
/// Captures usage information from an SSE stream
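///
/// Illustrative only; the kind of line it reads is an OpenAI-compatible usage chunk such as:
/// `data: {"choices":[],"usage":{"prompt_tokens":128,"completion_tokens":56}}`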
#[derive(Debug, Clone, Default)]
struct SseUsageCapture {
input_tokens: i64,
output_tokens: i64,
}
impl SseUsageCapture {
fn parse_sse_line(&mut self, line: &str) {
if let Some(data) = line.strip_prefix("data: ") {
if data == "[DONE]" {
return;
}
if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(data) {
if let Some(usage) = parsed.get("usage") {
if let Some(input) = usage.get("prompt_tokens").and_then(|v| v.as_i64()) {
self.input_tokens = input;
}
if let Some(output) = usage.get("completion_tokens").and_then(|v| v.as_i64()) {
self.output_tokens = output;
}
}
}
}
}
}
pub async fn execute_relay(
db: &PgPool,
task_id: &str,
provider_id: &str,
provider_base_url: &str,
request_body: &str,
stream: bool,
max_attempts: u32,
base_delay_ms: u64,
enc_key: &[u8; 32],
// false when called from `execute_relay_with_failover` (the outer layer manages task status);
// true when called standalone (this function manages task status itself).
manage_task_status: bool,
) -> SaasResult<RelayResponse> {
validate_provider_url(provider_base_url).await?;
let url = format!("{}/chat/completions", provider_base_url.trim_end_matches('/'));
// Reuse global HTTP clients to avoid rebuilding the TLS connection pool and DNS resolver on every request
static SHORT_CLIENT: std::sync::OnceLock<reqwest::Client> = std::sync::OnceLock::new();
static LONG_CLIENT: std::sync::OnceLock<reqwest::Client> = std::sync::OnceLock::new();
let client = if stream {
LONG_CLIENT.get_or_init(|| {
reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(300))
.build()
.expect("Failed to build long-timeout HTTP client")
})
} else {
SHORT_CLIENT.get_or_init(|| {
reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(30))
.build()
.expect("Failed to build short-timeout HTTP client")
})
};
let max_attempts = max_attempts.max(1).min(5);
// Key pool rotation state
let mut current_key_id: Option<String> = None;
let mut current_api_key: Option<String> = None;
for attempt in 0..max_attempts {
let is_first = attempt == 0;
if is_first && manage_task_status {
update_task_status(db, task_id, "processing", None, None, None).await?;
}
// A key must be (re)selected on the first attempt or after a 429
if current_key_id.is_none() {
match super::key_pool::select_best_key(db, provider_id, enc_key).await {
Ok(selection) => {
let key_id = selection.key_id.clone();
let key_value = selection.key.key_value.clone();
tracing::debug!(
"Relay task {} 选择 Key {} (attempt {})",
task_id, key_id, attempt + 1
);
current_key_id = Some(key_id);
current_api_key = Some(key_value);
}
Err(SaasError::RateLimited(msg)) => {
// All keys are cooling down
let err_msg = format!("Key Pool 耗尽: {}", msg);
if manage_task_status {
update_task_status(db, task_id, "failed", None, None, Some(&err_msg)).await?;
}
return Err(SaasError::RateLimited(msg));
}
Err(e) => return Err(e),
}
}
let key_id = current_key_id.as_ref()
.ok_or_else(|| SaasError::Internal("Key pool selection failed: no key_id".into()))?
.clone();
let api_key = current_api_key.clone();
let mut req_builder = client.post(&url)
.header("Content-Type", "application/json")
// Coding-agent APIs such as Kimi Coding Plan require a User-Agent that identifies a coding agent
.header("User-Agent", "claude-code/1.0")
.body(request_body.to_string());
if let Some(ref key) = api_key {
req_builder = req_builder.header("Authorization", format!("Bearer {}", key));
}
let result = req_builder.send().await;
match result {
Ok(resp) if resp.status().is_success() => {
if stream {
let usage_capture = Arc::new(Mutex::new(SseUsageCapture::default()));
let usage_capture_clone = usage_capture.clone();
let db_clone = db.clone();
let task_id_clone = task_id.to_string();
let key_id_for_spawn = key_id.clone();
// Bounded channel for backpressure: 128 chunks (~128KB) buffer.
// If the client reads slowly, the upstream is signaled via
// backpressure instead of growing memory indefinitely.
let (tx, rx) = tokio::sync::mpsc::channel::<Result<bytes::Bytes, std::io::Error>>(128);
// Spawn a task to consume the upstream stream and forward through the bounded channel
tokio::spawn(async move {
use futures::StreamExt;
let mut upstream = resp.bytes_stream();
while let Some(chunk_result) = upstream.next().await {
match chunk_result {
Ok(chunk) => {
// Parse SSE lines for usage tracking
if let Ok(text) = std::str::from_utf8(&chunk) {
let mut capture = usage_capture_clone.lock().await;
for line in text.lines() {
capture.parse_sse_line(line);
}
}
// Forward to bounded channel — if full, this applies backpressure
if tx.send(Ok(chunk)).await.is_err() {
tracing::debug!("SSE relay: client disconnected, stopping upstream");
break;
}
}
Err(e) => {
let err_msg = e.to_string();
if tx.send(Err(std::io::Error::other(e))).await.is_err() {
tracing::debug!("SSE relay: client disconnected before error sent: {}", err_msg);
}
break;
}
}
}
});
// Build StreamBridge: wraps the bounded receiver with heartbeat,
// timeout, and delayed cleanup (DeerFlow-inspired backpressure).
let body = build_stream_bridge(rx, task_id.to_string());
// Record usage + key consumption asynchronously after the SSE stream ends.
// A global Arc<Semaphore> caps the number of concurrent spawned tasks so that high concurrency cannot exhaust the connection pool.
static SSE_SPAWN_SEMAPHORE: std::sync::OnceLock<Arc<tokio::sync::Semaphore>> = std::sync::OnceLock::new();
let semaphore = SSE_SPAWN_SEMAPHORE.get_or_init(|| Arc::new(tokio::sync::Semaphore::new(64)));
let permit = match semaphore.clone().try_acquire_owned() {
Ok(p) => p,
Err(_) => {
// When the semaphore is full, skip usage recording; the stream itself is unaffected
tracing::warn!("SSE usage spawn at capacity, skipping usage record for task {}", task_id);
return Ok(RelayResponse::Sse(body));
}
};
tokio::spawn(async move {
let _permit = permit; // hold the permit until this task completes
// Brief delay to allow SSE stream to settle before recording
tokio::time::sleep(std::time::Duration::from_millis(500)).await;
let capture = usage_capture.lock().await;
let (input, output) = (
if capture.input_tokens > 0 { Some(capture.input_tokens) } else { None },
if capture.output_tokens > 0 { Some(capture.output_tokens) } else { None },
);
// Record task status with timeout to avoid holding DB connections
let db_op = async {
if let Err(e) = update_task_status(&db_clone, &task_id_clone, "completed", input, output, None).await {
tracing::warn!("Failed to update task status after SSE stream: {}", e);
}
// Record key usage (now 2 queries instead of 3)
let total_tokens = input.unwrap_or(0) + output.unwrap_or(0);
if let Err(e) = super::key_pool::record_key_usage(&db_clone, &key_id_for_spawn, Some(total_tokens)).await {
tracing::warn!("Failed to record key usage: {}", e);
}
};
if tokio::time::timeout(std::time::Duration::from_secs(5), db_op).await.is_err() {
tracing::warn!("SSE usage recording timed out for task {}", task_id_clone);
}
// StreamBridge delayed cleanup: release leftover resources STREAMBRIDGE_CLEANUP_DELAY (5s) after the stream ends
// (mainly Arc<SseUsageCapture>; dropping `_permit` returns the semaphore slot)
tokio::time::sleep(STREAMBRIDGE_CLEANUP_DELAY).await;
tracing::debug!(
"[StreamBridge] Cleanup delay elapsed for task {}",
task_id_clone
);
});
return Ok(RelayResponse::Sse(body));
} else {
let body = resp.text().await.unwrap_or_default();
let (input_tokens, output_tokens) = extract_token_usage(&body);
if manage_task_status {
update_task_status(db, task_id, "completed",
Some(input_tokens), Some(output_tokens), None).await?;
}
// Record key usage (failures are only logged and never block the response)
if let Err(e) = super::key_pool::record_key_usage(
db, &key_id, Some(input_tokens + output_tokens),
).await {
tracing::warn!("[Relay] Failed to record key usage for billing: {}", e);
}
return Ok(RelayResponse::Json(body));
}
}
Ok(resp) => {
let status = resp.status().as_u16();
if status == 429 {
// Parse the Retry-After header
let retry_after = resp.headers()
.get("retry-after")
.and_then(|v| v.to_str().ok())
.and_then(|v| v.parse::<u64>().ok());
// Mark the key as cooling down due to 429
if let Err(e) = super::key_pool::mark_key_429(db, &key_id, retry_after).await {
tracing::warn!("Failed to mark key 429: {}", e);
}
// Force key re-selection on the next iteration
current_key_id = None;
current_api_key = None;
if attempt + 1 >= max_attempts {
let err_msg = format!(
"Key Pool 轮转耗尽 ({} attempts),所有 Key 均被限流",
max_attempts
);
if manage_task_status {
update_task_status(db, task_id, "failed", None, None, Some(&err_msg)).await?;
}
return Err(SaasError::RateLimited(err_msg));
}
tracing::warn!(
"Relay task {} 收到 429Key {} 已标记冷却 (attempt {}/{})",
task_id, key_id, attempt + 1, max_attempts
);
// On 429, switch keys and retry immediately, without a backoff delay
continue;
}
if !is_retryable_status(status) || attempt + 1 >= max_attempts {
let body = resp.text().await.unwrap_or_default();
let err_msg = format!("上游返回 HTTP {}: {}", status, &body[..body.len().min(500)]);
if manage_task_status {
update_task_status(db, task_id, "failed", None, None, Some(&err_msg)).await?;
}
return Err(SaasError::Relay(err_msg));
}
tracing::warn!(
"Relay task {} 可重试错误 HTTP {} (attempt {}/{})",
task_id, status, attempt + 1, max_attempts
);
}
Err(e) => {
if !is_retryable_error(&e) || attempt + 1 >= max_attempts {
let err_msg = format!("请求上游失败: {}", e);
if manage_task_status {
update_task_status(db, task_id, "failed", None, None, Some(&err_msg)).await?;
}
return Err(SaasError::Relay(err_msg));
}
tracing::warn!(
"Relay task {} 网络错误 (attempt {}/{}): {}",
task_id, attempt + 1, max_attempts, e
);
}
}
// Non-429 errors use exponential backoff
let delay_ms = base_delay_ms * (1 << attempt);
tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
}
Err(SaasError::Relay("重试次数已耗尽".into()))
}
// ============ Cross-Provider Failover ============
/// Cross-provider failover executor.
///
/// Sorts candidate models by remaining quota, then tries each provider's key pool in turn
/// until a usable provider is found or all candidates are exhausted.
///
/// **Note**: failover only applies to pre-stream failures (connection errors, 429/5xx before the stream starts).
/// Once SSE streaming has begun, a mid-stream upstream disconnect does not trigger failover; this is an inherent limitation of the SSE protocol.
///
/// Returns (RelayResponse, actual_provider_id, actual_model_id) for precise billing attribution.
pub async fn execute_relay_with_failover(
db: &PgPool,
task_id: &str,
candidates: &[CandidateModel],
request_body: &str,
stream: bool,
max_attempts_per_provider: u32,
base_delay_ms: u64,
enc_key: &[u8; 32],
) -> SaasResult<(RelayResponse, String, String)> {
let mut last_error: Option<SaasError> = None;
let failover_start = std::time::Instant::now();
const FAILOVER_TIMEOUT: Duration = Duration::from_secs(60);
// C-3: the outer layer manages task status uniformly; set "processing" only once
update_task_status(db, task_id, "processing", None, None, None).await?;
for (idx, candidate) in candidates.iter().enumerate() {
// M-3: timeout-budget check to keep cascading failures from piling up
if failover_start.elapsed() >= FAILOVER_TIMEOUT {
tracing::warn!(
"Failover timeout ({:?}) exceeded after {}/{} candidates for task {}",
FAILOVER_TIMEOUT, idx, candidates.len(), task_id
);
break;
}
// Replace the model field in the request body with the current candidate's physical model ID
let patched_body = patch_model_in_body(request_body, &candidate.model_id);
match execute_relay(
db,
task_id,
&candidate.provider_id,
&candidate.base_url,
&patched_body,
stream,
max_attempts_per_provider,
base_delay_ms,
enc_key,
false, // C-3: task status is managed by the outer layer
)
.await
{
Ok(response) => {
if idx > 0 {
tracing::info!(
"Failover succeeded on candidate {}/{} (provider={}, model={})",
idx + 1,
candidates.len(),
candidate.provider_id,
candidate.model_id
);
}
return Ok((response, candidate.provider_id.clone(), candidate.model_id.clone()));
}
Err(SaasError::RateLimited(msg)) => {
tracing::warn!(
"Provider {} rate limited ({}), trying next candidate ({}/{})",
candidate.provider_id,
msg,
idx + 1,
candidates.len()
);
last_error = Some(SaasError::RateLimited(msg));
continue;
}
Err(e) => {
tracing::warn!(
"Provider {} failed: {}, trying next candidate ({}/{})",
candidate.provider_id,
e,
idx + 1,
candidates.len()
);
last_error = Some(e);
continue;
}
}
}
// C-3: all candidates failed; the outer layer marks the task as "failed"
let final_error = last_error.unwrap_or_else(|| SaasError::RateLimited(
"所有候选 Provider 均不可用".into(),
));
let err_msg = format!("{}", final_error);
if let Err(e) = update_task_status(db, task_id, "failed", None, None, Some(&err_msg)).await {
tracing::warn!("Failed to update task {} status after failover exhaustion: {}", task_id, e);
}
Err(final_error)
}
/// Replaces the "model" field in the JSON body with the current candidate's physical model ID.
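///
/// Illustrative only: `{"model":"gpt-x","stream":true}` patched with candidate `kimi-k2`
/// becomes `{"model":"kimi-k2","stream":true}`; a body that is not a JSON object is returned unchanged.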
fn patch_model_in_body(body: &str, new_model_id: &str) -> String {
if let Ok(mut parsed) = serde_json::from_str::<serde_json::Value>(body) {
if let Some(obj) = parsed.as_object_mut() {
obj.insert(
"model".to_string(),
serde_json::Value::String(new_model_id.to_string()),
);
}
serde_json::to_string(&parsed).unwrap_or_else(|_| body.to_string())
} else {
body.to_string()
}
}
/// Sorts candidate models by remaining quota.
///
/// Queries the current RPM headroom of each candidate provider's key pool and puts the provider with the most headroom first.
/// Reuses the live data in the key_usage_window table with a single aggregate query.
/// An in-memory cache (TTL 5s) keeps DB query frequency down.
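///
/// Illustrative only: with remaining RPM of, say, provider_a = 500 and provider_b = 120,
/// the order becomes [provider_a, provider_b]; a provider absent from the quota map sorts
/// first because it is treated as having full headroom (999999).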
pub async fn sort_candidates_by_quota(
db: &PgPool,
candidates: &mut [CandidateModel],
) {
if candidates.len() <= 1 {
return;
}
let provider_ids: Vec<String> = candidates.iter().map(|c| c.provider_id.clone()).collect();
// H-4: quota-sorting cache (TTL 5s) to reduce DB queries on the critical path
static QUOTA_CACHE: OnceLock<std::sync::Mutex<HashMap<String, (i64, std::time::Instant)>>> = OnceLock::new();
let cache = QUOTA_CACHE.get_or_init(|| std::sync::Mutex::new(HashMap::new()));
const QUOTA_CACHE_TTL: Duration = Duration::from_secs(5);
let now = std::time::Instant::now();
// Clone the cached values and release the lock immediately, so the MutexGuard never lives across an await
let cached_entries: HashMap<String, (i64, std::time::Instant)> = {
let guard = cache.lock().unwrap_or_else(|e| e.into_inner());
guard.clone()
};
let all_fresh = provider_ids.iter().all(|pid| {
cached_entries.get(pid)
.map(|(_, ts)| now.duration_since(*ts) < QUOTA_CACHE_TTL)
.unwrap_or(false)
});
let quota_map: HashMap<String, i64> = if all_fresh {
provider_ids.iter()
.filter_map(|pid| cached_entries.get(pid).map(|(remaining, _)| (pid.clone(), *remaining)))
.collect()
} else {
let quota_rows: Vec<(String, i64)> = match sqlx::query_as(
r#"
SELECT pk.provider_id,
SUM(COALESCE(pk.max_rpm, 999999) - COALESCE(uw.request_count, 0)) AS remaining_rpm
FROM provider_keys pk
LEFT JOIN key_usage_window uw ON pk.id = uw.key_id
AND uw.window_minute >= to_char(NOW() - INTERVAL '1 minute', 'YYYY-MM-DDTHH24:MI')
WHERE pk.provider_id = ANY($1)
AND pk.is_active = TRUE
AND (pk.cooldown_until IS NULL OR pk.cooldown_until::timestamptz <= NOW())
GROUP BY pk.provider_id
"#,
)
.bind(&provider_ids)
.fetch_all(db)
.await
{
Ok(rows) => rows,
Err(e) => {
// M-6: on DB query failure, log a warning and keep the original order
tracing::warn!("sort_candidates_by_quota DB query failed: {}", e);
return;
}
};
let map: HashMap<String, i64> = quota_rows.into_iter().collect();
// Update the cache and prune expired entries
{
let mut cache_guard = cache.lock().unwrap_or_else(|e| e.into_inner());
for (pid, remaining) in &map {
cache_guard.insert(pid.clone(), (*remaining, now));
}
// M-S3: drop entries older than 5x the TTL (25s) so entries for deleted providers do not linger forever
let ttl_5x = QUOTA_CACHE_TTL * 5;
cache_guard.retain(|_, (_, ts)| now.saturating_duration_since(*ts) < ttl_5x);
}
map
};
// H-1: a new provider has no usage records yet → unwrap_or(999999) means full headroom
candidates.sort_by(|a, b| {
let qa = quota_map.get(&a.provider_id).copied().unwrap_or(999999);
let qb = quota_map.get(&b.provider_id).copied().unwrap_or(999999);
qb.cmp(&qa) // descending: more remaining quota sorts first
});
}
/// Relay response type
#[derive(Debug)]
pub enum RelayResponse {
Json(String),
Sse(axum::body::Body),
}
// ============ StreamBridge ============
/// Builds the StreamBridge: wraps an mpsc::Receiver into an axum Body with heartbeats and a timeout.
///
/// Backpressure mechanism inspired by DeerFlow's StreamBridge:
/// - 15s heartbeat: when the upstream is silent, send the SSE comment line `: heartbeat\n\n` to keep the connection alive
/// - 180s timeout: after 180s with no real upstream data, emit a timeout event and close the stream
/// - 5s delayed cleanup: the caller's spawned task releases leftover resources shortly after the stream ends
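///
/// Illustrative idle output (heartbeat comments only, then the timeout event):
///
/// ```text
/// : heartbeat
///
/// : heartbeat
///
/// data: {"error":"stream_timeout","message":"upstream timed out"}
/// ```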
fn build_stream_bridge(
mut rx: tokio::sync::mpsc::Receiver<Result<bytes::Bytes, std::io::Error>>,
task_id: String,
) -> axum::body::Body {
// SSE heartbeat comment bytes: `: heartbeat\n\n`
// SSE spec: lines starting with `:` are comments and ignored by clients
const HEARTBEAT_BYTES: &[u8] = b": heartbeat\n\n";
// SSE timeout error event
const TIMEOUT_EVENT: &[u8] = b"data: {\"error\":\"stream_timeout\",\"message\":\"upstream timed out\"}\n\n";
let stream = async_stream::stream! {
// Track how many consecutive heartbeat-only cycles have elapsed.
// Real data resets this counter; after 12 heartbeats (180s) without
// real data, we terminate the stream.
let mut idle_heartbeats: u32 = 0;
loop {
// tokio::select! races the next data chunk against a heartbeat timer.
// The timer resets on every iteration, ensuring heartbeats only fire
// during genuine idle periods.
tokio::select! {
biased; // prioritize data over heartbeat
chunk = rx.recv() => {
match chunk {
Some(Ok(data)) => {
// Real data received — reset idle counter
idle_heartbeats = 0;
yield Ok::<bytes::Bytes, std::io::Error>(data);
}
Some(Err(e)) => {
tracing::warn!(
"[StreamBridge] Upstream error for task {}: {}",
task_id, e
);
yield Err(e);
break;
}
None => {
// Channel closed = upstream finished normally
tracing::debug!(
"[StreamBridge] Upstream completed for task {}",
task_id
);
break;
}
}
}
// Heartbeat: send SSE comment if no data for 15s
_ = tokio::time::sleep(STREAMBRIDGE_HEARTBEAT_INTERVAL) => {
idle_heartbeats += 1;
tracing::trace!(
"[StreamBridge] Heartbeat #{} for task {} (idle {}s)",
idle_heartbeats,
task_id,
idle_heartbeats as u64 * STREAMBRIDGE_HEARTBEAT_INTERVAL.as_secs(),
);
// After 12 consecutive heartbeats without real data (180s),
// terminate the stream to prevent connection leaks.
if idle_heartbeats >= 12 {
tracing::warn!(
"[StreamBridge] Timeout ({:?}) no real data, closing stream for task {}",
STREAMBRIDGE_TIMEOUT,
task_id,
);
yield Ok(bytes::Bytes::from_static(TIMEOUT_EVENT));
break;
}
yield Ok(bytes::Bytes::from_static(HEARTBEAT_BYTES));
}
}
}
};
// Pin the stream to a Box<dyn Stream + Send> to satisfy Body::from_stream
let boxed: std::pin::Pin<Box<dyn futures::Stream<Item = Result<bytes::Bytes, std::io::Error>> + Send>> =
Box::pin(stream);
axum::body::Body::from_stream(boxed)
}
// ============ Helpers ============
fn hash_request(body: &str) -> String {
use sha2::{Sha256, Digest};
hex::encode(Sha256::digest(body.as_bytes()))
}
/// Extracts token usage from a JSON response
fn extract_token_usage(body: &str) -> (i64, i64) {
let parsed: serde_json::Value = match serde_json::from_str(body) {
Ok(v) => v,
Err(e) => {
tracing::debug!("extract_token_usage: JSON parse failed (body len={}): {}", body.len(), e);
return (0, 0);
}
};
let usage = parsed.get("usage");
let input = usage
.and_then(|u| u.get("prompt_tokens"))
.and_then(|v| v.as_i64())
.unwrap_or(0);
let output = usage
.and_then(|u| u.get("completion_tokens"))
.and_then(|v| v.as_i64())
.unwrap_or(0);
(input, output)
}
/// Extracts token usage from a JSON response (public wrapper)
pub fn extract_token_usage_from_json(body: &str) -> (i64, i64) {
extract_token_usage(body)
}
/// SSRF protection: validates that a provider URL does not point at the internal network
async fn validate_provider_url(url: &str) -> SaasResult<()> {
let parsed: url::Url = url.parse().map_err(|_| {
SaasError::InvalidInput(format!("无效的 provider URL: {}", url))
})?;
// Only https is allowed
match parsed.scheme() {
"https" => {}
"http" => {
// http is allowed in development environments
let is_dev = std::env::var("ZCLAW_SAAS_DEV")
.map(|v| v == "true" || v == "1")
.unwrap_or(false);
if !is_dev {
return Err(SaasError::InvalidInput("生产环境禁止 http scheme请使用 https".into()));
}
}
_ => return Err(SaasError::InvalidInput(format!("不允许的 URL scheme: {}", parsed.scheme()))),
}
// Block internal/private addresses
let host = match parsed.host_str() {
Some(h) => h,
None => return Err(SaasError::InvalidInput("provider URL 缺少 host".into())),
};
// Strip IPv6 brackets
let host = host.trim_start_matches('[').trim_end_matches(']');
// Exact-match blocklist: only hostnames and special domains.
// Private IP ranges (10.x, 172.16-31.x, 192.168.x, 127.x, 169.254.x, ::1, etc.)
// are handled by is_private_ip() and do not need to be repeated here.
let blocked_exact = [
"localhost", "metadata.google.internal",
];
if blocked_exact.contains(&host) {
return Err(SaasError::InvalidInput(format!("provider URL 指向禁止的内网地址: {}", host)));
}
// Suffix match (blocks subdomains)
let blocked_suffixes = ["localhost", "internal", "local", "localhost.localdomain"];
for suffix in &blocked_suffixes {
if host.ends_with(&format!(".{}", suffix)) {
return Err(SaasError::InvalidInput(format!("provider URL 指向禁止的内网地址: {}", host)));
}
}
// Block purely numeric hosts (possible decimal IP notation, e.g. 2130706433 = 127.0.0.1)
if host.parse::<u64>().is_ok() {
return Err(SaasError::InvalidInput(format!("provider URL 使用了不允许的 IP 格式: {}", host)));
}
// Block hex/octal IP obfuscation (e.g. 0x7f000001, 0177.0.0.1)
if host.chars().all(|c| c.is_ascii_hexdigit() || c == '.' || c == ':' || c == 'x' || c == 'X') {
return Err(SaasError::InvalidInput(format!("provider URL 使用了不允许的 IP 格式: {}", host)));
}
// Block private IPv4 ranges (by parsing the IP literal)
if let Ok(ip) = host.parse::<std::net::IpAddr>() {
if is_private_ip(&ip) {
return Err(SaasError::InvalidInput(format!("provider URL 指向私有 IP 地址: {}", host)));
}
return Ok(());
}
// Resolve domain names asynchronously via DNS and check whether any result points at the internal network
let addr_str = format!("{}:0", host);
match tokio::net::lookup_host(&*addr_str).await {
Ok(addrs) => {
for sockaddr in addrs {
if is_private_ip(&sockaddr.ip()) {
return Err(SaasError::InvalidInput(
"provider URL 域名解析到内网地址".into()
));
}
}
}
Err(_) => {
// DNS resolution failed; possibly an invalid domain, so do not block the request
}
}
Ok(())
}
/// Returns whether an IP address is in a private/internal range
fn is_private_ip(ip: &std::net::IpAddr) -> bool {
match ip {
std::net::IpAddr::V4(v4) => {
let octets = v4.octets();
// 10.0.0.0/8
octets[0] == 10
// 172.16.0.0/12
|| (octets[0] == 172 && octets[1] >= 16 && octets[1] <= 31)
// 192.168.0.0/16
|| (octets[0] == 192 && octets[1] == 168)
// 127.0.0.0/8 (loopback)
|| octets[0] == 127
// 169.254.0.0/16 (link-local)
|| (octets[0] == 169 && octets[1] == 254)
// 0.0.0.0/8
|| octets[0] == 0
}
std::net::IpAddr::V6(v6) => {
// ::1 (loopback)
v6.is_loopback()
// ::ffff:x.x.x.x (IPv6-mapped IPv4)
|| v6.to_ipv4_mapped().map_or(false, |v4| is_private_ip(&std::net::IpAddr::V4(v4)))
// fe80::/10 (link-local)
|| (v6.segments()[0] & 0xffc0) == 0xfe80
}
}
}