From 64456d017204c555fbaf0b05bd6ac317d9b5e629 Mon Sep 17 00:00:00 2001 From: iven Date: Mon, 18 May 2026 02:32:39 +0800 Subject: [PATCH] =?UTF-8?q?feat(ai):=20Claude=20Provider=20=E5=AE=9E?= =?UTF-8?q?=E7=8E=B0=20generate=5Fwith=5Ftools=20=E2=80=94=20tool=5Fuse/to?= =?UTF-8?q?ol=5Fresult=20=E8=A7=A3=E6=9E=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- crates/erp-ai/src/provider/claude.rs | 158 ++++++++++++++++++++++++++- 1 file changed, 155 insertions(+), 3 deletions(-) diff --git a/crates/erp-ai/src/provider/claude.rs b/crates/erp-ai/src/provider/claude.rs index 88d3821..5321bd2 100644 --- a/crates/erp-ai/src/provider/claude.rs +++ b/crates/erp-ai/src/provider/claude.rs @@ -31,6 +31,13 @@ impl ClaudeProvider { } } +#[derive(Serialize)] +#[serde(untagged)] +enum ClaudeContent { + Text(String), + Blocks(Vec), +} + #[derive(Serialize)] struct ClaudeRequest { model: String, @@ -38,13 +45,22 @@ struct ClaudeRequest { temperature: f32, system: String, messages: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + tools: Option>, stream: bool, } #[derive(Serialize)] struct ClaudeMessage { role: String, - content: String, + content: ClaudeContent, +} + +#[derive(Serialize)] +struct ClaudeToolDef { + name: String, + description: String, + input_schema: serde_json::Value, } #[derive(Deserialize)] @@ -88,8 +104,9 @@ impl AiProvider for ClaudeProvider { system: req.system_prompt, messages: vec![ClaudeMessage { role: "user".into(), - content: req.user_prompt, + content: ClaudeContent::Text(req.user_prompt), }], + tools: None, stream: true, }; @@ -153,8 +170,9 @@ impl AiProvider for ClaudeProvider { system: req.system_prompt, messages: vec![ClaudeMessage { role: "user".into(), - content: req.user_prompt, + content: ClaudeContent::Text(req.user_prompt), }], + tools: None, stream: false, }; @@ -223,4 +241,138 @@ impl AiProvider for ClaudeProvider { Err(_) => Ok(false), } } + + async fn generate_with_tools( + &self, + messages: Vec, + tools: Vec, + system_prompt: &str, + model: &str, + temperature: f32, + max_tokens: u32, + ) -> AiResult { + use crate::dto::ChatMessageRole; + + let claude_messages: Vec = messages + .iter() + .map(|m| { + let role = match m.role { + ChatMessageRole::User => "user", + ChatMessageRole::Assistant => "assistant", + ChatMessageRole::Tool => "user", + }; + + let content = match m.role { + ChatMessageRole::Tool => { + let blocks = vec![serde_json::json!({ + "type": "tool_result", + "tool_use_id": m.tool_call_id.as_deref().unwrap_or(""), + "content": m.content, + })]; + ClaudeContent::Blocks(blocks) + } + ChatMessageRole::Assistant if m.tool_calls.is_some() => { + let mut blocks = vec![]; + if !m.content.is_empty() { + blocks.push(serde_json::json!({ + "type": "text", + "text": m.content, + })); + } + for tc in m.tool_calls.as_ref().unwrap() { + blocks.push(serde_json::json!({ + "type": "tool_use", + "id": tc.id, + "name": tc.name, + "input": tc.arguments, + })); + } + ClaudeContent::Blocks(blocks) + } + _ => ClaudeContent::Text(m.content.clone()), + }; + + ClaudeMessage { + role: role.to_string(), + content, + } + }) + .collect(); + + let claude_tools: Vec = tools + .into_iter() + .map(|t| ClaudeToolDef { + name: t.name, + description: t.description, + input_schema: t.parameters, + }) + .collect(); + + let req = ClaudeRequest { + model: model.to_string(), + max_tokens, + temperature, + system: system_prompt.to_string(), + messages: claude_messages, + tools: Some(claude_tools), + stream: false, + }; + + let resp = self + .client + .post(format!("{}/v1/messages", self.base_url)) + .header("x-api-key", &self.api_key) + .header("anthropic-version", "2023-06-01") + .header("content-type", "application/json") + .json(&req) + .send() + .await + .map_err(|e| AiError::ProviderError(e.to_string()))?; + + let status = resp.status(); + let body = resp + .text() + .await + .map_err(|e| AiError::ProviderError(e.to_string()))?; + + if !status.is_success() { + return Err(AiError::ProviderError(format!("Claude {status}: {body}"))); + } + + let parsed: serde_json::Value = serde_json::from_str(&body) + .map_err(|e| AiError::ProviderError(format!("解析响应失败: {e}")))?; + + let mut content_text = None; + let mut tool_calls = None; + + if let Some(blocks) = parsed["content"].as_array() { + for block in blocks { + match block["type"].as_str() { + Some("text") => { + content_text = block["text"].as_str().map(|s| s.to_string()); + } + Some("tool_use") => { + let tc = crate::dto::ToolCall { + id: block["id"].as_str().unwrap_or_default().to_string(), + name: block["name"].as_str().unwrap_or_default().to_string(), + arguments: block["input"].clone(), + }; + tool_calls.get_or_insert_with(Vec::new).push(tc); + } + _ => {} + } + } + } + + let usage = parsed["usage"].as_object().map(|u| crate::dto::TokenUsage { + input: u["input_tokens"].as_u64().unwrap_or(0) as u32, + output: u["output_tokens"].as_u64().unwrap_or(0) as u32, + }); + + Ok(crate::dto::AgentGenerateResponse { + content: content_text, + tool_calls, + usage, + }) + } }