diff --git a/crates/erp-ai/src/provider/mod.rs b/crates/erp-ai/src/provider/mod.rs index c5d7aeb..fd68e14 100644 --- a/crates/erp-ai/src/provider/mod.rs +++ b/crates/erp-ai/src/provider/mod.rs @@ -1,4 +1,5 @@ pub mod claude; +pub mod openai; pub mod registry; use async_trait::async_trait; diff --git a/crates/erp-ai/src/provider/openai.rs b/crates/erp-ai/src/provider/openai.rs new file mode 100644 index 0000000..7eb48de --- /dev/null +++ b/crates/erp-ai/src/provider/openai.rs @@ -0,0 +1,309 @@ +use async_stream::stream; +use async_trait::async_trait; +use futures::{Stream, StreamExt}; +use reqwest::Client; +use serde::{Deserialize, Serialize}; +use std::pin::Pin; + +use super::AiProvider; +use crate::dto::GenerateRequest; +use crate::error::{AiError, AiResult}; + +#[derive(Debug, Clone)] +pub struct OpenAIProvider { + client: Client, + api_key: String, + base_url: String, + default_model: String, +} + +impl OpenAIProvider { + pub fn new(api_key: String, base_url: String, default_model: String) -> Self { + Self { + client: Client::new(), + api_key, + base_url, + default_model, + } + } +} + +#[derive(Serialize)] +struct ChatRequest { + model: String, + max_tokens: u32, + temperature: f32, + messages: Vec, + stream: bool, +} + +#[derive(Serialize)] +struct ChatMessage { + role: String, + content: String, +} + +#[derive(Deserialize)] +struct ChatResponse { + choices: Vec, + usage: Option, +} + +#[derive(Deserialize)] +struct ChatChoice { + message: ChatMessageResp, +} + +#[derive(Deserialize)] +struct ChatMessageResp { + content: Option, +} + +#[derive(Deserialize)] +struct ChatUsage { + prompt_tokens: u32, + completion_tokens: u32, +} + +#[derive(Deserialize)] +struct StreamChunk { + choices: Vec, +} + +#[derive(Deserialize)] +struct StreamChoice { + delta: StreamDelta, +} + +#[derive(Deserialize)] +struct StreamDelta { + content: Option, +} + +#[async_trait] +impl AiProvider for OpenAIProvider { + async fn stream_generate( + &self, + req: GenerateRequest, + ) -> AiResult> + Send>>> { + let model = if req.model.is_empty() { + self.default_model.clone() + } else { + req.model + }; + + let chat_req = ChatRequest { + model, + max_tokens: req.max_tokens, + temperature: req.temperature, + messages: vec![ + ChatMessage { + role: "system".into(), + content: req.system_prompt, + }, + ChatMessage { + role: "user".into(), + content: req.user_prompt, + }, + ], + stream: true, + }; + + let response = self + .client + .post(format!("{}/v1/chat/completions", self.base_url)) + .header("Authorization", format!("Bearer {}", self.api_key)) + .header("content-type", "application/json") + .json(&chat_req) + .send() + .await + .map_err(|e| AiError::ProviderError(format!("OpenAI API 请求失败: {e}")))?; + + if !response.status().is_success() { + let status = response.status(); + let body = response.text().await.unwrap_or_default(); + return Err(AiError::ProviderError(format!( + "OpenAI API 错误 {status}: {body}" + ))); + } + + let s = Box::pin(stream! { + let mut stream = response.bytes_stream(); + while let Some(chunk_result) = stream.next().await { + let bytes = match chunk_result { + Ok(b) => b, + Err(e) => { + yield Err(AiError::ProviderError(format!("流读取错误: {e}"))); + break; + } + }; + + let text = String::from_utf8_lossy(&bytes); + for line in text.lines() { + if let Some(data) = line.strip_prefix("data: ") { + if data == "[DONE]" { + return; + } + if let Ok(chunk) = serde_json::from_str::(data) { + for choice in chunk.choices { + if let Some(content) = choice.delta.content { + yield Ok(content); + } + } + } + } + } + } + }); + + Ok(s) + } + + async fn generate(&self, req: GenerateRequest) -> AiResult { + let start = std::time::Instant::now(); + + let model = if req.model.is_empty() { + self.default_model.clone() + } else { + req.model.clone() + }; + + let chat_req = ChatRequest { + model: model.clone(), + max_tokens: req.max_tokens, + temperature: req.temperature, + messages: vec![ + ChatMessage { + role: "system".into(), + content: req.system_prompt, + }, + ChatMessage { + role: "user".into(), + content: req.user_prompt, + }, + ], + stream: false, + }; + + let resp = self + .client + .post(format!("{}/v1/chat/completions", self.base_url)) + .header("Authorization", format!("Bearer {}", self.api_key)) + .header("content-type", "application/json") + .json(&chat_req) + .send() + .await + .map_err(|e| AiError::ProviderError(e.to_string()))?; + + let status = resp.status(); + let body = resp + .text() + .await + .map_err(|e| AiError::ProviderError(e.to_string()))?; + + if !status.is_success() { + return Err(AiError::ProviderError(format!( + "OpenAI {status}: {body}" + ))); + } + + let parsed: ChatResponse = serde_json::from_str(&body) + .map_err(|e| AiError::ProviderError(format!("解析响应失败: {e}")))?; + + let content = parsed + .choices + .first() + .and_then(|c| c.message.content.clone()) + .unwrap_or_default(); + + let (input_tokens, output_tokens) = parsed + .usage + .map(|u| (u.prompt_tokens, u.completion_tokens)) + .unwrap_or((0, 0)); + + Ok(crate::dto::GenerateResponse { + content, + model, + input_tokens, + output_tokens, + duration_ms: start.elapsed().as_millis() as u64, + }) + } + + fn name(&self) -> &str { + "openai" + } + + async fn health_check(&self) -> AiResult { + let resp = self + .client + .get(format!("{}/v1/models", self.base_url)) + .header("Authorization", format!("Bearer {}", self.api_key)) + .send() + .await; + + match resp { + Ok(r) => Ok(r.status().is_success()), + Err(_) => Ok(false), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn openai_provider_construction() { + let provider = OpenAIProvider::new( + "sk-test".into(), + "https://api.openai.com".into(), + "gpt-4o".into(), + ); + assert_eq!(provider.name(), "openai"); + assert_eq!(provider.default_model, "gpt-4o"); + } + + #[test] + fn chat_request_serialization() { + let req = ChatRequest { + model: "gpt-4o".into(), + max_tokens: 1024, + temperature: 0.7, + messages: vec![ + ChatMessage { + role: "system".into(), + content: "你是助手".into(), + }, + ChatMessage { + role: "user".into(), + content: "你好".into(), + }, + ], + stream: false, + }; + let json = serde_json::to_value(&req).unwrap(); + assert_eq!(json["model"], "gpt-4o"); + assert_eq!(json["messages"].as_array().unwrap().len(), 2); + } + + #[test] + fn chat_response_deserialization() { + let json = r#"{ + "choices": [{"message": {"role": "assistant", "content": "你好!"}}], + "usage": {"prompt_tokens": 10, "completion_tokens": 5} + }"#; + let resp: ChatResponse = serde_json::from_str(json).unwrap(); + assert_eq!(resp.choices[0].message.content.as_deref(), Some("你好!")); + assert_eq!(resp.usage.unwrap().prompt_tokens, 10); + } + + #[test] + fn stream_chunk_deserialization() { + let json = r#"{ + "choices": [{"delta": {"content": "Hello"}, "index": 0}] + }"#; + let chunk: StreamChunk = serde_json::from_str(json).unwrap(); + assert_eq!(chunk.choices[0].delta.content.as_deref(), Some("Hello")); + } +}