feat(runtime): add streaming support to LlmDriver trait

- Add StreamChunk and StreamEvent types for Tauri event emission
- Add stream() method to LlmDriver trait with async-stream
- Implement Anthropic streaming with SSE parsing
- Implement OpenAI streaming with SSE parsing
- Add placeholder stream() for Gemini and Local drivers

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
iven
2026-03-24 01:44:40 +08:00
parent 4ba0a531aa
commit 820e3a1ffe
8 changed files with 409 additions and 13 deletions

View File

@@ -14,6 +14,7 @@ zclaw-memory = { workspace = true }
tokio = { workspace = true }
tokio-stream = { workspace = true }
futures = { workspace = true }
async-stream = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
thiserror = { workspace = true }

View File

@@ -1,12 +1,16 @@
//! Anthropic Claude driver implementation
use async_trait::async_trait;
use async_stream::stream;
use futures::{Stream, StreamExt};
use secrecy::{ExposeSecret, SecretString};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::pin::Pin;
use zclaw_types::{Result, ZclawError};
use super::{CompletionRequest, CompletionResponse, ContentBlock, LlmDriver, StopReason};
use crate::stream::StreamChunk;
/// Anthropic API driver
pub struct AnthropicDriver {
@@ -69,6 +73,130 @@ impl LlmDriver for AnthropicDriver {
Ok(self.convert_response(api_response))
}
/// Stream a completion from the Anthropic Messages API.
///
/// Parses the server-sent-event (SSE) response and yields `StreamChunk`s:
/// text / thinking deltas, tool-use start / end, and a terminal `Complete`
/// chunk carrying usage and stop reason. Transport and API errors are
/// surfaced as `Err` items on the stream.
fn stream(
    &self,
    request: CompletionRequest,
) -> Pin<Box<dyn Stream<Item = Result<StreamChunk>> + Send + '_>> {
    let mut stream_request = self.build_api_request(&request);
    stream_request.stream = true;

    let base_url = self.base_url.clone();
    let api_key = self.api_key.expose_secret().to_string();

