- Split zclaw-kernel/kernel.rs (1486 lines) into 9 domain modules - Split zclaw-kernel/generation.rs (1080 lines) into 3 modules - Add DeerFlow-inspired middleware: DanglingTool, SubagentLimit, ToolError, ToolOutputGuard - Add PromptBuilder for structured system prompt assembly - Add FactStore (zclaw-memory) for persistent fact extraction - Add task builtin tool for agent task management - Driver improvements: Anthropic/OpenAI extended thinking, Gemini safety settings - Replace let _ = with proper log::warn! across SaaS handlers - Remove unused dependency (url) from zclaw-hands
923 lines
41 KiB
Rust
923 lines
41 KiB
Rust
//! OpenAI-compatible driver implementation
|
|
|
|
use async_trait::async_trait;
|
|
use async_stream::stream;
|
|
use futures::{Stream, StreamExt};
|
|
use secrecy::{ExposeSecret, SecretString};
|
|
use reqwest::Client;
|
|
use serde::{Deserialize, Serialize};
|
|
use std::pin::Pin;
|
|
use zclaw_types::{Result, ZclawError};
|
|
|
|
use super::{CompletionRequest, CompletionResponse, ContentBlock, LlmDriver, StopReason};
|
|
use crate::stream::StreamChunk;
|
|
|
|
/// OpenAI-compatible driver
|
|
pub struct OpenAiDriver {
|
|
client: Client,
|
|
api_key: SecretString,
|
|
base_url: String,
|
|
}
|
|
|
|
impl OpenAiDriver {
|
|
pub fn new(api_key: SecretString) -> Self {
|
|
Self {
|
|
client: Client::builder()
|
|
.user_agent(crate::USER_AGENT)
|
|
.http1_only()
|
|
.timeout(std::time::Duration::from_secs(120)) // 2 minute timeout
|
|
.connect_timeout(std::time::Duration::from_secs(30)) // 30 second connect timeout
|
|
.build()
|
|
.unwrap_or_else(|_| Client::new()),
|
|
api_key,
|
|
base_url: "https://api.openai.com/v1".to_string(),
|
|
}
|
|
}
|
|
|
|
pub fn with_base_url(api_key: SecretString, base_url: String) -> Self {
|
|
Self {
|
|
client: Client::builder()
|
|
.user_agent(crate::USER_AGENT)
|
|
.http1_only()
|
|
.timeout(std::time::Duration::from_secs(120)) // 2 minute timeout
|
|
.connect_timeout(std::time::Duration::from_secs(30)) // 30 second connect timeout
|
|
.build()
|
|
.unwrap_or_else(|_| Client::new()),
|
|
api_key,
|
|
base_url,
|
|
}
|
|
}
|
|
}
|
|
|
|
#[async_trait]
|
|
impl LlmDriver for OpenAiDriver {
|
|
fn provider(&self) -> &str {
|
|
"openai"
|
|
}
|
|
|
|
fn is_configured(&self) -> bool {
|
|
!self.api_key.expose_secret().is_empty()
|
|
}
|
|
|
|
async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
|
|
let api_request = self.build_api_request(&request);
|
|
|
|
// Debug: log the request details
|
|
let url = format!("{}/chat/completions", self.base_url);
|
|
let request_body = serde_json::to_string(&api_request).unwrap_or_default();
|
|
tracing::debug!(target: "openai_driver", "Sending request to: {}", url);
|
|
tracing::trace!(target: "openai_driver", "Request body: {}", request_body);
|
|
|
|
let response = self.client
|
|
.post(&url)
|
|
.header("Authorization", format!("Bearer {}", self.api_key.expose_secret()))
|
|
.header("Accept", "*/*")
|
|
.json(&api_request)
|
|
.send()
|
|
.await
|
|
.map_err(|e| ZclawError::LlmError(format!("HTTP request failed: {}", e)))?;
|
|
|
|
if !response.status().is_success() {
|
|
let status = response.status();
|
|
let body = response.text().await.unwrap_or_default();
|
|
tracing::warn!(target: "openai_driver", "API error {}: {}", status, body);
|
|
return Err(ZclawError::LlmError(format!("API error {}: {}", status, body)));
|
|
}
|
|
|
|
tracing::debug!(target: "openai_driver", "Response status: {}", response.status());
|
|
|
|
let api_response: OpenAiResponse = response
|
|
.json()
|
|
.await
|
|
.map_err(|e| ZclawError::LlmError(format!("Failed to parse response: {}", e)))?;
|
|
|
|
Ok(self.convert_response(api_response, request.model))
|
|
}
|
|
|
|
fn stream(
|
|
&self,
|
|
request: CompletionRequest,
|
|
) -> Pin<Box<dyn Stream<Item = Result<StreamChunk>> + Send + '_>> {
|
|
// Check if we should use non-streaming mode for tool calls
|
|
// Some providers don't support streaming with tools:
|
|
// - Alibaba DashScope: "tools暂时无法与stream=True同时使用"
|
|
// - Zhipu GLM: May have similar limitations
|
|
let has_tools = !request.tools.is_empty();
|
|
let needs_non_streaming = self.base_url.contains("dashscope") ||
|
|
self.base_url.contains("aliyuncs") ||
|
|
self.base_url.contains("bigmodel.cn");
|
|
|
|
tracing::debug!(target: "openai_driver", "stream config: base_url={}, has_tools={}, needs_non_streaming={}",
|
|
self.base_url, has_tools, needs_non_streaming);
|
|
|
|
if has_tools && needs_non_streaming {
|
|
tracing::info!(target: "openai_driver", "Provider detected that may not support streaming with tools, using non-streaming mode. URL: {}", self.base_url);
|
|
// Use non-streaming mode and convert to stream
|
|
return self.stream_from_complete(request);
|
|
}
|
|
|
|
let mut stream_request = self.build_api_request(&request);
|
|
stream_request.stream = true;
|
|
|
|
// Debug: log the request details
|
|
let url = format!("{}/chat/completions", self.base_url);
|
|
let request_body = serde_json::to_string(&stream_request).unwrap_or_default();
|
|
tracing::debug!("[OpenAiDriver:stream] Sending request to: {}", url);
|
|
tracing::debug!("[OpenAiDriver:stream] Request body length: {} bytes", request_body.len());
|
|
tracing::trace!("[OpenAiDriver:stream] Request body: {}", request_body);
|
|
|
|
let base_url = self.base_url.clone();
|
|
let api_key = self.api_key.expose_secret().to_string();
|
|
|
|
Box::pin(stream! {
|
|
tracing::debug!("[OpenAI:stream] POST to {}/chat/completions", base_url);
|
|
tracing::debug!("[OpenAI:stream] Request model={}, stream={}", stream_request.model, stream_request.stream);
|
|
let response = match self.client
|
|
.post(format!("{}/chat/completions", base_url))
|
|
.header("Authorization", format!("Bearer {}", api_key))
|
|
.header("Content-Type", "application/json")
|
|
.timeout(std::time::Duration::from_secs(120)) // 2 minute timeout
|
|
.json(&stream_request)
|
|
.send()
|
|
.await
|
|
{
|
|
Ok(r) => {
|
|
tracing::debug!("[OpenAI:stream] Response status: {}, content-type: {:?}", r.status(), r.headers().get("content-type"));
|
|
r
|
|
},
|
|
Err(e) => {
|
|
tracing::debug!("[OpenAI:stream] HTTP request FAILED: {:?}", e);
|
|
yield Err(ZclawError::LlmError(format!("HTTP request failed: {}", e)));
|
|
return;
|
|
}
|
|
};
|
|
|
|
if !response.status().is_success() {
|
|
let status = response.status();
|
|
let body = response.text().await.unwrap_or_default();
|
|
tracing::debug!("[OpenAI:stream] API error {}: {}", status, &body[..body.len().min(500)]);
|
|
yield Err(ZclawError::LlmError(format!("API error {}: {}", status, body)));
|
|
return;
|
|
}
|
|
|
|
let mut byte_stream = response.bytes_stream();
|
|
let mut accumulated_tool_calls: std::collections::HashMap<String, (String, String)> = std::collections::HashMap::new();
|
|
let mut current_tool_id: Option<String> = None;
|
|
let mut sse_event_count: usize = 0;
|
|
let mut raw_bytes_total: usize = 0;
|
|
|
|
while let Some(chunk_result) = byte_stream.next().await {
|
|
let chunk = match chunk_result {
|
|
Ok(c) => c,
|
|
Err(e) => {
|
|
tracing::debug!("[OpenAI:stream] Byte stream error: {:?}", e);
|
|
yield Err(ZclawError::LlmError(format!("Stream error: {}", e)));
|
|
continue;
|
|
}
|
|
};
|
|
|
|
raw_bytes_total += chunk.len();
|
|
let text = String::from_utf8_lossy(&chunk);
|
|
// Log first 500 bytes of raw data for debugging SSE format
|
|
if raw_bytes_total <= 600 {
|
|
tracing::debug!("[OpenAI:stream] RAW chunk ({} bytes): {:?}", text.len(), &text[..text.len().min(500)]);
|
|
}
|
|
for line in text.lines() {
|
|
let trimmed = line.trim();
|
|
if trimmed.is_empty() || trimmed.starts_with(':') {
|
|
continue; // Skip empty lines and SSE comments
|
|
}
|
|
// Handle both "data: " (standard) and "data:" (no space)
|
|
let data = if let Some(d) = trimmed.strip_prefix("data: ") {
|
|
Some(d)
|
|
} else if let Some(d) = trimmed.strip_prefix("data:") {
|
|
Some(d.trim_start())
|
|
} else {
|
|
None
|
|
};
|
|
if let Some(data) = data {
|
|
sse_event_count += 1;
|
|
if sse_event_count <= 3 || data == "[DONE]" {
|
|
tracing::debug!("[OpenAI:stream] SSE #{}: {}", sse_event_count, &data[..data.len().min(300)]);
|
|
}
|
|
if data == "[DONE]" {
|
|
tracing::debug!("[OpenAI:stream] Received [DONE], total SSE events: {}, raw bytes: {}", sse_event_count, raw_bytes_total);
|
|
|
|
// Emit ToolUseEnd for all accumulated tool calls (skip invalid ones with empty name)
|
|
for (id, (name, args)) in &accumulated_tool_calls {
|
|
// Skip tool calls with empty name - they are invalid
|
|
if name.is_empty() {
|
|
tracing::warn!("[OpenAI] Skipping invalid tool call with empty name: id={}", id);
|
|
continue;
|
|
}
|
|
tracing::debug!("[OpenAI] Emitting ToolUseEnd: id={}, name={}, args={}", id, name, args);
|
|
// Ensure parsed args is always a valid JSON object
|
|
let parsed_args: serde_json::Value = if args.is_empty() {
|
|
serde_json::json!({})
|
|
} else {
|
|
serde_json::from_str(args).unwrap_or_else(|e| {
|
|
tracing::warn!("[OpenAI] Failed to parse tool args '{}': {}, using empty object", args, e);
|
|
serde_json::json!({})
|
|
})
|
|
};
|
|
yield Ok(StreamChunk::ToolUseEnd {
|
|
id: id.clone(),
|
|
input: parsed_args,
|
|
});
|
|
}
|
|
|
|
yield Ok(StreamChunk::Complete {
|
|
input_tokens: 0,
|
|
output_tokens: 0,
|
|
stop_reason: "end_turn".to_string(),
|
|
});
|
|
continue;
|
|
}
|
|
|
|
match serde_json::from_str::<OpenAiStreamResponse>(data) {
|
|
Ok(resp) => {
|
|
if let Some(choice) = resp.choices.first() {
|
|
let delta = &choice.delta;
|
|
|
|
// Handle text content
|
|
if let Some(content) = &delta.content {
|
|
if !content.is_empty() {
|
|
tracing::debug!("[OpenAI:stream] TextDelta: {} chars", content.len());
|
|
yield Ok(StreamChunk::TextDelta { delta: content.clone() });
|
|
}
|
|
}
|
|
|
|
// Handle reasoning_content (Kimi, Qwen, DeepSeek, GLM thinking)
|
|
if let Some(reasoning) = &delta.reasoning_content {
|
|
if !reasoning.is_empty() {
|
|
tracing::debug!("[OpenAI:stream] ThinkingDelta (reasoning_content): {} chars", reasoning.len());
|
|
yield Ok(StreamChunk::ThinkingDelta { delta: reasoning.clone() });
|
|
}
|
|
}
|
|
|
|
// Handle tool calls
|
|
if let Some(tool_calls) = &delta.tool_calls {
|
|
tracing::trace!("[OpenAI] Received tool_calls delta: {:?}", tool_calls);
|
|
for tc in tool_calls {
|
|
// Tool call start - has id and name
|
|
if let Some(id) = &tc.id {
|
|
// Get function name if available
|
|
let name = tc.function.as_ref()
|
|
.and_then(|f| f.name.clone())
|
|
.unwrap_or_default();
|
|
|
|
// Only emit ToolUseStart if we have a valid tool name
|
|
if !name.is_empty() {
|
|
tracing::debug!("[OpenAI] ToolUseStart: id={}, name={}", id, name);
|
|
current_tool_id = Some(id.clone());
|
|
accumulated_tool_calls.insert(id.clone(), (name.clone(), String::new()));
|
|
yield Ok(StreamChunk::ToolUseStart {
|
|
id: id.clone(),
|
|
name,
|
|
});
|
|
} else {
|
|
tracing::debug!("[OpenAI] Tool call with empty name, waiting for name delta: id={}", id);
|
|
// Still track the tool call but don't emit yet
|
|
current_tool_id = Some(id.clone());
|
|
accumulated_tool_calls.insert(id.clone(), (String::new(), String::new()));
|
|
}
|
|
}
|
|
|
|
// Tool call delta - has arguments
|
|
if let Some(function) = &tc.function {
|
|
tracing::trace!("[OpenAI] Function delta: name={:?}, arguments={:?}", function.name, function.arguments);
|
|
if let Some(args) = &function.arguments {
|
|
tracing::debug!("[OpenAI] ToolUseDelta: args={}", args);
|
|
// Try to find the tool by id or use current
|
|
let tool_id = tc.id.as_ref()
|
|
.or(current_tool_id.as_ref())
|
|
.cloned()
|
|
.unwrap_or_default();
|
|
|
|
yield Ok(StreamChunk::ToolUseDelta {
|
|
id: tool_id.clone(),
|
|
delta: args.clone(),
|
|
});
|
|
|
|
// Accumulate arguments
|
|
if let Some(entry) = accumulated_tool_calls.get_mut(&tool_id) {
|
|
tracing::debug!("[OpenAI] Accumulating args for tool {}: '{}' -> '{}'", tool_id, args, entry.1);
|
|
entry.1.push_str(args);
|
|
} else {
|
|
tracing::warn!("[OpenAI] No entry found for tool_id '{}' to accumulate args", tool_id);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
Err(e) => {
|
|
tracing::warn!("[OpenAI] Failed to parse SSE: {}, data: {}", e, data);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
tracing::debug!("[OpenAI:stream] Byte stream ended. Total: {} SSE events, {} raw bytes", sse_event_count, raw_bytes_total);
|
|
})
|
|
}
|
|
}
|
|
|
|
impl OpenAiDriver {
|
|
/// Check if this is a Coding Plan endpoint (requires coding context)
|
|
fn is_coding_plan_endpoint(&self) -> bool {
|
|
self.base_url.contains("coding.dashscope") ||
|
|
self.base_url.contains("coding/paas") ||
|
|
self.base_url.contains("api.kimi.com/coding")
|
|
}
|
|
|
|
fn build_api_request(&self, request: &CompletionRequest) -> OpenAiRequest {
|
|
// For Coding Plan endpoints, auto-add a coding assistant system prompt if not provided
|
|
let system_prompt = if request.system.is_none() && self.is_coding_plan_endpoint() {
|
|
Some("你是一个专业的编程助手,可以帮助用户解决编程问题、写代码、调试等。".to_string())
|
|
} else {
|
|
request.system.clone()
|
|
};
|
|
|
|
// Build messages with tool result truncation to prevent payload overflow.
|
|
// Most LLM APIs have a 2-4MB HTTP payload limit.
|
|
const MAX_TOOL_RESULT_BYTES: usize = 32_768; // 32KB per tool result
|
|
const MAX_PAYLOAD_BYTES: usize = 1_800_000; // 1.8MB (under 2MB API limit)
|
|
|
|
let mut messages: Vec<OpenAiMessage> = Vec::new();
|
|
let mut pending_tool_calls: Option<Vec<OpenAiToolCall>> = None;
|
|
let mut pending_content: Option<String> = None;
|
|
let mut pending_reasoning: Option<String> = None;
|
|
|
|
let flush_pending = |tc: &mut Option<Vec<OpenAiToolCall>>,
|
|
c: &mut Option<String>,
|
|
r: &mut Option<String>,
|
|
out: &mut Vec<OpenAiMessage>| {
|
|
let calls = tc.take();
|
|
let content = c.take();
|
|
let reasoning = r.take();
|
|
|
|
if let Some(calls) = calls {
|
|
if !calls.is_empty() {
|
|
// Merge assistant content + reasoning into the tool call message.
|
|
// IMPORTANT: Some APIs (Kimi, Qwen) require `content` to be non-empty
|
|
// even when tool_calls is set. Use a meaningful placeholder if content is empty.
|
|
let content_value = content.filter(|s| !s.trim().is_empty())
|
|
.unwrap_or_else(|| "正在调用工具...".to_string());
|
|
out.push(OpenAiMessage {
|
|
role: "assistant".to_string(),
|
|
content: Some(content_value),
|
|
reasoning_content: reasoning.filter(|s| !s.is_empty()),
|
|
tool_calls: Some(calls),
|
|
tool_call_id: None,
|
|
});
|
|
return;
|
|
}
|
|
}
|
|
// No tool calls — emit a plain assistant message.
|
|
// Ensure content is always Some() and non-empty to satisfy API requirements.
|
|
if content.is_some() || reasoning.is_some() {
|
|
let content_value = content.filter(|s| !s.trim().is_empty())
|
|
.unwrap_or_else(|| "正在思考...".to_string());
|
|
out.push(OpenAiMessage {
|
|
role: "assistant".to_string(),
|
|
content: Some(content_value),
|
|
reasoning_content: reasoning.filter(|s| !s.is_empty()),
|
|
tool_calls: None,
|
|
tool_call_id: None,
|
|
});
|
|
}
|
|
};
|
|
|
|
for msg in &request.messages {
|
|
match msg {
|
|
zclaw_types::Message::User { content } => {
|
|
flush_pending(&mut pending_tool_calls, &mut pending_content, &mut pending_reasoning, &mut messages);
|
|
messages.push(OpenAiMessage {
|
|
role: "user".to_string(),
|
|
content: Some(content.clone()),
|
|
tool_calls: None,
|
|
tool_call_id: None,
|
|
reasoning_content: None,
|
|
});
|
|
}
|
|
zclaw_types::Message::Assistant { content, thinking } => {
|
|
flush_pending(&mut pending_tool_calls, &mut pending_content, &mut pending_reasoning, &mut messages);
|
|
// Don't push immediately — wait to see if next messages are ToolUse
|
|
pending_content = Some(content.clone());
|
|
pending_reasoning = thinking.clone();
|
|
}
|
|
zclaw_types::Message::System { content } => {
|
|
flush_pending(&mut pending_tool_calls, &mut pending_content, &mut pending_reasoning, &mut messages);
|
|
messages.push(OpenAiMessage {
|
|
role: "system".to_string(),
|
|
content: Some(content.clone()),
|
|
tool_calls: None,
|
|
tool_call_id: None,
|
|
reasoning_content: None,
|
|
});
|
|
}
|
|
zclaw_types::Message::ToolUse { id, tool, input } => {
|
|
// Accumulate tool calls — they'll be merged with the pending assistant message
|
|
let args = if input.is_null() {
|
|
"{}".to_string()
|
|
} else {
|
|
serde_json::to_string(input).unwrap_or_else(|_| "{}".to_string())
|
|
};
|
|
pending_tool_calls
|
|
.get_or_insert_with(Vec::new)
|
|
.push(OpenAiToolCall {
|
|
id: id.clone(),
|
|
r#type: "function".to_string(),
|
|
function: FunctionCall {
|
|
name: tool.to_string(),
|
|
arguments: args,
|
|
},
|
|
});
|
|
}
|
|
zclaw_types::Message::ToolResult { tool_call_id, output, is_error, .. } => {
|
|
flush_pending(&mut pending_tool_calls, &mut pending_content, &mut pending_reasoning, &mut messages);
|
|
let content_str = if *is_error {
|
|
format!("Error: {}", output)
|
|
} else {
|
|
output.to_string()
|
|
};
|
|
// Truncate oversized tool results to prevent payload overflow
|
|
let truncated = if content_str.len() > MAX_TOOL_RESULT_BYTES {
|
|
let mut s = String::from(&content_str[..MAX_TOOL_RESULT_BYTES]);
|
|
s.push_str("\n\n... [内容已截断,原文过大]");
|
|
s
|
|
} else {
|
|
content_str
|
|
};
|
|
messages.push(OpenAiMessage {
|
|
role: "tool".to_string(),
|
|
content: Some(truncated),
|
|
tool_calls: None,
|
|
tool_call_id: Some(tool_call_id.clone()),
|
|
reasoning_content: None,
|
|
});
|
|
}
|
|
}
|
|
}
|
|
// Flush any remaining accumulated assistant content and/or tool calls
|
|
flush_pending(&mut pending_tool_calls, &mut pending_content, &mut pending_reasoning, &mut messages);
|
|
|
|
// Add system prompt if provided
|
|
let mut messages = messages;
|
|
if let Some(system) = &system_prompt {
|
|
messages.insert(0, OpenAiMessage {
|
|
role: "system".to_string(),
|
|
content: Some(system.clone()),
|
|
tool_calls: None,
|
|
tool_call_id: None,
|
|
reasoning_content: None,
|
|
});
|
|
}
|
|
|
|
let tools: Vec<OpenAiTool> = request.tools
|
|
.iter()
|
|
.map(|t| OpenAiTool {
|
|
r#type: "function".to_string(),
|
|
function: FunctionDef {
|
|
name: t.name.clone(),
|
|
description: t.description.clone(),
|
|
parameters: t.input_schema.clone(),
|
|
},
|
|
})
|
|
.collect();
|
|
|
|
let api_request = OpenAiRequest {
|
|
model: request.model.clone(), // Use model ID directly without any transformation
|
|
messages,
|
|
max_tokens: request.max_tokens,
|
|
temperature: request.temperature,
|
|
stop: if request.stop.is_empty() { None } else { Some(request.stop.clone()) },
|
|
stream: request.stream,
|
|
tools: if tools.is_empty() { None } else { Some(tools) },
|
|
reasoning_effort: request.reasoning_effort.clone(),
|
|
};
|
|
|
|
// Pre-send payload size validation
|
|
if let Ok(serialized) = serde_json::to_string(&api_request) {
|
|
if serialized.len() > MAX_PAYLOAD_BYTES {
|
|
tracing::warn!(
|
|
target: "openai_driver",
|
|
"Request payload too large: {} bytes (limit: {}), truncating messages",
|
|
serialized.len(),
|
|
MAX_PAYLOAD_BYTES
|
|
);
|
|
return Self::truncate_messages_to_fit(api_request, MAX_PAYLOAD_BYTES);
|
|
}
|
|
tracing::debug!(
|
|
target: "openai_driver",
|
|
"Request payload size: {} bytes (limit: {})",
|
|
serialized.len(),
|
|
MAX_PAYLOAD_BYTES
|
|
);
|
|
}
|
|
|
|
api_request
|
|
}
|
|
|
|
/// Emergency truncation: drop oldest non-system messages until payload fits
|
|
fn truncate_messages_to_fit(mut request: OpenAiRequest, _max_bytes: usize) -> OpenAiRequest {
|
|
// Keep system message (if any) and last 4 non-system messages
|
|
let has_system = request.messages.first()
|
|
.map(|m| m.role == "system")
|
|
.unwrap_or(false);
|
|
|
|
let non_system: Vec<OpenAiMessage> = request.messages.into_iter()
|
|
.filter(|m| m.role != "system")
|
|
.collect();
|
|
|
|
// Keep last N messages and truncate any remaining large tool results
|
|
let keep_count = 4.min(non_system.len());
|
|
let start = non_system.len() - keep_count;
|
|
let kept: Vec<OpenAiMessage> = non_system.into_iter()
|
|
.skip(start)
|
|
.map(|mut msg| {
|
|
// Additional per-message truncation for tool results
|
|
if msg.role == "tool" {
|
|
if let Some(ref content) = msg.content {
|
|
if content.len() > 16_384 {
|
|
let mut s = String::from(&content[..16_384]);
|
|
s.push_str("\n\n... [上下文压缩截断]");
|
|
msg.content = Some(s);
|
|
}
|
|
}
|
|
}
|
|
msg
|
|
})
|
|
.collect();
|
|
|
|
let mut messages = Vec::new();
|
|
if has_system {
|
|
messages.push(OpenAiMessage {
|
|
role: "system".to_string(),
|
|
content: Some("You are a helpful AI assistant. (注意:对话历史已被压缩以适应上下文大小限制)".to_string()),
|
|
tool_calls: None,
|
|
tool_call_id: None,
|
|
reasoning_content: None,
|
|
});
|
|
}
|
|
messages.extend(kept);
|
|
|
|
request.messages = messages;
|
|
request
|
|
}
|
|
|
|
fn convert_response(&self, api_response: OpenAiResponse, model: String) -> CompletionResponse {
|
|
let choice = api_response.choices.first();
|
|
|
|
tracing::debug!("[OpenAiDriver:convert_response] Processing response: {} choices, first choice: {:?}", api_response.choices.len(), choice.map(|c| format!("content={:?}, tool_calls={:?}, finish_reason={:?}", c.message.content, c.message.tool_calls.as_ref().map(|tc| tc.len()), c.finish_reason)));
|
|
|
|
let (content, stop_reason) = match choice {
|
|
Some(c) => {
|
|
// Priority: tool_calls > non-empty content > empty content
|
|
// This is important because some providers return empty content with tool_calls
|
|
let has_tool_calls = c.message.tool_calls.as_ref().map(|tc| !tc.is_empty()).unwrap_or(false);
|
|
let has_content = c.message.content.as_ref().map(|t| !t.is_empty()).unwrap_or(false);
|
|
let has_reasoning = c.message.reasoning_content.as_ref().map(|t| !t.is_empty()).unwrap_or(false);
|
|
|
|
let blocks = if has_tool_calls {
|
|
// Tool calls take priority — safe to unwrap after has_tool_calls check
|
|
let tool_calls = c.message.tool_calls.as_ref().cloned().unwrap_or_default();
|
|
tracing::debug!("[OpenAiDriver:convert_response] Using tool_calls: {} calls", tool_calls.len());
|
|
tool_calls.iter().map(|tc| ContentBlock::ToolUse {
|
|
id: tc.id.clone(),
|
|
name: tc.function.name.clone(),
|
|
input: serde_json::from_str(&tc.function.arguments).unwrap_or(serde_json::Value::Null),
|
|
}).collect()
|
|
} else if has_content {
|
|
// Non-empty content — safe to unwrap after has_content check
|
|
let text = c.message.content.as_deref().unwrap_or("");
|
|
tracing::debug!("[OpenAiDriver:convert_response] Using text content: {} chars", text.len());
|
|
vec![ContentBlock::Text { text: text.to_string() }]
|
|
} else if has_reasoning {
|
|
// Content empty but reasoning_content present (Kimi, Qwen, DeepSeek)
|
|
let reasoning = c.message.reasoning_content.as_deref().unwrap_or("");
|
|
tracing::debug!("[OpenAiDriver:convert_response] Using reasoning_content: {} chars", reasoning.len());
|
|
vec![ContentBlock::Text { text: reasoning.to_string() }]
|
|
} else {
|
|
// No content or tool_calls
|
|
tracing::debug!("[OpenAiDriver:convert_response] No content or tool_calls, using empty text");
|
|
vec![ContentBlock::Text { text: String::new() }]
|
|
};
|
|
|
|
let stop = match c.finish_reason.as_deref() {
|
|
Some("stop") => StopReason::EndTurn,
|
|
Some("length") => StopReason::MaxTokens,
|
|
Some("tool_calls") => StopReason::ToolUse,
|
|
_ => StopReason::EndTurn,
|
|
};
|
|
|
|
(blocks, stop)
|
|
}
|
|
None => {
|
|
tracing::debug!("[OpenAiDriver:convert_response] No choices in response");
|
|
(vec![ContentBlock::Text { text: String::new() }], StopReason::EndTurn)
|
|
}
|
|
};
|
|
|
|
let (input_tokens, output_tokens) = api_response.usage
|
|
.map(|u| (u.prompt_tokens, u.completion_tokens))
|
|
.unwrap_or((0, 0));
|
|
|
|
CompletionResponse {
|
|
content,
|
|
model,
|
|
input_tokens,
|
|
output_tokens,
|
|
stop_reason,
|
|
}
|
|
}
|
|
|
|
/// Convert a non-streaming completion to a stream for providers that don't support streaming with tools
|
|
fn stream_from_complete(&self, request: CompletionRequest) -> Pin<Box<dyn Stream<Item = Result<StreamChunk>> + Send + '_>> {
|
|
// Build non-streaming request
|
|
let mut complete_request = self.build_api_request(&request);
|
|
complete_request.stream = false;
|
|
|
|
// Capture values before entering the stream
|
|
let base_url = self.base_url.clone();
|
|
let api_key = self.api_key.expose_secret().to_string();
|
|
let model = request.model.clone();
|
|
|
|
tracing::debug!(target: "openai_driver", "stream_from_complete: Starting non-streaming request to: {}/chat/completions", base_url);
|
|
|
|
Box::pin(stream! {
|
|
let url = format!("{}/chat/completions", base_url);
|
|
tracing::debug!(target: "openai_driver", "stream_from_complete: Sending non-streaming request to: {}", url);
|
|
|
|
let response = match self.client
|
|
.post(&url)
|
|
.header("Authorization", format!("Bearer {}", api_key))
|
|
.header("Content-Type", "application/json")
|
|
.timeout(std::time::Duration::from_secs(120))
|
|
.json(&complete_request)
|
|
.send()
|
|
.await
|
|
{
|
|
Ok(r) => r,
|
|
Err(e) => {
|
|
yield Err(ZclawError::LlmError(format!("HTTP request failed: {}", e)));
|
|
return;
|
|
}
|
|
};
|
|
|
|
if !response.status().is_success() {
|
|
let status = response.status();
|
|
let body = response.text().await.unwrap_or_default();
|
|
yield Err(ZclawError::LlmError(format!("API error {}: {}", status, body)));
|
|
return;
|
|
}
|
|
|
|
let api_response: OpenAiResponse = match response.json().await {
|
|
Ok(r) => r,
|
|
Err(e) => {
|
|
tracing::warn!(target: "openai_driver", "stream_from_complete: Failed to parse response: {}", e);
|
|
yield Err(ZclawError::LlmError(format!("Failed to parse response: {}", e)));
|
|
return;
|
|
}
|
|
};
|
|
|
|
tracing::debug!(target: "openai_driver", "stream_from_complete: Got response with {} choices", api_response.choices.len());
|
|
if let Some(choice) = api_response.choices.first() {
|
|
tracing::debug!(target: "openai_driver", "stream_from_complete: First choice: content={:?}, tool_calls={:?}, finish_reason={:?}",
|
|
choice.message.content.as_ref().map(|c| {
|
|
if c.len() > 100 {
|
|
// 使用 floor_char_boundary 确保不在多字节字符中间截断
|
|
let end = c.floor_char_boundary(100);
|
|
&c[..end]
|
|
} else {
|
|
c.as_str()
|
|
}
|
|
}),
|
|
choice.message.tool_calls.as_ref().map(|tc| tc.len()),
|
|
choice.finish_reason);
|
|
}
|
|
|
|
// Convert response to stream chunks
|
|
let completion = self.convert_response(api_response, model.clone());
|
|
tracing::debug!(target: "openai_driver", "stream_from_complete: Converted to {} content blocks, stop_reason: {:?}", completion.content.len(), completion.stop_reason);
|
|
|
|
// Emit content blocks as stream chunks
|
|
for block in &completion.content {
|
|
tracing::debug!(target: "openai_driver", "stream_from_complete: Emitting block: {:?}", block);
|
|
match block {
|
|
ContentBlock::Text { text } => {
|
|
if !text.is_empty() {
|
|
tracing::debug!(target: "openai_driver", "stream_from_complete: Emitting TextDelta: {} chars", text.len());
|
|
yield Ok(StreamChunk::TextDelta { delta: text.clone() });
|
|
}
|
|
}
|
|
ContentBlock::Thinking { thinking } => {
|
|
yield Ok(StreamChunk::ThinkingDelta { delta: thinking.clone() });
|
|
}
|
|
ContentBlock::ToolUse { id, name, input } => {
|
|
tracing::debug!(target: "openai_driver", "stream_from_complete: Emitting ToolUse: id={}, name={}", id, name);
|
|
// Emit tool use start
|
|
yield Ok(StreamChunk::ToolUseStart {
|
|
id: id.clone(),
|
|
name: name.clone(),
|
|
});
|
|
// Emit tool use delta with arguments
|
|
if !input.is_null() {
|
|
let args_str = serde_json::to_string(input).unwrap_or_default();
|
|
yield Ok(StreamChunk::ToolUseDelta {
|
|
id: id.clone(),
|
|
delta: args_str,
|
|
});
|
|
}
|
|
// Emit tool use end
|
|
yield Ok(StreamChunk::ToolUseEnd {
|
|
id: id.clone(),
|
|
input: input.clone(),
|
|
});
|
|
}
|
|
}
|
|
}
|
|
|
|
// Emit completion
|
|
yield Ok(StreamChunk::Complete {
|
|
input_tokens: completion.input_tokens,
|
|
output_tokens: completion.output_tokens,
|
|
stop_reason: match completion.stop_reason {
|
|
StopReason::EndTurn => "end_turn",
|
|
StopReason::MaxTokens => "max_tokens",
|
|
StopReason::ToolUse => "tool_use",
|
|
StopReason::StopSequence => "stop",
|
|
StopReason::Error => "error",
|
|
}.to_string(),
|
|
});
|
|
})
|
|
}
|
|
}
|
|
|
|
// OpenAI API types
|
|
|
|
#[derive(Serialize)]
|
|
struct OpenAiRequest {
|
|
model: String,
|
|
messages: Vec<OpenAiMessage>,
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
max_tokens: Option<u32>,
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
temperature: Option<f32>,
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
stop: Option<Vec<String>>,
|
|
#[serde(default)]
|
|
stream: bool,
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
tools: Option<Vec<OpenAiTool>>,
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
reasoning_effort: Option<String>,
|
|
}
|
|
|
|
#[derive(Serialize)]
|
|
struct OpenAiMessage {
|
|
role: String,
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
content: Option<String>,
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
tool_calls: Option<Vec<OpenAiToolCall>>,
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
tool_call_id: Option<String>,
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
reasoning_content: Option<String>,
|
|
}
|
|
|
|
#[derive(Serialize)]
|
|
struct OpenAiToolCall {
|
|
id: String,
|
|
r#type: String,
|
|
function: FunctionCall,
|
|
}
|
|
|
|
impl Default for OpenAiToolCall {
|
|
fn default() -> Self {
|
|
Self {
|
|
id: String::new(),
|
|
r#type: "function".to_string(),
|
|
function: FunctionCall {
|
|
name: String::new(),
|
|
arguments: String::new(),
|
|
},
|
|
}
|
|
}
|
|
}
|
|
|
|
#[derive(Serialize)]
|
|
struct FunctionCall {
|
|
name: String,
|
|
arguments: String,
|
|
}
|
|
|
|
#[derive(Serialize)]
|
|
struct OpenAiTool {
|
|
r#type: String,
|
|
function: FunctionDef,
|
|
}
|
|
|
|
#[derive(Serialize)]
|
|
struct FunctionDef {
|
|
name: String,
|
|
description: String,
|
|
parameters: serde_json::Value,
|
|
}
|
|
|
|
#[derive(Deserialize, Default)]
|
|
struct OpenAiResponse {
|
|
#[serde(default)]
|
|
choices: Vec<OpenAiChoice>,
|
|
#[serde(default)]
|
|
usage: Option<OpenAiUsage>,
|
|
}
|
|
|
|
#[derive(Deserialize, Default, Clone)]
|
|
struct OpenAiChoice {
|
|
#[serde(default)]
|
|
message: OpenAiResponseMessage,
|
|
#[serde(default)]
|
|
finish_reason: Option<String>,
|
|
}
|
|
|
|
#[derive(Deserialize, Default, Clone)]
|
|
struct OpenAiResponseMessage {
|
|
#[serde(default)]
|
|
content: Option<String>,
|
|
#[serde(default)]
|
|
reasoning_content: Option<String>,
|
|
#[serde(default)]
|
|
tool_calls: Option<Vec<OpenAiToolCallResponse>>,
|
|
}
|
|
|
|
#[derive(Deserialize, Default, Clone)]
|
|
struct OpenAiToolCallResponse {
|
|
#[serde(default)]
|
|
id: String,
|
|
#[serde(default)]
|
|
function: FunctionCallResponse,
|
|
}
|
|
|
|
#[derive(Deserialize, Default, Clone)]
|
|
struct FunctionCallResponse {
|
|
#[serde(default)]
|
|
name: String,
|
|
#[serde(default)]
|
|
arguments: String,
|
|
}
|
|
|
|
#[derive(Deserialize, Default)]
|
|
struct OpenAiUsage {
|
|
#[serde(default)]
|
|
prompt_tokens: u32,
|
|
#[serde(default)]
|
|
completion_tokens: u32,
|
|
}
|
|
|
|
// OpenAI Streaming types
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
struct OpenAiStreamResponse {
|
|
#[serde(default)]
|
|
choices: Vec<OpenAiStreamChoice>,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
struct OpenAiStreamChoice {
|
|
#[serde(default)]
|
|
delta: OpenAiDelta,
|
|
#[serde(default)]
|
|
#[allow(dead_code)] // Used for deserialization, not accessed
|
|
finish_reason: Option<String>,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize, Default)]
|
|
struct OpenAiDelta {
|
|
#[serde(default)]
|
|
content: Option<String>,
|
|
#[serde(default)]
|
|
reasoning_content: Option<String>,
|
|
#[serde(default)]
|
|
tool_calls: Option<Vec<OpenAiToolCallDelta>>,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
struct OpenAiToolCallDelta {
|
|
#[serde(default)]
|
|
id: Option<String>,
|
|
#[serde(default)]
|
|
function: Option<OpenAiFunctionDelta>,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
struct OpenAiFunctionDelta {
|
|
#[serde(default)]
|
|
name: Option<String>,
|
|
#[serde(default)]
|
|
arguments: Option<String>,
|
|
}
|