feat(ai): 重构 AiState 集成 ProviderRegistry + QuotaService

AiState 新增 provider_registry 和 quota 字段
main.rs 启动时按配置注册 Claude/OpenAI/Ollama Provider
支持多 Provider 并发注册和健康检查
This commit is contained in:
iven
2026-05-05 15:18:26 +08:00
parent 63ff8660fc
commit a16e86bf04
2 changed files with 71 additions and 3 deletions

View File

@@ -4,8 +4,10 @@ use erp_core::events::EventBus;
use erp_core::health_provider::HealthDataProvider;
use sea_orm::DatabaseConnection;
use crate::provider::registry::ProviderRegistry;
use crate::service::analysis::AnalysisService;
use crate::service::prompt::PromptService;
use crate::service::quota::QuotaService;
use crate::service::suggestion::SuggestionService;
use crate::service::usage::UsageService;
@@ -18,4 +20,6 @@ pub struct AiState {
pub usage: Arc<UsageService>,
pub suggestion: Arc<SuggestionService>,
pub health_provider: Arc<dyn HealthDataProvider>,
pub provider_registry: Arc<ProviderRegistry>,
pub quota: Arc<QuotaService>,
}

View File

@@ -459,14 +459,71 @@ async fn main() -> anyhow::Result<()> {
// Pre-build AI state (avoids per-request reconstruction)
let ai_state = {
let mut provider = erp_ai::provider::claude::ClaudeProvider::new(
// 构建多 Provider 注册表
let registry = std::sync::Arc::new(erp_ai::provider::registry::ProviderRegistry::new());
// 始终注册默认 Claude provider兼容旧配置
{
let mut claude = erp_ai::provider::claude::ClaudeProvider::new(
config.ai.api_key.clone(),
);
if let Some(ref base_url) = config.ai.base_url {
claude = claude.with_base_url(base_url.clone());
}
registry.register("claude".to_string(), std::sync::Arc::new(claude));
}
// 注册配置中的额外 Provider
for (name, pcfg) in &config.ai.providers {
if !pcfg.is_enabled {
continue;
}
match pcfg.provider_type.as_str() {
"openai" => {
let api_key = pcfg.api_key_env.as_ref()
.and_then(|env| std::env::var(env).ok())
.unwrap_or_default();
let base_url = pcfg.base_url.clone()
.unwrap_or_else(|| "https://api.openai.com".to_string());
let provider = erp_ai::provider::openai::OpenAIProvider::new(
api_key, base_url, pcfg.default_model.clone(),
);
registry.register(name.clone(), std::sync::Arc::new(provider));
tracing::info!(provider = %name, "已注册 OpenAI 兼容提供商");
}
"ollama" => {
let base_url = pcfg.base_url.clone()
.unwrap_or_else(|| "http://localhost:11434".to_string());
let provider = erp_ai::provider::ollama::OllamaProvider::new(
base_url, pcfg.default_model.clone(),
);
registry.register(name.clone(), std::sync::Arc::new(provider));
tracing::info!(provider = %name, "已注册 Ollama 本地提供商");
}
"claude" => {
// 已作为默认注册,跳过
}
other => {
tracing::warn!(provider = %name, provider_type = %other, "未知的提供商类型,已跳过");
}
}
}
tracing::info!(providers = ?registry.provider_names(), "AI Provider 注册完成");
// 构建默认 provider 用于 AnalysisService保持 Box<dyn AiProvider> 签名)
let mut default_claude = erp_ai::provider::claude::ClaudeProvider::new(
config.ai.api_key.clone(),
);
if let Some(ref base_url) = config.ai.base_url {
provider = provider.with_base_url(base_url.clone());
default_claude = default_claude.with_base_url(base_url.clone());
}
let analysis = std::sync::Arc::new(
erp_ai::service::analysis::AnalysisService::new(Box::new(provider), db.clone()),
erp_ai::service::analysis::AnalysisService::new(
Box::new(default_claude),
db.clone(),
),
);
let prompt = std::sync::Arc::new(erp_ai::service::prompt::PromptService::new(db.clone()));
let usage = std::sync::Arc::new(erp_ai::service::usage::UsageService::new(db.clone()));
@@ -474,6 +531,11 @@ async fn main() -> anyhow::Result<()> {
let health_provider = std::sync::Arc::new(erp_health::HealthDataProviderImpl {
db: db.clone(),
});
let quota = std::sync::Arc::new(erp_ai::service::quota::QuotaService::new(
db.clone(),
config.ai.quota_check_enabled,
));
erp_ai::AiState {
db: db.clone(),
event_bus: event_bus.clone(),
@@ -482,6 +544,8 @@ async fn main() -> anyhow::Result<()> {
usage,
suggestion,
health_provider,
provider_registry: registry,
quota,
}
};