Some checks failed
CI / Check / macos-latest (push) Has been cancelled
CI / Check / ubuntu-latest (push) Has been cancelled
CI / Check / windows-latest (push) Has been cancelled
CI / Test / macos-latest (push) Has been cancelled
CI / Test / ubuntu-latest (push) Has been cancelled
CI / Test / windows-latest (push) Has been cancelled
CI / Clippy (push) Has been cancelled
CI / Format (push) Has been cancelled
CI / Security Audit (push) Has been cancelled
CI / Secrets Scan (push) Has been cancelled
CI / Install Script Smoke Test (push) Has been cancelled
193 lines
6.1 KiB
Rust
193 lines
6.1 KiB
Rust
//! Fallback driver — tries multiple LLM drivers in sequence.
//!
//! If a driver fails with a non-retryable error, the fallback driver moves on
//! to the next driver in the chain. Rate-limit and overload errors are returned
//! immediately instead of falling through.
|
|
|
|
use crate::llm_driver::{CompletionRequest, CompletionResponse, LlmDriver, LlmError, StreamEvent};
|
|
use async_trait::async_trait;
|
|
use std::sync::Arc;
|
|
use tracing::warn;
|
|
|
|
/// A driver that wraps multiple LLM drivers and tries each in order.
|
|
///
|
|
/// On failure, moves to the next driver. Rate-limit and overload errors
|
|
/// are bubbled up for retry logic to handle.
|
|
pub struct FallbackDriver {
|
|
drivers: Vec<Arc<dyn LlmDriver>>,
|
|
}
|
|
|
|
impl FallbackDriver {
|
|
/// Create a new fallback driver from an ordered chain of drivers.
|
|
///
|
|
/// The first driver is the primary; subsequent are fallbacks.
|
|
pub fn new(drivers: Vec<Arc<dyn LlmDriver>>) -> Self {
|
|
Self { drivers }
|
|
}
|
|
}
|
|
|
|
#[async_trait]
|
|
impl LlmDriver for FallbackDriver {
|
|
async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse, LlmError> {
|
|
let mut last_error = None;
|
|
|
|
for (i, driver) in self.drivers.iter().enumerate() {
|
|
match driver.complete(request.clone()).await {
|
|
Ok(response) => return Ok(response),
|
|
Err(e @ LlmError::RateLimited { .. }) | Err(e @ LlmError::Overloaded { .. }) => {
|
|
// Retryable errors — bubble up for the retry loop to handle
|
|
return Err(e);
|
|
}
|
|
Err(e) => {
|
|
warn!(
|
|
driver_index = i,
|
|
error = %e,
|
|
"Fallback driver failed, trying next"
|
|
);
|
|
last_error = Some(e);
|
|
}
|
|
}
|
|
}
|
|
|
|
Err(last_error.unwrap_or_else(|| LlmError::Api {
|
|
status: 0,
|
|
message: "No drivers configured in fallback chain".to_string(),
|
|
}))
|
|
}
|
|
|
|
async fn stream(
|
|
&self,
|
|
request: CompletionRequest,
|
|
tx: tokio::sync::mpsc::Sender<StreamEvent>,
|
|
) -> Result<CompletionResponse, LlmError> {
|
|
let mut last_error = None;
|
|
|
|
for (i, driver) in self.drivers.iter().enumerate() {
|
|
match driver.stream(request.clone(), tx.clone()).await {
|
|
Ok(response) => return Ok(response),
|
|
Err(e @ LlmError::RateLimited { .. }) | Err(e @ LlmError::Overloaded { .. }) => {
|
|
return Err(e);
|
|
}
|
|
Err(e) => {
|
|
warn!(
|
|
driver_index = i,
|
|
error = %e,
|
|
"Fallback driver (stream) failed, trying next"
|
|
);
|
|
last_error = Some(e);
|
|
}
|
|
}
|
|
}
|
|
|
|
Err(last_error.unwrap_or_else(|| LlmError::Api {
|
|
status: 0,
|
|
message: "No drivers configured in fallback chain".to_string(),
|
|
}))
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use openfang_types::message::{ContentBlock, StopReason, TokenUsage};

    /// Always fails with a non-retryable `Api` error (status 500).
    struct FailDriver;

    #[async_trait]
    impl LlmDriver for FailDriver {
        async fn complete(&self, _req: CompletionRequest) -> Result<CompletionResponse, LlmError> {
            Err(LlmError::Api {
                status: 500,
                message: "Internal error".to_string(),
            })
        }
    }

    /// Always fails with a non-retryable `Api` error whose status (503) differs
    /// from `FailDriver`'s, so tests can tell which driver's error was returned.
    struct OtherFailDriver;

    #[async_trait]
    impl LlmDriver for OtherFailDriver {
        async fn complete(&self, _req: CompletionRequest) -> Result<CompletionResponse, LlmError> {
            Err(LlmError::Api {
                status: 503,
                message: "Service unavailable".to_string(),
            })
        }
    }

    /// Always fails with a rate-limit error.
    struct RateLimitDriver;

    #[async_trait]
    impl LlmDriver for RateLimitDriver {
        async fn complete(&self, _req: CompletionRequest) -> Result<CompletionResponse, LlmError> {
            Err(LlmError::RateLimited {
                retry_after_ms: 5000,
            })
        }
    }

    /// Always succeeds with a single "OK" text block.
    struct OkDriver;

    #[async_trait]
    impl LlmDriver for OkDriver {
        async fn complete(&self, _req: CompletionRequest) -> Result<CompletionResponse, LlmError> {
            Ok(CompletionResponse {
                content: vec![ContentBlock::Text {
                    text: "OK".to_string(),
                }],
                stop_reason: StopReason::EndTurn,
                tool_calls: vec![],
                usage: TokenUsage {
                    input_tokens: 10,
                    output_tokens: 5,
                },
            })
        }
    }

    fn test_request() -> CompletionRequest {
        CompletionRequest {
            model: "test".to_string(),
            messages: vec![],
            tools: vec![],
            max_tokens: 100,
            temperature: 0.0,
            system: None,
            thinking: None,
        }
    }

    #[tokio::test]
    async fn test_fallback_primary_succeeds() {
        let driver = FallbackDriver::new(vec![
            Arc::new(OkDriver) as Arc<dyn LlmDriver>,
            Arc::new(FailDriver) as Arc<dyn LlmDriver>,
        ]);
        let result = driver.complete(test_request()).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap().text(), "OK");
    }

