feat(saas): Phase 3 — 模型请求中转服务

- OpenAI 兼容 API 代理 (/api/v1/relay/chat/completions)
- 中转任务管理 (创建/查询/状态跟踪)
- 可用模型列表端点 (仅 enabled providers+models)
- 任务生命周期 (queued → processing → completed/failed)
- 用量自动记录 (token 统计 + 错误追踪)
- 3 个新集成测试覆盖中转端点
This commit is contained in:
iven
2026-03-27 12:50:05 +08:00
parent fec64af565
commit a99a3df9dd
6 changed files with 500 additions and 1 deletions

View File

@@ -0,0 +1,165 @@
//! 中转服务 HTTP 处理器
use axum::{
extract::{Extension, Path, Query, State},
http::{HeaderMap, StatusCode},
response::{IntoResponse, Response},
Json,
};
use crate::state::AppState;
use crate::error::{SaasError, SaasResult};
use crate::auth::types::AuthContext;
use crate::auth::handlers::log_operation;
use crate::model_config::service as model_service;
use super::{types::*, service};
/// POST /api/v1/relay/chat/completions
/// OpenAI 兼容的聊天补全端点
pub async fn chat_completions(
State(state): State<AppState>,
Extension(ctx): Extension<AuthContext>,
_headers: HeaderMap,
Json(req): Json<serde_json::Value>,
) -> SaasResult<Response> {
// 检查 relay:use 权限
if !ctx.permissions.contains(&"relay:use".to_string()) {
return Err(SaasError::Forbidden("需要 relay:use 权限".into()));
}
let model_name = req.get("model")
.and_then(|v| v.as_str())
.ok_or_else(|| SaasError::InvalidInput("缺少 model 字段".into()))?;
let stream = req.get("stream")
.and_then(|v| v.as_bool())
.unwrap_or(false);
// 查找 model 对应的 provider
let models = model_service::list_models(&state.db, None).await?;
let target_model = models.iter().find(|m| m.model_id == model_name && m.enabled)
.ok_or_else(|| SaasError::NotFound(format!("模型 {} 不存在或未启用", model_name)))?;
// 获取 provider 信息
let provider = model_service::get_provider(&state.db, &target_model.provider_id).await?;
if !provider.enabled {
return Err(SaasError::Forbidden(format!("Provider {} 已禁用", provider.name)));
}
// 获取 provider 的 API key (从数据库直接查询)
let provider_api_key: Option<String> = sqlx::query_scalar(
"SELECT api_key FROM providers WHERE id = ?1"
)
.bind(&target_model.provider_id)
.fetch_optional(&state.db)
.await?
.flatten();
let request_body = serde_json::to_string(&req)?;
// 创建中转任务
let task = service::create_relay_task(
&state.db, &ctx.account_id, &target_model.provider_id,
&target_model.model_id, &request_body, 0,
).await?;
log_operation(&state.db, &ctx.account_id, "relay.request", "relay_task", &task.id,
Some(serde_json::json!({"model": model_name, "stream": stream})), None).await?;
// 执行中转
let response = service::execute_relay(
&state.db, &task.id, &provider.base_url,
provider_api_key.as_deref(), &request_body, stream,
).await;
match response {
Ok(service::RelayResponse::Json(body)) => {
// 记录用量
let parsed: serde_json::Value = serde_json::from_str(&body).unwrap_or_default();
let input_tokens = parsed.get("usage")
.and_then(|u| u.get("prompt_tokens"))
.and_then(|v| v.as_i64())
.unwrap_or(0);
let output_tokens = parsed.get("usage")
.and_then(|u| u.get("completion_tokens"))
.and_then(|v| v.as_i64())
.unwrap_or(0);
model_service::record_usage(
&state.db, &ctx.account_id, &target_model.provider_id,
&target_model.model_id, input_tokens, output_tokens,
None, "success", None,
).await?;
Ok((StatusCode::OK, [(axum::http::header::CONTENT_TYPE, "application/json")], body).into_response())
}
Ok(service::RelayResponse::Sse(body)) => {
model_service::record_usage(
&state.db, &ctx.account_id, &target_model.provider_id,
&target_model.model_id, 0, 0,
None, "success", None,
).await?;
Ok((StatusCode::OK, [(axum::http::header::CONTENT_TYPE, "text/event-stream")], body).into_response())
}
Err(e) => {
model_service::record_usage(
&state.db, &ctx.account_id, &target_model.provider_id,
&target_model.model_id, 0, 0,
None, "failed", Some(&e.to_string()),
).await?;
Err(e)
}
}
}
/// GET /api/v1/relay/tasks
pub async fn list_tasks(
State(state): State<AppState>,
Extension(ctx): Extension<AuthContext>,
Query(query): Query<RelayTaskQuery>,
) -> SaasResult<Json<Vec<RelayTaskInfo>>> {
service::list_relay_tasks(&state.db, &ctx.account_id, &query).await.map(Json)
}
/// GET /api/v1/relay/tasks/:id
pub async fn get_task(
State(state): State<AppState>,
Path(id): Path<String>,
Extension(ctx): Extension<AuthContext>,
) -> SaasResult<Json<RelayTaskInfo>> {
let task = service::get_relay_task(&state.db, &id).await?;
// 只允许查看自己的任务 (admin 可查看全部)
if task.account_id != ctx.account_id && !ctx.permissions.contains(&"relay:admin".to_string()) {
return Err(SaasError::Forbidden("无权查看此任务".into()));
}
Ok(Json(task))
}
/// GET /api/v1/relay/models
/// 列出可用的中转模型 (enabled providers + enabled models)
pub async fn list_available_models(
State(state): State<AppState>,
_ctx: Extension<AuthContext>,
) -> SaasResult<Json<Vec<serde_json::Value>>> {
let providers = model_service::list_providers(&state.db).await?;
let enabled_provider_ids: std::collections::HashSet<String> =
providers.iter().filter(|p| p.enabled).map(|p| p.id.clone()).collect();
let models = model_service::list_models(&state.db, None).await?;
let available: Vec<serde_json::Value> = models.into_iter()
.filter(|m| m.enabled && enabled_provider_ids.contains(&m.provider_id))
.map(|m| {
serde_json::json!({
"id": m.model_id,
"provider_id": m.provider_id,
"alias": m.alias,
"context_window": m.context_window,
"max_output_tokens": m.max_output_tokens,
"supports_streaming": m.supports_streaming,
"supports_vision": m.supports_vision,
})
})
.collect();
Ok(Json(available))
}