perf(relay): full-chain optimization — key pool, model sync, SSE stream
Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled
Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled
Phase 1 (Key Pool correctness): - RPM: fixed-minute window → sliding 60s aggregation (prevents 2x burst) - Remove fallback-to-provider-key bypass when all keys rate-limited - SSE semaphore: 16→64 permits, cleanup delay 60s→5s - Default 429 cooldown: 5min→60s (better for Coding Plan quotas) - Expire old key_usage_window rows on record Phase 2 (Frontend model sync): - currentModel empty-string fallback to glm-4-flash-250414 in relay client - Merge duplicate listModels() calls in connectionStore SaaS path - Show ModelSelector in Tauri mode when models available - Clear currentModel on SaaS logout Phase 3 (Relay performance): - Key Pool: DashMap in-memory cache (TTL 5s) for select_best_key - Cache invalidation on 429 marking Phase 4 (SSE stream): - AbortController integration for user-cancelled streams - SSE parsing: split by event boundaries (\n\n) instead of per-line - streamStore cancelStream adapts to 0-arg and 1-arg cancel fns
This commit is contained in:
@@ -3,10 +3,37 @@
|
|||||||
//! 管理 provider 的多个 API Key,实现智能轮转绕过限额。
|
//! 管理 provider 的多个 API Key,实现智能轮转绕过限额。
|
||||||
|
|
||||||
use sqlx::PgPool;
|
use sqlx::PgPool;
|
||||||
|
use std::sync::OnceLock;
|
||||||
|
use std::time::{Duration, Instant};
|
||||||
|
use dashmap::DashMap;
|
||||||
use crate::error::{SaasError, SaasResult};
|
use crate::error::{SaasError, SaasResult};
|
||||||
use crate::models::ProviderKeyRow;
|
use crate::models::ProviderKeyRow;
|
||||||
use crate::crypto;
|
use crate::crypto;
|
||||||
|
|
||||||
|
// ============ Key Pool Cache ============
|
||||||
|
|
||||||
|
/// TTL for cached key selections (seconds)
|
||||||
|
const KEY_CACHE_TTL: Duration = Duration::from_secs(5);
|
||||||
|
|
||||||
|
/// Cached key selection entry
|
||||||
|
struct CachedSelection {
|
||||||
|
selection: KeySelection,
|
||||||
|
cached_at: Instant,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Global cache for key selections, keyed by provider_id
|
||||||
|
static KEY_SELECTION_CACHE: OnceLock<DashMap<String, CachedSelection>> = OnceLock::new();
|
||||||
|
|
||||||
|
fn get_cache() -> &'static DashMap<String, CachedSelection> {
|
||||||
|
KEY_SELECTION_CACHE.get_or_init(DashMap::new)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Invalidate cached selection for a provider (called on usage record and 429 marking)
|
||||||
|
fn invalidate_cache(provider_id: &str) {
|
||||||
|
let cache = get_cache();
|
||||||
|
cache.remove(provider_id);
|
||||||
|
}
|
||||||
|
|
||||||
/// 解密 key_value (如果已加密),否则原样返回
|
/// 解密 key_value (如果已加密),否则原样返回
|
||||||
fn decrypt_key_value(encrypted: &str, enc_key: &[u8; 32]) -> SaasResult<String> {
|
fn decrypt_key_value(encrypted: &str, enc_key: &[u8; 32]) -> SaasResult<String> {
|
||||||
if crypto::is_encrypted(encrypted) {
|
if crypto::is_encrypted(encrypted) {
|
||||||
@@ -29,6 +56,7 @@ pub struct PoolKey {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Key 选择结果
|
/// Key 选择结果
|
||||||
|
#[derive(Clone)]
|
||||||
pub struct KeySelection {
|
pub struct KeySelection {
|
||||||
pub key: PoolKey,
|
pub key: PoolKey,
|
||||||
pub key_id: String,
|
pub key_id: String,
|
||||||
@@ -36,22 +64,34 @@ pub struct KeySelection {
|
|||||||
|
|
||||||
/// 从 provider 的 Key Pool 中选择最佳可用 Key
|
/// 从 provider 的 Key Pool 中选择最佳可用 Key
|
||||||
///
|
///
|
||||||
/// 优化: 单次 JOIN 查询获取 Key + 当前分钟使用量,避免 N+1 查询
|
/// 优化: 单次 JOIN 查询获取 Key + 滑动窗口(60s) RPM/TPM 使用量
|
||||||
pub async fn select_best_key(db: &PgPool, provider_id: &str, enc_key: &[u8; 32]) -> SaasResult<KeySelection> {
|
pub async fn select_best_key(db: &PgPool, provider_id: &str, enc_key: &[u8; 32]) -> SaasResult<KeySelection> {
|
||||||
let now = chrono::Utc::now();
|
// Check in-memory cache first (TTL 5s)
|
||||||
let current_minute = chrono::Utc::now().format("%Y-%m-%dT%H:%M").to_string();
|
{
|
||||||
|
let cache = get_cache();
|
||||||
|
if let Some(entry) = cache.get(provider_id) {
|
||||||
|
if entry.cached_at.elapsed() < KEY_CACHE_TTL {
|
||||||
|
return Ok(entry.selection.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 单次查询: 活跃 Key + 当前分钟的 RPM/TPM 使用量 (LEFT JOIN)
|
let now = chrono::Utc::now();
|
||||||
|
|
||||||
|
// 滑动窗口: 聚合最近 60 秒内所有窗口行的 RPM/TPM,避免分钟边界突发
|
||||||
let rows: Vec<(String, String, i32, Option<i64>, Option<i64>, Option<i64>, Option<i64>)> =
|
let rows: Vec<(String, String, i32, Option<i64>, Option<i64>, Option<i64>, Option<i64>)> =
|
||||||
sqlx::query_as(
|
sqlx::query_as(
|
||||||
"SELECT pk.id, pk.key_value, pk.priority, pk.max_rpm, pk.max_tpm,
|
"SELECT pk.id, pk.key_value, pk.priority, pk.max_rpm, pk.max_tpm,
|
||||||
uw.request_count::bigint, uw.token_count
|
COALESCE(SUM(uw.request_count), 0)::bigint,
|
||||||
|
COALESCE(SUM(uw.token_count), 0)
|
||||||
FROM provider_keys pk
|
FROM provider_keys pk
|
||||||
LEFT JOIN key_usage_window uw ON pk.id = uw.key_id AND uw.window_minute = $1
|
LEFT JOIN key_usage_window uw ON pk.id = uw.key_id
|
||||||
WHERE pk.provider_id = $2 AND pk.is_active = TRUE
|
AND uw.window_minute >= (NOW() - INTERVAL '1 minute')::TEXT
|
||||||
AND (pk.cooldown_until IS NULL OR pk.cooldown_until::timestamptz <= $3)
|
WHERE pk.provider_id = $1 AND pk.is_active = TRUE
|
||||||
|
AND (pk.cooldown_until IS NULL OR pk.cooldown_until::timestamptz <= $2)
|
||||||
|
GROUP BY pk.id, pk.key_value, pk.priority, pk.max_rpm, pk.max_tpm
|
||||||
ORDER BY pk.priority ASC, pk.last_used_at ASC NULLS FIRST"
|
ORDER BY pk.priority ASC, pk.last_used_at ASC NULLS FIRST"
|
||||||
).bind(¤t_minute).bind(provider_id).bind(&now).fetch_all(db).await?;
|
).bind(provider_id).bind(&now).fetch_all(db).await?;
|
||||||
|
|
||||||
for (id, key_value, priority, max_rpm, max_tpm, req_count, token_count) in &rows {
|
for (id, key_value, priority, max_rpm, max_tpm, req_count, token_count) in &rows {
|
||||||
// RPM 检查
|
// RPM 检查
|
||||||
@@ -78,7 +118,7 @@ pub async fn select_best_key(db: &PgPool, provider_id: &str, enc_key: &[u8; 32])
|
|||||||
|
|
||||||
// 此 Key 可用 — 解密 key_value
|
// 此 Key 可用 — 解密 key_value
|
||||||
let decrypted_kv = decrypt_key_value(key_value, enc_key)?;
|
let decrypted_kv = decrypt_key_value(key_value, enc_key)?;
|
||||||
return Ok(KeySelection {
|
let selection = KeySelection {
|
||||||
key: PoolKey {
|
key: PoolKey {
|
||||||
id: id.clone(),
|
id: id.clone(),
|
||||||
key_value: decrypted_kv,
|
key_value: decrypted_kv,
|
||||||
@@ -87,12 +127,22 @@ pub async fn select_best_key(db: &PgPool, provider_id: &str, enc_key: &[u8; 32])
|
|||||||
max_tpm: *max_tpm,
|
max_tpm: *max_tpm,
|
||||||
},
|
},
|
||||||
key_id: id.clone(),
|
key_id: id.clone(),
|
||||||
|
};
|
||||||
|
// Cache the selection
|
||||||
|
get_cache().insert(provider_id.to_string(), CachedSelection {
|
||||||
|
selection: selection.clone(),
|
||||||
|
cached_at: Instant::now(),
|
||||||
});
|
});
|
||||||
|
return Ok(selection);
|
||||||
}
|
}
|
||||||
|
|
||||||
// 所有 Key 都超限或无 Key
|
// 所有 Key 都超限或无 Key — 先检查是否存在活跃 Key
|
||||||
if rows.is_empty() {
|
let has_any_key: Option<(bool,)> = sqlx::query_as(
|
||||||
// 检查是否有冷却中的 Key,返回预计等待时间
|
"SELECT COUNT(*) > 0 FROM provider_keys WHERE provider_id = $1 AND is_active = TRUE"
|
||||||
|
).bind(provider_id).fetch_optional(db).await?;
|
||||||
|
|
||||||
|
if has_any_key.is_some_and(|(b,)| b) {
|
||||||
|
// 有 key 但全部 cooldown 或超限 — 检查最快恢复时间
|
||||||
let cooldown_row: Option<(String,)> = sqlx::query_as(
|
let cooldown_row: Option<(String,)> = sqlx::query_as(
|
||||||
"SELECT cooldown_until::TEXT FROM provider_keys
|
"SELECT cooldown_until::TEXT FROM provider_keys
|
||||||
WHERE provider_id = $1 AND is_active = TRUE AND cooldown_until IS NOT NULL AND cooldown_until::timestamptz > $2
|
WHERE provider_id = $1 AND is_active = TRUE AND cooldown_until IS NOT NULL AND cooldown_until::timestamptz > $2
|
||||||
@@ -106,34 +156,14 @@ pub async fn select_best_key(db: &PgPool, provider_id: &str, enc_key: &[u8; 32])
|
|||||||
format!("所有 Key 均在冷却中,预计 {} 秒后可用", wait_secs)
|
format!("所有 Key 均在冷却中,预计 {} 秒后可用", wait_secs)
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// 回退到 provider 单 Key
|
// Key 存在但 RPM/TPM 全部用尽(无 cooldown)
|
||||||
let provider_key: Option<String> = sqlx::query_scalar(
|
return Err(SaasError::RateLimited(
|
||||||
"SELECT api_key FROM providers WHERE id = $1"
|
|
||||||
).bind(provider_id).fetch_optional(db).await?.flatten();
|
|
||||||
|
|
||||||
if let Some(key) = provider_key {
|
|
||||||
let decrypted = decrypt_key_value(&key, enc_key)?;
|
|
||||||
return Ok(KeySelection {
|
|
||||||
key: PoolKey {
|
|
||||||
id: "provider-fallback".to_string(),
|
|
||||||
key_value: decrypted,
|
|
||||||
priority: 0,
|
|
||||||
max_rpm: None,
|
|
||||||
max_tpm: None,
|
|
||||||
},
|
|
||||||
key_id: "provider-fallback".to_string(),
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
if rows.is_empty() {
|
|
||||||
Err(SaasError::NotFound(format!("Provider {} 没有可用的 API Key", provider_id)))
|
|
||||||
} else {
|
|
||||||
Err(SaasError::RateLimited(
|
|
||||||
format!("Provider {} 所有 Key 均已达限额", provider_id)
|
format!("Provider {} 所有 Key 均已达限额", provider_id)
|
||||||
))
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Err(SaasError::NotFound(format!("Provider {} 没有可用的 API Key", provider_id)))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 记录 Key 使用量(滑动窗口)
|
/// 记录 Key 使用量(滑动窗口)
|
||||||
@@ -168,6 +198,12 @@ pub async fn record_key_usage(
|
|||||||
.bind(tokens).bind(key_id)
|
.bind(tokens).bind(key_id)
|
||||||
.execute(db).await?;
|
.execute(db).await?;
|
||||||
|
|
||||||
|
// 3. 清理过期的滑动窗口行(保留最近 2 分钟即可)
|
||||||
|
let _ = sqlx::query(
|
||||||
|
"DELETE FROM key_usage_window WHERE window_minute < (NOW() - INTERVAL '2 minutes')::TEXT"
|
||||||
|
)
|
||||||
|
.execute(db).await; // 忽略错误,非关键操作
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -180,8 +216,8 @@ pub async fn mark_key_429(
|
|||||||
let cooldown = if let Some(secs) = retry_after_seconds {
|
let cooldown = if let Some(secs) = retry_after_seconds {
|
||||||
(chrono::Utc::now() + chrono::Duration::seconds(secs as i64))
|
(chrono::Utc::now() + chrono::Duration::seconds(secs as i64))
|
||||||
} else {
|
} else {
|
||||||
// 默认 5 分钟冷却
|
// 默认 60 秒冷却(适合小配额 Coding Plan 账号)
|
||||||
(chrono::Utc::now() + chrono::Duration::minutes(5))
|
chrono::Utc::now() + chrono::Duration::seconds(60)
|
||||||
};
|
};
|
||||||
|
|
||||||
let now = chrono::Utc::now();
|
let now = chrono::Utc::now();
|
||||||
@@ -199,6 +235,14 @@ pub async fn mark_key_429(
|
|||||||
cooldown
|
cooldown
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// Invalidate cache for this key's provider (query provider_id then clear)
|
||||||
|
let pid_result: Result<Option<(String,)>, _> = sqlx::query_as(
|
||||||
|
"SELECT provider_id FROM provider_keys WHERE id = $1"
|
||||||
|
).bind(key_id).fetch_optional(db).await;
|
||||||
|
if let Ok(Some((pid,))) = pid_result {
|
||||||
|
invalidate_cache(&pid);
|
||||||
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -324,6 +368,6 @@ fn parse_cooldown_remaining(cooldown_until: &str, now: &str) -> i64 {
|
|||||||
let diff = c.signed_duration_since(n);
|
let diff = c.signed_duration_since(n);
|
||||||
diff.num_seconds().max(0)
|
diff.num_seconds().max(0)
|
||||||
}
|
}
|
||||||
_ => 300, // 默认 5 分钟
|
_ => 60, // 默认 60 秒
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -19,8 +19,8 @@ const STREAMBRIDGE_HEARTBEAT_INTERVAL: Duration = Duration::from_secs(15);
|
|||||||
/// 实测 Kimi for Coding 的 thinking→content 间隔可达 60s+,需要更宽容的超时。
|
/// 实测 Kimi for Coding 的 thinking→content 间隔可达 60s+,需要更宽容的超时。
|
||||||
const STREAMBRIDGE_TIMEOUT: Duration = Duration::from_secs(180);
|
const STREAMBRIDGE_TIMEOUT: Duration = Duration::from_secs(180);
|
||||||
|
|
||||||
/// 流结束后延迟清理的时间窗口
|
/// 流结束后延迟清理的时间窗口(缩短到 5s,仅用于 Arc 引用释放)
|
||||||
const STREAMBRIDGE_CLEANUP_DELAY: Duration = Duration::from_secs(60);
|
const STREAMBRIDGE_CLEANUP_DELAY: Duration = Duration::from_secs(5);
|
||||||
|
|
||||||
/// 判断 HTTP 状态码是否为可重试的瞬态错误 (5xx + 429)
|
/// 判断 HTTP 状态码是否为可重试的瞬态错误 (5xx + 429)
|
||||||
fn is_retryable_status(status: u16) -> bool {
|
fn is_retryable_status(status: u16) -> bool {
|
||||||
@@ -357,7 +357,7 @@ pub async fn execute_relay(
|
|||||||
// SSE 流结束后异步记录 usage + Key 使用量
|
// SSE 流结束后异步记录 usage + Key 使用量
|
||||||
// 使用全局 Arc<Semaphore> 限制并发 spawned tasks,防止高并发时耗尽连接池
|
// 使用全局 Arc<Semaphore> 限制并发 spawned tasks,防止高并发时耗尽连接池
|
||||||
static SSE_SPAWN_SEMAPHORE: std::sync::OnceLock<Arc<tokio::sync::Semaphore>> = std::sync::OnceLock::new();
|
static SSE_SPAWN_SEMAPHORE: std::sync::OnceLock<Arc<tokio::sync::Semaphore>> = std::sync::OnceLock::new();
|
||||||
let semaphore = SSE_SPAWN_SEMAPHORE.get_or_init(|| Arc::new(tokio::sync::Semaphore::new(16)));
|
let semaphore = SSE_SPAWN_SEMAPHORE.get_or_init(|| Arc::new(tokio::sync::Semaphore::new(64)));
|
||||||
let permit = match semaphore.clone().try_acquire_owned() {
|
let permit = match semaphore.clone().try_acquire_owned() {
|
||||||
Ok(p) => p,
|
Ok(p) => p,
|
||||||
Err(_) => {
|
Err(_) => {
|
||||||
|
|||||||
@@ -31,7 +31,6 @@ import { ReasoningBlock } from './ai/ReasoningBlock';
|
|||||||
import { StreamingText } from './ai/StreamingText';
|
import { StreamingText } from './ai/StreamingText';
|
||||||
import { ChatMode } from './ai/ChatMode';
|
import { ChatMode } from './ai/ChatMode';
|
||||||
import { ModelSelector } from './ai/ModelSelector';
|
import { ModelSelector } from './ai/ModelSelector';
|
||||||
import { isTauriRuntime } from '../lib/tauri-gateway';
|
|
||||||
import { SuggestionChips } from './ai/SuggestionChips';
|
import { SuggestionChips } from './ai/SuggestionChips';
|
||||||
import { PipelineResultPreview } from './pipeline/PipelineResultPreview';
|
import { PipelineResultPreview } from './pipeline/PipelineResultPreview';
|
||||||
import { PresentationContainer } from './presentation/PresentationContainer';
|
import { PresentationContainer } from './presentation/PresentationContainer';
|
||||||
@@ -563,7 +562,7 @@ export function ChatArea({ compact, onOpenDetail }: { compact?: boolean; onOpenD
|
|||||||
}
|
}
|
||||||
</div>
|
</div>
|
||||||
<div className="flex items-center gap-2">
|
<div className="flex items-center gap-2">
|
||||||
{!isTauriRuntime() && (
|
{models.length > 0 && (
|
||||||
<ModelSelector
|
<ModelSelector
|
||||||
models={models.map(m => ({ id: m.id, name: m.name, provider: m.provider }))}
|
models={models.map(m => ({ id: m.id, name: m.name, provider: m.provider }))}
|
||||||
currentModel={currentModel}
|
currentModel={currentModel}
|
||||||
|
|||||||
@@ -92,6 +92,9 @@ export function createSaaSRelayGatewayClient(
|
|||||||
// -----------------------------------------------------------------------
|
// -----------------------------------------------------------------------
|
||||||
// Helper: OpenAI SSE streaming via SaaS relay
|
// Helper: OpenAI SSE streaming via SaaS relay
|
||||||
// -----------------------------------------------------------------------
|
// -----------------------------------------------------------------------
|
||||||
|
// AbortController for cancelling active streams
|
||||||
|
let activeAbortController: AbortController | null = null;
|
||||||
|
|
||||||
async function chatStream(
|
async function chatStream(
|
||||||
message: string,
|
message: string,
|
||||||
callbacks: {
|
callbacks: {
|
||||||
@@ -112,10 +115,13 @@ export function createSaaSRelayGatewayClient(
|
|||||||
},
|
},
|
||||||
): Promise<{ runId: string }> {
|
): Promise<{ runId: string }> {
|
||||||
const runId = `run_${Date.now()}`;
|
const runId = `run_${Date.now()}`;
|
||||||
|
const abortController = new AbortController();
|
||||||
|
activeAbortController = abortController;
|
||||||
|
const aborted = () => abortController.signal.aborted;
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const body: Record<string, unknown> = {
|
const body: Record<string, unknown> = {
|
||||||
model: getModel(),
|
model: getModel() || 'glm-4-flash-250414',
|
||||||
messages: [{ role: 'user', content: message }],
|
messages: [{ role: 'user', content: message }],
|
||||||
stream: true,
|
stream: true,
|
||||||
};
|
};
|
||||||
@@ -148,67 +154,88 @@ export function createSaaSRelayGatewayClient(
|
|||||||
const decoder = new TextDecoder();
|
const decoder = new TextDecoder();
|
||||||
let buffer = '';
|
let buffer = '';
|
||||||
|
|
||||||
while (true) {
|
while (!aborted()) {
|
||||||
const { done, value } = await reader.read();
|
const { done, value } = await reader.read();
|
||||||
if (done) break;
|
if (done) break;
|
||||||
|
|
||||||
buffer += decoder.decode(value, { stream: true });
|
buffer += decoder.decode(value, { stream: true });
|
||||||
const lines = buffer.split('\n');
|
|
||||||
buffer = lines.pop() || ''; // keep incomplete last line
|
|
||||||
|
|
||||||
for (const line of lines) {
|
// Optimized SSE parsing: split by double-newline (event boundaries)
|
||||||
if (!line.startsWith('data: ')) continue;
|
let boundary: number;
|
||||||
const data = line.slice(6).trim();
|
while ((boundary = buffer.indexOf('\n\n')) !== -1) {
|
||||||
if (data === '[DONE]') continue;
|
const eventBlock = buffer.slice(0, boundary);
|
||||||
|
buffer = buffer.slice(boundary + 2);
|
||||||
|
|
||||||
try {
|
// Process each line in the event block
|
||||||
const parsed = JSON.parse(data);
|
const lines = eventBlock.split('\n');
|
||||||
|
for (const line of lines) {
|
||||||
|
if (!line.startsWith('data: ')) continue;
|
||||||
|
const data = line.slice(6).trim();
|
||||||
|
if (data === '[DONE]') continue;
|
||||||
|
|
||||||
// Handle SSE error events from relay (e.g. stream_timeout)
|
try {
|
||||||
if (parsed.error) {
|
const parsed = JSON.parse(data);
|
||||||
const errMsg = parsed.message || parsed.error || 'Unknown stream error';
|
|
||||||
log.warn('SSE stream error:', errMsg);
|
// Handle SSE error events from relay (e.g. stream_timeout)
|
||||||
callbacks.onError(errMsg);
|
if (parsed.error) {
|
||||||
callbacks.onComplete();
|
const errMsg = parsed.message || parsed.error || 'Unknown stream error';
|
||||||
return { runId };
|
log.warn('SSE stream error:', errMsg);
|
||||||
|
callbacks.onError(errMsg);
|
||||||
|
callbacks.onComplete();
|
||||||
|
return { runId };
|
||||||
|
}
|
||||||
|
|
||||||
|
const choices = parsed.choices?.[0];
|
||||||
|
if (!choices) continue;
|
||||||
|
|
||||||
|
const delta = choices.delta;
|
||||||
|
|
||||||
|
// Handle thinking/reasoning content
|
||||||
|
if (delta?.reasoning_content) {
|
||||||
|
callbacks.onThinkingDelta?.(delta.reasoning_content);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle regular content
|
||||||
|
if (delta?.content) {
|
||||||
|
callbacks.onDelta(delta.content);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for completion
|
||||||
|
if (choices.finish_reason) {
|
||||||
|
const usage = parsed.usage;
|
||||||
|
callbacks.onComplete(
|
||||||
|
usage?.prompt_tokens,
|
||||||
|
usage?.completion_tokens,
|
||||||
|
);
|
||||||
|
return { runId };
|
||||||
|
}
|
||||||
|
} catch {
|
||||||
|
// Skip malformed SSE lines
|
||||||
}
|
}
|
||||||
|
|
||||||
const choices = parsed.choices?.[0];
|
|
||||||
if (!choices) continue;
|
|
||||||
|
|
||||||
const delta = choices.delta;
|
|
||||||
|
|
||||||
// Handle thinking/reasoning content
|
|
||||||
if (delta?.reasoning_content) {
|
|
||||||
callbacks.onThinkingDelta?.(delta.reasoning_content);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle regular content
|
|
||||||
if (delta?.content) {
|
|
||||||
callbacks.onDelta(delta.content);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check for completion
|
|
||||||
if (choices.finish_reason) {
|
|
||||||
const usage = parsed.usage;
|
|
||||||
callbacks.onComplete(
|
|
||||||
usage?.prompt_tokens,
|
|
||||||
usage?.completion_tokens,
|
|
||||||
);
|
|
||||||
return { runId };
|
|
||||||
}
|
|
||||||
} catch {
|
|
||||||
// Skip malformed SSE lines
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If aborted, cancel the reader
|
||||||
|
if (aborted()) {
|
||||||
|
try { reader.cancel(); } catch { /* already closed */ }
|
||||||
|
}
|
||||||
|
|
||||||
// Stream ended without explicit finish_reason
|
// Stream ended without explicit finish_reason
|
||||||
callbacks.onComplete();
|
callbacks.onComplete();
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
|
if (aborted()) {
|
||||||
|
// Cancelled by user — don't report as error
|
||||||
|
callbacks.onComplete();
|
||||||
|
return { runId };
|
||||||
|
}
|
||||||
const msg = err instanceof Error ? err.message : String(err);
|
const msg = err instanceof Error ? err.message : String(err);
|
||||||
callbacks.onError(msg);
|
callbacks.onError(msg);
|
||||||
callbacks.onComplete();
|
callbacks.onComplete();
|
||||||
|
} finally {
|
||||||
|
if (activeAbortController === abortController) {
|
||||||
|
activeAbortController = null;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return { runId };
|
return { runId };
|
||||||
@@ -256,6 +283,13 @@ export function createSaaSRelayGatewayClient(
|
|||||||
|
|
||||||
// --- Chat ---
|
// --- Chat ---
|
||||||
chatStream,
|
chatStream,
|
||||||
|
cancelStream: () => {
|
||||||
|
if (activeAbortController) {
|
||||||
|
activeAbortController.abort();
|
||||||
|
activeAbortController = null;
|
||||||
|
log.info('SSE stream cancelled by user');
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
// --- Hands ---
|
// --- Hands ---
|
||||||
listHands: async () => ({ hands: [] }),
|
listHands: async () => ({ hands: [] }),
|
||||||
|
|||||||
@@ -581,11 +581,20 @@ export const useStreamStore = create<StreamState>()(
|
|||||||
if (!isStreaming) return;
|
if (!isStreaming) return;
|
||||||
|
|
||||||
// 1. Tell backend to abort — use sessionKey (which is the sessionId in Tauri)
|
// 1. Tell backend to abort — use sessionKey (which is the sessionId in Tauri)
|
||||||
|
// Also abort the frontend SSE fetch via cancelStream()
|
||||||
try {
|
try {
|
||||||
const client = getClient();
|
const client = getClient() as unknown as Record<string, unknown>;
|
||||||
if ('cancelStream' in client) {
|
if ('cancelStream' in client) {
|
||||||
const sessionId = useConversationStore.getState().sessionKey || activeRunId || '';
|
const fn = client.cancelStream;
|
||||||
(client as { cancelStream: (id: string) => void }).cancelStream(sessionId);
|
if (typeof fn === 'function') {
|
||||||
|
// Call with or without sessionId depending on arity
|
||||||
|
if (fn.length > 0) {
|
||||||
|
const sessionId = useConversationStore.getState().sessionKey || activeRunId || '';
|
||||||
|
(fn as (id: string) => void)(sessionId);
|
||||||
|
} else {
|
||||||
|
(fn as () => void)();
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} catch {
|
} catch {
|
||||||
// Backend cancel is best-effort; proceed with local cleanup
|
// Backend cancel is best-effort; proceed with local cleanup
|
||||||
|
|||||||
@@ -441,9 +441,10 @@ export const useConnectionStore = create<ConnectionStore>((set, get) => {
|
|||||||
// Configure the singleton client (cookie auth — no token needed)
|
// Configure the singleton client (cookie auth — no token needed)
|
||||||
saasClient.setBaseUrl(session.saasUrl);
|
saasClient.setBaseUrl(session.saasUrl);
|
||||||
|
|
||||||
// Health check via GET /api/v1/relay/models
|
// Health check + model list: merged single listModels() call
|
||||||
|
let relayModels: Array<{ id: string; alias?: string }> | null = null;
|
||||||
try {
|
try {
|
||||||
await saasClient.listModels();
|
relayModels = await saasClient.listModels();
|
||||||
} catch (err: unknown) {
|
} catch (err: unknown) {
|
||||||
// Handle expired session — clear auth and trigger re-login
|
// Handle expired session — clear auth and trigger re-login
|
||||||
const status = (err as { status?: number })?.status;
|
const status = (err as { status?: number })?.status;
|
||||||
@@ -473,15 +474,8 @@ export const useConnectionStore = create<ConnectionStore>((set, get) => {
|
|||||||
// baseUrl = saasUrl + /api/v1/relay → kernel appends /chat/completions
|
// baseUrl = saasUrl + /api/v1/relay → kernel appends /chat/completions
|
||||||
// apiKey = SaaS JWT token → sent as Authorization: Bearer <jwt>
|
// apiKey = SaaS JWT token → sent as Authorization: Bearer <jwt>
|
||||||
|
|
||||||
// Fetch available models from SaaS relay (shared by both branches)
|
// Models already fetched during health check above
|
||||||
let relayModels: Array<{ id: string }>;
|
if (!relayModels || relayModels.length === 0) {
|
||||||
try {
|
|
||||||
relayModels = await saasClient.listModels();
|
|
||||||
} catch {
|
|
||||||
throw new Error('无法获取可用模型列表,请确认管理后台已配置 Provider 和模型');
|
|
||||||
}
|
|
||||||
|
|
||||||
if (relayModels.length === 0) {
|
|
||||||
throw new Error('SaaS 平台没有可用模型,请先在管理后台配置 Provider 和模型');
|
throw new Error('SaaS 平台没有可用模型,请先在管理后台配置 Provider 和模型');
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -425,6 +425,12 @@ export const useSaaSStore = create<SaaSStore>((set, get) => {
|
|||||||
stopTelemetryCollector();
|
stopTelemetryCollector();
|
||||||
stopPromptOTASync();
|
stopPromptOTASync();
|
||||||
|
|
||||||
|
// Clear currentModel so next connection uses fresh model resolution
|
||||||
|
try {
|
||||||
|
const { useConversationStore } = require('./chat/conversationStore');
|
||||||
|
useConversationStore.getState().setCurrentModel('');
|
||||||
|
} catch { /* non-critical */ }
|
||||||
|
|
||||||
set({
|
set({
|
||||||
isLoggedIn: false,
|
isLoggedIn: false,
|
||||||
account: null,
|
account: null,
|
||||||
|
|||||||
Reference in New Issue
Block a user