//! OpenAI-compatible driver implementation use async_trait::async_trait; use async_stream::stream; use futures::{Stream, StreamExt}; use secrecy::{ExposeSecret, SecretString}; use reqwest::Client; use serde::{Deserialize, Serialize}; use std::pin::Pin; use zclaw_types::{Result, ZclawError}; use super::{CompletionRequest, CompletionResponse, ContentBlock, LlmDriver, StopReason}; use crate::stream::StreamChunk; /// OpenAI-compatible driver pub struct OpenAiDriver { client: Client, api_key: SecretString, base_url: String, } impl OpenAiDriver { pub fn new(api_key: SecretString) -> Self { Self { client: Client::builder() .user_agent(crate::USER_AGENT) .http1_only() .timeout(std::time::Duration::from_secs(120)) // 2 minute timeout .connect_timeout(std::time::Duration::from_secs(30)) // 30 second connect timeout .build() .unwrap_or_else(|_| Client::new()), api_key, base_url: "https://api.openai.com/v1".to_string(), } } pub fn with_base_url(api_key: SecretString, base_url: String) -> Self { Self { client: Client::builder() .user_agent(crate::USER_AGENT) .http1_only() .timeout(std::time::Duration::from_secs(120)) // 2 minute timeout .connect_timeout(std::time::Duration::from_secs(30)) // 30 second connect timeout .build() .unwrap_or_else(|_| Client::new()), api_key, base_url, } } } #[async_trait] impl LlmDriver for OpenAiDriver { fn provider(&self) -> &str { "openai" } fn is_configured(&self) -> bool { !self.api_key.expose_secret().is_empty() } async fn complete(&self, request: CompletionRequest) -> Result { let api_request = self.build_api_request(&request); // Debug: log the request details let url = format!("{}/chat/completions", self.base_url); let request_body = serde_json::to_string(&api_request).unwrap_or_default(); eprintln!("[OpenAiDriver] Sending request to: {}", url); eprintln!("[OpenAiDriver] Request body: {}", request_body); let response = self.client .post(&url) .header("Authorization", format!("Bearer {}", self.api_key.expose_secret())) .header("Accept", "*/*") .json(&api_request) .send() .await .map_err(|e| ZclawError::LlmError(format!("HTTP request failed: {}", e)))?; if !response.status().is_success() { let status = response.status(); let body = response.text().await.unwrap_or_default(); eprintln!("[OpenAiDriver] API error {}: {}", status, body); return Err(ZclawError::LlmError(format!("API error {}: {}", status, body))); } eprintln!("[OpenAiDriver] Response status: {}", response.status()); let api_response: OpenAiResponse = response .json() .await .map_err(|e| ZclawError::LlmError(format!("Failed to parse response: {}", e)))?; Ok(self.convert_response(api_response, request.model)) } fn stream( &self, request: CompletionRequest, ) -> Pin> + Send + '_>> { // Check if we should use non-streaming mode for tool calls // Some providers don't support streaming with tools: // - Alibaba DashScope: "tools暂时无法与stream=True同时使用" // - Zhipu GLM: May have similar limitations let has_tools = !request.tools.is_empty(); let needs_non_streaming = self.base_url.contains("dashscope") || self.base_url.contains("aliyuncs") || self.base_url.contains("bigmodel.cn"); eprintln!("[OpenAiDriver:stream] base_url={}, has_tools={}, needs_non_streaming={}", self.base_url, has_tools, needs_non_streaming); if has_tools && needs_non_streaming { eprintln!("[OpenAiDriver:stream] Provider detected that may not support streaming with tools, using non-streaming mode. URL: {}", self.base_url); // Use non-streaming mode and convert to stream return self.stream_from_complete(request); } let mut stream_request = self.build_api_request(&request); stream_request.stream = true; // Debug: log the request details let url = format!("{}/chat/completions", self.base_url); let request_body = serde_json::to_string(&stream_request).unwrap_or_default(); tracing::debug!("[OpenAiDriver:stream] Sending request to: {}", url); tracing::debug!("[OpenAiDriver:stream] Request body length: {} bytes", request_body.len()); tracing::trace!("[OpenAiDriver:stream] Request body: {}", request_body); let base_url = self.base_url.clone(); let api_key = self.api_key.expose_secret().to_string(); Box::pin(stream! { tracing::debug!("[OpenAiDriver:stream] Starting HTTP request..."); let response = match self.client .post(format!("{}/chat/completions", base_url)) .header("Authorization", format!("Bearer {}", api_key)) .header("Content-Type", "application/json") .timeout(std::time::Duration::from_secs(120)) // 2 minute timeout .json(&stream_request) .send() .await { Ok(r) => { tracing::debug!("[OpenAiDriver:stream] Got response, status: {}", r.status()); r }, Err(e) => { tracing::error!("[OpenAiDriver:stream] HTTP request failed: {:?}", e); yield Err(ZclawError::LlmError(format!("HTTP request failed: {}", e))); return; } }; if !response.status().is_success() { let status = response.status(); let body = response.text().await.unwrap_or_default(); yield Err(ZclawError::LlmError(format!("API error {}: {}", status, body))); return; } let mut byte_stream = response.bytes_stream(); let mut accumulated_tool_calls: std::collections::HashMap = std::collections::HashMap::new(); let mut current_tool_id: Option = None; while let Some(chunk_result) = byte_stream.next().await { let chunk = match chunk_result { Ok(c) => c, Err(e) => { yield Err(ZclawError::LlmError(format!("Stream error: {}", e))); continue; } }; let text = String::from_utf8_lossy(&chunk); for line in text.lines() { if let Some(data) = line.strip_prefix("data: ") { if data == "[DONE]" { tracing::debug!("[OpenAI] Stream done, accumulated_tool_calls: {:?}", accumulated_tool_calls.len()); // Emit ToolUseEnd for all accumulated tool calls (skip invalid ones with empty name) for (id, (name, args)) in &accumulated_tool_calls { // Skip tool calls with empty name - they are invalid if name.is_empty() { tracing::warn!("[OpenAI] Skipping invalid tool call with empty name: id={}", id); continue; } tracing::debug!("[OpenAI] Emitting ToolUseEnd: id={}, name={}, args={}", id, name, args); // Ensure parsed args is always a valid JSON object let parsed_args: serde_json::Value = if args.is_empty() { serde_json::json!({}) } else { serde_json::from_str(args).unwrap_or_else(|e| { tracing::warn!("[OpenAI] Failed to parse tool args '{}': {}, using empty object", args, e); serde_json::json!({}) }) }; yield Ok(StreamChunk::ToolUseEnd { id: id.clone(), input: parsed_args, }); } yield Ok(StreamChunk::Complete { input_tokens: 0, output_tokens: 0, stop_reason: "end_turn".to_string(), }); continue; } match serde_json::from_str::(data) { Ok(resp) => { if let Some(choice) = resp.choices.first() { let delta = &choice.delta; // Handle text content if let Some(content) = &delta.content { if !content.is_empty() { yield Ok(StreamChunk::TextDelta { delta: content.clone() }); } } // Handle tool calls if let Some(tool_calls) = &delta.tool_calls { tracing::trace!("[OpenAI] Received tool_calls delta: {:?}", tool_calls); for tc in tool_calls { // Tool call start - has id and name if let Some(id) = &tc.id { // Get function name if available let name = tc.function.as_ref() .and_then(|f| f.name.clone()) .unwrap_or_default(); // Only emit ToolUseStart if we have a valid tool name if !name.is_empty() { tracing::debug!("[OpenAI] ToolUseStart: id={}, name={}", id, name); current_tool_id = Some(id.clone()); accumulated_tool_calls.insert(id.clone(), (name.clone(), String::new())); yield Ok(StreamChunk::ToolUseStart { id: id.clone(), name, }); } else { tracing::debug!("[OpenAI] Tool call with empty name, waiting for name delta: id={}", id); // Still track the tool call but don't emit yet current_tool_id = Some(id.clone()); accumulated_tool_calls.insert(id.clone(), (String::new(), String::new())); } } // Tool call delta - has arguments if let Some(function) = &tc.function { tracing::trace!("[OpenAI] Function delta: name={:?}, arguments={:?}", function.name, function.arguments); if let Some(args) = &function.arguments { tracing::debug!("[OpenAI] ToolUseDelta: args={}", args); // Try to find the tool by id or use current let tool_id = tc.id.as_ref() .or(current_tool_id.as_ref()) .cloned() .unwrap_or_default(); yield Ok(StreamChunk::ToolUseDelta { id: tool_id.clone(), delta: args.clone(), }); // Accumulate arguments if let Some(entry) = accumulated_tool_calls.get_mut(&tool_id) { tracing::debug!("[OpenAI] Accumulating args for tool {}: '{}' -> '{}'", tool_id, args, entry.1); entry.1.push_str(args); } else { tracing::warn!("[OpenAI] No entry found for tool_id '{}' to accumulate args", tool_id); } } } } } } } Err(e) => { tracing::warn!("[OpenAI] Failed to parse SSE: {}, data: {}", e, data); } } } } } }) } } impl OpenAiDriver { /// Check if this is a Coding Plan endpoint (requires coding context) fn is_coding_plan_endpoint(&self) -> bool { self.base_url.contains("coding.dashscope") || self.base_url.contains("coding/paas") || self.base_url.contains("api.kimi.com/coding") } fn build_api_request(&self, request: &CompletionRequest) -> OpenAiRequest { // For Coding Plan endpoints, auto-add a coding assistant system prompt if not provided let system_prompt = if request.system.is_none() && self.is_coding_plan_endpoint() { Some("你是一个专业的编程助手,可以帮助用户解决编程问题、写代码、调试等。".to_string()) } else { request.system.clone() }; let messages: Vec = request.messages .iter() .filter_map(|msg| match msg { zclaw_types::Message::User { content } => Some(OpenAiMessage { role: "user".to_string(), content: Some(content.clone()), tool_calls: None, }), zclaw_types::Message::Assistant { content, thinking: _ } => Some(OpenAiMessage { role: "assistant".to_string(), content: Some(content.clone()), tool_calls: None, }), zclaw_types::Message::System { content } => Some(OpenAiMessage { role: "system".to_string(), content: Some(content.clone()), tool_calls: None, }), zclaw_types::Message::ToolUse { id, tool, input } => { // Ensure arguments is always a valid JSON object, never null or invalid let args = if input.is_null() { "{}".to_string() } else { serde_json::to_string(input).unwrap_or_else(|_| "{}".to_string()) }; Some(OpenAiMessage { role: "assistant".to_string(), content: None, tool_calls: Some(vec![OpenAiToolCall { id: id.clone(), r#type: "function".to_string(), function: FunctionCall { name: tool.to_string(), arguments: args, }, }]), }) } zclaw_types::Message::ToolResult { tool_call_id: _, output, is_error, .. } => Some(OpenAiMessage { role: "tool".to_string(), content: Some(if *is_error { format!("Error: {}", output) } else { output.to_string() }), tool_calls: None, }), }) .collect(); // Add system prompt if provided let mut messages = messages; if let Some(system) = &system_prompt { messages.insert(0, OpenAiMessage { role: "system".to_string(), content: Some(system.clone()), tool_calls: None, }); } let tools: Vec = request.tools .iter() .map(|t| OpenAiTool { r#type: "function".to_string(), function: FunctionDef { name: t.name.clone(), description: t.description.clone(), parameters: t.input_schema.clone(), }, }) .collect(); OpenAiRequest { model: request.model.clone(), // Use model ID directly without any transformation messages, max_tokens: request.max_tokens, temperature: request.temperature, stop: if request.stop.is_empty() { None } else { Some(request.stop.clone()) }, stream: request.stream, tools: if tools.is_empty() { None } else { Some(tools) }, } } fn convert_response(&self, api_response: OpenAiResponse, model: String) -> CompletionResponse { let choice = api_response.choices.first(); tracing::debug!("[OpenAiDriver:convert_response] Processing response: {} choices, first choice: {:?}", api_response.choices.len(), choice.map(|c| format!("content={:?}, tool_calls={:?}, finish_reason={:?}", c.message.content, c.message.tool_calls.as_ref().map(|tc| tc.len()), c.finish_reason))); let (content, stop_reason) = match choice { Some(c) => { // Priority: tool_calls > non-empty content > empty content // This is important because some providers return empty content with tool_calls let has_tool_calls = c.message.tool_calls.as_ref().map(|tc| !tc.is_empty()).unwrap_or(false); let has_content = c.message.content.as_ref().map(|t| !t.is_empty()).unwrap_or(false); let blocks = if has_tool_calls { // Tool calls take priority let tool_calls = c.message.tool_calls.as_ref().unwrap(); tracing::debug!("[OpenAiDriver:convert_response] Using tool_calls: {} calls", tool_calls.len()); tool_calls.iter().map(|tc| ContentBlock::ToolUse { id: tc.id.clone(), name: tc.function.name.clone(), input: serde_json::from_str(&tc.function.arguments).unwrap_or(serde_json::Value::Null), }).collect() } else if has_content { // Non-empty content let text = c.message.content.as_ref().unwrap(); tracing::debug!("[OpenAiDriver:convert_response] Using text content: {} chars", text.len()); vec![ContentBlock::Text { text: text.clone() }] } else { // No content or tool_calls tracing::debug!("[OpenAiDriver:convert_response] No content or tool_calls, using empty text"); vec![ContentBlock::Text { text: String::new() }] }; let stop = match c.finish_reason.as_deref() { Some("stop") => StopReason::EndTurn, Some("length") => StopReason::MaxTokens, Some("tool_calls") => StopReason::ToolUse, _ => StopReason::EndTurn, }; (blocks, stop) } None => { tracing::debug!("[OpenAiDriver:convert_response] No choices in response"); (vec![ContentBlock::Text { text: String::new() }], StopReason::EndTurn) } }; let (input_tokens, output_tokens) = api_response.usage .map(|u| (u.prompt_tokens, u.completion_tokens)) .unwrap_or((0, 0)); CompletionResponse { content, model, input_tokens, output_tokens, stop_reason, } } /// Convert a non-streaming completion to a stream for providers that don't support streaming with tools fn stream_from_complete(&self, request: CompletionRequest) -> Pin> + Send + '_>> { // Build non-streaming request let mut complete_request = self.build_api_request(&request); complete_request.stream = false; // Capture values before entering the stream let base_url = self.base_url.clone(); let api_key = self.api_key.expose_secret().to_string(); let model = request.model.clone(); eprintln!("[OpenAiDriver:stream_from_complete] Starting non-streaming request to: {}/chat/completions", base_url); Box::pin(stream! { let url = format!("{}/chat/completions", base_url); eprintln!("[OpenAiDriver:stream_from_complete] Sending non-streaming request to: {}", url); let response = match self.client .post(&url) .header("Authorization", format!("Bearer {}", api_key)) .header("Content-Type", "application/json") .timeout(std::time::Duration::from_secs(120)) .json(&complete_request) .send() .await { Ok(r) => r, Err(e) => { yield Err(ZclawError::LlmError(format!("HTTP request failed: {}", e))); return; } }; if !response.status().is_success() { let status = response.status(); let body = response.text().await.unwrap_or_default(); yield Err(ZclawError::LlmError(format!("API error {}: {}", status, body))); return; } let api_response: OpenAiResponse = match response.json().await { Ok(r) => r, Err(e) => { eprintln!("[OpenAiDriver:stream_from_complete] Failed to parse response: {}", e); yield Err(ZclawError::LlmError(format!("Failed to parse response: {}", e))); return; } }; eprintln!("[OpenAiDriver:stream_from_complete] Got response with {} choices", api_response.choices.len()); if let Some(choice) = api_response.choices.first() { eprintln!("[OpenAiDriver:stream_from_complete] First choice: content={:?}, tool_calls={:?}, finish_reason={:?}", choice.message.content.as_ref().map(|c| { if c.len() > 100 { // 使用 floor_char_boundary 确保不在多字节字符中间截断 let end = c.floor_char_boundary(100); &c[..end] } else { c.as_str() } }), choice.message.tool_calls.as_ref().map(|tc| tc.len()), choice.finish_reason); } // Convert response to stream chunks let completion = self.convert_response(api_response, model.clone()); eprintln!("[OpenAiDriver:stream_from_complete] Converted to {} content blocks, stop_reason: {:?}", completion.content.len(), completion.stop_reason); // Emit content blocks as stream chunks for block in &completion.content { eprintln!("[OpenAiDriver:stream_from_complete] Emitting block: {:?}", block); match block { ContentBlock::Text { text } => { if !text.is_empty() { eprintln!("[OpenAiDriver:stream_from_complete] Emitting TextDelta: {} chars", text.len()); yield Ok(StreamChunk::TextDelta { delta: text.clone() }); } } ContentBlock::Thinking { thinking } => { yield Ok(StreamChunk::ThinkingDelta { delta: thinking.clone() }); } ContentBlock::ToolUse { id, name, input } => { eprintln!("[OpenAiDriver:stream_from_complete] Emitting ToolUse: id={}, name={}", id, name); // Emit tool use start yield Ok(StreamChunk::ToolUseStart { id: id.clone(), name: name.clone(), }); // Emit tool use delta with arguments if !input.is_null() { let args_str = serde_json::to_string(input).unwrap_or_default(); yield Ok(StreamChunk::ToolUseDelta { id: id.clone(), delta: args_str, }); } // Emit tool use end yield Ok(StreamChunk::ToolUseEnd { id: id.clone(), input: input.clone(), }); } } } // Emit completion yield Ok(StreamChunk::Complete { input_tokens: completion.input_tokens, output_tokens: completion.output_tokens, stop_reason: match completion.stop_reason { StopReason::EndTurn => "end_turn", StopReason::MaxTokens => "max_tokens", StopReason::ToolUse => "tool_use", StopReason::StopSequence => "stop", StopReason::Error => "error", }.to_string(), }); }) } } // OpenAI API types #[derive(Serialize)] struct OpenAiRequest { model: String, messages: Vec, #[serde(skip_serializing_if = "Option::is_none")] max_tokens: Option, #[serde(skip_serializing_if = "Option::is_none")] temperature: Option, #[serde(skip_serializing_if = "Option::is_none")] stop: Option>, #[serde(default)] stream: bool, #[serde(skip_serializing_if = "Option::is_none")] tools: Option>, } #[derive(Serialize)] struct OpenAiMessage { role: String, #[serde(skip_serializing_if = "Option::is_none")] content: Option, #[serde(skip_serializing_if = "Option::is_none")] tool_calls: Option>, } #[derive(Serialize)] struct OpenAiToolCall { id: String, r#type: String, function: FunctionCall, } impl Default for OpenAiToolCall { fn default() -> Self { Self { id: String::new(), r#type: "function".to_string(), function: FunctionCall { name: String::new(), arguments: String::new(), }, } } } #[derive(Serialize)] struct FunctionCall { name: String, arguments: String, } #[derive(Serialize)] struct OpenAiTool { r#type: String, function: FunctionDef, } #[derive(Serialize)] struct FunctionDef { name: String, description: String, parameters: serde_json::Value, } #[derive(Deserialize, Default)] struct OpenAiResponse { #[serde(default)] choices: Vec, #[serde(default)] usage: Option, } #[derive(Deserialize, Default)] struct OpenAiChoice { #[serde(default)] message: OpenAiResponseMessage, #[serde(default)] finish_reason: Option, } #[derive(Deserialize, Default)] struct OpenAiResponseMessage { #[serde(default)] content: Option, #[serde(default)] tool_calls: Option>, } #[derive(Deserialize, Default)] struct OpenAiToolCallResponse { #[serde(default)] id: String, #[serde(default)] function: FunctionCallResponse, } #[derive(Deserialize, Default)] struct FunctionCallResponse { #[serde(default)] name: String, #[serde(default)] arguments: String, } #[derive(Deserialize, Default)] struct OpenAiUsage { #[serde(default)] prompt_tokens: u32, #[serde(default)] completion_tokens: u32, } // OpenAI Streaming types #[derive(Debug, Deserialize)] struct OpenAiStreamResponse { #[serde(default)] choices: Vec, } #[derive(Debug, Deserialize)] struct OpenAiStreamChoice { #[serde(default)] delta: OpenAiDelta, #[serde(default)] finish_reason: Option, } #[derive(Debug, Deserialize, Default)] struct OpenAiDelta { #[serde(default)] content: Option, #[serde(default)] tool_calls: Option>, } #[derive(Debug, Deserialize)] struct OpenAiToolCallDelta { #[serde(default)] id: Option, #[serde(default)] function: Option, } #[derive(Debug, Deserialize)] struct OpenAiFunctionDelta { #[serde(default)] name: Option, #[serde(default)] arguments: Option, }