Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled
P0-1: SaaS relay 模型别名解析 — "glm-4-flash" → "glm-4-flash-250414" (resolve_model)
P0-2: config.rs interpolate_env_vars UTF-8 修复 (chars 迭代器替代 bytes as char)
+ DB 启动编码检查 + docker-compose UTF-8 编码参数
P1-3: UI 模型选择器覆盖 Agent 默认模型 (model_override 全链路: TS→Tauri→Rust kernel)
P1-6: 知识搜索管道修复 — seed_knowledge 创建 chunks + 默认分类 (seed/uploaded/distillation)
P1-7: 用量限额从当前 Plan 读取 (非 stale usage 表)
P1-8: relay 双维度配额检查 (relay_requests + input_tokens)
P2-9: SSE 路径 token 计数修复 — 流结束检测替代固定 500ms sleep + billing increment
708 lines
28 KiB
Rust
708 lines
28 KiB
Rust
//! Relay-service HTTP handlers (中转服务 HTTP 处理器).
|
||
|
||
use axum::{
|
||
extract::{Extension, Path, Query, State},
|
||
http::{HeaderMap, StatusCode},
|
||
response::{IntoResponse, Response},
|
||
Json,
|
||
};
|
||
use crate::state::AppState;
|
||
use crate::error::{SaasError, SaasResult};
|
||
use crate::auth::types::AuthContext;
|
||
use crate::auth::handlers::check_permission;
|
||
use crate::model_config::service as model_service;
|
||
use super::{types::*, service};
|
||
|
||
/// POST /api/v1/relay/chat/completions
/// OpenAI-compatible chat completion endpoint.
///
/// Flow: permission check → in-memory queue-capacity check → request
/// validation → model resolution (model group with cross-provider failover,
/// or direct lookup with alias resolution) → relay execution → usage and
/// billing recording for both JSON and SSE response shapes.
///
/// Errors: `RateLimited` when the per-account queue is full, `InvalidInput`
/// for validation failures, `NotFound`/`Forbidden` for unresolvable or
/// disabled models/providers, plus whatever the relay execution returns.
pub async fn chat_completions(
    State(state): State<AppState>,
    Extension(ctx): Extension<AuthContext>,
    _headers: HeaderMap,
    Json(req): Json<serde_json::Value>,
) -> SaasResult<Response> {
    check_permission(&ctx, "relay:use")?;

    // Queue capacity check: an in-memory AtomicI64 counter replaces a DB COUNT query.
    let max_queue_size = {
        // Config read lock is scoped so it is dropped immediately after the read.
        let config = state.config.read().await;
        config.relay.max_queue_size
    };
    let queued_count = state.cache.relay_queue_count(&ctx.account_id);

    if queued_count >= max_queue_size as i64 {
        return Err(SaasError::RateLimited(
            format!("队列已满 ({} 个任务排队中),请稍后重试", queued_count)
        ));
    }

    // --- Input validation ---
    // Request body size limit (1 MB) — serialize once here, reuse the string below.
    const MAX_BODY_BYTES: usize = 1024 * 1024;
    let request_body = serde_json::to_string(&req)
        .map_err(|e| SaasError::InvalidInput(format!("请求体序列化失败: {}", e)))?;
    if request_body.len() > MAX_BODY_BYTES {
        return Err(SaasError::InvalidInput(
            format!("请求体超过大小限制 ({} bytes > {} bytes)", request_body.len(), MAX_BODY_BYTES)
        ));
    }

    // `model` field is required and must be a string.
    let model_name = req.get("model")
        .and_then(|v| v.as_str())
        .ok_or_else(|| SaasError::InvalidInput("缺少 model 字段".into()))?;

    // `messages` field: must exist and be a non-empty array.
    let messages = req.get("messages")
        .ok_or_else(|| SaasError::InvalidInput("缺少 messages 字段".into()))?;
    let messages_arr = messages.as_array()
        .ok_or_else(|| SaasError::InvalidInput("messages 必须是数组".into()))?;
    if messages_arr.is_empty() {
        return Err(SaasError::InvalidInput("messages 数组不能为空".into()));
    }

    // Validate the role and content of every message.
    let valid_roles = ["system", "user", "assistant", "tool"];
    for (i, msg) in messages_arr.iter().enumerate() {
        let role = msg.get("role")
            .and_then(|v| v.as_str())
            .ok_or_else(|| SaasError::InvalidInput(
                format!("messages[{}] 缺少 role 字段", i)
            ))?;
        if !valid_roles.contains(&role) {
            return Err(SaasError::InvalidInput(
                format!("messages[{}] 的 role 必须是 system/user/assistant/tool 之一,得到: {}", i, role)
            ));
        }

        let content = msg.get("content")
            .ok_or_else(|| SaasError::InvalidInput(
                format!("messages[{}] 缺少 content 字段", i)
            ))?;
        // content must be a string or an array (multimodal content parts).
        if !content.is_string() && !content.is_array() {
            return Err(SaasError::InvalidInput(
                format!("messages[{}] 的 content 必须是字符串或数组", i)
            ));
        }
    }

    // temperature range check (0.0 ~ 2.0, only when the field is present).
    if let Some(temp) = req.get("temperature") {
        match temp.as_f64() {
            Some(t) if t < 0.0 || t > 2.0 => {
                return Err(SaasError::InvalidInput(
                    format!("temperature 必须在 0.0 ~ 2.0 范围内,得到: {}", t)
                ));
            }
            Some(_) => {} // valid
            None => {
                return Err(SaasError::InvalidInput("temperature 必须是数字".into()));
            }
        }
    }

    // max_tokens range check (1 ~ 128000, only when the field is present).
    if let Some(tokens) = req.get("max_tokens") {
        match tokens.as_u64() {
            Some(t) if t < 1 || t > 128000 => {
                return Err(SaasError::InvalidInput(
                    format!("max_tokens 必须在 1 ~ 128000 范围内,得到: {}", t)
                ));
            }
            Some(_) => {} // valid
            None => {
                return Err(SaasError::InvalidInput("max_tokens 必须是正整数".into()));
            }
        }
    }
    // --- End of input validation ---

    let stream = req.get("stream")
        .and_then(|v| v.as_bool())
        .unwrap_or(false);

    // Model lookup — model groups (cross-provider failover) take priority,
    // falling back to a direct model lookup.
    let mut model_resolution = if let Some(group) = state.cache.get_model_group(model_name) {
        // Logical model group: build the candidate list from enabled members
        // whose provider and physical model are both present in the cache.
        let mut candidates: Vec<CandidateModel> = Vec::new();
        for member in &group.members {
            if !member.enabled {
                continue;
            }
            let provider = match state.cache.get_provider(&member.provider_id) {
                Some(p) => p,
                None => continue,
            };
            let physical_model = match state.cache.get_model(&member.model_id) {
                Some(m) => m,
                None => continue,
            };
            candidates.push(CandidateModel {
                provider_id: member.provider_id.clone(),
                model_id: member.model_id.clone(),
                base_url: provider.base_url.clone(),
                supports_streaming: physical_model.supports_streaming,
            });
        }
        if candidates.is_empty() {
            return Err(SaasError::NotFound(
                format!("模型组 '{}' 没有可用的候选 Provider", model_name)
            ));
        }
        ModelResolution::Group(candidates)
    } else {
        // Backward compatibility: direct model lookup + alias resolution
        // (e.g. "glm-4-flash" → "glm-4-flash-250414").
        let target_model = state.cache.resolve_model(model_name)
            .ok_or_else(|| SaasError::NotFound(format!("模型 {} 不存在或未启用", model_name)))?;

        // Provider info comes from the in-memory cache — no DB query here.
        let provider = state.cache.get_provider(&target_model.provider_id)
            .ok_or_else(|| SaasError::NotFound(format!("Provider {} 不存在", target_model.provider_id)))?;
        if !provider.enabled {
            return Err(SaasError::Forbidden(format!("Provider {} 已禁用", provider.name)));
        }

        ModelResolution::Direct(CandidateModel {
            provider_id: target_model.provider_id.clone(),
            model_id: target_model.model_id.clone(),
            base_url: provider.base_url.clone(),
            supports_streaming: target_model.supports_streaming,
        })
    };

    // Stream compatibility check
    if stream && model_resolution.any_non_streaming() {
        return Err(SaasError::InvalidInput(
            format!("模型 {} 不支持流式响应,请使用 stream: false", model_name)
        ));
    }

    // request_body was serialized and size-checked above; reused as-is below.

    // Create the relay task (extract config values, then release the read lock).
    let (max_attempts, retry_delay_ms, enc_key) = {
        let config = state.config.read().await;
        let key = config.api_key_encryption_key()
            .map_err(|e| SaasError::Internal(e.to_string()))?;
        (config.relay.max_attempts, config.relay.retry_delay_ms, key)
    };

    let task = service::create_relay_task(
        &state.db, &ctx.account_id,
        model_resolution.first_provider_id(),
        model_resolution.first_model_id(),
        &request_body, 0,
        max_attempts,
    ).await?;

    // Increment the in-memory queue counter (replaces a DB COUNT query).
    // Every exit path below must decrement it again to avoid counter leaks.
    state.cache.relay_enqueue(&ctx.account_id);

    // Dispatch the operation log asynchronously (non-blocking; keeps DB
    // connections off the critical path).
    // P3-06: include session_key/agent_id in the log for traceability.
    let log_meta = serde_json::json!({
        "model": model_name,
        "stream": stream,
        "session_key": req.get("session_key").and_then(|v| v.as_str()),
        "agent_id": req.get("agent_id").and_then(|v| v.as_str()),
    });
    state.dispatch_log_operation(
        &ctx.account_id, "relay.request", "relay_task", &task.id,
        Some(log_meta), ctx.client_ip.as_deref(),
    ).await;

    // Execute the relay, choosing the path from the resolution result.
    // C-1: capture the provider_id/model_id that actually served the request
    // so billing is attributed precisely (failover may pick a later candidate).
    let relay_result = match model_resolution {
        ModelResolution::Direct(ref candidate) => {
            // Single-provider direct routing (backward compatible).
            match service::execute_relay(
                &state.db, &task.id, &ctx.account_id, &candidate.provider_id,
                &candidate.base_url, &request_body, stream,
                max_attempts, retry_delay_ms, &enc_key,
                true, // standalone call; manages task status itself
            ).await {
                Ok(resp) => Ok((resp, candidate.provider_id.clone(), candidate.model_id.clone())),
                Err(e) => Err(e),
            }
        }
        ModelResolution::Group(ref mut candidates) => {
            // Cross-provider failover (candidates auto-sorted by remaining quota).
            // Note: failover only covers pre-stream failures (connection errors,
            // 429/5xx before the stream starts). Once SSE streaming has begun,
            // a mid-stream upstream disconnect does not trigger failover
            // (an inherent limitation of the SSE protocol).
            service::sort_candidates_by_quota(&state.db, candidates).await;
            service::execute_relay_with_failover(
                &state.db, &task.id, &ctx.account_id, candidates,
                &request_body, stream,
                max_attempts, retry_delay_ms, &enc_key
            ).await
        }
    };

    // On failure: record a failure usage row + decrement the queue counter
    // (failed requests are not billed).
    let (response, actual_provider_id, actual_model_id) = match relay_result {
        Ok(triple) => triple,
        Err(e) => {
            // Record the failure usage via the worker dispatcher.
            {
                let args = crate::workers::record_usage::RecordUsageArgs {
                    account_id: ctx.account_id.clone(),
                    provider_id: model_resolution.first_provider_id().to_string(),
                    model_id: model_resolution.first_model_id().to_string(),
                    input_tokens: 0,
                    output_tokens: 0,
                    latency_ms: None,
                    status: "failed".to_string(),
                    error_message: Some(e.to_string()),
                };
                if let Err(e2) = state.worker_dispatcher.dispatch("record_usage", args).await {
                    tracing::warn!("Failed to dispatch failure usage: {}", e2);
                }
            }
            // Decrement the queue counter (prevents counter leaks → pool exhaustion).
            state.cache.relay_dequeue(&ctx.account_id);
            return Err(e);
        }
    };

    // Attribute billing to the provider/model that actually served the request.
    let account_id_usage = ctx.account_id.clone();
    let provider_id_usage = actual_provider_id;
    let model_id_usage = actual_model_id;

    match response {
        service::RelayResponse::Json(body) => {
            let (input_tokens, output_tokens) = service::extract_token_usage_from_json(&body);
            // Record usage via the worker dispatcher (gated by SpawnLimiter;
            // does not block the response).
            {
                let args = crate::workers::record_usage::RecordUsageArgs {
                    account_id: account_id_usage.clone(),
                    provider_id: provider_id_usage.clone(),
                    model_id: model_id_usage.clone(),
                    input_tokens: input_tokens as i32,
                    output_tokens: output_tokens as i32,
                    latency_ms: None,
                    status: "success".to_string(),
                    error_message: None,
                };
                if let Err(e) = state.worker_dispatcher.dispatch("record_usage", args).await {
                    tracing::warn!("Failed to dispatch record_usage: {}", e);
                }
            }

            // Update billing quota in real time (relay_requests + tokens together).
            if let Err(e) = crate::billing::service::increment_usage(
                &state.db, &account_id_usage, input_tokens as i64, output_tokens as i64,
            ).await {
                tracing::warn!("Failed to increment billing usage for {}: {}", account_id_usage, e);
            }

            // Task finished — decrement the queue counter.
            state.cache.relay_dequeue(&account_id_usage);

            Ok((StatusCode::OK, [(axum::http::header::CONTENT_TYPE, "application/json")], body).into_response())
        }
        service::RelayResponse::Sse(body) => {
            // Record an SSE placeholder usage row via the worker dispatcher
            // (zero tokens for now; reconciled later — see below).
            {
                let args = crate::workers::record_usage::RecordUsageArgs {
                    account_id: account_id_usage.clone(),
                    provider_id: provider_id_usage.clone(),
                    model_id: model_id_usage.clone(),
                    input_tokens: 0,
                    output_tokens: 0,
                    latency_ms: None,
                    status: "streaming".to_string(),
                    error_message: None,
                };
                if let Err(e) = state.worker_dispatcher.dispatch("record_usage", args).await {
                    tracing::warn!("Failed to dispatch SSE usage: {}", e);
                }
            }

            // SSE: increment relay_requests immediately (token dimensions are
            // corrected afterwards by the AggregateUsageWorker reconciliation).
            if let Err(e) = crate::billing::service::increment_dimension(
                &state.db, &account_id_usage, "relay_requests",
            ).await {
                tracing::warn!("Failed to increment billing relay_requests for {}: {}", account_id_usage, e);
            }

            // The SSE stream is about to be returned — decrement the queue
            // counter (the streaming task has started processing).
            state.cache.relay_dequeue(&account_id_usage);

            let response = axum::response::Response::builder()
                .status(StatusCode::OK)
                .header(axum::http::header::CONTENT_TYPE, "text/event-stream")
                .header("Cache-Control", "no-cache")
                .header("Connection", "keep-alive")
                .body(body)
                .expect("SSE response builder with valid status/headers cannot fail");
            Ok(response)
        }
    }
}
|
||
|
||
/// GET /api/v1/relay/tasks
|
||
pub async fn list_tasks(
|
||
State(state): State<AppState>,
|
||
Extension(ctx): Extension<AuthContext>,
|
||
Query(query): Query<RelayTaskQuery>,
|
||
) -> SaasResult<Json<crate::common::PaginatedResponse<RelayTaskInfo>>> {
|
||
service::list_relay_tasks(&state.db, &ctx.account_id, &query).await.map(Json)
|
||
}
|
||
|
||
/// GET /api/v1/relay/tasks/:id
|
||
pub async fn get_task(
|
||
State(state): State<AppState>,
|
||
Path(id): Path<String>,
|
||
Extension(ctx): Extension<AuthContext>,
|
||
) -> SaasResult<Json<RelayTaskInfo>> {
|
||
let task = service::get_relay_task(&state.db, &id).await?;
|
||
// 只允许查看自己的任务 (admin 可查看全部)
|
||
if task.account_id != ctx.account_id {
|
||
check_permission(&ctx, "relay:admin")?;
|
||
}
|
||
Ok(Json(task))
|
||
}
|
||
|
||
/// GET /api/v1/relay/models
/// List available relay models (enabled providers + enabled models),
/// followed by enabled model groups (logical models) so the frontend can
/// display and select them.
pub async fn list_available_models(
    State(state): State<AppState>,
    _ctx: Extension<AuthContext>,
) -> SaasResult<Json<Vec<serde_json::Value>>> {
    // Single JOIN query instead of two full-table loads.
    let rows: Vec<(String, String, String, i64, i64, bool, bool, bool, String)> = sqlx::query_as(
        "SELECT m.model_id, m.provider_id, m.alias, m.context_window,
                m.max_output_tokens, m.supports_streaming, m.supports_vision,
                m.is_embedding, m.model_type
         FROM models m
         INNER JOIN providers p ON m.provider_id = p.id
         WHERE m.enabled = true AND p.enabled = true
         ORDER BY m.provider_id, m.model_id"
    )
    .fetch_all(&state.db)
    .await?;

    // Physical models first, as plain JSON objects.
    let mut available: Vec<serde_json::Value> = rows.into_iter()
        .map(|(model_id, provider_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, is_embedding, model_type)| {
            serde_json::json!({
                "id": model_id,
                "provider_id": provider_id,
                "alias": alias,
                "context_window": context_window,
                "max_output_tokens": max_output_tokens,
                "supports_streaming": supports_streaming,
                "supports_vision": supports_vision,
                "is_embedding": is_embedding,
                "model_type": model_type,
            })
        })
        .collect();

    // Append model groups (logical models) so the frontend can show and select them.
    for entry in state.cache.model_groups.iter() {
        let group = entry.value();
        if !group.enabled {
            continue;
        }
        // H-2: skip groups with no usable members so the frontend cannot select
        // a group whose requests would always fail.
        let active_members: Vec<_> = group.members.iter()
            .filter(|m| m.enabled)
            .collect();
        if active_members.is_empty() {
            continue;
        }
        // The group supports streaming only if every active member does.
        // NOTE(review): a member missing from the model cache counts as
        // streaming-capable here (unwrap_or(true)) — confirm this optimistic
        // default is intended.
        let all_streaming = active_members.iter().all(|m| {
            state.cache.get_model(&m.model_id)
                .map(|cm| cm.supports_streaming)
                .unwrap_or(true)
        });
        // The group supports vision if any active member does.
        let any_vision = active_members.iter().any(|m| {
            state.cache.get_model(&m.model_id)
                .map(|cm| cm.supports_vision)
                .unwrap_or(false)
        });
        available.push(serde_json::json!({
            "id": group.name,
            "provider_id": "group",
            "alias": group.display_name,
            "is_group": true,
            // member_count reports all members, including disabled ones.
            "member_count": group.members.len(),
            "supports_streaming": all_streaming,
            "supports_vision": any_vision,
        }));
    }

    Ok(Json(available))
}
|
||
|
||
/// POST /api/v1/relay/tasks/:id/retry (admin only)
/// Retry a failed relay task.
///
/// Re-resolves the model (group failover preferred) from the stored request
/// body, resets the task row to 'queued', then runs the retry on a detached
/// tokio task so the HTTP response returns immediately.
pub async fn retry_task(
    State(state): State<AppState>,
    Path(id): Path<String>,
    Extension(ctx): Extension<AuthContext>,
) -> SaasResult<Json<serde_json::Value>> {
    check_permission(&ctx, "relay:admin")?;

    let task = service::get_relay_task(&state.db, &id).await?;
    // Only tasks in the 'failed' state are retryable.
    if task.status != "failed" {
        return Err(SaasError::InvalidInput(format!(
            "只能重试失败的任务,当前状态: {}", task.status
        )));
    }

    // Load the original request body from the task row.
    let request_body: Option<String> = sqlx::query_scalar(
        "SELECT request_body FROM relay_tasks WHERE id = $1"
    )
    .bind(&id)
    .fetch_optional(&state.db)
    .await?
    .flatten();

    let body = request_body.ok_or_else(|| SaasError::Internal("任务请求体丢失".into()))?;

    // Parse the stream flag and model field out of the stored request body.
    let parsed_body: Option<serde_json::Value> = serde_json::from_str(&body).ok();
    let stream: bool = parsed_body.as_ref()
        .and_then(|v| v.get("stream").and_then(|s| s.as_bool()))
        .unwrap_or(false);
    let model_name: Option<String> = parsed_body.as_ref()
        .and_then(|v| v.get("model").and_then(|m| m.as_str()).map(|s| s.to_string()));

    // H-8: re-resolve model groups — if the original request used a model
    // group, the retry goes through the failover path instead of blindly
    // reusing the stored (possibly already-failed) provider_id.
    let mut model_resolution = if let Some(ref name) = model_name {
        if let Some(group) = state.cache.get_model_group(name) {
            // Model group: build the candidate list from enabled members whose
            // provider and physical model are both present in the cache.
            let mut candidates: Vec<CandidateModel> = Vec::new();
            for member in &group.members {
                if !member.enabled { continue; }
                let provider = match state.cache.get_provider(&member.provider_id) {
                    Some(p) => p,
                    None => continue,
                };
                let physical_model = match state.cache.get_model(&member.model_id) {
                    Some(m) => m,
                    None => continue,
                };
                candidates.push(CandidateModel {
                    provider_id: member.provider_id.clone(),
                    model_id: member.model_id.clone(),
                    base_url: provider.base_url.clone(),
                    supports_streaming: physical_model.supports_streaming,
                });
            }
            if candidates.is_empty() {
                return Err(SaasError::NotFound(
                    format!("模型组 '{}' 没有可用的候选 Provider(重试时解析)", name)
                ));
            }
            ModelResolution::Group(candidates)
        } else if let Some(target_model) = state.cache.get_model(name) {
            // Direct model path.
            let provider = state.cache.get_provider(&target_model.provider_id)
                .ok_or_else(|| SaasError::NotFound(format!("Provider {} 不存在(重试时解析)", target_model.provider_id)))?;
            ModelResolution::Direct(CandidateModel {
                provider_id: target_model.provider_id.clone(),
                model_id: target_model.model_id.clone(),
                base_url: provider.base_url.clone(),
                supports_streaming: target_model.supports_streaming,
            })
        } else {
            // Unresolvable — fall back to the stored provider_id (backward compatible).
            let provider = model_service::get_provider(&state.db, &task.provider_id).await?;
            ModelResolution::Direct(CandidateModel {
                provider_id: task.provider_id.clone(),
                model_id: task.model_id.clone(),
                base_url: provider.base_url.clone(),
                // NOTE(review): streaming support is assumed true because the
                // physical model is unknown here — confirm this default is intended.
                supports_streaming: true,
            })
        }
    } else {
        // No model field in the stored body — fall back to the stored provider_id.
        let provider = model_service::get_provider(&state.db, &task.provider_id).await?;
        ModelResolution::Direct(CandidateModel {
            provider_id: task.provider_id.clone(),
            model_id: task.model_id.clone(),
            base_url: provider.base_url.clone(),
            // NOTE(review): same assumed streaming default as above.
            supports_streaming: true,
        })
    };

    // Extract retry parameters, then drop the config read lock explicitly.
    let max_attempts = task.max_attempts as u32;
    let config = state.config.read().await;
    let base_delay_ms = config.relay.retry_delay_ms;
    let enc_key = config.api_key_encryption_key()
        .map_err(|e| SaasError::Internal(e.to_string()))?;
    drop(config);

    // Reset the task row to 'queued' so a new processing run is allowed.
    sqlx::query(
        "UPDATE relay_tasks SET status = 'queued', error_message = NULL, started_at = NULL, completed_at = NULL WHERE id = $1"
    )
    .bind(&id)
    .execute(&state.db)
    .await?;

    // Execute the retry asynchronously — path depends on the resolution result.
    let db = state.db.clone();
    let task_id = id.clone();
    let account_id_for_spawn = task.account_id.clone();
    let handle = tokio::spawn(async move {
        let result = match model_resolution {
            ModelResolution::Direct(ref candidate) => {
                service::execute_relay(
                    &db, &task_id, &account_id_for_spawn, &candidate.provider_id,
                    &candidate.base_url, &body, stream,
                    max_attempts, base_delay_ms, &enc_key,
                    true,
                ).await
            }
            ModelResolution::Group(ref mut candidates) => {
                service::sort_candidates_by_quota(&db, candidates).await;
                service::execute_relay_with_failover(
                    &db, &task_id, &account_id_for_spawn, candidates,
                    &body, stream,
                    max_attempts, base_delay_ms, &enc_key,
                ).await
                // Failover returns (response, provider_id, model_id); only the
                // response matters here.
                .map(|(resp, _, _)| resp)
            }
        };
        match result {
            Ok(_) => tracing::info!("Relay task {} 重试成功", task_id),
            Err(e) => tracing::warn!("Relay task {} 重试失败: {}", task_id, e),
        }
    });
    // Detach with warning — if the server shuts down mid-retry, the task is lost.
    // The DB status is already reset to 'queued', so a future restart can pick it up.
    tokio::spawn(async move {
        if let Err(e) = handle.await {
            tracing::warn!("Relay retry task aborted (server shutdown?): {}", e);
        }
    });

    // Dispatch the operation log asynchronously.
    state.dispatch_log_operation(
        &ctx.account_id, "relay.retry", "relay_task", &id,
        None, ctx.client_ip.as_deref(),
    ).await;

    Ok(Json(serde_json::json!({"ok": true, "task_id": id})))
}
|
||
|
||
// ============ Key pool management (admin only) ============
|
||
|
||
/// GET /api/v1/providers/:provider_id/keys
|
||
pub async fn list_provider_keys(
|
||
State(state): State<AppState>,
|
||
Extension(ctx): Extension<AuthContext>,
|
||
Path(provider_id): Path<String>,
|
||
) -> SaasResult<Json<Vec<serde_json::Value>>> {
|
||
check_permission(&ctx, "provider:manage")?;
|
||
let keys = super::key_pool::list_provider_keys(&state.db, &provider_id).await?;
|
||
Ok(Json(keys))
|
||
}
|
||
|
||
/// POST /api/v1/providers/:provider_id/keys
#[derive(serde::Deserialize)]
pub struct AddKeyRequest {
    // Human-readable label identifying the key within the pool.
    pub key_label: String,
    // Raw API key material; it is encrypted by the handler before storage.
    pub key_value: String,
    // Selection priority within the pool; serde defaults this to 0 when omitted.
    #[serde(default)]
    pub priority: i32,
    // Optional per-key requests-per-minute cap.
    pub max_rpm: Option<i64>,
    // Optional per-key tokens-per-minute cap.
    pub max_tpm: Option<i64>,
}
|
||
|
||
pub async fn add_provider_key(
|
||
State(state): State<AppState>,
|
||
Extension(ctx): Extension<AuthContext>,
|
||
Path(provider_id): Path<String>,
|
||
Json(req): Json<AddKeyRequest>,
|
||
) -> SaasResult<Json<serde_json::Value>> {
|
||
check_permission(&ctx, "provider:manage")?;
|
||
|
||
if req.key_label.trim().is_empty() {
|
||
return Err(SaasError::InvalidInput("key_label 不能为空".into()));
|
||
}
|
||
if req.key_value.trim().is_empty() {
|
||
return Err(SaasError::InvalidInput("key_value 不能为空".into()));
|
||
}
|
||
if req.key_value.len() < 20 {
|
||
return Err(SaasError::InvalidInput("key_value 长度不足(至少 20 字符)".into()));
|
||
}
|
||
if req.key_value.contains(char::is_whitespace) {
|
||
return Err(SaasError::InvalidInput("key_value 不能包含空白字符".into()));
|
||
}
|
||
|
||
// Encrypt the API key before storing in database
|
||
let enc_key = state.config.read().await.api_key_encryption_key()?;
|
||
let encrypted_value = crate::crypto::encrypt_value(&req.key_value, &enc_key)?;
|
||
|
||
let key_id = super::key_pool::add_provider_key(
|
||
&state.db, &provider_id, &req.key_label, &encrypted_value,
|
||
req.priority, req.max_rpm, req.max_tpm,
|
||
).await?;
|
||
|
||
// 异步派发操作日志
|
||
state.dispatch_log_operation(
|
||
&ctx.account_id, "provider_key.add", "provider_key", &key_id,
|
||
Some(serde_json::json!({"provider_id": provider_id, "label": req.key_label})),
|
||
ctx.client_ip.as_deref(),
|
||
).await;
|
||
|
||
Ok(Json(serde_json::json!({"ok": true, "key_id": key_id})))
|
||
}
|
||
|
||
/// PUT /api/v1/providers/:provider_id/keys/:key_id/toggle
#[derive(serde::Deserialize)]
pub struct ToggleKeyRequest {
    // Desired activation state for the key (true = active, false = inactive).
    pub active: bool,
}
|
||
|
||
pub async fn toggle_provider_key(
|
||
State(state): State<AppState>,
|
||
Extension(ctx): Extension<AuthContext>,
|
||
Path((provider_id, key_id)): Path<(String, String)>,
|
||
Json(req): Json<ToggleKeyRequest>,
|
||
) -> SaasResult<Json<serde_json::Value>> {
|
||
check_permission(&ctx, "provider:manage")?;
|
||
|
||
super::key_pool::toggle_key_active(&state.db, &key_id, req.active).await?;
|
||
|
||
// 异步派发操作日志
|
||
state.dispatch_log_operation(
|
||
&ctx.account_id, "provider_key.toggle", "provider_key", &key_id,
|
||
Some(serde_json::json!({"provider_id": provider_id, "active": req.active})),
|
||
ctx.client_ip.as_deref(),
|
||
).await;
|
||
|
||
Ok(Json(serde_json::json!({"ok": true})))
|
||
}
|
||
|
||
/// DELETE /api/v1/providers/:provider_id/keys/:key_id
|
||
pub async fn delete_provider_key(
|
||
State(state): State<AppState>,
|
||
Extension(ctx): Extension<AuthContext>,
|
||
Path((provider_id, key_id)): Path<(String, String)>,
|
||
) -> SaasResult<Json<serde_json::Value>> {
|
||
check_permission(&ctx, "provider:manage")?;
|
||
|
||
super::key_pool::delete_provider_key(&state.db, &key_id).await?;
|
||
|
||
// 异步派发操作日志
|
||
state.dispatch_log_operation(
|
||
&ctx.account_id, "provider_key.delete", "provider_key", &key_id,
|
||
Some(serde_json::json!({"provider_id": provider_id})),
|
||
ctx.client_ip.as_deref(),
|
||
).await;
|
||
|
||
Ok(Json(serde_json::json!({"ok": true})))
|
||
}
|