diff --git a/crates/erp-ai/src/provider/openai.rs b/crates/erp-ai/src/provider/openai.rs index a188da3..6121922 100644 --- a/crates/erp-ai/src/provider/openai.rs +++ b/crates/erp-ai/src/provider/openai.rs @@ -34,13 +34,32 @@ struct ChatRequest { max_tokens: u32, temperature: f32, messages: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + tools: Option>, stream: bool, } +#[derive(Serialize)] +struct ChatToolDef { + r#type: String, + function: ChatFunctionDef, +} + +#[derive(Serialize)] +struct ChatFunctionDef { + name: String, + description: String, + parameters: serde_json::Value, +} + #[derive(Serialize)] struct ChatMessage { role: String, - content: String, + content: Option, + #[serde(skip_serializing_if = "Option::is_none")] + tool_calls: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + tool_call_id: Option, } #[derive(Deserialize)] @@ -54,9 +73,24 @@ struct ChatChoice { message: ChatMessageResp, } -#[derive(Deserialize)] +#[derive(Deserialize, Serialize)] struct ChatMessageResp { content: Option, + #[serde(default)] + tool_calls: Option>, +} + +#[derive(Debug, Deserialize, Serialize)] +struct ChatToolCallResp { + id: String, + r#type: String, + function: ChatFunctionResp, +} + +#[derive(Debug, Deserialize, Serialize)] +struct ChatFunctionResp { + name: String, + arguments: String, } #[derive(Deserialize)] @@ -99,13 +133,18 @@ impl AiProvider for OpenAIProvider { messages: vec![ ChatMessage { role: "system".into(), - content: req.system_prompt, + content: Some(req.system_prompt), + tool_calls: None, + tool_call_id: None, }, ChatMessage { role: "user".into(), - content: req.user_prompt, + content: Some(req.user_prompt), + tool_calls: None, + tool_call_id: None, }, ], + tools: None, stream: true, }; @@ -175,13 +214,18 @@ impl AiProvider for OpenAIProvider { messages: vec![ ChatMessage { role: "system".into(), - content: req.system_prompt, + content: Some(req.system_prompt), + tool_calls: None, + tool_call_id: None, }, ChatMessage { role: "user".into(), - content: req.user_prompt, + content: Some(req.user_prompt), + tool_calls: None, + tool_call_id: None, }, ], + tools: None, stream: false, }; @@ -245,6 +289,138 @@ impl AiProvider for OpenAIProvider { Err(_) => Ok(false), } } + + async fn generate_with_tools( + &self, + messages: Vec, + tools: Vec, + system_prompt: &str, + model: &str, + temperature: f32, + max_tokens: u32, + ) -> AiResult { + use crate::dto::ChatMessageRole; + + let model = if model == "auto" || model.is_empty() { + self.default_model.clone() + } else { + model.to_string() + }; + + let mut chat_messages = vec![ChatMessage { + role: "system".into(), + content: Some(system_prompt.to_string()), + tool_calls: None, + tool_call_id: None, + }]; + + for m in &messages { + let (role, content) = match m.role { + ChatMessageRole::User => ("user", Some(m.content.clone())), + ChatMessageRole::Assistant => ( + "assistant", + if m.content.is_empty() { + None + } else { + Some(m.content.clone()) + }, + ), + ChatMessageRole::Tool => ("tool", Some(m.content.clone())), + }; + + let tool_calls = m.tool_calls.as_ref().map(|tcs| { + tcs.iter() + .map(|tc| ChatToolCallResp { + id: tc.id.clone(), + r#type: "function".into(), + function: ChatFunctionResp { + name: tc.name.clone(), + arguments: tc.arguments.to_string(), + }, + }) + .collect::>() + }); + + chat_messages.push(ChatMessage { + role: role.into(), + content, + tool_calls, + tool_call_id: m.tool_call_id.clone(), + }); + } + + let chat_tools: Vec = tools + .into_iter() + .map(|t| ChatToolDef { + r#type: "function".into(), + function: ChatFunctionDef { + name: t.name, + description: t.description, + parameters: t.parameters, + }, + }) + .collect(); + + let req = ChatRequest { + model: model.clone(), + max_tokens, + temperature, + messages: chat_messages, + tools: Some(chat_tools), + stream: false, + }; + + let resp = self + .client + .post(format!("{}/v1/chat/completions", self.base_url)) + .header("Authorization", format!("Bearer {}", self.api_key)) + .header("content-type", "application/json") + .json(&req) + .send() + .await + .map_err(|e| AiError::ProviderError(e.to_string()))?; + + let status = resp.status(); + let body = resp + .text() + .await + .map_err(|e| AiError::ProviderError(e.to_string()))?; + + if !status.is_success() { + return Err(AiError::ProviderError(format!("OpenAI {status}: {body}"))); + } + + let parsed: ChatResponse = serde_json::from_str(&body) + .map_err(|e| AiError::ProviderError(format!("解析响应失败: {e}")))?; + + let msg = parsed + .choices + .first() + .map(|c| &c.message) + .ok_or_else(|| AiError::ProviderError("无响应选项".into()))?; + + let tool_calls = msg.tool_calls.as_ref().map(|tcs| { + tcs.iter() + .map(|tc| crate::dto::ToolCall { + id: tc.id.clone(), + name: tc.function.name.clone(), + arguments: serde_json::from_str(&tc.function.arguments) + .unwrap_or(serde_json::Value::Null), + }) + .collect::>() + }); + + let usage = parsed.usage.map(|u| crate::dto::TokenUsage { + input: u.prompt_tokens, + output: u.completion_tokens, + }); + + Ok(crate::dto::AgentGenerateResponse { + content: msg.content.clone(), + tool_calls, + usage, + }) + } } #[cfg(test)] @@ -271,13 +447,18 @@ mod tests { messages: vec![ ChatMessage { role: "system".into(), - content: "你是助手".into(), + content: Some("你是助手".into()), + tool_calls: None, + tool_call_id: None, }, ChatMessage { role: "user".into(), - content: "你好".into(), + content: Some("你好".into()), + tool_calls: None, + tool_call_id: None, }, ], + tools: None, stream: false, }; let json = serde_json::to_value(&req).unwrap();