feat(ai): Claude Provider 实现 generate_with_tools — tool_use/tool_result 解析

This commit is contained in:
iven
2026-05-18 02:32:39 +08:00
parent cad48a97d5
commit 64456d0172

View File

@@ -31,6 +31,13 @@ impl ClaudeProvider {
}
}
#[derive(Serialize)]
#[serde(untagged)]
enum ClaudeContent {
Text(String),
Blocks(Vec<serde_json::Value>),
}
#[derive(Serialize)]
struct ClaudeRequest {
model: String,
@@ -38,13 +45,22 @@ struct ClaudeRequest {
temperature: f32,
system: String,
messages: Vec<ClaudeMessage>,
#[serde(skip_serializing_if = "Option::is_none")]
tools: Option<Vec<ClaudeToolDef>>,
stream: bool,
}
#[derive(Serialize)]
struct ClaudeMessage {
role: String,
content: String,
content: ClaudeContent,
}
#[derive(Serialize)]
struct ClaudeToolDef {
name: String,
description: String,
input_schema: serde_json::Value,
}
#[derive(Deserialize)]
@@ -88,8 +104,9 @@ impl AiProvider for ClaudeProvider {
system: req.system_prompt,
messages: vec![ClaudeMessage {
role: "user".into(),
content: req.user_prompt,
content: ClaudeContent::Text(req.user_prompt),
}],
tools: None,
stream: true,
};
@@ -153,8 +170,9 @@ impl AiProvider for ClaudeProvider {
system: req.system_prompt,
messages: vec![ClaudeMessage {
role: "user".into(),
content: req.user_prompt,
content: ClaudeContent::Text(req.user_prompt),
}],
tools: None,
stream: false,
};
@@ -223,4 +241,138 @@ impl AiProvider for ClaudeProvider {
Err(_) => Ok(false),
}
}
async fn generate_with_tools(
&self,
messages: Vec<crate::dto::ChatMessage>,
tools: Vec<crate::dto::ToolDefinition>,
system_prompt: &str,
model: &str,
temperature: f32,
max_tokens: u32,
) -> AiResult<crate::dto::AgentGenerateResponse> {
use crate::dto::ChatMessageRole;
let claude_messages: Vec<ClaudeMessage> = messages
.iter()
.map(|m| {
let role = match m.role {
ChatMessageRole::User => "user",
ChatMessageRole::Assistant => "assistant",
ChatMessageRole::Tool => "user",
};
let content = match m.role {
ChatMessageRole::Tool => {
let blocks = vec![serde_json::json!({
"type": "tool_result",
"tool_use_id": m.tool_call_id.as_deref().unwrap_or(""),
"content": m.content,
})];
ClaudeContent::Blocks(blocks)
}
ChatMessageRole::Assistant if m.tool_calls.is_some() => {
let mut blocks = vec![];
if !m.content.is_empty() {
blocks.push(serde_json::json!({
"type": "text",
"text": m.content,
}));
}
for tc in m.tool_calls.as_ref().unwrap() {
blocks.push(serde_json::json!({
"type": "tool_use",
"id": tc.id,
"name": tc.name,
"input": tc.arguments,
}));
}
ClaudeContent::Blocks(blocks)
}
_ => ClaudeContent::Text(m.content.clone()),
};
ClaudeMessage {
role: role.to_string(),
content,
}
})
.collect();
let claude_tools: Vec<ClaudeToolDef> = tools
.into_iter()
.map(|t| ClaudeToolDef {
name: t.name,
description: t.description,
input_schema: t.parameters,
})
.collect();
let req = ClaudeRequest {
model: model.to_string(),
max_tokens,
temperature,
system: system_prompt.to_string(),
messages: claude_messages,
tools: Some(claude_tools),
stream: false,
};
let resp = self
.client
.post(format!("{}/v1/messages", self.base_url))
.header("x-api-key", &self.api_key)
.header("anthropic-version", "2023-06-01")
.header("content-type", "application/json")
.json(&req)
.send()
.await
.map_err(|e| AiError::ProviderError(e.to_string()))?;
let status = resp.status();
let body = resp
.text()
.await
.map_err(|e| AiError::ProviderError(e.to_string()))?;
if !status.is_success() {
return Err(AiError::ProviderError(format!("Claude {status}: {body}")));
}
let parsed: serde_json::Value = serde_json::from_str(&body)
.map_err(|e| AiError::ProviderError(format!("解析响应失败: {e}")))?;
let mut content_text = None;
let mut tool_calls = None;
if let Some(blocks) = parsed["content"].as_array() {
for block in blocks {
match block["type"].as_str() {
Some("text") => {
content_text = block["text"].as_str().map(|s| s.to_string());
}
Some("tool_use") => {
let tc = crate::dto::ToolCall {
id: block["id"].as_str().unwrap_or_default().to_string(),
name: block["name"].as_str().unwrap_or_default().to_string(),
arguments: block["input"].clone(),
};
tool_calls.get_or_insert_with(Vec::new).push(tc);
}
_ => {}
}
}
}
let usage = parsed["usage"].as_object().map(|u| crate::dto::TokenUsage {
input: u["input_tokens"].as_u64().unwrap_or(0) as u32,
output: u["output_tokens"].as_u64().unwrap_or(0) as u32,
});
Ok(crate::dto::AgentGenerateResponse {
content: content_text,
tool_calls,
usage,
})
}
}