//! 模型配置业务逻辑 use sqlx::{PgPool, Row}; use crate::error::{SaasError, SaasResult}; use crate::common::{PaginatedResponse, normalize_pagination}; use crate::crypto; use crate::models::{ProviderRow, ModelRow, AccountApiKeyRow, UsageByModelRow, UsageByDayRow}; use super::types::*; // ============ Providers ============ pub async fn list_providers( db: &PgPool, page: Option, page_size: Option, enabled_filter: Option, ) -> SaasResult> { let (p, ps, offset) = normalize_pagination(page, page_size); let (count_sql, data_sql) = if enabled_filter.is_some() { ( "SELECT COUNT(*) FROM providers WHERE enabled = $1", "SELECT id, name, display_name, base_url, api_protocol, enabled, rate_limit_rpm, rate_limit_tpm, created_at::TEXT, updated_at::TEXT FROM providers WHERE enabled = $1 ORDER BY name LIMIT $2 OFFSET $3", ) } else { ( "SELECT COUNT(*) FROM providers", "SELECT id, name, display_name, base_url, api_protocol, enabled, rate_limit_rpm, rate_limit_tpm, created_at::TEXT, updated_at::TEXT FROM providers ORDER BY name LIMIT $1 OFFSET $2", ) }; let total: (i64,) = if let Some(en) = enabled_filter { sqlx::query_as(count_sql).bind(en).fetch_one(db).await? } else { sqlx::query_as(count_sql).fetch_one(db).await? }; let rows: Vec = if let Some(en) = enabled_filter { sqlx::query_as(data_sql) .bind(en).bind(ps as i64).bind(offset) .fetch_all(db).await? } else { sqlx::query_as(data_sql) .bind(ps as i64).bind(offset) .fetch_all(db).await? }; let items = rows.into_iter().map(|r| { ProviderInfo { id: r.id, name: r.name, display_name: r.display_name, base_url: r.base_url, api_protocol: r.api_protocol, enabled: r.enabled, rate_limit_rpm: r.rate_limit_rpm, rate_limit_tpm: r.rate_limit_tpm, created_at: r.created_at, updated_at: r.updated_at } }).collect(); Ok(PaginatedResponse { items, total: total.0, page: p, page_size: ps }) } pub async fn get_provider(db: &PgPool, provider_id: &str) -> SaasResult { let row: Option = sqlx::query_as( "SELECT id, name, display_name, base_url, api_protocol, enabled, rate_limit_rpm, rate_limit_tpm, created_at::TEXT, updated_at::TEXT FROM providers WHERE id = $1" ) .bind(provider_id) .fetch_optional(db) .await?; let r = row.ok_or_else(|| SaasError::NotFound(format!("Provider {} 不存在", provider_id)))?; Ok(ProviderInfo { id: r.id, name: r.name, display_name: r.display_name, base_url: r.base_url, api_protocol: r.api_protocol, enabled: r.enabled, rate_limit_rpm: r.rate_limit_rpm, rate_limit_tpm: r.rate_limit_tpm, created_at: r.created_at, updated_at: r.updated_at }) } pub async fn create_provider(db: &PgPool, req: &CreateProviderRequest, enc_key: &[u8; 32]) -> SaasResult { let id = uuid::Uuid::new_v4().to_string(); let now = chrono::Utc::now(); // 检查名称唯一性 let existing: Option<(String,)> = sqlx::query_as("SELECT id FROM providers WHERE name = $1") .bind(&req.name).fetch_optional(db).await?; if existing.is_some() { return Err(SaasError::AlreadyExists(format!("Provider '{}' 已存在", req.name))); } // 加密 API Key 后存储 let encrypted_api_key = if let Some(ref key) = req.api_key { if key.is_empty() { String::new() } else { crypto::encrypt_value(key, enc_key)? } } else { String::new() }; let display_name = req.display_name.as_deref().unwrap_or(&req.name); sqlx::query( "INSERT INTO providers (id, name, display_name, api_key, base_url, api_protocol, enabled, rate_limit_rpm, rate_limit_tpm, created_at, updated_at) VALUES ($1, $2, $3, $4, $5, $6, true, $7, $8, $9, $9)" ) .bind(&id).bind(&req.name).bind(display_name).bind(&encrypted_api_key) .bind(&req.base_url).bind(&req.api_protocol).bind(&req.rate_limit_rpm).bind(&req.rate_limit_tpm).bind(&now) .execute(db).await.map_err(|e| SaasError::from_sqlx_unique(e, &format!("Provider '{}'", req.name)))?; get_provider(db, &id).await } pub async fn update_provider( db: &PgPool, provider_id: &str, req: &UpdateProviderRequest, enc_key: &[u8; 32], ) -> SaasResult { let now = chrono::Utc::now(); // Encrypt api_key upfront if provided let encrypted_api_key = match req.api_key { Some(ref v) if !v.is_empty() => Some(crypto::encrypt_value(v, enc_key)?), Some(ref v) if v.is_empty() => Some(String::new()), _ => None, }; // COALESCE pattern: all updatable fields in a single static SQL. // NULL parameters leave the column unchanged. sqlx::query( "UPDATE providers SET display_name = COALESCE($1, display_name), base_url = COALESCE($2, base_url), api_protocol = COALESCE($3, api_protocol), api_key = COALESCE($4, api_key), enabled = COALESCE($5, enabled), rate_limit_rpm = COALESCE($6, rate_limit_rpm), rate_limit_tpm = COALESCE($7, rate_limit_tpm), updated_at = $8 WHERE id = $9" ) .bind(req.display_name.as_deref()) .bind(req.base_url.as_deref()) .bind(req.api_protocol.as_deref()) .bind(encrypted_api_key.as_deref()) .bind(req.enabled) .bind(req.rate_limit_rpm) .bind(req.rate_limit_tpm) .bind(&now) .bind(provider_id) .execute(db).await?; get_provider(db, provider_id).await } pub async fn delete_provider(db: &PgPool, provider_id: &str) -> SaasResult<()> { let result = sqlx::query("DELETE FROM providers WHERE id = $1") .bind(provider_id).execute(db).await?; if result.rows_affected() == 0 { return Err(SaasError::NotFound(format!("Provider {} 不存在", provider_id))); } Ok(()) } // ============ Models ============ pub async fn list_models( db: &PgPool, provider_id: Option<&str>, page: Option, page_size: Option, ) -> SaasResult> { let (p, ps, offset) = normalize_pagination(page, page_size); let (count_sql, data_sql) = if provider_id.is_some() { ( "SELECT COUNT(*) FROM models WHERE provider_id = $1", "SELECT id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, pricing_input, pricing_output, created_at::TEXT, updated_at::TEXT FROM models WHERE provider_id = $1 ORDER BY alias LIMIT $2 OFFSET $3", ) } else { ( "SELECT COUNT(*) FROM models", "SELECT id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, pricing_input, pricing_output, created_at::TEXT, updated_at::TEXT FROM models ORDER BY provider_id, alias LIMIT $1 OFFSET $2", ) }; let total: (i64,) = if let Some(pid) = provider_id { sqlx::query_as(count_sql).bind(pid).fetch_one(db).await? } else { sqlx::query_as(count_sql).fetch_one(db).await? }; let mut query = sqlx::query_as::<_, ModelRow>(data_sql); if let Some(pid) = provider_id { query = query.bind(pid); } let rows = query.bind(ps as i64).bind(offset).fetch_all(db).await?; let items = rows.into_iter().map(|r| { ModelInfo { id: r.id, provider_id: r.provider_id, model_id: r.model_id, alias: r.alias, context_window: r.context_window, max_output_tokens: r.max_output_tokens, supports_streaming: r.supports_streaming, supports_vision: r.supports_vision, enabled: r.enabled, pricing_input: r.pricing_input, pricing_output: r.pricing_output, created_at: r.created_at, updated_at: r.updated_at } }).collect(); Ok(PaginatedResponse { items, total: total.0, page: p, page_size: ps }) } pub async fn create_model(db: &PgPool, req: &CreateModelRequest) -> SaasResult { // 验证 provider 存在 let provider = get_provider(db, &req.provider_id).await?; let id = uuid::Uuid::new_v4().to_string(); let now = chrono::Utc::now(); // 检查 model 唯一性 let existing: Option<(String,)> = sqlx::query_as( "SELECT id FROM models WHERE provider_id = $1 AND model_id = $2" ) .bind(&req.provider_id).bind(&req.model_id) .fetch_optional(db).await?; if existing.is_some() { return Err(SaasError::AlreadyExists(format!( "模型 '{}' 已存在于 provider '{}'", req.model_id, provider.name ))); } // M-2: 检查 model_id 不与模型组名冲突(避免路由歧义) let group_conflict: Option<(String,)> = sqlx::query_as("SELECT id FROM model_groups WHERE name = $1") .bind(&req.model_id).fetch_optional(db).await?; if group_conflict.is_some() { return Err(SaasError::InvalidInput( format!("模型 ID '{}' 与已有模型组名称冲突,请使用不同的 ID", req.model_id) )); } let ctx = req.context_window.unwrap_or(8192); let max_out = req.max_output_tokens.unwrap_or(4096); let streaming = req.supports_streaming.unwrap_or(true); let vision = req.supports_vision.unwrap_or(false); let pi = req.pricing_input.unwrap_or(0.0); let po = req.pricing_output.unwrap_or(0.0); sqlx::query( "INSERT INTO models (id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, pricing_input, pricing_output, created_at, updated_at) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, true, $9, $10, $11, $11)" ) .bind(&id).bind(&req.provider_id).bind(&req.model_id).bind(req.alias.as_deref().unwrap_or(&req.model_id)) .bind(ctx).bind(max_out).bind(streaming).bind(vision).bind(pi).bind(po).bind(&now) .execute(db).await.map_err(|e| SaasError::from_sqlx_unique(e, &format!("模型 '{}' 在 Provider '{}'", req.model_id, req.provider_id)))?; get_model(db, &id).await } pub async fn get_model(db: &PgPool, model_id: &str) -> SaasResult { let row: Option = sqlx::query_as( "SELECT id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, pricing_input, pricing_output, created_at::TEXT, updated_at::TEXT FROM models WHERE id = $1" ) .bind(model_id) .fetch_optional(db) .await?; let r = row.ok_or_else(|| SaasError::NotFound(format!("模型 {} 不存在", model_id)))?; Ok(ModelInfo { id: r.id, provider_id: r.provider_id, model_id: r.model_id, alias: r.alias, context_window: r.context_window, max_output_tokens: r.max_output_tokens, supports_streaming: r.supports_streaming, supports_vision: r.supports_vision, enabled: r.enabled, pricing_input: r.pricing_input, pricing_output: r.pricing_output, created_at: r.created_at, updated_at: r.updated_at }) } pub async fn update_model( db: &PgPool, model_id: &str, req: &UpdateModelRequest, ) -> SaasResult { let now = chrono::Utc::now(); // COALESCE pattern: all updatable fields in a single static SQL. // NULL parameters leave the column unchanged. sqlx::query( "UPDATE models SET alias = COALESCE($1, alias), context_window = COALESCE($2, context_window), max_output_tokens = COALESCE($3, max_output_tokens), supports_streaming = COALESCE($4, supports_streaming), supports_vision = COALESCE($5, supports_vision), enabled = COALESCE($6, enabled), pricing_input = COALESCE($7, pricing_input), pricing_output = COALESCE($8, pricing_output), updated_at = $9 WHERE id = $10" ) .bind(req.alias.as_deref()) .bind(req.context_window) .bind(req.max_output_tokens) .bind(req.supports_streaming) .bind(req.supports_vision) .bind(req.enabled) .bind(req.pricing_input) .bind(req.pricing_output) .bind(&now) .bind(model_id) .execute(db).await?; get_model(db, model_id).await } pub async fn delete_model(db: &PgPool, model_id: &str) -> SaasResult<()> { let result = sqlx::query("DELETE FROM models WHERE id = $1") .bind(model_id).execute(db).await?; if result.rows_affected() == 0 { return Err(SaasError::NotFound(format!("模型 {} 不存在", model_id))); } Ok(()) } // ============ Account API Keys ============ pub async fn list_account_api_keys( db: &PgPool, account_id: &str, provider_id: Option<&str>, page: Option, page_size: Option, ) -> SaasResult> { let (p, ps, offset) = normalize_pagination(page, page_size); // Build COUNT and data queries based on whether provider_id is provided let (count_sql, data_sql) = if provider_id.is_some() { ( "SELECT COUNT(*) FROM account_api_keys WHERE account_id = $1 AND provider_id = $2 AND revoked_at IS NULL", "SELECT id, provider_id, key_label, permissions, enabled, last_used_at::TEXT, created_at::TEXT, key_value FROM account_api_keys WHERE account_id = $1 AND provider_id = $2 AND revoked_at IS NULL ORDER BY created_at DESC LIMIT $3 OFFSET $4", ) } else { ( "SELECT COUNT(*) FROM account_api_keys WHERE account_id = $1 AND revoked_at IS NULL", "SELECT id, provider_id, key_label, permissions, enabled, last_used_at::TEXT, created_at::TEXT, key_value FROM account_api_keys WHERE account_id = $1 AND revoked_at IS NULL ORDER BY created_at DESC LIMIT $2 OFFSET $3", ) }; let total: (i64,) = if provider_id.is_some() { let mut q = sqlx::query_as(count_sql).bind(account_id); if let Some(pid) = provider_id { q = q.bind(pid); } q.fetch_one(db).await? } else { sqlx::query_as(count_sql).bind(account_id).fetch_one(db).await? }; let mut query = sqlx::query_as::<_, AccountApiKeyRow>(data_sql) .bind(account_id); if let Some(pid) = provider_id { query = query.bind(pid); } let rows = query.bind(ps as i64).bind(offset).fetch_all(db).await?; let items = rows.into_iter().map(|r| { let permissions: Vec = serde_json::from_str(&r.permissions).unwrap_or_default(); let masked = mask_api_key(&r.key_value); AccountApiKeyInfo { id: r.id, provider_id: r.provider_id, key_label: r.key_label, permissions, enabled: r.enabled, last_used_at: r.last_used_at, created_at: r.created_at, masked_key: masked } }).collect(); Ok(PaginatedResponse { items, total: total.0, page: p, page_size: ps }) } pub async fn create_account_api_key( db: &PgPool, account_id: &str, req: &CreateAccountApiKeyRequest, enc_key: &[u8; 32], ) -> SaasResult { // 验证 provider 存在 get_provider(db, &req.provider_id).await?; let id = uuid::Uuid::new_v4().to_string(); let now = chrono::Utc::now(); let permissions = serde_json::to_string(&req.permissions)?; // 加密 key_value 后存储 let encrypted_key_value = crypto::encrypt_value(&req.key_value, enc_key)?; sqlx::query( "INSERT INTO account_api_keys (id, account_id, provider_id, key_value, key_label, permissions, enabled, created_at, updated_at) VALUES ($1, $2, $3, $4, $5, $6, true, $7, $7)" ) .bind(&id).bind(account_id).bind(&req.provider_id).bind(&encrypted_key_value) .bind(&req.key_label).bind(&permissions).bind(&now) .execute(db).await?; let masked = mask_api_key(&req.key_value); Ok(AccountApiKeyInfo { id, provider_id: req.provider_id.clone(), key_label: req.key_label.clone(), permissions: req.permissions.clone(), enabled: true, last_used_at: None, created_at: now.to_rfc3339(), masked_key: masked, }) } pub async fn rotate_account_api_key( db: &PgPool, key_id: &str, account_id: &str, new_key_value: &str, enc_key: &[u8; 32], ) -> SaasResult<()> { let now = chrono::Utc::now(); let encrypted_value = crypto::encrypt_value(new_key_value, enc_key)?; let result = sqlx::query( "UPDATE account_api_keys SET key_value = $1, updated_at = $2 WHERE id = $3 AND account_id = $4 AND revoked_at IS NULL" ) .bind(&encrypted_value).bind(&now).bind(key_id).bind(account_id) .execute(db).await?; if result.rows_affected() == 0 { return Err(SaasError::NotFound("API Key 不存在或已撤销".into())); } Ok(()) } pub async fn revoke_account_api_key( db: &PgPool, key_id: &str, account_id: &str, ) -> SaasResult<()> { let now = chrono::Utc::now(); let result = sqlx::query( "UPDATE account_api_keys SET revoked_at = $1 WHERE id = $2 AND account_id = $3 AND revoked_at IS NULL" ) .bind(&now).bind(key_id).bind(account_id) .execute(db).await?; if result.rows_affected() == 0 { return Err(SaasError::NotFound("API Key 不存在或已撤销".into())); } Ok(()) } // ============ Usage Statistics ============ pub async fn get_usage_stats( db: &PgPool, account_id: &str, query: &UsageQuery, ) -> SaasResult { // Optional date filters: pass as TEXT with explicit $N::timestamptz SQL cast. // This avoids the sqlx NULL-without-type-OID problem — PG's ::timestamptz // gives a typed NULL even when sqlx sends an untyped NULL. let from_str: Option<&str> = query.from.as_deref(); // For 'to' date-only strings, append T23:59:59 to include the entire day let to_str: Option = query.to.as_ref().map(|s| { if s.len() == 10 { format!("{}T23:59:59", s) } else { s.clone() } }); // Build SQL dynamically to avoid sqlx NULL-without-type-OID problem entirely. // Date parameters are injected as SQL literals (validated above via chrono parse). // Only account_id uses parameterized binding to prevent SQL injection on user input. let mut where_parts = vec![format!("account_id = '{}'", account_id.replace('\'', "''"))]; if let Some(f) = from_str { // Validate: must be parseable as a date let valid = chrono::NaiveDate::parse_from_str(f, "%Y-%m-%d").is_ok() || chrono::NaiveDateTime::parse_from_str(f, "%Y-%m-%dT%H:%M:%S%.f").is_ok(); if !valid { return Err(SaasError::InvalidInput(format!("Invalid 'from' date: {}", f))); } where_parts.push(format!("created_at::timestamptz >= '{}T00:00:00Z'::timestamptz", f.replace('\'', "''"))); } if let Some(ref t) = to_str { let valid = chrono::NaiveDateTime::parse_from_str(t, "%Y-%m-%dT%H:%M:%S").is_ok() || chrono::NaiveDate::parse_from_str(t, "%Y-%m-%d").is_ok(); if !valid { return Err(SaasError::InvalidInput(format!("Invalid 'to' date: {}", t))); } where_parts.push(format!("created_at::timestamptz <= '{}'::timestamptz", t.replace('\'', "''"))); } if let Some(ref pid) = query.provider_id { where_parts.push(format!("provider_id = '{}'", pid.replace('\'', "''"))); } if let Some(ref mid) = query.model_id { where_parts.push(format!("model_id = '{}'", mid.replace('\'', "''"))); } let where_clause = where_parts.join(" AND "); let total_sql = format!( "SELECT COUNT(*)::bigint, COALESCE(SUM(input_tokens), 0)::bigint, COALESCE(SUM(output_tokens), 0)::bigint FROM usage_records WHERE {}", where_clause ); let row = sqlx::query(&total_sql).fetch_one(db).await?; let total_requests: i64 = row.try_get(0).unwrap_or(0); let total_input: i64 = row.try_get(1).unwrap_or(0); let total_output: i64 = row.try_get(2).unwrap_or(0); // 按模型统计 let by_model_sql = format!( "SELECT provider_id, model_id, COUNT(*)::bigint AS request_count, COALESCE(SUM(input_tokens), 0)::bigint AS input_tokens, COALESCE(SUM(output_tokens), 0)::bigint AS output_tokens FROM usage_records WHERE {} GROUP BY provider_id, model_id ORDER BY COUNT(*) DESC LIMIT 20", where_clause ); let by_model_rows: Vec = sqlx::query_as(&by_model_sql).fetch_all(db).await?; let by_model: Vec = by_model_rows.into_iter() .map(|r| { ModelUsage { provider_id: r.provider_id, model_id: r.model_id, request_count: r.request_count, input_tokens: r.input_tokens, output_tokens: r.output_tokens } }).collect(); // 按天统计 (使用 days 参数或默认 30 天) let days = query.days.unwrap_or(30).min(365).max(1) as i64; let from_days_str = (chrono::Utc::now() - chrono::Duration::days(days)) .format("%Y-%m-%d").to_string(); let daily_sql = format!( "SELECT created_at::date::text as day, COUNT(*)::bigint AS request_count, COALESCE(SUM(input_tokens), 0)::bigint AS input_tokens, COALESCE(SUM(output_tokens), 0)::bigint AS output_tokens FROM usage_records WHERE account_id = '{}' AND created_at::timestamptz >= '{}T00:00:00Z'::timestamptz GROUP BY created_at::date ORDER BY day DESC LIMIT {}", account_id.replace('\'', "''"), from_days_str.replace('\'', "''"), days ); let daily_rows: Vec = sqlx::query_as(&daily_sql).fetch_all(db).await?; let by_day: Vec = daily_rows.into_iter() .map(|r| { DailyUsage { date: r.day, request_count: r.request_count, input_tokens: r.input_tokens, output_tokens: r.output_tokens } }).collect(); // 按 group_by 过滤返回 let group_by = query.group_by.as_deref(); let by_model = if group_by == Some("model") || group_by.is_none() { by_model } else { vec![] }; let by_day = if group_by == Some("day") || group_by.is_none() { by_day } else { vec![] }; Ok(UsageStats { total_requests, total_input_tokens: total_input, total_output_tokens: total_output, by_model, by_day, }) } pub async fn record_usage( db: &PgPool, account_id: &str, provider_id: &str, model_id: &str, input_tokens: i64, output_tokens: i64, latency_ms: Option, status: &str, error_message: Option<&str>, ) -> SaasResult<()> { let now = chrono::Utc::now(); sqlx::query( "INSERT INTO usage_records (account_id, provider_id, model_id, input_tokens, output_tokens, latency_ms, status, error_message, created_at) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)" ) .bind(account_id).bind(provider_id).bind(model_id) .bind(input_tokens).bind(output_tokens).bind(latency_ms) .bind(status).bind(error_message).bind(&now) .execute(db).await?; Ok(()) } // ============ Helpers ============ fn mask_api_key(key: &str) -> String { if key.len() <= 8 { return "*".repeat(key.len()); } format!("{}...{}", &key[..4], &key[key.len()-4..]) } // ============ Model Groups ============ pub async fn list_model_groups(db: &PgPool) -> SaasResult> { let group_rows: Vec<(String, String, String, String, bool, String, String, String)> = sqlx::query_as( "SELECT id, name, display_name, COALESCE(description, ''), enabled, COALESCE(failover_strategy, 'quota_aware'), created_at::TEXT, updated_at::TEXT FROM model_groups ORDER BY name" ).fetch_all(db).await?; let member_rows: Vec<(String, String, String, String, i32, bool)> = sqlx::query_as( "SELECT id, group_id, provider_id, model_id, priority, enabled FROM model_group_members ORDER BY priority ASC" ).fetch_all(db).await?; let groups = group_rows.into_iter().map(|(id, name, display_name, description, enabled, failover_strategy, created_at, updated_at)| { let members: Vec = member_rows.iter() .filter(|(_, gid, _, _, _, _)| gid == &id) .map(|(mid, _, pid, mid2, pri, en)| ModelGroupMemberInfo { id: mid.clone(), provider_id: pid.clone(), model_id: mid2.clone(), priority: *pri, enabled: *en, }) .collect(); ModelGroupInfo { id, name, display_name, description, enabled, failover_strategy, members, created_at, updated_at } }).collect(); Ok(groups) } pub async fn get_model_group(db: &PgPool, group_id: &str) -> SaasResult { let row: Option<(String, String, String, String, bool, String, String, String)> = sqlx::query_as( "SELECT id, name, display_name, COALESCE(description, ''), enabled, COALESCE(failover_strategy, 'quota_aware'), created_at::TEXT, updated_at::TEXT FROM model_groups WHERE id = $1" ).bind(group_id).fetch_optional(db).await?; let (id, name, display_name, description, enabled, failover_strategy, created_at, updated_at) = row.ok_or_else(|| SaasError::NotFound(format!("模型组 {} 不存在", group_id)))?; let member_rows: Vec<(String, String, String, String, i32, bool)> = sqlx::query_as( "SELECT id, group_id, provider_id, model_id, priority, enabled FROM model_group_members WHERE group_id = $1 ORDER BY priority ASC" ).bind(group_id).fetch_all(db).await?; let members = member_rows.into_iter() .map(|(mid, _, pid, mid2, pri, en)| ModelGroupMemberInfo { id: mid, provider_id: pid, model_id: mid2, priority: pri, enabled: en, }) .collect(); Ok(ModelGroupInfo { id, name, display_name, description, enabled, failover_strategy, members, created_at, updated_at }) } pub async fn create_model_group(db: &PgPool, req: &CreateModelGroupRequest) -> SaasResult { // M-S1: failover_strategy 白名单校验 const VALID_STRATEGIES: &[&str] = &["quota_aware", "priority", "random"]; if !VALID_STRATEGIES.contains(&req.failover_strategy.as_str()) { return Err(SaasError::InvalidInput( format!("failover_strategy 必须是 {:?} 之一", VALID_STRATEGIES) )); } let id = uuid::Uuid::new_v4().to_string(); let now = chrono::Utc::now(); // 检查名称唯一性 let existing: Option<(String,)> = sqlx::query_as("SELECT id FROM model_groups WHERE name = $1") .bind(&req.name).fetch_optional(db).await?; if existing.is_some() { return Err(SaasError::AlreadyExists(format!("模型组 '{}' 已存在", req.name))); } // 名称不能和已有 model_id 冲突(避免路由歧义) let model_conflict: Option<(String,)> = sqlx::query_as("SELECT model_id FROM models WHERE model_id = $1") .bind(&req.name).fetch_optional(db).await?; if model_conflict.is_some() { return Err(SaasError::InvalidInput( format!("模型组名称 '{}' 与已有模型 ID 冲突,请使用不同的名称", req.name) )); } sqlx::query( "INSERT INTO model_groups (id, name, display_name, description, failover_strategy, created_at, updated_at) VALUES ($1, $2, $3, $4, $5, $6, $6)" ) .bind(&id).bind(&req.name).bind(&req.display_name).bind(&req.description) .bind(&req.failover_strategy).bind(&now) .execute(db).await.map_err(|e| SaasError::from_sqlx_unique(e, &format!("模型组 '{}'", req.name)))?; get_model_group(db, &id).await } pub async fn update_model_group( db: &PgPool, group_id: &str, req: &UpdateModelGroupRequest, ) -> SaasResult { let now = chrono::Utc::now(); sqlx::query( "UPDATE model_groups SET display_name = COALESCE($1, display_name), description = COALESCE($2, description), enabled = COALESCE($3, enabled), failover_strategy = COALESCE($4, failover_strategy), updated_at = $5 WHERE id = $6" ) .bind(req.display_name.as_deref()) .bind(req.description.as_deref()) .bind(req.enabled) .bind(req.failover_strategy.as_deref()) .bind(&now) .bind(group_id) .execute(db).await?; get_model_group(db, group_id).await } pub async fn delete_model_group(db: &PgPool, group_id: &str) -> SaasResult<()> { let result = sqlx::query("DELETE FROM model_groups WHERE id = $1") .bind(group_id).execute(db).await?; if result.rows_affected() == 0 { return Err(SaasError::NotFound(format!("模型组 {} 不存在", group_id))); } Ok(()) } pub async fn add_group_member( db: &PgPool, group_id: &str, req: &AddGroupMemberRequest, ) -> SaasResult { // 验证 group 存在 sqlx::query_scalar::<_, String>("SELECT id FROM model_groups WHERE id = $1") .bind(group_id).fetch_optional(db).await? .ok_or_else(|| SaasError::NotFound(format!("模型组 {} 不存在", group_id)))?; // 验证 provider 存在 sqlx::query_scalar::<_, String>("SELECT id FROM providers WHERE id = $1") .bind(&req.provider_id).fetch_optional(db).await? .ok_or_else(|| SaasError::NotFound(format!("Provider {} 不存在", req.provider_id)))?; // 验证 model 存在(避免插入无效 model_id 导致 relay 运行时找不到模型) sqlx::query_scalar::<_, String>("SELECT model_id FROM models WHERE model_id = $1") .bind(&req.model_id).fetch_optional(db).await? .ok_or_else(|| SaasError::NotFound(format!("模型 {} 不存在", req.model_id)))?; // M-S4: 检查重复成员(避免 DB unique violation 返回 500) let duplicate: Option<(String,)> = sqlx::query_as( "SELECT id FROM model_group_members WHERE group_id = $1 AND provider_id = $2 AND model_id = $3" ) .bind(group_id).bind(&req.provider_id).bind(&req.model_id) .fetch_optional(db).await?; if duplicate.is_some() { return Err(SaasError::AlreadyExists( format!("Provider {} 的模型 {} 已在该模型组中", req.provider_id, req.model_id) )); } let id = uuid::Uuid::new_v4().to_string(); sqlx::query( "INSERT INTO model_group_members (id, group_id, provider_id, model_id, priority, created_at, updated_at) VALUES ($1, $2, $3, $4, $5, NOW(), NOW())" ) .bind(&id).bind(group_id).bind(&req.provider_id).bind(&req.model_id).bind(req.priority) .execute(db).await.map_err(|e| SaasError::from_sqlx_unique(e, &format!("Provider {} 的模型 {} 在该模型组", req.provider_id, req.model_id)))?; Ok(ModelGroupMemberInfo { id, provider_id: req.provider_id.clone(), model_id: req.model_id.clone(), priority: req.priority, enabled: true, }) } pub async fn remove_group_member(db: &PgPool, group_id: &str, member_id: &str) -> SaasResult<()> { // M-5: 验证成员确实属于该组 let result = sqlx::query("DELETE FROM model_group_members WHERE id = $1 AND group_id = $2") .bind(member_id).bind(group_id).execute(db).await?; if result.rows_affected() == 0 { return Err(SaasError::NotFound(format!("成员 {} 不属于该模型组", member_id))); } Ok(()) }