//! Chat commands: send message, streaming chat use std::sync::Arc; use serde::{Deserialize, Serialize}; use tauri::{AppHandle, Emitter, State}; use zclaw_types::AgentId; use super::{validate_agent_id, KernelState, SessionStreamGuard, StreamCancelFlags}; use crate::intelligence::validation::validate_string_length; // --------------------------------------------------------------------------- // Request / Response types // --------------------------------------------------------------------------- /// Chat request #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct ChatRequest { pub agent_id: String, pub message: String, } /// Chat response #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct ChatResponse { pub content: String, pub input_tokens: u32, pub output_tokens: u32, } /// Streaming chat event for Tauri emission #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(rename_all = "camelCase", tag = "type")] pub enum StreamChatEvent { Delta { delta: String }, ThinkingDelta { delta: String }, ToolStart { name: String, input: serde_json::Value }, ToolEnd { name: String, output: serde_json::Value }, IterationStart { iteration: usize, max_iterations: usize }, HandStart { name: String, params: serde_json::Value }, HandEnd { name: String, result: serde_json::Value }, Complete { input_tokens: u32, output_tokens: u32 }, Error { message: String }, } /// Streaming chat request #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct StreamChatRequest { pub agent_id: String, pub session_id: String, pub message: String, /// Enable extended thinking/reasoning #[serde(default)] pub thinking_enabled: Option, /// Reasoning effort level (low/medium/high) #[serde(default)] pub reasoning_effort: Option, /// Enable plan mode #[serde(default)] pub plan_mode: Option, /// Enable sub-agent delegation (Ultra mode only) #[serde(default)] pub subagent_enabled: Option, } // --------------------------------------------------------------------------- // Commands // --------------------------------------------------------------------------- /// Send a message to an agent // @connected #[tauri::command] pub async fn agent_chat( state: State<'_, KernelState>, request: ChatRequest, ) -> Result { validate_agent_id(&request.agent_id)?; validate_string_length(&request.message, "message", 100000) .map_err(|e| format!("Invalid message: {}", e))?; let kernel_lock = state.lock().await; let kernel = kernel_lock.as_ref() .ok_or_else(|| "Kernel not initialized. Call kernel_init first.".to_string())?; let id: AgentId = request.agent_id.parse() .map_err(|_| "Invalid agent ID format".to_string())?; let response = kernel.send_message(&id, request.message) .await .map_err(|e| format!("Chat failed: {}", e))?; Ok(ChatResponse { content: response.content, input_tokens: response.input_tokens, output_tokens: response.output_tokens, }) } /// Send a message to an agent with streaming response /// /// This command initiates a streaming chat session. Events are emitted /// via Tauri's event system with the name "stream:chunk" and include /// the session_id for routing. // @connected #[tauri::command] pub async fn agent_chat_stream( app: AppHandle, state: State<'_, KernelState>, identity_state: State<'_, crate::intelligence::IdentityManagerState>, heartbeat_state: State<'_, crate::intelligence::HeartbeatEngineState>, reflection_state: State<'_, crate::intelligence::ReflectionEngineState>, stream_guard: State<'_, SessionStreamGuard>, cancel_flags: State<'_, StreamCancelFlags>, request: StreamChatRequest, ) -> Result<(), String> { validate_agent_id(&request.agent_id)?; validate_string_length(&request.message, "message", 100000) .map_err(|e| format!("Invalid message: {}", e))?; let id: AgentId = request.agent_id.parse() .map_err(|_| "Invalid agent ID format".to_string())?; let session_id = request.session_id.clone(); let agent_id_str = request.agent_id.clone(); let message = request.message.clone(); // Session-level concurrency guard using atomic flag let session_active = stream_guard .entry(session_id.clone()) .or_insert_with(|| Arc::new(std::sync::atomic::AtomicBool::new(false))); // Atomically set flag from false→true, fail if already true if session_active .compare_exchange(false, true, std::sync::atomic::Ordering::SeqCst, std::sync::atomic::Ordering::SeqCst) .is_err() { tracing::warn!( "[agent_chat_stream] Session {} already has an active stream — rejecting", session_id ); return Err(format!("Session {} already has an active stream", session_id)); } // Prepare cleanup resources for error paths (before spawn takes ownership) let err_cleanup_guard = stream_guard.inner().clone(); let err_cleanup_cancel = cancel_flags.inner().clone(); let err_cleanup_session_id = session_id.clone(); let err_cleanup_flag = Arc::clone(&*session_active); // Register cancellation flag for this session let cancel_flag = cancel_flags .entry(session_id.clone()) .or_insert_with(|| Arc::new(std::sync::atomic::AtomicBool::new(false))); // Ensure flag is reset (in case of stale entry from a previous stream) cancel_flag.store(false, std::sync::atomic::Ordering::SeqCst); let cancel_clone = Arc::clone(&*cancel_flag); let cancel_flags_map: StreamCancelFlags = cancel_flags.inner().clone(); // AUTO-INIT HEARTBEAT { let mut engines = heartbeat_state.lock().await; if !engines.contains_key(&request.agent_id) { let engine = crate::intelligence::heartbeat::HeartbeatEngine::new( request.agent_id.clone(), None, ); engines.insert(request.agent_id.clone(), engine); // Start the engine after insertion via the stored reference if let Some(e) = engines.get(&request.agent_id) { e.start().await; } tracing::info!("[agent_chat_stream] Auto-initialized and started heartbeat for agent: {}", request.agent_id); } } // PRE-CONVERSATION: Build intelligence-enhanced system prompt let enhanced_prompt = crate::intelligence_hooks::pre_conversation_hook( &request.agent_id, &request.message, &identity_state, ).await.unwrap_or_default(); // Get the streaming receiver while holding the lock, then release it let (mut rx, llm_driver) = { let kernel_lock = state.lock().await; let kernel = kernel_lock.as_ref() .ok_or_else(|| { // Cleanup on error: release guard + cancel flag err_cleanup_flag.store(false, std::sync::atomic::Ordering::SeqCst); err_cleanup_guard.remove(&err_cleanup_session_id); err_cleanup_cancel.remove(&err_cleanup_session_id); "Kernel not initialized. Call kernel_init first.".to_string() })?; let driver = Some(kernel.driver()); let prompt_arg = if enhanced_prompt.is_empty() { None } else { Some(enhanced_prompt) }; let session_id_parsed = if session_id.is_empty() { None } else { match uuid::Uuid::parse_str(&session_id) { Ok(uuid) => Some(zclaw_types::SessionId::from_uuid(uuid)), Err(e) => { // Cleanup on error err_cleanup_flag.store(false, std::sync::atomic::Ordering::SeqCst); err_cleanup_guard.remove(&err_cleanup_session_id); err_cleanup_cancel.remove(&err_cleanup_session_id); return Err(format!( "Invalid session_id '{}': {}. Cannot reuse conversation context.", session_id, e )); } } }; // Build chat mode config from request parameters let chat_mode_config = zclaw_kernel::ChatModeConfig { thinking_enabled: request.thinking_enabled, reasoning_effort: request.reasoning_effort.clone(), plan_mode: request.plan_mode, subagent_enabled: request.subagent_enabled, }; let rx = kernel.send_message_stream_with_prompt( &id, message.clone(), prompt_arg, session_id_parsed, Some(chat_mode_config), ) .await .map_err(|e| { // Cleanup on error err_cleanup_flag.store(false, std::sync::atomic::Ordering::SeqCst); err_cleanup_guard.remove(&err_cleanup_session_id); err_cleanup_cancel.remove(&err_cleanup_session_id); format!("Failed to start streaming: {}", e) })?; (rx, driver) }; let hb_state = heartbeat_state.inner().clone(); let rf_state = reflection_state.inner().clone(); // Clone the guard map for cleanup in the spawned task let guard_map: SessionStreamGuard = stream_guard.inner().clone(); // Spawn a task to process stream events. // The session_active flag is cleared when task completes. let guard_clone = Arc::clone(&*session_active); tokio::spawn(async move { use zclaw_runtime::LoopEvent; tracing::debug!("[agent_chat_stream] Starting stream processing for session: {}", session_id); let stream_timeout = tokio::time::Duration::from_secs(300); loop { // Check cancellation flag before each recv if cancel_clone.load(std::sync::atomic::Ordering::SeqCst) { tracing::info!("[agent_chat_stream] Stream cancelled for session: {}", session_id); let _ = app.emit("stream:chunk", serde_json::json!({ "sessionId": session_id, "event": StreamChatEvent::Error { message: "已取消".to_string() } })); break; } match tokio::time::timeout(stream_timeout, rx.recv()).await { Ok(Some(event)) => { let stream_event = match &event { LoopEvent::Delta(delta) => { tracing::trace!("[agent_chat_stream] Delta: {} bytes", delta.len()); StreamChatEvent::Delta { delta: delta.clone() } } LoopEvent::ThinkingDelta(delta) => { tracing::trace!("[agent_chat_stream] ThinkingDelta: {} bytes", delta.len()); StreamChatEvent::ThinkingDelta { delta: delta.clone() } } LoopEvent::ToolStart { name, input } => { tracing::debug!("[agent_chat_stream] ToolStart: {}", name); if name.starts_with("hand_") { StreamChatEvent::HandStart { name: name.clone(), params: input.clone() } } else { StreamChatEvent::ToolStart { name: name.clone(), input: input.clone() } } } LoopEvent::ToolEnd { name, output } => { tracing::debug!("[agent_chat_stream] ToolEnd: {}", name); if name.starts_with("hand_") { StreamChatEvent::HandEnd { name: name.clone(), result: output.clone() } } else { StreamChatEvent::ToolEnd { name: name.clone(), output: output.clone() } } } LoopEvent::IterationStart { iteration, max_iterations } => { tracing::debug!("[agent_chat_stream] IterationStart: {}/{}", iteration, max_iterations); StreamChatEvent::IterationStart { iteration: *iteration, max_iterations: *max_iterations } } LoopEvent::Complete(result) => { tracing::info!("[agent_chat_stream] Complete: input_tokens={}, output_tokens={}", result.input_tokens, result.output_tokens); let agent_id_hook = agent_id_str.clone(); let message_hook = message.clone(); let hb = hb_state.clone(); let rf = rf_state.clone(); let driver = llm_driver.clone(); tokio::spawn(async move { crate::intelligence_hooks::post_conversation_hook( &agent_id_hook, &message_hook, &hb, &rf, driver, ).await; }); StreamChatEvent::Complete { input_tokens: result.input_tokens, output_tokens: result.output_tokens, } } LoopEvent::Error(message) => { tracing::warn!("[agent_chat_stream] Error: {}", message); StreamChatEvent::Error { message: message.clone() } } }; if let Err(e) = app.emit("stream:chunk", serde_json::json!({ "sessionId": session_id, "event": stream_event })) { tracing::warn!("[agent_chat_stream] Failed to emit event: {}", e); break; } if matches!(event, LoopEvent::Complete(_) | LoopEvent::Error(_)) { break; } } Ok(None) => { tracing::info!("[agent_chat_stream] Stream channel closed for session: {}", session_id); break; } Err(_) => { tracing::warn!("[agent_chat_stream] Stream idle timeout for session: {}", session_id); let _ = app.emit("stream:chunk", serde_json::json!({ "sessionId": session_id, "event": StreamChatEvent::Error { message: "流式响应超时,请重试".to_string() } })); break; } } } tracing::debug!("[agent_chat_stream] Stream processing ended for session: {}", session_id); // Release session lock and clean up DashMap entries to prevent memory leaks. // Use compare_exchange to only remove if we still own the flag (guards against // a new stream for the same session_id starting after we broke out of the loop). if guard_clone.compare_exchange(true, false, std::sync::atomic::Ordering::SeqCst, std::sync::atomic::Ordering::SeqCst).is_ok() { guard_map.remove(&session_id); } // Clean up cancellation flag (always safe — cancel is session-scoped) cancel_flags_map.remove(&session_id); }); Ok(()) } /// Cancel an active stream for a given session. /// /// Sets the cancellation flag for the session, which the streaming task /// checks on each iteration. The task will then emit an error event /// and clean up. // @connected #[tauri::command] pub async fn cancel_stream( cancel_flags: State<'_, StreamCancelFlags>, session_id: String, ) -> Result<(), String> { if let Some(flag) = cancel_flags.get(&session_id) { flag.store(true, std::sync::atomic::Ordering::SeqCst); tracing::info!("[cancel_stream] Cancel requested for session: {}", session_id); Ok(()) } else { // No active stream for this session — not an error, just a no-op tracing::debug!("[cancel_stream] No active stream for session: {}", session_id); Ok(()) } }