From a16e86bf044f89017d181a1007ebb02e7bc6c823 Mon Sep 17 00:00:00 2001 From: iven Date: Tue, 5 May 2026 15:18:26 +0800 Subject: [PATCH] =?UTF-8?q?feat(ai):=20=E9=87=8D=E6=9E=84=20AiState=20?= =?UTF-8?q?=E9=9B=86=E6=88=90=20ProviderRegistry=20+=20QuotaService?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit AiState 新增 provider_registry 和 quota 字段 main.rs 启动时按配置注册 Claude/OpenAI/Ollama Provider 支持多 Provider 并发注册和健康检查 --- crates/erp-ai/src/state.rs | 4 ++ crates/erp-server/src/main.rs | 70 +++++++++++++++++++++++++++++++++-- 2 files changed, 71 insertions(+), 3 deletions(-) diff --git a/crates/erp-ai/src/state.rs b/crates/erp-ai/src/state.rs index a1c7646..78c7a9c 100644 --- a/crates/erp-ai/src/state.rs +++ b/crates/erp-ai/src/state.rs @@ -4,8 +4,10 @@ use erp_core::events::EventBus; use erp_core::health_provider::HealthDataProvider; use sea_orm::DatabaseConnection; +use crate::provider::registry::ProviderRegistry; use crate::service::analysis::AnalysisService; use crate::service::prompt::PromptService; +use crate::service::quota::QuotaService; use crate::service::suggestion::SuggestionService; use crate::service::usage::UsageService; @@ -18,4 +20,6 @@ pub struct AiState { pub usage: Arc, pub suggestion: Arc, pub health_provider: Arc, + pub provider_registry: Arc, + pub quota: Arc, } diff --git a/crates/erp-server/src/main.rs b/crates/erp-server/src/main.rs index 80d6667..2fc481b 100644 --- a/crates/erp-server/src/main.rs +++ b/crates/erp-server/src/main.rs @@ -459,14 +459,71 @@ async fn main() -> anyhow::Result<()> { // Pre-build AI state (avoids per-request reconstruction) let ai_state = { - let mut provider = erp_ai::provider::claude::ClaudeProvider::new( + // 构建多 Provider 注册表 + let registry = std::sync::Arc::new(erp_ai::provider::registry::ProviderRegistry::new()); + + // 始终注册默认 Claude provider(兼容旧配置) + { + let mut claude = erp_ai::provider::claude::ClaudeProvider::new( + config.ai.api_key.clone(), + ); + if let Some(ref base_url) = config.ai.base_url { + claude = claude.with_base_url(base_url.clone()); + } + registry.register("claude".to_string(), std::sync::Arc::new(claude)); + } + + // 注册配置中的额外 Provider + for (name, pcfg) in &config.ai.providers { + if !pcfg.is_enabled { + continue; + } + match pcfg.provider_type.as_str() { + "openai" => { + let api_key = pcfg.api_key_env.as_ref() + .and_then(|env| std::env::var(env).ok()) + .unwrap_or_default(); + let base_url = pcfg.base_url.clone() + .unwrap_or_else(|| "https://api.openai.com".to_string()); + let provider = erp_ai::provider::openai::OpenAIProvider::new( + api_key, base_url, pcfg.default_model.clone(), + ); + registry.register(name.clone(), std::sync::Arc::new(provider)); + tracing::info!(provider = %name, "已注册 OpenAI 兼容提供商"); + } + "ollama" => { + let base_url = pcfg.base_url.clone() + .unwrap_or_else(|| "http://localhost:11434".to_string()); + let provider = erp_ai::provider::ollama::OllamaProvider::new( + base_url, pcfg.default_model.clone(), + ); + registry.register(name.clone(), std::sync::Arc::new(provider)); + tracing::info!(provider = %name, "已注册 Ollama 本地提供商"); + } + "claude" => { + // 已作为默认注册,跳过 + } + other => { + tracing::warn!(provider = %name, provider_type = %other, "未知的提供商类型,已跳过"); + } + } + } + + tracing::info!(providers = ?registry.provider_names(), "AI Provider 注册完成"); + + // 构建默认 provider 用于 AnalysisService(保持 Box 签名) + let mut default_claude = erp_ai::provider::claude::ClaudeProvider::new( config.ai.api_key.clone(), ); if let Some(ref base_url) = config.ai.base_url { - provider = provider.with_base_url(base_url.clone()); + default_claude = default_claude.with_base_url(base_url.clone()); } + let analysis = std::sync::Arc::new( - erp_ai::service::analysis::AnalysisService::new(Box::new(provider), db.clone()), + erp_ai::service::analysis::AnalysisService::new( + Box::new(default_claude), + db.clone(), + ), ); let prompt = std::sync::Arc::new(erp_ai::service::prompt::PromptService::new(db.clone())); let usage = std::sync::Arc::new(erp_ai::service::usage::UsageService::new(db.clone())); @@ -474,6 +531,11 @@ async fn main() -> anyhow::Result<()> { let health_provider = std::sync::Arc::new(erp_health::HealthDataProviderImpl { db: db.clone(), }); + let quota = std::sync::Arc::new(erp_ai::service::quota::QuotaService::new( + db.clone(), + config.ai.quota_check_enabled, + )); + erp_ai::AiState { db: db.clone(), event_bus: event_bus.clone(), @@ -482,6 +544,8 @@ async fn main() -> anyhow::Result<()> { usage, suggestion, health_provider, + provider_registry: registry, + quota, } };