chore: 提交所有工作进度 — SaaS 后端增强、Admin UI、桌面端集成

包含大量 SaaS 平台改进、Admin 管理后台更新、桌面端集成完善、
文档同步、测试文件重构等内容。为 QA 测试准备干净工作树。
This commit is contained in:
iven
2026-03-29 10:46:26 +08:00
parent 9a5fad2b59
commit 5fdf96c3f5
268 changed files with 22011 additions and 3886 deletions

View File

@@ -1,6 +1,8 @@
//! 中转服务核心逻辑
use sqlx::SqlitePool;
use sqlx::PgPool;
use std::sync::Arc;
use std::sync::Mutex;
use crate::error::{SaasError, SaasResult};
use super::types::*;
use futures::StreamExt;
@@ -18,7 +20,7 @@ fn is_retryable_error(e: &reqwest::Error) -> bool {
// ============ Relay Task Management ============
pub async fn create_relay_task(
db: &SqlitePool,
db: &PgPool,
account_id: &str,
provider_id: &str,
model_id: &str,
@@ -33,7 +35,7 @@ pub async fn create_relay_task(
sqlx::query(
"INSERT INTO relay_tasks (id, account_id, provider_id, model_id, request_hash, request_body, status, priority, attempt_count, max_attempts, queued_at, created_at)
VALUES (?1, ?2, ?3, ?4, ?5, ?6, 'queued', ?7, 0, ?8, ?9, ?9)"
VALUES ($1, $2, $3, $4, $5, $6, 'queued', $7, 0, $8, $9, $9)"
)
.bind(&id).bind(account_id).bind(provider_id).bind(model_id)
.bind(&request_hash).bind(request_body).bind(priority).bind(max_attempts as i64).bind(&now)
@@ -42,11 +44,11 @@ pub async fn create_relay_task(
get_relay_task(db, &id).await
}
pub async fn get_relay_task(db: &SqlitePool, task_id: &str) -> SaasResult<RelayTaskInfo> {
pub async fn get_relay_task(db: &PgPool, task_id: &str) -> SaasResult<RelayTaskInfo> {
let row: Option<(String, String, String, String, String, i64, i64, i64, i64, i64, Option<String>, String, Option<String>, Option<String>, String)> =
sqlx::query_as(
"SELECT id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at
FROM relay_tasks WHERE id = ?1"
FROM relay_tasks WHERE id = $1"
)
.bind(task_id)
.fetch_optional(db)
@@ -63,50 +65,62 @@ pub async fn get_relay_task(db: &SqlitePool, task_id: &str) -> SaasResult<RelayT
}
pub async fn list_relay_tasks(
db: &SqlitePool, account_id: &str, query: &RelayTaskQuery,
) -> SaasResult<Vec<RelayTaskInfo>> {
let page = query.page.unwrap_or(1).max(1);
let page_size = query.page_size.unwrap_or(20).min(100);
let offset = (page - 1) * page_size;
db: &PgPool, account_id: &str, query: &RelayTaskQuery,
) -> SaasResult<crate::common::PaginatedResponse<RelayTaskInfo>> {
let page = query.page.unwrap_or(1).max(1) as u32;
let page_size = query.page_size.unwrap_or(20).min(100) as u32;
let offset = ((page - 1) * page_size) as i64;
let sql = if query.status.is_some() {
"SELECT id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at
FROM relay_tasks WHERE account_id = ?1 AND status = ?2 ORDER BY created_at DESC LIMIT ?3 OFFSET ?4"
let (count_sql, data_sql) = if query.status.is_some() {
(
"SELECT COUNT(*) FROM relay_tasks WHERE account_id = $1 AND status = $2",
"SELECT id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at
FROM relay_tasks WHERE account_id = $1 AND status = $2 ORDER BY created_at DESC LIMIT $3 OFFSET $4"
)
} else {
"SELECT id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at
FROM relay_tasks WHERE account_id = ?1 ORDER BY created_at DESC LIMIT ?2 OFFSET ?3"
(
"SELECT COUNT(*) FROM relay_tasks WHERE account_id = $1",
"SELECT id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at
FROM relay_tasks WHERE account_id = $1 ORDER BY created_at DESC LIMIT $2 OFFSET $3"
)
};
let mut query_builder = sqlx::query_as::<_, (String, String, String, String, String, i64, i64, i64, i64, i64, Option<String>, String, Option<String>, Option<String>, String)>(sql)
let total: i64 = if query.status.is_some() {
sqlx::query_scalar(count_sql).bind(account_id).bind(query.status.as_ref().unwrap()).fetch_one(db).await?
} else {
sqlx::query_scalar(count_sql).bind(account_id).fetch_one(db).await?
};
let mut query_builder = sqlx::query_as::<_, (String, String, String, String, String, i64, i64, i64, i64, i64, Option<String>, String, Option<String>, Option<String>, String)>(data_sql)
.bind(account_id);
if let Some(ref status) = query.status {
query_builder = query_builder.bind(status);
}
query_builder = query_builder.bind(page_size).bind(offset);
let rows = query_builder.fetch_all(db).await?;
Ok(rows.into_iter().map(|(id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at)| {
let rows = query_builder.bind(page_size as i64).bind(offset).fetch_all(db).await?;
let items: Vec<RelayTaskInfo> = rows.into_iter().map(|(id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at)| {
RelayTaskInfo { id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at }
}).collect())
}).collect();
Ok(crate::common::PaginatedResponse { items, total, page, page_size })
}
pub async fn update_task_status(
db: &SqlitePool, task_id: &str, status: &str,
db: &PgPool, task_id: &str, status: &str,
input_tokens: Option<i64>, output_tokens: Option<i64>,
error_message: Option<&str>,
) -> SaasResult<()> {
let now = chrono::Utc::now().to_rfc3339();
let update_sql = match status {
"processing" => "started_at = ?1, status = 'processing', attempt_count = attempt_count + 1",
"completed" => "completed_at = ?1, status = 'completed', input_tokens = COALESCE(?2, input_tokens), output_tokens = COALESCE(?3, output_tokens)",
"failed" => "completed_at = ?1, status = 'failed', error_message = ?2",
"processing" => "started_at = $1, status = 'processing', attempt_count = attempt_count + 1",
"completed" => "completed_at = $1, status = 'completed', input_tokens = COALESCE($2, input_tokens), output_tokens = COALESCE($3, output_tokens)",
"failed" => "completed_at = $1, status = 'failed', error_message = $2",
_ => return Err(SaasError::InvalidInput(format!("无效任务状态: {}", status))),
};
let sql = format!("UPDATE relay_tasks SET {} WHERE id = ?4", update_sql);
let sql = format!("UPDATE relay_tasks SET {} WHERE id = $4", update_sql);
let mut query = sqlx::query(&sql).bind(&now);
if status == "completed" {
@@ -123,15 +137,43 @@ pub async fn update_task_status(
// ============ Relay Execution ============
/// Captures token-usage figures reported inside an SSE (server-sent events)
/// relay stream.
///
/// Filled in incrementally by `parse_sse_line`; each new `usage` object seen
/// in the stream overwrites the previously captured values, so after the
/// stream ends the fields hold the last (typically final/cumulative — TODO
/// confirm upstream semantics) usage report. Starts at zero via `Default`.
#[derive(Debug, Clone, Default)]
struct SseUsageCapture {
    // Last `usage.prompt_tokens` value parsed from the stream (0 if never seen).
    input_tokens: i64,
    // Last `usage.completion_tokens` value parsed from the stream (0 if never seen).
    output_tokens: i64,
}
impl SseUsageCapture {
    /// Inspect a single SSE line and record any token usage it carries.
    ///
    /// Only `data: `-prefixed lines are considered. The `[DONE]` sentinel,
    /// lines that are not valid JSON, and payloads without a `usage` object
    /// are all silently ignored. When a `usage` object is present, its
    /// `prompt_tokens` / `completion_tokens` fields (if numeric) overwrite
    /// the captured values — the last report in the stream wins.
    fn parse_sse_line(&mut self, line: &str) {
        let payload = match line.strip_prefix("data: ") {
            Some(p) => p,
            None => return,
        };
        if payload == "[DONE]" {
            return;
        }
        let value = match serde_json::from_str::<serde_json::Value>(payload) {
            Ok(v) => v,
            Err(_) => return,
        };
        let usage = match value.get("usage") {
            Some(u) => u,
            None => return,
        };
        if let Some(n) = usage.get("prompt_tokens").and_then(serde_json::Value::as_i64) {
            self.input_tokens = n;
        }
        if let Some(n) = usage.get("completion_tokens").and_then(serde_json::Value::as_i64) {
            self.output_tokens = n;
        }
    }
}
pub async fn execute_relay(
db: &SqlitePool,
db: &PgPool,
task_id: &str,
provider_id: &str,
provider_base_url: &str,
provider_api_key: Option<&str>,
request_body: &str,
stream: bool,
max_attempts: u32,
base_delay_ms: u64,
enc_key: &[u8; 32],
) -> SaasResult<RelayResponse> {
validate_provider_url(provider_base_url)?;
@@ -144,17 +186,47 @@ pub async fn execute_relay(
let max_attempts = max_attempts.max(1).min(5);
// Key Pool 轮转状态
let mut current_key_id: Option<String> = None;
let mut current_api_key: Option<String> = None;
for attempt in 0..max_attempts {
let is_first = attempt == 0;
if is_first {
update_task_status(db, task_id, "processing", None, None, None).await?;
}
// 首次或 429 后需要重新选择 Key
if current_key_id.is_none() {
match super::key_pool::select_best_key(db, provider_id, enc_key).await {
Ok(selection) => {
let key_id = selection.key_id.clone();
let key_value = selection.key.key_value.clone();
tracing::debug!(
"Relay task {} 选择 Key {} (attempt {})",
task_id, key_id, attempt + 1
);
current_key_id = Some(key_id);
current_api_key = Some(key_value);
}
Err(SaasError::RateLimited(msg)) => {
// 所有 Key 均在冷却中
let err_msg = format!("Key Pool 耗尽: {}", msg);
update_task_status(db, task_id, "failed", None, None, Some(&err_msg)).await?;
return Err(SaasError::RateLimited(msg));
}
Err(e) => return Err(e),
}
}
let key_id = current_key_id.as_ref().unwrap().clone();
let api_key = current_api_key.clone();
let mut req_builder = client.post(&url)
.header("Content-Type", "application/json")
.body(request_body.to_string());
if let Some(key) = provider_api_key {
if let Some(ref key) = api_key {
req_builder = req_builder.header("Authorization", format!("Bearer {}", key));
}
@@ -162,31 +234,128 @@ pub async fn execute_relay(
match result {
Ok(resp) if resp.status().is_success() => {
// 成功
if stream {
let byte_stream = resp.bytes_stream()
.map(|result| result.map_err(std::io::Error::other));
let body = axum::body::Body::from_stream(byte_stream);
update_task_status(db, task_id, "completed", None, None, None).await?;
let usage_capture = Arc::new(Mutex::new(SseUsageCapture::default()));
let usage_capture_clone = usage_capture.clone();
let db_clone = db.clone();
let task_id_clone = task_id.to_string();
let key_id_for_spawn = key_id.clone();
// Bounded channel for backpressure: 128 chunks (~128KB) buffer.
// If the client reads slowly, the upstream is signaled via
// backpressure instead of growing memory indefinitely.
let (tx, rx) = tokio::sync::mpsc::channel::<Result<bytes::Bytes, std::io::Error>>(128);
// Spawn a task to consume the upstream stream and forward through the bounded channel
tokio::spawn(async move {
use futures::StreamExt;
let mut upstream = resp.bytes_stream();
while let Some(chunk_result) = upstream.next().await {
match chunk_result {
Ok(chunk) => {
// Parse SSE lines for usage tracking
if let Ok(text) = std::str::from_utf8(&chunk) {
if let Ok(mut capture) = usage_capture_clone.lock() {
for line in text.lines() {
capture.parse_sse_line(line);
}
}
}
// Forward to bounded channel — if full, this applies backpressure
if tx.send(Ok(chunk)).await.is_err() {
tracing::debug!("SSE relay: client disconnected, stopping upstream");
break;
}
}
Err(e) => {
let _ = tx.send(Err(std::io::Error::other(e))).await;
break;
}
}
}
});
// Convert mpsc::Receiver into a Body stream
let body_stream = tokio_stream::wrappers::ReceiverStream::new(rx);
let body = axum::body::Body::from_stream(body_stream);
// SSE 流结束后异步记录 usage + Key 使用量
tokio::spawn(async move {
tokio::time::sleep(std::time::Duration::from_secs(3)).await;
let (input, output) = match usage_capture.lock() {
Ok(capture) => (
if capture.input_tokens > 0 { Some(capture.input_tokens) } else { None },
if capture.output_tokens > 0 { Some(capture.output_tokens) } else { None },
),
Err(e) => {
tracing::warn!("Usage capture lock poisoned: {}", e);
(None, None)
}
};
// 记录任务状态
if let Err(e) = update_task_status(&db_clone, &task_id_clone, "completed", input, output, None).await {
tracing::warn!("Failed to update task status after SSE stream: {}", e);
}
// 记录 Key 使用量
let total_tokens = input.unwrap_or(0) + output.unwrap_or(0);
if let Err(e) = super::key_pool::record_key_usage(&db_clone, &key_id_for_spawn, Some(total_tokens)).await {
tracing::warn!("Failed to record key usage: {}", e);
}
});
return Ok(RelayResponse::Sse(body));
} else {
let body = resp.text().await.unwrap_or_default();
let (input_tokens, output_tokens) = extract_token_usage(&body);
update_task_status(db, task_id, "completed",
Some(input_tokens), Some(output_tokens), None).await?;
// 记录 Key 使用量
let _ = super::key_pool::record_key_usage(
db, &key_id, Some(input_tokens + output_tokens),
).await;
return Ok(RelayResponse::Json(body));
}
}
Ok(resp) => {
let status = resp.status().as_u16();
if status == 429 {
// 解析 Retry-After header
let retry_after = resp.headers()
.get("retry-after")
.and_then(|v| v.to_str().ok())
.and_then(|v| v.parse::<u64>().ok());
// 标记 Key 为 429 冷却
if let Err(e) = super::key_pool::mark_key_429(db, &key_id, retry_after).await {
tracing::warn!("Failed to mark key 429: {}", e);
}
// 强制下次迭代重新选择 Key
current_key_id = None;
current_api_key = None;
if attempt + 1 >= max_attempts {
let err_msg = format!(
"Key Pool 轮转耗尽 ({} attempts),所有 Key 均被限流",
max_attempts
);
update_task_status(db, task_id, "failed", None, None, Some(&err_msg)).await?;
return Err(SaasError::RateLimited(err_msg));
}
tracing::warn!(
"Relay task {} 收到 429Key {} 已标记冷却 (attempt {}/{})",
task_id, key_id, attempt + 1, max_attempts
);
// 429 时立即切换 Key 重试,不做退避延迟
continue;
}
if !is_retryable_status(status) || attempt + 1 >= max_attempts {
// 4xx 客户端错误或已达最大重试次数 → 立即失败
let body = resp.text().await.unwrap_or_default();
let err_msg = format!("上游返回 HTTP {}: {}", status, &body[..body.len().min(500)]);
update_task_status(db, task_id, "failed", None, None, Some(&err_msg)).await?;
return Err(SaasError::Relay(err_msg));
}
// 可重试的服务端错误 → 继续循环
tracing::warn!(
"Relay task {} 可重试错误 HTTP {} (attempt {}/{})",
task_id, status, attempt + 1, max_attempts
@@ -205,12 +374,11 @@ pub async fn execute_relay(
}
}
// 指数退避: base_delay * 2^attempt
// 非 429 错误使用指数退避
let delay_ms = base_delay_ms * (1 << attempt);
tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
}
// 理论上不会到达 (循环内已处理),但满足编译器
Err(SaasError::Relay("重试次数已耗尽".into()))
}
@@ -228,6 +396,7 @@ fn hash_request(body: &str) -> String {
hex::encode(Sha256::digest(body.as_bytes()))
}
/// 从 JSON 响应中提取 token 使用量
fn extract_token_usage(body: &str) -> (i64, i64) {
let parsed: serde_json::Value = match serde_json::from_str(body) {
Ok(v) => v,
@@ -247,6 +416,11 @@ fn extract_token_usage(body: &str) -> (i64, i64) {
(input, output)
}
/// Extract token usage from a JSON response body (public wrapper).
///
/// Delegates to the module-private `extract_token_usage` and returns the
/// `(input_tokens, output_tokens)` pair unchanged.
pub fn extract_token_usage_from_json(body: &str) -> (i64, i64) {
    let (input_tokens, output_tokens) = extract_token_usage(body);
    (input_tokens, output_tokens)
}
/// SSRF 防护: 验证 provider URL 不指向内网
fn validate_provider_url(url: &str) -> SaasResult<()> {
let parsed: url::Url = url.parse().map_err(|_| {
@@ -274,6 +448,9 @@ fn validate_provider_url(url: &str) -> SaasResult<()> {
None => return Err(SaasError::InvalidInput("provider URL 缺少 host".into())),
};
// 去除 IPv6 方括号
let host = host.trim_start_matches('[').trim_end_matches(']');
// 精确匹配的阻止列表
let blocked_exact = [
"127.0.0.1", "0.0.0.0", "localhost", "::1", "::ffff:127.0.0.1",
@@ -292,16 +469,39 @@ fn validate_provider_url(url: &str) -> SaasResult<()> {
}
}
// 阻止纯数字 host (可能是十进制 IP 表示法,如 2130706433 = 127.0.0.1)
if host.parse::<u64>().is_ok() {
return Err(SaasError::InvalidInput(format!("provider URL 使用了不允许的 IP 格式: {}", host)));
}
// 阻止十六进制/八进制 IP 混淆 (如 0x7f000001, 0177.0.0.1)
if host.chars().all(|c| c.is_ascii_hexdigit() || c == '.' || c == ':' || c == 'x' || c == 'X') {
return Err(SaasError::InvalidInput(format!("provider URL 使用了不允许的 IP 格式: {}", host)));
}
// 阻止 IPv4 私有网段 (通过解析 IP)
if let Ok(ip) = host.parse::<std::net::IpAddr>() {
if is_private_ip(&ip) {
return Err(SaasError::InvalidInput(format!("provider URL 指向私有 IP 地址: {}", host)));
}
return Ok(());
}
// 阻止纯数字 host (可能是十进制 IP 表示法,如 2130706433 = 127.0.0.1)
if host.parse::<u64>().is_ok() {
return Err(SaasError::InvalidInput(format!("provider URL 使用了不允许的 IP 格式: {}", host)));
// 对域名做 DNS 解析,检查解析结果是否指向内网
let addr_str: String = format!("{}:0", host);
match std::net::ToSocketAddrs::to_socket_addrs(&addr_str) {
Ok(addrs) => {
for sockaddr in addrs {
if is_private_ip(&sockaddr.ip()) {
return Err(SaasError::InvalidInput(
"provider URL 域名解析到内网地址".into()
));
}
}
}
Err(_) => {
// DNS 解析失败,可能是无效域名,不阻止请求
}
}
Ok(())