chore: 提交所有工作进度 — SaaS 后端增强、Admin UI、桌面端集成

包含大量 SaaS 平台改进、Admin 管理后台更新、桌面端集成完善、
文档同步、测试文件重构等内容。为 QA 测试准备干净工作树。
This commit is contained in:
iven
2026-03-29 10:46:26 +08:00
parent 9a5fad2b59
commit 5fdf96c3f5
268 changed files with 22011 additions and 3886 deletions

View File

@@ -5,19 +5,24 @@ use axum::{
http::StatusCode, Json,
};
use crate::state::AppState;
use crate::error::SaasResult;
use crate::error::{SaasResult, SaasError};
use crate::auth::types::AuthContext;
use crate::auth::handlers::{log_operation, check_permission};
use crate::common::PaginatedResponse;
use super::{types::*, service};
// ============ Providers ============
/// GET /api/v1/providers
/// GET /api/v1/providers?enabled=true&page=1&page_size=20
pub async fn list_providers(
State(state): State<AppState>,
_ctx: Extension<AuthContext>,
) -> SaasResult<Json<Vec<ProviderInfo>>> {
service::list_providers(&state.db).await.map(Json)
Query(params): Query<std::collections::HashMap<String, String>>,
) -> SaasResult<Json<PaginatedResponse<ProviderInfo>>> {
let page = params.get("page").and_then(|v| v.parse().ok());
let page_size = params.get("page_size").and_then(|v| v.parse().ok());
let enabled_filter = params.get("enabled").and_then(|v| v.parse().ok());
service::list_providers(&state.db, page, page_size, enabled_filter).await.map(Json)
}
/// GET /api/v1/providers/:id
@@ -36,13 +41,17 @@ pub async fn create_provider(
Json(req): Json<CreateProviderRequest>,
) -> SaasResult<(StatusCode, Json<ProviderInfo>)> {
check_permission(&ctx, "provider:manage")?;
let provider = service::create_provider(&state.db, &req).await?;
let config = state.config.read().await;
let enc_key = config.api_key_encryption_key()
.map_err(|e| SaasError::Internal(e.to_string()))?;
drop(config);
let provider = service::create_provider(&state.db, &req, &enc_key).await?;
log_operation(&state.db, &ctx.account_id, "provider.create", "provider", &provider.id,
Some(serde_json::json!({"name": &req.name})), ctx.client_ip.as_deref()).await?;
Ok((StatusCode::CREATED, Json(provider)))
}
/// PUT /api/v1/providers/:id (admin only)
/// PATCH /api/v1/providers/:id (admin only)
pub async fn update_provider(
State(state): State<AppState>,
Path(id): Path<String>,
@@ -50,7 +59,11 @@ pub async fn update_provider(
Json(req): Json<UpdateProviderRequest>,
) -> SaasResult<Json<ProviderInfo>> {
check_permission(&ctx, "provider:manage")?;
let provider = service::update_provider(&state.db, &id, &req).await?;
let config = state.config.read().await;
let enc_key = config.api_key_encryption_key()
.map_err(|e| SaasError::Internal(e.to_string()))?;
drop(config);
let provider = service::update_provider(&state.db, &id, &req, &enc_key).await?;
log_operation(&state.db, &ctx.account_id, "provider.update", "provider", &id, None, ctx.client_ip.as_deref()).await?;
Ok(Json(provider))
}
@@ -69,14 +82,16 @@ pub async fn delete_provider(
// ============ Models ============
/// GET /api/v1/models?provider_id=xxx
/// GET /api/v1/models?provider_id=xxx&page=1&page_size=20
pub async fn list_models(
State(state): State<AppState>,
Query(params): Query<std::collections::HashMap<String, String>>,
_ctx: Extension<AuthContext>,
) -> SaasResult<Json<Vec<ModelInfo>>> {
) -> SaasResult<Json<PaginatedResponse<ModelInfo>>> {
let provider_id = params.get("provider_id").map(|s| s.as_str());
service::list_models(&state.db, provider_id).await.map(Json)
let page = params.get("page").and_then(|v| v.parse().ok());
let page_size = params.get("page_size").and_then(|v| v.parse().ok());
service::list_models(&state.db, provider_id, page, page_size).await.map(Json)
}
/// GET /api/v1/models/:id
@@ -101,7 +116,7 @@ pub async fn create_model(
Ok((StatusCode::CREATED, Json(model)))
}
/// PUT /api/v1/models/:id (admin only)
/// PATCH /api/v1/models/:id (admin only)
pub async fn update_model(
State(state): State<AppState>,
Path(id): Path<String>,
@@ -128,14 +143,16 @@ pub async fn delete_model(
// ============ Account API Keys ============
/// GET /api/v1/keys?provider_id=xxx
/// GET /api/v1/keys?provider_id=xxx&page=1&page_size=20
pub async fn list_api_keys(
State(state): State<AppState>,
Extension(ctx): Extension<AuthContext>,
Query(params): Query<std::collections::HashMap<String, String>>,
) -> SaasResult<Json<Vec<AccountApiKeyInfo>>> {
) -> SaasResult<Json<PaginatedResponse<AccountApiKeyInfo>>> {
let provider_id = params.get("provider_id").map(|s| s.as_str());
service::list_account_api_keys(&state.db, &ctx.account_id, provider_id).await.map(Json)
let page = params.get("page").and_then(|v| v.parse().ok());
let page_size = params.get("page_size").and_then(|v| v.parse().ok());
service::list_account_api_keys(&state.db, &ctx.account_id, provider_id, page, page_size).await.map(Json)
}
/// POST /api/v1/keys
@@ -144,7 +161,11 @@ pub async fn create_api_key(
Extension(ctx): Extension<AuthContext>,
Json(req): Json<CreateAccountApiKeyRequest>,
) -> SaasResult<(StatusCode, Json<AccountApiKeyInfo>)> {
let key = service::create_account_api_key(&state.db, &ctx.account_id, &req).await?;
let config = state.config.read().await;
let enc_key = config.api_key_encryption_key()
.map_err(|e| SaasError::Internal(e.to_string()))?;
drop(config);
let key = service::create_account_api_key(&state.db, &ctx.account_id, &req, &enc_key).await?;
log_operation(&state.db, &ctx.account_id, "api_key.create", "api_key", &key.id,
Some(serde_json::json!({"provider_id": &req.provider_id})), ctx.client_ip.as_deref()).await?;
Ok((StatusCode::CREATED, Json(key)))
@@ -157,7 +178,11 @@ pub async fn rotate_api_key(
Extension(ctx): Extension<AuthContext>,
Json(req): Json<RotateApiKeyRequest>,
) -> SaasResult<Json<serde_json::Value>> {
service::rotate_account_api_key(&state.db, &id, &ctx.account_id, &req.new_key_value).await?;
let config = state.config.read().await;
let enc_key = config.api_key_encryption_key()
.map_err(|e| SaasError::Internal(e.to_string()))?;
drop(config);
service::rotate_account_api_key(&state.db, &id, &ctx.account_id, &req.new_key_value, &enc_key).await?;
log_operation(&state.db, &ctx.account_id, "api_key.rotate", "api_key", &id, None, ctx.client_ip.as_deref()).await?;
Ok(Json(serde_json::json!({"ok": true})))
}
@@ -189,6 +214,6 @@ pub async fn list_provider_models(
State(state): State<AppState>,
Path(provider_id): Path<String>,
_ctx: Extension<AuthContext>,
) -> SaasResult<Json<Vec<ModelInfo>>> {
service::list_models(&state.db, Some(&provider_id)).await.map(Json)
) -> SaasResult<Json<PaginatedResponse<ModelInfo>>> {
service::list_models(&state.db, Some(&provider_id), None, None).await.map(Json)
}

View File

@@ -12,15 +12,15 @@ pub fn routes() -> axum::Router<AppState> {
axum::Router::new()
// Providers
.route("/api/v1/providers", get(handlers::list_providers).post(handlers::create_provider))
.route("/api/v1/providers/{id}", get(handlers::get_provider).put(handlers::update_provider).delete(handlers::delete_provider))
.route("/api/v1/providers/{id}/models", get(handlers::list_provider_models))
.route("/api/v1/providers/:id", get(handlers::get_provider).patch(handlers::update_provider).delete(handlers::delete_provider))
.route("/api/v1/providers/:id/models", get(handlers::list_provider_models))
// Models
.route("/api/v1/models", get(handlers::list_models).post(handlers::create_model))
.route("/api/v1/models/{id}", get(handlers::get_model).put(handlers::update_model).delete(handlers::delete_model))
.route("/api/v1/models/:id", get(handlers::get_model).patch(handlers::update_model).delete(handlers::delete_model))
// Account API Keys
.route("/api/v1/keys", get(handlers::list_api_keys).post(handlers::create_api_key))
.route("/api/v1/keys/{id}", delete(handlers::revoke_api_key))
.route("/api/v1/keys/{id}/rotate", post(handlers::rotate_api_key))
.route("/api/v1/keys/:id", delete(handlers::revoke_api_key))
.route("/api/v1/keys/:id/rotate", post(handlers::rotate_api_key))
// Usage
.route("/api/v1/usage", get(handlers::get_usage))
}

View File

@@ -1,30 +1,61 @@
//! 模型配置业务逻辑
use sqlx::SqlitePool;
use sqlx::{PgPool, Row};
use crate::error::{SaasError, SaasResult};
use crate::common::{PaginatedResponse, normalize_pagination};
use crate::crypto;
use super::types::*;
// ============ Providers ============
pub async fn list_providers(db: &SqlitePool) -> SaasResult<Vec<ProviderInfo>> {
let rows: Vec<(String, String, String, String, String, bool, Option<i64>, Option<i64>, String, String)> =
sqlx::query_as(
"SELECT id, name, display_name, base_url, api_protocol, enabled, rate_limit_rpm, rate_limit_tpm, created_at, updated_at
FROM providers ORDER BY name"
)
.fetch_all(db)
.await?;
pub async fn list_providers(
db: &PgPool, page: Option<u32>, page_size: Option<u32>, enabled_filter: Option<bool>,
) -> SaasResult<PaginatedResponse<ProviderInfo>> {
let (p, ps, offset) = normalize_pagination(page, page_size);
Ok(rows.into_iter().map(|(id, name, display_name, base_url, api_protocol, enabled, rpm, tpm, created_at, updated_at)| {
let (count_sql, data_sql) = if enabled_filter.is_some() {
(
"SELECT COUNT(*) FROM providers WHERE enabled = $1",
"SELECT id, name, display_name, base_url, api_protocol, enabled, rate_limit_rpm, rate_limit_tpm, created_at, updated_at
FROM providers WHERE enabled = $1 ORDER BY name LIMIT $2 OFFSET $3",
)
} else {
(
"SELECT COUNT(*) FROM providers",
"SELECT id, name, display_name, base_url, api_protocol, enabled, rate_limit_rpm, rate_limit_tpm, created_at, updated_at
FROM providers ORDER BY name LIMIT $1 OFFSET $2",
)
};
let total: (i64,) = if let Some(en) = enabled_filter {
sqlx::query_as(count_sql).bind(en).fetch_one(db).await?
} else {
sqlx::query_as(count_sql).fetch_one(db).await?
};
let rows: Vec<(String, String, String, String, String, bool, Option<i64>, Option<i64>, String, String)> =
if let Some(en) = enabled_filter {
sqlx::query_as(data_sql)
.bind(en).bind(ps as i64).bind(offset)
.fetch_all(db).await?
} else {
sqlx::query_as(data_sql)
.bind(ps as i64).bind(offset)
.fetch_all(db).await?
};
let items = rows.into_iter().map(|(id, name, display_name, base_url, api_protocol, enabled, rpm, tpm, created_at, updated_at)| {
ProviderInfo { id, name, display_name, base_url, api_protocol, enabled, rate_limit_rpm: rpm, rate_limit_tpm: tpm, created_at, updated_at }
}).collect())
}).collect();
Ok(PaginatedResponse { items, total: total.0, page: p, page_size: ps })
}
pub async fn get_provider(db: &SqlitePool, provider_id: &str) -> SaasResult<ProviderInfo> {
pub async fn get_provider(db: &PgPool, provider_id: &str) -> SaasResult<ProviderInfo> {
let row: Option<(String, String, String, String, String, bool, Option<i64>, Option<i64>, String, String)> =
sqlx::query_as(
"SELECT id, name, display_name, base_url, api_protocol, enabled, rate_limit_rpm, rate_limit_tpm, created_at, updated_at
FROM providers WHERE id = ?1"
FROM providers WHERE id = $1"
)
.bind(provider_id)
.fetch_optional(db)
@@ -36,22 +67,33 @@ pub async fn get_provider(db: &SqlitePool, provider_id: &str) -> SaasResult<Prov
Ok(ProviderInfo { id, name, display_name, base_url, api_protocol, enabled, rate_limit_rpm: rpm, rate_limit_tpm: tpm, created_at, updated_at })
}
pub async fn create_provider(db: &SqlitePool, req: &CreateProviderRequest) -> SaasResult<ProviderInfo> {
pub async fn create_provider(db: &PgPool, req: &CreateProviderRequest, enc_key: &[u8; 32]) -> SaasResult<ProviderInfo> {
let id = uuid::Uuid::new_v4().to_string();
let now = chrono::Utc::now().to_rfc3339();
// 检查名称唯一性
let existing: Option<(String,)> = sqlx::query_as("SELECT id FROM providers WHERE name = ?1")
let existing: Option<(String,)> = sqlx::query_as("SELECT id FROM providers WHERE name = $1")
.bind(&req.name).fetch_optional(db).await?;
if existing.is_some() {
return Err(SaasError::AlreadyExists(format!("Provider '{}' 已存在", req.name)));
}
// 加密 API Key 后存储
let encrypted_api_key = if let Some(ref key) = req.api_key {
if key.is_empty() {
String::new()
} else {
crypto::encrypt_value(key, enc_key)?
}
} else {
String::new()
};
sqlx::query(
"INSERT INTO providers (id, name, display_name, api_key, base_url, api_protocol, enabled, rate_limit_rpm, rate_limit_tpm, created_at, updated_at)
VALUES (?1, ?2, ?3, ?4, ?5, ?6, 1, ?7, ?8, ?9, ?9)"
VALUES ($1, $2, $3, $4, $5, $6, true, $7, $8, $9, $9)"
)
.bind(&id).bind(&req.name).bind(&req.display_name).bind(&req.api_key)
.bind(&id).bind(&req.name).bind(&req.display_name).bind(&encrypted_api_key)
.bind(&req.base_url).bind(&req.api_protocol).bind(&req.rate_limit_rpm).bind(&req.rate_limit_tpm).bind(&now)
.execute(db).await?;
@@ -59,29 +101,34 @@ pub async fn create_provider(db: &SqlitePool, req: &CreateProviderRequest) -> Sa
}
pub async fn update_provider(
db: &SqlitePool, provider_id: &str, req: &UpdateProviderRequest,
db: &PgPool, provider_id: &str, req: &UpdateProviderRequest, enc_key: &[u8; 32],
) -> SaasResult<ProviderInfo> {
let now = chrono::Utc::now().to_rfc3339();
let mut updates = Vec::new();
let mut params: Vec<Box<dyn std::fmt::Display + Send + Sync>> = Vec::new();
let mut param_idx = 1;
if let Some(ref v) = req.display_name { updates.push("display_name = ?"); params.push(Box::new(v.clone())); }
if let Some(ref v) = req.base_url { updates.push("base_url = ?"); params.push(Box::new(v.clone())); }
if let Some(ref v) = req.api_protocol { updates.push("api_protocol = ?"); params.push(Box::new(v.clone())); }
if let Some(ref v) = req.api_key { updates.push("api_key = ?"); params.push(Box::new(v.clone())); }
if let Some(v) = req.enabled { updates.push("enabled = ?"); params.push(Box::new(v)); }
if let Some(v) = req.rate_limit_rpm { updates.push("rate_limit_rpm = ?"); params.push(Box::new(v)); }
if let Some(v) = req.rate_limit_tpm { updates.push("rate_limit_tpm = ?"); params.push(Box::new(v)); }
if let Some(ref v) = req.display_name { updates.push(format!("display_name = ${}", param_idx)); params.push(Box::new(v.clone())); param_idx += 1; }
if let Some(ref v) = req.base_url { updates.push(format!("base_url = ${}", param_idx)); params.push(Box::new(v.clone())); param_idx += 1; }
if let Some(ref v) = req.api_protocol { updates.push(format!("api_protocol = ${}", param_idx)); params.push(Box::new(v.clone())); param_idx += 1; }
if let Some(ref v) = req.api_key {
let encrypted = if v.is_empty() { String::new() } else { crypto::encrypt_value(v, enc_key)? };
updates.push(format!("api_key = ${}", param_idx)); params.push(Box::new(encrypted)); param_idx += 1;
}
if let Some(v) = req.enabled { updates.push(format!("enabled = ${}", param_idx)); params.push(Box::new(v)); param_idx += 1; }
if let Some(v) = req.rate_limit_rpm { updates.push(format!("rate_limit_rpm = ${}", param_idx)); params.push(Box::new(v)); param_idx += 1; }
if let Some(v) = req.rate_limit_tpm { updates.push(format!("rate_limit_tpm = ${}", param_idx)); params.push(Box::new(v)); param_idx += 1; }
if updates.is_empty() {
return get_provider(db, provider_id).await;
}
updates.push("updated_at = ?");
updates.push(format!("updated_at = ${}", param_idx));
params.push(Box::new(now.clone()));
param_idx += 1;
params.push(Box::new(provider_id.to_string()));
let sql = format!("UPDATE providers SET {} WHERE id = ?", updates.join(", "));
let sql = format!("UPDATE providers SET {} WHERE id = ${}", updates.join(", "), param_idx);
let mut query = sqlx::query(&sql);
for p in &params {
query = query.bind(format!("{}", p));
@@ -91,8 +138,8 @@ pub async fn update_provider(
get_provider(db, provider_id).await
}
pub async fn delete_provider(db: &SqlitePool, provider_id: &str) -> SaasResult<()> {
let result = sqlx::query("DELETE FROM providers WHERE id = ?1")
pub async fn delete_provider(db: &PgPool, provider_id: &str) -> SaasResult<()> {
let result = sqlx::query("DELETE FROM providers WHERE id = $1")
.bind(provider_id).execute(db).await?;
if result.rows_affected() == 0 {
@@ -103,27 +150,45 @@ pub async fn delete_provider(db: &SqlitePool, provider_id: &str) -> SaasResult<(
// ============ Models ============
pub async fn list_models(db: &SqlitePool, provider_id: Option<&str>) -> SaasResult<Vec<ModelInfo>> {
let sql = if provider_id.is_some() {
"SELECT id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, pricing_input, pricing_output, created_at, updated_at
FROM models WHERE provider_id = ?1 ORDER BY alias"
pub async fn list_models(
db: &PgPool, provider_id: Option<&str>, page: Option<u32>, page_size: Option<u32>,
) -> SaasResult<PaginatedResponse<ModelInfo>> {
let (p, ps, offset) = normalize_pagination(page, page_size);
let (count_sql, data_sql) = if provider_id.is_some() {
(
"SELECT COUNT(*) FROM models WHERE provider_id = $1",
"SELECT id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, pricing_input, pricing_output, created_at, updated_at
FROM models WHERE provider_id = $1 ORDER BY alias LIMIT $2 OFFSET $3",
)
} else {
"SELECT id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, pricing_input, pricing_output, created_at, updated_at
FROM models ORDER BY provider_id, alias"
(
"SELECT COUNT(*) FROM models",
"SELECT id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, pricing_input, pricing_output, created_at, updated_at
FROM models ORDER BY provider_id, alias LIMIT $1 OFFSET $2",
)
};
let mut query = sqlx::query_as::<_, (String, String, String, String, i64, i64, bool, bool, bool, f64, f64, String, String)>(sql);
let total: (i64,) = if let Some(pid) = provider_id {
sqlx::query_as(count_sql).bind(pid).fetch_one(db).await?
} else {
sqlx::query_as(count_sql).fetch_one(db).await?
};
let mut query = sqlx::query_as::<_, (String, String, String, String, i64, i64, bool, bool, bool, f64, f64, String, String)>(data_sql);
if let Some(pid) = provider_id {
query = query.bind(pid);
}
let rows = query.bind(ps as i64).bind(offset).fetch_all(db).await?;
let rows = query.fetch_all(db).await?;
Ok(rows.into_iter().map(|(id, provider_id, model_id, alias, ctx, max_out, streaming, vision, enabled, pi, po, created_at, updated_at)| {
let items = rows.into_iter().map(|(id, provider_id, model_id, alias, ctx, max_out, streaming, vision, enabled, pi, po, created_at, updated_at)| {
ModelInfo { id, provider_id, model_id, alias, context_window: ctx, max_output_tokens: max_out, supports_streaming: streaming, supports_vision: vision, enabled, pricing_input: pi, pricing_output: po, created_at, updated_at }
}).collect())
}).collect();
Ok(PaginatedResponse { items, total: total.0, page: p, page_size: ps })
}
pub async fn create_model(db: &SqlitePool, req: &CreateModelRequest) -> SaasResult<ModelInfo> {
pub async fn create_model(db: &PgPool, req: &CreateModelRequest) -> SaasResult<ModelInfo> {
// 验证 provider 存在
let provider = get_provider(db, &req.provider_id).await?;
@@ -132,7 +197,7 @@ pub async fn create_model(db: &SqlitePool, req: &CreateModelRequest) -> SaasResu
// 检查 model 唯一性
let existing: Option<(String,)> = sqlx::query_as(
"SELECT id FROM models WHERE provider_id = ?1 AND model_id = ?2"
"SELECT id FROM models WHERE provider_id = $1 AND model_id = $2"
)
.bind(&req.provider_id).bind(&req.model_id)
.fetch_optional(db).await?;
@@ -152,7 +217,7 @@ pub async fn create_model(db: &SqlitePool, req: &CreateModelRequest) -> SaasResu
sqlx::query(
"INSERT INTO models (id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, pricing_input, pricing_output, created_at, updated_at)
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, 1, ?9, ?10, ?11, ?11)"
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, true, $9, $10, $11, $11)"
)
.bind(&id).bind(&req.provider_id).bind(&req.model_id).bind(&req.alias)
.bind(ctx).bind(max_out).bind(streaming).bind(vision).bind(pi).bind(po).bind(&now)
@@ -161,11 +226,11 @@ pub async fn create_model(db: &SqlitePool, req: &CreateModelRequest) -> SaasResu
get_model(db, &id).await
}
pub async fn get_model(db: &SqlitePool, model_id: &str) -> SaasResult<ModelInfo> {
pub async fn get_model(db: &PgPool, model_id: &str) -> SaasResult<ModelInfo> {
let row: Option<(String, String, String, String, i64, i64, bool, bool, bool, f64, f64, String, String)> =
sqlx::query_as(
"SELECT id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, pricing_input, pricing_output, created_at, updated_at
FROM models WHERE id = ?1"
FROM models WHERE id = $1"
)
.bind(model_id)
.fetch_optional(db)
@@ -178,30 +243,32 @@ pub async fn get_model(db: &SqlitePool, model_id: &str) -> SaasResult<ModelInfo>
}
pub async fn update_model(
db: &SqlitePool, model_id: &str, req: &UpdateModelRequest,
db: &PgPool, model_id: &str, req: &UpdateModelRequest,
) -> SaasResult<ModelInfo> {
let now = chrono::Utc::now().to_rfc3339();
let mut updates = Vec::new();
let mut params: Vec<Box<dyn std::fmt::Display + Send + Sync>> = Vec::new();
let mut param_idx = 1;
if let Some(ref v) = req.alias { updates.push("alias = ?"); params.push(Box::new(v.clone())); }
if let Some(v) = req.context_window { updates.push("context_window = ?"); params.push(Box::new(v)); }
if let Some(v) = req.max_output_tokens { updates.push("max_output_tokens = ?"); params.push(Box::new(v)); }
if let Some(v) = req.supports_streaming { updates.push("supports_streaming = ?"); params.push(Box::new(v)); }
if let Some(v) = req.supports_vision { updates.push("supports_vision = ?"); params.push(Box::new(v)); }
if let Some(v) = req.enabled { updates.push("enabled = ?"); params.push(Box::new(v)); }
if let Some(v) = req.pricing_input { updates.push("pricing_input = ?"); params.push(Box::new(v)); }
if let Some(v) = req.pricing_output { updates.push("pricing_output = ?"); params.push(Box::new(v)); }
if let Some(ref v) = req.alias { updates.push(format!("alias = ${}", param_idx)); params.push(Box::new(v.clone())); param_idx += 1; }
if let Some(v) = req.context_window { updates.push(format!("context_window = ${}", param_idx)); params.push(Box::new(v)); param_idx += 1; }
if let Some(v) = req.max_output_tokens { updates.push(format!("max_output_tokens = ${}", param_idx)); params.push(Box::new(v)); param_idx += 1; }
if let Some(v) = req.supports_streaming { updates.push(format!("supports_streaming = ${}", param_idx)); params.push(Box::new(v)); param_idx += 1; }
if let Some(v) = req.supports_vision { updates.push(format!("supports_vision = ${}", param_idx)); params.push(Box::new(v)); param_idx += 1; }
if let Some(v) = req.enabled { updates.push(format!("enabled = ${}", param_idx)); params.push(Box::new(v)); param_idx += 1; }
if let Some(v) = req.pricing_input { updates.push(format!("pricing_input = ${}", param_idx)); params.push(Box::new(v)); param_idx += 1; }
if let Some(v) = req.pricing_output { updates.push(format!("pricing_output = ${}", param_idx)); params.push(Box::new(v)); param_idx += 1; }
if updates.is_empty() {
return get_model(db, model_id).await;
}
updates.push("updated_at = ?");
updates.push(format!("updated_at = ${}", param_idx));
params.push(Box::new(now.clone()));
param_idx += 1;
params.push(Box::new(model_id.to_string()));
let sql = format!("UPDATE models SET {} WHERE id = ?", updates.join(", "));
let sql = format!("UPDATE models SET {} WHERE id = ${}", updates.join(", "), param_idx);
let mut query = sqlx::query(&sql);
for p in &params {
query = query.bind(format!("{}", p));
@@ -211,8 +278,8 @@ pub async fn update_model(
get_model(db, model_id).await
}
pub async fn delete_model(db: &SqlitePool, model_id: &str) -> SaasResult<()> {
let result = sqlx::query("DELETE FROM models WHERE id = ?1")
pub async fn delete_model(db: &PgPool, model_id: &str) -> SaasResult<()> {
let result = sqlx::query("DELETE FROM models WHERE id = $1")
.bind(model_id).execute(db).await?;
if result.rows_affected() == 0 {
@@ -224,32 +291,52 @@ pub async fn delete_model(db: &SqlitePool, model_id: &str) -> SaasResult<()> {
// ============ Account API Keys ============
pub async fn list_account_api_keys(
db: &SqlitePool, account_id: &str, provider_id: Option<&str>,
) -> SaasResult<Vec<AccountApiKeyInfo>> {
let sql = if provider_id.is_some() {
"SELECT id, provider_id, key_label, permissions, enabled, last_used_at, created_at, key_value
FROM account_api_keys WHERE account_id = ?1 AND provider_id = ?2 AND revoked_at IS NULL ORDER BY created_at DESC"
db: &PgPool, account_id: &str, provider_id: Option<&str>,
page: Option<u32>, page_size: Option<u32>,
) -> SaasResult<PaginatedResponse<AccountApiKeyInfo>> {
let (p, ps, offset) = normalize_pagination(page, page_size);
// Build COUNT and data queries based on whether provider_id is provided
let (count_sql, data_sql) = if provider_id.is_some() {
(
"SELECT COUNT(*) FROM account_api_keys WHERE account_id = $1 AND provider_id = $2 AND revoked_at IS NULL",
"SELECT id, provider_id, key_label, permissions, enabled, last_used_at, created_at, key_value
FROM account_api_keys WHERE account_id = $1 AND provider_id = $2 AND revoked_at IS NULL ORDER BY created_at DESC LIMIT $3 OFFSET $4",
)
} else {
"SELECT id, provider_id, key_label, permissions, enabled, last_used_at, created_at, key_value
FROM account_api_keys WHERE account_id = ?1 AND revoked_at IS NULL ORDER BY created_at DESC"
(
"SELECT COUNT(*) FROM account_api_keys WHERE account_id = $1 AND revoked_at IS NULL",
"SELECT id, provider_id, key_label, permissions, enabled, last_used_at, created_at, key_value
FROM account_api_keys WHERE account_id = $1 AND revoked_at IS NULL ORDER BY created_at DESC LIMIT $2 OFFSET $3",
)
};
let mut query = sqlx::query_as::<_, (String, String, Option<String>, String, bool, Option<String>, String, String)>(sql)
let total: (i64,) = if provider_id.is_some() {
let mut q = sqlx::query_as(count_sql).bind(account_id);
if let Some(pid) = provider_id { q = q.bind(pid); }
q.fetch_one(db).await?
} else {
sqlx::query_as(count_sql).bind(account_id).fetch_one(db).await?
};
let mut query = sqlx::query_as::<_, (String, String, Option<String>, String, bool, Option<String>, String, String)>(data_sql)
.bind(account_id);
if let Some(pid) = provider_id {
query = query.bind(pid);
}
let rows = query.bind(ps as i64).bind(offset).fetch_all(db).await?;
let rows = query.fetch_all(db).await?;
Ok(rows.into_iter().map(|(id, provider_id, key_label, perms, enabled, last_used, created_at, key_value)| {
let items = rows.into_iter().map(|(id, provider_id, key_label, perms, enabled, last_used, created_at, key_value)| {
let permissions: Vec<String> = serde_json::from_str(&perms).unwrap_or_default();
let masked = mask_api_key(&key_value);
AccountApiKeyInfo { id, provider_id, key_label, permissions, enabled, last_used_at: last_used, created_at, masked_key: masked }
}).collect())
}).collect();
Ok(PaginatedResponse { items, total: total.0, page: p, page_size: ps })
}
pub async fn create_account_api_key(
db: &SqlitePool, account_id: &str, req: &CreateAccountApiKeyRequest,
db: &PgPool, account_id: &str, req: &CreateAccountApiKeyRequest, enc_key: &[u8; 32],
) -> SaasResult<AccountApiKeyInfo> {
// 验证 provider 存在
get_provider(db, &req.provider_id).await?;
@@ -258,11 +345,14 @@ pub async fn create_account_api_key(
let now = chrono::Utc::now().to_rfc3339();
let permissions = serde_json::to_string(&req.permissions)?;
// 加密 key_value 后存储
let encrypted_key_value = crypto::encrypt_value(&req.key_value, enc_key)?;
sqlx::query(
"INSERT INTO account_api_keys (id, account_id, provider_id, key_value, key_label, permissions, enabled, created_at, updated_at)
VALUES (?1, ?2, ?3, ?4, ?5, ?6, 1, ?7, ?7)"
VALUES ($1, $2, $3, $4, $5, $6, true, $7, $7)"
)
.bind(&id).bind(account_id).bind(&req.provider_id).bind(&req.key_value)
.bind(&id).bind(account_id).bind(&req.provider_id).bind(&encrypted_key_value)
.bind(&req.key_label).bind(&permissions).bind(&now)
.execute(db).await?;
@@ -275,13 +365,14 @@ pub async fn create_account_api_key(
}
pub async fn rotate_account_api_key(
db: &SqlitePool, key_id: &str, account_id: &str, new_key_value: &str,
db: &PgPool, key_id: &str, account_id: &str, new_key_value: &str, enc_key: &[u8; 32],
) -> SaasResult<()> {
let now = chrono::Utc::now().to_rfc3339();
let encrypted_value = crypto::encrypt_value(new_key_value, enc_key)?;
let result = sqlx::query(
"UPDATE account_api_keys SET key_value = ?1, updated_at = ?2 WHERE id = ?3 AND account_id = ?4 AND revoked_at IS NULL"
"UPDATE account_api_keys SET key_value = $1, updated_at = $2 WHERE id = $3 AND account_id = $4 AND revoked_at IS NULL"
)
.bind(new_key_value).bind(&now).bind(key_id).bind(account_id)
.bind(&encrypted_value).bind(&now).bind(key_id).bind(account_id)
.execute(db).await?;
if result.rows_affected() == 0 {
@@ -291,11 +382,11 @@ pub async fn rotate_account_api_key(
}
pub async fn revoke_account_api_key(
db: &SqlitePool, key_id: &str, account_id: &str,
db: &PgPool, key_id: &str, account_id: &str,
) -> SaasResult<()> {
let now = chrono::Utc::now().to_rfc3339();
let result = sqlx::query(
"UPDATE account_api_keys SET revoked_at = ?1 WHERE id = ?2 AND account_id = ?3 AND revoked_at IS NULL"
"UPDATE account_api_keys SET revoked_at = $1 WHERE id = $2 AND account_id = $3 AND revoked_at IS NULL"
)
.bind(&now).bind(key_id).bind(account_id)
.execute(db).await?;
@@ -309,25 +400,30 @@ pub async fn revoke_account_api_key(
// ============ Usage Statistics ============
pub async fn get_usage_stats(
db: &SqlitePool, account_id: &str, query: &UsageQuery,
db: &PgPool, account_id: &str, query: &UsageQuery,
) -> SaasResult<UsageStats> {
let mut where_clauses = vec!["account_id = ?".to_string()];
let mut param_idx = 1;
let mut where_clauses = vec![format!("account_id = ${}", param_idx)];
let mut params: Vec<String> = vec![account_id.to_string()];
param_idx += 1;
if let Some(ref from) = query.from {
where_clauses.push("created_at >= ?".to_string());
where_clauses.push(format!("created_at >= ${}", param_idx));
params.push(from.clone());
param_idx += 1;
}
if let Some(ref to) = query.to {
where_clauses.push("created_at <= ?".to_string());
where_clauses.push(format!("created_at <= ${}", param_idx));
params.push(to.clone());
param_idx += 1;
}
if let Some(ref pid) = query.provider_id {
where_clauses.push("provider_id = ?".to_string());
where_clauses.push(format!("provider_id = ${}", param_idx));
params.push(pid.clone());
param_idx += 1;
}
if let Some(ref mid) = query.model_id {
where_clauses.push("model_id = ?".to_string());
where_clauses.push(format!("model_id = ${}", param_idx));
params.push(mid.clone());
}
@@ -335,18 +431,21 @@ pub async fn get_usage_stats(
// 总量统计
let total_sql = format!(
"SELECT COUNT(*), COALESCE(SUM(input_tokens), 0), COALESCE(SUM(output_tokens), 0)
"SELECT COUNT(*)::bigint, COALESCE(SUM(input_tokens), 0), COALESCE(SUM(output_tokens), 0)
FROM usage_records WHERE {}", where_sql
);
let mut total_query = sqlx::query_as::<_, (i64, i64, i64)>(&total_sql);
let mut total_query = sqlx::query(&total_sql);
for p in &params {
total_query = total_query.bind(p);
}
let (total_requests, total_input, total_output) = total_query.fetch_one(db).await?;
let row = total_query.fetch_one(db).await?;
let total_requests: i64 = row.try_get(0).unwrap_or(0);
let total_input: i64 = row.try_get(1).unwrap_or(0);
let total_output: i64 = row.try_get(2).unwrap_or(0);
// 按模型统计
let by_model_sql = format!(
"SELECT provider_id, model_id, COUNT(*), COALESCE(SUM(input_tokens), 0), COALESCE(SUM(output_tokens), 0)
"SELECT provider_id, model_id, COUNT(*)::bigint, COALESCE(SUM(input_tokens), 0), COALESCE(SUM(output_tokens), 0)
FROM usage_records WHERE {} GROUP BY provider_id, model_id ORDER BY COUNT(*) DESC LIMIT 20",
where_sql
);
@@ -360,21 +459,27 @@ pub async fn get_usage_stats(
ModelUsage { provider_id, model_id, request_count: count, input_tokens: input, output_tokens: output }
}).collect();
// 按天统计 (最近 30 天)
let from_30d = (chrono::Utc::now() - chrono::Duration::days(30)).to_rfc3339();
// 按天统计 (使用 days 参数或默认 30 天)
let days = query.days.unwrap_or(30).min(365).max(1) as i64;
let from_days = (chrono::Utc::now() - chrono::Duration::days(days)).format("%Y-%m-%d").to_string() + "T00:00:00Z";
let daily_sql = format!(
"SELECT DATE(created_at) as day, COUNT(*), COALESCE(SUM(input_tokens), 0), COALESCE(SUM(output_tokens), 0)
FROM usage_records WHERE account_id = ?1 AND created_at >= ?2
GROUP BY DATE(created_at) ORDER BY day DESC LIMIT 30"
"SELECT SUBSTRING(created_at, 1, 10) as day, COUNT(*)::bigint, COALESCE(SUM(input_tokens), 0), COALESCE(SUM(output_tokens), 0)
FROM usage_records WHERE account_id = $1 AND created_at >= $2
GROUP BY SUBSTRING(created_at, 1, 10) ORDER BY day DESC LIMIT $3"
);
let daily_rows: Vec<(String, i64, i64, i64)> = sqlx::query_as(&daily_sql)
.bind(account_id).bind(&from_30d)
.bind(account_id).bind(&from_days).bind(days as i32)
.fetch_all(db).await?;
let by_day: Vec<DailyUsage> = daily_rows.into_iter()
.map(|(date, count, input, output)| {
DailyUsage { date, request_count: count, input_tokens: input, output_tokens: output }
}).collect();
// 按 group_by 过滤返回
let group_by = query.group_by.as_deref();
let by_model = if group_by == Some("model") || group_by.is_none() { by_model } else { vec![] };
let by_day = if group_by == Some("day") || group_by.is_none() { by_day } else { vec![] };
Ok(UsageStats {
total_requests,
total_input_tokens: total_input,
@@ -385,14 +490,14 @@ pub async fn get_usage_stats(
}
pub async fn record_usage(
db: &SqlitePool, account_id: &str, provider_id: &str, model_id: &str,
db: &PgPool, account_id: &str, provider_id: &str, model_id: &str,
input_tokens: i64, output_tokens: i64, latency_ms: Option<i64>,
status: &str, error_message: Option<&str>,
) -> SaasResult<()> {
let now = chrono::Utc::now().to_rfc3339();
sqlx::query(
"INSERT INTO usage_records (account_id, provider_id, model_id, input_tokens, output_tokens, latency_ms, status, error_message, created_at)
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)"
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)"
)
.bind(account_id).bind(provider_id).bind(model_id)
.bind(input_tokens).bind(output_tokens).bind(latency_ms)

View File

@@ -149,6 +149,10 @@ pub struct UsageQuery {
pub to: Option<String>,
pub provider_id: Option<String>,
pub model_id: Option<String>,
/// 聚合维度: "day" 或 "model"。不传则返回完整 UsageStats
pub group_by: Option<String>,
/// 最近 N 天
pub days: Option<i32>,
}
// --- Seed Data ---