Phase 1.1: API Token 认证中间件 - auth_middleware 新增 zclaw_ 前缀 token 分支 (SHA-256 验证) - 合并 token 自身权限与角色权限,异步更新 last_used_at - 添加 GET /api/v1/auth/me 端点返回当前用户信息 - get_role_permissions 改为 pub(crate) 供中间件调用 Phase 1.2: 真实 SSE 流式中转 - RelayResponse::Sse 改为 axum::body::Body (bytes_stream) - 流式请求超时提升至 300s,转发 SSE headers (Cache-Control, Connection) - 添加 futures 依赖用于 StreamExt Phase 1.3: 滑动窗口速率限制中间件 - 按 account_id 做 per-minute 限流 (默认 60 rpm + 10 burst) - 超限返回 429 + Retry-After header - RateLimitConfig 支持配置化,DashMap 存储时间戳 21 tests passed, zero warnings.
298 lines
12 KiB
Rust
298 lines
12 KiB
Rust
//! 中转服务核心逻辑
|
||
|
||
use sqlx::SqlitePool;
|
||
use crate::error::{SaasError, SaasResult};
|
||
use super::types::*;
|
||
use futures::StreamExt;
|
||
|
||
// ============ Relay Task Management ============
|
||
|
||
pub async fn create_relay_task(
|
||
db: &SqlitePool,
|
||
account_id: &str,
|
||
provider_id: &str,
|
||
model_id: &str,
|
||
request_body: &str,
|
||
priority: i64,
|
||
) -> SaasResult<RelayTaskInfo> {
|
||
let id = uuid::Uuid::new_v4().to_string();
|
||
let now = chrono::Utc::now().to_rfc3339();
|
||
let request_hash = hash_request(request_body);
|
||
|
||
sqlx::query(
|
||
"INSERT INTO relay_tasks (id, account_id, provider_id, model_id, request_hash, request_body, status, priority, attempt_count, max_attempts, queued_at, created_at)
|
||
VALUES (?1, ?2, ?3, ?4, ?5, ?6, 'queued', ?7, 0, 3, ?8, ?8)"
|
||
)
|
||
.bind(&id).bind(account_id).bind(provider_id).bind(model_id)
|
||
.bind(&request_hash).bind(request_body).bind(priority).bind(&now)
|
||
.execute(db).await?;
|
||
|
||
get_relay_task(db, &id).await
|
||
}
|
||
|
||
pub async fn get_relay_task(db: &SqlitePool, task_id: &str) -> SaasResult<RelayTaskInfo> {
|
||
let row: Option<(String, String, String, String, String, i64, i64, i64, i64, i64, Option<String>, String, Option<String>, Option<String>, String)> =
|
||
sqlx::query_as(
|
||
"SELECT id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at
|
||
FROM relay_tasks WHERE id = ?1"
|
||
)
|
||
.bind(task_id)
|
||
.fetch_optional(db)
|
||
.await?;
|
||
|
||
let (id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at) =
|
||
row.ok_or_else(|| SaasError::NotFound(format!("中转任务 {} 不存在", task_id)))?;
|
||
|
||
Ok(RelayTaskInfo {
|
||
id, account_id, provider_id, model_id, status, priority,
|
||
attempt_count, max_attempts, input_tokens, output_tokens,
|
||
error_message, queued_at, started_at, completed_at, created_at,
|
||
})
|
||
}
|
||
|
||
pub async fn list_relay_tasks(
|
||
db: &SqlitePool, account_id: &str, query: &RelayTaskQuery,
|
||
) -> SaasResult<Vec<RelayTaskInfo>> {
|
||
let page = query.page.unwrap_or(1).max(1);
|
||
let page_size = query.page_size.unwrap_or(20).min(100);
|
||
let offset = (page - 1) * page_size;
|
||
|
||
let sql = if query.status.is_some() {
|
||
"SELECT id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at
|
||
FROM relay_tasks WHERE account_id = ?1 AND status = ?2 ORDER BY created_at DESC LIMIT ?3 OFFSET ?4"
|
||
} else {
|
||
"SELECT id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at
|
||
FROM relay_tasks WHERE account_id = ?1 ORDER BY created_at DESC LIMIT ?2 OFFSET ?3"
|
||
};
|
||
|
||
let mut query_builder = sqlx::query_as::<_, (String, String, String, String, String, i64, i64, i64, i64, i64, Option<String>, String, Option<String>, Option<String>, String)>(sql)
|
||
.bind(account_id);
|
||
|
||
if let Some(ref status) = query.status {
|
||
query_builder = query_builder.bind(status);
|
||
}
|
||
|
||
query_builder = query_builder.bind(page_size).bind(offset);
|
||
|
||
let rows = query_builder.fetch_all(db).await?;
|
||
Ok(rows.into_iter().map(|(id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at)| {
|
||
RelayTaskInfo { id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at }
|
||
}).collect())
|
||
}
|
||
|
||
pub async fn update_task_status(
|
||
db: &SqlitePool, task_id: &str, status: &str,
|
||
input_tokens: Option<i64>, output_tokens: Option<i64>,
|
||
error_message: Option<&str>,
|
||
) -> SaasResult<()> {
|
||
let now = chrono::Utc::now().to_rfc3339();
|
||
|
||
let update_sql = match status {
|
||
"processing" => "started_at = ?1, status = 'processing', attempt_count = attempt_count + 1",
|
||
"completed" => "completed_at = ?1, status = 'completed', input_tokens = COALESCE(?2, input_tokens), output_tokens = COALESCE(?3, output_tokens)",
|
||
"failed" => "completed_at = ?1, status = 'failed', error_message = ?2",
|
||
_ => return Err(SaasError::InvalidInput(format!("无效任务状态: {}", status))),
|
||
};
|
||
|
||
let sql = format!("UPDATE relay_tasks SET {} WHERE id = ?4", update_sql);
|
||
|
||
let mut query = sqlx::query(&sql).bind(&now);
|
||
if status == "completed" {
|
||
query = query.bind(input_tokens).bind(output_tokens);
|
||
}
|
||
if status == "failed" {
|
||
query = query.bind(error_message);
|
||
}
|
||
query = query.bind(task_id);
|
||
query.execute(db).await?;
|
||
|
||
Ok(())
|
||
}
|
||
|
||
// ============ Relay Execution ============
|
||
|
||
pub async fn execute_relay(
|
||
db: &SqlitePool,
|
||
task_id: &str,
|
||
provider_base_url: &str,
|
||
provider_api_key: Option<&str>,
|
||
request_body: &str,
|
||
stream: bool,
|
||
) -> SaasResult<RelayResponse> {
|
||
update_task_status(db, task_id, "processing", None, None, None).await?;
|
||
|
||
// SSRF 防护: 验证 URL scheme 和禁止内网地址
|
||
validate_provider_url(provider_base_url)?;
|
||
|
||
let url = format!("{}/chat/completions", provider_base_url.trim_end_matches('/'));
|
||
let _start = std::time::Instant::now();
|
||
|
||
let client = reqwest::Client::builder()
|
||
.timeout(std::time::Duration::from_secs(if stream { 300 } else { 30 }))
|
||
.build()
|
||
.map_err(|e| SaasError::Internal(format!("HTTP 客户端构建失败: {}", e)))?;
|
||
let mut req_builder = client.post(&url)
|
||
.header("Content-Type", "application/json")
|
||
.body(request_body.to_string());
|
||
|
||
if let Some(key) = provider_api_key {
|
||
req_builder = req_builder.header("Authorization", format!("Bearer {}", key));
|
||
}
|
||
|
||
let result = req_builder.send().await;
|
||
|
||
match result {
|
||
Ok(resp) if resp.status().is_success() => {
|
||
if stream {
|
||
// 真实 SSE 流式: 使用 bytes_stream 而非 text().await 缓冲
|
||
let stream = resp.bytes_stream()
|
||
.map(|result| result.map_err(std::io::Error::other));
|
||
let body = axum::body::Body::from_stream(stream);
|
||
// 流式模式下无法提取 token usage,标记为 completed (usage=0)
|
||
update_task_status(db, task_id, "completed", None, None, None).await?;
|
||
Ok(RelayResponse::Sse(body))
|
||
} else {
|
||
let body = resp.text().await.unwrap_or_default();
|
||
let (input_tokens, output_tokens) = extract_token_usage(&body);
|
||
update_task_status(db, task_id, "completed",
|
||
Some(input_tokens), Some(output_tokens), None).await?;
|
||
Ok(RelayResponse::Json(body))
|
||
}
|
||
}
|
||
Ok(resp) => {
|
||
let status = resp.status().as_u16();
|
||
let body = resp.text().await.unwrap_or_default();
|
||
let err_msg = format!("上游返回 HTTP {}: {}", status, &body[..body.len().min(500)]);
|
||
update_task_status(db, task_id, "failed", None, None, Some(&err_msg)).await?;
|
||
Err(SaasError::Relay(err_msg))
|
||
}
|
||
Err(e) => {
|
||
let err_msg = format!("请求上游失败: {}", e);
|
||
update_task_status(db, task_id, "failed", None, None, Some(&err_msg)).await?;
|
||
Err(SaasError::Relay(err_msg))
|
||
}
|
||
}
|
||
}
|
||
|
||
/// 中转响应类型
|
||
#[derive(Debug)]
|
||
pub enum RelayResponse {
|
||
Json(String),
|
||
Sse(axum::body::Body),
|
||
}
|
||
|
||
// ============ Helpers ============
|
||
|
||
fn hash_request(body: &str) -> String {
|
||
use sha2::{Sha256, Digest};
|
||
hex::encode(Sha256::digest(body.as_bytes()))
|
||
}
|
||
|
||
fn extract_token_usage(body: &str) -> (i64, i64) {
|
||
let parsed: serde_json::Value = match serde_json::from_str(body) {
|
||
Ok(v) => v,
|
||
Err(_) => return (0, 0),
|
||
};
|
||
|
||
let usage = parsed.get("usage");
|
||
let input = usage
|
||
.and_then(|u| u.get("prompt_tokens"))
|
||
.and_then(|v| v.as_i64())
|
||
.unwrap_or(0);
|
||
let output = usage
|
||
.and_then(|u| u.get("completion_tokens"))
|
||
.and_then(|v| v.as_i64())
|
||
.unwrap_or(0);
|
||
|
||
(input, output)
|
||
}
|
||
|
||
/// SSRF 防护: 验证 provider URL 不指向内网
|
||
fn validate_provider_url(url: &str) -> SaasResult<()> {
|
||
let parsed: url::Url = url.parse().map_err(|_| {
|
||
SaasError::InvalidInput(format!("无效的 provider URL: {}", url))
|
||
})?;
|
||
|
||
// 只允许 https
|
||
match parsed.scheme() {
|
||
"https" => {}
|
||
"http" => {
|
||
// 开发环境允许 http
|
||
let is_dev = std::env::var("ZCLAW_SAAS_DEV")
|
||
.map(|v| v == "true" || v == "1")
|
||
.unwrap_or(false);
|
||
if !is_dev {
|
||
return Err(SaasError::InvalidInput("生产环境禁止 http scheme,请使用 https".into()));
|
||
}
|
||
}
|
||
_ => return Err(SaasError::InvalidInput(format!("不允许的 URL scheme: {}", parsed.scheme()))),
|
||
}
|
||
|
||
// 禁止内网地址
|
||
let host = match parsed.host_str() {
|
||
Some(h) => h,
|
||
None => return Err(SaasError::InvalidInput("provider URL 缺少 host".into())),
|
||
};
|
||
|
||
// 精确匹配的阻止列表
|
||
let blocked_exact = [
|
||
"127.0.0.1", "0.0.0.0", "localhost", "::1", "::ffff:127.0.0.1",
|
||
"0:0:0:0:0:ffff:7f00:1", "169.254.169.254", "metadata.google.internal",
|
||
"10.0.0.1", "172.16.0.1", "192.168.0.1",
|
||
];
|
||
if blocked_exact.contains(&host) {
|
||
return Err(SaasError::InvalidInput(format!("provider URL 指向禁止的内网地址: {}", host)));
|
||
}
|
||
|
||
// 后缀匹配 (阻止子域名)
|
||
let blocked_suffixes = ["localhost", "internal", "local", "localhost.localdomain"];
|
||
for suffix in &blocked_suffixes {
|
||
if host.ends_with(&format!(".{}", suffix)) {
|
||
return Err(SaasError::InvalidInput(format!("provider URL 指向禁止的内网地址: {}", host)));
|
||
}
|
||
}
|
||
|
||
// 阻止 IPv4 私有网段 (通过解析 IP)
|
||
if let Ok(ip) = host.parse::<std::net::IpAddr>() {
|
||
if is_private_ip(&ip) {
|
||
return Err(SaasError::InvalidInput(format!("provider URL 指向私有 IP 地址: {}", host)));
|
||
}
|
||
}
|
||
|
||
// 阻止纯数字 host (可能是十进制 IP 表示法,如 2130706433 = 127.0.0.1)
|
||
if host.parse::<u64>().is_ok() {
|
||
return Err(SaasError::InvalidInput(format!("provider URL 使用了不允许的 IP 格式: {}", host)));
|
||
}
|
||
|
||
Ok(())
|
||
}
|
||
|
||
/// Returns `true` if `ip` is in a range the relay must never connect to.
///
/// Blocked ranges:
/// - IPv4: private, loopback, link-local, shared/CGNAT, "this network",
///   benchmarking, multicast, reserved and broadcast.
/// - IPv6: unspecified, loopback, unique-local, link-local, site-local and
///   multicast.
/// - IPv4 addresses embedded in IPv6 are checked recursively. This covers
///   mapped `::ffff:a.b.c.d`, compatible `::a.b.c.d`, and NAT64
///   `64:ff9b::a.b.c.d`.
fn is_private_ip(ip: &std::net::IpAddr) -> bool {
    use std::net::{IpAddr, Ipv4Addr};

    match ip {
        IpAddr::V4(v4) => {
            let [a, b, ..] = v4.octets();
            a == 0                                          // 0.0.0.0/8 "this network"
                || a == 10                                  // 10.0.0.0/8 private
                || (a == 100 && (64..=127).contains(&b))    // 100.64.0.0/10 shared/CGNAT (some cloud metadata)
                || a == 127                                 // 127.0.0.0/8 loopback
                || (a == 169 && b == 254)                   // 169.254.0.0/16 link-local (cloud metadata)
                || (a == 172 && (16..=31).contains(&b))     // 172.16.0.0/12 private
                || (a == 192 && b == 168)                   // 192.168.0.0/16 private
                || (a == 198 && (18..=19).contains(&b))     // 198.18.0.0/15 benchmarking
                || a >= 224                                 // 224/4 multicast, 240/4 reserved, broadcast
        }
        IpAddr::V6(v6) => {
            let segs = v6.segments();
            let first = segs[0];
            let octets = v6.octets();
            let is_nat64 = segs[..6] == [0x64, 0xff9b, 0, 0, 0, 0];

            v6.is_unspecified()                         // ::
                || v6.is_loopback()                     // ::1
                || (first & 0xfe00) == 0xfc00           // fc00::/7 unique local
                || (first & 0xffc0) == 0xfe80           // fe80::/10 link-local
                || (first & 0xffc0) == 0xfec0           // fec0::/10 site-local (deprecated)
                || (first & 0xff00) == 0xff00           // ff00::/8 multicast
                // ::ffff:a.b.c.d (mapped) and ::a.b.c.d (compatible)
                || v6.to_ipv4().is_some_and(|v4| is_private_ip(&IpAddr::V4(v4)))
                // 64:ff9b::/96 NAT64 with the IPv4 address in the low 32 bits
                || (is_nat64
                    && is_private_ip(&IpAddr::V4(Ipv4Addr::new(
                        octets[12], octets[13], octets[14], octets[15],
                    ))))
        }
    }
}
|