    Box::pin(stream! {
        let response = match self.client
            .post(format!("{}/v1/messages", base_url))
            .header("x-api-key", api_key)
            .header("anthropic-version", "2023-06-01")
            .header("content-type", "application/json")
            .json(&stream_request)
            .send()
            .await
        {
            Ok(r) => r,
            Err(e) => {
                yield Err(ZclawError::LlmError(format!("HTTP request failed: {}", e)));
                return;
            }
        };

        if !response.status().is_success() {
            let status = response.status();
            let body = response.text().await.unwrap_or_default();
            yield Err(ZclawError::LlmError(format!("API error {}: {}", status, body)));
            return;
        }

        let mut byte_stream = response.bytes_stream();
        // Tool-use state: `content_block_start` sets the id, deltas append
        // partial JSON, `content_block_stop` assembles and emits it.
        let mut current_tool_id: Option<String> = None;
        let mut tool_input_buffer = String::new();
        // SSE events can be split across network chunks; buffer raw bytes and
        // only parse complete (newline-terminated) lines. Splitting on bytes
        // rather than lossy-decoding each chunk also avoids corrupting a
        // multi-byte UTF-8 sequence that straddles a chunk boundary.
        let mut line_buffer: Vec<u8> = Vec::new();

        while let Some(chunk_result) = byte_stream.next().await {
            let chunk = match chunk_result {
                Ok(c) => c,
                Err(e) => {
                    yield Err(ZclawError::LlmError(format!("Stream error: {}", e)));
                    continue;
                }
            };
            line_buffer.extend_from_slice(&chunk);

            // Drain every complete line; a trailing partial line stays
            // buffered until the next chunk arrives.
            while let Some(pos) = line_buffer.iter().position(|&b| b == b'\n') {
                let line_bytes: Vec<u8> = line_buffer.drain(..=pos).collect();
                let line_owned = String::from_utf8_lossy(&line_bytes);
                let line = line_owned.trim_end();

                if let Some(data) = line.strip_prefix("data: ") {
                    // Anthropic does not send "[DONE]"; kept as a harmless guard
                    // for OpenAI-compatible proxies.
                    if data == "[DONE]" {
                        continue;
                    }
                    match serde_json::from_str::<AnthropicStreamEvent>(data) {
                        Ok(event) => {
                            match event.event_type.as_str() {
                                "content_block_delta" => {
                                    if let Some(delta) = event.delta {
                                        if let Some(text) = delta.text {
                                            yield Ok(StreamChunk::TextDelta { delta: text });
                                        }
                                        if let Some(thinking) = delta.thinking {
                                            yield Ok(StreamChunk::ThinkingDelta { delta: thinking });
                                        }
                                        // Tool input arrives as partial JSON
                                        // fragments; buffer until block stop.
                                        if let Some(json) = delta.partial_json {
                                            tool_input_buffer.push_str(&json);
                                        }
                                    }
                                }
                                "content_block_start" => {
                                    if let Some(block) = event.content_block {
                                        if block.block_type == "tool_use" {
                                            current_tool_id = block.id.clone();
                                            yield Ok(StreamChunk::ToolUseStart {
                                                id: block.id.unwrap_or_default(),
                                                name: block.name.unwrap_or_default(),
                                            });
                                        }
                                    }
                                }
                                "content_block_stop" => {
                                    // Only tool_use blocks set current_tool_id,
                                    // so text blocks pass through silently.
                                    if let Some(id) = current_tool_id.take() {
                                        let input: serde_json::Value = serde_json::from_str(&tool_input_buffer)
                                            .unwrap_or(serde_json::Value::Object(Default::default()));
                                        yield Ok(StreamChunk::ToolUseEnd {
                                            id,
                                            input,
                                        });
                                        tool_input_buffer.clear();
                                    }
                                }
                                "message_delta" => {
                                    // The Messages API nests stop_reason under
                                    // "delta" and puts usage at the event's top
                                    // level; AnthropicStreamEvent doesn't model
                                    // those fields, so re-read them from the raw
                                    // JSON. The legacy `message` shape is kept
                                    // as a fallback.
                                    let raw: serde_json::Value =
                                        serde_json::from_str(data).unwrap_or(serde_json::Value::Null);
                                    let stop_reason = raw
                                        .pointer("/delta/stop_reason")
                                        .and_then(|v| v.as_str())
                                        .map(str::to_string)
                                        .or_else(|| event.message.as_ref().and_then(|m| m.stop_reason.clone()));
                                    if let Some(stop_reason) = stop_reason {
                                        let raw_tok = |path: &str| {
                                            raw.pointer(path).and_then(|v| v.as_u64()).unwrap_or(0) as u32
                                        };
                                        let msg_usage = event.message.as_ref().and_then(|m| m.usage.as_ref());
                                        yield Ok(StreamChunk::Complete {
                                            input_tokens: msg_usage.map(|u| u.input_tokens)
                                                .unwrap_or_else(|| raw_tok("/usage/input_tokens")),
                                            output_tokens: msg_usage.map(|u| u.output_tokens)
                                                .unwrap_or_else(|| raw_tok("/usage/output_tokens")),
                                            stop_reason,
                                        });
                                    }
                                }
                                "error" => {
                                    yield Ok(StreamChunk::Error {
                                        message: "Stream error".to_string(),
                                    });
                                }
                                _ => {}
                            }
                        }
                        Err(e) => {
                            tracing::warn!("Failed to parse SSE event: {} - {}", e, data);
                        }
                    }
                }
            }
        }
    })
}
}
impl AnthropicDriver {
@@ -224,3 +352,56 @@ struct AnthropicUsage {
input_tokens: u32,
output_tokens: u32,
}
// Streaming types
//
// Wire-format mirrors for Anthropic SSE payloads. All optional fields use
// #[serde(default)] so partial events deserialize without error.
/// SSE event from Anthropic API
#[derive(Debug, Deserialize)]
struct AnthropicStreamEvent {
/// Event discriminator, e.g. "content_block_delta", "content_block_start",
/// "content_block_stop", "message_delta", "error".
#[serde(rename = "type")]
event_type: String,
/// Content-block index within the message. Currently unread by the
/// stream parser (will trigger a dead_code warning until used).
#[serde(default)]
index: Option<u32>,
/// Delta payload for "content_block_delta" events.
#[serde(default)]
delta: Option<AnthropicDelta>,
/// Block descriptor for "content_block_start" events.
#[serde(default)]
content_block: Option<AnthropicStreamContentBlock>,
/// NOTE(review): per Anthropic's streaming docs, "message_delta" events
/// carry stop_reason under "delta" and usage at the event's top level,
/// not under a "message" key — verify this field is ever populated.
#[serde(default)]
message: Option<AnthropicStreamMessage>,
}
/// Delta payload: exactly one of these is typically set per event.
#[derive(Debug, Deserialize)]
struct AnthropicDelta {
/// Text content fragment ("text_delta").
#[serde(default)]
text: Option<String>,
/// Extended-thinking fragment ("thinking_delta").
#[serde(default)]
thinking: Option<String>,
/// Tool-input JSON fragment ("input_json_delta"); concatenated by the
/// parser until the block stops.
#[serde(default)]
partial_json: Option<String>,
}
/// Content-block header from "content_block_start".
#[derive(Debug, Deserialize)]
struct AnthropicStreamContentBlock {
/// Block kind, e.g. "text" or "tool_use".
#[serde(rename = "type")]
block_type: String,
/// Tool-use id (present for "tool_use" blocks).
#[serde(default)]
id: Option<String>,
/// Tool name (present for "tool_use" blocks).
#[serde(default)]
name: Option<String>,
}
/// Final-message metadata (stop reason + usage).
#[derive(Debug, Deserialize)]
struct AnthropicStreamMessage {
/// Why generation stopped, e.g. "end_turn" or "tool_use".
#[serde(default)]
stop_reason: Option<String>,
/// Token accounting for the message.
#[serde(default)]
usage: Option<AnthropicStreamUsage>,
}
/// Token counts; missing fields default to 0.
#[derive(Debug, Deserialize)]
struct AnthropicStreamUsage {
#[serde(default)]
input_tokens: u32,
#[serde(default)]
output_tokens: u32,
}

View File

@@ -1,11 +1,14 @@
//! Google Gemini driver implementation
use async_trait::async_trait;
use futures::{Stream, StreamExt};
use secrecy::{ExposeSecret, SecretString};
use reqwest::Client;
use zclaw_types::Result;
use std::pin::Pin;
use zclaw_types::{Result, ZclawError};
use super::{CompletionRequest, CompletionResponse, ContentBlock, LlmDriver, StopReason};
use crate::stream::StreamChunk;
/// Google Gemini driver
pub struct GeminiDriver {
@@ -46,4 +49,14 @@ impl LlmDriver for GeminiDriver {
stop_reason: StopReason::EndTurn,
})
}
/// Streaming is not implemented for Gemini yet.
///
/// Returns a one-item stream whose only element is an error, so callers
/// get the same interface shape as the streaming-capable drivers.
fn stream(
    &self,
    _request: CompletionRequest,
) -> Pin<Box<dyn Stream<Item = Result<StreamChunk>> + Send + '_>> {
    let unsupported =
        ZclawError::LlmError("Gemini streaming not yet implemented".to_string());
    Box::pin(futures::stream::once(async move { Err(unsupported) }))
}
}

