use async_stream::stream; use async_trait::async_trait; use futures::{Stream, StreamExt}; use reqwest::Client; use serde::{Deserialize, Serialize}; use std::pin::Pin; use super::AiProvider; use crate::dto::GenerateRequest; use crate::error::{AiError, AiResult}; #[derive(Debug, Clone)] pub struct OpenAIProvider { client: Client, api_key: String, base_url: String, default_model: String, } impl OpenAIProvider { pub fn new(api_key: String, base_url: String, default_model: String) -> Self { Self { client: Client::new(), api_key, base_url, default_model, } } } #[derive(Serialize)] struct ChatRequest { model: String, max_tokens: u32, temperature: f32, messages: Vec, #[serde(skip_serializing_if = "Option::is_none")] tools: Option>, stream: bool, } #[derive(Serialize)] struct ChatToolDef { r#type: String, function: ChatFunctionDef, } #[derive(Serialize)] struct ChatFunctionDef { name: String, description: String, parameters: serde_json::Value, } #[derive(Serialize)] struct ChatMessage { role: String, content: Option, #[serde(skip_serializing_if = "Option::is_none")] tool_calls: Option>, #[serde(skip_serializing_if = "Option::is_none")] tool_call_id: Option, } #[derive(Deserialize)] struct ChatResponse { choices: Vec, usage: Option, } #[derive(Deserialize)] struct ChatChoice { message: ChatMessageResp, } #[derive(Deserialize, Serialize)] struct ChatMessageResp { content: Option, #[serde(default)] tool_calls: Option>, } #[derive(Debug, Deserialize, Serialize)] struct ChatToolCallResp { id: String, r#type: String, function: ChatFunctionResp, } #[derive(Debug, Deserialize, Serialize)] struct ChatFunctionResp { name: String, arguments: String, } #[derive(Deserialize)] struct ChatUsage { prompt_tokens: u32, completion_tokens: u32, } #[derive(Deserialize)] struct StreamChunk { choices: Vec, } #[derive(Deserialize)] struct StreamChoice { delta: StreamDelta, } #[derive(Deserialize)] struct StreamDelta { content: Option, } #[async_trait] impl AiProvider for OpenAIProvider { async fn stream_generate( &self, req: GenerateRequest, ) -> AiResult> + Send>>> { let model = if req.model.is_empty() { self.default_model.clone() } else { req.model }; let chat_req = ChatRequest { model, max_tokens: req.max_tokens, temperature: req.temperature, messages: vec![ ChatMessage { role: "system".into(), content: Some(req.system_prompt), tool_calls: None, tool_call_id: None, }, ChatMessage { role: "user".into(), content: Some(req.user_prompt), tool_calls: None, tool_call_id: None, }, ], tools: None, stream: true, }; let response = self .client .post(format!("{}/v1/chat/completions", self.base_url)) .header("Authorization", format!("Bearer {}", self.api_key)) .header("content-type", "application/json") .json(&chat_req) .send() .await .map_err(|e| AiError::ProviderError(format!("OpenAI API 请求失败: {e}")))?; if !response.status().is_success() { let status = response.status(); let body = response.text().await.unwrap_or_default(); return Err(AiError::ProviderError(format!( "OpenAI API 错误 {status}: {body}" ))); } let s = Box::pin(stream! { let mut stream = response.bytes_stream(); while let Some(chunk_result) = stream.next().await { let bytes = match chunk_result { Ok(b) => b, Err(e) => { yield Err(AiError::ProviderError(format!("流读取错误: {e}"))); break; } }; let text = String::from_utf8_lossy(&bytes); for line in text.lines() { if let Some(data) = line.strip_prefix("data: ") { if data == "[DONE]" { return; } if let Ok(chunk) = serde_json::from_str::(data) { for choice in chunk.choices { if let Some(content) = choice.delta.content { yield Ok(content); } } } } } } }); Ok(s) } async fn generate(&self, req: GenerateRequest) -> AiResult { let start = std::time::Instant::now(); let model = if req.model.is_empty() { self.default_model.clone() } else { req.model.clone() }; let chat_req = ChatRequest { model: model.clone(), max_tokens: req.max_tokens, temperature: req.temperature, messages: vec![ ChatMessage { role: "system".into(), content: Some(req.system_prompt), tool_calls: None, tool_call_id: None, }, ChatMessage { role: "user".into(), content: Some(req.user_prompt), tool_calls: None, tool_call_id: None, }, ], tools: None, stream: false, }; let resp = self .client .post(format!("{}/v1/chat/completions", self.base_url)) .header("Authorization", format!("Bearer {}", self.api_key)) .header("content-type", "application/json") .json(&chat_req) .send() .await .map_err(|e| AiError::ProviderError(e.to_string()))?; let status = resp.status(); let body = resp .text() .await .map_err(|e| AiError::ProviderError(e.to_string()))?; if !status.is_success() { return Err(AiError::ProviderError(format!("OpenAI {status}: {body}"))); } let parsed: ChatResponse = serde_json::from_str(&body) .map_err(|e| AiError::ProviderError(format!("解析响应失败: {e}")))?; let content = parsed .choices .first() .and_then(|c| c.message.content.clone()) .unwrap_or_default(); let (input_tokens, output_tokens) = parsed .usage .map(|u| (u.prompt_tokens, u.completion_tokens)) .unwrap_or((0, 0)); Ok(crate::dto::GenerateResponse { content, model, input_tokens, output_tokens, duration_ms: start.elapsed().as_millis() as u64, }) } fn name(&self) -> &str { "openai" } async fn health_check(&self) -> AiResult { let resp = self .client .get(format!("{}/v1/models", self.base_url)) .header("Authorization", format!("Bearer {}", self.api_key)) .send() .await; match resp { Ok(r) => Ok(r.status().is_success()), Err(_) => Ok(false), } } async fn generate_with_tools( &self, messages: Vec, tools: Vec, system_prompt: &str, model: &str, temperature: f32, max_tokens: u32, ) -> AiResult { use crate::dto::ChatMessageRole; let model = if model == "auto" || model.is_empty() { self.default_model.clone() } else { model.to_string() }; let mut chat_messages = vec![ChatMessage { role: "system".into(), content: Some(system_prompt.to_string()), tool_calls: None, tool_call_id: None, }]; for m in &messages { let (role, content) = match m.role { ChatMessageRole::User => ("user", Some(m.content.clone())), ChatMessageRole::Assistant => ( "assistant", if m.content.is_empty() { None } else { Some(m.content.clone()) }, ), ChatMessageRole::Tool => ("tool", Some(m.content.clone())), }; let tool_calls = m.tool_calls.as_ref().map(|tcs| { tcs.iter() .map(|tc| ChatToolCallResp { id: tc.id.clone(), r#type: "function".into(), function: ChatFunctionResp { name: tc.name.clone(), arguments: tc.arguments.to_string(), }, }) .collect::>() }); chat_messages.push(ChatMessage { role: role.into(), content, tool_calls, tool_call_id: m.tool_call_id.clone(), }); } let chat_tools: Vec = tools .into_iter() .map(|t| ChatToolDef { r#type: "function".into(), function: ChatFunctionDef { name: t.name, description: t.description, parameters: t.parameters, }, }) .collect(); let req = ChatRequest { model: model.clone(), max_tokens, temperature, messages: chat_messages, tools: Some(chat_tools), stream: false, }; let resp = self .client .post(format!("{}/v1/chat/completions", self.base_url)) .header("Authorization", format!("Bearer {}", self.api_key)) .header("content-type", "application/json") .json(&req) .send() .await .map_err(|e| AiError::ProviderError(e.to_string()))?; let status = resp.status(); let body = resp .text() .await .map_err(|e| AiError::ProviderError(e.to_string()))?; if !status.is_success() { return Err(AiError::ProviderError(format!("OpenAI {status}: {body}"))); } let parsed: ChatResponse = serde_json::from_str(&body) .map_err(|e| AiError::ProviderError(format!("解析响应失败: {e}")))?; let msg = parsed .choices .first() .map(|c| &c.message) .ok_or_else(|| AiError::ProviderError("无响应选项".into()))?; let tool_calls = msg.tool_calls.as_ref().map(|tcs| { tcs.iter() .map(|tc| crate::dto::ToolCall { id: tc.id.clone(), name: tc.function.name.clone(), arguments: serde_json::from_str(&tc.function.arguments) .unwrap_or(serde_json::Value::Null), }) .collect::>() }); let usage = parsed.usage.map(|u| crate::dto::TokenUsage { input: u.prompt_tokens, output: u.completion_tokens, }); Ok(crate::dto::AgentGenerateResponse { content: msg.content.clone(), tool_calls, usage, }) } } #[cfg(test)] mod tests { use super::*; #[test] fn openai_provider_construction() { let provider = OpenAIProvider::new( "sk-test".into(), "https://api.openai.com".into(), "gpt-4o".into(), ); assert_eq!(provider.name(), "openai"); assert_eq!(provider.default_model, "gpt-4o"); } #[test] fn chat_request_serialization() { let req = ChatRequest { model: "gpt-4o".into(), max_tokens: 1024, temperature: 0.7, messages: vec![ ChatMessage { role: "system".into(), content: Some("你是助手".into()), tool_calls: None, tool_call_id: None, }, ChatMessage { role: "user".into(), content: Some("你好".into()), tool_calls: None, tool_call_id: None, }, ], tools: None, stream: false, }; let json = serde_json::to_value(&req).unwrap(); assert_eq!(json["model"], "gpt-4o"); assert_eq!(json["messages"].as_array().unwrap().len(), 2); } #[test] fn chat_response_deserialization() { let json = r#"{ "choices": [{"message": {"role": "assistant", "content": "你好!"}}], "usage": {"prompt_tokens": 10, "completion_tokens": 5} }"#; let resp: ChatResponse = serde_json::from_str(json).unwrap(); assert_eq!(resp.choices[0].message.content.as_deref(), Some("你好!")); assert_eq!(resp.usage.unwrap().prompt_tokens, 10); } #[test] fn stream_chunk_deserialization() { let json = r#"{ "choices": [{"delta": {"content": "Hello"}, "index": 0}] }"#; let chunk: StreamChunk = serde_json::from_str(json).unwrap(); assert_eq!(chunk.choices[0].delta.content.as_deref(), Some("Hello")); } }