Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled
Batch 5 (P0): GrowthIntegration 接入 Tauri - Kernel 新增 set_viking()/set_extraction_driver() 桥接 SqliteStorage - 中间件链共享存储,MemoryExtractor 接入 LLM 驱动 Batch 6 (P1): 输入验证 + Heartbeat - Relay 验证补全(stream 兼容检查、API key 格式校验) - UUID 类型校验、SessionId 错误返回 - Heartbeat 默认开启 + 首次聊天自动初始化 Batch 7 (P2): 死代码清理 - zclaw-channels 整体移除(317 行) - multi-agent 特性门控、admin 方法标注 Batch 8 (P2): Pipeline 模板 - PipelineMetadata 新增 annotations 字段 - pipeline_templates 命令 + 2 个示例模板 - fallback driver base_url 修复(doubao/qwen/deepseek 端点) Batch 9 (P1): SpeechHand/TwitterHand 真实实现 - SpeechHand: tts_method 字段 + Browser TTS 前端集成 (Web Speech API) - TwitterHand: 12 个 action 全部替换为 Twitter API v2 真实 HTTP 调用 - chatStore/useAutomationEvents 双路径 TTS 触发
494 lines
20 KiB
Rust
494 lines
20 KiB
Rust
//! 模型配置业务逻辑
|
|
|
|
use sqlx::{PgPool, Row};
|
|
use crate::error::{SaasError, SaasResult};
|
|
use crate::common::{PaginatedResponse, normalize_pagination};
|
|
use crate::crypto;
|
|
use crate::models::{ProviderRow, ModelRow, AccountApiKeyRow, UsageByModelRow, UsageByDayRow};
|
|
use super::types::*;
|
|
|
|
// ============ Providers ============
|
|
|
|
pub async fn list_providers(
|
|
db: &PgPool, page: Option<u32>, page_size: Option<u32>, enabled_filter: Option<bool>,
|
|
) -> SaasResult<PaginatedResponse<ProviderInfo>> {
|
|
let (p, ps, offset) = normalize_pagination(page, page_size);
|
|
|
|
let (count_sql, data_sql) = if enabled_filter.is_some() {
|
|
(
|
|
"SELECT COUNT(*) FROM providers WHERE enabled = $1",
|
|
"SELECT id, name, display_name, base_url, api_protocol, enabled, rate_limit_rpm, rate_limit_tpm, created_at, updated_at
|
|
FROM providers WHERE enabled = $1 ORDER BY name LIMIT $2 OFFSET $3",
|
|
)
|
|
} else {
|
|
(
|
|
"SELECT COUNT(*) FROM providers",
|
|
"SELECT id, name, display_name, base_url, api_protocol, enabled, rate_limit_rpm, rate_limit_tpm, created_at, updated_at
|
|
FROM providers ORDER BY name LIMIT $1 OFFSET $2",
|
|
)
|
|
};
|
|
|
|
let total: (i64,) = if let Some(en) = enabled_filter {
|
|
sqlx::query_as(count_sql).bind(en).fetch_one(db).await?
|
|
} else {
|
|
sqlx::query_as(count_sql).fetch_one(db).await?
|
|
};
|
|
|
|
let rows: Vec<ProviderRow> =
|
|
if let Some(en) = enabled_filter {
|
|
sqlx::query_as(data_sql)
|
|
.bind(en).bind(ps as i64).bind(offset)
|
|
.fetch_all(db).await?
|
|
} else {
|
|
sqlx::query_as(data_sql)
|
|
.bind(ps as i64).bind(offset)
|
|
.fetch_all(db).await?
|
|
};
|
|
|
|
let items = rows.into_iter().map(|r| {
|
|
ProviderInfo { id: r.id, name: r.name, display_name: r.display_name, base_url: r.base_url, api_protocol: r.api_protocol, enabled: r.enabled, rate_limit_rpm: r.rate_limit_rpm, rate_limit_tpm: r.rate_limit_tpm, created_at: r.created_at, updated_at: r.updated_at }
|
|
}).collect();
|
|
|
|
Ok(PaginatedResponse { items, total: total.0, page: p, page_size: ps })
|
|
}
|
|
|
|
pub async fn get_provider(db: &PgPool, provider_id: &str) -> SaasResult<ProviderInfo> {
|
|
let row: Option<ProviderRow> =
|
|
sqlx::query_as(
|
|
"SELECT id, name, display_name, base_url, api_protocol, enabled, rate_limit_rpm, rate_limit_tpm, created_at, updated_at
|
|
FROM providers WHERE id = $1"
|
|
)
|
|
.bind(provider_id)
|
|
.fetch_optional(db)
|
|
.await?;
|
|
|
|
let r = row.ok_or_else(|| SaasError::NotFound(format!("Provider {} 不存在", provider_id)))?;
|
|
|
|
Ok(ProviderInfo { id: r.id, name: r.name, display_name: r.display_name, base_url: r.base_url, api_protocol: r.api_protocol, enabled: r.enabled, rate_limit_rpm: r.rate_limit_rpm, rate_limit_tpm: r.rate_limit_tpm, created_at: r.created_at, updated_at: r.updated_at })
|
|
}
|
|
|
|
pub async fn create_provider(db: &PgPool, req: &CreateProviderRequest, enc_key: &[u8; 32]) -> SaasResult<ProviderInfo> {
|
|
let id = uuid::Uuid::new_v4().to_string();
|
|
let now = chrono::Utc::now().to_rfc3339();
|
|
|
|
// 检查名称唯一性
|
|
let existing: Option<(String,)> = sqlx::query_as("SELECT id FROM providers WHERE name = $1")
|
|
.bind(&req.name).fetch_optional(db).await?;
|
|
if existing.is_some() {
|
|
return Err(SaasError::AlreadyExists(format!("Provider '{}' 已存在", req.name)));
|
|
}
|
|
|
|
// 加密 API Key 后存储
|
|
let encrypted_api_key = if let Some(ref key) = req.api_key {
|
|
if key.is_empty() {
|
|
String::new()
|
|
} else {
|
|
crypto::encrypt_value(key, enc_key)?
|
|
}
|
|
} else {
|
|
String::new()
|
|
};
|
|
|
|
sqlx::query(
|
|
"INSERT INTO providers (id, name, display_name, api_key, base_url, api_protocol, enabled, rate_limit_rpm, rate_limit_tpm, created_at, updated_at)
|
|
VALUES ($1, $2, $3, $4, $5, $6, true, $7, $8, $9, $9)"
|
|
)
|
|
.bind(&id).bind(&req.name).bind(&req.display_name).bind(&encrypted_api_key)
|
|
.bind(&req.base_url).bind(&req.api_protocol).bind(&req.rate_limit_rpm).bind(&req.rate_limit_tpm).bind(&now)
|
|
.execute(db).await?;
|
|
|
|
get_provider(db, &id).await
|
|
}
|
|
|
|
pub async fn update_provider(
|
|
db: &PgPool, provider_id: &str, req: &UpdateProviderRequest, enc_key: &[u8; 32],
|
|
) -> SaasResult<ProviderInfo> {
|
|
let now = chrono::Utc::now().to_rfc3339();
|
|
|
|
// Encrypt api_key upfront if provided
|
|
let encrypted_api_key = match req.api_key {
|
|
Some(ref v) if !v.is_empty() => Some(crypto::encrypt_value(v, enc_key)?),
|
|
Some(ref v) if v.is_empty() => Some(String::new()),
|
|
_ => None,
|
|
};
|
|
|
|
// COALESCE pattern: all updatable fields in a single static SQL.
|
|
// NULL parameters leave the column unchanged.
|
|
sqlx::query(
|
|
"UPDATE providers SET
|
|
display_name = COALESCE($1, display_name),
|
|
base_url = COALESCE($2, base_url),
|
|
api_protocol = COALESCE($3, api_protocol),
|
|
api_key = COALESCE($4, api_key),
|
|
enabled = COALESCE($5, enabled),
|
|
rate_limit_rpm = COALESCE($6, rate_limit_rpm),
|
|
rate_limit_tpm = COALESCE($7, rate_limit_tpm),
|
|
updated_at = $8
|
|
WHERE id = $9"
|
|
)
|
|
.bind(req.display_name.as_deref())
|
|
.bind(req.base_url.as_deref())
|
|
.bind(req.api_protocol.as_deref())
|
|
.bind(encrypted_api_key.as_deref())
|
|
.bind(req.enabled)
|
|
.bind(req.rate_limit_rpm)
|
|
.bind(req.rate_limit_tpm)
|
|
.bind(&now)
|
|
.bind(provider_id)
|
|
.execute(db).await?;
|
|
|
|
get_provider(db, provider_id).await
|
|
}
|
|
|
|
pub async fn delete_provider(db: &PgPool, provider_id: &str) -> SaasResult<()> {
|
|
let result = sqlx::query("DELETE FROM providers WHERE id = $1")
|
|
.bind(provider_id).execute(db).await?;
|
|
|
|
if result.rows_affected() == 0 {
|
|
return Err(SaasError::NotFound(format!("Provider {} 不存在", provider_id)));
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
// ============ Models ============
|
|
|
|
pub async fn list_models(
|
|
db: &PgPool, provider_id: Option<&str>, page: Option<u32>, page_size: Option<u32>,
|
|
) -> SaasResult<PaginatedResponse<ModelInfo>> {
|
|
let (p, ps, offset) = normalize_pagination(page, page_size);
|
|
|
|
let (count_sql, data_sql) = if provider_id.is_some() {
|
|
(
|
|
"SELECT COUNT(*) FROM models WHERE provider_id = $1",
|
|
"SELECT id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, pricing_input, pricing_output, created_at, updated_at
|
|
FROM models WHERE provider_id = $1 ORDER BY alias LIMIT $2 OFFSET $3",
|
|
)
|
|
} else {
|
|
(
|
|
"SELECT COUNT(*) FROM models",
|
|
"SELECT id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, pricing_input, pricing_output, created_at, updated_at
|
|
FROM models ORDER BY provider_id, alias LIMIT $1 OFFSET $2",
|
|
)
|
|
};
|
|
|
|
let total: (i64,) = if let Some(pid) = provider_id {
|
|
sqlx::query_as(count_sql).bind(pid).fetch_one(db).await?
|
|
} else {
|
|
sqlx::query_as(count_sql).fetch_one(db).await?
|
|
};
|
|
|
|
let mut query = sqlx::query_as::<_, ModelRow>(data_sql);
|
|
if let Some(pid) = provider_id {
|
|
query = query.bind(pid);
|
|
}
|
|
let rows = query.bind(ps as i64).bind(offset).fetch_all(db).await?;
|
|
|
|
let items = rows.into_iter().map(|r| {
|
|
ModelInfo { id: r.id, provider_id: r.provider_id, model_id: r.model_id, alias: r.alias, context_window: r.context_window, max_output_tokens: r.max_output_tokens, supports_streaming: r.supports_streaming, supports_vision: r.supports_vision, enabled: r.enabled, pricing_input: r.pricing_input, pricing_output: r.pricing_output, created_at: r.created_at, updated_at: r.updated_at }
|
|
}).collect();
|
|
|
|
Ok(PaginatedResponse { items, total: total.0, page: p, page_size: ps })
|
|
}
|
|
|
|
pub async fn create_model(db: &PgPool, req: &CreateModelRequest) -> SaasResult<ModelInfo> {
|
|
// 验证 provider 存在
|
|
let provider = get_provider(db, &req.provider_id).await?;
|
|
|
|
let id = uuid::Uuid::new_v4().to_string();
|
|
let now = chrono::Utc::now().to_rfc3339();
|
|
|
|
// 检查 model 唯一性
|
|
let existing: Option<(String,)> = sqlx::query_as(
|
|
"SELECT id FROM models WHERE provider_id = $1 AND model_id = $2"
|
|
)
|
|
.bind(&req.provider_id).bind(&req.model_id)
|
|
.fetch_optional(db).await?;
|
|
|
|
if existing.is_some() {
|
|
return Err(SaasError::AlreadyExists(format!(
|
|
"模型 '{}' 已存在于 provider '{}'", req.model_id, provider.name
|
|
)));
|
|
}
|
|
|
|
let ctx = req.context_window.unwrap_or(8192);
|
|
let max_out = req.max_output_tokens.unwrap_or(4096);
|
|
let streaming = req.supports_streaming.unwrap_or(true);
|
|
let vision = req.supports_vision.unwrap_or(false);
|
|
let pi = req.pricing_input.unwrap_or(0.0);
|
|
let po = req.pricing_output.unwrap_or(0.0);
|
|
|
|
sqlx::query(
|
|
"INSERT INTO models (id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, pricing_input, pricing_output, created_at, updated_at)
|
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, true, $9, $10, $11, $11)"
|
|
)
|
|
.bind(&id).bind(&req.provider_id).bind(&req.model_id).bind(&req.alias)
|
|
.bind(ctx).bind(max_out).bind(streaming).bind(vision).bind(pi).bind(po).bind(&now)
|
|
.execute(db).await?;
|
|
|
|
get_model(db, &id).await
|
|
}
|
|
|
|
pub async fn get_model(db: &PgPool, model_id: &str) -> SaasResult<ModelInfo> {
|
|
let row: Option<ModelRow> =
|
|
sqlx::query_as(
|
|
"SELECT id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, pricing_input, pricing_output, created_at, updated_at
|
|
FROM models WHERE id = $1"
|
|
)
|
|
.bind(model_id)
|
|
.fetch_optional(db)
|
|
.await?;
|
|
|
|
let r = row.ok_or_else(|| SaasError::NotFound(format!("模型 {} 不存在", model_id)))?;
|
|
|
|
Ok(ModelInfo { id: r.id, provider_id: r.provider_id, model_id: r.model_id, alias: r.alias, context_window: r.context_window, max_output_tokens: r.max_output_tokens, supports_streaming: r.supports_streaming, supports_vision: r.supports_vision, enabled: r.enabled, pricing_input: r.pricing_input, pricing_output: r.pricing_output, created_at: r.created_at, updated_at: r.updated_at })
|
|
}
|
|
|
|
pub async fn update_model(
|
|
db: &PgPool, model_id: &str, req: &UpdateModelRequest,
|
|
) -> SaasResult<ModelInfo> {
|
|
let now = chrono::Utc::now().to_rfc3339();
|
|
|
|
// COALESCE pattern: all updatable fields in a single static SQL.
|
|
// NULL parameters leave the column unchanged.
|
|
sqlx::query(
|
|
"UPDATE models SET
|
|
alias = COALESCE($1, alias),
|
|
context_window = COALESCE($2, context_window),
|
|
max_output_tokens = COALESCE($3, max_output_tokens),
|
|
supports_streaming = COALESCE($4, supports_streaming),
|
|
supports_vision = COALESCE($5, supports_vision),
|
|
enabled = COALESCE($6, enabled),
|
|
pricing_input = COALESCE($7, pricing_input),
|
|
pricing_output = COALESCE($8, pricing_output),
|
|
updated_at = $9
|
|
WHERE id = $10"
|
|
)
|
|
.bind(req.alias.as_deref())
|
|
.bind(req.context_window)
|
|
.bind(req.max_output_tokens)
|
|
.bind(req.supports_streaming)
|
|
.bind(req.supports_vision)
|
|
.bind(req.enabled)
|
|
.bind(req.pricing_input)
|
|
.bind(req.pricing_output)
|
|
.bind(&now)
|
|
.bind(model_id)
|
|
.execute(db).await?;
|
|
|
|
get_model(db, model_id).await
|
|
}
|
|
|
|
pub async fn delete_model(db: &PgPool, model_id: &str) -> SaasResult<()> {
|
|
let result = sqlx::query("DELETE FROM models WHERE id = $1")
|
|
.bind(model_id).execute(db).await?;
|
|
|
|
if result.rows_affected() == 0 {
|
|
return Err(SaasError::NotFound(format!("模型 {} 不存在", model_id)));
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
// ============ Account API Keys ============
|
|
|
|
pub async fn list_account_api_keys(
|
|
db: &PgPool, account_id: &str, provider_id: Option<&str>,
|
|
page: Option<u32>, page_size: Option<u32>,
|
|
) -> SaasResult<PaginatedResponse<AccountApiKeyInfo>> {
|
|
let (p, ps, offset) = normalize_pagination(page, page_size);
|
|
|
|
// Build COUNT and data queries based on whether provider_id is provided
|
|
let (count_sql, data_sql) = if provider_id.is_some() {
|
|
(
|
|
"SELECT COUNT(*) FROM account_api_keys WHERE account_id = $1 AND provider_id = $2 AND revoked_at IS NULL",
|
|
"SELECT id, provider_id, key_label, permissions, enabled, last_used_at, created_at, key_value
|
|
FROM account_api_keys WHERE account_id = $1 AND provider_id = $2 AND revoked_at IS NULL ORDER BY created_at DESC LIMIT $3 OFFSET $4",
|
|
)
|
|
} else {
|
|
(
|
|
"SELECT COUNT(*) FROM account_api_keys WHERE account_id = $1 AND revoked_at IS NULL",
|
|
"SELECT id, provider_id, key_label, permissions, enabled, last_used_at, created_at, key_value
|
|
FROM account_api_keys WHERE account_id = $1 AND revoked_at IS NULL ORDER BY created_at DESC LIMIT $2 OFFSET $3",
|
|
)
|
|
};
|
|
|
|
let total: (i64,) = if provider_id.is_some() {
|
|
let mut q = sqlx::query_as(count_sql).bind(account_id);
|
|
if let Some(pid) = provider_id { q = q.bind(pid); }
|
|
q.fetch_one(db).await?
|
|
} else {
|
|
sqlx::query_as(count_sql).bind(account_id).fetch_one(db).await?
|
|
};
|
|
|
|
let mut query = sqlx::query_as::<_, AccountApiKeyRow>(data_sql)
|
|
.bind(account_id);
|
|
if let Some(pid) = provider_id {
|
|
query = query.bind(pid);
|
|
}
|
|
let rows = query.bind(ps as i64).bind(offset).fetch_all(db).await?;
|
|
|
|
let items = rows.into_iter().map(|r| {
|
|
let permissions: Vec<String> = serde_json::from_str(&r.permissions).unwrap_or_default();
|
|
let masked = mask_api_key(&r.key_value);
|
|
AccountApiKeyInfo { id: r.id, provider_id: r.provider_id, key_label: r.key_label, permissions, enabled: r.enabled, last_used_at: r.last_used_at, created_at: r.created_at, masked_key: masked }
|
|
}).collect();
|
|
|
|
Ok(PaginatedResponse { items, total: total.0, page: p, page_size: ps })
|
|
}
|
|
|
|
pub async fn create_account_api_key(
|
|
db: &PgPool, account_id: &str, req: &CreateAccountApiKeyRequest, enc_key: &[u8; 32],
|
|
) -> SaasResult<AccountApiKeyInfo> {
|
|
// 验证 provider 存在
|
|
get_provider(db, &req.provider_id).await?;
|
|
|
|
let id = uuid::Uuid::new_v4().to_string();
|
|
let now = chrono::Utc::now().to_rfc3339();
|
|
let permissions = serde_json::to_string(&req.permissions)?;
|
|
|
|
// 加密 key_value 后存储
|
|
let encrypted_key_value = crypto::encrypt_value(&req.key_value, enc_key)?;
|
|
|
|
sqlx::query(
|
|
"INSERT INTO account_api_keys (id, account_id, provider_id, key_value, key_label, permissions, enabled, created_at, updated_at)
|
|
VALUES ($1, $2, $3, $4, $5, $6, true, $7, $7)"
|
|
)
|
|
.bind(&id).bind(account_id).bind(&req.provider_id).bind(&encrypted_key_value)
|
|
.bind(&req.key_label).bind(&permissions).bind(&now)
|
|
.execute(db).await?;
|
|
|
|
let masked = mask_api_key(&req.key_value);
|
|
Ok(AccountApiKeyInfo {
|
|
id, provider_id: req.provider_id.clone(), key_label: req.key_label.clone(),
|
|
permissions: req.permissions.clone(), enabled: true, last_used_at: None,
|
|
created_at: now, masked_key: masked,
|
|
})
|
|
}
|
|
|
|
pub async fn rotate_account_api_key(
|
|
db: &PgPool, key_id: &str, account_id: &str, new_key_value: &str, enc_key: &[u8; 32],
|
|
) -> SaasResult<()> {
|
|
let now = chrono::Utc::now().to_rfc3339();
|
|
let encrypted_value = crypto::encrypt_value(new_key_value, enc_key)?;
|
|
let result = sqlx::query(
|
|
"UPDATE account_api_keys SET key_value = $1, updated_at = $2 WHERE id = $3 AND account_id = $4 AND revoked_at IS NULL"
|
|
)
|
|
.bind(&encrypted_value).bind(&now).bind(key_id).bind(account_id)
|
|
.execute(db).await?;
|
|
|
|
if result.rows_affected() == 0 {
|
|
return Err(SaasError::NotFound("API Key 不存在或已撤销".into()));
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
pub async fn revoke_account_api_key(
|
|
db: &PgPool, key_id: &str, account_id: &str,
|
|
) -> SaasResult<()> {
|
|
let now = chrono::Utc::now().to_rfc3339();
|
|
let result = sqlx::query(
|
|
"UPDATE account_api_keys SET revoked_at = $1 WHERE id = $2 AND account_id = $3 AND revoked_at IS NULL"
|
|
)
|
|
.bind(&now).bind(key_id).bind(account_id)
|
|
.execute(db).await?;
|
|
|
|
if result.rows_affected() == 0 {
|
|
return Err(SaasError::NotFound("API Key 不存在或已撤销".into()));
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
// ============ Usage Statistics ============
|
|
|
|
pub async fn get_usage_stats(
|
|
db: &PgPool, account_id: &str, query: &UsageQuery,
|
|
) -> SaasResult<UsageStats> {
|
|
// Static SQL with conditional filter pattern:
|
|
// account_id is always required; optional filters use ($N IS NULL OR col = $N).
|
|
let total_sql = "SELECT COUNT(*)::bigint, COALESCE(SUM(input_tokens), 0), COALESCE(SUM(output_tokens), 0)
|
|
FROM usage_records WHERE account_id = $1 AND ($2 IS NULL OR created_at >= $2) AND ($3 IS NULL OR created_at <= $3) AND ($4 IS NULL OR provider_id = $4) AND ($5 IS NULL OR model_id = $5)";
|
|
|
|
let row = sqlx::query(total_sql)
|
|
.bind(account_id)
|
|
.bind(&query.from)
|
|
.bind(&query.to)
|
|
.bind(&query.provider_id)
|
|
.bind(&query.model_id)
|
|
.fetch_one(db).await?;
|
|
let total_requests: i64 = row.try_get(0).unwrap_or(0);
|
|
let total_input: i64 = row.try_get(1).unwrap_or(0);
|
|
let total_output: i64 = row.try_get(2).unwrap_or(0);
|
|
|
|
// 按模型统计
|
|
let by_model_sql = "SELECT provider_id, model_id, COUNT(*)::bigint AS request_count, COALESCE(SUM(input_tokens), 0) AS input_tokens, COALESCE(SUM(output_tokens), 0) AS output_tokens
|
|
FROM usage_records WHERE account_id = $1 AND ($2 IS NULL OR created_at >= $2) AND ($3 IS NULL OR created_at <= $3) AND ($4 IS NULL OR provider_id = $4) AND ($5 IS NULL OR model_id = $5) GROUP BY provider_id, model_id ORDER BY COUNT(*) DESC LIMIT 20";
|
|
|
|
let by_model_rows: Vec<UsageByModelRow> = sqlx::query_as(by_model_sql)
|
|
.bind(account_id)
|
|
.bind(&query.from)
|
|
.bind(&query.to)
|
|
.bind(&query.provider_id)
|
|
.bind(&query.model_id)
|
|
.fetch_all(db).await?;
|
|
let by_model: Vec<ModelUsage> = by_model_rows.into_iter()
|
|
.map(|r| {
|
|
ModelUsage { provider_id: r.provider_id, model_id: r.model_id, request_count: r.request_count, input_tokens: r.input_tokens, output_tokens: r.output_tokens }
|
|
}).collect();
|
|
|
|
// 按天统计 (使用 days 参数或默认 30 天)
|
|
let days = query.days.unwrap_or(30).min(365).max(1) as i64;
|
|
let from_days = (chrono::Utc::now() - chrono::Duration::days(days))
|
|
.date_naive()
|
|
.and_hms_opt(0, 0, 0).unwrap()
|
|
.and_utc()
|
|
.to_rfc3339();
|
|
let daily_sql = "SELECT created_at::date::text as day, COUNT(*)::bigint AS request_count, COALESCE(SUM(input_tokens), 0) AS input_tokens, COALESCE(SUM(output_tokens), 0) AS output_tokens
|
|
FROM usage_records WHERE account_id = $1 AND created_at >= $2
|
|
GROUP BY created_at::date ORDER BY day DESC LIMIT $3";
|
|
let daily_rows: Vec<UsageByDayRow> = sqlx::query_as(daily_sql)
|
|
.bind(account_id).bind(&from_days).bind(days as i32)
|
|
.fetch_all(db).await?;
|
|
let by_day: Vec<DailyUsage> = daily_rows.into_iter()
|
|
.map(|r| {
|
|
DailyUsage { date: r.day, request_count: r.request_count, input_tokens: r.input_tokens, output_tokens: r.output_tokens }
|
|
}).collect();
|
|
|
|
// 按 group_by 过滤返回
|
|
let group_by = query.group_by.as_deref();
|
|
let by_model = if group_by == Some("model") || group_by.is_none() { by_model } else { vec![] };
|
|
let by_day = if group_by == Some("day") || group_by.is_none() { by_day } else { vec![] };
|
|
|
|
Ok(UsageStats {
|
|
total_requests,
|
|
total_input_tokens: total_input,
|
|
total_output_tokens: total_output,
|
|
by_model,
|
|
by_day,
|
|
})
|
|
}
|
|
|
|
pub async fn record_usage(
|
|
db: &PgPool, account_id: &str, provider_id: &str, model_id: &str,
|
|
input_tokens: i64, output_tokens: i64, latency_ms: Option<i64>,
|
|
status: &str, error_message: Option<&str>,
|
|
) -> SaasResult<()> {
|
|
let now = chrono::Utc::now().to_rfc3339();
|
|
sqlx::query(
|
|
"INSERT INTO usage_records (account_id, provider_id, model_id, input_tokens, output_tokens, latency_ms, status, error_message, created_at)
|
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)"
|
|
)
|
|
.bind(account_id).bind(provider_id).bind(model_id)
|
|
.bind(input_tokens).bind(output_tokens).bind(latency_ms)
|
|
.bind(status).bind(error_message).bind(&now)
|
|
.execute(db).await?;
|
|
Ok(())
|
|
}
|
|
|
|
// ============ Helpers ============
|
|
|
|
/// Masks an API key for display, revealing at most its first and last four characters.
///
/// Keys of eight characters or fewer are replaced entirely by one `*` per character.
/// Longer keys become `"<first 4>...<last 4>"`.
///
/// Lengths and cut points are measured in `char`s, not bytes. Byte-slicing a key
/// that contains multi-byte UTF-8 characters (e.g. `&key[..4]`) would panic when
/// the cut falls inside a character; counting chars avoids that.
fn mask_api_key(key: &str) -> String {
    let char_count = key.chars().count();
    if char_count <= 8 {
        return "*".repeat(char_count);
    }
    let head: String = key.chars().take(4).collect();
    let tail: String = key.chars().skip(char_count - 4).collect();
    format!("{}...{}", head, tail)
}
|