diff --git a/crates/zclaw-saas/src/relay/key_pool.rs b/crates/zclaw-saas/src/relay/key_pool.rs index 63e6da3..b679fca 100644 --- a/crates/zclaw-saas/src/relay/key_pool.rs +++ b/crates/zclaw-saas/src/relay/key_pool.rs @@ -3,10 +3,37 @@ //! 管理 provider 的多个 API Key,实现智能轮转绕过限额。 use sqlx::PgPool; +use std::sync::OnceLock; +use std::time::{Duration, Instant}; +use dashmap::DashMap; use crate::error::{SaasError, SaasResult}; - use crate::models::ProviderKeyRow; +use crate::models::ProviderKeyRow; use crate::crypto; +// ============ Key Pool Cache ============ + +/// TTL for cached key selections (seconds) +const KEY_CACHE_TTL: Duration = Duration::from_secs(5); + +/// Cached key selection entry +struct CachedSelection { + selection: KeySelection, + cached_at: Instant, +} + +/// Global cache for key selections, keyed by provider_id +static KEY_SELECTION_CACHE: OnceLock> = OnceLock::new(); + +fn get_cache() -> &'static DashMap { + KEY_SELECTION_CACHE.get_or_init(DashMap::new) +} + +/// Invalidate cached selection for a provider (called on usage record and 429 marking) +fn invalidate_cache(provider_id: &str) { + let cache = get_cache(); + cache.remove(provider_id); +} + /// 解密 key_value (如果已加密),否则原样返回 fn decrypt_key_value(encrypted: &str, enc_key: &[u8; 32]) -> SaasResult { if crypto::is_encrypted(encrypted) { @@ -29,6 +56,7 @@ pub struct PoolKey { } /// Key 选择结果 +#[derive(Clone)] pub struct KeySelection { pub key: PoolKey, pub key_id: String, @@ -36,22 +64,34 @@ pub struct KeySelection { /// 从 provider 的 Key Pool 中选择最佳可用 Key /// -/// 优化: 单次 JOIN 查询获取 Key + 当前分钟使用量,避免 N+1 查询 +/// 优化: 单次 JOIN 查询获取 Key + 滑动窗口(60s) RPM/TPM 使用量 pub async fn select_best_key(db: &PgPool, provider_id: &str, enc_key: &[u8; 32]) -> SaasResult { - let now = chrono::Utc::now(); - let current_minute = chrono::Utc::now().format("%Y-%m-%dT%H:%M").to_string(); + // Check in-memory cache first (TTL 5s) + { + let cache = get_cache(); + if let Some(entry) = cache.get(provider_id) { + if entry.cached_at.elapsed() < KEY_CACHE_TTL { + return Ok(entry.selection.clone()); + } + } + } - // 单次查询: 活跃 Key + 当前分钟的 RPM/TPM 使用量 (LEFT JOIN) + let now = chrono::Utc::now(); + + // 滑动窗口: 聚合最近 60 秒内所有窗口行的 RPM/TPM,避免分钟边界突发 let rows: Vec<(String, String, i32, Option, Option, Option, Option)> = sqlx::query_as( "SELECT pk.id, pk.key_value, pk.priority, pk.max_rpm, pk.max_tpm, - uw.request_count::bigint, uw.token_count + COALESCE(SUM(uw.request_count), 0)::bigint, + COALESCE(SUM(uw.token_count), 0) FROM provider_keys pk - LEFT JOIN key_usage_window uw ON pk.id = uw.key_id AND uw.window_minute = $1 - WHERE pk.provider_id = $2 AND pk.is_active = TRUE - AND (pk.cooldown_until IS NULL OR pk.cooldown_until::timestamptz <= $3) + LEFT JOIN key_usage_window uw ON pk.id = uw.key_id + AND uw.window_minute >= (NOW() - INTERVAL '1 minute')::TEXT + WHERE pk.provider_id = $1 AND pk.is_active = TRUE + AND (pk.cooldown_until IS NULL OR pk.cooldown_until::timestamptz <= $2) + GROUP BY pk.id, pk.key_value, pk.priority, pk.max_rpm, pk.max_tpm ORDER BY pk.priority ASC, pk.last_used_at ASC NULLS FIRST" - ).bind(¤t_minute).bind(provider_id).bind(&now).fetch_all(db).await?; + ).bind(provider_id).bind(&now).fetch_all(db).await?; for (id, key_value, priority, max_rpm, max_tpm, req_count, token_count) in &rows { // RPM 检查 @@ -78,7 +118,7 @@ pub async fn select_best_key(db: &PgPool, provider_id: &str, enc_key: &[u8; 32]) // 此 Key 可用 — 解密 key_value let decrypted_kv = decrypt_key_value(key_value, enc_key)?; - return Ok(KeySelection { + let selection = KeySelection { key: PoolKey { id: id.clone(), key_value: decrypted_kv, @@ -87,12 +127,22 @@ pub async fn select_best_key(db: &PgPool, provider_id: &str, enc_key: &[u8; 32]) max_tpm: *max_tpm, }, key_id: id.clone(), + }; + // Cache the selection + get_cache().insert(provider_id.to_string(), CachedSelection { + selection: selection.clone(), + cached_at: Instant::now(), }); + return Ok(selection); } - // 所有 Key 都超限或无 Key - if rows.is_empty() { - // 检查是否有冷却中的 Key,返回预计等待时间 + // 所有 Key 都超限或无 Key — 先检查是否存在活跃 Key + let has_any_key: Option<(bool,)> = sqlx::query_as( + "SELECT COUNT(*) > 0 FROM provider_keys WHERE provider_id = $1 AND is_active = TRUE" + ).bind(provider_id).fetch_optional(db).await?; + + if has_any_key.is_some_and(|(b,)| b) { + // 有 key 但全部 cooldown 或超限 — 检查最快恢复时间 let cooldown_row: Option<(String,)> = sqlx::query_as( "SELECT cooldown_until::TEXT FROM provider_keys WHERE provider_id = $1 AND is_active = TRUE AND cooldown_until IS NOT NULL AND cooldown_until::timestamptz > $2 @@ -106,34 +156,14 @@ pub async fn select_best_key(db: &PgPool, provider_id: &str, enc_key: &[u8; 32]) format!("所有 Key 均在冷却中,预计 {} 秒后可用", wait_secs) )); } - } - // 回退到 provider 单 Key - let provider_key: Option = sqlx::query_scalar( - "SELECT api_key FROM providers WHERE id = $1" - ).bind(provider_id).fetch_optional(db).await?.flatten(); - - if let Some(key) = provider_key { - let decrypted = decrypt_key_value(&key, enc_key)?; - return Ok(KeySelection { - key: PoolKey { - id: "provider-fallback".to_string(), - key_value: decrypted, - priority: 0, - max_rpm: None, - max_tpm: None, - }, - key_id: "provider-fallback".to_string(), - }); - } - - if rows.is_empty() { - Err(SaasError::NotFound(format!("Provider {} 没有可用的 API Key", provider_id))) - } else { - Err(SaasError::RateLimited( + // Key 存在但 RPM/TPM 全部用尽(无 cooldown) + return Err(SaasError::RateLimited( format!("Provider {} 所有 Key 均已达限额", provider_id) - )) + )); } + + Err(SaasError::NotFound(format!("Provider {} 没有可用的 API Key", provider_id))) } /// 记录 Key 使用量(滑动窗口) @@ -168,6 +198,12 @@ pub async fn record_key_usage( .bind(tokens).bind(key_id) .execute(db).await?; + // 3. 清理过期的滑动窗口行(保留最近 2 分钟即可) + let _ = sqlx::query( + "DELETE FROM key_usage_window WHERE window_minute < (NOW() - INTERVAL '2 minutes')::TEXT" + ) + .execute(db).await; // 忽略错误,非关键操作 + Ok(()) } @@ -180,8 +216,8 @@ pub async fn mark_key_429( let cooldown = if let Some(secs) = retry_after_seconds { (chrono::Utc::now() + chrono::Duration::seconds(secs as i64)) } else { - // 默认 5 分钟冷却 - (chrono::Utc::now() + chrono::Duration::minutes(5)) + // 默认 60 秒冷却(适合小配额 Coding Plan 账号) + chrono::Utc::now() + chrono::Duration::seconds(60) }; let now = chrono::Utc::now(); @@ -199,6 +235,14 @@ pub async fn mark_key_429( cooldown ); + // Invalidate cache for this key's provider (query provider_id then clear) + let pid_result: Result, _> = sqlx::query_as( + "SELECT provider_id FROM provider_keys WHERE id = $1" + ).bind(key_id).fetch_optional(db).await; + if let Ok(Some((pid,))) = pid_result { + invalidate_cache(&pid); + } + Ok(()) } @@ -324,6 +368,6 @@ fn parse_cooldown_remaining(cooldown_until: &str, now: &str) -> i64 { let diff = c.signed_duration_since(n); diff.num_seconds().max(0) } - _ => 300, // 默认 5 分钟 + _ => 60, // 默认 60 秒 } } diff --git a/crates/zclaw-saas/src/relay/service.rs b/crates/zclaw-saas/src/relay/service.rs index 79a7c7f..fcfa28e 100644 --- a/crates/zclaw-saas/src/relay/service.rs +++ b/crates/zclaw-saas/src/relay/service.rs @@ -19,8 +19,8 @@ const STREAMBRIDGE_HEARTBEAT_INTERVAL: Duration = Duration::from_secs(15); /// 实测 Kimi for Coding 的 thinking→content 间隔可达 60s+,需要更宽容的超时。 const STREAMBRIDGE_TIMEOUT: Duration = Duration::from_secs(180); -/// 流结束后延迟清理的时间窗口 -const STREAMBRIDGE_CLEANUP_DELAY: Duration = Duration::from_secs(60); +/// 流结束后延迟清理的时间窗口(缩短到 5s,仅用于 Arc 引用释放) +const STREAMBRIDGE_CLEANUP_DELAY: Duration = Duration::from_secs(5); /// 判断 HTTP 状态码是否为可重试的瞬态错误 (5xx + 429) fn is_retryable_status(status: u16) -> bool { @@ -357,7 +357,7 @@ pub async fn execute_relay( // SSE 流结束后异步记录 usage + Key 使用量 // 使用全局 Arc 限制并发 spawned tasks,防止高并发时耗尽连接池 static SSE_SPAWN_SEMAPHORE: std::sync::OnceLock> = std::sync::OnceLock::new(); - let semaphore = SSE_SPAWN_SEMAPHORE.get_or_init(|| Arc::new(tokio::sync::Semaphore::new(16))); + let semaphore = SSE_SPAWN_SEMAPHORE.get_or_init(|| Arc::new(tokio::sync::Semaphore::new(64))); let permit = match semaphore.clone().try_acquire_owned() { Ok(p) => p, Err(_) => { diff --git a/desktop/src/components/ChatArea.tsx b/desktop/src/components/ChatArea.tsx index da20cec..a977c7c 100644 --- a/desktop/src/components/ChatArea.tsx +++ b/desktop/src/components/ChatArea.tsx @@ -31,7 +31,6 @@ import { ReasoningBlock } from './ai/ReasoningBlock'; import { StreamingText } from './ai/StreamingText'; import { ChatMode } from './ai/ChatMode'; import { ModelSelector } from './ai/ModelSelector'; -import { isTauriRuntime } from '../lib/tauri-gateway'; import { SuggestionChips } from './ai/SuggestionChips'; import { PipelineResultPreview } from './pipeline/PipelineResultPreview'; import { PresentationContainer } from './presentation/PresentationContainer'; @@ -563,7 +562,7 @@ export function ChatArea({ compact, onOpenDetail }: { compact?: boolean; onOpenD }
- {!isTauriRuntime() && ( + {models.length > 0 && ( ({ id: m.id, name: m.name, provider: m.provider }))} currentModel={currentModel} diff --git a/desktop/src/lib/saas-relay-client.ts b/desktop/src/lib/saas-relay-client.ts index fa06f8c..3642714 100644 --- a/desktop/src/lib/saas-relay-client.ts +++ b/desktop/src/lib/saas-relay-client.ts @@ -92,6 +92,9 @@ export function createSaaSRelayGatewayClient( // ----------------------------------------------------------------------- // Helper: OpenAI SSE streaming via SaaS relay // ----------------------------------------------------------------------- + // AbortController for cancelling active streams + let activeAbortController: AbortController | null = null; + async function chatStream( message: string, callbacks: { @@ -112,10 +115,13 @@ export function createSaaSRelayGatewayClient( }, ): Promise<{ runId: string }> { const runId = `run_${Date.now()}`; + const abortController = new AbortController(); + activeAbortController = abortController; + const aborted = () => abortController.signal.aborted; try { const body: Record = { - model: getModel(), + model: getModel() || 'glm-4-flash-250414', messages: [{ role: 'user', content: message }], stream: true, }; @@ -148,67 +154,88 @@ export function createSaaSRelayGatewayClient( const decoder = new TextDecoder(); let buffer = ''; - while (true) { + while (!aborted()) { const { done, value } = await reader.read(); if (done) break; buffer += decoder.decode(value, { stream: true }); - const lines = buffer.split('\n'); - buffer = lines.pop() || ''; // keep incomplete last line - for (const line of lines) { - if (!line.startsWith('data: ')) continue; - const data = line.slice(6).trim(); - if (data === '[DONE]') continue; + // Optimized SSE parsing: split by double-newline (event boundaries) + let boundary: number; + while ((boundary = buffer.indexOf('\n\n')) !== -1) { + const eventBlock = buffer.slice(0, boundary); + buffer = buffer.slice(boundary + 2); - try { - const parsed = JSON.parse(data); + // Process each line in the event block + const lines = eventBlock.split('\n'); + for (const line of lines) { + if (!line.startsWith('data: ')) continue; + const data = line.slice(6).trim(); + if (data === '[DONE]') continue; - // Handle SSE error events from relay (e.g. stream_timeout) - if (parsed.error) { - const errMsg = parsed.message || parsed.error || 'Unknown stream error'; - log.warn('SSE stream error:', errMsg); - callbacks.onError(errMsg); - callbacks.onComplete(); - return { runId }; + try { + const parsed = JSON.parse(data); + + // Handle SSE error events from relay (e.g. stream_timeout) + if (parsed.error) { + const errMsg = parsed.message || parsed.error || 'Unknown stream error'; + log.warn('SSE stream error:', errMsg); + callbacks.onError(errMsg); + callbacks.onComplete(); + return { runId }; + } + + const choices = parsed.choices?.[0]; + if (!choices) continue; + + const delta = choices.delta; + + // Handle thinking/reasoning content + if (delta?.reasoning_content) { + callbacks.onThinkingDelta?.(delta.reasoning_content); + } + + // Handle regular content + if (delta?.content) { + callbacks.onDelta(delta.content); + } + + // Check for completion + if (choices.finish_reason) { + const usage = parsed.usage; + callbacks.onComplete( + usage?.prompt_tokens, + usage?.completion_tokens, + ); + return { runId }; + } + } catch { + // Skip malformed SSE lines } - - const choices = parsed.choices?.[0]; - if (!choices) continue; - - const delta = choices.delta; - - // Handle thinking/reasoning content - if (delta?.reasoning_content) { - callbacks.onThinkingDelta?.(delta.reasoning_content); - } - - // Handle regular content - if (delta?.content) { - callbacks.onDelta(delta.content); - } - - // Check for completion - if (choices.finish_reason) { - const usage = parsed.usage; - callbacks.onComplete( - usage?.prompt_tokens, - usage?.completion_tokens, - ); - return { runId }; - } - } catch { - // Skip malformed SSE lines } } } + // If aborted, cancel the reader + if (aborted()) { + try { reader.cancel(); } catch { /* already closed */ } + } + // Stream ended without explicit finish_reason callbacks.onComplete(); } catch (err) { + if (aborted()) { + // Cancelled by user — don't report as error + callbacks.onComplete(); + return { runId }; + } const msg = err instanceof Error ? err.message : String(err); callbacks.onError(msg); callbacks.onComplete(); + } finally { + if (activeAbortController === abortController) { + activeAbortController = null; + } } return { runId }; @@ -256,6 +283,13 @@ export function createSaaSRelayGatewayClient( // --- Chat --- chatStream, + cancelStream: () => { + if (activeAbortController) { + activeAbortController.abort(); + activeAbortController = null; + log.info('SSE stream cancelled by user'); + } + }, // --- Hands --- listHands: async () => ({ hands: [] }), diff --git a/desktop/src/store/chat/streamStore.ts b/desktop/src/store/chat/streamStore.ts index 6ae7478..2107a15 100644 --- a/desktop/src/store/chat/streamStore.ts +++ b/desktop/src/store/chat/streamStore.ts @@ -581,11 +581,20 @@ export const useStreamStore = create()( if (!isStreaming) return; // 1. Tell backend to abort — use sessionKey (which is the sessionId in Tauri) + // Also abort the frontend SSE fetch via cancelStream() try { - const client = getClient(); + const client = getClient() as unknown as Record; if ('cancelStream' in client) { - const sessionId = useConversationStore.getState().sessionKey || activeRunId || ''; - (client as { cancelStream: (id: string) => void }).cancelStream(sessionId); + const fn = client.cancelStream; + if (typeof fn === 'function') { + // Call with or without sessionId depending on arity + if (fn.length > 0) { + const sessionId = useConversationStore.getState().sessionKey || activeRunId || ''; + (fn as (id: string) => void)(sessionId); + } else { + (fn as () => void)(); + } + } } } catch { // Backend cancel is best-effort; proceed with local cleanup diff --git a/desktop/src/store/connectionStore.ts b/desktop/src/store/connectionStore.ts index 97bef3f..94f9ec1 100644 --- a/desktop/src/store/connectionStore.ts +++ b/desktop/src/store/connectionStore.ts @@ -441,9 +441,10 @@ export const useConnectionStore = create((set, get) => { // Configure the singleton client (cookie auth — no token needed) saasClient.setBaseUrl(session.saasUrl); - // Health check via GET /api/v1/relay/models + // Health check + model list: merged single listModels() call + let relayModels: Array<{ id: string; alias?: string }> | null = null; try { - await saasClient.listModels(); + relayModels = await saasClient.listModels(); } catch (err: unknown) { // Handle expired session — clear auth and trigger re-login const status = (err as { status?: number })?.status; @@ -473,15 +474,8 @@ export const useConnectionStore = create((set, get) => { // baseUrl = saasUrl + /api/v1/relay → kernel appends /chat/completions // apiKey = SaaS JWT token → sent as Authorization: Bearer - // Fetch available models from SaaS relay (shared by both branches) - let relayModels: Array<{ id: string }>; - try { - relayModels = await saasClient.listModels(); - } catch { - throw new Error('无法获取可用模型列表,请确认管理后台已配置 Provider 和模型'); - } - - if (relayModels.length === 0) { + // Models already fetched during health check above + if (!relayModels || relayModels.length === 0) { throw new Error('SaaS 平台没有可用模型,请先在管理后台配置 Provider 和模型'); } diff --git a/desktop/src/store/saasStore.ts b/desktop/src/store/saasStore.ts index 6bc0a11..71cd1bb 100644 --- a/desktop/src/store/saasStore.ts +++ b/desktop/src/store/saasStore.ts @@ -425,6 +425,12 @@ export const useSaaSStore = create((set, get) => { stopTelemetryCollector(); stopPromptOTASync(); + // Clear currentModel so next connection uses fresh model resolution + try { + const { useConversationStore } = require('./chat/conversationStore'); + useConversationStore.getState().setCurrentModel(''); + } catch { /* non-critical */ } + set({ isLoggedIn: false, account: null,