Files
hms/crates/erp-server/src/main.rs
iven d2512ca9db
Some checks failed
CI / rust-check (push) Has been cancelled
CI / rust-test (push) Has been cancelled
CI / frontend-build (push) Has been cancelled
CI / security-audit (push) Has been cancelled
feat(ai): 集成知识库到 AnalysisService — system_prompt 自动注入临床规则
Phase 3 Task 23: AnalysisService 新增可选 knowledge_source,
stream_analyze 前自动查询 L1/L2/L3 知识并注入 system_prompt
2026-05-05 16:01:52 +08:00

902 lines
34 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
mod config;
mod db;
mod dialysis_workflow;
mod handlers;
mod middleware;
mod outbox;
mod state;
mod tasks;
/// Root OpenAPI document: carries only the API-level metadata (title, version,
/// description) via the utoipa `OpenApi` derive.
///
/// Module-specific paths and schemas are declared on the sibling `*ApiDoc`
/// structs in this file. NOTE(review): presumably those are merged into this
/// document by the OpenAPI handler — the merge itself is not visible here.
#[derive(OpenApi)]
#[openapi(info(
title = "ERP Platform API",
version = "0.1.0",
description = "ERP 平台底座 REST API 文档"
))]
struct ApiDoc;
/// OpenAPI path and schema collection for the auth module.
///
/// Lists the `erp_auth` handlers (authentication, users, roles/permissions)
/// and the request/response DTOs that should appear in the generated spec.
#[derive(OpenApi)]
#[openapi(
paths(
erp_auth::handler::auth_handler::login,
erp_auth::handler::auth_handler::refresh,
erp_auth::handler::auth_handler::logout,
erp_auth::handler::auth_handler::change_password,
erp_auth::handler::user_handler::list_users,
erp_auth::handler::user_handler::create_user,
erp_auth::handler::user_handler::get_user,
erp_auth::handler::user_handler::update_user,
erp_auth::handler::user_handler::delete_user,
erp_auth::handler::user_handler::assign_roles,
erp_auth::handler::role_handler::list_roles,
erp_auth::handler::role_handler::create_role,
erp_auth::handler::role_handler::get_role,
erp_auth::handler::role_handler::update_role,
erp_auth::handler::role_handler::delete_role,
erp_auth::handler::role_handler::assign_permissions,
erp_auth::handler::role_handler::get_role_permissions,
erp_auth::handler::role_handler::list_permissions,
),
components(
schemas(
erp_auth::dto::LoginReq,
erp_auth::dto::LoginResp,
erp_auth::dto::RefreshReq,
erp_auth::dto::UserResp,
erp_auth::dto::CreateUserReq,
erp_auth::dto::UpdateUserReq,
erp_auth::dto::RoleResp,
erp_auth::dto::CreateRoleReq,
erp_auth::dto::UpdateRoleReq,
erp_auth::dto::PermissionResp,
erp_auth::dto::AssignPermissionsReq,
erp_auth::dto::ChangePasswordReq,
)
)
)]
struct AuthApiDoc;
/// OpenAPI path and schema collection for the config module.
///
/// Lists the `erp_config` handlers (dictionaries, menus, numbering rules,
/// theme, languages, settings) and the DTOs that should appear in the
/// generated spec.
#[derive(OpenApi)]
#[openapi(
paths(
erp_config::handler::dictionary_handler::list_dictionaries,
erp_config::handler::dictionary_handler::create_dictionary,
erp_config::handler::dictionary_handler::update_dictionary,
erp_config::handler::dictionary_handler::delete_dictionary,
erp_config::handler::dictionary_handler::list_items_by_code,
erp_config::handler::dictionary_handler::create_item,
erp_config::handler::dictionary_handler::update_item,
erp_config::handler::menu_handler::get_menus,
erp_config::handler::menu_handler::create_menu,
erp_config::handler::menu_handler::update_menu,
erp_config::handler::menu_handler::delete_menu,
erp_config::handler::numbering_handler::list_numbering_rules,
erp_config::handler::numbering_handler::create_numbering_rule,
erp_config::handler::numbering_handler::update_numbering_rule,
erp_config::handler::numbering_handler::generate_number,
erp_config::handler::numbering_handler::delete_numbering_rule,
erp_config::handler::theme_handler::get_theme,
erp_config::handler::theme_handler::update_theme,
erp_config::handler::language_handler::list_languages,
erp_config::handler::language_handler::update_language,
erp_config::handler::setting_handler::get_setting,
erp_config::handler::setting_handler::update_setting,
erp_config::handler::setting_handler::delete_setting,
),
components(
schemas(
erp_config::dto::DictionaryResp,
erp_config::dto::CreateDictionaryReq,
erp_config::dto::UpdateDictionaryReq,
erp_config::dto::DictionaryItemResp,
erp_config::dto::CreateDictionaryItemReq,
erp_config::dto::UpdateDictionaryItemReq,
erp_config::dto::MenuResp,
erp_config::dto::CreateMenuReq,
erp_config::dto::UpdateMenuReq,
erp_config::dto::NumberingRuleResp,
erp_config::dto::CreateNumberingRuleReq,
erp_config::dto::UpdateNumberingRuleReq,
erp_config::dto::ThemeResp,
)
)
)]
struct ConfigApiDoc;
/// OpenAPI path and schema collection for the workflow module.
///
/// Lists the `erp_workflow` handlers (process definitions, process instances,
/// tasks) and the DTOs that should appear in the generated spec.
#[derive(OpenApi)]
#[openapi(
paths(
erp_workflow::handler::definition_handler::list_definitions,
erp_workflow::handler::definition_handler::create_definition,
erp_workflow::handler::definition_handler::get_definition,
erp_workflow::handler::definition_handler::update_definition,
erp_workflow::handler::definition_handler::publish_definition,
erp_workflow::handler::instance_handler::start_instance,
erp_workflow::handler::instance_handler::list_instances,
erp_workflow::handler::instance_handler::get_instance,
erp_workflow::handler::instance_handler::suspend_instance,
erp_workflow::handler::instance_handler::terminate_instance,
erp_workflow::handler::instance_handler::resume_instance,
erp_workflow::handler::task_handler::list_pending_tasks,
erp_workflow::handler::task_handler::list_completed_tasks,
erp_workflow::handler::task_handler::complete_task,
erp_workflow::handler::task_handler::delegate_task,
),
components(
schemas(
erp_workflow::dto::ProcessDefinitionResp,
erp_workflow::dto::CreateProcessDefinitionReq,
erp_workflow::dto::UpdateProcessDefinitionReq,
erp_workflow::dto::ProcessInstanceResp,
erp_workflow::dto::StartInstanceReq,
erp_workflow::dto::TaskResp,
erp_workflow::dto::CompleteTaskReq,
erp_workflow::dto::DelegateTaskReq,
)
)
)]
struct WorkflowApiDoc;
/// OpenAPI path and schema collection for the message module.
///
/// Lists the `erp_message` handlers (messages, templates, subscriptions) and
/// the DTOs that should appear in the generated spec.
#[derive(OpenApi)]
#[openapi(
paths(
erp_message::handler::message_handler::list_messages,
erp_message::handler::message_handler::unread_count,
erp_message::handler::message_handler::send_message,
erp_message::handler::message_handler::mark_read,
erp_message::handler::message_handler::mark_all_read,
erp_message::handler::message_handler::delete_message,
erp_message::handler::template_handler::list_templates,
erp_message::handler::template_handler::create_template,
erp_message::handler::subscription_handler::update_subscription,
),
components(
schemas(
erp_message::dto::MessageResp,
erp_message::dto::SendMessageReq,
erp_message::dto::MessageQuery,
erp_message::dto::UnreadCountResp,
erp_message::dto::MessageTemplateResp,
erp_message::dto::CreateTemplateReq,
erp_message::dto::MessageSubscriptionResp,
erp_message::dto::UpdateSubscriptionReq,
)
)
)]
struct MessageApiDoc;
use axum::Router;
use axum::middleware as axum_middleware;
use config::AppConfig;
use erp_auth::middleware::jwt_auth_middleware_fn;
use state::AppState;
use tower_http::services::ServeDir;
use tracing_subscriber::EnvFilter;
use utoipa::OpenApi;
use erp_core::events::EventBus;
use erp_core::module::{ErpModule, ModuleContext, ModuleRegistry};
use erp_server_migration::MigratorTrait;
use sea_orm::{ConnectionTrait, FromQueryResult};
/// Server entry point.
///
/// Startup sequence, in order:
/// 1. Load configuration and install the tracing subscriber.
/// 2. Refuse to start when secrets / connection URLs still hold placeholder values.
/// 3. Connect to the database, run migrations and seed the default tenant if none exists.
/// 4. Create the Redis client, event bus, business modules, plugin engine and
///    background listeners/tasks.
/// 5. Build the AI state and the shared `AppState`.
/// 6. Assemble the axum router (unthrottled, public, protected, FHIR, gateway, uploads),
///    start the metrics server and serve until a shutdown signal arrives.
/// 7. Shut all modules down.
///
/// # Errors
/// Returns an error when configuration loading, the database connection, migrations,
/// seeding of the default tenant, module startup, permission sync, socket binding or
/// serving fails. Placeholder secrets terminate the process with exit code 1.
///
/// # Panics
/// Panics when the PII KEK is not valid hex, or (release builds only) when it is
/// still the placeholder value.
#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // Load config
    let config = AppConfig::load()?;

    // Initialize tracing BEFORE the security checks below: `tracing` events emitted
    // while no subscriber is installed are discarded, which would make the process
    // exit with code 1 without ever telling the operator why.
    tracing_subscriber::fmt()
        .with_env_filter(
            EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(&config.log.level)),
        )
        .json()
        .init();

    // ── Security checks: refuse to start with default secrets ──────────────
    if config.jwt.secret == "__MUST_SET_VIA_ENV__" || config.jwt.secret == "change-me-in-production" {
        tracing::error!(
            "JWT 密钥为默认值,拒绝启动。请设置环境变量 ERP__JWT__SECRET"
        );
        std::process::exit(1);
    }
    if config.database.url == "__MUST_SET_VIA_ENV__" {
        tracing::error!(
            "数据库 URL 为默认占位值,拒绝启动。请设置环境变量 ERP__DATABASE__URL"
        );
        std::process::exit(1);
    }
    if config.redis.url == "__MUST_SET_VIA_ENV__" {
        tracing::error!(
            "Redis URL 为默认占位值,拒绝启动。请设置环境变量 ERP__REDIS__URL"
        );
        std::process::exit(1);
    }
    if !config.wechat.dev_mode && (config.wechat.appid == "__MUST_SET_VIA_ENV__" || config.wechat.secret == "__MUST_SET_VIA_ENV__") {
        tracing::error!(
            "微信凭据为默认占位值,拒绝启动。请设置环境变量 ERP__WECHAT__APPID 和 ERP__WECHAT__SECRET"
        );
        std::process::exit(1);
    }
    if config.health.aes_key == "__MUST_SET_VIA_ENV__" || config.health.hmac_key == "__MUST_SET_VIA_ENV__" {
        // Note: the health keys have been superseded by the unified KEK (ERP__CRYPTO__KEK);
        // this check is kept for compatibility only, hence a warning rather than an exit.
        tracing::warn!(
            "ERP__HEALTH__AES_KEY/HMAC_KEY 未设置(已迁移到 ERP__CRYPTO__KEK 统一密钥体系)"
        );
    }

    tracing::info!(
        version = env!("CARGO_PKG_VERSION"),
        "ERP Server starting..."
    );

    // Connect to database
    let db = db::connect(&config.database).await?;

    // Run migrations
    erp_server_migration::Migrator::up(&db, None).await?;
    tracing::info!("Database migrations applied");

    // Seed default tenant and auth data if not present, and resolve the actual tenant ID
    let default_tenant_id = {
        #[derive(sea_orm::FromQueryResult)]
        struct TenantId {
            id: uuid::Uuid,
        }

        let existing = TenantId::find_by_statement(sea_orm::Statement::from_string(
            sea_orm::DatabaseBackend::Postgres,
            "SELECT id FROM tenant WHERE deleted_at IS NULL LIMIT 1".to_string(),
        ))
        .one(&db)
        .await
        .map_err(|e| anyhow::anyhow!("Failed to query tenants: {}", e))?;

        match existing {
            Some(row) => {
                tracing::info!(tenant_id = %row.id, "Default tenant already exists, skipping seed");
                row.id
            }
            None => {
                let new_tenant_id = uuid::Uuid::now_v7();

                // Insert default tenant using raw SQL (no tenant entity in erp-server)
                db.execute(sea_orm::Statement::from_sql_and_values(
                    sea_orm::DatabaseBackend::Postgres,
                    "INSERT INTO tenant (id, name, code, status, created_at, updated_at) VALUES ($1, $2, $3, $4, NOW(), NOW())",
                    [
                        new_tenant_id.into(),
                        "Default Tenant".into(),
                        "default".into(),
                        "active".into(),
                    ],
                ))
                .await
                .map_err(|e| anyhow::anyhow!("Failed to create default tenant: {}", e))?;
                tracing::info!(tenant_id = %new_tenant_id, "Created default tenant");

                // Seed auth data (permissions, roles, admin user)
                erp_auth::service::seed::seed_tenant_auth(
                    &db,
                    new_tenant_id,
                    &config.auth.super_admin_password,
                )
                .await
                .map_err(|e| anyhow::anyhow!("Failed to seed auth data: {}", e))?;
                tracing::info!(tenant_id = %new_tenant_id, "Default tenant ready with auth seed data");

                // Seed AI workflow definitions (best-effort: a failure only logs a warning)
                if let Err(e) = erp_workflow::service::ai_workflow_seed::ensure_ai_workflows(&db, new_tenant_id).await {
                    tracing::warn!(error = %e, "Failed to seed AI workflow definitions");
                }

                // Seed dialysis session workflow definition (best-effort as well).
                // NOTE(review): the tenant id is passed for both the second and third
                // argument — confirm against the callee's signature that this is intended.
                if let Err(e) = dialysis_workflow::seed_dialysis_session_workflow(&db, new_tenant_id, new_tenant_id).await {
                    tracing::warn!(error = %e, "Failed to seed dialysis session workflow");
                }

                new_tenant_id
            }
        }
    };

    // Connect to Redis
    let redis_client = redis::Client::open(&config.redis.url[..])?;
    tracing::info!("Redis client created");

    // Initialize event bus (capacity 1024 events)
    let event_bus = EventBus::new(1024);

    // Initialize auth module
    let auth_module = erp_auth::AuthModule::new();
    tracing::info!(
        module = auth_module.name(),
        version = auth_module.version(),
        "Auth module initialized"
    );

    // Initialize config module
    let config_module = erp_config::ConfigModule::new();
    tracing::info!(
        module = config_module.name(),
        version = config_module.version(),
        "Config module initialized"
    );

    // Initialize workflow module
    let workflow_module = erp_workflow::WorkflowModule::new();
    tracing::info!(
        module = workflow_module.name(),
        version = workflow_module.version(),
        "Workflow module initialized"
    );

    // Initialize message module
    let message_module = erp_message::MessageModule::new();
    tracing::info!(
        module = message_module.name(),
        version = message_module.version(),
        "Message module initialized"
    );

    // Initialize health module
    let health_module = erp_health::HealthModule::new();
    tracing::info!(
        module = health_module.name(),
        version = health_module.version(),
        "Health module initialized"
    );

    // Initialize AI module
    let ai_module = erp_ai::AiModule;
    tracing::info!(
        module = ai_module.name(),
        version = ai_module.version(),
        "AI module initialized"
    );

    // (The points module has been merged into the erp-health `/health/points/*` routes.)

    // Initialize dialysis module
    let dialysis_module = erp_dialysis::DialysisModule;
    tracing::info!(
        module = dialysis_module.name(),
        version = dialysis_module.version(),
        "Dialysis module initialized"
    );

    // Initialize module registry and register modules
    let registry = ModuleRegistry::new()
        .register(auth_module)
        .register(config_module)
        .register(workflow_module)
        .register(message_module)
        .register(health_module)
        .register(ai_module)
        .register(dialysis_module);
    tracing::info!(
        module_count = registry.modules().len(),
        "Modules registered"
    );

    // Initialize plugin engine
    let plugin_config = erp_plugin::engine::PluginEngineConfig::default();
    let plugin_engine = erp_plugin::engine::PluginEngine::new(
        db.clone(),
        event_bus.clone(),
        plugin_config,
    )?;
    tracing::info!("Plugin engine initialized");

    // Register plugin module
    let plugin_module = erp_plugin::module::PluginModule;
    let registry = registry.register(plugin_module);

    // Register event handlers
    registry.register_handlers(&event_bus);

    // Start up all modules (on_startup is invoked in topological order)
    let module_ctx = ModuleContext {
        db: db.clone(),
        event_bus: event_bus.clone(),
    };
    registry.startup_all(&module_ctx).await?;
    tracing::info!("All modules started");

    // Sync the permissions declared by every module into the database (upsert)
    sync_module_permissions(&db, &registry, default_tenant_id).await?;

    // Recover running plugins (reloaded automatically after a server restart);
    // a failure is logged but does not abort startup.
    match plugin_engine.recover_plugins(&db).await {
        Ok(recovered) => {
            tracing::info!(count = recovered.len(), "Plugins recovered");
        }
        Err(e) => {
            tracing::error!(error = %e, "Failed to recover plugins");
        }
    }

    // Start message event listener (workflow events → message notifications)
    erp_message::MessageModule::start_event_listener(db.clone(), event_bus.clone());
    tracing::info!("Message event listener started");

    // Start plugin notification listener (plugin.trigger.* → admin notifications)
    erp_plugin::notification::start_notification_listener(db.clone(), event_bus.clone());
    tracing::info!("Plugin notification listener started");

    // Start dialysis workflow orchestrator (dialysis.record.created → BPMN workflow)
    dialysis_workflow::start_dialysis_workflow_orchestrator(db.clone(), event_bus.clone());
    tracing::info!("Dialysis workflow orchestrator started");

    // Start outbox relay (LISTEN/NOTIFY + fallback poll for pending domain events)
    outbox::start_outbox_relay(db.clone(), event_bus.clone(), config.database.url.clone());
    tracing::info!("Outbox relay started");

    // Start event cleanup (archive old published events + purge processed_events)
    tasks::start_event_cleanup(db.clone());

    // Start DB connection pool metrics sampling (every 30s)
    tasks::start_pool_metrics(db.clone());

    // Start timeout checker (scan overdue tasks every 60s)
    erp_workflow::WorkflowModule::start_timeout_checker(db.clone(), event_bus.clone());
    tracing::info!("Timeout checker started");

    // Health-module background tasks are started inside HealthModule::on_startup().

    // Copy out what is needed after `config` is moved into the shared state.
    let host = config.server.host.clone();
    let port = config.server.port;

    // Extract JWT secret for middleware construction
    let jwt_secret = config.jwt.secret.clone();

    // Pre-build AI state (avoids per-request reconstruction)
    let ai_state = {
        // Build the multi-provider registry
        let provider_registry = std::sync::Arc::new(erp_ai::provider::registry::ProviderRegistry::new());

        // Always register the default Claude provider (compatible with the legacy config)
        {
            let mut claude = erp_ai::provider::claude::ClaudeProvider::new(
                config.ai.api_key.clone(),
            );
            if let Some(ref base_url) = config.ai.base_url {
                claude = claude.with_base_url(base_url.clone());
            }
            provider_registry.register("claude".to_string(), std::sync::Arc::new(claude));
        }

        // Register the additional providers listed in the configuration
        for (name, pcfg) in &config.ai.providers {
            if !pcfg.is_enabled {
                continue;
            }
            match pcfg.provider_type.as_str() {
                "openai" => {
                    // The API key is read from the environment variable named in the
                    // config; a missing variable yields an empty key.
                    let api_key = pcfg.api_key_env.as_ref()
                        .and_then(|env| std::env::var(env).ok())
                        .unwrap_or_default();
                    let base_url = pcfg.base_url.clone()
                        .unwrap_or_else(|| "https://api.openai.com".to_string());
                    let provider = erp_ai::provider::openai::OpenAIProvider::new(
                        api_key, base_url, pcfg.default_model.clone(),
                    );
                    provider_registry.register(name.clone(), std::sync::Arc::new(provider));
                    tracing::info!(provider = %name, "已注册 OpenAI 兼容提供商");
                }
                "ollama" => {
                    let base_url = pcfg.base_url.clone()
                        .unwrap_or_else(|| "http://localhost:11434".to_string());
                    let provider = erp_ai::provider::ollama::OllamaProvider::new(
                        base_url, pcfg.default_model.clone(),
                    );
                    provider_registry.register(name.clone(), std::sync::Arc::new(provider));
                    tracing::info!(provider = %name, "已注册 Ollama 本地提供商");
                }
                "claude" => {
                    // Already registered as the default provider; skip
                }
                other => {
                    tracing::warn!(provider = %name, provider_type = %other, "未知的提供商类型,已跳过");
                }
            }
        }
        tracing::info!(providers = ?provider_registry.provider_names(), "AI Provider 注册完成");

        // Build a separate default provider for AnalysisService
        // (its constructor keeps the Box<dyn AiProvider> signature)
        let mut default_claude = erp_ai::provider::claude::ClaudeProvider::new(
            config.ai.api_key.clone(),
        );
        if let Some(ref base_url) = config.ai.base_url {
            default_claude = default_claude.with_base_url(base_url.clone());
        }
        let analysis_svc = erp_ai::service::analysis::AnalysisService::new(
            Box::new(default_claude),
            db.clone(),
        ).with_knowledge_source(std::sync::Arc::new(
            erp_ai::knowledge::structured_source::StructuredKnowledgeSource::new(db.clone()),
        ));
        let analysis = std::sync::Arc::new(analysis_svc);
        let prompt = std::sync::Arc::new(erp_ai::service::prompt::PromptService::new(db.clone()));
        let usage = std::sync::Arc::new(erp_ai::service::usage::UsageService::new(db.clone()));
        let suggestion = std::sync::Arc::new(erp_ai::service::suggestion::SuggestionService);
        let health_provider = std::sync::Arc::new(erp_health::HealthDataProviderImpl {
            db: db.clone(),
        });
        let quota = std::sync::Arc::new(erp_ai::service::quota::QuotaService::new(
            db.clone(),
            config.ai.quota_check_enabled,
        ));
        let cache_ttl = std::time::Duration::from_secs(config.ai.cache_ttl_seconds);
        let cache = std::sync::Arc::new(erp_ai::service::cache::CacheService::new(
            redis_client.clone(),
            db.clone(),
            cache_ttl,
        ));

        erp_ai::AiState {
            db: db.clone(),
            event_bus: event_bus.clone(),
            analysis,
            prompt,
            usage,
            suggestion,
            health_provider,
            provider_registry,
            quota,
            cache,
        }
    };

    // Start auto trend analysis (every 24h, scans high-risk patients)
    erp_ai::service::auto_analysis::start_auto_analysis(ai_state.clone());
    tracing::info!("Auto trend analysis scheduler started");

    // Build shared state
    // PII crypto: the placeholder KEK is tolerated in debug builds only
    // (falls back to a development default); release builds panic.
    let pii_crypto = if config.crypto.kek == "__MUST_SET_VIA_ENV__" {
        #[cfg(debug_assertions)]
        {
            tracing::warn!("⚠️ PII KEK 使用开发默认值,仅用于本地开发");
            erp_core::crypto::PiiCrypto::dev_default()
        }
        #[cfg(not(debug_assertions))]
        {
            panic!("ERP__CRYPTO__KEK must be set in production. Use a 64-char hex string (32 bytes).");
        }
    } else {
        erp_core::crypto::PiiCrypto::from_kek_hex(&config.crypto.kek)
            .expect("PII KEK must be valid 64-char hex (32 bytes). Set ERP__CRYPTO__KEK")
    };

    let state = AppState {
        db,
        config,
        event_bus,
        module_registry: registry,
        redis: redis_client.clone(),
        default_tenant_id,
        plugin_engine,
        plugin_entity_cache: moka::sync::Cache::builder()
            .max_capacity(1000)
            .time_to_idle(std::time::Duration::from_secs(300))
            .build(),
        ai_state,
        pii_crypto,
    };

    // --- Build the router ---
    //
    // The router is split into two layers:
    //   1. Public routes: no JWT required (health, login, refresh)
    //   2. Protected routes: JWT required (user CRUD, logout)
    //
    // Both layers share the same AppState. The protected layer wraps routes
    // with the jwt_auth_middleware_fn.

    // Public routes (no authentication, but IP-based rate limiting).
    // In axum the layer added LAST is the outermost one and sees the request first,
    // so the execution order here is: rate_limit_by_ip → account_lockout → handler.
    let public_routes = Router::new()
        .merge(erp_auth::AuthModule::public_routes())
        .merge(erp_health::HealthModule::public_routes())
        .layer(axum::middleware::from_fn_with_state(
            state.clone(),
            middleware::rate_limit::account_lockout_middleware,
        ))
        .layer(axum::middleware::from_fn_with_state(
            state.clone(),
            middleware::rate_limit::rate_limit_by_ip,
        ))
        .with_state(state.clone());

    // Unthrottled public routes (health, docs, brand) — no rate limiting
    let unthrottled_routes = Router::new()
        .merge(handlers::health::health_check_router())
        .route(
            "/docs/openapi.json",
            axum::routing::get(handlers::openapi::openapi_spec),
        )
        .merge(erp_config::ConfigModule::public_routes())
        .with_state(state.clone());

    // Clone jwt_secret for upload auth before protected_routes closure moves it
    let secret_for_uploads = jwt_secret.clone();

    // Protected routes (JWT authentication required)
    // User-based rate limiting (100 req/min) applied after JWT auth
    let protected_routes = erp_auth::AuthModule::protected_routes()
        .merge(erp_config::ConfigModule::protected_routes())
        .merge(erp_workflow::WorkflowModule::protected_routes())
        .merge(erp_message::MessageModule::protected_routes())
        .merge(erp_plugin::module::PluginModule::protected_routes())
        .merge(erp_health::HealthModule::protected_routes())
        .merge(erp_ai::AiModule::protected_routes())
        .merge(erp_dialysis::DialysisModule::protected_routes())
        .merge(handlers::audit_log::audit_log_router())
        .route(
            "/upload",
            axum::routing::post(handlers::upload::upload_file),
        )
        .route(
            "/admin/tenants/{id}/rotate-key",
            axum::routing::post(handlers::crypto_admin::rotate_tenant_key),
        )
        .route(
            "/analytics/batch",
            axum::routing::post(handlers::analytics::batch),
        )
        .layer(axum::middleware::from_fn_with_state(
            state.clone(),
            middleware::rate_limit::rate_limit_by_user,
        ))
        .layer({
            let db = state.db.clone();
            axum_middleware::from_fn(move |req, next| {
                let secret = jwt_secret.clone();
                let db = db.clone();
                async move { jwt_auth_middleware_fn(secret, Some(db), req, next).await }
            })
        })
        // Tenant RLS (SET app.current_tenant_id).
        // NOTE(review): the original comment says this runs AFTER JWT auth, but this
        // layer is added last and is therefore the outermost one, i.e. it sees the
        // request BEFORE the JWT middleware. Confirm that tenant_rls_middleware does
        // not rely on anything the JWT middleware puts on the request; if it does,
        // this layer must be moved above the JWT `.layer(...)` call.
        .layer({
            let db = state.db.clone();
            axum_middleware::from_fn(move |req, next| {
                let db = db.clone();
                async move { middleware::tenant_rls::tenant_rls_middleware(db, req, next).await }
            })
        })
        .with_state(state.clone());

    // Merge public + protected into the final application router
    // All API routes are nested under /api/v1
    let cors = build_cors_layer(&state.config.cors.allowed_origins);

    // Uploaded files are served statically, guarded by a token check.
    let upload_dir = state.config.storage.upload_dir.clone();
    let uploads_router = Router::new()
        .fallback_service(ServeDir::new(&upload_dir))
        .layer(axum_middleware::from_fn(move |req, next| {
            let secret = secret_for_uploads.clone();
            async move { upload_auth_middleware(secret, req, next).await }
        }));

    let app = Router::new()
        .nest("/api/v1", unthrottled_routes.merge(public_routes).merge(protected_routes))
        .nest("/fhir", erp_health::HealthModule::fhir_routes().with_state(state.clone()))
        .nest(
            "/health/gateway",
            erp_health::HealthModule::gateway_routes()
                .layer(axum::middleware::from_fn_with_state(
                    state.clone(),
                    erp_health::gateway_auth::gateway_auth_middleware,
                ))
                .with_state(state.clone()),
        )
        .nest("/uploads", uploads_router)
        .layer(axum::middleware::from_fn(middleware::metrics::metrics_middleware))
        .layer(cors);

    // Start Prometheus metrics exporter on a separate port
    let metrics_port = state.config.server.metrics_port;
    middleware::metrics::start_metrics_server(metrics_port);

    let addr = format!("{}:{}", host, port);
    let listener = tokio::net::TcpListener::bind(&addr).await?;
    tracing::info!(addr = %addr, "Server listening");

    // Graceful shutdown on CTRL+C (or SIGTERM on unix, see shutdown_signal)
    axum::serve(listener, app)
        .with_graceful_shutdown(shutdown_signal())
        .await?;

    // Gracefully shut down all modules (in reverse topological order)
    state.module_registry.shutdown_all().await?;
    tracing::info!("Server shutdown complete");

    Ok(())
}
/// JWT auth middleware for `/uploads` file serving.
///
/// Accepts token from either `Authorization: Bearer <token>` header
/// or `?token=<token>` query parameter (for browser `<img>` / direct downloads).
async fn upload_auth_middleware(
jwt_secret: String,
req: axum::extract::Request,
next: axum::middleware::Next,
) -> Result<axum::response::Response, erp_core::error::AppError> {
use erp_auth::service::token_service::TokenService;
let token = req
.headers()
.get(axum::http::header::AUTHORIZATION)
.and_then(|v| v.to_str().ok())
.and_then(|v| v.strip_prefix("Bearer "))
.map(|s| s.to_string())
.or_else(|| {
req.uri().query().and_then(|q| {
q.split('&').find_map(|pair| {
let (k, v) = pair.split_once('=').unwrap_or((pair, ""));
if k == "token" && !v.is_empty() {
Some(v.to_string())
} else {
None
}
})
})
});
let token = token.ok_or(erp_core::error::AppError::Unauthorized)?;
let claims = TokenService::decode_token(&token, &jwt_secret)
.map_err(|_| erp_core::error::AppError::Unauthorized)?;
if claims.token_type != "access" {
return Err(erp_core::error::AppError::Unauthorized);
}
Ok(next.run(req).await)
}
/// Build a CORS layer from the comma-separated allowed origins config.
///
/// If the config is "*", allows all origins (development mode).
/// Otherwise, parses each origin as a URL and restricts to those origins only.
fn build_cors_layer(allowed_origins: &str) -> tower_http::cors::CorsLayer {
use axum::http::HeaderValue;
use tower_http::cors::AllowOrigin;
let origins = allowed_origins
.split(',')
.map(|s| s.trim())
.filter(|s| !s.is_empty())
.collect::<Vec<_>>();
if origins.len() == 1 && origins[0] == "*" {
tracing::warn!(
"⚠️ CORS 允许所有来源 — 仅限开发环境使用!\
生产环境请通过 ERP__CORS__ALLOWED_ORIGINS 设置具体的来源域名"
);
return tower_http::cors::CorsLayer::permissive();
}
let allowed: Vec<HeaderValue> = origins
.iter()
.filter_map(|o| o.parse::<HeaderValue>().ok())
.collect();
tracing::info!(origins = ?origins, "CORS: restricting to allowed origins");
tower_http::cors::CorsLayer::new()
.allow_origin(AllowOrigin::list(allowed))
.allow_methods([
axum::http::Method::GET,
axum::http::Method::POST,
axum::http::Method::PUT,
axum::http::Method::DELETE,
axum::http::Method::PATCH,
])
.allow_headers([
axum::http::header::AUTHORIZATION,
axum::http::header::CONTENT_TYPE,
])
.allow_credentials(true)
}
/// Resolves once the process is asked to stop: on CTRL+C on every platform,
/// and additionally on SIGTERM on unix. Used as the graceful-shutdown trigger.
///
/// # Panics
/// Panics if a signal handler cannot be installed.
async fn shutdown_signal() {
    let interrupt = async {
        tokio::signal::ctrl_c()
            .await
            .expect("failed to install CTRL+C handler");
    };

    #[cfg(unix)]
    let sigterm = async {
        let mut stream =
            tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
                .expect("failed to install signal handler");
        stream.recv().await;
    };

    // No SIGTERM outside unix: use a future that never completes so that only
    // CTRL+C can win the race below.
    #[cfg(not(unix))]
    let sigterm = std::future::pending::<()>();

    tokio::select! {
        _ = interrupt => {
            tracing::info!("Received CTRL+C, shutting down gracefully...");
        },
        _ = sigterm => {
            tracing::info!("Received SIGTERM, shutting down gracefully...");
        },
    }
}
/// Synchronises the permissions declared by every registered module into the database.
///
/// Each descriptor returned by a module's `permissions()` is inserted with an
/// `ON CONFLICT (tenant_id, code) ... DO NOTHING` statement, so:
/// - a permission that does not exist yet is inserted;
/// - a non-deleted permission with the same `tenant_id` + `code` is left untouched.
///
/// Afterwards, on every call, all permissions of the tenant are linked to the
/// tenant's `admin` role (existing links are left alone).
///
/// # Errors
/// Propagates the first database error; statements already executed are not rolled back.
async fn sync_module_permissions(
    db: &sea_orm::DatabaseConnection,
    registry: &erp_core::module::ModuleRegistry,
    tenant_id: uuid::Uuid,
) -> Result<(), anyhow::Error> {
    const INSERT_PERMISSION_SQL: &str = r#"INSERT INTO permissions (id, tenant_id, code, name, resource, action, description, created_at, updated_at, created_by, updated_by, deleted_at, version)
VALUES ($1, $2, $3, $4, $5, $6, $7, NOW(), NOW(), $8, $8, NULL, 1)
ON CONFLICT (tenant_id, code) WHERE deleted_at IS NULL DO NOTHING"#;
    const BIND_ADMIN_SQL: &str = r#"INSERT INTO role_permissions (role_id, permission_id, tenant_id, data_scope, created_at, updated_at, created_by, updated_by, deleted_at, version)
SELECT r.id, p.id, p.tenant_id, 'all', NOW(), NOW(), $1, $1, NULL, 1
FROM permissions p
JOIN roles r ON r.code = 'admin' AND r.tenant_id = p.tenant_id AND r.deleted_at IS NULL
WHERE p.tenant_id = $2
ON CONFLICT DO NOTHING"#;

    // Rows written here are attributed to the nil UUID as the "system" user.
    let system_user_id = uuid::Uuid::nil();
    let mut inserted = 0u32;

    for module in registry.modules() {
        for perm in module.permissions() {
            // The action is the last dot-separated segment of the permission code.
            let action = perm.code.rsplit('.').next().unwrap_or("manage");
            let statement = sea_orm::Statement::from_sql_and_values(
                sea_orm::DatabaseBackend::Postgres,
                INSERT_PERMISSION_SQL,
                [
                    uuid::Uuid::now_v7().into(),
                    tenant_id.into(),
                    perm.code.clone().into(),
                    perm.name.clone().into(),
                    perm.module.clone().into(),
                    action.into(),
                    perm.description.clone().into(),
                    system_user_id.into(),
                ],
            );
            let outcome = db.execute(statement).await?;
            // Zero affected rows means the permission already existed.
            if outcome.rows_affected() > 0 {
                inserted += 1;
            }
        }
    }

    // On every startup make sure the admin role owns all module permissions
    // (guards against missing permission-role links).
    let bind_admin = sea_orm::Statement::from_sql_and_values(
        sea_orm::DatabaseBackend::Postgres,
        BIND_ADMIN_SQL,
        [system_user_id.into(), tenant_id.into()],
    );
    db.execute(bind_admin).await?;

    if inserted > 0 {
        tracing::info!(total_new = inserted, "New module permissions synced and bound to admin role");
    }
    Ok(())
}