View File

@@ -1,10 +1,13 @@
//! Local LLM driver (Ollama, LM Studio, vLLM, etc.)
use async_trait::async_trait;
use futures::{Stream, StreamExt};
use reqwest::Client;
use zclaw_types::Result;
use std::pin::Pin;
use zclaw_types::{Result, ZclawError};
use super::{CompletionRequest, CompletionResponse, ContentBlock, LlmDriver, StopReason};
use crate::stream::StreamChunk;
/// Local LLM driver for Ollama, LM Studio, vLLM, etc.
pub struct LocalDriver {
@@ -56,4 +59,14 @@ impl LlmDriver for LocalDriver {
stop_reason: StopReason::EndTurn,
})
}
/// Streaming is not implemented for local backends yet.
///
/// Returns a one-item stream whose only element is an error, so callers
/// get the same interface shape as the streaming-capable drivers.
fn stream(
    &self,
    _request: CompletionRequest,
) -> Pin<Box<dyn Stream<Item = Result<StreamChunk>> + Send + '_>> {
    let unsupported =
        ZclawError::LlmError("Local driver streaming not yet implemented".to_string());
    Box::pin(futures::stream::once(async move { Err(unsupported) }))
}
}

View File

@@ -3,10 +3,14 @@
//! This module provides a unified interface for multiple LLM providers.
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use futures::Stream;
use secrecy::SecretString;
use serde::{Deserialize, Serialize};
use std::pin::Pin;
use zclaw_types::Result;
use crate::stream::StreamChunk;
mod anthropic;
mod openai;
mod gemini;
@@ -26,6 +30,13 @@ pub trait LlmDriver: Send + Sync {
/// Send a completion request
async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse>;
/// Send a streaming completion request
/// Returns a stream of chunks
fn stream(
&self,
request: CompletionRequest,
) -> Pin<Box<dyn Stream<Item = Result<StreamChunk>> + Send + '_>>;
/// Check if the driver is properly configured
fn is_configured(&self) -> bool;
}

