feat(ai): 实现 ProviderRegistry 并发安全多提供商注册与路由

DashMap 支持并发注册,resolve() 按首选→回退→任意可用顺序
实时健康检查,含 4 个单元测试覆盖正常/降级/全不可用场景
This commit is contained in:
iven
2026-05-05 15:07:19 +08:00
parent 24bb8e7bca
commit 74b1d44068
4 changed files with 202 additions and 0 deletions

View File

@@ -111,6 +111,7 @@ erp-dialysis = { path = "crates/erp-dialysis" }
futures = "0.3"
tokio-stream = "0.1"
async-stream = "0.3"
dashmap = "6"
# Template engine
handlebars = "6"

View File

@@ -21,5 +21,6 @@ utoipa.workspace = true
async-trait.workspace = true
reqwest.workspace = true
handlebars.workspace = true
dashmap.workspace = true
sha2.workspace = true
hex.workspace = true

View File

@@ -1,4 +1,5 @@
pub mod claude;
pub mod registry;
use async_trait::async_trait;
use futures::Stream;

View File

@@ -0,0 +1,199 @@
use crate::error::AiError;
use crate::provider::AiProvider;
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use dashmap::DashMap;
use serde::Serialize;
use std::sync::Arc;
use tokio::sync::RwLock;
#[derive(Debug, Clone, Serialize)]
pub enum ProviderHealth {
Healthy { last_check: DateTime<Utc> },
Degraded { last_check: DateTime<Utc>, error: String },
Unavailable { since: DateTime<Utc>, error: String },
}
impl ProviderHealth {
pub fn is_healthy(&self) -> bool {
matches!(self, ProviderHealth::Healthy { .. })
}
}
struct ProviderEntry {
provider: Arc<dyn AiProvider>,
health: Arc<RwLock<ProviderHealth>>,
}
pub struct ProviderRegistry {
entries: DashMap<String, ProviderEntry>,
}
impl ProviderRegistry {
pub fn new() -> Self {
Self {
entries: DashMap::new(),
}
}
pub fn register(&self, name: String, provider: Arc<dyn AiProvider>) {
let health = Arc::new(RwLock::new(ProviderHealth::Healthy {
last_check: Utc::now(),
}));
self.entries.insert(name, ProviderEntry { provider, health });
}
pub async fn resolve(&self, preferred: &str) -> crate::error::AiResult<ResolvedProvider> {
// 1. 首选 Provider实时健康检查
if let Some(entry) = self.entries.get(preferred) {
if entry.provider.health_check().await.unwrap_or(false) {
return Ok(ResolvedProvider {
provider_name: preferred.to_string(),
provider: entry.provider.clone(),
});
}
}
// 2. 任何可用 Provider
for entry in self.entries.iter() {
if entry.value().provider.health_check().await.unwrap_or(false) {
return Ok(ResolvedProvider {
provider_name: entry.key().to_string(),
provider: entry.value().provider.clone(),
});
}
}
Err(AiError::ProviderUnavailable(preferred.to_string()))
}
pub async fn health_check_all(&self) -> DashMap<String, ProviderHealth> {
let result = DashMap::new();
for entry in self.entries.iter() {
let healthy = entry.value().provider.health_check().await.unwrap_or(false);
let new_health = if healthy {
ProviderHealth::Healthy { last_check: Utc::now() }
} else {
ProviderHealth::Unavailable {
since: Utc::now(),
error: "health check failed".to_string(),
}
};
*entry.value().health.write().await = new_health.clone();
result.insert(entry.key().to_string(), new_health);
}
result
}
pub fn provider_names(&self) -> Vec<String> {
self.entries.iter().map(|e| e.key().to_string()).collect()
}
}
pub struct ResolvedProvider {
provider_name: String,
provider: Arc<dyn AiProvider>,
}
impl ResolvedProvider {
pub fn provider_name(&self) -> &str { &self.provider_name }
pub fn provider(&self) -> &dyn AiProvider { self.provider.as_ref() }
pub fn into_arc(self) -> Arc<dyn AiProvider> { self.provider }
}
// === 测试桩 ===
struct MockProvider {
name: String,
healthy: Arc<std::sync::atomic::AtomicBool>,
}
#[async_trait]
impl AiProvider for MockProvider {
async fn stream_generate(
&self,
_req: crate::dto::GenerateRequest,
) -> crate::error::AiResult<std::pin::Pin<Box<dyn futures::Stream<Item = crate::error::AiResult<String>> + Send>>> {
// 简单返回一个空流
let s = async_stream::stream! { yield Ok("mock".to_string()); };
Ok(Box::pin(s))
}
async fn generate(&self, _req: crate::dto::GenerateRequest) -> crate::error::AiResult<crate::dto::GenerateResponse> {
Ok(crate::dto::GenerateResponse {
content: "mock".to_string(),
model: "mock".to_string(),
input_tokens: 0,
output_tokens: 0,
duration_ms: 0,
})
}
fn name(&self) -> &str { &self.name }
async fn health_check(&self) -> crate::error::AiResult<bool> {
Ok(self.healthy.load(std::sync::atomic::Ordering::Relaxed))
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::atomic::AtomicBool;
fn mock(name: &str, healthy: bool) -> (Arc<MockProvider>, Arc<AtomicBool>) {
let flag = Arc::new(AtomicBool::new(healthy));
let provider = Arc::new(MockProvider {
name: name.to_string(),
healthy: flag.clone(),
});
(provider, flag)
}
#[tokio::test]
async fn resolve_returns_preferred_when_healthy() {
let registry = ProviderRegistry::new();
let (claude, _) = mock("claude", true);
registry.register("claude".to_string(), claude as Arc<dyn AiProvider>);
let resolved = registry.resolve("claude").await.unwrap();
assert_eq!(resolved.provider_name(), "claude");
}
#[tokio::test]
async fn resolve_falls_back_when_unhealthy() {
let registry = ProviderRegistry::new();
let (claude, claude_flag) = mock("claude", true);
let (ollama, _) = mock("ollama", true);
registry.register("claude".to_string(), claude as Arc<dyn AiProvider>);
registry.register("ollama".to_string(), ollama as Arc<dyn AiProvider>);
claude_flag.store(false, std::sync::atomic::Ordering::Relaxed);
let resolved = registry.resolve("claude").await.unwrap();
assert_eq!(resolved.provider_name(), "ollama");
}
#[tokio::test]
async fn resolve_errors_when_all_unhealthy() {
let registry = ProviderRegistry::new();
let (claude, _) = mock("claude", false);
registry.register("claude".to_string(), claude as Arc<dyn AiProvider>);
let result = registry.resolve("claude").await;
assert!(result.is_err());
}
#[tokio::test]
async fn health_check_all() {
let registry = ProviderRegistry::new();
let (claude, _) = mock("claude", true);
let (ollama, _) = mock("ollama", false);
registry.register("claude".to_string(), claude as Arc<dyn AiProvider>);
registry.register("ollama".to_string(), ollama as Arc<dyn AiProvider>);
let statuses = registry.health_check_all().await;
assert!(statuses.get("claude").unwrap().is_healthy());
assert!(!statuses.get("ollama").unwrap().is_healthy());
}
}