feat(ai): Claude Provider 实现 generate_with_tools — tool_use/tool_result 解析
This commit is contained in:
@@ -31,6 +31,13 @@ impl ClaudeProvider {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize)]
|
||||||
|
#[serde(untagged)]
|
||||||
|
enum ClaudeContent {
|
||||||
|
Text(String),
|
||||||
|
Blocks(Vec<serde_json::Value>),
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Serialize)]
|
#[derive(Serialize)]
|
||||||
struct ClaudeRequest {
|
struct ClaudeRequest {
|
||||||
model: String,
|
model: String,
|
||||||
@@ -38,13 +45,22 @@ struct ClaudeRequest {
|
|||||||
temperature: f32,
|
temperature: f32,
|
||||||
system: String,
|
system: String,
|
||||||
messages: Vec<ClaudeMessage>,
|
messages: Vec<ClaudeMessage>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
tools: Option<Vec<ClaudeToolDef>>,
|
||||||
stream: bool,
|
stream: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize)]
|
#[derive(Serialize)]
|
||||||
struct ClaudeMessage {
|
struct ClaudeMessage {
|
||||||
role: String,
|
role: String,
|
||||||
content: String,
|
content: ClaudeContent,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize)]
|
||||||
|
struct ClaudeToolDef {
|
||||||
|
name: String,
|
||||||
|
description: String,
|
||||||
|
input_schema: serde_json::Value,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
@@ -88,8 +104,9 @@ impl AiProvider for ClaudeProvider {
|
|||||||
system: req.system_prompt,
|
system: req.system_prompt,
|
||||||
messages: vec![ClaudeMessage {
|
messages: vec![ClaudeMessage {
|
||||||
role: "user".into(),
|
role: "user".into(),
|
||||||
content: req.user_prompt,
|
content: ClaudeContent::Text(req.user_prompt),
|
||||||
}],
|
}],
|
||||||
|
tools: None,
|
||||||
stream: true,
|
stream: true,
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -153,8 +170,9 @@ impl AiProvider for ClaudeProvider {
|
|||||||
system: req.system_prompt,
|
system: req.system_prompt,
|
||||||
messages: vec![ClaudeMessage {
|
messages: vec![ClaudeMessage {
|
||||||
role: "user".into(),
|
role: "user".into(),
|
||||||
content: req.user_prompt,
|
content: ClaudeContent::Text(req.user_prompt),
|
||||||
}],
|
}],
|
||||||
|
tools: None,
|
||||||
stream: false,
|
stream: false,
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -223,4 +241,138 @@ impl AiProvider for ClaudeProvider {
|
|||||||
Err(_) => Ok(false),
|
Err(_) => Ok(false),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn generate_with_tools(
|
||||||
|
&self,
|
||||||
|
messages: Vec<crate::dto::ChatMessage>,
|
||||||
|
tools: Vec<crate::dto::ToolDefinition>,
|
||||||
|
system_prompt: &str,
|
||||||
|
model: &str,
|
||||||
|
temperature: f32,
|
||||||
|
max_tokens: u32,
|
||||||
|
) -> AiResult<crate::dto::AgentGenerateResponse> {
|
||||||
|
use crate::dto::ChatMessageRole;
|
||||||
|
|
||||||
|
let claude_messages: Vec<ClaudeMessage> = messages
|
||||||
|
.iter()
|
||||||
|
.map(|m| {
|
||||||
|
let role = match m.role {
|
||||||
|
ChatMessageRole::User => "user",
|
||||||
|
ChatMessageRole::Assistant => "assistant",
|
||||||
|
ChatMessageRole::Tool => "user",
|
||||||
|
};
|
||||||
|
|
||||||
|
let content = match m.role {
|
||||||
|
ChatMessageRole::Tool => {
|
||||||
|
let blocks = vec![serde_json::json!({
|
||||||
|
"type": "tool_result",
|
||||||
|
"tool_use_id": m.tool_call_id.as_deref().unwrap_or(""),
|
||||||
|
"content": m.content,
|
||||||
|
})];
|
||||||
|
ClaudeContent::Blocks(blocks)
|
||||||
|
}
|
||||||
|
ChatMessageRole::Assistant if m.tool_calls.is_some() => {
|
||||||
|
let mut blocks = vec![];
|
||||||
|
if !m.content.is_empty() {
|
||||||
|
blocks.push(serde_json::json!({
|
||||||
|
"type": "text",
|
||||||
|
"text": m.content,
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
for tc in m.tool_calls.as_ref().unwrap() {
|
||||||
|
blocks.push(serde_json::json!({
|
||||||
|
"type": "tool_use",
|
||||||
|
"id": tc.id,
|
||||||
|
"name": tc.name,
|
||||||
|
"input": tc.arguments,
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
ClaudeContent::Blocks(blocks)
|
||||||
|
}
|
||||||
|
_ => ClaudeContent::Text(m.content.clone()),
|
||||||
|
};
|
||||||
|
|
||||||
|
ClaudeMessage {
|
||||||
|
role: role.to_string(),
|
||||||
|
content,
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let claude_tools: Vec<ClaudeToolDef> = tools
|
||||||
|
.into_iter()
|
||||||
|
.map(|t| ClaudeToolDef {
|
||||||
|
name: t.name,
|
||||||
|
description: t.description,
|
||||||
|
input_schema: t.parameters,
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let req = ClaudeRequest {
|
||||||
|
model: model.to_string(),
|
||||||
|
max_tokens,
|
||||||
|
temperature,
|
||||||
|
system: system_prompt.to_string(),
|
||||||
|
messages: claude_messages,
|
||||||
|
tools: Some(claude_tools),
|
||||||
|
stream: false,
|
||||||
|
};
|
||||||
|
|
||||||
|
let resp = self
|
||||||
|
.client
|
||||||
|
.post(format!("{}/v1/messages", self.base_url))
|
||||||
|
.header("x-api-key", &self.api_key)
|
||||||
|
.header("anthropic-version", "2023-06-01")
|
||||||
|
.header("content-type", "application/json")
|
||||||
|
.json(&req)
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.map_err(|e| AiError::ProviderError(e.to_string()))?;
|
||||||
|
|
||||||
|
let status = resp.status();
|
||||||
|
let body = resp
|
||||||
|
.text()
|
||||||
|
.await
|
||||||
|
.map_err(|e| AiError::ProviderError(e.to_string()))?;
|
||||||
|
|
||||||
|
if !status.is_success() {
|
||||||
|
return Err(AiError::ProviderError(format!("Claude {status}: {body}")));
|
||||||
|
}
|
||||||
|
|
||||||
|
let parsed: serde_json::Value = serde_json::from_str(&body)
|
||||||
|
.map_err(|e| AiError::ProviderError(format!("解析响应失败: {e}")))?;
|
||||||
|
|
||||||
|
let mut content_text = None;
|
||||||
|
let mut tool_calls = None;
|
||||||
|
|
||||||
|
if let Some(blocks) = parsed["content"].as_array() {
|
||||||
|
for block in blocks {
|
||||||
|
match block["type"].as_str() {
|
||||||
|
Some("text") => {
|
||||||
|
content_text = block["text"].as_str().map(|s| s.to_string());
|
||||||
|
}
|
||||||
|
Some("tool_use") => {
|
||||||
|
let tc = crate::dto::ToolCall {
|
||||||
|
id: block["id"].as_str().unwrap_or_default().to_string(),
|
||||||
|
name: block["name"].as_str().unwrap_or_default().to_string(),
|
||||||
|
arguments: block["input"].clone(),
|
||||||
|
};
|
||||||
|
tool_calls.get_or_insert_with(Vec::new).push(tc);
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let usage = parsed["usage"].as_object().map(|u| crate::dto::TokenUsage {
|
||||||
|
input: u["input_tokens"].as_u64().unwrap_or(0) as u32,
|
||||||
|
output: u["output_tokens"].as_u64().unwrap_or(0) as u32,
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(crate::dto::AgentGenerateResponse {
|
||||||
|
content: content_text,
|
||||||
|
tool_calls,
|
||||||
|
usage,
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user