//! Google Gemini driver implementation //! //! Implements the Gemini REST API v1beta with full support for: //! - Text generation (complete and streaming) //! - Tool / function calling //! - System instructions //! - Token usage reporting use async_trait::async_trait; use async_stream::stream; use futures::{Stream, StreamExt}; use secrecy::{ExposeSecret, SecretString}; use reqwest::Client; use serde::{Deserialize, Serialize}; use std::pin::Pin; use zclaw_types::{Result, ZclawError}; use super::{CompletionRequest, CompletionResponse, ContentBlock, LlmDriver, StopReason}; use crate::stream::StreamChunk; /// Google Gemini driver pub struct GeminiDriver { client: Client, api_key: SecretString, base_url: String, } impl GeminiDriver { pub fn new(api_key: SecretString) -> Self { Self { client: Client::builder() .user_agent(crate::USER_AGENT) .http1_only() .timeout(std::time::Duration::from_secs(120)) .connect_timeout(std::time::Duration::from_secs(30)) .build() .unwrap_or_else(|_| Client::new()), api_key, base_url: "https://generativelanguage.googleapis.com/v1beta".to_string(), } } pub fn with_base_url(api_key: SecretString, base_url: String) -> Self { Self { client: Client::builder() .user_agent(crate::USER_AGENT) .http1_only() .timeout(std::time::Duration::from_secs(120)) .connect_timeout(std::time::Duration::from_secs(30)) .build() .unwrap_or_else(|_| Client::new()), api_key, base_url, } } } #[async_trait] impl LlmDriver for GeminiDriver { fn provider(&self) -> &str { "gemini" } fn is_configured(&self) -> bool { !self.api_key.expose_secret().is_empty() } async fn complete(&self, request: CompletionRequest) -> Result { let api_request = self.build_api_request(&request); let url = format!( "{}/models/{}:generateContent", self.base_url, request.model, ); tracing::debug!(target: "gemini_driver", "Sending request to: {}", url); let response = self.client .post(&url) .header("content-type", "application/json") .header("x-goog-api-key", self.api_key.expose_secret()) .json(&api_request) .send() .await .map_err(|e| ZclawError::LlmError(format!("HTTP request failed: {}", e)))?; if !response.status().is_success() { let status = response.status(); let body = response.text().await.unwrap_or_default(); tracing::warn!(target: "gemini_driver", "API error {}: {}", status, body); return Err(ZclawError::LlmError(format!("API error {}: {}", status, body))); } let api_response: GeminiResponse = response .json() .await .map_err(|e| ZclawError::LlmError(format!("Failed to parse response: {}", e)))?; Ok(self.convert_response(api_response, request.model)) } fn stream( &self, request: CompletionRequest, ) -> Pin> + Send + '_>> { let api_request = self.build_api_request(&request); let url = format!( "{}/models/{}:streamGenerateContent?alt=sse", self.base_url, request.model, ); tracing::debug!(target: "gemini_driver", "Starting stream request to: {}", url); Box::pin(stream! { let response = match self.client .post(&url) .header("content-type", "application/json") .header("x-goog-api-key", self.api_key.expose_secret()) .timeout(std::time::Duration::from_secs(120)) .json(&api_request) .send() .await { Ok(r) => { tracing::debug!(target: "gemini_driver", "Stream response status: {}", r.status()); r }, Err(e) => { tracing::error!(target: "gemini_driver", "HTTP request failed: {:?}", e); yield Err(ZclawError::LlmError(format!("HTTP request failed: {}", e))); return; } }; if !response.status().is_success() { let status = response.status(); let body = response.text().await.unwrap_or_default(); yield Err(ZclawError::LlmError(format!("API error {}: {}", status, body))); return; } let mut byte_stream = response.bytes_stream(); let mut accumulated_tool_calls: std::collections::HashMap = std::collections::HashMap::new(); while let Some(chunk_result) = byte_stream.next().await { let chunk = match chunk_result { Ok(c) => c, Err(e) => { yield Err(ZclawError::LlmError(format!("Stream error: {}", e))); continue; } }; let text = String::from_utf8_lossy(&chunk); for line in text.lines() { if let Some(data) = line.strip_prefix("data: ") { match serde_json::from_str::(data) { Ok(resp) => { if let Some(candidate) = resp.candidates.first() { let content = match &candidate.content { Some(c) => c, None => continue, }; let parts = &content.parts; for (idx, part) in parts.iter().enumerate() { // Handle text content if let Some(text) = &part.text { if !text.is_empty() { yield Ok(StreamChunk::TextDelta { delta: text.clone() }); } } // Handle function call (tool use) if let Some(fc) = &part.function_call { let name = fc.name.clone().unwrap_or_default(); let args = fc.args.clone().unwrap_or(serde_json::Value::Object(Default::default())); // Emit ToolUseStart if this is a new tool call if !accumulated_tool_calls.contains_key(&idx) { accumulated_tool_calls.insert(idx, (name.clone(), String::new())); yield Ok(StreamChunk::ToolUseStart { id: format!("gemini_call_{}", idx), name, }); } // Emit the function arguments as delta let args_str = serde_json::to_string(&args).unwrap_or_default(); let call_id = format!("gemini_call_{}", idx); yield Ok(StreamChunk::ToolUseDelta { id: call_id.clone(), delta: args_str.clone(), }); // Accumulate if let Some(entry) = accumulated_tool_calls.get_mut(&idx) { entry.1 = args_str; } } } // When the candidate is finished, emit ToolUseEnd for all pending if let Some(ref finish_reason) = candidate.finish_reason { let is_final = finish_reason == "STOP" || finish_reason == "MAX_TOKENS"; if is_final { // Emit ToolUseEnd for all accumulated tool calls for (idx, (_name, args_str)) in &accumulated_tool_calls { let input: serde_json::Value = if args_str.is_empty() { serde_json::json!({}) } else { serde_json::from_str(args_str).unwrap_or_else(|e| { tracing::warn!(target: "gemini_driver", "Failed to parse tool args '{}': {}", args_str, e); serde_json::json!({}) }) }; yield Ok(StreamChunk::ToolUseEnd { id: format!("gemini_call_{}", idx), input, }); } // Extract usage metadata from the response let usage = resp.usage_metadata.as_ref(); let input_tokens = usage.map(|u| u.prompt_token_count.unwrap_or(0)).unwrap_or(0); let output_tokens = usage.map(|u| u.candidates_token_count.unwrap_or(0)).unwrap_or(0); let stop_reason = match finish_reason.as_str() { "STOP" => "end_turn", "MAX_TOKENS" => "max_tokens", "SAFETY" => "error", "RECITATION" => "error", _ => "end_turn", }; yield Ok(StreamChunk::Complete { input_tokens, output_tokens, stop_reason: stop_reason.to_string(), }); } } } } Err(e) => { tracing::warn!(target: "gemini_driver", "Failed to parse SSE event: {} - {}", e, data); } } } } } }) } } impl GeminiDriver { /// Convert a CompletionRequest into the Gemini API request format. /// /// Key mapping decisions: /// - `system` prompt maps to `systemInstruction` /// - Messages use Gemini's `contents` array with `role`/`parts` /// - Tool definitions use `functionDeclarations` /// - Tool results are sent as `functionResponse` parts in `user` messages fn build_api_request(&self, request: &CompletionRequest) -> GeminiRequest { if request.thinking_enabled { tracing::debug!("[GeminiDriver] thinking_enabled=true but Gemini does not support native thinking mode; ignoring"); } let mut contents: Vec = Vec::new(); for msg in &request.messages { match msg { zclaw_types::Message::User { content } => { contents.push(GeminiContent { role: "user".to_string(), parts: vec![GeminiPart { text: Some(content.clone()), inline_data: None, function_call: None, function_response: None, }], }); } zclaw_types::Message::Assistant { content, thinking } => { let mut parts = Vec::new(); // Gemini does not have a native "thinking" field, so we prepend // any thinking content as a text part with a marker. if let Some(think) = thinking { if !think.is_empty() { parts.push(GeminiPart { text: Some(format!("[thinking]\n{}\n[/thinking]", think)), inline_data: None, function_call: None, function_response: None, }); } } parts.push(GeminiPart { text: Some(content.clone()), inline_data: None, function_call: None, function_response: None, }); contents.push(GeminiContent { role: "model".to_string(), parts, }); } zclaw_types::Message::ToolUse { id: _, tool, input } => { // Tool use from the assistant is represented as a functionCall part let args = if input.is_null() { serde_json::json!({}) } else { input.clone() }; contents.push(GeminiContent { role: "model".to_string(), parts: vec![GeminiPart { text: None, inline_data: None, function_call: Some(GeminiFunctionCall { name: Some(tool.to_string()), args: Some(args), }), function_response: None, }], }); } zclaw_types::Message::ToolResult { tool_call_id, tool, output, is_error } => { // Tool results are sent as functionResponse parts in a "user" role message. // Gemini requires that function responses reference the function name // and include the response wrapped in a "result" or "error" key. let response_content = if *is_error { serde_json::json!({ "error": output.to_string() }) } else { serde_json::json!({ "result": output.clone() }) }; contents.push(GeminiContent { role: "user".to_string(), parts: vec![GeminiPart { text: None, inline_data: None, function_call: None, function_response: Some(GeminiFunctionResponse { name: tool.to_string(), response: response_content, }), }], }); // Gemini ignores tool_call_id, but we log it for debugging let _ = tool_call_id; } zclaw_types::Message::System { content } => { // System messages are converted to user messages with system context. // Note: the primary system prompt is handled via systemInstruction. // Inline system messages in conversation history become user messages. contents.push(GeminiContent { role: "user".to_string(), parts: vec![GeminiPart { text: Some(content.clone()), inline_data: None, function_call: None, function_response: None, }], }); } } } // Build tool declarations let function_declarations: Vec = request.tools .iter() .map(|t| GeminiFunctionDeclaration { name: t.name.clone(), description: t.description.clone(), parameters: t.input_schema.clone(), }) .collect(); // Build generation config let mut generation_config = GeminiGenerationConfig::default(); if let Some(temp) = request.temperature { generation_config.temperature = Some(temp); } if let Some(max) = request.max_tokens { generation_config.max_output_tokens = Some(max); } if !request.stop.is_empty() { generation_config.stop_sequences = Some(request.stop.clone()); } // Build system instruction let system_instruction = request.system.as_ref().map(|s| GeminiSystemInstruction { parts: vec![GeminiPart { text: Some(s.clone()), inline_data: None, function_call: None, function_response: None, }], }); GeminiRequest { contents, system_instruction, generation_config: Some(generation_config), tools: if function_declarations.is_empty() { None } else { Some(vec![GeminiTool { function_declarations, }]) }, } } /// Convert a Gemini API response into a CompletionResponse. fn convert_response(&self, api_response: GeminiResponse, model: String) -> CompletionResponse { let candidate = api_response.candidates.first(); let (content, stop_reason) = match candidate { Some(c) => { let parts = c.content.as_ref() .map(|content| content.parts.as_slice()) .unwrap_or(&[]); let mut blocks: Vec = Vec::new(); let mut has_tool_use = false; for part in parts { // Handle text content if let Some(text) = &part.text { // Skip thinking markers we injected if text.starts_with("[thinking]\n") && text.contains("[/thinking]") { let thinking_content = text .strip_prefix("[thinking]\n") .and_then(|s| s.strip_suffix("\n[/thinking]")) .unwrap_or(""); if !thinking_content.is_empty() { blocks.push(ContentBlock::Thinking { thinking: thinking_content.to_string(), }); } } else if !text.is_empty() { blocks.push(ContentBlock::Text { text: text.clone() }); } } // Handle function call (tool use) if let Some(fc) = &part.function_call { has_tool_use = true; blocks.push(ContentBlock::ToolUse { id: format!("gemini_call_{}", blocks.len()), name: fc.name.clone().unwrap_or_default(), input: fc.args.clone().unwrap_or(serde_json::Value::Object(Default::default())), }); } } // If there are no content blocks, add an empty text block if blocks.is_empty() { blocks.push(ContentBlock::Text { text: String::new() }); } let stop = match c.finish_reason.as_deref() { Some("STOP") => StopReason::EndTurn, Some("MAX_TOKENS") => StopReason::MaxTokens, Some("SAFETY") => StopReason::Error, Some("RECITATION") => StopReason::Error, Some("TOOL_USE") => StopReason::ToolUse, _ => { if has_tool_use { StopReason::ToolUse } else { StopReason::EndTurn } } }; (blocks, stop) } None => { tracing::warn!(target: "gemini_driver", "No candidates in response"); ( vec![ContentBlock::Text { text: String::new() }], StopReason::EndTurn, ) } }; let usage = api_response.usage_metadata.as_ref(); let input_tokens = usage.map(|u| u.prompt_token_count.unwrap_or(0)).unwrap_or(0); let output_tokens = usage.map(|u| u.candidates_token_count.unwrap_or(0)).unwrap_or(0); CompletionResponse { content, model, input_tokens, output_tokens, stop_reason, } } } // --------------------------------------------------------------------------- // Gemini API request types // --------------------------------------------------------------------------- #[derive(Serialize)] struct GeminiRequest { contents: Vec, #[serde(skip_serializing_if = "Option::is_none")] system_instruction: Option, #[serde(skip_serializing_if = "Option::is_none")] generation_config: Option, #[serde(skip_serializing_if = "Option::is_none")] tools: Option>, } #[derive(Serialize)] struct GeminiContent { role: String, parts: Vec, } #[derive(Serialize, Clone)] struct GeminiPart { #[serde(skip_serializing_if = "Option::is_none")] text: Option, #[serde(skip_serializing_if = "Option::is_none")] inline_data: Option, #[serde(rename = "functionCall", skip_serializing_if = "Option::is_none")] function_call: Option, #[serde(rename = "functionResponse", skip_serializing_if = "Option::is_none")] function_response: Option, } #[derive(Serialize)] struct GeminiSystemInstruction { parts: Vec, } #[derive(Serialize)] struct GeminiGenerationConfig { #[serde(skip_serializing_if = "Option::is_none")] temperature: Option, #[serde(skip_serializing_if = "Option::is_none")] max_output_tokens: Option, #[serde(rename = "stopSequences", skip_serializing_if = "Option::is_none")] stop_sequences: Option>, } impl Default for GeminiGenerationConfig { fn default() -> Self { Self { temperature: None, max_output_tokens: None, stop_sequences: None, } } } #[derive(Serialize)] struct GeminiTool { #[serde(rename = "functionDeclarations")] function_declarations: Vec, } #[derive(Serialize)] struct GeminiFunctionDeclaration { name: String, description: String, parameters: serde_json::Value, } #[derive(Serialize, Clone)] struct GeminiFunctionCall { #[serde(skip_serializing_if = "Option::is_none")] name: Option, #[serde(skip_serializing_if = "Option::is_none")] args: Option, } #[derive(Serialize, Clone)] struct GeminiFunctionResponse { name: String, response: serde_json::Value, } // --------------------------------------------------------------------------- // Gemini API response types // --------------------------------------------------------------------------- #[derive(Deserialize)] struct GeminiResponse { #[serde(default)] candidates: Vec, #[serde(default)] usage_metadata: Option, } #[derive(Debug, Deserialize)] struct GeminiCandidate { #[serde(default)] content: Option, #[serde(default)] finish_reason: Option, } #[derive(Debug, Deserialize)] struct GeminiResponseContent { #[serde(default)] parts: Vec, #[serde(default)] #[allow(dead_code)] // @reserved: deserialized from Gemini API, not accessed in code role: Option, } #[derive(Debug, Deserialize)] struct GeminiResponsePart { #[serde(default)] text: Option, #[serde(rename = "functionCall", default)] function_call: Option, } #[derive(Debug, Deserialize)] struct GeminiResponseFunctionCall { #[serde(default)] name: Option, #[serde(default)] args: Option, } #[derive(Debug, Deserialize)] struct GeminiUsageMetadata { #[serde(default)] prompt_token_count: Option, #[serde(default)] candidates_token_count: Option, #[serde(default)] #[allow(dead_code)] // @reserved: deserialized from Gemini API, not accessed in code total_token_count: Option, } // --------------------------------------------------------------------------- // Gemini streaming types // --------------------------------------------------------------------------- /// Streaming response from the Gemini SSE endpoint. /// Each SSE event contains the same structure as the non-streaming response, /// but with incremental content. #[derive(Debug, Deserialize)] struct GeminiStreamResponse { #[serde(default)] candidates: Vec, #[serde(default)] usage_metadata: Option, }