Files
zclaw_openfang/crates/zclaw-saas/src/model_config/service.rs
iven e0eb7173c5
Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled
fix: 三端联调 P1 修复 — API密钥页崩溃 + 桌面端401恢复 + 用量统计全零
P1-03: vite.config.ts proxy '/api' → '/api/' 加尾部斜杠,
  防止前缀匹配 /api-keys 导致 SPA 路由崩溃

P1-01: kernel_init 增加 api_key 变更检测(token 刷新后自动重连),
  streamStore 增加 401 自动恢复(refresh token → kernel reconnect),
  KernelClient 新增 getConfig() 方法

P1-02: /api/v1/usage 总计改从 billing_usage_quotas 读取
  (authoritative source,SSE 和 JSON 均写入),
  by_model/by_day 仍从 usage_records 读取
2026-04-14 22:02:02 +08:00

723 lines
31 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

//! 模型配置业务逻辑
use sqlx::{PgPool, Row};
use crate::error::{SaasError, SaasResult};
use crate::common::{PaginatedResponse, normalize_pagination};
use crate::crypto;
use crate::models::{ProviderRow, ModelRow, AccountApiKeyRow, UsageByModelRow, UsageByDayRow};
use super::types::*;
// ============ Providers ============
/// Returns one page of providers, ordered by name.
///
/// When `enabled_filter` is `Some(flag)`, only providers whose `enabled` column
/// equals `flag` are counted and returned; with `None` every provider is included.
/// The response carries the total number of matching providers alongside the page.
pub async fn list_providers(
    db: &PgPool, page: Option<u32>, page_size: Option<u32>, enabled_filter: Option<bool>,
) -> SaasResult<PaginatedResponse<ProviderInfo>> {
    let (page_no, limit, offset) = normalize_pagination(page, page_size);

    let (total, rows): ((i64,), Vec<ProviderRow>) = match enabled_filter {
        Some(flag) => {
            let total: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM providers WHERE enabled = $1")
                .bind(flag)
                .fetch_one(db)
                .await?;
            let rows: Vec<ProviderRow> = sqlx::query_as(
                "SELECT id, name, display_name, base_url, api_protocol, enabled, rate_limit_rpm, rate_limit_tpm, created_at::TEXT, updated_at::TEXT
FROM providers WHERE enabled = $1 ORDER BY name LIMIT $2 OFFSET $3",
            )
            .bind(flag)
            .bind(limit as i64)
            .bind(offset)
            .fetch_all(db)
            .await?;
            (total, rows)
        }
        None => {
            let total: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM providers")
                .fetch_one(db)
                .await?;
            let rows: Vec<ProviderRow> = sqlx::query_as(
                "SELECT id, name, display_name, base_url, api_protocol, enabled, rate_limit_rpm, rate_limit_tpm, created_at::TEXT, updated_at::TEXT
FROM providers ORDER BY name LIMIT $1 OFFSET $2",
            )
            .bind(limit as i64)
            .bind(offset)
            .fetch_all(db)
            .await?;
            (total, rows)
        }
    };

    let items = rows
        .into_iter()
        .map(|row| ProviderInfo {
            id: row.id,
            name: row.name,
            display_name: row.display_name,
            base_url: row.base_url,
            api_protocol: row.api_protocol,
            enabled: row.enabled,
            rate_limit_rpm: row.rate_limit_rpm,
            rate_limit_tpm: row.rate_limit_tpm,
            created_at: row.created_at,
            updated_at: row.updated_at,
        })
        .collect();

    Ok(PaginatedResponse { items, total: total.0, page: page_no, page_size: limit })
}
/// Loads a single provider by its primary key.
///
/// # Errors
/// Returns `SaasError::NotFound` when no provider has the given id.
pub async fn get_provider(db: &PgPool, provider_id: &str) -> SaasResult<ProviderInfo> {
    let found = sqlx::query_as::<_, ProviderRow>(
        "SELECT id, name, display_name, base_url, api_protocol, enabled, rate_limit_rpm, rate_limit_tpm, created_at::TEXT, updated_at::TEXT
FROM providers WHERE id = $1"
    )
    .bind(provider_id)
    .fetch_optional(db)
    .await?;

    match found {
        Some(row) => Ok(ProviderInfo {
            id: row.id,
            name: row.name,
            display_name: row.display_name,
            base_url: row.base_url,
            api_protocol: row.api_protocol,
            enabled: row.enabled,
            rate_limit_rpm: row.rate_limit_rpm,
            rate_limit_tpm: row.rate_limit_tpm,
            created_at: row.created_at,
            updated_at: row.updated_at,
        }),
        None => Err(SaasError::NotFound(format!("Provider {} 不存在", provider_id))),
    }
}
pub async fn create_provider(db: &PgPool, req: &CreateProviderRequest, enc_key: &[u8; 32]) -> SaasResult<ProviderInfo> {
let id = uuid::Uuid::new_v4().to_string();
let now = chrono::Utc::now();
// 检查名称唯一性
let existing: Option<(String,)> = sqlx::query_as("SELECT id FROM providers WHERE name = $1")
.bind(&req.name).fetch_optional(db).await?;
if existing.is_some() {
return Err(SaasError::AlreadyExists(format!("Provider '{}' 已存在", req.name)));
}
// 加密 API Key 后存储
let encrypted_api_key = if let Some(ref key) = req.api_key {
if key.is_empty() {
String::new()
} else {
crypto::encrypt_value(key, enc_key)?
}
} else {
String::new()
};
let display_name = req.display_name.as_deref().unwrap_or(&req.name);
sqlx::query(
"INSERT INTO providers (id, name, display_name, api_key, base_url, api_protocol, enabled, rate_limit_rpm, rate_limit_tpm, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, true, $7, $8, $9, $9)"
)
.bind(&id).bind(&req.name).bind(display_name).bind(&encrypted_api_key)
.bind(&req.base_url).bind(&req.api_protocol).bind(&req.rate_limit_rpm).bind(&req.rate_limit_tpm).bind(&now)
.execute(db).await.map_err(|e| SaasError::from_sqlx_unique(e, &format!("Provider '{}'", req.name)))?;
get_provider(db, &id).await
}
pub async fn update_provider(
db: &PgPool, provider_id: &str, req: &UpdateProviderRequest, enc_key: &[u8; 32],
) -> SaasResult<ProviderInfo> {
let now = chrono::Utc::now();
// Encrypt api_key upfront if provided
let encrypted_api_key = match req.api_key {
Some(ref v) if !v.is_empty() => Some(crypto::encrypt_value(v, enc_key)?),
Some(ref v) if v.is_empty() => Some(String::new()),
_ => None,
};
// COALESCE pattern: all updatable fields in a single static SQL.
// NULL parameters leave the column unchanged.
sqlx::query(
"UPDATE providers SET
display_name = COALESCE($1, display_name),
base_url = COALESCE($2, base_url),
api_protocol = COALESCE($3, api_protocol),
api_key = COALESCE($4, api_key),
enabled = COALESCE($5, enabled),
rate_limit_rpm = COALESCE($6, rate_limit_rpm),
rate_limit_tpm = COALESCE($7, rate_limit_tpm),
updated_at = $8
WHERE id = $9"
)
.bind(req.display_name.as_deref())
.bind(req.base_url.as_deref())
.bind(req.api_protocol.as_deref())
.bind(encrypted_api_key.as_deref())
.bind(req.enabled)
.bind(req.rate_limit_rpm)
.bind(req.rate_limit_tpm)
.bind(&now)
.bind(provider_id)
.execute(db).await?;
get_provider(db, provider_id).await
}
/// Deletes the provider with the given id.
///
/// # Errors
/// `NotFound` when no row was deleted.
pub async fn delete_provider(db: &PgPool, provider_id: &str) -> SaasResult<()> {
    let removed = sqlx::query("DELETE FROM providers WHERE id = $1")
        .bind(provider_id)
        .execute(db)
        .await?
        .rows_affected();
    match removed {
        0 => Err(SaasError::NotFound(format!("Provider {} 不存在", provider_id))),
        _ => Ok(()),
    }
}
// ============ Models ============
/// Returns one page of models.
///
/// With `provider_id`, only that provider's models are listed, ordered by alias.
/// Without it, all models are listed, ordered by provider and then alias.
pub async fn list_models(
    db: &PgPool, provider_id: Option<&str>, page: Option<u32>, page_size: Option<u32>,
) -> SaasResult<PaginatedResponse<ModelInfo>> {
    let (page_no, limit, offset) = normalize_pagination(page, page_size);

    let (total, rows): ((i64,), Vec<ModelRow>) = match provider_id {
        Some(pid) => {
            let total: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM models WHERE provider_id = $1")
                .bind(pid)
                .fetch_one(db)
                .await?;
            let rows = sqlx::query_as::<_, ModelRow>(
                "SELECT id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, is_embedding, model_type, pricing_input, pricing_output, created_at::TEXT, updated_at::TEXT
FROM models WHERE provider_id = $1 ORDER BY alias LIMIT $2 OFFSET $3",
            )
            .bind(pid)
            .bind(limit as i64)
            .bind(offset)
            .fetch_all(db)
            .await?;
            (total, rows)
        }
        None => {
            let total: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM models")
                .fetch_one(db)
                .await?;
            let rows = sqlx::query_as::<_, ModelRow>(
                "SELECT id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, is_embedding, model_type, pricing_input, pricing_output, created_at::TEXT, updated_at::TEXT
FROM models ORDER BY provider_id, alias LIMIT $1 OFFSET $2",
            )
            .bind(limit as i64)
            .bind(offset)
            .fetch_all(db)
            .await?;
            (total, rows)
        }
    };

    let items = rows
        .into_iter()
        .map(|row| ModelInfo {
            id: row.id,
            provider_id: row.provider_id,
            model_id: row.model_id,
            alias: row.alias,
            context_window: row.context_window,
            max_output_tokens: row.max_output_tokens,
            supports_streaming: row.supports_streaming,
            supports_vision: row.supports_vision,
            enabled: row.enabled,
            is_embedding: row.is_embedding,
            model_type: row.model_type,
            pricing_input: row.pricing_input,
            pricing_output: row.pricing_output,
            created_at: row.created_at,
            updated_at: row.updated_at,
        })
        .collect();

    Ok(PaginatedResponse { items, total: total.0, page: page_no, page_size: limit })
}
/// Registers a new, enabled model under an existing provider.
///
/// Defaults for omitted fields:
/// * context window 8192 and max output 4096;
/// * streaming on, vision off, not an embedding model;
/// * `model_type` is `"embedding"` or `"chat"` depending on `is_embedding`;
/// * both prices 0.0;
/// * `alias` equals `model_id`.
///
/// # Errors
/// * `NotFound` if the provider does not exist.
/// * `AlreadyExists` if the provider already has this `model_id`.
/// * `InvalidInput` if `model_id` equals an existing model-group name (M-2:
///   the two would be ambiguous when routing).
pub async fn create_model(db: &PgPool, req: &CreateModelRequest) -> SaasResult<ModelInfo> {
    let provider = get_provider(db, &req.provider_id).await?;
    let new_id = uuid::Uuid::new_v4().to_string();
    let created_at = chrono::Utc::now();

    let already_registered = sqlx::query_as::<_, (String,)>(
        "SELECT id FROM models WHERE provider_id = $1 AND model_id = $2"
    )
    .bind(&req.provider_id)
    .bind(&req.model_id)
    .fetch_optional(db)
    .await?
    .is_some();
    if already_registered {
        return Err(SaasError::AlreadyExists(format!(
            "模型 '{}' 已存在于 provider '{}'", req.model_id, provider.name
        )));
    }

    let clashes_with_group = sqlx::query_as::<_, (String,)>("SELECT id FROM model_groups WHERE name = $1")
        .bind(&req.model_id)
        .fetch_optional(db)
        .await?
        .is_some();
    if clashes_with_group {
        return Err(SaasError::InvalidInput(
            format!("模型 ID '{}' 与已有模型组名称冲突,请使用不同的 ID", req.model_id)
        ));
    }

    let embedding = req.is_embedding.unwrap_or(false);
    let default_type = if embedding { "embedding" } else { "chat" };
    let alias = req.alias.as_deref().unwrap_or(&req.model_id);

    sqlx::query(
        "INSERT INTO models (id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, is_embedding, model_type, pricing_input, pricing_output, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, true, $9, $10, $11, $12, $13, $13)"
    )
    .bind(&new_id)
    .bind(&req.provider_id)
    .bind(&req.model_id)
    .bind(alias)
    .bind(req.context_window.unwrap_or(8192))
    .bind(req.max_output_tokens.unwrap_or(4096))
    .bind(req.supports_streaming.unwrap_or(true))
    .bind(req.supports_vision.unwrap_or(false))
    .bind(embedding)
    .bind(req.model_type.as_deref().unwrap_or(default_type))
    .bind(req.pricing_input.unwrap_or(0.0))
    .bind(req.pricing_output.unwrap_or(0.0))
    .bind(&created_at)
    .execute(db)
    .await
    .map_err(|e| SaasError::from_sqlx_unique(e, &format!("模型 '{}' 在 Provider '{}'", req.model_id, req.provider_id)))?;

    get_model(db, &new_id).await
}
/// Loads a single model by its primary key (`models.id`, not `model_id`).
///
/// # Errors
/// Returns `SaasError::NotFound` when no model has the given id.
pub async fn get_model(db: &PgPool, model_id: &str) -> SaasResult<ModelInfo> {
    let found = sqlx::query_as::<_, ModelRow>(
        "SELECT id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, is_embedding, model_type, pricing_input, pricing_output, created_at::TEXT, updated_at::TEXT
FROM models WHERE id = $1"
    )
    .bind(model_id)
    .fetch_optional(db)
    .await?;

    match found {
        Some(row) => Ok(ModelInfo {
            id: row.id,
            provider_id: row.provider_id,
            model_id: row.model_id,
            alias: row.alias,
            context_window: row.context_window,
            max_output_tokens: row.max_output_tokens,
            supports_streaming: row.supports_streaming,
            supports_vision: row.supports_vision,
            enabled: row.enabled,
            is_embedding: row.is_embedding,
            model_type: row.model_type,
            pricing_input: row.pricing_input,
            pricing_output: row.pricing_output,
            created_at: row.created_at,
            updated_at: row.updated_at,
        }),
        None => Err(SaasError::NotFound(format!("模型 {} 不存在", model_id))),
    }
}
/// Partially updates a model and returns its new state.
///
/// Each `None` field in `req` is bound as SQL NULL, and COALESCE then keeps the
/// current column value. `updated_at` is always refreshed.
///
/// # Errors
/// Database errors. `NotFound` comes from the follow-up `get_model` when the id
/// does not exist.
pub async fn update_model(
    db: &PgPool, model_id: &str, req: &UpdateModelRequest,
) -> SaasResult<ModelInfo> {
    const UPDATE_SQL: &str = "UPDATE models SET
alias = COALESCE($1, alias),
context_window = COALESCE($2, context_window),
max_output_tokens = COALESCE($3, max_output_tokens),
supports_streaming = COALESCE($4, supports_streaming),
supports_vision = COALESCE($5, supports_vision),
enabled = COALESCE($6, enabled),
is_embedding = COALESCE($7, is_embedding),
model_type = COALESCE($8, model_type),
pricing_input = COALESCE($9, pricing_input),
pricing_output = COALESCE($10, pricing_output),
updated_at = $11
WHERE id = $12";

    let touched_at = chrono::Utc::now();
    sqlx::query(UPDATE_SQL)
        .bind(req.alias.as_deref())
        .bind(req.context_window)
        .bind(req.max_output_tokens)
        .bind(req.supports_streaming)
        .bind(req.supports_vision)
        .bind(req.enabled)
        .bind(req.is_embedding)
        .bind(req.model_type.as_deref())
        .bind(req.pricing_input)
        .bind(req.pricing_output)
        .bind(&touched_at)
        .bind(model_id)
        .execute(db)
        .await?;

    get_model(db, model_id).await
}
/// Deletes the model whose primary key is `model_id`.
///
/// # Errors
/// `NotFound` when no row was deleted.
pub async fn delete_model(db: &PgPool, model_id: &str) -> SaasResult<()> {
    let removed = sqlx::query("DELETE FROM models WHERE id = $1")
        .bind(model_id)
        .execute(db)
        .await?
        .rows_affected();
    match removed {
        0 => Err(SaasError::NotFound(format!("模型 {} 不存在", model_id))),
        _ => Ok(()),
    }
}
// ============ Account API Keys ============
/// Lists the non-revoked API keys owned by `account_id`, newest first.
///
/// When `provider_id` is given, only keys for that provider are counted and
/// returned. `permissions` is stored as a JSON array string; a value that fails
/// to parse yields an empty permission list rather than an error.
///
/// NOTE(review): `create_account_api_key` stores `key_value` encrypted, so the
/// `masked_key` computed here is derived from the stored (encrypted) value, not
/// the plaintext key. It will not match the mask returned at creation time.
/// Fixing this needs the encryption key (a signature change); confirm against
/// the `crypto` module and the callers.
pub async fn list_account_api_keys(
    db: &PgPool, account_id: &str, provider_id: Option<&str>,
    page: Option<u32>, page_size: Option<u32>,
) -> SaasResult<PaginatedResponse<AccountApiKeyInfo>> {
    let (p, ps, offset) = normalize_pagination(page, page_size);
    // Both statements bind account_id as $1 and, when filtering, provider_id as $2;
    // the pagination parameters follow.
    let (count_sql, data_sql) = if provider_id.is_some() {
        (
            "SELECT COUNT(*) FROM account_api_keys WHERE account_id = $1 AND provider_id = $2 AND revoked_at IS NULL",
            "SELECT id, provider_id, key_label, permissions, enabled, last_used_at::TEXT, created_at::TEXT, key_value
FROM account_api_keys WHERE account_id = $1 AND provider_id = $2 AND revoked_at IS NULL ORDER BY created_at DESC LIMIT $3 OFFSET $4",
        )
    } else {
        (
            "SELECT COUNT(*) FROM account_api_keys WHERE account_id = $1 AND revoked_at IS NULL",
            "SELECT id, provider_id, key_label, permissions, enabled, last_used_at::TEXT, created_at::TEXT, key_value
FROM account_api_keys WHERE account_id = $1 AND revoked_at IS NULL ORDER BY created_at DESC LIMIT $2 OFFSET $3",
        )
    };

    // Build both queries with the shared leading binds in one place, instead of
    // re-checking `provider_id` inside an already-filtered branch.
    let mut count_query = sqlx::query_as::<_, (i64,)>(count_sql).bind(account_id);
    let mut data_query = sqlx::query_as::<_, AccountApiKeyRow>(data_sql).bind(account_id);
    if let Some(pid) = provider_id {
        count_query = count_query.bind(pid);
        data_query = data_query.bind(pid);
    }

    let total = count_query.fetch_one(db).await?;
    let rows = data_query.bind(ps as i64).bind(offset).fetch_all(db).await?;

    let items = rows
        .into_iter()
        .map(|r| {
            let permissions: Vec<String> = serde_json::from_str(&r.permissions).unwrap_or_default();
            let masked_key = mask_api_key(&r.key_value);
            AccountApiKeyInfo {
                id: r.id,
                provider_id: r.provider_id,
                key_label: r.key_label,
                permissions,
                enabled: r.enabled,
                last_used_at: r.last_used_at,
                created_at: r.created_at,
                masked_key,
            }
        })
        .collect();

    Ok(PaginatedResponse { items, total: total.0, page: p, page_size: ps })
}
/// Stores a new, enabled API key for `account_id` and returns its public view.
///
/// The key value is encrypted with `enc_key` before storage, and `permissions`
/// is serialised to a JSON string. The returned `masked_key` is computed from the
/// plaintext value supplied in `req`.
///
/// # Errors
/// `NotFound` if the provider does not exist; serialisation, encryption or
/// database errors otherwise.
pub async fn create_account_api_key(
    db: &PgPool, account_id: &str, req: &CreateAccountApiKeyRequest, enc_key: &[u8; 32],
) -> SaasResult<AccountApiKeyInfo> {
    // Fail fast with NotFound if the target provider is unknown.
    get_provider(db, &req.provider_id).await?;

    let key_id = uuid::Uuid::new_v4().to_string();
    let created = chrono::Utc::now();
    let permissions_json = serde_json::to_string(&req.permissions)?;
    let ciphertext = crypto::encrypt_value(&req.key_value, enc_key)?;

    sqlx::query(
        "INSERT INTO account_api_keys (id, account_id, provider_id, key_value, key_label, permissions, enabled, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, true, $7, $7)"
    )
    .bind(&key_id)
    .bind(account_id)
    .bind(&req.provider_id)
    .bind(&ciphertext)
    .bind(&req.key_label)
    .bind(&permissions_json)
    .bind(&created)
    .execute(db)
    .await?;

    Ok(AccountApiKeyInfo {
        id: key_id,
        provider_id: req.provider_id.clone(),
        key_label: req.key_label.clone(),
        permissions: req.permissions.clone(),
        enabled: true,
        last_used_at: None,
        created_at: created.to_rfc3339(),
        masked_key: mask_api_key(&req.key_value),
    })
}
/// Replaces the stored value of one of `account_id`'s API keys.
///
/// The new value is encrypted with `enc_key`. Only a key that belongs to the
/// account and has not been revoked is updated.
///
/// # Errors
/// `NotFound` when no such active key exists; encryption/database errors otherwise.
pub async fn rotate_account_api_key(
    db: &PgPool, key_id: &str, account_id: &str, new_key_value: &str, enc_key: &[u8; 32],
) -> SaasResult<()> {
    let now = chrono::Utc::now();
    let ciphertext = crypto::encrypt_value(new_key_value, enc_key)?;
    let touched = sqlx::query(
        "UPDATE account_api_keys SET key_value = $1, updated_at = $2 WHERE id = $3 AND account_id = $4 AND revoked_at IS NULL"
    )
    .bind(&ciphertext)
    .bind(&now)
    .bind(key_id)
    .bind(account_id)
    .execute(db)
    .await?
    .rows_affected();
    match touched {
        0 => Err(SaasError::NotFound("API Key 不存在或已撤销".into())),
        _ => Ok(()),
    }
}
/// Soft-deletes one of `account_id`'s API keys by setting `revoked_at`.
///
/// A key that is already revoked is treated as missing.
///
/// # Errors
/// `NotFound` when no active key with `key_id` belongs to the account.
pub async fn revoke_account_api_key(
    db: &PgPool, key_id: &str, account_id: &str,
) -> SaasResult<()> {
    let revoked_at = chrono::Utc::now();
    let touched = sqlx::query(
        "UPDATE account_api_keys SET revoked_at = $1 WHERE id = $2 AND account_id = $3 AND revoked_at IS NULL"
    )
    .bind(&revoked_at)
    .bind(key_id)
    .bind(account_id)
    .execute(db)
    .await?
    .rows_affected();
    match touched {
        0 => Err(SaasError::NotFound("API Key 不存在或已撤销".into())),
        _ => Ok(()),
    }
}
// ============ Usage Statistics ============
pub async fn get_usage_stats(
db: &PgPool, account_id: &str, query: &UsageQuery,
) -> SaasResult<UsageStats> {
// === Totals: from billing_usage_quotas (authoritative source) ===
// billing_usage_quotas is written to on every relay request (both JSON and SSE),
// whereas usage_records has 0 tokens for SSE requests. Use billing as the primary source.
let billing_row = sqlx::query(
"SELECT COALESCE(SUM(input_tokens), 0)::bigint,
COALESCE(SUM(output_tokens), 0)::bigint,
COALESCE(SUM(relay_requests), 0)::bigint
FROM billing_usage_quotas WHERE account_id = $1"
)
.bind(account_id)
.fetch_one(db)
.await?;
let total_input: i64 = billing_row.try_get(0).unwrap_or(0);
let total_output: i64 = billing_row.try_get(1).unwrap_or(0);
let total_requests: i64 = billing_row.try_get(2).unwrap_or(0);
// === Breakdowns: from usage_records (per-request detail) ===
// Optional date filters: pass as TEXT with explicit SQL cast.
let from_str: Option<&str> = query.from.as_deref();
let to_str: Option<String> = query.to.as_ref().map(|s| {
if s.len() == 10 { format!("{}T23:59:59", s) } else { s.clone() }
});
// Build SQL dynamically for usage_records breakdowns.
// Date parameters are injected as SQL literals (validated via chrono parse).
let mut where_parts = vec![format!("account_id = '{}'", account_id.replace('\'', "''"))];
if let Some(f) = from_str {
let valid = chrono::NaiveDate::parse_from_str(f, "%Y-%m-%d").is_ok()
|| chrono::NaiveDateTime::parse_from_str(f, "%Y-%m-%dT%H:%M:%S%.f").is_ok();
if !valid {
return Err(SaasError::InvalidInput(format!("Invalid 'from' date: {}", f)));
}
where_parts.push(format!("created_at::timestamptz >= '{}T00:00:00Z'::timestamptz", f.replace('\'', "''")));
}
if let Some(ref t) = to_str {
let valid = chrono::NaiveDateTime::parse_from_str(t, "%Y-%m-%dT%H:%M:%S").is_ok()
|| chrono::NaiveDate::parse_from_str(t, "%Y-%m-%d").is_ok();
if !valid {
return Err(SaasError::InvalidInput(format!("Invalid 'to' date: {}", t)));
}
where_parts.push(format!("created_at::timestamptz <= '{}'::timestamptz", t.replace('\'', "''")));
}
if let Some(ref pid) = query.provider_id {
where_parts.push(format!("provider_id = '{}'", pid.replace('\'', "''")));
}
if let Some(ref mid) = query.model_id {
where_parts.push(format!("model_id = '{}'", mid.replace('\'', "''")));
}
let where_clause = where_parts.join(" AND ");
// 按模型统计
let by_model_sql = format!(
"SELECT provider_id, model_id, COUNT(*)::bigint AS request_count, COALESCE(SUM(input_tokens), 0)::bigint AS input_tokens, COALESCE(SUM(output_tokens), 0)::bigint AS output_tokens
FROM usage_records WHERE {} GROUP BY provider_id, model_id ORDER BY COUNT(*) DESC LIMIT 20", where_clause
);
let by_model_rows: Vec<UsageByModelRow> = sqlx::query_as(&by_model_sql).fetch_all(db).await?;
let by_model: Vec<ModelUsage> = by_model_rows.into_iter()
.map(|r| {
ModelUsage { provider_id: r.provider_id, model_id: r.model_id, request_count: r.request_count, input_tokens: r.input_tokens, output_tokens: r.output_tokens }
}).collect();
// 按天统计 (使用 days 参数或默认 30 天)
let days = query.days.unwrap_or(30).min(365).max(1) as i64;
let from_days_str = (chrono::Utc::now() - chrono::Duration::days(days))
.format("%Y-%m-%d").to_string();
let daily_sql = format!(
"SELECT created_at::date::text as day, COUNT(*)::bigint AS request_count, COALESCE(SUM(input_tokens), 0)::bigint AS input_tokens, COALESCE(SUM(output_tokens), 0)::bigint AS output_tokens
FROM usage_records WHERE account_id = '{}' AND created_at::timestamptz >= '{}T00:00:00Z'::timestamptz
GROUP BY created_at::date ORDER BY day DESC LIMIT {}",
account_id.replace('\'', "''"), from_days_str.replace('\'', "''"), days
);
let daily_rows: Vec<UsageByDayRow> = sqlx::query_as(&daily_sql).fetch_all(db).await?;
let by_day: Vec<DailyUsage> = daily_rows.into_iter()
.map(|r| {
DailyUsage { date: r.day, request_count: r.request_count, input_tokens: r.input_tokens, output_tokens: r.output_tokens }
}).collect();
// 按 group_by 过滤返回
let group_by = query.group_by.as_deref();
let by_model = if group_by == Some("model") || group_by.is_none() { by_model } else { vec![] };
let by_day = if group_by == Some("day") || group_by.is_none() { by_day } else { vec![] };
Ok(UsageStats {
total_requests,
total_input_tokens: total_input,
total_output_tokens: total_output,
by_model,
by_day,
})
}
/// Appends one row to `usage_records`, timestamped with the current UTC time.
///
/// `latency_ms` and `error_message` are stored as NULL when `None`.
pub async fn record_usage(
    db: &PgPool, account_id: &str, provider_id: &str, model_id: &str,
    input_tokens: i64, output_tokens: i64, latency_ms: Option<i64>,
    status: &str, error_message: Option<&str>,
) -> SaasResult<()> {
    const INSERT_SQL: &str = "INSERT INTO usage_records (account_id, provider_id, model_id, input_tokens, output_tokens, latency_ms, status, error_message, created_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)";

    let recorded_at = chrono::Utc::now();
    sqlx::query(INSERT_SQL)
        .bind(account_id)
        .bind(provider_id)
        .bind(model_id)
        .bind(input_tokens)
        .bind(output_tokens)
        .bind(latency_ms)
        .bind(status)
        .bind(error_message)
        .bind(&recorded_at)
        .execute(db)
        .await?;
    Ok(())
}
// ============ Helpers ============
/// Masks an API key for display.
///
/// Keys of at most 8 characters become all `*` (one per character). Longer keys
/// keep their first and last 4 characters, joined by `...`.
///
/// Lengths and cut points are measured in `char`s rather than bytes. The
/// previous byte-index slicing panicked on non-ASCII input whenever index 4 or
/// `len - 4` fell inside a multi-byte character.
fn mask_api_key(key: &str) -> String {
    let char_count = key.chars().count();
    if char_count <= 8 {
        return "*".repeat(char_count);
    }
    let head: String = key.chars().take(4).collect();
    let tail: String = key.chars().skip(char_count - 4).collect();
    format!("{}...{}", head, tail)
}
// ============ Model Groups ============
/// Lists every model group (ordered by name) together with its members.
///
/// Within each group, members are ordered by ascending `priority`. A NULL
/// `description` is returned as `''` and a NULL `failover_strategy` as
/// `'quota_aware'`.
pub async fn list_model_groups(db: &PgPool) -> SaasResult<Vec<ModelGroupInfo>> {
    let group_rows: Vec<(String, String, String, String, bool, String, String, String)> = sqlx::query_as(
        "SELECT id, name, display_name, COALESCE(description, ''), enabled,
COALESCE(failover_strategy, 'quota_aware'), created_at::TEXT, updated_at::TEXT
FROM model_groups ORDER BY name"
    ).fetch_all(db).await?;
    let member_rows: Vec<(String, String, String, String, i32, bool)> = sqlx::query_as(
        "SELECT id, group_id, provider_id, model_id, priority, enabled
FROM model_group_members ORDER BY priority ASC"
    ).fetch_all(db).await?;

    // Bucket members by group id in a single pass instead of rescanning every
    // member row for every group (O(groups * members)). Rows arrive sorted by
    // priority and are pushed in that order, so each bucket stays sorted. The
    // values are moved, not cloned.
    let mut members_by_group: std::collections::HashMap<String, Vec<ModelGroupMemberInfo>> =
        std::collections::HashMap::new();
    for (member_id, group_id, provider_id, model_id, priority, enabled) in member_rows {
        members_by_group
            .entry(group_id)
            .or_default()
            .push(ModelGroupMemberInfo { id: member_id, provider_id, model_id, priority, enabled });
    }

    let groups = group_rows
        .into_iter()
        .map(|(id, name, display_name, description, enabled, failover_strategy, created_at, updated_at)| {
            let members = members_by_group.remove(&id).unwrap_or_default();
            ModelGroupInfo { id, name, display_name, description, enabled, failover_strategy, members, created_at, updated_at }
        })
        .collect();
    Ok(groups)
}
/// Loads one model group and its members (ordered by ascending priority).
///
/// A NULL `description` is returned as `''` and a NULL `failover_strategy` as
/// `'quota_aware'`.
///
/// # Errors
/// `NotFound` when no group has the given id.
pub async fn get_model_group(db: &PgPool, group_id: &str) -> SaasResult<ModelGroupInfo> {
    let header: Option<(String, String, String, String, bool, String, String, String)> = sqlx::query_as(
        "SELECT id, name, display_name, COALESCE(description, ''), enabled,
COALESCE(failover_strategy, 'quota_aware'), created_at::TEXT, updated_at::TEXT
FROM model_groups WHERE id = $1"
    )
    .bind(group_id)
    .fetch_optional(db)
    .await?;

    let (id, name, display_name, description, enabled, failover_strategy, created_at, updated_at) = match header {
        Some(row) => row,
        None => return Err(SaasError::NotFound(format!("模型组 {} 不存在", group_id))),
    };

    let member_rows: Vec<(String, String, String, String, i32, bool)> = sqlx::query_as(
        "SELECT id, group_id, provider_id, model_id, priority, enabled
FROM model_group_members WHERE group_id = $1 ORDER BY priority ASC"
    )
    .bind(group_id)
    .fetch_all(db)
    .await?;

    let mut members = Vec::with_capacity(member_rows.len());
    for (member_id, _owner, provider_id, model_id, priority, member_enabled) in member_rows {
        members.push(ModelGroupMemberInfo {
            id: member_id,
            provider_id,
            model_id,
            priority,
            enabled: member_enabled,
        });
    }

    Ok(ModelGroupInfo { id, name, display_name, description, enabled, failover_strategy, members, created_at, updated_at })
}
/// Creates a model group and returns it (initially with no members).
///
/// # Errors
/// * `InvalidInput` if `failover_strategy` is not whitelisted (M-S1).
/// * `AlreadyExists` if a group with this name exists.
/// * `InvalidInput` if the name equals an existing model's `model_id` (the two
///   would be ambiguous when routing).
/// * Database errors otherwise; unique violations map to `AlreadyExists`.
pub async fn create_model_group(db: &PgPool, req: &CreateModelGroupRequest) -> SaasResult<ModelGroupInfo> {
    const VALID_STRATEGIES: &[&str] = &["quota_aware", "priority", "random"];
    let strategy_ok = VALID_STRATEGIES.iter().any(|s| *s == req.failover_strategy);
    if !strategy_ok {
        return Err(SaasError::InvalidInput(
            format!("failover_strategy 必须是 {:?} 之一", VALID_STRATEGIES)
        ));
    }

    let new_id = uuid::Uuid::new_v4().to_string();
    let created_at = chrono::Utc::now();

    let name_taken = sqlx::query_as::<_, (String,)>("SELECT id FROM model_groups WHERE name = $1")
        .bind(&req.name)
        .fetch_optional(db)
        .await?
        .is_some();
    if name_taken {
        return Err(SaasError::AlreadyExists(format!("模型组 '{}' 已存在", req.name)));
    }

    let clashes_with_model = sqlx::query_as::<_, (String,)>("SELECT model_id FROM models WHERE model_id = $1")
        .bind(&req.name)
        .fetch_optional(db)
        .await?
        .is_some();
    if clashes_with_model {
        return Err(SaasError::InvalidInput(
            format!("模型组名称 '{}' 与已有模型 ID 冲突,请使用不同的名称", req.name)
        ));
    }

    sqlx::query(
        "INSERT INTO model_groups (id, name, display_name, description, failover_strategy, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, $6)"
    )
    .bind(&new_id)
    .bind(&req.name)
    .bind(&req.display_name)
    .bind(&req.description)
    .bind(&req.failover_strategy)
    .bind(&created_at)
    .execute(db)
    .await
    .map_err(|e| SaasError::from_sqlx_unique(e, &format!("模型组 '{}'", req.name)))?;

    get_model_group(db, &new_id).await
}
pub async fn update_model_group(
db: &PgPool, group_id: &str, req: &UpdateModelGroupRequest,
) -> SaasResult<ModelGroupInfo> {
let now = chrono::Utc::now();
sqlx::query(
"UPDATE model_groups SET
display_name = COALESCE($1, display_name),
description = COALESCE($2, description),
enabled = COALESCE($3, enabled),
failover_strategy = COALESCE($4, failover_strategy),
updated_at = $5
WHERE id = $6"
)
.bind(req.display_name.as_deref())
.bind(req.description.as_deref())
.bind(req.enabled)
.bind(req.failover_strategy.as_deref())
.bind(&now)
.bind(group_id)
.execute(db).await?;
get_model_group(db, group_id).await
}
/// Deletes the model group with the given id.
///
/// # Errors
/// `NotFound` when no row was deleted.
pub async fn delete_model_group(db: &PgPool, group_id: &str) -> SaasResult<()> {
    let removed = sqlx::query("DELETE FROM model_groups WHERE id = $1")
        .bind(group_id)
        .execute(db)
        .await?
        .rows_affected();
    match removed {
        0 => Err(SaasError::NotFound(format!("模型组 {} 不存在", group_id))),
        _ => Ok(()),
    }
}
/// Adds a (provider, model) pair to a model group and returns the new member.
///
/// Checks, in order:
/// 1. the group exists;
/// 2. the provider exists;
/// 3. the model is registered under that provider;
/// 4. the pair is not already a member (M-S4).
///
/// Check 3 is scoped to the provider because members are keyed by
/// (provider_id, model_id). Checking `model_id` alone let through a model that
/// exists only under a different provider, which the relay would then fail to
/// resolve at runtime.
///
/// NOTE(review): the returned `enabled: true` assumes the column defaults to
/// true, since the INSERT does not set it; confirm against the schema.
///
/// # Errors
/// `NotFound` for a missing group/provider/model, `AlreadyExists` for a
/// duplicate member, and database errors otherwise.
pub async fn add_group_member(
    db: &PgPool, group_id: &str, req: &AddGroupMemberRequest,
) -> SaasResult<ModelGroupMemberInfo> {
    // 验证 group 存在 (group must exist)
    sqlx::query_scalar::<_, String>("SELECT id FROM model_groups WHERE id = $1")
        .bind(group_id).fetch_optional(db).await?
        .ok_or_else(|| SaasError::NotFound(format!("模型组 {} 不存在", group_id)))?;
    // 验证 provider 存在 (provider must exist)
    sqlx::query_scalar::<_, String>("SELECT id FROM providers WHERE id = $1")
        .bind(&req.provider_id).fetch_optional(db).await?
        .ok_or_else(|| SaasError::NotFound(format!("Provider {} 不存在", req.provider_id)))?;
    // The model must be registered under *this* provider, so the relay can resolve it.
    sqlx::query_scalar::<_, String>("SELECT model_id FROM models WHERE provider_id = $1 AND model_id = $2")
        .bind(&req.provider_id).bind(&req.model_id).fetch_optional(db).await?
        .ok_or_else(|| SaasError::NotFound(format!("模型 {} 不存在于 Provider {}", req.model_id, req.provider_id)))?;
    // M-S4: reject duplicates up front, so the client gets AlreadyExists rather
    // than a raw unique-violation error.
    let duplicate: Option<(String,)> = sqlx::query_as(
        "SELECT id FROM model_group_members WHERE group_id = $1 AND provider_id = $2 AND model_id = $3"
    )
    .bind(group_id).bind(&req.provider_id).bind(&req.model_id)
    .fetch_optional(db).await?;
    if duplicate.is_some() {
        return Err(SaasError::AlreadyExists(
            format!("Provider {} 的模型 {} 已在该模型组中", req.provider_id, req.model_id)
        ));
    }
    let id = uuid::Uuid::new_v4().to_string();
    sqlx::query(
        "INSERT INTO model_group_members (id, group_id, provider_id, model_id, priority, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, NOW(), NOW())"
    )
    .bind(&id).bind(group_id).bind(&req.provider_id).bind(&req.model_id).bind(req.priority)
    .execute(db).await.map_err(|e| SaasError::from_sqlx_unique(e, &format!("Provider {} 的模型 {} 在该模型组", req.provider_id, req.model_id)))?;
    Ok(ModelGroupMemberInfo {
        id,
        provider_id: req.provider_id.clone(),
        model_id: req.model_id.clone(),
        priority: req.priority,
        enabled: true,
    })
}
/// Removes a member from a model group.
///
/// The delete is scoped by both `member_id` and `group_id` (M-5), so a member
/// cannot be removed through a group it does not belong to.
///
/// # Errors
/// `NotFound` when no matching member row was deleted.
pub async fn remove_group_member(db: &PgPool, group_id: &str, member_id: &str) -> SaasResult<()> {
    let removed = sqlx::query("DELETE FROM model_group_members WHERE id = $1 AND group_id = $2")
        .bind(member_id)
        .bind(group_id)
        .execute(db)
        .await?
        .rows_affected();
    match removed {
        0 => Err(SaasError::NotFound(format!("成员 {} 不属于该模型组", member_id))),
        _ => Ok(()),
    }
}