From a99a3df9dd36065c2b2a154b657e6d38068288bc Mon Sep 17 00:00:00 2001 From: iven Date: Fri, 27 Mar 2026 12:50:05 +0800 Subject: [PATCH] =?UTF-8?q?feat(saas):=20Phase=203=20=E2=80=94=20=E6=A8=A1?= =?UTF-8?q?=E5=9E=8B=E8=AF=B7=E6=B1=82=E4=B8=AD=E8=BD=AC=E6=9C=8D=E5=8A=A1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - OpenAI 兼容 API 代理 (/api/v1/relay/chat/completions) - 中转任务管理 (创建/查询/状态跟踪) - 可用模型列表端点 (仅 enabled providers+models) - 任务生命周期 (queued → processing → completed/failed) - 用量自动记录 (token 统计 + 错误追踪) - 3 个新集成测试覆盖中转端点 --- crates/zclaw-saas/src/main.rs | 1 + crates/zclaw-saas/src/relay/handlers.rs | 165 ++++++++++++++++ crates/zclaw-saas/src/relay/mod.rs | 18 +- crates/zclaw-saas/src/relay/service.rs | 197 ++++++++++++++++++++ crates/zclaw-saas/src/relay/types.rs | 59 ++++++ crates/zclaw-saas/tests/integration_test.rs | 61 ++++++ 6 files changed, 500 insertions(+), 1 deletion(-) create mode 100644 crates/zclaw-saas/src/relay/handlers.rs create mode 100644 crates/zclaw-saas/src/relay/service.rs create mode 100644 crates/zclaw-saas/src/relay/types.rs diff --git a/crates/zclaw-saas/src/main.rs b/crates/zclaw-saas/src/main.rs index 3c72030..e3a9ab1 100644 --- a/crates/zclaw-saas/src/main.rs +++ b/crates/zclaw-saas/src/main.rs @@ -44,6 +44,7 @@ fn build_router(state: AppState) -> axum::Router { let protected_routes = zclaw_saas::auth::protected_routes() .merge(zclaw_saas::account::routes()) .merge(zclaw_saas::model_config::routes()) + .merge(zclaw_saas::relay::routes()) .layer(middleware::from_fn_with_state( state.clone(), zclaw_saas::auth::auth_middleware, diff --git a/crates/zclaw-saas/src/relay/handlers.rs b/crates/zclaw-saas/src/relay/handlers.rs new file mode 100644 index 0000000..0ce12c6 --- /dev/null +++ b/crates/zclaw-saas/src/relay/handlers.rs @@ -0,0 +1,165 @@ +//! 
Relay service HTTP handlers
+
+use axum::{
+    extract::{Extension, Path, Query, State},
+    http::{HeaderMap, StatusCode},
+    response::{IntoResponse, Response},
+    Json,
+};
+use crate::state::AppState;
+use crate::error::{SaasError, SaasResult};
+use crate::auth::types::AuthContext;
+use crate::auth::handlers::log_operation;
+use crate::model_config::service as model_service;
+use super::{types::*, service};
+
+/// `POST /api/v1/relay/chat/completions` — OpenAI-compatible chat completion endpoint.
+///
+/// Requires the `relay:use` permission. Looks up `model` among the enabled models, checks
+/// that its provider is enabled, creates a relay task, forwards the JSON body upstream and
+/// records usage. Responds with the upstream body as JSON, or as `text/event-stream` when
+/// `stream` is true (the body is fully buffered by `service::execute_relay` either way).
+pub async fn chat_completions(
+    State(state): State<AppState>,
+    Extension(ctx): Extension<AuthContext>,
+    _headers: HeaderMap,
+    Json(req): Json<serde_json::Value>,
+) -> SaasResult<Response> {
+    // Permission gate: the caller must hold `relay:use`.
+    if !ctx.permissions.contains(&"relay:use".to_string()) {
+        return Err(SaasError::Forbidden("需要 relay:use 权限".into()));
+    }
+
+    let model_name = req.get("model")
+        .and_then(|v| v.as_str())
+        .ok_or_else(|| SaasError::InvalidInput("缺少 model 字段".into()))?;
+    // A missing or non-boolean `stream` field means a non-streaming request.
+    let stream = req.get("stream").and_then(|v| v.as_bool()).unwrap_or(false);
+
+    // Find the enabled model matching the requested id.
+    let models = model_service::list_models(&state.db, None).await?;
+    let target_model = models.iter().find(|m| m.model_id == model_name && m.enabled)
+        .ok_or_else(|| SaasError::NotFound(format!("模型 {} 不存在或未启用", model_name)))?;
+
+    // Its provider must be enabled too.
+    let provider = model_service::get_provider(&state.db, &target_model.provider_id).await?;
+    if !provider.enabled {
+        return Err(SaasError::Forbidden(format!("Provider {} 已禁用", provider.name)));
+    }
+
+    // Read the provider's API key directly from the database (absent key => no auth header).
+    let provider_api_key: Option<String> = sqlx::query_scalar(
+        "SELECT api_key FROM providers WHERE id = ?1"
+    )
+    .bind(&target_model.provider_id)
+    .fetch_optional(&state.db)
+    .await?
+    .flatten();
+
+    let request_body = serde_json::to_string(&req)?;
+
+    // Create the relay task record; `execute_relay` updates its status afterwards.
+    let task = service::create_relay_task(
+        &state.db, &ctx.account_id, &target_model.provider_id,
+        &target_model.model_id, &request_body, 0,
+    ).await?;
+
+    log_operation(&state.db, &ctx.account_id, "relay.request", "relay_task", &task.id,
+        Some(serde_json::json!({"model": model_name, "stream": stream})), None).await?;
+
+    let response = service::execute_relay(
+        &state.db, &task.id, &provider.base_url,
+        provider_api_key.as_deref(), &request_body, stream,
+    ).await;
+
+    match response {
+        Ok(service::RelayResponse::Json(body)) => {
+            // Record usage from the upstream `usage` object; missing fields count as 0.
+            let parsed: serde_json::Value = serde_json::from_str(&body).unwrap_or_default();
+            let usage_field = |name: &str| {
+                parsed.get("usage").and_then(|u| u.get(name)).and_then(|v| v.as_i64()).unwrap_or(0)
+            };
+            let input_tokens = usage_field("prompt_tokens");
+            let output_tokens = usage_field("completion_tokens");
+
+            model_service::record_usage(
+                &state.db, &ctx.account_id, &target_model.provider_id,
+                &target_model.model_id, input_tokens, output_tokens,
+                None, "success", None,
+            ).await?;
+
+            Ok((StatusCode::OK, [(axum::http::header::CONTENT_TYPE, "application/json")], body).into_response())
+        }
+        Ok(service::RelayResponse::Sse(body)) => {
+            // Streamed bodies are not parsed for usage; zero tokens are recorded.
+            model_service::record_usage(
+                &state.db, &ctx.account_id, &target_model.provider_id,
+                &target_model.model_id, 0, 0,
+                None, "success", None,
+            ).await?;
+
+            Ok((StatusCode::OK, [(axum::http::header::CONTENT_TYPE, "text/event-stream")], body).into_response())
+        }
+        Err(e) => {
+            // Record the failed attempt with the error text, then return the error.
+            model_service::record_usage(
+                &state.db, &ctx.account_id, &target_model.provider_id,
+                &target_model.model_id, 0, 0,
+                None, "failed", Some(&e.to_string()),
+            ).await?;
+            Err(e)
+        }
+    }
+}
+
+/// `GET /api/v1/relay/tasks` — lists the caller's relay tasks (see `RelayTaskQuery`).
+pub async fn list_tasks(
+    State(state): State<AppState>,
+    Extension(ctx): Extension<AuthContext>,
+    Query(query): Query<RelayTaskQuery>,
+) -> SaasResult<Json<Vec<RelayTaskInfo>>> {
+    service::list_relay_tasks(&state.db, &ctx.account_id, &query).await.map(Json)
+}
+
+/// `GET /api/v1/relay/tasks/:id` — returns one task; non-owners need `relay:admin`.
+pub async fn get_task(
+    State(state): State<AppState>,
+    Path(id): Path<String>,
+    Extension(ctx): Extension<AuthContext>,
+) -> SaasResult<Json<RelayTaskInfo>> {
+    let task = service::get_relay_task(&state.db, &id).await?;
+    // Only the owner may view a task, unless the caller holds `relay:admin`.
+    if task.account_id != ctx.account_id && !ctx.permissions.contains(&"relay:admin".to_string()) {
+        return Err(SaasError::Forbidden("无权查看此任务".into()));
+    }
+    Ok(Json(task))
+}
+
+/// `GET /api/v1/relay/models` — lists the models usable through the relay.
+/// A model is listed only when both it and its provider are enabled.
+pub async fn list_available_models(
+    State(state): State<AppState>,
+    _ctx: Extension<AuthContext>,
+) -> SaasResult<Json<Vec<serde_json::Value>>> {
+    let providers = model_service::list_providers(&state.db).await?;
+    let enabled_provider_ids: std::collections::HashSet<String> =
+        providers.iter().filter(|p| p.enabled).map(|p| p.id.clone()).collect();
+
+    let models = model_service::list_models(&state.db, None).await?;
+    let available: Vec<serde_json::Value> = models.into_iter()
+        .filter(|m| m.enabled && enabled_provider_ids.contains(&m.provider_id))
+        .map(|m| {
+            serde_json::json!({
+                "id": m.model_id,
+                "provider_id": m.provider_id,
+                "alias": m.alias,
+                "context_window": m.context_window,
+                "max_output_tokens": m.max_output_tokens,
+                "supports_streaming": m.supports_streaming,
+                "supports_vision": m.supports_vision,
+            })
+        })
+        .collect();
+
+    Ok(Json(available))
+}
diff --git a/crates/zclaw-saas/src/relay/mod.rs b/crates/zclaw-saas/src/relay/mod.rs
index 8504245..0ef760f 100644
--- a/crates/zclaw-saas/src/relay/mod.rs
+++ b/crates/zclaw-saas/src/relay/mod.rs
@@ -1 +1,17 @@
-//! 请求中转模块
+//! 
Relay service module
+
+pub mod types;
+pub mod service;
+pub mod handlers;
+
+use axum::routing::{get, post};
+use crate::state::AppState;
+
+/// Relay service routes (authentication required).
+pub fn routes() -> axum::Router<AppState> {
+    axum::Router::new()
+        .route("/api/v1/relay/chat/completions", post(handlers::chat_completions))
+        .route("/api/v1/relay/tasks", get(handlers::list_tasks))
+        .route("/api/v1/relay/tasks/{id}", get(handlers::get_task))
+        .route("/api/v1/relay/models", get(handlers::list_available_models))
+}
diff --git a/crates/zclaw-saas/src/relay/service.rs b/crates/zclaw-saas/src/relay/service.rs
new file mode 100644
index 0000000..06fa697
--- /dev/null
+++ b/crates/zclaw-saas/src/relay/service.rs
@@ -0,0 +1,197 @@
+//! Core logic of the relay service.
+
+use sqlx::SqlitePool;
+use crate::error::{SaasError, SaasResult};
+use super::types::*;
+
+// ============ Relay Task Management ============
+
+/// Columns read for every task query; the order must match [`RelayTaskRow`].
+const TASK_COLUMNS: &str = "id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at";
+
+/// One `relay_tasks` row as decoded by sqlx, in [`TASK_COLUMNS`] order.
+type RelayTaskRow = (String, String, String, String, String, i64, i64, i64, i64, i64, Option<String>, String, Option<String>, Option<String>, String);
+
+/// Converts a decoded row tuple into a [`RelayTaskInfo`].
+fn row_to_task(row: RelayTaskRow) -> RelayTaskInfo {
+    let (id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts,
+        input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at) = row;
+    RelayTaskInfo {
+        id, account_id, provider_id, model_id, status, priority,
+        attempt_count, max_attempts, input_tokens, output_tokens,
+        error_message, queued_at, started_at, completed_at, created_at,
+    }
+}
+
+/// Inserts a new task in the `queued` state and returns the stored row.
+///
+/// The request body is saved together with its SHA-256 hex digest (`request_hash`).
+pub async fn create_relay_task(
+    db: &SqlitePool, account_id: &str, provider_id: &str,
+    model_id: &str, request_body: &str, priority: i64,
+) -> SaasResult<RelayTaskInfo> {
+    let id = uuid::Uuid::new_v4().to_string();
+    let now = chrono::Utc::now().to_rfc3339();
+    let request_hash = hash_request(request_body);
+
+    sqlx::query(
+        "INSERT INTO relay_tasks (id, account_id, provider_id, model_id, request_hash, request_body, status, priority, attempt_count, max_attempts, queued_at, created_at)
+         VALUES (?1, ?2, ?3, ?4, ?5, ?6, 'queued', ?7, 0, 3, ?8, ?8)"
+    )
+    .bind(&id).bind(account_id).bind(provider_id).bind(model_id)
+    .bind(&request_hash).bind(request_body).bind(priority).bind(&now)
+    .execute(db).await?;
+
+    get_relay_task(db, &id).await
+}
+
+/// Loads one task by id; returns `SaasError::NotFound` when no row matches.
+pub async fn get_relay_task(db: &SqlitePool, task_id: &str) -> SaasResult<RelayTaskInfo> {
+    let sql = format!("SELECT {} FROM relay_tasks WHERE id = ?1", TASK_COLUMNS);
+    let row: Option<RelayTaskRow> = sqlx::query_as(&sql)
+        .bind(task_id)
+        .fetch_optional(db)
+        .await?;
+
+    row.map(row_to_task)
+        .ok_or_else(|| SaasError::NotFound(format!("中转任务 {} 不存在", task_id)))
+}
+
+/// Lists `account_id`'s tasks, newest first, optionally filtered by `query.status`.
+///
+/// `page` is 1-based (default 1); `page_size` defaults to 20 and is clamped to 1..=100.
+pub async fn list_relay_tasks(
+    db: &SqlitePool, account_id: &str, query: &RelayTaskQuery,
+) -> SaasResult<Vec<RelayTaskInfo>> {
+    let page = query.page.unwrap_or(1).max(1);
+    // Clamp both ends: SQLite treats a negative LIMIT as "no upper bound".
+    let page_size = query.page_size.unwrap_or(20).clamp(1, 100);
+    // Saturate so an absurdly large `page` cannot overflow the multiplication.
+    let offset = (page - 1).saturating_mul(page_size);
+
+    let sql = if query.status.is_some() {
+        format!("SELECT {} FROM relay_tasks WHERE account_id = ?1 AND status = ?2 ORDER BY created_at DESC LIMIT ?3 OFFSET ?4", TASK_COLUMNS)
+    } else {
+        format!("SELECT {} FROM relay_tasks WHERE account_id = ?1 ORDER BY created_at DESC LIMIT ?2 OFFSET ?3", TASK_COLUMNS)
+    };
+
+    let mut query_builder = sqlx::query_as::<_, RelayTaskRow>(&sql).bind(account_id);
+    if let Some(ref status) = query.status {
+        query_builder = query_builder.bind(status);
+    }
+    query_builder = query_builder.bind(page_size).bind(offset);
+
+    let rows = query_builder.fetch_all(db).await?;
+    Ok(rows.into_iter().map(row_to_task).collect())
+}
+
+/// Moves a task to `processing`, `completed` or `failed` and stamps the matching timestamp.
+///
+/// Token counts are written only for `completed` (`None` keeps the stored value) and
+/// `error_message` only for `failed`. Any other status yields `SaasError::InvalidInput`.
+pub async fn update_task_status(
+    db: &SqlitePool, task_id: &str, status: &str,
+    input_tokens: Option<i64>, output_tokens: Option<i64>,
+    error_message: Option<&str>,
+) -> SaasResult<()> {
+    let now = chrono::Utc::now().to_rfc3339();
+
+    // Every arm numbers its placeholders to match exactly what it binds. With a shared
+    // `WHERE id = ?4`, `processing` and `failed` left ?4 unbound (NULL) and updated no row.
+    let query = match status {
+        "processing" => sqlx::query(
+            "UPDATE relay_tasks SET started_at = ?1, status = 'processing', attempt_count = attempt_count + 1 WHERE id = ?2"
+        ).bind(&now).bind(task_id),
+        "completed" => sqlx::query(
+            "UPDATE relay_tasks SET completed_at = ?1, status = 'completed', input_tokens = COALESCE(?2, input_tokens), output_tokens = COALESCE(?3, output_tokens) WHERE id = ?4"
+        ).bind(&now).bind(input_tokens).bind(output_tokens).bind(task_id),
+        "failed" => sqlx::query(
+            "UPDATE relay_tasks SET completed_at = ?1, status = 'failed', error_message = ?2 WHERE id = ?3"
+        ).bind(&now).bind(error_message).bind(task_id),
+        _ => return Err(SaasError::InvalidInput(format!("无效任务状态: {}", status))),
+    };
+    query.execute(db).await?;
+
+    Ok(())
+}
+
+// ============ Relay Execution ============
+
+/// Sends `request_body` to `{provider_base_url}/chat/completions` and tracks the task state.
+///
+/// The task is marked `processing` first, then `completed` (with token usage for non-stream
+/// responses) or `failed` with an error message. The upstream body is read in full before
+/// returning, also when `stream` is true. Upstream failures map to `SaasError::Relay`.
+pub async fn execute_relay(
+    db: &SqlitePool, task_id: &str, provider_base_url: &str,
+    provider_api_key: Option<&str>, request_body: &str, stream: bool,
+) -> SaasResult<RelayResponse> {
+    update_task_status(db, task_id, "processing", None, None, None).await?;
+
+    let url = format!("{}/chat/completions", provider_base_url.trim_end_matches('/'));
+    let client = reqwest::Client::new();
+    let mut req_builder = client.post(&url)
+        .header("Content-Type", "application/json")
+        .body(request_body.to_string());
+
+    if let Some(key) = provider_api_key {
+        req_builder = req_builder.header("Authorization", format!("Bearer {}", key));
+    }
+
+    let result = req_builder.send().await;
+
+    match result {
+        Ok(resp) if resp.status().is_success() => {
+            let body = resp.text().await.unwrap_or_default();
+            if stream {
+                update_task_status(db, task_id, "completed", None, None, None).await?;
+                Ok(RelayResponse::Sse(body))
+            } else {
+                let (input_tokens, output_tokens) = extract_token_usage(&body);
+                update_task_status(db, task_id, "completed",
+                    Some(input_tokens), Some(output_tokens), None).await?;
+                Ok(RelayResponse::Json(body))
+            }
+        }
+        Ok(resp) => {
+            let status = resp.status().as_u16();
+            let body = resp.text().await.unwrap_or_default();
+            // Truncate by characters: `&body[..500]` panics when byte 500 splits a UTF-8 character.
+            let snippet: String = body.chars().take(500).collect();
+            let err_msg = format!("上游返回 HTTP {}: {}", status, snippet);
+            update_task_status(db, task_id, "failed", None, None, Some(&err_msg)).await?;
+            Err(SaasError::Relay(err_msg))
+        }
+        Err(e) => {
+            let err_msg = format!("请求上游失败: {}", e);
+            update_task_status(db, task_id, "failed", None, None, Some(&err_msg)).await?;
+            Err(SaasError::Relay(err_msg))
+        }
+    }
+}
+
+/// Upstream response body, tagged by the kind of request that produced it.
+#[derive(Debug)]
+pub enum RelayResponse {
+    /// Body of a non-streaming request.
+    Json(String),
+    /// Fully buffered body of a request made with `stream` set.
+    Sse(String),
+}
+
+// ============ Helpers ============
+
+/// Returns the hex-encoded SHA-256 digest of `body`.
+fn hash_request(body: &str) -> String {
+    use sha2::{Sha256, Digest};
+    hex::encode(Sha256::digest(body.as_bytes()))
+}
+
+/// Reads `usage.prompt_tokens` and `usage.completion_tokens` from a JSON body (0 when absent).
+fn extract_token_usage(body: &str) -> (i64, i64) {
+    let parsed: serde_json::Value = serde_json::from_str(body).unwrap_or_default();
+    let tokens = |field: &str| {
+        parsed.get("usage").and_then(|u| u.get(field)).and_then(|v| v.as_i64()).unwrap_or(0)
+    };
+    (tokens("prompt_tokens"), tokens("completion_tokens"))
+}
diff --git a/crates/zclaw-saas/src/relay/types.rs b/crates/zclaw-saas/src/relay/types.rs
new file mode 100644
index 0000000..64fdefe
--- /dev/null
+++ b/crates/zclaw-saas/src/relay/types.rs
@@ -0,0 +1,59 @@
+//! 
Relay service type definitions
+
+use serde::{Deserialize, Serialize};
+
+/// Relay chat request (OpenAI-compatible format).
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct RelayChatRequest {
+    pub model: String,
+    pub messages: Vec<ChatMessage>,
+    #[serde(default)]
+    pub temperature: Option<f64>, // NOTE(review): type argument reconstructed — confirm
+    #[serde(default)]
+    pub max_tokens: Option<u32>, // NOTE(review): type argument reconstructed — confirm
+    #[serde(default)]
+    pub stream: bool,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ChatMessage {
+    pub role: String,               // e.g. "user", as sent by the integration test
+    pub content: serde_json::Value, // arbitrary JSON value, not restricted to a string
+}
+
+/// Relay task information, one row of `relay_tasks` as returned by the task endpoints.
+#[derive(Debug, Clone, Serialize)]
+pub struct RelayTaskInfo {
+    pub id: String,
+    pub account_id: String,
+    pub provider_id: String,
+    pub model_id: String,
+    pub status: String,                // `queued`, `processing`, `completed` or `failed`
+    pub priority: i64,
+    pub attempt_count: i64,
+    pub max_attempts: i64,
+    pub input_tokens: i64,
+    pub output_tokens: i64,
+    pub error_message: Option<String>, // written when the task is marked `failed`
+    pub queued_at: String,             // RFC 3339 timestamp
+    pub started_at: Option<String>,    // RFC 3339, written when marked `processing`
+    pub completed_at: Option<String>,  // RFC 3339, written on `completed` or `failed`
+    pub created_at: String,            // RFC 3339 timestamp
+}
+
+/// Relay task list query: optional status filter plus 1-based pagination.
+#[derive(Debug, Deserialize)]
+pub struct RelayTaskQuery {
+    pub status: Option<String>,
+    pub page: Option<i64>, // NOTE(review): integer type reconstructed — confirm
+    pub page_size: Option<i64>, // NOTE(review): integer type reconstructed — confirm
+}
+
+/// Provider rate-limit state (rpm/tpm presumably requests/tokens per minute — confirm).
+#[derive(Debug, Clone)]
+pub struct RateLimitState {
+    pub rpm: i64,
+    pub tpm: i64,
+    pub concurrent: usize,
+    pub max_concurrent: usize,
+}
diff --git a/crates/zclaw-saas/tests/integration_test.rs b/crates/zclaw-saas/tests/integration_test.rs
index d47d016..206d3e7 100644
--- a/crates/zclaw-saas/tests/integration_test.rs
+++ b/crates/zclaw-saas/tests/integration_test.rs
@@ -22,6 +22,7 @@ async fn build_test_app() -> axum::Router {
     let protected_routes = zclaw_saas::auth::protected_routes()
         .merge(zclaw_saas::account::routes())
         .merge(zclaw_saas::model_config::routes())
+        .merge(zclaw_saas::relay::routes())
         .layer(axum::middleware::from_fn_with_state(
             state.clone(),
             zclaw_saas::auth::auth_middleware,
@@ -288,3 +289,63 @@ async fn test_api_keys_lifecycle() {
     // provider 不存在 → 404
     assert_eq!(resp.status(), StatusCode::NOT_FOUND);
 }
+
+// ============ Phase 3: relay service tests ============
+
+#[tokio::test]
+async fn test_relay_models_list() {
+    let app = build_test_app().await;
+    let token = register_and_login(&app, "relayuser", "relayuser@example.com").await;
+
+    // List the available relay models (an empty list, since there is no provider/model seed data)
+    let req = Request::builder()
+        .method("GET")
+        .uri("/api/v1/relay/models")
+        .header("Authorization", auth_header(&token))
+        .body(Body::empty())
+        .unwrap();
+
+    let resp = app.clone().oneshot(req).await.unwrap();
+    assert_eq!(resp.status(), StatusCode::OK);
+    let body_bytes = axum::body::to_bytes(resp.into_body(), MAX_BODY_SIZE).await.unwrap();
+    let body: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap();
+    assert!(body.is_array());
+}
+
+#[tokio::test]
+async fn test_relay_chat_no_model() {
+    let app = build_test_app().await;
+    let token = register_and_login(&app, "relayfail", "relayfail@example.com").await;
+
+    // Try to relay to a model that does not exist
+    let req = Request::builder()
+        .method("POST")
+        .uri("/api/v1/relay/chat/completions")
+        .header("Content-Type", "application/json")
+        .header("Authorization", auth_header(&token))
+        .body(Body::from(serde_json::to_string(&json!({
+            "model": "nonexistent-model",
+            "messages": [{"role": "user", "content": "hello"}]
+        })).unwrap()))
+        .unwrap();
+
+    let resp = app.clone().oneshot(req).await.unwrap();
+    // Model does not exist → 404
+    assert_eq!(resp.status(), StatusCode::NOT_FOUND);
+}
+
+#[tokio::test]
+async fn test_relay_tasks_list() {
+    let app = build_test_app().await;
+    let token = register_and_login(&app, "relaytasks", "relaytasks@example.com").await;
+
+    let req = Request::builder()
+        .method("GET")
+        .uri("/api/v1/relay/tasks")
+        .header("Authorization", auth_header(&token))
+        .body(Body::empty())
+        .unwrap();
+
+    let resp = app.oneshot(req).await.unwrap();
+    assert_eq!(resp.status(), StatusCode::OK);
+}