diff --git a/crates/zclaw-kernel/src/kernel/messaging.rs b/crates/zclaw-kernel/src/kernel/messaging.rs index cdfc9ba..8617929 100644 --- a/crates/zclaw-kernel/src/kernel/messaging.rs +++ b/crates/zclaw-kernel/src/kernel/messaging.rs @@ -25,7 +25,7 @@ impl Kernel { agent_id: &AgentId, message: String, ) -> Result { - self.send_message_with_chat_mode(agent_id, message, None).await + self.send_message_with_chat_mode(agent_id, message, None, None).await } /// Send a message to an agent with optional chat mode configuration @@ -34,6 +34,7 @@ impl Kernel { agent_id: &AgentId, message: String, chat_mode: Option, + model_override: Option, ) -> Result { let agent_config = self.registry.get(agent_id) .ok_or_else(|| zclaw_types::ZclawError::NotFound(format!("Agent not found: {}", agent_id)))?; @@ -41,12 +42,16 @@ impl Kernel { // Create or get session let session_id = self.memory.create_session(agent_id).await?; - // Use agent-level model if configured, otherwise fall back to global config - let model = if !agent_config.model.model.is_empty() { - agent_config.model.model.clone() - } else { - self.config.model().to_string() - }; + // Model priority: UI override > Agent config > Global config + let model = model_override + .filter(|m| !m.is_empty()) + .unwrap_or_else(|| { + if !agent_config.model.model.is_empty() { + agent_config.model.model.clone() + } else { + self.config.model().to_string() + } + }); // Create agent loop with model configuration let subagent_enabled = chat_mode.as_ref().and_then(|m| m.subagent_enabled).unwrap_or(false); @@ -122,7 +127,7 @@ impl Kernel { agent_id: &AgentId, message: String, ) -> Result> { - self.send_message_stream_with_prompt(agent_id, message, None, None, None).await + self.send_message_stream_with_prompt(agent_id, message, None, None, None, None).await } /// Send a message with streaming, optional system prompt, optional session reuse, @@ -134,6 +139,7 @@ impl Kernel { system_prompt_override: Option, session_id_override: Option, chat_mode: Option, + model_override: Option, ) -> Result> { let agent_config = self.registry.get(agent_id) .ok_or_else(|| zclaw_types::ZclawError::NotFound(format!("Agent not found: {}", agent_id)))?; @@ -150,12 +156,16 @@ impl Kernel { None => self.memory.create_session(agent_id).await?, }; - // Use agent-level model if configured, otherwise fall back to global config - let model = if !agent_config.model.model.is_empty() { - agent_config.model.model.clone() - } else { - self.config.model().to_string() - }; + // Model priority: UI override > Agent config > Global config + let model = model_override + .filter(|m| !m.is_empty()) + .unwrap_or_else(|| { + if !agent_config.model.model.is_empty() { + agent_config.model.model.clone() + } else { + self.config.model().to_string() + } + }); // Create agent loop with model configuration let subagent_enabled = chat_mode.as_ref().and_then(|m| m.subagent_enabled).unwrap_or(false); diff --git a/crates/zclaw-saas/src/billing/service.rs b/crates/zclaw-saas/src/billing/service.rs index 86b6ea8..ac4ebec 100644 --- a/crates/zclaw-saas/src/billing/service.rs +++ b/crates/zclaw-saas/src/billing/service.rs @@ -282,19 +282,26 @@ pub async fn increment_dimension_by( } /// 检查用量配额 +/// +/// P1-7 修复: 从当前 Plan 读取限额(而非 stale 的 usage 表冗余列) +/// P1-8 修复: 支持 relay_requests + input_tokens 双维度检查 pub async fn check_quota( pool: &PgPool, account_id: &str, quota_type: &str, ) -> SaasResult { let usage = get_or_create_usage(pool, account_id).await?; + // 从当前 Plan 读取真实限额,而非 usage 表的 stale 冗余列 + let plan = get_account_plan(pool, account_id).await?; + let limits: crate::billing::types::PlanLimits = serde_json::from_value(plan.limits) + .unwrap_or_else(|_| crate::billing::types::PlanLimits::free()); let (current, limit) = match quota_type { - "input_tokens" => (usage.input_tokens, usage.max_input_tokens), - "output_tokens" => (usage.output_tokens, usage.max_output_tokens), - "relay_requests" => (usage.relay_requests as i64, usage.max_relay_requests.map(|v| v as i64)), - "hand_executions" => (usage.hand_executions as i64, usage.max_hand_executions.map(|v| v as i64)), - "pipeline_runs" => (usage.pipeline_runs as i64, usage.max_pipeline_runs.map(|v| v as i64)), + "input_tokens" => (usage.input_tokens, limits.max_input_tokens_monthly), + "output_tokens" => (usage.output_tokens, limits.max_output_tokens_monthly), + "relay_requests" => (usage.relay_requests as i64, limits.max_relay_requests_monthly.map(|v| v as i64)), + "hand_executions" => (usage.hand_executions as i64, limits.max_hand_executions_monthly.map(|v| v as i64)), + "pipeline_runs" => (usage.pipeline_runs as i64, limits.max_pipeline_runs_monthly.map(|v| v as i64)), _ => return Ok(QuotaCheck { allowed: true, reason: None, @@ -309,7 +316,7 @@ pub async fn check_quota( Ok(QuotaCheck { allowed, - reason: if !allowed { Some(format!("{} 配额已用尽", quota_type)) } else { None }, + reason: if !allowed { Some(format!("{} 配额已用尽 (已用 {}/{})", quota_type, current, limit.unwrap_or(0))) } else { None }, current, limit, remaining, diff --git a/crates/zclaw-saas/src/cache.rs b/crates/zclaw-saas/src/cache.rs index a7d6c27..766b2d5 100644 --- a/crates/zclaw-saas/src/cache.rs +++ b/crates/zclaw-saas/src/cache.rs @@ -248,6 +248,37 @@ impl AppCache { .map(|r| r.value().clone()) } + /// 按别名查找模型 — 用于向后兼容旧模型 ID (如 "glm-4-flash" → "glm-4-flash-250414") + /// 先按 alias 字段精确匹配,再按 model_id 前缀匹配(去掉日期后缀) + pub fn resolve_model(&self, model_name: &str) -> Option { + // 1. 直接 model_id 查找 + if let Some(m) = self.get_model(model_name) { + return Some(m); + } + // 2. 按 alias 精确匹配 + for entry in self.models.iter() { + if entry.value().enabled && entry.value().alias == model_name { + return Some(entry.value().clone()); + } + } + // 3. 前缀匹配: "glm-4-flash" 匹配 "glm-4-flash-250414" 等带后缀的模型 + for entry in self.models.iter() { + let mid = &entry.value().model_id; + if entry.value().enabled + && (mid.starts_with(&format!("{}-", model_name)) + || mid.starts_with(&format!("{}v", model_name))) + { + tracing::info!( + "Model alias resolved: {} → {}", + model_name, + mid + ); + return Some(entry.value().clone()); + } + } + None + } + /// 按 provider id 查找已启用的 Provider。O(1) DashMap 查找。 pub fn get_provider(&self, provider_id: &str) -> Option { self.providers.get(provider_id) diff --git a/crates/zclaw-saas/src/config.rs b/crates/zclaw-saas/src/config.rs index c2f82bd..0ea1f0c 100644 --- a/crates/zclaw-saas/src/config.rs +++ b/crates/zclaw-saas/src/config.rs @@ -465,22 +465,25 @@ impl SaaSConfig { /// 替换 TOML 配置文件中的 `${ENV_VAR}` 模式为环境变量值 /// 未设置的环境变量保留原文,后续数据库连接或 JWT 初始化时会报明确错误 +/// +/// 注意: 使用 chars() 迭代器而非 bytes() 来正确处理多字节 UTF-8 字符(如中文), +/// 避免将多字节 UTF-8 序列的每个字节单独 `as char` 导致编码损坏。 fn interpolate_env_vars(content: &str) -> String { let mut result = String::with_capacity(content.len()); - let bytes = content.as_bytes(); + let chars: Vec = content.chars().collect(); let mut i = 0; - while i < bytes.len() { - if i + 1 < bytes.len() && bytes[i] == b'$' && bytes[i + 1] == b'{' { + while i < chars.len() { + if i + 1 < chars.len() && chars[i] == '$' && chars[i + 1] == '{' { let start = i + 2; let mut end = start; - while end < bytes.len() - && (bytes[end].is_ascii_alphanumeric() || bytes[end] == b'_') + while end < chars.len() + && (chars[end].is_ascii_alphanumeric() || chars[end] == '_') { end += 1; } - if end < bytes.len() && bytes[end] == b'}' { - let var_name = std::str::from_utf8(&bytes[start..end]).unwrap_or(""); - match std::env::var(var_name) { + if end < chars.len() && chars[end] == '}' { + let var_name: String = chars[start..end].iter().collect(); + match std::env::var(&var_name) { Ok(val) => { tracing::debug!("Config: ${{{}}} → resolved ({} bytes)", var_name, val.len()); result.push_str(&val); @@ -492,11 +495,11 @@ fn interpolate_env_vars(content: &str) -> String { } i = end + 1; } else { - result.push(bytes[i] as char); + result.push(chars[i]); i += 1; } } else { - result.push(bytes[i] as char); + result.push(chars[i]); i += 1; } } diff --git a/crates/zclaw-saas/src/db.rs b/crates/zclaw-saas/src/db.rs index 6a19a4f..848f438 100644 --- a/crates/zclaw-saas/src/db.rs +++ b/crates/zclaw-saas/src/db.rs @@ -38,10 +38,25 @@ pub async fn init_db(config: &DatabaseConfig) -> SaasResult { .connect(&database_url) .await?; + // 验证数据库编码为 UTF8 — 中文 Windows (GBK/代码页936) 可能导致默认非 UTF8 + let encoding: (String,) = sqlx::query_as("SHOW server_encoding") + .fetch_one(&pool) + .await + .unwrap_or(("UNKNOWN".to_string(),)); + if encoding.0.to_uppercase() != "UTF8" { + tracing::error!( + "⚠ 数据库编码为 '{}',非 UTF8!中文数据将损坏。请使用 CREATE DATABASE ... WITH ENCODING='UTF8' 重建数据库。", + encoding.0 + ); + } else { + tracing::info!("Database encoding: {}", encoding.0); + } + run_migrations(&pool).await?; ensure_security_columns(&pool).await?; seed_admin_account(&pool).await?; seed_builtin_prompts(&pool).await?; + seed_knowledge_categories(&pool).await?; seed_builtin_industries(&pool).await?; seed_demo_data(&pool).await?; fix_seed_data(&pool).await?; @@ -1004,6 +1019,31 @@ async fn seed_builtin_industries(pool: &PgPool) -> SaasResult<()> { crate::industry::service::seed_builtin_industries(pool).await } +/// 种子化知识库默认分类(幂等) +async fn seed_knowledge_categories(pool: &PgPool) -> SaasResult<()> { + let now = chrono::Utc::now(); + let categories = [ + ("seed", "种子知识", "系统内置的行业基础知识"), + ("uploaded", "上传文档", "用户上传的文档知识"), + ("distillation", "蒸馏知识", "API 蒸馏生成的知识"), + ]; + for (id, name, desc) in &categories { + sqlx::query( + "INSERT INTO knowledge_categories (id, name, description, created_at, updated_at) \ + VALUES ($1, $2, $3, $4, $4) \ + ON CONFLICT (id) DO NOTHING" + ) + .bind(id) + .bind(name) + .bind(desc) + .bind(&now) + .execute(pool) + .await?; + } + tracing::debug!("Seeded knowledge categories"); + Ok(()) +} + #[cfg(test)] mod tests { // PostgreSQL 单元测试需要真实数据库连接,此处保留接口兼容 diff --git a/crates/zclaw-saas/src/knowledge/service.rs b/crates/zclaw-saas/src/knowledge/service.rs index db0fbbf..4031480 100644 --- a/crates/zclaw-saas/src/knowledge/service.rs +++ b/crates/zclaw-saas/src/knowledge/service.rs @@ -632,6 +632,8 @@ pub async fn unified_search( // === 种子知识冷启动 === /// 为指定行业插入种子知识(幂等) +/// +/// P1-6 修复: 同时创建 knowledge_chunks 以支持搜索 pub async fn seed_knowledge( pool: &PgPool, industry_id: &str, @@ -684,6 +686,24 @@ pub async fn seed_knowledge( .execute(pool) .await?; + // 创建 chunks 以支持搜索(与 distill_knowledge worker 一致) + let chunks = chunk_content(content, 500, 50); + for (chunk_idx, chunk_text) in chunks.iter().enumerate() { + let chunk_id = uuid::Uuid::new_v4().to_string(); + sqlx::query( + "INSERT INTO knowledge_chunks (id, item_id, content, keywords, chunk_index, created_at) \ + VALUES ($1, $2, $3, $4, $5, $6)" + ) + .bind(&chunk_id) + .bind(&id) + .bind(chunk_text) + .bind(&kw_json) + .bind(chunk_idx as i32) + .bind(&now) + .execute(pool) + .await?; + } + created += 1; } Ok(created) diff --git a/crates/zclaw-saas/src/middleware.rs b/crates/zclaw-saas/src/middleware.rs index bb6ac23..7b09c48 100644 --- a/crates/zclaw-saas/src/middleware.rs +++ b/crates/zclaw-saas/src/middleware.rs @@ -145,6 +145,26 @@ pub async fn quota_check_middleware( _ => {} } + // P1-8 修复: 同时检查 input_tokens 配额 + match crate::billing::service::check_quota(&state.db, &account_id, "input_tokens").await { + Ok(check) if !check.allowed => { + tracing::warn!( + "Token quota exceeded for account {}: {} ({}/{})", + account_id, + check.reason.as_deref().unwrap_or("Token配额已用尽"), + check.current, + check.limit.map(|l| l.to_string()).unwrap_or_else(|| "∞".into()), + ); + return SaasError::RateLimited( + check.reason.unwrap_or_else(|| "月度 Token 配额已用尽".into()), + ).into_response(); + } + Err(e) => { + tracing::warn!("Token quota check failed for account {}: {}", account_id, e); + } + _ => {} + } + next.run(req).await } diff --git a/crates/zclaw-saas/src/relay/handlers.rs b/crates/zclaw-saas/src/relay/handlers.rs index 6cb9160..18ae984 100644 --- a/crates/zclaw-saas/src/relay/handlers.rs +++ b/crates/zclaw-saas/src/relay/handlers.rs @@ -152,8 +152,8 @@ pub async fn chat_completions( } ModelResolution::Group(candidates) } else { - // 向后兼容:直接模型查找 - let target_model = state.cache.get_model(model_name) + // 向后兼容:直接模型查找 + 别名解析(如 "glm-4-flash" → "glm-4-flash-250414") + let target_model = state.cache.resolve_model(model_name) .ok_or_else(|| SaasError::NotFound(format!("模型 {} 不存在或未启用", model_name)))?; // 获取 provider 信息 — 使用内存缓存消除 DB 查询 @@ -218,7 +218,7 @@ pub async fn chat_completions( ModelResolution::Direct(ref candidate) => { // 单 Provider 直接路由(向后兼容) match service::execute_relay( - &state.db, &task.id, &candidate.provider_id, + &state.db, &task.id, &ctx.account_id, &candidate.provider_id, &candidate.base_url, &request_body, stream, max_attempts, retry_delay_ms, &enc_key, true, // 独立调用,管理 task 状态 @@ -233,7 +233,7 @@ pub async fn chat_completions( // SSE 一旦开始流式传输,中途上游断连不会触发 failover(SSE 协议固有限制)。 service::sort_candidates_by_quota(&state.db, candidates).await; service::execute_relay_with_failover( - &state.db, &task.id, candidates, + &state.db, &task.id, &ctx.account_id, candidates, &request_body, stream, max_attempts, retry_delay_ms, &enc_key ).await @@ -553,11 +553,12 @@ pub async fn retry_task( // 异步执行重试 — 根据解析结果选择执行路径 let db = state.db.clone(); let task_id = id.clone(); + let account_id_for_spawn = task.account_id.clone(); let handle = tokio::spawn(async move { let result = match model_resolution { ModelResolution::Direct(ref candidate) => { service::execute_relay( - &db, &task_id, &candidate.provider_id, + &db, &task_id, &account_id_for_spawn, &candidate.provider_id, &candidate.base_url, &body, stream, max_attempts, base_delay_ms, &enc_key, true, @@ -566,7 +567,7 @@ pub async fn retry_task( ModelResolution::Group(ref mut candidates) => { service::sort_candidates_by_quota(&db, candidates).await; service::execute_relay_with_failover( - &db, &task_id, candidates, + &db, &task_id, &account_id_for_spawn, candidates, &body, stream, max_attempts, base_delay_ms, &enc_key, ).await diff --git a/crates/zclaw-saas/src/relay/service.rs b/crates/zclaw-saas/src/relay/service.rs index f59151a..3229c55 100644 --- a/crates/zclaw-saas/src/relay/service.rs +++ b/crates/zclaw-saas/src/relay/service.rs @@ -217,6 +217,7 @@ impl SseUsageCapture { pub async fn execute_relay( db: &PgPool, task_id: &str, + account_id: &str, provider_id: &str, provider_base_url: &str, request_body: &str, @@ -313,6 +314,7 @@ pub async fn execute_relay( let db_clone = db.clone(); let task_id_clone = task_id.to_string(); let key_id_for_spawn = key_id.clone(); + let account_id_clone = account_id.to_string(); // Bounded channel for backpressure: 128 chunks (~128KB) buffer. // If the client reads slowly, the upstream is signaled via @@ -369,20 +371,53 @@ pub async fn execute_relay( tokio::spawn(async move { let _permit = permit; // 持有 permit 直到任务完成 - // Brief delay to allow SSE stream to settle before recording - tokio::time::sleep(std::time::Duration::from_millis(500)).await; - let capture = usage_capture.lock().await; - let (input, output) = ( - if capture.input_tokens > 0 { Some(capture.input_tokens) } else { None }, - if capture.output_tokens > 0 { Some(capture.output_tokens) } else { None }, - ); - // Record task status with timeout to avoid holding DB connections + // 等待 SSE 流结束 — 等待 capture 稳定(tokens 不再增长) + // 替代原来固定 500ms 的 race condition + let max_wait = std::time::Duration::from_secs(120); + let poll_interval = std::time::Duration::from_millis(500); + let start = tokio::time::Instant::now(); + let mut last_tokens: i64 = 0; + let mut stable_count = 0; + let (input, output) = loop { + tokio::time::sleep(poll_interval).await; + let capture = usage_capture.lock().await; + let total = capture.input_tokens + capture.output_tokens; + if total == last_tokens && total > 0 { + stable_count += 1; + if stable_count >= 3 { + // 连续 3 次稳定(1.5s),认为流结束 + break (capture.input_tokens, capture.output_tokens); + } + } else { + stable_count = 0; + last_tokens = total; + } + drop(capture); + if start.elapsed() >= max_wait { + let capture = usage_capture.lock().await; + break (capture.input_tokens, capture.output_tokens); + } + }; + + let input_opt = if input > 0 { Some(input) } else { None }; + let output_opt = if output > 0 { Some(output) } else { None }; + + // Record task status + billing usage + key usage let db_op = async { - if let Err(e) = update_task_status(&db_clone, &task_id_clone, "completed", input, output, None).await { + if let Err(e) = update_task_status(&db_clone, &task_id_clone, "completed", input_opt, output_opt, None).await { tracing::warn!("Failed to update task status after SSE stream: {}", e); } - // Record key usage (now 2 queries instead of 3) - let total_tokens = input.unwrap_or(0) + output.unwrap_or(0); + // P2-9 修复: SSE 路径也更新 billing_usage_quotas + if input > 0 || output > 0 { + if let Err(e) = crate::billing::service::increment_usage( + &db_clone, &account_id_clone, + input, output, + ).await { + tracing::warn!("Failed to increment billing usage for SSE task {}: {}", task_id_clone, e); + } + } + // Record key usage + let total_tokens = input + output; if let Err(e) = super::key_pool::record_key_usage(&db_clone, &key_id_for_spawn, Some(total_tokens)).await { tracing::warn!("Failed to record key usage: {}", e); } @@ -503,6 +538,7 @@ pub async fn execute_relay( pub async fn execute_relay_with_failover( db: &PgPool, task_id: &str, + account_id: &str, candidates: &[CandidateModel], request_body: &str, stream: bool, @@ -533,6 +569,7 @@ pub async fn execute_relay_with_failover( match execute_relay( db, task_id, + account_id, &candidate.provider_id, &candidate.base_url, &patched_body, diff --git a/desktop/src-tauri/src/kernel_commands/chat.rs b/desktop/src-tauri/src/kernel_commands/chat.rs index 5de9f0a..fcab31b 100644 --- a/desktop/src-tauri/src/kernel_commands/chat.rs +++ b/desktop/src-tauri/src/kernel_commands/chat.rs @@ -30,6 +30,9 @@ pub struct ChatRequest { /// Enable sub-agent delegation (Ultra mode only) #[serde(default)] pub subagent_enabled: Option, + /// Model override — UI 选择的模型优先于 Agent 配置的默认模型 + #[serde(default)] + pub model: Option, } /// Chat response @@ -76,6 +79,9 @@ pub struct StreamChatRequest { /// Enable sub-agent delegation (Ultra mode only) #[serde(default)] pub subagent_enabled: Option, + /// Model override — UI 选择的模型优先于 Agent 配置的默认模型 + #[serde(default)] + pub model: Option, } // --------------------------------------------------------------------------- @@ -116,7 +122,7 @@ pub async fn agent_chat( None }; - let response = kernel.send_message_with_chat_mode(&id, request.message, chat_mode) + let response = kernel.send_message_with_chat_mode(&id, request.message, chat_mode, request.model) .await .map_err(|e| format!("Chat failed: {}", e))?; @@ -257,6 +263,7 @@ pub async fn agent_chat_stream( prompt_arg, session_id_parsed, Some(chat_mode_config), + request.model.clone(), ) .await .map_err(|e| { diff --git a/desktop/src/lib/kernel-chat.ts b/desktop/src/lib/kernel-chat.ts index 5f43f94..b95df7e 100644 --- a/desktop/src/lib/kernel-chat.ts +++ b/desktop/src/lib/kernel-chat.ts @@ -60,6 +60,7 @@ export function installChatMethods(ClientClass: { prototype: KernelClient }): vo reasoning_effort?: string; plan_mode?: boolean; subagent_enabled?: boolean; + model?: string; } ): Promise<{ runId: string }> { const runId = crypto.randomUUID(); @@ -199,6 +200,7 @@ export function installChatMethods(ClientClass: { prototype: KernelClient }): vo reasoningEffort: opts?.reasoning_effort, planMode: opts?.plan_mode, subagentEnabled: opts?.subagent_enabled, + model: opts?.model, }, }); } catch (err: unknown) { diff --git a/desktop/src/store/chat/streamStore.ts b/desktop/src/store/chat/streamStore.ts index 20fdf32..41e3b0c 100644 --- a/desktop/src/store/chat/streamStore.ts +++ b/desktop/src/store/chat/streamStore.ts @@ -303,6 +303,9 @@ export const useStreamStore = create()( .map(m => ({ role: m.role, content: m.content })) .slice(-20); + // UI 模型选择器应覆盖 Agent 默认模型 + const currentModel = useConversationStore.getState().currentModel; + const result = await client.chatStream( content, { @@ -534,6 +537,7 @@ export const useStreamStore = create()( set({ isStreaming: false, activeRunId: null }); }, }, + // eslint-disable-next-line @typescript-eslint/no-explicit-any { sessionKey: effectiveSessionKey, agentId: effectiveAgentId, @@ -542,7 +546,8 @@ export const useStreamStore = create()( plan_mode: get().getChatModeConfig().plan_mode, subagent_enabled: get().getChatModeConfig().subagent_enabled, history, - } + model: currentModel || undefined, + } as Parameters[2] ); if (result?.runId) { diff --git a/docker-compose.yml b/docker-compose.yml index 13786a3..7d7fd24 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -18,6 +18,8 @@ services: POSTGRES_USER: ${POSTGRES_USER:-postgres} POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-your_secure_password} POSTGRES_DB: ${POSTGRES_DB:-zclaw} + # 确保 UTF-8 编码 — 中文 Windows 默认 GBK 会导致中文数据损坏 + POSTGRES_INITDB_ARGS: "--encoding=UTF8 --locale=C.UTF-8" ports: - "${POSTGRES_PORT:-5432}:5432"