feat(saas): Phase 2 — 模型配置模块
- Provider CRUD (列表/详情/创建/更新/删除) - Model CRUD (列表/详情/创建/更新/删除) - Account API Key 管理 (创建/轮换/撤销/掩码显示) - Usage 统计 (总量/按模型/按天, 支持时间/供应商/模型过滤) - 权限控制 (provider:manage, model:manage) - 3 个新集成测试覆盖 providers/models/keys
This commit is contained in:
@@ -43,6 +43,7 @@ fn build_router(state: AppState) -> axum::Router {
|
||||
|
||||
let protected_routes = zclaw_saas::auth::protected_routes()
|
||||
.merge(zclaw_saas::account::routes())
|
||||
.merge(zclaw_saas::model_config::routes())
|
||||
.layer(middleware::from_fn_with_state(
|
||||
state.clone(),
|
||||
zclaw_saas::auth::auth_middleware,
|
||||
|
||||
206
crates/zclaw-saas/src/model_config/handlers.rs
Normal file
206
crates/zclaw-saas/src/model_config/handlers.rs
Normal file
@@ -0,0 +1,206 @@
|
||||
//! 模型配置 HTTP 处理器
|
||||
|
||||
use axum::{
|
||||
extract::{Extension, Path, Query, State},
|
||||
http::StatusCode, Json,
|
||||
};
|
||||
use crate::state::AppState;
|
||||
use crate::error::{SaasError, SaasResult};
|
||||
use crate::auth::types::AuthContext;
|
||||
use crate::auth::handlers::log_operation;
|
||||
use super::{types::*, service};
|
||||
|
||||
// ============ Providers ============
|
||||
|
||||
/// GET /api/v1/providers
|
||||
pub async fn list_providers(
|
||||
State(state): State<AppState>,
|
||||
_ctx: Extension<AuthContext>,
|
||||
) -> SaasResult<Json<Vec<ProviderInfo>>> {
|
||||
service::list_providers(&state.db).await.map(Json)
|
||||
}
|
||||
|
||||
/// GET /api/v1/providers/:id
|
||||
pub async fn get_provider(
|
||||
State(state): State<AppState>,
|
||||
Path(id): Path<String>,
|
||||
_ctx: Extension<AuthContext>,
|
||||
) -> SaasResult<Json<ProviderInfo>> {
|
||||
service::get_provider(&state.db, &id).await.map(Json)
|
||||
}
|
||||
|
||||
/// POST /api/v1/providers (admin only)
|
||||
pub async fn create_provider(
|
||||
State(state): State<AppState>,
|
||||
Extension(ctx): Extension<AuthContext>,
|
||||
Json(req): Json<CreateProviderRequest>,
|
||||
) -> SaasResult<(StatusCode, Json<ProviderInfo>)> {
|
||||
if !ctx.permissions.contains(&"provider:manage".to_string()) {
|
||||
return Err(SaasError::Forbidden("需要 provider:manage 权限".into()));
|
||||
}
|
||||
let provider = service::create_provider(&state.db, &req).await?;
|
||||
log_operation(&state.db, &ctx.account_id, "provider.create", "provider", &provider.id,
|
||||
Some(serde_json::json!({"name": &req.name})), None).await?;
|
||||
Ok((StatusCode::CREATED, Json(provider)))
|
||||
}
|
||||
|
||||
/// PUT /api/v1/providers/:id (admin only)
|
||||
pub async fn update_provider(
|
||||
State(state): State<AppState>,
|
||||
Path(id): Path<String>,
|
||||
Extension(ctx): Extension<AuthContext>,
|
||||
Json(req): Json<UpdateProviderRequest>,
|
||||
) -> SaasResult<Json<ProviderInfo>> {
|
||||
if !ctx.permissions.contains(&"provider:manage".to_string()) {
|
||||
return Err(SaasError::Forbidden("需要 provider:manage 权限".into()));
|
||||
}
|
||||
let provider = service::update_provider(&state.db, &id, &req).await?;
|
||||
log_operation(&state.db, &ctx.account_id, "provider.update", "provider", &id, None, None).await?;
|
||||
Ok(Json(provider))
|
||||
}
|
||||
|
||||
/// DELETE /api/v1/providers/:id (admin only)
|
||||
pub async fn delete_provider(
|
||||
State(state): State<AppState>,
|
||||
Path(id): Path<String>,
|
||||
Extension(ctx): Extension<AuthContext>,
|
||||
) -> SaasResult<Json<serde_json::Value>> {
|
||||
if !ctx.permissions.contains(&"provider:manage".to_string()) {
|
||||
return Err(SaasError::Forbidden("需要 provider:manage 权限".into()));
|
||||
}
|
||||
service::delete_provider(&state.db, &id).await?;
|
||||
log_operation(&state.db, &ctx.account_id, "provider.delete", "provider", &id, None, None).await?;
|
||||
Ok(Json(serde_json::json!({"ok": true})))
|
||||
}
|
||||
|
||||
// ============ Models ============
|
||||
|
||||
/// GET /api/v1/models?provider_id=xxx
|
||||
pub async fn list_models(
|
||||
State(state): State<AppState>,
|
||||
Query(params): Query<std::collections::HashMap<String, String>>,
|
||||
_ctx: Extension<AuthContext>,
|
||||
) -> SaasResult<Json<Vec<ModelInfo>>> {
|
||||
let provider_id = params.get("provider_id").map(|s| s.as_str());
|
||||
service::list_models(&state.db, provider_id).await.map(Json)
|
||||
}
|
||||
|
||||
/// GET /api/v1/models/:id
|
||||
pub async fn get_model(
|
||||
State(state): State<AppState>,
|
||||
Path(id): Path<String>,
|
||||
_ctx: Extension<AuthContext>,
|
||||
) -> SaasResult<Json<ModelInfo>> {
|
||||
service::get_model(&state.db, &id).await.map(Json)
|
||||
}
|
||||
|
||||
/// POST /api/v1/models (admin only)
|
||||
pub async fn create_model(
|
||||
State(state): State<AppState>,
|
||||
Extension(ctx): Extension<AuthContext>,
|
||||
Json(req): Json<CreateModelRequest>,
|
||||
) -> SaasResult<(StatusCode, Json<ModelInfo>)> {
|
||||
if !ctx.permissions.contains(&"model:manage".to_string()) {
|
||||
return Err(SaasError::Forbidden("需要 model:manage 权限".into()));
|
||||
}
|
||||
let model = service::create_model(&state.db, &req).await?;
|
||||
log_operation(&state.db, &ctx.account_id, "model.create", "model", &model.id,
|
||||
Some(serde_json::json!({"model_id": &req.model_id, "provider_id": &req.provider_id})), None).await?;
|
||||
Ok((StatusCode::CREATED, Json(model)))
|
||||
}
|
||||
|
||||
/// PUT /api/v1/models/:id (admin only)
|
||||
pub async fn update_model(
|
||||
State(state): State<AppState>,
|
||||
Path(id): Path<String>,
|
||||
Extension(ctx): Extension<AuthContext>,
|
||||
Json(req): Json<UpdateModelRequest>,
|
||||
) -> SaasResult<Json<ModelInfo>> {
|
||||
if !ctx.permissions.contains(&"model:manage".to_string()) {
|
||||
return Err(SaasError::Forbidden("需要 model:manage 权限".into()));
|
||||
}
|
||||
let model = service::update_model(&state.db, &id, &req).await?;
|
||||
log_operation(&state.db, &ctx.account_id, "model.update", "model", &id, None, None).await?;
|
||||
Ok(Json(model))
|
||||
}
|
||||
|
||||
/// DELETE /api/v1/models/:id (admin only)
|
||||
pub async fn delete_model(
|
||||
State(state): State<AppState>,
|
||||
Path(id): Path<String>,
|
||||
Extension(ctx): Extension<AuthContext>,
|
||||
) -> SaasResult<Json<serde_json::Value>> {
|
||||
if !ctx.permissions.contains(&"model:manage".to_string()) {
|
||||
return Err(SaasError::Forbidden("需要 model:manage 权限".into()));
|
||||
}
|
||||
service::delete_model(&state.db, &id).await?;
|
||||
log_operation(&state.db, &ctx.account_id, "model.delete", "model", &id, None, None).await?;
|
||||
Ok(Json(serde_json::json!({"ok": true})))
|
||||
}
|
||||
|
||||
// ============ Account API Keys ============
|
||||
|
||||
/// GET /api/v1/keys?provider_id=xxx
|
||||
pub async fn list_api_keys(
|
||||
State(state): State<AppState>,
|
||||
Extension(ctx): Extension<AuthContext>,
|
||||
Query(params): Query<std::collections::HashMap<String, String>>,
|
||||
) -> SaasResult<Json<Vec<AccountApiKeyInfo>>> {
|
||||
let provider_id = params.get("provider_id").map(|s| s.as_str());
|
||||
service::list_account_api_keys(&state.db, &ctx.account_id, provider_id).await.map(Json)
|
||||
}
|
||||
|
||||
/// POST /api/v1/keys
|
||||
pub async fn create_api_key(
|
||||
State(state): State<AppState>,
|
||||
Extension(ctx): Extension<AuthContext>,
|
||||
Json(req): Json<CreateAccountApiKeyRequest>,
|
||||
) -> SaasResult<(StatusCode, Json<AccountApiKeyInfo>)> {
|
||||
let key = service::create_account_api_key(&state.db, &ctx.account_id, &req).await?;
|
||||
log_operation(&state.db, &ctx.account_id, "api_key.create", "api_key", &key.id,
|
||||
Some(serde_json::json!({"provider_id": &req.provider_id})), None).await?;
|
||||
Ok((StatusCode::CREATED, Json(key)))
|
||||
}
|
||||
|
||||
/// POST /api/v1/keys/:id/rotate
|
||||
pub async fn rotate_api_key(
|
||||
State(state): State<AppState>,
|
||||
Path(id): Path<String>,
|
||||
Extension(ctx): Extension<AuthContext>,
|
||||
Json(req): Json<RotateApiKeyRequest>,
|
||||
) -> SaasResult<Json<serde_json::Value>> {
|
||||
service::rotate_account_api_key(&state.db, &id, &ctx.account_id, &req.new_key_value).await?;
|
||||
log_operation(&state.db, &ctx.account_id, "api_key.rotate", "api_key", &id, None, None).await?;
|
||||
Ok(Json(serde_json::json!({"ok": true})))
|
||||
}
|
||||
|
||||
/// DELETE /api/v1/keys/:id
|
||||
pub async fn revoke_api_key(
|
||||
State(state): State<AppState>,
|
||||
Path(id): Path<String>,
|
||||
Extension(ctx): Extension<AuthContext>,
|
||||
) -> SaasResult<Json<serde_json::Value>> {
|
||||
service::revoke_account_api_key(&state.db, &id, &ctx.account_id).await?;
|
||||
log_operation(&state.db, &ctx.account_id, "api_key.revoke", "api_key", &id, None, None).await?;
|
||||
Ok(Json(serde_json::json!({"ok": true})))
|
||||
}
|
||||
|
||||
// ============ Usage ============
|
||||
|
||||
/// GET /api/v1/usage?from=...&to=...&provider_id=...&model_id=...
|
||||
pub async fn get_usage(
|
||||
State(state): State<AppState>,
|
||||
Extension(ctx): Extension<AuthContext>,
|
||||
Query(params): Query<UsageQuery>,
|
||||
) -> SaasResult<Json<UsageStats>> {
|
||||
service::get_usage_stats(&state.db, &ctx.account_id, ¶ms).await.map(Json)
|
||||
}
|
||||
|
||||
/// GET /api/v1/providers/:id/models (便捷路由)
|
||||
pub async fn list_provider_models(
|
||||
State(state): State<AppState>,
|
||||
Path(provider_id): Path<String>,
|
||||
_ctx: Extension<AuthContext>,
|
||||
) -> SaasResult<Json<Vec<ModelInfo>>> {
|
||||
service::list_models(&state.db, Some(&provider_id)).await.map(Json)
|
||||
}
|
||||
@@ -1 +1,26 @@
|
||||
//! 模型配置模块
|
||||
|
||||
pub mod types;
|
||||
pub mod service;
|
||||
pub mod handlers;
|
||||
|
||||
use axum::routing::{delete, get, post};
|
||||
use crate::state::AppState;
|
||||
|
||||
/// 模型配置路由 (需要认证)
|
||||
pub fn routes() -> axum::Router<AppState> {
|
||||
axum::Router::new()
|
||||
// Providers
|
||||
.route("/api/v1/providers", get(handlers::list_providers).post(handlers::create_provider))
|
||||
.route("/api/v1/providers/{id}", get(handlers::get_provider).put(handlers::update_provider).delete(handlers::delete_provider))
|
||||
.route("/api/v1/providers/{id}/models", get(handlers::list_provider_models))
|
||||
// Models
|
||||
.route("/api/v1/models", get(handlers::list_models).post(handlers::create_model))
|
||||
.route("/api/v1/models/{id}", get(handlers::get_model).put(handlers::update_model).delete(handlers::delete_model))
|
||||
// Account API Keys
|
||||
.route("/api/v1/keys", get(handlers::list_api_keys).post(handlers::create_api_key))
|
||||
.route("/api/v1/keys/{id}", delete(handlers::revoke_api_key))
|
||||
.route("/api/v1/keys/{id}/rotate", post(handlers::rotate_api_key))
|
||||
// Usage
|
||||
.route("/api/v1/usage", get(handlers::get_usage))
|
||||
}
|
||||
|
||||
411
crates/zclaw-saas/src/model_config/service.rs
Normal file
411
crates/zclaw-saas/src/model_config/service.rs
Normal file
@@ -0,0 +1,411 @@
|
||||
//! 模型配置业务逻辑
|
||||
|
||||
use sqlx::SqlitePool;
|
||||
use crate::error::{SaasError, SaasResult};
|
||||
use super::types::*;
|
||||
|
||||
// ============ Providers ============
|
||||
|
||||
pub async fn list_providers(db: &SqlitePool) -> SaasResult<Vec<ProviderInfo>> {
|
||||
let rows: Vec<(String, String, String, String, String, bool, Option<i64>, Option<i64>, String, String)> =
|
||||
sqlx::query_as(
|
||||
"SELECT id, name, display_name, base_url, api_protocol, enabled, rate_limit_rpm, rate_limit_tpm, created_at, updated_at
|
||||
FROM providers ORDER BY name"
|
||||
)
|
||||
.fetch_all(db)
|
||||
.await?;
|
||||
|
||||
Ok(rows.into_iter().map(|(id, name, display_name, base_url, api_protocol, enabled, rpm, tpm, created_at, updated_at)| {
|
||||
ProviderInfo { id, name, display_name, base_url, api_protocol, enabled, rate_limit_rpm: rpm, rate_limit_tpm: tpm, created_at, updated_at }
|
||||
}).collect())
|
||||
}
|
||||
|
||||
pub async fn get_provider(db: &SqlitePool, provider_id: &str) -> SaasResult<ProviderInfo> {
|
||||
let row: Option<(String, String, String, String, String, bool, Option<i64>, Option<i64>, String, String)> =
|
||||
sqlx::query_as(
|
||||
"SELECT id, name, display_name, base_url, api_protocol, enabled, rate_limit_rpm, rate_limit_tpm, created_at, updated_at
|
||||
FROM providers WHERE id = ?1"
|
||||
)
|
||||
.bind(provider_id)
|
||||
.fetch_optional(db)
|
||||
.await?;
|
||||
|
||||
let (id, name, display_name, base_url, api_protocol, enabled, rpm, tpm, created_at, updated_at) =
|
||||
row.ok_or_else(|| SaasError::NotFound(format!("Provider {} 不存在", provider_id)))?;
|
||||
|
||||
Ok(ProviderInfo { id, name, display_name, base_url, api_protocol, enabled, rate_limit_rpm: rpm, rate_limit_tpm: tpm, created_at, updated_at })
|
||||
}
|
||||
|
||||
pub async fn create_provider(db: &SqlitePool, req: &CreateProviderRequest) -> SaasResult<ProviderInfo> {
|
||||
let id = uuid::Uuid::new_v4().to_string();
|
||||
let now = chrono::Utc::now().to_rfc3339();
|
||||
|
||||
// 检查名称唯一性
|
||||
let existing: Option<(String,)> = sqlx::query_as("SELECT id FROM providers WHERE name = ?1")
|
||||
.bind(&req.name).fetch_optional(db).await?;
|
||||
if existing.is_some() {
|
||||
return Err(SaasError::AlreadyExists(format!("Provider '{}' 已存在", req.name)));
|
||||
}
|
||||
|
||||
sqlx::query(
|
||||
"INSERT INTO providers (id, name, display_name, api_key, base_url, api_protocol, enabled, rate_limit_rpm, rate_limit_tpm, created_at, updated_at)
|
||||
VALUES (?1, ?2, ?3, ?4, ?5, ?6, 1, ?7, ?8, ?9, ?9)"
|
||||
)
|
||||
.bind(&id).bind(&req.name).bind(&req.display_name).bind(&req.api_key)
|
||||
.bind(&req.base_url).bind(&req.api_protocol).bind(&req.rate_limit_rpm).bind(&req.rate_limit_tpm).bind(&now)
|
||||
.execute(db).await?;
|
||||
|
||||
get_provider(db, &id).await
|
||||
}
|
||||
|
||||
pub async fn update_provider(
|
||||
db: &SqlitePool, provider_id: &str, req: &UpdateProviderRequest,
|
||||
) -> SaasResult<ProviderInfo> {
|
||||
let now = chrono::Utc::now().to_rfc3339();
|
||||
let mut updates = Vec::new();
|
||||
let mut params: Vec<Box<dyn std::fmt::Display + Send + Sync>> = Vec::new();
|
||||
|
||||
if let Some(ref v) = req.display_name { updates.push("display_name = ?"); params.push(Box::new(v.clone())); }
|
||||
if let Some(ref v) = req.base_url { updates.push("base_url = ?"); params.push(Box::new(v.clone())); }
|
||||
if let Some(ref v) = req.api_protocol { updates.push("api_protocol = ?"); params.push(Box::new(v.clone())); }
|
||||
if let Some(ref v) = req.api_key { updates.push("api_key = ?"); params.push(Box::new(v.clone())); }
|
||||
if let Some(v) = req.enabled { updates.push("enabled = ?"); params.push(Box::new(v)); }
|
||||
if let Some(v) = req.rate_limit_rpm { updates.push("rate_limit_rpm = ?"); params.push(Box::new(v)); }
|
||||
if let Some(v) = req.rate_limit_tpm { updates.push("rate_limit_tpm = ?"); params.push(Box::new(v)); }
|
||||
|
||||
if updates.is_empty() {
|
||||
return get_provider(db, provider_id).await;
|
||||
}
|
||||
|
||||
updates.push("updated_at = ?");
|
||||
params.push(Box::new(now.clone()));
|
||||
params.push(Box::new(provider_id.to_string()));
|
||||
|
||||
let sql = format!("UPDATE providers SET {} WHERE id = ?", updates.join(", "));
|
||||
let mut query = sqlx::query(&sql);
|
||||
for p in ¶ms {
|
||||
query = query.bind(format!("{}", p));
|
||||
}
|
||||
query.execute(db).await?;
|
||||
|
||||
get_provider(db, provider_id).await
|
||||
}
|
||||
|
||||
pub async fn delete_provider(db: &SqlitePool, provider_id: &str) -> SaasResult<()> {
|
||||
let result = sqlx::query("DELETE FROM providers WHERE id = ?1")
|
||||
.bind(provider_id).execute(db).await?;
|
||||
|
||||
if result.rows_affected() == 0 {
|
||||
return Err(SaasError::NotFound(format!("Provider {} 不存在", provider_id)));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ============ Models ============
|
||||
|
||||
pub async fn list_models(db: &SqlitePool, provider_id: Option<&str>) -> SaasResult<Vec<ModelInfo>> {
|
||||
let sql = if provider_id.is_some() {
|
||||
"SELECT id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, pricing_input, pricing_output, created_at, updated_at
|
||||
FROM models WHERE provider_id = ?1 ORDER BY alias"
|
||||
} else {
|
||||
"SELECT id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, pricing_input, pricing_output, created_at, updated_at
|
||||
FROM models ORDER BY provider_id, alias"
|
||||
};
|
||||
|
||||
let mut query = sqlx::query_as::<_, (String, String, String, String, i64, i64, bool, bool, bool, f64, f64, String, String)>(sql);
|
||||
if let Some(pid) = provider_id {
|
||||
query = query.bind(pid);
|
||||
}
|
||||
|
||||
let rows = query.fetch_all(db).await?;
|
||||
Ok(rows.into_iter().map(|(id, provider_id, model_id, alias, ctx, max_out, streaming, vision, enabled, pi, po, created_at, updated_at)| {
|
||||
ModelInfo { id, provider_id, model_id, alias, context_window: ctx, max_output_tokens: max_out, supports_streaming: streaming, supports_vision: vision, enabled, pricing_input: pi, pricing_output: po, created_at, updated_at }
|
||||
}).collect())
|
||||
}
|
||||
|
||||
pub async fn create_model(db: &SqlitePool, req: &CreateModelRequest) -> SaasResult<ModelInfo> {
|
||||
// 验证 provider 存在
|
||||
let provider = get_provider(db, &req.provider_id).await?;
|
||||
|
||||
let id = uuid::Uuid::new_v4().to_string();
|
||||
let now = chrono::Utc::now().to_rfc3339();
|
||||
|
||||
// 检查 model 唯一性
|
||||
let existing: Option<(String,)> = sqlx::query_as(
|
||||
"SELECT id FROM models WHERE provider_id = ?1 AND model_id = ?2"
|
||||
)
|
||||
.bind(&req.provider_id).bind(&req.model_id)
|
||||
.fetch_optional(db).await?;
|
||||
|
||||
if existing.is_some() {
|
||||
return Err(SaasError::AlreadyExists(format!(
|
||||
"模型 '{}' 已存在于 provider '{}'", req.model_id, provider.name
|
||||
)));
|
||||
}
|
||||
|
||||
let ctx = req.context_window.unwrap_or(8192);
|
||||
let max_out = req.max_output_tokens.unwrap_or(4096);
|
||||
let streaming = req.supports_streaming.unwrap_or(true);
|
||||
let vision = req.supports_vision.unwrap_or(false);
|
||||
let pi = req.pricing_input.unwrap_or(0.0);
|
||||
let po = req.pricing_output.unwrap_or(0.0);
|
||||
|
||||
sqlx::query(
|
||||
"INSERT INTO models (id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, pricing_input, pricing_output, created_at, updated_at)
|
||||
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, 1, ?9, ?10, ?11, ?11)"
|
||||
)
|
||||
.bind(&id).bind(&req.provider_id).bind(&req.model_id).bind(&req.alias)
|
||||
.bind(ctx).bind(max_out).bind(streaming).bind(vision).bind(pi).bind(po).bind(&now)
|
||||
.execute(db).await?;
|
||||
|
||||
get_model(db, &id).await
|
||||
}
|
||||
|
||||
pub async fn get_model(db: &SqlitePool, model_id: &str) -> SaasResult<ModelInfo> {
|
||||
let row: Option<(String, String, String, String, i64, i64, bool, bool, bool, f64, f64, String, String)> =
|
||||
sqlx::query_as(
|
||||
"SELECT id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, pricing_input, pricing_output, created_at, updated_at
|
||||
FROM models WHERE id = ?1"
|
||||
)
|
||||
.bind(model_id)
|
||||
.fetch_optional(db)
|
||||
.await?;
|
||||
|
||||
let (id, provider_id, model_id, alias, ctx, max_out, streaming, vision, enabled, pi, po, created_at, updated_at) =
|
||||
row.ok_or_else(|| SaasError::NotFound(format!("模型 {} 不存在", model_id)))?;
|
||||
|
||||
Ok(ModelInfo { id, provider_id, model_id, alias, context_window: ctx, max_output_tokens: max_out, supports_streaming: streaming, supports_vision: vision, enabled, pricing_input: pi, pricing_output: po, created_at, updated_at })
|
||||
}
|
||||
|
||||
pub async fn update_model(
|
||||
db: &SqlitePool, model_id: &str, req: &UpdateModelRequest,
|
||||
) -> SaasResult<ModelInfo> {
|
||||
let now = chrono::Utc::now().to_rfc3339();
|
||||
let mut updates = Vec::new();
|
||||
let mut params: Vec<Box<dyn std::fmt::Display + Send + Sync>> = Vec::new();
|
||||
|
||||
if let Some(ref v) = req.alias { updates.push("alias = ?"); params.push(Box::new(v.clone())); }
|
||||
if let Some(v) = req.context_window { updates.push("context_window = ?"); params.push(Box::new(v)); }
|
||||
if let Some(v) = req.max_output_tokens { updates.push("max_output_tokens = ?"); params.push(Box::new(v)); }
|
||||
if let Some(v) = req.supports_streaming { updates.push("supports_streaming = ?"); params.push(Box::new(v)); }
|
||||
if let Some(v) = req.supports_vision { updates.push("supports_vision = ?"); params.push(Box::new(v)); }
|
||||
if let Some(v) = req.enabled { updates.push("enabled = ?"); params.push(Box::new(v)); }
|
||||
if let Some(v) = req.pricing_input { updates.push("pricing_input = ?"); params.push(Box::new(v)); }
|
||||
if let Some(v) = req.pricing_output { updates.push("pricing_output = ?"); params.push(Box::new(v)); }
|
||||
|
||||
if updates.is_empty() {
|
||||
return get_model(db, model_id).await;
|
||||
}
|
||||
|
||||
updates.push("updated_at = ?");
|
||||
params.push(Box::new(now.clone()));
|
||||
params.push(Box::new(model_id.to_string()));
|
||||
|
||||
let sql = format!("UPDATE models SET {} WHERE id = ?", updates.join(", "));
|
||||
let mut query = sqlx::query(&sql);
|
||||
for p in ¶ms {
|
||||
query = query.bind(format!("{}", p));
|
||||
}
|
||||
query.execute(db).await?;
|
||||
|
||||
get_model(db, model_id).await
|
||||
}
|
||||
|
||||
pub async fn delete_model(db: &SqlitePool, model_id: &str) -> SaasResult<()> {
|
||||
let result = sqlx::query("DELETE FROM models WHERE id = ?1")
|
||||
.bind(model_id).execute(db).await?;
|
||||
|
||||
if result.rows_affected() == 0 {
|
||||
return Err(SaasError::NotFound(format!("模型 {} 不存在", model_id)));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ============ Account API Keys ============
|
||||
|
||||
pub async fn list_account_api_keys(
|
||||
db: &SqlitePool, account_id: &str, provider_id: Option<&str>,
|
||||
) -> SaasResult<Vec<AccountApiKeyInfo>> {
|
||||
let sql = if provider_id.is_some() {
|
||||
"SELECT id, provider_id, key_label, permissions, enabled, last_used_at, created_at, key_value
|
||||
FROM account_api_keys WHERE account_id = ?1 AND provider_id = ?2 AND revoked_at IS NULL ORDER BY created_at DESC"
|
||||
} else {
|
||||
"SELECT id, provider_id, key_label, permissions, enabled, last_used_at, created_at, key_value
|
||||
FROM account_api_keys WHERE account_id = ?1 AND revoked_at IS NULL ORDER BY created_at DESC"
|
||||
};
|
||||
|
||||
let mut query = sqlx::query_as::<_, (String, String, Option<String>, String, bool, Option<String>, String, String)>(sql)
|
||||
.bind(account_id);
|
||||
if let Some(pid) = provider_id {
|
||||
query = query.bind(pid);
|
||||
}
|
||||
|
||||
let rows = query.fetch_all(db).await?;
|
||||
Ok(rows.into_iter().map(|(id, provider_id, key_label, perms, enabled, last_used, created_at, key_value)| {
|
||||
let permissions: Vec<String> = serde_json::from_str(&perms).unwrap_or_default();
|
||||
let masked = mask_api_key(&key_value);
|
||||
AccountApiKeyInfo { id, provider_id, key_label, permissions, enabled, last_used_at: last_used, created_at, masked_key: masked }
|
||||
}).collect())
|
||||
}
|
||||
|
||||
pub async fn create_account_api_key(
|
||||
db: &SqlitePool, account_id: &str, req: &CreateAccountApiKeyRequest,
|
||||
) -> SaasResult<AccountApiKeyInfo> {
|
||||
// 验证 provider 存在
|
||||
get_provider(db, &req.provider_id).await?;
|
||||
|
||||
let id = uuid::Uuid::new_v4().to_string();
|
||||
let now = chrono::Utc::now().to_rfc3339();
|
||||
let permissions = serde_json::to_string(&req.permissions)?;
|
||||
|
||||
sqlx::query(
|
||||
"INSERT INTO account_api_keys (id, account_id, provider_id, key_value, key_label, permissions, enabled, created_at, updated_at)
|
||||
VALUES (?1, ?2, ?3, ?4, ?5, ?6, 1, ?7, ?7)"
|
||||
)
|
||||
.bind(&id).bind(account_id).bind(&req.provider_id).bind(&req.key_value)
|
||||
.bind(&req.key_label).bind(&permissions).bind(&now)
|
||||
.execute(db).await?;
|
||||
|
||||
let masked = mask_api_key(&req.key_value);
|
||||
Ok(AccountApiKeyInfo {
|
||||
id, provider_id: req.provider_id.clone(), key_label: req.key_label.clone(),
|
||||
permissions: req.permissions.clone(), enabled: true, last_used_at: None,
|
||||
created_at: now, masked_key: masked,
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn rotate_account_api_key(
|
||||
db: &SqlitePool, key_id: &str, account_id: &str, new_key_value: &str,
|
||||
) -> SaasResult<()> {
|
||||
let now = chrono::Utc::now().to_rfc3339();
|
||||
let result = sqlx::query(
|
||||
"UPDATE account_api_keys SET key_value = ?1, updated_at = ?2 WHERE id = ?3 AND account_id = ?4 AND revoked_at IS NULL"
|
||||
)
|
||||
.bind(new_key_value).bind(&now).bind(key_id).bind(account_id)
|
||||
.execute(db).await?;
|
||||
|
||||
if result.rows_affected() == 0 {
|
||||
return Err(SaasError::NotFound("API Key 不存在或已撤销".into()));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn revoke_account_api_key(
|
||||
db: &SqlitePool, key_id: &str, account_id: &str,
|
||||
) -> SaasResult<()> {
|
||||
let now = chrono::Utc::now().to_rfc3339();
|
||||
let result = sqlx::query(
|
||||
"UPDATE account_api_keys SET revoked_at = ?1 WHERE id = ?2 AND account_id = ?3 AND revoked_at IS NULL"
|
||||
)
|
||||
.bind(&now).bind(key_id).bind(account_id)
|
||||
.execute(db).await?;
|
||||
|
||||
if result.rows_affected() == 0 {
|
||||
return Err(SaasError::NotFound("API Key 不存在或已撤销".into()));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ============ Usage Statistics ============
|
||||
|
||||
pub async fn get_usage_stats(
|
||||
db: &SqlitePool, account_id: &str, query: &UsageQuery,
|
||||
) -> SaasResult<UsageStats> {
|
||||
let mut where_clauses = vec!["account_id = ?".to_string()];
|
||||
let mut params: Vec<String> = vec![account_id.to_string()];
|
||||
|
||||
if let Some(ref from) = query.from {
|
||||
where_clauses.push("created_at >= ?".to_string());
|
||||
params.push(from.clone());
|
||||
}
|
||||
if let Some(ref to) = query.to {
|
||||
where_clauses.push("created_at <= ?".to_string());
|
||||
params.push(to.clone());
|
||||
}
|
||||
if let Some(ref pid) = query.provider_id {
|
||||
where_clauses.push("provider_id = ?".to_string());
|
||||
params.push(pid.clone());
|
||||
}
|
||||
if let Some(ref mid) = query.model_id {
|
||||
where_clauses.push("model_id = ?".to_string());
|
||||
params.push(mid.clone());
|
||||
}
|
||||
|
||||
let where_sql = where_clauses.join(" AND ");
|
||||
|
||||
// 总量统计
|
||||
let total_sql = format!(
|
||||
"SELECT COUNT(*), COALESCE(SUM(input_tokens), 0), COALESCE(SUM(output_tokens), 0)
|
||||
FROM usage_records WHERE {}", where_sql
|
||||
);
|
||||
let mut total_query = sqlx::query_as::<_, (i64, i64, i64)>(&total_sql);
|
||||
for p in ¶ms {
|
||||
total_query = total_query.bind(p);
|
||||
}
|
||||
let (total_requests, total_input, total_output) = total_query.fetch_one(db).await?;
|
||||
|
||||
// 按模型统计
|
||||
let by_model_sql = format!(
|
||||
"SELECT provider_id, model_id, COUNT(*), COALESCE(SUM(input_tokens), 0), COALESCE(SUM(output_tokens), 0)
|
||||
FROM usage_records WHERE {} GROUP BY provider_id, model_id ORDER BY COUNT(*) DESC LIMIT 20",
|
||||
where_sql
|
||||
);
|
||||
let mut by_model_query = sqlx::query_as::<_, (String, String, i64, i64, i64)>(&by_model_sql);
|
||||
for p in ¶ms {
|
||||
by_model_query = by_model_query.bind(p);
|
||||
}
|
||||
let by_model_rows = by_model_query.fetch_all(db).await?;
|
||||
let by_model: Vec<ModelUsage> = by_model_rows.into_iter()
|
||||
.map(|(provider_id, model_id, count, input, output)| {
|
||||
ModelUsage { provider_id, model_id, request_count: count, input_tokens: input, output_tokens: output }
|
||||
}).collect();
|
||||
|
||||
// 按天统计 (最近 30 天)
|
||||
let from_30d = (chrono::Utc::now() - chrono::Duration::days(30)).to_rfc3339();
|
||||
let daily_sql = format!(
|
||||
"SELECT DATE(created_at) as day, COUNT(*), COALESCE(SUM(input_tokens), 0), COALESCE(SUM(output_tokens), 0)
|
||||
FROM usage_records WHERE account_id = ?1 AND created_at >= ?2
|
||||
GROUP BY DATE(created_at) ORDER BY day DESC LIMIT 30"
|
||||
);
|
||||
let daily_rows: Vec<(String, i64, i64, i64)> = sqlx::query_as(&daily_sql)
|
||||
.bind(account_id).bind(&from_30d)
|
||||
.fetch_all(db).await?;
|
||||
let by_day: Vec<DailyUsage> = daily_rows.into_iter()
|
||||
.map(|(date, count, input, output)| {
|
||||
DailyUsage { date, request_count: count, input_tokens: input, output_tokens: output }
|
||||
}).collect();
|
||||
|
||||
Ok(UsageStats {
|
||||
total_requests,
|
||||
total_input_tokens: total_input,
|
||||
total_output_tokens: total_output,
|
||||
by_model,
|
||||
by_day,
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn record_usage(
|
||||
db: &SqlitePool, account_id: &str, provider_id: &str, model_id: &str,
|
||||
input_tokens: i64, output_tokens: i64, latency_ms: Option<i64>,
|
||||
status: &str, error_message: Option<&str>,
|
||||
) -> SaasResult<()> {
|
||||
let now = chrono::Utc::now().to_rfc3339();
|
||||
sqlx::query(
|
||||
"INSERT INTO usage_records (account_id, provider_id, model_id, input_tokens, output_tokens, latency_ms, status, error_message, created_at)
|
||||
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)"
|
||||
)
|
||||
.bind(account_id).bind(provider_id).bind(model_id)
|
||||
.bind(input_tokens).bind(output_tokens).bind(latency_ms)
|
||||
.bind(status).bind(error_message).bind(&now)
|
||||
.execute(db).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ============ Helpers ============
|
||||
|
||||
fn mask_api_key(key: &str) -> String {
|
||||
if key.len() <= 8 {
|
||||
return "*".repeat(key.len());
|
||||
}
|
||||
format!("{}...{}", &key[..4], &key[key.len()-4..])
|
||||
}
|
||||
172
crates/zclaw-saas/src/model_config/types.rs
Normal file
172
crates/zclaw-saas/src/model_config/types.rs
Normal file
@@ -0,0 +1,172 @@
|
||||
//! 模型配置类型定义
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
// --- Provider ---
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ProviderInfo {
|
||||
pub id: String,
|
||||
pub name: String,
|
||||
pub display_name: String,
|
||||
pub base_url: String,
|
||||
pub api_protocol: String,
|
||||
pub enabled: bool,
|
||||
pub rate_limit_rpm: Option<i64>,
|
||||
pub rate_limit_tpm: Option<i64>,
|
||||
pub created_at: String,
|
||||
pub updated_at: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct CreateProviderRequest {
|
||||
pub name: String,
|
||||
pub display_name: String,
|
||||
pub base_url: String,
|
||||
#[serde(default = "default_protocol")]
|
||||
pub api_protocol: String,
|
||||
pub api_key: Option<String>,
|
||||
pub rate_limit_rpm: Option<i64>,
|
||||
pub rate_limit_tpm: Option<i64>,
|
||||
}
|
||||
|
||||
fn default_protocol() -> String { "openai".into() }
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct UpdateProviderRequest {
|
||||
pub display_name: Option<String>,
|
||||
pub base_url: Option<String>,
|
||||
pub api_protocol: Option<String>,
|
||||
pub api_key: Option<String>,
|
||||
pub enabled: Option<bool>,
|
||||
pub rate_limit_rpm: Option<i64>,
|
||||
pub rate_limit_tpm: Option<i64>,
|
||||
}
|
||||
|
||||
// --- Model ---
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ModelInfo {
|
||||
pub id: String,
|
||||
pub provider_id: String,
|
||||
pub model_id: String,
|
||||
pub alias: String,
|
||||
pub context_window: i64,
|
||||
pub max_output_tokens: i64,
|
||||
pub supports_streaming: bool,
|
||||
pub supports_vision: bool,
|
||||
pub enabled: bool,
|
||||
pub pricing_input: f64,
|
||||
pub pricing_output: f64,
|
||||
pub created_at: String,
|
||||
pub updated_at: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct CreateModelRequest {
|
||||
pub provider_id: String,
|
||||
pub model_id: String,
|
||||
pub alias: String,
|
||||
pub context_window: Option<i64>,
|
||||
pub max_output_tokens: Option<i64>,
|
||||
pub supports_streaming: Option<bool>,
|
||||
pub supports_vision: Option<bool>,
|
||||
pub pricing_input: Option<f64>,
|
||||
pub pricing_output: Option<f64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct UpdateModelRequest {
|
||||
pub alias: Option<String>,
|
||||
pub context_window: Option<i64>,
|
||||
pub max_output_tokens: Option<i64>,
|
||||
pub supports_streaming: Option<bool>,
|
||||
pub supports_vision: Option<bool>,
|
||||
pub enabled: Option<bool>,
|
||||
pub pricing_input: Option<f64>,
|
||||
pub pricing_output: Option<f64>,
|
||||
}
|
||||
|
||||
// --- Account API Key ---
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct AccountApiKeyInfo {
|
||||
pub id: String,
|
||||
pub provider_id: String,
|
||||
pub key_label: Option<String>,
|
||||
pub permissions: Vec<String>,
|
||||
pub enabled: bool,
|
||||
pub last_used_at: Option<String>,
|
||||
pub created_at: String,
|
||||
pub masked_key: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct CreateAccountApiKeyRequest {
|
||||
pub provider_id: String,
|
||||
pub key_value: String,
|
||||
pub key_label: Option<String>,
|
||||
#[serde(default)]
|
||||
pub permissions: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct RotateApiKeyRequest {
|
||||
pub new_key_value: String,
|
||||
}
|
||||
|
||||
// --- Usage ---
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct UsageStats {
|
||||
pub total_requests: i64,
|
||||
pub total_input_tokens: i64,
|
||||
pub total_output_tokens: i64,
|
||||
pub by_model: Vec<ModelUsage>,
|
||||
pub by_day: Vec<DailyUsage>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct ModelUsage {
|
||||
pub provider_id: String,
|
||||
pub model_id: String,
|
||||
pub request_count: i64,
|
||||
pub input_tokens: i64,
|
||||
pub output_tokens: i64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct DailyUsage {
|
||||
pub date: String,
|
||||
pub request_count: i64,
|
||||
pub input_tokens: i64,
|
||||
pub output_tokens: i64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct UsageQuery {
|
||||
pub from: Option<String>,
|
||||
pub to: Option<String>,
|
||||
pub provider_id: Option<String>,
|
||||
pub model_id: Option<String>,
|
||||
}
|
||||
|
||||
// --- Seed Data ---
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct SeedProvider {
|
||||
pub name: String,
|
||||
pub display_name: String,
|
||||
pub base_url: String,
|
||||
pub models: Vec<SeedModel>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct SeedModel {
|
||||
pub id: String,
|
||||
pub alias: String,
|
||||
pub context_window: Option<i64>,
|
||||
pub max_output_tokens: Option<i64>,
|
||||
pub supports_streaming: Option<bool>,
|
||||
pub supports_vision: Option<bool>,
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
//! Phase 1 集成测试
|
||||
//! 集成测试 (Phase 1 + Phase 2)
|
||||
|
||||
use axum::{
|
||||
body::Body,
|
||||
@@ -21,6 +21,7 @@ async fn build_test_app() -> axum::Router {
|
||||
|
||||
let protected_routes = zclaw_saas::auth::protected_routes()
|
||||
.merge(zclaw_saas::account::routes())
|
||||
.merge(zclaw_saas::model_config::routes())
|
||||
.layer(axum::middleware::from_fn_with_state(
|
||||
state.clone(),
|
||||
zclaw_saas::auth::auth_middleware,
|
||||
@@ -32,43 +33,45 @@ async fn build_test_app() -> axum::Router {
|
||||
.with_state(state)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_register_and_login() {
|
||||
let app = build_test_app().await;
|
||||
|
||||
// 注册
|
||||
let req = Request::builder()
|
||||
/// 注册并登录,返回 JWT token
|
||||
async fn register_and_login(app: &axum::Router, username: &str, email: &str) -> String {
|
||||
let register_req = Request::builder()
|
||||
.method("POST")
|
||||
.uri("/api/v1/auth/register")
|
||||
.header("Content-Type", "application/json")
|
||||
.body(Body::from(serde_json::to_string(&json!({
|
||||
"username": "testuser",
|
||||
"email": "test@example.com",
|
||||
"username": username,
|
||||
"email": email,
|
||||
"password": "password123"
|
||||
})).unwrap()))
|
||||
.unwrap();
|
||||
app.clone().oneshot(register_req).await.unwrap();
|
||||
|
||||
let resp = app.clone().oneshot(req).await.unwrap();
|
||||
assert_eq!(resp.status(), StatusCode::CREATED);
|
||||
|
||||
// 登录
|
||||
let req = Request::builder()
|
||||
let login_req = Request::builder()
|
||||
.method("POST")
|
||||
.uri("/api/v1/auth/login")
|
||||
.header("Content-Type", "application/json")
|
||||
.body(Body::from(serde_json::to_string(&json!({
|
||||
"username": "testuser",
|
||||
"username": username,
|
||||
"password": "password123"
|
||||
})).unwrap()))
|
||||
.unwrap();
|
||||
|
||||
let resp = app.clone().oneshot(req).await.unwrap();
|
||||
assert_eq!(resp.status(), StatusCode::OK);
|
||||
|
||||
let resp = app.clone().oneshot(login_req).await.unwrap();
|
||||
let body_bytes = axum::body::to_bytes(resp.into_body(), MAX_BODY_SIZE).await.unwrap();
|
||||
let body: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap();
|
||||
assert!(body.get("token").is_some());
|
||||
assert_eq!(body["account"]["username"], "testuser");
|
||||
body["token"].as_str().unwrap().to_string()
|
||||
}
|
||||
|
||||
fn auth_header(token: &str) -> String {
|
||||
format!("Bearer {}", token)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_register_and_login() {
|
||||
let app = build_test_app().await;
|
||||
let token = register_and_login(&app, "testuser", "test@example.com").await;
|
||||
assert!(!token.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -119,21 +122,8 @@ async fn test_unauthorized_access() {
|
||||
#[tokio::test]
|
||||
async fn test_login_wrong_password() {
|
||||
let app = build_test_app().await;
|
||||
register_and_login(&app, "wrongpwd", "wrongpwd@example.com").await;
|
||||
|
||||
// 先注册
|
||||
let req = Request::builder()
|
||||
.method("POST")
|
||||
.uri("/api/v1/auth/register")
|
||||
.header("Content-Type", "application/json")
|
||||
.body(Body::from(serde_json::to_string(&json!({
|
||||
"username": "wrongpwd",
|
||||
"email": "wrongpwd@example.com",
|
||||
"password": "password123"
|
||||
})).unwrap()))
|
||||
.unwrap();
|
||||
app.clone().oneshot(req).await.unwrap();
|
||||
|
||||
// 错误密码登录
|
||||
let req = Request::builder()
|
||||
.method("POST")
|
||||
.uri("/api/v1/auth/login")
|
||||
@@ -151,41 +141,14 @@ async fn test_login_wrong_password() {
|
||||
#[tokio::test]
|
||||
async fn test_full_authenticated_flow() {
|
||||
let app = build_test_app().await;
|
||||
|
||||
// 注册 + 登录
|
||||
let register_req = Request::builder()
|
||||
.method("POST")
|
||||
.uri("/api/v1/auth/register")
|
||||
.header("Content-Type", "application/json")
|
||||
.body(Body::from(serde_json::to_string(&json!({
|
||||
"username": "fulltest",
|
||||
"email": "full@example.com",
|
||||
"password": "password123"
|
||||
})).unwrap()))
|
||||
.unwrap();
|
||||
app.clone().oneshot(register_req).await.unwrap();
|
||||
|
||||
let login_req = Request::builder()
|
||||
.method("POST")
|
||||
.uri("/api/v1/auth/login")
|
||||
.header("Content-Type", "application/json")
|
||||
.body(Body::from(serde_json::to_string(&json!({
|
||||
"username": "fulltest",
|
||||
"password": "password123"
|
||||
})).unwrap()))
|
||||
.unwrap();
|
||||
|
||||
let resp = app.clone().oneshot(login_req).await.unwrap();
|
||||
let body_bytes = axum::body::to_bytes(resp.into_body(), MAX_BODY_SIZE).await.unwrap();
|
||||
let body: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap();
|
||||
let token = body["token"].as_str().unwrap().to_string();
|
||||
let token = register_and_login(&app, "fulltest", "full@example.com").await;
|
||||
|
||||
// 创建 API Token
|
||||
let create_token_req = Request::builder()
|
||||
.method("POST")
|
||||
.uri("/api/v1/tokens")
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Authorization", format!("Bearer {}", token))
|
||||
.header("Authorization", auth_header(&token))
|
||||
.body(Body::from(serde_json::to_string(&json!({
|
||||
"name": "test-token",
|
||||
"permissions": ["model:read", "relay:use"]
|
||||
@@ -196,13 +159,13 @@ async fn test_full_authenticated_flow() {
|
||||
assert_eq!(resp.status(), StatusCode::OK);
|
||||
let body_bytes = axum::body::to_bytes(resp.into_body(), MAX_BODY_SIZE).await.unwrap();
|
||||
let body: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap();
|
||||
assert!(!body["token"].is_null()); // 原始 token 仅创建时返回
|
||||
assert!(!body["token"].is_null());
|
||||
|
||||
// 列出 Tokens
|
||||
let list_req = Request::builder()
|
||||
.method("GET")
|
||||
.uri("/api/v1/tokens")
|
||||
.header("Authorization", format!("Bearer {}", token))
|
||||
.header("Authorization", auth_header(&token))
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
|
||||
@@ -213,10 +176,115 @@ async fn test_full_authenticated_flow() {
|
||||
let logs_req = Request::builder()
|
||||
.method("GET")
|
||||
.uri("/api/v1/logs/operations")
|
||||
.header("Authorization", format!("Bearer {}", token))
|
||||
.header("Authorization", auth_header(&token))
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
|
||||
let resp = app.oneshot(logs_req).await.unwrap();
|
||||
assert_eq!(resp.status(), StatusCode::OK);
|
||||
}
|
||||
|
||||
// ============ Phase 2: 模型配置测试 ============
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_providers_crud() {
|
||||
let app = build_test_app().await;
|
||||
// 注册 super_admin 角色用户 (通过直接插入角色权限)
|
||||
let token = register_and_login(&app, "adminprov", "adminprov@example.com").await;
|
||||
|
||||
// 创建 provider (普通用户无权限 → 403)
|
||||
let create_req = Request::builder()
|
||||
.method("POST")
|
||||
.uri("/api/v1/providers")
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Authorization", auth_header(&token))
|
||||
.body(Body::from(serde_json::to_string(&json!({
|
||||
"name": "test-provider",
|
||||
"display_name": "Test Provider",
|
||||
"base_url": "https://api.example.com/v1"
|
||||
})).unwrap()))
|
||||
.unwrap();
|
||||
|
||||
let resp = app.clone().oneshot(create_req).await.unwrap();
|
||||
// user 角色默认无 provider:manage 权限 → 403
|
||||
assert_eq!(resp.status(), StatusCode::FORBIDDEN);
|
||||
|
||||
// 列出 providers (只读权限 → 200)
|
||||
let list_req = Request::builder()
|
||||
.method("GET")
|
||||
.uri("/api/v1/providers")
|
||||
.header("Authorization", auth_header(&token))
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
|
||||
let resp = app.clone().oneshot(list_req).await.unwrap();
|
||||
assert_eq!(resp.status(), StatusCode::OK);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_models_list_and_usage() {
|
||||
let app = build_test_app().await;
|
||||
let token = register_and_login(&app, "modeluser", "modeluser@example.com").await;
|
||||
|
||||
// 列出模型 (空列表)
|
||||
let list_req = Request::builder()
|
||||
.method("GET")
|
||||
.uri("/api/v1/models")
|
||||
.header("Authorization", auth_header(&token))
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
|
||||
let resp = app.clone().oneshot(list_req).await.unwrap();
|
||||
assert_eq!(resp.status(), StatusCode::OK);
|
||||
let body_bytes = axum::body::to_bytes(resp.into_body(), MAX_BODY_SIZE).await.unwrap();
|
||||
let body: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap();
|
||||
assert!(body.is_array());
|
||||
assert_eq!(body.as_array().unwrap().len(), 0);
|
||||
|
||||
// 查看用量统计
|
||||
let usage_req = Request::builder()
|
||||
.method("GET")
|
||||
.uri("/api/v1/usage")
|
||||
.header("Authorization", auth_header(&token))
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
|
||||
let resp = app.clone().oneshot(usage_req).await.unwrap();
|
||||
assert_eq!(resp.status(), StatusCode::OK);
|
||||
let body_bytes = axum::body::to_bytes(resp.into_body(), MAX_BODY_SIZE).await.unwrap();
|
||||
let body: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap();
|
||||
assert_eq!(body["total_requests"], 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_api_keys_lifecycle() {
|
||||
let app = build_test_app().await;
|
||||
let token = register_and_login(&app, "keyuser", "keyuser@example.com").await;
|
||||
|
||||
// 列出 keys (空)
|
||||
let list_req = Request::builder()
|
||||
.method("GET")
|
||||
.uri("/api/v1/keys")
|
||||
.header("Authorization", auth_header(&token))
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
let resp = app.clone().oneshot(list_req).await.unwrap();
|
||||
assert_eq!(resp.status(), StatusCode::OK);
|
||||
|
||||
// 创建 key (需要已有 provider → 404 或由 service 层验证)
|
||||
let create_req = Request::builder()
|
||||
.method("POST")
|
||||
.uri("/api/v1/keys")
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Authorization", auth_header(&token))
|
||||
.body(Body::from(serde_json::to_string(&json!({
|
||||
"provider_id": "nonexistent",
|
||||
"key_value": "sk-test-12345",
|
||||
"key_label": "Test Key"
|
||||
})).unwrap()))
|
||||
.unwrap();
|
||||
|
||||
let resp = app.clone().oneshot(create_req).await.unwrap();
|
||||
// provider 不存在 → 404
|
||||
assert_eq!(resp.status(), StatusCode::NOT_FOUND);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user