View File

@@ -1,12 +1,16 @@
//! OpenAI-compatible driver implementation
use async_trait::async_trait;
use async_stream::stream;
use futures::{Stream, StreamExt};
use secrecy::{ExposeSecret, SecretString};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::pin::Pin;
use zclaw_types::{Result, ZclawError};
use super::{CompletionRequest, CompletionResponse, ContentBlock, LlmDriver, StopReason, ToolDefinition};
use super::{CompletionRequest, CompletionResponse, ContentBlock, LlmDriver, StopReason};
use crate::stream::StreamChunk;
/// OpenAI-compatible driver
pub struct OpenAiDriver {
@@ -85,6 +89,93 @@ impl LlmDriver for OpenAiDriver {
Ok(self.convert_response(api_response, request.model))
}
/// Stream a completion from an OpenAI-compatible Chat Completions API.
///
/// Yields `StreamChunk::TextDelta` for content fragments and
/// `StreamChunk::ToolUseDelta` for tool-call argument fragments; a
/// `Complete` chunk is emitted when the server sends the "[DONE]"
/// sentinel. Transport and API errors are surfaced as `Err` items.
fn stream(
    &self,
    request: CompletionRequest,
) -> Pin<Box<dyn Stream<Item = Result<StreamChunk>> + Send + '_>> {
    let mut stream_request = self.build_api_request(&request);
    stream_request.stream = true;

    let base_url = self.base_url.clone();
    let api_key = self.api_key.expose_secret().to_string();

