//! Local LLM driver (Ollama, LM Studio, vLLM, etc.) //! //! Uses the OpenAI-compatible API format. The only differences from the //! OpenAI driver are: no API key is required, and base_url points to a //! local server. use async_trait::async_trait; use async_stream::stream; use futures::{Stream, StreamExt}; use reqwest::Client; use serde::{Deserialize, Serialize}; use std::pin::Pin; use zclaw_types::{Result, ZclawError}; use super::{CompletionRequest, CompletionResponse, ContentBlock, LlmDriver, StopReason}; use crate::stream::StreamChunk; /// Local LLM driver for Ollama, LM Studio, vLLM, and other OpenAI-compatible servers. pub struct LocalDriver { client: Client, base_url: String, } impl LocalDriver { /// Create a driver pointing at a custom OpenAI-compatible endpoint. /// /// The `base_url` should end with `/v1` (e.g. `http://localhost:8080/v1`). pub fn new(base_url: impl Into) -> Self { Self { client: Client::builder() .user_agent(crate::USER_AGENT) .timeout(std::time::Duration::from_secs(300)) // 5 min -- local inference can be slow .connect_timeout(std::time::Duration::from_secs(10)) // short connect timeout .build() .unwrap_or_else(|_| Client::new()), base_url: base_url.into(), } } /// Ollama default endpoint (`http://localhost:11434/v1`). pub fn ollama() -> Self { Self::new("http://localhost:11434/v1") } /// LM Studio default endpoint (`http://localhost:1234/v1`). pub fn lm_studio() -> Self { Self::new("http://localhost:1234/v1") } /// vLLM default endpoint (`http://localhost:8000/v1`). pub fn vllm() -> Self { Self::new("http://localhost:8000/v1") } // ---------------------------------------------------------------- // Request / response conversion (OpenAI-compatible format) // ---------------------------------------------------------------- fn build_api_request(&self, request: &CompletionRequest) -> LocalApiRequest { if request.thinking_enabled { tracing::debug!("[LocalDriver] thinking_enabled=true but local driver does not support native thinking mode; ignoring"); } let messages: Vec = request .messages .iter() .filter_map(|msg| match msg { zclaw_types::Message::User { content } => Some(LocalApiMessage { role: "user".to_string(), content: Some(content.clone()), tool_calls: None, }), zclaw_types::Message::Assistant { content, thinking: _, } => Some(LocalApiMessage { role: "assistant".to_string(), content: Some(content.clone()), tool_calls: None, }), zclaw_types::Message::System { content } => Some(LocalApiMessage { role: "system".to_string(), content: Some(content.clone()), tool_calls: None, }), zclaw_types::Message::ToolUse { id, tool, input, .. } => { let args = if input.is_null() { "{}".to_string() } else { serde_json::to_string(input).unwrap_or_else(|_| "{}".to_string()) }; Some(LocalApiMessage { role: "assistant".to_string(), content: None, tool_calls: Some(vec![LocalApiToolCall { id: id.clone(), r#type: "function".to_string(), function: LocalFunctionCall { name: tool.to_string(), arguments: args, }, }]), }) } zclaw_types::Message::ToolResult { output, is_error, .. } => Some(LocalApiMessage { role: "tool".to_string(), content: Some(if *is_error { format!("Error: {}", output) } else { output.to_string() }), tool_calls: None, }), }) .collect(); // Prepend system prompt when provided. let mut messages = messages; if let Some(system) = &request.system { messages.insert( 0, LocalApiMessage { role: "system".to_string(), content: Some(system.clone()), tool_calls: None, }, ); } let tools: Vec = request .tools .iter() .map(|t| LocalApiTool { r#type: "function".to_string(), function: LocalFunctionDef { name: t.name.clone(), description: t.description.clone(), parameters: t.input_schema.clone(), }, }) .collect(); LocalApiRequest { model: request.model.clone(), messages, max_tokens: request.max_tokens, temperature: request.temperature, stop: if request.stop.is_empty() { None } else { Some(request.stop.clone()) }, stream: request.stream, tools: if tools.is_empty() { None } else { Some(tools) }, } } fn convert_response( &self, api_response: LocalApiResponse, model: String, ) -> CompletionResponse { let choice = api_response.choices.first(); let (content, stop_reason) = match choice { Some(c) => { let has_tool_calls = c .message .tool_calls .as_ref() .map(|tc| !tc.is_empty()) .unwrap_or(false); let has_content = c .message .content .as_ref() .map(|t| !t.is_empty()) .unwrap_or(false); let blocks = if has_tool_calls { let tool_calls = c.message.tool_calls.as_deref().unwrap_or_default(); tool_calls .iter() .map(|tc| { let input: serde_json::Value = serde_json::from_str(&tc.function.arguments) .unwrap_or(serde_json::Value::Null); ContentBlock::ToolUse { id: tc.id.clone(), name: tc.function.name.clone(), input, } }) .collect() } else if has_content { vec![ContentBlock::Text { text: c.message.content.clone().unwrap_or_default(), }] } else { vec![ContentBlock::Text { text: String::new(), }] }; let stop = match c.finish_reason.as_deref() { Some("stop") => StopReason::EndTurn, Some("length") => StopReason::MaxTokens, Some("tool_calls") => StopReason::ToolUse, _ => StopReason::EndTurn, }; (blocks, stop) } None => ( vec![ContentBlock::Text { text: String::new(), }], StopReason::EndTurn, ), }; let (input_tokens, output_tokens) = api_response .usage .map(|u| (u.prompt_tokens, u.completion_tokens)) .unwrap_or((0, 0)); CompletionResponse { content, model, input_tokens, output_tokens, stop_reason, cache_creation_input_tokens: None, cache_read_input_tokens: None, } } /// Build the `reqwest::RequestBuilder` with an optional Authorization header. /// /// Ollama does not need one; LM Studio / vLLM may be configured with an /// optional API key. We send the header only when a key is present. fn authenticated_post(&self, url: &str) -> reqwest::RequestBuilder { self.client.post(url).header("Accept", "*/*") } } #[async_trait] impl LlmDriver for LocalDriver { fn provider(&self) -> &str { "local" } fn is_configured(&self) -> bool { // Local drivers never require an API key. true } async fn complete(&self, request: CompletionRequest) -> Result { let api_request = self.build_api_request(&request); let url = format!("{}/chat/completions", self.base_url); tracing::debug!(target: "local_driver", "Sending request to {}", url); tracing::trace!( target: "local_driver", "Request body: {}", serde_json::to_string(&api_request).unwrap_or_default() ); let response = self .authenticated_post(&url) .json(&api_request) .send() .await .map_err(|e| { let hint = connection_error_hint(&e); ZclawError::LlmError(format!("Failed to connect to local LLM server at {}: {}{}", self.base_url, e, hint)) })?; if !response.status().is_success() { let status = response.status(); let body = response.text().await.unwrap_or_default(); tracing::warn!(target: "local_driver", "API error {}: {}", status, body); return Err(ZclawError::LlmError(format!( "Local LLM API error {}: {}", status, body ))); } let api_response: LocalApiResponse = response .json() .await .map_err(|e| ZclawError::LlmError(format!("Failed to parse response: {}", e)))?; Ok(self.convert_response(api_response, request.model)) } fn stream( &self, request: CompletionRequest, ) -> Pin> + Send + '_>> { let mut stream_request = self.build_api_request(&request); stream_request.stream = true; let url = format!("{}/chat/completions", self.base_url); tracing::debug!(target: "local_driver", "Starting stream to {}", url); Box::pin(stream! { let response = match self .authenticated_post(&url) .header("Content-Type", "application/json") .timeout(std::time::Duration::from_secs(300)) .json(&stream_request) .send() .await { Ok(r) => { tracing::debug!(target: "local_driver", "Stream response status: {}", r.status()); r } Err(e) => { let hint = connection_error_hint(&e); tracing::error!(target: "local_driver", "Stream connection failed: {}{}", e, hint); yield Err(ZclawError::LlmError(format!( "Failed to connect to local LLM server at {}: {}{}", self.base_url, e, hint ))); return; } }; if !response.status().is_success() { let status = response.status(); let body = response.text().await.unwrap_or_default(); yield Err(ZclawError::LlmError(format!("API error {}: {}", status, body))); return; } let mut byte_stream = response.bytes_stream(); let mut accumulated_tool_calls: std::collections::HashMap = std::collections::HashMap::new(); let mut current_tool_id: Option = None; while let Some(chunk_result) = byte_stream.next().await { let chunk = match chunk_result { Ok(c) => c, Err(e) => { yield Err(ZclawError::LlmError(format!("Stream error: {}", e))); continue; } }; let text = String::from_utf8_lossy(&chunk); for line in text.lines() { if let Some(data) = line.strip_prefix("data: ") { if data == "[DONE]" { tracing::debug!( target: "local_driver", "Stream done, tool_calls accumulated: {}", accumulated_tool_calls.len() ); for (id, (name, args)) in &accumulated_tool_calls { if name.is_empty() { tracing::warn!( target: "local_driver", "Skipping tool call with empty name: id={}", id ); continue; } let parsed_args: serde_json::Value = if args.is_empty() { serde_json::json!({}) } else { serde_json::from_str(args).unwrap_or_else(|e| { tracing::warn!( target: "local_driver", "Failed to parse tool args '{}': {}", args, e ); serde_json::json!({}) }) }; yield Ok(StreamChunk::ToolUseEnd { id: id.clone(), input: parsed_args, }); } yield Ok(StreamChunk::Complete { input_tokens: 0, output_tokens: 0, stop_reason: "end_turn".to_string(), cache_creation_input_tokens: None, cache_read_input_tokens: None, }); continue; } match serde_json::from_str::(data) { Ok(resp) => { if let Some(choice) = resp.choices.first() { let delta = &choice.delta; // Text content if let Some(content) = &delta.content { if !content.is_empty() { yield Ok(StreamChunk::TextDelta { delta: content.clone(), }); } } // Tool calls if let Some(tool_calls) = &delta.tool_calls { for tc in tool_calls { // Tool call start if let Some(id) = &tc.id { let name = tc .function .as_ref() .and_then(|f| f.name.clone()) .unwrap_or_default(); if !name.is_empty() { current_tool_id = Some(id.clone()); accumulated_tool_calls .insert(id.clone(), (name.clone(), String::new())); yield Ok(StreamChunk::ToolUseStart { id: id.clone(), name, }); } else { current_tool_id = Some(id.clone()); accumulated_tool_calls .insert(id.clone(), (String::new(), String::new())); } } // Tool call delta if let Some(function) = &tc.function { if let Some(args) = &function.arguments { let tool_id = tc .id .as_ref() .or(current_tool_id.as_ref()) .cloned() .unwrap_or_default(); yield Ok(StreamChunk::ToolUseDelta { id: tool_id.clone(), delta: args.clone(), }); if let Some(entry) = accumulated_tool_calls.get_mut(&tool_id) { entry.1.push_str(args); } } } } } } } Err(e) => { tracing::warn!( target: "local_driver", "Failed to parse SSE: {}, data: {}", e, data ); } } } } } }) } } // --------------------------------------------------------------------------- // Connection-error diagnostics // --------------------------------------------------------------------------- /// Return a human-readable hint when the local server appears to be unreachable. fn connection_error_hint(error: &reqwest::Error) -> String { if error.is_connect() { format!( "\n\nHint: Is the local LLM server running at {}?\n\ Make sure the server is started before using this driver.", // Extract just the host:port from whatever error we have. "localhost" ) } else if error.is_timeout() { "\n\nHint: The request timed out. Local inference can be slow -- \ try a smaller model or increase the timeout." .to_string() } else { String::new() } } // --------------------------------------------------------------------------- // OpenAI-compatible API types (private to this module) // --------------------------------------------------------------------------- #[derive(Serialize)] struct LocalApiRequest { model: String, messages: Vec, #[serde(skip_serializing_if = "Option::is_none")] max_tokens: Option, #[serde(skip_serializing_if = "Option::is_none")] temperature: Option, #[serde(skip_serializing_if = "Option::is_none")] stop: Option>, #[serde(default)] stream: bool, #[serde(skip_serializing_if = "Option::is_none")] tools: Option>, } #[derive(Serialize)] struct LocalApiMessage { role: String, #[serde(skip_serializing_if = "Option::is_none")] content: Option, #[serde(skip_serializing_if = "Option::is_none")] tool_calls: Option>, } #[derive(Serialize)] struct LocalApiToolCall { id: String, r#type: String, function: LocalFunctionCall, } #[derive(Serialize)] struct LocalFunctionCall { name: String, arguments: String, } #[derive(Serialize)] struct LocalApiTool { r#type: String, function: LocalFunctionDef, } #[derive(Serialize)] struct LocalFunctionDef { name: String, description: String, parameters: serde_json::Value, } // --- Response types --- #[derive(Deserialize, Default)] struct LocalApiResponse { #[serde(default)] choices: Vec, #[serde(default)] usage: Option, } #[derive(Deserialize, Default)] struct LocalApiChoice { #[serde(default)] message: LocalApiResponseMessage, #[serde(default)] finish_reason: Option, } #[derive(Deserialize, Default)] struct LocalApiResponseMessage { #[serde(default)] content: Option, #[serde(default)] tool_calls: Option>, } #[derive(Deserialize, Default)] struct LocalApiToolCallResponse { #[serde(default)] id: String, #[serde(default)] function: LocalFunctionCallResponse, } #[derive(Deserialize, Default)] struct LocalFunctionCallResponse { #[serde(default)] name: String, #[serde(default)] arguments: String, } #[derive(Deserialize, Default)] struct LocalApiUsage { #[serde(default)] prompt_tokens: u32, #[serde(default)] completion_tokens: u32, } // --- Streaming types --- #[derive(Debug, Deserialize)] struct LocalStreamResponse { #[serde(default)] choices: Vec, } #[derive(Debug, Deserialize)] struct LocalStreamChoice { #[serde(default)] delta: LocalDelta, #[serde(default)] #[allow(dead_code)] // Deserialized from SSE, not accessed in code finish_reason: Option, } #[derive(Debug, Deserialize, Default)] struct LocalDelta { #[serde(default)] content: Option, #[serde(default)] tool_calls: Option>, } #[derive(Debug, Deserialize)] struct LocalToolCallDelta { #[serde(default)] id: Option, #[serde(default)] function: Option, } #[derive(Debug, Deserialize)] struct LocalFunctionDelta { #[serde(default)] name: Option, #[serde(default)] arguments: Option, }