Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled
P1-03: vite.config.ts proxy '/api' → '/api/' 加尾部斜杠, 防止前缀匹配 /api-keys 导致 SPA 路由崩溃 P1-01: kernel_init 增加 api_key 变更检测(token 刷新后自动重连), streamStore 增加 401 自动恢复(refresh token → kernel reconnect), KernelClient 新增 getConfig() 方法 P1-02: /api/v1/usage 总计改从 billing_usage_quotas 读取 (authoritative source,SSE 和 JSON 均写入), by_model/by_day 仍从 usage_records 读取
723 lines
31 KiB
Rust
723 lines
31 KiB
Rust
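//! Model configuration business logic (模型配置业务逻辑): providers, models, account API keys,
//! usage statistics and model groups.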
|
||
|
||
use sqlx::{PgPool, Row};
|
||
use crate::error::{SaasError, SaasResult};
|
||
use crate::common::{PaginatedResponse, normalize_pagination};
|
||
use crate::crypto;
|
||
use crate::models::{ProviderRow, ModelRow, AccountApiKeyRow, UsageByModelRow, UsageByDayRow};
|
||
use super::types::*;
|
||
|
||
// ============ Providers ============
|
||
|
||
pub async fn list_providers(
|
||
db: &PgPool, page: Option<u32>, page_size: Option<u32>, enabled_filter: Option<bool>,
|
||
) -> SaasResult<PaginatedResponse<ProviderInfo>> {
|
||
let (p, ps, offset) = normalize_pagination(page, page_size);
|
||
|
||
let (count_sql, data_sql) = if enabled_filter.is_some() {
|
||
(
|
||
"SELECT COUNT(*) FROM providers WHERE enabled = $1",
|
||
"SELECT id, name, display_name, base_url, api_protocol, enabled, rate_limit_rpm, rate_limit_tpm, created_at::TEXT, updated_at::TEXT
|
||
FROM providers WHERE enabled = $1 ORDER BY name LIMIT $2 OFFSET $3",
|
||
)
|
||
} else {
|
||
(
|
||
"SELECT COUNT(*) FROM providers",
|
||
"SELECT id, name, display_name, base_url, api_protocol, enabled, rate_limit_rpm, rate_limit_tpm, created_at::TEXT, updated_at::TEXT
|
||
FROM providers ORDER BY name LIMIT $1 OFFSET $2",
|
||
)
|
||
};
|
||
|
||
let total: (i64,) = if let Some(en) = enabled_filter {
|
||
sqlx::query_as(count_sql).bind(en).fetch_one(db).await?
|
||
} else {
|
||
sqlx::query_as(count_sql).fetch_one(db).await?
|
||
};
|
||
|
||
let rows: Vec<ProviderRow> =
|
||
if let Some(en) = enabled_filter {
|
||
sqlx::query_as(data_sql)
|
||
.bind(en).bind(ps as i64).bind(offset)
|
||
.fetch_all(db).await?
|
||
} else {
|
||
sqlx::query_as(data_sql)
|
||
.bind(ps as i64).bind(offset)
|
||
.fetch_all(db).await?
|
||
};
|
||
|
||
let items = rows.into_iter().map(|r| {
|
||
ProviderInfo { id: r.id, name: r.name, display_name: r.display_name, base_url: r.base_url, api_protocol: r.api_protocol, enabled: r.enabled, rate_limit_rpm: r.rate_limit_rpm, rate_limit_tpm: r.rate_limit_tpm, created_at: r.created_at, updated_at: r.updated_at }
|
||
}).collect();
|
||
|
||
Ok(PaginatedResponse { items, total: total.0, page: p, page_size: ps })
|
||
}
|
||
|
||
pub async fn get_provider(db: &PgPool, provider_id: &str) -> SaasResult<ProviderInfo> {
|
||
let row: Option<ProviderRow> =
|
||
sqlx::query_as(
|
||
"SELECT id, name, display_name, base_url, api_protocol, enabled, rate_limit_rpm, rate_limit_tpm, created_at::TEXT, updated_at::TEXT
|
||
FROM providers WHERE id = $1"
|
||
)
|
||
.bind(provider_id)
|
||
.fetch_optional(db)
|
||
.await?;
|
||
|
||
let r = row.ok_or_else(|| SaasError::NotFound(format!("Provider {} 不存在", provider_id)))?;
|
||
|
||
Ok(ProviderInfo { id: r.id, name: r.name, display_name: r.display_name, base_url: r.base_url, api_protocol: r.api_protocol, enabled: r.enabled, rate_limit_rpm: r.rate_limit_rpm, rate_limit_tpm: r.rate_limit_tpm, created_at: r.created_at, updated_at: r.updated_at })
|
||
}
|
||
|
||
pub async fn create_provider(db: &PgPool, req: &CreateProviderRequest, enc_key: &[u8; 32]) -> SaasResult<ProviderInfo> {
|
||
let id = uuid::Uuid::new_v4().to_string();
|
||
let now = chrono::Utc::now();
|
||
|
||
// 检查名称唯一性
|
||
let existing: Option<(String,)> = sqlx::query_as("SELECT id FROM providers WHERE name = $1")
|
||
.bind(&req.name).fetch_optional(db).await?;
|
||
if existing.is_some() {
|
||
return Err(SaasError::AlreadyExists(format!("Provider '{}' 已存在", req.name)));
|
||
}
|
||
|
||
// 加密 API Key 后存储
|
||
let encrypted_api_key = if let Some(ref key) = req.api_key {
|
||
if key.is_empty() {
|
||
String::new()
|
||
} else {
|
||
crypto::encrypt_value(key, enc_key)?
|
||
}
|
||
} else {
|
||
String::new()
|
||
};
|
||
|
||
let display_name = req.display_name.as_deref().unwrap_or(&req.name);
|
||
|
||
sqlx::query(
|
||
"INSERT INTO providers (id, name, display_name, api_key, base_url, api_protocol, enabled, rate_limit_rpm, rate_limit_tpm, created_at, updated_at)
|
||
VALUES ($1, $2, $3, $4, $5, $6, true, $7, $8, $9, $9)"
|
||
)
|
||
.bind(&id).bind(&req.name).bind(display_name).bind(&encrypted_api_key)
|
||
.bind(&req.base_url).bind(&req.api_protocol).bind(&req.rate_limit_rpm).bind(&req.rate_limit_tpm).bind(&now)
|
||
.execute(db).await.map_err(|e| SaasError::from_sqlx_unique(e, &format!("Provider '{}'", req.name)))?;
|
||
|
||
get_provider(db, &id).await
|
||
}
|
||
|
||
pub async fn update_provider(
|
||
db: &PgPool, provider_id: &str, req: &UpdateProviderRequest, enc_key: &[u8; 32],
|
||
) -> SaasResult<ProviderInfo> {
|
||
let now = chrono::Utc::now();
|
||
|
||
// Encrypt api_key upfront if provided
|
||
let encrypted_api_key = match req.api_key {
|
||
Some(ref v) if !v.is_empty() => Some(crypto::encrypt_value(v, enc_key)?),
|
||
Some(ref v) if v.is_empty() => Some(String::new()),
|
||
_ => None,
|
||
};
|
||
|
||
// COALESCE pattern: all updatable fields in a single static SQL.
|
||
// NULL parameters leave the column unchanged.
|
||
sqlx::query(
|
||
"UPDATE providers SET
|
||
display_name = COALESCE($1, display_name),
|
||
base_url = COALESCE($2, base_url),
|
||
api_protocol = COALESCE($3, api_protocol),
|
||
api_key = COALESCE($4, api_key),
|
||
enabled = COALESCE($5, enabled),
|
||
rate_limit_rpm = COALESCE($6, rate_limit_rpm),
|
||
rate_limit_tpm = COALESCE($7, rate_limit_tpm),
|
||
updated_at = $8
|
||
WHERE id = $9"
|
||
)
|
||
.bind(req.display_name.as_deref())
|
||
.bind(req.base_url.as_deref())
|
||
.bind(req.api_protocol.as_deref())
|
||
.bind(encrypted_api_key.as_deref())
|
||
.bind(req.enabled)
|
||
.bind(req.rate_limit_rpm)
|
||
.bind(req.rate_limit_tpm)
|
||
.bind(&now)
|
||
.bind(provider_id)
|
||
.execute(db).await?;
|
||
|
||
get_provider(db, provider_id).await
|
||
}
|
||
|
||
pub async fn delete_provider(db: &PgPool, provider_id: &str) -> SaasResult<()> {
|
||
let result = sqlx::query("DELETE FROM providers WHERE id = $1")
|
||
.bind(provider_id).execute(db).await?;
|
||
|
||
if result.rows_affected() == 0 {
|
||
return Err(SaasError::NotFound(format!("Provider {} 不存在", provider_id)));
|
||
}
|
||
Ok(())
|
||
}
|
||
|
||
// ============ Models ============
|
||
|
||
pub async fn list_models(
|
||
db: &PgPool, provider_id: Option<&str>, page: Option<u32>, page_size: Option<u32>,
|
||
) -> SaasResult<PaginatedResponse<ModelInfo>> {
|
||
let (p, ps, offset) = normalize_pagination(page, page_size);
|
||
|
||
let (count_sql, data_sql) = if provider_id.is_some() {
|
||
(
|
||
"SELECT COUNT(*) FROM models WHERE provider_id = $1",
|
||
"SELECT id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, is_embedding, model_type, pricing_input, pricing_output, created_at::TEXT, updated_at::TEXT
|
||
FROM models WHERE provider_id = $1 ORDER BY alias LIMIT $2 OFFSET $3",
|
||
)
|
||
} else {
|
||
(
|
||
"SELECT COUNT(*) FROM models",
|
||
"SELECT id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, is_embedding, model_type, pricing_input, pricing_output, created_at::TEXT, updated_at::TEXT
|
||
FROM models ORDER BY provider_id, alias LIMIT $1 OFFSET $2",
|
||
)
|
||
};
|
||
|
||
let total: (i64,) = if let Some(pid) = provider_id {
|
||
sqlx::query_as(count_sql).bind(pid).fetch_one(db).await?
|
||
} else {
|
||
sqlx::query_as(count_sql).fetch_one(db).await?
|
||
};
|
||
|
||
let mut query = sqlx::query_as::<_, ModelRow>(data_sql);
|
||
if let Some(pid) = provider_id {
|
||
query = query.bind(pid);
|
||
}
|
||
let rows = query.bind(ps as i64).bind(offset).fetch_all(db).await?;
|
||
|
||
let items = rows.into_iter().map(|r| {
|
||
ModelInfo { id: r.id, provider_id: r.provider_id, model_id: r.model_id, alias: r.alias, context_window: r.context_window, max_output_tokens: r.max_output_tokens, supports_streaming: r.supports_streaming, supports_vision: r.supports_vision, enabled: r.enabled, is_embedding: r.is_embedding, model_type: r.model_type.clone(), pricing_input: r.pricing_input, pricing_output: r.pricing_output, created_at: r.created_at, updated_at: r.updated_at }
|
||
}).collect();
|
||
|
||
Ok(PaginatedResponse { items, total: total.0, page: p, page_size: ps })
|
||
}
|
||
|
||
pub async fn create_model(db: &PgPool, req: &CreateModelRequest) -> SaasResult<ModelInfo> {
|
||
// 验证 provider 存在
|
||
let provider = get_provider(db, &req.provider_id).await?;
|
||
|
||
let id = uuid::Uuid::new_v4().to_string();
|
||
let now = chrono::Utc::now();
|
||
|
||
// 检查 model 唯一性
|
||
let existing: Option<(String,)> = sqlx::query_as(
|
||
"SELECT id FROM models WHERE provider_id = $1 AND model_id = $2"
|
||
)
|
||
.bind(&req.provider_id).bind(&req.model_id)
|
||
.fetch_optional(db).await?;
|
||
|
||
if existing.is_some() {
|
||
return Err(SaasError::AlreadyExists(format!(
|
||
"模型 '{}' 已存在于 provider '{}'", req.model_id, provider.name
|
||
)));
|
||
}
|
||
|
||
// M-2: 检查 model_id 不与模型组名冲突(避免路由歧义)
|
||
let group_conflict: Option<(String,)> = sqlx::query_as("SELECT id FROM model_groups WHERE name = $1")
|
||
.bind(&req.model_id).fetch_optional(db).await?;
|
||
if group_conflict.is_some() {
|
||
return Err(SaasError::InvalidInput(
|
||
format!("模型 ID '{}' 与已有模型组名称冲突,请使用不同的 ID", req.model_id)
|
||
));
|
||
}
|
||
|
||
let ctx = req.context_window.unwrap_or(8192);
|
||
let max_out = req.max_output_tokens.unwrap_or(4096);
|
||
let streaming = req.supports_streaming.unwrap_or(true);
|
||
let vision = req.supports_vision.unwrap_or(false);
|
||
let is_embedding = req.is_embedding.unwrap_or(false);
|
||
let model_type = req.model_type.as_deref().unwrap_or(if is_embedding { "embedding" } else { "chat" });
|
||
let pi = req.pricing_input.unwrap_or(0.0);
|
||
let po = req.pricing_output.unwrap_or(0.0);
|
||
|
||
sqlx::query(
|
||
"INSERT INTO models (id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, is_embedding, model_type, pricing_input, pricing_output, created_at, updated_at)
|
||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, true, $9, $10, $11, $12, $13, $13)"
|
||
)
|
||
.bind(&id).bind(&req.provider_id).bind(&req.model_id).bind(req.alias.as_deref().unwrap_or(&req.model_id))
|
||
.bind(ctx).bind(max_out).bind(streaming).bind(vision).bind(is_embedding).bind(model_type).bind(pi).bind(po).bind(&now)
|
||
.execute(db).await.map_err(|e| SaasError::from_sqlx_unique(e, &format!("模型 '{}' 在 Provider '{}'", req.model_id, req.provider_id)))?;
|
||
|
||
get_model(db, &id).await
|
||
}
|
||
|
||
pub async fn get_model(db: &PgPool, model_id: &str) -> SaasResult<ModelInfo> {
|
||
let row: Option<ModelRow> =
|
||
sqlx::query_as(
|
||
"SELECT id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, is_embedding, model_type, pricing_input, pricing_output, created_at::TEXT, updated_at::TEXT
|
||
FROM models WHERE id = $1"
|
||
)
|
||
.bind(model_id)
|
||
.fetch_optional(db)
|
||
.await?;
|
||
|
||
let r = row.ok_or_else(|| SaasError::NotFound(format!("模型 {} 不存在", model_id)))?;
|
||
|
||
Ok(ModelInfo { id: r.id, provider_id: r.provider_id, model_id: r.model_id, alias: r.alias, context_window: r.context_window, max_output_tokens: r.max_output_tokens, supports_streaming: r.supports_streaming, supports_vision: r.supports_vision, enabled: r.enabled, is_embedding: r.is_embedding, model_type: r.model_type.clone(), pricing_input: r.pricing_input, pricing_output: r.pricing_output, created_at: r.created_at, updated_at: r.updated_at })
|
||
}
|
||
|
||
pub async fn update_model(
|
||
db: &PgPool, model_id: &str, req: &UpdateModelRequest,
|
||
) -> SaasResult<ModelInfo> {
|
||
let now = chrono::Utc::now();
|
||
|
||
// COALESCE pattern: all updatable fields in a single static SQL.
|
||
// NULL parameters leave the column unchanged.
|
||
sqlx::query(
|
||
"UPDATE models SET
|
||
alias = COALESCE($1, alias),
|
||
context_window = COALESCE($2, context_window),
|
||
max_output_tokens = COALESCE($3, max_output_tokens),
|
||
supports_streaming = COALESCE($4, supports_streaming),
|
||
supports_vision = COALESCE($5, supports_vision),
|
||
enabled = COALESCE($6, enabled),
|
||
is_embedding = COALESCE($7, is_embedding),
|
||
model_type = COALESCE($8, model_type),
|
||
pricing_input = COALESCE($9, pricing_input),
|
||
pricing_output = COALESCE($10, pricing_output),
|
||
updated_at = $11
|
||
WHERE id = $12"
|
||
)
|
||
.bind(req.alias.as_deref())
|
||
.bind(req.context_window)
|
||
.bind(req.max_output_tokens)
|
||
.bind(req.supports_streaming)
|
||
.bind(req.supports_vision)
|
||
.bind(req.enabled)
|
||
.bind(req.is_embedding)
|
||
.bind(req.model_type.as_deref())
|
||
.bind(req.pricing_input)
|
||
.bind(req.pricing_output)
|
||
.bind(&now)
|
||
.bind(model_id)
|
||
.execute(db).await?;
|
||
|
||
get_model(db, model_id).await
|
||
}
|
||
|
||
pub async fn delete_model(db: &PgPool, model_id: &str) -> SaasResult<()> {
|
||
let result = sqlx::query("DELETE FROM models WHERE id = $1")
|
||
.bind(model_id).execute(db).await?;
|
||
|
||
if result.rows_affected() == 0 {
|
||
return Err(SaasError::NotFound(format!("模型 {} 不存在", model_id)));
|
||
}
|
||
Ok(())
|
||
}
|
||
|
||
// ============ Account API Keys ============
|
||
|
||
pub async fn list_account_api_keys(
|
||
db: &PgPool, account_id: &str, provider_id: Option<&str>,
|
||
page: Option<u32>, page_size: Option<u32>,
|
||
) -> SaasResult<PaginatedResponse<AccountApiKeyInfo>> {
|
||
let (p, ps, offset) = normalize_pagination(page, page_size);
|
||
|
||
// Build COUNT and data queries based on whether provider_id is provided
|
||
let (count_sql, data_sql) = if provider_id.is_some() {
|
||
(
|
||
"SELECT COUNT(*) FROM account_api_keys WHERE account_id = $1 AND provider_id = $2 AND revoked_at IS NULL",
|
||
"SELECT id, provider_id, key_label, permissions, enabled, last_used_at::TEXT, created_at::TEXT, key_value
|
||
FROM account_api_keys WHERE account_id = $1 AND provider_id = $2 AND revoked_at IS NULL ORDER BY created_at DESC LIMIT $3 OFFSET $4",
|
||
)
|
||
} else {
|
||
(
|
||
"SELECT COUNT(*) FROM account_api_keys WHERE account_id = $1 AND revoked_at IS NULL",
|
||
"SELECT id, provider_id, key_label, permissions, enabled, last_used_at::TEXT, created_at::TEXT, key_value
|
||
FROM account_api_keys WHERE account_id = $1 AND revoked_at IS NULL ORDER BY created_at DESC LIMIT $2 OFFSET $3",
|
||
)
|
||
};
|
||
|
||
let total: (i64,) = if provider_id.is_some() {
|
||
let mut q = sqlx::query_as(count_sql).bind(account_id);
|
||
if let Some(pid) = provider_id { q = q.bind(pid); }
|
||
q.fetch_one(db).await?
|
||
} else {
|
||
sqlx::query_as(count_sql).bind(account_id).fetch_one(db).await?
|
||
};
|
||
|
||
let mut query = sqlx::query_as::<_, AccountApiKeyRow>(data_sql)
|
||
.bind(account_id);
|
||
if let Some(pid) = provider_id {
|
||
query = query.bind(pid);
|
||
}
|
||
let rows = query.bind(ps as i64).bind(offset).fetch_all(db).await?;
|
||
|
||
let items = rows.into_iter().map(|r| {
|
||
let permissions: Vec<String> = serde_json::from_str(&r.permissions).unwrap_or_default();
|
||
let masked = mask_api_key(&r.key_value);
|
||
AccountApiKeyInfo { id: r.id, provider_id: r.provider_id, key_label: r.key_label, permissions, enabled: r.enabled, last_used_at: r.last_used_at, created_at: r.created_at, masked_key: masked }
|
||
}).collect();
|
||
|
||
Ok(PaginatedResponse { items, total: total.0, page: p, page_size: ps })
|
||
}
|
||
|
||
pub async fn create_account_api_key(
|
||
db: &PgPool, account_id: &str, req: &CreateAccountApiKeyRequest, enc_key: &[u8; 32],
|
||
) -> SaasResult<AccountApiKeyInfo> {
|
||
// 验证 provider 存在
|
||
get_provider(db, &req.provider_id).await?;
|
||
|
||
let id = uuid::Uuid::new_v4().to_string();
|
||
let now = chrono::Utc::now();
|
||
let permissions = serde_json::to_string(&req.permissions)?;
|
||
|
||
// 加密 key_value 后存储
|
||
let encrypted_key_value = crypto::encrypt_value(&req.key_value, enc_key)?;
|
||
|
||
sqlx::query(
|
||
"INSERT INTO account_api_keys (id, account_id, provider_id, key_value, key_label, permissions, enabled, created_at, updated_at)
|
||
VALUES ($1, $2, $3, $4, $5, $6, true, $7, $7)"
|
||
)
|
||
.bind(&id).bind(account_id).bind(&req.provider_id).bind(&encrypted_key_value)
|
||
.bind(&req.key_label).bind(&permissions).bind(&now)
|
||
.execute(db).await?;
|
||
|
||
let masked = mask_api_key(&req.key_value);
|
||
Ok(AccountApiKeyInfo {
|
||
id, provider_id: req.provider_id.clone(), key_label: req.key_label.clone(),
|
||
permissions: req.permissions.clone(), enabled: true, last_used_at: None,
|
||
created_at: now.to_rfc3339(), masked_key: masked,
|
||
})
|
||
}
|
||
|
||
pub async fn rotate_account_api_key(
|
||
db: &PgPool, key_id: &str, account_id: &str, new_key_value: &str, enc_key: &[u8; 32],
|
||
) -> SaasResult<()> {
|
||
let now = chrono::Utc::now();
|
||
let encrypted_value = crypto::encrypt_value(new_key_value, enc_key)?;
|
||
let result = sqlx::query(
|
||
"UPDATE account_api_keys SET key_value = $1, updated_at = $2 WHERE id = $3 AND account_id = $4 AND revoked_at IS NULL"
|
||
)
|
||
.bind(&encrypted_value).bind(&now).bind(key_id).bind(account_id)
|
||
.execute(db).await?;
|
||
|
||
if result.rows_affected() == 0 {
|
||
return Err(SaasError::NotFound("API Key 不存在或已撤销".into()));
|
||
}
|
||
Ok(())
|
||
}
|
||
|
||
pub async fn revoke_account_api_key(
|
||
db: &PgPool, key_id: &str, account_id: &str,
|
||
) -> SaasResult<()> {
|
||
let now = chrono::Utc::now();
|
||
let result = sqlx::query(
|
||
"UPDATE account_api_keys SET revoked_at = $1 WHERE id = $2 AND account_id = $3 AND revoked_at IS NULL"
|
||
)
|
||
.bind(&now).bind(key_id).bind(account_id)
|
||
.execute(db).await?;
|
||
|
||
if result.rows_affected() == 0 {
|
||
return Err(SaasError::NotFound("API Key 不存在或已撤销".into()));
|
||
}
|
||
Ok(())
|
||
}
|
||
|
||
// ============ Usage Statistics ============
|
||
|
||
pub async fn get_usage_stats(
|
||
db: &PgPool, account_id: &str, query: &UsageQuery,
|
||
) -> SaasResult<UsageStats> {
|
||
// === Totals: from billing_usage_quotas (authoritative source) ===
|
||
// billing_usage_quotas is written to on every relay request (both JSON and SSE),
|
||
// whereas usage_records has 0 tokens for SSE requests. Use billing as the primary source.
|
||
let billing_row = sqlx::query(
|
||
"SELECT COALESCE(SUM(input_tokens), 0)::bigint,
|
||
COALESCE(SUM(output_tokens), 0)::bigint,
|
||
COALESCE(SUM(relay_requests), 0)::bigint
|
||
FROM billing_usage_quotas WHERE account_id = $1"
|
||
)
|
||
.bind(account_id)
|
||
.fetch_one(db)
|
||
.await?;
|
||
let total_input: i64 = billing_row.try_get(0).unwrap_or(0);
|
||
let total_output: i64 = billing_row.try_get(1).unwrap_or(0);
|
||
let total_requests: i64 = billing_row.try_get(2).unwrap_or(0);
|
||
|
||
// === Breakdowns: from usage_records (per-request detail) ===
|
||
// Optional date filters: pass as TEXT with explicit SQL cast.
|
||
let from_str: Option<&str> = query.from.as_deref();
|
||
let to_str: Option<String> = query.to.as_ref().map(|s| {
|
||
if s.len() == 10 { format!("{}T23:59:59", s) } else { s.clone() }
|
||
});
|
||
|
||
// Build SQL dynamically for usage_records breakdowns.
|
||
// Date parameters are injected as SQL literals (validated via chrono parse).
|
||
let mut where_parts = vec![format!("account_id = '{}'", account_id.replace('\'', "''"))];
|
||
if let Some(f) = from_str {
|
||
let valid = chrono::NaiveDate::parse_from_str(f, "%Y-%m-%d").is_ok()
|
||
|| chrono::NaiveDateTime::parse_from_str(f, "%Y-%m-%dT%H:%M:%S%.f").is_ok();
|
||
if !valid {
|
||
return Err(SaasError::InvalidInput(format!("Invalid 'from' date: {}", f)));
|
||
}
|
||
where_parts.push(format!("created_at::timestamptz >= '{}T00:00:00Z'::timestamptz", f.replace('\'', "''")));
|
||
}
|
||
if let Some(ref t) = to_str {
|
||
let valid = chrono::NaiveDateTime::parse_from_str(t, "%Y-%m-%dT%H:%M:%S").is_ok()
|
||
|| chrono::NaiveDate::parse_from_str(t, "%Y-%m-%d").is_ok();
|
||
if !valid {
|
||
return Err(SaasError::InvalidInput(format!("Invalid 'to' date: {}", t)));
|
||
}
|
||
where_parts.push(format!("created_at::timestamptz <= '{}'::timestamptz", t.replace('\'', "''")));
|
||
}
|
||
if let Some(ref pid) = query.provider_id {
|
||
where_parts.push(format!("provider_id = '{}'", pid.replace('\'', "''")));
|
||
}
|
||
if let Some(ref mid) = query.model_id {
|
||
where_parts.push(format!("model_id = '{}'", mid.replace('\'', "''")));
|
||
}
|
||
let where_clause = where_parts.join(" AND ");
|
||
|
||
// 按模型统计
|
||
let by_model_sql = format!(
|
||
"SELECT provider_id, model_id, COUNT(*)::bigint AS request_count, COALESCE(SUM(input_tokens), 0)::bigint AS input_tokens, COALESCE(SUM(output_tokens), 0)::bigint AS output_tokens
|
||
FROM usage_records WHERE {} GROUP BY provider_id, model_id ORDER BY COUNT(*) DESC LIMIT 20", where_clause
|
||
);
|
||
let by_model_rows: Vec<UsageByModelRow> = sqlx::query_as(&by_model_sql).fetch_all(db).await?;
|
||
let by_model: Vec<ModelUsage> = by_model_rows.into_iter()
|
||
.map(|r| {
|
||
ModelUsage { provider_id: r.provider_id, model_id: r.model_id, request_count: r.request_count, input_tokens: r.input_tokens, output_tokens: r.output_tokens }
|
||
}).collect();
|
||
|
||
// 按天统计 (使用 days 参数或默认 30 天)
|
||
let days = query.days.unwrap_or(30).min(365).max(1) as i64;
|
||
let from_days_str = (chrono::Utc::now() - chrono::Duration::days(days))
|
||
.format("%Y-%m-%d").to_string();
|
||
let daily_sql = format!(
|
||
"SELECT created_at::date::text as day, COUNT(*)::bigint AS request_count, COALESCE(SUM(input_tokens), 0)::bigint AS input_tokens, COALESCE(SUM(output_tokens), 0)::bigint AS output_tokens
|
||
FROM usage_records WHERE account_id = '{}' AND created_at::timestamptz >= '{}T00:00:00Z'::timestamptz
|
||
GROUP BY created_at::date ORDER BY day DESC LIMIT {}",
|
||
account_id.replace('\'', "''"), from_days_str.replace('\'', "''"), days
|
||
);
|
||
let daily_rows: Vec<UsageByDayRow> = sqlx::query_as(&daily_sql).fetch_all(db).await?;
|
||
let by_day: Vec<DailyUsage> = daily_rows.into_iter()
|
||
.map(|r| {
|
||
DailyUsage { date: r.day, request_count: r.request_count, input_tokens: r.input_tokens, output_tokens: r.output_tokens }
|
||
}).collect();
|
||
|
||
// 按 group_by 过滤返回
|
||
let group_by = query.group_by.as_deref();
|
||
let by_model = if group_by == Some("model") || group_by.is_none() { by_model } else { vec![] };
|
||
let by_day = if group_by == Some("day") || group_by.is_none() { by_day } else { vec![] };
|
||
|
||
Ok(UsageStats {
|
||
total_requests,
|
||
total_input_tokens: total_input,
|
||
total_output_tokens: total_output,
|
||
by_model,
|
||
by_day,
|
||
})
|
||
}
|
||
|
||
pub async fn record_usage(
|
||
db: &PgPool, account_id: &str, provider_id: &str, model_id: &str,
|
||
input_tokens: i64, output_tokens: i64, latency_ms: Option<i64>,
|
||
status: &str, error_message: Option<&str>,
|
||
) -> SaasResult<()> {
|
||
let now = chrono::Utc::now();
|
||
sqlx::query(
|
||
"INSERT INTO usage_records (account_id, provider_id, model_id, input_tokens, output_tokens, latency_ms, status, error_message, created_at)
|
||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)"
|
||
)
|
||
.bind(account_id).bind(provider_id).bind(model_id)
|
||
.bind(input_tokens).bind(output_tokens).bind(latency_ms)
|
||
.bind(status).bind(error_message).bind(&now)
|
||
.execute(db).await?;
|
||
Ok(())
|
||
}
|
||
|
||
// ============ Helpers ============
|
||
|
||
/// Masks an API key for display, keeping only its first and last four characters
/// (`"sk-1234567890abcd"` → `"sk-1...abcd"`).
///
/// Keys of eight characters or fewer are replaced entirely by `*`, one per character, so a
/// short key is never partially revealed. Counting and slicing operate on `char`s, not bytes,
/// so user-supplied keys containing multi-byte UTF-8 cannot trigger a char-boundary panic.
/// For ASCII keys, the output is identical to byte-based masking.
fn mask_api_key(key: &str) -> String {
    let char_count = key.chars().count();
    if char_count <= 8 {
        return "*".repeat(char_count);
    }
    let head: String = key.chars().take(4).collect();
    let tail: String = key.chars().skip(char_count - 4).collect();
    format!("{}...{}", head, tail)
}
|
||
|
||
// ============ Model Groups ============
|
||
|
||
pub async fn list_model_groups(db: &PgPool) -> SaasResult<Vec<ModelGroupInfo>> {
|
||
let group_rows: Vec<(String, String, String, String, bool, String, String, String)> = sqlx::query_as(
|
||
"SELECT id, name, display_name, COALESCE(description, ''), enabled,
|
||
COALESCE(failover_strategy, 'quota_aware'), created_at::TEXT, updated_at::TEXT
|
||
FROM model_groups ORDER BY name"
|
||
).fetch_all(db).await?;
|
||
|
||
let member_rows: Vec<(String, String, String, String, i32, bool)> = sqlx::query_as(
|
||
"SELECT id, group_id, provider_id, model_id, priority, enabled
|
||
FROM model_group_members ORDER BY priority ASC"
|
||
).fetch_all(db).await?;
|
||
|
||
let groups = group_rows.into_iter().map(|(id, name, display_name, description, enabled, failover_strategy, created_at, updated_at)| {
|
||
let members: Vec<ModelGroupMemberInfo> = member_rows.iter()
|
||
.filter(|(_, gid, _, _, _, _)| gid == &id)
|
||
.map(|(mid, _, pid, mid2, pri, en)| ModelGroupMemberInfo {
|
||
id: mid.clone(),
|
||
provider_id: pid.clone(),
|
||
model_id: mid2.clone(),
|
||
priority: *pri,
|
||
enabled: *en,
|
||
})
|
||
.collect();
|
||
ModelGroupInfo { id, name, display_name, description, enabled, failover_strategy, members, created_at, updated_at }
|
||
}).collect();
|
||
|
||
Ok(groups)
|
||
}
|
||
|
||
pub async fn get_model_group(db: &PgPool, group_id: &str) -> SaasResult<ModelGroupInfo> {
|
||
let row: Option<(String, String, String, String, bool, String, String, String)> = sqlx::query_as(
|
||
"SELECT id, name, display_name, COALESCE(description, ''), enabled,
|
||
COALESCE(failover_strategy, 'quota_aware'), created_at::TEXT, updated_at::TEXT
|
||
FROM model_groups WHERE id = $1"
|
||
).bind(group_id).fetch_optional(db).await?;
|
||
|
||
let (id, name, display_name, description, enabled, failover_strategy, created_at, updated_at) =
|
||
row.ok_or_else(|| SaasError::NotFound(format!("模型组 {} 不存在", group_id)))?;
|
||
|
||
let member_rows: Vec<(String, String, String, String, i32, bool)> = sqlx::query_as(
|
||
"SELECT id, group_id, provider_id, model_id, priority, enabled
|
||
FROM model_group_members WHERE group_id = $1 ORDER BY priority ASC"
|
||
).bind(group_id).fetch_all(db).await?;
|
||
|
||
let members = member_rows.into_iter()
|
||
.map(|(mid, _, pid, mid2, pri, en)| ModelGroupMemberInfo {
|
||
id: mid, provider_id: pid, model_id: mid2, priority: pri, enabled: en,
|
||
})
|
||
.collect();
|
||
|
||
Ok(ModelGroupInfo { id, name, display_name, description, enabled, failover_strategy, members, created_at, updated_at })
|
||
}
|
||
|
||
pub async fn create_model_group(db: &PgPool, req: &CreateModelGroupRequest) -> SaasResult<ModelGroupInfo> {
|
||
// M-S1: failover_strategy 白名单校验
|
||
const VALID_STRATEGIES: &[&str] = &["quota_aware", "priority", "random"];
|
||
if !VALID_STRATEGIES.contains(&req.failover_strategy.as_str()) {
|
||
return Err(SaasError::InvalidInput(
|
||
format!("failover_strategy 必须是 {:?} 之一", VALID_STRATEGIES)
|
||
));
|
||
}
|
||
|
||
let id = uuid::Uuid::new_v4().to_string();
|
||
let now = chrono::Utc::now();
|
||
|
||
// 检查名称唯一性
|
||
let existing: Option<(String,)> = sqlx::query_as("SELECT id FROM model_groups WHERE name = $1")
|
||
.bind(&req.name).fetch_optional(db).await?;
|
||
if existing.is_some() {
|
||
return Err(SaasError::AlreadyExists(format!("模型组 '{}' 已存在", req.name)));
|
||
}
|
||
|
||
// 名称不能和已有 model_id 冲突(避免路由歧义)
|
||
let model_conflict: Option<(String,)> = sqlx::query_as("SELECT model_id FROM models WHERE model_id = $1")
|
||
.bind(&req.name).fetch_optional(db).await?;
|
||
if model_conflict.is_some() {
|
||
return Err(SaasError::InvalidInput(
|
||
format!("模型组名称 '{}' 与已有模型 ID 冲突,请使用不同的名称", req.name)
|
||
));
|
||
}
|
||
|
||
sqlx::query(
|
||
"INSERT INTO model_groups (id, name, display_name, description, failover_strategy, created_at, updated_at)
|
||
VALUES ($1, $2, $3, $4, $5, $6, $6)"
|
||
)
|
||
.bind(&id).bind(&req.name).bind(&req.display_name).bind(&req.description)
|
||
.bind(&req.failover_strategy).bind(&now)
|
||
.execute(db).await.map_err(|e| SaasError::from_sqlx_unique(e, &format!("模型组 '{}'", req.name)))?;
|
||
|
||
get_model_group(db, &id).await
|
||
}
|
||
|
||
pub async fn update_model_group(
|
||
db: &PgPool, group_id: &str, req: &UpdateModelGroupRequest,
|
||
) -> SaasResult<ModelGroupInfo> {
|
||
let now = chrono::Utc::now();
|
||
|
||
sqlx::query(
|
||
"UPDATE model_groups SET
|
||
display_name = COALESCE($1, display_name),
|
||
description = COALESCE($2, description),
|
||
enabled = COALESCE($3, enabled),
|
||
failover_strategy = COALESCE($4, failover_strategy),
|
||
updated_at = $5
|
||
WHERE id = $6"
|
||
)
|
||
.bind(req.display_name.as_deref())
|
||
.bind(req.description.as_deref())
|
||
.bind(req.enabled)
|
||
.bind(req.failover_strategy.as_deref())
|
||
.bind(&now)
|
||
.bind(group_id)
|
||
.execute(db).await?;
|
||
|
||
get_model_group(db, group_id).await
|
||
}
|
||
|
||
pub async fn delete_model_group(db: &PgPool, group_id: &str) -> SaasResult<()> {
|
||
let result = sqlx::query("DELETE FROM model_groups WHERE id = $1")
|
||
.bind(group_id).execute(db).await?;
|
||
if result.rows_affected() == 0 {
|
||
return Err(SaasError::NotFound(format!("模型组 {} 不存在", group_id)));
|
||
}
|
||
Ok(())
|
||
}
|
||
|
||
pub async fn add_group_member(
|
||
db: &PgPool, group_id: &str, req: &AddGroupMemberRequest,
|
||
) -> SaasResult<ModelGroupMemberInfo> {
|
||
// 验证 group 存在
|
||
sqlx::query_scalar::<_, String>("SELECT id FROM model_groups WHERE id = $1")
|
||
.bind(group_id).fetch_optional(db).await?
|
||
.ok_or_else(|| SaasError::NotFound(format!("模型组 {} 不存在", group_id)))?;
|
||
|
||
// 验证 provider 存在
|
||
sqlx::query_scalar::<_, String>("SELECT id FROM providers WHERE id = $1")
|
||
.bind(&req.provider_id).fetch_optional(db).await?
|
||
.ok_or_else(|| SaasError::NotFound(format!("Provider {} 不存在", req.provider_id)))?;
|
||
|
||
// 验证 model 存在(避免插入无效 model_id 导致 relay 运行时找不到模型)
|
||
sqlx::query_scalar::<_, String>("SELECT model_id FROM models WHERE model_id = $1")
|
||
.bind(&req.model_id).fetch_optional(db).await?
|
||
.ok_or_else(|| SaasError::NotFound(format!("模型 {} 不存在", req.model_id)))?;
|
||
|
||
// M-S4: 检查重复成员(避免 DB unique violation 返回 500)
|
||
let duplicate: Option<(String,)> = sqlx::query_as(
|
||
"SELECT id FROM model_group_members WHERE group_id = $1 AND provider_id = $2 AND model_id = $3"
|
||
)
|
||
.bind(group_id).bind(&req.provider_id).bind(&req.model_id)
|
||
.fetch_optional(db).await?;
|
||
if duplicate.is_some() {
|
||
return Err(SaasError::AlreadyExists(
|
||
format!("Provider {} 的模型 {} 已在该模型组中", req.provider_id, req.model_id)
|
||
));
|
||
}
|
||
|
||
let id = uuid::Uuid::new_v4().to_string();
|
||
sqlx::query(
|
||
"INSERT INTO model_group_members (id, group_id, provider_id, model_id, priority, created_at, updated_at)
|
||
VALUES ($1, $2, $3, $4, $5, NOW(), NOW())"
|
||
)
|
||
.bind(&id).bind(group_id).bind(&req.provider_id).bind(&req.model_id).bind(req.priority)
|
||
.execute(db).await.map_err(|e| SaasError::from_sqlx_unique(e, &format!("Provider {} 的模型 {} 在该模型组", req.provider_id, req.model_id)))?;
|
||
|
||
Ok(ModelGroupMemberInfo {
|
||
id,
|
||
provider_id: req.provider_id.clone(),
|
||
model_id: req.model_id.clone(),
|
||
priority: req.priority,
|
||
enabled: true,
|
||
})
|
||
}
|
||
|
||
pub async fn remove_group_member(db: &PgPool, group_id: &str, member_id: &str) -> SaasResult<()> {
|
||
// M-5: 验证成员确实属于该组
|
||
let result = sqlx::query("DELETE FROM model_group_members WHERE id = $1 AND group_id = $2")
|
||
.bind(member_id).bind(group_id).execute(db).await?;
|
||
if result.rows_affected() == 0 {
|
||
return Err(SaasError::NotFound(format!("成员 {} 不属于该模型组", member_id)));
|
||
}
|
||
Ok(())
|
||
}
|