//! 中转服务核心逻辑 use sqlx::SqlitePool; use crate::error::{SaasError, SaasResult}; use super::types::*; // ============ Relay Task Management ============ pub async fn create_relay_task( db: &SqlitePool, account_id: &str, provider_id: &str, model_id: &str, request_body: &str, priority: i64, ) -> SaasResult { let id = uuid::Uuid::new_v4().to_string(); let now = chrono::Utc::now().to_rfc3339(); let request_hash = hash_request(request_body); sqlx::query( "INSERT INTO relay_tasks (id, account_id, provider_id, model_id, request_hash, request_body, status, priority, attempt_count, max_attempts, queued_at, created_at) VALUES (?1, ?2, ?3, ?4, ?5, ?6, 'queued', ?7, 0, 3, ?8, ?8)" ) .bind(&id).bind(account_id).bind(provider_id).bind(model_id) .bind(&request_hash).bind(request_body).bind(priority).bind(&now) .execute(db).await?; get_relay_task(db, &id).await } pub async fn get_relay_task(db: &SqlitePool, task_id: &str) -> SaasResult { let row: Option<(String, String, String, String, String, i64, i64, i64, i64, i64, Option, String, Option, Option, String)> = sqlx::query_as( "SELECT id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at FROM relay_tasks WHERE id = ?1" ) .bind(task_id) .fetch_optional(db) .await?; let (id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at) = row.ok_or_else(|| SaasError::NotFound(format!("中转任务 {} 不存在", task_id)))?; Ok(RelayTaskInfo { id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at, }) } pub async fn list_relay_tasks( db: &SqlitePool, account_id: &str, query: &RelayTaskQuery, ) -> SaasResult> { let page = query.page.unwrap_or(1).max(1); let page_size = query.page_size.unwrap_or(20).min(100); let offset = (page - 1) * page_size; let sql = if query.status.is_some() { "SELECT id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at FROM relay_tasks WHERE account_id = ?1 AND status = ?2 ORDER BY created_at DESC LIMIT ?3 OFFSET ?4" } else { "SELECT id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at FROM relay_tasks WHERE account_id = ?1 ORDER BY created_at DESC LIMIT ?2 OFFSET ?3" }; let mut query_builder = sqlx::query_as::<_, (String, String, String, String, String, i64, i64, i64, i64, i64, Option, String, Option, Option, String)>(sql) .bind(account_id); if let Some(ref status) = query.status { query_builder = query_builder.bind(status); } query_builder = query_builder.bind(page_size).bind(offset); let rows = query_builder.fetch_all(db).await?; Ok(rows.into_iter().map(|(id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at)| { RelayTaskInfo { id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at } }).collect()) } pub async fn update_task_status( db: &SqlitePool, task_id: &str, status: &str, input_tokens: Option, output_tokens: Option, error_message: Option<&str>, ) -> SaasResult<()> { let now = chrono::Utc::now().to_rfc3339(); let update_sql = match status { "processing" => "started_at = ?1, status = 'processing', attempt_count = attempt_count + 1", "completed" => "completed_at = ?1, status = 'completed', input_tokens = COALESCE(?2, input_tokens), output_tokens = COALESCE(?3, output_tokens)", "failed" => "completed_at = ?1, status = 'failed', error_message = ?2", _ => return Err(SaasError::InvalidInput(format!("无效任务状态: {}", status))), }; let sql = format!("UPDATE relay_tasks SET {} WHERE id = ?4", update_sql); let mut query = sqlx::query(&sql).bind(&now); if status == "completed" { query = query.bind(input_tokens).bind(output_tokens); } if status == "failed" { query = query.bind(error_message); } query = query.bind(task_id); query.execute(db).await?; Ok(()) } // ============ Relay Execution ============ pub async fn execute_relay( db: &SqlitePool, task_id: &str, provider_base_url: &str, provider_api_key: Option<&str>, request_body: &str, stream: bool, ) -> SaasResult { update_task_status(db, task_id, "processing", None, None, None).await?; let url = format!("{}/chat/completions", provider_base_url.trim_end_matches('/')); let _start = std::time::Instant::now(); let client = reqwest::Client::new(); let mut req_builder = client.post(&url) .header("Content-Type", "application/json") .body(request_body.to_string()); if let Some(key) = provider_api_key { req_builder = req_builder.header("Authorization", format!("Bearer {}", key)); } let result = req_builder.send().await; match result { Ok(resp) if resp.status().is_success() => { if stream { let body = resp.text().await.unwrap_or_default(); update_task_status(db, task_id, "completed", None, None, None).await?; Ok(RelayResponse::Sse(body)) } else { let body = resp.text().await.unwrap_or_default(); let (input_tokens, output_tokens) = extract_token_usage(&body); update_task_status(db, task_id, "completed", Some(input_tokens), Some(output_tokens), None).await?; Ok(RelayResponse::Json(body)) } } Ok(resp) => { let status = resp.status().as_u16(); let body = resp.text().await.unwrap_or_default(); let err_msg = format!("上游返回 HTTP {}: {}", status, &body[..body.len().min(500)]); update_task_status(db, task_id, "failed", None, None, Some(&err_msg)).await?; Err(SaasError::Relay(err_msg)) } Err(e) => { let err_msg = format!("请求上游失败: {}", e); update_task_status(db, task_id, "failed", None, None, Some(&err_msg)).await?; Err(SaasError::Relay(err_msg)) } } } /// 中转响应类型 #[derive(Debug)] pub enum RelayResponse { Json(String), Sse(String), } // ============ Helpers ============ fn hash_request(body: &str) -> String { use sha2::{Sha256, Digest}; hex::encode(Sha256::digest(body.as_bytes())) } fn extract_token_usage(body: &str) -> (i64, i64) { let parsed: serde_json::Value = match serde_json::from_str(body) { Ok(v) => v, Err(_) => return (0, 0), }; let usage = parsed.get("usage"); let input = usage .and_then(|u| u.get("prompt_tokens")) .and_then(|v| v.as_i64()) .unwrap_or(0); let output = usage .and_then(|u| u.get("completion_tokens")) .and_then(|v| v.as_i64()) .unwrap_or(0); (input, output) }