feat(saas): Phase 3 — 模型请求中转服务
- OpenAI 兼容 API 代理 (/api/v1/relay/chat/completions) - 中转任务管理 (创建/查询/状态跟踪) - 可用模型列表端点 (仅 enabled providers+models) - 任务生命周期 (queued → processing → completed/failed) - 用量自动记录 (token 统计 + 错误追踪) - 3 个新集成测试覆盖中转端点
This commit is contained in:
165
crates/zclaw-saas/src/relay/handlers.rs
Normal file
165
crates/zclaw-saas/src/relay/handlers.rs
Normal file
@@ -0,0 +1,165 @@
|
||||
//! 中转服务 HTTP 处理器
|
||||
|
||||
use axum::{
|
||||
extract::{Extension, Path, Query, State},
|
||||
http::{HeaderMap, StatusCode},
|
||||
response::{IntoResponse, Response},
|
||||
Json,
|
||||
};
|
||||
use crate::state::AppState;
|
||||
use crate::error::{SaasError, SaasResult};
|
||||
use crate::auth::types::AuthContext;
|
||||
use crate::auth::handlers::log_operation;
|
||||
use crate::model_config::service as model_service;
|
||||
use super::{types::*, service};
|
||||
|
||||
/// POST /api/v1/relay/chat/completions
|
||||
/// OpenAI 兼容的聊天补全端点
|
||||
pub async fn chat_completions(
|
||||
State(state): State<AppState>,
|
||||
Extension(ctx): Extension<AuthContext>,
|
||||
_headers: HeaderMap,
|
||||
Json(req): Json<serde_json::Value>,
|
||||
) -> SaasResult<Response> {
|
||||
// 检查 relay:use 权限
|
||||
if !ctx.permissions.contains(&"relay:use".to_string()) {
|
||||
return Err(SaasError::Forbidden("需要 relay:use 权限".into()));
|
||||
}
|
||||
|
||||
let model_name = req.get("model")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| SaasError::InvalidInput("缺少 model 字段".into()))?;
|
||||
|
||||
let stream = req.get("stream")
|
||||
.and_then(|v| v.as_bool())
|
||||
.unwrap_or(false);
|
||||
|
||||
// 查找 model 对应的 provider
|
||||
let models = model_service::list_models(&state.db, None).await?;
|
||||
let target_model = models.iter().find(|m| m.model_id == model_name && m.enabled)
|
||||
.ok_or_else(|| SaasError::NotFound(format!("模型 {} 不存在或未启用", model_name)))?;
|
||||
|
||||
// 获取 provider 信息
|
||||
let provider = model_service::get_provider(&state.db, &target_model.provider_id).await?;
|
||||
if !provider.enabled {
|
||||
return Err(SaasError::Forbidden(format!("Provider {} 已禁用", provider.name)));
|
||||
}
|
||||
|
||||
// 获取 provider 的 API key (从数据库直接查询)
|
||||
let provider_api_key: Option<String> = sqlx::query_scalar(
|
||||
"SELECT api_key FROM providers WHERE id = ?1"
|
||||
)
|
||||
.bind(&target_model.provider_id)
|
||||
.fetch_optional(&state.db)
|
||||
.await?
|
||||
.flatten();
|
||||
|
||||
let request_body = serde_json::to_string(&req)?;
|
||||
|
||||
// 创建中转任务
|
||||
let task = service::create_relay_task(
|
||||
&state.db, &ctx.account_id, &target_model.provider_id,
|
||||
&target_model.model_id, &request_body, 0,
|
||||
).await?;
|
||||
|
||||
log_operation(&state.db, &ctx.account_id, "relay.request", "relay_task", &task.id,
|
||||
Some(serde_json::json!({"model": model_name, "stream": stream})), None).await?;
|
||||
|
||||
// 执行中转
|
||||
let response = service::execute_relay(
|
||||
&state.db, &task.id, &provider.base_url,
|
||||
provider_api_key.as_deref(), &request_body, stream,
|
||||
).await;
|
||||
|
||||
match response {
|
||||
Ok(service::RelayResponse::Json(body)) => {
|
||||
// 记录用量
|
||||
let parsed: serde_json::Value = serde_json::from_str(&body).unwrap_or_default();
|
||||
let input_tokens = parsed.get("usage")
|
||||
.and_then(|u| u.get("prompt_tokens"))
|
||||
.and_then(|v| v.as_i64())
|
||||
.unwrap_or(0);
|
||||
let output_tokens = parsed.get("usage")
|
||||
.and_then(|u| u.get("completion_tokens"))
|
||||
.and_then(|v| v.as_i64())
|
||||
.unwrap_or(0);
|
||||
|
||||
model_service::record_usage(
|
||||
&state.db, &ctx.account_id, &target_model.provider_id,
|
||||
&target_model.model_id, input_tokens, output_tokens,
|
||||
None, "success", None,
|
||||
).await?;
|
||||
|
||||
Ok((StatusCode::OK, [(axum::http::header::CONTENT_TYPE, "application/json")], body).into_response())
|
||||
}
|
||||
Ok(service::RelayResponse::Sse(body)) => {
|
||||
model_service::record_usage(
|
||||
&state.db, &ctx.account_id, &target_model.provider_id,
|
||||
&target_model.model_id, 0, 0,
|
||||
None, "success", None,
|
||||
).await?;
|
||||
|
||||
Ok((StatusCode::OK, [(axum::http::header::CONTENT_TYPE, "text/event-stream")], body).into_response())
|
||||
}
|
||||
Err(e) => {
|
||||
model_service::record_usage(
|
||||
&state.db, &ctx.account_id, &target_model.provider_id,
|
||||
&target_model.model_id, 0, 0,
|
||||
None, "failed", Some(&e.to_string()),
|
||||
).await?;
|
||||
Err(e)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// GET /api/v1/relay/tasks
|
||||
pub async fn list_tasks(
|
||||
State(state): State<AppState>,
|
||||
Extension(ctx): Extension<AuthContext>,
|
||||
Query(query): Query<RelayTaskQuery>,
|
||||
) -> SaasResult<Json<Vec<RelayTaskInfo>>> {
|
||||
service::list_relay_tasks(&state.db, &ctx.account_id, &query).await.map(Json)
|
||||
}
|
||||
|
||||
/// GET /api/v1/relay/tasks/:id
|
||||
pub async fn get_task(
|
||||
State(state): State<AppState>,
|
||||
Path(id): Path<String>,
|
||||
Extension(ctx): Extension<AuthContext>,
|
||||
) -> SaasResult<Json<RelayTaskInfo>> {
|
||||
let task = service::get_relay_task(&state.db, &id).await?;
|
||||
// 只允许查看自己的任务 (admin 可查看全部)
|
||||
if task.account_id != ctx.account_id && !ctx.permissions.contains(&"relay:admin".to_string()) {
|
||||
return Err(SaasError::Forbidden("无权查看此任务".into()));
|
||||
}
|
||||
Ok(Json(task))
|
||||
}
|
||||
|
||||
/// GET /api/v1/relay/models
|
||||
/// 列出可用的中转模型 (enabled providers + enabled models)
|
||||
pub async fn list_available_models(
|
||||
State(state): State<AppState>,
|
||||
_ctx: Extension<AuthContext>,
|
||||
) -> SaasResult<Json<Vec<serde_json::Value>>> {
|
||||
let providers = model_service::list_providers(&state.db).await?;
|
||||
let enabled_provider_ids: std::collections::HashSet<String> =
|
||||
providers.iter().filter(|p| p.enabled).map(|p| p.id.clone()).collect();
|
||||
|
||||
let models = model_service::list_models(&state.db, None).await?;
|
||||
let available: Vec<serde_json::Value> = models.into_iter()
|
||||
.filter(|m| m.enabled && enabled_provider_ids.contains(&m.provider_id))
|
||||
.map(|m| {
|
||||
serde_json::json!({
|
||||
"id": m.model_id,
|
||||
"provider_id": m.provider_id,
|
||||
"alias": m.alias,
|
||||
"context_window": m.context_window,
|
||||
"max_output_tokens": m.max_output_tokens,
|
||||
"supports_streaming": m.supports_streaming,
|
||||
"supports_vision": m.supports_vision,
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(Json(available))
|
||||
}
|
||||
Reference in New Issue
Block a user