feat(ai): 重构 AiState 集成 ProviderRegistry + QuotaService
AiState 新增 provider_registry 和 quota 字段 main.rs 启动时按配置注册 Claude/OpenAI/Ollama Provider 支持多 Provider 并发注册和健康检查
This commit is contained in:
@@ -4,8 +4,10 @@ use erp_core::events::EventBus;
|
|||||||
use erp_core::health_provider::HealthDataProvider;
|
use erp_core::health_provider::HealthDataProvider;
|
||||||
use sea_orm::DatabaseConnection;
|
use sea_orm::DatabaseConnection;
|
||||||
|
|
||||||
|
use crate::provider::registry::ProviderRegistry;
|
||||||
use crate::service::analysis::AnalysisService;
|
use crate::service::analysis::AnalysisService;
|
||||||
use crate::service::prompt::PromptService;
|
use crate::service::prompt::PromptService;
|
||||||
|
use crate::service::quota::QuotaService;
|
||||||
use crate::service::suggestion::SuggestionService;
|
use crate::service::suggestion::SuggestionService;
|
||||||
use crate::service::usage::UsageService;
|
use crate::service::usage::UsageService;
|
||||||
|
|
||||||
@@ -18,4 +20,6 @@ pub struct AiState {
|
|||||||
pub usage: Arc<UsageService>,
|
pub usage: Arc<UsageService>,
|
||||||
pub suggestion: Arc<SuggestionService>,
|
pub suggestion: Arc<SuggestionService>,
|
||||||
pub health_provider: Arc<dyn HealthDataProvider>,
|
pub health_provider: Arc<dyn HealthDataProvider>,
|
||||||
|
pub provider_registry: Arc<ProviderRegistry>,
|
||||||
|
pub quota: Arc<QuotaService>,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -459,14 +459,71 @@ async fn main() -> anyhow::Result<()> {
|
|||||||
|
|
||||||
// Pre-build AI state (avoids per-request reconstruction)
|
// Pre-build AI state (avoids per-request reconstruction)
|
||||||
let ai_state = {
|
let ai_state = {
|
||||||
let mut provider = erp_ai::provider::claude::ClaudeProvider::new(
|
// 构建多 Provider 注册表
|
||||||
|
let registry = std::sync::Arc::new(erp_ai::provider::registry::ProviderRegistry::new());
|
||||||
|
|
||||||
|
// 始终注册默认 Claude provider(兼容旧配置)
|
||||||
|
{
|
||||||
|
let mut claude = erp_ai::provider::claude::ClaudeProvider::new(
|
||||||
|
config.ai.api_key.clone(),
|
||||||
|
);
|
||||||
|
if let Some(ref base_url) = config.ai.base_url {
|
||||||
|
claude = claude.with_base_url(base_url.clone());
|
||||||
|
}
|
||||||
|
registry.register("claude".to_string(), std::sync::Arc::new(claude));
|
||||||
|
}
|
||||||
|
|
||||||
|
// 注册配置中的额外 Provider
|
||||||
|
for (name, pcfg) in &config.ai.providers {
|
||||||
|
if !pcfg.is_enabled {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
match pcfg.provider_type.as_str() {
|
||||||
|
"openai" => {
|
||||||
|
let api_key = pcfg.api_key_env.as_ref()
|
||||||
|
.and_then(|env| std::env::var(env).ok())
|
||||||
|
.unwrap_or_default();
|
||||||
|
let base_url = pcfg.base_url.clone()
|
||||||
|
.unwrap_or_else(|| "https://api.openai.com".to_string());
|
||||||
|
let provider = erp_ai::provider::openai::OpenAIProvider::new(
|
||||||
|
api_key, base_url, pcfg.default_model.clone(),
|
||||||
|
);
|
||||||
|
registry.register(name.clone(), std::sync::Arc::new(provider));
|
||||||
|
tracing::info!(provider = %name, "已注册 OpenAI 兼容提供商");
|
||||||
|
}
|
||||||
|
"ollama" => {
|
||||||
|
let base_url = pcfg.base_url.clone()
|
||||||
|
.unwrap_or_else(|| "http://localhost:11434".to_string());
|
||||||
|
let provider = erp_ai::provider::ollama::OllamaProvider::new(
|
||||||
|
base_url, pcfg.default_model.clone(),
|
||||||
|
);
|
||||||
|
registry.register(name.clone(), std::sync::Arc::new(provider));
|
||||||
|
tracing::info!(provider = %name, "已注册 Ollama 本地提供商");
|
||||||
|
}
|
||||||
|
"claude" => {
|
||||||
|
// 已作为默认注册,跳过
|
||||||
|
}
|
||||||
|
other => {
|
||||||
|
tracing::warn!(provider = %name, provider_type = %other, "未知的提供商类型,已跳过");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tracing::info!(providers = ?registry.provider_names(), "AI Provider 注册完成");
|
||||||
|
|
||||||
|
// 构建默认 provider 用于 AnalysisService(保持 Box<dyn AiProvider> 签名)
|
||||||
|
let mut default_claude = erp_ai::provider::claude::ClaudeProvider::new(
|
||||||
config.ai.api_key.clone(),
|
config.ai.api_key.clone(),
|
||||||
);
|
);
|
||||||
if let Some(ref base_url) = config.ai.base_url {
|
if let Some(ref base_url) = config.ai.base_url {
|
||||||
provider = provider.with_base_url(base_url.clone());
|
default_claude = default_claude.with_base_url(base_url.clone());
|
||||||
}
|
}
|
||||||
|
|
||||||
let analysis = std::sync::Arc::new(
|
let analysis = std::sync::Arc::new(
|
||||||
erp_ai::service::analysis::AnalysisService::new(Box::new(provider), db.clone()),
|
erp_ai::service::analysis::AnalysisService::new(
|
||||||
|
Box::new(default_claude),
|
||||||
|
db.clone(),
|
||||||
|
),
|
||||||
);
|
);
|
||||||
let prompt = std::sync::Arc::new(erp_ai::service::prompt::PromptService::new(db.clone()));
|
let prompt = std::sync::Arc::new(erp_ai::service::prompt::PromptService::new(db.clone()));
|
||||||
let usage = std::sync::Arc::new(erp_ai::service::usage::UsageService::new(db.clone()));
|
let usage = std::sync::Arc::new(erp_ai::service::usage::UsageService::new(db.clone()));
|
||||||
@@ -474,6 +531,11 @@ async fn main() -> anyhow::Result<()> {
|
|||||||
let health_provider = std::sync::Arc::new(erp_health::HealthDataProviderImpl {
|
let health_provider = std::sync::Arc::new(erp_health::HealthDataProviderImpl {
|
||||||
db: db.clone(),
|
db: db.clone(),
|
||||||
});
|
});
|
||||||
|
let quota = std::sync::Arc::new(erp_ai::service::quota::QuotaService::new(
|
||||||
|
db.clone(),
|
||||||
|
config.ai.quota_check_enabled,
|
||||||
|
));
|
||||||
|
|
||||||
erp_ai::AiState {
|
erp_ai::AiState {
|
||||||
db: db.clone(),
|
db: db.clone(),
|
||||||
event_bus: event_bus.clone(),
|
event_bus: event_bus.clone(),
|
||||||
@@ -482,6 +544,8 @@ async fn main() -> anyhow::Result<()> {
|
|||||||
usage,
|
usage,
|
||||||
suggestion,
|
suggestion,
|
||||||
health_provider,
|
health_provider,
|
||||||
|
provider_registry: registry,
|
||||||
|
quota,
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user