Addresses findings from deep code audit:

H-1: Pass abortController.signal to saasClient.chatCompletion() so
user-cancelled streams actually abort the HTTP connection (previously
only the read loop stopped, leaving the server-side SSE connection open).

H-2: ModelSelector now shows only when (!isTauriRuntime() || isLoggedIn).
Prevents a decorative model list in Tauri local kernel mode, where model
selection has no effect (violates CLAUDE.md §5.2).

M-1: Normalize CRLF to LF before SSE event boundary parsing (\n\n).
Prevents unbounded buffer growth behind nginx/CDN proxies that emit CRLF
line endings (events framed with \r\n\r\n would never split on \n\n).

M-2: SQL window_minute comparison uses to_char(NOW()-interval, format)
instead of (NOW()-interval)::TEXT, matching the stored format exactly.

M-3: sort_candidates_by_quota uses the same sliding 60s window as
select_best_key.

LOW: Fix misleading invalidate_cache doc comment.
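For M-1, a minimal sketch of the normalization step, assuming the parser splits
the SSE buffer on "\n\n" event boundaries (the helper name is illustrative, not
the actual function in this diff):

// Hypothetical helper: fold CRLF/CR into LF before splitting on "\n\n",
// so events framed as "\r\n\r\n" by nginx/CDN still terminate correctly.
fn normalize_sse_chunk(raw: &str) -> String {
    raw.replace("\r\n", "\n").replace('\r', "\n")
}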
//! Relay service core logic

use sqlx::PgPool;
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::OnceLock;
use std::time::Duration;
use tokio::sync::Mutex;
use crate::error::{SaasError, SaasResult};
use crate::models::RelayTaskRow;
use super::types::*;

// ============ StreamBridge backpressure constants ============

/// Interval between SSE heartbeat comment lines while the upstream is idle.
const STREAMBRIDGE_HEARTBEAT_INTERVAL: Duration = Duration::from_secs(15);

/// Idle timeout after which the connection is dropped (180s = 12 heartbeats).
/// Kimi for Coding has been observed to pause 60s+ between thinking and
/// content, so the timeout needs to be generous.
const STREAMBRIDGE_TIMEOUT: Duration = Duration::from_secs(180);

/// Delayed-cleanup window after the stream ends (shortened to 5s; only used
/// to release Arc references).
const STREAMBRIDGE_CLEANUP_DELAY: Duration = Duration::from_secs(5);

/// Whether an HTTP status code is a retryable transient error (5xx + 429).
fn is_retryable_status(status: u16) -> bool {
    status == 429 || (500..600).contains(&status)
}

/// Whether a reqwest error is a retryable network error.
fn is_retryable_error(e: &reqwest::Error) -> bool {
    e.is_timeout() || e.is_connect() || e.is_request()
}
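
// Illustrative test sketch (not part of the original file): pins down the
// retry classification above — 429 and all of 5xx are transient; other 4xx,
// 2xx, and out-of-range codes are not.
#[cfg(test)]
mod retry_status_tests {
    use super::is_retryable_status;

    #[test]
    fn classifies_transient_statuses() {
        assert!(is_retryable_status(429));
        assert!(is_retryable_status(500));
        assert!(is_retryable_status(599));
        assert!(!is_retryable_status(200));
        assert!(!is_retryable_status(400));
        assert!(!is_retryable_status(404));
        assert!(!is_retryable_status(600));
    }
}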

// ============ Relay Task Management ============

/// Whether a sqlx error is a retryable transient error (pool exhaustion,
/// temporary network failure).
fn is_transient_db_error(e: &sqlx::Error) -> bool {
    matches!(e, sqlx::Error::PoolTimedOut | sqlx::Error::Io(_))
}

pub async fn create_relay_task(
    db: &PgPool,
    account_id: &str,
    provider_id: &str,
    model_id: &str,
    request_body: &str,
    priority: i32,
    max_attempts: u32,
) -> SaasResult<RelayTaskInfo> {
    let id = uuid::Uuid::new_v4().to_string();
    let now = chrono::Utc::now();
    let request_hash = hash_request(request_body);
    let max_attempts = max_attempts.clamp(1, 5);

    let query = sqlx::query_as::<_, RelayTaskRow>(
        "INSERT INTO relay_tasks (id, account_id, provider_id, model_id, request_hash, request_body, status, priority, attempt_count, max_attempts, queued_at, created_at)
         VALUES ($1, $2, $3, $4, $5, $6, 'queued', $7, 0, $8, $9, $9)
         RETURNING id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at::TEXT, started_at::TEXT, completed_at::TEXT, created_at::TEXT"
    )
    .bind(&id).bind(account_id).bind(provider_id).bind(model_id)
    .bind(&request_hash).bind(request_body).bind(priority).bind(max_attempts as i64).bind(&now);

    // Retry once on transient DB errors (pool exhaustion / timeouts).
    let row = match query.fetch_one(db).await {
        Ok(row) => row,
        Err(e) if is_transient_db_error(&e) => {
            tracing::warn!("Transient DB error in create_relay_task, retrying: {}", e);
            tokio::time::sleep(Duration::from_millis(200)).await;
            sqlx::query_as::<_, RelayTaskRow>(
                "INSERT INTO relay_tasks (id, account_id, provider_id, model_id, request_hash, request_body, status, priority, attempt_count, max_attempts, queued_at, created_at)
                 VALUES ($1, $2, $3, $4, $5, $6, 'queued', $7, 0, $8, $9, $9)
                 RETURNING id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at::TEXT, started_at::TEXT, completed_at::TEXT, created_at::TEXT"
            )
            .bind(&id).bind(account_id).bind(provider_id).bind(model_id)
            .bind(&request_hash).bind(request_body).bind(priority).bind(max_attempts as i64).bind(&now)
            .fetch_one(db)
            .await?
        }
        Err(e) => return Err(e.into()),
    };

    Ok(RelayTaskInfo {
        id: row.id, account_id: row.account_id, provider_id: row.provider_id, model_id: row.model_id,
        status: row.status, priority: row.priority, attempt_count: row.attempt_count,
        max_attempts: row.max_attempts, input_tokens: row.input_tokens, output_tokens: row.output_tokens,
        error_message: row.error_message, queued_at: row.queued_at, started_at: row.started_at,
        completed_at: row.completed_at, created_at: row.created_at,
    })
}

pub async fn get_relay_task(db: &PgPool, task_id: &str) -> SaasResult<RelayTaskInfo> {
    let row: Option<RelayTaskRow> = sqlx::query_as(
        "SELECT id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at::TEXT, started_at::TEXT, completed_at::TEXT, created_at::TEXT
         FROM relay_tasks WHERE id = $1"
    )
    .bind(task_id)
    .fetch_optional(db)
    .await?;

    let r = row.ok_or_else(|| SaasError::NotFound(format!("Relay task {} not found", task_id)))?;

    Ok(RelayTaskInfo {
        id: r.id, account_id: r.account_id, provider_id: r.provider_id, model_id: r.model_id, status: r.status, priority: r.priority,
        attempt_count: r.attempt_count, max_attempts: r.max_attempts, input_tokens: r.input_tokens, output_tokens: r.output_tokens,
        error_message: r.error_message, queued_at: r.queued_at, started_at: r.started_at, completed_at: r.completed_at, created_at: r.created_at,
    })
}

pub async fn list_relay_tasks(
    db: &PgPool, account_id: &str, query: &RelayTaskQuery,
) -> SaasResult<crate::common::PaginatedResponse<RelayTaskInfo>> {
    let page = query.page.unwrap_or(1).max(1) as u32;
    let page_size = query.page_size.unwrap_or(20).min(100) as u32;
    let offset = ((page - 1) * page_size) as i64;

    let (count_sql, data_sql) = if query.status.is_some() {
        (
            "SELECT COUNT(*) FROM relay_tasks WHERE account_id = $1 AND status = $2",
            "SELECT id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at::TEXT, started_at::TEXT, completed_at::TEXT, created_at::TEXT
             FROM relay_tasks WHERE account_id = $1 AND status = $2 ORDER BY created_at DESC LIMIT $3 OFFSET $4"
        )
    } else {
        (
            "SELECT COUNT(*) FROM relay_tasks WHERE account_id = $1",
            "SELECT id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at::TEXT, started_at::TEXT, completed_at::TEXT, created_at::TEXT
             FROM relay_tasks WHERE account_id = $1 ORDER BY created_at DESC LIMIT $2 OFFSET $3"
        )
    };

    let total: i64 = if let Some(ref status) = query.status {
        sqlx::query_scalar(count_sql).bind(account_id).bind(status).fetch_one(db).await?
    } else {
        sqlx::query_scalar(count_sql).bind(account_id).fetch_one(db).await?
    };

    let mut query_builder = sqlx::query_as::<_, RelayTaskRow>(data_sql).bind(account_id);

    if let Some(ref status) = query.status {
        query_builder = query_builder.bind(status);
    }

    let rows = query_builder.bind(page_size as i64).bind(offset).fetch_all(db).await?;
    let items: Vec<RelayTaskInfo> = rows.into_iter().map(|r| RelayTaskInfo {
        id: r.id, account_id: r.account_id, provider_id: r.provider_id, model_id: r.model_id,
        status: r.status, priority: r.priority, attempt_count: r.attempt_count, max_attempts: r.max_attempts,
        input_tokens: r.input_tokens, output_tokens: r.output_tokens, error_message: r.error_message,
        queued_at: r.queued_at, started_at: r.started_at, completed_at: r.completed_at, created_at: r.created_at,
    }).collect();

    Ok(crate::common::PaginatedResponse { items, total, page, page_size })
}

pub async fn update_task_status(
    db: &PgPool, task_id: &str, status: &str,
    input_tokens: Option<i64>, output_tokens: Option<i64>,
    error_message: Option<&str>,
) -> SaasResult<()> {
    let now = chrono::Utc::now();

    match status {
        "processing" => {
            sqlx::query(
                "UPDATE relay_tasks SET started_at = $1, status = 'processing', attempt_count = attempt_count + 1 WHERE id = $2"
            )
            .bind(&now).bind(task_id)
            .execute(db).await?;
        }
        "completed" => {
            sqlx::query(
                "UPDATE relay_tasks SET completed_at = $1, status = 'completed', input_tokens = COALESCE($2, input_tokens), output_tokens = COALESCE($3, output_tokens) WHERE id = $4"
            )
            .bind(&now).bind(input_tokens).bind(output_tokens).bind(task_id)
            .execute(db).await?;
        }
        "failed" => {
            sqlx::query(
                "UPDATE relay_tasks SET completed_at = $1, status = 'failed', error_message = $2 WHERE id = $3"
            )
            .bind(&now).bind(error_message).bind(task_id)
            .execute(db).await?;
        }
        _ => return Err(SaasError::InvalidInput(format!("Invalid task status: {}", status))),
    }

    Ok(())
}

// ============ Relay Execution ============

/// Captures usage info embedded in an SSE stream.
#[derive(Debug, Clone, Default)]
struct SseUsageCapture {
    input_tokens: i64,
    output_tokens: i64,
}

impl SseUsageCapture {
    fn parse_sse_line(&mut self, line: &str) {
        if let Some(data) = line.strip_prefix("data: ") {
            if data == "[DONE]" {
                return;
            }
            if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(data) {
                if let Some(usage) = parsed.get("usage") {
                    if let Some(input) = usage.get("prompt_tokens").and_then(|v| v.as_i64()) {
                        self.input_tokens = input;
                    }
                    if let Some(output) = usage.get("completion_tokens").and_then(|v| v.as_i64()) {
                        self.output_tokens = output;
                    }
                }
            }
        }
    }
}
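
// Illustrative test sketch (not part of the original file): feeds typical
// OpenAI-style SSE lines through the capture and checks that the last seen
// usage block wins and that "[DONE]" is ignored.
#[cfg(test)]
mod sse_usage_capture_tests {
    use super::SseUsageCapture;

    #[test]
    fn captures_last_usage_block() {
        let mut capture = SseUsageCapture::default();
        capture.parse_sse_line(r#"data: {"choices":[{"delta":{"content":"hi"}}]}"#);
        capture.parse_sse_line(r#"data: {"usage":{"prompt_tokens":12,"completion_tokens":3}}"#);
        capture.parse_sse_line(r#"data: {"usage":{"prompt_tokens":12,"completion_tokens":34}}"#);
        capture.parse_sse_line("data: [DONE]");
        assert_eq!(capture.input_tokens, 12);
        assert_eq!(capture.output_tokens, 34);
    }
}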

pub async fn execute_relay(
    db: &PgPool,
    task_id: &str,
    provider_id: &str,
    provider_base_url: &str,
    request_body: &str,
    stream: bool,
    max_attempts: u32,
    base_delay_ms: u64,
    enc_key: &[u8; 32],
    // false when called from `execute_relay_with_failover`, which owns the
    // task status; true for standalone calls, where this function manages it.
    manage_task_status: bool,
) -> SaasResult<RelayResponse> {
    validate_provider_url(provider_base_url).await?;

    let url = format!("{}/chat/completions", provider_base_url.trim_end_matches('/'));

    // Reuse global HTTP clients so each request doesn't rebuild the TLS
    // connection pool and DNS resolver.
    static SHORT_CLIENT: std::sync::OnceLock<reqwest::Client> = std::sync::OnceLock::new();
    static LONG_CLIENT: std::sync::OnceLock<reqwest::Client> = std::sync::OnceLock::new();
    let client = if stream {
        LONG_CLIENT.get_or_init(|| {
            reqwest::Client::builder()
                .timeout(std::time::Duration::from_secs(300))
                .build()
                .expect("Failed to build long-timeout HTTP client")
        })
    } else {
        SHORT_CLIENT.get_or_init(|| {
            reqwest::Client::builder()
                .timeout(std::time::Duration::from_secs(30))
                .build()
                .expect("Failed to build short-timeout HTTP client")
        })
    };

    let max_attempts = max_attempts.clamp(1, 5);

    // Key pool rotation state.
    let mut current_key_id: Option<String> = None;
    let mut current_api_key: Option<String> = None;

    for attempt in 0..max_attempts {
        let is_first = attempt == 0;
        if is_first && manage_task_status {
            update_task_status(db, task_id, "processing", None, None, None).await?;
        }

        // Select a key on the first attempt, or re-select after a 429.
        if current_key_id.is_none() {
            match super::key_pool::select_best_key(db, provider_id, enc_key).await {
                Ok(selection) => {
                    let key_id = selection.key_id.clone();
                    let key_value = selection.key.key_value.clone();
                    tracing::debug!(
                        "Relay task {} selected key {} (attempt {})",
                        task_id, key_id, attempt + 1
                    );
                    current_key_id = Some(key_id);
                    current_api_key = Some(key_value);
                }
                Err(SaasError::RateLimited(msg)) => {
                    // Every key is cooling down.
                    let err_msg = format!("Key pool exhausted: {}", msg);
                    if manage_task_status {
                        update_task_status(db, task_id, "failed", None, None, Some(&err_msg)).await?;
                    }
                    return Err(SaasError::RateLimited(msg));
                }
                Err(e) => return Err(e),
            }
        }

        let key_id = current_key_id.as_ref()
            .ok_or_else(|| SaasError::Internal("Key pool selection failed: no key_id".into()))?
            .clone();
        let api_key = current_api_key.clone();

        let mut req_builder = client.post(&url)
            .header("Content-Type", "application/json")
            // Coding-agent APIs (e.g. Kimi Coding Plan) require the User-Agent
            // to identify as a coding agent.
            .header("User-Agent", "claude-code/1.0")
            .body(request_body.to_string());

        if let Some(ref key) = api_key {
            req_builder = req_builder.header("Authorization", format!("Bearer {}", key));
        }

        let result = req_builder.send().await;

        match result {
            Ok(resp) if resp.status().is_success() => {
                if stream {
                    let usage_capture = Arc::new(Mutex::new(SseUsageCapture::default()));
                    let usage_capture_clone = usage_capture.clone();
                    let db_clone = db.clone();
                    let task_id_clone = task_id.to_string();
                    let key_id_for_spawn = key_id.clone();

                    // Bounded channel for backpressure: 128 chunks (~128 KiB) of buffer.
                    // If the client reads slowly, the upstream is throttled via
                    // backpressure instead of growing memory indefinitely.
                    let (tx, rx) = tokio::sync::mpsc::channel::<Result<bytes::Bytes, std::io::Error>>(128);

                    // Consume the upstream stream and forward it through the bounded channel.
                    tokio::spawn(async move {
                        use futures::StreamExt;
                        let mut upstream = resp.bytes_stream();
                        while let Some(chunk_result) = upstream.next().await {
                            match chunk_result {
                                Ok(chunk) => {
                                    // Parse SSE lines for usage tracking.
                                    if let Ok(text) = std::str::from_utf8(&chunk) {
                                        let mut capture = usage_capture_clone.lock().await;
                                        for line in text.lines() {
                                            capture.parse_sse_line(line);
                                        }
                                    }
                                    // Forward to the bounded channel — if full, this applies backpressure.
                                    if tx.send(Ok(chunk)).await.is_err() {
                                        tracing::debug!("SSE relay: client disconnected, stopping upstream");
                                        break;
                                    }
                                }
                                Err(e) => {
                                    let err_msg = e.to_string();
                                    if tx.send(Err(std::io::Error::other(e))).await.is_err() {
                                        tracing::debug!("SSE relay: client disconnected before error sent: {}", err_msg);
                                    }
                                    break;
                                }
                            }
                        }
                    });

                    // Build the StreamBridge: wraps the bounded receiver with heartbeat,
                    // timeout, and delayed cleanup (DeerFlow-inspired backpressure).
                    let body = build_stream_bridge(rx, task_id.to_string());

                    // Record usage + key consumption asynchronously after the SSE stream ends.
                    // A global Arc<Semaphore> caps concurrent spawned tasks so high
                    // concurrency cannot exhaust the connection pool.
                    static SSE_SPAWN_SEMAPHORE: std::sync::OnceLock<Arc<tokio::sync::Semaphore>> = std::sync::OnceLock::new();
                    let semaphore = SSE_SPAWN_SEMAPHORE.get_or_init(|| Arc::new(tokio::sync::Semaphore::new(64)));
                    let permit = match semaphore.clone().try_acquire_owned() {
                        Ok(p) => p,
                        Err(_) => {
                            // Semaphore full: skip usage recording; the stream itself is unaffected.
                            tracing::warn!("SSE usage spawn at capacity, skipping usage record for task {}", task_id);
                            return Ok(RelayResponse::Sse(body));
                        }
                    };

                    tokio::spawn(async move {
                        let _permit = permit; // hold the permit until this task finishes
                        // Brief delay to let the SSE stream settle before recording.
                        tokio::time::sleep(std::time::Duration::from_millis(500)).await;
                        let capture = usage_capture.lock().await;
                        let (input, output) = (
                            if capture.input_tokens > 0 { Some(capture.input_tokens) } else { None },
                            if capture.output_tokens > 0 { Some(capture.output_tokens) } else { None },
                        );
                        // Record task status with a timeout to avoid holding DB connections.
                        let db_op = async {
                            if let Err(e) = update_task_status(&db_clone, &task_id_clone, "completed", input, output, None).await {
                                tracing::warn!("Failed to update task status after SSE stream: {}", e);
                            }
                            // Record key usage (now 2 queries instead of 3).
                            let total_tokens = input.unwrap_or(0) + output.unwrap_or(0);
                            if let Err(e) = super::key_pool::record_key_usage(&db_clone, &key_id_for_spawn, Some(total_tokens)).await {
                                tracing::warn!("Failed to record key usage: {}", e);
                            }
                        };
                        if tokio::time::timeout(std::time::Duration::from_secs(5), db_op).await.is_err() {
                            tracing::warn!("SSE usage recording timed out for task {}", task_id_clone);
                        }

                        // Delayed StreamBridge cleanup: release residual resources
                        // STREAMBRIDGE_CLEANUP_DELAY (5s) after the stream ends
                        // (mainly Arc<SseUsageCapture>; dropping _permit returns
                        // the semaphore slot).
                        tokio::time::sleep(STREAMBRIDGE_CLEANUP_DELAY).await;
                        tracing::debug!(
                            "[StreamBridge] Cleanup delay elapsed for task {}",
                            task_id_clone
                        );
                    });

                    return Ok(RelayResponse::Sse(body));
                } else {
                    let body = resp.text().await.unwrap_or_default();
                    let (input_tokens, output_tokens) = extract_token_usage(&body);
                    if manage_task_status {
                        update_task_status(db, task_id, "completed",
                            Some(input_tokens), Some(output_tokens), None).await?;
                    }
                    // Record key usage (failures are logged, never block the response).
                    if let Err(e) = super::key_pool::record_key_usage(
                        db, &key_id, Some(input_tokens + output_tokens),
                    ).await {
                        tracing::warn!("[Relay] Failed to record key usage for billing: {}", e);
                    }
                    return Ok(RelayResponse::Json(body));
                }
            }
            Ok(resp) => {
                let status = resp.status().as_u16();
                if status == 429 {
                    // Parse the Retry-After header.
                    let retry_after = resp.headers()
                        .get("retry-after")
                        .and_then(|v| v.to_str().ok())
                        .and_then(|v| v.parse::<u64>().ok());

                    // Put the key into 429 cooldown.
                    if let Err(e) = super::key_pool::mark_key_429(db, &key_id, retry_after).await {
                        tracing::warn!("Failed to mark key 429: {}", e);
                    }

                    // Force key re-selection on the next iteration.
                    current_key_id = None;
                    current_api_key = None;

                    if attempt + 1 >= max_attempts {
                        let err_msg = format!(
                            "Key pool rotation exhausted ({} attempts); every key is rate limited",
                            max_attempts
                        );
                        if manage_task_status {
                            update_task_status(db, task_id, "failed", None, None, Some(&err_msg)).await?;
                        }
                        return Err(SaasError::RateLimited(err_msg));
                    }

                    tracing::warn!(
                        "Relay task {} got 429; key {} marked for cooldown (attempt {}/{})",
                        task_id, key_id, attempt + 1, max_attempts
                    );
                    // On 429, switch keys and retry immediately — no backoff delay.
                    continue;
                }
                if !is_retryable_status(status) || attempt + 1 >= max_attempts {
                    let body = resp.text().await.unwrap_or_default();
                    // Truncate on char boundaries: slicing at a fixed byte
                    // offset could panic inside a multi-byte UTF-8 sequence.
                    let snippet: String = body.chars().take(500).collect();
                    let err_msg = format!("Upstream returned HTTP {}: {}", status, snippet);
                    if manage_task_status {
                        update_task_status(db, task_id, "failed", None, None, Some(&err_msg)).await?;
                    }
                    return Err(SaasError::Relay(err_msg));
                }
                tracing::warn!(
                    "Relay task {} retryable error HTTP {} (attempt {}/{})",
                    task_id, status, attempt + 1, max_attempts
                );
            }
            Err(e) => {
                if !is_retryable_error(&e) || attempt + 1 >= max_attempts {
                    let err_msg = format!("Upstream request failed: {}", e);
                    if manage_task_status {
                        update_task_status(db, task_id, "failed", None, None, Some(&err_msg)).await?;
                    }
                    return Err(SaasError::Relay(err_msg));
                }
                tracing::warn!(
                    "Relay task {} network error (attempt {}/{}): {}",
                    task_id, attempt + 1, max_attempts, e
                );
            }
        }

        // Exponential backoff for non-429 errors.
        let delay_ms = base_delay_ms * (1 << attempt);
        tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
    }

    Err(SaasError::Relay("Retry attempts exhausted".into()))
}
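
// Illustrative call-site sketch (not part of the original file; "acct-1",
// "prov-1", and the base URL are placeholders): demonstrates the
// `manage_task_status` contract — standalone callers pass `true` so this
// function owns the task row, while `execute_relay_with_failover` passes
// `false` and owns it itself.
#[allow(dead_code)]
async fn example_standalone_relay(
    db: &PgPool,
    enc_key: &[u8; 32],
) -> SaasResult<RelayResponse> {
    let body = r#"{"model":"model-x","messages":[{"role":"user","content":"hi"}]}"#;
    let task = create_relay_task(db, "acct-1", "prov-1", "model-x", body, 0, 3).await?;
    execute_relay(
        db,
        &task.id,
        "prov-1",
        "https://api.example.com/v1", // base URL; "/chat/completions" is appended
        body,
        false, // non-streaming: returns RelayResponse::Json
        3,     // max_attempts (clamped to 1..=5)
        500,   // base_delay_ms for exponential backoff
        enc_key,
        true,  // standalone call: manage task status here
    )
    .await
}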

// ============ Cross-Provider Failover ============

/// Cross-provider failover executor.
///
/// Sorts candidate models by remaining quota, then tries each provider's key
/// pool in turn until a usable provider is found or all candidates are exhausted.
///
/// **Note**: failover only applies to pre-stream failures (connection errors,
/// 429/5xx before the stream starts). Once SSE streaming has begun, a
/// mid-stream upstream disconnect does not trigger failover — an inherent
/// limitation of the SSE protocol.
///
/// Returns (RelayResponse, actual_provider_id, actual_model_id) for precise
/// billing attribution.
pub async fn execute_relay_with_failover(
    db: &PgPool,
    task_id: &str,
    candidates: &[CandidateModel],
    request_body: &str,
    stream: bool,
    max_attempts_per_provider: u32,
    base_delay_ms: u64,
    enc_key: &[u8; 32],
) -> SaasResult<(RelayResponse, String, String)> {
    let mut last_error: Option<SaasError> = None;
    let failover_start = std::time::Instant::now();
    const FAILOVER_TIMEOUT: Duration = Duration::from_secs(60);

    // C-3: the outer layer owns task status — set "processing" exactly once.
    update_task_status(db, task_id, "processing", None, None, None).await?;

    for (idx, candidate) in candidates.iter().enumerate() {
        // M-3: timeout budget check — keeps cascading failures from piling up.
        if failover_start.elapsed() >= FAILOVER_TIMEOUT {
            tracing::warn!(
                "Failover timeout ({:?}) exceeded after {}/{} candidates for task {}",
                FAILOVER_TIMEOUT, idx, candidates.len(), task_id
            );
            break;
        }

        // Patch the request body's model field to this candidate's physical model ID.
        let patched_body = patch_model_in_body(request_body, &candidate.model_id);

        match execute_relay(
            db,
            task_id,
            &candidate.provider_id,
            &candidate.base_url,
            &patched_body,
            stream,
            max_attempts_per_provider,
            base_delay_ms,
            enc_key,
            false, // C-3: outer layer manages task status
        )
        .await
        {
            Ok(response) => {
                if idx > 0 {
                    tracing::info!(
                        "Failover succeeded on candidate {}/{} (provider={}, model={})",
                        idx + 1,
                        candidates.len(),
                        candidate.provider_id,
                        candidate.model_id
                    );
                }
                return Ok((response, candidate.provider_id.clone(), candidate.model_id.clone()));
            }
            Err(SaasError::RateLimited(msg)) => {
                tracing::warn!(
                    "Provider {} rate limited ({}), trying next candidate ({}/{})",
                    candidate.provider_id,
                    msg,
                    idx + 1,
                    candidates.len()
                );
                last_error = Some(SaasError::RateLimited(msg));
                continue;
            }
            Err(e) => {
                tracing::warn!(
                    "Provider {} failed: {}, trying next candidate ({}/{})",
                    candidate.provider_id,
                    e,
                    idx + 1,
                    candidates.len()
                );
                last_error = Some(e);
                continue;
            }
        }
    }

    // C-3: every candidate failed — the outer layer marks the task "failed", once.
    let final_error = last_error.unwrap_or_else(|| SaasError::RateLimited(
        "No candidate provider is available".into(),
    ));
    let err_msg = final_error.to_string();
    if let Err(e) = update_task_status(db, task_id, "failed", None, None, Some(&err_msg)).await {
        tracing::warn!("Failed to update task {} status after failover exhaustion: {}", task_id, e);
    }
    Err(final_error)
}

/// Replaces the "model" field in a JSON body with the candidate's physical model ID.
fn patch_model_in_body(body: &str, new_model_id: &str) -> String {
    if let Ok(mut parsed) = serde_json::from_str::<serde_json::Value>(body) {
        if let Some(obj) = parsed.as_object_mut() {
            obj.insert(
                "model".to_string(),
                serde_json::Value::String(new_model_id.to_string()),
            );
        }
        serde_json::to_string(&parsed).unwrap_or_else(|_| body.to_string())
    } else {
        body.to_string()
    }
}
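
// Illustrative test sketch (not part of the original file): the patch rewrites
// "model" on valid JSON and passes invalid bodies through untouched.
#[cfg(test)]
mod patch_model_tests {
    use super::patch_model_in_body;

    #[test]
    fn rewrites_model_field() {
        let patched = patch_model_in_body(r#"{"model":"logical-x","stream":true}"#, "phys-1");
        let v: serde_json::Value = serde_json::from_str(&patched).unwrap();
        assert_eq!(v["model"], "phys-1");
        assert_eq!(v["stream"], true);
    }

    #[test]
    fn passes_invalid_json_through() {
        assert_eq!(patch_model_in_body("not json", "phys-1"), "not json");
    }
}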

/// Sorts candidate models by remaining quota.
///
/// Queries each candidate provider's current key-pool RPM headroom; the most
/// headroom sorts first. Reuses live data from the key_usage_window table with
/// a single aggregate query, plus an in-memory cache (5s TTL) to reduce DB
/// query frequency.
pub async fn sort_candidates_by_quota(
    db: &PgPool,
    candidates: &mut [CandidateModel],
) {
    if candidates.len() <= 1 {
        return;
    }

    let provider_ids: Vec<String> = candidates.iter().map(|c| c.provider_id.clone()).collect();

    // H-4: quota-sort cache (5s TTL) to reduce DB queries on the hot path.
    static QUOTA_CACHE: OnceLock<std::sync::Mutex<HashMap<String, (i64, std::time::Instant)>>> = OnceLock::new();
    let cache = QUOTA_CACHE.get_or_init(|| std::sync::Mutex::new(HashMap::new()));
    const QUOTA_CACHE_TTL: Duration = Duration::from_secs(5);

    let now = std::time::Instant::now();
    // Clone the cached values and release the lock immediately, so no
    // MutexGuard is held across an await.
    let cached_entries: HashMap<String, (i64, std::time::Instant)> = {
        let guard = cache.lock().unwrap_or_else(|e| e.into_inner());
        guard.clone()
    };
    let all_fresh = provider_ids.iter().all(|pid| {
        cached_entries.get(pid)
            .map(|(_, ts)| now.duration_since(*ts) < QUOTA_CACHE_TTL)
            .unwrap_or(false)
    });

    let quota_map: HashMap<String, i64> = if all_fresh {
        provider_ids.iter()
            .filter_map(|pid| cached_entries.get(pid).map(|(remaining, _)| (pid.clone(), *remaining)))
            .collect()
    } else {
        let quota_rows: Vec<(String, i64)> = match sqlx::query_as(
            r#"
            SELECT pk.provider_id,
                   SUM(COALESCE(pk.max_rpm, 999999) - COALESCE(uw.request_count, 0)) AS remaining_rpm
            FROM provider_keys pk
            LEFT JOIN key_usage_window uw ON pk.id = uw.key_id
                AND uw.window_minute >= to_char(NOW() - INTERVAL '1 minute', 'YYYY-MM-DDTHH24:MI')
            WHERE pk.provider_id = ANY($1)
              AND pk.is_active = TRUE
              AND (pk.cooldown_until IS NULL OR pk.cooldown_until::timestamptz <= NOW())
            GROUP BY pk.provider_id
            "#,
        )
        .bind(&provider_ids)
        .fetch_all(db)
        .await
        {
            Ok(rows) => rows,
            Err(e) => {
                // M-6: on DB query failure, log a warning and keep the original order.
                tracing::warn!("sort_candidates_by_quota DB query failed: {}", e);
                return;
            }
        };

        let map: HashMap<String, i64> = quota_rows.into_iter().collect();

        // Update the cache and evict expired entries.
        {
            let mut cache_guard = cache.lock().unwrap_or_else(|e| e.into_inner());
            for (pid, remaining) in &map {
                cache_guard.insert(pid.clone(), (*remaining, now));
            }
            // M-S3: evict entries older than 5× TTL (25s) so entries for
            // deleted providers don't linger forever.
            let ttl_5x = QUOTA_CACHE_TTL * 5;
            cache_guard.retain(|_, (_, ts)| now.saturating_duration_since(*ts) < ttl_5x);
        }

        map
    };

    // H-1: a brand-new provider has no usage records → unwrap_or(999999) means full headroom.
    candidates.sort_by(|a, b| {
        let qa = quota_map.get(&a.provider_id).copied().unwrap_or(999999);
        let qb = quota_map.get(&b.provider_id).copied().unwrap_or(999999);
        qb.cmp(&qa) // descending: most remaining quota first
    });
}

/// Relay response type.
#[derive(Debug)]
pub enum RelayResponse {
    Json(String),
    Sse(axum::body::Body),
}

// ============ StreamBridge ============

/// Builds the StreamBridge: wraps an mpsc::Receiver as an axum Body with
/// heartbeats and an idle timeout.
///
/// Backpressure mechanics borrowed from DeerFlow's StreamBridge:
/// - 15s heartbeat: while the upstream is silent, emit the SSE comment line
///   `: heartbeat\n\n` to keep the connection alive
/// - 180s timeout: after 180s (12 heartbeats) with no real data, emit a
///   timeout event and close the stream
/// - 5s delayed cleanup: the caller's spawned task releases residual
///   resources STREAMBRIDGE_CLEANUP_DELAY after the stream ends
fn build_stream_bridge(
    mut rx: tokio::sync::mpsc::Receiver<Result<bytes::Bytes, std::io::Error>>,
    task_id: String,
) -> axum::body::Body {
    // SSE heartbeat comment bytes: `: heartbeat\n\n`.
    // Per the SSE spec, lines starting with `:` are comments and ignored by clients.
    const HEARTBEAT_BYTES: &[u8] = b": heartbeat\n\n";
    // SSE timeout error event.
    const TIMEOUT_EVENT: &[u8] = b"data: {\"error\":\"stream_timeout\",\"message\":\"upstream timed out\"}\n\n";

    let stream = async_stream::stream! {
        // Track how many consecutive heartbeat-only cycles have elapsed.
        // Real data resets this counter; after 12 heartbeats (180s) without
        // real data, we terminate the stream.
        let mut idle_heartbeats: u32 = 0;

        loop {
            // tokio::select! races the next data chunk against a heartbeat timer.
            // The timer resets on every iteration, so heartbeats only fire
            // during genuine idle periods.
            tokio::select! {
                biased; // prioritize data over heartbeat

                chunk = rx.recv() => {
                    match chunk {
                        Some(Ok(data)) => {
                            // Real data received — reset the idle counter.
                            idle_heartbeats = 0;
                            yield Ok::<bytes::Bytes, std::io::Error>(data);
                        }
                        Some(Err(e)) => {
                            tracing::warn!(
                                "[StreamBridge] Upstream error for task {}: {}",
                                task_id, e
                            );
                            yield Err(e);
                            break;
                        }
                        None => {
                            // Channel closed = upstream finished normally.
                            tracing::debug!(
                                "[StreamBridge] Upstream completed for task {}",
                                task_id
                            );
                            break;
                        }
                    }
                }

                // Heartbeat: send an SSE comment if there has been no data for 15s.
                _ = tokio::time::sleep(STREAMBRIDGE_HEARTBEAT_INTERVAL) => {
                    idle_heartbeats += 1;
                    tracing::trace!(
                        "[StreamBridge] Heartbeat #{} for task {} (idle {}s)",
                        idle_heartbeats,
                        task_id,
                        idle_heartbeats as u64 * STREAMBRIDGE_HEARTBEAT_INTERVAL.as_secs(),
                    );

                    // After 12 consecutive heartbeats without real data
                    // (STREAMBRIDGE_TIMEOUT = 180s), terminate the stream to
                    // prevent connection leaks.
                    if idle_heartbeats >= 12 {
                        tracing::warn!(
                            "[StreamBridge] Timeout ({:?}) no real data, closing stream for task {}",
                            STREAMBRIDGE_TIMEOUT,
                            task_id,
                        );
                        yield Ok(bytes::Bytes::from_static(TIMEOUT_EVENT));
                        break;
                    }

                    yield Ok(bytes::Bytes::from_static(HEARTBEAT_BYTES));
                }
            }
        }
    };

    // Pin the stream to a Box<dyn Stream + Send> to satisfy Body::from_stream.
    let boxed: std::pin::Pin<Box<dyn futures::Stream<Item = Result<bytes::Bytes, std::io::Error>> + Send>> =
        Box::pin(stream);

    axum::body::Body::from_stream(boxed)
}
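
// Illustrative handler sketch (not part of the original file): shows how the
// bridged Body is typically mounted as an SSE response. The header set follows
// SSE convention; route wiring and handler shape are assumptions.
#[allow(dead_code)]
fn sse_response(body: axum::body::Body) -> axum::response::Response {
    axum::response::Response::builder()
        .header("content-type", "text/event-stream")
        .header("cache-control", "no-cache")
        .body(body)
        .expect("static headers are valid")
}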

// ============ Helpers ============

fn hash_request(body: &str) -> String {
    use sha2::{Digest, Sha256};
    hex::encode(Sha256::digest(body.as_bytes()))
}
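
// Illustrative test sketch (not part of the original file): pins the request
// hash to the well-known SHA-256 test vectors, hex-encoded lowercase.
#[cfg(test)]
mod hash_request_tests {
    use super::hash_request;

    #[test]
    fn matches_known_sha256_vectors() {
        assert_eq!(
            hash_request(""),
            "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
        );
        assert_eq!(
            hash_request("abc"),
            "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"
        );
    }
}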

/// Extracts token usage from a JSON response.
fn extract_token_usage(body: &str) -> (i64, i64) {
    let parsed: serde_json::Value = match serde_json::from_str(body) {
        Ok(v) => v,
        Err(e) => {
            tracing::debug!("extract_token_usage: JSON parse failed (body len={}): {}", body.len(), e);
            return (0, 0);
        }
    };

    let usage = parsed.get("usage");
    let input = usage
        .and_then(|u| u.get("prompt_tokens"))
        .and_then(|v| v.as_i64())
        .unwrap_or(0);
    let output = usage
        .and_then(|u| u.get("completion_tokens"))
        .and_then(|v| v.as_i64())
        .unwrap_or(0);

    (input, output)
}

/// Extracts token usage from a JSON response (public version).
pub fn extract_token_usage_from_json(body: &str) -> (i64, i64) {
    extract_token_usage(body)
}
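
// Illustrative test sketch (not part of the original file): missing or
// malformed usage blocks fall back to (0, 0).
#[cfg(test)]
mod token_usage_tests {
    use super::extract_token_usage_from_json;

    #[test]
    fn extracts_prompt_and_completion_tokens() {
        let body = r#"{"usage":{"prompt_tokens":7,"completion_tokens":11}}"#;
        assert_eq!(extract_token_usage_from_json(body), (7, 11));
    }

    #[test]
    fn defaults_to_zero_when_absent() {
        assert_eq!(extract_token_usage_from_json("{}"), (0, 0));
        assert_eq!(extract_token_usage_from_json("not json"), (0, 0));
    }
}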

/// SSRF protection: verifies the provider URL does not point at internal networks.
async fn validate_provider_url(url: &str) -> SaasResult<()> {
    let parsed: url::Url = url.parse().map_err(|_| {
        SaasError::InvalidInput(format!("Invalid provider URL: {}", url))
    })?;

    // Only https is allowed (http is permitted in development).
    match parsed.scheme() {
        "https" => {}
        "http" => {
            let is_dev = std::env::var("ZCLAW_SAAS_DEV")
                .map(|v| v == "true" || v == "1")
                .unwrap_or(false);
            if !is_dev {
                return Err(SaasError::InvalidInput("The http scheme is forbidden in production; use https".into()));
            }
        }
        _ => return Err(SaasError::InvalidInput(format!("Disallowed URL scheme: {}", parsed.scheme()))),
    }

    // Block internal addresses.
    let host = match parsed.host_str() {
        Some(h) => h,
        None => return Err(SaasError::InvalidInput("provider URL is missing a host".into())),
    };

    // Strip IPv6 brackets.
    let host = host.trim_start_matches('[').trim_end_matches(']');

    // Exact-match blocklist: hostnames and special domains only.
    // Private IP ranges (10.x, 172.16-31.x, 192.168.x, 127.x, 169.254.x, ::1, etc.)
    // are handled by is_private_ip() and don't need to be listed here.
    let blocked_exact = [
        "localhost", "metadata.google.internal",
    ];
    if blocked_exact.contains(&host) {
        return Err(SaasError::InvalidInput(format!("provider URL points at a blocked internal address: {}", host)));
    }

    // Suffix match (blocks subdomains).
    let blocked_suffixes = ["localhost", "internal", "local", "localhost.localdomain"];
    for suffix in &blocked_suffixes {
        if host.ends_with(&format!(".{}", suffix)) {
            return Err(SaasError::InvalidInput(format!("provider URL points at a blocked internal address: {}", host)));
        }
    }

    // Block purely numeric hosts (possible decimal IP notation, e.g. 2130706433 = 127.0.0.1).
    if host.parse::<u64>().is_ok() {
        return Err(SaasError::InvalidInput(format!("provider URL uses a disallowed IP format: {}", host)));
    }

    // Block hex/octal IP obfuscation (e.g. 0x7f000001, 0177.0.0.1). Note that
    // this predicate also matches every literal IPv4/IPv6 address, so the
    // direct IpAddr check below acts as defense in depth.
    if host.chars().all(|c| c.is_ascii_hexdigit() || c == '.' || c == ':' || c == 'x' || c == 'X') {
        return Err(SaasError::InvalidInput(format!("provider URL uses a disallowed IP format: {}", host)));
    }

    // Block private IPv4 ranges (by parsing the IP).
    if let Ok(ip) = host.parse::<std::net::IpAddr>() {
        if is_private_ip(&ip) {
            return Err(SaasError::InvalidInput(format!("provider URL points at a private IP address: {}", host)));
        }
        return Ok(());
    }

    // Resolve domain names via async DNS and check whether any result points
    // at an internal network.
    let addr_str = format!("{}:0", host);
    match tokio::net::lookup_host(&*addr_str).await {
        Ok(addrs) => {
            for sockaddr in addrs {
                if is_private_ip(&sockaddr.ip()) {
                    return Err(SaasError::InvalidInput(
                        "provider URL resolves to an internal address".into()
                    ));
                }
            }
        }
        Err(_) => {
            // DNS resolution failed — possibly an invalid domain; do not block the request.
        }
    }

    Ok(())
}

/// Checks whether an IP belongs to a private/internal address range.
fn is_private_ip(ip: &std::net::IpAddr) -> bool {
    match ip {
        std::net::IpAddr::V4(v4) => {
            let octets = v4.octets();
            // 10.0.0.0/8
            octets[0] == 10
                // 172.16.0.0/12
                || (octets[0] == 172 && octets[1] >= 16 && octets[1] <= 31)
                // 192.168.0.0/16
                || (octets[0] == 192 && octets[1] == 168)
                // 127.0.0.0/8 (loopback)
                || octets[0] == 127
                // 169.254.0.0/16 (link-local)
                || (octets[0] == 169 && octets[1] == 254)
                // 0.0.0.0/8
                || octets[0] == 0
        }
        std::net::IpAddr::V6(v6) => {
            // ::1 (loopback)
            v6.is_loopback()
                // ::ffff:x.x.x.x (IPv6-mapped IPv4)
                || v6.to_ipv4_mapped().map_or(false, |v4| is_private_ip(&std::net::IpAddr::V4(v4)))
                // fe80::/10 (link-local)
                || (v6.segments()[0] & 0xffc0) == 0xfe80
        }
    }
}
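
// Illustrative test sketch (not part of the original file): covers the
// documented ranges, including the IPv6-mapped IPv4 bypass case.
#[cfg(test)]
mod private_ip_tests {
    use super::is_private_ip;
    use std::net::IpAddr;

    #[test]
    fn classifies_private_and_public_addresses() {
        let private = [
            "10.0.0.1", "172.16.0.1", "192.168.1.1", "127.0.0.1",
            "169.254.1.1", "::1", "::ffff:127.0.0.1", "fe80::1",
        ];
        for s in private {
            assert!(is_private_ip(&s.parse::<IpAddr>().unwrap()), "{} should be private", s);
        }
        let public = ["8.8.8.8", "1.1.1.1", "2001:4860:4860::8888"];
        for s in public {
            assert!(!is_private_ip(&s.parse::<IpAddr>().unwrap()), "{} should be public", s);
        }
    }
}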