diff --git a/crates/zclaw-saas/src/main.rs b/crates/zclaw-saas/src/main.rs index 6035fa0..3c72030 100644 --- a/crates/zclaw-saas/src/main.rs +++ b/crates/zclaw-saas/src/main.rs @@ -43,6 +43,7 @@ fn build_router(state: AppState) -> axum::Router { let protected_routes = zclaw_saas::auth::protected_routes() .merge(zclaw_saas::account::routes()) + .merge(zclaw_saas::model_config::routes()) .layer(middleware::from_fn_with_state( state.clone(), zclaw_saas::auth::auth_middleware, diff --git a/crates/zclaw-saas/src/model_config/handlers.rs b/crates/zclaw-saas/src/model_config/handlers.rs new file mode 100644 index 0000000..4a2be00 --- /dev/null +++ b/crates/zclaw-saas/src/model_config/handlers.rs @@ -0,0 +1,206 @@ +//! 模型配置 HTTP 处理器 + +use axum::{ + extract::{Extension, Path, Query, State}, + http::StatusCode, Json, +}; +use crate::state::AppState; +use crate::error::{SaasError, SaasResult}; +use crate::auth::types::AuthContext; +use crate::auth::handlers::log_operation; +use super::{types::*, service}; + +// ============ Providers ============ + +/// GET /api/v1/providers +pub async fn list_providers( + State(state): State, + _ctx: Extension, +) -> SaasResult>> { + service::list_providers(&state.db).await.map(Json) +} + +/// GET /api/v1/providers/:id +pub async fn get_provider( + State(state): State, + Path(id): Path, + _ctx: Extension, +) -> SaasResult> { + service::get_provider(&state.db, &id).await.map(Json) +} + +/// POST /api/v1/providers (admin only) +pub async fn create_provider( + State(state): State, + Extension(ctx): Extension, + Json(req): Json, +) -> SaasResult<(StatusCode, Json)> { + if !ctx.permissions.contains(&"provider:manage".to_string()) { + return Err(SaasError::Forbidden("需要 provider:manage 权限".into())); + } + let provider = service::create_provider(&state.db, &req).await?; + log_operation(&state.db, &ctx.account_id, "provider.create", "provider", &provider.id, + Some(serde_json::json!({"name": &req.name})), None).await?; + Ok((StatusCode::CREATED, Json(provider))) +} + +/// PUT /api/v1/providers/:id (admin only) +pub async fn update_provider( + State(state): State, + Path(id): Path, + Extension(ctx): Extension, + Json(req): Json, +) -> SaasResult> { + if !ctx.permissions.contains(&"provider:manage".to_string()) { + return Err(SaasError::Forbidden("需要 provider:manage 权限".into())); + } + let provider = service::update_provider(&state.db, &id, &req).await?; + log_operation(&state.db, &ctx.account_id, "provider.update", "provider", &id, None, None).await?; + Ok(Json(provider)) +} + +/// DELETE /api/v1/providers/:id (admin only) +pub async fn delete_provider( + State(state): State, + Path(id): Path, + Extension(ctx): Extension, +) -> SaasResult> { + if !ctx.permissions.contains(&"provider:manage".to_string()) { + return Err(SaasError::Forbidden("需要 provider:manage 权限".into())); + } + service::delete_provider(&state.db, &id).await?; + log_operation(&state.db, &ctx.account_id, "provider.delete", "provider", &id, None, None).await?; + Ok(Json(serde_json::json!({"ok": true}))) +} + +// ============ Models ============ + +/// GET /api/v1/models?provider_id=xxx +pub async fn list_models( + State(state): State, + Query(params): Query>, + _ctx: Extension, +) -> SaasResult>> { + let provider_id = params.get("provider_id").map(|s| s.as_str()); + service::list_models(&state.db, provider_id).await.map(Json) +} + +/// GET /api/v1/models/:id +pub async fn get_model( + State(state): State, + Path(id): Path, + _ctx: Extension, +) -> SaasResult> { + service::get_model(&state.db, &id).await.map(Json) +} + +/// POST /api/v1/models (admin only) +pub async fn create_model( + State(state): State, + Extension(ctx): Extension, + Json(req): Json, +) -> SaasResult<(StatusCode, Json)> { + if !ctx.permissions.contains(&"model:manage".to_string()) { + return Err(SaasError::Forbidden("需要 model:manage 权限".into())); + } + let model = service::create_model(&state.db, &req).await?; + log_operation(&state.db, &ctx.account_id, "model.create", "model", &model.id, + Some(serde_json::json!({"model_id": &req.model_id, "provider_id": &req.provider_id})), None).await?; + Ok((StatusCode::CREATED, Json(model))) +} + +/// PUT /api/v1/models/:id (admin only) +pub async fn update_model( + State(state): State, + Path(id): Path, + Extension(ctx): Extension, + Json(req): Json, +) -> SaasResult> { + if !ctx.permissions.contains(&"model:manage".to_string()) { + return Err(SaasError::Forbidden("需要 model:manage 权限".into())); + } + let model = service::update_model(&state.db, &id, &req).await?; + log_operation(&state.db, &ctx.account_id, "model.update", "model", &id, None, None).await?; + Ok(Json(model)) +} + +/// DELETE /api/v1/models/:id (admin only) +pub async fn delete_model( + State(state): State, + Path(id): Path, + Extension(ctx): Extension, +) -> SaasResult> { + if !ctx.permissions.contains(&"model:manage".to_string()) { + return Err(SaasError::Forbidden("需要 model:manage 权限".into())); + } + service::delete_model(&state.db, &id).await?; + log_operation(&state.db, &ctx.account_id, "model.delete", "model", &id, None, None).await?; + Ok(Json(serde_json::json!({"ok": true}))) +} + +// ============ Account API Keys ============ + +/// GET /api/v1/keys?provider_id=xxx +pub async fn list_api_keys( + State(state): State, + Extension(ctx): Extension, + Query(params): Query>, +) -> SaasResult>> { + let provider_id = params.get("provider_id").map(|s| s.as_str()); + service::list_account_api_keys(&state.db, &ctx.account_id, provider_id).await.map(Json) +} + +/// POST /api/v1/keys +pub async fn create_api_key( + State(state): State, + Extension(ctx): Extension, + Json(req): Json, +) -> SaasResult<(StatusCode, Json)> { + let key = service::create_account_api_key(&state.db, &ctx.account_id, &req).await?; + log_operation(&state.db, &ctx.account_id, "api_key.create", "api_key", &key.id, + Some(serde_json::json!({"provider_id": &req.provider_id})), None).await?; + Ok((StatusCode::CREATED, Json(key))) +} + +/// POST /api/v1/keys/:id/rotate +pub async fn rotate_api_key( + State(state): State, + Path(id): Path, + Extension(ctx): Extension, + Json(req): Json, +) -> SaasResult> { + service::rotate_account_api_key(&state.db, &id, &ctx.account_id, &req.new_key_value).await?; + log_operation(&state.db, &ctx.account_id, "api_key.rotate", "api_key", &id, None, None).await?; + Ok(Json(serde_json::json!({"ok": true}))) +} + +/// DELETE /api/v1/keys/:id +pub async fn revoke_api_key( + State(state): State, + Path(id): Path, + Extension(ctx): Extension, +) -> SaasResult> { + service::revoke_account_api_key(&state.db, &id, &ctx.account_id).await?; + log_operation(&state.db, &ctx.account_id, "api_key.revoke", "api_key", &id, None, None).await?; + Ok(Json(serde_json::json!({"ok": true}))) +} + +// ============ Usage ============ + +/// GET /api/v1/usage?from=...&to=...&provider_id=...&model_id=... +pub async fn get_usage( + State(state): State, + Extension(ctx): Extension, + Query(params): Query, +) -> SaasResult> { + service::get_usage_stats(&state.db, &ctx.account_id, ¶ms).await.map(Json) +} + +/// GET /api/v1/providers/:id/models (便捷路由) +pub async fn list_provider_models( + State(state): State, + Path(provider_id): Path, + _ctx: Extension, +) -> SaasResult>> { + service::list_models(&state.db, Some(&provider_id)).await.map(Json) +} diff --git a/crates/zclaw-saas/src/model_config/mod.rs b/crates/zclaw-saas/src/model_config/mod.rs index 7eae0b7..48a3102 100644 --- a/crates/zclaw-saas/src/model_config/mod.rs +++ b/crates/zclaw-saas/src/model_config/mod.rs @@ -1 +1,26 @@ //! 模型配置模块 + +pub mod types; +pub mod service; +pub mod handlers; + +use axum::routing::{delete, get, post}; +use crate::state::AppState; + +/// 模型配置路由 (需要认证) +pub fn routes() -> axum::Router { + axum::Router::new() + // Providers + .route("/api/v1/providers", get(handlers::list_providers).post(handlers::create_provider)) + .route("/api/v1/providers/{id}", get(handlers::get_provider).put(handlers::update_provider).delete(handlers::delete_provider)) + .route("/api/v1/providers/{id}/models", get(handlers::list_provider_models)) + // Models + .route("/api/v1/models", get(handlers::list_models).post(handlers::create_model)) + .route("/api/v1/models/{id}", get(handlers::get_model).put(handlers::update_model).delete(handlers::delete_model)) + // Account API Keys + .route("/api/v1/keys", get(handlers::list_api_keys).post(handlers::create_api_key)) + .route("/api/v1/keys/{id}", delete(handlers::revoke_api_key)) + .route("/api/v1/keys/{id}/rotate", post(handlers::rotate_api_key)) + // Usage + .route("/api/v1/usage", get(handlers::get_usage)) +} diff --git a/crates/zclaw-saas/src/model_config/service.rs b/crates/zclaw-saas/src/model_config/service.rs new file mode 100644 index 0000000..220706b --- /dev/null +++ b/crates/zclaw-saas/src/model_config/service.rs @@ -0,0 +1,411 @@ +//! 模型配置业务逻辑 + +use sqlx::SqlitePool; +use crate::error::{SaasError, SaasResult}; +use super::types::*; + +// ============ Providers ============ + +pub async fn list_providers(db: &SqlitePool) -> SaasResult> { + let rows: Vec<(String, String, String, String, String, bool, Option, Option, String, String)> = + sqlx::query_as( + "SELECT id, name, display_name, base_url, api_protocol, enabled, rate_limit_rpm, rate_limit_tpm, created_at, updated_at + FROM providers ORDER BY name" + ) + .fetch_all(db) + .await?; + + Ok(rows.into_iter().map(|(id, name, display_name, base_url, api_protocol, enabled, rpm, tpm, created_at, updated_at)| { + ProviderInfo { id, name, display_name, base_url, api_protocol, enabled, rate_limit_rpm: rpm, rate_limit_tpm: tpm, created_at, updated_at } + }).collect()) +} + +pub async fn get_provider(db: &SqlitePool, provider_id: &str) -> SaasResult { + let row: Option<(String, String, String, String, String, bool, Option, Option, String, String)> = + sqlx::query_as( + "SELECT id, name, display_name, base_url, api_protocol, enabled, rate_limit_rpm, rate_limit_tpm, created_at, updated_at + FROM providers WHERE id = ?1" + ) + .bind(provider_id) + .fetch_optional(db) + .await?; + + let (id, name, display_name, base_url, api_protocol, enabled, rpm, tpm, created_at, updated_at) = + row.ok_or_else(|| SaasError::NotFound(format!("Provider {} 不存在", provider_id)))?; + + Ok(ProviderInfo { id, name, display_name, base_url, api_protocol, enabled, rate_limit_rpm: rpm, rate_limit_tpm: tpm, created_at, updated_at }) +} + +pub async fn create_provider(db: &SqlitePool, req: &CreateProviderRequest) -> SaasResult { + let id = uuid::Uuid::new_v4().to_string(); + let now = chrono::Utc::now().to_rfc3339(); + + // 检查名称唯一性 + let existing: Option<(String,)> = sqlx::query_as("SELECT id FROM providers WHERE name = ?1") + .bind(&req.name).fetch_optional(db).await?; + if existing.is_some() { + return Err(SaasError::AlreadyExists(format!("Provider '{}' 已存在", req.name))); + } + + sqlx::query( + "INSERT INTO providers (id, name, display_name, api_key, base_url, api_protocol, enabled, rate_limit_rpm, rate_limit_tpm, created_at, updated_at) + VALUES (?1, ?2, ?3, ?4, ?5, ?6, 1, ?7, ?8, ?9, ?9)" + ) + .bind(&id).bind(&req.name).bind(&req.display_name).bind(&req.api_key) + .bind(&req.base_url).bind(&req.api_protocol).bind(&req.rate_limit_rpm).bind(&req.rate_limit_tpm).bind(&now) + .execute(db).await?; + + get_provider(db, &id).await +} + +pub async fn update_provider( + db: &SqlitePool, provider_id: &str, req: &UpdateProviderRequest, +) -> SaasResult { + let now = chrono::Utc::now().to_rfc3339(); + let mut updates = Vec::new(); + let mut params: Vec> = Vec::new(); + + if let Some(ref v) = req.display_name { updates.push("display_name = ?"); params.push(Box::new(v.clone())); } + if let Some(ref v) = req.base_url { updates.push("base_url = ?"); params.push(Box::new(v.clone())); } + if let Some(ref v) = req.api_protocol { updates.push("api_protocol = ?"); params.push(Box::new(v.clone())); } + if let Some(ref v) = req.api_key { updates.push("api_key = ?"); params.push(Box::new(v.clone())); } + if let Some(v) = req.enabled { updates.push("enabled = ?"); params.push(Box::new(v)); } + if let Some(v) = req.rate_limit_rpm { updates.push("rate_limit_rpm = ?"); params.push(Box::new(v)); } + if let Some(v) = req.rate_limit_tpm { updates.push("rate_limit_tpm = ?"); params.push(Box::new(v)); } + + if updates.is_empty() { + return get_provider(db, provider_id).await; + } + + updates.push("updated_at = ?"); + params.push(Box::new(now.clone())); + params.push(Box::new(provider_id.to_string())); + + let sql = format!("UPDATE providers SET {} WHERE id = ?", updates.join(", ")); + let mut query = sqlx::query(&sql); + for p in ¶ms { + query = query.bind(format!("{}", p)); + } + query.execute(db).await?; + + get_provider(db, provider_id).await +} + +pub async fn delete_provider(db: &SqlitePool, provider_id: &str) -> SaasResult<()> { + let result = sqlx::query("DELETE FROM providers WHERE id = ?1") + .bind(provider_id).execute(db).await?; + + if result.rows_affected() == 0 { + return Err(SaasError::NotFound(format!("Provider {} 不存在", provider_id))); + } + Ok(()) +} + +// ============ Models ============ + +pub async fn list_models(db: &SqlitePool, provider_id: Option<&str>) -> SaasResult> { + let sql = if provider_id.is_some() { + "SELECT id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, pricing_input, pricing_output, created_at, updated_at + FROM models WHERE provider_id = ?1 ORDER BY alias" + } else { + "SELECT id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, pricing_input, pricing_output, created_at, updated_at + FROM models ORDER BY provider_id, alias" + }; + + let mut query = sqlx::query_as::<_, (String, String, String, String, i64, i64, bool, bool, bool, f64, f64, String, String)>(sql); + if let Some(pid) = provider_id { + query = query.bind(pid); + } + + let rows = query.fetch_all(db).await?; + Ok(rows.into_iter().map(|(id, provider_id, model_id, alias, ctx, max_out, streaming, vision, enabled, pi, po, created_at, updated_at)| { + ModelInfo { id, provider_id, model_id, alias, context_window: ctx, max_output_tokens: max_out, supports_streaming: streaming, supports_vision: vision, enabled, pricing_input: pi, pricing_output: po, created_at, updated_at } + }).collect()) +} + +pub async fn create_model(db: &SqlitePool, req: &CreateModelRequest) -> SaasResult { + // 验证 provider 存在 + let provider = get_provider(db, &req.provider_id).await?; + + let id = uuid::Uuid::new_v4().to_string(); + let now = chrono::Utc::now().to_rfc3339(); + + // 检查 model 唯一性 + let existing: Option<(String,)> = sqlx::query_as( + "SELECT id FROM models WHERE provider_id = ?1 AND model_id = ?2" + ) + .bind(&req.provider_id).bind(&req.model_id) + .fetch_optional(db).await?; + + if existing.is_some() { + return Err(SaasError::AlreadyExists(format!( + "模型 '{}' 已存在于 provider '{}'", req.model_id, provider.name + ))); + } + + let ctx = req.context_window.unwrap_or(8192); + let max_out = req.max_output_tokens.unwrap_or(4096); + let streaming = req.supports_streaming.unwrap_or(true); + let vision = req.supports_vision.unwrap_or(false); + let pi = req.pricing_input.unwrap_or(0.0); + let po = req.pricing_output.unwrap_or(0.0); + + sqlx::query( + "INSERT INTO models (id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, pricing_input, pricing_output, created_at, updated_at) + VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, 1, ?9, ?10, ?11, ?11)" + ) + .bind(&id).bind(&req.provider_id).bind(&req.model_id).bind(&req.alias) + .bind(ctx).bind(max_out).bind(streaming).bind(vision).bind(pi).bind(po).bind(&now) + .execute(db).await?; + + get_model(db, &id).await +} + +pub async fn get_model(db: &SqlitePool, model_id: &str) -> SaasResult { + let row: Option<(String, String, String, String, i64, i64, bool, bool, bool, f64, f64, String, String)> = + sqlx::query_as( + "SELECT id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, pricing_input, pricing_output, created_at, updated_at + FROM models WHERE id = ?1" + ) + .bind(model_id) + .fetch_optional(db) + .await?; + + let (id, provider_id, model_id, alias, ctx, max_out, streaming, vision, enabled, pi, po, created_at, updated_at) = + row.ok_or_else(|| SaasError::NotFound(format!("模型 {} 不存在", model_id)))?; + + Ok(ModelInfo { id, provider_id, model_id, alias, context_window: ctx, max_output_tokens: max_out, supports_streaming: streaming, supports_vision: vision, enabled, pricing_input: pi, pricing_output: po, created_at, updated_at }) +} + +pub async fn update_model( + db: &SqlitePool, model_id: &str, req: &UpdateModelRequest, +) -> SaasResult { + let now = chrono::Utc::now().to_rfc3339(); + let mut updates = Vec::new(); + let mut params: Vec> = Vec::new(); + + if let Some(ref v) = req.alias { updates.push("alias = ?"); params.push(Box::new(v.clone())); } + if let Some(v) = req.context_window { updates.push("context_window = ?"); params.push(Box::new(v)); } + if let Some(v) = req.max_output_tokens { updates.push("max_output_tokens = ?"); params.push(Box::new(v)); } + if let Some(v) = req.supports_streaming { updates.push("supports_streaming = ?"); params.push(Box::new(v)); } + if let Some(v) = req.supports_vision { updates.push("supports_vision = ?"); params.push(Box::new(v)); } + if let Some(v) = req.enabled { updates.push("enabled = ?"); params.push(Box::new(v)); } + if let Some(v) = req.pricing_input { updates.push("pricing_input = ?"); params.push(Box::new(v)); } + if let Some(v) = req.pricing_output { updates.push("pricing_output = ?"); params.push(Box::new(v)); } + + if updates.is_empty() { + return get_model(db, model_id).await; + } + + updates.push("updated_at = ?"); + params.push(Box::new(now.clone())); + params.push(Box::new(model_id.to_string())); + + let sql = format!("UPDATE models SET {} WHERE id = ?", updates.join(", ")); + let mut query = sqlx::query(&sql); + for p in ¶ms { + query = query.bind(format!("{}", p)); + } + query.execute(db).await?; + + get_model(db, model_id).await +} + +pub async fn delete_model(db: &SqlitePool, model_id: &str) -> SaasResult<()> { + let result = sqlx::query("DELETE FROM models WHERE id = ?1") + .bind(model_id).execute(db).await?; + + if result.rows_affected() == 0 { + return Err(SaasError::NotFound(format!("模型 {} 不存在", model_id))); + } + Ok(()) +} + +// ============ Account API Keys ============ + +pub async fn list_account_api_keys( + db: &SqlitePool, account_id: &str, provider_id: Option<&str>, +) -> SaasResult> { + let sql = if provider_id.is_some() { + "SELECT id, provider_id, key_label, permissions, enabled, last_used_at, created_at, key_value + FROM account_api_keys WHERE account_id = ?1 AND provider_id = ?2 AND revoked_at IS NULL ORDER BY created_at DESC" + } else { + "SELECT id, provider_id, key_label, permissions, enabled, last_used_at, created_at, key_value + FROM account_api_keys WHERE account_id = ?1 AND revoked_at IS NULL ORDER BY created_at DESC" + }; + + let mut query = sqlx::query_as::<_, (String, String, Option, String, bool, Option, String, String)>(sql) + .bind(account_id); + if let Some(pid) = provider_id { + query = query.bind(pid); + } + + let rows = query.fetch_all(db).await?; + Ok(rows.into_iter().map(|(id, provider_id, key_label, perms, enabled, last_used, created_at, key_value)| { + let permissions: Vec = serde_json::from_str(&perms).unwrap_or_default(); + let masked = mask_api_key(&key_value); + AccountApiKeyInfo { id, provider_id, key_label, permissions, enabled, last_used_at: last_used, created_at, masked_key: masked } + }).collect()) +} + +pub async fn create_account_api_key( + db: &SqlitePool, account_id: &str, req: &CreateAccountApiKeyRequest, +) -> SaasResult { + // 验证 provider 存在 + get_provider(db, &req.provider_id).await?; + + let id = uuid::Uuid::new_v4().to_string(); + let now = chrono::Utc::now().to_rfc3339(); + let permissions = serde_json::to_string(&req.permissions)?; + + sqlx::query( + "INSERT INTO account_api_keys (id, account_id, provider_id, key_value, key_label, permissions, enabled, created_at, updated_at) + VALUES (?1, ?2, ?3, ?4, ?5, ?6, 1, ?7, ?7)" + ) + .bind(&id).bind(account_id).bind(&req.provider_id).bind(&req.key_value) + .bind(&req.key_label).bind(&permissions).bind(&now) + .execute(db).await?; + + let masked = mask_api_key(&req.key_value); + Ok(AccountApiKeyInfo { + id, provider_id: req.provider_id.clone(), key_label: req.key_label.clone(), + permissions: req.permissions.clone(), enabled: true, last_used_at: None, + created_at: now, masked_key: masked, + }) +} + +pub async fn rotate_account_api_key( + db: &SqlitePool, key_id: &str, account_id: &str, new_key_value: &str, +) -> SaasResult<()> { + let now = chrono::Utc::now().to_rfc3339(); + let result = sqlx::query( + "UPDATE account_api_keys SET key_value = ?1, updated_at = ?2 WHERE id = ?3 AND account_id = ?4 AND revoked_at IS NULL" + ) + .bind(new_key_value).bind(&now).bind(key_id).bind(account_id) + .execute(db).await?; + + if result.rows_affected() == 0 { + return Err(SaasError::NotFound("API Key 不存在或已撤销".into())); + } + Ok(()) +} + +pub async fn revoke_account_api_key( + db: &SqlitePool, key_id: &str, account_id: &str, +) -> SaasResult<()> { + let now = chrono::Utc::now().to_rfc3339(); + let result = sqlx::query( + "UPDATE account_api_keys SET revoked_at = ?1 WHERE id = ?2 AND account_id = ?3 AND revoked_at IS NULL" + ) + .bind(&now).bind(key_id).bind(account_id) + .execute(db).await?; + + if result.rows_affected() == 0 { + return Err(SaasError::NotFound("API Key 不存在或已撤销".into())); + } + Ok(()) +} + +// ============ Usage Statistics ============ + +pub async fn get_usage_stats( + db: &SqlitePool, account_id: &str, query: &UsageQuery, +) -> SaasResult { + let mut where_clauses = vec!["account_id = ?".to_string()]; + let mut params: Vec = vec![account_id.to_string()]; + + if let Some(ref from) = query.from { + where_clauses.push("created_at >= ?".to_string()); + params.push(from.clone()); + } + if let Some(ref to) = query.to { + where_clauses.push("created_at <= ?".to_string()); + params.push(to.clone()); + } + if let Some(ref pid) = query.provider_id { + where_clauses.push("provider_id = ?".to_string()); + params.push(pid.clone()); + } + if let Some(ref mid) = query.model_id { + where_clauses.push("model_id = ?".to_string()); + params.push(mid.clone()); + } + + let where_sql = where_clauses.join(" AND "); + + // 总量统计 + let total_sql = format!( + "SELECT COUNT(*), COALESCE(SUM(input_tokens), 0), COALESCE(SUM(output_tokens), 0) + FROM usage_records WHERE {}", where_sql + ); + let mut total_query = sqlx::query_as::<_, (i64, i64, i64)>(&total_sql); + for p in ¶ms { + total_query = total_query.bind(p); + } + let (total_requests, total_input, total_output) = total_query.fetch_one(db).await?; + + // 按模型统计 + let by_model_sql = format!( + "SELECT provider_id, model_id, COUNT(*), COALESCE(SUM(input_tokens), 0), COALESCE(SUM(output_tokens), 0) + FROM usage_records WHERE {} GROUP BY provider_id, model_id ORDER BY COUNT(*) DESC LIMIT 20", + where_sql + ); + let mut by_model_query = sqlx::query_as::<_, (String, String, i64, i64, i64)>(&by_model_sql); + for p in ¶ms { + by_model_query = by_model_query.bind(p); + } + let by_model_rows = by_model_query.fetch_all(db).await?; + let by_model: Vec = by_model_rows.into_iter() + .map(|(provider_id, model_id, count, input, output)| { + ModelUsage { provider_id, model_id, request_count: count, input_tokens: input, output_tokens: output } + }).collect(); + + // 按天统计 (最近 30 天) + let from_30d = (chrono::Utc::now() - chrono::Duration::days(30)).to_rfc3339(); + let daily_sql = format!( + "SELECT DATE(created_at) as day, COUNT(*), COALESCE(SUM(input_tokens), 0), COALESCE(SUM(output_tokens), 0) + FROM usage_records WHERE account_id = ?1 AND created_at >= ?2 + GROUP BY DATE(created_at) ORDER BY day DESC LIMIT 30" + ); + let daily_rows: Vec<(String, i64, i64, i64)> = sqlx::query_as(&daily_sql) + .bind(account_id).bind(&from_30d) + .fetch_all(db).await?; + let by_day: Vec = daily_rows.into_iter() + .map(|(date, count, input, output)| { + DailyUsage { date, request_count: count, input_tokens: input, output_tokens: output } + }).collect(); + + Ok(UsageStats { + total_requests, + total_input_tokens: total_input, + total_output_tokens: total_output, + by_model, + by_day, + }) +} + +pub async fn record_usage( + db: &SqlitePool, account_id: &str, provider_id: &str, model_id: &str, + input_tokens: i64, output_tokens: i64, latency_ms: Option, + status: &str, error_message: Option<&str>, +) -> SaasResult<()> { + let now = chrono::Utc::now().to_rfc3339(); + sqlx::query( + "INSERT INTO usage_records (account_id, provider_id, model_id, input_tokens, output_tokens, latency_ms, status, error_message, created_at) + VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)" + ) + .bind(account_id).bind(provider_id).bind(model_id) + .bind(input_tokens).bind(output_tokens).bind(latency_ms) + .bind(status).bind(error_message).bind(&now) + .execute(db).await?; + Ok(()) +} + +// ============ Helpers ============ + +fn mask_api_key(key: &str) -> String { + if key.len() <= 8 { + return "*".repeat(key.len()); + } + format!("{}...{}", &key[..4], &key[key.len()-4..]) +} diff --git a/crates/zclaw-saas/src/model_config/types.rs b/crates/zclaw-saas/src/model_config/types.rs new file mode 100644 index 0000000..c6e79cc --- /dev/null +++ b/crates/zclaw-saas/src/model_config/types.rs @@ -0,0 +1,172 @@ +//! 模型配置类型定义 + +use serde::{Deserialize, Serialize}; + +// --- Provider --- + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ProviderInfo { + pub id: String, + pub name: String, + pub display_name: String, + pub base_url: String, + pub api_protocol: String, + pub enabled: bool, + pub rate_limit_rpm: Option, + pub rate_limit_tpm: Option, + pub created_at: String, + pub updated_at: String, +} + +#[derive(Debug, Deserialize)] +pub struct CreateProviderRequest { + pub name: String, + pub display_name: String, + pub base_url: String, + #[serde(default = "default_protocol")] + pub api_protocol: String, + pub api_key: Option, + pub rate_limit_rpm: Option, + pub rate_limit_tpm: Option, +} + +fn default_protocol() -> String { "openai".into() } + +#[derive(Debug, Deserialize)] +pub struct UpdateProviderRequest { + pub display_name: Option, + pub base_url: Option, + pub api_protocol: Option, + pub api_key: Option, + pub enabled: Option, + pub rate_limit_rpm: Option, + pub rate_limit_tpm: Option, +} + +// --- Model --- + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ModelInfo { + pub id: String, + pub provider_id: String, + pub model_id: String, + pub alias: String, + pub context_window: i64, + pub max_output_tokens: i64, + pub supports_streaming: bool, + pub supports_vision: bool, + pub enabled: bool, + pub pricing_input: f64, + pub pricing_output: f64, + pub created_at: String, + pub updated_at: String, +} + +#[derive(Debug, Deserialize)] +pub struct CreateModelRequest { + pub provider_id: String, + pub model_id: String, + pub alias: String, + pub context_window: Option, + pub max_output_tokens: Option, + pub supports_streaming: Option, + pub supports_vision: Option, + pub pricing_input: Option, + pub pricing_output: Option, +} + +#[derive(Debug, Deserialize)] +pub struct UpdateModelRequest { + pub alias: Option, + pub context_window: Option, + pub max_output_tokens: Option, + pub supports_streaming: Option, + pub supports_vision: Option, + pub enabled: Option, + pub pricing_input: Option, + pub pricing_output: Option, +} + +// --- Account API Key --- + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AccountApiKeyInfo { + pub id: String, + pub provider_id: String, + pub key_label: Option, + pub permissions: Vec, + pub enabled: bool, + pub last_used_at: Option, + pub created_at: String, + pub masked_key: String, +} + +#[derive(Debug, Deserialize)] +pub struct CreateAccountApiKeyRequest { + pub provider_id: String, + pub key_value: String, + pub key_label: Option, + #[serde(default)] + pub permissions: Vec, +} + +#[derive(Debug, Deserialize)] +pub struct RotateApiKeyRequest { + pub new_key_value: String, +} + +// --- Usage --- + +#[derive(Debug, Serialize)] +pub struct UsageStats { + pub total_requests: i64, + pub total_input_tokens: i64, + pub total_output_tokens: i64, + pub by_model: Vec, + pub by_day: Vec, +} + +#[derive(Debug, Serialize)] +pub struct ModelUsage { + pub provider_id: String, + pub model_id: String, + pub request_count: i64, + pub input_tokens: i64, + pub output_tokens: i64, +} + +#[derive(Debug, Serialize)] +pub struct DailyUsage { + pub date: String, + pub request_count: i64, + pub input_tokens: i64, + pub output_tokens: i64, +} + +#[derive(Debug, Deserialize)] +pub struct UsageQuery { + pub from: Option, + pub to: Option, + pub provider_id: Option, + pub model_id: Option, +} + +// --- Seed Data --- + +#[derive(Debug, Deserialize)] +pub struct SeedProvider { + pub name: String, + pub display_name: String, + pub base_url: String, + pub models: Vec, +} + +#[derive(Debug, Deserialize)] +pub struct SeedModel { + pub id: String, + pub alias: String, + pub context_window: Option, + pub max_output_tokens: Option, + pub supports_streaming: Option, + pub supports_vision: Option, +} diff --git a/crates/zclaw-saas/tests/integration_test.rs b/crates/zclaw-saas/tests/integration_test.rs index 3f547a0..d47d016 100644 --- a/crates/zclaw-saas/tests/integration_test.rs +++ b/crates/zclaw-saas/tests/integration_test.rs @@ -1,4 +1,4 @@ -//! Phase 1 集成测试 +//! 集成测试 (Phase 1 + Phase 2) use axum::{ body::Body, @@ -21,6 +21,7 @@ async fn build_test_app() -> axum::Router { let protected_routes = zclaw_saas::auth::protected_routes() .merge(zclaw_saas::account::routes()) + .merge(zclaw_saas::model_config::routes()) .layer(axum::middleware::from_fn_with_state( state.clone(), zclaw_saas::auth::auth_middleware, @@ -32,43 +33,45 @@ async fn build_test_app() -> axum::Router { .with_state(state) } -#[tokio::test] -async fn test_register_and_login() { - let app = build_test_app().await; - - // 注册 - let req = Request::builder() +/// 注册并登录,返回 JWT token +async fn register_and_login(app: &axum::Router, username: &str, email: &str) -> String { + let register_req = Request::builder() .method("POST") .uri("/api/v1/auth/register") .header("Content-Type", "application/json") .body(Body::from(serde_json::to_string(&json!({ - "username": "testuser", - "email": "test@example.com", + "username": username, + "email": email, "password": "password123" })).unwrap())) .unwrap(); + app.clone().oneshot(register_req).await.unwrap(); - let resp = app.clone().oneshot(req).await.unwrap(); - assert_eq!(resp.status(), StatusCode::CREATED); - - // 登录 - let req = Request::builder() + let login_req = Request::builder() .method("POST") .uri("/api/v1/auth/login") .header("Content-Type", "application/json") .body(Body::from(serde_json::to_string(&json!({ - "username": "testuser", + "username": username, "password": "password123" })).unwrap())) .unwrap(); - let resp = app.clone().oneshot(req).await.unwrap(); - assert_eq!(resp.status(), StatusCode::OK); - + let resp = app.clone().oneshot(login_req).await.unwrap(); let body_bytes = axum::body::to_bytes(resp.into_body(), MAX_BODY_SIZE).await.unwrap(); let body: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap(); - assert!(body.get("token").is_some()); - assert_eq!(body["account"]["username"], "testuser"); + body["token"].as_str().unwrap().to_string() +} + +fn auth_header(token: &str) -> String { + format!("Bearer {}", token) +} + +#[tokio::test] +async fn test_register_and_login() { + let app = build_test_app().await; + let token = register_and_login(&app, "testuser", "test@example.com").await; + assert!(!token.is_empty()); } #[tokio::test] @@ -119,21 +122,8 @@ async fn test_unauthorized_access() { #[tokio::test] async fn test_login_wrong_password() { let app = build_test_app().await; + register_and_login(&app, "wrongpwd", "wrongpwd@example.com").await; - // 先注册 - let req = Request::builder() - .method("POST") - .uri("/api/v1/auth/register") - .header("Content-Type", "application/json") - .body(Body::from(serde_json::to_string(&json!({ - "username": "wrongpwd", - "email": "wrongpwd@example.com", - "password": "password123" - })).unwrap())) - .unwrap(); - app.clone().oneshot(req).await.unwrap(); - - // 错误密码登录 let req = Request::builder() .method("POST") .uri("/api/v1/auth/login") @@ -151,41 +141,14 @@ async fn test_login_wrong_password() { #[tokio::test] async fn test_full_authenticated_flow() { let app = build_test_app().await; - - // 注册 + 登录 - let register_req = Request::builder() - .method("POST") - .uri("/api/v1/auth/register") - .header("Content-Type", "application/json") - .body(Body::from(serde_json::to_string(&json!({ - "username": "fulltest", - "email": "full@example.com", - "password": "password123" - })).unwrap())) - .unwrap(); - app.clone().oneshot(register_req).await.unwrap(); - - let login_req = Request::builder() - .method("POST") - .uri("/api/v1/auth/login") - .header("Content-Type", "application/json") - .body(Body::from(serde_json::to_string(&json!({ - "username": "fulltest", - "password": "password123" - })).unwrap())) - .unwrap(); - - let resp = app.clone().oneshot(login_req).await.unwrap(); - let body_bytes = axum::body::to_bytes(resp.into_body(), MAX_BODY_SIZE).await.unwrap(); - let body: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap(); - let token = body["token"].as_str().unwrap().to_string(); + let token = register_and_login(&app, "fulltest", "full@example.com").await; // 创建 API Token let create_token_req = Request::builder() .method("POST") .uri("/api/v1/tokens") .header("Content-Type", "application/json") - .header("Authorization", format!("Bearer {}", token)) + .header("Authorization", auth_header(&token)) .body(Body::from(serde_json::to_string(&json!({ "name": "test-token", "permissions": ["model:read", "relay:use"] @@ -196,13 +159,13 @@ async fn test_full_authenticated_flow() { assert_eq!(resp.status(), StatusCode::OK); let body_bytes = axum::body::to_bytes(resp.into_body(), MAX_BODY_SIZE).await.unwrap(); let body: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap(); - assert!(!body["token"].is_null()); // 原始 token 仅创建时返回 + assert!(!body["token"].is_null()); // 列出 Tokens let list_req = Request::builder() .method("GET") .uri("/api/v1/tokens") - .header("Authorization", format!("Bearer {}", token)) + .header("Authorization", auth_header(&token)) .body(Body::empty()) .unwrap(); @@ -213,10 +176,115 @@ async fn test_full_authenticated_flow() { let logs_req = Request::builder() .method("GET") .uri("/api/v1/logs/operations") - .header("Authorization", format!("Bearer {}", token)) + .header("Authorization", auth_header(&token)) .body(Body::empty()) .unwrap(); let resp = app.oneshot(logs_req).await.unwrap(); assert_eq!(resp.status(), StatusCode::OK); } + +// ============ Phase 2: 模型配置测试 ============ + +#[tokio::test] +async fn test_providers_crud() { + let app = build_test_app().await; + // 注册 super_admin 角色用户 (通过直接插入角色权限) + let token = register_and_login(&app, "adminprov", "adminprov@example.com").await; + + // 创建 provider (普通用户无权限 → 403) + let create_req = Request::builder() + .method("POST") + .uri("/api/v1/providers") + .header("Content-Type", "application/json") + .header("Authorization", auth_header(&token)) + .body(Body::from(serde_json::to_string(&json!({ + "name": "test-provider", + "display_name": "Test Provider", + "base_url": "https://api.example.com/v1" + })).unwrap())) + .unwrap(); + + let resp = app.clone().oneshot(create_req).await.unwrap(); + // user 角色默认无 provider:manage 权限 → 403 + assert_eq!(resp.status(), StatusCode::FORBIDDEN); + + // 列出 providers (只读权限 → 200) + let list_req = Request::builder() + .method("GET") + .uri("/api/v1/providers") + .header("Authorization", auth_header(&token)) + .body(Body::empty()) + .unwrap(); + + let resp = app.clone().oneshot(list_req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); +} + +#[tokio::test] +async fn test_models_list_and_usage() { + let app = build_test_app().await; + let token = register_and_login(&app, "modeluser", "modeluser@example.com").await; + + // 列出模型 (空列表) + let list_req = Request::builder() + .method("GET") + .uri("/api/v1/models") + .header("Authorization", auth_header(&token)) + .body(Body::empty()) + .unwrap(); + + let resp = app.clone().oneshot(list_req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + let body_bytes = axum::body::to_bytes(resp.into_body(), MAX_BODY_SIZE).await.unwrap(); + let body: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap(); + assert!(body.is_array()); + assert_eq!(body.as_array().unwrap().len(), 0); + + // 查看用量统计 + let usage_req = Request::builder() + .method("GET") + .uri("/api/v1/usage") + .header("Authorization", auth_header(&token)) + .body(Body::empty()) + .unwrap(); + + let resp = app.clone().oneshot(usage_req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + let body_bytes = axum::body::to_bytes(resp.into_body(), MAX_BODY_SIZE).await.unwrap(); + let body: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap(); + assert_eq!(body["total_requests"], 0); +} + +#[tokio::test] +async fn test_api_keys_lifecycle() { + let app = build_test_app().await; + let token = register_and_login(&app, "keyuser", "keyuser@example.com").await; + + // 列出 keys (空) + let list_req = Request::builder() + .method("GET") + .uri("/api/v1/keys") + .header("Authorization", auth_header(&token)) + .body(Body::empty()) + .unwrap(); + let resp = app.clone().oneshot(list_req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + + // 创建 key (需要已有 provider → 404 或由 service 层验证) + let create_req = Request::builder() + .method("POST") + .uri("/api/v1/keys") + .header("Content-Type", "application/json") + .header("Authorization", auth_header(&token)) + .body(Body::from(serde_json::to_string(&json!({ + "provider_id": "nonexistent", + "key_value": "sk-test-12345", + "key_label": "Test Key" + })).unwrap())) + .unwrap(); + + let resp = app.clone().oneshot(create_req).await.unwrap(); + // provider 不存在 → 404 + assert_eq!(resp.status(), StatusCode::NOT_FOUND); +}