489 lines
14 KiB
Rust
489 lines
14 KiB
Rust
use async_stream::stream;
|
|
use async_trait::async_trait;
|
|
use futures::{Stream, StreamExt};
|
|
use reqwest::Client;
|
|
use serde::{Deserialize, Serialize};
|
|
use std::pin::Pin;
|
|
|
|
use super::AiProvider;
|
|
use crate::dto::GenerateRequest;
|
|
use crate::error::{AiError, AiResult};
|
|
|
|
#[derive(Debug, Clone)]
|
|
pub struct OpenAIProvider {
|
|
client: Client,
|
|
api_key: String,
|
|
base_url: String,
|
|
default_model: String,
|
|
}
|
|
|
|
impl OpenAIProvider {
|
|
pub fn new(api_key: String, base_url: String, default_model: String) -> Self {
|
|
Self {
|
|
client: Client::new(),
|
|
api_key,
|
|
base_url,
|
|
default_model,
|
|
}
|
|
}
|
|
}
|
|
|
|
#[derive(Serialize)]
|
|
struct ChatRequest {
|
|
model: String,
|
|
max_tokens: u32,
|
|
temperature: f32,
|
|
messages: Vec<ChatMessage>,
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
tools: Option<Vec<ChatToolDef>>,
|
|
stream: bool,
|
|
}
|
|
|
|
#[derive(Serialize)]
|
|
struct ChatToolDef {
|
|
r#type: String,
|
|
function: ChatFunctionDef,
|
|
}
|
|
|
|
#[derive(Serialize)]
|
|
struct ChatFunctionDef {
|
|
name: String,
|
|
description: String,
|
|
parameters: serde_json::Value,
|
|
}
|
|
|
|
#[derive(Serialize)]
|
|
struct ChatMessage {
|
|
role: String,
|
|
content: Option<String>,
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
tool_calls: Option<Vec<ChatToolCallResp>>,
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
tool_call_id: Option<String>,
|
|
}
|
|
|
|
#[derive(Deserialize)]
|
|
struct ChatResponse {
|
|
choices: Vec<ChatChoice>,
|
|
usage: Option<ChatUsage>,
|
|
}
|
|
|
|
#[derive(Deserialize)]
|
|
struct ChatChoice {
|
|
message: ChatMessageResp,
|
|
}
|
|
|
|
#[derive(Deserialize, Serialize)]
|
|
struct ChatMessageResp {
|
|
content: Option<String>,
|
|
#[serde(default)]
|
|
tool_calls: Option<Vec<ChatToolCallResp>>,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize, Serialize)]
|
|
struct ChatToolCallResp {
|
|
id: String,
|
|
r#type: String,
|
|
function: ChatFunctionResp,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize, Serialize)]
|
|
struct ChatFunctionResp {
|
|
name: String,
|
|
arguments: String,
|
|
}
|
|
|
|
#[derive(Deserialize)]
|
|
struct ChatUsage {
|
|
prompt_tokens: u32,
|
|
completion_tokens: u32,
|
|
}
|
|
|
|
#[derive(Deserialize)]
|
|
struct StreamChunk {
|
|
choices: Vec<StreamChoice>,
|
|
}
|
|
|
|
#[derive(Deserialize)]
|
|
struct StreamChoice {
|
|
delta: StreamDelta,
|
|
}
|
|
|
|
#[derive(Deserialize)]
|
|
struct StreamDelta {
|
|
content: Option<String>,
|
|
}
|
|
|
|
#[async_trait]
|
|
impl AiProvider for OpenAIProvider {
|
|
async fn stream_generate(
|
|
&self,
|
|
req: GenerateRequest,
|
|
) -> AiResult<Pin<Box<dyn Stream<Item = AiResult<String>> + Send>>> {
|
|
let model = if req.model.is_empty() {
|
|
self.default_model.clone()
|
|
} else {
|
|
req.model
|
|
};
|
|
|
|
let chat_req = ChatRequest {
|
|
model,
|
|
max_tokens: req.max_tokens,
|
|
temperature: req.temperature,
|
|
messages: vec![
|
|
ChatMessage {
|
|
role: "system".into(),
|
|
content: Some(req.system_prompt),
|
|
tool_calls: None,
|
|
tool_call_id: None,
|
|
},
|
|
ChatMessage {
|
|
role: "user".into(),
|
|
content: Some(req.user_prompt),
|
|
tool_calls: None,
|
|
tool_call_id: None,
|
|
},
|
|
],
|
|
tools: None,
|
|
stream: true,
|
|
};
|
|
|
|
let response = self
|
|
.client
|
|
.post(format!("{}/v1/chat/completions", self.base_url))
|
|
.header("Authorization", format!("Bearer {}", self.api_key))
|
|
.header("content-type", "application/json")
|
|
.json(&chat_req)
|
|
.send()
|
|
.await
|
|
.map_err(|e| AiError::ProviderError(format!("OpenAI API 请求失败: {e}")))?;
|
|
|
|
if !response.status().is_success() {
|
|
let status = response.status();
|
|
let body = response.text().await.unwrap_or_default();
|
|
return Err(AiError::ProviderError(format!(
|
|
"OpenAI API 错误 {status}: {body}"
|
|
)));
|
|
}
|
|
|
|
let s = Box::pin(stream! {
|
|
let mut stream = response.bytes_stream();
|
|
while let Some(chunk_result) = stream.next().await {
|
|
let bytes = match chunk_result {
|
|
Ok(b) => b,
|
|
Err(e) => {
|
|
yield Err(AiError::ProviderError(format!("流读取错误: {e}")));
|
|
break;
|
|
}
|
|
};
|
|
|
|
let text = String::from_utf8_lossy(&bytes);
|
|
for line in text.lines() {
|
|
if let Some(data) = line.strip_prefix("data: ") {
|
|
if data == "[DONE]" {
|
|
return;
|
|
}
|
|
if let Ok(chunk) = serde_json::from_str::<StreamChunk>(data) {
|
|
for choice in chunk.choices {
|
|
if let Some(content) = choice.delta.content {
|
|
yield Ok(content);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
});
|
|
|
|
Ok(s)
|
|
}
|
|
|
|
async fn generate(&self, req: GenerateRequest) -> AiResult<crate::dto::GenerateResponse> {
|
|
let start = std::time::Instant::now();
|
|
|
|
let model = if req.model.is_empty() {
|
|
self.default_model.clone()
|
|
} else {
|
|
req.model.clone()
|
|
};
|
|
|
|
let chat_req = ChatRequest {
|
|
model: model.clone(),
|
|
max_tokens: req.max_tokens,
|
|
temperature: req.temperature,
|
|
messages: vec![
|
|
ChatMessage {
|
|
role: "system".into(),
|
|
content: Some(req.system_prompt),
|
|
tool_calls: None,
|
|
tool_call_id: None,
|
|
},
|
|
ChatMessage {
|
|
role: "user".into(),
|
|
content: Some(req.user_prompt),
|
|
tool_calls: None,
|
|
tool_call_id: None,
|
|
},
|
|
],
|
|
tools: None,
|
|
stream: false,
|
|
};
|
|
|
|
let resp = self
|
|
.client
|
|
.post(format!("{}/v1/chat/completions", self.base_url))
|
|
.header("Authorization", format!("Bearer {}", self.api_key))
|
|
.header("content-type", "application/json")
|
|
.json(&chat_req)
|
|
.send()
|
|
.await
|
|
.map_err(|e| AiError::ProviderError(e.to_string()))?;
|
|
|
|
let status = resp.status();
|
|
let body = resp
|
|
.text()
|
|
.await
|
|
.map_err(|e| AiError::ProviderError(e.to_string()))?;
|
|
|
|
if !status.is_success() {
|
|
return Err(AiError::ProviderError(format!("OpenAI {status}: {body}")));
|
|
}
|
|
|
|
let parsed: ChatResponse = serde_json::from_str(&body)
|
|
.map_err(|e| AiError::ProviderError(format!("解析响应失败: {e}")))?;
|
|
|
|
let content = parsed
|
|
.choices
|
|
.first()
|
|
.and_then(|c| c.message.content.clone())
|
|
.unwrap_or_default();
|
|
|
|
let (input_tokens, output_tokens) = parsed
|
|
.usage
|
|
.map(|u| (u.prompt_tokens, u.completion_tokens))
|
|
.unwrap_or((0, 0));
|
|
|
|
Ok(crate::dto::GenerateResponse {
|
|
content,
|
|
model,
|
|
input_tokens,
|
|
output_tokens,
|
|
duration_ms: start.elapsed().as_millis() as u64,
|
|
})
|
|
}
|
|
|
|
fn name(&self) -> &str {
|
|
"openai"
|
|
}
|
|
|
|
/// Probes the API by requesting `{base_url}/v1/models` with the bearer token.
///
/// Returns `Ok(true)` only when the request completes with a successful HTTP
/// status. Transport failures are reported as `Ok(false)`; this method never
/// returns `Err`.
async fn health_check(&self) -> AiResult<bool> {
    let models_url = format!("{}/v1/models", self.base_url);
    let bearer = format!("Bearer {}", self.api_key);

    let outcome = self
        .client
        .get(models_url)
        .header("Authorization", bearer)
        .send()
        .await;

    let healthy = outcome
        .map(|response| response.status().is_success())
        .unwrap_or(false);
    Ok(healthy)
}
|
|
|
|
async fn generate_with_tools(
|
|
&self,
|
|
messages: Vec<crate::dto::ChatMessage>,
|
|
tools: Vec<crate::dto::ToolDefinition>,
|
|
system_prompt: &str,
|
|
model: &str,
|
|
temperature: f32,
|
|
max_tokens: u32,
|
|
) -> AiResult<crate::dto::AgentGenerateResponse> {
|
|
use crate::dto::ChatMessageRole;
|
|
|
|
let model = if model == "auto" || model.is_empty() {
|
|
self.default_model.clone()
|
|
} else {
|
|
model.to_string()
|
|
};
|
|
|
|
let mut chat_messages = vec![ChatMessage {
|
|
role: "system".into(),
|
|
content: Some(system_prompt.to_string()),
|
|
tool_calls: None,
|
|
tool_call_id: None,
|
|
}];
|
|
|
|
for m in &messages {
|
|
let (role, content) = match m.role {
|
|
ChatMessageRole::User => ("user", Some(m.content.clone())),
|
|
ChatMessageRole::Assistant => (
|
|
"assistant",
|
|
if m.content.is_empty() {
|
|
None
|
|
} else {
|
|
Some(m.content.clone())
|
|
},
|
|
),
|
|
ChatMessageRole::Tool => ("tool", Some(m.content.clone())),
|
|
};
|
|
|
|
let tool_calls = m.tool_calls.as_ref().map(|tcs| {
|
|
tcs.iter()
|
|
.map(|tc| ChatToolCallResp {
|
|
id: tc.id.clone(),
|
|
r#type: "function".into(),
|
|
function: ChatFunctionResp {
|
|
name: tc.name.clone(),
|
|
arguments: tc.arguments.to_string(),
|
|
},
|
|
})
|
|
.collect::<Vec<_>>()
|
|
});
|
|
|
|
chat_messages.push(ChatMessage {
|
|
role: role.into(),
|
|
content,
|
|
tool_calls,
|
|
tool_call_id: m.tool_call_id.clone(),
|
|
});
|
|
}
|
|
|
|
let chat_tools: Vec<ChatToolDef> = tools
|
|
.into_iter()
|
|
.map(|t| ChatToolDef {
|
|
r#type: "function".into(),
|
|
function: ChatFunctionDef {
|
|
name: t.name,
|
|
description: t.description,
|
|
parameters: t.parameters,
|
|
},
|
|
})
|
|
.collect();
|
|
|
|
let req = ChatRequest {
|
|
model: model.clone(),
|
|
max_tokens,
|
|
temperature,
|
|
messages: chat_messages,
|
|
tools: Some(chat_tools),
|
|
stream: false,
|
|
};
|
|
|
|
let resp = self
|
|
.client
|
|
.post(format!("{}/v1/chat/completions", self.base_url))
|
|
.header("Authorization", format!("Bearer {}", self.api_key))
|
|
.header("content-type", "application/json")
|
|
.json(&req)
|
|
.send()
|
|
.await
|
|
.map_err(|e| AiError::ProviderError(e.to_string()))?;
|
|
|
|
let status = resp.status();
|
|
let body = resp
|
|
.text()
|
|
.await
|
|
.map_err(|e| AiError::ProviderError(e.to_string()))?;
|
|
|
|
if !status.is_success() {
|
|
return Err(AiError::ProviderError(format!("OpenAI {status}: {body}")));
|
|
}
|
|
|
|
let parsed: ChatResponse = serde_json::from_str(&body)
|
|
.map_err(|e| AiError::ProviderError(format!("解析响应失败: {e}")))?;
|
|
|
|
let msg = parsed
|
|
.choices
|
|
.first()
|
|
.map(|c| &c.message)
|
|
.ok_or_else(|| AiError::ProviderError("无响应选项".into()))?;
|
|
|
|
let tool_calls = msg.tool_calls.as_ref().map(|tcs| {
|
|
tcs.iter()
|
|
.map(|tc| crate::dto::ToolCall {
|
|
id: tc.id.clone(),
|
|
name: tc.function.name.clone(),
|
|
arguments: serde_json::from_str(&tc.function.arguments)
|
|
.unwrap_or(serde_json::Value::Null),
|
|
})
|
|
.collect::<Vec<_>>()
|
|
});
|
|
|
|
let usage = parsed.usage.map(|u| crate::dto::TokenUsage {
|
|
input: u.prompt_tokens,
|
|
output: u.completion_tokens,
|
|
});
|
|
|
|
Ok(crate::dto::AgentGenerateResponse {
|
|
content: msg.content.clone(),
|
|
tool_calls,
|
|
usage,
|
|
})
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a message that carries only a role and text, with no tool data.
    fn text_message(role: &str, text: &str) -> ChatMessage {
        ChatMessage {
            role: role.into(),
            content: Some(text.into()),
            tool_calls: None,
            tool_call_id: None,
        }
    }

    /// The constructor stores the default model and the provider reports its name.
    #[test]
    fn openai_provider_construction() {
        let api_key = String::from("sk-test");
        let base_url = String::from("https://api.openai.com");
        let default_model = String::from("gpt-4o");

        let provider = OpenAIProvider::new(api_key, base_url, default_model);

        assert_eq!(provider.name(), "openai");
        assert_eq!(provider.default_model, "gpt-4o");
    }

    /// A request serializes its model name and both of its messages.
    #[test]
    fn chat_request_serialization() {
        let req = ChatRequest {
            model: "gpt-4o".into(),
            max_tokens: 1024,
            temperature: 0.7,
            messages: vec![
                text_message("system", "你是助手"),
                text_message("user", "你好"),
            ],
            tools: None,
            stream: false,
        };

        let encoded = serde_json::to_value(&req).unwrap();
        let encoded_messages = encoded["messages"].as_array().unwrap();

        assert_eq!(encoded["model"], "gpt-4o");
        assert_eq!(encoded_messages.len(), 2);
    }

    /// A response body yields the first choice's text and the prompt token count.
    #[test]
    fn chat_response_deserialization() {
        let json = r#"{
            "choices": [{"message": {"role": "assistant", "content": "你好!"}}],
            "usage": {"prompt_tokens": 10, "completion_tokens": 5}
        }"#;

        let ChatResponse { choices, usage } = serde_json::from_str(json).unwrap();

        assert_eq!(choices[0].message.content.as_deref(), Some("你好!"));
        assert_eq!(usage.unwrap().prompt_tokens, 10);
    }

    /// A stream chunk exposes the delta text and ignores unknown keys such as `index`.
    #[test]
    fn stream_chunk_deserialization() {
        let json = r#"{
            "choices": [{"delta": {"content": "Hello"}, "index": 0}]
        }"#;

        let chunk: StreamChunk = serde_json::from_str(json).unwrap();
        let first_delta = &chunk.choices[0].delta;

        assert_eq!(first_delta.content.as_deref(), Some("Hello"));
    }
}
|