    #[tokio::test]
    async fn test_fallback_primary_fails_secondary_succeeds() {
        let driver = FallbackDriver::new(vec![
            Arc::new(FailDriver) as Arc<dyn LlmDriver>,
            Arc::new(OkDriver) as Arc<dyn LlmDriver>,
        ]);
        let result = driver.complete(test_request()).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap().text(), "OK");
    }

    #[tokio::test]
    async fn test_fallback_all_fail() {
        let driver = FallbackDriver::new(vec![
            Arc::new(FailDriver) as Arc<dyn LlmDriver>,
            Arc::new(FailDriver) as Arc<dyn LlmDriver>,
        ]);
        let result = driver.complete(test_request()).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_fallback_all_fail_returns_last_error() {
        let driver = FallbackDriver::new(vec![
            Arc::new(FailDriver) as Arc<dyn LlmDriver>,
            Arc::new(OtherFailDriver) as Arc<dyn LlmDriver>,
        ]);
        let result = driver.complete(test_request()).await;
        // The error from the final driver in the chain is the one reported.
        assert!(matches!(result, Err(LlmError::Api { status: 503, .. })));
    }

    #[tokio::test]
    async fn test_fallback_empty_chain_errors() {
        let driver = FallbackDriver::new(vec![]);
        let result = driver.complete(test_request()).await;
        assert!(matches!(result, Err(LlmError::Api { status: 0, .. })));
    }

    #[tokio::test]
    async fn test_rate_limit_bubbles_up() {
        let driver = FallbackDriver::new(vec![
            Arc::new(RateLimitDriver) as Arc<dyn LlmDriver>,
            Arc::new(OkDriver) as Arc<dyn LlmDriver>,
        ]);
        let result = driver.complete(test_request()).await;
        // Rate limit should NOT fall through to next driver
        assert!(matches!(result, Err(LlmError::RateLimited { .. })));
    }

    #[tokio::test]
    async fn test_rate_limit_from_fallback_bubbles_up() {
        let driver = FallbackDriver::new(vec![
            Arc::new(FailDriver) as Arc<dyn LlmDriver>,
            Arc::new(RateLimitDriver) as Arc<dyn LlmDriver>,
            Arc::new(OkDriver) as Arc<dyn LlmDriver>,
        ]);
        let result = driver.complete(test_request()).await;
        // A rate limit from a non-primary driver also stops the chain.
        assert!(matches!(result, Err(LlmError::RateLimited { .. })));
    }
}
|