feat(ai): 实现 ProviderRegistry 并发安全多提供商注册与路由
DashMap 支持并发注册,resolve() 按首选→回退→任意可用顺序 实时健康检查,含 4 个单元测试覆盖正常/降级/全不可用场景
This commit is contained in:
@@ -111,6 +111,7 @@ erp-dialysis = { path = "crates/erp-dialysis" }
|
||||
futures = "0.3"
|
||||
tokio-stream = "0.1"
|
||||
async-stream = "0.3"
|
||||
dashmap = "6"
|
||||
|
||||
# Template engine
|
||||
handlebars = "6"
|
||||
|
||||
@@ -21,5 +21,6 @@ utoipa.workspace = true
|
||||
async-trait.workspace = true
|
||||
reqwest.workspace = true
|
||||
handlebars.workspace = true
|
||||
dashmap.workspace = true
|
||||
sha2.workspace = true
|
||||
hex.workspace = true
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
pub mod claude;
|
||||
pub mod registry;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use futures::Stream;
|
||||
|
||||
199
crates/erp-ai/src/provider/registry.rs
Normal file
199
crates/erp-ai/src/provider/registry.rs
Normal file
@@ -0,0 +1,199 @@
|
||||
use crate::error::AiError;
|
||||
use crate::provider::AiProvider;
|
||||
use async_trait::async_trait;
|
||||
use chrono::{DateTime, Utc};
|
||||
use dashmap::DashMap;
|
||||
use serde::Serialize;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub enum ProviderHealth {
|
||||
Healthy { last_check: DateTime<Utc> },
|
||||
Degraded { last_check: DateTime<Utc>, error: String },
|
||||
Unavailable { since: DateTime<Utc>, error: String },
|
||||
}
|
||||
|
||||
impl ProviderHealth {
|
||||
pub fn is_healthy(&self) -> bool {
|
||||
matches!(self, ProviderHealth::Healthy { .. })
|
||||
}
|
||||
}
|
||||
|
||||
struct ProviderEntry {
|
||||
provider: Arc<dyn AiProvider>,
|
||||
health: Arc<RwLock<ProviderHealth>>,
|
||||
}
|
||||
|
||||
pub struct ProviderRegistry {
|
||||
entries: DashMap<String, ProviderEntry>,
|
||||
}
|
||||
|
||||
impl ProviderRegistry {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
entries: DashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn register(&self, name: String, provider: Arc<dyn AiProvider>) {
|
||||
let health = Arc::new(RwLock::new(ProviderHealth::Healthy {
|
||||
last_check: Utc::now(),
|
||||
}));
|
||||
self.entries.insert(name, ProviderEntry { provider, health });
|
||||
}
|
||||
|
||||
pub async fn resolve(&self, preferred: &str) -> crate::error::AiResult<ResolvedProvider> {
|
||||
// 1. 首选 Provider(实时健康检查)
|
||||
if let Some(entry) = self.entries.get(preferred) {
|
||||
if entry.provider.health_check().await.unwrap_or(false) {
|
||||
return Ok(ResolvedProvider {
|
||||
provider_name: preferred.to_string(),
|
||||
provider: entry.provider.clone(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 任何可用 Provider
|
||||
for entry in self.entries.iter() {
|
||||
if entry.value().provider.health_check().await.unwrap_or(false) {
|
||||
return Ok(ResolvedProvider {
|
||||
provider_name: entry.key().to_string(),
|
||||
provider: entry.value().provider.clone(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
Err(AiError::ProviderUnavailable(preferred.to_string()))
|
||||
}
|
||||
|
||||
pub async fn health_check_all(&self) -> DashMap<String, ProviderHealth> {
|
||||
let result = DashMap::new();
|
||||
for entry in self.entries.iter() {
|
||||
let healthy = entry.value().provider.health_check().await.unwrap_or(false);
|
||||
let new_health = if healthy {
|
||||
ProviderHealth::Healthy { last_check: Utc::now() }
|
||||
} else {
|
||||
ProviderHealth::Unavailable {
|
||||
since: Utc::now(),
|
||||
error: "health check failed".to_string(),
|
||||
}
|
||||
};
|
||||
*entry.value().health.write().await = new_health.clone();
|
||||
result.insert(entry.key().to_string(), new_health);
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
pub fn provider_names(&self) -> Vec<String> {
|
||||
self.entries.iter().map(|e| e.key().to_string()).collect()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ResolvedProvider {
|
||||
provider_name: String,
|
||||
provider: Arc<dyn AiProvider>,
|
||||
}
|
||||
|
||||
impl ResolvedProvider {
|
||||
pub fn provider_name(&self) -> &str { &self.provider_name }
|
||||
pub fn provider(&self) -> &dyn AiProvider { self.provider.as_ref() }
|
||||
pub fn into_arc(self) -> Arc<dyn AiProvider> { self.provider }
|
||||
}
|
||||
|
||||
// === 测试桩 ===
|
||||
|
||||
struct MockProvider {
|
||||
name: String,
|
||||
healthy: Arc<std::sync::atomic::AtomicBool>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl AiProvider for MockProvider {
|
||||
async fn stream_generate(
|
||||
&self,
|
||||
_req: crate::dto::GenerateRequest,
|
||||
) -> crate::error::AiResult<std::pin::Pin<Box<dyn futures::Stream<Item = crate::error::AiResult<String>> + Send>>> {
|
||||
// 简单返回一个空流
|
||||
let s = async_stream::stream! { yield Ok("mock".to_string()); };
|
||||
Ok(Box::pin(s))
|
||||
}
|
||||
|
||||
async fn generate(&self, _req: crate::dto::GenerateRequest) -> crate::error::AiResult<crate::dto::GenerateResponse> {
|
||||
Ok(crate::dto::GenerateResponse {
|
||||
content: "mock".to_string(),
|
||||
model: "mock".to_string(),
|
||||
input_tokens: 0,
|
||||
output_tokens: 0,
|
||||
duration_ms: 0,
|
||||
})
|
||||
}
|
||||
|
||||
fn name(&self) -> &str { &self.name }
|
||||
|
||||
async fn health_check(&self) -> crate::error::AiResult<bool> {
|
||||
Ok(self.healthy.load(std::sync::atomic::Ordering::Relaxed))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
|
||||
fn mock(name: &str, healthy: bool) -> (Arc<MockProvider>, Arc<AtomicBool>) {
|
||||
let flag = Arc::new(AtomicBool::new(healthy));
|
||||
let provider = Arc::new(MockProvider {
|
||||
name: name.to_string(),
|
||||
healthy: flag.clone(),
|
||||
});
|
||||
(provider, flag)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn resolve_returns_preferred_when_healthy() {
|
||||
let registry = ProviderRegistry::new();
|
||||
let (claude, _) = mock("claude", true);
|
||||
registry.register("claude".to_string(), claude as Arc<dyn AiProvider>);
|
||||
|
||||
let resolved = registry.resolve("claude").await.unwrap();
|
||||
assert_eq!(resolved.provider_name(), "claude");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn resolve_falls_back_when_unhealthy() {
|
||||
let registry = ProviderRegistry::new();
|
||||
let (claude, claude_flag) = mock("claude", true);
|
||||
let (ollama, _) = mock("ollama", true);
|
||||
registry.register("claude".to_string(), claude as Arc<dyn AiProvider>);
|
||||
registry.register("ollama".to_string(), ollama as Arc<dyn AiProvider>);
|
||||
|
||||
claude_flag.store(false, std::sync::atomic::Ordering::Relaxed);
|
||||
|
||||
let resolved = registry.resolve("claude").await.unwrap();
|
||||
assert_eq!(resolved.provider_name(), "ollama");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn resolve_errors_when_all_unhealthy() {
|
||||
let registry = ProviderRegistry::new();
|
||||
let (claude, _) = mock("claude", false);
|
||||
registry.register("claude".to_string(), claude as Arc<dyn AiProvider>);
|
||||
|
||||
let result = registry.resolve("claude").await;
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn health_check_all() {
|
||||
let registry = ProviderRegistry::new();
|
||||
let (claude, _) = mock("claude", true);
|
||||
let (ollama, _) = mock("ollama", false);
|
||||
registry.register("claude".to_string(), claude as Arc<dyn AiProvider>);
|
||||
registry.register("ollama".to_string(), ollama as Arc<dyn AiProvider>);
|
||||
|
||||
let statuses = registry.health_check_all().await;
|
||||
assert!(statuses.get("claude").unwrap().is_healthy());
|
||||
assert!(!statuses.get("ollama").unwrap().is_healthy());
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user