//! 中转服务 HTTP 处理器 use axum::{ extract::{Extension, Path, Query, State}, http::{HeaderMap, StatusCode}, response::{IntoResponse, Response}, Json, }; use crate::state::AppState; use crate::error::{SaasError, SaasResult}; use crate::auth::types::AuthContext; use crate::auth::handlers::{log_operation, check_permission}; use crate::model_config::service as model_service; use super::{types::*, service}; /// POST /api/v1/relay/chat/completions /// OpenAI 兼容的聊天补全端点 pub async fn chat_completions( State(state): State, Extension(ctx): Extension, _headers: HeaderMap, Json(req): Json, ) -> SaasResult { check_permission(&ctx, "relay:use")?; // 队列容量检查:防止过载(立即释放读锁) let max_queue_size = { let config = state.config.read().await; config.relay.max_queue_size }; let queued_count: i64 = sqlx::query_scalar( "SELECT COUNT(*) FROM relay_tasks WHERE account_id = $1 AND status IN ('queued', 'processing')" ) .bind(&ctx.account_id) .fetch_one(&state.db) .await .unwrap_or(0); if queued_count >= max_queue_size as i64 { return Err(SaasError::RateLimited( format!("队列已满 ({} 个任务排队中),请稍后重试", queued_count) )); } // --- 输入验证 --- // 请求体大小限制 (1 MB) const MAX_BODY_BYTES: usize = 1024 * 1024; let estimated_size = serde_json::to_string(&req).map(|s| s.len()).unwrap_or(0); if estimated_size > MAX_BODY_BYTES { return Err(SaasError::InvalidInput( format!("请求体超过大小限制 ({} bytes > {} bytes)", estimated_size, MAX_BODY_BYTES) )); } // model 字段 let model_name = req.get("model") .and_then(|v| v.as_str()) .ok_or_else(|| SaasError::InvalidInput("缺少 model 字段".into()))?; // messages 字段:必须存在且为非空数组 let messages = req.get("messages") .ok_or_else(|| SaasError::InvalidInput("缺少 messages 字段".into()))?; let messages_arr = messages.as_array() .ok_or_else(|| SaasError::InvalidInput("messages 必须是数组".into()))?; if messages_arr.is_empty() { return Err(SaasError::InvalidInput("messages 数组不能为空".into())); } // 验证每个 message 的 role 和 content let valid_roles = ["system", "user", "assistant", "tool"]; for (i, msg) in messages_arr.iter().enumerate() { let role = msg.get("role") .and_then(|v| v.as_str()) .ok_or_else(|| SaasError::InvalidInput( format!("messages[{}] 缺少 role 字段", i) ))?; if !valid_roles.contains(&role) { return Err(SaasError::InvalidInput( format!("messages[{}] 的 role 必须是 system/user/assistant/tool 之一,得到: {}", i, role) )); } let content = msg.get("content") .ok_or_else(|| SaasError::InvalidInput( format!("messages[{}] 缺少 content 字段", i) ))?; // content 必须是字符串或数组 (多模态) if !content.is_string() && !content.is_array() { return Err(SaasError::InvalidInput( format!("messages[{}] 的 content 必须是字符串或数组", i) )); } } // temperature 范围校验 if let Some(temp) = req.get("temperature") { match temp.as_f64() { Some(t) if t < 0.0 || t > 2.0 => { return Err(SaasError::InvalidInput( format!("temperature 必须在 0.0 ~ 2.0 范围内,得到: {}", t) )); } Some(_) => {} // valid None => { return Err(SaasError::InvalidInput("temperature 必须是数字".into())); } } } // max_tokens 范围校验 if let Some(tokens) = req.get("max_tokens") { match tokens.as_u64() { Some(t) if t < 1 || t > 128000 => { return Err(SaasError::InvalidInput( format!("max_tokens 必须在 1 ~ 128000 范围内,得到: {}", t) )); } Some(_) => {} // valid None => { return Err(SaasError::InvalidInput("max_tokens 必须是正整数".into())); } } } // --- 输入验证结束 --- let stream = req.get("stream") .and_then(|v| v.as_bool()) .unwrap_or(false); // 查找 model 对应的 provider — 使用精准查询避免全量加载 let target_model: Option = sqlx::query_as( "SELECT id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, pricing_input, pricing_output, created_at, updated_at FROM models WHERE model_id = $1 AND enabled = true LIMIT 1" ) .bind(&model_name) .fetch_optional(&state.db) .await?; let target_model = target_model .ok_or_else(|| SaasError::NotFound(format!("模型 {} 不存在或未启用", model_name)))?; // 获取 provider 信息 let provider = model_service::get_provider(&state.db, &target_model.provider_id).await?; if !provider.enabled { return Err(SaasError::Forbidden(format!("Provider {} 已禁用", provider.name))); } let request_body = serde_json::to_string(&req)?; // 创建中转任务(提取配置后立即释放读锁) let (max_attempts, retry_delay_ms, enc_key) = { let config = state.config.read().await; let key = config.api_key_encryption_key() .map_err(|e| SaasError::Internal(e.to_string()))?; (config.relay.max_attempts, config.relay.retry_delay_ms, key) }; let task = service::create_relay_task( &state.db, &ctx.account_id, &target_model.provider_id, &target_model.model_id, &request_body, 0, max_attempts, ).await?; log_operation(&state.db, &ctx.account_id, "relay.request", "relay_task", &task.id, Some(serde_json::json!({"model": model_name, "stream": stream})), ctx.client_ip.as_deref()).await?; // 执行中转 (Key Pool 自动选择 + 429 轮转) let response = service::execute_relay( &state.db, &task.id, &target_model.provider_id, &provider.base_url, &request_body, stream, max_attempts, retry_delay_ms, &enc_key, ).await; match response { Ok(service::RelayResponse::Json(body)) => { let (input_tokens, output_tokens) = service::extract_token_usage_from_json(&body); model_service::record_usage( &state.db, &ctx.account_id, &target_model.provider_id, &target_model.model_id, input_tokens, output_tokens, None, "success", None, ).await?; Ok((StatusCode::OK, [(axum::http::header::CONTENT_TYPE, "application/json")], body).into_response()) } Ok(service::RelayResponse::Sse(body)) => { // SSE 流的 usage 统计在 service 层异步处理 // 这里先记录一个占位记录,实际值会在流结束后更新 model_service::record_usage( &state.db, &ctx.account_id, &target_model.provider_id, &target_model.model_id, 0, 0, None, "streaming", None, ).await?; let response = axum::response::Response::builder() .status(StatusCode::OK) .header(axum::http::header::CONTENT_TYPE, "text/event-stream") .header("Cache-Control", "no-cache") .header("Connection", "keep-alive") .body(body) .unwrap(); Ok(response) } Err(e) => { model_service::record_usage( &state.db, &ctx.account_id, &target_model.provider_id, &target_model.model_id, 0, 0, None, "failed", Some(&e.to_string()), ).await?; Err(e) } } } /// GET /api/v1/relay/tasks pub async fn list_tasks( State(state): State, Extension(ctx): Extension, Query(query): Query, ) -> SaasResult>> { service::list_relay_tasks(&state.db, &ctx.account_id, &query).await.map(Json) } /// GET /api/v1/relay/tasks/:id pub async fn get_task( State(state): State, Path(id): Path, Extension(ctx): Extension, ) -> SaasResult> { let task = service::get_relay_task(&state.db, &id).await?; // 只允许查看自己的任务 (admin 可查看全部) if task.account_id != ctx.account_id { check_permission(&ctx, "relay:admin")?; } Ok(Json(task)) } /// GET /api/v1/relay/models /// 列出可用的中转模型 (enabled providers + enabled models) pub async fn list_available_models( State(state): State, _ctx: Extension, ) -> SaasResult>> { // 单次 JOIN 查询替代 2 次全量加载 let rows: Vec<(String, String, String, i64, i64, bool, bool)> = sqlx::query_as( "SELECT m.model_id, m.provider_id, m.alias, m.context_window, m.max_output_tokens, m.supports_streaming, m.supports_vision FROM models m INNER JOIN providers p ON m.provider_id = p.id WHERE m.enabled = true AND p.enabled = true ORDER BY m.provider_id, m.model_id" ) .fetch_all(&state.db) .await?; let available: Vec = rows.into_iter() .map(|(model_id, provider_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision)| { serde_json::json!({ "id": model_id, "provider_id": provider_id, "alias": alias, "context_window": context_window, "max_output_tokens": max_output_tokens, "supports_streaming": supports_streaming, "supports_vision": supports_vision, }) }) .collect(); Ok(Json(available)) } /// POST /api/v1/relay/tasks/:id/retry (admin only) /// 重试失败的中转任务 pub async fn retry_task( State(state): State, Path(id): Path, Extension(ctx): Extension, ) -> SaasResult> { check_permission(&ctx, "relay:admin")?; let task = service::get_relay_task(&state.db, &id).await?; if task.status != "failed" { return Err(SaasError::InvalidInput(format!( "只能重试失败的任务,当前状态: {}", task.status ))); } // 获取 provider 信息 let provider = model_service::get_provider(&state.db, &task.provider_id).await?; // 读取原始请求体 let request_body: Option = sqlx::query_scalar( "SELECT request_body FROM relay_tasks WHERE id = $1" ) .bind(&id) .fetch_optional(&state.db) .await? .flatten(); let body = request_body.ok_or_else(|| SaasError::Internal("任务请求体丢失".into()))?; // 从 request body 解析 stream 标志 let stream: bool = serde_json::from_str::(&body) .ok() .and_then(|v| v.get("stream").and_then(|s| s.as_bool())) .unwrap_or(false); let max_attempts = task.max_attempts as u32; let config = state.config.read().await; let base_delay_ms = config.relay.retry_delay_ms; let enc_key = config.api_key_encryption_key() .map_err(|e| SaasError::Internal(e.to_string()))?; // 重置任务状态为 queued 以允许新的 processing sqlx::query( "UPDATE relay_tasks SET status = 'queued', error_message = NULL, started_at = NULL, completed_at = NULL WHERE id = $1" ) .bind(&id) .execute(&state.db) .await?; // 异步执行重试 (Key Pool 自动选择) let db = state.db.clone(); let task_id = id.clone(); let provider_id = task.provider_id.clone(); let base_url = provider.base_url.clone(); tokio::spawn(async move { match service::execute_relay( &db, &task_id, &provider_id, &base_url, &body, stream, max_attempts, base_delay_ms, &enc_key, ).await { Ok(_) => tracing::info!("Relay task {} 重试成功", task_id), Err(e) => tracing::warn!("Relay task {} 重试失败: {}", task_id, e), } }); log_operation(&state.db, &ctx.account_id, "relay.retry", "relay_task", &id, None, ctx.client_ip.as_deref()).await?; Ok(Json(serde_json::json!({"ok": true, "task_id": id}))) } // ============ Key Pool 管理 (admin only) ============ /// GET /api/v1/providers/:provider_id/keys pub async fn list_provider_keys( State(state): State, Extension(ctx): Extension, Path(provider_id): Path, ) -> SaasResult>> { check_permission(&ctx, "provider:manage")?; let keys = super::key_pool::list_provider_keys(&state.db, &provider_id).await?; Ok(Json(keys)) } /// POST /api/v1/providers/:provider_id/keys #[derive(serde::Deserialize)] pub struct AddKeyRequest { pub key_label: String, pub key_value: String, #[serde(default)] pub priority: i32, pub max_rpm: Option, pub max_tpm: Option, pub quota_reset_interval: Option, } pub async fn add_provider_key( State(state): State, Extension(ctx): Extension, Path(provider_id): Path, Json(req): Json, ) -> SaasResult> { check_permission(&ctx, "provider:manage")?; if req.key_label.trim().is_empty() { return Err(SaasError::InvalidInput("key_label 不能为空".into())); } if req.key_value.trim().is_empty() { return Err(SaasError::InvalidInput("key_value 不能为空".into())); } let key_id = super::key_pool::add_provider_key( &state.db, &provider_id, &req.key_label, &req.key_value, req.priority, req.max_rpm, req.max_tpm, req.quota_reset_interval.as_deref(), ).await?; log_operation(&state.db, &ctx.account_id, "provider_key.add", "provider_key", &key_id, Some(serde_json::json!({"provider_id": provider_id, "label": req.key_label})), ctx.client_ip.as_deref()).await?; Ok(Json(serde_json::json!({"ok": true, "key_id": key_id}))) } /// PUT /api/v1/providers/:provider_id/keys/:key_id/toggle #[derive(serde::Deserialize)] pub struct ToggleKeyRequest { pub active: bool, } pub async fn toggle_provider_key( State(state): State, Extension(ctx): Extension, Path((provider_id, key_id)): Path<(String, String)>, Json(req): Json, ) -> SaasResult> { check_permission(&ctx, "provider:manage")?; super::key_pool::toggle_key_active(&state.db, &key_id, req.active).await?; log_operation(&state.db, &ctx.account_id, "provider_key.toggle", "provider_key", &key_id, Some(serde_json::json!({"provider_id": provider_id, "active": req.active})), ctx.client_ip.as_deref()).await?; Ok(Json(serde_json::json!({"ok": true}))) } /// DELETE /api/v1/providers/:provider_id/keys/:key_id pub async fn delete_provider_key( State(state): State, Extension(ctx): Extension, Path((provider_id, key_id)): Path<(String, String)>, ) -> SaasResult> { check_permission(&ctx, "provider:manage")?; super::key_pool::delete_provider_key(&state.db, &key_id).await?; log_operation(&state.db, &ctx.account_id, "provider_key.delete", "provider_key", &key_id, Some(serde_json::json!({"provider_id": provider_id})), ctx.client_ip.as_deref()).await?; Ok(Json(serde_json::json!({"ok": true}))) }