feat(ai): 重构 AiState 集成 ProviderRegistry + QuotaService
AiState 新增 provider_registry 和 quota 字段 main.rs 启动时按配置注册 Claude/OpenAI/Ollama Provider 支持多 Provider 并发注册和健康检查
This commit is contained in:
@@ -4,8 +4,10 @@ use erp_core::events::EventBus;
|
||||
use erp_core::health_provider::HealthDataProvider;
|
||||
use sea_orm::DatabaseConnection;
|
||||
|
||||
use crate::provider::registry::ProviderRegistry;
|
||||
use crate::service::analysis::AnalysisService;
|
||||
use crate::service::prompt::PromptService;
|
||||
use crate::service::quota::QuotaService;
|
||||
use crate::service::suggestion::SuggestionService;
|
||||
use crate::service::usage::UsageService;
|
||||
|
||||
@@ -18,4 +20,6 @@ pub struct AiState {
|
||||
pub usage: Arc<UsageService>,
|
||||
pub suggestion: Arc<SuggestionService>,
|
||||
pub health_provider: Arc<dyn HealthDataProvider>,
|
||||
pub provider_registry: Arc<ProviderRegistry>,
|
||||
pub quota: Arc<QuotaService>,
|
||||
}
|
||||
|
||||
@@ -459,14 +459,71 @@ async fn main() -> anyhow::Result<()> {
|
||||
|
||||
// Pre-build AI state (avoids per-request reconstruction)
|
||||
let ai_state = {
|
||||
let mut provider = erp_ai::provider::claude::ClaudeProvider::new(
|
||||
// 构建多 Provider 注册表
|
||||
let registry = std::sync::Arc::new(erp_ai::provider::registry::ProviderRegistry::new());
|
||||
|
||||
// 始终注册默认 Claude provider(兼容旧配置)
|
||||
{
|
||||
let mut claude = erp_ai::provider::claude::ClaudeProvider::new(
|
||||
config.ai.api_key.clone(),
|
||||
);
|
||||
if let Some(ref base_url) = config.ai.base_url {
|
||||
claude = claude.with_base_url(base_url.clone());
|
||||
}
|
||||
registry.register("claude".to_string(), std::sync::Arc::new(claude));
|
||||
}
|
||||
|
||||
// 注册配置中的额外 Provider
|
||||
for (name, pcfg) in &config.ai.providers {
|
||||
if !pcfg.is_enabled {
|
||||
continue;
|
||||
}
|
||||
match pcfg.provider_type.as_str() {
|
||||
"openai" => {
|
||||
let api_key = pcfg.api_key_env.as_ref()
|
||||
.and_then(|env| std::env::var(env).ok())
|
||||
.unwrap_or_default();
|
||||
let base_url = pcfg.base_url.clone()
|
||||
.unwrap_or_else(|| "https://api.openai.com".to_string());
|
||||
let provider = erp_ai::provider::openai::OpenAIProvider::new(
|
||||
api_key, base_url, pcfg.default_model.clone(),
|
||||
);
|
||||
registry.register(name.clone(), std::sync::Arc::new(provider));
|
||||
tracing::info!(provider = %name, "已注册 OpenAI 兼容提供商");
|
||||
}
|
||||
"ollama" => {
|
||||
let base_url = pcfg.base_url.clone()
|
||||
.unwrap_or_else(|| "http://localhost:11434".to_string());
|
||||
let provider = erp_ai::provider::ollama::OllamaProvider::new(
|
||||
base_url, pcfg.default_model.clone(),
|
||||
);
|
||||
registry.register(name.clone(), std::sync::Arc::new(provider));
|
||||
tracing::info!(provider = %name, "已注册 Ollama 本地提供商");
|
||||
}
|
||||
"claude" => {
|
||||
// 已作为默认注册,跳过
|
||||
}
|
||||
other => {
|
||||
tracing::warn!(provider = %name, provider_type = %other, "未知的提供商类型,已跳过");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tracing::info!(providers = ?registry.provider_names(), "AI Provider 注册完成");
|
||||
|
||||
// 构建默认 provider 用于 AnalysisService(保持 Box<dyn AiProvider> 签名)
|
||||
let mut default_claude = erp_ai::provider::claude::ClaudeProvider::new(
|
||||
config.ai.api_key.clone(),
|
||||
);
|
||||
if let Some(ref base_url) = config.ai.base_url {
|
||||
provider = provider.with_base_url(base_url.clone());
|
||||
default_claude = default_claude.with_base_url(base_url.clone());
|
||||
}
|
||||
|
||||
let analysis = std::sync::Arc::new(
|
||||
erp_ai::service::analysis::AnalysisService::new(Box::new(provider), db.clone()),
|
||||
erp_ai::service::analysis::AnalysisService::new(
|
||||
Box::new(default_claude),
|
||||
db.clone(),
|
||||
),
|
||||
);
|
||||
let prompt = std::sync::Arc::new(erp_ai::service::prompt::PromptService::new(db.clone()));
|
||||
let usage = std::sync::Arc::new(erp_ai::service::usage::UsageService::new(db.clone()));
|
||||
@@ -474,6 +531,11 @@ async fn main() -> anyhow::Result<()> {
|
||||
let health_provider = std::sync::Arc::new(erp_health::HealthDataProviderImpl {
|
||||
db: db.clone(),
|
||||
});
|
||||
let quota = std::sync::Arc::new(erp_ai::service::quota::QuotaService::new(
|
||||
db.clone(),
|
||||
config.ai.quota_check_enabled,
|
||||
));
|
||||
|
||||
erp_ai::AiState {
|
||||
db: db.clone(),
|
||||
event_bus: event_bus.clone(),
|
||||
@@ -482,6 +544,8 @@ async fn main() -> anyhow::Result<()> {
|
||||
usage,
|
||||
suggestion,
|
||||
health_provider,
|
||||
provider_registry: registry,
|
||||
quota,
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user