feat(ai): OpenAI Provider 实现 generate_with_tools — function calling 支持
This commit is contained in:
@@ -34,13 +34,32 @@ struct ChatRequest {
|
||||
max_tokens: u32,
|
||||
temperature: f32,
|
||||
messages: Vec<ChatMessage>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
tools: Option<Vec<ChatToolDef>>,
|
||||
stream: bool,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct ChatToolDef {
|
||||
r#type: String,
|
||||
function: ChatFunctionDef,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct ChatFunctionDef {
|
||||
name: String,
|
||||
description: String,
|
||||
parameters: serde_json::Value,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct ChatMessage {
|
||||
role: String,
|
||||
content: String,
|
||||
content: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
tool_calls: Option<Vec<ChatToolCallResp>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
tool_call_id: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
@@ -54,9 +73,24 @@ struct ChatChoice {
|
||||
message: ChatMessageResp,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[derive(Deserialize, Serialize)]
|
||||
struct ChatMessageResp {
|
||||
content: Option<String>,
|
||||
#[serde(default)]
|
||||
tool_calls: Option<Vec<ChatToolCallResp>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
struct ChatToolCallResp {
|
||||
id: String,
|
||||
r#type: String,
|
||||
function: ChatFunctionResp,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
struct ChatFunctionResp {
|
||||
name: String,
|
||||
arguments: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
@@ -99,13 +133,18 @@ impl AiProvider for OpenAIProvider {
|
||||
messages: vec![
|
||||
ChatMessage {
|
||||
role: "system".into(),
|
||||
content: req.system_prompt,
|
||||
content: Some(req.system_prompt),
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
ChatMessage {
|
||||
role: "user".into(),
|
||||
content: req.user_prompt,
|
||||
content: Some(req.user_prompt),
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
],
|
||||
tools: None,
|
||||
stream: true,
|
||||
};
|
||||
|
||||
@@ -175,13 +214,18 @@ impl AiProvider for OpenAIProvider {
|
||||
messages: vec![
|
||||
ChatMessage {
|
||||
role: "system".into(),
|
||||
content: req.system_prompt,
|
||||
content: Some(req.system_prompt),
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
ChatMessage {
|
||||
role: "user".into(),
|
||||
content: req.user_prompt,
|
||||
content: Some(req.user_prompt),
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
],
|
||||
tools: None,
|
||||
stream: false,
|
||||
};
|
||||
|
||||
@@ -245,6 +289,138 @@ impl AiProvider for OpenAIProvider {
|
||||
Err(_) => Ok(false),
|
||||
}
|
||||
}
|
||||
|
||||
async fn generate_with_tools(
|
||||
&self,
|
||||
messages: Vec<crate::dto::ChatMessage>,
|
||||
tools: Vec<crate::dto::ToolDefinition>,
|
||||
system_prompt: &str,
|
||||
model: &str,
|
||||
temperature: f32,
|
||||
max_tokens: u32,
|
||||
) -> AiResult<crate::dto::AgentGenerateResponse> {
|
||||
use crate::dto::ChatMessageRole;
|
||||
|
||||
let model = if model == "auto" || model.is_empty() {
|
||||
self.default_model.clone()
|
||||
} else {
|
||||
model.to_string()
|
||||
};
|
||||
|
||||
let mut chat_messages = vec![ChatMessage {
|
||||
role: "system".into(),
|
||||
content: Some(system_prompt.to_string()),
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
}];
|
||||
|
||||
for m in &messages {
|
||||
let (role, content) = match m.role {
|
||||
ChatMessageRole::User => ("user", Some(m.content.clone())),
|
||||
ChatMessageRole::Assistant => (
|
||||
"assistant",
|
||||
if m.content.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(m.content.clone())
|
||||
},
|
||||
),
|
||||
ChatMessageRole::Tool => ("tool", Some(m.content.clone())),
|
||||
};
|
||||
|
||||
let tool_calls = m.tool_calls.as_ref().map(|tcs| {
|
||||
tcs.iter()
|
||||
.map(|tc| ChatToolCallResp {
|
||||
id: tc.id.clone(),
|
||||
r#type: "function".into(),
|
||||
function: ChatFunctionResp {
|
||||
name: tc.name.clone(),
|
||||
arguments: tc.arguments.to_string(),
|
||||
},
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
});
|
||||
|
||||
chat_messages.push(ChatMessage {
|
||||
role: role.into(),
|
||||
content,
|
||||
tool_calls,
|
||||
tool_call_id: m.tool_call_id.clone(),
|
||||
});
|
||||
}
|
||||
|
||||
let chat_tools: Vec<ChatToolDef> = tools
|
||||
.into_iter()
|
||||
.map(|t| ChatToolDef {
|
||||
r#type: "function".into(),
|
||||
function: ChatFunctionDef {
|
||||
name: t.name,
|
||||
description: t.description,
|
||||
parameters: t.parameters,
|
||||
},
|
||||
})
|
||||
.collect();
|
||||
|
||||
let req = ChatRequest {
|
||||
model: model.clone(),
|
||||
max_tokens,
|
||||
temperature,
|
||||
messages: chat_messages,
|
||||
tools: Some(chat_tools),
|
||||
stream: false,
|
||||
};
|
||||
|
||||
let resp = self
|
||||
.client
|
||||
.post(format!("{}/v1/chat/completions", self.base_url))
|
||||
.header("Authorization", format!("Bearer {}", self.api_key))
|
||||
.header("content-type", "application/json")
|
||||
.json(&req)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| AiError::ProviderError(e.to_string()))?;
|
||||
|
||||
let status = resp.status();
|
||||
let body = resp
|
||||
.text()
|
||||
.await
|
||||
.map_err(|e| AiError::ProviderError(e.to_string()))?;
|
||||
|
||||
if !status.is_success() {
|
||||
return Err(AiError::ProviderError(format!("OpenAI {status}: {body}")));
|
||||
}
|
||||
|
||||
let parsed: ChatResponse = serde_json::from_str(&body)
|
||||
.map_err(|e| AiError::ProviderError(format!("解析响应失败: {e}")))?;
|
||||
|
||||
let msg = parsed
|
||||
.choices
|
||||
.first()
|
||||
.map(|c| &c.message)
|
||||
.ok_or_else(|| AiError::ProviderError("无响应选项".into()))?;
|
||||
|
||||
let tool_calls = msg.tool_calls.as_ref().map(|tcs| {
|
||||
tcs.iter()
|
||||
.map(|tc| crate::dto::ToolCall {
|
||||
id: tc.id.clone(),
|
||||
name: tc.function.name.clone(),
|
||||
arguments: serde_json::from_str(&tc.function.arguments)
|
||||
.unwrap_or(serde_json::Value::Null),
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
});
|
||||
|
||||
let usage = parsed.usage.map(|u| crate::dto::TokenUsage {
|
||||
input: u.prompt_tokens,
|
||||
output: u.completion_tokens,
|
||||
});
|
||||
|
||||
Ok(crate::dto::AgentGenerateResponse {
|
||||
content: msg.content.clone(),
|
||||
tool_calls,
|
||||
usage,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -271,13 +447,18 @@ mod tests {
|
||||
messages: vec![
|
||||
ChatMessage {
|
||||
role: "system".into(),
|
||||
content: "你是助手".into(),
|
||||
content: Some("你是助手".into()),
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
ChatMessage {
|
||||
role: "user".into(),
|
||||
content: "你好".into(),
|
||||
content: Some("你好".into()),
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
],
|
||||
tools: None,
|
||||
stream: false,
|
||||
};
|
||||
let json = serde_json::to_value(&req).unwrap();
|
||||
|
||||
Reference in New Issue
Block a user