    Box::pin(stream! {
        let response = match self.client
            .post(format!("{}/chat/completions", base_url))
            .header("Authorization", format!("Bearer {}", api_key))
            .header("Content-Type", "application/json")
            .json(&stream_request)
            .send()
            .await
        {
            Ok(r) => r,
            Err(e) => {
                yield Err(ZclawError::LlmError(format!("HTTP request failed: {}", e)));
                return;
            }
        };

        if !response.status().is_success() {
            let status = response.status();
            let body = response.text().await.unwrap_or_default();
            yield Err(ZclawError::LlmError(format!("API error {}: {}", status, body)));
            return;
        }

        let mut byte_stream = response.bytes_stream();
        // SSE events can be split across network chunks; buffer raw bytes and
        // only parse complete (newline-terminated) lines so a JSON payload is
        // never handed to serde in two halves.
        let mut line_buffer: Vec<u8> = Vec::new();
        // Last finish_reason reported by the server; surfaced in the final
        // Complete chunk instead of a hard-coded "end_turn".
        let mut finish_reason: Option<String> = None;

        while let Some(chunk_result) = byte_stream.next().await {
            let chunk = match chunk_result {
                Ok(c) => c,
                Err(e) => {
                    yield Err(ZclawError::LlmError(format!("Stream error: {}", e)));
                    continue;
                }
            };
            line_buffer.extend_from_slice(&chunk);

            // Drain every complete line; a trailing partial line stays
            // buffered until the next chunk arrives.
            while let Some(pos) = line_buffer.iter().position(|&b| b == b'\n') {
                let line_bytes: Vec<u8> = line_buffer.drain(..=pos).collect();
                let line_owned = String::from_utf8_lossy(&line_bytes);
                let line = line_owned.trim_end();

                if let Some(data) = line.strip_prefix("data: ") {
                    if data == "[DONE]" {
                        // NOTE(review): usage is not included in OpenAI stream
                        // chunks unless stream_options.include_usage is set, so
                        // token counts are reported as 0 here.
                        yield Ok(StreamChunk::Complete {
                            input_tokens: 0,
                            output_tokens: 0,
                            stop_reason: finish_reason.take().unwrap_or_else(|| "end_turn".to_string()),
                        });
                        continue;
                    }
                    match serde_json::from_str::<OpenAiStreamResponse>(data) {
                        Ok(resp) => {
                            if let Some(choice) = resp.choices.first() {
                                // Remember the stop reason; "[DONE]" follows it.
                                if let Some(fr) = &choice.finish_reason {
                                    finish_reason = Some(fr.clone());
                                }
                                let delta = &choice.delta;
                                if let Some(content) = &delta.content {
                                    yield Ok(StreamChunk::TextDelta { delta: content.clone() });
                                }
                                if let Some(tool_calls) = &delta.tool_calls {
                                    for tc in tool_calls {
                                        if let Some(function) = &tc.function {
                                            if let Some(args) = &function.arguments {
                                                yield Ok(StreamChunk::ToolUseDelta {
                                                    id: tc.id.clone().unwrap_or_default(),
                                                    delta: args.clone(),
                                                });
                                            }
                                        }
                                    }
                                }
                            }
                        }
                        Err(e) => {
                            tracing::warn!("Failed to parse OpenAI SSE: {}", e);
                        }
                    }
                }
            }
        }
    })
}
}
impl OpenAiDriver {
@@ -334,3 +425,41 @@ struct OpenAiUsage {
#[serde(default)]
completion_tokens: u32,
}
// OpenAI Streaming types
//
// Wire-format mirrors for Chat Completions SSE chunks. All fields use
// #[serde(default)] so partial payloads deserialize without error.
/// One SSE payload from a chat-completions stream.
#[derive(Debug, Deserialize)]
struct OpenAiStreamResponse {
/// Usually exactly one choice per chunk.
#[serde(default)]
choices: Vec<OpenAiStreamChoice>,
}
/// A single streamed choice.
#[derive(Debug, Deserialize)]
struct OpenAiStreamChoice {
/// Incremental content for this chunk.
#[serde(default)]
delta: OpenAiDelta,
/// Why generation stopped (e.g. "stop", "tool_calls"); present only on
/// the final delta for the choice.
#[serde(default)]
finish_reason: Option<String>,
}
/// Incremental message content; either text or tool-call fragments.
#[derive(Debug, Deserialize, Default)]
struct OpenAiDelta {
/// Assistant text fragment.
#[serde(default)]
content: Option<String>,
/// Tool-call argument fragments.
#[serde(default)]
tool_calls: Option<Vec<OpenAiToolCallDelta>>,
}
/// A fragment of one tool call within a delta.
#[derive(Debug, Deserialize)]
struct OpenAiToolCallDelta {
/// Tool-call id; sent on the first fragment, absent on later ones.
#[serde(default)]
id: Option<String>,
/// Function-call payload fragment.
#[serde(default)]
function: Option<OpenAiFunctionDelta>,
}
/// Streamed function-call arguments.
#[derive(Debug, Deserialize)]
struct OpenAiFunctionDelta {
/// Partial JSON of the arguments; concatenate fragments to reassemble.
#[serde(default)]
arguments: Option<String>,
}

View File

@@ -1,11 +1,58 @@
//! Streaming utilities
//! Streaming response types
use serde::{Deserialize, Serialize};
use tokio::sync::mpsc;
use zclaw_types::Result;
/// Stream event for LLM responses
/// Stream chunk emitted during streaming
/// This is the serializable type sent via Tauri events
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum StreamChunk {
/// Incremental assistant text (serialized as {"type":"text_delta",...}).
TextDelta { delta: String },
/// Incremental extended-thinking text from models that emit it.
ThinkingDelta { delta: String },
/// A tool invocation started; `id` correlates later deltas and the end.
ToolUseStart { id: String, name: String },
/// A fragment of the tool call's input (OpenAI-style argument streaming).
ToolUseDelta { id: String, delta: String },
/// The tool invocation finished; `input` is the fully assembled JSON.
ToolUseEnd { id: String, input: serde_json::Value },
/// Terminal chunk: token usage and the provider's stop reason.
Complete {
input_tokens: u32,
output_tokens: u32,
stop_reason: String,
},
/// A provider-reported error occurred mid-stream.
Error { message: String },
}
/// Streaming event for Tauri emission
///
/// Wraps a [`StreamChunk`] with the routing metadata a frontend needs to
/// dispatch it to the right session/agent view.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamEvent {
/// Session ID for routing
pub session_id: String,
/// Agent ID for routing
pub agent_id: String,
/// The chunk content
pub chunk: StreamChunk,
}
impl StreamEvent {
pub fn new(session_id: impl Into<String>, agent_id: impl Into<String>, chunk: StreamChunk) -> Self {
Self {
session_id: session_id.into(),
agent_id: agent_id.into(),
chunk,
}
}
}
/// Legacy stream event for internal use with mpsc channels
#[derive(Debug, Clone)]
pub enum StreamEvent {
pub enum InternalStreamEvent {
/// Text delta received
TextDelta(String),
/// Thinking delta received
@@ -24,31 +71,31 @@ pub enum StreamEvent {
/// Stream sender wrapper
pub struct StreamSender {
tx: mpsc::Sender<StreamEvent>,
tx: mpsc::Sender<InternalStreamEvent>,
}
impl StreamSender {
pub fn new(tx: mpsc::Sender<StreamEvent>) -> Self {
pub fn new(tx: mpsc::Sender<InternalStreamEvent>) -> Self {
Self { tx }
}
pub async fn send_text(&self, delta: impl Into<String>) -> Result<()> {
self.tx.send(StreamEvent::TextDelta(delta.into())).await.ok();
self.tx.send(InternalStreamEvent::TextDelta(delta.into())).await.ok();
Ok(())
}
pub async fn send_thinking(&self, delta: impl Into<String>) -> Result<()> {
self.tx.send(StreamEvent::ThinkingDelta(delta.into())).await.ok();
self.tx.send(InternalStreamEvent::ThinkingDelta(delta.into())).await.ok();
Ok(())
}
pub async fn send_complete(&self, input_tokens: u32, output_tokens: u32) -> Result<()> {
self.tx.send(StreamEvent::Complete { input_tokens, output_tokens }).await.ok();
self.tx.send(InternalStreamEvent::Complete { input_tokens, output_tokens }).await.ok();
Ok(())
}
pub async fn send_error(&self, error: impl Into<String>) -> Result<()> {
self.tx.send(StreamEvent::Error(error.into())).await.ok();
self.tx.send(InternalStreamEvent::Error(error.into())).await.ok();
Ok(())
}
}