Files
zclaw_openfang/crates/zclaw-runtime/src/driver/gemini.rs
iven d9b0b4f4f7
Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled
fix(audit): Batch 7-9 dead_code 标注 + TODO 清理 + 文档同步
Batch 7: dead_code 标注统一 (16 处)
- crates/ 9 处: growth, kernel, pipeline, runtime, saas, skills
- src-tauri/ 7 处: classroom, intelligence, browser, mcp
- 统一格式: #[allow(dead_code)] // @reserved: <原因>

Batch 7+: EvolutionEngine L2/L3 10 个未使用 pub 函数
- 全部标注 @reserved: EvolutionEngine L2/L3, post-release integration

Batch 9: TODO → FUTURE 标记 (4 处)
- html.rs: template-based export
- nl_schedule.rs: LLM-assisted parsing
- knowledge/handlers.rs: category_id from upload
- personality_detector.rs: VikingStorage persistence

Batch 5+: Cargo.lock 更新 (serde_yaml_bw 迁移)

全量测试通过: 719 passed, 0 failed
2026-04-19 08:54:57 +08:00

664 lines
26 KiB
Rust

//! Google Gemini driver implementation
//!
//! Implements the Gemini REST API v1beta with full support for:
//! - Text generation (complete and streaming)
//! - Tool / function calling
//! - System instructions
//! - Token usage reporting
use async_trait::async_trait;
use async_stream::stream;
use futures::{Stream, StreamExt};
use secrecy::{ExposeSecret, SecretString};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::pin::Pin;
use zclaw_types::{Result, ZclawError};
use super::{CompletionRequest, CompletionResponse, ContentBlock, LlmDriver, StopReason};
use crate::stream::StreamChunk;
/// Google Gemini driver
///
/// Holds the HTTP client, API key, and base URL used by all requests.
/// The key is wrapped in a `SecretString` so it is only exposed at the
/// point of use (`expose_secret`) and never via Debug formatting.
pub struct GeminiDriver {
    // Shared reqwest client (connection pooling across requests).
    client: Client,
    // Sent via the `x-goog-api-key` header on every request.
    api_key: SecretString,
    // API root, e.g. "https://generativelanguage.googleapis.com/v1beta";
    // overridable via `with_base_url` for proxies or test servers.
    base_url: String,
}
impl GeminiDriver {
pub fn new(api_key: SecretString) -> Self {
Self {
client: Client::builder()
.user_agent(crate::USER_AGENT)
.http1_only()
.timeout(std::time::Duration::from_secs(120))
.connect_timeout(std::time::Duration::from_secs(30))
.build()
.unwrap_or_else(|_| Client::new()),
api_key,
base_url: "https://generativelanguage.googleapis.com/v1beta".to_string(),
}
}
pub fn with_base_url(api_key: SecretString, base_url: String) -> Self {
Self {
client: Client::builder()
.user_agent(crate::USER_AGENT)
.http1_only()
.timeout(std::time::Duration::from_secs(120))
.connect_timeout(std::time::Duration::from_secs(30))
.build()
.unwrap_or_else(|_| Client::new()),
api_key,
base_url,
}
}
}
#[async_trait]
impl LlmDriver for GeminiDriver {
    /// Provider identifier ("gemini").
    fn provider(&self) -> &str {
        "gemini"
    }

    /// Configured iff a non-empty API key was supplied.
    fn is_configured(&self) -> bool {
        !self.api_key.expose_secret().is_empty()
    }

