From 28299807b6c58ec078a509c5051bb642dd5aad91 Mon Sep 17 00:00:00 2001 From: iven Date: Thu, 2 Apr 2026 19:24:44 +0800 Subject: [PATCH] =?UTF-8?q?fix(desktop):=20DeerFlow=20UI=20=E2=80=94=20Cha?= =?UTF-8?q?tArea=20refactor=20+=20ai-elements=20+=20dead=20CSS=20cleanup?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ChatArea retry button uses setInput instead of direct sendToGateway, fix bootstrap spinner stuck for non-logged-in users, remove dead CSS (aurora-title/sidebar-open/quick-action-chips), add ai components (ReasoningBlock/StreamingText/ChatMode/ModelSelector/TaskProgress), add ClassroomPlayer + ResizableChatLayout + artifact panel Co-Authored-By: Claude Opus 4.6 --- crates/zclaw-saas/Cargo.toml | 6 +- .../20260401000003_models_is_embedding.sql | 10 + .../20260402000003_scheduled_task_results.sql | 5 + crates/zclaw-saas/src/auth/mod.rs | 76 ++- crates/zclaw-saas/src/billing/handlers.rs | 196 +++++-- crates/zclaw-saas/src/billing/mod.rs | 7 +- crates/zclaw-saas/src/billing/payment.rs | 497 ++++++++++++++++-- crates/zclaw-saas/src/billing/service.rs | 73 ++- crates/zclaw-saas/src/cache.rs | 16 + crates/zclaw-saas/src/config.rs | 173 +++++- crates/zclaw-saas/src/main.rs | 1 + .../zclaw-saas/src/model_config/handlers.rs | 19 + crates/zclaw-saas/src/models/prompt.rs | 9 + crates/zclaw-saas/src/models/telemetry.rs | 18 + crates/zclaw-saas/src/prompt/service.rs | 20 +- crates/zclaw-saas/src/relay/key_pool.rs | 33 ++ crates/zclaw-saas/src/relay/service.rs | 139 ++++- .../zclaw-saas/src/scheduled_task/service.rs | 10 +- crates/zclaw-saas/src/scheduled_task/types.rs | 2 + crates/zclaw-saas/src/scheduler.rs | 89 +++- crates/zclaw-saas/src/state.rs | 108 +++- crates/zclaw-saas/src/telemetry/service.rs | 26 +- .../src/workers/generate_embedding.rs | 37 +- .../src/workers/update_last_used.rs | 7 +- .../src-tauri/src/classroom_commands/chat.rs | 223 ++++++++ .../src/classroom_commands/export.rs | 152 ++++++ .../src/classroom_commands/generate.rs | 286 ++++++++++ .../src-tauri/src/classroom_commands/mod.rs | 41 ++ .../src-tauri/src/intelligence/identity.rs | 17 +- desktop/src-tauri/src/kernel_commands/chat.rs | 5 + .../src/kernel_commands/lifecycle.rs | 127 +++++ desktop/src-tauri/src/lib.rs | 23 +- desktop/src/App.tsx | 10 +- desktop/src/components/ChatArea.tsx | 109 +++- desktop/src/components/ClassroomPreviewer.tsx | 11 + .../components/FirstConversationPrompt.tsx | 42 +- .../src/components/PipelineResultPreview.tsx | 7 + desktop/src/components/ai/Conversation.tsx | 2 +- .../src/components/ai/ResizableChatLayout.tsx | 4 +- .../components/classroom_player/AgentChat.tsx | 121 +++++ .../classroom_player/ClassroomPlayer.tsx | 231 ++++++++ .../classroom_player/NotesSidebar.tsx | 71 +++ .../classroom_player/SceneRenderer.tsx | 219 ++++++++ .../components/classroom_player/TtsPlayer.tsx | 155 ++++++ .../classroom_player/WhiteboardCanvas.tsx | 295 +++++++++++ .../src/components/classroom_player/index.ts | 12 + desktop/src/hooks/useClassroom.ts | 76 +++ desktop/src/index.css | 57 +- desktop/src/lib/audit-logger.ts | 4 + desktop/src/lib/classroom-adapter.ts | 142 +++++ desktop/src/lib/error-handling.ts | 9 +- desktop/src/lib/kernel-chat.ts | 6 + desktop/src/lib/kernel-hands.ts | 36 +- desktop/src/lib/kernel-skills.ts | 6 +- desktop/src/lib/kernel-triggers.ts | 7 +- desktop/src/lib/kernel-types.ts | 7 + desktop/src/lib/saas-admin.ts | 233 -------- desktop/src/lib/saas-client.ts | 100 +--- desktop/src/lib/saas-relay.ts | 3 +- desktop/src/lib/secure-storage.ts | 30 -- desktop/src/lib/security-index.ts | 1 - desktop/src/store/chat/artifactStore.ts | 54 ++ desktop/src/store/chat/conversationStore.ts | 368 +++++++++++++ desktop/src/store/chatStore.ts | 67 +-- desktop/src/store/classroomStore.ts | 223 ++++++++ desktop/src/store/index.ts | 11 + desktop/src/store/saasStore.ts | 21 + desktop/src/types/chat.ts | 133 +++++ desktop/src/types/classroom.ts | 181 +++++++ desktop/src/types/index.ts | 41 ++ 70 files changed, 4938 insertions(+), 618 deletions(-) create mode 100644 crates/zclaw-saas/migrations/20260401000003_models_is_embedding.sql create mode 100644 crates/zclaw-saas/migrations/20260402000003_scheduled_task_results.sql create mode 100644 desktop/src-tauri/src/classroom_commands/chat.rs create mode 100644 desktop/src-tauri/src/classroom_commands/export.rs create mode 100644 desktop/src-tauri/src/classroom_commands/generate.rs create mode 100644 desktop/src-tauri/src/classroom_commands/mod.rs create mode 100644 desktop/src/components/classroom_player/AgentChat.tsx create mode 100644 desktop/src/components/classroom_player/ClassroomPlayer.tsx create mode 100644 desktop/src/components/classroom_player/NotesSidebar.tsx create mode 100644 desktop/src/components/classroom_player/SceneRenderer.tsx create mode 100644 desktop/src/components/classroom_player/TtsPlayer.tsx create mode 100644 desktop/src/components/classroom_player/WhiteboardCanvas.tsx create mode 100644 desktop/src/components/classroom_player/index.ts create mode 100644 desktop/src/hooks/useClassroom.ts create mode 100644 desktop/src/lib/classroom-adapter.ts delete mode 100644 desktop/src/lib/saas-admin.ts create mode 100644 desktop/src/store/chat/artifactStore.ts create mode 100644 desktop/src/store/chat/conversationStore.ts create mode 100644 desktop/src/store/classroomStore.ts create mode 100644 desktop/src/types/chat.ts create mode 100644 desktop/src/types/classroom.ts diff --git a/crates/zclaw-saas/Cargo.toml b/crates/zclaw-saas/Cargo.toml index e6fb012..4c4f3b8 100644 --- a/crates/zclaw-saas/Cargo.toml +++ b/crates/zclaw-saas/Cargo.toml @@ -26,14 +26,17 @@ chrono = { workspace = true } tracing = { workspace = true } tracing-subscriber = { workspace = true } sqlx = { workspace = true } +pgvector = { version = "0.4", features = ["sqlx"] } reqwest = { workspace = true } secrecy = { workspace = true } sha2 = { workspace = true } rand = { workspace = true } dashmap = { workspace = true } hex = { workspace = true } +rsa = { workspace = true, features = ["sha2"] } +base64 = { workspace = true } socket2 = { workspace = true } -url = "2" +url = { workspace = true } axum = { workspace = true } axum-extra = { workspace = true } @@ -47,6 +50,7 @@ data-encoding = "2" regex = { workspace = true } aes-gcm = { workspace = true } bytes = { workspace = true } +async-stream = { workspace = true } [dev-dependencies] tempfile = { workspace = true } diff --git a/crates/zclaw-saas/migrations/20260401000003_models_is_embedding.sql b/crates/zclaw-saas/migrations/20260401000003_models_is_embedding.sql new file mode 100644 index 0000000..35d5a02 --- /dev/null +++ b/crates/zclaw-saas/migrations/20260401000003_models_is_embedding.sql @@ -0,0 +1,10 @@ +-- Add is_embedding column to models table +-- Distinguishes embedding models from chat/completion models +ALTER TABLE models ADD COLUMN IF NOT EXISTS is_embedding BOOLEAN NOT NULL DEFAULT FALSE; + +-- Add model_type column for future extensibility (chat, embedding, image, audio, etc.) +ALTER TABLE models ADD COLUMN IF NOT EXISTS model_type TEXT NOT NULL DEFAULT 'chat'; + +-- Index for quick filtering of embedding models +CREATE INDEX IF NOT EXISTS idx_models_is_embedding ON models(is_embedding) WHERE is_embedding = TRUE; +CREATE INDEX IF NOT EXISTS idx_models_model_type ON models(model_type); diff --git a/crates/zclaw-saas/migrations/20260402000003_scheduled_task_results.sql b/crates/zclaw-saas/migrations/20260402000003_scheduled_task_results.sql new file mode 100644 index 0000000..d64a57c --- /dev/null +++ b/crates/zclaw-saas/migrations/20260402000003_scheduled_task_results.sql @@ -0,0 +1,5 @@ +-- Add execution result columns to scheduled_tasks +-- Tracks the output and duration of each task execution for observability + +ALTER TABLE scheduled_tasks ADD COLUMN IF NOT EXISTS last_result TEXT; +ALTER TABLE scheduled_tasks ADD COLUMN IF NOT EXISTS last_duration_ms INTEGER; diff --git a/crates/zclaw-saas/src/auth/mod.rs b/crates/zclaw-saas/src/auth/mod.rs index 188562f..915c091 100644 --- a/crates/zclaw-saas/src/auth/mod.rs +++ b/crates/zclaw-saas/src/auth/mod.rs @@ -67,14 +67,17 @@ async fn verify_api_token(state: &AppState, raw_token: &str, client_ip: Option Option { - // 优先从 ConnectInfo 获取 - if let Some(ConnectInfo(addr)) = req.extensions().get::>() { - return Some(addr.ip().to_string()); +/// 从请求中提取客户端 IP(安全版:仅对 trusted_proxies 解析 XFF) +fn extract_client_ip(req: &Request, trusted_proxies: &[String]) -> Option { + // 优先从 ConnectInfo 获取直接连接 IP + let connect_ip = req.extensions() + .get::>() + .map(|ConnectInfo(addr)| addr.ip().to_string()); + + // 仅当直接连接 IP 在 trusted_proxies 中时,才信任 XFF/X-Real-IP + if let Some(ref ip) = connect_ip { + if trusted_proxies.iter().any(|p| p == ip) { + // 受信代理 → 从 XFF 取真实客户端 IP + if let Some(forwarded) = req.headers() + .get("x-forwarded-for") + .and_then(|v| v.to_str().ok()) + { + if let Some(client) = forwarded.split(',').next() { + let trimmed = client.trim(); + if !trimmed.is_empty() { + return Some(trimmed.to_string()); + } + } + } + // 尝试 X-Real-IP + if let Some(real_ip) = req.headers() + .get("x-real-ip") + .and_then(|v| v.to_str().ok()) + { + let trimmed = real_ip.trim(); + if !trimmed.is_empty() { + return Some(trimmed.to_string()); + } + } + } } - // 回退到 X-Forwarded-For / X-Real-IP - if let Some(forwarded) = req.headers() - .get("x-forwarded-for") - .and_then(|v| v.to_str().ok()) - { - return Some(forwarded.split(',').next()?.trim().to_string()); - } - req.headers() - .get("x-real-ip") - .and_then(|v| v.to_str().ok()) - .map(|s| s.to_string()) + + // 非受信来源或无代理头 → 返回直接连接 IP + connect_ip } /// 认证中间件: 从 JWT Cookie / Authorization Header / API Token 提取身份 @@ -110,7 +133,10 @@ pub async fn auth_middleware( mut req: Request, next: Next, ) -> Response { - let client_ip = extract_client_ip(&req); + let client_ip = { + let config = state.config.read().await; + extract_client_ip(&req, &config.server.trusted_proxies) + }; let auth_header = req.headers() .get(header::AUTHORIZATION) .and_then(|v| v.to_str().ok()); diff --git a/crates/zclaw-saas/src/billing/handlers.rs b/crates/zclaw-saas/src/billing/handlers.rs index ea812bb..00c0b87 100644 --- a/crates/zclaw-saas/src/billing/handlers.rs +++ b/crates/zclaw-saas/src/billing/handlers.rs @@ -4,7 +4,6 @@ use axum::{ extract::{Extension, Form, Path, Query, State}, Json, }; -use axum::response::Html; use serde::Deserialize; use crate::auth::types::AuthContext; @@ -90,9 +89,8 @@ pub async fn increment_usage_dimension( )); } - for _ in 0..req.count { - service::increment_dimension(&state.db, &ctx.account_id, &req.dimension).await?; - } + // 单次原子更新,避免循环 N 次数据库查询 + service::increment_dimension_by(&state.db, &ctx.account_id, &req.dimension, req.count).await?; // 返回更新后的用量 let usage = service::get_or_create_usage(&state.db, &ctx.account_id).await?; @@ -109,10 +107,12 @@ pub async fn create_payment( Extension(ctx): Extension, Json(req): Json, ) -> SaasResult> { + let config = state.config.read().await; let result = super::payment::create_payment( &state.db, &ctx.account_id, &req, + &config.payment, ).await?; Ok(Json(result)) } @@ -139,22 +139,28 @@ pub async fn payment_callback( ) -> SaasResult { tracing::info!("Payment callback received: method={}, body_len={}", method, body.len()); - // 解析回调参数 let body_str = String::from_utf8_lossy(&body); + let config = state.config.read().await; - // 支付宝回调:form-urlencoded 格式 - // 微信回调:JSON 格式 - let (trade_no, status) = if method == "alipay" { - parse_alipay_callback(&body_str) + let (trade_no, status, callback_amount) = if method == "alipay" { + parse_alipay_callback(&body_str, &config.payment)? } else if method == "wechat" { - parse_wechat_callback(&body_str) + parse_wechat_callback(&body_str, &config.payment)? } else { tracing::warn!("Unknown payment callback method: {}", method); return Ok("fail".into()); }; - if let Some(trade_no) = trade_no { - super::payment::handle_payment_callback(&state.db, &trade_no, &status).await?; + // trade_no 是必填字段,缺失说明回调格式异常 + let trade_no = trade_no.ok_or_else(|| { + tracing::warn!("Payment callback missing out_trade_no: method={}", method); + SaasError::InvalidInput("回调缺少交易号".into()) + })?; + + if let Err(e) = super::payment::handle_payment_callback(&state.db, &trade_no, &status, callback_amount).await { + // 对外返回通用错误,不泄露内部细节 + tracing::error!("Payment callback processing failed: method={}, error={}", method, e); + return Ok("fail".into()); } // 支付宝期望 "success",微信期望 JSON @@ -178,6 +184,11 @@ pub struct MockPayQuery { pub async fn mock_pay_page( Query(params): Query, ) -> axum::response::Html { + // HTML 转义防止 XSS + let safe_subject = html_escape(¶ms.subject); + let safe_trade_no = html_escape(¶ms.trade_no); + let amount_yuan = params.amount as f64 / 100.0; + axum::response::Html(format!(r#" @@ -194,23 +205,19 @@ body {{ font-family: system-ui; max-width: 480px; margin: 40px auto; padding: 20
-
{subject}
+
{safe_subject}
¥{amount_yuan}
- 订单号: {trade_no} + 订单号: {safe_trade_no}
- +
- "#, - subject = params.subject, - trade_no = params.trade_no, - amount_yuan = params.amount as f64 / 100.0, - )) + "#)) } #[derive(Debug, Deserialize)] @@ -226,7 +233,7 @@ pub async fn mock_pay_confirm( ) -> SaasResult> { let status = if form.action == "success" { "success" } else { "failed" }; - if let Err(e) = super::payment::handle_payment_callback(&state.db, &form.trade_no, status).await { + if let Err(e) = super::payment::handle_payment_callback(&state.db, &form.trade_no, status, None).await { tracing::error!("Mock payment callback failed: {}", e); } @@ -249,31 +256,140 @@ body {{ font-family: system-ui; max-width: 480px; margin: 40px auto; padding: 20 "#))) } -// === 辅助函数 === +// === 回调解析 === -fn parse_alipay_callback(body: &str) -> (Option, String) { - // 简化解析:支付宝回调是 form-urlencoded +/// 解析支付宝回调并验签,返回 (trade_no, status, callback_amount_cents) +fn parse_alipay_callback( + body: &str, + config: &crate::config::PaymentConfig, +) -> SaasResult<(Option, String, Option)> { + // form-urlencoded → key=value 对 + let mut params: Vec<(String, String)> = Vec::new(); for pair in body.split('&') { if let Some((k, v)) = pair.split_once('=') { - if k == "out_trade_no" { - return (Some(urlencoding::decode(v).unwrap_or_default().to_string()), "TRADE_SUCCESS".into()); - } + params.push(( + k.to_string(), + urlencoding::decode(v).unwrap_or_default().to_string(), + )); } } - (None, "unknown".into()) + + let mut trade_no = None; + let mut callback_amount: Option = None; + + // 验签:生产环境强制,开发环境允许跳过 + let is_dev = std::env::var("ZCLAW_SAAS_DEV") + .map(|v| v == "true" || v == "1") + .unwrap_or(false); + + if let Some(ref public_key) = config.alipay_public_key { + match super::payment::verify_alipay_callback(¶ms, public_key) { + Ok(true) => {} + Ok(false) => { + tracing::warn!("Alipay callback signature verification FAILED"); + return Err(SaasError::InvalidInput("支付宝回调验签失败".into())); + } + Err(e) => { + tracing::error!("Alipay callback verification error: {}", e); + return Err(SaasError::InvalidInput("支付宝回调验签异常".into())); + } + } + } else if !is_dev { + tracing::error!("Alipay public key not configured in production — rejecting callback"); + return Err(SaasError::InvalidInput("支付宝公钥未配置,无法验签".into())); + } else { + tracing::warn!("Alipay public key not configured (dev mode), skipping signature verification"); + } + + // 提取 trade_no、trade_status 和 total_amount + let mut trade_status = "unknown".to_string(); + for (k, v) in ¶ms { + match k.as_str() { + "out_trade_no" => trade_no = Some(v.clone()), + "trade_status" => trade_status = v.clone(), + "total_amount" => { + // 支付宝金额为元(字符串),转为分(整数) + if let Ok(yuan) = v.parse::() { + callback_amount = Some((yuan * 100.0).round() as i32); + } + } + _ => {} + } + } + + // 支付宝成功状态映射 + let status = if trade_status == "TRADE_SUCCESS" || trade_status == "TRADE_FINISHED" { + "TRADE_SUCCESS" + } else { + &trade_status + }; + + Ok((trade_no, status.to_string(), callback_amount)) } -fn parse_wechat_callback(body: &str) -> (Option, String) { - // 微信回调是 JSON - if let Ok(v) = serde_json::from_str::(body) { - if let Some(event_type) = v.get("event_type").and_then(|t| t.as_str()) { - if event_type == "TRANSACTION.SUCCESS" { - let trade_no = v.pointer("/resource/out_trade_no") - .and_then(|v| v.as_str()) - .map(|s| s.to_string()); - return (trade_no, "SUCCESS".into()); - } - } +/// 解析微信支付回调,解密 resource 字段,返回 (trade_no, status, callback_amount_cents) +fn parse_wechat_callback( + body: &str, + config: &crate::config::PaymentConfig, +) -> SaasResult<(Option, String, Option)> { + let v: serde_json::Value = serde_json::from_str(body) + .map_err(|e| SaasError::InvalidInput(format!("微信回调 JSON 解析失败: {}", e)))?; + + let event_type = v.get("event_type") + .and_then(|t| t.as_str()) + .unwrap_or(""); + + if event_type != "TRANSACTION.SUCCESS" { + // 非支付成功事件(如退款等),忽略 + return Ok((None, event_type.to_string(), None)); } - (None, "unknown".into()) + + // 解密 resource 字段 + let resource = v.get("resource") + .ok_or_else(|| SaasError::InvalidInput("微信回调缺少 resource 字段".into()))?; + + let ciphertext = resource.get("ciphertext") + .and_then(|v| v.as_str()) + .ok_or_else(|| SaasError::InvalidInput("微信回调 resource 缺少 ciphertext".into()))?; + let nonce = resource.get("nonce") + .and_then(|v| v.as_str()) + .ok_or_else(|| SaasError::InvalidInput("微信回调 resource 缺少 nonce".into()))?; + let associated_data = resource.get("associated_data") + .and_then(|v| v.as_str()) + .unwrap_or(""); + + let api_v3_key = config.wechat_api_v3_key.as_deref() + .ok_or_else(|| SaasError::InvalidInput("微信 API v3 密钥未配置,无法解密回调".into()))?; + + let plaintext = super::payment::decrypt_wechat_resource( + ciphertext, nonce, associated_data, api_v3_key, + )?; + + let decrypted: serde_json::Value = serde_json::from_str(&plaintext) + .map_err(|e| SaasError::Internal(format!("微信回调解密内容 JSON 解析失败: {}", e)))?; + + let trade_no = decrypted.get("out_trade_no") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()); + + let trade_state = decrypted.get("trade_state") + .and_then(|v| v.as_str()) + .unwrap_or("UNKNOWN"); + + // 微信金额已为分(整数) + let callback_amount = decrypted.get("amount") + .and_then(|a| a.get("total")) + .and_then(|v| v.as_i64()) + .map(|v| v as i32); + + Ok((trade_no, trade_state.to_string(), callback_amount)) +} + +/// HTML 转义,防止 XSS 注入 +fn html_escape(s: &str) -> String { + s.replace('&', "&") + .replace('<', "<") + .replace('>', ">") + .replace('"', """) + .replace('\'', "'") } diff --git a/crates/zclaw-saas/src/billing/mod.rs b/crates/zclaw-saas/src/billing/mod.rs index 6997b24..6b54b1a 100644 --- a/crates/zclaw-saas/src/billing/mod.rs +++ b/crates/zclaw-saas/src/billing/mod.rs @@ -7,6 +7,7 @@ pub mod payment; use axum::routing::{get, post}; +/// 需要认证的计费路由 pub fn routes() -> axum::Router { axum::Router::new() .route("/api/v1/billing/plans", get(handlers::list_plans)) @@ -16,7 +17,11 @@ pub fn routes() -> axum::Router { .route("/api/v1/billing/usage/increment", post(handlers::increment_usage_dimension)) .route("/api/v1/billing/payments", post(handlers::create_payment)) .route("/api/v1/billing/payments/{id}", get(handlers::get_payment_status)) - // 支付回调(无需 auth) +} + +/// 支付回调路由(无需 auth — 支付宝/微信服务器回调) +pub fn callback_routes() -> axum::Router { + axum::Router::new() .route("/api/v1/billing/callback/{method}", post(handlers::payment_callback)) } diff --git a/crates/zclaw-saas/src/billing/payment.rs b/crates/zclaw-saas/src/billing/payment.rs index 9953a41..4f90fd3 100644 --- a/crates/zclaw-saas/src/billing/payment.rs +++ b/crates/zclaw-saas/src/billing/payment.rs @@ -1,21 +1,25 @@ -//! 支付集成 — 支付宝/微信支付 +//! 支付集成 — 支付宝/微信支付(直连 HTTP 实现) //! -//! 开发模式使用 mock 支付,生产模式调用真实支付 API。 -//! 真实集成需要: -//! - 支付宝:alipay-sdk-rust 或 HTTP 直连(支付宝开放平台 v3 API) -//! - 微信支付:wxpay-rust 或 wechat-pay-rs +//! 不依赖第三方 SDK,使用 `rsa` crate 做 RSA2 签名,`reqwest` 做 HTTP 调用。 +//! 开发模式(`ZCLAW_SAAS_DEV=true`)使用 mock 支付。 use sqlx::PgPool; +use crate::config::PaymentConfig; use crate::error::{SaasError, SaasResult}; use super::types::*; -/// 创建支付订单 +// ──────────────────────────────────────────────────────────────── +// 公开 API +// ──────────────────────────────────────────────────────────────── + +/// 创建支付订单,返回支付链接/二维码 URL /// -/// 返回支付链接/二维码 URL,前端跳转或展示 +/// 发票和支付记录在事务中创建,确保原子性。 pub async fn create_payment( pool: &PgPool, account_id: &str, req: &CreatePaymentRequest, + config: &PaymentConfig, ) -> SaasResult { // 1. 获取计划信息 let plan = sqlx::query_as::<_, BillingPlan>( @@ -40,7 +44,10 @@ pub async fn create_payment( return Err(SaasError::InvalidInput("已订阅该计划".into())); } - // 2. 创建发票 + // 2. 在事务中创建发票和支付记录 + let mut tx = pool.begin().await + .map_err(|e| SaasError::Internal(format!("开启事务失败: {}", e)))?; + let invoice_id = uuid::Uuid::new_v4().to_string(); let now = chrono::Utc::now(); let due = now + chrono::Duration::days(1); @@ -58,10 +65,9 @@ pub async fn create_payment( .bind(format!("{} - {} ({})", plan.display_name, plan.interval, now.format("%Y-%m"))) .bind(due.to_rfc3339()) .bind(now.to_rfc3339()) - .execute(pool) + .execute(&mut *tx) .await?; - // 3. 创建支付记录 let payment_id = uuid::Uuid::new_v4().to_string(); let trade_no = format!("ZCLAW-{}-{}", chrono::Utc::now().format("%Y%m%d%H%M%S"), &payment_id[..8]); @@ -75,14 +81,23 @@ pub async fn create_payment( .bind(account_id) .bind(plan.price_cents) .bind(&plan.currency) - .bind(format!("{:?}", req.payment_method).to_lowercase()) + .bind(req.payment_method.to_string()) .bind(&trade_no) .bind(now.to_rfc3339()) - .execute(pool) + .execute(&mut *tx) .await?; - // 4. 生成支付链接 - let pay_url = generate_pay_url(req.payment_method, &trade_no, plan.price_cents, &plan.display_name)?; + tx.commit().await + .map_err(|e| SaasError::Internal(format!("事务提交失败: {}", e)))?; + + // 3. 生成支付链接 + let pay_url = generate_pay_url( + req.payment_method, + &trade_no, + plan.price_cents, + &plan.display_name, + config, + ).await?; Ok(PaymentResult { payment_id, @@ -93,70 +108,108 @@ pub async fn create_payment( } /// 处理支付回调(支付宝/微信异步通知) +/// +/// `callback_amount_cents` 来自回调报文的金额(分),用于与 DB 金额交叉验证。 +/// 整个操作在数据库事务中执行,使用 SELECT FOR UPDATE 防止并发竞态。 pub async fn handle_payment_callback( pool: &PgPool, trade_no: &str, status: &str, + callback_amount_cents: Option, ) -> SaasResult<()> { - // 1. 查找支付记录 - let payment: Option<(String, String, String, i32)> = sqlx::query_as::<_, (String, String, String, i32)>( - "SELECT id, invoice_id, account_id, amount_cents \ - FROM billing_payments WHERE external_trade_no = $1 AND status = 'pending'" + // 1. 在事务中锁定支付记录,防止 TOCTOU 竞态 + let mut tx = pool.begin().await + .map_err(|e| SaasError::Internal(format!("开启事务失败: {}", e)))?; + + let payment: Option<(String, String, String, i32, String)> = sqlx::query_as::<_, (String, String, String, i32, String)>( + "SELECT id, invoice_id, account_id, amount_cents, status \ + FROM billing_payments WHERE external_trade_no = $1 FOR UPDATE" ) .bind(trade_no) - .fetch_optional(pool) + .fetch_optional(&mut *tx) .await?; - let (payment_id, invoice_id, account_id, _amount) = match payment { + let (payment_id, invoice_id, account_id, db_amount, current_status) = match payment { Some(p) => p, None => { - tracing::warn!("Payment callback for unknown/expired trade: {}", trade_no); + tracing::error!("Payment callback for unknown trade: {}", sanitize_log(trade_no)); + tx.rollback().await?; return Ok(()); } }; + // 幂等性:已处理过直接返回 + if current_status != "pending" { + tracing::info!("Payment already processed (idempotent): trade={}, status={}", sanitize_log(trade_no), current_status); + tx.rollback().await?; + return Ok(()); + } + + // 2. 金额交叉验证(防篡改) + let is_dev = std::env::var("ZCLAW_SAAS_DEV") + .map(|v| v == "true" || v == "1") + .unwrap_or(false); + + if let Some(callback_amount) = callback_amount_cents { + if callback_amount != db_amount { + tracing::error!( + "Amount mismatch: trade={}, db_amount={}, callback_amount={}. Rejecting.", + sanitize_log(trade_no), db_amount, callback_amount + ); + tx.rollback().await?; + return Err(SaasError::InvalidInput("回调验证失败".into())); + } + } else if !is_dev { + // 非开发环境必须有金额 + tracing::error!("Callback without amount in non-dev mode: trade={}", sanitize_log(trade_no)); + tx.rollback().await?; + return Err(SaasError::InvalidInput("回调缺少金额验证".into())); + } else { + tracing::warn!("DEV: Skipping amount verification for trade={}", sanitize_log(trade_no)); + } + let now = chrono::Utc::now().to_rfc3339(); if status == "success" || status == "TRADE_SUCCESS" || status == "SUCCESS" { - // 2. 更新支付状态 + // 3. 更新支付状态 sqlx::query( "UPDATE billing_payments SET status = 'succeeded', paid_at = $1, updated_at = $1 WHERE id = $2" ) .bind(&now) .bind(&payment_id) - .execute(pool) + .execute(&mut *tx) .await?; - // 3. 更新发票状态 + // 4. 更新发票状态 sqlx::query( "UPDATE billing_invoices SET status = 'paid', paid_at = $1, updated_at = $1 WHERE id = $2" ) .bind(&now) .bind(&invoice_id) - .execute(pool) + .execute(&mut *tx) .await?; - // 4. 获取发票关联的计划 + // 5. 获取发票关联的计划 let plan_id: Option = sqlx::query_scalar( "SELECT plan_id FROM billing_invoices WHERE id = $1" ) .bind(&invoice_id) - .fetch_optional(pool) + .fetch_optional(&mut *tx) .await? .flatten(); if let Some(plan_id) = plan_id { - // 5. 取消旧订阅 + // 6. 取消旧订阅 sqlx::query( "UPDATE billing_subscriptions SET status = 'canceled', canceled_at = $1, updated_at = $1 \ WHERE account_id = $2 AND status IN ('trial', 'active')" ) .bind(&now) .bind(&account_id) - .execute(pool) + .execute(&mut *tx) .await?; - // 6. 创建新订阅(30 天周期) + // 7. 创建新订阅(30 天周期) let sub_id = uuid::Uuid::new_v4().to_string(); let period_end = (chrono::Utc::now() + chrono::Duration::days(30)).to_rfc3339(); let period_start = chrono::Utc::now().to_rfc3339(); @@ -172,7 +225,7 @@ pub async fn handle_payment_callback( .bind(&period_start) .bind(&period_end) .bind(&now) - .execute(pool) + .execute(&mut *tx) .await?; tracing::info!( @@ -180,18 +233,34 @@ pub async fn handle_payment_callback( account_id, plan_id, sub_id ); } + + tx.commit().await + .map_err(|e| SaasError::Internal(format!("事务提交失败: {}", e)))?; } else { - // 支付失败 + // 支付失败:截断 status 防止注入,更新发票为 void + let safe_reason = truncate_str(status, 200); sqlx::query( "UPDATE billing_payments SET status = 'failed', failure_reason = $1, updated_at = $2 WHERE id = $3" ) - .bind(status) + .bind(&safe_reason) .bind(&now) .bind(&payment_id) - .execute(pool) + .execute(&mut *tx) .await?; - tracing::warn!("Payment failed: trade={}, status={}", trade_no, status); + // 同时将发票标记为 void + sqlx::query( + "UPDATE billing_invoices SET status = 'void', voided_at = $1, updated_at = $1 WHERE id = $2" + ) + .bind(&now) + .bind(&invoice_id) + .execute(&mut *tx) + .await?; + + tx.commit().await + .map_err(|e| SaasError::Internal(format!("事务提交失败: {}", e)))?; + + tracing::warn!("Payment failed: trade={}, status={}", sanitize_log(trade_no), safe_reason); } Ok(()) @@ -223,44 +292,356 @@ pub async fn query_payment_status( })) } -// === 内部函数 === +// ──────────────────────────────────────────────────────────────── +// 支付 URL 生成 +// ──────────────────────────────────────────────────────────────── -/// 生成支付 URL(开发模式使用 mock) -fn generate_pay_url( +/// 生成支付 URL:根据配置决定 mock 或真实支付 +async fn generate_pay_url( method: PaymentMethod, trade_no: &str, amount_cents: i32, subject: &str, + config: &PaymentConfig, ) -> SaasResult { let is_dev = std::env::var("ZCLAW_SAAS_DEV") .map(|v| v == "true" || v == "1") .unwrap_or(false); if is_dev { - // 开发模式:返回 mock 支付页面 URL - let base = std::env::var("ZCLAW_SAAS_URL") - .unwrap_or_else(|_| "http://localhost:8080".into()); - return Ok(format!( - "{}/api/v1/billing/mock-pay?trade_no={}&amount={}&subject={}", - base, trade_no, amount_cents, - urlencoding::encode(subject), - )); + return Ok(mock_pay_url(trade_no, amount_cents, subject)); } match method { - PaymentMethod::Alipay => { - // TODO: 真实支付宝集成 - // 需要 ALIPAY_APP_ID, ALIPAY_PRIVATE_KEY 等环境变量 - Err(SaasError::InvalidInput( - "支付宝支付集成尚未配置,请联系管理员".into(), - )) + PaymentMethod::Alipay => generate_alipay_url(trade_no, amount_cents, subject, config), + PaymentMethod::Wechat => generate_wechat_url(trade_no, amount_cents, subject, config).await, + } +} + +fn mock_pay_url(trade_no: &str, amount_cents: i32, subject: &str) -> String { + let base = std::env::var("ZCLAW_SAAS_URL") + .unwrap_or_else(|_| "http://localhost:8080".into()); + format!( + "{}/api/v1/billing/mock-pay?trade_no={}&amount={}&subject={}", + base, + urlencoding::encode(trade_no), + amount_cents, + urlencoding::encode(subject), + ) +} + +// ──────────────────────────────────────────────────────────────── +// 支付宝 — alipay.trade.page.pay(RSA2 签名 + 证书模式) +// ──────────────────────────────────────────────────────────────── + +fn generate_alipay_url( + trade_no: &str, + amount_cents: i32, + subject: &str, + config: &PaymentConfig, +) -> SaasResult { + let app_id = config.alipay_app_id.as_deref() + .ok_or_else(|| SaasError::InvalidInput("支付宝 app_id 未配置".into()))?; + let private_key_pem = config.alipay_private_key.as_deref() + .ok_or_else(|| SaasError::InvalidInput("支付宝商户私钥未配置".into()))?; + let notify_url = config.alipay_notify_url.as_deref() + .ok_or_else(|| SaasError::InvalidInput("支付宝回调 URL 未配置".into()))?; + + // 金额:分 → 元(整数运算避免浮点精度问题) + let yuan_part = amount_cents / 100; + let cent_part = amount_cents % 100; + let amount_yuan = format!("{}.{:02}", yuan_part, cent_part); + + // 构建请求参数(字典序) + let mut params: Vec<(&str, String)> = vec![ + ("app_id", app_id.to_string()), + ("method", "alipay.trade.page.pay".to_string()), + ("charset", "utf-8".to_string()), + ("sign_type", "RSA2".to_string()), + ("timestamp", chrono::Local::now().format("%Y-%m-%d %H:%M:%S").to_string()), + ("version", "1.0".to_string()), + ("notify_url", notify_url.to_string()), + ("biz_content", serde_json::json!({ + "out_trade_no": trade_no, + "total_amount": amount_yuan, + "subject": subject, + "product_code": "FAST_INSTANT_TRADE_PAY", + }).to_string()), + ]; + + // 按 key 字典序排列并拼接 + params.sort_by(|a, b| a.0.cmp(b.0)); + let sign_str: String = params.iter() + .map(|(k, v)| format!("{}={}", k, v)) + .collect::>() + .join("&"); + + // RSA2 签名 + let sign = rsa_sign_sha256_base64(private_key_pem, sign_str.as_bytes())?; + + // 构建 gateway URL + params.push(("sign", sign)); + let query: String = params.iter() + .map(|(k, v)| format!("{}={}", k, urlencoding::encode(v))) + .collect::>() + .join("&"); + + Ok(format!("https://openapi.alipay.com/gateway.do?{}", query)) +} + +// ──────────────────────────────────────────────────────────────── +// 微信支付 — V3 Native Pay(QR 码模式) +// ──────────────────────────────────────────────────────────────── + +async fn generate_wechat_url( + trade_no: &str, + amount_cents: i32, + subject: &str, + config: &PaymentConfig, +) -> SaasResult { + let mch_id = config.wechat_mch_id.as_deref() + .ok_or_else(|| SaasError::InvalidInput("微信支付商户号未配置".into()))?; + let serial_no = config.wechat_serial_no.as_deref() + .ok_or_else(|| SaasError::InvalidInput("微信支付证书序列号未配置".into()))?; + let private_key_pem = config.wechat_private_key_path.as_deref() + .ok_or_else(|| SaasError::InvalidInput("微信支付私钥路径未配置".into()))?; + let notify_url = config.wechat_notify_url.as_deref() + .ok_or_else(|| SaasError::InvalidInput("微信支付回调 URL 未配置".into()))?; + let app_id = config.wechat_app_id.as_deref() + .ok_or_else(|| SaasError::InvalidInput("微信支付 App ID 未配置".into()))?; + + // 读取私钥文件 + let private_key = std::fs::read_to_string(private_key_pem) + .map_err(|e| SaasError::InvalidInput(format!("微信支付私钥文件读取失败: {}", e)))?; + + let body = serde_json::json!({ + "appid": app_id, + "mchid": mch_id, + "description": subject, + "out_trade_no": trade_no, + "notify_url": notify_url, + "amount": { + "total": amount_cents, + "currency": "CNY", + }, + }); + let body_str = body.to_string(); + + // 构建签名字符串 + let timestamp = chrono::Utc::now().timestamp().to_string(); + let nonce_str = uuid::Uuid::new_v4().to_string().replace("-", ""); + let sign_message = format!( + "POST\n/v3/pay/transactions/native\n{}\n{}\n{}\n", + timestamp, nonce_str, body_str + ); + + let signature = rsa_sign_sha256_base64(&private_key, sign_message.as_bytes())?; + + // 构建 Authorization 头 + let auth_header = format!( + "WECHATPAY2-SHA256-RSA2048 mchid=\"{}\",nonce_str=\"{}\",timestamp=\"{}\",serial_no=\"{}\",signature=\"{}\"", + mch_id, nonce_str, timestamp, serial_no, signature + ); + + // 发送请求 + let client = reqwest::Client::new(); + let resp = client + .post("https://api.mch.weixin.qq.com/v3/pay/transactions/native") + .header("Content-Type", "application/json") + .header("Authorization", auth_header) + .header("Accept", "application/json") + .body(body_str) + .send() + .await + .map_err(|e| SaasError::Internal(format!("微信支付请求失败: {}", e)))?; + + if !resp.status().is_success() { + let status = resp.status(); + let text = resp.text().await.unwrap_or_default(); + tracing::error!("WeChat Pay API error: status={}, body={}", status, text); + return Err(SaasError::InvalidInput(format!( + "微信支付创建订单失败 (HTTP {})", status + ))); + } + + let resp_json: serde_json::Value = resp.json().await + .map_err(|e| SaasError::Internal(format!("微信支付响应解析失败: {}", e)))?; + + let code_url = resp_json.get("code_url") + .and_then(|v| v.as_str()) + .ok_or_else(|| SaasError::Internal("微信支付响应缺少 code_url".into()))? + .to_string(); + + Ok(code_url) +} + +// ──────────────────────────────────────────────────────────────── +// 回调验签 +// ──────────────────────────────────────────────────────────────── + +/// 验证支付宝回调签名 +pub fn verify_alipay_callback( + params: &[(String, String)], + alipay_public_key_pem: &str, +) -> SaasResult { + // 1. 提取 sign 和 sign_type,剩余参数字典序拼接 + let mut sign = None; + let mut filtered: Vec<(&str, &str)> = Vec::new(); + + for (k, v) in params { + match k.as_str() { + "sign" => sign = Some(v.clone()), + "sign_type" => {} // 跳过 + _ => { + if !v.is_empty() { + filtered.push((k.as_str(), v.as_str())); + } + } } - PaymentMethod::Wechat => { - // TODO: 真实微信支付集成 - // 需要 WECHAT_PAY_MCH_ID, WECHAT_PAY_API_KEY 等 - Err(SaasError::InvalidInput( - "微信支付集成尚未配置,请联系管理员".into(), - )) + } + + let sign = match sign { + Some(s) => s, + None => return Ok(false), + }; + + filtered.sort_by(|a, b| a.0.cmp(b.0)); + let sign_str: String = filtered.iter() + .map(|(k, v)| format!("{}={}", k, v)) + .collect::>() + .join("&"); + + // 2. 用支付宝公钥验签 + rsa_verify_sha256(alipay_public_key_pem, sign_str.as_bytes(), &sign) +} + +/// 解密微信支付回调 resource 字段(AES-256-GCM) +pub fn decrypt_wechat_resource( + ciphertext_b64: &str, + nonce: &str, + associated_data: &str, + api_v3_key: &str, +) -> SaasResult { + use aes_gcm::{Aes256Gcm, KeyInit, Nonce}; + use aes_gcm::aead::Aead; + use base64::Engine; + + let key_bytes = api_v3_key.as_bytes(); + if key_bytes.len() != 32 { + return Err(SaasError::Internal("微信 API v3 密钥必须为 32 字节".into())); + } + + let nonce_bytes = nonce.as_bytes(); + if nonce_bytes.len() != 12 { + return Err(SaasError::InvalidInput("微信回调 nonce 长度必须为 12 字节".into())); + } + + let ciphertext = base64::engine::general_purpose::STANDARD + .decode(ciphertext_b64) + .map_err(|e| SaasError::Internal(format!("base64 解码失败: {}", e)))?; + + let cipher = Aes256Gcm::new_from_slice(key_bytes) + .map_err(|e| SaasError::Internal(format!("AES 密钥初始化失败: {}", e)))?; + let nonce = Nonce::from_slice(nonce_bytes); + + let plaintext = cipher + .decrypt(nonce, aes_gcm::aead::Payload { + msg: &ciphertext, + aad: associated_data.as_bytes(), + }) + .map_err(|e| SaasError::Internal(format!("AES-GCM 解密失败: {}", e)))?; + + String::from_utf8(plaintext) + .map_err(|e| SaasError::Internal(format!("解密结果 UTF-8 转换失败: {}", e))) +} + +// ──────────────────────────────────────────────────────────────── +// RSA 工具函数 +// ──────────────────────────────────────────────────────────────── + +/// SHA256WithRSA 签名 + Base64 编码(PKCS#1 v1.5) +fn rsa_sign_sha256_base64( + private_key_pem: &str, + message: &[u8], +) -> SaasResult { + use rsa::pkcs8::DecodePrivateKey; + use rsa::signature::{Signer, SignatureEncoding}; + use sha2::Sha256; + use rsa::pkcs1v15::SigningKey; + use base64::Engine; + + let private_key = rsa::RsaPrivateKey::from_pkcs8_pem(private_key_pem) + .map_err(|e| SaasError::Internal(format!("RSA 私钥解析失败: {}", e)))?; + + let signing_key = SigningKey::::new(private_key); + let signature = signing_key.sign(message); + + Ok(base64::engine::general_purpose::STANDARD.encode(signature.to_bytes())) +} + +/// SHA256WithRSA 验签 +fn rsa_verify_sha256( + public_key_pem: &str, + message: &[u8], + signature_b64: &str, +) -> SaasResult { + use rsa::pkcs8::DecodePublicKey; + use rsa::signature::Verifier; + use sha2::Sha256; + use rsa::pkcs1v15::VerifyingKey; + use base64::Engine; + + let public_key = match rsa::RsaPublicKey::from_public_key_pem(public_key_pem) { + Ok(k) => k, + Err(e) => { + tracing::error!("RSA 公钥解析失败: {}", e); + return Ok(false); + } + }; + + let signature_bytes = match base64::engine::general_purpose::STANDARD.decode(signature_b64) { + Ok(b) => b, + Err(e) => { + tracing::error!("签名 base64 解码失败: {}", e); + return Ok(false); + } + }; + + let verifying_key = VerifyingKey::::new(public_key); + let signature = match rsa::pkcs1v15::Signature::try_from(signature_bytes.as_slice()) { + Ok(s) => s, + Err(_) => return Ok(false), + }; + + Ok(verifying_key.verify(message, &signature).is_ok()) +} + +// ──────────────────────────────────────────────────────────────── +// 辅助函数 +// ──────────────────────────────────────────────────────────────── + +/// 日志安全:只保留字母数字和 `-` `_`,防止日志注入 +fn sanitize_log(s: &str) -> String { + s.chars() + .filter(|c| c.is_alphanumeric() || *c == '-' || *c == '_') + .collect() +} + +/// 截断字符串到指定长度(按字符而非字节) +fn truncate_str(s: &str, max_len: usize) -> String { + let chars: Vec = s.chars().collect(); + if chars.len() <= max_len { + s.to_string() + } else { + chars.into_iter().take(max_len).collect() + } +} + +impl std::fmt::Display for PaymentMethod { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Alipay => write!(f, "alipay"), + Self::Wechat => write!(f, "wechat"), } } } diff --git a/crates/zclaw-saas/src/billing/service.rs b/crates/zclaw-saas/src/billing/service.rs index eccdd57..71561f6 100644 --- a/crates/zclaw-saas/src/billing/service.rs +++ b/crates/zclaw-saas/src/billing/service.rs @@ -17,8 +17,19 @@ pub async fn list_plans(pool: &PgPool) -> SaasResult> { Ok(plans) } -/// 获取单个计划 +/// 获取单个计划(公开 API 只返回 active 计划) pub async fn get_plan(pool: &PgPool, plan_id: &str) -> SaasResult> { + let plan = sqlx::query_as::<_, BillingPlan>( + "SELECT * FROM billing_plans WHERE id = $1 AND status = 'active'" + ) + .bind(plan_id) + .fetch_optional(pool) + .await?; + Ok(plan) +} + +/// 获取单个计划(内部使用,不过滤 status,用于已订阅用户查看旧计划) +pub async fn get_plan_any_status(pool: &PgPool, plan_id: &str) -> SaasResult> { let plan = sqlx::query_as::<_, BillingPlan>( "SELECT * FROM billing_plans WHERE id = $1" ) @@ -47,7 +58,7 @@ pub async fn get_active_subscription( /// 获取账户当前计划(有订阅返回订阅计划,否则返回 Free) pub async fn get_account_plan(pool: &PgPool, account_id: &str) -> SaasResult { if let Some(sub) = get_active_subscription(pool, account_id).await? { - if let Some(plan) = get_plan(pool, &sub.plan_id).await? { + if let Some(plan) = get_plan_any_status(pool, &sub.plan_id).await? { return Ok(plan); } } @@ -81,7 +92,7 @@ pub async fn get_account_plan(pool: &PgPool, account_id: &str) -> SaasResult SaasResult { let now = chrono::Utc::now(); let period_start = now @@ -91,7 +102,7 @@ pub async fn get_or_create_usage(pool: &PgPool, account_id: &str) -> SaasResult< .with_second(0).unwrap_or(now) .with_nanosecond(0).unwrap_or(now); - // 尝试获取现有记录 + // 先尝试获取已有记录 let existing = sqlx::query_as::<_, UsageQuota>( "SELECT * FROM billing_usage_quotas \ WHERE account_id = $1 AND period_start = $2" @@ -122,13 +133,15 @@ pub async fn get_or_create_usage(pool: &PgPool, account_id: &str) -> SaasResult< .with_second(0).unwrap_or(now) .with_nanosecond(0).unwrap_or(now); + // 使用 INSERT ON CONFLICT 原子创建(防止并发重复插入) let id = uuid::Uuid::new_v4().to_string(); - let usage = sqlx::query_as::<_, UsageQuota>( + let inserted = sqlx::query_as::<_, UsageQuota>( "INSERT INTO billing_usage_quotas \ (id, account_id, period_start, period_end, \ max_input_tokens, max_output_tokens, max_relay_requests, \ max_hand_executions, max_pipeline_runs) \ VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) \ + ON CONFLICT (account_id, period_start) DO NOTHING \ RETURNING *" ) .bind(&id) @@ -140,6 +153,20 @@ pub async fn get_or_create_usage(pool: &PgPool, account_id: &str) -> SaasResult< .bind(limits.max_relay_requests_monthly) .bind(limits.max_hand_executions_monthly) .bind(limits.max_pipeline_runs_monthly) + .fetch_optional(pool) + .await?; + + if let Some(usage) = inserted { + return Ok(usage); + } + + // ON CONFLICT 说明另一个并发请求已经创建了,直接查询返回 + let usage = sqlx::query_as::<_, UsageQuota>( + "SELECT * FROM billing_usage_quotas \ + WHERE account_id = $1 AND period_start = $2" + ) + .bind(account_id) + .bind(period_start) .fetch_one(pool) .await?; @@ -173,7 +200,7 @@ pub async fn increment_usage( Ok(()) } -/// 增加单一维度用量计数(hand_executions / pipeline_runs / relay_requests) +/// 增加单一维度用量计数(单次 +1) /// /// 使用静态 SQL 分支(白名单),避免动态列名注入风险。 pub async fn increment_dimension( @@ -206,6 +233,40 @@ pub async fn increment_dimension( Ok(()) } +/// 增加单一维度用量计数(批量 +N,原子操作,替代循环调用) +/// +/// 使用静态 SQL 分支(白名单),避免动态列名注入风险。 +pub async fn increment_dimension_by( + pool: &PgPool, + account_id: &str, + dimension: &str, + count: i32, +) -> SaasResult<()> { + let usage = get_or_create_usage(pool, account_id).await?; + + match dimension { + "relay_requests" => { + sqlx::query( + "UPDATE billing_usage_quotas SET relay_requests = relay_requests + $1, updated_at = NOW() WHERE id = $2" + ).bind(count).bind(&usage.id).execute(pool).await?; + } + "hand_executions" => { + sqlx::query( + "UPDATE billing_usage_quotas SET hand_executions = hand_executions + $1, updated_at = NOW() WHERE id = $2" + ).bind(count).bind(&usage.id).execute(pool).await?; + } + "pipeline_runs" => { + sqlx::query( + "UPDATE billing_usage_quotas SET pipeline_runs = pipeline_runs + $1, updated_at = NOW() WHERE id = $2" + ).bind(count).bind(&usage.id).execute(pool).await?; + } + _ => return Err(crate::error::SaasError::InvalidInput( + format!("Unknown usage dimension: {}", dimension) + )), + } + Ok(()) +} + /// 检查用量配额 pub async fn check_quota( pool: &PgPool, diff --git a/crates/zclaw-saas/src/cache.rs b/crates/zclaw-saas/src/cache.rs index 6106b2b..6e60efa 100644 --- a/crates/zclaw-saas/src/cache.rs +++ b/crates/zclaw-saas/src/cache.rs @@ -167,6 +167,22 @@ impl AppCache { self.relay_queue_counts.retain(|k, _| db_keys.contains(k)); } + // ============ 快捷查找(Phase 2: 减少关键路径 DB 查询) ============ + + /// 按 model_id 查找已启用的模型。O(1) DashMap 查找。 + pub fn get_model(&self, model_id: &str) -> Option { + self.models.get(model_id) + .filter(|m| m.enabled) + .map(|r| r.value().clone()) + } + + /// 按 provider id 查找已启用的 Provider。O(1) DashMap 查找。 + pub fn get_provider(&self, provider_id: &str) -> Option { + self.providers.get(provider_id) + .filter(|p| p.enabled) + .map(|r| r.value().clone()) + } + // ============ 缓存失效 ============ /// 清除 model 缓存中的指定条目(Admin CRUD 后调用) diff --git a/crates/zclaw-saas/src/config.rs b/crates/zclaw-saas/src/config.rs index 258f212..c2f82bd 100644 --- a/crates/zclaw-saas/src/config.rs +++ b/crates/zclaw-saas/src/config.rs @@ -4,9 +4,15 @@ use serde::{Deserialize, Serialize}; use std::path::PathBuf; use secrecy::SecretString; +/// 当前期望的配置版本 +const CURRENT_CONFIG_VERSION: u32 = 1; + /// SaaS 服务器完整配置 #[derive(Debug, Clone, Serialize, Deserialize)] pub struct SaaSConfig { + /// Configuration schema version + #[serde(default = "default_config_version")] + pub config_version: u32, pub server: ServerConfig, pub database: DatabaseConfig, pub auth: AuthConfig, @@ -15,6 +21,8 @@ pub struct SaaSConfig { pub rate_limit: RateLimitConfig, #[serde(default)] pub scheduler: SchedulerConfig, + #[serde(default)] + pub payment: PaymentConfig, } /// Scheduler 定时任务配置 @@ -66,6 +74,30 @@ pub struct ServerConfig { pub struct DatabaseConfig { #[serde(default = "default_db_url")] pub url: String, + /// 连接池最大连接数 + #[serde(default = "default_max_connections")] + pub max_connections: u32, + /// 连接池最小连接数 + #[serde(default = "default_min_connections")] + pub min_connections: u32, + /// 获取连接超时 (秒) + #[serde(default = "default_acquire_timeout")] + pub acquire_timeout_secs: u64, + /// 空闲连接回收超时 (秒) + #[serde(default = "default_idle_timeout")] + pub idle_timeout_secs: u64, + /// 连接最大生命周期 (秒) + #[serde(default = "default_max_lifetime")] + pub max_lifetime_secs: u64, + /// Worker 并发上限 (Semaphore permits) + #[serde(default = "default_worker_concurrency")] + pub worker_concurrency: usize, + /// 限流事件批量 flush 间隔 (秒) + #[serde(default = "default_rate_limit_batch_interval")] + pub rate_limit_batch_interval_secs: u64, + /// 限流事件批量 flush 最大条目数 + #[serde(default = "default_rate_limit_batch_max")] + pub rate_limit_batch_max_size: usize, } /// 认证配置 @@ -97,12 +129,21 @@ pub struct RelayConfig { pub max_attempts: u32, } +fn default_config_version() -> u32 { 1 } fn default_host() -> String { "0.0.0.0".into() } fn default_port() -> u16 { 8080 } fn default_db_url() -> String { "postgres://localhost:5432/zclaw".into() } fn default_jwt_hours() -> i64 { 24 } fn default_totp_issuer() -> String { "ZCLAW SaaS".into() } fn default_refresh_hours() -> i64 { 168 } +fn default_max_connections() -> u32 { 100 } +fn default_min_connections() -> u32 { 5 } +fn default_acquire_timeout() -> u64 { 8 } +fn default_idle_timeout() -> u64 { 180 } +fn default_max_lifetime() -> u64 { 900 } +fn default_worker_concurrency() -> usize { 20 } +fn default_rate_limit_batch_interval() -> u64 { 5 } +fn default_rate_limit_batch_max() -> usize { 500 } fn default_max_queue() -> usize { 1000 } fn default_max_concurrent() -> usize { 5 } fn default_batch_window() -> u64 { 50 } @@ -132,15 +173,115 @@ impl Default for RateLimitConfig { } } +/// 支付配置 +/// +/// 支付宝和微信支付商户配置。所有字段通过环境变量传入(不写入 TOML 文件)。 +/// 字段缺失时自动降级为 mock 支付模式。 +/// +/// 注意:自定义 Debug 和 Serialize 实现会隐藏敏感字段。 +#[derive(Clone, Serialize, Deserialize)] +pub struct PaymentConfig { + /// 支付宝 App ID(来自支付宝开放平台) + #[serde(default)] + pub alipay_app_id: Option, + /// 支付宝商户私钥(RSA2)— 敏感,不序列化 + #[serde(default, skip_serializing)] + pub alipay_private_key: Option, + /// 支付宝公钥证书路径(用于验签) + #[serde(default)] + pub alipay_cert_path: Option, + /// 支付宝回调通知 URL + #[serde(default)] + pub alipay_notify_url: Option, + /// 支付宝公钥(用于回调验签,PEM 格式)— 敏感,不序列化 + #[serde(default, skip_serializing)] + pub alipay_public_key: Option, + + /// 微信支付商户号 + #[serde(default)] + pub wechat_mch_id: Option, + /// 微信支付商户证书序列号 + #[serde(default)] + pub wechat_serial_no: Option, + /// 微信支付商户私钥路径 + #[serde(default)] + pub wechat_private_key_path: Option, + /// 微信支付 API v3 密钥 — 敏感,不序列化 + #[serde(default, skip_serializing)] + pub wechat_api_v3_key: Option, + /// 微信支付回调通知 URL + #[serde(default)] + pub wechat_notify_url: Option, + /// 微信支付 App ID(公众号/小程序) + #[serde(default)] + pub wechat_app_id: Option, +} + +impl std::fmt::Debug for PaymentConfig { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("PaymentConfig") + .field("alipay_app_id", &self.alipay_app_id) + .field("alipay_private_key", &self.alipay_private_key.as_ref().map(|_| "***REDACTED***")) + .field("alipay_cert_path", &self.alipay_cert_path) + .field("alipay_notify_url", &self.alipay_notify_url) + .field("alipay_public_key", &self.alipay_public_key.as_ref().map(|_| "***REDACTED***")) + .field("wechat_mch_id", &self.wechat_mch_id) + .field("wechat_serial_no", &self.wechat_serial_no) + .field("wechat_private_key_path", &self.wechat_private_key_path) + .field("wechat_api_v3_key", &self.wechat_api_v3_key.as_ref().map(|_| "***REDACTED***")) + .field("wechat_notify_url", &self.wechat_notify_url) + .field("wechat_app_id", &self.wechat_app_id) + .finish() + } +} + +impl Default for PaymentConfig { + fn default() -> Self { + // 优先从环境变量读取,未配置则降级 mock + Self { + alipay_app_id: std::env::var("ALIPAY_APP_ID").ok(), + alipay_private_key: std::env::var("ALIPAY_PRIVATE_KEY").ok(), + alipay_cert_path: std::env::var("ALIPAY_CERT_PATH").ok(), + alipay_notify_url: std::env::var("ALIPAY_NOTIFY_URL").ok(), + alipay_public_key: std::env::var("ALIPAY_PUBLIC_KEY").ok(), + wechat_mch_id: std::env::var("WECHAT_PAY_MCH_ID").ok(), + wechat_serial_no: std::env::var("WECHAT_PAY_SERIAL_NO").ok(), + wechat_private_key_path: std::env::var("WECHAT_PAY_PRIVATE_KEY_PATH").ok(), + wechat_api_v3_key: std::env::var("WECHAT_PAY_API_V3_KEY").ok(), + wechat_notify_url: std::env::var("WECHAT_PAY_NOTIFY_URL").ok(), + wechat_app_id: std::env::var("WECHAT_PAY_APP_ID").ok(), + } + } +} + +impl PaymentConfig { + /// 支付宝是否已完整配置 + pub fn alipay_configured(&self) -> bool { + self.alipay_app_id.is_some() + && self.alipay_private_key.is_some() + && self.alipay_notify_url.is_some() + } + + /// 微信支付是否已完整配置 + pub fn wechat_configured(&self) -> bool { + self.wechat_mch_id.is_some() + && self.wechat_serial_no.is_some() + && self.wechat_private_key_path.is_some() + && self.wechat_notify_url.is_some() + } +} + impl Default for SaaSConfig { fn default() -> Self { Self { + config_version: 1, server: ServerConfig::default(), database: DatabaseConfig::default(), auth: AuthConfig::default(), relay: RelayConfig::default(), rate_limit: RateLimitConfig::default(), scheduler: SchedulerConfig::default(), + payment: PaymentConfig::default(), } } } @@ -158,7 +299,17 @@ impl Default for ServerConfig { impl Default for DatabaseConfig { fn default() -> Self { - Self { url: default_db_url() } + Self { + url: default_db_url(), + max_connections: default_max_connections(), + min_connections: default_min_connections(), + acquire_timeout_secs: default_acquire_timeout(), + idle_timeout_secs: default_idle_timeout(), + max_lifetime_secs: default_max_lifetime(), + worker_concurrency: default_worker_concurrency(), + rate_limit_batch_interval_secs: default_rate_limit_batch_interval(), + rate_limit_batch_max_size: default_rate_limit_batch_max(), + } } } @@ -220,6 +371,26 @@ impl SaaSConfig { SaaSConfig::default() }; + // 配置版本兼容性检查 + if config.config_version < CURRENT_CONFIG_VERSION { + tracing::warn!( + "[Config] config_version ({}) is below current version ({}). \ + Some features may not work correctly. \ + Please update your saas-config.toml. \ + See docs for migration guide.", + config.config_version, + CURRENT_CONFIG_VERSION + ); + } else if config.config_version > CURRENT_CONFIG_VERSION { + tracing::error!( + "[Config] config_version ({}) is ahead of supported version ({}). \ + This server version may not support all configured features. \ + Consider upgrading the server.", + config.config_version, + CURRENT_CONFIG_VERSION + ); + } + // 环境变量覆盖数据库 URL (避免在配置文件中存储密码) if let Ok(db_url) = std::env::var("ZCLAW_DATABASE_URL") { config.database.url = db_url; diff --git a/crates/zclaw-saas/src/main.rs b/crates/zclaw-saas/src/main.rs index e635745..36e7dab 100644 --- a/crates/zclaw-saas/src/main.rs +++ b/crates/zclaw-saas/src/main.rs @@ -323,6 +323,7 @@ async fn build_router(state: AppState) -> axum::Router { let public_routes = zclaw_saas::auth::routes() .route("/api/health", axum::routing::get(health_handler)) + .merge(zclaw_saas::billing::callback_routes()) .layer(middleware::from_fn_with_state( state.clone(), zclaw_saas::middleware::public_rate_limit_middleware, diff --git a/crates/zclaw-saas/src/model_config/handlers.rs b/crates/zclaw-saas/src/model_config/handlers.rs index 0a2c18e..a38eebb 100644 --- a/crates/zclaw-saas/src/model_config/handlers.rs +++ b/crates/zclaw-saas/src/model_config/handlers.rs @@ -82,6 +82,10 @@ pub async fn create_provider( let provider = service::create_provider(&state.db, &req, &enc_key).await?; log_operation(&state.db, &ctx.account_id, "provider.create", "provider", &provider.id, Some(serde_json::json!({"name": &req.name})), ctx.client_ip.as_deref()).await?; + // Admin mutation 后立即刷新缓存,消除 60s 陈旧窗口 + if let Err(e) = state.cache.load_from_db(&state.db).await { + tracing::warn!("Cache reload failed after provider.create: {}", e); + } Ok((StatusCode::CREATED, Json(provider))) } @@ -102,6 +106,9 @@ pub async fn update_provider( drop(config); let provider = service::update_provider(&state.db, &id, &req, &enc_key).await?; log_operation(&state.db, &ctx.account_id, "provider.update", "provider", &id, None, ctx.client_ip.as_deref()).await?; + if let Err(e) = state.cache.load_from_db(&state.db).await{ + tracing::warn!("Cache reload failed after provider.update: {}", e); + } Ok(Json(provider)) } @@ -114,6 +121,9 @@ pub async fn delete_provider( check_permission(&ctx, "provider:manage")?; service::delete_provider(&state.db, &id).await?; log_operation(&state.db, &ctx.account_id, "provider.delete", "provider", &id, None, ctx.client_ip.as_deref()).await?; + if let Err(e) = state.cache.load_from_db(&state.db).await{ + tracing::warn!("Cache reload failed after provider.delete: {}", e); + } Ok(Json(serde_json::json!({"ok": true}))) } @@ -150,6 +160,9 @@ pub async fn create_model( let model = service::create_model(&state.db, &req).await?; log_operation(&state.db, &ctx.account_id, "model.create", "model", &model.id, Some(serde_json::json!({"model_id": &req.model_id, "provider_id": &req.provider_id})), ctx.client_ip.as_deref()).await?; + if let Err(e) = state.cache.load_from_db(&state.db).await{ + tracing::warn!("Cache reload failed after model.create: {}", e); + } Ok((StatusCode::CREATED, Json(model))) } @@ -163,6 +176,9 @@ pub async fn update_model( check_permission(&ctx, "model:manage")?; let model = service::update_model(&state.db, &id, &req).await?; log_operation(&state.db, &ctx.account_id, "model.update", "model", &id, None, ctx.client_ip.as_deref()).await?; + if let Err(e) = state.cache.load_from_db(&state.db).await{ + tracing::warn!("Cache reload failed after model.update: {}", e); + } Ok(Json(model)) } @@ -175,6 +191,9 @@ pub async fn delete_model( check_permission(&ctx, "model:manage")?; service::delete_model(&state.db, &id).await?; log_operation(&state.db, &ctx.account_id, "model.delete", "model", &id, None, ctx.client_ip.as_deref()).await?; + if let Err(e) = state.cache.load_from_db(&state.db).await{ + tracing::warn!("Cache reload failed after model.delete: {}", e); + } Ok(Json(serde_json::json!({"ok": true}))) } diff --git a/crates/zclaw-saas/src/models/prompt.rs b/crates/zclaw-saas/src/models/prompt.rs index eeee2f1..0a7ff9a 100644 --- a/crates/zclaw-saas/src/models/prompt.rs +++ b/crates/zclaw-saas/src/models/prompt.rs @@ -29,3 +29,12 @@ pub struct PromptVersionRow { pub min_app_version: Option, pub created_at: String, } + +/// prompt_sync_status 表行 +#[derive(Debug, FromRow)] +pub struct PromptSyncStatusRow { + pub device_id: String, + pub template_id: String, + pub synced_version: i32, + pub synced_at: String, +} diff --git a/crates/zclaw-saas/src/models/telemetry.rs b/crates/zclaw-saas/src/models/telemetry.rs index 7c02beb..bce2fc8 100644 --- a/crates/zclaw-saas/src/models/telemetry.rs +++ b/crates/zclaw-saas/src/models/telemetry.rs @@ -2,6 +2,24 @@ use sqlx::FromRow; +/// telemetry_reports 表行 +#[derive(Debug, FromRow)] +pub struct TelemetryReportRow { + pub id: String, + pub account_id: String, + pub device_id: String, + pub app_version: Option, + pub model_id: String, + pub input_tokens: i64, + pub output_tokens: i64, + pub latency_ms: Option, + pub success: bool, + pub error_type: Option, + pub connection_mode: Option, + pub reported_at: String, + pub created_at: String, +} + /// telemetry 按 model 分组统计 #[derive(Debug, FromRow)] pub struct TelemetryModelStatsRow { diff --git a/crates/zclaw-saas/src/prompt/service.rs b/crates/zclaw-saas/src/prompt/service.rs index 24c4c3b..e8495d7 100644 --- a/crates/zclaw-saas/src/prompt/service.rs +++ b/crates/zclaw-saas/src/prompt/service.rs @@ -4,7 +4,7 @@ use sqlx::PgPool; use crate::error::{SaasError, SaasResult}; use crate::common::PaginatedResponse; use crate::common::normalize_pagination; -use crate::models::{PromptTemplateRow, PromptVersionRow}; +use crate::models::{PromptTemplateRow, PromptVersionRow, PromptSyncStatusRow}; use super::types::*; /// 创建提示词模板 + 初始版本 @@ -310,3 +310,21 @@ pub async fn check_updates( server_time: chrono::Utc::now().to_rfc3339(), }) } + +/// 查询设备的提示词同步状态 +pub async fn get_sync_status( + db: &PgPool, + device_id: &str, +) -> SaasResult> { + let rows = sqlx::query_as::<_, PromptSyncStatusRow>( + "SELECT device_id, template_id, synced_version, synced_at \ + FROM prompt_sync_status \ + WHERE device_id = $1 \ + ORDER BY synced_at DESC \ + LIMIT 50" + ) + .bind(device_id) + .fetch_all(db) + .await?; + Ok(rows) +} diff --git a/crates/zclaw-saas/src/relay/key_pool.rs b/crates/zclaw-saas/src/relay/key_pool.rs index d4eb6c5..ebd718a 100644 --- a/crates/zclaw-saas/src/relay/key_pool.rs +++ b/crates/zclaw-saas/src/relay/key_pool.rs @@ -281,6 +281,39 @@ pub async fn delete_provider_key( Ok(()) } +/// Key 使用窗口统计 +#[derive(Debug, Clone)] +pub struct KeyUsageStats { + pub key_id: String, + pub window_minute: String, + pub request_count: i32, + pub token_count: i64, +} + +/// 查询指定 Key 的最近使用窗口统计 +pub async fn get_key_usage_stats( + db: &PgPool, + key_id: &str, + limit: i64, +) -> SaasResult> { + let limit = limit.min(60).max(1); + let rows: Vec<(String, String, i32, i64)> = sqlx::query_as( + "SELECT key_id, window_minute, request_count, token_count \ + FROM key_usage_window \ + WHERE key_id = $1 \ + ORDER BY window_minute DESC \ + LIMIT $2" + ) + .bind(key_id) + .bind(limit) + .fetch_all(db) + .await?; + + Ok(rows.into_iter().map(|(key_id, window_minute, request_count, token_count)| { + KeyUsageStats { key_id, window_minute, request_count, token_count } + }).collect()) +} + /// 解析冷却剩余时间(秒) fn parse_cooldown_remaining(cooldown_until: &str, now: &str) -> i64 { let cooldown = chrono::DateTime::parse_from_rfc3339(cooldown_until); diff --git a/crates/zclaw-saas/src/relay/service.rs b/crates/zclaw-saas/src/relay/service.rs index 399e705..c3fe5a1 100644 --- a/crates/zclaw-saas/src/relay/service.rs +++ b/crates/zclaw-saas/src/relay/service.rs @@ -2,11 +2,23 @@ use sqlx::PgPool; use std::sync::Arc; +use std::time::Duration; use tokio::sync::Mutex; use crate::error::{SaasError, SaasResult}; use crate::models::RelayTaskRow; use super::types::*; +// ============ StreamBridge 背压常量 ============ + +/// 上游无数据时,发送 SSE 心跳注释行的间隔 +const STREAMBRIDGE_HEARTBEAT_INTERVAL: Duration = Duration::from_secs(15); + +/// 上游无数据时,丢弃连接的超时阈值 +const STREAMBRIDGE_TIMEOUT: Duration = Duration::from_secs(30); + +/// 流结束后延迟清理的时间窗口 +const STREAMBRIDGE_CLEANUP_DELAY: Duration = Duration::from_secs(60); + /// 判断 HTTP 状态码是否为可重试的瞬态错误 (5xx + 429) fn is_retryable_status(status: u16) -> bool { status == 429 || (500..600).contains(&status) @@ -33,15 +45,24 @@ pub async fn create_relay_task( let request_hash = hash_request(request_body); let max_attempts = max_attempts.max(1).min(5); - sqlx::query( + // INSERT ... RETURNING 合并两次 DB 往返为一次 + let row: RelayTaskRow = sqlx::query_as( "INSERT INTO relay_tasks (id, account_id, provider_id, model_id, request_hash, request_body, status, priority, attempt_count, max_attempts, queued_at, created_at) - VALUES ($1, $2, $3, $4, $5, $6, 'queued', $7, 0, $8, $9, $9)" + VALUES ($1, $2, $3, $4, $5, $6, 'queued', $7, 0, $8, $9, $9) + RETURNING id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at" ) .bind(&id).bind(account_id).bind(provider_id).bind(model_id) .bind(&request_hash).bind(request_body).bind(priority).bind(max_attempts as i64).bind(&now) - .execute(db).await?; + .fetch_one(db) + .await?; - get_relay_task(db, &id).await + Ok(RelayTaskInfo { + id: row.id, account_id: row.account_id, provider_id: row.provider_id, model_id: row.model_id, + status: row.status, priority: row.priority, attempt_count: row.attempt_count, + max_attempts: row.max_attempts, input_tokens: row.input_tokens, output_tokens: row.output_tokens, + error_message: row.error_message, queued_at: row.queued_at, started_at: row.started_at, + completed_at: row.completed_at, created_at: row.created_at, + }) } pub async fn get_relay_task(db: &PgPool, task_id: &str) -> SaasResult { @@ -295,9 +316,9 @@ pub async fn execute_relay( } }); - // Convert mpsc::Receiver into a Body stream - let body_stream = tokio_stream::wrappers::ReceiverStream::new(rx); - let body = axum::body::Body::from_stream(body_stream); + // Build StreamBridge: wraps the bounded receiver with heartbeat, + // timeout, and delayed cleanup (DeerFlow-inspired backpressure). + let body = build_stream_bridge(rx, task_id.to_string()); // SSE 流结束后异步记录 usage + Key 使用量 // 使用全局 Arc 限制并发 spawned tasks,防止高并发时耗尽连接池 @@ -335,6 +356,14 @@ pub async fn execute_relay( if tokio::time::timeout(std::time::Duration::from_secs(5), db_op).await.is_err() { tracing::warn!("SSE usage recording timed out for task {}", task_id_clone); } + + // StreamBridge 延迟清理:流结束 60s 后释放残留资源 + // (主要是 Arc 等,通过 drop(_permit) 归还信号量) + tokio::time::sleep(STREAMBRIDGE_CLEANUP_DELAY).await; + tracing::debug!( + "[StreamBridge] Cleanup delay elapsed for task {}", + task_id_clone + ); }); return Ok(RelayResponse::Sse(body)); @@ -346,7 +375,9 @@ pub async fn execute_relay( // 记录 Key 使用量 let _ = super::key_pool::record_key_usage( db, &key_id, Some(input_tokens + output_tokens), - ).await; + ).await.map_err(|e| { + tracing::warn!("[Relay] Failed to record key usage for billing: {}", e); + }); return Ok(RelayResponse::Json(body)); } } @@ -423,6 +454,98 @@ pub enum RelayResponse { Sse(axum::body::Body), } +// ============ StreamBridge ============ + +/// 构建 StreamBridge:将 mpsc::Receiver 包装为带心跳、超时的 axum Body。 +/// +/// 借鉴 DeerFlow StreamBridge 背压机制: +/// - 15s 心跳:上游长时间无输出时,发送 SSE 注释行 `: heartbeat\n\n` 保持连接活跃 +/// - 30s 超时:上游连续 30s 无真实数据时,发送超时事件并关闭流 +/// - 60s 延迟清理:由调用方的 spawned task 在流结束后延迟释放资源 +fn build_stream_bridge( + mut rx: tokio::sync::mpsc::Receiver>, + task_id: String, +) -> axum::body::Body { + // SSE heartbeat comment bytes: `: heartbeat\n\n` + // SSE spec: lines starting with `:` are comments and ignored by clients + const HEARTBEAT_BYTES: &[u8] = b": heartbeat\n\n"; + // SSE timeout error event + const TIMEOUT_EVENT: &[u8] = b"data: {\"error\":\"stream_timeout\",\"message\":\"upstream timed out\"}\n\n"; + + let stream = async_stream::stream! { + // Track how many consecutive heartbeat-only cycles have elapsed. + // Real data resets this counter; after 2 heartbeats (30s) without + // real data, we terminate the stream. + let mut idle_heartbeats: u32 = 0; + + loop { + // tokio::select! races the next data chunk against a heartbeat timer. + // The timer resets on every iteration, ensuring heartbeats only fire + // during genuine idle periods. + tokio::select! { + biased; // prioritize data over heartbeat + + chunk = rx.recv() => { + match chunk { + Some(Ok(data)) => { + // Real data received — reset idle counter + idle_heartbeats = 0; + yield Ok::(data); + } + Some(Err(e)) => { + tracing::warn!( + "[StreamBridge] Upstream error for task {}: {}", + task_id, e + ); + yield Err(e); + break; + } + None => { + // Channel closed = upstream finished normally + tracing::debug!( + "[StreamBridge] Upstream completed for task {}", + task_id + ); + break; + } + } + } + + // Heartbeat: send SSE comment if no data for 15s + _ = tokio::time::sleep(STREAMBRIDGE_HEARTBEAT_INTERVAL) => { + idle_heartbeats += 1; + tracing::trace!( + "[StreamBridge] Heartbeat #{} for task {} (idle {}s)", + idle_heartbeats, + task_id, + idle_heartbeats as u64 * STREAMBRIDGE_HEARTBEAT_INTERVAL.as_secs(), + ); + + // After 2 consecutive heartbeats without real data (30s), + // terminate the stream to prevent connection leaks. + if idle_heartbeats >= 2 { + tracing::warn!( + "[StreamBridge] Timeout ({:?}) no real data, closing stream for task {}", + STREAMBRIDGE_TIMEOUT, + task_id, + ); + yield Ok(bytes::Bytes::from_static(TIMEOUT_EVENT)); + break; + } + + yield Ok(bytes::Bytes::from_static(HEARTBEAT_BYTES)); + } + } + } + }; + + // Pin the stream to a Box to satisfy Body::from_stream + let boxed: std::pin::Pin> + Send>> = + Box::pin(stream); + + axum::body::Body::from_stream(boxed) +} + // ============ Helpers ============ fn hash_request(body: &str) -> String { diff --git a/crates/zclaw-saas/src/scheduled_task/service.rs b/crates/zclaw-saas/src/scheduled_task/service.rs index 3acd247..501c6aa 100644 --- a/crates/zclaw-saas/src/scheduled_task/service.rs +++ b/crates/zclaw-saas/src/scheduled_task/service.rs @@ -20,7 +20,9 @@ struct ScheduledTaskRow { last_run_at: Option, next_run_at: Option, run_count: i32, + last_result: Option, last_error: Option, + last_duration_ms: Option, input_payload: Option, created_at: String, } @@ -41,7 +43,9 @@ impl ScheduledTaskRow { last_run: self.last_run_at.clone(), next_run: self.next_run_at.clone(), run_count: self.run_count, + last_result: self.last_result.clone(), last_error: self.last_error.clone(), + last_duration_ms: self.last_duration_ms, created_at: self.created_at.clone(), } } @@ -86,7 +90,9 @@ pub async fn create_task( last_run: None, next_run: None, run_count: 0, + last_result: None, last_error: None, + last_duration_ms: None, created_at: now, }) } @@ -99,7 +105,7 @@ pub async fn list_tasks( let rows: Vec = sqlx::query_as( "SELECT id, account_id, name, description, schedule, schedule_type, target_type, target_id, enabled, last_run_at, next_run_at, - run_count, last_error, input_payload, created_at + run_count, last_result, last_error, last_duration_ms, input_payload, created_at FROM scheduled_tasks WHERE account_id = $1 ORDER BY created_at DESC" ) .bind(account_id) @@ -118,7 +124,7 @@ pub async fn get_task( let row: Option = sqlx::query_as( "SELECT id, account_id, name, description, schedule, schedule_type, target_type, target_id, enabled, last_run_at, next_run_at, - run_count, last_error, input_payload, created_at + run_count, last_result, last_error, last_duration_ms, input_payload, created_at FROM scheduled_tasks WHERE id = $1 AND account_id = $2" ) .bind(task_id) diff --git a/crates/zclaw-saas/src/scheduled_task/types.rs b/crates/zclaw-saas/src/scheduled_task/types.rs index f680c87..9401288 100644 --- a/crates/zclaw-saas/src/scheduled_task/types.rs +++ b/crates/zclaw-saas/src/scheduled_task/types.rs @@ -58,6 +58,8 @@ pub struct ScheduledTaskResponse { pub last_run: Option, pub next_run: Option, pub run_count: i32, + pub last_result: Option, pub last_error: Option, + pub last_duration_ms: Option, pub created_at: String, } diff --git a/crates/zclaw-saas/src/scheduler.rs b/crates/zclaw-saas/src/scheduler.rs index b795189..f4af561 100644 --- a/crates/zclaw-saas/src/scheduler.rs +++ b/crates/zclaw-saas/src/scheduler.rs @@ -3,11 +3,18 @@ //! 通过 TOML 配置定时任务,无需改代码调整调度时间。 //! 配置格式在 config.rs 的 SchedulerConfig / JobConfig 中定义。 -use std::time::Duration; +use std::time::{Duration, Instant}; use sqlx::PgPool; use crate::config::SchedulerConfig; use crate::workers::WorkerDispatcher; +/// 单次任务执行的产出 +struct TaskExecution { + result: Option, + error: Option, + duration_ms: i64, +} + /// 解析时间间隔字符串为 Duration pub fn parse_duration(s: &str) -> Result { let s = s.trim().to_lowercase(); @@ -143,23 +150,42 @@ pub fn start_user_task_scheduler(db: PgPool) { }); } -/// 执行单个调度任务 +/// 执行单个调度任务,返回执行产出(结果/错误/耗时) async fn execute_scheduled_task( db: &PgPool, task_id: &str, target_type: &str, -) -> Result<(), Box> { - let task_info: Option<(String, Option)> = sqlx::query_as( +) -> TaskExecution { + let start = Instant::now(); + + let task_info: Option<(String, Option)> = match sqlx::query_as( "SELECT name, config_json FROM scheduled_tasks WHERE id = $1" ) .bind(task_id) .fetch_optional(db) .await - .map_err(|e| format!("Failed to fetch task {}: {}", task_id, e))?; + { + Ok(info) => info, + Err(e) => { + let elapsed = start.elapsed().as_millis() as i64; + return TaskExecution { + result: None, + error: Some(format!("Failed to fetch task {}: {}", task_id, e)), + duration_ms: elapsed, + }; + } + }; let (task_name, _config_json) = match task_info { Some(info) => info, - None => return Err(format!("Task {} not found", task_id).into()), + None => { + let elapsed = start.elapsed().as_millis() as i64; + return TaskExecution { + result: None, + error: Some(format!("Task {} not found", task_id)), + duration_ms: elapsed, + }; + } }; tracing::info!( @@ -167,22 +193,39 @@ async fn execute_scheduled_task( task_name, target_type ); - match target_type { + let exec_result = match target_type { t if t == "agent" => { tracing::info!("[UserScheduler] Agent task '{}' queued for execution", task_name); + Ok("agent_dispatched".to_string()) } t if t == "hand" => { tracing::info!("[UserScheduler] Hand task '{}' queued for execution", task_name); + Ok("hand_dispatched".to_string()) } t if t == "workflow" => { tracing::info!("[UserScheduler] Workflow task '{}' queued for execution", task_name); + Ok("workflow_dispatched".to_string()) } other => { tracing::warn!("[UserScheduler] Unknown target_type '{}' for task '{}'", other, task_name); + Err(format!("Unknown target_type: {}", other)) } - } + }; - Ok(()) + let elapsed = start.elapsed().as_millis() as i64; + + match exec_result { + Ok(msg) => TaskExecution { + result: Some(msg), + error: None, + duration_ms: elapsed, + }, + Err(err) => TaskExecution { + result: None, + error: Some(err), + duration_ms: elapsed, + }, + } } async fn tick_user_tasks(db: &PgPool) -> Result<(), sqlx::Error> { @@ -206,17 +249,19 @@ async fn tick_user_tasks(db: &PgPool) -> Result<(), sqlx::Error> { task_id, target_type, schedule_type ); - // 执行任务 - match execute_scheduled_task(db, &task_id, &target_type).await { - Ok(()) => { - tracing::info!("[UserScheduler] task {} executed successfully", task_id); - } - Err(e) => { - tracing::error!("[UserScheduler] task {} execution failed: {}", task_id, e); - } + // 执行任务并收集产出 + let exec = execute_scheduled_task(db, &task_id, &target_type).await; + + if let Some(ref err) = exec.error { + tracing::error!("[UserScheduler] task {} execution failed: {}", task_id, err); + } else { + tracing::info!( + "[UserScheduler] task {} executed successfully ({}ms)", + task_id, exec.duration_ms + ); } - // 更新任务状态 + // 更新任务状态(含执行产出) let result = sqlx::query( "UPDATE scheduled_tasks SET last_run_at = NOW(), @@ -228,10 +273,16 @@ async fn tick_user_tasks(db: &PgPool) -> Result<(), sqlx::Error> { WHEN schedule_type = 'interval' AND interval_seconds IS NOT NULL THEN NOW() + (interval_seconds || ' seconds')::INTERVAL ELSE NULL - END + END, + last_result = $2, + last_error = $3, + last_duration_ms = $4 WHERE id = $1" ) .bind(&task_id) + .bind(&exec.result) + .bind(&exec.error) + .bind(exec.duration_ms) .execute(db) .await; diff --git a/crates/zclaw-saas/src/state.rs b/crates/zclaw-saas/src/state.rs index 5d9280d..6a1d369 100644 --- a/crates/zclaw-saas/src/state.rs +++ b/crates/zclaw-saas/src/state.rs @@ -10,6 +10,44 @@ use crate::config::SaaSConfig; use crate::workers::WorkerDispatcher; use crate::cache::AppCache; +// ============ SpawnLimiter ============ + +/// 可复用的并发限制器,基于 Arc。 +/// 复用 SSE_SPAWN_SEMAPHORE 模式,为 Worker、中间件等场景提供统一门控。 +#[derive(Clone)] +pub struct SpawnLimiter { + semaphore: Arc, + name: &'static str, +} + +impl SpawnLimiter { + pub fn new(name: &'static str, max_permits: usize) -> Self { + Self { + semaphore: Arc::new(tokio::sync::Semaphore::new(max_permits)), + name, + } + } + + /// 尝试获取 permit,满时返回 None(适用于可丢弃的操作如 usage 记录) + pub fn try_acquire(&self) -> Option { + self.semaphore.clone().try_acquire_owned().ok() + } + + /// 异步等待 permit(适用于不可丢弃的操作如 Worker 任务) + pub async fn acquire(&self) -> tokio::sync::OwnedSemaphorePermit { + self.semaphore + .clone() + .acquire_owned() + .await + .expect("SpawnLimiter semaphore closed unexpectedly") + } + + pub fn name(&self) -> &'static str { self.name } + pub fn available(&self) -> usize { self.semaphore.available_permits() } +} + +// ============ AppState ============ + /// 全局应用状态,通过 Axum State 共享 #[derive(Clone)] pub struct AppState { @@ -33,10 +71,20 @@ pub struct AppState { pub shutdown_token: CancellationToken, /// 应用缓存: Model/Provider/队列计数器 pub cache: AppCache, + /// Worker spawn 并发限制器 + pub worker_limiter: SpawnLimiter, + /// 限流事件批量累加器: key → 待写入计数 + pub rate_limit_batch: Arc>, } impl AppState { - pub fn new(db: PgPool, config: SaaSConfig, worker_dispatcher: WorkerDispatcher, shutdown_token: CancellationToken) -> anyhow::Result { + pub fn new( + db: PgPool, + config: SaaSConfig, + worker_dispatcher: WorkerDispatcher, + shutdown_token: CancellationToken, + worker_limiter: SpawnLimiter, + ) -> anyhow::Result { let jwt_secret = config.jwt_secret()?; let rpm = config.rate_limit.requests_per_minute; Ok(Self { @@ -50,6 +98,8 @@ impl AppState { worker_dispatcher, shutdown_token, cache: AppCache::new(), + worker_limiter, + rate_limit_batch: Arc::new(dashmap::DashMap::new()), }) } @@ -96,4 +146,60 @@ impl AppState { tracing::warn!("Failed to dispatch log_operation: {}", e); } } + + /// 限流事件批量 flush 到 DB + /// + /// 使用 swap-to-zero 模式:先将计数器原子归零,DB 写入成功后删除条目。 + /// 如果 DB 写入失败,归零的计数会在下次 flush 时重新累加(因 middleware 持续写入)。 + pub async fn flush_rate_limit_batch(&self, max_batch: usize) { + // 阶段1: 收集非零 key,将计数器原子归零(而非删除) + // 这样如果 DB 写入失败,middleware 的新累加会在已有 key 上继续 + let mut batch: Vec<(String, i64)> = Vec::with_capacity(max_batch.min(64)); + + let keys: Vec = self.rate_limit_batch.iter() + .filter(|e| *e.value() > 0) + .take(max_batch) + .map(|e| e.key().clone()) + .collect(); + + for key in &keys { + // 原子交换为 0,取走当前值 + if let Some(mut entry) = self.rate_limit_batch.get_mut(key) { + if *entry > 0 { + batch.push((key.clone(), *entry)); + *entry = 0; // 归零而非删除 + } + } + } + + if batch.is_empty() { return; } + + let keys_buf: Vec = batch.iter().map(|(k, _)| k.clone()).collect(); + let counts: Vec = batch.iter().map(|(_, c)| *c).collect(); + + let result = sqlx::query( + "INSERT INTO rate_limit_events (key, window_start, count) + SELECT u.key, NOW(), u.cnt FROM UNNEST($1::text[], $2::bigint[]) AS u(key, cnt)" + ) + .bind(&keys_buf) + .bind(&counts) + .execute(&self.db) + .await; + + if let Err(e) = result { + // DB 写入失败:将归零的计数加回去,避免数据丢失 + tracing::warn!("[RateLimitBatch] flush failed ({} entries), restoring counts: {}", batch.len(), e); + for (key, count) in &batch { + if let Some(mut entry) = self.rate_limit_batch.get_mut(key) { + *entry += *count; + } + } + } else { + // DB 写入成功:删除已归零的条目 + for (key, _) in &batch { + self.rate_limit_batch.remove_if(key, |_, v| *v == 0); + } + tracing::debug!("[RateLimitBatch] flushed {} entries", batch.len()); + } + } } diff --git a/crates/zclaw-saas/src/telemetry/service.rs b/crates/zclaw-saas/src/telemetry/service.rs index c5def32..3f5c25b 100644 --- a/crates/zclaw-saas/src/telemetry/service.rs +++ b/crates/zclaw-saas/src/telemetry/service.rs @@ -2,7 +2,7 @@ use sqlx::PgPool; use crate::error::SaasResult; -use crate::models::{TelemetryModelStatsRow, TelemetryDailyStatsRow}; +use crate::models::{TelemetryModelStatsRow, TelemetryDailyStatsRow, TelemetryReportRow}; use super::types::*; const CHUNK_SIZE: usize = 100; @@ -270,3 +270,27 @@ pub async fn get_daily_stats( Ok(stats) } + +/// 查询账号最近的遥测报告 +pub async fn get_recent_reports( + db: &PgPool, + account_id: &str, + limit: i64, +) -> SaasResult> { + let limit = limit.min(100).max(1); + let rows = sqlx::query_as::<_, TelemetryReportRow>( + "SELECT id, account_id, device_id, app_version, model_id, \ + input_tokens, output_tokens, latency_ms, success, \ + error_type, connection_mode, \ + reported_at::text, created_at::text \ + FROM telemetry_reports \ + WHERE account_id = $1 \ + ORDER BY reported_at DESC \ + LIMIT $2" + ) + .bind(account_id) + .bind(limit) + .fetch_all(db) + .await?; + Ok(rows) +} diff --git a/crates/zclaw-saas/src/workers/generate_embedding.rs b/crates/zclaw-saas/src/workers/generate_embedding.rs index db763fd..e22b14d 100644 --- a/crates/zclaw-saas/src/workers/generate_embedding.rs +++ b/crates/zclaw-saas/src/workers/generate_embedding.rs @@ -44,13 +44,7 @@ impl Worker for GenerateEmbeddingWorker { } }; - // 2. 删除旧分块(full refresh on each update) - sqlx::query("DELETE FROM knowledge_chunks WHERE item_id = $1") - .bind(&args.item_id) - .execute(db) - .await?; - - // 3. 分块 + // 2. 分块 let chunks = crate::knowledge::service::chunk_content(&content, 512, 64); if chunks.is_empty() { @@ -58,13 +52,32 @@ impl Worker for GenerateEmbeddingWorker { return Ok(()); } - // 4. 写入分块(带关键词继承) + // 3. 在事务中删除旧分块 + 插入新分块(防止并发竞争条件) + let mut tx = db.begin().await?; + + // 锁定条目行防止并发 worker 同时处理同一条目 + let locked: Option<(String,)> = sqlx::query_as( + "SELECT id FROM knowledge_items WHERE id = $1 FOR UPDATE" + ) + .bind(&args.item_id) + .fetch_optional(&mut *tx) + .await?; + + if locked.is_none() { + tx.rollback().await?; + tracing::warn!("GenerateEmbedding: item {} was deleted during processing", args.item_id); + return Ok(()); + } + + sqlx::query("DELETE FROM knowledge_chunks WHERE item_id = $1") + .bind(&args.item_id) + .execute(&mut *tx) + .await?; + for (idx, chunk) in chunks.iter().enumerate() { let chunk_id = uuid::Uuid::new_v4().to_string(); - // 为每个 chunk 提取额外关键词(简单策略:标题 + 继承关键词) let mut chunk_keywords = keywords.clone(); - // 从 chunk 内容提取高频词作为补充关键词 extract_chunk_keywords(chunk, &mut chunk_keywords); sqlx::query( @@ -76,10 +89,12 @@ impl Worker for GenerateEmbeddingWorker { .bind(idx as i32) .bind(chunk) .bind(&chunk_keywords) - .execute(db) + .execute(&mut *tx) .await?; } + tx.commit().await?; + tracing::info!( "GenerateEmbedding: item '{}' → {} chunks (keywords: {})", title, diff --git a/crates/zclaw-saas/src/workers/update_last_used.rs b/crates/zclaw-saas/src/workers/update_last_used.rs index 4b0f09f..292223e 100644 --- a/crates/zclaw-saas/src/workers/update_last_used.rs +++ b/crates/zclaw-saas/src/workers/update_last_used.rs @@ -8,7 +8,8 @@ use super::Worker; #[derive(Debug, Serialize, Deserialize)] pub struct UpdateLastUsedArgs { - pub token_id: String, + /// token_hash 用于 WHERE 条件匹配 + pub token_hash: String, } pub struct UpdateLastUsedWorker; @@ -23,9 +24,9 @@ impl Worker for UpdateLastUsedWorker { async fn perform(&self, db: &PgPool, args: Self::Args) -> SaasResult<()> { let now = chrono::Utc::now().to_rfc3339(); - sqlx::query("UPDATE api_tokens SET last_used_at = $1 WHERE id = $2") + sqlx::query("UPDATE api_tokens SET last_used_at = $1 WHERE token_hash = $2") .bind(&now) - .bind(&args.token_id) + .bind(&args.token_hash) .execute(db) .await?; Ok(()) diff --git a/desktop/src-tauri/src/classroom_commands/chat.rs b/desktop/src-tauri/src/classroom_commands/chat.rs new file mode 100644 index 0000000..6c13f58 --- /dev/null +++ b/desktop/src-tauri/src/classroom_commands/chat.rs @@ -0,0 +1,223 @@ +//! Classroom multi-agent chat commands +//! +//! - `classroom_chat` — send a message and receive multi-agent responses +//! - `classroom_chat_history` — retrieve chat history for a classroom + +use std::sync::Arc; +use tokio::sync::Mutex; +use serde::{Deserialize, Serialize}; +use tauri::State; + +use zclaw_kernel::generation::{ + AgentProfile, AgentRole, + ClassroomChatMessage, ClassroomChatState, + ClassroomChatRequest, + build_chat_prompt, parse_chat_responses, +}; +use zclaw_runtime::CompletionRequest; + +use super::ClassroomStore; +use crate::kernel_commands::KernelState; + +// --------------------------------------------------------------------------- +// State +// --------------------------------------------------------------------------- + +/// Chat state store: classroom_id → chat state +pub type ChatStore = Arc>>; + +pub fn create_chat_state() -> ChatStore { + Arc::new(Mutex::new(std::collections::HashMap::new())) +} + +// --------------------------------------------------------------------------- +// Request / Response +// --------------------------------------------------------------------------- + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ClassroomChatCmdRequest { + pub classroom_id: String, + pub user_message: String, + pub scene_context: Option, +} + +// --------------------------------------------------------------------------- +// Commands +// --------------------------------------------------------------------------- + +/// Send a message in the classroom chat and get multi-agent responses. +#[tauri::command] +pub async fn classroom_chat( + store: State<'_, ClassroomStore>, + chat_store: State<'_, ChatStore>, + kernel_state: State<'_, KernelState>, + request: ClassroomChatCmdRequest, +) -> Result, String> { + if request.user_message.trim().is_empty() { + return Err("Message cannot be empty".to_string()); + } + + // Get classroom data + let classroom = { + let s = store.lock().await; + s.get(&request.classroom_id) + .cloned() + .ok_or_else(|| format!("Classroom '{}' not found", request.classroom_id))? + }; + + // Create user message + let user_msg = ClassroomChatMessage::user_message(&request.user_message); + + // Get chat history for context + let history: Vec = { + let cs = chat_store.lock().await; + cs.get(&request.classroom_id) + .map(|s| s.messages.clone()) + .unwrap_or_default() + }; + + // Try LLM-powered multi-agent responses, fallback to placeholder + let agent_responses = match generate_llm_responses(&kernel_state, &classroom.agents, &request.user_message, request.scene_context.as_deref(), &history).await { + Ok(responses) => responses, + Err(e) => { + tracing::warn!("LLM chat generation failed, using placeholders: {}", e); + generate_placeholder_responses( + &classroom.agents, + &request.user_message, + request.scene_context.as_deref(), + ) + } + }; + + // Store in chat state + { + let mut cs = chat_store.lock().await; + let state = cs.entry(request.classroom_id.clone()) + .or_insert_with(|| ClassroomChatState { + messages: vec![], + active: true, + }); + + state.messages.push(user_msg); + state.messages.extend(agent_responses.clone()); + } + + Ok(agent_responses) +} + +/// Retrieve chat history for a classroom +#[tauri::command] +pub async fn classroom_chat_history( + chat_store: State<'_, ChatStore>, + classroom_id: String, +) -> Result, String> { + let cs = chat_store.lock().await; + Ok(cs.get(&classroom_id) + .map(|s| s.messages.clone()) + .unwrap_or_default()) +} + +// --------------------------------------------------------------------------- +// Placeholder response generation +// --------------------------------------------------------------------------- + +fn generate_placeholder_responses( + agents: &[AgentProfile], + user_message: &str, + scene_context: Option<&str>, +) -> Vec { + let mut responses = Vec::new(); + + // Teacher always responds + if let Some(teacher) = agents.iter().find(|a| a.role == AgentRole::Teacher) { + let context_hint = scene_context + .map(|ctx| format!("关于「{}」,", ctx)) + .unwrap_or_default(); + + responses.push(ClassroomChatMessage::agent_message( + teacher, + &format!("{}这是一个很好的问题!让我来详细解释一下「{}」的核心概念...", context_hint, user_message), + )); + } + + // Assistant chimes in + if let Some(assistant) = agents.iter().find(|a| a.role == AgentRole::Assistant) { + responses.push(ClassroomChatMessage::agent_message( + assistant, + "我来补充一下要点 📌", + )); + } + + // One student responds + if let Some(student) = agents.iter().find(|a| a.role == AgentRole::Student) { + responses.push(ClassroomChatMessage::agent_message( + student, + &format!("谢谢老师!我大概理解了{}", user_message), + )); + } + + responses +} + +// --------------------------------------------------------------------------- +// LLM-powered response generation +// --------------------------------------------------------------------------- + +async fn generate_llm_responses( + kernel_state: &State<'_, KernelState>, + agents: &[AgentProfile], + user_message: &str, + scene_context: Option<&str>, + history: &[ClassroomChatMessage], +) -> Result, String> { + let driver = { + let ks = kernel_state.lock().await; + ks.as_ref() + .map(|k| k.driver()) + .ok_or_else(|| "Kernel not initialized".to_string())? + }; + + if !driver.is_configured() { + return Err("LLM driver not configured".to_string()); + } + + // Build the chat request for prompt generation (include history) + let chat_request = ClassroomChatRequest { + classroom_id: String::new(), + user_message: user_message.to_string(), + agents: agents.to_vec(), + scene_context: scene_context.map(|s| s.to_string()), + history: history.to_vec(), + }; + + let prompt = build_chat_prompt(&chat_request); + + let request = CompletionRequest { + model: "default".to_string(), + system: Some("你是一个课堂多智能体讨论的协调器。".to_string()), + messages: vec![zclaw_types::Message::User { + content: prompt, + }], + ..Default::default() + }; + + let response = driver.complete(request).await + .map_err(|e| format!("LLM call failed: {}", e))?; + + // Extract text from response + let text = response.content.iter() + .filter_map(|block| match block { + zclaw_runtime::ContentBlock::Text { text } => Some(text.as_str()), + _ => None, + }) + .collect::>() + .join(""); + + let responses = parse_chat_responses(&text, agents); + if responses.is_empty() { + return Err("LLM returned no parseable agent responses".to_string()); + } + + Ok(responses) +} diff --git a/desktop/src-tauri/src/classroom_commands/export.rs b/desktop/src-tauri/src/classroom_commands/export.rs new file mode 100644 index 0000000..bbf3f92 --- /dev/null +++ b/desktop/src-tauri/src/classroom_commands/export.rs @@ -0,0 +1,152 @@ +//! Classroom export commands +//! +//! - `classroom_export` — export classroom as HTML, Markdown, or JSON + +use serde::{Deserialize, Serialize}; +use tauri::State; + +use zclaw_kernel::generation::Classroom; + +use super::ClassroomStore; + +// --------------------------------------------------------------------------- +// Types +// --------------------------------------------------------------------------- + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ClassroomExportRequest { + pub classroom_id: String, + pub format: String, // "html" | "markdown" | "json" +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ClassroomExportResponse { + pub content: String, + pub filename: String, + pub mime_type: String, +} + +// --------------------------------------------------------------------------- +// Command +// --------------------------------------------------------------------------- + +#[tauri::command] +pub async fn classroom_export( + store: State<'_, ClassroomStore>, + request: ClassroomExportRequest, +) -> Result { + let classroom = { + let s = store.lock().await; + s.get(&request.classroom_id) + .cloned() + .ok_or_else(|| format!("Classroom '{}' not found", request.classroom_id))? + }; + + match request.format.as_str() { + "json" => export_json(&classroom), + "html" => export_html(&classroom), + "markdown" | "md" => export_markdown(&classroom), + _ => Err(format!("Unsupported export format: '{}'. Use html, markdown, or json.", request.format)), + } +} + +// --------------------------------------------------------------------------- +// Exporters +// --------------------------------------------------------------------------- + +fn export_json(classroom: &Classroom) -> Result { + let content = serde_json::to_string_pretty(classroom) + .map_err(|e| format!("JSON serialization failed: {}", e))?; + + Ok(ClassroomExportResponse { + filename: format!("{}.json", sanitize_filename(&classroom.title)), + content, + mime_type: "application/json".to_string(), + }) +} + +fn export_html(classroom: &Classroom) -> Result { + let mut html = String::from(r#""#); + html.push_str(&format!("{}", html_escape(&classroom.title))); + html.push_str(r#""#); + + html.push_str(&format!("

{}

", html_escape(&classroom.title))); + html.push_str(&format!("

{}

", html_escape(&classroom.description))); + + // Agents + html.push_str("

课堂角色

"); + for agent in &classroom.agents { + html.push_str(&format!( + r#"{} {}"#, + agent.color, agent.avatar, html_escape(&agent.name) + )); + } + html.push_str("
"); + + // Scenes + html.push_str("

课程内容

"); + for scene in &classroom.scenes { + let type_class = match scene.content.scene_type { + zclaw_kernel::generation::SceneType::Quiz => "quiz", + zclaw_kernel::generation::SceneType::Discussion => "discussion", + _ => "", + }; + html.push_str(&format!( + r#"

{}

类型: {:?} | 时长: {}秒

"#, + type_class, + html_escape(&scene.content.title), + scene.content.scene_type, + scene.content.duration_seconds + )); + } + + html.push_str(""); + + Ok(ClassroomExportResponse { + filename: format!("{}.html", sanitize_filename(&classroom.title)), + content: html, + mime_type: "text/html".to_string(), + }) +} + +fn export_markdown(classroom: &Classroom) -> Result { + let mut md = String::new(); + md.push_str(&format!("# {}\n\n", &classroom.title)); + md.push_str(&format!("{}\n\n", &classroom.description)); + + md.push_str("## 课堂角色\n\n"); + for agent in &classroom.agents { + md.push_str(&format!("- {} **{}** ({:?})\n", agent.avatar, agent.name, agent.role)); + } + md.push('\n'); + + md.push_str("## 课程内容\n\n"); + for (i, scene) in classroom.scenes.iter().enumerate() { + md.push_str(&format!("### {}. {}\n\n", i + 1, scene.content.title)); + md.push_str(&format!("- 类型: `{:?}`\n", scene.content.scene_type)); + md.push_str(&format!("- 时长: {}秒\n\n", scene.content.duration_seconds)); + } + + Ok(ClassroomExportResponse { + filename: format!("{}.md", sanitize_filename(&classroom.title)), + content: md, + mime_type: "text/markdown".to_string(), + }) +} + +fn sanitize_filename(name: &str) -> String { + name.chars() + .map(|c| if c.is_alphanumeric() || c == '-' || c == '_' { c } else { '_' }) + .collect::() + .trim_matches('_') + .to_string() +} + +fn html_escape(s: &str) -> String { + s.replace('&', "&") + .replace('<', "<") + .replace('>', ">") + .replace('"', """) +} diff --git a/desktop/src-tauri/src/classroom_commands/generate.rs b/desktop/src-tauri/src/classroom_commands/generate.rs new file mode 100644 index 0000000..f504a18 --- /dev/null +++ b/desktop/src-tauri/src/classroom_commands/generate.rs @@ -0,0 +1,286 @@ +//! Classroom generation commands +//! +//! - `classroom_generate` — start 4-stage pipeline, emit progress events +//! - `classroom_generation_progress` — query current progress +//! - `classroom_cancel_generation` — cancel active generation +//! - `classroom_get` — retrieve generated classroom data +//! - `classroom_list` — list all generated classrooms + +use serde::{Deserialize, Serialize}; +use tauri::{AppHandle, Emitter, State}; + +use zclaw_kernel::generation::{ + Classroom, GenerationPipeline, GenerationRequest as KernelGenRequest, GenerationStage, + TeachingStyle, DifficultyLevel, +}; + +use super::{ClassroomStore, GenerationTasks}; +use crate::kernel_commands::KernelState; + +// --------------------------------------------------------------------------- +// Request / Response types +// --------------------------------------------------------------------------- + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ClassroomGenerateRequest { + pub topic: String, + pub document: Option, + pub style: Option, + pub level: Option, + pub target_duration_minutes: Option, + pub scene_count: Option, + pub custom_instructions: Option, + pub language: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ClassroomGenerateResponse { + pub classroom_id: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ClassroomProgressResponse { + pub stage: String, + pub progress: u8, + pub activity: String, + pub items_progress: Option<(usize, usize)>, +} + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +fn parse_style(s: Option<&str>) -> TeachingStyle { + match s.unwrap_or("lecture") { + "discussion" => TeachingStyle::Discussion, + "pbl" => TeachingStyle::Pbl, + "flipped" => TeachingStyle::Flipped, + "socratic" => TeachingStyle::Socratic, + _ => TeachingStyle::Lecture, + } +} + +fn parse_level(l: Option<&str>) -> DifficultyLevel { + match l.unwrap_or("intermediate") { + "beginner" => DifficultyLevel::Beginner, + "advanced" => DifficultyLevel::Advanced, + "expert" => DifficultyLevel::Expert, + _ => DifficultyLevel::Intermediate, + } +} + +fn stage_name(stage: &GenerationStage) -> &'static str { + match stage { + GenerationStage::AgentProfiles => "agent_profiles", + GenerationStage::Outline => "outline", + GenerationStage::Scene => "scene", + GenerationStage::Complete => "complete", + } +} + +// --------------------------------------------------------------------------- +// Commands +// --------------------------------------------------------------------------- + +/// Start classroom generation (4-stage pipeline). +/// Progress events are emitted via `classroom:progress`. +/// Supports cancellation between stages by removing the task from GenerationTasks. +#[tauri::command] +pub async fn classroom_generate( + app: AppHandle, + store: State<'_, ClassroomStore>, + tasks: State<'_, GenerationTasks>, + kernel_state: State<'_, KernelState>, + request: ClassroomGenerateRequest, +) -> Result { + if request.topic.trim().is_empty() { + return Err("Topic is required".to_string()); + } + + let topic_clone = request.topic.clone(); + + let kernel_request = KernelGenRequest { + topic: request.topic.clone(), + document: request.document.clone(), + style: parse_style(request.style.as_deref()), + level: parse_level(request.level.as_deref()), + target_duration_minutes: request.target_duration_minutes.unwrap_or(30), + scene_count: request.scene_count, + custom_instructions: request.custom_instructions.clone(), + language: request.language.clone().or_else(|| Some("zh-CN".to_string())), + }; + + // Register generation task so cancellation can check it + { + use zclaw_kernel::generation::GenerationProgress; + let mut t = tasks.lock().await; + t.insert(topic_clone.clone(), GenerationProgress { + stage: zclaw_kernel::generation::GenerationStage::AgentProfiles, + progress: 0, + activity: "Starting generation...".to_string(), + items_progress: None, + eta_seconds: None, + }); + } + + // Get LLM driver from kernel if available, otherwise use placeholder mode + let pipeline = { + let ks = kernel_state.lock().await; + if let Some(kernel) = ks.as_ref() { + GenerationPipeline::with_driver(kernel.driver()) + } else { + GenerationPipeline::new() + } + }; + + // Helper: check if cancelled + let is_cancelled = || { + let t = tasks.blocking_lock(); + !t.contains_key(&topic_clone) + }; + + // Helper: emit progress event + let emit_progress = |stage: &str, progress: u8, activity: &str| { + let _ = app.emit("classroom:progress", serde_json::json!({ + "topic": &topic_clone, + "stage": stage, + "progress": progress, + "activity": activity + })); + }; + + // ── Stage 0: Agent Profiles ── + emit_progress("agent_profiles", 5, "生成课堂角色..."); + let agents = pipeline.generate_agent_profiles(&kernel_request).await; + emit_progress("agent_profiles", 25, "角色生成完成"); + if is_cancelled() { + return Err("Generation cancelled".to_string()); + } + + // ── Stage 1: Outline ── + emit_progress("outline", 30, "分析主题,生成大纲..."); + let outline = pipeline.generate_outline(&kernel_request).await + .map_err(|e| format!("Outline generation failed: {}", e))?; + emit_progress("outline", 50, &format!("大纲完成:{} 个场景", outline.len())); + if is_cancelled() { + return Err("Generation cancelled".to_string()); + } + + // ── Stage 2: Scenes (parallel) ── + emit_progress("scene", 55, &format!("并行生成 {} 个场景...", outline.len())); + let scenes = pipeline.generate_scenes(&outline).await + .map_err(|e| format!("Scene generation failed: {}", e))?; + if is_cancelled() { + return Err("Generation cancelled".to_string()); + } + + // ── Stage 3: Assemble ── + emit_progress("complete", 90, "组装课堂..."); + + // Build classroom directly (pipeline.build_classroom is private) + let total_duration: u32 = scenes.iter().map(|s| s.content.duration_seconds).sum(); + let objectives = outline.iter() + .take(3) + .map(|item| format!("理解: {}", item.title)) + .collect::>(); + let classroom_id = uuid::Uuid::new_v4().to_string(); + + let classroom = Classroom { + id: classroom_id.clone(), + title: format!("课堂: {}", kernel_request.topic), + description: format!("{:?} 风格课堂 — {}", kernel_request.style, kernel_request.topic), + topic: kernel_request.topic.clone(), + style: kernel_request.style, + level: kernel_request.level, + total_duration, + objectives, + scenes, + agents, + metadata: zclaw_kernel::generation::ClassroomMetadata { + generated_at: std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_millis() as i64, + source_document: kernel_request.document.map(|_| "user_document".to_string()), + model: None, + version: "2.0.0".to_string(), + custom: serde_json::Map::new(), + }, + }; + + // Store classroom + { + let mut s = store.lock().await; + s.insert(classroom_id.clone(), classroom); + } + + // Clear generation task + { + let mut t = tasks.lock().await; + t.remove(&topic_clone); + } + + // Emit completion + emit_progress("complete", 100, "课堂生成完成"); + + Ok(ClassroomGenerateResponse { + classroom_id, + }) +} + +/// Get current generation progress for a topic +#[tauri::command] +pub async fn classroom_generation_progress( + tasks: State<'_, GenerationTasks>, + topic: String, +) -> Result { + let t = tasks.lock().await; + let progress = t.get(&topic); + Ok(ClassroomProgressResponse { + stage: progress.map(|p| stage_name(&p.stage).to_string()).unwrap_or_else(|| "none".to_string()), + progress: progress.map(|p| p.progress).unwrap_or(0), + activity: progress.map(|p| p.activity.clone()).unwrap_or_default(), + items_progress: progress.and_then(|p| p.items_progress), + }) +} + +/// Cancel an active generation +#[tauri::command] +pub async fn classroom_cancel_generation( + tasks: State<'_, GenerationTasks>, + topic: String, +) -> Result<(), String> { + let mut t = tasks.lock().await; + t.remove(&topic); + Ok(()) +} + +/// Retrieve a generated classroom by ID +#[tauri::command] +pub async fn classroom_get( + store: State<'_, ClassroomStore>, + classroom_id: String, +) -> Result { + let s = store.lock().await; + s.get(&classroom_id) + .cloned() + .ok_or_else(|| format!("Classroom '{}' not found", classroom_id)) +} + +/// List all generated classrooms (id + title only) +#[tauri::command] +pub async fn classroom_list( + store: State<'_, ClassroomStore>, +) -> Result, String> { + let s = store.lock().await; + Ok(s.values().map(|c| serde_json::json!({ + "id": c.id, + "title": c.title, + "topic": c.topic, + "totalDuration": c.total_duration, + "sceneCount": c.scenes.len(), + })).collect()) +} diff --git a/desktop/src-tauri/src/classroom_commands/mod.rs b/desktop/src-tauri/src/classroom_commands/mod.rs new file mode 100644 index 0000000..4aa179f --- /dev/null +++ b/desktop/src-tauri/src/classroom_commands/mod.rs @@ -0,0 +1,41 @@ +//! Classroom generation and interaction commands +//! +//! Tauri commands for the OpenMAIC-style interactive classroom: +//! - Generate classroom (4-stage pipeline with progress events) +//! - Multi-agent chat +//! - Export (HTML/Markdown/JSON) + +use std::sync::Arc; +use tokio::sync::Mutex; +use zclaw_kernel::generation::Classroom; + +pub mod chat; +pub mod export; +pub mod generate; + +// --------------------------------------------------------------------------- +// Shared state types +// --------------------------------------------------------------------------- + +/// In-memory classroom store: classroom_id → Classroom +pub type ClassroomStore = Arc>>; + +/// Active generation tasks: topic → progress +pub type GenerationTasks = Arc>>; + +// Re-export chat state type +// Re-export chat state type — used by lib.rs to construct managed state +#[allow(unused_imports)] +pub use chat::ChatStore; + +// --------------------------------------------------------------------------- +// State constructors +// --------------------------------------------------------------------------- + +pub fn create_classroom_state() -> ClassroomStore { + Arc::new(Mutex::new(std::collections::HashMap::new())) +} + +pub fn create_generation_tasks() -> GenerationTasks { + Arc::new(Mutex::new(std::collections::HashMap::new())) +} diff --git a/desktop/src-tauri/src/intelligence/identity.rs b/desktop/src-tauri/src/intelligence/identity.rs index 7205904..729ae9a 100644 --- a/desktop/src-tauri/src/intelligence/identity.rs +++ b/desktop/src-tauri/src/intelligence/identity.rs @@ -258,11 +258,18 @@ impl AgentIdentityManager { if !identity.instructions.is_empty() { sections.push(identity.instructions.clone()); } - if !identity.user_profile.is_empty() - && identity.user_profile != default_user_profile() - { - sections.push(format!("## 用户画像\n{}", identity.user_profile)); - } + // NOTE: user_profile injection is intentionally disabled. + // The reflection engine may accumulate overly specific details from past + // conversations (e.g., "广东光华", "汕头玩具产业") into user_profile. + // These details then leak into every new conversation's system prompt, + // causing the model to think about old topics instead of the current query. + // Memory injection should only happen via MemoryMiddleware with relevance + // filtering, not unconditionally via user_profile. + // if !identity.user_profile.is_empty() + // && identity.user_profile != default_user_profile() + // { + // sections.push(format!("## 用户画像\n{}", identity.user_profile)); + // } if let Some(ctx) = memory_context { sections.push(ctx.to_string()); } diff --git a/desktop/src-tauri/src/kernel_commands/chat.rs b/desktop/src-tauri/src/kernel_commands/chat.rs index 4ec6f6e..49ba752 100644 --- a/desktop/src-tauri/src/kernel_commands/chat.rs +++ b/desktop/src-tauri/src/kernel_commands/chat.rs @@ -34,6 +34,7 @@ pub struct ChatResponse { #[serde(rename_all = "camelCase", tag = "type")] pub enum StreamChatEvent { Delta { delta: String }, + ThinkingDelta { delta: String }, ToolStart { name: String, input: serde_json::Value }, ToolEnd { name: String, output: serde_json::Value }, IterationStart { iteration: usize, max_iterations: usize }, @@ -218,6 +219,10 @@ pub async fn agent_chat_stream( tracing::trace!("[agent_chat_stream] Delta: {} bytes", delta.len()); StreamChatEvent::Delta { delta: delta.clone() } } + LoopEvent::ThinkingDelta(delta) => { + tracing::trace!("[agent_chat_stream] ThinkingDelta: {} bytes", delta.len()); + StreamChatEvent::ThinkingDelta { delta: delta.clone() } + } LoopEvent::ToolStart { name, input } => { tracing::debug!("[agent_chat_stream] ToolStart: {}", name); if name.starts_with("hand_") { diff --git a/desktop/src-tauri/src/kernel_commands/lifecycle.rs b/desktop/src-tauri/src/kernel_commands/lifecycle.rs index 8a5bbd0..ee79f0f 100644 --- a/desktop/src-tauri/src/kernel_commands/lifecycle.rs +++ b/desktop/src-tauri/src/kernel_commands/lifecycle.rs @@ -249,3 +249,130 @@ pub async fn kernel_shutdown( Ok(()) } + +/// Apply SaaS-synced configuration to the Kernel config file. +/// +/// Writes relevant config values (agent, llm categories) to the TOML config file. +/// The changes take effect on the next Kernel restart. +#[tauri::command] +pub async fn kernel_apply_saas_config( + configs: Vec, +) -> Result { + use std::io::Write; + + let config_path = zclaw_kernel::config::KernelConfig::find_config_path() + .ok_or_else(|| "No config file path found".to_string())?; + + // Read existing config or create empty + let existing = if config_path.exists() { + std::fs::read_to_string(&config_path).unwrap_or_default() + } else { + String::new() + }; + + let mut updated = existing; + let mut applied: u32 = 0; + + for config in &configs { + // Only process kernel-relevant categories + if !matches!(config.category.as_str(), "agent" | "llm") { + continue; + } + + // Write key=value to the [llm] or [agent] section + let section = &config.category; + let key = config.key.replace('.', "_"); + let value = &config.value; + + // Simple TOML patching: find or create section, update key + let section_header = format!("[{}]", section); + let line_to_set = format!("{} = {}", key, toml_quote_value(value)); + + if let Some(section_start) = updated.find(§ion_header) { + // Section exists, find or add the key within it + let after_header = section_start + section_header.len(); + let next_section = updated[after_header..].find("\n[") + .map(|i| after_header + i) + .unwrap_or(updated.len()); + + let section_content = &updated[after_header..next_section]; + let key_prefix = format!("\n{} =", key); + let key_prefix_alt = format!("\n{}=", key); + + if let Some(key_pos) = section_content.find(&key_prefix) + .or_else(|| section_content.find(&key_prefix_alt)) + { + // Key exists, replace the line + let line_start = after_header + key_pos + 1; // skip \n + let line_end = updated[line_start..].find('\n') + .map(|i| line_start + i) + .unwrap_or(updated.len()); + updated = format!( + "{}{}{}\n{}", + &updated[..line_start], + line_to_set, + if line_end < updated.len() { "" } else { "" }, + &updated[line_end..] + ); + // Remove the extra newline if line_end included one + updated = updated.replace(&format!("{}\n\n", line_to_set), &format!("{}\n", line_to_set)); + } else { + // Key doesn't exist, append to section + updated.insert_str(next_section, format!("\n{}", line_to_set).as_str()); + } + } else { + // Section doesn't exist, append it + updated = format!("{}\n{}\n{}\n", updated.trim_end(), section_header, line_to_set); + } + applied += 1; + } + + if applied > 0 { + // Ensure parent directory exists + if let Some(parent) = config_path.parent() { + std::fs::create_dir_all(parent).map_err(|e| format!("Failed to create config dir: {}", e))?; + } + + let mut file = std::fs::File::create(&config_path) + .map_err(|e| format!("Failed to write config: {}", e))?; + file.write_all(updated.as_bytes()) + .map_err(|e| format!("Failed to write config: {}", e))?; + + tracing::info!( + "[kernel_apply_saas_config] Applied {} config items to {:?} (restart required)", + applied, + config_path + ); + } + + Ok(applied) +} + +/// Single config item from SaaS sync +#[derive(Debug, Clone, serde::Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SaasConfigItem { + pub category: String, + pub key: String, + pub value: String, +} + +/// Quote a value for TOML format +fn toml_quote_value(value: &str) -> String { + // Try to parse as number or boolean + if value == "true" || value == "false" { + return value.to_string(); + } + if let Ok(n) = value.parse::() { + return n.to_string(); + } + if let Ok(n) = value.parse::() { + return n.to_string(); + } + // Handle multi-line strings with TOML triple-quote syntax + if value.contains('\n') { + return format!("\"\"\"\n{}\"\"\"", value.replace('\\', "\\\\").replace("\"\"\"", "'\"'\"'\"")); + } + // Default: quote as string + format!("\"{}\"", value.replace('\\', "\\\\").replace('"', "\\\"")) +} diff --git a/desktop/src-tauri/src/lib.rs b/desktop/src-tauri/src/lib.rs index 2dec273..ab59589 100644 --- a/desktop/src-tauri/src/lib.rs +++ b/desktop/src-tauri/src/lib.rs @@ -34,6 +34,9 @@ mod kernel_commands; // Pipeline commands (DSL-based workflows) mod pipeline_commands; +// Classroom generation and interaction commands +mod classroom_commands; + // Gateway sub-modules (runtime, config, io, commands) mod gateway; @@ -99,6 +102,11 @@ pub fn run() { // Initialize Pipeline state (DSL-based workflows) let pipeline_state = pipeline_commands::create_pipeline_state(); + // Initialize Classroom state (generation + chat) + let classroom_state = classroom_commands::create_classroom_state(); + let classroom_chat_state = classroom_commands::chat::create_chat_state(); + let classroom_gen_tasks = classroom_commands::create_generation_tasks(); + tauri::Builder::default() .plugin(tauri_plugin_opener::init()) .manage(browser_state) @@ -110,11 +118,15 @@ pub fn run() { .manage(scheduler_state) .manage(kernel_commands::SessionStreamGuard::default()) .manage(pipeline_state) + .manage(classroom_state) + .manage(classroom_chat_state) + .manage(classroom_gen_tasks) .invoke_handler(tauri::generate_handler![ // Internal ZCLAW Kernel commands (preferred) kernel_commands::lifecycle::kernel_init, kernel_commands::lifecycle::kernel_status, kernel_commands::lifecycle::kernel_shutdown, + kernel_commands::lifecycle::kernel_apply_saas_config, kernel_commands::agent::agent_create, kernel_commands::agent::agent_list, kernel_commands::agent::agent_get, @@ -300,7 +312,16 @@ pub fn run() { intelligence::identity::identity_get_snapshots, intelligence::identity::identity_restore_snapshot, intelligence::identity::identity_list_agents, - intelligence::identity::identity_delete_agent + intelligence::identity::identity_delete_agent, + // Classroom generation and interaction commands + classroom_commands::generate::classroom_generate, + classroom_commands::generate::classroom_generation_progress, + classroom_commands::generate::classroom_cancel_generation, + classroom_commands::generate::classroom_get, + classroom_commands::generate::classroom_list, + classroom_commands::chat::classroom_chat, + classroom_commands::chat::classroom_chat_history, + classroom_commands::export::classroom_export ]) .run(tauri::generate_context!()) .expect("error while running tauri application"); diff --git a/desktop/src/App.tsx b/desktop/src/App.tsx index dc4d6a2..0e58f06 100644 --- a/desktop/src/App.tsx +++ b/desktop/src/App.tsx @@ -29,6 +29,7 @@ import { useProposalNotifications, ProposalNotificationHandler } from './lib/use import { useToast } from './components/ui/Toast'; import type { Clone } from './store/agentStore'; import { createLogger } from './lib/logger'; +import { startOfflineMonitor } from './store/offlineStore'; const log = createLogger('App'); @@ -86,6 +87,8 @@ function App() { useEffect(() => { document.title = 'ZCLAW'; + const stopOfflineMonitor = startOfflineMonitor(); + return () => { stopOfflineMonitor(); }; }, []); // Restore SaaS session from OS keyring on startup (before auth gate) @@ -152,8 +155,11 @@ function App() { let mounted = true; const bootstrap = async () => { - // 未登录时不启动 bootstrap - if (!useSaaSStore.getState().isLoggedIn) return; + // 未登录时不启动 bootstrap,直接结束 loading + if (!useSaaSStore.getState().isLoggedIn) { + setBootstrapping(false); + return; + } try { // Step 1: Check and start local gateway in Tauri environment diff --git a/desktop/src/components/ChatArea.tsx b/desktop/src/components/ChatArea.tsx index fe41c54..a5d396a 100644 --- a/desktop/src/components/ChatArea.tsx +++ b/desktop/src/components/ChatArea.tsx @@ -2,6 +2,7 @@ import { useState, useEffect, useRef, useCallback, useMemo, type MutableRefObjec import { motion, AnimatePresence } from 'framer-motion'; import { List, type ListImperativeAPI } from 'react-window'; import { useChatStore, Message } from '../store/chatStore'; +import { useArtifactStore } from '../store/chat/artifactStore'; import { useConnectionStore } from '../store/connectionStore'; import { useAgentStore } from '../store/agentStore'; import { useConfigStore } from '../store/configStore'; @@ -12,6 +13,8 @@ import { ArtifactPanel } from './ai/ArtifactPanel'; import { ToolCallChain } from './ai/ToolCallChain'; import { listItemVariants, defaultTransition, fadeInVariants } from '../lib/animations'; import { FirstConversationPrompt } from './FirstConversationPrompt'; +import { ClassroomPlayer } from './classroom_player'; +import { useClassroomStore } from '../store/classroomStore'; // MessageSearch temporarily removed during DeerFlow redesign import { OfflineIndicator } from './OfflineIndicator'; import { @@ -45,11 +48,14 @@ export function ChatArea() { messages, currentAgent, isStreaming, isLoading, currentModel, sendMessage: sendToGateway, setCurrentModel, initStreamListener, newConversation, chatMode, setChatMode, suggestions, - artifacts, selectedArtifactId, artifactPanelOpen, - selectArtifact, setArtifactPanelOpen, totalInputTokens, totalOutputTokens, } = useChatStore(); + const { + artifacts, selectedArtifactId, artifactPanelOpen, + selectArtifact, setArtifactPanelOpen, + } = useArtifactStore(); const connectionState = useConnectionStore((s) => s.connectionState); + const { activeClassroom, classroomOpen, closeClassroom, generating, progressPercent, progressActivity, error: classroomError, clearError: clearClassroomError } = useClassroomStore(); const clones = useAgentStore((s) => s.clones); const models = useConfigStore((s) => s.models); @@ -203,9 +209,76 @@ export function ChatArea() { ); return ( +
+ {/* Generation progress overlay */} + + {generating && ( + +
+
+
+

+ 正在生成课堂... +

+

+ {progressActivity || '准备中...'} +

+
+ {progressPercent > 0 && ( +
+
+
+
+

{progressPercent}%

+
+ )} + +
+ + )} + + + {/* ClassroomPlayer overlay */} + + {classroomOpen && activeClassroom && ( + + + + )} + + + {/* Classroom generation error banner */} + {classroomError && ( +
+ 课堂生成失败: {classroomError} + +
+ )} {/* Header — DeerFlow-style: minimal */}
@@ -298,6 +371,7 @@ export function ChatArea() { getHeight={getHeight} onHeightChange={setHeight} messageRefs={messageRefs} + setInput={setInput} /> ) : ( messages.map((message) => ( @@ -310,7 +384,7 @@ export function ChatArea() { layout transition={defaultTransition} > - + )) )} @@ -433,19 +507,16 @@ export function ChatArea() { rightPanelOpen={artifactPanelOpen} onRightPanelToggle={setArtifactPanelOpen} /> +
); } -function MessageBubble({ message }: { message: Message }) { - // Tool messages are now absorbed into the assistant message's toolSteps chain. - // Legacy standalone tool messages (from older sessions) still render as before. +function MessageBubble({ message, setInput }: { message: Message; setInput: (text: string) => void }) { if (message.role === 'tool') { return null; } const isUser = message.role === 'user'; - - // 思考中状态:streaming 且内容为空时显示思考指示器 const isThinking = message.streaming && !message.content; // Download message as Markdown file @@ -518,7 +589,20 @@ function MessageBubble({ message }: { message: Message }) { : '...'}
{message.error && ( -

{message.error}

+
+

{message.error}

+ +
)} {/* Download button for AI messages - show on hover */} {!isUser && message.content && !message.streaming && ( @@ -543,6 +627,7 @@ interface VirtualizedMessageRowProps { message: Message; onHeightChange: (height: number) => void; messageRefs: MutableRefObject>; + setInput: (text: string) => void; } /** @@ -553,6 +638,7 @@ function VirtualizedMessageRow({ message, onHeightChange, messageRefs, + setInput, style, ariaAttributes, }: VirtualizedMessageRowProps & { @@ -587,7 +673,7 @@ function VirtualizedMessageRow({ className="py-3" {...ariaAttributes} > - +
); } @@ -598,6 +684,7 @@ interface VirtualizedMessageListProps { getHeight: (id: string, role: string) => number; onHeightChange: (id: string, height: number) => void; messageRefs: MutableRefObject>; + setInput: (text: string) => void; } /** @@ -610,6 +697,7 @@ function VirtualizedMessageList({ getHeight, onHeightChange, messageRefs, + setInput, }: VirtualizedMessageListProps) { // Row component for react-window v2 const RowComponent = (props: { @@ -625,6 +713,7 @@ function VirtualizedMessageList({ message={messages[props.index]} onHeightChange={(h) => onHeightChange(messages[props.index].id, h)} messageRefs={messageRefs} + setInput={setInput} style={props.style} ariaAttributes={props.ariaAttributes} /> diff --git a/desktop/src/components/ClassroomPreviewer.tsx b/desktop/src/components/ClassroomPreviewer.tsx index b892303..5d49e3c 100644 --- a/desktop/src/components/ClassroomPreviewer.tsx +++ b/desktop/src/components/ClassroomPreviewer.tsx @@ -67,6 +67,7 @@ interface ClassroomPreviewerProps { data: ClassroomData; onClose?: () => void; onExport?: (format: 'pptx' | 'html' | 'pdf') => void; + onOpenFullPlayer?: () => void; } // === Sub-Components === @@ -271,6 +272,7 @@ function OutlinePanel({ export function ClassroomPreviewer({ data, onExport, + onOpenFullPlayer, }: ClassroomPreviewerProps) { const [currentSceneIndex, setCurrentSceneIndex] = useState(0); const [isPlaying, setIsPlaying] = useState(false); @@ -398,6 +400,15 @@ export function ClassroomPreviewer({

+ {onOpenFullPlayer && ( + + )}
); diff --git a/desktop/src/components/ai/Conversation.tsx b/desktop/src/components/ai/Conversation.tsx index 675a88a..067182d 100644 --- a/desktop/src/components/ai/Conversation.tsx +++ b/desktop/src/components/ai/Conversation.tsx @@ -109,7 +109,7 @@ export function Conversation({ children, className = '' }: ConversationProps) {
{children}
diff --git a/desktop/src/components/ai/ResizableChatLayout.tsx b/desktop/src/components/ai/ResizableChatLayout.tsx index 5eba9a6..8046d73 100644 --- a/desktop/src/components/ai/ResizableChatLayout.tsx +++ b/desktop/src/components/ai/ResizableChatLayout.tsx @@ -62,7 +62,7 @@ export function ResizableChatLayout({ if (!rightPanelOpen || !rightPanel) { return ( -
+
{chatPanel} +
+
+
+ ); +} diff --git a/desktop/src/components/classroom_player/ClassroomPlayer.tsx b/desktop/src/components/classroom_player/ClassroomPlayer.tsx new file mode 100644 index 0000000..8806e93 --- /dev/null +++ b/desktop/src/components/classroom_player/ClassroomPlayer.tsx @@ -0,0 +1,231 @@ +/** + * ClassroomPlayer — Full-screen interactive classroom player. + * + * Layout: Notes sidebar | Main stage | Chat panel + * Top: Title + Agent avatars + * Bottom: Scene navigation + playback controls + */ + +import { useState, useCallback, useEffect } from 'react'; +import { invoke } from '@tauri-apps/api/core'; +import { useClassroom } from '../../hooks/useClassroom'; +import { SceneRenderer } from './SceneRenderer'; +import { AgentChat } from './AgentChat'; +import { NotesSidebar } from './NotesSidebar'; +import { TtsPlayer } from './TtsPlayer'; +import { Download } from 'lucide-react'; + +interface ClassroomPlayerProps { + onClose: () => void; +} + +export function ClassroomPlayer({ onClose }: ClassroomPlayerProps) { + const { + activeClassroom, + chatMessages, + chatLoading, + sendChatMessage, + } = useClassroom(); + + const [currentSceneIndex, setCurrentSceneIndex] = useState(0); + const [sidebarOpen, setSidebarOpen] = useState(true); + const [chatOpen, setChatOpen] = useState(true); + const [exporting, setExporting] = useState(false); + + const classroom = activeClassroom; + const scenes = classroom?.scenes ?? []; + const agents = classroom?.agents ?? []; + const currentScene = scenes[currentSceneIndex] ?? null; + + // Navigate to next/prev scene + const goNext = useCallback(() => { + setCurrentSceneIndex((i) => Math.min(i + 1, scenes.length - 1)); + }, [scenes.length]); + + const goPrev = useCallback(() => { + setCurrentSceneIndex((i) => Math.max(i - 1, 0)); + }, []); + + // Keyboard shortcuts + useEffect(() => { + const handler = (e: KeyboardEvent) => { + if (e.key === 'ArrowRight') goNext(); + else if (e.key === 'ArrowLeft') goPrev(); + else if (e.key === 'Escape') onClose(); + }; + window.addEventListener('keydown', handler); + return () => window.removeEventListener('keydown', handler); + }, [goNext, goPrev, onClose]); + + // Chat handler + const handleChatSend = useCallback(async (message: string) => { + const sceneContext = currentScene?.content.title; + await sendChatMessage(message, sceneContext); + }, [sendChatMessage, currentScene]); + + // Export handler + const handleExport = useCallback(async (format: 'html' | 'markdown' | 'json') => { + if (!classroom) return; + setExporting(true); + try { + const result = await invoke<{ content: string; filename: string; mimeType: string }>( + 'classroom_export', + { request: { classroomId: classroom.id, format } } + ); + // Download the exported file + const blob = new Blob([result.content], { type: result.mimeType }); + const url = URL.createObjectURL(blob); + const a = document.createElement('a'); + a.href = url; + a.download = result.filename; + a.click(); + URL.revokeObjectURL(url); + } catch (e) { + console.error('Export failed:', e); + } finally { + setExporting(false); + } + }, [classroom]); + + if (!classroom) { + return ( +
+ No classroom loaded +
+ ); + } + + return ( +
+ {/* Header */} +
+
+ +

+ {classroom.title} +

+
+ + {/* Agent avatars */} +
+ {agents.map((agent) => ( + + {agent.avatar} + + ))} +
+ +
+ + + {/* Export dropdown */} +
+ +
+ + + +
+
+
+
+ + {/* Main content */} +
+ {/* Notes sidebar */} + {sidebarOpen && ( + + )} + + {/* Main stage */} +
+ {currentScene ? ( + + ) : ( +
+ No scenes available +
+ )} +
+ + {/* Chat panel */} + {chatOpen && ( + + )} +
+ + {/* Bottom navigation */} +
+
+ + + {currentSceneIndex + 1} / {scenes.length} + + +
+ + {/* TTS + Scene info */} +
+ {currentScene?.content.notes && ( + + )} +
+ {currentScene?.content.sceneType ?? ''} + {currentScene?.content.durationSeconds + ? ` · ${Math.floor(currentScene.content.durationSeconds / 60)}:${String(currentScene.content.durationSeconds % 60).padStart(2, '0')}` + : ''} +
+
+
+
+ ); +} diff --git a/desktop/src/components/classroom_player/NotesSidebar.tsx b/desktop/src/components/classroom_player/NotesSidebar.tsx new file mode 100644 index 0000000..d90a45b --- /dev/null +++ b/desktop/src/components/classroom_player/NotesSidebar.tsx @@ -0,0 +1,71 @@ +/** + * NotesSidebar — Scene outline navigation + notes. + * + * Left panel showing all scenes as clickable items with notes. + */ + +import type { GeneratedScene } from '../../types/classroom'; + +interface NotesSidebarProps { + scenes: GeneratedScene[]; + currentIndex: number; + onSelectScene: (index: number) => void; +} + +export function NotesSidebar({ scenes, currentIndex, onSelectScene }: NotesSidebarProps) { + return ( +
+
+

+ Outline +

+
+ + +
+ ); +} + +function getTypeColor(type: string): string { + switch (type) { + case 'slide': return '#6366F1'; + case 'quiz': return '#F59E0B'; + case 'discussion': return '#10B981'; + case 'interactive': return '#8B5CF6'; + case 'pbl': return '#EF4444'; + case 'media': return '#06B6D4'; + default: return '#9CA3AF'; + } +} diff --git a/desktop/src/components/classroom_player/SceneRenderer.tsx b/desktop/src/components/classroom_player/SceneRenderer.tsx new file mode 100644 index 0000000..c5fda5b --- /dev/null +++ b/desktop/src/components/classroom_player/SceneRenderer.tsx @@ -0,0 +1,219 @@ +/** + * SceneRenderer — Renders a single classroom scene. + * + * Supports scene types: slide, quiz, discussion, interactive, text, pbl, media. + * Executes scene actions (speech, whiteboard, quiz, discussion). + */ + +import { useState, useEffect, useCallback } from 'react'; +import type { GeneratedScene, SceneContent, SceneAction, AgentProfile } from '../../types/classroom'; + +interface SceneRendererProps { + scene: GeneratedScene; + agents: AgentProfile[]; + autoPlay?: boolean; +} + +export function SceneRenderer({ scene, agents, autoPlay = true }: SceneRendererProps) { + const { content } = scene; + const [actionIndex, setActionIndex] = useState(0); + const [isPlaying, setIsPlaying] = useState(autoPlay); + const [whiteboardItems, setWhiteboardItems] = useState>([]); + + const actions = content.actions ?? []; + const currentAction = actions[actionIndex] ?? null; + + // Auto-advance through actions + useEffect(() => { + if (!isPlaying || actions.length === 0) return; + if (actionIndex >= actions.length) { + setIsPlaying(false); + return; + } + + const delay = getActionDelay(actions[actionIndex]); + const timer = setTimeout(() => { + processAction(actions[actionIndex]); + setActionIndex((i) => i + 1); + }, delay); + + return () => clearTimeout(timer); + }, [actionIndex, isPlaying, actions]); + + const processAction = useCallback((action: SceneAction) => { + switch (action.type) { + case 'whiteboard_draw_text': + case 'whiteboard_draw_shape': + case 'whiteboard_draw_chart': + case 'whiteboard_draw_latex': + setWhiteboardItems((prev) => [...prev, { type: action.type, data: action }]); + break; + case 'whiteboard_clear': + setWhiteboardItems([]); + break; + } + }, []); + + // Render scene based on type + return ( +
+ {/* Scene title */} +
+

+ {content.title} +

+ {content.notes && ( +

{content.notes}

+ )} +
+ + {/* Main content area */} +
+ {/* Content panel */} +
+ {renderContent(content)} +
+ + {/* Whiteboard area */} + {whiteboardItems.length > 0 && ( +
+ + {whiteboardItems.map((item, i) => ( + {renderWhiteboardItem(item)} + ))} + +
+ )} +
+ + {/* Current action indicator */} + {currentAction && ( +
+ {renderCurrentAction(currentAction, agents)} +
+ )} + + {/* Playback controls */} +
+ + + + Action {Math.min(actionIndex + 1, actions.length)} / {actions.length} + +
+
+ ); +} + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +function getActionDelay(action: SceneAction): number { + switch (action.type) { + case 'speech': return 2000; + case 'whiteboard_draw_text': return 800; + case 'whiteboard_draw_shape': return 600; + case 'quiz_show': return 5000; + case 'discussion': return 10000; + default: return 1000; + } +} + +function renderContent(content: SceneContent) { + const data = content.content; + if (!data || typeof data !== 'object') return null; + + // Handle slide content + const keyPoints = data.key_points as string[] | undefined; + const description = data.description as string | undefined; + const slides = data.slides as Array<{ title: string; content: string }> | undefined; + + return ( +
+ {description && ( +

{description}

+ )} + {keyPoints && keyPoints.length > 0 && ( +
    + {keyPoints.map((point, i) => ( +
  • + + {point} +
  • + ))} +
+ )} + {slides && slides.map((slide, i) => ( +
+

{slide.title}

+

{slide.content}

+
+ ))} +
+ ); +} + +function renderCurrentAction(action: SceneAction, agents: AgentProfile[]) { + switch (action.type) { + case 'speech': { + const agent = agents.find(a => a.role === action.agentRole); + return ( +
+ {agent?.avatar ?? '💬'} +
+ {agent?.name ?? action.agentRole} +

{action.text}

+
+
+ ); + } + case 'quiz_show': + return
Quiz: {action.quizId}
; + case 'discussion': + return
Discussion: {action.topic}
; + default: + return
{action.type}
; + } +} + + +function renderWhiteboardItem(item: { type: string; data: Record }) { + switch (item.type) { + case 'whiteboard_draw_text': { + const d = item.data; + if ('text' in d && 'x' in d && 'y' in d) { + return ( + + {String(d.text ?? '')} + + ); + } + return null; + } + case 'whiteboard_draw_shape': { + const d = item.data as Record; + const x = typeof d.x === 'number' ? d.x : 0; + const y = typeof d.y === 'number' ? d.y : 0; + const w = typeof d.width === 'number' ? d.width : 100; + const h = typeof d.height === 'number' ? d.height : 50; + const fill = typeof d.fill === 'string' ? d.fill : '#e5e5e5'; + return ( + + ); + } + } +} diff --git a/desktop/src/components/classroom_player/TtsPlayer.tsx b/desktop/src/components/classroom_player/TtsPlayer.tsx new file mode 100644 index 0000000..88b973f --- /dev/null +++ b/desktop/src/components/classroom_player/TtsPlayer.tsx @@ -0,0 +1,155 @@ +/** + * TtsPlayer — Text-to-Speech playback controls for classroom narration. + * + * Uses the browser's built-in SpeechSynthesis API. + * Provides play/pause, speed, and volume controls. + */ + +import { useState, useEffect, useCallback, useRef } from 'react'; +import { Volume2, VolumeX, Pause, Play, Gauge } from 'lucide-react'; + +interface TtsPlayerProps { + text: string; + autoPlay?: boolean; + onEnd?: () => void; +} + +export function TtsPlayer({ text, autoPlay = false, onEnd }: TtsPlayerProps) { + const [isPlaying, setIsPlaying] = useState(false); + const [isPaused, setIsPaused] = useState(false); + const [rate, setRate] = useState(1.0); + const [isMuted, setIsMuted] = useState(false); + const utteranceRef = useRef(null); + + const speak = useCallback(() => { + if (!text || typeof window === 'undefined') return; + + window.speechSynthesis.cancel(); + + const utterance = new SpeechSynthesisUtterance(text); + utterance.lang = 'zh-CN'; + utterance.rate = rate; + utterance.volume = isMuted ? 0 : 1; + + utterance.onend = () => { + setIsPlaying(false); + setIsPaused(false); + onEnd?.(); + }; + utterance.onerror = () => { + setIsPlaying(false); + setIsPaused(false); + }; + + utteranceRef.current = utterance; + window.speechSynthesis.speak(utterance); + setIsPlaying(true); + setIsPaused(false); + }, [text, rate, isMuted, onEnd]); + + const togglePlay = useCallback(() => { + if (isPlaying && !isPaused) { + window.speechSynthesis.pause(); + setIsPaused(true); + } else if (isPaused) { + window.speechSynthesis.resume(); + setIsPaused(false); + } else { + speak(); + } + }, [isPlaying, isPaused, speak]); + + const stop = useCallback(() => { + window.speechSynthesis.cancel(); + setIsPlaying(false); + setIsPaused(false); + }, []); + + // Auto-play when text changes + useEffect(() => { + if (autoPlay && text) { + speak(); + } + return () => { + if (typeof window !== 'undefined') { + window.speechSynthesis.cancel(); + } + }; + }, [text, autoPlay, speak]); + + // Cleanup on unmount + useEffect(() => { + return () => { + if (typeof window !== 'undefined') { + window.speechSynthesis.cancel(); + } + }; + }, []); + + if (!text) return null; + + return ( +
+ {/* Play/Pause button */} + + + {/* Stop button */} + {isPlaying && ( + + )} + + {/* Speed control */} +
+ + +
+ + {/* Mute toggle */} + + + {/* Status indicator */} + {isPlaying && ( + + {isPaused ? '已暂停' : '朗读中...'} + + )} +
+ ); +} diff --git a/desktop/src/components/classroom_player/WhiteboardCanvas.tsx b/desktop/src/components/classroom_player/WhiteboardCanvas.tsx new file mode 100644 index 0000000..820dd52 --- /dev/null +++ b/desktop/src/components/classroom_player/WhiteboardCanvas.tsx @@ -0,0 +1,295 @@ +/** + * WhiteboardCanvas — SVG-based whiteboard for classroom scene rendering. + * + * Supports incremental drawing operations: + * - Text (positioned labels) + * - Shapes (rectangles, circles, arrows) + * - Charts (bar/line/pie via simple SVG) + * - LaTeX (rendered as styled text blocks) + */ + +import { useCallback } from 'react'; +import type { SceneAction } from '../../types/classroom'; + +interface WhiteboardCanvasProps { + items: WhiteboardItem[]; + width?: number; + height?: number; +} + +export interface WhiteboardItem { + type: string; + data: SceneAction; +} + +export function WhiteboardCanvas({ + items, + width = 800, + height = 600, +}: WhiteboardCanvasProps) { + const renderItem = useCallback((item: WhiteboardItem, index: number) => { + switch (item.type) { + case 'whiteboard_draw_text': + return ; + case 'whiteboard_draw_shape': + return ; + case 'whiteboard_draw_chart': + return ; + case 'whiteboard_draw_latex': + return ; + default: + return null; + } + }, []); + + return ( +
+ + {/* Grid background */} + + + + + + + + {/* Rendered items */} + {items.map((item, i) => renderItem(item, i))} + +
+ ); +} + +// --------------------------------------------------------------------------- +// Sub-components +// --------------------------------------------------------------------------- + +interface TextDrawData { + type: 'whiteboard_draw_text'; + x: number; + y: number; + text: string; + fontSize?: number; + color?: string; +} + +function TextItem({ data }: { data: TextDrawData }) { + return ( + + {data.text} + + ); +} + +interface ShapeDrawData { + type: 'whiteboard_draw_shape'; + shape: string; + x: number; + y: number; + width: number; + height: number; + fill?: string; +} + +function ShapeItem({ data }: { data: ShapeDrawData }) { + switch (data.shape) { + case 'circle': + return ( + + ); + case 'arrow': + return ( + + + + + + + + + ); + default: // rectangle + return ( + + ); + } +} + +interface ChartDrawData { + type: 'whiteboard_draw_chart'; + chartType: string; + data: Record; + x: number; + y: number; + width: number; + height: number; +} + +function ChartItem({ data }: { data: ChartDrawData }) { + const chartData = data.data; + const labels = (chartData?.labels as string[]) ?? []; + const values = (chartData?.values as number[]) ?? []; + + if (labels.length === 0 || values.length === 0) return null; + + switch (data.chartType) { + case 'bar': + return ; + case 'line': + return ; + default: + return ; + } +} + +function BarChart({ data, labels, values }: { + data: ChartDrawData; + labels: string[]; + values: number[]; +}) { + const maxVal = Math.max(...values, 1); + const barWidth = data.width / (labels.length * 2); + const chartHeight = data.height - 30; + + return ( + + {values.map((val, i) => { + const barHeight = (val / maxVal) * chartHeight; + return ( + + + + {labels[i]} + + + ); + })} + + ); +} + +function LineChart({ data, labels, values }: { + data: ChartDrawData; + labels: string[]; + values: number[]; +}) { + const maxVal = Math.max(...values, 1); + const chartHeight = data.height - 30; + const stepX = data.width / Math.max(labels.length - 1, 1); + + const points = values.map((val, i) => { + const x = i * stepX; + const y = chartHeight - (val / maxVal) * chartHeight; + return `${x},${y}`; + }).join(' '); + + return ( + + + {values.map((val, i) => { + const x = i * stepX; + const y = chartHeight - (val / maxVal) * chartHeight; + return ( + + + + {labels[i]} + + + ); + })} + + ); +} + +interface LatexDrawData { + type: 'whiteboard_draw_latex'; + latex: string; + x: number; + y: number; +} + +function LatexItem({ data }: { data: LatexDrawData }) { + return ( + + + + {data.latex} + + + ); +} diff --git a/desktop/src/components/classroom_player/index.ts b/desktop/src/components/classroom_player/index.ts new file mode 100644 index 0000000..2c3e87d --- /dev/null +++ b/desktop/src/components/classroom_player/index.ts @@ -0,0 +1,12 @@ +/** + * Classroom Player Components + * + * Re-exports all classroom player components. + */ + +export { ClassroomPlayer } from './ClassroomPlayer'; +export { SceneRenderer } from './SceneRenderer'; +export { AgentChat } from './AgentChat'; +export { NotesSidebar } from './NotesSidebar'; +export { WhiteboardCanvas } from './WhiteboardCanvas'; +export { TtsPlayer } from './TtsPlayer'; diff --git a/desktop/src/hooks/useClassroom.ts b/desktop/src/hooks/useClassroom.ts new file mode 100644 index 0000000..468eeb2 --- /dev/null +++ b/desktop/src/hooks/useClassroom.ts @@ -0,0 +1,76 @@ +/** + * useClassroom — React hook wrapping the classroom store for component consumption. + * + * Provides a simplified interface for classroom generation and chat. + */ + +import { useCallback } from 'react'; +import { + useClassroomStore, + type GenerationRequest, +} from '../store/classroomStore'; +import type { + Classroom, + ClassroomChatMessage, +} from '../types/classroom'; + +export interface UseClassroomReturn { + /** Is generation in progress */ + generating: boolean; + /** Current generation stage name */ + progressStage: string | null; + /** Progress percentage 0-100 */ + progressPercent: number; + /** The active classroom */ + activeClassroom: Classroom | null; + /** Chat messages for active classroom */ + chatMessages: ClassroomChatMessage[]; + /** Is a chat request loading */ + chatLoading: boolean; + /** Error message, if any */ + error: string | null; + /** Start classroom generation */ + startGeneration: (request: GenerationRequest) => Promise; + /** Cancel active generation */ + cancelGeneration: () => void; + /** Send a chat message in the active classroom */ + sendChatMessage: (message: string, sceneContext?: string) => Promise; + /** Clear current error */ + clearError: () => void; +} + +/** + * Hook for classroom generation and multi-agent chat. + * + * Components should use this hook rather than accessing the store directly, + * to keep the rendering logic decoupled from state management. + */ +export function useClassroom(): UseClassroomReturn { + const { + generating, + progressStage, + progressPercent, + activeClassroom, + chatMessages, + chatLoading, + error, + startGeneration, + cancelGeneration, + sendChatMessage, + clearError, + } = useClassroomStore(); + + return { + generating, + progressStage, + progressPercent, + activeClassroom, + chatMessages, + chatLoading, + error, + startGeneration: useCallback((req: GenerationRequest) => startGeneration(req), [startGeneration]), + cancelGeneration: useCallback(() => cancelGeneration(), [cancelGeneration]), + sendChatMessage: useCallback((msg, ctx) => sendChatMessage(msg, ctx), [sendChatMessage]), + clearError: useCallback(() => clearError(), [clearError]), + }; +} diff --git a/desktop/src/index.css b/desktop/src/index.css index 18b8a88..c49980e 100644 --- a/desktop/src/index.css +++ b/desktop/src/index.css @@ -1,27 +1,5 @@ @import "tailwindcss"; -/* Aurora gradient animation for welcome title (DeerFlow-inspired) */ -@keyframes gradient-shift { - 0%, 100% { background-position: 0% 50%; } - 50% { background-position: 100% 50%; } -} - -.aurora-title { - background: linear-gradient( - 135deg, - #f97316 0%, /* orange-500 */ - #ef4444 25%, /* red-500 */ - #f97316 50%, /* orange-500 */ - #fb923c 75%, /* orange-400 */ - #f97316 100% /* orange-500 */ - ); - background-size: 200% 200%; - -webkit-background-clip: text; - background-clip: text; - -webkit-text-fill-color: transparent; - animation: gradient-shift 4s ease infinite; -} - :root { /* Brand Colors - 中性灰色系 */ --color-primary: #374151; /* gray-700 */ @@ -154,3 +132,38 @@ textarea:focus-visible { outline: none !important; box-shadow: none !important; } + +/* === Accessibility: reduced motion === */ +@media (prefers-reduced-motion: reduce) { + *, *::before, *::after { + animation-duration: 0.01ms !important; + animation-iteration-count: 1 !important; + transition-duration: 0.01ms !important; + scroll-behavior: auto !important; + } +} + +/* === Responsive breakpoints for small windows/tablets === */ +@media (max-width: 768px) { + /* Auto-collapse sidebar aside on narrow viewports */ + aside.w-64 { + width: 0 !important; + min-width: 0 !important; + overflow: hidden; + border-right: none !important; + } + aside.w-64.sidebar-open { + width: 260px !important; + min-width: 260px !important; + position: fixed; + z-index: 50; + height: 100vh; + } +} + +@media (max-width: 480px) { + .chat-bubble-assistant, + .chat-bubble-user { + max-width: 95% !important; + } +} diff --git a/desktop/src/lib/audit-logger.ts b/desktop/src/lib/audit-logger.ts index eb6aba8..1399b06 100644 --- a/desktop/src/lib/audit-logger.ts +++ b/desktop/src/lib/audit-logger.ts @@ -3,6 +3,10 @@ * * 为 ZCLAW 前端操作提供统一的审计日志记录功能。 * 记录关键操作(Hand 触发、Agent 创建等)到本地存储。 + * + * @reserved This module is reserved for future audit logging integration. + * It is not currently imported by any component. When audit logging is needed, + * import { logAudit, logAuditSuccess, logAuditFailure } from this module. */ import { createLogger } from './logger'; diff --git a/desktop/src/lib/classroom-adapter.ts b/desktop/src/lib/classroom-adapter.ts new file mode 100644 index 0000000..b1fa7df --- /dev/null +++ b/desktop/src/lib/classroom-adapter.ts @@ -0,0 +1,142 @@ +/** + * Classroom Adapter + * + * Bridges the old ClassroomData type (ClassroomPreviewer) with the new + * Classroom type (ClassroomPlayer + Tauri backend). + */ + +import type { Classroom, GeneratedScene } from '../types/classroom'; +import { SceneType, TeachingStyle, DifficultyLevel } from '../types/classroom'; +import type { ClassroomData, ClassroomScene } from '../components/ClassroomPreviewer'; + +// --------------------------------------------------------------------------- +// Old → New (ClassroomData → Classroom) +// --------------------------------------------------------------------------- + +/** + * Convert a legacy ClassroomData to the new Classroom format. + * Used when opening ClassroomPlayer from Pipeline result previews. + */ +export function adaptToClassroom(data: ClassroomData): Classroom { + const scenes: GeneratedScene[] = data.scenes.map((scene, index) => ({ + id: scene.id, + outlineId: `outline-${index}`, + content: { + title: scene.title, + sceneType: mapSceneType(scene.type), + content: { + heading: scene.content.heading ?? scene.title, + key_points: scene.content.bullets ?? [], + description: scene.content.explanation, + quiz: scene.content.quiz ?? undefined, + }, + actions: [], + durationSeconds: scene.duration ?? 60, + notes: scene.narration, + }, + order: index, + })) as GeneratedScene[]; + + return { + id: data.id, + title: data.title, + description: data.subject, + topic: data.subject, + style: TeachingStyle.Lecture, + level: mapDifficulty(data.difficulty), + totalDuration: data.duration * 60, + objectives: [], + scenes, + agents: [], + metadata: { + generatedAt: new Date(data.createdAt).getTime(), + version: '1.0', + custom: {}, + }, + }; +} + +// --------------------------------------------------------------------------- +// New → Old (Classroom → ClassroomData) +// --------------------------------------------------------------------------- + +/** + * Convert a new Classroom to the legacy ClassroomData format. + * Used when rendering ClassroomPreviewer from new pipeline results. + */ +export function adaptToClassroomData(classroom: Classroom): ClassroomData { + const scenes: ClassroomScene[] = classroom.scenes.map((scene) => { + const data = scene.content.content as Record; + return { + id: scene.id, + title: scene.content.title, + type: mapToLegacySceneType(scene.content.sceneType), + content: { + heading: (data?.heading as string) ?? scene.content.title, + bullets: (data?.key_points as string[]) ?? [], + explanation: (data?.description as string) ?? '', + quiz: (data?.quiz as ClassroomScene['content']['quiz']) ?? undefined, + }, + narration: scene.content.notes, + duration: scene.content.durationSeconds, + }; + }); + + return { + id: classroom.id, + title: classroom.title, + subject: classroom.topic, + difficulty: mapToLegacyDifficulty(classroom.level), + duration: Math.ceil(classroom.totalDuration / 60), + scenes, + outline: { + sections: classroom.scenes.map((scene) => ({ + title: scene.content.title, + scenes: [scene.id], + })), + }, + createdAt: new Date(classroom.metadata.generatedAt).toISOString(), + }; +} + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +function mapSceneType(type: ClassroomScene['type']): SceneType { + switch (type) { + case 'title': return SceneType.Slide; + case 'content': return SceneType.Slide; + case 'quiz': return SceneType.Quiz; + case 'interactive': return SceneType.Interactive; + case 'summary': return SceneType.Text; + default: return SceneType.Slide; + } +} + +function mapToLegacySceneType(sceneType: string): ClassroomScene['type'] { + switch (sceneType) { + case 'quiz': return 'quiz'; + case 'interactive': return 'interactive'; + case 'text': return 'summary'; + default: return 'content'; + } +} + +function mapDifficulty(difficulty: string): DifficultyLevel { + switch (difficulty) { + case '初级': return DifficultyLevel.Beginner; + case '中级': return DifficultyLevel.Intermediate; + case '高级': return DifficultyLevel.Advanced; + default: return DifficultyLevel.Intermediate; + } +} + +function mapToLegacyDifficulty(level: string): ClassroomData['difficulty'] { + switch (level) { + case 'beginner': return '初级'; + case 'advanced': return '高级'; + case 'expert': return '高级'; + default: return '中级'; + } +} diff --git a/desktop/src/lib/error-handling.ts b/desktop/src/lib/error-handling.ts index 444b1d3..f697f70 100644 --- a/desktop/src/lib/error-handling.ts +++ b/desktop/src/lib/error-handling.ts @@ -56,12 +56,19 @@ function initErrorStore(): void { errors: [], addError: (error: AppError) => { + // Dedup: skip if same title+message already exists and undismissed + const isDuplicate = errorStore.errors.some( + (e) => !e.dismissed && e.title === error.title && e.message === error.message + ); + if (isDuplicate) return; + const storedError: StoredError = { ...error, dismissed: false, reported: false, }; - errorStore.errors = [storedError, ...errorStore.errors]; + // Cap at 50 errors to prevent unbounded growth + errorStore.errors = [storedError, ...errorStore.errors].slice(0, 50); // Notify listeners notifyErrorListeners(error); }, diff --git a/desktop/src/lib/kernel-chat.ts b/desktop/src/lib/kernel-chat.ts index bb66043..41d22ed 100644 --- a/desktop/src/lib/kernel-chat.ts +++ b/desktop/src/lib/kernel-chat.ts @@ -103,6 +103,12 @@ export function installChatMethods(ClientClass: { prototype: KernelClient }): vo callbacks.onDelta(streamEvent.delta); break; + case 'thinkingDelta': + if (callbacks.onThinkingDelta) { + callbacks.onThinkingDelta(streamEvent.delta); + } + break; + case 'tool_start': log.debug('Tool started:', streamEvent.name, streamEvent.input); if (callbacks.onTool) { diff --git a/desktop/src/lib/kernel-hands.ts b/desktop/src/lib/kernel-hands.ts index 8ad2dde..eaa8174 100644 --- a/desktop/src/lib/kernel-hands.ts +++ b/desktop/src/lib/kernel-hands.ts @@ -5,8 +5,20 @@ */ import { invoke } from '@tauri-apps/api/core'; +import { listen, type UnlistenFn } from '@tauri-apps/api/event'; +import { createLogger } from './logger'; import type { KernelClient } from './kernel-client'; +const log = createLogger('KernelHands'); + +/** Payload emitted by the Rust backend on `hand-execution-complete` events. */ +export interface HandExecutionCompletePayload { + approvalId: string; + handId: string; + success: boolean; + error?: string | null; +} + export function installHandMethods(ClientClass: { prototype: KernelClient }): void { const proto = ClientClass.prototype as any; @@ -92,7 +104,7 @@ export function installHandMethods(ClientClass: { prototype: KernelClient }): vo */ proto.getHandStatus = async function (this: KernelClient, name: string, runId: string): Promise<{ status: string; result?: unknown }> { try { - return await invoke('hand_run_status', { handName: name, runId }); + return await invoke('hand_run_status', { runId }); } catch (e) { const { createLogger } = await import('./logger'); createLogger('KernelHands').debug('hand_run_status failed', { name, runId, error: e }); @@ -171,4 +183,26 @@ export function installHandMethods(ClientClass: { prototype: KernelClient }): vo proto.respondToApproval = async function (this: KernelClient, approvalId: string, approved: boolean, reason?: string): Promise { return invoke('approval_respond', { id: approvalId, approved, reason }); }; + + // ─── Event Listeners ─── + + /** + * Listen for `hand-execution-complete` events emitted by the Rust backend + * after a hand finishes executing (both from direct trigger and approval flow). + * + * Returns an unlisten function for cleanup. + */ + proto.onHandExecutionComplete = async function ( + this: KernelClient, + callback: (payload: HandExecutionCompletePayload) => void, + ): Promise { + const unlisten = await listen( + 'hand-execution-complete', + (event) => { + log.debug('hand-execution-complete', event.payload); + callback(event.payload); + }, + ); + return unlisten; + }; } diff --git a/desktop/src/lib/kernel-skills.ts b/desktop/src/lib/kernel-skills.ts index c0f3798..092a6bf 100644 --- a/desktop/src/lib/kernel-skills.ts +++ b/desktop/src/lib/kernel-skills.ts @@ -109,7 +109,11 @@ export function installSkillMethods(ClientClass: { prototype: KernelClient }): v }> { return invoke('skill_execute', { id, - context: {}, + context: { + agentId: '', + sessionId: '', + workingDir: '', + }, input: input || {}, }); }; diff --git a/desktop/src/lib/kernel-triggers.ts b/desktop/src/lib/kernel-triggers.ts index 551f67d..1f3c7c0 100644 --- a/desktop/src/lib/kernel-triggers.ts +++ b/desktop/src/lib/kernel-triggers.ts @@ -96,7 +96,12 @@ export function installTriggerMethods(ClientClass: { prototype: KernelClient }): triggerType?: TriggerTypeSpec; }): Promise { try { - return await invoke('trigger_update', { id, updates }); + return await invoke('trigger_update', { + id, + name: updates.name, + enabled: updates.enabled, + handId: updates.handId, + }); } catch (error) { this.log('error', `[TriggersAPI] updateTrigger(${id}) failed: ${this.formatError(error)}`); throw error; diff --git a/desktop/src/lib/kernel-types.ts b/desktop/src/lib/kernel-types.ts index f469f16..009ff41 100644 --- a/desktop/src/lib/kernel-types.ts +++ b/desktop/src/lib/kernel-types.ts @@ -58,6 +58,7 @@ export interface EventCallback { export interface StreamCallbacks { onDelta: (delta: string) => void; + onThinkingDelta?: (delta: string) => void; onTool?: (tool: string, input: string, output: string) => void; onHand?: (name: string, status: string, result?: unknown) => void; onComplete: (inputTokens?: number, outputTokens?: number) => void; @@ -71,6 +72,11 @@ export interface StreamEventDelta { delta: string; } +export interface StreamEventThinkingDelta { + type: 'thinkingDelta'; + delta: string; +} + export interface StreamEventToolStart { type: 'tool_start'; name: string; @@ -114,6 +120,7 @@ export interface StreamEventHandEnd { export type StreamChatEvent = | StreamEventDelta + | StreamEventThinkingDelta | StreamEventToolStart | StreamEventToolEnd | StreamEventIterationStart diff --git a/desktop/src/lib/saas-admin.ts b/desktop/src/lib/saas-admin.ts deleted file mode 100644 index d950b85..0000000 --- a/desktop/src/lib/saas-admin.ts +++ /dev/null @@ -1,233 +0,0 @@ -/** - * SaaS Admin Methods — Mixin - * - * Installs admin panel API methods onto SaaSClient.prototype. - * Uses the same mixin pattern as gateway-api.ts. - * - * Reserved for future admin UI (Next.js admin dashboard). - * These methods are not called by the desktop app but are kept as thin API - * wrappers for when the admin panel is built. - */ - -import type { - ProviderInfo, - CreateProviderRequest, - UpdateProviderRequest, - ModelInfo, - CreateModelRequest, - UpdateModelRequest, - AccountApiKeyInfo, - CreateApiKeyRequest, - AccountPublic, - UpdateAccountRequest, - PaginatedResponse, - TokenInfo, - CreateTokenRequest, - OperationLogInfo, - DashboardStats, - RoleInfo, - CreateRoleRequest, - UpdateRoleRequest, - PermissionTemplate, - CreateTemplateRequest, -} from './saas-types'; - -export function installAdminMethods(ClientClass: { prototype: any }): void { - const proto = ClientClass.prototype; - - // --- Provider Management (Admin) --- - - /** List all providers */ - proto.listProviders = async function (this: { request(method: string, path: string, body?: unknown): Promise }): Promise { - return this.request('GET', '/api/v1/providers'); - }; - - /** Get provider by ID */ - proto.getProvider = async function (this: { request(method: string, path: string, body?: unknown): Promise }, id: string): Promise { - return this.request('GET', `/api/v1/providers/${id}`); - }; - - /** Create a new provider (admin only) */ - proto.createProvider = async function (this: { request(method: string, path: string, body?: unknown): Promise }, data: CreateProviderRequest): Promise { - return this.request('POST', '/api/v1/providers', data); - }; - - /** Update a provider (admin only) */ - proto.updateProvider = async function (this: { request(method: string, path: string, body?: unknown): Promise }, id: string, data: UpdateProviderRequest): Promise { - return this.request('PATCH', `/api/v1/providers/${id}`, data); - }; - - /** Delete a provider (admin only) */ - proto.deleteProvider = async function (this: { request(method: string, path: string, body?: unknown): Promise }, id: string): Promise { - await this.request('DELETE', `/api/v1/providers/${id}`); - }; - - // --- Model Management (Admin) --- - - /** List models, optionally filtered by provider */ - proto.listModelsAdmin = async function (this: { request(method: string, path: string, body?: unknown): Promise }, providerId?: string): Promise { - const qs = providerId ? `?provider_id=${encodeURIComponent(providerId)}` : ''; - return this.request('GET', `/api/v1/models${qs}`); - }; - - /** Get model by ID */ - proto.getModel = async function (this: { request(method: string, path: string, body?: unknown): Promise }, id: string): Promise { - return this.request('GET', `/api/v1/models/${id}`); - }; - - /** Create a new model (admin only) */ - proto.createModel = async function (this: { request(method: string, path: string, body?: unknown): Promise }, data: CreateModelRequest): Promise { - return this.request('POST', '/api/v1/models', data); - }; - - /** Update a model (admin only) */ - proto.updateModel = async function (this: { request(method: string, path: string, body?: unknown): Promise }, id: string, data: UpdateModelRequest): Promise { - return this.request('PATCH', `/api/v1/models/${id}`, data); - }; - - /** Delete a model (admin only) */ - proto.deleteModel = async function (this: { request(method: string, path: string, body?: unknown): Promise }, id: string): Promise { - await this.request('DELETE', `/api/v1/models/${id}`); - }; - - // --- Account API Keys --- - - /** List account's API keys */ - proto.listApiKeys = async function (this: { request(method: string, path: string, body?: unknown): Promise }, providerId?: string): Promise { - const qs = providerId ? `?provider_id=${encodeURIComponent(providerId)}` : ''; - return this.request('GET', `/api/v1/keys${qs}`); - }; - - /** Create a new API key */ - proto.createApiKey = async function (this: { request(method: string, path: string, body?: unknown): Promise }, data: CreateApiKeyRequest): Promise { - return this.request('POST', '/api/v1/keys', data); - }; - - /** Rotate an API key */ - proto.rotateApiKey = async function (this: { request(method: string, path: string, body?: unknown): Promise }, id: string, newKeyValue: string): Promise { - await this.request('POST', `/api/v1/keys/${id}/rotate`, { new_key_value: newKeyValue }); - }; - - /** Revoke an API key */ - proto.revokeApiKey = async function (this: { request(method: string, path: string, body?: unknown): Promise }, id: string): Promise { - await this.request('DELETE', `/api/v1/keys/${id}`); - }; - - // --- Account Management (Admin) --- - - /** List all accounts (admin only) */ - proto.listAccounts = async function (this: { request(method: string, path: string, body?: unknown): Promise }, params?: { page?: number; page_size?: number; role?: string; status?: string; search?: string }): Promise> { - const qs = new URLSearchParams(); - if (params?.page) qs.set('page', String(params.page)); - if (params?.page_size) qs.set('page_size', String(params.page_size)); - if (params?.role) qs.set('role', params.role); - if (params?.status) qs.set('status', params.status); - if (params?.search) qs.set('search', params.search); - const query = qs.toString(); - return this.request>('GET', `/api/v1/accounts${query ? '?' + query : ''}`); - }; - - /** Get account by ID (admin or self) */ - proto.getAccount = async function (this: { request(method: string, path: string, body?: unknown): Promise }, id: string): Promise { - return this.request('GET', `/api/v1/accounts/${id}`); - }; - - /** Update account (admin or self) */ - proto.updateAccount = async function (this: { request(method: string, path: string, body?: unknown): Promise }, id: string, data: UpdateAccountRequest): Promise { - return this.request('PATCH', `/api/v1/accounts/${id}`, data); - }; - - /** Update account status (admin only) */ - proto.updateAccountStatus = async function (this: { request(method: string, path: string, body?: unknown): Promise }, id: string, status: 'active' | 'disabled' | 'suspended'): Promise { - await this.request('PATCH', `/api/v1/accounts/${id}/status`, { status }); - }; - - // --- API Token Management --- - - /** List API tokens for current account */ - proto.listTokens = async function (this: { request(method: string, path: string, body?: unknown): Promise }): Promise { - return this.request('GET', '/api/v1/tokens'); - }; - - /** Create a new API token */ - proto.createToken = async function (this: { request(method: string, path: string, body?: unknown): Promise }, data: CreateTokenRequest): Promise { - return this.request('POST', '/api/v1/tokens', data); - }; - - /** Revoke an API token */ - proto.revokeToken = async function (this: { request(method: string, path: string, body?: unknown): Promise }, id: string): Promise { - await this.request('DELETE', `/api/v1/tokens/${id}`); - }; - - // --- Operation Logs (Admin) --- - - /** List operation logs (admin only) */ - proto.listOperationLogs = async function (this: { request(method: string, path: string, body?: unknown): Promise }, params?: { page?: number; page_size?: number }): Promise { - const qs = new URLSearchParams(); - if (params?.page) qs.set('page', String(params.page)); - if (params?.page_size) qs.set('page_size', String(params.page_size)); - const query = qs.toString(); - return this.request('GET', `/api/v1/logs/operations${query ? '?' + query : ''}`); - }; - - // --- Dashboard Statistics (Admin) --- - - /** Get dashboard statistics (admin only) */ - proto.getDashboardStats = async function (this: { request(method: string, path: string, body?: unknown): Promise }): Promise { - return this.request('GET', '/api/v1/stats/dashboard'); - }; - - // --- Role Management (Admin) --- - - /** List all roles */ - proto.listRoles = async function (this: { request(method: string, path: string, body?: unknown): Promise }): Promise { - return this.request('GET', '/api/v1/roles'); - }; - - /** Get role by ID */ - proto.getRole = async function (this: { request(method: string, path: string, body?: unknown): Promise }, id: string): Promise { - return this.request('GET', `/api/v1/roles/${id}`); - }; - - /** Create a new role (admin only) */ - proto.createRole = async function (this: { request(method: string, path: string, body?: unknown): Promise }, data: CreateRoleRequest): Promise { - return this.request('POST', '/api/v1/roles', data); - }; - - /** Update a role (admin only) */ - proto.updateRole = async function (this: { request(method: string, path: string, body?: unknown): Promise }, id: string, data: UpdateRoleRequest): Promise { - return this.request('PUT', `/api/v1/roles/${id}`, data); - }; - - /** Delete a role (admin only) */ - proto.deleteRole = async function (this: { request(method: string, path: string, body?: unknown): Promise }, id: string): Promise { - await this.request('DELETE', `/api/v1/roles/${id}`); - }; - - // --- Permission Templates --- - - /** List permission templates */ - proto.listPermissionTemplates = async function (this: { request(method: string, path: string, body?: unknown): Promise }): Promise { - return this.request('GET', '/api/v1/permission-templates'); - }; - - /** Get permission template by ID */ - proto.getPermissionTemplate = async function (this: { request(method: string, path: string, body?: unknown): Promise }, id: string): Promise { - return this.request('GET', `/api/v1/permission-templates/${id}`); - }; - - /** Create a permission template (admin only) */ - proto.createPermissionTemplate = async function (this: { request(method: string, path: string, body?: unknown): Promise }, data: CreateTemplateRequest): Promise { - return this.request('POST', '/api/v1/permission-templates', data); - }; - - /** Delete a permission template (admin only) */ - proto.deletePermissionTemplate = async function (this: { request(method: string, path: string, body?: unknown): Promise }, id: string): Promise { - await this.request('DELETE', `/api/v1/permission-templates/${id}`); - }; - - /** Apply permission template to accounts (admin only) */ - proto.applyPermissionTemplate = async function (this: { request(method: string, path: string, body?: unknown): Promise }, templateId: string, accountIds: string[]): Promise<{ ok: boolean; applied_count: number }> { - return this.request<{ ok: boolean; applied_count: number }>('POST', `/api/v1/permission-templates/${templateId}/apply`, { account_ids: accountIds }); - }; -} diff --git a/desktop/src/lib/saas-client.ts b/desktop/src/lib/saas-client.ts index dc56ecf..92b69e2 100644 --- a/desktop/src/lib/saas-client.ts +++ b/desktop/src/lib/saas-client.ts @@ -17,7 +17,6 @@ * - saas-errors.ts — SaaSApiError class * - saas-session.ts — session persistence (load/save/clear) * - saas-auth.ts — login/register/TOTP methods (mixin) - * - saas-admin.ts — admin panel API methods (mixin) * - saas-relay.ts — relay tasks, chat completion, usage (mixin) * - saas-prompt.ts — prompt OTA methods (mixin) * - saas-telemetry.ts — telemetry reporting methods (mixin) @@ -96,26 +95,6 @@ import type { SaaSErrorResponse, RelayTaskInfo, UsageStats, - ProviderInfo, - CreateProviderRequest, - UpdateProviderRequest, - ModelInfo, - CreateModelRequest, - UpdateModelRequest, - AccountApiKeyInfo, - CreateApiKeyRequest, - AccountPublic, - UpdateAccountRequest, - PaginatedResponse, - TokenInfo, - CreateTokenRequest, - OperationLogInfo, - DashboardStats, - RoleInfo, - CreateRoleRequest, - UpdateRoleRequest, - PermissionTemplate, - CreateTemplateRequest, PromptCheckResult, PromptTemplateInfo, PromptVersionInfo, @@ -128,7 +107,7 @@ import { createLogger } from './logger'; const saasLog = createLogger('saas-client'); import { installAuthMethods } from './saas-auth'; -import { installAdminMethods } from './saas-admin'; + import { installRelayMethods } from './saas-relay'; import { installPromptMethods } from './saas-prompt'; import { installTelemetryMethods } from './saas-telemetry'; @@ -140,6 +119,25 @@ export class SaaSClient { private baseUrl: string; private token: string | null = null; + /** + * Refresh mutex: shared Promise to prevent concurrent token refresh. + * When multiple requests hit 401 simultaneously, they all await the same + * refresh Promise instead of triggering N parallel refresh calls. + */ + private _refreshPromise: Promise | null = null; + + /** + * Thread-safe token refresh — coalesces concurrent refresh attempts into one. + * First caller triggers the actual refresh; subsequent callers await the same Promise. + */ + async refreshMutex(): Promise { + if (this._refreshPromise) return this._refreshPromise; + this._refreshPromise = this.refreshToken().finally(() => { + this._refreshPromise = null; + }); + return this._refreshPromise; + } + constructor(baseUrl: string) { this.baseUrl = baseUrl.replace(/\/+$/, ''); } @@ -237,7 +235,7 @@ export class SaaSClient { // 401: 尝试刷新 Token 后重试 (防止递归) if (response.status === 401 && !this._isAuthEndpoint(path) && !_isRefreshRetry) { try { - const newToken = await this.refreshToken(); + const newToken = await this.refreshMutex(); if (newToken) { return this.request(method, path, body, timeoutMs, true); } @@ -394,7 +392,7 @@ export class SaaSClient { * Used for template selection during onboarding. */ async fetchAvailableTemplates(): Promise { - return this.request('GET', '/agent-templates/available'); + return this.request('GET', '/api/v1/agent-templates/available'); } /** @@ -402,13 +400,12 @@ export class SaaSClient { * Returns all fields needed to create an agent from template. */ async fetchTemplateFull(id: string): Promise { - return this.request('GET', `/agent-templates/${id}/full`); + return this.request('GET', `/api/v1/agent-templates/${id}/full`); } } // === Install mixin methods === installAuthMethods(SaaSClient); -installAdminMethods(SaaSClient); installRelayMethods(SaaSClient); installPromptMethods(SaaSClient); installTelemetryMethods(SaaSClient); @@ -429,57 +426,6 @@ export interface SaaSClient { verifyTotp(code: string): Promise; disableTotp(password: string): Promise; - // --- Admin: Providers (saas-admin.ts) --- - listProviders(): Promise; - getProvider(id: string): Promise; - createProvider(data: CreateProviderRequest): Promise; - updateProvider(id: string, data: UpdateProviderRequest): Promise; - deleteProvider(id: string): Promise; - - // --- Admin: Models (saas-admin.ts) --- - listModelsAdmin(providerId?: string): Promise; - getModel(id: string): Promise; - createModel(data: CreateModelRequest): Promise; - updateModel(id: string, data: UpdateModelRequest): Promise; - deleteModel(id: string): Promise; - - // --- Admin: API Keys (saas-admin.ts) --- - listApiKeys(providerId?: string): Promise; - createApiKey(data: CreateApiKeyRequest): Promise; - rotateApiKey(id: string, newKeyValue: string): Promise; - revokeApiKey(id: string): Promise; - - // --- Admin: Accounts (saas-admin.ts) --- - listAccounts(params?: { page?: number; page_size?: number; role?: string; status?: string; search?: string }): Promise>; - getAccount(id: string): Promise; - updateAccount(id: string, data: UpdateAccountRequest): Promise; - updateAccountStatus(id: string, status: 'active' | 'disabled' | 'suspended'): Promise; - - // --- Admin: Tokens (saas-admin.ts) --- - listTokens(): Promise; - createToken(data: CreateTokenRequest): Promise; - revokeToken(id: string): Promise; - - // --- Admin: Logs (saas-admin.ts) --- - listOperationLogs(params?: { page?: number; page_size?: number }): Promise; - - // --- Admin: Dashboard (saas-admin.ts) --- - getDashboardStats(): Promise; - - // --- Admin: Roles (saas-admin.ts) --- - listRoles(): Promise; - getRole(id: string): Promise; - createRole(data: CreateRoleRequest): Promise; - updateRole(id: string, data: UpdateRoleRequest): Promise; - deleteRole(id: string): Promise; - - // --- Admin: Permission Templates (saas-admin.ts) --- - listPermissionTemplates(): Promise; - getPermissionTemplate(id: string): Promise; - createPermissionTemplate(data: CreateTemplateRequest): Promise; - deletePermissionTemplate(id: string): Promise; - applyPermissionTemplate(templateId: string, accountIds: string[]): Promise<{ ok: boolean; applied_count: number }>; - // --- Relay (saas-relay.ts) --- listRelayTasks(query?: { status?: string; page?: number; page_size?: number }): Promise; getRelayTask(taskId: string): Promise; diff --git a/desktop/src/lib/saas-relay.ts b/desktop/src/lib/saas-relay.ts index 2d481fe..6a0322b 100644 --- a/desktop/src/lib/saas-relay.ts +++ b/desktop/src/lib/saas-relay.ts @@ -55,6 +55,7 @@ export function installRelayMethods(ClientClass: { prototype: any }): void { _serverReachable: boolean; _isAuthEndpoint(path: string): boolean; refreshToken(): Promise; + refreshMutex(): Promise; }, body: unknown, signal?: AbortSignal, @@ -87,7 +88,7 @@ export function installRelayMethods(ClientClass: { prototype: any }): void { // On 401, attempt token refresh once if (response.status === 401 && attempt === 0 && !this._isAuthEndpoint('/api/v1/relay/chat/completions')) { try { - const newToken = await this.refreshToken(); + const newToken = await this.refreshMutex(); if (newToken) continue; // Retry with refreshed token } catch (e) { logger.debug('Token refresh failed', { error: e }); diff --git a/desktop/src/lib/secure-storage.ts b/desktop/src/lib/secure-storage.ts index 82cf598..ff0bd97 100644 --- a/desktop/src/lib/secure-storage.ts +++ b/desktop/src/lib/secure-storage.ts @@ -299,36 +299,6 @@ function readLocalStorageBackup(key: string): string | null { } } -/** - * Synchronous versions for compatibility with existing code - * These use localStorage only and are provided for gradual migration - */ -export const secureStorageSync = { - /** - * Synchronously get a value from localStorage (for migration only) - * @deprecated Use async secureStorage.get() instead - */ - get(key: string): string | null { - return readLocalStorageBackup(key); - }, - - /** - * Synchronously set a value in localStorage (for migration only) - * @deprecated Use async secureStorage.set() instead - */ - set(key: string, value: string): void { - writeLocalStorageBackup(key, value); - }, - - /** - * Synchronously delete a value from localStorage (for migration only) - * @deprecated Use async secureStorage.delete() instead - */ - delete(key: string): void { - clearLocalStorageBackup(key); - }, -}; - // === Device Keys Secure Storage === /** diff --git a/desktop/src/lib/security-index.ts b/desktop/src/lib/security-index.ts index fc91cb9..1143cdb 100644 --- a/desktop/src/lib/security-index.ts +++ b/desktop/src/lib/security-index.ts @@ -47,7 +47,6 @@ export type { EncryptedData } from './crypto-utils'; // Re-export secure storage export { secureStorage, - secureStorageSync, isSecureStorageAvailable, storeDeviceKeys, getDeviceKeys, diff --git a/desktop/src/store/chat/artifactStore.ts b/desktop/src/store/chat/artifactStore.ts new file mode 100644 index 0000000..bd9747c --- /dev/null +++ b/desktop/src/store/chat/artifactStore.ts @@ -0,0 +1,54 @@ +/** + * ArtifactStore — manages the artifact panel state. + * + * Extracted from chatStore.ts as part of the structured refactor. + * This store has zero external dependencies — the simplest slice to extract. + * + * @see docs/superpowers/specs/2026-04-02-chatstore-refactor-design.md §3.5 + */ + +import { create } from 'zustand'; +import type { ArtifactFile } from '../../components/ai/ArtifactPanel'; + +// --------------------------------------------------------------------------- +// State interface +// --------------------------------------------------------------------------- + +export interface ArtifactState { + /** All artifacts generated in the current session */ + artifacts: ArtifactFile[]; + /** Currently selected artifact ID */ + selectedArtifactId: string | null; + /** Whether the artifact panel is open */ + artifactPanelOpen: boolean; + + // Actions + addArtifact: (artifact: ArtifactFile) => void; + selectArtifact: (id: string | null) => void; + setArtifactPanelOpen: (open: boolean) => void; + clearArtifacts: () => void; +} + +// --------------------------------------------------------------------------- +// Store +// --------------------------------------------------------------------------- + +export const useArtifactStore = create()((set) => ({ + artifacts: [], + selectedArtifactId: null, + artifactPanelOpen: false, + + addArtifact: (artifact: ArtifactFile) => + set((state) => ({ + artifacts: [...state.artifacts, artifact], + selectedArtifactId: artifact.id, + artifactPanelOpen: true, + })), + + selectArtifact: (id: string | null) => set({ selectedArtifactId: id }), + + setArtifactPanelOpen: (open: boolean) => set({ artifactPanelOpen: open }), + + clearArtifacts: () => + set({ artifacts: [], selectedArtifactId: null, artifactPanelOpen: false }), +})); diff --git a/desktop/src/store/chat/conversationStore.ts b/desktop/src/store/chat/conversationStore.ts new file mode 100644 index 0000000..808f60e --- /dev/null +++ b/desktop/src/store/chat/conversationStore.ts @@ -0,0 +1,368 @@ +/** + * ConversationStore — manages conversation lifecycle, agent switching, and persistence. + * + * Extracted from chatStore.ts as part of the structured refactor. + * Responsible for: conversation CRUD, agent list/sync, session/model state. + * + * @see docs/superpowers/specs/2026-04-02-chatstore-refactor-design.md §3.2 + */ + +import { create } from 'zustand'; +import { persist } from 'zustand/middleware'; +import { generateRandomString } from '../lib/crypto-utils'; +import { createLogger } from '../lib/logger'; +import type { Message } from './chatStore'; + +const log = createLogger('ConversationStore'); + +// --------------------------------------------------------------------------- +// Types +// --------------------------------------------------------------------------- + +export interface Conversation { + id: string; + title: string; + messages: Message[]; + sessionKey: string | null; + agentId: string | null; + createdAt: Date; + updatedAt: Date; +} + +export interface Agent { + id: string; + name: string; + icon: string; + color: string; + lastMessage: string; + time: string; +} + +export interface AgentProfileLike { + id: string; + name: string; + nickname?: string; + role?: string; +} + +// Re-export Message for internal use (avoids circular imports during migration) +export type { Message }; + +// --------------------------------------------------------------------------- +// State interface +// --------------------------------------------------------------------------- + +export interface ConversationState { + conversations: Conversation[]; + currentConversationId: string | null; + agents: Agent[]; + currentAgent: Agent | null; + sessionKey: string | null; + currentModel: string; + + // Actions + newConversation: (currentMessages: Message[]) => Conversation[]; + switchConversation: (id: string, currentMessages: Message[]) => { + conversations: Conversation[]; + messages: Message[]; + sessionKey: string | null; + currentAgent: Agent; + currentConversationId: string; + isStreaming: boolean; + } | null; + deleteConversation: (id: string, currentConversationId: string | null) => { + conversations: Conversation[]; + resetMessages: boolean; + }; + setCurrentAgent: (agent: Agent, currentMessages: Message[]) => { + conversations: Conversation[]; + currentAgent: Agent; + messages: Message[]; + sessionKey: string | null; + isStreaming: boolean; + currentConversationId: string | null; + }; + syncAgents: (profiles: AgentProfileLike[]) => { + agents: Agent[]; + currentAgent: Agent; + }; + setCurrentModel: (model: string) => void; + upsertActiveConversation: (currentMessages: Message[]) => Conversation[]; + getCurrentConversationId: () => string | null; + getCurrentAgent: () => Agent | null; + getSessionKey: () => string | null; +} + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +function generateConvId(): string { + return `conv_${Date.now()}_${generateRandomString(4)}`; +} + +function deriveTitle(messages: Message[]): string { + const firstUser = messages.find(m => m.role === 'user'); + if (firstUser) { + const text = firstUser.content.trim(); + return text.length > 30 ? text.slice(0, 30) + '...' : text; + } + return '新对话'; +} + +const DEFAULT_AGENT: Agent = { + id: '1', + name: 'ZCLAW', + icon: '🦞', + color: 'bg-gradient-to-br from-orange-500 to-red-500', + lastMessage: '发送消息开始对话', + time: '', +}; + +export { DEFAULT_AGENT }; + +export function toChatAgent(profile: AgentProfileLike): Agent { + return { + id: profile.id, + name: profile.name, + icon: profile.nickname?.slice(0, 1) || '🦞', + color: 'bg-gradient-to-br from-orange-500 to-red-500', + lastMessage: profile.role || '新分身', + time: '', + }; +} + +export function resolveConversationAgentId(agent: Agent | null): string | null { + if (!agent || agent.id === DEFAULT_AGENT.id) { + return null; + } + return agent.id; +} + +export function resolveGatewayAgentId(agent: Agent | null): string | undefined { + if (!agent || agent.id === DEFAULT_AGENT.id || agent.id.startsWith('clone_')) { + return undefined; + } + return agent.id; +} + +export function resolveAgentForConversation(agentId: string | null, agents: Agent[]): Agent { + if (!agentId) { + return DEFAULT_AGENT; + } + return agents.find((agent) => agent.id === agentId) || DEFAULT_AGENT; +} + +function upsertActiveConversation( + conversations: Conversation[], + messages: Message[], + sessionKey: string | null, + currentConversationId: string | null, + currentAgent: Agent | null, +): Conversation[] { + if (messages.length === 0) { + return conversations; + } + + const currentId = currentConversationId || generateConvId(); + const existingIdx = conversations.findIndex((conv) => conv.id === currentId); + const nextConversation: Conversation = { + id: currentId, + title: deriveTitle(messages), + messages: [...messages], + sessionKey, + agentId: resolveConversationAgentId(currentAgent), + createdAt: existingIdx >= 0 ? conversations[existingIdx].createdAt : new Date(), + updatedAt: new Date(), + }; + + if (existingIdx >= 0) { + const updated = [...conversations]; + updated[existingIdx] = nextConversation; + return updated; + } + + return [nextConversation, ...conversations]; +} + +// --------------------------------------------------------------------------- +// Store +// --------------------------------------------------------------------------- + +export const useConversationStore = create()( + persist( + (set, get) => ({ + conversations: [], + currentConversationId: null, + agents: [DEFAULT_AGENT], + currentAgent: DEFAULT_AGENT, + sessionKey: null, + currentModel: 'glm-4-flash', + + newConversation: (currentMessages: Message[]) => { + const state = get(); + const conversations = upsertActiveConversation( + [...state.conversations], currentMessages, state.sessionKey, + state.currentConversationId, state.currentAgent, + ); + set({ + conversations, + sessionKey: null, + currentConversationId: null, + }); + return conversations; + }, + + switchConversation: (id: string, currentMessages: Message[]) => { + const state = get(); + const conversations = upsertActiveConversation( + [...state.conversations], currentMessages, state.sessionKey, + state.currentConversationId, state.currentAgent, + ); + + const target = conversations.find(c => c.id === id); + if (target) { + set({ + conversations, + currentAgent: resolveAgentForConversation(target.agentId, state.agents), + currentConversationId: target.id, + }); + return { + conversations, + messages: [...target.messages], + sessionKey: target.sessionKey, + currentAgent: resolveAgentForConversation(target.agentId, state.agents), + currentConversationId: target.id, + isStreaming: false, + }; + } + return null; + }, + + deleteConversation: (id: string, currentConversationId: string | null) => { + const state = get(); + const conversations = state.conversations.filter(c => c.id !== id); + const resetMessages = currentConversationId === id; + if (resetMessages) { + set({ conversations, currentConversationId: null, sessionKey: null }); + } else { + set({ conversations }); + } + return { conversations, resetMessages }; + }, + + setCurrentAgent: (agent: Agent, currentMessages: Message[]) => { + const state = get(); + if (state.currentAgent?.id === agent.id) { + set({ currentAgent: agent }); + return { + conversations: state.conversations, + currentAgent: agent, + messages: currentMessages, + sessionKey: state.sessionKey, + isStreaming: false, + currentConversationId: state.currentConversationId, + }; + } + + const conversations = upsertActiveConversation( + [...state.conversations], currentMessages, state.sessionKey, + state.currentConversationId, state.currentAgent, + ); + + const agentConversation = conversations.find(c => + c.agentId === agent.id || + (agent.id === DEFAULT_AGENT.id && c.agentId === null) + ); + + if (agentConversation) { + set({ + conversations, + currentAgent: agent, + currentConversationId: agentConversation.id, + }); + return { + conversations, + currentAgent: agent, + messages: [...agentConversation.messages], + sessionKey: agentConversation.sessionKey, + isStreaming: false, + currentConversationId: agentConversation.id, + }; + } + + set({ + conversations, + currentAgent: agent, + sessionKey: null, + currentConversationId: null, + }); + return { + conversations, + currentAgent: agent, + messages: [], + sessionKey: null, + isStreaming: false, + currentConversationId: null, + }; + }, + + syncAgents: (profiles: AgentProfileLike[]) => { + const state = get(); + const cloneAgents = profiles.length > 0 ? profiles.map(toChatAgent) : []; + const agents = cloneAgents.length > 0 + ? [DEFAULT_AGENT, ...cloneAgents] + : [DEFAULT_AGENT]; + const currentAgent = state.currentConversationId + ? resolveAgentForConversation( + state.conversations.find((conv) => conv.id === state.currentConversationId)?.agentId || null, + agents + ) + : state.currentAgent + ? agents.find((a) => a.id === state.currentAgent?.id) || agents[0] + : agents[0]; + + set({ agents, currentAgent }); + return { agents, currentAgent }; + }, + + setCurrentModel: (model: string) => set({ currentModel: model }), + + upsertActiveConversation: (currentMessages: Message[]) => { + const state = get(); + const conversations = upsertActiveConversation( + [...state.conversations], currentMessages, state.sessionKey, + state.currentConversationId, state.currentAgent, + ); + set({ conversations }); + return conversations; + }, + + getCurrentConversationId: () => get().currentConversationId, + getCurrentAgent: () => get().currentAgent, + getSessionKey: () => get().sessionKey, +}), + { + name: 'zclaw-conversation-storage', + partialize: (state) => ({ + conversations: state.conversations, + currentModel: state.currentModel, + currentAgentId: state.currentAgent?.id, + currentConversationId: state.currentConversationId, + }), + onRehydrateStorage: () => (state) => { + if (state?.conversations) { + for (const conv of state.conversations) { + conv.createdAt = new Date(conv.createdAt); + conv.updatedAt = new Date(conv.updatedAt); + for (const msg of conv.messages) { + msg.timestamp = new Date(msg.timestamp); + msg.streaming = false; + msg.optimistic = false; + } + } + } + }, + }, + ), +); diff --git a/desktop/src/store/chatStore.ts b/desktop/src/store/chatStore.ts index 80f4b69..06b7182 100644 --- a/desktop/src/store/chatStore.ts +++ b/desktop/src/store/chatStore.ts @@ -103,10 +103,6 @@ interface ChatState { chatMode: ChatModeType; // Follow-up suggestions suggestions: string[]; - // Artifacts (DeerFlow-inspired) - artifacts: import('../components/ai/ArtifactPanel').ArtifactFile[]; - selectedArtifactId: string | null; - artifactPanelOpen: boolean; addMessage: (message: Message) => void; updateMessage: (id: string, updates: Partial) => void; @@ -128,11 +124,6 @@ interface ChatState { setSuggestions: (suggestions: string[]) => void; addSubtask: (messageId: string, task: Subtask) => void; updateSubtask: (messageId: string, taskId: string, updates: Partial) => void; - // Artifact management (DeerFlow-inspired) - addArtifact: (artifact: import('../components/ai/ArtifactPanel').ArtifactFile) => void; - selectArtifact: (id: string | null) => void; - setArtifactPanelOpen: (open: boolean) => void; - clearArtifacts: () => void; } function generateConvId(): string { @@ -271,10 +262,6 @@ export const useChatStore = create()( totalOutputTokens: 0, chatMode: 'thinking' as ChatModeType, suggestions: [], - artifacts: [], - selectedArtifactId: null, - artifactPanelOpen: false, - addMessage: (message: Message) => set((state) => ({ messages: [...state.messages, message] })), @@ -401,6 +388,10 @@ export const useChatStore = create()( }, sendMessage: async (content: string) => { + // Concurrency guard: prevent rapid double-click bypassing UI-level isStreaming check. + // React re-render is async — two clicks within the same frame both read isStreaming=false. + if (get().isStreaming) return; + const { addMessage, currentAgent, sessionKey } = get(); // Clear stale suggestions when user sends a new message set({ suggestions: [] }); @@ -436,27 +427,10 @@ export const useChatStore = create()( // Context compaction is handled by the kernel (AgentLoop with_compaction_threshold). // Frontend no longer performs duplicate compaction — see crates/zclaw-runtime/src/compaction.rs. - - // Build memory-enhanced content using layered context (L0/L1/L2) - let enhancedContent = content; - try { - const contextResult = await intelligenceClient.memory.buildContext( - agentId, - content, - 500, // token budget for memory context - ); - if (contextResult.systemPromptAddition) { - const systemPrompt = await intelligenceClient.identity.buildPrompt( - agentId, - contextResult.systemPromptAddition, - ); - if (systemPrompt) { - enhancedContent = `\n${systemPrompt}\n\n\n${content}`; - } - } - } catch (err) { - log.warn('Memory context build failed, proceeding without:', err); - } + // Memory context injection is handled by backend MemoryMiddleware (before_completion), + // which injects relevant memories into the system prompt. Frontend must NOT duplicate + // this by embedding old conversation memories into the user message content — that causes + // context leaking (old conversations appearing in new chat thinking/output). // Add user message (original content for display) // Mark as optimistic -- will be cleared when server confirms via onComplete @@ -504,7 +478,7 @@ export const useChatStore = create()( // Try streaming first (ZCLAW WebSocket) const result = await client.chatStream( - enhancedContent, + content, { onDelta: (delta: string) => { // Update message content directly (works for both KernelClient and GatewayClient) @@ -516,6 +490,15 @@ export const useChatStore = create()( ), })); }, + onThinkingDelta: (delta: string) => { + set((s) => ({ + messages: s.messages.map((m) => + m.id === assistantId + ? { ...m, thinkingContent: (m.thinkingContent || '') + delta } + : m + ), + })); + }, onTool: (tool: string, input: string, output: string) => { const step: ToolCallStep = { id: `step_${Date.now()}_${generateRandomString(4)}`, @@ -732,20 +715,6 @@ export const useChatStore = create()( ), })), - // Artifact management (DeerFlow-inspired) - addArtifact: (artifact) => - set((state) => ({ - artifacts: [...state.artifacts, artifact], - selectedArtifactId: artifact.id, - artifactPanelOpen: true, - })), - - selectArtifact: (id) => set({ selectedArtifactId: id }), - - setArtifactPanelOpen: (open) => set({ artifactPanelOpen: open }), - - clearArtifacts: () => set({ artifacts: [], selectedArtifactId: null, artifactPanelOpen: false }), - initStreamListener: () => { const client = getClient(); diff --git a/desktop/src/store/classroomStore.ts b/desktop/src/store/classroomStore.ts new file mode 100644 index 0000000..31673ab --- /dev/null +++ b/desktop/src/store/classroomStore.ts @@ -0,0 +1,223 @@ +/** + * Classroom Store + * + * Zustand store for classroom generation, chat messages, + * and active classroom data. Uses Tauri invoke for backend calls. + */ + +import { create } from 'zustand'; +import { invoke } from '@tauri-apps/api/core'; +import { listen } from '@tauri-apps/api/event'; +import type { + Classroom, + ClassroomChatMessage, +} from '../types/classroom'; +import { createLogger } from '../lib/logger'; + +const log = createLogger('ClassroomStore'); + +// --------------------------------------------------------------------------- +// Types +// --------------------------------------------------------------------------- + +export interface GenerationRequest { + topic: string; + document?: string; + style?: string; + level?: string; + targetDurationMinutes?: number; + sceneCount?: number; + customInstructions?: string; + language?: string; +} + +export interface GenerationResult { + classroomId: string; +} + +export interface GenerationProgressEvent { + topic: string; + stage: string; + progress: number; + activity: string; +} + +// --------------------------------------------------------------------------- +// Store interface +// --------------------------------------------------------------------------- + +export interface ClassroomState { + /** Currently generating classroom */ + generating: boolean; + /** Generation progress stage */ + progressStage: string | null; + progressPercent: number; + progressActivity: string; + /** Topic being generated (used for cancel) */ + generatingTopic: string | null; + /** The active classroom */ + activeClassroom: Classroom | null; + /** Whether the ClassroomPlayer overlay is open */ + classroomOpen: boolean; + /** Chat messages for the active classroom */ + chatMessages: ClassroomChatMessage[]; + /** Whether chat is loading */ + chatLoading: boolean; + /** Generation error message */ + error: string | null; +} + +export interface ClassroomActions { + startGeneration: (request: GenerationRequest) => Promise; + cancelGeneration: () => void; + loadClassroom: (id: string) => Promise; + setActiveClassroom: (classroom: Classroom) => void; + openClassroom: () => void; + closeClassroom: () => void; + sendChatMessage: (message: string, sceneContext?: string) => Promise; + clearError: () => void; + reset: () => void; +} + +export type ClassroomStore = ClassroomState & ClassroomActions; + +// --------------------------------------------------------------------------- +// Store +// --------------------------------------------------------------------------- + +export const useClassroomStore = create()((set, get) => ({ + generating: false, + progressStage: null, + progressPercent: 0, + progressActivity: '', + generatingTopic: null, + activeClassroom: null, + classroomOpen: false, + chatMessages: [], + chatLoading: false, + error: null, + + startGeneration: async (request) => { + set({ + generating: true, + progressStage: 'agent_profiles', + progressPercent: 0, + progressActivity: 'Starting generation...', + generatingTopic: request.topic, + error: null, + }); + + // Listen for progress events from Rust + const unlisten = await listen('classroom:progress', (event) => { + const { stage, progress, activity } = event.payload; + set({ + progressStage: stage, + progressPercent: progress, + progressActivity: activity, + }); + }); + + try { + const result = await invoke('classroom_generate', { request }); + set({ generating: false }); + await get().loadClassroom(result.classroomId); + set({ classroomOpen: true }); + return result.classroomId; + } catch (e) { + const msg = e instanceof Error ? e.message : String(e); + log.error('Generation failed', { error: msg }); + set({ generating: false, error: msg }); + throw e; + } finally { + unlisten(); + } + }, + + cancelGeneration: () => { + const topic = get().generatingTopic; + if (topic) { + invoke('classroom_cancel_generation', { topic }).catch(() => {}); + } + set({ generating: false, generatingTopic: null }); + }, + + loadClassroom: async (id) => { + try { + const classroom = await invoke('classroom_get', { classroomId: id }); + set({ activeClassroom: classroom, chatMessages: [] }); + } catch (e) { + const msg = e instanceof Error ? e.message : String(e); + log.error('Failed to load classroom', { error: msg }); + set({ error: msg }); + } + }, + + setActiveClassroom: (classroom) => { + set({ activeClassroom: classroom, chatMessages: [], classroomOpen: true }); + }, + + openClassroom: () => { + set({ classroomOpen: true }); + }, + + closeClassroom: () => { + set({ classroomOpen: false }); + }, + + sendChatMessage: async (message, sceneContext) => { + const classroom = get().activeClassroom; + if (!classroom) { + log.error('No active classroom'); + return; + } + + // Create a local user message for display + const userMsg: ClassroomChatMessage = { + id: `user-${Date.now()}`, + agentId: 'user', + agentName: '你', + agentAvatar: '👤', + content: message, + timestamp: Date.now(), + role: 'user', + color: '#3b82f6', + }; + + set((state) => ({ + chatMessages: [...state.chatMessages, userMsg], + chatLoading: true, + })); + + try { + const responses = await invoke('classroom_chat', { + request: { + classroomId: classroom.id, + userMessage: message, + sceneContext: sceneContext ?? null, + }, + }); + set((state) => ({ + chatMessages: [...state.chatMessages, ...responses], + chatLoading: false, + })); + } catch (e) { + const msg = e instanceof Error ? e.message : String(e); + log.error('Chat failed', { error: msg }); + set({ chatLoading: false }); + } + }, + + clearError: () => set({ error: null }), + + reset: () => set({ + generating: false, + progressStage: null, + progressPercent: 0, + progressActivity: '', + activeClassroom: null, + classroomOpen: false, + chatMessages: [], + chatLoading: false, + error: null, + }), +})); diff --git a/desktop/src/store/index.ts b/desktop/src/store/index.ts index 75d0083..2218d64 100644 --- a/desktop/src/store/index.ts +++ b/desktop/src/store/index.ts @@ -55,6 +55,17 @@ export type { SessionOptions, } from '../components/BrowserHand/templates/types'; +// === Classroom Store === +export { useClassroomStore } from './classroomStore'; +export type { + ClassroomState, + ClassroomActions, + ClassroomStore, + GenerationRequest, + GenerationResult, + GenerationProgressEvent, +} from './classroomStore'; + // === Store Initialization === import { getClient } from './connectionStore'; diff --git a/desktop/src/store/saasStore.ts b/desktop/src/store/saasStore.ts index a34dac7..3be4bfa 100644 --- a/desktop/src/store/saasStore.ts +++ b/desktop/src/store/saasStore.ts @@ -536,6 +536,27 @@ export const useSaaSStore = create((set, get) => { // Update last sync timestamp localStorage.setItem(lastSyncKey, result.pulled_at); log.info(`Synced ${result.configs.length} config items from SaaS`); + + // Propagate Kernel-relevant configs to Rust backend + const kernelCategories = ['agent', 'llm']; + const kernelConfigs = result.configs.filter( + (c) => kernelCategories.includes(c.category) && c.value !== null + ); + if (kernelConfigs.length > 0) { + try { + const { invoke } = await import('@tauri-apps/api/core'); + await invoke('kernel_apply_saas_config', { + configs: kernelConfigs.map((c) => ({ + category: c.category, + key: c.key, + value: c.value, + })), + }); + log.info(`Propagated ${kernelConfigs.length} Kernel configs to Rust backend`); + } catch (invokeErr: unknown) { + log.warn('Failed to propagate configs to Kernel (non-fatal):', invokeErr); + } + } } catch (err: unknown) { log.warn('Failed to sync config from SaaS:', err); } diff --git a/desktop/src/types/chat.ts b/desktop/src/types/chat.ts new file mode 100644 index 0000000..d4ea95c --- /dev/null +++ b/desktop/src/types/chat.ts @@ -0,0 +1,133 @@ +/** + * Unified chat types for the ZCLAW desktop chat system. + * + * This module consolidates types previously scattered across + * chatStore.ts, session.ts, and component-level type exports. + */ + +// --- Re-export from component modules for backward compat --- +export type { ChatModeType, ChatModeConfig } from '../components/ai/ChatMode'; +export { CHAT_MODES } from '../components/ai/ChatMode'; +export type { Subtask } from '../components/ai/TaskProgress'; +export type { ToolCallStep } from '../components/ai/ToolCallChain'; +export type { ArtifactFile } from '../components/ai/ArtifactPanel'; + +// --- Core chat types --- + +export interface MessageFile { + name: string; + path?: string; + size?: number; + type?: string; +} + +export interface CodeBlock { + language?: string; + filename?: string; + content?: string; +} + +/** + * Unified message type for all chat messages. + * Supersedes both ChatStore.Message (6 roles) and SessionMessage (3 roles). + */ +export interface ChatMessage { + id: string; + role: 'user' | 'assistant' | 'tool' | 'hand' | 'workflow' | 'system'; + content: string; + timestamp: Date; + streaming?: boolean; + optimistic?: boolean; + runId?: string; + + // Thinking/reasoning + thinkingContent?: string; + + // Error & retry + error?: string; + /** Preserved original content before error overlay, used for retry */ + originalContent?: string; + + // Tool call chain + toolSteps?: import('../components/ai/ToolCallChain').ToolCallStep[]; + toolName?: string; + toolInput?: string; + toolOutput?: string; + + // Hand event fields + handName?: string; + handStatus?: string; + handResult?: unknown; + + // Workflow event fields + workflowId?: string; + workflowStep?: string; + workflowStatus?: string; + workflowResult?: unknown; + + // Sub-agent task tracking + subtasks?: import('../components/ai/TaskProgress').Subtask[]; + + // Attachments + files?: MessageFile[]; + codeBlocks?: CodeBlock[]; + + // Metadata + metadata?: { + inputTokens?: number; + outputTokens?: number; + model?: string; + }; +} + +/** + * A conversation container with messages, session key, and agent binding. + */ +export interface Conversation { + id: string; + title: string; + messages: ChatMessage[]; + sessionKey: string | null; + agentId: string | null; + createdAt: Date; + updatedAt: Date; +} + +/** + * Lightweight agent representation for the chat UI sidebar. + * Distinct from types/agent.ts Agent (which is a backend entity). + */ +export interface ChatAgent { + id: string; + name: string; + icon: string; + color: string; + lastMessage: string; + time: string; +} + +/** + * Minimal profile shape for agent sync operations. + */ +export interface AgentProfileLike { + id: string; + name: string; + nickname?: string; + role?: string; +} + +/** + * Token usage reported on stream completion. + */ +export interface TokenUsage { + inputTokens: number; + outputTokens: number; +} + +/** + * Context passed when sending a message. + */ +export interface SendMessageContext { + files?: MessageFile[]; + parentMessageId?: string; +} diff --git a/desktop/src/types/classroom.ts b/desktop/src/types/classroom.ts new file mode 100644 index 0000000..850910d --- /dev/null +++ b/desktop/src/types/classroom.ts @@ -0,0 +1,181 @@ +/** + * Classroom Generation Types + * + * Mirror of Rust `zclaw-kernel::generation` module types. + * Used by classroom player, hooks, and store. + */ + +// --- Agent Types --- + +export enum AgentRole { + Teacher = 'teacher', + Assistant = 'assistant', + Student = 'student', +} + +export interface AgentProfile { + id: string; + name: string; + role: AgentRole; + persona: string; + avatar: string; + color: string; + allowedActions: string[]; + priority: number; +} + +// --- Scene Types --- + +export enum SceneType { + Slide = 'slide', + Quiz = 'quiz', + Interactive = 'interactive', + Pbl = 'pbl', + Discussion = 'discussion', + Media = 'media', + Text = 'text', +} + +export enum GenerationStage { + AgentProfiles = 'agent_profiles', + Outline = 'outline', + Scene = 'scene', + Complete = 'complete', +} + +// --- Scene Actions --- + +export type SceneAction = + | { type: 'speech'; text: string; agentRole: string } + | { type: 'whiteboard_draw_text'; x: number; y: number; text: string; fontSize?: number; color?: string } + | { type: 'whiteboard_draw_shape'; shape: string; x: number; y: number; width: number; height: number; fill?: string } + | { type: 'whiteboard_draw_chart'; chartType: string; data: unknown; x: number; y: number; width: number; height: number } + | { type: 'whiteboard_draw_latex'; latex: string; x: number; y: number } + | { type: 'whiteboard_clear' } + | { type: 'slideshow_spotlight'; elementId: string } + | { type: 'slideshow_next' } + | { type: 'quiz_show'; quizId: string } + | { type: 'discussion'; topic: string; durationSeconds?: number }; + +// --- Content Structures --- + +export interface SceneContent { + title: string; + sceneType: SceneType; + content: Record; + actions: SceneAction[]; + durationSeconds: number; + notes?: string; +} + +export interface OutlineItem { + id: string; + title: string; + description: string; + sceneType: SceneType; + keyPoints: string[]; + durationSeconds: number; + dependencies: string[]; +} + +export interface GeneratedScene { + id: string; + outlineId: string; + content: SceneContent; + order: number; +} + +// --- Teaching Config --- + +export enum TeachingStyle { + Lecture = 'lecture', + Discussion = 'discussion', + Pbl = 'pbl', + Flipped = 'flipped', + Socratic = 'socratic', +} + +export enum DifficultyLevel { + Beginner = 'beginner', + Intermediate = 'intermediate', + Advanced = 'advanced', + Expert = 'expert', +} + +// --- Classroom --- + +export interface ClassroomMetadata { + generatedAt: number; + sourceDocument?: string; + model?: string; + version: string; + custom: Record; +} + +export interface Classroom { + id: string; + title: string; + description: string; + topic: string; + style: TeachingStyle; + level: DifficultyLevel; + totalDuration: number; + objectives: string[]; + scenes: GeneratedScene[]; + agents: AgentProfile[]; + metadata: ClassroomMetadata; + outline?: string; +} + +// --- Generation Request --- + +export interface GenerationRequest { + topic: string; + document?: string; + style: TeachingStyle; + level: DifficultyLevel; + targetDurationMinutes: number; + sceneCount?: number; + customInstructions?: string; + language?: string; +} + +// --- Generation Progress --- + +export interface GenerationProgress { + stage: GenerationStage; + progress: number; + activity: string; + itemsProgress?: [number, number]; + etaSeconds?: number; +} + +// --- Chat Types --- + +export interface ClassroomChatMessage { + id: string; + agentId: string; + agentName: string; + agentAvatar: string; + content: string; + timestamp: number; + role: string; + color: string; +} + +export interface ClassroomChatState { + messages: ClassroomChatMessage[]; + active: boolean; +} + +export interface ClassroomChatRequest { + classroomId: string; + userMessage: string; + agents: AgentProfile[]; + sceneContext?: string; + history: ClassroomChatMessage[]; +} + +export interface ClassroomChatResponse { + responses: ClassroomChatMessage[]; +} diff --git a/desktop/src/types/index.ts b/desktop/src/types/index.ts index 1e8a52d..a3c0e86 100644 --- a/desktop/src/types/index.ts +++ b/desktop/src/types/index.ts @@ -156,3 +156,44 @@ export { filterByStatus, searchAutomationItems, } from './automation'; + +// Classroom Types +export type { + AgentProfile, + SceneContent, + GeneratedScene, + ClassroomMetadata, + Classroom, + GenerationRequest, + GenerationProgress, + ClassroomChatMessage, + ClassroomChatState, + ClassroomChatRequest, + ClassroomChatResponse, + SceneAction, + OutlineItem, +} from './classroom'; + +export { + AgentRole, + SceneType, + GenerationStage, + TeachingStyle, + DifficultyLevel, +} from './classroom'; + +// Chat Types (unified) +export type { + ChatMessage, + Conversation, + ChatAgent, + AgentProfileLike, + TokenUsage, + SendMessageContext, + MessageFile, + CodeBlock, +} from './chat'; + +export { + CHAT_MODES, +} from './chat';