chore: 提交所有工作进度 — SaaS 后端增强、Admin UI、桌面端集成
包含大量 SaaS 平台改进、Admin 管理后台更新、桌面端集成完善、文档同步、测试文件重构等内容。为 QA 测试准备干净工作树。
This commit is contained in:
@@ -23,6 +23,22 @@ pub async fn chat_completions(
|
||||
) -> SaasResult<Response> {
|
||||
check_permission(&ctx, "relay:use")?;
|
||||
|
||||
// 队列容量检查:防止过载
|
||||
let config = state.config.read().await;
|
||||
let queued_count: i64 = sqlx::query_scalar(
|
||||
"SELECT COUNT(*) FROM relay_tasks WHERE account_id = $1 AND status IN ('queued', 'processing')"
|
||||
)
|
||||
.bind(&ctx.account_id)
|
||||
.fetch_one(&state.db)
|
||||
.await
|
||||
.unwrap_or(0);
|
||||
|
||||
if queued_count >= config.relay.max_queue_size as i64 {
|
||||
return Err(SaasError::RateLimited(
|
||||
format!("队列已满 ({} 个任务排队中),请稍后重试", queued_count)
|
||||
));
|
||||
}
|
||||
|
||||
let model_name = req.get("model")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| SaasError::InvalidInput("缺少 model 字段".into()))?;
|
||||
@@ -32,7 +48,7 @@ pub async fn chat_completions(
|
||||
.unwrap_or(false);
|
||||
|
||||
// 查找 model 对应的 provider
|
||||
let models = model_service::list_models(&state.db, None).await?;
|
||||
let models = model_service::list_models(&state.db, None, None, None).await?.items;
|
||||
let target_model = models.iter().find(|m| m.model_id == model_name && m.enabled)
|
||||
.ok_or_else(|| SaasError::NotFound(format!("模型 {} 不存在或未启用", model_name)))?;
|
||||
|
||||
@@ -42,15 +58,6 @@ pub async fn chat_completions(
|
||||
return Err(SaasError::Forbidden(format!("Provider {} 已禁用", provider.name)));
|
||||
}
|
||||
|
||||
// 获取 provider 的 API key (从数据库直接查询)
|
||||
let provider_api_key: Option<String> = sqlx::query_scalar(
|
||||
"SELECT api_key FROM providers WHERE id = ?1"
|
||||
)
|
||||
.bind(&target_model.provider_id)
|
||||
.fetch_optional(&state.db)
|
||||
.await?
|
||||
.flatten();
|
||||
|
||||
let request_body = serde_json::to_string(&req)?;
|
||||
|
||||
// 创建中转任务
|
||||
@@ -64,27 +71,22 @@ pub async fn chat_completions(
|
||||
log_operation(&state.db, &ctx.account_id, "relay.request", "relay_task", &task.id,
|
||||
Some(serde_json::json!({"model": model_name, "stream": stream})), ctx.client_ip.as_deref()).await?;
|
||||
|
||||
// 执行中转 (带重试)
|
||||
// 获取加密密钥用于解密 API Key
|
||||
let enc_key = config.api_key_encryption_key()
|
||||
.map_err(|e| SaasError::Internal(e.to_string()))?;
|
||||
|
||||
// 执行中转 (Key Pool 自动选择 + 429 轮转)
|
||||
let response = service::execute_relay(
|
||||
&state.db, &task.id, &provider.base_url,
|
||||
provider_api_key.as_deref(), &request_body, stream,
|
||||
&state.db, &task.id, &target_model.provider_id,
|
||||
&provider.base_url, &request_body, stream,
|
||||
config.relay.max_attempts,
|
||||
config.relay.retry_delay_ms,
|
||||
&enc_key,
|
||||
).await;
|
||||
|
||||
match response {
|
||||
Ok(service::RelayResponse::Json(body)) => {
|
||||
// 记录用量
|
||||
let parsed: serde_json::Value = serde_json::from_str(&body).unwrap_or_default();
|
||||
let input_tokens = parsed.get("usage")
|
||||
.and_then(|u| u.get("prompt_tokens"))
|
||||
.and_then(|v| v.as_i64())
|
||||
.unwrap_or(0);
|
||||
let output_tokens = parsed.get("usage")
|
||||
.and_then(|u| u.get("completion_tokens"))
|
||||
.and_then(|v| v.as_i64())
|
||||
.unwrap_or(0);
|
||||
|
||||
let (input_tokens, output_tokens) = service::extract_token_usage_from_json(&body);
|
||||
model_service::record_usage(
|
||||
&state.db, &ctx.account_id, &target_model.provider_id,
|
||||
&target_model.model_id, input_tokens, output_tokens,
|
||||
@@ -94,13 +96,14 @@ pub async fn chat_completions(
|
||||
Ok((StatusCode::OK, [(axum::http::header::CONTENT_TYPE, "application/json")], body).into_response())
|
||||
}
|
||||
Ok(service::RelayResponse::Sse(body)) => {
|
||||
// SSE 流的 usage 统计在 service 层异步处理
|
||||
// 这里先记录一个占位记录,实际值会在流结束后更新
|
||||
model_service::record_usage(
|
||||
&state.db, &ctx.account_id, &target_model.provider_id,
|
||||
&target_model.model_id, 0, 0,
|
||||
None, "success", None,
|
||||
None, "streaming", None,
|
||||
).await?;
|
||||
|
||||
// 流式响应: 直接转发 axum::body::Body
|
||||
let response = axum::response::Response::builder()
|
||||
.status(StatusCode::OK)
|
||||
.header(axum::http::header::CONTENT_TYPE, "text/event-stream")
|
||||
@@ -126,7 +129,7 @@ pub async fn list_tasks(
|
||||
State(state): State<AppState>,
|
||||
Extension(ctx): Extension<AuthContext>,
|
||||
Query(query): Query<RelayTaskQuery>,
|
||||
) -> SaasResult<Json<Vec<RelayTaskInfo>>> {
|
||||
) -> SaasResult<Json<crate::common::PaginatedResponse<RelayTaskInfo>>> {
|
||||
service::list_relay_tasks(&state.db, &ctx.account_id, &query).await.map(Json)
|
||||
}
|
||||
|
||||
@@ -150,11 +153,11 @@ pub async fn list_available_models(
|
||||
State(state): State<AppState>,
|
||||
_ctx: Extension<AuthContext>,
|
||||
) -> SaasResult<Json<Vec<serde_json::Value>>> {
|
||||
let providers = model_service::list_providers(&state.db).await?;
|
||||
let providers = model_service::list_providers(&state.db, None, None, None).await?.items;
|
||||
let enabled_provider_ids: std::collections::HashSet<String> =
|
||||
providers.iter().filter(|p| p.enabled).map(|p| p.id.clone()).collect();
|
||||
|
||||
let models = model_service::list_models(&state.db, None).await?;
|
||||
let models = model_service::list_models(&state.db, None, None, None).await?.items;
|
||||
let available: Vec<serde_json::Value> = models.into_iter()
|
||||
.filter(|m| m.enabled && enabled_provider_ids.contains(&m.provider_id))
|
||||
.map(|m| {
|
||||
@@ -191,17 +194,10 @@ pub async fn retry_task(
|
||||
|
||||
// 获取 provider 信息
|
||||
let provider = model_service::get_provider(&state.db, &task.provider_id).await?;
|
||||
let provider_api_key: Option<String> = sqlx::query_scalar(
|
||||
"SELECT api_key FROM providers WHERE id = ?1"
|
||||
)
|
||||
.bind(&task.provider_id)
|
||||
.fetch_optional(&state.db)
|
||||
.await?
|
||||
.flatten();
|
||||
|
||||
// 读取原始请求体
|
||||
let request_body: Option<String> = sqlx::query_scalar(
|
||||
"SELECT request_body FROM relay_tasks WHERE id = ?1"
|
||||
"SELECT request_body FROM relay_tasks WHERE id = $1"
|
||||
)
|
||||
.bind(&id)
|
||||
.fetch_optional(&state.db)
|
||||
@@ -219,23 +215,27 @@ pub async fn retry_task(
|
||||
let max_attempts = task.max_attempts as u32;
|
||||
let config = state.config.read().await;
|
||||
let base_delay_ms = config.relay.retry_delay_ms;
|
||||
let enc_key = config.api_key_encryption_key()
|
||||
.map_err(|e| SaasError::Internal(e.to_string()))?;
|
||||
|
||||
// 重置任务状态为 queued 以允许新的 processing
|
||||
sqlx::query(
|
||||
"UPDATE relay_tasks SET status = 'queued', error_message = NULL, started_at = NULL, completed_at = NULL WHERE id = ?1"
|
||||
"UPDATE relay_tasks SET status = 'queued', error_message = NULL, started_at = NULL, completed_at = NULL WHERE id = $1"
|
||||
)
|
||||
.bind(&id)
|
||||
.execute(&state.db)
|
||||
.await?;
|
||||
|
||||
// 异步执行重试
|
||||
// 异步执行重试 (Key Pool 自动选择)
|
||||
let db = state.db.clone();
|
||||
let task_id = id.clone();
|
||||
let provider_id = task.provider_id.clone();
|
||||
let base_url = provider.base_url.clone();
|
||||
tokio::spawn(async move {
|
||||
match service::execute_relay(
|
||||
&db, &task_id, &provider.base_url,
|
||||
provider_api_key.as_deref(), &body, stream,
|
||||
max_attempts, base_delay_ms,
|
||||
&db, &task_id, &provider_id,
|
||||
&base_url, &body, stream,
|
||||
max_attempts, base_delay_ms, &enc_key,
|
||||
).await {
|
||||
Ok(_) => tracing::info!("Relay task {} 重试成功", task_id),
|
||||
Err(e) => tracing::warn!("Relay task {} 重试失败: {}", task_id, e),
|
||||
@@ -247,3 +247,96 @@ pub async fn retry_task(
|
||||
|
||||
Ok(Json(serde_json::json!({"ok": true, "task_id": id})))
|
||||
}
|
||||
|
||||
// ============ Key Pool 管理 (admin only) ============
|
||||
|
||||
/// GET /api/v1/providers/:provider_id/keys
|
||||
pub async fn list_provider_keys(
|
||||
State(state): State<AppState>,
|
||||
Extension(ctx): Extension<AuthContext>,
|
||||
Path(provider_id): Path<String>,
|
||||
) -> SaasResult<Json<Vec<serde_json::Value>>> {
|
||||
check_permission(&ctx, "provider:manage")?;
|
||||
let keys = super::key_pool::list_provider_keys(&state.db, &provider_id).await?;
|
||||
Ok(Json(keys))
|
||||
}
|
||||
|
||||
/// POST /api/v1/providers/:provider_id/keys
|
||||
#[derive(serde::Deserialize)]
|
||||
pub struct AddKeyRequest {
|
||||
pub key_label: String,
|
||||
pub key_value: String,
|
||||
#[serde(default)]
|
||||
pub priority: i32,
|
||||
pub max_rpm: Option<i64>,
|
||||
pub max_tpm: Option<i64>,
|
||||
pub quota_reset_interval: Option<String>,
|
||||
}
|
||||
|
||||
pub async fn add_provider_key(
|
||||
State(state): State<AppState>,
|
||||
Extension(ctx): Extension<AuthContext>,
|
||||
Path(provider_id): Path<String>,
|
||||
Json(req): Json<AddKeyRequest>,
|
||||
) -> SaasResult<Json<serde_json::Value>> {
|
||||
check_permission(&ctx, "provider:manage")?;
|
||||
|
||||
if req.key_label.trim().is_empty() {
|
||||
return Err(SaasError::InvalidInput("key_label 不能为空".into()));
|
||||
}
|
||||
if req.key_value.trim().is_empty() {
|
||||
return Err(SaasError::InvalidInput("key_value 不能为空".into()));
|
||||
}
|
||||
|
||||
let key_id = super::key_pool::add_provider_key(
|
||||
&state.db, &provider_id, &req.key_label, &req.key_value,
|
||||
req.priority, req.max_rpm, req.max_tpm,
|
||||
req.quota_reset_interval.as_deref(),
|
||||
).await?;
|
||||
|
||||
log_operation(&state.db, &ctx.account_id, "provider_key.add", "provider_key", &key_id,
|
||||
Some(serde_json::json!({"provider_id": provider_id, "label": req.key_label})),
|
||||
ctx.client_ip.as_deref()).await?;
|
||||
|
||||
Ok(Json(serde_json::json!({"ok": true, "key_id": key_id})))
|
||||
}
|
||||
|
||||
/// PUT /api/v1/providers/:provider_id/keys/:key_id/toggle
|
||||
#[derive(serde::Deserialize)]
|
||||
pub struct ToggleKeyRequest {
|
||||
pub active: bool,
|
||||
}
|
||||
|
||||
pub async fn toggle_provider_key(
|
||||
State(state): State<AppState>,
|
||||
Extension(ctx): Extension<AuthContext>,
|
||||
Path((provider_id, key_id)): Path<(String, String)>,
|
||||
Json(req): Json<ToggleKeyRequest>,
|
||||
) -> SaasResult<Json<serde_json::Value>> {
|
||||
check_permission(&ctx, "provider:manage")?;
|
||||
|
||||
super::key_pool::toggle_key_active(&state.db, &key_id, req.active).await?;
|
||||
|
||||
log_operation(&state.db, &ctx.account_id, "provider_key.toggle", "provider_key", &key_id,
|
||||
Some(serde_json::json!({"provider_id": provider_id, "active": req.active})),
|
||||
ctx.client_ip.as_deref()).await?;
|
||||
|
||||
Ok(Json(serde_json::json!({"ok": true})))
|
||||
}
|
||||
|
||||
/// DELETE /api/v1/providers/:provider_id/keys/:key_id
|
||||
pub async fn delete_provider_key(
|
||||
State(state): State<AppState>,
|
||||
Extension(ctx): Extension<AuthContext>,
|
||||
Path((provider_id, key_id)): Path<(String, String)>,
|
||||
) -> SaasResult<Json<serde_json::Value>> {
|
||||
check_permission(&ctx, "provider:manage")?;
|
||||
|
||||
super::key_pool::delete_provider_key(&state.db, &key_id).await?;
|
||||
|
||||
log_operation(&state.db, &ctx.account_id, "provider_key.delete", "provider_key", &key_id,
|
||||
Some(serde_json::json!({"provider_id": provider_id})),
|
||||
ctx.client_ip.as_deref()).await?;
|
||||
|
||||
Ok(Json(serde_json::json!({"ok": true})))
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user