feat(runtime): add streaming support to LlmDriver trait
- Add StreamChunk and StreamEvent types for Tauri event emission
- Add stream() method to LlmDriver trait with async-stream
- Implement Anthropic streaming with SSE parsing
- Implement OpenAI streaming with SSE parsing
- Add placeholder stream() for Gemini and Local drivers

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -1,12 +1,16 @@
|
||||
//! OpenAI-compatible driver implementation
|
||||
|
||||
use async_trait::async_trait;
|
||||
use async_stream::stream;
|
||||
use futures::{Stream, StreamExt};
|
||||
use secrecy::{ExposeSecret, SecretString};
|
||||
use reqwest::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::pin::Pin;
|
||||
use zclaw_types::{Result, ZclawError};
|
||||
|
||||
use super::{CompletionRequest, CompletionResponse, ContentBlock, LlmDriver, StopReason, ToolDefinition};
|
||||
use super::{CompletionRequest, CompletionResponse, ContentBlock, LlmDriver, StopReason};
|
||||
use crate::stream::StreamChunk;
|
||||
|
||||
/// OpenAI-compatible driver
|
||||
pub struct OpenAiDriver {
|
||||
@@ -85,6 +89,93 @@ impl LlmDriver for OpenAiDriver {
|
||||
|
||||
Ok(self.convert_response(api_response, request.model))
|
||||
}
|
||||
|
||||
fn stream(
|
||||
&self,
|
||||
request: CompletionRequest,
|
||||
) -> Pin<Box<dyn Stream<Item = Result<StreamChunk>> + Send + '_>> {
|
||||
let mut stream_request = self.build_api_request(&request);
|
||||
stream_request.stream = true;
|
||||
|
||||
let base_url = self.base_url.clone();
|
||||
let api_key = self.api_key.expose_secret().to_string();
|
||||
|
||||
Box::pin(stream! {
|
||||
let response = match self.client
|
||||
.post(format!("{}/chat/completions", base_url))
|
||||
.header("Authorization", format!("Bearer {}", api_key))
|
||||
.header("Content-Type", "application/json")
|
||||
.json(&stream_request)
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
yield Err(ZclawError::LlmError(format!("HTTP request failed: {}", e)));
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
let body = response.text().await.unwrap_or_default();
|
||||
yield Err(ZclawError::LlmError(format!("API error {}: {}", status, body)));
|
||||
return;
|
||||
}
|
||||
|
||||
let mut byte_stream = response.bytes_stream();
|
||||
|
||||
while let Some(chunk_result) = byte_stream.next().await {
|
||||
let chunk = match chunk_result {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
yield Err(ZclawError::LlmError(format!("Stream error: {}", e)));
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let text = String::from_utf8_lossy(&chunk);
|
||||
for line in text.lines() {
|
||||
if let Some(data) = line.strip_prefix("data: ") {
|
||||
if data == "[DONE]" {
|
||||
yield Ok(StreamChunk::Complete {
|
||||
input_tokens: 0,
|
||||
output_tokens: 0,
|
||||
stop_reason: "end_turn".to_string(),
|
||||
});
|
||||
continue;
|
||||
}
|
||||
|
||||
match serde_json::from_str::<OpenAiStreamResponse>(data) {
|
||||
Ok(resp) => {
|
||||
if let Some(choice) = resp.choices.first() {
|
||||
let delta = &choice.delta;
|
||||
if let Some(content) = &delta.content {
|
||||
yield Ok(StreamChunk::TextDelta { delta: content.clone() });
|
||||
}
|
||||
if let Some(tool_calls) = &delta.tool_calls {
|
||||
for tc in tool_calls {
|
||||
if let Some(function) = &tc.function {
|
||||
if let Some(args) = &function.arguments {
|
||||
yield Ok(StreamChunk::ToolUseDelta {
|
||||
id: tc.id.clone().unwrap_or_default(),
|
||||
delta: args.clone(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to parse OpenAI SSE: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl OpenAiDriver {
|
||||
@@ -334,3 +425,41 @@ struct OpenAiUsage {
|
||||
#[serde(default)]
|
||||
completion_tokens: u32,
|
||||
}
|
||||
|
||||
// OpenAI Streaming types
|
||||
|
||||
/// One JSON payload from an OpenAI-compatible SSE `data:` frame.
#[derive(Debug, Deserialize)]
struct OpenAiStreamResponse {
    // Defaults to empty so frames without `choices` still deserialize.
    #[serde(default)]
    choices: Vec<OpenAiStreamChoice>,
}
|
||||
|
||||
/// A single streamed choice; carries the incremental delta for this frame.
#[derive(Debug, Deserialize)]
struct OpenAiStreamChoice {
    // Missing `delta` deserializes as an all-None OpenAiDelta.
    #[serde(default)]
    delta: OpenAiDelta,
    // Set on the final frame for this choice (e.g. "stop", "tool_calls").
    // NOTE(review): currently unread by the stream parser — the [DONE]
    // sentinel is used for completion instead; kept for future use.
    #[serde(default)]
    finish_reason: Option<String>,
}
|
||||
|
||||
/// Incremental message content within a streamed choice. Either field (or
/// both) may be absent on any given frame.
#[derive(Debug, Deserialize, Default)]
struct OpenAiDelta {
    // Appended text for the assistant message, if this frame carries text.
    #[serde(default)]
    content: Option<String>,
    // Incremental tool-call fragments, if this frame carries tool-use data.
    #[serde(default)]
    tool_calls: Option<Vec<OpenAiToolCallDelta>>,
}
|
||||
|
||||
/// A fragment of a tool call inside a streamed delta.
#[derive(Debug, Deserialize)]
struct OpenAiToolCallDelta {
    // Tool-call id; may be absent on continuation frames, in which case the
    // parser substitutes an empty string.
    #[serde(default)]
    id: Option<String>,
    // Function payload fragment carried by this frame, if any.
    #[serde(default)]
    function: Option<OpenAiFunctionDelta>,
}
|
||||
|
||||
/// Function-call fragment within a tool-call delta.
#[derive(Debug, Deserialize)]
struct OpenAiFunctionDelta {
    // Partial JSON-arguments string; fragments are concatenated by the
    // consumer across frames to form the full arguments document.
    #[serde(default)]
    arguments: Option<String>,
}
|
||||
|
||||
Reference in New Issue
Block a user