    /// One-shot completion against `models/{model}:generateContent`.
    ///
    /// Errors surface as `ZclawError::LlmError` for transport failures,
    /// non-success HTTP statuses, and unparseable response bodies.
    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
        let api_request = self.build_api_request(&request);
        let url = format!(
            "{}/models/{}:generateContent",
            self.base_url,
            request.model,
        );
        tracing::debug!(target: "gemini_driver", "Sending request to: {}", url);
        // API key travels in the x-goog-api-key header rather than the URL,
        // keeping it out of access logs.
        let response = self.client
            .post(&url)
            .header("content-type", "application/json")
            .header("x-goog-api-key", self.api_key.expose_secret())
            .json(&api_request)
            .send()
            .await
            .map_err(|e| ZclawError::LlmError(format!("HTTP request failed: {}", e)))?;
        if !response.status().is_success() {
            let status = response.status();
            // Best-effort body capture for diagnostics; empty on read failure.
            let body = response.text().await.unwrap_or_default();
            tracing::warn!(target: "gemini_driver", "API error {}: {}", status, body);
            return Err(ZclawError::LlmError(format!("API error {}: {}", status, body)));
        }
        let api_response: GeminiResponse = response
            .json()
            .await
            .map_err(|e| ZclawError::LlmError(format!("Failed to parse response: {}", e)))?;
        Ok(self.convert_response(api_response, request.model))
    }

    /// Streaming completion against `models/{model}:streamGenerateContent?alt=sse`.
    ///
    /// Emits TextDelta / ToolUseStart / ToolUseDelta chunks as SSE events
    /// arrive, then ToolUseEnd for each accumulated tool call plus a final
    /// Complete chunk when the candidate reports STOP or MAX_TOKENS.
    fn stream(
        &self,
        request: CompletionRequest,
    ) -> Pin<Box<dyn Stream<Item = Result<StreamChunk>> + Send + '_>> {
        let api_request = self.build_api_request(&request);
        let url = format!(
            "{}/models/{}:streamGenerateContent?alt=sse",
            self.base_url,
            request.model,
        );
        tracing::debug!(target: "gemini_driver", "Starting stream request to: {}", url);
        Box::pin(stream! {
            let response = match self.client
                .post(&url)
                .header("content-type", "application/json")
                .header("x-goog-api-key", self.api_key.expose_secret())
                .timeout(std::time::Duration::from_secs(120))
                .json(&api_request)
                .send()
                .await
            {
                Ok(r) => {
                    tracing::debug!(target: "gemini_driver", "Stream response status: {}", r.status());
                    r
                },
                Err(e) => {
                    tracing::error!(target: "gemini_driver", "HTTP request failed: {:?}", e);
                    yield Err(ZclawError::LlmError(format!("HTTP request failed: {}", e)));
                    return;
                }
            };
            if !response.status().is_success() {
                let status = response.status();
                let body = response.text().await.unwrap_or_default();
                yield Err(ZclawError::LlmError(format!("API error {}: {}", status, body)));
                return;
            }
            let mut byte_stream = response.bytes_stream();
            // Tool calls accumulated per part index: idx -> (name, latest args JSON).
            let mut accumulated_tool_calls: std::collections::HashMap<usize, (String, String)> = std::collections::HashMap::new();
            while let Some(chunk_result) = byte_stream.next().await {
                let chunk = match chunk_result {
                    Ok(c) => c,
                    Err(e) => {
                        yield Err(ZclawError::LlmError(format!("Stream error: {}", e)));
                        continue;
                    }
                };
                // NOTE(review): each network chunk is decoded and split into lines
                // independently. An SSE event split across chunk boundaries would
                // fail JSON parsing below and be dropped — consider carrying a
                // line buffer across chunks. TODO confirm whether partial events
                // occur in practice with this endpoint.
                let text = String::from_utf8_lossy(&chunk);
                for line in text.lines() {
                    if let Some(data) = line.strip_prefix("data: ") {
                        match serde_json::from_str::<GeminiStreamResponse>(data) {
                            Ok(resp) => {
                                if let Some(candidate) = resp.candidates.first() {
                                    // Candidates may arrive without content (e.g. final
                                    // bookkeeping events); skip those.
                                    let content = match &candidate.content {
                                        Some(c) => c,
                                        None => continue,
                                    };
                                    let parts = &content.parts;
                                    for (idx, part) in parts.iter().enumerate() {
                                        // Handle text content
                                        if let Some(text) = &part.text {
                                            if !text.is_empty() {
                                                yield Ok(StreamChunk::TextDelta { delta: text.clone() });
                                            }
                                        }
                                        // Handle function call (tool use)
                                        if let Some(fc) = &part.function_call {
                                            let name = fc.name.clone().unwrap_or_default();
                                            let args = fc.args.clone().unwrap_or(serde_json::Value::Object(Default::default()));
                                            // Emit ToolUseStart if this is a new tool call
                                            if !accumulated_tool_calls.contains_key(&idx) {
                                                accumulated_tool_calls.insert(idx, (name.clone(), String::new()));
                                                yield Ok(StreamChunk::ToolUseStart {
                                                    id: format!("gemini_call_{}", idx),
                                                    name,
                                                });
                                            }
                                            // Emit the function arguments as delta
                                            let args_str = serde_json::to_string(&args).unwrap_or_default();
                                            let call_id = format!("gemini_call_{}", idx);
                                            yield Ok(StreamChunk::ToolUseDelta {
                                                id: call_id.clone(),
                                                delta: args_str.clone(),
                                            });
                                            // Accumulate: latest args replace earlier ones
                                            // (the overwrite implies whole-object updates,
                                            // not incremental fragments).
                                            if let Some(entry) = accumulated_tool_calls.get_mut(&idx) {
                                                entry.1 = args_str;
                                            }
                                        }
                                    }
                                    // When the candidate is finished, emit ToolUseEnd for all pending
                                    if let Some(ref finish_reason) = candidate.finish_reason {
                                        // NOTE(review): only STOP / MAX_TOKENS produce a
                                        // Complete chunk; SAFETY / RECITATION end the stream
                                        // without one (the "SAFETY"/"RECITATION" arms in the
                                        // match below are therefore unreachable) — confirm
                                        // downstream consumers tolerate a missing Complete.
                                        let is_final = finish_reason == "STOP" || finish_reason == "MAX_TOKENS";
                                        if is_final {
                                            // Emit ToolUseEnd for all accumulated tool calls
                                            for (idx, (_name, args_str)) in &accumulated_tool_calls {
                                                let input: serde_json::Value = if args_str.is_empty() {
                                                    serde_json::json!({})
                                                } else {
                                                    serde_json::from_str(args_str).unwrap_or_else(|e| {
                                                        tracing::warn!(target: "gemini_driver", "Failed to parse tool args '{}': {}", args_str, e);
                                                        serde_json::json!({})
                                                    })
                                                };
                                                yield Ok(StreamChunk::ToolUseEnd {
                                                    id: format!("gemini_call_{}", idx),
                                                    input,
                                                });
                                            }
                                            // Extract usage metadata from the response
                                            let usage = resp.usage_metadata.as_ref();
                                            let input_tokens = usage.map(|u| u.prompt_token_count.unwrap_or(0)).unwrap_or(0);
                                            let output_tokens = usage.map(|u| u.candidates_token_count.unwrap_or(0)).unwrap_or(0);
                                            let stop_reason = match finish_reason.as_str() {
                                                "STOP" => "end_turn",
                                                "MAX_TOKENS" => "max_tokens",
                                                "SAFETY" => "error",
                                                "RECITATION" => "error",
                                                _ => "end_turn",
                                            };
                                            yield Ok(StreamChunk::Complete {
                                                input_tokens,
                                                output_tokens,
                                                stop_reason: stop_reason.to_string(),
                                            });
                                        }
                                    }
                                }
                            }
                            Err(e) => {
                                // Non-fatal: skip malformed (or partial) SSE payloads.
                                tracing::warn!(target: "gemini_driver", "Failed to parse SSE event: {} - {}", e, data);
                            }
                        }
                    }
                }
            }
        })
    }
}
impl GeminiDriver {
    /// Convert a CompletionRequest into the Gemini API request format.
    ///
    /// Key mapping decisions:
    /// - `system` prompt maps to `systemInstruction`
    /// - Messages use Gemini's `contents` array with `role`/`parts`
    /// - Tool definitions use `functionDeclarations`
    /// - Tool results are sent as `functionResponse` parts in `user` messages
    fn build_api_request(&self, request: &CompletionRequest) -> GeminiRequest {
        // Gemini has no native extended-thinking toggle here; the flag is
        // accepted but deliberately ignored (logged for visibility).
        if request.thinking_enabled {
            tracing::debug!("[GeminiDriver] thinking_enabled=true but Gemini does not support native thinking mode; ignoring");
        }
        let mut contents: Vec<GeminiContent> = Vec::new();
        for msg in &request.messages {
            match msg {
                zclaw_types::Message::User { content } => {
                    contents.push(GeminiContent {
                        role: "user".to_string(),
                        parts: vec![GeminiPart {
                            text: Some(content.clone()),
                            inline_data: None,
                            function_call: None,
                            function_response: None,
                        }],
                    });
                }
                zclaw_types::Message::Assistant { content, thinking } => {
                    let mut parts = Vec::new();
                    // Gemini does not have a native "thinking" field, so we prepend
                    // any thinking content as a text part with a marker.
                    // convert_response strips the same marker back out.
                    if let Some(think) = thinking {
                        if !think.is_empty() {
                            parts.push(GeminiPart {
                                text: Some(format!("[thinking]\n{}\n[/thinking]", think)),
                                inline_data: None,
                                function_call: None,
                                function_response: None,
                            });
                        }
                    }
                    parts.push(GeminiPart {
                        text: Some(content.clone()),
                        inline_data: None,
                        function_call: None,
                        function_response: None,
                    });
                    // Assistant turns use Gemini's "model" role.
                    contents.push(GeminiContent {
                        role: "model".to_string(),
                        parts,
                    });
                }
                zclaw_types::Message::ToolUse { id: _, tool, input } => {
                    // Tool use from the assistant is represented as a functionCall part
                    let args = if input.is_null() {
                        serde_json::json!({})
                    } else {
                        input.clone()
                    };
                    contents.push(GeminiContent {
                        role: "model".to_string(),
                        parts: vec![GeminiPart {
                            text: None,
                            inline_data: None,
                            function_call: Some(GeminiFunctionCall {
                                name: Some(tool.to_string()),
                                args: Some(args),
                            }),
                            function_response: None,
                        }],
                    });
                }
                zclaw_types::Message::ToolResult { tool_call_id, tool, output, is_error } => {
                    // Tool results are sent as functionResponse parts in a "user" role message.
                    // Gemini requires that function responses reference the function name
                    // and include the response wrapped in a "result" or "error" key.
                    let response_content = if *is_error {
                        serde_json::json!({ "error": output.to_string() })
                    } else {
                        serde_json::json!({ "result": output.clone() })
                    };
                    contents.push(GeminiContent {
                        role: "user".to_string(),
                        parts: vec![GeminiPart {
                            text: None,
                            inline_data: None,
                            function_call: None,
                            function_response: Some(GeminiFunctionResponse {
                                name: tool.to_string(),
                                response: response_content,
                            }),
                        }],
                    });
                    // Gemini ignores tool_call_id, but we log it for debugging
                    let _ = tool_call_id;
                }
                zclaw_types::Message::System { content } => {
                    // System messages are converted to user messages with system context.
                    // Note: the primary system prompt is handled via systemInstruction.
                    // Inline system messages in conversation history become user messages.
                    contents.push(GeminiContent {
                        role: "user".to_string(),
                        parts: vec![GeminiPart {
                            text: Some(content.clone()),
                            inline_data: None,
                            function_call: None,
                            function_response: None,
                        }],
                    });
                }
            }
        }
        // Build tool declarations
        let function_declarations: Vec<GeminiFunctionDeclaration> = request.tools
            .iter()
            .map(|t| GeminiFunctionDeclaration {
                name: t.name.clone(),
                description: t.description.clone(),
                parameters: t.input_schema.clone(),
            })
            .collect();
        // Build generation config: only fields the caller provided are set, so
        // the API applies its own defaults for the rest.
        let mut generation_config = GeminiGenerationConfig::default();
        if let Some(temp) = request.temperature {
            generation_config.temperature = Some(temp);
        }
        if let Some(max) = request.max_tokens {
            generation_config.max_output_tokens = Some(max);
        }
        if !request.stop.is_empty() {
            generation_config.stop_sequences = Some(request.stop.clone());
        }
        // Build system instruction
        let system_instruction = request.system.as_ref().map(|s| GeminiSystemInstruction {
            parts: vec![GeminiPart {
                text: Some(s.clone()),
                inline_data: None,
                function_call: None,
                function_response: None,
            }],
        });
        GeminiRequest {
            contents,
            system_instruction,
            generation_config: Some(generation_config),
            // Omit `tools` entirely when no functions are registered.
            tools: if function_declarations.is_empty() {
                None
            } else {
                Some(vec![GeminiTool {
                    function_declarations,
                }])
            },
        }
    }

