diff --git a/Cargo.toml b/Cargo.toml index 31af1a4..85fbe9a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,6 +27,7 @@ rust-version = "1.75" tokio = { version = "1", features = ["full"] } tokio-stream = "0.1" futures = "0.3" +async-stream = "0.3" # Serialization serde = { version = "1", features = ["derive"] } diff --git a/crates/zclaw-runtime/Cargo.toml b/crates/zclaw-runtime/Cargo.toml index 11c9cde..7d37d54 100644 --- a/crates/zclaw-runtime/Cargo.toml +++ b/crates/zclaw-runtime/Cargo.toml @@ -14,6 +14,7 @@ zclaw-memory = { workspace = true } tokio = { workspace = true } tokio-stream = { workspace = true } futures = { workspace = true } +async-stream = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } thiserror = { workspace = true } diff --git a/crates/zclaw-runtime/src/driver/anthropic.rs b/crates/zclaw-runtime/src/driver/anthropic.rs index 99b0e02..3082f6f 100644 --- a/crates/zclaw-runtime/src/driver/anthropic.rs +++ b/crates/zclaw-runtime/src/driver/anthropic.rs @@ -1,12 +1,16 @@ //! Anthropic Claude driver implementation use async_trait::async_trait; +use async_stream::stream; +use futures::{Stream, StreamExt}; use secrecy::{ExposeSecret, SecretString}; use reqwest::Client; use serde::{Deserialize, Serialize}; +use std::pin::Pin; use zclaw_types::{Result, ZclawError}; use super::{CompletionRequest, CompletionResponse, ContentBlock, LlmDriver, StopReason}; +use crate::stream::StreamChunk; /// Anthropic API driver pub struct AnthropicDriver { @@ -69,6 +73,130 @@ impl LlmDriver for AnthropicDriver { Ok(self.convert_response(api_response)) } + + fn stream( + &self, + request: CompletionRequest, + ) -> Pin> + Send + '_>> { + let mut stream_request = self.build_api_request(&request); + stream_request.stream = true; + + let base_url = self.base_url.clone(); + let api_key = self.api_key.expose_secret().to_string(); + + Box::pin(stream! { + let response = match self.client + .post(format!("{}/v1/messages", base_url)) + .header("x-api-key", api_key) + .header("anthropic-version", "2023-06-01") + .header("content-type", "application/json") + .json(&stream_request) + .send() + .await + { + Ok(r) => r, + Err(e) => { + yield Err(ZclawError::LlmError(format!("HTTP request failed: {}", e))); + return; + } + }; + + if !response.status().is_success() { + let status = response.status(); + let body = response.text().await.unwrap_or_default(); + yield Err(ZclawError::LlmError(format!("API error {}: {}", status, body))); + return; + } + + let mut byte_stream = response.bytes_stream(); + let mut current_tool_id: Option = None; + let mut tool_input_buffer = String::new(); + + while let Some(chunk_result) = byte_stream.next().await { + let chunk = match chunk_result { + Ok(c) => c, + Err(e) => { + yield Err(ZclawError::LlmError(format!("Stream error: {}", e))); + continue; + } + }; + + let text = String::from_utf8_lossy(&chunk); + for line in text.lines() { + if let Some(data) = line.strip_prefix("data: ") { + if data == "[DONE]" { + continue; + } + + match serde_json::from_str::(data) { + Ok(event) => { + match event.event_type.as_str() { + "content_block_delta" => { + if let Some(delta) = event.delta { + if let Some(text) = delta.text { + yield Ok(StreamChunk::TextDelta { delta: text }); + } + if let Some(thinking) = delta.thinking { + yield Ok(StreamChunk::ThinkingDelta { delta: thinking }); + } + if let Some(json) = delta.partial_json { + tool_input_buffer.push_str(&json); + } + } + } + "content_block_start" => { + if let Some(block) = event.content_block { + match block.block_type.as_str() { + "tool_use" => { + current_tool_id = block.id.clone(); + yield Ok(StreamChunk::ToolUseStart { + id: block.id.unwrap_or_default(), + name: block.name.unwrap_or_default(), + }); + } + _ => {} + } + } + } + "content_block_stop" => { + if let Some(id) = current_tool_id.take() { + let input: serde_json::Value = serde_json::from_str(&tool_input_buffer) + .unwrap_or(serde_json::Value::Object(Default::default())); + yield Ok(StreamChunk::ToolUseEnd { + id, + input, + }); + tool_input_buffer.clear(); + } + } + "message_delta" => { + if let Some(msg) = event.message { + if msg.stop_reason.is_some() { + yield Ok(StreamChunk::Complete { + input_tokens: msg.usage.as_ref().map(|u| u.input_tokens).unwrap_or(0), + output_tokens: msg.usage.as_ref().map(|u| u.output_tokens).unwrap_or(0), + stop_reason: msg.stop_reason.unwrap_or_else(|| "end_turn".to_string()), + }); + } + } + } + "error" => { + yield Ok(StreamChunk::Error { + message: "Stream error".to_string(), + }); + } + _ => {} + } + } + Err(e) => { + tracing::warn!("Failed to parse SSE event: {} - {}", e, data); + } + } + } + } + } + }) + } } impl AnthropicDriver { @@ -224,3 +352,56 @@ struct AnthropicUsage { input_tokens: u32, output_tokens: u32, } + +// Streaming types + +/// SSE event from Anthropic API +#[derive(Debug, Deserialize)] +struct AnthropicStreamEvent { + #[serde(rename = "type")] + event_type: String, + #[serde(default)] + index: Option, + #[serde(default)] + delta: Option, + #[serde(default)] + content_block: Option, + #[serde(default)] + message: Option, +} + +#[derive(Debug, Deserialize)] +struct AnthropicDelta { + #[serde(default)] + text: Option, + #[serde(default)] + thinking: Option, + #[serde(default)] + partial_json: Option, +} + +#[derive(Debug, Deserialize)] +struct AnthropicStreamContentBlock { + #[serde(rename = "type")] + block_type: String, + #[serde(default)] + id: Option, + #[serde(default)] + name: Option, +} + +#[derive(Debug, Deserialize)] +struct AnthropicStreamMessage { + #[serde(default)] + stop_reason: Option, + #[serde(default)] + usage: Option, +} + +#[derive(Debug, Deserialize)] +struct AnthropicStreamUsage { + #[serde(default)] + input_tokens: u32, + #[serde(default)] + output_tokens: u32, +} diff --git a/crates/zclaw-runtime/src/driver/gemini.rs b/crates/zclaw-runtime/src/driver/gemini.rs index a2f450d..9003ee6 100644 --- a/crates/zclaw-runtime/src/driver/gemini.rs +++ b/crates/zclaw-runtime/src/driver/gemini.rs @@ -1,11 +1,14 @@ //! Google Gemini driver implementation use async_trait::async_trait; +use futures::{Stream, StreamExt}; use secrecy::{ExposeSecret, SecretString}; use reqwest::Client; -use zclaw_types::Result; +use std::pin::Pin; +use zclaw_types::{Result, ZclawError}; use super::{CompletionRequest, CompletionResponse, ContentBlock, LlmDriver, StopReason}; +use crate::stream::StreamChunk; /// Google Gemini driver pub struct GeminiDriver { @@ -46,4 +49,14 @@ impl LlmDriver for GeminiDriver { stop_reason: StopReason::EndTurn, }) } + + fn stream( + &self, + _request: CompletionRequest, + ) -> Pin> + Send + '_>> { + // Placeholder - return error stream + Box::pin(futures::stream::once(async { + Err(ZclawError::LlmError("Gemini streaming not yet implemented".to_string())) + })) + } } diff --git a/crates/zclaw-runtime/src/driver/local.rs b/crates/zclaw-runtime/src/driver/local.rs index 0224133..d7234c3 100644 --- a/crates/zclaw-runtime/src/driver/local.rs +++ b/crates/zclaw-runtime/src/driver/local.rs @@ -1,10 +1,13 @@ //! Local LLM driver (Ollama, LM Studio, vLLM, etc.) use async_trait::async_trait; +use futures::{Stream, StreamExt}; use reqwest::Client; -use zclaw_types::Result; +use std::pin::Pin; +use zclaw_types::{Result, ZclawError}; use super::{CompletionRequest, CompletionResponse, ContentBlock, LlmDriver, StopReason}; +use crate::stream::StreamChunk; /// Local LLM driver for Ollama, LM Studio, vLLM, etc. pub struct LocalDriver { @@ -56,4 +59,14 @@ impl LlmDriver for LocalDriver { stop_reason: StopReason::EndTurn, }) } + + fn stream( + &self, + _request: CompletionRequest, + ) -> Pin> + Send + '_>> { + // Placeholder - return error stream + Box::pin(futures::stream::once(async { + Err(ZclawError::LlmError("Local driver streaming not yet implemented".to_string())) + })) + } } diff --git a/crates/zclaw-runtime/src/driver/mod.rs b/crates/zclaw-runtime/src/driver/mod.rs index 0a8d4cf..8390cde 100644 --- a/crates/zclaw-runtime/src/driver/mod.rs +++ b/crates/zclaw-runtime/src/driver/mod.rs @@ -3,10 +3,14 @@ //! This module provides a unified interface for multiple LLM providers. use async_trait::async_trait; -use serde::{Deserialize, Serialize}; +use futures::Stream; use secrecy::SecretString; +use serde::{Deserialize, Serialize}; +use std::pin::Pin; use zclaw_types::Result; +use crate::stream::StreamChunk; + mod anthropic; mod openai; mod gemini; @@ -26,6 +30,13 @@ pub trait LlmDriver: Send + Sync { /// Send a completion request async fn complete(&self, request: CompletionRequest) -> Result; + /// Send a streaming completion request + /// Returns a stream of chunks + fn stream( + &self, + request: CompletionRequest, + ) -> Pin> + Send + '_>>; + /// Check if the driver is properly configured fn is_configured(&self) -> bool; } diff --git a/crates/zclaw-runtime/src/driver/openai.rs b/crates/zclaw-runtime/src/driver/openai.rs index 0926220..b9d7562 100644 --- a/crates/zclaw-runtime/src/driver/openai.rs +++ b/crates/zclaw-runtime/src/driver/openai.rs @@ -1,12 +1,16 @@ //! OpenAI-compatible driver implementation use async_trait::async_trait; +use async_stream::stream; +use futures::{Stream, StreamExt}; use secrecy::{ExposeSecret, SecretString}; use reqwest::Client; use serde::{Deserialize, Serialize}; +use std::pin::Pin; use zclaw_types::{Result, ZclawError}; -use super::{CompletionRequest, CompletionResponse, ContentBlock, LlmDriver, StopReason, ToolDefinition}; +use super::{CompletionRequest, CompletionResponse, ContentBlock, LlmDriver, StopReason}; +use crate::stream::StreamChunk; /// OpenAI-compatible driver pub struct OpenAiDriver { @@ -85,6 +89,93 @@ impl LlmDriver for OpenAiDriver { Ok(self.convert_response(api_response, request.model)) } + + fn stream( + &self, + request: CompletionRequest, + ) -> Pin> + Send + '_>> { + let mut stream_request = self.build_api_request(&request); + stream_request.stream = true; + + let base_url = self.base_url.clone(); + let api_key = self.api_key.expose_secret().to_string(); + + Box::pin(stream! { + let response = match self.client + .post(format!("{}/chat/completions", base_url)) + .header("Authorization", format!("Bearer {}", api_key)) + .header("Content-Type", "application/json") + .json(&stream_request) + .send() + .await + { + Ok(r) => r, + Err(e) => { + yield Err(ZclawError::LlmError(format!("HTTP request failed: {}", e))); + return; + } + }; + + if !response.status().is_success() { + let status = response.status(); + let body = response.text().await.unwrap_or_default(); + yield Err(ZclawError::LlmError(format!("API error {}: {}", status, body))); + return; + } + + let mut byte_stream = response.bytes_stream(); + + while let Some(chunk_result) = byte_stream.next().await { + let chunk = match chunk_result { + Ok(c) => c, + Err(e) => { + yield Err(ZclawError::LlmError(format!("Stream error: {}", e))); + continue; + } + }; + + let text = String::from_utf8_lossy(&chunk); + for line in text.lines() { + if let Some(data) = line.strip_prefix("data: ") { + if data == "[DONE]" { + yield Ok(StreamChunk::Complete { + input_tokens: 0, + output_tokens: 0, + stop_reason: "end_turn".to_string(), + }); + continue; + } + + match serde_json::from_str::(data) { + Ok(resp) => { + if let Some(choice) = resp.choices.first() { + let delta = &choice.delta; + if let Some(content) = &delta.content { + yield Ok(StreamChunk::TextDelta { delta: content.clone() }); + } + if let Some(tool_calls) = &delta.tool_calls { + for tc in tool_calls { + if let Some(function) = &tc.function { + if let Some(args) = &function.arguments { + yield Ok(StreamChunk::ToolUseDelta { + id: tc.id.clone().unwrap_or_default(), + delta: args.clone(), + }); + } + } + } + } + } + } + Err(e) => { + tracing::warn!("Failed to parse OpenAI SSE: {}", e); + } + } + } + } + } + }) + } } impl OpenAiDriver { @@ -334,3 +425,41 @@ struct OpenAiUsage { #[serde(default)] completion_tokens: u32, } + +// OpenAI Streaming types + +#[derive(Debug, Deserialize)] +struct OpenAiStreamResponse { + #[serde(default)] + choices: Vec, +} + +#[derive(Debug, Deserialize)] +struct OpenAiStreamChoice { + #[serde(default)] + delta: OpenAiDelta, + #[serde(default)] + finish_reason: Option, +} + +#[derive(Debug, Deserialize, Default)] +struct OpenAiDelta { + #[serde(default)] + content: Option, + #[serde(default)] + tool_calls: Option>, +} + +#[derive(Debug, Deserialize)] +struct OpenAiToolCallDelta { + #[serde(default)] + id: Option, + #[serde(default)] + function: Option, +} + +#[derive(Debug, Deserialize)] +struct OpenAiFunctionDelta { + #[serde(default)] + arguments: Option, +} diff --git a/crates/zclaw-runtime/src/stream.rs b/crates/zclaw-runtime/src/stream.rs index b96a3d7..51a12e9 100644 --- a/crates/zclaw-runtime/src/stream.rs +++ b/crates/zclaw-runtime/src/stream.rs @@ -1,11 +1,58 @@ -//! Streaming utilities +//! Streaming response types +use serde::{Deserialize, Serialize}; use tokio::sync::mpsc; use zclaw_types::Result; -/// Stream event for LLM responses +/// Stream chunk emitted during streaming +/// This is the serializable type sent via Tauri events +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum StreamChunk { + /// Text delta + TextDelta { delta: String }, + /// Thinking delta (for extended thinking models) + ThinkingDelta { delta: String }, + /// Tool use started + ToolUseStart { id: String, name: String }, + /// Tool use input delta + ToolUseDelta { id: String, delta: String }, + /// Tool use completed + ToolUseEnd { id: String, input: serde_json::Value }, + /// Stream completed + Complete { + input_tokens: u32, + output_tokens: u32, + stop_reason: String, + }, + /// Error occurred + Error { message: String }, +} + +/// Streaming event for Tauri emission +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct StreamEvent { + /// Session ID for routing + pub session_id: String, + /// Agent ID for routing + pub agent_id: String, + /// The chunk content + pub chunk: StreamChunk, +} + +impl StreamEvent { + pub fn new(session_id: impl Into, agent_id: impl Into, chunk: StreamChunk) -> Self { + Self { + session_id: session_id.into(), + agent_id: agent_id.into(), + chunk, + } + } +} + +/// Legacy stream event for internal use with mpsc channels #[derive(Debug, Clone)] -pub enum StreamEvent { +pub enum InternalStreamEvent { /// Text delta received TextDelta(String), /// Thinking delta received @@ -24,31 +71,31 @@ pub enum StreamEvent { /// Stream sender wrapper pub struct StreamSender { - tx: mpsc::Sender, + tx: mpsc::Sender, } impl StreamSender { - pub fn new(tx: mpsc::Sender) -> Self { + pub fn new(tx: mpsc::Sender) -> Self { Self { tx } } pub async fn send_text(&self, delta: impl Into) -> Result<()> { - self.tx.send(StreamEvent::TextDelta(delta.into())).await.ok(); + self.tx.send(InternalStreamEvent::TextDelta(delta.into())).await.ok(); Ok(()) } pub async fn send_thinking(&self, delta: impl Into) -> Result<()> { - self.tx.send(StreamEvent::ThinkingDelta(delta.into())).await.ok(); + self.tx.send(InternalStreamEvent::ThinkingDelta(delta.into())).await.ok(); Ok(()) } pub async fn send_complete(&self, input_tokens: u32, output_tokens: u32) -> Result<()> { - self.tx.send(StreamEvent::Complete { input_tokens, output_tokens }).await.ok(); + self.tx.send(InternalStreamEvent::Complete { input_tokens, output_tokens }).await.ok(); Ok(()) } pub async fn send_error(&self, error: impl Into) -> Result<()> { - self.tx.send(StreamEvent::Error(error.into())).await.ok(); + self.tx.send(InternalStreamEvent::Error(error.into())).await.ok(); Ok(()) } }