Files
zclaw_openfang/crates/zclaw-saas/src/relay/service.rs
iven d760b9ca10 feat(saas): Phase 1 后端能力补强 — API Token 认证、真实 SSE 流式、速率限制
Phase 1.1: API Token 认证中间件
- auth_middleware 新增 zclaw_ 前缀 token 分支 (SHA-256 验证)
- 合并 token 自身权限与角色权限,异步更新 last_used_at
- 添加 GET /api/v1/auth/me 端点返回当前用户信息
- get_role_permissions 改为 pub(crate) 供中间件调用

Phase 1.2: 真实 SSE 流式中转
- RelayResponse::Sse 改为 axum::body::Body (bytes_stream)
- 流式请求超时提升至 300s,转发 SSE headers (Cache-Control, Connection)
- 添加 futures 依赖用于 StreamExt

Phase 1.3: 滑动窗口速率限制中间件
- 按 account_id 做 per-minute 限流 (默认 60 rpm + 10 burst)
- 超限返回 429 + Retry-After header
- RateLimitConfig 支持配置化,DashMap 存储时间戳

21 tests passed, zero warnings.
2026-03-27 13:49:45 +08:00

298 lines
12 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

//! 中转服务核心逻辑
use sqlx::SqlitePool;
use crate::error::{SaasError, SaasResult};
use super::types::*;
use futures::StreamExt;
// ============ Relay Task Management ============
/// Inserts a new relay task in the `queued` state and returns the stored row.
///
/// The task id is a freshly generated UUID v4 and `request_hash` is computed
/// from `request_body` by `hash_request`. `queued_at` and `created_at` share a
/// single RFC 3339 UTC timestamp; `attempt_count` starts at 0 and
/// `max_attempts` at 3.
///
/// # Errors
/// Propagates any database error from the insert or from the follow-up lookup
/// performed through `get_relay_task`.
pub async fn create_relay_task(
    db: &SqlitePool,
    account_id: &str,
    provider_id: &str,
    model_id: &str,
    request_body: &str,
    priority: i64,
) -> SaasResult<RelayTaskInfo> {
    let task_id = uuid::Uuid::new_v4().to_string();
    let timestamp = chrono::Utc::now().to_rfc3339();
    let body_digest = hash_request(request_body);

    // `?8` is referenced twice so both timestamp columns receive the same value.
    let insert = sqlx::query(
"INSERT INTO relay_tasks (id, account_id, provider_id, model_id, request_hash, request_body, status, priority, attempt_count, max_attempts, queued_at, created_at)
VALUES (?1, ?2, ?3, ?4, ?5, ?6, 'queued', ?7, 0, 3, ?8, ?8)"
    )
    .bind(&task_id)
    .bind(account_id)
    .bind(provider_id)
    .bind(model_id)
    .bind(&body_digest)
    .bind(request_body)
    .bind(priority)
    .bind(&timestamp);
    insert.execute(db).await?;

    // Read the row back so the caller sees exactly what was persisted.
    get_relay_task(db, &task_id).await
}
/// Fetches a single relay task by its id.
///
/// # Errors
/// Returns `SaasError::NotFound` when no row has the given id; database
/// errors are propagated unchanged.
pub async fn get_relay_task(db: &SqlitePool, task_id: &str) -> SaasResult<RelayTaskInfo> {
    // Positional row shape; it must stay in the same order as the SELECT list.
    type TaskRow = (
        String, String, String, String, String,
        i64, i64, i64, i64, i64,
        Option<String>, String, Option<String>, Option<String>, String,
    );

    let fetched: Option<TaskRow> = sqlx::query_as(
"SELECT id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at
FROM relay_tasks WHERE id = ?1"
    )
    .bind(task_id)
    .fetch_optional(db)
    .await?;

    let Some((
        id, account_id, provider_id, model_id, status,
        priority, attempt_count, max_attempts, input_tokens, output_tokens,
        error_message, queued_at, started_at, completed_at, created_at,
    )) = fetched
    else {
        return Err(SaasError::NotFound(format!("中转任务 {} 不存在", task_id)));
    };

    Ok(RelayTaskInfo {
        id,
        account_id,
        provider_id,
        model_id,
        status,
        priority,
        attempt_count,
        max_attempts,
        input_tokens,
        output_tokens,
        error_message,
        queued_at,
        started_at,
        completed_at,
        created_at,
    })
}
/// Lists the relay tasks of one account, newest first, with pagination.
///
/// * `query.page` is 1-based; missing or smaller values are treated as 1.
/// * `query.page_size` defaults to 20 and is clamped to `1..=100`.
/// * `query.status`, when present, restricts the result to that status.
///
/// # Errors
/// Propagates any database error.
pub async fn list_relay_tasks(
    db: &SqlitePool, account_id: &str, query: &RelayTaskQuery,
) -> SaasResult<Vec<RelayTaskInfo>> {
    let page = query.page.unwrap_or(1).max(1);
    // Enforce a lower bound as well as the upper one: SQLite treats a negative
    // LIMIT as "no limit", which would bypass the cap of 100 rows.
    let page_size = query.page_size.unwrap_or(20).clamp(1, 100);
    // Saturate instead of overflowing on an absurdly large page number.
    let offset = (page - 1).saturating_mul(page_size);

    let sql = if query.status.is_some() {
"SELECT id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at
FROM relay_tasks WHERE account_id = ?1 AND status = ?2 ORDER BY created_at DESC LIMIT ?3 OFFSET ?4"
    } else {
"SELECT id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at
FROM relay_tasks WHERE account_id = ?1 ORDER BY created_at DESC LIMIT ?2 OFFSET ?3"
    };

    // Values are bound in placeholder order: account id, optional status,
    // then page size and offset.
    let mut query_builder = sqlx::query_as::<_, (String, String, String, String, String, i64, i64, i64, i64, i64, Option<String>, String, Option<String>, Option<String>, String)>(sql)
        .bind(account_id);
    if let Some(ref status) = query.status {
        query_builder = query_builder.bind(status);
    }
    query_builder = query_builder.bind(page_size).bind(offset);

    let rows = query_builder.fetch_all(db).await?;
    Ok(rows
        .into_iter()
        .map(|(id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at)| {
            RelayTaskInfo { id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at }
        })
        .collect())
}
/// Moves a relay task to a new status and records the matching bookkeeping.
///
/// Supported `status` values:
/// * `"processing"` — sets `started_at` and increments `attempt_count`.
/// * `"completed"` — sets `completed_at`; `input_tokens` / `output_tokens`
///   overwrite the stored counters only when they are `Some` (via `COALESCE`).
/// * `"failed"` — sets `completed_at` and stores `error_message`.
///
/// Arguments that do not apply to the chosen status are ignored.
///
/// # Errors
/// Returns `SaasError::InvalidInput` for any other status; database errors
/// are propagated.
pub async fn update_task_status(
    db: &SqlitePool, task_id: &str, status: &str,
    input_tokens: Option<i64>, output_tokens: Option<i64>,
    error_message: Option<&str>,
) -> SaasResult<()> {
    let now = chrono::Utc::now().to_rfc3339();

    // Each statement numbers its placeholders contiguously and binds exactly
    // that many values. A shared `WHERE id = ?4` with fewer than four bound
    // values would leave `?4` unbound (NULL), so the UPDATE would match no row.
    let query = match status {
        "processing" => sqlx::query(
            "UPDATE relay_tasks SET started_at = ?1, status = 'processing', attempt_count = attempt_count + 1 WHERE id = ?2",
        )
        .bind(&now)
        .bind(task_id),
        "completed" => sqlx::query(
            "UPDATE relay_tasks SET completed_at = ?1, status = 'completed', input_tokens = COALESCE(?2, input_tokens), output_tokens = COALESCE(?3, output_tokens) WHERE id = ?4",
        )
        .bind(&now)
        .bind(input_tokens)
        .bind(output_tokens)
        .bind(task_id),
        "failed" => sqlx::query(
            "UPDATE relay_tasks SET completed_at = ?1, status = 'failed', error_message = ?2 WHERE id = ?3",
        )
        .bind(&now)
        .bind(error_message)
        .bind(task_id),
        _ => return Err(SaasError::InvalidInput(format!("无效任务状态: {}", status))),
    };

    query.execute(db).await?;
    Ok(())
}
// ============ Relay Execution ============
/// Returns the longest prefix of `text` that is at most `max_bytes` bytes long
/// and ends on a UTF-8 character boundary, so slicing can never panic.
fn truncate_on_char_boundary(text: &str, max_bytes: usize) -> &str {
    let mut end = text.len().min(max_bytes);
    // Index 0 is always a boundary, so this loop terminates.
    while !text.is_char_boundary(end) {
        end -= 1;
    }
    &text[..end]
}

