feat(saas): Phase 2 — 模型配置模块

- Provider CRUD (列表/详情/创建/更新/删除)
- Model CRUD (列表/详情/创建/更新/删除)
- Account API Key 管理 (创建/轮换/撤销/掩码显示)
- Usage 统计 (总量/按模型/按天, 支持时间/供应商/模型过滤)
- 权限控制 (provider:manage, model:manage)
- 3 个新集成测试覆盖 providers/models/keys
This commit is contained in:
iven
2026-03-27 12:46:59 +08:00
parent a2f8112d69
commit fec64af565
6 changed files with 949 additions and 66 deletions

View File

@@ -43,6 +43,7 @@ fn build_router(state: AppState) -> axum::Router {
let protected_routes = zclaw_saas::auth::protected_routes()
.merge(zclaw_saas::account::routes())
.merge(zclaw_saas::model_config::routes())
.layer(middleware::from_fn_with_state(
state.clone(),
zclaw_saas::auth::auth_middleware,

View File

@@ -0,0 +1,206 @@
//! 模型配置 HTTP 处理器
use axum::{
extract::{Extension, Path, Query, State},
http::StatusCode, Json,
};
use crate::state::AppState;
use crate::error::{SaasError, SaasResult};
use crate::auth::types::AuthContext;
use crate::auth::handlers::log_operation;
use super::{types::*, service};
// ============ Providers ============
/// GET /api/v1/providers
pub async fn list_providers(
State(state): State<AppState>,
_ctx: Extension<AuthContext>,
) -> SaasResult<Json<Vec<ProviderInfo>>> {
service::list_providers(&state.db).await.map(Json)
}
/// GET /api/v1/providers/:id
pub async fn get_provider(
State(state): State<AppState>,
Path(id): Path<String>,
_ctx: Extension<AuthContext>,
) -> SaasResult<Json<ProviderInfo>> {
service::get_provider(&state.db, &id).await.map(Json)
}
/// POST /api/v1/providers (admin only)
pub async fn create_provider(
State(state): State<AppState>,
Extension(ctx): Extension<AuthContext>,
Json(req): Json<CreateProviderRequest>,
) -> SaasResult<(StatusCode, Json<ProviderInfo>)> {
if !ctx.permissions.contains(&"provider:manage".to_string()) {
return Err(SaasError::Forbidden("需要 provider:manage 权限".into()));
}
let provider = service::create_provider(&state.db, &req).await?;
log_operation(&state.db, &ctx.account_id, "provider.create", "provider", &provider.id,
Some(serde_json::json!({"name": &req.name})), None).await?;
Ok((StatusCode::CREATED, Json(provider)))
}
/// PUT /api/v1/providers/:id (admin only)
pub async fn update_provider(
State(state): State<AppState>,
Path(id): Path<String>,
Extension(ctx): Extension<AuthContext>,
Json(req): Json<UpdateProviderRequest>,
) -> SaasResult<Json<ProviderInfo>> {
if !ctx.permissions.contains(&"provider:manage".to_string()) {
return Err(SaasError::Forbidden("需要 provider:manage 权限".into()));
}
let provider = service::update_provider(&state.db, &id, &req).await?;
log_operation(&state.db, &ctx.account_id, "provider.update", "provider", &id, None, None).await?;
Ok(Json(provider))
}
/// DELETE /api/v1/providers/:id (admin only)
pub async fn delete_provider(
State(state): State<AppState>,
Path(id): Path<String>,
Extension(ctx): Extension<AuthContext>,
) -> SaasResult<Json<serde_json::Value>> {
if !ctx.permissions.contains(&"provider:manage".to_string()) {
return Err(SaasError::Forbidden("需要 provider:manage 权限".into()));
}
service::delete_provider(&state.db, &id).await?;
log_operation(&state.db, &ctx.account_id, "provider.delete", "provider", &id, None, None).await?;
Ok(Json(serde_json::json!({"ok": true})))
}
// ============ Models ============
/// GET /api/v1/models?provider_id=xxx
pub async fn list_models(
State(state): State<AppState>,
Query(params): Query<std::collections::HashMap<String, String>>,
_ctx: Extension<AuthContext>,
) -> SaasResult<Json<Vec<ModelInfo>>> {
let provider_id = params.get("provider_id").map(|s| s.as_str());
service::list_models(&state.db, provider_id).await.map(Json)
}
/// GET /api/v1/models/:id
pub async fn get_model(
State(state): State<AppState>,
Path(id): Path<String>,
_ctx: Extension<AuthContext>,
) -> SaasResult<Json<ModelInfo>> {
service::get_model(&state.db, &id).await.map(Json)
}
/// POST /api/v1/models (admin only)
pub async fn create_model(
State(state): State<AppState>,
Extension(ctx): Extension<AuthContext>,
Json(req): Json<CreateModelRequest>,
) -> SaasResult<(StatusCode, Json<ModelInfo>)> {
if !ctx.permissions.contains(&"model:manage".to_string()) {
return Err(SaasError::Forbidden("需要 model:manage 权限".into()));
}
let model = service::create_model(&state.db, &req).await?;
log_operation(&state.db, &ctx.account_id, "model.create", "model", &model.id,
Some(serde_json::json!({"model_id": &req.model_id, "provider_id": &req.provider_id})), None).await?;
Ok((StatusCode::CREATED, Json(model)))
}
/// PUT /api/v1/models/:id (admin only)
pub async fn update_model(
State(state): State<AppState>,
Path(id): Path<String>,
Extension(ctx): Extension<AuthContext>,
Json(req): Json<UpdateModelRequest>,
) -> SaasResult<Json<ModelInfo>> {
if !ctx.permissions.contains(&"model:manage".to_string()) {
return Err(SaasError::Forbidden("需要 model:manage 权限".into()));
}
let model = service::update_model(&state.db, &id, &req).await?;
log_operation(&state.db, &ctx.account_id, "model.update", "model", &id, None, None).await?;
Ok(Json(model))
}
/// DELETE /api/v1/models/:id (admin only)
pub async fn delete_model(
State(state): State<AppState>,
Path(id): Path<String>,
Extension(ctx): Extension<AuthContext>,
) -> SaasResult<Json<serde_json::Value>> {
if !ctx.permissions.contains(&"model:manage".to_string()) {
return Err(SaasError::Forbidden("需要 model:manage 权限".into()));
}
service::delete_model(&state.db, &id).await?;
log_operation(&state.db, &ctx.account_id, "model.delete", "model", &id, None, None).await?;
Ok(Json(serde_json::json!({"ok": true})))
}
// ============ Account API Keys ============
/// GET /api/v1/keys?provider_id=xxx
pub async fn list_api_keys(
State(state): State<AppState>,
Extension(ctx): Extension<AuthContext>,
Query(params): Query<std::collections::HashMap<String, String>>,
) -> SaasResult<Json<Vec<AccountApiKeyInfo>>> {
let provider_id = params.get("provider_id").map(|s| s.as_str());
service::list_account_api_keys(&state.db, &ctx.account_id, provider_id).await.map(Json)
}
/// POST /api/v1/keys
pub async fn create_api_key(
State(state): State<AppState>,
Extension(ctx): Extension<AuthContext>,
Json(req): Json<CreateAccountApiKeyRequest>,
) -> SaasResult<(StatusCode, Json<AccountApiKeyInfo>)> {
let key = service::create_account_api_key(&state.db, &ctx.account_id, &req).await?;
log_operation(&state.db, &ctx.account_id, "api_key.create", "api_key", &key.id,
Some(serde_json::json!({"provider_id": &req.provider_id})), None).await?;
Ok((StatusCode::CREATED, Json(key)))
}
/// POST /api/v1/keys/:id/rotate
pub async fn rotate_api_key(
State(state): State<AppState>,
Path(id): Path<String>,
Extension(ctx): Extension<AuthContext>,
Json(req): Json<RotateApiKeyRequest>,
) -> SaasResult<Json<serde_json::Value>> {
service::rotate_account_api_key(&state.db, &id, &ctx.account_id, &req.new_key_value).await?;
log_operation(&state.db, &ctx.account_id, "api_key.rotate", "api_key", &id, None, None).await?;
Ok(Json(serde_json::json!({"ok": true})))
}
/// DELETE /api/v1/keys/:id
pub async fn revoke_api_key(
State(state): State<AppState>,
Path(id): Path<String>,
Extension(ctx): Extension<AuthContext>,
) -> SaasResult<Json<serde_json::Value>> {
service::revoke_account_api_key(&state.db, &id, &ctx.account_id).await?;
log_operation(&state.db, &ctx.account_id, "api_key.revoke", "api_key", &id, None, None).await?;
Ok(Json(serde_json::json!({"ok": true})))
}
// ============ Usage ============
/// GET /api/v1/usage?from=...&to=...&provider_id=...&model_id=...
pub async fn get_usage(
State(state): State<AppState>,
Extension(ctx): Extension<AuthContext>,
Query(params): Query<UsageQuery>,
) -> SaasResult<Json<UsageStats>> {
service::get_usage_stats(&state.db, &ctx.account_id, &params).await.map(Json)
}
/// GET /api/v1/providers/:id/models (便捷路由)
pub async fn list_provider_models(
State(state): State<AppState>,
Path(provider_id): Path<String>,
_ctx: Extension<AuthContext>,
) -> SaasResult<Json<Vec<ModelInfo>>> {
service::list_models(&state.db, Some(&provider_id)).await.map(Json)
}

View File

@@ -1 +1,26 @@
//! 模型配置模块
pub mod types;
pub mod service;
pub mod handlers;
use axum::routing::{delete, get, post};
use crate::state::AppState;
/// 模型配置路由 (需要认证)
pub fn routes() -> axum::Router<AppState> {
axum::Router::new()
// Providers
.route("/api/v1/providers", get(handlers::list_providers).post(handlers::create_provider))
.route("/api/v1/providers/{id}", get(handlers::get_provider).put(handlers::update_provider).delete(handlers::delete_provider))
.route("/api/v1/providers/{id}/models", get(handlers::list_provider_models))
// Models
.route("/api/v1/models", get(handlers::list_models).post(handlers::create_model))
.route("/api/v1/models/{id}", get(handlers::get_model).put(handlers::update_model).delete(handlers::delete_model))
// Account API Keys
.route("/api/v1/keys", get(handlers::list_api_keys).post(handlers::create_api_key))
.route("/api/v1/keys/{id}", delete(handlers::revoke_api_key))
.route("/api/v1/keys/{id}/rotate", post(handlers::rotate_api_key))
// Usage
.route("/api/v1/usage", get(handlers::get_usage))
}

View File

@@ -0,0 +1,411 @@
//! 模型配置业务逻辑
use sqlx::SqlitePool;
use crate::error::{SaasError, SaasResult};
use super::types::*;
// ============ Providers ============
pub async fn list_providers(db: &SqlitePool) -> SaasResult<Vec<ProviderInfo>> {
let rows: Vec<(String, String, String, String, String, bool, Option<i64>, Option<i64>, String, String)> =
sqlx::query_as(
"SELECT id, name, display_name, base_url, api_protocol, enabled, rate_limit_rpm, rate_limit_tpm, created_at, updated_at
FROM providers ORDER BY name"
)
.fetch_all(db)
.await?;
Ok(rows.into_iter().map(|(id, name, display_name, base_url, api_protocol, enabled, rpm, tpm, created_at, updated_at)| {
ProviderInfo { id, name, display_name, base_url, api_protocol, enabled, rate_limit_rpm: rpm, rate_limit_tpm: tpm, created_at, updated_at }
}).collect())
}
pub async fn get_provider(db: &SqlitePool, provider_id: &str) -> SaasResult<ProviderInfo> {
let row: Option<(String, String, String, String, String, bool, Option<i64>, Option<i64>, String, String)> =
sqlx::query_as(
"SELECT id, name, display_name, base_url, api_protocol, enabled, rate_limit_rpm, rate_limit_tpm, created_at, updated_at
FROM providers WHERE id = ?1"
)
.bind(provider_id)
.fetch_optional(db)
.await?;
let (id, name, display_name, base_url, api_protocol, enabled, rpm, tpm, created_at, updated_at) =
row.ok_or_else(|| SaasError::NotFound(format!("Provider {} 不存在", provider_id)))?;
Ok(ProviderInfo { id, name, display_name, base_url, api_protocol, enabled, rate_limit_rpm: rpm, rate_limit_tpm: tpm, created_at, updated_at })
}
pub async fn create_provider(db: &SqlitePool, req: &CreateProviderRequest) -> SaasResult<ProviderInfo> {
let id = uuid::Uuid::new_v4().to_string();
let now = chrono::Utc::now().to_rfc3339();
// 检查名称唯一性
let existing: Option<(String,)> = sqlx::query_as("SELECT id FROM providers WHERE name = ?1")
.bind(&req.name).fetch_optional(db).await?;
if existing.is_some() {
return Err(SaasError::AlreadyExists(format!("Provider '{}' 已存在", req.name)));
}
sqlx::query(
"INSERT INTO providers (id, name, display_name, api_key, base_url, api_protocol, enabled, rate_limit_rpm, rate_limit_tpm, created_at, updated_at)
VALUES (?1, ?2, ?3, ?4, ?5, ?6, 1, ?7, ?8, ?9, ?9)"
)
.bind(&id).bind(&req.name).bind(&req.display_name).bind(&req.api_key)
.bind(&req.base_url).bind(&req.api_protocol).bind(&req.rate_limit_rpm).bind(&req.rate_limit_tpm).bind(&now)
.execute(db).await?;
get_provider(db, &id).await
}
pub async fn update_provider(
db: &SqlitePool, provider_id: &str, req: &UpdateProviderRequest,
) -> SaasResult<ProviderInfo> {
let now = chrono::Utc::now().to_rfc3339();
let mut updates = Vec::new();
let mut params: Vec<Box<dyn std::fmt::Display + Send + Sync>> = Vec::new();
if let Some(ref v) = req.display_name { updates.push("display_name = ?"); params.push(Box::new(v.clone())); }
if let Some(ref v) = req.base_url { updates.push("base_url = ?"); params.push(Box::new(v.clone())); }
if let Some(ref v) = req.api_protocol { updates.push("api_protocol = ?"); params.push(Box::new(v.clone())); }
if let Some(ref v) = req.api_key { updates.push("api_key = ?"); params.push(Box::new(v.clone())); }
if let Some(v) = req.enabled { updates.push("enabled = ?"); params.push(Box::new(v)); }
if let Some(v) = req.rate_limit_rpm { updates.push("rate_limit_rpm = ?"); params.push(Box::new(v)); }
if let Some(v) = req.rate_limit_tpm { updates.push("rate_limit_tpm = ?"); params.push(Box::new(v)); }
if updates.is_empty() {
return get_provider(db, provider_id).await;
}
updates.push("updated_at = ?");
params.push(Box::new(now.clone()));
params.push(Box::new(provider_id.to_string()));
let sql = format!("UPDATE providers SET {} WHERE id = ?", updates.join(", "));
let mut query = sqlx::query(&sql);
for p in &params {
query = query.bind(format!("{}", p));
}
query.execute(db).await?;
get_provider(db, provider_id).await
}
pub async fn delete_provider(db: &SqlitePool, provider_id: &str) -> SaasResult<()> {
let result = sqlx::query("DELETE FROM providers WHERE id = ?1")
.bind(provider_id).execute(db).await?;
if result.rows_affected() == 0 {
return Err(SaasError::NotFound(format!("Provider {} 不存在", provider_id)));
}
Ok(())
}
// ============ Models ============
pub async fn list_models(db: &SqlitePool, provider_id: Option<&str>) -> SaasResult<Vec<ModelInfo>> {
let sql = if provider_id.is_some() {
"SELECT id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, pricing_input, pricing_output, created_at, updated_at
FROM models WHERE provider_id = ?1 ORDER BY alias"
} else {
"SELECT id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, pricing_input, pricing_output, created_at, updated_at
FROM models ORDER BY provider_id, alias"
};
let mut query = sqlx::query_as::<_, (String, String, String, String, i64, i64, bool, bool, bool, f64, f64, String, String)>(sql);
if let Some(pid) = provider_id {
query = query.bind(pid);
}
let rows = query.fetch_all(db).await?;
Ok(rows.into_iter().map(|(id, provider_id, model_id, alias, ctx, max_out, streaming, vision, enabled, pi, po, created_at, updated_at)| {
ModelInfo { id, provider_id, model_id, alias, context_window: ctx, max_output_tokens: max_out, supports_streaming: streaming, supports_vision: vision, enabled, pricing_input: pi, pricing_output: po, created_at, updated_at }
}).collect())
}
pub async fn create_model(db: &SqlitePool, req: &CreateModelRequest) -> SaasResult<ModelInfo> {
// 验证 provider 存在
let provider = get_provider(db, &req.provider_id).await?;
let id = uuid::Uuid::new_v4().to_string();
let now = chrono::Utc::now().to_rfc3339();
// 检查 model 唯一性
let existing: Option<(String,)> = sqlx::query_as(
"SELECT id FROM models WHERE provider_id = ?1 AND model_id = ?2"
)
.bind(&req.provider_id).bind(&req.model_id)
.fetch_optional(db).await?;
if existing.is_some() {
return Err(SaasError::AlreadyExists(format!(
"模型 '{}' 已存在于 provider '{}'", req.model_id, provider.name
)));
}
let ctx = req.context_window.unwrap_or(8192);
let max_out = req.max_output_tokens.unwrap_or(4096);
let streaming = req.supports_streaming.unwrap_or(true);
let vision = req.supports_vision.unwrap_or(false);
let pi = req.pricing_input.unwrap_or(0.0);
let po = req.pricing_output.unwrap_or(0.0);
sqlx::query(
"INSERT INTO models (id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, pricing_input, pricing_output, created_at, updated_at)
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, 1, ?9, ?10, ?11, ?11)"
)
.bind(&id).bind(&req.provider_id).bind(&req.model_id).bind(&req.alias)
.bind(ctx).bind(max_out).bind(streaming).bind(vision).bind(pi).bind(po).bind(&now)
.execute(db).await?;
get_model(db, &id).await
}
pub async fn get_model(db: &SqlitePool, model_id: &str) -> SaasResult<ModelInfo> {
let row: Option<(String, String, String, String, i64, i64, bool, bool, bool, f64, f64, String, String)> =
sqlx::query_as(
"SELECT id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, pricing_input, pricing_output, created_at, updated_at
FROM models WHERE id = ?1"
)
.bind(model_id)
.fetch_optional(db)
.await?;
let (id, provider_id, model_id, alias, ctx, max_out, streaming, vision, enabled, pi, po, created_at, updated_at) =
row.ok_or_else(|| SaasError::NotFound(format!("模型 {} 不存在", model_id)))?;
Ok(ModelInfo { id, provider_id, model_id, alias, context_window: ctx, max_output_tokens: max_out, supports_streaming: streaming, supports_vision: vision, enabled, pricing_input: pi, pricing_output: po, created_at, updated_at })
}
pub async fn update_model(
db: &SqlitePool, model_id: &str, req: &UpdateModelRequest,
) -> SaasResult<ModelInfo> {
let now = chrono::Utc::now().to_rfc3339();
let mut updates = Vec::new();
let mut params: Vec<Box<dyn std::fmt::Display + Send + Sync>> = Vec::new();
if let Some(ref v) = req.alias { updates.push("alias = ?"); params.push(Box::new(v.clone())); }
if let Some(v) = req.context_window { updates.push("context_window = ?"); params.push(Box::new(v)); }
if let Some(v) = req.max_output_tokens { updates.push("max_output_tokens = ?"); params.push(Box::new(v)); }
if let Some(v) = req.supports_streaming { updates.push("supports_streaming = ?"); params.push(Box::new(v)); }
if let Some(v) = req.supports_vision { updates.push("supports_vision = ?"); params.push(Box::new(v)); }
if let Some(v) = req.enabled { updates.push("enabled = ?"); params.push(Box::new(v)); }
if let Some(v) = req.pricing_input { updates.push("pricing_input = ?"); params.push(Box::new(v)); }
if let Some(v) = req.pricing_output { updates.push("pricing_output = ?"); params.push(Box::new(v)); }
if updates.is_empty() {
return get_model(db, model_id).await;
}
updates.push("updated_at = ?");
params.push(Box::new(now.clone()));
params.push(Box::new(model_id.to_string()));
let sql = format!("UPDATE models SET {} WHERE id = ?", updates.join(", "));
let mut query = sqlx::query(&sql);
for p in &params {
query = query.bind(format!("{}", p));
}
query.execute(db).await?;
get_model(db, model_id).await
}
pub async fn delete_model(db: &SqlitePool, model_id: &str) -> SaasResult<()> {
let result = sqlx::query("DELETE FROM models WHERE id = ?1")
.bind(model_id).execute(db).await?;
if result.rows_affected() == 0 {
return Err(SaasError::NotFound(format!("模型 {} 不存在", model_id)));
}
Ok(())
}
// ============ Account API Keys ============
pub async fn list_account_api_keys(
db: &SqlitePool, account_id: &str, provider_id: Option<&str>,
) -> SaasResult<Vec<AccountApiKeyInfo>> {
let sql = if provider_id.is_some() {
"SELECT id, provider_id, key_label, permissions, enabled, last_used_at, created_at, key_value
FROM account_api_keys WHERE account_id = ?1 AND provider_id = ?2 AND revoked_at IS NULL ORDER BY created_at DESC"
} else {
"SELECT id, provider_id, key_label, permissions, enabled, last_used_at, created_at, key_value
FROM account_api_keys WHERE account_id = ?1 AND revoked_at IS NULL ORDER BY created_at DESC"
};
let mut query = sqlx::query_as::<_, (String, String, Option<String>, String, bool, Option<String>, String, String)>(sql)
.bind(account_id);
if let Some(pid) = provider_id {
query = query.bind(pid);
}
let rows = query.fetch_all(db).await?;
Ok(rows.into_iter().map(|(id, provider_id, key_label, perms, enabled, last_used, created_at, key_value)| {
let permissions: Vec<String> = serde_json::from_str(&perms).unwrap_or_default();
let masked = mask_api_key(&key_value);
AccountApiKeyInfo { id, provider_id, key_label, permissions, enabled, last_used_at: last_used, created_at, masked_key: masked }
}).collect())
}
pub async fn create_account_api_key(
db: &SqlitePool, account_id: &str, req: &CreateAccountApiKeyRequest,
) -> SaasResult<AccountApiKeyInfo> {
// 验证 provider 存在
get_provider(db, &req.provider_id).await?;
let id = uuid::Uuid::new_v4().to_string();
let now = chrono::Utc::now().to_rfc3339();
let permissions = serde_json::to_string(&req.permissions)?;
sqlx::query(
"INSERT INTO account_api_keys (id, account_id, provider_id, key_value, key_label, permissions, enabled, created_at, updated_at)
VALUES (?1, ?2, ?3, ?4, ?5, ?6, 1, ?7, ?7)"
)
.bind(&id).bind(account_id).bind(&req.provider_id).bind(&req.key_value)
.bind(&req.key_label).bind(&permissions).bind(&now)
.execute(db).await?;
let masked = mask_api_key(&req.key_value);
Ok(AccountApiKeyInfo {
id, provider_id: req.provider_id.clone(), key_label: req.key_label.clone(),
permissions: req.permissions.clone(), enabled: true, last_used_at: None,
created_at: now, masked_key: masked,
})
}
pub async fn rotate_account_api_key(
db: &SqlitePool, key_id: &str, account_id: &str, new_key_value: &str,
) -> SaasResult<()> {
let now = chrono::Utc::now().to_rfc3339();
let result = sqlx::query(
"UPDATE account_api_keys SET key_value = ?1, updated_at = ?2 WHERE id = ?3 AND account_id = ?4 AND revoked_at IS NULL"
)
.bind(new_key_value).bind(&now).bind(key_id).bind(account_id)
.execute(db).await?;
if result.rows_affected() == 0 {
return Err(SaasError::NotFound("API Key 不存在或已撤销".into()));
}
Ok(())
}
pub async fn revoke_account_api_key(
db: &SqlitePool, key_id: &str, account_id: &str,
) -> SaasResult<()> {
let now = chrono::Utc::now().to_rfc3339();
let result = sqlx::query(
"UPDATE account_api_keys SET revoked_at = ?1 WHERE id = ?2 AND account_id = ?3 AND revoked_at IS NULL"
)
.bind(&now).bind(key_id).bind(account_id)
.execute(db).await?;
if result.rows_affected() == 0 {
return Err(SaasError::NotFound("API Key 不存在或已撤销".into()));
}
Ok(())
}
// ============ Usage Statistics ============
pub async fn get_usage_stats(
db: &SqlitePool, account_id: &str, query: &UsageQuery,
) -> SaasResult<UsageStats> {
let mut where_clauses = vec!["account_id = ?".to_string()];
let mut params: Vec<String> = vec![account_id.to_string()];
if let Some(ref from) = query.from {
where_clauses.push("created_at >= ?".to_string());
params.push(from.clone());
}
if let Some(ref to) = query.to {
where_clauses.push("created_at <= ?".to_string());
params.push(to.clone());
}
if let Some(ref pid) = query.provider_id {
where_clauses.push("provider_id = ?".to_string());
params.push(pid.clone());
}
if let Some(ref mid) = query.model_id {
where_clauses.push("model_id = ?".to_string());
params.push(mid.clone());
}
let where_sql = where_clauses.join(" AND ");
// 总量统计
let total_sql = format!(
"SELECT COUNT(*), COALESCE(SUM(input_tokens), 0), COALESCE(SUM(output_tokens), 0)
FROM usage_records WHERE {}", where_sql
);
let mut total_query = sqlx::query_as::<_, (i64, i64, i64)>(&total_sql);
for p in &params {
total_query = total_query.bind(p);
}
let (total_requests, total_input, total_output) = total_query.fetch_one(db).await?;
// 按模型统计
let by_model_sql = format!(
"SELECT provider_id, model_id, COUNT(*), COALESCE(SUM(input_tokens), 0), COALESCE(SUM(output_tokens), 0)
FROM usage_records WHERE {} GROUP BY provider_id, model_id ORDER BY COUNT(*) DESC LIMIT 20",
where_sql
);
let mut by_model_query = sqlx::query_as::<_, (String, String, i64, i64, i64)>(&by_model_sql);
for p in &params {
by_model_query = by_model_query.bind(p);
}
let by_model_rows = by_model_query.fetch_all(db).await?;
let by_model: Vec<ModelUsage> = by_model_rows.into_iter()
.map(|(provider_id, model_id, count, input, output)| {
ModelUsage { provider_id, model_id, request_count: count, input_tokens: input, output_tokens: output }
}).collect();
// 按天统计 (最近 30 天)
let from_30d = (chrono::Utc::now() - chrono::Duration::days(30)).to_rfc3339();
let daily_sql = format!(
"SELECT DATE(created_at) as day, COUNT(*), COALESCE(SUM(input_tokens), 0), COALESCE(SUM(output_tokens), 0)
FROM usage_records WHERE account_id = ?1 AND created_at >= ?2
GROUP BY DATE(created_at) ORDER BY day DESC LIMIT 30"
);
let daily_rows: Vec<(String, i64, i64, i64)> = sqlx::query_as(&daily_sql)
.bind(account_id).bind(&from_30d)
.fetch_all(db).await?;
let by_day: Vec<DailyUsage> = daily_rows.into_iter()
.map(|(date, count, input, output)| {
DailyUsage { date, request_count: count, input_tokens: input, output_tokens: output }
}).collect();
Ok(UsageStats {
total_requests,
total_input_tokens: total_input,
total_output_tokens: total_output,
by_model,
by_day,
})
}
pub async fn record_usage(
db: &SqlitePool, account_id: &str, provider_id: &str, model_id: &str,
input_tokens: i64, output_tokens: i64, latency_ms: Option<i64>,
status: &str, error_message: Option<&str>,
) -> SaasResult<()> {
let now = chrono::Utc::now().to_rfc3339();
sqlx::query(
"INSERT INTO usage_records (account_id, provider_id, model_id, input_tokens, output_tokens, latency_ms, status, error_message, created_at)
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)"
)
.bind(account_id).bind(provider_id).bind(model_id)
.bind(input_tokens).bind(output_tokens).bind(latency_ms)
.bind(status).bind(error_message).bind(&now)
.execute(db).await?;
Ok(())
}
// ============ Helpers ============
fn mask_api_key(key: &str) -> String {
if key.len() <= 8 {
return "*".repeat(key.len());
}
format!("{}...{}", &key[..4], &key[key.len()-4..])
}

View File

@@ -0,0 +1,172 @@
//! 模型配置类型定义
use serde::{Deserialize, Serialize};
// --- Provider ---
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProviderInfo {
pub id: String,
pub name: String,
pub display_name: String,
pub base_url: String,
pub api_protocol: String,
pub enabled: bool,
pub rate_limit_rpm: Option<i64>,
pub rate_limit_tpm: Option<i64>,
pub created_at: String,
pub updated_at: String,
}
#[derive(Debug, Deserialize)]
pub struct CreateProviderRequest {
pub name: String,
pub display_name: String,
pub base_url: String,
#[serde(default = "default_protocol")]
pub api_protocol: String,
pub api_key: Option<String>,
pub rate_limit_rpm: Option<i64>,
pub rate_limit_tpm: Option<i64>,
}
fn default_protocol() -> String { "openai".into() }
#[derive(Debug, Deserialize)]
pub struct UpdateProviderRequest {
pub display_name: Option<String>,
pub base_url: Option<String>,
pub api_protocol: Option<String>,
pub api_key: Option<String>,
pub enabled: Option<bool>,
pub rate_limit_rpm: Option<i64>,
pub rate_limit_tpm: Option<i64>,
}
// --- Model ---
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelInfo {
pub id: String,
pub provider_id: String,
pub model_id: String,
pub alias: String,
pub context_window: i64,
pub max_output_tokens: i64,
pub supports_streaming: bool,
pub supports_vision: bool,
pub enabled: bool,
pub pricing_input: f64,
pub pricing_output: f64,
pub created_at: String,
pub updated_at: String,
}
#[derive(Debug, Deserialize)]
pub struct CreateModelRequest {
pub provider_id: String,
pub model_id: String,
pub alias: String,
pub context_window: Option<i64>,
pub max_output_tokens: Option<i64>,
pub supports_streaming: Option<bool>,
pub supports_vision: Option<bool>,
pub pricing_input: Option<f64>,
pub pricing_output: Option<f64>,
}
#[derive(Debug, Deserialize)]
pub struct UpdateModelRequest {
pub alias: Option<String>,
pub context_window: Option<i64>,
pub max_output_tokens: Option<i64>,
pub supports_streaming: Option<bool>,
pub supports_vision: Option<bool>,
pub enabled: Option<bool>,
pub pricing_input: Option<f64>,
pub pricing_output: Option<f64>,
}
// --- Account API Key ---
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AccountApiKeyInfo {
pub id: String,
pub provider_id: String,
pub key_label: Option<String>,
pub permissions: Vec<String>,
pub enabled: bool,
pub last_used_at: Option<String>,
pub created_at: String,
pub masked_key: String,
}
#[derive(Debug, Deserialize)]
pub struct CreateAccountApiKeyRequest {
pub provider_id: String,
pub key_value: String,
pub key_label: Option<String>,
#[serde(default)]
pub permissions: Vec<String>,
}
#[derive(Debug, Deserialize)]
pub struct RotateApiKeyRequest {
pub new_key_value: String,
}
// --- Usage ---
#[derive(Debug, Serialize)]
pub struct UsageStats {
pub total_requests: i64,
pub total_input_tokens: i64,
pub total_output_tokens: i64,
pub by_model: Vec<ModelUsage>,
pub by_day: Vec<DailyUsage>,
}
#[derive(Debug, Serialize)]
pub struct ModelUsage {
pub provider_id: String,
pub model_id: String,
pub request_count: i64,
pub input_tokens: i64,
pub output_tokens: i64,
}
#[derive(Debug, Serialize)]
pub struct DailyUsage {
pub date: String,
pub request_count: i64,
pub input_tokens: i64,
pub output_tokens: i64,
}
#[derive(Debug, Deserialize)]
pub struct UsageQuery {
pub from: Option<String>,
pub to: Option<String>,
pub provider_id: Option<String>,
pub model_id: Option<String>,
}
// --- Seed Data ---
#[derive(Debug, Deserialize)]
pub struct SeedProvider {
pub name: String,
pub display_name: String,
pub base_url: String,
pub models: Vec<SeedModel>,
}
#[derive(Debug, Deserialize)]
pub struct SeedModel {
pub id: String,
pub alias: String,
pub context_window: Option<i64>,
pub max_output_tokens: Option<i64>,
pub supports_streaming: Option<bool>,
pub supports_vision: Option<bool>,
}

View File

@@ -1,4 +1,4 @@
//! Phase 1 集成测试
//! 集成测试 (Phase 1 + Phase 2)
use axum::{
body::Body,
@@ -21,6 +21,7 @@ async fn build_test_app() -> axum::Router {
let protected_routes = zclaw_saas::auth::protected_routes()
.merge(zclaw_saas::account::routes())
.merge(zclaw_saas::model_config::routes())
.layer(axum::middleware::from_fn_with_state(
state.clone(),
zclaw_saas::auth::auth_middleware,
@@ -32,43 +33,45 @@ async fn build_test_app() -> axum::Router {
.with_state(state)
}
#[tokio::test]
async fn test_register_and_login() {
let app = build_test_app().await;
// 注册
let req = Request::builder()
/// 注册并登录,返回 JWT token
async fn register_and_login(app: &axum::Router, username: &str, email: &str) -> String {
let register_req = Request::builder()
.method("POST")
.uri("/api/v1/auth/register")
.header("Content-Type", "application/json")
.body(Body::from(serde_json::to_string(&json!({
"username": "testuser",
"email": "test@example.com",
"username": username,
"email": email,
"password": "password123"
})).unwrap()))
.unwrap();
app.clone().oneshot(register_req).await.unwrap();
let resp = app.clone().oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::CREATED);
// 登录
let req = Request::builder()
let login_req = Request::builder()
.method("POST")
.uri("/api/v1/auth/login")
.header("Content-Type", "application/json")
.body(Body::from(serde_json::to_string(&json!({
"username": "testuser",
"username": username,
"password": "password123"
})).unwrap()))
.unwrap();
let resp = app.clone().oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
let resp = app.clone().oneshot(login_req).await.unwrap();
let body_bytes = axum::body::to_bytes(resp.into_body(), MAX_BODY_SIZE).await.unwrap();
let body: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap();
assert!(body.get("token").is_some());
assert_eq!(body["account"]["username"], "testuser");
body["token"].as_str().unwrap().to_string()
}
fn auth_header(token: &str) -> String {
format!("Bearer {}", token)
}
#[tokio::test]
async fn test_register_and_login() {
let app = build_test_app().await;
let token = register_and_login(&app, "testuser", "test@example.com").await;
assert!(!token.is_empty());
}
#[tokio::test]
@@ -119,21 +122,8 @@ async fn test_unauthorized_access() {
#[tokio::test]
async fn test_login_wrong_password() {
let app = build_test_app().await;
register_and_login(&app, "wrongpwd", "wrongpwd@example.com").await;
// 先注册
let req = Request::builder()
.method("POST")
.uri("/api/v1/auth/register")
.header("Content-Type", "application/json")
.body(Body::from(serde_json::to_string(&json!({
"username": "wrongpwd",
"email": "wrongpwd@example.com",
"password": "password123"
})).unwrap()))
.unwrap();
app.clone().oneshot(req).await.unwrap();
// 错误密码登录
let req = Request::builder()
.method("POST")
.uri("/api/v1/auth/login")
@@ -151,41 +141,14 @@ async fn test_login_wrong_password() {
#[tokio::test]
async fn test_full_authenticated_flow() {
let app = build_test_app().await;
// 注册 + 登录
let register_req = Request::builder()
.method("POST")
.uri("/api/v1/auth/register")
.header("Content-Type", "application/json")
.body(Body::from(serde_json::to_string(&json!({
"username": "fulltest",
"email": "full@example.com",
"password": "password123"
})).unwrap()))
.unwrap();
app.clone().oneshot(register_req).await.unwrap();
let login_req = Request::builder()
.method("POST")
.uri("/api/v1/auth/login")
.header("Content-Type", "application/json")
.body(Body::from(serde_json::to_string(&json!({
"username": "fulltest",
"password": "password123"
})).unwrap()))
.unwrap();
let resp = app.clone().oneshot(login_req).await.unwrap();
let body_bytes = axum::body::to_bytes(resp.into_body(), MAX_BODY_SIZE).await.unwrap();
let body: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap();
let token = body["token"].as_str().unwrap().to_string();
let token = register_and_login(&app, "fulltest", "full@example.com").await;
// 创建 API Token
let create_token_req = Request::builder()
.method("POST")
.uri("/api/v1/tokens")
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", token))
.header("Authorization", auth_header(&token))
.body(Body::from(serde_json::to_string(&json!({
"name": "test-token",
"permissions": ["model:read", "relay:use"]
@@ -196,13 +159,13 @@ async fn test_full_authenticated_flow() {
assert_eq!(resp.status(), StatusCode::OK);
let body_bytes = axum::body::to_bytes(resp.into_body(), MAX_BODY_SIZE).await.unwrap();
let body: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap();
assert!(!body["token"].is_null()); // 原始 token 仅创建时返回
assert!(!body["token"].is_null());
// 列出 Tokens
let list_req = Request::builder()
.method("GET")
.uri("/api/v1/tokens")
.header("Authorization", format!("Bearer {}", token))
.header("Authorization", auth_header(&token))
.body(Body::empty())
.unwrap();
@@ -213,10 +176,115 @@ async fn test_full_authenticated_flow() {
let logs_req = Request::builder()
.method("GET")
.uri("/api/v1/logs/operations")
.header("Authorization", format!("Bearer {}", token))
.header("Authorization", auth_header(&token))
.body(Body::empty())
.unwrap();
let resp = app.oneshot(logs_req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
}
// ============ Phase 2: 模型配置测试 ============
#[tokio::test]
async fn test_providers_crud() {
let app = build_test_app().await;
// 注册 super_admin 角色用户 (通过直接插入角色权限)
let token = register_and_login(&app, "adminprov", "adminprov@example.com").await;
// 创建 provider (普通用户无权限 → 403)
let create_req = Request::builder()
.method("POST")
.uri("/api/v1/providers")
.header("Content-Type", "application/json")
.header("Authorization", auth_header(&token))
.body(Body::from(serde_json::to_string(&json!({
"name": "test-provider",
"display_name": "Test Provider",
"base_url": "https://api.example.com/v1"
})).unwrap()))
.unwrap();
let resp = app.clone().oneshot(create_req).await.unwrap();
// user 角色默认无 provider:manage 权限 → 403
assert_eq!(resp.status(), StatusCode::FORBIDDEN);
// 列出 providers (只读权限 → 200)
let list_req = Request::builder()
.method("GET")
.uri("/api/v1/providers")
.header("Authorization", auth_header(&token))
.body(Body::empty())
.unwrap();
let resp = app.clone().oneshot(list_req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
}
#[tokio::test]
async fn test_models_list_and_usage() {
let app = build_test_app().await;
let token = register_and_login(&app, "modeluser", "modeluser@example.com").await;
// 列出模型 (空列表)
let list_req = Request::builder()
.method("GET")
.uri("/api/v1/models")
.header("Authorization", auth_header(&token))
.body(Body::empty())
.unwrap();
let resp = app.clone().oneshot(list_req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
let body_bytes = axum::body::to_bytes(resp.into_body(), MAX_BODY_SIZE).await.unwrap();
let body: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap();
assert!(body.is_array());
assert_eq!(body.as_array().unwrap().len(), 0);
// 查看用量统计
let usage_req = Request::builder()
.method("GET")
.uri("/api/v1/usage")
.header("Authorization", auth_header(&token))
.body(Body::empty())
.unwrap();
let resp = app.clone().oneshot(usage_req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
let body_bytes = axum::body::to_bytes(resp.into_body(), MAX_BODY_SIZE).await.unwrap();
let body: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap();
assert_eq!(body["total_requests"], 0);
}
#[tokio::test]
async fn test_api_keys_lifecycle() {
let app = build_test_app().await;
let token = register_and_login(&app, "keyuser", "keyuser@example.com").await;
// 列出 keys (空)
let list_req = Request::builder()
.method("GET")
.uri("/api/v1/keys")
.header("Authorization", auth_header(&token))
.body(Body::empty())
.unwrap();
let resp = app.clone().oneshot(list_req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
// 创建 key (需要已有 provider → 404 或由 service 层验证)
let create_req = Request::builder()
.method("POST")
.uri("/api/v1/keys")
.header("Content-Type", "application/json")
.header("Authorization", auth_header(&token))
.body(Body::from(serde_json::to_string(&json!({
"provider_id": "nonexistent",
"key_value": "sk-test-12345",
"key_label": "Test Key"
})).unwrap()))
.unwrap();
let resp = app.clone().oneshot(create_req).await.unwrap();
// provider 不存在 → 404
assert_eq!(resp.status(), StatusCode::NOT_FOUND);
}