feat(ai): 实现 OpenAIProvider 兼容 OpenAI API 格式

支持 /v1/chat/completions 端点的流式/非流式生成 + 健康检查
含序列化/反序列化单元测试
This commit is contained in:
iven
2026-05-05 15:08:41 +08:00
parent 74b1d44068
commit b728618d61
2 changed files with 310 additions and 0 deletions

View File

@@ -1,4 +1,5 @@
pub mod claude;
pub mod openai;
pub mod registry;
use async_trait::async_trait;

View File

@@ -0,0 +1,309 @@
use async_stream::stream;
use async_trait::async_trait;
use futures::{Stream, StreamExt};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::pin::Pin;
use super::AiProvider;
use crate::dto::GenerateRequest;
use crate::error::{AiError, AiResult};
#[derive(Debug, Clone)]
pub struct OpenAIProvider {
client: Client,
api_key: String,
base_url: String,
default_model: String,
}
impl OpenAIProvider {
pub fn new(api_key: String, base_url: String, default_model: String) -> Self {
Self {
client: Client::new(),
api_key,
base_url,
default_model,
}
}
}
#[derive(Serialize)]
struct ChatRequest {
model: String,
max_tokens: u32,
temperature: f32,
messages: Vec<ChatMessage>,
stream: bool,
}
#[derive(Serialize)]
struct ChatMessage {
role: String,
content: String,
}
#[derive(Deserialize)]
struct ChatResponse {
choices: Vec<ChatChoice>,
usage: Option<ChatUsage>,
}
#[derive(Deserialize)]
struct ChatChoice {
message: ChatMessageResp,
}
#[derive(Deserialize)]
struct ChatMessageResp {
content: Option<String>,
}
#[derive(Deserialize)]
struct ChatUsage {
prompt_tokens: u32,
completion_tokens: u32,
}
#[derive(Deserialize)]
struct StreamChunk {
choices: Vec<StreamChoice>,
}
#[derive(Deserialize)]
struct StreamChoice {
delta: StreamDelta,
}
#[derive(Deserialize)]
struct StreamDelta {
content: Option<String>,
}
#[async_trait]
impl AiProvider for OpenAIProvider {
async fn stream_generate(
&self,
req: GenerateRequest,
) -> AiResult<Pin<Box<dyn Stream<Item = AiResult<String>> + Send>>> {
let model = if req.model.is_empty() {
self.default_model.clone()
} else {
req.model
};
let chat_req = ChatRequest {
model,
max_tokens: req.max_tokens,
temperature: req.temperature,
messages: vec![
ChatMessage {
role: "system".into(),
content: req.system_prompt,
},
ChatMessage {
role: "user".into(),
content: req.user_prompt,
},
],
stream: true,
};
let response = self
.client
.post(format!("{}/v1/chat/completions", self.base_url))
.header("Authorization", format!("Bearer {}", self.api_key))
.header("content-type", "application/json")
.json(&chat_req)
.send()
.await
.map_err(|e| AiError::ProviderError(format!("OpenAI API 请求失败: {e}")))?;
if !response.status().is_success() {
let status = response.status();
let body = response.text().await.unwrap_or_default();
return Err(AiError::ProviderError(format!(
"OpenAI API 错误 {status}: {body}"
)));
}
let s = Box::pin(stream! {
let mut stream = response.bytes_stream();
while let Some(chunk_result) = stream.next().await {
let bytes = match chunk_result {
Ok(b) => b,
Err(e) => {
yield Err(AiError::ProviderError(format!("流读取错误: {e}")));
break;
}
};
let text = String::from_utf8_lossy(&bytes);
for line in text.lines() {
if let Some(data) = line.strip_prefix("data: ") {
if data == "[DONE]" {
return;
}
if let Ok(chunk) = serde_json::from_str::<StreamChunk>(data) {
for choice in chunk.choices {
if let Some(content) = choice.delta.content {
yield Ok(content);
}
}
}
}
}
}
});
Ok(s)
}
async fn generate(&self, req: GenerateRequest) -> AiResult<crate::dto::GenerateResponse> {
let start = std::time::Instant::now();
let model = if req.model.is_empty() {
self.default_model.clone()
} else {
req.model.clone()
};
let chat_req = ChatRequest {
model: model.clone(),
max_tokens: req.max_tokens,
temperature: req.temperature,
messages: vec![
ChatMessage {
role: "system".into(),
content: req.system_prompt,
},
ChatMessage {
role: "user".into(),
content: req.user_prompt,
},
],
stream: false,
};
let resp = self
.client
.post(format!("{}/v1/chat/completions", self.base_url))
.header("Authorization", format!("Bearer {}", self.api_key))
.header("content-type", "application/json")
.json(&chat_req)
.send()
.await
.map_err(|e| AiError::ProviderError(e.to_string()))?;
let status = resp.status();
let body = resp
.text()
.await
.map_err(|e| AiError::ProviderError(e.to_string()))?;
if !status.is_success() {
return Err(AiError::ProviderError(format!(
"OpenAI {status}: {body}"
)));
}
let parsed: ChatResponse = serde_json::from_str(&body)
.map_err(|e| AiError::ProviderError(format!("解析响应失败: {e}")))?;
let content = parsed
.choices
.first()
.and_then(|c| c.message.content.clone())
.unwrap_or_default();
let (input_tokens, output_tokens) = parsed
.usage
.map(|u| (u.prompt_tokens, u.completion_tokens))
.unwrap_or((0, 0));
Ok(crate::dto::GenerateResponse {
content,
model,
input_tokens,
output_tokens,
duration_ms: start.elapsed().as_millis() as u64,
})
}
fn name(&self) -> &str {
"openai"
}
async fn health_check(&self) -> AiResult<bool> {
let resp = self
.client
.get(format!("{}/v1/models", self.base_url))
.header("Authorization", format!("Bearer {}", self.api_key))
.send()
.await;
match resp {
Ok(r) => Ok(r.status().is_success()),
Err(_) => Ok(false),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn openai_provider_construction() {
let provider = OpenAIProvider::new(
"sk-test".into(),
"https://api.openai.com".into(),
"gpt-4o".into(),
);
assert_eq!(provider.name(), "openai");
assert_eq!(provider.default_model, "gpt-4o");
}
#[test]
fn chat_request_serialization() {
let req = ChatRequest {
model: "gpt-4o".into(),
max_tokens: 1024,
temperature: 0.7,
messages: vec![
ChatMessage {
role: "system".into(),
content: "你是助手".into(),
},
ChatMessage {
role: "user".into(),
content: "你好".into(),
},
],
stream: false,
};
let json = serde_json::to_value(&req).unwrap();
assert_eq!(json["model"], "gpt-4o");
assert_eq!(json["messages"].as_array().unwrap().len(), 2);
}
#[test]
fn chat_response_deserialization() {
let json = r#"{
"choices": [{"message": {"role": "assistant", "content": "你好!"}}],
"usage": {"prompt_tokens": 10, "completion_tokens": 5}
}"#;
let resp: ChatResponse = serde_json::from_str(json).unwrap();
assert_eq!(resp.choices[0].message.content.as_deref(), Some("你好!"));
assert_eq!(resp.usage.unwrap().prompt_tokens, 10);
}
#[test]
fn stream_chunk_deserialization() {
let json = r#"{
"choices": [{"delta": {"content": "Hello"}, "index": 0}]
}"#;
let chunk: StreamChunk = serde_json::from_str(json).unwrap();
assert_eq!(chunk.choices[0].delta.content.as_deref(), Some("Hello"));
}
}