From 74b1d4406898e1a0ba7b98be0446c4fb5bca98f7 Mon Sep 17 00:00:00 2001 From: iven Date: Tue, 5 May 2026 15:07:19 +0800 Subject: [PATCH] =?UTF-8?q?feat(ai):=20=E5=AE=9E=E7=8E=B0=20ProviderRegist?= =?UTF-8?q?ry=20=E5=B9=B6=E5=8F=91=E5=AE=89=E5=85=A8=E5=A4=9A=E6=8F=90?= =?UTF-8?q?=E4=BE=9B=E5=95=86=E6=B3=A8=E5=86=8C=E4=B8=8E=E8=B7=AF=E7=94=B1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit DashMap 支持并发注册,resolve() 按首选→回退→任意可用顺序 实时健康检查,含 4 个单元测试覆盖正常/降级/全不可用场景 --- Cargo.toml | 1 + crates/erp-ai/Cargo.toml | 1 + crates/erp-ai/src/provider/mod.rs | 1 + crates/erp-ai/src/provider/registry.rs | 199 +++++++++++++++++++++++++ 4 files changed, 202 insertions(+) create mode 100644 crates/erp-ai/src/provider/registry.rs diff --git a/Cargo.toml b/Cargo.toml index bb0a224..8b42093 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -111,6 +111,7 @@ erp-dialysis = { path = "crates/erp-dialysis" } futures = "0.3" tokio-stream = "0.1" async-stream = "0.3" +dashmap = "6" # Template engine handlebars = "6" diff --git a/crates/erp-ai/Cargo.toml b/crates/erp-ai/Cargo.toml index 413e85e..a087f91 100644 --- a/crates/erp-ai/Cargo.toml +++ b/crates/erp-ai/Cargo.toml @@ -21,5 +21,6 @@ utoipa.workspace = true async-trait.workspace = true reqwest.workspace = true handlebars.workspace = true +dashmap.workspace = true sha2.workspace = true hex.workspace = true diff --git a/crates/erp-ai/src/provider/mod.rs b/crates/erp-ai/src/provider/mod.rs index 1f35e19..c5d7aeb 100644 --- a/crates/erp-ai/src/provider/mod.rs +++ b/crates/erp-ai/src/provider/mod.rs @@ -1,4 +1,5 @@ pub mod claude; +pub mod registry; use async_trait::async_trait; use futures::Stream; diff --git a/crates/erp-ai/src/provider/registry.rs b/crates/erp-ai/src/provider/registry.rs new file mode 100644 index 0000000..c38eeca --- /dev/null +++ b/crates/erp-ai/src/provider/registry.rs @@ -0,0 +1,199 @@ +use crate::error::AiError; +use crate::provider::AiProvider; +use async_trait::async_trait; +use chrono::{DateTime, Utc}; +use dashmap::DashMap; +use serde::Serialize; +use std::sync::Arc; +use tokio::sync::RwLock; + +#[derive(Debug, Clone, Serialize)] +pub enum ProviderHealth { + Healthy { last_check: DateTime }, + Degraded { last_check: DateTime, error: String }, + Unavailable { since: DateTime, error: String }, +} + +impl ProviderHealth { + pub fn is_healthy(&self) -> bool { + matches!(self, ProviderHealth::Healthy { .. }) + } +} + +struct ProviderEntry { + provider: Arc, + health: Arc>, +} + +pub struct ProviderRegistry { + entries: DashMap, +} + +impl ProviderRegistry { + pub fn new() -> Self { + Self { + entries: DashMap::new(), + } + } + + pub fn register(&self, name: String, provider: Arc) { + let health = Arc::new(RwLock::new(ProviderHealth::Healthy { + last_check: Utc::now(), + })); + self.entries.insert(name, ProviderEntry { provider, health }); + } + + pub async fn resolve(&self, preferred: &str) -> crate::error::AiResult { + // 1. 首选 Provider(实时健康检查) + if let Some(entry) = self.entries.get(preferred) { + if entry.provider.health_check().await.unwrap_or(false) { + return Ok(ResolvedProvider { + provider_name: preferred.to_string(), + provider: entry.provider.clone(), + }); + } + } + + // 2. 任何可用 Provider + for entry in self.entries.iter() { + if entry.value().provider.health_check().await.unwrap_or(false) { + return Ok(ResolvedProvider { + provider_name: entry.key().to_string(), + provider: entry.value().provider.clone(), + }); + } + } + + Err(AiError::ProviderUnavailable(preferred.to_string())) + } + + pub async fn health_check_all(&self) -> DashMap { + let result = DashMap::new(); + for entry in self.entries.iter() { + let healthy = entry.value().provider.health_check().await.unwrap_or(false); + let new_health = if healthy { + ProviderHealth::Healthy { last_check: Utc::now() } + } else { + ProviderHealth::Unavailable { + since: Utc::now(), + error: "health check failed".to_string(), + } + }; + *entry.value().health.write().await = new_health.clone(); + result.insert(entry.key().to_string(), new_health); + } + result + } + + pub fn provider_names(&self) -> Vec { + self.entries.iter().map(|e| e.key().to_string()).collect() + } +} + +pub struct ResolvedProvider { + provider_name: String, + provider: Arc, +} + +impl ResolvedProvider { + pub fn provider_name(&self) -> &str { &self.provider_name } + pub fn provider(&self) -> &dyn AiProvider { self.provider.as_ref() } + pub fn into_arc(self) -> Arc { self.provider } +} + +// === 测试桩 === + +struct MockProvider { + name: String, + healthy: Arc, +} + +#[async_trait] +impl AiProvider for MockProvider { + async fn stream_generate( + &self, + _req: crate::dto::GenerateRequest, + ) -> crate::error::AiResult> + Send>>> { + // 简单返回一个空流 + let s = async_stream::stream! { yield Ok("mock".to_string()); }; + Ok(Box::pin(s)) + } + + async fn generate(&self, _req: crate::dto::GenerateRequest) -> crate::error::AiResult { + Ok(crate::dto::GenerateResponse { + content: "mock".to_string(), + model: "mock".to_string(), + input_tokens: 0, + output_tokens: 0, + duration_ms: 0, + }) + } + + fn name(&self) -> &str { &self.name } + + async fn health_check(&self) -> crate::error::AiResult { + Ok(self.healthy.load(std::sync::atomic::Ordering::Relaxed)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::atomic::AtomicBool; + + fn mock(name: &str, healthy: bool) -> (Arc, Arc) { + let flag = Arc::new(AtomicBool::new(healthy)); + let provider = Arc::new(MockProvider { + name: name.to_string(), + healthy: flag.clone(), + }); + (provider, flag) + } + + #[tokio::test] + async fn resolve_returns_preferred_when_healthy() { + let registry = ProviderRegistry::new(); + let (claude, _) = mock("claude", true); + registry.register("claude".to_string(), claude as Arc); + + let resolved = registry.resolve("claude").await.unwrap(); + assert_eq!(resolved.provider_name(), "claude"); + } + + #[tokio::test] + async fn resolve_falls_back_when_unhealthy() { + let registry = ProviderRegistry::new(); + let (claude, claude_flag) = mock("claude", true); + let (ollama, _) = mock("ollama", true); + registry.register("claude".to_string(), claude as Arc); + registry.register("ollama".to_string(), ollama as Arc); + + claude_flag.store(false, std::sync::atomic::Ordering::Relaxed); + + let resolved = registry.resolve("claude").await.unwrap(); + assert_eq!(resolved.provider_name(), "ollama"); + } + + #[tokio::test] + async fn resolve_errors_when_all_unhealthy() { + let registry = ProviderRegistry::new(); + let (claude, _) = mock("claude", false); + registry.register("claude".to_string(), claude as Arc); + + let result = registry.resolve("claude").await; + assert!(result.is_err()); + } + + #[tokio::test] + async fn health_check_all() { + let registry = ProviderRegistry::new(); + let (claude, _) = mock("claude", true); + let (ollama, _) = mock("ollama", false); + registry.register("claude".to_string(), claude as Arc); + registry.register("ollama".to_string(), ollama as Arc); + + let statuses = registry.health_check_all().await; + assert!(statuses.get("claude").unwrap().is_healthy()); + assert!(!statuses.get("ollama").unwrap().is_healthy()); + } +}