Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled
refactor: 统一Hands系统常量到单个源文件 refactor: 更新Hands中文名称和描述 fix: 修复技能市场在连接状态变化时重新加载 fix: 修复身份变更提案的错误处理逻辑 docs: 更新多个功能文档的验证状态和实现位置 docs: 更新Hands系统文档 test: 添加测试文件验证工作区路径
725 lines
30 KiB
Rust
725 lines
30 KiB
Rust
//! OpenAI-compatible driver implementation
|
|
|
|
use async_trait::async_trait;
|
|
use async_stream::stream;
|
|
use futures::{Stream, StreamExt};
|
|
use secrecy::{ExposeSecret, SecretString};
|
|
use reqwest::Client;
|
|
use serde::{Deserialize, Serialize};
|
|
use std::pin::Pin;
|
|
use zclaw_types::{Result, ZclawError};
|
|
|
|
use super::{CompletionRequest, CompletionResponse, ContentBlock, LlmDriver, StopReason};
|
|
use crate::stream::StreamChunk;
|
|
|
|
/// OpenAI-compatible driver
///
/// Speaks the OpenAI `/chat/completions` wire protocol; with a custom
/// `base_url` it also targets compatible gateways (DashScope, Zhipu, Kimi…).
pub struct OpenAiDriver {
    // Shared HTTP client (120 s request timeout, 30 s connect timeout,
    // HTTP/1.1 only — see the constructors).
    client: Client,
    // Bearer token; kept in `SecretString` so it must be explicitly exposed.
    api_key: SecretString,
    // API root without trailing slash, e.g. "https://api.openai.com/v1".
    base_url: String,
}
|
|
|
|
impl OpenAiDriver {
|
|
pub fn new(api_key: SecretString) -> Self {
|
|
Self {
|
|
client: Client::builder()
|
|
.user_agent(crate::USER_AGENT)
|
|
.http1_only()
|
|
.timeout(std::time::Duration::from_secs(120)) // 2 minute timeout
|
|
.connect_timeout(std::time::Duration::from_secs(30)) // 30 second connect timeout
|
|
.build()
|
|
.unwrap_or_else(|_| Client::new()),
|
|
api_key,
|
|
base_url: "https://api.openai.com/v1".to_string(),
|
|
}
|
|
}
|
|
|
|
pub fn with_base_url(api_key: SecretString, base_url: String) -> Self {
|
|
Self {
|
|
client: Client::builder()
|
|
.user_agent(crate::USER_AGENT)
|
|
.http1_only()
|
|
.timeout(std::time::Duration::from_secs(120)) // 2 minute timeout
|
|
.connect_timeout(std::time::Duration::from_secs(30)) // 30 second connect timeout
|
|
.build()
|
|
.unwrap_or_else(|_| Client::new()),
|
|
api_key,
|
|
base_url,
|
|
}
|
|
}
|
|
}
|
|
|
|
#[async_trait]
|
|
impl LlmDriver for OpenAiDriver {
|
|
fn provider(&self) -> &str {
|
|
"openai"
|
|
}
|
|
|
|
fn is_configured(&self) -> bool {
|
|
!self.api_key.expose_secret().is_empty()
|
|
}
|
|
|
|
async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
|
|
let api_request = self.build_api_request(&request);
|
|
|
|
// Debug: log the request details
|
|
let url = format!("{}/chat/completions", self.base_url);
|
|
let request_body = serde_json::to_string(&api_request).unwrap_or_default();
|
|
eprintln!("[OpenAiDriver] Sending request to: {}", url);
|
|
eprintln!("[OpenAiDriver] Request body: {}", request_body);
|
|
|
|
let response = self.client
|
|
.post(&url)
|
|
.header("Authorization", format!("Bearer {}", self.api_key.expose_secret()))
|
|
.header("Accept", "*/*")
|
|
.json(&api_request)
|
|
.send()
|
|
.await
|
|
.map_err(|e| ZclawError::LlmError(format!("HTTP request failed: {}", e)))?;
|
|
|
|
if !response.status().is_success() {
|
|
let status = response.status();
|
|
let body = response.text().await.unwrap_or_default();
|
|
eprintln!("[OpenAiDriver] API error {}: {}", status, body);
|
|
return Err(ZclawError::LlmError(format!("API error {}: {}", status, body)));
|
|
}
|
|
|
|
eprintln!("[OpenAiDriver] Response status: {}", response.status());
|
|
|
|
let api_response: OpenAiResponse = response
|
|
.json()
|
|
.await
|
|
.map_err(|e| ZclawError::LlmError(format!("Failed to parse response: {}", e)))?;
|
|
|
|
Ok(self.convert_response(api_response, request.model))
|
|
}
|
|
|
|
fn stream(
|
|
&self,
|
|
request: CompletionRequest,
|
|
) -> Pin<Box<dyn Stream<Item = Result<StreamChunk>> + Send + '_>> {
|
|
// Check if we should use non-streaming mode for tool calls
|
|
// Some providers don't support streaming with tools:
|
|
// - Alibaba DashScope: "tools暂时无法与stream=True同时使用"
|
|
// - Zhipu GLM: May have similar limitations
|
|
let has_tools = !request.tools.is_empty();
|
|
let needs_non_streaming = self.base_url.contains("dashscope") ||
|
|
self.base_url.contains("aliyuncs") ||
|
|
self.base_url.contains("bigmodel.cn");
|
|
|
|
eprintln!("[OpenAiDriver:stream] base_url={}, has_tools={}, needs_non_streaming={}",
|
|
self.base_url, has_tools, needs_non_streaming);
|
|
|
|
if has_tools && needs_non_streaming {
|
|
eprintln!("[OpenAiDriver:stream] Provider detected that may not support streaming with tools, using non-streaming mode. URL: {}", self.base_url);
|
|
// Use non-streaming mode and convert to stream
|
|
return self.stream_from_complete(request);
|
|
}
|
|
|
|
let mut stream_request = self.build_api_request(&request);
|
|
stream_request.stream = true;
|
|
|
|
// Debug: log the request details
|
|
let url = format!("{}/chat/completions", self.base_url);
|
|
let request_body = serde_json::to_string(&stream_request).unwrap_or_default();
|
|
tracing::debug!("[OpenAiDriver:stream] Sending request to: {}", url);
|
|
tracing::debug!("[OpenAiDriver:stream] Request body length: {} bytes", request_body.len());
|
|
tracing::trace!("[OpenAiDriver:stream] Request body: {}", request_body);
|
|
|
|
let base_url = self.base_url.clone();
|
|
let api_key = self.api_key.expose_secret().to_string();
|
|
|
|
Box::pin(stream! {
|
|
tracing::debug!("[OpenAiDriver:stream] Starting HTTP request...");
|
|
let response = match self.client
|
|
.post(format!("{}/chat/completions", base_url))
|
|
.header("Authorization", format!("Bearer {}", api_key))
|
|
.header("Content-Type", "application/json")
|
|
.timeout(std::time::Duration::from_secs(120)) // 2 minute timeout
|
|
.json(&stream_request)
|
|
.send()
|
|
.await
|
|
{
|
|
Ok(r) => {
|
|
tracing::debug!("[OpenAiDriver:stream] Got response, status: {}", r.status());
|
|
r
|
|
},
|
|
Err(e) => {
|
|
tracing::error!("[OpenAiDriver:stream] HTTP request failed: {:?}", e);
|
|
yield Err(ZclawError::LlmError(format!("HTTP request failed: {}", e)));
|
|
return;
|
|
}
|
|
};
|
|
|
|
if !response.status().is_success() {
|
|
let status = response.status();
|
|
let body = response.text().await.unwrap_or_default();
|
|
yield Err(ZclawError::LlmError(format!("API error {}: {}", status, body)));
|
|
return;
|
|
}
|
|
|
|
let mut byte_stream = response.bytes_stream();
|
|
let mut accumulated_tool_calls: std::collections::HashMap<String, (String, String)> = std::collections::HashMap::new();
|
|
let mut current_tool_id: Option<String> = None;
|
|
|
|
while let Some(chunk_result) = byte_stream.next().await {
|
|
let chunk = match chunk_result {
|
|
Ok(c) => c,
|
|
Err(e) => {
|
|
yield Err(ZclawError::LlmError(format!("Stream error: {}", e)));
|
|
continue;
|
|
}
|
|
};
|
|
|
|
let text = String::from_utf8_lossy(&chunk);
|
|
for line in text.lines() {
|
|
if let Some(data) = line.strip_prefix("data: ") {
|
|
if data == "[DONE]" {
|
|
tracing::debug!("[OpenAI] Stream done, accumulated_tool_calls: {:?}", accumulated_tool_calls.len());
|
|
|
|
// Emit ToolUseEnd for all accumulated tool calls (skip invalid ones with empty name)
|
|
for (id, (name, args)) in &accumulated_tool_calls {
|
|
// Skip tool calls with empty name - they are invalid
|
|
if name.is_empty() {
|
|
tracing::warn!("[OpenAI] Skipping invalid tool call with empty name: id={}", id);
|
|
continue;
|
|
}
|
|
tracing::debug!("[OpenAI] Emitting ToolUseEnd: id={}, name={}, args={}", id, name, args);
|
|
// Ensure parsed args is always a valid JSON object
|
|
let parsed_args: serde_json::Value = if args.is_empty() {
|
|
serde_json::json!({})
|
|
} else {
|
|
serde_json::from_str(args).unwrap_or_else(|e| {
|
|
tracing::warn!("[OpenAI] Failed to parse tool args '{}': {}, using empty object", args, e);
|
|
serde_json::json!({})
|
|
})
|
|
};
|
|
yield Ok(StreamChunk::ToolUseEnd {
|
|
id: id.clone(),
|
|
input: parsed_args,
|
|
});
|
|
}
|
|
|
|
yield Ok(StreamChunk::Complete {
|
|
input_tokens: 0,
|
|
output_tokens: 0,
|
|
stop_reason: "end_turn".to_string(),
|
|
});
|
|
continue;
|
|
}
|
|
|
|
match serde_json::from_str::<OpenAiStreamResponse>(data) {
|
|
Ok(resp) => {
|
|
if let Some(choice) = resp.choices.first() {
|
|
let delta = &choice.delta;
|
|
|
|
// Handle text content
|
|
if let Some(content) = &delta.content {
|
|
if !content.is_empty() {
|
|
yield Ok(StreamChunk::TextDelta { delta: content.clone() });
|
|
}
|
|
}
|
|
|
|
// Handle tool calls
|
|
if let Some(tool_calls) = &delta.tool_calls {
|
|
tracing::trace!("[OpenAI] Received tool_calls delta: {:?}", tool_calls);
|
|
for tc in tool_calls {
|
|
// Tool call start - has id and name
|
|
if let Some(id) = &tc.id {
|
|
// Get function name if available
|
|
let name = tc.function.as_ref()
|
|
.and_then(|f| f.name.clone())
|
|
.unwrap_or_default();
|
|
|
|
// Only emit ToolUseStart if we have a valid tool name
|
|
if !name.is_empty() {
|
|
tracing::debug!("[OpenAI] ToolUseStart: id={}, name={}", id, name);
|
|
current_tool_id = Some(id.clone());
|
|
accumulated_tool_calls.insert(id.clone(), (name.clone(), String::new()));
|
|
yield Ok(StreamChunk::ToolUseStart {
|
|
id: id.clone(),
|
|
name,
|
|
});
|
|
} else {
|
|
tracing::debug!("[OpenAI] Tool call with empty name, waiting for name delta: id={}", id);
|
|
// Still track the tool call but don't emit yet
|
|
current_tool_id = Some(id.clone());
|
|
accumulated_tool_calls.insert(id.clone(), (String::new(), String::new()));
|
|
}
|
|
}
|
|
|
|
// Tool call delta - has arguments
|
|
if let Some(function) = &tc.function {
|
|
tracing::trace!("[OpenAI] Function delta: name={:?}, arguments={:?}", function.name, function.arguments);
|
|
if let Some(args) = &function.arguments {
|
|
tracing::debug!("[OpenAI] ToolUseDelta: args={}", args);
|
|
// Try to find the tool by id or use current
|
|
let tool_id = tc.id.as_ref()
|
|
.or(current_tool_id.as_ref())
|
|
.cloned()
|
|
.unwrap_or_default();
|
|
|
|
yield Ok(StreamChunk::ToolUseDelta {
|
|
id: tool_id.clone(),
|
|
delta: args.clone(),
|
|
});
|
|
|
|
// Accumulate arguments
|
|
if let Some(entry) = accumulated_tool_calls.get_mut(&tool_id) {
|
|
tracing::debug!("[OpenAI] Accumulating args for tool {}: '{}' -> '{}'", tool_id, args, entry.1);
|
|
entry.1.push_str(args);
|
|
} else {
|
|
tracing::warn!("[OpenAI] No entry found for tool_id '{}' to accumulate args", tool_id);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
Err(e) => {
|
|
tracing::warn!("[OpenAI] Failed to parse SSE: {}, data: {}", e, data);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
impl OpenAiDriver {
|
|
/// Check if this is a Coding Plan endpoint (requires coding context)
|
|
fn is_coding_plan_endpoint(&self) -> bool {
|
|
self.base_url.contains("coding.dashscope") ||
|
|
self.base_url.contains("coding/paas") ||
|
|
self.base_url.contains("api.kimi.com/coding")
|
|
}
|
|
|
|
fn build_api_request(&self, request: &CompletionRequest) -> OpenAiRequest {
|
|
// For Coding Plan endpoints, auto-add a coding assistant system prompt if not provided
|
|
let system_prompt = if request.system.is_none() && self.is_coding_plan_endpoint() {
|
|
Some("你是一个专业的编程助手,可以帮助用户解决编程问题、写代码、调试等。".to_string())
|
|
} else {
|
|
request.system.clone()
|
|
};
|
|
|
|
let messages: Vec<OpenAiMessage> = request.messages
|
|
.iter()
|
|
.filter_map(|msg| match msg {
|
|
zclaw_types::Message::User { content } => Some(OpenAiMessage {
|
|
role: "user".to_string(),
|
|
content: Some(content.clone()),
|
|
tool_calls: None,
|
|
}),
|
|
zclaw_types::Message::Assistant { content, thinking: _ } => Some(OpenAiMessage {
|
|
role: "assistant".to_string(),
|
|
content: Some(content.clone()),
|
|
tool_calls: None,
|
|
}),
|
|
zclaw_types::Message::System { content } => Some(OpenAiMessage {
|
|
role: "system".to_string(),
|
|
content: Some(content.clone()),
|
|
tool_calls: None,
|
|
}),
|
|
zclaw_types::Message::ToolUse { id, tool, input } => {
|
|
// Ensure arguments is always a valid JSON object, never null or invalid
|
|
let args = if input.is_null() {
|
|
"{}".to_string()
|
|
} else {
|
|
serde_json::to_string(input).unwrap_or_else(|_| "{}".to_string())
|
|
};
|
|
Some(OpenAiMessage {
|
|
role: "assistant".to_string(),
|
|
content: None,
|
|
tool_calls: Some(vec![OpenAiToolCall {
|
|
id: id.clone(),
|
|
r#type: "function".to_string(),
|
|
function: FunctionCall {
|
|
name: tool.to_string(),
|
|
arguments: args,
|
|
},
|
|
}]),
|
|
})
|
|
}
|
|
zclaw_types::Message::ToolResult { tool_call_id: _, output, is_error, .. } => Some(OpenAiMessage {
|
|
role: "tool".to_string(),
|
|
content: Some(if *is_error {
|
|
format!("Error: {}", output)
|
|
} else {
|
|
output.to_string()
|
|
}),
|
|
tool_calls: None,
|
|
}),
|
|
})
|
|
.collect();
|
|
|
|
// Add system prompt if provided
|
|
let mut messages = messages;
|
|
if let Some(system) = &system_prompt {
|
|
messages.insert(0, OpenAiMessage {
|
|
role: "system".to_string(),
|
|
content: Some(system.clone()),
|
|
tool_calls: None,
|
|
});
|
|
}
|
|
|
|
let tools: Vec<OpenAiTool> = request.tools
|
|
.iter()
|
|
.map(|t| OpenAiTool {
|
|
r#type: "function".to_string(),
|
|
function: FunctionDef {
|
|
name: t.name.clone(),
|
|
description: t.description.clone(),
|
|
parameters: t.input_schema.clone(),
|
|
},
|
|
})
|
|
.collect();
|
|
|
|
OpenAiRequest {
|
|
model: request.model.clone(), // Use model ID directly without any transformation
|
|
messages,
|
|
max_tokens: request.max_tokens,
|
|
temperature: request.temperature,
|
|
stop: if request.stop.is_empty() { None } else { Some(request.stop.clone()) },
|
|
stream: request.stream,
|
|
tools: if tools.is_empty() { None } else { Some(tools) },
|
|
}
|
|
}
|
|
|
|
fn convert_response(&self, api_response: OpenAiResponse, model: String) -> CompletionResponse {
|
|
let choice = api_response.choices.first();
|
|
|
|
tracing::debug!("[OpenAiDriver:convert_response] Processing response: {} choices, first choice: {:?}", api_response.choices.len(), choice.map(|c| format!("content={:?}, tool_calls={:?}, finish_reason={:?}", c.message.content, c.message.tool_calls.as_ref().map(|tc| tc.len()), c.finish_reason)));
|
|
|
|
let (content, stop_reason) = match choice {
|
|
Some(c) => {
|
|
// Priority: tool_calls > non-empty content > empty content
|
|
// This is important because some providers return empty content with tool_calls
|
|
let has_tool_calls = c.message.tool_calls.as_ref().map(|tc| !tc.is_empty()).unwrap_or(false);
|
|
let has_content = c.message.content.as_ref().map(|t| !t.is_empty()).unwrap_or(false);
|
|
|
|
let blocks = if has_tool_calls {
|
|
// Tool calls take priority
|
|
let tool_calls = c.message.tool_calls.as_ref().unwrap();
|
|
tracing::debug!("[OpenAiDriver:convert_response] Using tool_calls: {} calls", tool_calls.len());
|
|
tool_calls.iter().map(|tc| ContentBlock::ToolUse {
|
|
id: tc.id.clone(),
|
|
name: tc.function.name.clone(),
|
|
input: serde_json::from_str(&tc.function.arguments).unwrap_or(serde_json::Value::Null),
|
|
}).collect()
|
|
} else if has_content {
|
|
// Non-empty content
|
|
let text = c.message.content.as_ref().unwrap();
|
|
tracing::debug!("[OpenAiDriver:convert_response] Using text content: {} chars", text.len());
|
|
vec![ContentBlock::Text { text: text.clone() }]
|
|
} else {
|
|
// No content or tool_calls
|
|
tracing::debug!("[OpenAiDriver:convert_response] No content or tool_calls, using empty text");
|
|
vec![ContentBlock::Text { text: String::new() }]
|
|
};
|
|
|
|
let stop = match c.finish_reason.as_deref() {
|
|
Some("stop") => StopReason::EndTurn,
|
|
Some("length") => StopReason::MaxTokens,
|
|
Some("tool_calls") => StopReason::ToolUse,
|
|
_ => StopReason::EndTurn,
|
|
};
|
|
|
|
(blocks, stop)
|
|
}
|
|
None => {
|
|
tracing::debug!("[OpenAiDriver:convert_response] No choices in response");
|
|
(vec![ContentBlock::Text { text: String::new() }], StopReason::EndTurn)
|
|
}
|
|
};
|
|
|
|
let (input_tokens, output_tokens) = api_response.usage
|
|
.map(|u| (u.prompt_tokens, u.completion_tokens))
|
|
.unwrap_or((0, 0));
|
|
|
|
CompletionResponse {
|
|
content,
|
|
model,
|
|
input_tokens,
|
|
output_tokens,
|
|
stop_reason,
|
|
}
|
|
}
|
|
|
|
/// Convert a non-streaming completion to a stream for providers that don't support streaming with tools
|
|
fn stream_from_complete(&self, request: CompletionRequest) -> Pin<Box<dyn Stream<Item = Result<StreamChunk>> + Send + '_>> {
|
|
// Build non-streaming request
|
|
let mut complete_request = self.build_api_request(&request);
|
|
complete_request.stream = false;
|
|
|
|
// Capture values before entering the stream
|
|
let base_url = self.base_url.clone();
|
|
let api_key = self.api_key.expose_secret().to_string();
|
|
let model = request.model.clone();
|
|
|
|
eprintln!("[OpenAiDriver:stream_from_complete] Starting non-streaming request to: {}/chat/completions", base_url);
|
|
|
|
Box::pin(stream! {
|
|
let url = format!("{}/chat/completions", base_url);
|
|
eprintln!("[OpenAiDriver:stream_from_complete] Sending non-streaming request to: {}", url);
|
|
|
|
let response = match self.client
|
|
.post(&url)
|
|
.header("Authorization", format!("Bearer {}", api_key))
|
|
.header("Content-Type", "application/json")
|
|
.timeout(std::time::Duration::from_secs(120))
|
|
.json(&complete_request)
|
|
.send()
|
|
.await
|
|
{
|
|
Ok(r) => r,
|
|
Err(e) => {
|
|
yield Err(ZclawError::LlmError(format!("HTTP request failed: {}", e)));
|
|
return;
|
|
}
|
|
};
|
|
|
|
if !response.status().is_success() {
|
|
let status = response.status();
|
|
let body = response.text().await.unwrap_or_default();
|
|
yield Err(ZclawError::LlmError(format!("API error {}: {}", status, body)));
|
|
return;
|
|
}
|
|
|
|
let api_response: OpenAiResponse = match response.json().await {
|
|
Ok(r) => r,
|
|
Err(e) => {
|
|
eprintln!("[OpenAiDriver:stream_from_complete] Failed to parse response: {}", e);
|
|
yield Err(ZclawError::LlmError(format!("Failed to parse response: {}", e)));
|
|
return;
|
|
}
|
|
};
|
|
|
|
eprintln!("[OpenAiDriver:stream_from_complete] Got response with {} choices", api_response.choices.len());
|
|
if let Some(choice) = api_response.choices.first() {
|
|
eprintln!("[OpenAiDriver:stream_from_complete] First choice: content={:?}, tool_calls={:?}, finish_reason={:?}",
|
|
choice.message.content.as_ref().map(|c| {
|
|
if c.len() > 100 {
|
|
// 使用 floor_char_boundary 确保不在多字节字符中间截断
|
|
let end = c.floor_char_boundary(100);
|
|
&c[..end]
|
|
} else {
|
|
c.as_str()
|
|
}
|
|
}),
|
|
choice.message.tool_calls.as_ref().map(|tc| tc.len()),
|
|
choice.finish_reason);
|
|
}
|
|
|
|
// Convert response to stream chunks
|
|
let completion = self.convert_response(api_response, model.clone());
|
|
eprintln!("[OpenAiDriver:stream_from_complete] Converted to {} content blocks, stop_reason: {:?}", completion.content.len(), completion.stop_reason);
|
|
|
|
// Emit content blocks as stream chunks
|
|
for block in &completion.content {
|
|
eprintln!("[OpenAiDriver:stream_from_complete] Emitting block: {:?}", block);
|
|
match block {
|
|
ContentBlock::Text { text } => {
|
|
if !text.is_empty() {
|
|
eprintln!("[OpenAiDriver:stream_from_complete] Emitting TextDelta: {} chars", text.len());
|
|
yield Ok(StreamChunk::TextDelta { delta: text.clone() });
|
|
}
|
|
}
|
|
ContentBlock::Thinking { thinking } => {
|
|
yield Ok(StreamChunk::ThinkingDelta { delta: thinking.clone() });
|
|
}
|
|
ContentBlock::ToolUse { id, name, input } => {
|
|
eprintln!("[OpenAiDriver:stream_from_complete] Emitting ToolUse: id={}, name={}", id, name);
|
|
// Emit tool use start
|
|
yield Ok(StreamChunk::ToolUseStart {
|
|
id: id.clone(),
|
|
name: name.clone(),
|
|
});
|
|
// Emit tool use delta with arguments
|
|
if !input.is_null() {
|
|
let args_str = serde_json::to_string(input).unwrap_or_default();
|
|
yield Ok(StreamChunk::ToolUseDelta {
|
|
id: id.clone(),
|
|
delta: args_str,
|
|
});
|
|
}
|
|
// Emit tool use end
|
|
yield Ok(StreamChunk::ToolUseEnd {
|
|
id: id.clone(),
|
|
input: input.clone(),
|
|
});
|
|
}
|
|
}
|
|
}
|
|
|
|
// Emit completion
|
|
yield Ok(StreamChunk::Complete {
|
|
input_tokens: completion.input_tokens,
|
|
output_tokens: completion.output_tokens,
|
|
stop_reason: match completion.stop_reason {
|
|
StopReason::EndTurn => "end_turn",
|
|
StopReason::MaxTokens => "max_tokens",
|
|
StopReason::ToolUse => "tool_use",
|
|
StopReason::StopSequence => "stop",
|
|
StopReason::Error => "error",
|
|
}.to_string(),
|
|
});
|
|
})
|
|
}
|
|
}
|
|
|
|
// OpenAI API types
|
|
|
|
#[derive(Serialize)]
|
|
struct OpenAiRequest {
|
|
model: String,
|
|
messages: Vec<OpenAiMessage>,
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
max_tokens: Option<u32>,
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
temperature: Option<f32>,
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
stop: Option<Vec<String>>,
|
|
#[serde(default)]
|
|
stream: bool,
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
tools: Option<Vec<OpenAiTool>>,
|
|
}
|
|
|
|
/// One chat message in OpenAI wire format (request side).
#[derive(Serialize)]
struct OpenAiMessage {
    // "system" | "user" | "assistant" | "tool" — see build_api_request.
    role: String,
    // Text body; None for assistant messages that only carry tool_calls.
    #[serde(skip_serializing_if = "Option::is_none")]
    content: Option<String>,
    // Present only on assistant messages that invoke tools.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<OpenAiToolCall>>,
    // NOTE(review): OpenAI requires a `tool_call_id` field on role="tool"
    // messages; it is not modelled here — verify against the target API.
}
|
|
|
|
/// A tool invocation attached to an assistant message (request side).
#[derive(Serialize)]
struct OpenAiToolCall {
    // Call ID carried over from the originating `Message::ToolUse`.
    id: String,
    // Always "function" in the current OpenAI API.
    r#type: String,
    function: FunctionCall,
}
|
|
|
|
impl Default for OpenAiToolCall {
|
|
fn default() -> Self {
|
|
Self {
|
|
id: String::new(),
|
|
r#type: "function".to_string(),
|
|
function: FunctionCall {
|
|
name: String::new(),
|
|
arguments: String::new(),
|
|
},
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Function name plus JSON-encoded arguments (request side).
#[derive(Serialize)]
struct FunctionCall {
    name: String,
    // Arguments as a JSON string; build_api_request guarantees this is
    // always a valid JSON object (never null).
    arguments: String,
}
|
|
|
|
/// A tool definition advertised to the model in the request.
#[derive(Serialize)]
struct OpenAiTool {
    // Always "function" in the current OpenAI API.
    r#type: String,
    function: FunctionDef,
}
|
|
|
|
/// Schema description of a callable function (request side).
#[derive(Serialize)]
struct FunctionDef {
    name: String,
    description: String,
    // JSON Schema for the function's parameters (copied from the
    // provider-agnostic tool's `input_schema`).
    parameters: serde_json::Value,
}
|
|
|
|
/// Top-level non-streaming response body.
///
/// Every field defaults so partial or nonstandard provider responses still
/// deserialize instead of erroring out.
#[derive(Deserialize, Default)]
struct OpenAiResponse {
    #[serde(default)]
    choices: Vec<OpenAiChoice>,
    // Token accounting; absent on some OpenAI-compatible providers.
    #[serde(default)]
    usage: Option<OpenAiUsage>,
}
|
|
|
|
/// One completion choice; only the first is consumed by this driver.
#[derive(Deserialize, Default)]
struct OpenAiChoice {
    #[serde(default)]
    message: OpenAiResponseMessage,
    // "stop" | "length" | "tool_calls" | … — mapped to StopReason in
    // convert_response; anything unrecognized becomes EndTurn.
    #[serde(default)]
    finish_reason: Option<String>,
}
|
|
|
|
/// Assistant message payload of a non-streaming choice.
#[derive(Deserialize, Default)]
struct OpenAiResponseMessage {
    // May be empty/absent when the model responds with tool calls only.
    #[serde(default)]
    content: Option<String>,
    #[serde(default)]
    tool_calls: Option<Vec<OpenAiToolCallResponse>>,
}
|
|
|
|
/// A complete tool call in a non-streaming response.
#[derive(Deserialize, Default)]
struct OpenAiToolCallResponse {
    #[serde(default)]
    id: String,
    #[serde(default)]
    function: FunctionCallResponse,
}
|
|
|
|
/// Function name and argument payload of a response-side tool call.
#[derive(Deserialize, Default)]
struct FunctionCallResponse {
    #[serde(default)]
    name: String,
    // JSON-encoded argument object as a raw string; parsed (leniently) in
    // convert_response.
    #[serde(default)]
    arguments: String,
}
|
|
|
|
/// Token usage accounting reported by the API.
#[derive(Deserialize, Default)]
struct OpenAiUsage {
    #[serde(default)]
    prompt_tokens: u32,
    #[serde(default)]
    completion_tokens: u32,
}
|
|
|
|
// OpenAI Streaming types
|
|
|
|
// OpenAI Streaming types

/// One parsed SSE `data:` payload of a streaming response.
#[derive(Debug, Deserialize)]
struct OpenAiStreamResponse {
    #[serde(default)]
    choices: Vec<OpenAiStreamChoice>,
}
|
|
|
|
/// One choice inside a streaming chunk; only the first is consumed.
#[derive(Debug, Deserialize)]
struct OpenAiStreamChoice {
    // Incremental changes carried by this chunk.
    #[serde(default)]
    delta: OpenAiDelta,
    // Set on the final content chunk. Not currently read by the streaming
    // loop in this file, which keys off the "[DONE]" sentinel instead.
    #[serde(default)]
    finish_reason: Option<String>,
}
|
|
|
|
/// Incremental message content within one streaming chunk.
#[derive(Debug, Deserialize, Default)]
struct OpenAiDelta {
    // Text fragment, if this chunk carries text.
    #[serde(default)]
    content: Option<String>,
    // Tool-call fragments, if this chunk carries tool-call data.
    #[serde(default)]
    tool_calls: Option<Vec<OpenAiToolCallDelta>>,
}
|
|
|
|
/// Fragment of a tool call within a streaming delta.
#[derive(Debug, Deserialize)]
struct OpenAiToolCallDelta {
    // Present on the fragment that starts a tool call; subsequent argument
    // fragments may omit it (the stream loop falls back to the current id).
    #[serde(default)]
    id: Option<String>,
    #[serde(default)]
    function: Option<OpenAiFunctionDelta>,
}
|
|
|
|
/// Fragment of a function invocation within a streaming tool-call delta.
#[derive(Debug, Deserialize)]
struct OpenAiFunctionDelta {
    // Function name; usually only on the first fragment of a call.
    #[serde(default)]
    name: Option<String>,
    // Partial JSON of the argument object; concatenated across fragments.
    #[serde(default)]
    arguments: Option<String>,
}
|