    /// Convert a Gemini API response into a CompletionResponse.
    ///
    /// Only the first candidate is used. `[thinking]…[/thinking]` markers
    /// injected by build_api_request are converted back into Thinking blocks.
    fn convert_response(&self, api_response: GeminiResponse, model: String) -> CompletionResponse {
        let candidate = api_response.candidates.first();
        let (content, stop_reason) = match candidate {
            Some(c) => {
                // A candidate may legally arrive without content (e.g. blocked).
                let parts = c.content.as_ref()
                    .map(|content| content.parts.as_slice())
                    .unwrap_or(&[]);
                let mut blocks: Vec<ContentBlock> = Vec::new();
                let mut has_tool_use = false;
                for part in parts {
                    // Handle text content
                    if let Some(text) = &part.text {
                        // Skip thinking markers we injected
                        if text.starts_with("[thinking]\n") && text.contains("[/thinking]") {
                            let thinking_content = text
                                .strip_prefix("[thinking]\n")
                                .and_then(|s| s.strip_suffix("\n[/thinking]"))
                                .unwrap_or("");
                            if !thinking_content.is_empty() {
                                blocks.push(ContentBlock::Thinking {
                                    thinking: thinking_content.to_string(),
                                });
                            }
                        } else if !text.is_empty() {
                            blocks.push(ContentBlock::Text { text: text.clone() });
                        }
                    }
                    // Handle function call (tool use)
                    if let Some(fc) = &part.function_call {
                        has_tool_use = true;
                        // Gemini assigns no call IDs; synthesize a stable one
                        // from the block position.
                        blocks.push(ContentBlock::ToolUse {
                            id: format!("gemini_call_{}", blocks.len()),
                            name: fc.name.clone().unwrap_or_default(),
                            input: fc.args.clone().unwrap_or(serde_json::Value::Object(Default::default())),
                        });
                    }
                }
                // If there are no content blocks, add an empty text block
                if blocks.is_empty() {
                    blocks.push(ContentBlock::Text { text: String::new() });
                }
                let stop = match c.finish_reason.as_deref() {
                    Some("STOP") => StopReason::EndTurn,
                    Some("MAX_TOKENS") => StopReason::MaxTokens,
                    Some("SAFETY") => StopReason::Error,
                    Some("RECITATION") => StopReason::Error,
                    Some("TOOL_USE") => StopReason::ToolUse,
                    // Unknown or absent finish reason: infer ToolUse from content.
                    _ => {
                        if has_tool_use {
                            StopReason::ToolUse
                        } else {
                            StopReason::EndTurn
                        }
                    }
                };
                (blocks, stop)
            }
            None => {
                tracing::warn!(target: "gemini_driver", "No candidates in response");
                (
                    vec![ContentBlock::Text { text: String::new() }],
                    StopReason::EndTurn,
                )
            }
        };
        // Token counts default to 0 when usage metadata is absent.
        let usage = api_response.usage_metadata.as_ref();
        let input_tokens = usage.map(|u| u.prompt_token_count.unwrap_or(0)).unwrap_or(0);
        let output_tokens = usage.map(|u| u.candidates_token_count.unwrap_or(0)).unwrap_or(0);
        CompletionResponse {
            content,
            model,
            input_tokens,
            output_tokens,
            stop_reason,
        }
    }
}
// ---------------------------------------------------------------------------
// Gemini API request types
// ---------------------------------------------------------------------------
/// Top-level request body for `models/{model}:generateContent`.
///
/// Serialized in camelCase (`systemInstruction`, `generationConfig`) to match
/// the canonical Gemini REST API field names; previously these were emitted
/// in snake_case, which the proto3-JSON parser accepts but is non-canonical.
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct GeminiRequest {
    /// Conversation turns (user / model messages plus tool traffic).
    contents: Vec<GeminiContent>,
    /// Optional system prompt, sent as `systemInstruction`.
    #[serde(skip_serializing_if = "Option::is_none")]
    system_instruction: Option<GeminiSystemInstruction>,
    /// Optional sampling / length controls, sent as `generationConfig`.
    #[serde(skip_serializing_if = "Option::is_none")]
    generation_config: Option<GeminiGenerationConfig>,
    /// Function declarations; omitted entirely when no tools are registered.
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<GeminiTool>>,
}
/// One conversation turn in the `contents` array.
#[derive(Serialize)]
struct GeminiContent {
    // "user" or "model" — Gemini has no "system"/"tool" roles; those are
    // mapped by build_api_request.
    role: String,
    parts: Vec<GeminiPart>,
}
/// A single content part. Exactly one of the fields is expected to be set;
/// `None` fields are omitted from the JSON payload.
///
/// `rename_all = "camelCase"` covers all fields (including `inline_data` →
/// `inlineData`, which previously serialized in non-canonical snake_case),
/// replacing the per-field renames.
#[derive(Serialize, Clone)]
#[serde(rename_all = "camelCase")]
struct GeminiPart {
    /// Plain text content.
    #[serde(skip_serializing_if = "Option::is_none")]
    text: Option<String>,
    /// Inline binary data (`inlineData`); currently always `None` in this driver.
    #[serde(skip_serializing_if = "Option::is_none")]
    inline_data: Option<serde_json::Value>,
    /// Assistant-originated tool invocation (`functionCall`).
    #[serde(skip_serializing_if = "Option::is_none")]
    function_call: Option<GeminiFunctionCall>,
    /// Tool result echoed back to the model (`functionResponse`).
    #[serde(skip_serializing_if = "Option::is_none")]
    function_response: Option<GeminiFunctionResponse>,
}
/// Wrapper for the `systemInstruction` payload: a bare list of parts
/// (no role field).
#[derive(Serialize)]
struct GeminiSystemInstruction {
    parts: Vec<GeminiPart>,
}
/// Generation parameters forwarded as the `generationConfig` object.
///
/// All fields are optional; `None` fields are omitted so the API applies its
/// own defaults. `rename_all = "camelCase"` yields the canonical field names
/// (`maxOutputTokens` — previously sent as snake_case — and `stopSequences`,
/// replacing its per-field rename). `Default` is derived: with every field an
/// `Option` the hand-written all-`None` impl was redundant.
#[derive(Serialize, Default)]
#[serde(rename_all = "camelCase")]
struct GeminiGenerationConfig {
    /// Sampling temperature.
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    /// Upper bound on generated tokens (`maxOutputTokens`).
    #[serde(skip_serializing_if = "Option::is_none")]
    max_output_tokens: Option<u32>,
    /// Sequences that terminate generation (`stopSequences`).
    #[serde(skip_serializing_if = "Option::is_none")]
    stop_sequences: Option<Vec<String>>,
}
/// Tool container: Gemini groups all function declarations under a single
/// `functionDeclarations` array inside one tool object.
#[derive(Serialize)]
struct GeminiTool {
    #[serde(rename = "functionDeclarations")]
    function_declarations: Vec<GeminiFunctionDeclaration>,
}
/// A single callable function exposed to the model.
#[derive(Serialize)]
struct GeminiFunctionDeclaration {
    name: String,
    description: String,
    // JSON Schema describing the function's arguments
    // (copied from the tool's input_schema).
    parameters: serde_json::Value,
}
/// A function invocation replayed to the model as part of history.
#[derive(Serialize, Clone)]
struct GeminiFunctionCall {
    #[serde(skip_serializing_if = "Option::is_none")]
    name: Option<String>,
    // Arguments as a JSON object ({} when the original input was null).
    #[serde(skip_serializing_if = "Option::is_none")]
    args: Option<serde_json::Value>,
}
/// Result of a tool execution sent back to the model, keyed by function name.
#[derive(Serialize, Clone)]
struct GeminiFunctionResponse {
    name: String,
    // Wrapped by build_api_request as {"result": ...} or {"error": ...}.
    response: serde_json::Value,
}
// ---------------------------------------------------------------------------
// Gemini API response types
// ---------------------------------------------------------------------------
/// Non-streaming response body from `generateContent`.
///
/// The Gemini API emits camelCase JSON (`usageMetadata`), so
/// `rename_all = "camelCase"` is required: without it, `usage_metadata`
/// silently deserialized to `None` (thanks to `#[serde(default)]`) and
/// token counts always read as 0.
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
struct GeminiResponse {
    /// Generated candidates; may be empty (e.g. fully blocked prompts).
    #[serde(default)]
    candidates: Vec<GeminiCandidate>,
    /// Token accounting (`usageMetadata`).
    #[serde(default)]
    usage_metadata: Option<GeminiUsageMetadata>,
}
/// One generation candidate.
///
/// `finishReason` arrives in camelCase from the API; without
/// `rename_all = "camelCase"` it always deserialized as `None`, so stop
/// reasons were never detected and the stream never emitted its final
/// `Complete` chunk.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct GeminiCandidate {
    /// Candidate content; may be absent (e.g. safety-blocked output).
    #[serde(default)]
    content: Option<GeminiResponseContent>,
    /// e.g. "STOP", "MAX_TOKENS", "SAFETY", "RECITATION".
    #[serde(default)]
    finish_reason: Option<String>,
}
/// Content payload of a response candidate.
#[derive(Debug, Deserialize)]
struct GeminiResponseContent {
    #[serde(default)]
    parts: Vec<GeminiResponsePart>,
    // Role of the producing side; carried by the API but unused here.
    #[serde(default)]
    #[allow(dead_code)] // @reserved: deserialized from Gemini API, not accessed in code
    role: Option<String>,
}
/// One part of a candidate: plain text and/or a function call.
#[derive(Debug, Deserialize)]
struct GeminiResponsePart {
    #[serde(default)]
    text: Option<String>,
    #[serde(rename = "functionCall", default)]
    function_call: Option<GeminiResponseFunctionCall>,
}
/// Function-call payload within a response part.
#[derive(Debug, Deserialize)]
struct GeminiResponseFunctionCall {
    #[serde(default)]
    name: Option<String>,
    // Arguments as a JSON object.
    #[serde(default)]
    args: Option<serde_json::Value>,
}
/// Token accounting block (`usageMetadata`).
///
/// The API emits camelCase field names (`promptTokenCount`,
/// `candidatesTokenCount`, `totalTokenCount`); `rename_all = "camelCase"`
/// is required — without it every field fell back to `None` via
/// `#[serde(default)]` and all reported token counts were 0.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct GeminiUsageMetadata {
    #[serde(default)]
    prompt_token_count: Option<u32>,
    #[serde(default)]
    candidates_token_count: Option<u32>,
    #[serde(default)]
    #[allow(dead_code)] // @reserved: deserialized from Gemini API, not accessed in code
    total_token_count: Option<u32>,
}
// ---------------------------------------------------------------------------
// Gemini streaming types
// ---------------------------------------------------------------------------
/// Streaming response from the Gemini SSE endpoint.
///
/// Each SSE event contains the same structure as the non-streaming response,
/// but with incremental content. `rename_all = "camelCase"` is required so
/// `usageMetadata` (camelCase on the wire) actually populates instead of
/// silently defaulting to `None`.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct GeminiStreamResponse {
    #[serde(default)]
    candidates: Vec<GeminiCandidate>,
    #[serde(default)]
    usage_metadata: Option<GeminiUsageMetadata>,
}