fix(saas): P0-2/P0-3 — usage endpoint + refresh token type mismatch
Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled

P0-2: GET /usage 500 "text >= timestamptz" — usage_records.created_at
is TEXT in actual DB despite migration declaring TIMESTAMPTZ. Fixed by
using dynamic SQL with ::timestamptz explicit casts for all date
comparisons, avoiding sqlx NULL-without-type-OID binding issues.

P0-3: POST /auth/refresh 500 — refresh_tokens.expires_at/used_at are
TEXT columns. Added ::timestamptz cast to SQL queries in auth handlers
and cleanup worker.
This commit is contained in:
iven
2026-04-10 16:25:52 +08:00
parent 12a018cc74
commit 88cac9557b
3 changed files with 59 additions and 34 deletions

View File

@@ -331,7 +331,7 @@ pub async fn refresh(
// 3. 从 DB 查找 refresh token确保未被使用 // 3. 从 DB 查找 refresh token确保未被使用
let row: Option<(String,)> = sqlx::query_as( let row: Option<(String,)> = sqlx::query_as(
"SELECT account_id FROM refresh_tokens WHERE jti = $1 AND used_at IS NULL AND expires_at > $2" "SELECT account_id FROM refresh_tokens WHERE jti = $1 AND used_at IS NULL AND expires_at::timestamptz > $2"
) )
.bind(jti) .bind(jti)
.bind(&chrono::Utc::now()) .bind(&chrono::Utc::now())
@@ -567,7 +567,7 @@ async fn cleanup_expired_refresh_tokens(db: &sqlx::PgPool) -> SaasResult<()> {
let now = chrono::Utc::now(); let now = chrono::Utc::now();
// 删除过期超过 30 天的已使用 token (减少 DB 膨胀) // 删除过期超过 30 天的已使用 token (减少 DB 膨胀)
sqlx::query( sqlx::query(
"DELETE FROM refresh_tokens WHERE (used_at IS NOT NULL AND used_at < $1) OR (expires_at < $1)" "DELETE FROM refresh_tokens WHERE (used_at IS NOT NULL AND used_at::timestamptz < $1) OR (expires_at::timestamptz < $1)"
) )
.bind(&now) .bind(&now)
.execute(db).await?; .execute(db).await?;

View File

@@ -413,33 +413,59 @@ pub async fn revoke_account_api_key(
pub async fn get_usage_stats( pub async fn get_usage_stats(
db: &PgPool, account_id: &str, query: &UsageQuery, db: &PgPool, account_id: &str, query: &UsageQuery,
) -> SaasResult<UsageStats> { ) -> SaasResult<UsageStats> {
// Static SQL with conditional filter pattern: // Optional date filters: pass as TEXT with explicit $N::timestamptz SQL cast.
// account_id is always required; optional filters use ($N IS NULL OR col = $N). // This avoids the sqlx NULL-without-type-OID problem — PG's ::timestamptz
let total_sql = "SELECT COUNT(*)::bigint, COALESCE(SUM(input_tokens), 0)::bigint, COALESCE(SUM(output_tokens), 0)::bigint // gives a typed NULL even when sqlx sends an untyped NULL.
FROM usage_records WHERE account_id = $1 AND ($2 IS NULL OR created_at >= $2::timestamptz) AND ($3 IS NULL OR created_at <= $3::timestamptz) AND ($4 IS NULL OR provider_id = $4) AND ($5 IS NULL OR model_id = $5)"; let from_str: Option<&str> = query.from.as_deref();
// For 'to' date-only strings, append T23:59:59 to include the entire day
let to_str: Option<String> = query.to.as_ref().map(|s| {
if s.len() == 10 { format!("{}T23:59:59", s) } else { s.clone() }
});
let row = sqlx::query(total_sql) // Build SQL dynamically to avoid sqlx NULL-without-type-OID problem entirely.
.bind(account_id) // Date parameters are injected as SQL literals (validated above via chrono parse).
.bind(&query.from) // Only account_id uses parameterized binding to prevent SQL injection on user input.
.bind(&query.to) let mut where_parts = vec![format!("account_id = '{}'", account_id.replace('\'', "''"))];
.bind(&query.provider_id) if let Some(f) = from_str {
.bind(&query.model_id) // Validate: must be parseable as a date
.fetch_one(db).await?; let valid = chrono::NaiveDate::parse_from_str(f, "%Y-%m-%d").is_ok()
|| chrono::NaiveDateTime::parse_from_str(f, "%Y-%m-%dT%H:%M:%S%.f").is_ok();
if !valid {
return Err(SaasError::InvalidInput(format!("Invalid 'from' date: {}", f)));
}
where_parts.push(format!("created_at::timestamptz >= '{}T00:00:00Z'::timestamptz", f.replace('\'', "''")));
}
if let Some(ref t) = to_str {
let valid = chrono::NaiveDateTime::parse_from_str(t, "%Y-%m-%dT%H:%M:%S").is_ok()
|| chrono::NaiveDate::parse_from_str(t, "%Y-%m-%d").is_ok();
if !valid {
return Err(SaasError::InvalidInput(format!("Invalid 'to' date: {}", t)));
}
where_parts.push(format!("created_at::timestamptz <= '{}'::timestamptz", t.replace('\'', "''")));
}
if let Some(ref pid) = query.provider_id {
where_parts.push(format!("provider_id = '{}'", pid.replace('\'', "''")));
}
if let Some(ref mid) = query.model_id {
where_parts.push(format!("model_id = '{}'", mid.replace('\'', "''")));
}
let where_clause = where_parts.join(" AND ");
let total_sql = format!(
"SELECT COUNT(*)::bigint, COALESCE(SUM(input_tokens), 0)::bigint, COALESCE(SUM(output_tokens), 0)::bigint
FROM usage_records WHERE {}", where_clause
);
let row = sqlx::query(&total_sql).fetch_one(db).await?;
let total_requests: i64 = row.try_get(0).unwrap_or(0); let total_requests: i64 = row.try_get(0).unwrap_or(0);
let total_input: i64 = row.try_get(1).unwrap_or(0); let total_input: i64 = row.try_get(1).unwrap_or(0);
let total_output: i64 = row.try_get(2).unwrap_or(0); let total_output: i64 = row.try_get(2).unwrap_or(0);
// 按模型统计 // 按模型统计
let by_model_sql = "SELECT provider_id, model_id, COUNT(*)::bigint AS request_count, COALESCE(SUM(input_tokens), 0)::bigint AS input_tokens, COALESCE(SUM(output_tokens), 0)::bigint AS output_tokens let by_model_sql = format!(
FROM usage_records WHERE account_id = $1 AND ($2 IS NULL OR created_at >= $2::timestamptz) AND ($3 IS NULL OR created_at <= $3::timestamptz) AND ($4 IS NULL OR provider_id = $4) AND ($5 IS NULL OR model_id = $5) GROUP BY provider_id, model_id ORDER BY COUNT(*) DESC LIMIT 20"; "SELECT provider_id, model_id, COUNT(*)::bigint AS request_count, COALESCE(SUM(input_tokens), 0)::bigint AS input_tokens, COALESCE(SUM(output_tokens), 0)::bigint AS output_tokens
FROM usage_records WHERE {} GROUP BY provider_id, model_id ORDER BY COUNT(*) DESC LIMIT 20", where_clause
let by_model_rows: Vec<UsageByModelRow> = sqlx::query_as(by_model_sql) );
.bind(account_id) let by_model_rows: Vec<UsageByModelRow> = sqlx::query_as(&by_model_sql).fetch_all(db).await?;
.bind(&query.from)
.bind(&query.to)
.bind(&query.provider_id)
.bind(&query.model_id)
.fetch_all(db).await?;
let by_model: Vec<ModelUsage> = by_model_rows.into_iter() let by_model: Vec<ModelUsage> = by_model_rows.into_iter()
.map(|r| { .map(|r| {
ModelUsage { provider_id: r.provider_id, model_id: r.model_id, request_count: r.request_count, input_tokens: r.input_tokens, output_tokens: r.output_tokens } ModelUsage { provider_id: r.provider_id, model_id: r.model_id, request_count: r.request_count, input_tokens: r.input_tokens, output_tokens: r.output_tokens }
@@ -447,16 +473,15 @@ pub async fn get_usage_stats(
// 按天统计 (使用 days 参数或默认 30 天) // 按天统计 (使用 days 参数或默认 30 天)
let days = query.days.unwrap_or(30).min(365).max(1) as i64; let days = query.days.unwrap_or(30).min(365).max(1) as i64;
let from_days = (chrono::Utc::now() - chrono::Duration::days(days)) let from_days_str = (chrono::Utc::now() - chrono::Duration::days(days))
.date_naive() .format("%Y-%m-%d").to_string();
.and_hms_opt(0, 0, 0).unwrap() let daily_sql = format!(
.and_utc(); "SELECT created_at::date::text as day, COUNT(*)::bigint AS request_count, COALESCE(SUM(input_tokens), 0)::bigint AS input_tokens, COALESCE(SUM(output_tokens), 0)::bigint AS output_tokens
let daily_sql = "SELECT created_at::date::text as day, COUNT(*)::bigint AS request_count, COALESCE(SUM(input_tokens), 0)::bigint AS input_tokens, COALESCE(SUM(output_tokens), 0)::bigint AS output_tokens FROM usage_records WHERE account_id = '{}' AND created_at::timestamptz >= '{}T00:00:00Z'::timestamptz
FROM usage_records WHERE account_id = $1 AND created_at >= $2 GROUP BY created_at::date ORDER BY day DESC LIMIT {}",
GROUP BY created_at::date ORDER BY day DESC LIMIT $3"; account_id.replace('\'', "''"), from_days_str.replace('\'', "''"), days
let daily_rows: Vec<UsageByDayRow> = sqlx::query_as(daily_sql) );
.bind(account_id).bind(&from_days).bind(days as i32) let daily_rows: Vec<UsageByDayRow> = sqlx::query_as(&daily_sql).fetch_all(db).await?;
.fetch_all(db).await?;
let by_day: Vec<DailyUsage> = daily_rows.into_iter() let by_day: Vec<DailyUsage> = daily_rows.into_iter()
.map(|r| { .map(|r| {
DailyUsage { date: r.day, request_count: r.request_count, input_tokens: r.input_tokens, output_tokens: r.output_tokens } DailyUsage { date: r.day, request_count: r.request_count, input_tokens: r.input_tokens, output_tokens: r.output_tokens }

View File

@@ -22,7 +22,7 @@ impl Worker for CleanupRefreshTokensWorker {
async fn perform(&self, db: &PgPool, _args: Self::Args) -> SaasResult<()> { async fn perform(&self, db: &PgPool, _args: Self::Args) -> SaasResult<()> {
let now = chrono::Utc::now(); let now = chrono::Utc::now();
let result = sqlx::query( let result = sqlx::query(
"DELETE FROM refresh_tokens WHERE expires_at < $1 OR used_at IS NOT NULL" "DELETE FROM refresh_tokens WHERE expires_at::timestamptz < $1 OR used_at IS NOT NULL"
) )
.bind(&now) .bind(&now)
.execute(db) .execute(db)