/// Forwards `request_body` to `{provider_base_url}/chat/completions` and
/// records the outcome on the relay task identified by `task_id`.
///
/// * The task is first marked `processing`.
/// * `provider_base_url` must pass `validate_provider_url` (SSRF guard).
/// * `provider_api_key`, when present, is sent as a `Bearer` token.
/// * With `stream == true` the upstream body is forwarded as a byte stream
///   (300 s client timeout); otherwise it is buffered as text (30 s timeout)
///   and token usage is extracted from it.
///
/// Every failure after the task has been marked `processing` is recorded as
/// `failed` on the task before the error is returned.
///
/// # Errors
/// * `SaasError::InvalidInput` when the provider URL is rejected.
/// * `SaasError::Internal` when the HTTP client cannot be built.
/// * `SaasError::Relay` when the upstream request fails or returns a
///   non-success status.
/// * Database errors from the status updates are propagated.
pub async fn execute_relay(
    db: &SqlitePool,
    task_id: &str,
    provider_base_url: &str,
    provider_api_key: Option<&str>,
    request_body: &str,
    stream: bool,
) -> SaasResult<RelayResponse> {
    update_task_status(db, task_id, "processing", None, None, None).await?;

    // SSRF guard: check the URL scheme and reject internal addresses. A
    // rejection is recorded on the task so it does not stay in `processing`.
    if let Err(err) = validate_provider_url(provider_base_url) {
        let reason = match &err {
            SaasError::InvalidInput(msg) => msg.clone(),
            _ => String::from("provider URL 校验失败"),
        };
        update_task_status(db, task_id, "failed", None, None, Some(reason.as_str())).await?;
        return Err(err);
    }

    let url = format!("{}/chat/completions", provider_base_url.trim_end_matches('/'));

    // Streaming responses may stay open much longer than one-shot responses.
    let timeout = std::time::Duration::from_secs(if stream { 300 } else { 30 });
    let client = match reqwest::Client::builder().timeout(timeout).build() {
        Ok(client) => client,
        Err(e) => {
            let err_msg = format!("HTTP 客户端构建失败: {}", e);
            update_task_status(db, task_id, "failed", None, None, Some(err_msg.as_str())).await?;
            return Err(SaasError::Internal(err_msg));
        }
    };

    let mut req_builder = client.post(&url)
        .header("Content-Type", "application/json")
        .body(request_body.to_string());
    if let Some(key) = provider_api_key {
        req_builder = req_builder.header("Authorization", format!("Bearer {}", key));
    }

    match req_builder.send().await {
        Ok(resp) if resp.status().is_success() => {
            if stream {
                // True SSE pass-through: forward chunks as they arrive instead
                // of buffering the whole body with `text().await`.
                let byte_stream = resp.bytes_stream()
                    .map(|result| result.map_err(std::io::Error::other));
                let body = axum::body::Body::from_stream(byte_stream);
                // Token usage cannot be extracted in streaming mode, so the
                // task is marked completed without touching the counters.
                update_task_status(db, task_id, "completed", None, None, None).await?;
                Ok(RelayResponse::Sse(body))
            } else {
                // NOTE(review): a body read error is swallowed here and yields
                // an empty body with zero token usage.
                let body = resp.text().await.unwrap_or_default();
                let (input_tokens, output_tokens) = extract_token_usage(&body);
                update_task_status(db, task_id, "completed",
                    Some(input_tokens), Some(output_tokens), None).await?;
                Ok(RelayResponse::Json(body))
            }
        }
        Ok(resp) => {
            let status = resp.status().as_u16();
            let body = resp.text().await.unwrap_or_default();
            // Keep at most 500 bytes of the upstream body, cut on a character
            // boundary so multi-byte text cannot cause a slicing panic.
            let err_msg = format!("上游返回 HTTP {}: {}", status, truncate_on_char_boundary(&body, 500));
            update_task_status(db, task_id, "failed", None, None, Some(&err_msg)).await?;
            Err(SaasError::Relay(err_msg))
        }
        Err(e) => {
            let err_msg = format!("请求上游失败: {}", e);
            update_task_status(db, task_id, "failed", None, None, Some(&err_msg)).await?;
            Err(SaasError::Relay(err_msg))
        }
    }
}
/// Response produced by a relay call.
#[derive(Debug)]
pub enum RelayResponse {
    /// Complete upstream response body, buffered as text (non-streaming requests).
    Json(String),
    /// Upstream response body forwarded as a stream (streaming requests).
    Sse(axum::body::Body),
}
// ============ Helpers ============
/// Returns the lowercase hex encoding of the SHA-256 digest of `body`.
fn hash_request(body: &str) -> String {
    use sha2::{Digest, Sha256};

    let mut hasher = Sha256::new();
    hasher.update(body.as_bytes());
    hex::encode(hasher.finalize())
}
/// Reads `(prompt_tokens, completion_tokens)` from the `usage` object of a
/// JSON response body.
///
/// Each counter falls back to 0 when the body is not valid JSON, when `usage`
/// is absent, or when the field is missing or not an integer.
fn extract_token_usage(body: &str) -> (i64, i64) {
    let Ok(root) = serde_json::from_str::<serde_json::Value>(body) else {
        return (0, 0);
    };

    // Looks up one integer counter under `usage`, defaulting to 0.
    let counter = |field: &str| -> i64 {
        root.get("usage")
            .and_then(|usage| usage.get(field))
            .and_then(serde_json::Value::as_i64)
            .unwrap_or(0)
    };

    (counter("prompt_tokens"), counter("completion_tokens"))
}
/// SSRF guard: checks that a provider URL does not point at an internal host.
///
/// Rules applied, in order:
/// 1. The URL must parse.
/// 2. The scheme must be `https`; `http` is accepted only when the
///    `ZCLAW_SAAS_DEV` environment variable is `"true"` or `"1"`.
/// 3. The host must be present and must not be on the exact block list.
/// 4. The host must not equal, or be a subdomain of, a blocked suffix.
/// 5. A host that is an IP literal must not satisfy `is_private_ip`.
/// 6. A purely numeric host (decimal IP notation) is rejected.
///
/// Only the literal host in the URL is inspected; no DNS resolution is
/// performed here.
///
/// # Errors
/// Returns `SaasError::InvalidInput` describing the first rule that failed.
fn validate_provider_url(url: &str) -> SaasResult<()> {
    let parsed: url::Url = url.parse().map_err(|_| {
        SaasError::InvalidInput(format!("无效的 provider URL: {}", url))
    })?;

    // Only https is allowed outside of development.
    match parsed.scheme() {
        "https" => {}
        "http" => {
            // Plain http is tolerated in development environments only.
            let is_dev = std::env::var("ZCLAW_SAAS_DEV")
                .map(|v| v == "true" || v == "1")
                .unwrap_or(false);
            if !is_dev {
                return Err(SaasError::InvalidInput("生产环境禁止 http scheme请使用 https".into()));
            }
        }
        other => return Err(SaasError::InvalidInput(format!("不允许的 URL scheme: {}", other))),
    }

    let raw_host = match parsed.host_str() {
        Some(h) => h,
        None => return Err(SaasError::InvalidInput("provider URL 缺少 host".into())),
    };
    // `host_str` renders IPv6 literals in brackets ("[::1]"). Strip them so
    // the block list and the IP parser below actually see the address.
    let host = raw_host
        .strip_prefix('[')
        .and_then(|inner| inner.strip_suffix(']'))
        .unwrap_or(raw_host);
    // A fully-qualified name with a trailing dot ("localhost.") names the same
    // host, so compare without it.
    let host = host.strip_suffix('.').unwrap_or(host);
    if host.is_empty() {
        return Err(SaasError::InvalidInput("provider URL 缺少 host".into()));
    }

    // Exact-match block list.
    let blocked_exact = [
        "127.0.0.1", "0.0.0.0", "localhost", "::1", "::ffff:127.0.0.1",
        "0:0:0:0:0:ffff:7f00:1", "169.254.169.254", "metadata.google.internal",
        "10.0.0.1", "172.16.0.1", "192.168.0.1",
    ];
    if blocked_exact.contains(&host) {
        return Err(SaasError::InvalidInput(format!("provider URL 指向禁止的内网地址: {}", host)));
    }

    // Suffix block list: rejects the name itself and any subdomain of it.
    let blocked_suffixes = ["localhost", "internal", "local", "localhost.localdomain"];
    let hits_blocked_suffix = blocked_suffixes.iter().any(|suffix| {
        host.strip_suffix(suffix)
            .is_some_and(|rest| rest.is_empty() || rest.ends_with('.'))
    });
    if hits_blocked_suffix {
        return Err(SaasError::InvalidInput(format!("provider URL 指向禁止的内网地址: {}", host)));
    }

    // Reject private / internal ranges for IP-literal hosts (IPv4 and IPv6).
    if let Ok(ip) = host.parse::<std::net::IpAddr>() {
        if is_private_ip(&ip) {
            return Err(SaasError::InvalidInput(format!("provider URL 指向私有 IP 地址: {}", host)));
        }
    }

    // Reject purely numeric hosts, which may be decimal IP notation
    // (for example 2130706433 = 127.0.0.1).
    if host.parse::<u64>().is_ok() {
        return Err(SaasError::InvalidInput(format!("provider URL 使用了不允许的 IP 格式: {}", host)));
    }

    Ok(())
}
/// Reports whether `ip` belongs to a private, loopback, link-local or
/// otherwise non-public address range.
///
/// IPv4 ranges covered: `0.0.0.0/8`, `10.0.0.0/8`, `100.64.0.0/10`,
/// `127.0.0.0/8`, `169.254.0.0/16`, `172.16.0.0/12`, `192.168.0.0/16`.
///
/// IPv6 ranges covered: `::` (unspecified), `::1` (loopback), `fc00::/7`
/// (unique local), `fe80::/10` (link-local), and any IPv4-mapped address
/// (`::ffff:a.b.c.d`) whose embedded IPv4 address is itself covered.
fn is_private_ip(ip: &std::net::IpAddr) -> bool {
    match ip {
        std::net::IpAddr::V4(v4) => {
            let [first, second, _, _] = v4.octets();
            // 10.0.0.0/8
            first == 10
                // 172.16.0.0/12
                || (first == 172 && (16..=31).contains(&second))
                // 192.168.0.0/16
                || (first == 192 && second == 168)
                // 127.0.0.0/8 (loopback)
                || first == 127
                // 169.254.0.0/16 (link-local)
                || (first == 169 && second == 254)
                // 0.0.0.0/8
                || first == 0
                // 100.64.0.0/10 (shared address space, not publicly routable)
                || (first == 100 && (64..=127).contains(&second))
        }
        std::net::IpAddr::V6(v6) => {
            let leading = v6.segments()[0];
            // ::1 (loopback)
            v6.is_loopback()
                // :: (unspecified, the IPv6 counterpart of 0.0.0.0)
                || v6.is_unspecified()
                // ::ffff:x.x.x.x (IPv4-mapped): judge the embedded IPv4 address
                || v6
                    .to_ipv4_mapped()
                    .is_some_and(|v4| is_private_ip(&std::net::IpAddr::V4(v4)))
                // fe80::/10 (link-local)
                || (leading & 0xffc0) == 0xfe80
                // fc00::/7 (unique local addresses)
                || (leading & 0xfe00) == 0xfc00
        }
    }
}