//! 模型配置业务逻辑 use sqlx::{PgPool, Row}; use crate::error::{SaasError, SaasResult}; use crate::common::{PaginatedResponse, normalize_pagination}; use crate::crypto; use crate::models::{ProviderRow, ModelRow, AccountApiKeyRow, UsageByModelRow, UsageByDayRow}; use super::types::*; // ============ Providers ============ pub async fn list_providers( db: &PgPool, page: Option, page_size: Option, enabled_filter: Option, ) -> SaasResult> { let (p, ps, offset) = normalize_pagination(page, page_size); let (count_sql, data_sql) = if enabled_filter.is_some() { ( "SELECT COUNT(*) FROM providers WHERE enabled = $1", "SELECT id, name, display_name, base_url, api_protocol, enabled, rate_limit_rpm, rate_limit_tpm, created_at, updated_at FROM providers WHERE enabled = $1 ORDER BY name LIMIT $2 OFFSET $3", ) } else { ( "SELECT COUNT(*) FROM providers", "SELECT id, name, display_name, base_url, api_protocol, enabled, rate_limit_rpm, rate_limit_tpm, created_at, updated_at FROM providers ORDER BY name LIMIT $1 OFFSET $2", ) }; let total: (i64,) = if let Some(en) = enabled_filter { sqlx::query_as(count_sql).bind(en).fetch_one(db).await? } else { sqlx::query_as(count_sql).fetch_one(db).await? }; let rows: Vec = if let Some(en) = enabled_filter { sqlx::query_as(data_sql) .bind(en).bind(ps as i64).bind(offset) .fetch_all(db).await? } else { sqlx::query_as(data_sql) .bind(ps as i64).bind(offset) .fetch_all(db).await? }; let items = rows.into_iter().map(|r| { ProviderInfo { id: r.id, name: r.name, display_name: r.display_name, base_url: r.base_url, api_protocol: r.api_protocol, enabled: r.enabled, rate_limit_rpm: r.rate_limit_rpm, rate_limit_tpm: r.rate_limit_tpm, created_at: r.created_at, updated_at: r.updated_at } }).collect(); Ok(PaginatedResponse { items, total: total.0, page: p, page_size: ps }) } pub async fn get_provider(db: &PgPool, provider_id: &str) -> SaasResult { let row: Option = sqlx::query_as( "SELECT id, name, display_name, base_url, api_protocol, enabled, rate_limit_rpm, rate_limit_tpm, created_at, updated_at FROM providers WHERE id = $1" ) .bind(provider_id) .fetch_optional(db) .await?; let r = row.ok_or_else(|| SaasError::NotFound(format!("Provider {} 不存在", provider_id)))?; Ok(ProviderInfo { id: r.id, name: r.name, display_name: r.display_name, base_url: r.base_url, api_protocol: r.api_protocol, enabled: r.enabled, rate_limit_rpm: r.rate_limit_rpm, rate_limit_tpm: r.rate_limit_tpm, created_at: r.created_at, updated_at: r.updated_at }) } pub async fn create_provider(db: &PgPool, req: &CreateProviderRequest, enc_key: &[u8; 32]) -> SaasResult { let id = uuid::Uuid::new_v4().to_string(); let now = chrono::Utc::now().to_rfc3339(); // 检查名称唯一性 let existing: Option<(String,)> = sqlx::query_as("SELECT id FROM providers WHERE name = $1") .bind(&req.name).fetch_optional(db).await?; if existing.is_some() { return Err(SaasError::AlreadyExists(format!("Provider '{}' 已存在", req.name))); } // 加密 API Key 后存储 let encrypted_api_key = if let Some(ref key) = req.api_key { if key.is_empty() { String::new() } else { crypto::encrypt_value(key, enc_key)? } } else { String::new() }; sqlx::query( "INSERT INTO providers (id, name, display_name, api_key, base_url, api_protocol, enabled, rate_limit_rpm, rate_limit_tpm, created_at, updated_at) VALUES ($1, $2, $3, $4, $5, $6, true, $7, $8, $9, $9)" ) .bind(&id).bind(&req.name).bind(&req.display_name).bind(&encrypted_api_key) .bind(&req.base_url).bind(&req.api_protocol).bind(&req.rate_limit_rpm).bind(&req.rate_limit_tpm).bind(&now) .execute(db).await?; get_provider(db, &id).await } pub async fn update_provider( db: &PgPool, provider_id: &str, req: &UpdateProviderRequest, enc_key: &[u8; 32], ) -> SaasResult { let now = chrono::Utc::now().to_rfc3339(); // Encrypt api_key upfront if provided let encrypted_api_key = match req.api_key { Some(ref v) if !v.is_empty() => Some(crypto::encrypt_value(v, enc_key)?), Some(ref v) if v.is_empty() => Some(String::new()), _ => None, }; // COALESCE pattern: all updatable fields in a single static SQL. // NULL parameters leave the column unchanged. sqlx::query( "UPDATE providers SET display_name = COALESCE($1, display_name), base_url = COALESCE($2, base_url), api_protocol = COALESCE($3, api_protocol), api_key = COALESCE($4, api_key), enabled = COALESCE($5, enabled), rate_limit_rpm = COALESCE($6, rate_limit_rpm), rate_limit_tpm = COALESCE($7, rate_limit_tpm), updated_at = $8 WHERE id = $9" ) .bind(req.display_name.as_deref()) .bind(req.base_url.as_deref()) .bind(req.api_protocol.as_deref()) .bind(encrypted_api_key.as_deref()) .bind(req.enabled) .bind(req.rate_limit_rpm) .bind(req.rate_limit_tpm) .bind(&now) .bind(provider_id) .execute(db).await?; get_provider(db, provider_id).await } pub async fn delete_provider(db: &PgPool, provider_id: &str) -> SaasResult<()> { let result = sqlx::query("DELETE FROM providers WHERE id = $1") .bind(provider_id).execute(db).await?; if result.rows_affected() == 0 { return Err(SaasError::NotFound(format!("Provider {} 不存在", provider_id))); } Ok(()) } // ============ Models ============ pub async fn list_models( db: &PgPool, provider_id: Option<&str>, page: Option, page_size: Option, ) -> SaasResult> { let (p, ps, offset) = normalize_pagination(page, page_size); let (count_sql, data_sql) = if provider_id.is_some() { ( "SELECT COUNT(*) FROM models WHERE provider_id = $1", "SELECT id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, pricing_input, pricing_output, created_at, updated_at FROM models WHERE provider_id = $1 ORDER BY alias LIMIT $2 OFFSET $3", ) } else { ( "SELECT COUNT(*) FROM models", "SELECT id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, pricing_input, pricing_output, created_at, updated_at FROM models ORDER BY provider_id, alias LIMIT $1 OFFSET $2", ) }; let total: (i64,) = if let Some(pid) = provider_id { sqlx::query_as(count_sql).bind(pid).fetch_one(db).await? } else { sqlx::query_as(count_sql).fetch_one(db).await? }; let mut query = sqlx::query_as::<_, ModelRow>(data_sql); if let Some(pid) = provider_id { query = query.bind(pid); } let rows = query.bind(ps as i64).bind(offset).fetch_all(db).await?; let items = rows.into_iter().map(|r| { ModelInfo { id: r.id, provider_id: r.provider_id, model_id: r.model_id, alias: r.alias, context_window: r.context_window, max_output_tokens: r.max_output_tokens, supports_streaming: r.supports_streaming, supports_vision: r.supports_vision, enabled: r.enabled, pricing_input: r.pricing_input, pricing_output: r.pricing_output, created_at: r.created_at, updated_at: r.updated_at } }).collect(); Ok(PaginatedResponse { items, total: total.0, page: p, page_size: ps }) } pub async fn create_model(db: &PgPool, req: &CreateModelRequest) -> SaasResult { // 验证 provider 存在 let provider = get_provider(db, &req.provider_id).await?; let id = uuid::Uuid::new_v4().to_string(); let now = chrono::Utc::now().to_rfc3339(); // 检查 model 唯一性 let existing: Option<(String,)> = sqlx::query_as( "SELECT id FROM models WHERE provider_id = $1 AND model_id = $2" ) .bind(&req.provider_id).bind(&req.model_id) .fetch_optional(db).await?; if existing.is_some() { return Err(SaasError::AlreadyExists(format!( "模型 '{}' 已存在于 provider '{}'", req.model_id, provider.name ))); } let ctx = req.context_window.unwrap_or(8192); let max_out = req.max_output_tokens.unwrap_or(4096); let streaming = req.supports_streaming.unwrap_or(true); let vision = req.supports_vision.unwrap_or(false); let pi = req.pricing_input.unwrap_or(0.0); let po = req.pricing_output.unwrap_or(0.0); sqlx::query( "INSERT INTO models (id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, pricing_input, pricing_output, created_at, updated_at) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, true, $9, $10, $11, $11)" ) .bind(&id).bind(&req.provider_id).bind(&req.model_id).bind(&req.alias) .bind(ctx).bind(max_out).bind(streaming).bind(vision).bind(pi).bind(po).bind(&now) .execute(db).await?; get_model(db, &id).await } pub async fn get_model(db: &PgPool, model_id: &str) -> SaasResult { let row: Option = sqlx::query_as( "SELECT id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, pricing_input, pricing_output, created_at, updated_at FROM models WHERE id = $1" ) .bind(model_id) .fetch_optional(db) .await?; let r = row.ok_or_else(|| SaasError::NotFound(format!("模型 {} 不存在", model_id)))?; Ok(ModelInfo { id: r.id, provider_id: r.provider_id, model_id: r.model_id, alias: r.alias, context_window: r.context_window, max_output_tokens: r.max_output_tokens, supports_streaming: r.supports_streaming, supports_vision: r.supports_vision, enabled: r.enabled, pricing_input: r.pricing_input, pricing_output: r.pricing_output, created_at: r.created_at, updated_at: r.updated_at }) } pub async fn update_model( db: &PgPool, model_id: &str, req: &UpdateModelRequest, ) -> SaasResult { let now = chrono::Utc::now().to_rfc3339(); // COALESCE pattern: all updatable fields in a single static SQL. // NULL parameters leave the column unchanged. sqlx::query( "UPDATE models SET alias = COALESCE($1, alias), context_window = COALESCE($2, context_window), max_output_tokens = COALESCE($3, max_output_tokens), supports_streaming = COALESCE($4, supports_streaming), supports_vision = COALESCE($5, supports_vision), enabled = COALESCE($6, enabled), pricing_input = COALESCE($7, pricing_input), pricing_output = COALESCE($8, pricing_output), updated_at = $9 WHERE id = $10" ) .bind(req.alias.as_deref()) .bind(req.context_window) .bind(req.max_output_tokens) .bind(req.supports_streaming) .bind(req.supports_vision) .bind(req.enabled) .bind(req.pricing_input) .bind(req.pricing_output) .bind(&now) .bind(model_id) .execute(db).await?; get_model(db, model_id).await } pub async fn delete_model(db: &PgPool, model_id: &str) -> SaasResult<()> { let result = sqlx::query("DELETE FROM models WHERE id = $1") .bind(model_id).execute(db).await?; if result.rows_affected() == 0 { return Err(SaasError::NotFound(format!("模型 {} 不存在", model_id))); } Ok(()) } // ============ Account API Keys ============ pub async fn list_account_api_keys( db: &PgPool, account_id: &str, provider_id: Option<&str>, page: Option, page_size: Option, ) -> SaasResult> { let (p, ps, offset) = normalize_pagination(page, page_size); // Build COUNT and data queries based on whether provider_id is provided let (count_sql, data_sql) = if provider_id.is_some() { ( "SELECT COUNT(*) FROM account_api_keys WHERE account_id = $1 AND provider_id = $2 AND revoked_at IS NULL", "SELECT id, provider_id, key_label, permissions, enabled, last_used_at, created_at, key_value FROM account_api_keys WHERE account_id = $1 AND provider_id = $2 AND revoked_at IS NULL ORDER BY created_at DESC LIMIT $3 OFFSET $4", ) } else { ( "SELECT COUNT(*) FROM account_api_keys WHERE account_id = $1 AND revoked_at IS NULL", "SELECT id, provider_id, key_label, permissions, enabled, last_used_at, created_at, key_value FROM account_api_keys WHERE account_id = $1 AND revoked_at IS NULL ORDER BY created_at DESC LIMIT $2 OFFSET $3", ) }; let total: (i64,) = if provider_id.is_some() { let mut q = sqlx::query_as(count_sql).bind(account_id); if let Some(pid) = provider_id { q = q.bind(pid); } q.fetch_one(db).await? } else { sqlx::query_as(count_sql).bind(account_id).fetch_one(db).await? }; let mut query = sqlx::query_as::<_, AccountApiKeyRow>(data_sql) .bind(account_id); if let Some(pid) = provider_id { query = query.bind(pid); } let rows = query.bind(ps as i64).bind(offset).fetch_all(db).await?; let items = rows.into_iter().map(|r| { let permissions: Vec = serde_json::from_str(&r.permissions).unwrap_or_default(); let masked = mask_api_key(&r.key_value); AccountApiKeyInfo { id: r.id, provider_id: r.provider_id, key_label: r.key_label, permissions, enabled: r.enabled, last_used_at: r.last_used_at, created_at: r.created_at, masked_key: masked } }).collect(); Ok(PaginatedResponse { items, total: total.0, page: p, page_size: ps }) } pub async fn create_account_api_key( db: &PgPool, account_id: &str, req: &CreateAccountApiKeyRequest, enc_key: &[u8; 32], ) -> SaasResult { // 验证 provider 存在 get_provider(db, &req.provider_id).await?; let id = uuid::Uuid::new_v4().to_string(); let now = chrono::Utc::now().to_rfc3339(); let permissions = serde_json::to_string(&req.permissions)?; // 加密 key_value 后存储 let encrypted_key_value = crypto::encrypt_value(&req.key_value, enc_key)?; sqlx::query( "INSERT INTO account_api_keys (id, account_id, provider_id, key_value, key_label, permissions, enabled, created_at, updated_at) VALUES ($1, $2, $3, $4, $5, $6, true, $7, $7)" ) .bind(&id).bind(account_id).bind(&req.provider_id).bind(&encrypted_key_value) .bind(&req.key_label).bind(&permissions).bind(&now) .execute(db).await?; let masked = mask_api_key(&req.key_value); Ok(AccountApiKeyInfo { id, provider_id: req.provider_id.clone(), key_label: req.key_label.clone(), permissions: req.permissions.clone(), enabled: true, last_used_at: None, created_at: now, masked_key: masked, }) } pub async fn rotate_account_api_key( db: &PgPool, key_id: &str, account_id: &str, new_key_value: &str, enc_key: &[u8; 32], ) -> SaasResult<()> { let now = chrono::Utc::now().to_rfc3339(); let encrypted_value = crypto::encrypt_value(new_key_value, enc_key)?; let result = sqlx::query( "UPDATE account_api_keys SET key_value = $1, updated_at = $2 WHERE id = $3 AND account_id = $4 AND revoked_at IS NULL" ) .bind(&encrypted_value).bind(&now).bind(key_id).bind(account_id) .execute(db).await?; if result.rows_affected() == 0 { return Err(SaasError::NotFound("API Key 不存在或已撤销".into())); } Ok(()) } pub async fn revoke_account_api_key( db: &PgPool, key_id: &str, account_id: &str, ) -> SaasResult<()> { let now = chrono::Utc::now().to_rfc3339(); let result = sqlx::query( "UPDATE account_api_keys SET revoked_at = $1 WHERE id = $2 AND account_id = $3 AND revoked_at IS NULL" ) .bind(&now).bind(key_id).bind(account_id) .execute(db).await?; if result.rows_affected() == 0 { return Err(SaasError::NotFound("API Key 不存在或已撤销".into())); } Ok(()) } // ============ Usage Statistics ============ pub async fn get_usage_stats( db: &PgPool, account_id: &str, query: &UsageQuery, ) -> SaasResult { // Static SQL with conditional filter pattern: // account_id is always required; optional filters use ($N IS NULL OR col = $N). let total_sql = "SELECT COUNT(*)::bigint, COALESCE(SUM(input_tokens), 0), COALESCE(SUM(output_tokens), 0) FROM usage_records WHERE account_id = $1 AND ($2 IS NULL OR created_at >= $2) AND ($3 IS NULL OR created_at <= $3) AND ($4 IS NULL OR provider_id = $4) AND ($5 IS NULL OR model_id = $5)"; let row = sqlx::query(total_sql) .bind(account_id) .bind(&query.from) .bind(&query.to) .bind(&query.provider_id) .bind(&query.model_id) .fetch_one(db).await?; let total_requests: i64 = row.try_get(0).unwrap_or(0); let total_input: i64 = row.try_get(1).unwrap_or(0); let total_output: i64 = row.try_get(2).unwrap_or(0); // 按模型统计 let by_model_sql = "SELECT provider_id, model_id, COUNT(*)::bigint AS request_count, COALESCE(SUM(input_tokens), 0) AS input_tokens, COALESCE(SUM(output_tokens), 0) AS output_tokens FROM usage_records WHERE account_id = $1 AND ($2 IS NULL OR created_at >= $2) AND ($3 IS NULL OR created_at <= $3) AND ($4 IS NULL OR provider_id = $4) AND ($5 IS NULL OR model_id = $5) GROUP BY provider_id, model_id ORDER BY COUNT(*) DESC LIMIT 20"; let by_model_rows: Vec = sqlx::query_as(by_model_sql) .bind(account_id) .bind(&query.from) .bind(&query.to) .bind(&query.provider_id) .bind(&query.model_id) .fetch_all(db).await?; let by_model: Vec = by_model_rows.into_iter() .map(|r| { ModelUsage { provider_id: r.provider_id, model_id: r.model_id, request_count: r.request_count, input_tokens: r.input_tokens, output_tokens: r.output_tokens } }).collect(); // 按天统计 (使用 days 参数或默认 30 天) let days = query.days.unwrap_or(30).min(365).max(1) as i64; let from_days = (chrono::Utc::now() - chrono::Duration::days(days)) .date_naive() .and_hms_opt(0, 0, 0).unwrap() .and_utc() .to_rfc3339(); let daily_sql = "SELECT created_at::date::text as day, COUNT(*)::bigint AS request_count, COALESCE(SUM(input_tokens), 0) AS input_tokens, COALESCE(SUM(output_tokens), 0) AS output_tokens FROM usage_records WHERE account_id = $1 AND created_at >= $2 GROUP BY created_at::date ORDER BY day DESC LIMIT $3"; let daily_rows: Vec = sqlx::query_as(daily_sql) .bind(account_id).bind(&from_days).bind(days as i32) .fetch_all(db).await?; let by_day: Vec = daily_rows.into_iter() .map(|r| { DailyUsage { date: r.day, request_count: r.request_count, input_tokens: r.input_tokens, output_tokens: r.output_tokens } }).collect(); // 按 group_by 过滤返回 let group_by = query.group_by.as_deref(); let by_model = if group_by == Some("model") || group_by.is_none() { by_model } else { vec![] }; let by_day = if group_by == Some("day") || group_by.is_none() { by_day } else { vec![] }; Ok(UsageStats { total_requests, total_input_tokens: total_input, total_output_tokens: total_output, by_model, by_day, }) } pub async fn record_usage( db: &PgPool, account_id: &str, provider_id: &str, model_id: &str, input_tokens: i64, output_tokens: i64, latency_ms: Option, status: &str, error_message: Option<&str>, ) -> SaasResult<()> { let now = chrono::Utc::now().to_rfc3339(); sqlx::query( "INSERT INTO usage_records (account_id, provider_id, model_id, input_tokens, output_tokens, latency_ms, status, error_message, created_at) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)" ) .bind(account_id).bind(provider_id).bind(model_id) .bind(input_tokens).bind(output_tokens).bind(latency_ms) .bind(status).bind(error_message).bind(&now) .execute(db).await?; Ok(()) } // ============ Helpers ============ fn mask_api_key(key: &str) -> String { if key.len() <= 8 { return "*".repeat(key.len()); } format!("{}...{}", &key[..4], &key[key.len()-4..]) }