//! 中转服务 HTTP 处理器 use axum::{ extract::{Extension, Path, Query, State}, http::{HeaderMap, StatusCode}, response::{IntoResponse, Response}, Json, }; use crate::state::AppState; use crate::error::{SaasError, SaasResult}; use crate::auth::types::AuthContext; use crate::auth::handlers::check_permission; use crate::model_config::service as model_service; use super::{types::*, service}; /// POST /api/v1/relay/chat/completions /// OpenAI 兼容的聊天补全端点 pub async fn chat_completions( State(state): State, Extension(ctx): Extension, _headers: HeaderMap, Json(req): Json, ) -> SaasResult { check_permission(&ctx, "relay:use")?; // 队列容量检查:使用内存 AtomicI64 计数器,消除 DB COUNT 查询 let max_queue_size = { let config = state.config.read().await; config.relay.max_queue_size }; let queued_count = state.cache.relay_queue_count(&ctx.account_id); if queued_count >= max_queue_size as i64 { return Err(SaasError::RateLimited( format!("队列已满 ({} 个任务排队中),请稍后重试", queued_count) )); } // --- 输入验证 --- // 请求体大小限制 (1 MB) — 直接序列化一次,后续复用 const MAX_BODY_BYTES: usize = 1024 * 1024; let request_body = serde_json::to_string(&req) .map_err(|e| SaasError::InvalidInput(format!("请求体序列化失败: {}", e)))?; if request_body.len() > MAX_BODY_BYTES { return Err(SaasError::InvalidInput( format!("请求体超过大小限制 ({} bytes > {} bytes)", request_body.len(), MAX_BODY_BYTES) )); } // model 字段 let model_name = req.get("model") .and_then(|v| v.as_str()) .ok_or_else(|| SaasError::InvalidInput("缺少 model 字段".into()))?; // messages 字段:必须存在且为非空数组 let messages = req.get("messages") .ok_or_else(|| SaasError::InvalidInput("缺少 messages 字段".into()))?; let messages_arr = messages.as_array() .ok_or_else(|| SaasError::InvalidInput("messages 必须是数组".into()))?; if messages_arr.is_empty() { return Err(SaasError::InvalidInput("messages 数组不能为空".into())); } // 验证每个 message 的 role 和 content let valid_roles = ["system", "user", "assistant", "tool"]; for (i, msg) in messages_arr.iter().enumerate() { let role = msg.get("role") .and_then(|v| v.as_str()) .ok_or_else(|| SaasError::InvalidInput( format!("messages[{}] 缺少 role 字段", i) ))?; if !valid_roles.contains(&role) { return Err(SaasError::InvalidInput( format!("messages[{}] 的 role 必须是 system/user/assistant/tool 之一,得到: {}", i, role) )); } let content = msg.get("content") .ok_or_else(|| SaasError::InvalidInput( format!("messages[{}] 缺少 content 字段", i) ))?; // content 必须是字符串或数组 (多模态) if !content.is_string() && !content.is_array() { return Err(SaasError::InvalidInput( format!("messages[{}] 的 content 必须是字符串或数组", i) )); } } // temperature 范围校验 if let Some(temp) = req.get("temperature") { match temp.as_f64() { Some(t) if t < 0.0 || t > 2.0 => { return Err(SaasError::InvalidInput( format!("temperature 必须在 0.0 ~ 2.0 范围内,得到: {}", t) )); } Some(_) => {} // valid None => { return Err(SaasError::InvalidInput("temperature 必须是数字".into())); } } } // max_tokens 范围校验 if let Some(tokens) = req.get("max_tokens") { match tokens.as_u64() { Some(t) if t < 1 || t > 128000 => { return Err(SaasError::InvalidInput( format!("max_tokens 必须在 1 ~ 128000 范围内,得到: {}", t) )); } Some(_) => {} // valid None => { return Err(SaasError::InvalidInput("max_tokens 必须是正整数".into())); } } } // --- 输入验证结束 --- let stream = req.get("stream") .and_then(|v| v.as_bool()) .unwrap_or(false); // 查找 model — 优先检查模型组(跨 Provider Failover),回退到直接模型查找 let mut model_resolution = if let Some(group) = state.cache.get_model_group(model_name) { // 逻辑模型组:构建候选列表 let mut candidates: Vec = Vec::new(); for member in &group.members { if !member.enabled { continue; } let provider = match state.cache.get_provider(&member.provider_id) { Some(p) => p, None => continue, }; let physical_model = match state.cache.get_model(&member.model_id) { Some(m) => m, None => continue, }; candidates.push(CandidateModel { provider_id: member.provider_id.clone(), model_id: member.model_id.clone(), base_url: provider.base_url.clone(), supports_streaming: physical_model.supports_streaming, }); } if candidates.is_empty() { return Err(SaasError::NotFound( format!("模型组 '{}' 没有可用的候选 Provider", model_name) )); } ModelResolution::Group(candidates) } else { // 向后兼容:直接模型查找 + 别名解析(如 "glm-4-flash" → "glm-4-flash-250414") let target_model = state.cache.resolve_model(model_name) .ok_or_else(|| SaasError::NotFound(format!("模型 {} 不存在或未启用", model_name)))?; // 获取 provider 信息 — 使用内存缓存消除 DB 查询 let provider = state.cache.get_provider(&target_model.provider_id) .ok_or_else(|| SaasError::NotFound(format!("Provider {} 不存在", target_model.provider_id)))?; if !provider.enabled { return Err(SaasError::Forbidden(format!("Provider {} 已禁用", provider.name))); } ModelResolution::Direct(CandidateModel { provider_id: target_model.provider_id.clone(), model_id: target_model.model_id.clone(), base_url: provider.base_url.clone(), supports_streaming: target_model.supports_streaming, }) }; // Stream compatibility check if stream && model_resolution.any_non_streaming() { return Err(SaasError::InvalidInput( format!("模型 {} 不支持流式响应,请使用 stream: false", model_name) )); } // request_body 已在前面序列化并验证大小,直接复用 // 创建中转任务(提取配置后立即释放读锁) let (max_attempts, retry_delay_ms, enc_key) = { let config = state.config.read().await; let key = config.api_key_encryption_key() .map_err(|e| SaasError::Internal(e.to_string()))?; (config.relay.max_attempts, config.relay.retry_delay_ms, key) }; let task = service::create_relay_task( &state.db, &ctx.account_id, model_resolution.first_provider_id(), model_resolution.first_model_id(), &request_body, 0, max_attempts, ).await?; // 递增内存队列计数器(替代 DB COUNT 查询) state.cache.relay_enqueue(&ctx.account_id); // 异步派发操作日志(非阻塞,不占用关键路径 DB 连接) // P3-06: Include session_key/agent_id in log for traceability let log_meta = serde_json::json!({ "model": model_name, "stream": stream, "session_key": req.get("session_key").and_then(|v| v.as_str()), "agent_id": req.get("agent_id").and_then(|v| v.as_str()), }); state.dispatch_log_operation( &ctx.account_id, "relay.request", "relay_task", &task.id, Some(log_meta), ctx.client_ip.as_deref(), ).await; // 执行中转:根据解析结果选择执行路径 // C-1: 提取实际服务的 provider_id / model_id 用于精准计费归因 let relay_result = match model_resolution { ModelResolution::Direct(ref candidate) => { // 单 Provider 直接路由(向后兼容) match service::execute_relay( &state.db, &task.id, &ctx.account_id, &candidate.provider_id, &candidate.base_url, &request_body, stream, max_attempts, retry_delay_ms, &enc_key, true, // 独立调用,管理 task 状态 ).await { Ok(resp) => Ok((resp, candidate.provider_id.clone(), candidate.model_id.clone())), Err(e) => Err(e), } } ModelResolution::Group(ref mut candidates) => { // 跨 Provider Failover(按配额余量自动排序) // 注意: Failover 仅适用于预流失败(连接错误、429/5xx 在流开始前)。 // SSE 一旦开始流式传输,中途上游断连不会触发 failover(SSE 协议固有限制)。 service::sort_candidates_by_quota(&state.db, candidates).await; service::execute_relay_with_failover( &state.db, &task.id, &ctx.account_id, candidates, &request_body, stream, max_attempts, retry_delay_ms, &enc_key ).await } }; // 失败时:记录 failure usage + 递减队列计数器(失败请求不计费) let (response, actual_provider_id, actual_model_id) = match relay_result { Ok(triple) => triple, Err(e) => { // 通过 Worker dispatch 记录 failure usage { let args = crate::workers::record_usage::RecordUsageArgs { account_id: ctx.account_id.clone(), provider_id: model_resolution.first_provider_id().to_string(), model_id: model_resolution.first_model_id().to_string(), input_tokens: 0, output_tokens: 0, latency_ms: None, status: "failed".to_string(), error_message: Some(e.to_string()), }; if let Err(e2) = state.worker_dispatcher.dispatch("record_usage", args).await { tracing::warn!("Failed to dispatch failure usage: {}", e2); } } // 递减队列计数器(防止队列计数泄漏 → 连接池耗尽) state.cache.relay_dequeue(&ctx.account_id); return Err(e); } }; // 使用实际服务的 provider/model 进行计费归因 let account_id_usage = ctx.account_id.clone(); let provider_id_usage = actual_provider_id; let model_id_usage = actual_model_id; match response { service::RelayResponse::Json(body) => { let (input_tokens, output_tokens) = service::extract_token_usage_from_json(&body); // 通过 Worker dispatch 记录 usage(受 SpawnLimiter 门控,不阻塞响应) { let args = crate::workers::record_usage::RecordUsageArgs { account_id: account_id_usage.clone(), provider_id: provider_id_usage.clone(), model_id: model_id_usage.clone(), input_tokens: input_tokens as i32, output_tokens: output_tokens as i32, latency_ms: None, status: "success".to_string(), error_message: None, }; if let Err(e) = state.worker_dispatcher.dispatch("record_usage", args).await { tracing::warn!("Failed to dispatch record_usage: {}", e); } } // 实时更新计费配额(relay_requests + tokens 同步递增) if let Err(e) = crate::billing::service::increment_usage( &state.db, &account_id_usage, input_tokens as i64, output_tokens as i64, ).await { tracing::warn!("Failed to increment billing usage for {}: {}", account_id_usage, e); } // 任务完成,递减队列计数器 state.cache.relay_dequeue(&account_id_usage); Ok((StatusCode::OK, [(axum::http::header::CONTENT_TYPE, "application/json")], body).into_response()) } service::RelayResponse::Sse(body) => { // 通过 Worker dispatch 记录 SSE 占位 usage { let args = crate::workers::record_usage::RecordUsageArgs { account_id: account_id_usage.clone(), provider_id: provider_id_usage.clone(), model_id: model_id_usage.clone(), input_tokens: 0, output_tokens: 0, latency_ms: None, status: "streaming".to_string(), error_message: None, }; if let Err(e) = state.worker_dispatcher.dispatch("record_usage", args).await { tracing::warn!("Failed to dispatch SSE usage: {}", e); } } // SSE: relay_requests 实时递增(tokens 由 AggregateUsageWorker 对账修正) if let Err(e) = crate::billing::service::increment_dimension( &state.db, &account_id_usage, "relay_requests", ).await { tracing::warn!("Failed to increment billing relay_requests for {}: {}", account_id_usage, e); } // SSE 流已返回,递减队列计数器(流式任务开始处理) state.cache.relay_dequeue(&account_id_usage); let response = axum::response::Response::builder() .status(StatusCode::OK) .header(axum::http::header::CONTENT_TYPE, "text/event-stream") .header("Cache-Control", "no-cache") .header("Connection", "keep-alive") .body(body) .expect("SSE response builder with valid status/headers cannot fail"); Ok(response) } } } /// GET /api/v1/relay/tasks pub async fn list_tasks( State(state): State, Extension(ctx): Extension, Query(query): Query, ) -> SaasResult>> { service::list_relay_tasks(&state.db, &ctx.account_id, &query).await.map(Json) } /// GET /api/v1/relay/tasks/:id pub async fn get_task( State(state): State, Path(id): Path, Extension(ctx): Extension, ) -> SaasResult> { let task = service::get_relay_task(&state.db, &id).await?; // 只允许查看自己的任务 (admin 可查看全部) if task.account_id != ctx.account_id { check_permission(&ctx, "relay:admin")?; } Ok(Json(task)) } /// GET /api/v1/relay/models /// 列出可用的中转模型 (enabled providers + enabled models) pub async fn list_available_models( State(state): State, _ctx: Extension, ) -> SaasResult>> { // 单次 JOIN 查询替代 2 次全量加载 let rows: Vec<(String, String, String, i64, i64, bool, bool, bool, String)> = sqlx::query_as( "SELECT m.model_id, m.provider_id, m.alias, m.context_window, m.max_output_tokens, m.supports_streaming, m.supports_vision, m.is_embedding, m.model_type FROM models m INNER JOIN providers p ON m.provider_id = p.id WHERE m.enabled = true AND p.enabled = true ORDER BY m.provider_id, m.model_id" ) .fetch_all(&state.db) .await?; let mut available: Vec = rows.into_iter() .map(|(model_id, provider_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, is_embedding, model_type)| { serde_json::json!({ "id": model_id, "provider_id": provider_id, "alias": alias, "context_window": context_window, "max_output_tokens": max_output_tokens, "supports_streaming": supports_streaming, "supports_vision": supports_vision, "is_embedding": is_embedding, "model_type": model_type, }) }) .collect(); // 追加模型组(逻辑模型),使前端能展示和选择 for entry in state.cache.model_groups.iter() { let group = entry.value(); if !group.enabled { continue; } // H-2: 过滤无可用成员的模型组,避免前端选择后请求失败 let active_members: Vec<_> = group.members.iter() .filter(|m| m.enabled) .collect(); if active_members.is_empty() { continue; } // 所有 active 成员都支持 streaming → 模型组支持 streaming let all_streaming = active_members.iter().all(|m| { state.cache.get_model(&m.model_id) .map(|cm| cm.supports_streaming) .unwrap_or(true) }); // 任一 active 成员支持 vision → 模型组支持 vision let any_vision = active_members.iter().any(|m| { state.cache.get_model(&m.model_id) .map(|cm| cm.supports_vision) .unwrap_or(false) }); available.push(serde_json::json!({ "id": group.name, "provider_id": "group", "alias": group.display_name, "is_group": true, "member_count": group.members.len(), "supports_streaming": all_streaming, "supports_vision": any_vision, })); } Ok(Json(available)) } /// POST /api/v1/relay/tasks/:id/retry (admin only) /// 重试失败的中转任务 pub async fn retry_task( State(state): State, Path(id): Path, Extension(ctx): Extension, ) -> SaasResult> { check_permission(&ctx, "relay:admin")?; let task = service::get_relay_task(&state.db, &id).await?; if task.status != "failed" { return Err(SaasError::InvalidInput(format!( "只能重试失败的任务,当前状态: {}", task.status ))); } // 读取原始请求体 let request_body: Option = sqlx::query_scalar( "SELECT request_body FROM relay_tasks WHERE id = $1" ) .bind(&id) .fetch_optional(&state.db) .await? .flatten(); let body = request_body.ok_or_else(|| SaasError::Internal("任务请求体丢失".into()))?; // 从 request body 解析 stream 标志和 model 字段 let parsed_body: Option = serde_json::from_str(&body).ok(); let stream: bool = parsed_body.as_ref() .and_then(|v| v.get("stream").and_then(|s| s.as_bool())) .unwrap_or(false); let model_name: Option = parsed_body.as_ref() .and_then(|v| v.get("model").and_then(|m| m.as_str()).map(|s| s.to_string())); // H-8: 重新解析模型组 — 如果原始请求使用模型组,重试时走 failover 路径 // 而不是盲目使用存储的(可能已失败的)provider_id let mut model_resolution = if let Some(ref name) = model_name { if let Some(group) = state.cache.get_model_group(name) { // 模型组:构建候选列表 let mut candidates: Vec = Vec::new(); for member in &group.members { if !member.enabled { continue; } let provider = match state.cache.get_provider(&member.provider_id) { Some(p) => p, None => continue, }; let physical_model = match state.cache.get_model(&member.model_id) { Some(m) => m, None => continue, }; candidates.push(CandidateModel { provider_id: member.provider_id.clone(), model_id: member.model_id.clone(), base_url: provider.base_url.clone(), supports_streaming: physical_model.supports_streaming, }); } if candidates.is_empty() { return Err(SaasError::NotFound( format!("模型组 '{}' 没有可用的候选 Provider(重试时解析)", name) )); } ModelResolution::Group(candidates) } else if let Some(target_model) = state.cache.get_model(name) { // 直接模型 let provider = state.cache.get_provider(&target_model.provider_id) .ok_or_else(|| SaasError::NotFound(format!("Provider {} 不存在(重试时解析)", target_model.provider_id)))?; ModelResolution::Direct(CandidateModel { provider_id: target_model.provider_id.clone(), model_id: target_model.model_id.clone(), base_url: provider.base_url.clone(), supports_streaming: target_model.supports_streaming, }) } else { // 无法解析,回退到存储的 provider_id(向后兼容) let provider = model_service::get_provider(&state.db, &task.provider_id).await?; ModelResolution::Direct(CandidateModel { provider_id: task.provider_id.clone(), model_id: task.model_id.clone(), base_url: provider.base_url.clone(), supports_streaming: true, }) } } else { // 无 model 字段,回退到存储的 provider_id let provider = model_service::get_provider(&state.db, &task.provider_id).await?; ModelResolution::Direct(CandidateModel { provider_id: task.provider_id.clone(), model_id: task.model_id.clone(), base_url: provider.base_url.clone(), supports_streaming: true, }) }; let max_attempts = task.max_attempts as u32; let config = state.config.read().await; let base_delay_ms = config.relay.retry_delay_ms; let enc_key = config.api_key_encryption_key() .map_err(|e| SaasError::Internal(e.to_string()))?; drop(config); // 重置任务状态为 queued 以允许新的 processing sqlx::query( "UPDATE relay_tasks SET status = 'queued', error_message = NULL, started_at = NULL, completed_at = NULL WHERE id = $1" ) .bind(&id) .execute(&state.db) .await?; // 异步执行重试 — 根据解析结果选择执行路径 let db = state.db.clone(); let task_id = id.clone(); let account_id_for_spawn = task.account_id.clone(); let handle = tokio::spawn(async move { let result = match model_resolution { ModelResolution::Direct(ref candidate) => { service::execute_relay( &db, &task_id, &account_id_for_spawn, &candidate.provider_id, &candidate.base_url, &body, stream, max_attempts, base_delay_ms, &enc_key, true, ).await } ModelResolution::Group(ref mut candidates) => { service::sort_candidates_by_quota(&db, candidates).await; service::execute_relay_with_failover( &db, &task_id, &account_id_for_spawn, candidates, &body, stream, max_attempts, base_delay_ms, &enc_key, ).await .map(|(resp, _, _)| resp) } }; match result { Ok(_) => tracing::info!("Relay task {} 重试成功", task_id), Err(e) => tracing::warn!("Relay task {} 重试失败: {}", task_id, e), } }); // Detach with warning — if server shuts down mid-retry, the task is lost. // The DB status is already reset to 'queued', so a future restart can pick it up. tokio::spawn(async move { if let Err(e) = handle.await { tracing::warn!("Relay retry task aborted (server shutdown?): {}", e); } }); // 异步派发操作日志 state.dispatch_log_operation( &ctx.account_id, "relay.retry", "relay_task", &id, None, ctx.client_ip.as_deref(), ).await; Ok(Json(serde_json::json!({"ok": true, "task_id": id}))) } // ============ Key Pool 管理 (admin only) ============ /// GET /api/v1/providers/:provider_id/keys pub async fn list_provider_keys( State(state): State, Extension(ctx): Extension, Path(provider_id): Path, ) -> SaasResult>> { check_permission(&ctx, "provider:manage")?; let keys = super::key_pool::list_provider_keys(&state.db, &provider_id).await?; Ok(Json(keys)) } /// POST /api/v1/providers/:provider_id/keys #[derive(serde::Deserialize)] pub struct AddKeyRequest { pub key_label: String, pub key_value: String, #[serde(default)] pub priority: i32, pub max_rpm: Option, pub max_tpm: Option, } pub async fn add_provider_key( State(state): State, Extension(ctx): Extension, Path(provider_id): Path, Json(req): Json, ) -> SaasResult> { check_permission(&ctx, "provider:manage")?; if req.key_label.trim().is_empty() { return Err(SaasError::InvalidInput("key_label 不能为空".into())); } if req.key_value.trim().is_empty() { return Err(SaasError::InvalidInput("key_value 不能为空".into())); } if req.key_value.len() < 20 { return Err(SaasError::InvalidInput("key_value 长度不足(至少 20 字符)".into())); } if req.key_value.contains(char::is_whitespace) { return Err(SaasError::InvalidInput("key_value 不能包含空白字符".into())); } // Encrypt the API key before storing in database let enc_key = state.config.read().await.api_key_encryption_key()?; let encrypted_value = crate::crypto::encrypt_value(&req.key_value, &enc_key)?; let key_id = super::key_pool::add_provider_key( &state.db, &provider_id, &req.key_label, &encrypted_value, req.priority, req.max_rpm, req.max_tpm, ).await?; // 异步派发操作日志 state.dispatch_log_operation( &ctx.account_id, "provider_key.add", "provider_key", &key_id, Some(serde_json::json!({"provider_id": provider_id, "label": req.key_label})), ctx.client_ip.as_deref(), ).await; Ok(Json(serde_json::json!({"ok": true, "key_id": key_id}))) } /// PUT /api/v1/providers/:provider_id/keys/:key_id/toggle #[derive(serde::Deserialize)] pub struct ToggleKeyRequest { pub active: bool, } pub async fn toggle_provider_key( State(state): State, Extension(ctx): Extension, Path((provider_id, key_id)): Path<(String, String)>, Json(req): Json, ) -> SaasResult> { check_permission(&ctx, "provider:manage")?; super::key_pool::toggle_key_active(&state.db, &key_id, req.active).await?; // 异步派发操作日志 state.dispatch_log_operation( &ctx.account_id, "provider_key.toggle", "provider_key", &key_id, Some(serde_json::json!({"provider_id": provider_id, "active": req.active})), ctx.client_ip.as_deref(), ).await; Ok(Json(serde_json::json!({"ok": true}))) } /// DELETE /api/v1/providers/:provider_id/keys/:key_id pub async fn delete_provider_key( State(state): State, Extension(ctx): Extension, Path((provider_id, key_id)): Path<(String, String)>, ) -> SaasResult> { check_permission(&ctx, "provider:manage")?; super::key_pool::delete_provider_key(&state.db, &key_id).await?; // 异步派发操作日志 state.dispatch_log_operation( &ctx.account_id, "provider_key.delete", "provider_key", &key_id, Some(serde_json::json!({"provider_id": provider_id})), ctx.client_ip.as_deref(), ).await; Ok(Json(serde_json::json!({"ok": true}))) }