//! 中转服务核心逻辑 use sqlx::PgPool; use std::collections::HashMap; use std::sync::Arc; use std::sync::OnceLock; use std::time::Duration; use tokio::sync::Mutex; use crate::error::{SaasError, SaasResult}; use crate::models::RelayTaskRow; use super::types::*; // ============ StreamBridge 背压常量 ============ /// 上游无数据时,发送 SSE 心跳注释行的间隔 const STREAMBRIDGE_HEARTBEAT_INTERVAL: Duration = Duration::from_secs(15); /// 上游无数据时,丢弃连接的超时阈值(180s = 12 个心跳) /// 实测 Kimi for Coding 的 thinking→content 间隔可达 60s+,需要更宽容的超时。 const STREAMBRIDGE_TIMEOUT: Duration = Duration::from_secs(180); /// 流结束后延迟清理的时间窗口(缩短到 5s,仅用于 Arc 引用释放) const STREAMBRIDGE_CLEANUP_DELAY: Duration = Duration::from_secs(5); /// 判断 HTTP 状态码是否为可重试的瞬态错误 (5xx + 429) fn is_retryable_status(status: u16) -> bool { status == 429 || (500..600).contains(&status) } /// 判断 reqwest 错误是否为可重试的网络错误 fn is_retryable_error(e: &reqwest::Error) -> bool { e.is_timeout() || e.is_connect() || e.is_request() } // ============ Relay Task Management ============ /// 判断 sqlx 错误是否为可重试的瞬态错误(连接池耗尽、临时网络故障) fn is_transient_db_error(e: &sqlx::Error) -> bool { matches!(e, sqlx::Error::PoolTimedOut | sqlx::Error::Io(_)) } pub async fn create_relay_task( db: &PgPool, account_id: &str, provider_id: &str, model_id: &str, request_body: &str, priority: i32, max_attempts: u32, ) -> SaasResult { let id = uuid::Uuid::new_v4().to_string(); let now = chrono::Utc::now(); let request_hash = hash_request(request_body); let max_attempts = max_attempts.max(1).min(5); let query = sqlx::query_as::<_, RelayTaskRow>( "INSERT INTO relay_tasks (id, account_id, provider_id, model_id, request_hash, request_body, status, priority, attempt_count, max_attempts, queued_at, created_at) VALUES ($1, $2, $3, $4, $5, $6, 'queued', $7, 0, $8, $9, $9) RETURNING id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at::TEXT, started_at::TEXT, completed_at::TEXT, created_at::TEXT" ) .bind(&id).bind(account_id).bind(provider_id).bind(model_id) .bind(&request_hash).bind(request_body).bind(priority).bind(max_attempts as i64).bind(&now); // 对瞬时 DB 错误(连接池耗尽/超时)重试一次 let row = match query.fetch_one(db).await { Ok(row) => row, Err(e) if is_transient_db_error(&e) => { tracing::warn!("Transient DB error in create_relay_task, retrying: {}", e); tokio::time::sleep(Duration::from_millis(200)).await; sqlx::query_as::<_, RelayTaskRow>( "INSERT INTO relay_tasks (id, account_id, provider_id, model_id, request_hash, request_body, status, priority, attempt_count, max_attempts, queued_at, created_at) VALUES ($1, $2, $3, $4, $5, $6, 'queued', $7, 0, $8, $9, $9) RETURNING id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at::TEXT, started_at::TEXT, completed_at::TEXT, created_at::TEXT" ) .bind(&id).bind(account_id).bind(provider_id).bind(model_id) .bind(&request_hash).bind(request_body).bind(priority).bind(max_attempts as i64).bind(&now) .fetch_one(db) .await? } Err(e) => return Err(e.into()), }; Ok(RelayTaskInfo { id: row.id, account_id: row.account_id, provider_id: row.provider_id, model_id: row.model_id, status: row.status, priority: row.priority, attempt_count: row.attempt_count, max_attempts: row.max_attempts, input_tokens: row.input_tokens, output_tokens: row.output_tokens, error_message: row.error_message, queued_at: row.queued_at, started_at: row.started_at, completed_at: row.completed_at, created_at: row.created_at, }) } pub async fn get_relay_task(db: &PgPool, task_id: &str) -> SaasResult { let row: Option = sqlx::query_as( "SELECT id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at::TEXT, started_at::TEXT, completed_at::TEXT, created_at::TEXT FROM relay_tasks WHERE id = $1" ) .bind(task_id) .fetch_optional(db) .await?; let r = row.ok_or_else(|| SaasError::NotFound(format!("中转任务 {} 不存在", task_id)))?; Ok(RelayTaskInfo { id: r.id, account_id: r.account_id, provider_id: r.provider_id, model_id: r.model_id, status: r.status, priority: r.priority, attempt_count: r.attempt_count, max_attempts: r.max_attempts, input_tokens: r.input_tokens, output_tokens: r.output_tokens, error_message: r.error_message, queued_at: r.queued_at, started_at: r.started_at, completed_at: r.completed_at, created_at: r.created_at, }) } pub async fn list_relay_tasks( db: &PgPool, account_id: &str, query: &RelayTaskQuery, ) -> SaasResult> { let page = query.page.unwrap_or(1).max(1) as u32; let page_size = query.page_size.unwrap_or(20).min(100) as u32; let offset = ((page - 1) * page_size) as i64; let (count_sql, data_sql) = if query.status.is_some() { ( "SELECT COUNT(*) FROM relay_tasks WHERE account_id = $1 AND status = $2", "SELECT id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at::TEXT, started_at::TEXT, completed_at::TEXT, created_at::TEXT FROM relay_tasks WHERE account_id = $1 AND status = $2 ORDER BY created_at DESC LIMIT $3 OFFSET $4" ) } else { ( "SELECT COUNT(*) FROM relay_tasks WHERE account_id = $1", "SELECT id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at::TEXT, started_at::TEXT, completed_at::TEXT, created_at::TEXT FROM relay_tasks WHERE account_id = $1 ORDER BY created_at DESC LIMIT $2 OFFSET $3" ) }; let total: i64 = if query.status.is_some() { sqlx::query_scalar(count_sql).bind(account_id).bind(query.status.as_ref().unwrap()).fetch_one(db).await? } else { sqlx::query_scalar(count_sql).bind(account_id).fetch_one(db).await? }; let mut query_builder = sqlx::query_as::<_, RelayTaskRow>(data_sql) .bind(account_id); if let Some(ref status) = query.status { query_builder = query_builder.bind(status); } let rows = query_builder.bind(page_size as i64).bind(offset).fetch_all(db).await?; let items: Vec = rows.into_iter().map(|r| { RelayTaskInfo { id: r.id, account_id: r.account_id, provider_id: r.provider_id, model_id: r.model_id, status: r.status, priority: r.priority, attempt_count: r.attempt_count, max_attempts: r.max_attempts, input_tokens: r.input_tokens, output_tokens: r.output_tokens, error_message: r.error_message, queued_at: r.queued_at, started_at: r.started_at, completed_at: r.completed_at, created_at: r.created_at } }).collect(); Ok(crate::common::PaginatedResponse { items, total, page, page_size }) } pub async fn update_task_status( db: &PgPool, task_id: &str, status: &str, input_tokens: Option, output_tokens: Option, error_message: Option<&str>, ) -> SaasResult<()> { let now = chrono::Utc::now(); match status { "processing" => { sqlx::query( "UPDATE relay_tasks SET started_at = $1, status = 'processing', attempt_count = attempt_count + 1 WHERE id = $2" ) .bind(&now).bind(task_id) .execute(db).await?; } "completed" => { sqlx::query( "UPDATE relay_tasks SET completed_at = $1, status = 'completed', input_tokens = COALESCE($2, input_tokens), output_tokens = COALESCE($3, output_tokens) WHERE id = $4" ) .bind(&now).bind(input_tokens).bind(output_tokens).bind(task_id) .execute(db).await?; } "failed" => { sqlx::query( "UPDATE relay_tasks SET completed_at = $1, status = 'failed', error_message = $2 WHERE id = $3" ) .bind(&now).bind(error_message).bind(task_id) .execute(db).await?; } _ => return Err(SaasError::InvalidInput(format!("无效任务状态: {}", status))), } Ok(()) } // ============ Relay Execution ============ /// SSE 流中的 usage 信息捕获器 #[derive(Debug, Clone, Default)] struct SseUsageCapture { input_tokens: i64, output_tokens: i64, } impl SseUsageCapture { fn parse_sse_line(&mut self, line: &str) { if let Some(data) = line.strip_prefix("data: ") { if data == "[DONE]" { return; } if let Ok(parsed) = serde_json::from_str::(data) { if let Some(usage) = parsed.get("usage") { if let Some(input) = usage.get("prompt_tokens").and_then(|v| v.as_i64()) { self.input_tokens = input; } if let Some(output) = usage.get("completion_tokens").and_then(|v| v.as_i64()) { self.output_tokens = output; } } } } } } pub async fn execute_relay( db: &PgPool, task_id: &str, provider_id: &str, provider_base_url: &str, request_body: &str, stream: bool, max_attempts: u32, base_delay_ms: u64, enc_key: &[u8; 32], // 当由 `execute_relay_with_failover` 调用时为 false,由外层统一管理 task 状态; // 独立调用时为 true,由本函数管理 task 状态。 manage_task_status: bool, ) -> SaasResult { validate_provider_url(provider_base_url).await?; let url = format!("{}/chat/completions", provider_base_url.trim_end_matches('/')); // 复用全局 HTTP 客户端,避免每次请求重建 TLS 连接池和 DNS 解析器 static SHORT_CLIENT: std::sync::OnceLock = std::sync::OnceLock::new(); static LONG_CLIENT: std::sync::OnceLock = std::sync::OnceLock::new(); let client = if stream { LONG_CLIENT.get_or_init(|| { reqwest::Client::builder() .timeout(std::time::Duration::from_secs(300)) .build() .expect("Failed to build long-timeout HTTP client") }) } else { SHORT_CLIENT.get_or_init(|| { reqwest::Client::builder() .timeout(std::time::Duration::from_secs(30)) .build() .expect("Failed to build short-timeout HTTP client") }) }; let max_attempts = max_attempts.max(1).min(5); // Key Pool 轮转状态 let mut current_key_id: Option = None; let mut current_api_key: Option = None; for attempt in 0..max_attempts { let is_first = attempt == 0; if is_first && manage_task_status { update_task_status(db, task_id, "processing", None, None, None).await?; } // 首次或 429 后需要重新选择 Key if current_key_id.is_none() { match super::key_pool::select_best_key(db, provider_id, enc_key).await { Ok(selection) => { let key_id = selection.key_id.clone(); let key_value = selection.key.key_value.clone(); tracing::debug!( "Relay task {} 选择 Key {} (attempt {})", task_id, key_id, attempt + 1 ); current_key_id = Some(key_id); current_api_key = Some(key_value); } Err(SaasError::RateLimited(msg)) => { // 所有 Key 均在冷却中 let err_msg = format!("Key Pool 耗尽: {}", msg); if manage_task_status { update_task_status(db, task_id, "failed", None, None, Some(&err_msg)).await?; } return Err(SaasError::RateLimited(msg)); } Err(e) => return Err(e), } } let key_id = current_key_id.as_ref() .ok_or_else(|| SaasError::Internal("Key pool selection failed: no key_id".into()))? .clone(); let api_key = current_api_key.clone(); let mut req_builder = client.post(&url) .header("Content-Type", "application/json") // Kimi Coding Plan 等 Coding Agent API 需要识别 User-Agent 为 coding agent .header("User-Agent", "claude-code/1.0") .body(request_body.to_string()); if let Some(ref key) = api_key { req_builder = req_builder.header("Authorization", format!("Bearer {}", key)); } let result = req_builder.send().await; match result { Ok(resp) if resp.status().is_success() => { if stream { let usage_capture = Arc::new(Mutex::new(SseUsageCapture::default())); let usage_capture_clone = usage_capture.clone(); let db_clone = db.clone(); let task_id_clone = task_id.to_string(); let key_id_for_spawn = key_id.clone(); // Bounded channel for backpressure: 128 chunks (~128KB) buffer. // If the client reads slowly, the upstream is signaled via // backpressure instead of growing memory indefinitely. let (tx, rx) = tokio::sync::mpsc::channel::>(128); // Spawn a task to consume the upstream stream and forward through the bounded channel tokio::spawn(async move { use futures::StreamExt; let mut upstream = resp.bytes_stream(); while let Some(chunk_result) = upstream.next().await { match chunk_result { Ok(chunk) => { // Parse SSE lines for usage tracking if let Ok(text) = std::str::from_utf8(&chunk) { let mut capture = usage_capture_clone.lock().await; for line in text.lines() { capture.parse_sse_line(line); } } // Forward to bounded channel — if full, this applies backpressure if tx.send(Ok(chunk)).await.is_err() { tracing::debug!("SSE relay: client disconnected, stopping upstream"); break; } } Err(e) => { let err_msg = e.to_string(); if tx.send(Err(std::io::Error::other(e))).await.is_err() { tracing::debug!("SSE relay: client disconnected before error sent: {}", err_msg); } break; } } } }); // Build StreamBridge: wraps the bounded receiver with heartbeat, // timeout, and delayed cleanup (DeerFlow-inspired backpressure). let body = build_stream_bridge(rx, task_id.to_string()); // SSE 流结束后异步记录 usage + Key 使用量 // 使用全局 Arc 限制并发 spawned tasks,防止高并发时耗尽连接池 static SSE_SPAWN_SEMAPHORE: std::sync::OnceLock> = std::sync::OnceLock::new(); let semaphore = SSE_SPAWN_SEMAPHORE.get_or_init(|| Arc::new(tokio::sync::Semaphore::new(64))); let permit = match semaphore.clone().try_acquire_owned() { Ok(p) => p, Err(_) => { // 信号量满时跳过 usage 记录,流本身不受影响 tracing::warn!("SSE usage spawn at capacity, skipping usage record for task {}", task_id); return Ok(RelayResponse::Sse(body)); } }; tokio::spawn(async move { let _permit = permit; // 持有 permit 直到任务完成 // Brief delay to allow SSE stream to settle before recording tokio::time::sleep(std::time::Duration::from_millis(500)).await; let capture = usage_capture.lock().await; let (input, output) = ( if capture.input_tokens > 0 { Some(capture.input_tokens) } else { None }, if capture.output_tokens > 0 { Some(capture.output_tokens) } else { None }, ); // Record task status with timeout to avoid holding DB connections let db_op = async { if let Err(e) = update_task_status(&db_clone, &task_id_clone, "completed", input, output, None).await { tracing::warn!("Failed to update task status after SSE stream: {}", e); } // Record key usage (now 2 queries instead of 3) let total_tokens = input.unwrap_or(0) + output.unwrap_or(0); if let Err(e) = super::key_pool::record_key_usage(&db_clone, &key_id_for_spawn, Some(total_tokens)).await { tracing::warn!("Failed to record key usage: {}", e); } }; if tokio::time::timeout(std::time::Duration::from_secs(5), db_op).await.is_err() { tracing::warn!("SSE usage recording timed out for task {}", task_id_clone); } // StreamBridge 延迟清理:流结束 60s 后释放残留资源 // (主要是 Arc 等,通过 drop(_permit) 归还信号量) tokio::time::sleep(STREAMBRIDGE_CLEANUP_DELAY).await; tracing::debug!( "[StreamBridge] Cleanup delay elapsed for task {}", task_id_clone ); }); return Ok(RelayResponse::Sse(body)); } else { let body = resp.text().await.unwrap_or_default(); let (input_tokens, output_tokens) = extract_token_usage(&body); if manage_task_status { update_task_status(db, task_id, "completed", Some(input_tokens), Some(output_tokens), None).await?; } // 记录 Key 使用量(失败仅记录,不阻塞响应) if let Err(e) = super::key_pool::record_key_usage( db, &key_id, Some(input_tokens + output_tokens), ).await { tracing::warn!("[Relay] Failed to record key usage for billing: {}", e); } return Ok(RelayResponse::Json(body)); } } Ok(resp) => { let status = resp.status().as_u16(); if status == 429 { // 解析 Retry-After header let retry_after = resp.headers() .get("retry-after") .and_then(|v| v.to_str().ok()) .and_then(|v| v.parse::().ok()); // 标记 Key 为 429 冷却 if let Err(e) = super::key_pool::mark_key_429(db, &key_id, retry_after).await { tracing::warn!("Failed to mark key 429: {}", e); } // 强制下次迭代重新选择 Key current_key_id = None; current_api_key = None; if attempt + 1 >= max_attempts { let err_msg = format!( "Key Pool 轮转耗尽 ({} attempts),所有 Key 均被限流", max_attempts ); if manage_task_status { update_task_status(db, task_id, "failed", None, None, Some(&err_msg)).await?; } return Err(SaasError::RateLimited(err_msg)); } tracing::warn!( "Relay task {} 收到 429,Key {} 已标记冷却 (attempt {}/{})", task_id, key_id, attempt + 1, max_attempts ); // 429 时立即切换 Key 重试,不做退避延迟 continue; } if !is_retryable_status(status) || attempt + 1 >= max_attempts { let body = resp.text().await.unwrap_or_default(); let err_msg = format!("上游返回 HTTP {}: {}", status, &body[..body.len().min(500)]); if manage_task_status { update_task_status(db, task_id, "failed", None, None, Some(&err_msg)).await?; } return Err(SaasError::Relay(err_msg)); } tracing::warn!( "Relay task {} 可重试错误 HTTP {} (attempt {}/{})", task_id, status, attempt + 1, max_attempts ); } Err(e) => { if !is_retryable_error(&e) || attempt + 1 >= max_attempts { let err_msg = format!("请求上游失败: {}", e); if manage_task_status { update_task_status(db, task_id, "failed", None, None, Some(&err_msg)).await?; } return Err(SaasError::Relay(err_msg)); } tracing::warn!( "Relay task {} 网络错误 (attempt {}/{}): {}", task_id, attempt + 1, max_attempts, e ); } } // 非 429 错误使用指数退避 let delay_ms = base_delay_ms * (1 << attempt); tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await; } Err(SaasError::Relay("重试次数已耗尽".into())) } // ============ 跨 Provider Failover ============ /// 跨 Provider Failover 执行器 /// /// 按配额余量自动排序候选模型,依次尝试每个 Provider 的 Key Pool, /// 直到找到可用 Provider 或全部耗尽。 /// /// **注意**:Failover 仅适用于预流失败(连接错误、429/5xx 在流开始之前)。 /// SSE 一旦开始流式传输,中途上游断连不会触发 failover — 这是 SSE 协议的固有限制。 /// /// 返回 (RelayResponse, actual_provider_id, actual_model_id) 用于精确计费归因。 pub async fn execute_relay_with_failover( db: &PgPool, task_id: &str, candidates: &[CandidateModel], request_body: &str, stream: bool, max_attempts_per_provider: u32, base_delay_ms: u64, enc_key: &[u8; 32], ) -> SaasResult<(RelayResponse, String, String)> { let mut last_error: Option = None; let failover_start = std::time::Instant::now(); const FAILOVER_TIMEOUT: Duration = Duration::from_secs(60); // C-3: 外层统一管理 task 状态 — 仅设一次 "processing" update_task_status(db, task_id, "processing", None, None, None).await?; for (idx, candidate) in candidates.iter().enumerate() { // M-3: 超时预算检查 — 防止级联失败累积过长 if failover_start.elapsed() >= FAILOVER_TIMEOUT { tracing::warn!( "Failover timeout ({:?}) exceeded after {}/{} candidates for task {}", FAILOVER_TIMEOUT, idx, candidates.len(), task_id ); break; } // 替换请求体中的 model 字段为当前候选的物理模型 ID let patched_body = patch_model_in_body(request_body, &candidate.model_id); match execute_relay( db, task_id, &candidate.provider_id, &candidate.base_url, &patched_body, stream, max_attempts_per_provider, base_delay_ms, enc_key, false, // C-3: 外层管理 task 状态 ) .await { Ok(response) => { if idx > 0 { tracing::info!( "Failover succeeded on candidate {}/{} (provider={}, model={})", idx + 1, candidates.len(), candidate.provider_id, candidate.model_id ); } return Ok((response, candidate.provider_id.clone(), candidate.model_id.clone())); } Err(SaasError::RateLimited(msg)) => { tracing::warn!( "Provider {} rate limited ({}), trying next candidate ({}/{})", candidate.provider_id, msg, idx + 1, candidates.len() ); last_error = Some(SaasError::RateLimited(msg)); continue; } Err(e) => { tracing::warn!( "Provider {} failed: {}, trying next candidate ({}/{})", candidate.provider_id, e, idx + 1, candidates.len() ); last_error = Some(e); continue; } } } // C-3: 所有候选失败 — 外层统一标记 task 为 "failed" let final_error = last_error.unwrap_or_else(|| SaasError::RateLimited( "所有候选 Provider 均不可用".into(), )); let err_msg = format!("{}", final_error); if let Err(e) = update_task_status(db, task_id, "failed", None, None, Some(&err_msg)).await { tracing::warn!("Failed to update task {} status after failover exhaustion: {}", task_id, e); } Err(final_error) } /// 替换 JSON body 中的 "model" 字段为当前候选的物理模型 ID fn patch_model_in_body(body: &str, new_model_id: &str) -> String { if let Ok(mut parsed) = serde_json::from_str::(body) { if let Some(obj) = parsed.as_object_mut() { obj.insert( "model".to_string(), serde_json::Value::String(new_model_id.to_string()), ); } serde_json::to_string(&parsed).unwrap_or_else(|_| body.to_string()) } else { body.to_string() } } /// 按配额余量排序候选模型 /// /// 查询每个候选 Provider 的 Key Pool 当前 RPM 余量,余量最多的排前面。 /// 复用 key_usage_window 表的实时数据,仅执行一次聚合查询。 /// 使用内存缓存(TTL 5s)减少 DB 查询频率。 pub async fn sort_candidates_by_quota( db: &PgPool, candidates: &mut [CandidateModel], ) { if candidates.len() <= 1 { return; } let provider_ids: Vec = candidates.iter().map(|c| c.provider_id.clone()).collect(); // H-4: 配额排序缓存(TTL 5 秒),减少关键路径 DB 查询 static QUOTA_CACHE: OnceLock>> = OnceLock::new(); let cache = QUOTA_CACHE.get_or_init(|| std::sync::Mutex::new(HashMap::new())); const QUOTA_CACHE_TTL: Duration = Duration::from_secs(5); let now = std::time::Instant::now(); // 先提取缓存值后立即释放锁,避免 MutexGuard 跨 await let cached_entries: HashMap = { let guard = cache.lock().unwrap_or_else(|e| e.into_inner()); guard.clone() }; let all_fresh = provider_ids.iter().all(|pid| { cached_entries.get(pid) .map(|(_, ts)| now.duration_since(*ts) < QUOTA_CACHE_TTL) .unwrap_or(false) }); let quota_map: HashMap = if all_fresh { provider_ids.iter() .filter_map(|pid| cached_entries.get(pid).map(|(remaining, _)| (pid.clone(), *remaining))) .collect() } else { let quota_rows: Vec<(String, i64)> = match sqlx::query_as( r#" SELECT pk.provider_id, SUM(COALESCE(pk.max_rpm, 999999) - COALESCE(uw.request_count, 0)) AS remaining_rpm FROM provider_keys pk LEFT JOIN key_usage_window uw ON pk.id = uw.key_id AND uw.window_minute >= to_char(NOW() - INTERVAL '1 minute', 'YYYY-MM-DDTHH24:MI') WHERE pk.provider_id = ANY($1) AND pk.is_active = TRUE AND (pk.cooldown_until IS NULL OR pk.cooldown_until::timestamptz <= NOW()) GROUP BY pk.provider_id "#, ) .bind(&provider_ids) .fetch_all(db) .await { Ok(rows) => rows, Err(e) => { // M-6: DB 查询失败时记录警告,使用原始顺序 tracing::warn!("sort_candidates_by_quota DB query failed: {}", e); return; } }; let map: HashMap = quota_rows.into_iter().collect(); // 更新缓存 + 清理过期条目 { let mut cache_guard = cache.lock().unwrap_or_else(|e| e.into_inner()); for (pid, remaining) in &map { cache_guard.insert(pid.clone(), (*remaining, now)); } // M-S3: 清理超过 TTL 5x(25s)的陈旧条目,防止已删除 Provider 的条目永久残留 let ttl_5x = QUOTA_CACHE_TTL * 5; cache_guard.retain(|_, (_, ts)| now.saturating_duration_since(*ts) < ttl_5x); } map }; // H-1: 新 Provider 没有 usage 记录 → unwrap_or(999999) 表示完整余量 candidates.sort_by(|a, b| { let qa = quota_map.get(&a.provider_id).copied().unwrap_or(999999); let qb = quota_map.get(&b.provider_id).copied().unwrap_or(999999); qb.cmp(&qa) // 降序:余量多的排前面 }); } /// 中转响应类型 #[derive(Debug)] pub enum RelayResponse { Json(String), Sse(axum::body::Body), } // ============ StreamBridge ============ /// 构建 StreamBridge:将 mpsc::Receiver 包装为带心跳、超时的 axum Body。 /// /// 借鉴 DeerFlow StreamBridge 背压机制: /// - 15s 心跳:上游长时间无输出时,发送 SSE 注释行 `: heartbeat\n\n` 保持连接活跃 /// - 30s 超时:上游连续 30s 无真实数据时,发送超时事件并关闭流 /// - 60s 延迟清理:由调用方的 spawned task 在流结束后延迟释放资源 fn build_stream_bridge( mut rx: tokio::sync::mpsc::Receiver>, task_id: String, ) -> axum::body::Body { // SSE heartbeat comment bytes: `: heartbeat\n\n` // SSE spec: lines starting with `:` are comments and ignored by clients const HEARTBEAT_BYTES: &[u8] = b": heartbeat\n\n"; // SSE timeout error event const TIMEOUT_EVENT: &[u8] = b"data: {\"error\":\"stream_timeout\",\"message\":\"upstream timed out\"}\n\n"; let stream = async_stream::stream! { // Track how many consecutive heartbeat-only cycles have elapsed. // Real data resets this counter; after 2 heartbeats (30s) without // real data, we terminate the stream. let mut idle_heartbeats: u32 = 0; loop { // tokio::select! races the next data chunk against a heartbeat timer. // The timer resets on every iteration, ensuring heartbeats only fire // during genuine idle periods. tokio::select! { biased; // prioritize data over heartbeat chunk = rx.recv() => { match chunk { Some(Ok(data)) => { // Real data received — reset idle counter idle_heartbeats = 0; yield Ok::(data); } Some(Err(e)) => { tracing::warn!( "[StreamBridge] Upstream error for task {}: {}", task_id, e ); yield Err(e); break; } None => { // Channel closed = upstream finished normally tracing::debug!( "[StreamBridge] Upstream completed for task {}", task_id ); break; } } } // Heartbeat: send SSE comment if no data for 15s _ = tokio::time::sleep(STREAMBRIDGE_HEARTBEAT_INTERVAL) => { idle_heartbeats += 1; tracing::trace!( "[StreamBridge] Heartbeat #{} for task {} (idle {}s)", idle_heartbeats, task_id, idle_heartbeats as u64 * STREAMBRIDGE_HEARTBEAT_INTERVAL.as_secs(), ); // After 12 consecutive heartbeats without real data (180s), // terminate the stream to prevent connection leaks. if idle_heartbeats >= 12 { tracing::warn!( "[StreamBridge] Timeout ({:?}) no real data, closing stream for task {}", STREAMBRIDGE_TIMEOUT, task_id, ); yield Ok(bytes::Bytes::from_static(TIMEOUT_EVENT)); break; } yield Ok(bytes::Bytes::from_static(HEARTBEAT_BYTES)); } } } }; // Pin the stream to a Box to satisfy Body::from_stream let boxed: std::pin::Pin> + Send>> = Box::pin(stream); axum::body::Body::from_stream(boxed) } // ============ Helpers ============ fn hash_request(body: &str) -> String { use sha2::{Sha256, Digest}; hex::encode(Sha256::digest(body.as_bytes())) } /// 从 JSON 响应中提取 token 使用量 fn extract_token_usage(body: &str) -> (i64, i64) { let parsed: serde_json::Value = match serde_json::from_str(body) { Ok(v) => v, Err(e) => { tracing::debug!("extract_token_usage: JSON parse failed (body len={}): {}", body.len(), e); return (0, 0); } }; let usage = parsed.get("usage"); let input = usage .and_then(|u| u.get("prompt_tokens")) .and_then(|v| v.as_i64()) .unwrap_or(0); let output = usage .and_then(|u| u.get("completion_tokens")) .and_then(|v| v.as_i64()) .unwrap_or(0); (input, output) } /// 从 JSON 响应中提取 token 使用量 (公开版本) pub fn extract_token_usage_from_json(body: &str) -> (i64, i64) { extract_token_usage(body) } /// SSRF 防护: 验证 provider URL 不指向内网 async fn validate_provider_url(url: &str) -> SaasResult<()> { let parsed: url::Url = url.parse().map_err(|_| { SaasError::InvalidInput(format!("无效的 provider URL: {}", url)) })?; // 只允许 https match parsed.scheme() { "https" => {} "http" => { // 开发环境允许 http let is_dev = std::env::var("ZCLAW_SAAS_DEV") .map(|v| v == "true" || v == "1") .unwrap_or(false); if !is_dev { return Err(SaasError::InvalidInput("生产环境禁止 http scheme,请使用 https".into())); } } _ => return Err(SaasError::InvalidInput(format!("不允许的 URL scheme: {}", parsed.scheme()))), } // 禁止内网地址 let host = match parsed.host_str() { Some(h) => h, None => return Err(SaasError::InvalidInput("provider URL 缺少 host".into())), }; // 去除 IPv6 方括号 let host = host.trim_start_matches('[').trim_end_matches(']'); // 精确匹配的阻止列表: 仅包含主机名和特殊域名 // 私有 IP 范围 (10.x, 172.16-31.x, 192.168.x, 127.x, 169.254.x, ::1 等) // 由 is_private_ip() 统一判断,无需在此重复列出 let blocked_exact = [ "localhost", "metadata.google.internal", ]; if blocked_exact.contains(&host) { return Err(SaasError::InvalidInput(format!("provider URL 指向禁止的内网地址: {}", host))); } // 后缀匹配 (阻止子域名) let blocked_suffixes = ["localhost", "internal", "local", "localhost.localdomain"]; for suffix in &blocked_suffixes { if host.ends_with(&format!(".{}", suffix)) { return Err(SaasError::InvalidInput(format!("provider URL 指向禁止的内网地址: {}", host))); } } // 阻止纯数字 host (可能是十进制 IP 表示法,如 2130706433 = 127.0.0.1) if host.parse::().is_ok() { return Err(SaasError::InvalidInput(format!("provider URL 使用了不允许的 IP 格式: {}", host))); } // 阻止十六进制/八进制 IP 混淆 (如 0x7f000001, 0177.0.0.1) if host.chars().all(|c| c.is_ascii_hexdigit() || c == '.' || c == ':' || c == 'x' || c == 'X') { return Err(SaasError::InvalidInput(format!("provider URL 使用了不允许的 IP 格式: {}", host))); } // 阻止 IPv4 私有网段 (通过解析 IP) if let Ok(ip) = host.parse::() { if is_private_ip(&ip) { return Err(SaasError::InvalidInput(format!("provider URL 指向私有 IP 地址: {}", host))); } return Ok(()); } // 对域名做异步 DNS 解析,检查解析结果是否指向内网 let addr_str = format!("{}:0", host); match tokio::net::lookup_host(&*addr_str).await { Ok(addrs) => { for sockaddr in addrs { if is_private_ip(&sockaddr.ip()) { return Err(SaasError::InvalidInput( "provider URL 域名解析到内网地址".into() )); } } } Err(_) => { // DNS 解析失败,可能是无效域名,不阻止请求 } } Ok(()) } /// 检查 IP 是否属于私有/内网地址范围 fn is_private_ip(ip: &std::net::IpAddr) -> bool { match ip { std::net::IpAddr::V4(v4) => { let octets = v4.octets(); // 10.0.0.0/8 octets[0] == 10 // 172.16.0.0/12 || (octets[0] == 172 && octets[1] >= 16 && octets[1] <= 31) // 192.168.0.0/16 || (octets[0] == 192 && octets[1] == 168) // 127.0.0.0/8 (loopback) || octets[0] == 127 // 169.254.0.0/16 (link-local) || (octets[0] == 169 && octets[1] == 254) // 0.0.0.0/8 || octets[0] == 0 } std::net::IpAddr::V6(v6) => { // ::1 (loopback) v6.is_loopback() // ::ffff:x.x.x.x (IPv6-mapped IPv4) || v6.to_ipv4_mapped().map_or(false, |v4| is_private_ip(&std::net::IpAddr::V4(v4))) // fe80::/10 (link-local) || (v6.segments()[0] & 0xffc0) == 0xfe80 } } }