//! Fallback driver — tries multiple LLM drivers in sequence. //! //! If the primary driver fails with a non-retryable error, the fallback driver //! moves to the next driver in the chain. use crate::llm_driver::{CompletionRequest, CompletionResponse, LlmDriver, LlmError, StreamEvent}; use async_trait::async_trait; use std::sync::Arc; use tracing::warn; /// A driver that wraps multiple LLM drivers and tries each in order. /// /// On failure, moves to the next driver. Rate-limit and overload errors /// are bubbled up for retry logic to handle. pub struct FallbackDriver { drivers: Vec>, } impl FallbackDriver { /// Create a new fallback driver from an ordered chain of drivers. /// /// The first driver is the primary; subsequent are fallbacks. pub fn new(drivers: Vec>) -> Self { Self { drivers } } } #[async_trait] impl LlmDriver for FallbackDriver { async fn complete(&self, request: CompletionRequest) -> Result { let mut last_error = None; for (i, driver) in self.drivers.iter().enumerate() { match driver.complete(request.clone()).await { Ok(response) => return Ok(response), Err(e @ LlmError::RateLimited { .. }) | Err(e @ LlmError::Overloaded { .. }) => { // Retryable errors — bubble up for the retry loop to handle return Err(e); } Err(e) => { warn!( driver_index = i, error = %e, "Fallback driver failed, trying next" ); last_error = Some(e); } } } Err(last_error.unwrap_or_else(|| LlmError::Api { status: 0, message: "No drivers configured in fallback chain".to_string(), })) } async fn stream( &self, request: CompletionRequest, tx: tokio::sync::mpsc::Sender, ) -> Result { let mut last_error = None; for (i, driver) in self.drivers.iter().enumerate() { match driver.stream(request.clone(), tx.clone()).await { Ok(response) => return Ok(response), Err(e @ LlmError::RateLimited { .. }) | Err(e @ LlmError::Overloaded { .. }) => { return Err(e); } Err(e) => { warn!( driver_index = i, error = %e, "Fallback driver (stream) failed, trying next" ); last_error = Some(e); } } } Err(last_error.unwrap_or_else(|| LlmError::Api { status: 0, message: "No drivers configured in fallback chain".to_string(), })) } } #[cfg(test)] mod tests { use super::*; use crate::llm_driver::CompletionResponse; use openfang_types::message::{ContentBlock, StopReason, TokenUsage}; struct FailDriver; #[async_trait] impl LlmDriver for FailDriver { async fn complete(&self, _req: CompletionRequest) -> Result { Err(LlmError::Api { status: 500, message: "Internal error".to_string(), }) } } struct OkDriver; #[async_trait] impl LlmDriver for OkDriver { async fn complete(&self, _req: CompletionRequest) -> Result { Ok(CompletionResponse { content: vec![ContentBlock::Text { text: "OK".to_string(), }], stop_reason: StopReason::EndTurn, tool_calls: vec![], usage: TokenUsage { input_tokens: 10, output_tokens: 5, }, }) } } fn test_request() -> CompletionRequest { CompletionRequest { model: "test".to_string(), messages: vec![], tools: vec![], max_tokens: 100, temperature: 0.0, system: None, thinking: None, } } #[tokio::test] async fn test_fallback_primary_succeeds() { let driver = FallbackDriver::new(vec![ Arc::new(OkDriver) as Arc, Arc::new(FailDriver) as Arc, ]); let result = driver.complete(test_request()).await; assert!(result.is_ok()); assert_eq!(result.unwrap().text(), "OK"); } #[tokio::test] async fn test_fallback_primary_fails_secondary_succeeds() { let driver = FallbackDriver::new(vec![ Arc::new(FailDriver) as Arc, Arc::new(OkDriver) as Arc, ]); let result = driver.complete(test_request()).await; assert!(result.is_ok()); } #[tokio::test] async fn test_fallback_all_fail() { let driver = FallbackDriver::new(vec![ Arc::new(FailDriver) as Arc, Arc::new(FailDriver) as Arc, ]); let result = driver.complete(test_request()).await; assert!(result.is_err()); } #[tokio::test] async fn test_rate_limit_bubbles_up() { struct RateLimitDriver; #[async_trait] impl LlmDriver for RateLimitDriver { async fn complete( &self, _req: CompletionRequest, ) -> Result { Err(LlmError::RateLimited { retry_after_ms: 5000, }) } } let driver = FallbackDriver::new(vec![ Arc::new(RateLimitDriver) as Arc, Arc::new(OkDriver) as Arc, ]); let result = driver.complete(test_request()).await; // Rate limit should NOT fall through to next driver assert!(matches!(result, Err(LlmError::RateLimited { .. }))); } }