fix: 三端联调 P1 修复 — API密钥页崩溃 + 桌面端401恢复 + 用量统计全零
Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled

P1-03: vite.config.ts proxy '/api' → '/api/' 加尾部斜杠,
  防止前缀匹配 /api-keys 导致 SPA 路由崩溃

P1-01: kernel_init 增加 api_key 变更检测(token 刷新后自动重连),
  streamStore 增加 401 自动恢复(refresh token → kernel reconnect),
  KernelClient 新增 getConfig() 方法

P1-02: /api/v1/usage 总计改从 billing_usage_quotas 读取
  (authoritative source,SSE 和 JSON 均写入),
  by_model/by_day 仍从 usage_records 读取
This commit is contained in:
iven
2026-04-14 22:02:02 +08:00
parent 6721a1cc6e
commit e0eb7173c5
5 changed files with 85 additions and 24 deletions

View File

@@ -20,7 +20,7 @@ export default defineConfig({
timeout: 600_000, timeout: 600_000,
proxyTimeout: 600_000, proxyTimeout: 600_000,
}, },
'/api': { '/api/': {
target: 'http://localhost:8080', target: 'http://localhost:8080',
changeOrigin: true, changeOrigin: true,
timeout: 30_000, timeout: 30_000,

View File

@@ -419,21 +419,33 @@ pub async fn revoke_account_api_key(
pub async fn get_usage_stats( pub async fn get_usage_stats(
db: &PgPool, account_id: &str, query: &UsageQuery, db: &PgPool, account_id: &str, query: &UsageQuery,
) -> SaasResult<UsageStats> { ) -> SaasResult<UsageStats> {
// Optional date filters: pass as TEXT with explicit $N::timestamptz SQL cast. // === Totals: from billing_usage_quotas (authoritative source) ===
// This avoids the sqlx NULL-without-type-OID problem — PG's ::timestamptz // billing_usage_quotas is written to on every relay request (both JSON and SSE),
// gives a typed NULL even when sqlx sends an untyped NULL. // whereas usage_records has 0 tokens for SSE requests. Use billing as the primary source.
let billing_row = sqlx::query(
"SELECT COALESCE(SUM(input_tokens), 0)::bigint,
COALESCE(SUM(output_tokens), 0)::bigint,
COALESCE(SUM(relay_requests), 0)::bigint
FROM billing_usage_quotas WHERE account_id = $1"
)
.bind(account_id)
.fetch_one(db)
.await?;
let total_input: i64 = billing_row.try_get(0).unwrap_or(0);
let total_output: i64 = billing_row.try_get(1).unwrap_or(0);
let total_requests: i64 = billing_row.try_get(2).unwrap_or(0);
// === Breakdowns: from usage_records (per-request detail) ===
// Optional date filters: pass as TEXT with explicit SQL cast.
let from_str: Option<&str> = query.from.as_deref(); let from_str: Option<&str> = query.from.as_deref();
// For 'to' date-only strings, append T23:59:59 to include the entire day
let to_str: Option<String> = query.to.as_ref().map(|s| { let to_str: Option<String> = query.to.as_ref().map(|s| {
if s.len() == 10 { format!("{}T23:59:59", s) } else { s.clone() } if s.len() == 10 { format!("{}T23:59:59", s) } else { s.clone() }
}); });
// Build SQL dynamically to avoid sqlx NULL-without-type-OID problem entirely. // Build SQL dynamically for usage_records breakdowns.
// Date parameters are injected as SQL literals (validated above via chrono parse). // Date parameters are injected as SQL literals (validated via chrono parse).
// Only account_id uses parameterized binding to prevent SQL injection on user input.
let mut where_parts = vec![format!("account_id = '{}'", account_id.replace('\'', "''"))]; let mut where_parts = vec![format!("account_id = '{}'", account_id.replace('\'', "''"))];
if let Some(f) = from_str { if let Some(f) = from_str {
// Validate: must be parseable as a date
let valid = chrono::NaiveDate::parse_from_str(f, "%Y-%m-%d").is_ok() let valid = chrono::NaiveDate::parse_from_str(f, "%Y-%m-%d").is_ok()
|| chrono::NaiveDateTime::parse_from_str(f, "%Y-%m-%dT%H:%M:%S%.f").is_ok(); || chrono::NaiveDateTime::parse_from_str(f, "%Y-%m-%dT%H:%M:%S%.f").is_ok();
if !valid { if !valid {
@@ -457,15 +469,6 @@ pub async fn get_usage_stats(
} }
let where_clause = where_parts.join(" AND "); let where_clause = where_parts.join(" AND ");
let total_sql = format!(
"SELECT COUNT(*)::bigint, COALESCE(SUM(input_tokens), 0)::bigint, COALESCE(SUM(output_tokens), 0)::bigint
FROM usage_records WHERE {}", where_clause
);
let row = sqlx::query(&total_sql).fetch_one(db).await?;
let total_requests: i64 = row.try_get(0).unwrap_or(0);
let total_input: i64 = row.try_get(1).unwrap_or(0);
let total_output: i64 = row.try_get(2).unwrap_or(0);
// 按模型统计 // 按模型统计
let by_model_sql = format!( let by_model_sql = format!(
"SELECT provider_id, model_id, COUNT(*)::bigint AS request_count, COALESCE(SUM(input_tokens), 0)::bigint AS input_tokens, COALESCE(SUM(output_tokens), 0)::bigint AS output_tokens "SELECT provider_id, model_id, COUNT(*)::bigint AS request_count, COALESCE(SUM(input_tokens), 0)::bigint AS input_tokens, COALESCE(SUM(output_tokens), 0)::bigint AS output_tokens

View File

@@ -73,15 +73,18 @@ pub async fn kernel_init(
// Get current config from kernel // Get current config from kernel
let current_config = kernel.config(); let current_config = kernel.config();
// Check if config changed // Check if config changed (model, base_url, or api_key)
let config_changed = if let Some(ref req) = config_request { let config_changed = if let Some(ref req) = config_request {
let default_base_url = zclaw_kernel::config::KernelConfig::from_provider( let default_base_url = zclaw_kernel::config::KernelConfig::from_provider(
&req.provider, "", &req.model, None, &req.api_protocol &req.provider, "", &req.model, None, &req.api_protocol
).llm.base_url; ).llm.base_url;
let request_base_url = req.base_url.clone().unwrap_or(default_base_url.clone()); let request_base_url = req.base_url.clone().unwrap_or(default_base_url.clone());
let current_api_key = &current_config.llm.api_key;
let request_api_key = req.api_key.as_deref().unwrap_or("");
current_config.llm.model != req.model || current_config.llm.model != req.model ||
current_config.llm.base_url != request_base_url current_config.llm.base_url != request_base_url ||
current_api_key != request_api_key
} else { } else {
false false
}; };

View File

@@ -164,6 +164,11 @@ export class KernelClient {
this.config = config; this.config = config;
} }
/** Get current kernel configuration (for auth token refresh) */
getConfig(): KernelConfig | undefined {
return this.config;
}
getState(): ConnectionState { getState(): ConnectionState {
return this.state; return this.state;
} }

View File

@@ -38,6 +38,47 @@ import { useArtifactStore } from './artifactStore';
const log = createLogger('StreamStore'); const log = createLogger('StreamStore');
// ---------------------------------------------------------------------------
// 401 Auth Error Recovery
// ---------------------------------------------------------------------------
/**
* Detect and handle 401 auth errors during chat streaming.
* Attempts token refresh → kernel reconnect → auto-retry.
* Returns a user-friendly error message if recovery fails.
*/
async function tryRecoverFromAuthError(error: string): Promise<string | null> {
const is401 = /401|Unauthorized|UNAUTHORIZED|未认证|认证已过期/.test(error);
if (!is401) return null;
log.info('Detected 401 auth error, attempting token refresh...');
try {
const { saasClient } = await import('../../lib/saas-client');
const newToken = await saasClient.refreshMutex();
if (newToken) {
// Update kernel config with refreshed token → triggers kernel re-init via changed api_key detection
const { useConnectionStore } = await import('../connectionStore');
const { getKernelClient } = await import('../../lib/kernel-client');
const kernelClient = getKernelClient();
const currentConfig = kernelClient.getConfig();
if (currentConfig) {
kernelClient.setConfig({ ...currentConfig, apiKey: newToken });
await kernelClient.connect();
log.info('Kernel reconnected with refreshed token');
}
return '认证已刷新,请重新发送消息';
}
} catch (refreshErr) {
log.warn('Token refresh failed, triggering logout:', refreshErr);
try {
const { useSaaSStore } = await import('../saasStore');
useSaaSStore.getState().logout();
} catch { /* non-critical */ }
return 'SaaS 会话已过期,请重新登录';
}
return '认证失败,请重新登录';
}
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
// Types // Types
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
@@ -517,7 +558,7 @@ export const useStreamStore = create<StreamState>()(
} }
} }
}, },
onError: (error: string) => { onError: async (error: string) => {
// Flush any remaining buffered deltas before erroring // Flush any remaining buffered deltas before erroring
if (flushTimer !== null) { if (flushTimer !== null) {
clearTimeout(flushTimer); clearTimeout(flushTimer);
@@ -525,10 +566,14 @@ export const useStreamStore = create<StreamState>()(
} }
flushBuffers(); flushBuffers();
// Attempt 401 auth recovery (token refresh + kernel reconnect)
const recoveryMsg = await tryRecoverFromAuthError(error);
const displayError = recoveryMsg || error;
_chat?.updateMessages(msgs => _chat?.updateMessages(msgs =>
msgs.map(m => msgs.map(m =>
m.id === assistantId m.id === assistantId
? { ...m, content: error, streaming: false, error } ? { ...m, content: displayError, streaming: false, error: displayError }
: m.role === 'user' && m.optimistic && m.timestamp.getTime() >= streamStartTime : m.role === 'user' && m.optimistic && m.timestamp.getTime() >= streamStartTime
? { ...m, optimistic: false } ? { ...m, optimistic: false }
: m : m
@@ -573,13 +618,18 @@ export const useStreamStore = create<StreamState>()(
textBuffer = ''; textBuffer = '';
thinkBuffer = ''; thinkBuffer = '';
const errorMessage = err instanceof Error ? err.message : '无法连接 Gateway'; let errorMessage = err instanceof Error ? err.message : '无法连接 Gateway';
// Attempt 401 auth recovery
const recoveryMsg = await tryRecoverFromAuthError(errorMessage);
if (recoveryMsg) errorMessage = recoveryMsg;
_chat?.updateMessages(msgs => _chat?.updateMessages(msgs =>
msgs.map(m => msgs.map(m =>
m.id === assistantId m.id === assistantId
? { ? {
...m, ...m,
content: `⚠️ ${errorMessage}`, content: errorMessage,
streaming: false, streaming: false,
error: errorMessage, error: errorMessage,
} }