Files
openfang/crates/openfang-runtime/src/drivers/fallback.rs
iven 92e5def702
Some checks failed
CI / Check / macos-latest (push) Has been cancelled
CI / Check / ubuntu-latest (push) Has been cancelled
CI / Check / windows-latest (push) Has been cancelled
CI / Test / macos-latest (push) Has been cancelled
CI / Test / ubuntu-latest (push) Has been cancelled
CI / Test / windows-latest (push) Has been cancelled
CI / Clippy (push) Has been cancelled
CI / Format (push) Has been cancelled
CI / Security Audit (push) Has been cancelled
CI / Secrets Scan (push) Has been cancelled
CI / Install Script Smoke Test (push) Has been cancelled
初始化提交
2026-03-01 16:24:24 +08:00

193 lines
6.1 KiB
Rust

//! Fallback driver — tries multiple LLM drivers in sequence.
//!
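//! If a driver in the chain fails with a non-retryable error, the fallback driver
//! moves on to the next driver in the chain. Retryable errors (rate limiting,
//! overload) are returned immediately so the caller's retry logic can handle them.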
use crate::llm_driver::{CompletionRequest, CompletionResponse, LlmDriver, LlmError, StreamEvent};
use async_trait::async_trait;
use std::sync::Arc;
use tracing::warn;
/// An [`LlmDriver`] that delegates each request to an ordered chain of drivers.
///
/// Drivers are tried front to back until one succeeds. Rate-limit and overload
/// errors are not absorbed by the chain: they are handed back to the caller so
/// its retry logic can deal with them.
pub struct FallbackDriver {
    /// Drivers in priority order: index 0 is the primary, the rest are fallbacks.
    drivers: Vec<Arc<dyn LlmDriver>>,
}

impl FallbackDriver {
    /// Builds a fallback chain from `drivers`, consulted in the given order.
    ///
    /// The first element acts as the primary driver. Each later element is only
    /// used after every driver before it has failed with a non-retryable error.
    pub fn new(drivers: Vec<Arc<dyn LlmDriver>>) -> Self {
        FallbackDriver { drivers }
    }
}
/// Returns `true` for errors the fallback chain hands straight back to the
/// caller instead of trying the next driver.
///
/// Rate-limit and overload errors are returned unchanged so the caller's retry
/// logic can handle them.
fn is_retryable(err: &LlmError) -> bool {
    matches!(
        err,
        LlmError::RateLimited { .. } | LlmError::Overloaded { .. }
    )
}

/// Error reported when the chain contains no drivers at all.
fn no_drivers_error() -> LlmError {
    LlmError::Api {
        status: 0,
        message: "No drivers configured in fallback chain".to_string(),
    }
}

#[async_trait]
impl LlmDriver for FallbackDriver {
    /// Sends `request` to each driver in turn and returns the first success.
    ///
    /// Every driver except the last gets its own clone of the request. The last
    /// one receives the original, and whatever it returns (success or error) is
    /// the result of the whole chain.
    ///
    /// # Errors
    ///
    /// - A rate-limit or overload error from any driver is returned immediately.
    /// - If every driver fails, the last driver's error is returned.
    /// - An empty chain yields `LlmError::Api { status: 0, .. }`.
    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse, LlmError> {
        // Peel off the final driver up front: it can consume `request` without a
        // clone, and its failure must not be logged as "trying next".
        let (last, rest) = self.drivers.split_last().ok_or_else(no_drivers_error)?;
        for (i, driver) in rest.iter().enumerate() {
            match driver.complete(request.clone()).await {
                Ok(response) => return Ok(response),
                // Retryable errors — bubble up for the retry loop to handle.
                Err(e) if is_retryable(&e) => return Err(e),
                Err(e) => warn!(
                    driver_index = i,
                    error = %e,
                    "Fallback driver failed, trying next"
                ),
            }
        }
        last.complete(request).await
    }

    /// Streaming counterpart of `complete`: same ordering and error rules, with
    /// every driver writing its events into the caller's `tx`.
    ///
    /// NOTE(review): events a failed driver has already pushed into `tx` are not
    /// retracted, so a mid-stream failure followed by a fallback can deliver output
    /// from two drivers to the same receiver — confirm consumers tolerate this.
    ///
    /// # Errors
    ///
    /// Same as `complete`.
    async fn stream(
        &self,
        request: CompletionRequest,
        tx: tokio::sync::mpsc::Sender<StreamEvent>,
    ) -> Result<CompletionResponse, LlmError> {
        // The last driver takes ownership of both the request and the sender.
        let (last, rest) = self.drivers.split_last().ok_or_else(no_drivers_error)?;
        for (i, driver) in rest.iter().enumerate() {
            match driver.stream(request.clone(), tx.clone()).await {
                Ok(response) => return Ok(response),
                // Retryable errors — bubble up for the retry loop to handle.
                Err(e) if is_retryable(&e) => return Err(e),
                Err(e) => warn!(
                    driver_index = i,
                    error = %e,
                    "Fallback driver (stream) failed, trying next"
                ),
            }
        }
        last.stream(request, tx).await
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use openfang_types::message::{ContentBlock, StopReason, TokenUsage};

    /// Always fails with a non-retryable HTTP 500 API error.
    struct FailDriver;

    #[async_trait]
    impl LlmDriver for FailDriver {
        async fn complete(&self, _req: CompletionRequest) -> Result<CompletionResponse, LlmError> {
            Err(LlmError::Api {
                status: 500,
                message: "Internal error".to_string(),
            })
        }
    }

    /// Always succeeds with a single "OK" text block.
    struct OkDriver;

    #[async_trait]
    impl LlmDriver for OkDriver {
        async fn complete(&self, _req: CompletionRequest) -> Result<CompletionResponse, LlmError> {
            Ok(CompletionResponse {
                content: vec![ContentBlock::Text {
                    text: "OK".to_string(),
                }],
                stop_reason: StopReason::EndTurn,
                tool_calls: vec![],
                usage: TokenUsage {
                    input_tokens: 10,
                    output_tokens: 5,
                },
            })
        }
    }

    /// Always fails with a retryable rate-limit error.
    struct RateLimitDriver;

    #[async_trait]
    impl LlmDriver for RateLimitDriver {
        async fn complete(&self, _req: CompletionRequest) -> Result<CompletionResponse, LlmError> {
            Err(LlmError::RateLimited {
                retry_after_ms: 5000,
            })
        }
    }

    fn test_request() -> CompletionRequest {
        CompletionRequest {
            model: "test".to_string(),
            messages: vec![],
            tools: vec![],
            max_tokens: 100,
            temperature: 0.0,
            system: None,
            thinking: None,
        }
    }

    #[tokio::test]
    async fn test_fallback_primary_succeeds() {
        let driver = FallbackDriver::new(vec![
            Arc::new(OkDriver) as Arc<dyn LlmDriver>,
            Arc::new(FailDriver) as Arc<dyn LlmDriver>,
        ]);
        let result = driver.complete(test_request()).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap().text(), "OK");
    }

    #[tokio::test]
    async fn test_fallback_primary_fails_secondary_succeeds() {
        let driver = FallbackDriver::new(vec![
            Arc::new(FailDriver) as Arc<dyn LlmDriver>,
            Arc::new(OkDriver) as Arc<dyn LlmDriver>,
        ]);
        let result = driver.complete(test_request()).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap().text(), "OK");
    }

    #[tokio::test]
    async fn test_fallback_all_fail() {
        let driver = FallbackDriver::new(vec![
            Arc::new(FailDriver) as Arc<dyn LlmDriver>,
            Arc::new(FailDriver) as Arc<dyn LlmDriver>,
        ]);
        let result = driver.complete(test_request()).await;
        // The last driver's error is what the caller sees.
        assert!(matches!(result, Err(LlmError::Api { status: 500, .. })));
    }

    #[tokio::test]
    async fn test_fallback_empty_chain_errors() {
        let driver = FallbackDriver::new(Vec::new());
        let result = driver.complete(test_request()).await;
        assert!(matches!(result, Err(LlmError::Api { status: 0, .. })));
    }

    #[tokio::test]
    async fn test_rate_limit_bubbles_up() {
        let driver = FallbackDriver::new(vec![
            Arc::new(RateLimitDriver) as Arc<dyn LlmDriver>,
            Arc::new(OkDriver) as Arc<dyn LlmDriver>,
        ]);
        let result = driver.complete(test_request()).await;
        // Rate limit should NOT fall through to next driver
        assert!(matches!(result, Err(LlmError::RateLimited { .. })));
    }

    #[tokio::test]
    async fn test_rate_limit_from_fallback_bubbles_up() {
        let driver = FallbackDriver::new(vec![
            Arc::new(FailDriver) as Arc<dyn LlmDriver>,
            Arc::new(RateLimitDriver) as Arc<dyn LlmDriver>,
            Arc::new(OkDriver) as Arc<dyn LlmDriver>,
        ]);
        let result = driver.complete(test_request()).await;
        // A retryable error from a fallback stops the chain just like one from the primary.
        assert!(matches!(result, Err(LlmError::RateLimited { .. })));
    }

    #[tokio::test]
    async fn test_stream_primary_fails_secondary_succeeds() {
        let driver = FallbackDriver::new(vec![
            Arc::new(FailDriver) as Arc<dyn LlmDriver>,
            Arc::new(OkDriver) as Arc<dyn LlmDriver>,
        ]);
        // Keep the receiver alive in case the driver sends stream events.
        let (tx, _rx) = tokio::sync::mpsc::channel(16);
        let result = driver.stream(test_request(), tx).await;
        assert_eq!(result.unwrap().text(), "OK");
    }

    #[tokio::test]
    async fn test_stream_rate_limit_bubbles_up() {
        let driver = FallbackDriver::new(vec![
            Arc::new(RateLimitDriver) as Arc<dyn LlmDriver>,
            Arc::new(OkDriver) as Arc<dyn LlmDriver>,
        ]);
        let (tx, _rx) = tokio::sync::mpsc::channel(16);
        let result = driver.stream(test_request(), tx).await;
        assert!(matches!(result, Err(LlmError::RateLimited { .. })));
    }
}