feat: initialize Nuanji (Warm Notes) project

- Base platform from base.git (ERP base: auth, core, config, message, workflow, plugin)
- Created erp-diary module skeleton (lib.rs, dto.rs, error.rs, event.rs, state.rs)
- Integrated erp-diary into workspace and erp-server
- Added DiaryModule registration in main.rs
- Added DiaryState FromRef in state.rs
- Diary routes mounted (empty routes, ready for implementation)
- Product design spec v1.2 preserved in docs/
- Implementation plan preserved in plans/

Cargo check: OK
Cargo test: OK (78+ base tests passing)
This commit is contained in:
iven
2026-05-31 20:52:19 +08:00
commit c539e6fd83
285 changed files with 59156 additions and 0 deletions

View File

@@ -0,0 +1,156 @@
use serde::Deserialize;
#[derive(Debug, Clone, Deserialize)]
pub struct AppConfig {
pub server: ServerConfig,
pub database: DatabaseConfig,
pub redis: RedisConfig,
pub jwt: JwtConfig,
pub auth: AuthConfig,
pub log: LogConfig,
pub cors: CorsConfig,
pub wechat: WechatConfig,
pub crypto: CryptoConfig,
pub storage: StorageConfig,
#[serde(default)]
pub rate_limit: RateLimitConfig,
}
#[derive(Debug, Clone, Deserialize)]
pub struct ServerConfig {
pub host: String,
pub port: u16,
#[serde(default = "default_metrics_port")]
pub metrics_port: u16,
}
fn default_metrics_port() -> u16 {
9090
}
#[derive(Debug, Clone, Deserialize)]
pub struct DatabaseConfig {
pub url: String,
pub max_connections: u32,
pub min_connections: u32,
}
#[derive(Debug, Clone, Deserialize)]
pub struct RedisConfig {
pub url: String,
}
#[derive(Debug, Clone, Deserialize)]
pub struct JwtConfig {
pub secret: String,
pub access_token_ttl: String,
pub refresh_token_ttl: String,
}
#[derive(Debug, Clone, Deserialize)]
pub struct LogConfig {
pub level: String,
}
#[derive(Debug, Clone, Deserialize)]
pub struct AuthConfig {
pub super_admin_password: String,
}
#[derive(Debug, Clone, Deserialize)]
pub struct CorsConfig {
/// Comma-separated list of allowed origins.
/// Use "*" to allow all origins (development only).
pub allowed_origins: String,
}
#[derive(Debug, Clone, Deserialize)]
pub struct WechatConfig {
pub appid: String,
pub secret: String,
#[serde(default)]
pub dev_mode: bool,
}
#[derive(Debug, Clone, Deserialize)]
pub struct CryptoConfig {
/// Master KEK (64 字符 hex 编码32 字节)。用于加密保护每租户 DEK。
/// Phase A 阶段同时作为全局数据加密密钥使用。
pub kek: String,
}
#[derive(Debug, Clone, Deserialize)]
pub struct StorageConfig {
/// 文件上传目录(本地存储)
pub upload_dir: String,
/// 单文件最大大小(如 "10MB"
pub max_file_size: String,
/// 签名 URL 密钥HMAC-SHA256
#[serde(default = "default_secret_key")]
pub secret_key: String,
}
fn default_secret_key() -> String {
#[cfg(debug_assertions)]
{
"dev-only-secret-key-change-in-production".to_string()
}
#[cfg(not(debug_assertions))]
{
panic!("ERP__STORAGE__SECRET_KEY 必须设置(生产环境不允许使用默认签名密钥)")
}
}
impl StorageConfig {
/// 解析 max_file_size 为字节数
pub fn max_file_size_bytes(&self) -> u64 {
let s = self.max_file_size.to_uppercase();
if let Some(num) = s.strip_suffix("MB") {
num.trim().parse::<u64>().unwrap_or(10) * 1024 * 1024
} else if let Some(num) = s.strip_suffix("KB") {
num.trim().parse::<u64>().unwrap_or(1024) * 1024
} else if let Some(num) = s.strip_suffix("GB") {
num.trim().parse::<u64>().unwrap_or(1) * 1024 * 1024 * 1024
} else {
s.parse::<u64>().unwrap_or(10 * 1024 * 1024)
}
}
}
#[derive(Debug, Clone, Deserialize)]
pub struct RateLimitConfig {
/// Redis 不可达时是否拒绝请求fail-close
/// true = 安全优先Redis 故障时返回 503。
/// false = 可用性优先Redis 故障时放行。
#[serde(default = "default_fail_close")]
pub fail_close: bool,
}
fn default_fail_close() -> bool {
true
}
impl Default for RateLimitConfig {
fn default() -> Self {
Self {
fail_close: default_fail_close(),
}
}
}
impl AppConfig {
pub fn load() -> anyhow::Result<Self> {
let config = config::Config::builder()
.add_source(config::File::with_name("config/default"))
.add_source(config::Environment::with_prefix("ERP").separator("__"))
.build()?;
let app_config: Self = config.try_deserialize()?;
// 安全检查:禁止在生产使用默认 JWT 密钥
if app_config.jwt.secret == "change-me-in-production" {
tracing::warn!("⚠️ JWT 密钥使用默认值,请通过 ERP__JWT__SECRET 环境变量设置安全密钥");
}
Ok(app_config)
}
}

View File

@@ -0,0 +1,16 @@
use sea_orm::{ConnectOptions, Database, DatabaseConnection};
use std::time::Duration;
use crate::config::DatabaseConfig;
pub async fn connect(config: &DatabaseConfig) -> anyhow::Result<DatabaseConnection> {
let mut opt = ConnectOptions::new(&config.url);
opt.max_connections(config.max_connections)
.min_connections(config.min_connections)
.connect_timeout(Duration::from_secs(10))
.idle_timeout(Duration::from_secs(600));
let db = Database::connect(opt).await?;
tracing::info!("Database connected successfully");
Ok(db)
}

View File

@@ -0,0 +1,65 @@
use axum::Json;
use axum::extract::Extension;
use serde::Deserialize;
use tracing;
use erp_core::error::AppError;
use erp_core::rbac::require_permission;
use erp_core::types::{ApiResponse, TenantContext};
const MAX_EVENTS_PER_BATCH: usize = 100;
#[derive(Debug, Deserialize)]
#[allow(dead_code)] // 客户端上报结构体,字段后续接入分析表时使用
pub struct AnalyticsEvent {
pub event: String,
pub properties: Option<serde_json::Value>,
#[serde(deserialize_with = "deserialize_flexible_timestamp")]
pub timestamp: Option<String>,
pub page: Option<String>,
pub user_id: Option<String>,
pub patient_id: Option<String>,
}
fn deserialize_flexible_timestamp<'de, D>(de: D) -> Result<Option<String>, D::Error>
where
D: serde::Deserializer<'de>,
{
use serde::de;
let val = Option::<serde_json::Value>::deserialize(de)?;
match val {
None => Ok(None),
Some(serde_json::Value::String(s)) => Ok(Some(s)),
Some(serde_json::Value::Number(n)) => Ok(Some(n.to_string())),
_ => Err(de::Error::custom("timestamp must be string or number")),
}
}
#[derive(Debug, Deserialize)]
pub struct BatchRequest {
pub events: Vec<AnalyticsEvent>,
}
/// 接收小程序批量埋点事件。
/// 当前为日志记录模式 — 后续可接入 ClickHouse/PostgreSQL 分析表。
pub async fn batch(
Extension(ctx): Extension<TenantContext>,
Json(req): Json<BatchRequest>,
) -> Result<Json<ApiResponse<()>>, AppError> {
require_permission(&ctx, "system.analytics.submit")?;
if req.events.len() > MAX_EVENTS_PER_BATCH {
return Err(AppError::Validation(format!(
"批量埋点事件数不能超过 {}",
MAX_EVENTS_PER_BATCH
)));
}
for evt in &req.events {
tracing::info!(
event = %evt.event,
page = ?evt.page,
properties = ?evt.properties,
"Analytics event received"
);
}
Ok(Json(ApiResponse::ok(())))
}

View File

@@ -0,0 +1,156 @@
use axum::Router;
use axum::extract::{Extension, FromRef, Query, State};
use axum::response::Json;
use axum::routing::get;
use sea_orm::{ColumnTrait, EntityTrait, PaginatorTrait, QueryFilter, QueryOrder};
use serde::{Deserialize, Serialize};
use erp_core::entity::audit_log;
use erp_core::error::AppError;
use erp_core::types::{ApiResponse, PaginatedResponse, TenantContext};
#[derive(Debug, Deserialize)]
pub struct AuditLogQuery {
pub resource_type: Option<String>,
pub user_id: Option<uuid::Uuid>,
pub page: Option<u64>,
pub page_size: Option<u64>,
}
#[derive(Debug, Serialize)]
pub struct AuditLogResp {
pub id: uuid::Uuid,
pub tenant_id: uuid::Uuid,
pub user_id: Option<uuid::Uuid>,
pub user_name: Option<String>,
pub action: String,
pub resource_type: String,
pub resource_id: Option<uuid::Uuid>,
pub old_value: Option<serde_json::Value>,
pub new_value: Option<serde_json::Value>,
pub ip_address: Option<String>,
pub user_agent: Option<String>,
pub created_at: chrono::DateTime<chrono::Utc>,
}
impl From<audit_log::Model> for AuditLogResp {
fn from(m: audit_log::Model) -> Self {
Self {
id: m.id,
tenant_id: m.tenant_id,
user_id: m.user_id,
user_name: None,
action: m.action,
resource_type: m.resource_type,
resource_id: m.resource_id,
old_value: m.old_value,
new_value: m.new_value,
ip_address: m.ip_address,
user_agent: m.user_agent,
created_at: m.created_at,
}
}
}
async fn resolve_user_names(
db: &sea_orm::DatabaseConnection,
items: &[audit_log::Model],
) -> std::collections::HashMap<uuid::Uuid, String> {
use erp_auth::entity::user;
let user_ids: Vec<uuid::Uuid> = items
.iter()
.filter_map(|i| i.user_id)
.collect::<std::collections::HashSet<_>>()
.into_iter()
.collect();
if user_ids.is_empty() {
return std::collections::HashMap::new();
}
let users = user::Entity::find()
.filter(user::Column::Id.is_in(user_ids))
.all(db)
.await
.unwrap_or_default();
users
.into_iter()
.map(|u| {
let name = u
.display_name
.filter(|n| !n.is_empty())
.unwrap_or(u.username);
(u.id, name)
})
.collect()
}
/// GET /audit-logs
pub async fn list_audit_logs<S>(
State(db): State<sea_orm::DatabaseConnection>,
Extension(ctx): Extension<TenantContext>,
Query(params): Query<AuditLogQuery>,
) -> Result<Json<ApiResponse<PaginatedResponse<AuditLogResp>>>, AppError>
where
sea_orm::DatabaseConnection: FromRef<S>,
S: Clone + Send + Sync + 'static,
{
let page = params.page.unwrap_or(1).max(1);
let page_size = params.page_size.unwrap_or(20).min(100);
let tenant_id = ctx.tenant_id;
let mut q = audit_log::Entity::find().filter(audit_log::Column::TenantId.eq(tenant_id));
if let Some(rt) = &params.resource_type {
q = q.filter(audit_log::Column::ResourceType.eq(rt.clone()));
}
if let Some(uid) = &params.user_id {
q = q.filter(audit_log::Column::UserId.eq(*uid));
}
let paginator = q
.order_by_desc(audit_log::Column::CreatedAt)
.paginate(&db, page_size);
let total = paginator
.num_items()
.await
.map_err(|e| AppError::Internal(format!("查询审计日志失败: {e}")))?;
let items = paginator
.fetch_page(page - 1)
.await
.map_err(|e| AppError::Internal(format!("查询审计日志失败: {e}")))?;
let user_map = resolve_user_names(&db, &items).await;
let resp_items: Vec<AuditLogResp> = items
.into_iter()
.map(|m| {
let user_name = m.user_id.and_then(|uid| user_map.get(&uid).cloned());
let mut resp = AuditLogResp::from(m);
resp.user_name = user_name;
resp
})
.collect();
let total_pages = total.div_ceil(page_size);
Ok(Json(ApiResponse::ok(PaginatedResponse {
data: resp_items,
total,
page,
page_size,
total_pages,
})))
}
pub fn audit_log_router<S>() -> Router<S>
where
sea_orm::DatabaseConnection: FromRef<S>,
S: Clone + Send + Sync + 'static,
{
Router::new().route("/audit-logs", get(list_audit_logs))
}

View File

@@ -0,0 +1,76 @@
use axum::Extension;
use axum::Json;
use axum::extract::{FromRef, Path, State};
use sea_orm::{ConnectionTrait, DatabaseBackend, Statement};
use serde_json::{Value, json};
use uuid::Uuid;
use erp_core::error::AppError;
use erp_core::rbac::require_permission;
use erp_core::types::{ApiResponse, TenantContext};
use crate::state::AppState;
/// POST /api/v1/admin/tenants/:id/rotate-key
/// 密钥轮换 — 生成新 DEK持久化到 tenant_crypto_keys使缓存失效
pub async fn rotate_tenant_key<S>(
State(state): State<AppState>,
Extension(ctx): Extension<TenantContext>,
Path(tenant_id): Path<Uuid>,
) -> Result<Json<ApiResponse<Value>>, AppError>
where
AppState: FromRef<S>,
S: Clone + Send + Sync + 'static,
{
require_permission(&ctx, "tenant.manage")?;
// 读取当前最大版本号
let max_version: Option<i32> = {
let row = state.db.query_one(Statement::from_sql_and_values(
DatabaseBackend::Postgres,
"SELECT COALESCE(MAX(key_version), 0) as v FROM tenant_crypto_keys WHERE tenant_id = $1 AND deleted_at IS NULL",
[tenant_id.into()],
)).await.map_err(|e| AppError::Internal(format!("查询密钥版本失败: {}", e)))?;
row.and_then(|r| r.try_get_by_index::<i32>(0).ok())
};
let current_version = max_version.unwrap_or(0);
let new_version = current_version + 1;
// 将旧版本标记为不活跃
state.db.execute(Statement::from_sql_and_values(
DatabaseBackend::Postgres,
"UPDATE tenant_crypto_keys SET is_active = false, updated_at = now() WHERE tenant_id = $1 AND is_active = true AND deleted_at IS NULL",
[tenant_id.into()],
)).await.map_err(|e| AppError::Internal(format!("停用旧 DEK 失败: {}", e)))?;
// 生成新 DEK 并用 KEK 加密
let kek = state.pii_crypto.kek();
let (_new_dek, encrypted_dek) = erp_core::crypto::DekManager::generate_new_dek(kek)
.map_err(|e| AppError::Internal(format!("生成新 DEK 失败: {}", e)))?;
// 持久化新 DEK
let new_id = Uuid::now_v7();
state.db.execute(Statement::from_sql_and_values(
DatabaseBackend::Postgres,
"INSERT INTO tenant_crypto_keys (id, tenant_id, encrypted_dek, key_version, is_active, created_at, updated_at, version) VALUES ($1, $2, $3, $4, true, now(), now(), 1)",
[new_id.into(), tenant_id.into(), encrypted_dek.into(), new_version.into()],
)).await.map_err(|e| AppError::Internal(format!("存储新 DEK 失败: {}", e)))?;
// 使 DEK 缓存失效
state.pii_crypto.invalidate_dek(tenant_id);
tracing::info!(
tenant_id = %tenant_id,
old_version = current_version,
new_version = new_version,
"密钥轮换完成(新 DEK 已持久化,缓存已清除)"
);
Ok(Json(ApiResponse::ok(json!({
"message": "密钥轮换已完成",
"tenant_id": tenant_id,
"old_version": current_version,
"new_version": new_version,
"note": "后台重加密任务需要单独触发,旧数据仍可用旧 DEK 解密"
}))))
}

View File

@@ -0,0 +1,135 @@
use axum::Router;
use axum::extract::State;
use axum::response::Json;
use axum::routing::get;
use serde::Serialize;
use crate::state::AppState;
#[derive(Debug, Serialize)]
pub struct HealthResponse {
pub status: String,
pub version: String,
pub modules: Vec<String>,
}
/// GET /health — 轻量存活检查
pub async fn health_check(State(state): State<AppState>) -> Json<HealthResponse> {
let modules = state
.module_registry
.modules()
.iter()
.map(|m| m.name().to_string())
.collect();
Json(HealthResponse {
status: "ok".to_string(),
version: env!("CARGO_PKG_VERSION").to_string(),
modules,
})
}
#[derive(Debug, Serialize)]
pub struct ReadyResponse {
pub status: String,
pub version: String,
pub database: ComponentStatus,
pub redis: ComponentStatus,
pub modules: Vec<String>,
}
#[derive(Debug, Serialize)]
pub struct ComponentStatus {
pub status: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub latency_ms: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub error: Option<String>,
}
/// GET /health/ready — 就绪检查(含 DB + Redis 连通性)
pub async fn readiness_check(State(state): State<AppState>) -> Json<ReadyResponse> {
let modules = state
.module_registry
.modules()
.iter()
.map(|m| m.name().to_string())
.collect();
let (db_status, redis_status) =
tokio::join!(check_database(&state.db), check_redis(&state.redis),);
let overall = if db_status.status == "ok" && redis_status.status == "ok" {
"ok"
} else if db_status.status == "ok" {
"degraded"
} else {
"unavailable"
};
Json(ReadyResponse {
status: overall.to_string(),
version: env!("CARGO_PKG_VERSION").to_string(),
database: db_status,
redis: redis_status,
modules,
})
}
async fn check_database(db: &sea_orm::DatabaseConnection) -> ComponentStatus {
use sea_orm::ConnectionTrait;
let start = std::time::Instant::now();
let stmt =
sea_orm::Statement::from_string(sea_orm::DatabaseBackend::Postgres, "SELECT 1".to_string());
match db.query_one(stmt).await {
Ok(_) => ComponentStatus {
status: "ok".to_string(),
latency_ms: Some(start.elapsed().as_millis() as u64),
error: None,
},
Err(e) => {
tracing::error!(error = %e, "Database health check failed");
ComponentStatus {
status: "error".to_string(),
latency_ms: Some(start.elapsed().as_millis() as u64),
error: Some("connection failed".to_string()),
}
}
}
}
async fn check_redis(client: &redis::Client) -> ComponentStatus {
let start = std::time::Instant::now();
match client.get_multiplexed_async_connection().await {
Ok(mut conn) => match redis::cmd("PING").query_async::<String>(&mut conn).await {
Ok(_) => ComponentStatus {
status: "ok".to_string(),
latency_ms: Some(start.elapsed().as_millis() as u64),
error: None,
},
Err(e) => {
tracing::error!(error = %e, "Redis PING failed");
ComponentStatus {
status: "error".to_string(),
latency_ms: Some(start.elapsed().as_millis() as u64),
error: Some("connection failed".to_string()),
}
}
},
Err(e) => {
tracing::error!(error = %e, "Redis connection failed");
ComponentStatus {
status: "error".to_string(),
latency_ms: Some(start.elapsed().as_millis() as u64),
error: Some("connection failed".to_string()),
}
}
}
}
pub fn health_check_router() -> Router<AppState> {
Router::new()
.route("/health", get(health_check))
.route("/health/live", get(health_check))
.route("/health/ready", get(readiness_check))
}

View File

@@ -0,0 +1,6 @@
pub mod analytics;
pub mod audit_log;
pub mod crypto_admin;
pub mod health;
pub mod openapi;
pub mod upload;

View File

@@ -0,0 +1,25 @@
use axum::response::{IntoResponse, Json, Response};
use utoipa::OpenApi;
use crate::{ApiDoc, AuthApiDoc, ConfigApiDoc, MessageApiDoc, WorkflowApiDoc};
/// GET /docs/openapi.json
///
/// 返回 OpenAPI 3.0 规范 JSON 文档,合并所有模块的路径和 schema。
/// 仅在 debug 模式下可用,生产构建返回 404。
pub async fn openapi_spec() -> Response {
#[cfg(debug_assertions)]
{
let mut spec = ApiDoc::openapi();
spec.merge(AuthApiDoc::openapi());
spec.merge(ConfigApiDoc::openapi());
spec.merge(WorkflowApiDoc::openapi());
spec.merge(MessageApiDoc::openapi());
Json(serde_json::to_value(spec).unwrap_or_default()).into_response()
}
#[cfg(not(debug_assertions))]
{
(axum::http::StatusCode::NOT_FOUND, "Not Found").into_response()
}
}

View File

@@ -0,0 +1,220 @@
use axum::Extension;
use axum::extract::{FromRef, Multipart, State};
use axum::response::Json;
use erp_core::error::AppError;
use erp_core::types::{ApiResponse, TenantContext};
use serde::Serialize;
use uuid::Uuid;
use crate::state::AppState;
#[derive(Debug, Serialize, utoipa::ToSchema)]
pub struct UploadResp {
pub url: String,
pub filename: String,
pub size: u64,
pub content_type: String,
}
/// 上传单个文件。
///
/// 接受 multipart/form-data将文件保存到本地目录
/// 返回可通过 `/uploads/` 前缀访问的 URL。
#[utoipa::path(
post,
path = "/upload",
request_body(content_type = "multipart/form-data"),
responses(
(status = 200, description = "上传成功", body = ApiResponse<UploadResp>),
(status = 413, description = "文件过大"),
(status = 400, description = "无文件或不支持的类型"),
),
tag = "文件上传",
)]
pub async fn upload_file<S>(
State(state): State<AppState>,
Extension(ctx): Extension<TenantContext>,
mut multipart: Multipart,
) -> Result<Json<ApiResponse<UploadResp>>, AppError>
where
AppState: FromRef<S>,
S: Clone + Send + Sync + 'static,
{
let max_size = state.config.storage.max_file_size_bytes();
let upload_dir = &state.config.storage.upload_dir;
// 确保上传目录存在
let base_dir = std::path::Path::new(upload_dir);
let tenant_dir = base_dir.join(ctx.tenant_id.to_string());
tokio::fs::create_dir_all(&tenant_dir)
.await
.map_err(|e| AppError::Internal(format!("创建上传目录失败: {}", e)))?;
// 读取第一个 field 作为上传文件
let field = multipart
.next_field()
.await
.map_err(|e| AppError::Validation(format!("读取上传数据失败: {}", e)))?
.ok_or_else(|| AppError::Validation("未找到上传文件".to_string()))?;
let content_type = field
.content_type()
.unwrap_or("application/octet-stream")
.to_string();
// 验证文件类型
validate_content_type(&content_type)?;
let original_name = field.name().unwrap_or("file").to_string();
let data = field
.bytes()
.await
.map_err(|e| AppError::Validation(format!("读取文件数据失败: {}", e)))?;
if data.len() as u64 > max_size {
return Err(AppError::Validation(format!(
"文件大小超过限制(最大 {}",
format_size(max_size)
)));
}
// 校验 magic bytes验证文件实际内容与声明的 Content-Type 一致
validate_magic_bytes(&content_type, &data)?;
// 生成唯一文件名,保留原始扩展名
let ext = std::path::Path::new(&original_name)
.extension()
.and_then(|e| e.to_str())
.unwrap_or("bin");
let file_id = Uuid::now_v7();
let filename = format!("{}.{}", file_id, ext);
let file_path = tenant_dir.join(&filename);
let data_vec: Vec<u8> = data.to_vec();
tokio::fs::write(&file_path, &data_vec)
.await
.map_err(|e| AppError::Internal(format!("写入文件失败: {}", e)))?;
let url = format!("/uploads/{}/{}", ctx.tenant_id, filename);
tracing::info!(
tenant_id = %ctx.tenant_id,
filename = %filename,
size = data_vec.len(),
content_type = %content_type,
"文件上传成功"
);
Ok(Json(ApiResponse::ok(UploadResp {
url,
filename: original_name,
size: data_vec.len() as u64,
content_type,
})))
}
/// 允许的文件类型
fn validate_content_type(content_type: &str) -> Result<(), AppError> {
const ALLOWED: &[&str] = &[
"image/jpeg",
"image/png",
"image/gif",
"image/webp",
"application/pdf",
"application/msword",
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
"application/vnd.ms-excel",
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
];
if !ALLOWED.contains(&content_type) {
return Err(AppError::Validation(format!(
"不支持的文件类型: {}",
content_type
)));
}
Ok(())
}
/// 校验文件 magic bytes文件签名与声明的 Content-Type 是否一致。
///
/// 防止攻击者通过修改 Content-Type 头上传恶意文件。
/// 对于 Office 格式等复杂签名,跳过 magic bytes 校验(仅依赖白名单)。
fn validate_magic_bytes(content_type: &str, data: &[u8]) -> Result<(), AppError> {
// 需要至少几个字节才能校验
if data.is_empty() {
return Err(AppError::Validation("文件内容为空".to_string()));
}
let signature: &[u8] = match content_type {
"image/jpeg" => {
// JPEG: FF D8 FF
b"\xFF\xD8\xFF"
}
"image/png" => {
// PNG: 89 50 4E 47 0D 0A 1A 0A
b"\x89PNG\r\n\x1A\n"
}
"image/gif" => {
// GIF: 47 49 46 38 (GIF8)
b"GIF8"
}
"image/webp" => {
// WebP: RIFF....WEBP (12 bytes)
// 前 4 字节: 52 49 46 46 (RIFF)
// 字节 8-11: 57 45 42 50 (WEBP)
if data.len() < 12 {
return Err(AppError::Validation(
"文件数据不足,无法验证 WebP 格式".to_string(),
));
}
let riff_ok = &data[0..4] == b"RIFF";
let webp_ok = &data[8..12] == b"WEBP";
if riff_ok && webp_ok {
return Ok(());
}
return Err(AppError::Validation(
"文件内容与声明的类型 (image/webp) 不匹配".to_string(),
));
}
"application/pdf" => {
// PDF: 25 50 44 46 (%PDF)
b"%PDF"
}
// Office 格式的 magic bytes 较复杂OLE2 / ZIP-based OOXML
// 仅依赖白名单,跳过 magic bytes 校验
"application/msword"
| "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
| "application/vnd.ms-excel"
| "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" => {
return Ok(());
}
_ => return Ok(()),
};
if data.len() < signature.len() {
return Err(AppError::Validation(
"文件数据不足,无法验证文件格式".to_string(),
));
}
if &data[..signature.len()] != signature {
return Err(AppError::Validation(format!(
"文件内容与声明的类型 ({}) 不匹配",
content_type
)));
}
Ok(())
}
fn format_size(bytes: u64) -> String {
if bytes >= 1024 * 1024 * 1024 {
format!("{}GB", bytes / (1024 * 1024 * 1024))
} else if bytes >= 1024 * 1024 {
format!("{}MB", bytes / (1024 * 1024))
} else {
format!("{}KB", bytes / 1024)
}
}

View File

@@ -0,0 +1,820 @@
mod config;
mod db;
mod handlers;
mod middleware;
mod outbox;
mod state;
mod tasks;
/// OpenAPI 规范定义 — 通过 utoipa derive 合并各模块 schema。
#[derive(OpenApi)]
#[openapi(info(
title = "ERP Platform API",
version = "0.1.0",
description = "ERP 平台底座 REST API 文档"
))]
struct ApiDoc;
/// Auth 模块的 OpenAPI 路径收集
#[derive(OpenApi)]
#[openapi(
paths(
erp_auth::handler::auth_handler::login,
erp_auth::handler::auth_handler::refresh,
erp_auth::handler::auth_handler::logout,
erp_auth::handler::auth_handler::change_password,
erp_auth::handler::user_handler::list_users,
erp_auth::handler::user_handler::create_user,
erp_auth::handler::user_handler::get_user,
erp_auth::handler::user_handler::update_user,
erp_auth::handler::user_handler::delete_user,
erp_auth::handler::user_handler::assign_roles,
erp_auth::handler::role_handler::list_roles,
erp_auth::handler::role_handler::create_role,
erp_auth::handler::role_handler::get_role,
erp_auth::handler::role_handler::update_role,
erp_auth::handler::role_handler::delete_role,
erp_auth::handler::role_handler::assign_permissions,
erp_auth::handler::role_handler::get_role_permissions,
erp_auth::handler::role_handler::list_permissions,
),
components(schemas(
erp_auth::dto::LoginReq,
erp_auth::dto::LoginResp,
erp_auth::dto::RefreshReq,
erp_auth::dto::UserResp,
erp_auth::dto::CreateUserReq,
erp_auth::dto::UpdateUserReq,
erp_auth::dto::RoleResp,
erp_auth::dto::CreateRoleReq,
erp_auth::dto::UpdateRoleReq,
erp_auth::dto::PermissionResp,
erp_auth::dto::AssignPermissionsReq,
erp_auth::dto::ChangePasswordReq,
))
)]
struct AuthApiDoc;
/// Config 模块的 OpenAPI 路径收集
#[derive(OpenApi)]
#[openapi(
paths(
erp_config::handler::dictionary_handler::list_dictionaries,
erp_config::handler::dictionary_handler::create_dictionary,
erp_config::handler::dictionary_handler::update_dictionary,
erp_config::handler::dictionary_handler::delete_dictionary,
erp_config::handler::dictionary_handler::list_items_by_code,
erp_config::handler::dictionary_handler::create_item,
erp_config::handler::dictionary_handler::update_item,
erp_config::handler::menu_handler::get_menus,
erp_config::handler::menu_handler::create_menu,
erp_config::handler::menu_handler::update_menu,
erp_config::handler::menu_handler::delete_menu,
erp_config::handler::numbering_handler::list_numbering_rules,
erp_config::handler::numbering_handler::create_numbering_rule,
erp_config::handler::numbering_handler::update_numbering_rule,
erp_config::handler::numbering_handler::generate_number,
erp_config::handler::numbering_handler::delete_numbering_rule,
erp_config::handler::theme_handler::get_theme,
erp_config::handler::theme_handler::update_theme,
erp_config::handler::language_handler::list_languages,
erp_config::handler::language_handler::update_language,
erp_config::handler::setting_handler::get_setting,
erp_config::handler::setting_handler::update_setting,
erp_config::handler::setting_handler::delete_setting,
),
components(schemas(
erp_config::dto::DictionaryResp,
erp_config::dto::CreateDictionaryReq,
erp_config::dto::UpdateDictionaryReq,
erp_config::dto::DictionaryItemResp,
erp_config::dto::CreateDictionaryItemReq,
erp_config::dto::UpdateDictionaryItemReq,
erp_config::dto::MenuResp,
erp_config::dto::CreateMenuReq,
erp_config::dto::UpdateMenuReq,
erp_config::dto::NumberingRuleResp,
erp_config::dto::CreateNumberingRuleReq,
erp_config::dto::UpdateNumberingRuleReq,
erp_config::dto::ThemeResp,
))
)]
struct ConfigApiDoc;
/// Workflow 模块的 OpenAPI 路径收集
#[derive(OpenApi)]
#[openapi(
paths(
erp_workflow::handler::definition_handler::list_definitions,
erp_workflow::handler::definition_handler::create_definition,
erp_workflow::handler::definition_handler::get_definition,
erp_workflow::handler::definition_handler::update_definition,
erp_workflow::handler::definition_handler::publish_definition,
erp_workflow::handler::instance_handler::start_instance,
erp_workflow::handler::instance_handler::list_instances,
erp_workflow::handler::instance_handler::get_instance,
erp_workflow::handler::instance_handler::suspend_instance,
erp_workflow::handler::instance_handler::terminate_instance,
erp_workflow::handler::instance_handler::resume_instance,
erp_workflow::handler::task_handler::list_pending_tasks,
erp_workflow::handler::task_handler::list_completed_tasks,
erp_workflow::handler::task_handler::complete_task,
erp_workflow::handler::task_handler::delegate_task,
),
components(schemas(
erp_workflow::dto::ProcessDefinitionResp,
erp_workflow::dto::CreateProcessDefinitionReq,
erp_workflow::dto::UpdateProcessDefinitionReq,
erp_workflow::dto::ProcessInstanceResp,
erp_workflow::dto::StartInstanceReq,
erp_workflow::dto::TaskResp,
erp_workflow::dto::CompleteTaskReq,
erp_workflow::dto::DelegateTaskReq,
))
)]
struct WorkflowApiDoc;
/// Message 模块的 OpenAPI 路径收集
#[derive(OpenApi)]
#[openapi(
paths(
erp_message::handler::message_handler::list_messages,
erp_message::handler::message_handler::unread_count,
erp_message::handler::message_handler::send_message,
erp_message::handler::message_handler::mark_read,
erp_message::handler::message_handler::mark_all_read,
erp_message::handler::message_handler::delete_message,
erp_message::handler::template_handler::list_templates,
erp_message::handler::template_handler::create_template,
erp_message::handler::subscription_handler::update_subscription,
),
components(schemas(
erp_message::dto::MessageResp,
erp_message::dto::SendMessageReq,
erp_message::dto::MessageQuery,
erp_message::dto::UnreadCountResp,
erp_message::dto::MessageTemplateResp,
erp_message::dto::CreateTemplateReq,
erp_message::dto::MessageSubscriptionResp,
erp_message::dto::UpdateSubscriptionReq,
))
)]
struct MessageApiDoc;
use axum::Router;
use axum::middleware as axum_middleware;
use config::AppConfig;
use erp_auth::middleware::jwt_auth_middleware_fn;
use state::AppState;
use tower_http::services::ServeDir;
use tracing_subscriber::EnvFilter;
use utoipa::OpenApi;
use erp_core::events::EventBus;
use erp_core::module::{ErpModule, ModuleContext, ModuleRegistry};
use erp_server_migration::MigratorTrait;
use sea_orm::{ConnectionTrait, FromQueryResult};
#[tokio::main]
async fn main() -> anyhow::Result<()> {
// Load config
let config = AppConfig::load()?;
// ── 安全检查:拒绝默认密钥 ──────────────────────────
if config.jwt.secret == "__MUST_SET_VIA_ENV__" || config.jwt.secret == "change-me-in-production"
{
tracing::error!("JWT 密钥为默认值,拒绝启动。请设置环境变量 ERP__JWT__SECRET");
std::process::exit(1);
}
if config.database.url == "__MUST_SET_VIA_ENV__" {
tracing::error!("数据库 URL 为默认占位值,拒绝启动。请设置环境变量 ERP__DATABASE__URL");
std::process::exit(1);
}
if config.redis.url == "__MUST_SET_VIA_ENV__" {
tracing::error!("Redis URL 为默认占位值,拒绝启动。请设置环境变量 ERP__REDIS__URL");
std::process::exit(1);
}
if !config.wechat.dev_mode
&& (config.wechat.appid == "__MUST_SET_VIA_ENV__"
|| config.wechat.secret == "__MUST_SET_VIA_ENV__")
{
tracing::error!(
"微信凭据为默认占位值,拒绝启动。请设置环境变量 ERP__WECHAT__APPID 和 ERP__WECHAT__SECRET"
);
std::process::exit(1);
}
// Initialize tracing
tracing_subscriber::fmt()
.with_env_filter(
EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(&config.log.level)),
)
.json()
.init();
tracing::info!(
version = env!("CARGO_PKG_VERSION"),
"ERP Server starting..."
);
// Connect to database
let db = db::connect(&config.database).await?;
// Run migrations
erp_server_migration::Migrator::up(&db, None).await?;
tracing::info!("Database migrations applied");
// Seed default tenant and auth data if not present, and resolve the actual tenant ID
let default_tenant_id = {
#[derive(sea_orm::FromQueryResult)]
struct TenantId {
id: uuid::Uuid,
}
let existing = TenantId::find_by_statement(sea_orm::Statement::from_string(
sea_orm::DatabaseBackend::Postgres,
"SELECT id FROM tenant WHERE deleted_at IS NULL LIMIT 1".to_string(),
))
.one(&db)
.await
.map_err(|e| anyhow::anyhow!("Failed to query tenants: {}", e))?;
match existing {
Some(row) => {
tracing::info!(tenant_id = %row.id, "Default tenant already exists, skipping seed");
row.id
}
None => {
let new_tenant_id = uuid::Uuid::now_v7();
// Insert default tenant using raw SQL (no tenant entity in erp-server)
db.execute(sea_orm::Statement::from_sql_and_values(
sea_orm::DatabaseBackend::Postgres,
"INSERT INTO tenant (id, name, code, status, created_at, updated_at) VALUES ($1, $2, $3, $4, NOW(), NOW())",
[
new_tenant_id.into(),
"Default Tenant".into(),
"default".into(),
"active".into(),
],
))
.await
.map_err(|e| anyhow::anyhow!("Failed to create default tenant: {}", e))?;
tracing::info!(tenant_id = %new_tenant_id, "Created default tenant");
// Seed auth data (permissions, roles, admin user)
erp_auth::service::seed::seed_tenant_auth(
&db,
new_tenant_id,
&config.auth.super_admin_password,
)
.await
.map_err(|e| anyhow::anyhow!("Failed to seed auth data: {}", e))?;
tracing::info!(tenant_id = %new_tenant_id, "Default tenant ready with auth seed data");
// Seed AI workflow definitions
if let Err(e) =
erp_workflow::service::ai_workflow_seed::ensure_ai_workflows(&db, new_tenant_id)
.await
{
tracing::warn!(error = %e, "Failed to seed AI workflow definitions");
}
new_tenant_id
}
}
};
// Connect to Redis
let redis_client = redis::Client::open(&config.redis.url[..])?;
tracing::info!("Redis client created");
// Initialize event bus (capacity 1024 events)
let event_bus = EventBus::new(1024);
// Initialize auth module
let auth_module = erp_auth::AuthModule::new();
tracing::info!(
module = auth_module.name(),
version = auth_module.version(),
"Auth module initialized"
);
// Initialize config module
let config_module = erp_config::ConfigModule::new();
tracing::info!(
module = config_module.name(),
version = config_module.version(),
"Config module initialized"
);
// Initialize workflow module
let workflow_module = erp_workflow::WorkflowModule::new();
tracing::info!(
module = workflow_module.name(),
version = workflow_module.version(),
"Workflow module initialized"
);
// Initialize message module
let message_module = erp_message::MessageModule::new();
tracing::info!(
module = message_module.name(),
version = message_module.version(),
"Message module initialized"
);
// Initialize diary module (暖记业务)
let diary_module = erp_diary::DiaryModule;
tracing::info!(
module = diary_module.name(),
version = diary_module.version(),
"Diary module initialized"
);
// Initialize module registry and register modules
let registry = ModuleRegistry::new()
.register(auth_module)
.register(config_module)
.register(workflow_module)
.register(message_module)
.register(diary_module);
tracing::info!(
module_count = registry.modules().len(),
"Modules registered"
);
// Initialize plugin engine
let plugin_config = erp_plugin::engine::PluginEngineConfig::default();
let plugin_engine =
erp_plugin::engine::PluginEngine::new(db.clone(), event_bus.clone(), plugin_config)?;
tracing::info!("Plugin engine initialized");
// Register plugin module
let plugin_module = erp_plugin::module::PluginModule;
let registry = registry.register(plugin_module);
// Register event handlers
registry.register_handlers(&event_bus);
// Startup all modules (按拓扑顺序调用 on_startup)
let module_ctx = ModuleContext {
db: db.clone(),
event_bus: event_bus.clone(),
};
registry.startup_all(&module_ctx).await?;
tracing::info!("All modules started");
// 同步所有模块声明的权限到数据库upsert
sync_module_permissions(&db, &registry, default_tenant_id).await?;
// 恢复运行中的插件(服务器重启后自动重新加载)
match plugin_engine.recover_plugins(&db).await {
Ok(recovered) => {
let count: usize = recovered.len();
tracing::info!(count, "Plugins recovered");
}
Err(e) => {
tracing::error!(error = %e, "Failed to recover plugins");
}
}
// Start message event listener (workflow events → message notifications)
erp_message::MessageModule::start_event_listener(db.clone(), event_bus.clone());
tracing::info!("Message event listener started");
// Start plugin notification listener (plugin.trigger.* → admin notifications)
erp_plugin::notification::start_notification_listener(db.clone(), event_bus.clone());
tracing::info!("Plugin notification listener started");
// Start outbox relay (LISTEN/NOTIFY + fallback poll for pending domain events)
outbox::start_outbox_relay(db.clone(), event_bus.clone(), config.database.url.clone());
tracing::info!("Outbox relay started");
// Start timeout checker (scan overdue tasks every 60s)
erp_workflow::WorkflowModule::start_timeout_checker(db.clone(), event_bus.clone());
tracing::info!("Timeout checker started");
let host = config.server.host.clone();
let port = config.server.port;
// Extract JWT secret for middleware construction
let jwt_secret = config.jwt.secret.clone();
// Build PII crypto — used by auth module for token encryption
let pii_crypto = if config.crypto.kek == "__MUST_SET_VIA_ENV__" {
#[cfg(debug_assertions)]
{
tracing::warn!("⚠️ PII KEK 使用开发默认值,仅用于本地开发");
erp_core::crypto::PiiCrypto::dev_default()
}
#[cfg(not(debug_assertions))]
{
panic!(
"ERP__CRYPTO__KEK must be set in production. Use a 64-char hex string (32 bytes)."
);
}
} else {
erp_core::crypto::PiiCrypto::from_kek_hex(&config.crypto.kek)
.expect("PII KEK must be valid 64-char hex (32 bytes). Set ERP__CRYPTO__KEK")
};
// Build shared state
let cron_heartbeat = std::sync::Arc::new(std::sync::atomic::AtomicU64::new(
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs(),
));
let state = AppState {
db,
config,
event_bus,
module_registry: registry,
redis: redis_client.clone(),
default_tenant_id,
plugin_engine,
plugin_entity_cache: moka::sync::Cache::builder()
.max_capacity(1000)
.time_to_idle(std::time::Duration::from_secs(300))
.build(),
pii_crypto,
cron_heartbeat: cron_heartbeat.clone(),
};
// Start background tasks with heartbeat
tasks::start_event_cleanup(state.db.clone(), state.cron_heartbeat.clone());
tasks::start_pool_metrics(state.db.clone(), state.cron_heartbeat.clone());
// --- Build the router ---
//
// The router is split into two layers:
// 1. Public routes: no JWT required (health, login, refresh)
// 2. Protected routes: JWT required (user CRUD, logout)
//
// Both layers share the same AppState. The protected layer wraps routes
// with the jwt_auth_middleware_fn.
// Public routes (no authentication, but IP-based rate limiting)
// Layer execution order (outer → inner): account_lockout → rate_limit_by_ip
// So account lockout check runs FIRST, then IP rate limiting
let public_routes = Router::new()
.merge(erp_auth::AuthModule::public_routes())
.layer(axum::middleware::from_fn_with_state(
state.clone(),
middleware::rate_limit::account_lockout_middleware,
))
.layer(axum::middleware::from_fn_with_state(
state.clone(),
middleware::rate_limit::rate_limit_by_ip,
))
.with_state(state.clone());
// Refresh token routes — higher rate limit (30/min) than login (5/min)
let refresh_routes = Router::new()
.merge(erp_auth::AuthModule::refresh_routes())
.layer(axum::middleware::from_fn_with_state(
state.clone(),
middleware::rate_limit::rate_limit_refresh_by_ip,
))
.with_state(state.clone());
// Unthrottled public routes (health, docs, brand) — no rate limiting
let unthrottled_routes = Router::new()
.merge(handlers::health::health_check_router())
.route(
"/docs/openapi.json",
axum::routing::get(handlers::openapi::openapi_spec),
)
.merge(erp_config::ConfigModule::public_routes())
.with_state(state.clone());
// Clone jwt_secret for upload auth before protected_routes closure moves it
let secret_for_uploads = jwt_secret.clone();
// Protected routes (JWT authentication required)
// User-based rate limiting (100 req/min) applied after JWT auth
let protected_routes = erp_auth::AuthModule::protected_routes()
.merge(erp_config::ConfigModule::protected_routes())
.merge(erp_workflow::WorkflowModule::protected_routes())
.merge(erp_message::MessageModule::protected_routes())
.merge(erp_plugin::module::PluginModule::protected_routes())
.merge(erp_diary::DiaryModule::protected_routes())
.merge(handlers::audit_log::audit_log_router())
.route(
"/upload",
axum::routing::post(handlers::upload::upload_file),
)
.route(
"/admin/tenants/{id}/rotate-key",
axum::routing::post(handlers::crypto_admin::rotate_tenant_key),
)
.route(
"/analytics/batch",
axum::routing::post(handlers::analytics::batch),
)
.layer(axum::middleware::from_fn(
middleware::frozen_module::frozen_module_middleware,
))
.layer(axum::middleware::from_fn_with_state(
state.clone(),
middleware::rate_limit::rate_limit_by_user,
))
.layer({
let db = state.db.clone();
let jwt_secret_for_auth = jwt_secret.clone();
axum_middleware::from_fn(move |req, next| {
let secret = jwt_secret_for_auth.clone();
let db = db.clone();
async move { jwt_auth_middleware_fn(secret, Some(db), req, next).await }
})
})
// Tenant RLS — 在 JWT 之后执行SET app.current_tenant_id
.layer({
let db = state.db.clone();
axum_middleware::from_fn(move |req, next| {
let db = db.clone();
async move { middleware::tenant_rls::tenant_rls_middleware(db, req, next).await }
})
})
.with_state(state.clone());
// Merge public + protected into the final application router
// All API routes are nested under /api/v1
let cors = build_cors_layer(&state.config.cors.allowed_origins);
let upload_dir = state.config.storage.upload_dir.clone();
let uploads_router = Router::new()
.fallback_service(ServeDir::new(&upload_dir))
.layer(axum_middleware::from_fn(move |req, next| {
let secret = secret_for_uploads.clone();
async move { upload_auth_middleware(secret, req, next).await }
}));
let app = Router::new()
.nest(
"/api/v1",
unthrottled_routes
.merge(public_routes)
.merge(refresh_routes)
.merge(protected_routes),
)
.nest("/uploads", uploads_router)
.layer(axum::middleware::from_fn(
middleware::metrics::metrics_middleware,
))
.layer(cors)
.layer(axum::middleware::from_fn(security_headers_middleware));
// Start Prometheus metrics exporter on a separate port
let metrics_port = state.config.server.metrics_port;
middleware::metrics::start_metrics_server(metrics_port);
let addr = format!("{}:{}", host, port);
let listener = tokio::net::TcpListener::bind(&addr).await?;
tracing::info!(addr = %addr, "Server listening");
// Graceful shutdown on CTRL+C
axum::serve(listener, app)
.with_graceful_shutdown(shutdown_signal())
.await?;
// 优雅关闭所有模块(按拓扑逆序)
state.module_registry.shutdown_all().await?;
tracing::info!("Server shutdown complete");
Ok(())
}
/// JWT auth middleware for `/uploads` file serving.
///
/// Accepts token from either `Authorization: Bearer <token>` header
/// or `?token=<token>` query parameter (for browser `<img>` / direct downloads).
async fn upload_auth_middleware(
jwt_secret: String,
req: axum::extract::Request,
next: axum::middleware::Next,
) -> Result<axum::response::Response, erp_core::error::AppError> {
use erp_auth::service::token_service::TokenService;
let token = req
.headers()
.get(axum::http::header::AUTHORIZATION)
.and_then(|v| v.to_str().ok())
.and_then(|v| v.strip_prefix("Bearer "))
.map(|s| s.to_string())
.or_else(|| {
req.uri().query().and_then(|q| {
q.split('&').find_map(|pair| {
let (k, v) = pair.split_once('=').unwrap_or((pair, ""));
if k == "token" && !v.is_empty() {
Some(v.to_string())
} else {
None
}
})
})
});
let token = token.ok_or(erp_core::error::AppError::Unauthorized)?;
let claims = TokenService::decode_token(&token, &jwt_secret)
.map_err(|_| erp_core::error::AppError::Unauthorized)?;
if claims.token_type != "access" {
return Err(erp_core::error::AppError::Unauthorized);
}
Ok(next.run(req).await)
}
/// Build a CORS layer from the comma-separated allowed origins config.
///
/// If the config is "*", allows all origins (development mode).
/// Otherwise, parses each origin as a URL and restricts to those origins only.
fn build_cors_layer(allowed_origins: &str) -> tower_http::cors::CorsLayer {
use axum::http::HeaderValue;
use tower_http::cors::AllowOrigin;
let origins = allowed_origins
.split(',')
.map(|s| s.trim())
.filter(|s| !s.is_empty())
.collect::<Vec<_>>();
if origins.len() == 1 && origins[0] == "*" {
#[cfg(not(debug_assertions))]
{
tracing::error!("CORS wildcard '*' is not allowed in production builds");
panic!(
"Refusing to start with CORS wildcard in release mode. Set ERP__CORS__ALLOWED_ORIGINS to specific domains."
);
}
#[cfg(debug_assertions)]
{
tracing::warn!(
"⚠️ CORS 允许所有来源 — 仅限开发环境使用!\
生产环境请通过 ERP__CORS__ALLOWED_ORIGINS 设置具体的来源域名"
);
return tower_http::cors::CorsLayer::permissive();
}
}
let allowed: Vec<HeaderValue> = origins
.iter()
.filter_map(|o| o.parse::<HeaderValue>().ok())
.collect();
tracing::info!(origins = ?origins, "CORS: restricting to allowed origins");
tower_http::cors::CorsLayer::new()
.allow_origin(AllowOrigin::list(allowed))
.allow_methods([
axum::http::Method::GET,
axum::http::Method::POST,
axum::http::Method::PUT,
axum::http::Method::DELETE,
axum::http::Method::PATCH,
])
.allow_headers([
axum::http::header::AUTHORIZATION,
axum::http::header::CONTENT_TYPE,
])
.allow_credentials(true)
}
async fn security_headers_middleware(
req: axum::extract::Request,
next: axum::middleware::Next,
) -> axum::response::Response {
use axum::http::{HeaderValue, header};
let mut response = next.run(req).await;
let headers = response.headers_mut();
headers.insert(header::X_FRAME_OPTIONS, HeaderValue::from_static("DENY"));
headers.insert(
header::X_CONTENT_TYPE_OPTIONS,
HeaderValue::from_static("nosniff"),
);
headers.insert(
header::HeaderName::from_static("x-xss-protection"),
HeaderValue::from_static("1; mode=block"),
);
headers.insert(
header::HeaderName::from_static("referrer-policy"),
HeaderValue::from_static("strict-origin-when-cross-origin"),
);
headers.insert(
header::STRICT_TRANSPORT_SECURITY,
HeaderValue::from_static("max-age=63072000; includeSubDomains; preload"),
);
headers.insert(
header::HeaderName::from_static("content-security-policy"),
HeaderValue::from_static(
"default-src 'self'; script-src 'self'; style-src 'self' 'unsafe-inline'; \
img-src 'self' data: blob: https:; connect-src 'self' wss:; \
frame-ancestors 'none'; base-uri 'self'; form-action 'self'",
),
);
headers.insert(
header::HeaderName::from_static("permissions-policy"),
HeaderValue::from_static("camera=(), microphone=(), geolocation=(), payment=()"),
);
response
}
async fn shutdown_signal() {
let ctrl_c = async {
tokio::signal::ctrl_c()
.await
.expect("failed to install CTRL+C handler");
};
#[cfg(unix)]
let terminate = async {
tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
.expect("failed to install signal handler")
.recv()
.await;
};
#[cfg(not(unix))]
let terminate = std::future::pending::<()>();
tokio::select! {
_ = ctrl_c => {
tracing::info!("Received CTRL+C, shutting down gracefully...");
},
_ = terminate => {
tracing::info!("Received SIGTERM, shutting down gracefully...");
},
}
}
/// 同步所有模块声明的权限到数据库。
///
/// 对每个模块的 `permissions()` 返回的权限执行 upsert
/// - 新权限INSERT
/// - 已有权限(同 tenant_id + code跳过
///
/// 同时将新权限分配给 admin 角色。
async fn sync_module_permissions(
db: &sea_orm::DatabaseConnection,
registry: &erp_core::module::ModuleRegistry,
tenant_id: uuid::Uuid,
) -> Result<(), anyhow::Error> {
let system_user_id = uuid::Uuid::nil();
let mut total_new = 0u32;
for module in registry.modules() {
let perms = module.permissions();
if perms.is_empty() {
continue;
}
for perm in perms {
let result = db.execute(sea_orm::Statement::from_sql_and_values(
sea_orm::DatabaseBackend::Postgres,
r#"INSERT INTO permissions (id, tenant_id, code, name, resource, action, description, created_at, updated_at, created_by, updated_by, deleted_at, version)
VALUES ($1, $2, $3, $4, $5, $6, $7, NOW(), NOW(), $8, $8, NULL, 1)
ON CONFLICT (tenant_id, code) WHERE deleted_at IS NULL DO NOTHING"#,
[
uuid::Uuid::now_v7().into(),
tenant_id.into(),
perm.code.clone().into(),
perm.name.clone().into(),
perm.module.clone().into(),
perm.code.split('.').next_back().unwrap_or("manage").into(),
perm.description.clone().into(),
system_user_id.into(),
],
)).await?;
let rows = result.rows_affected();
if rows > 0 {
total_new += 1;
}
}
}
// 每次启动都确保 admin 角色拥有所有模块权限(防止权限-角色关联缺失)
db.execute(sea_orm::Statement::from_sql_and_values(
sea_orm::DatabaseBackend::Postgres,
r#"INSERT INTO role_permissions (role_id, permission_id, tenant_id, data_scope, created_at, updated_at, created_by, updated_by, deleted_at, version)
SELECT r.id, p.id, p.tenant_id, 'all', NOW(), NOW(), $1, $1, NULL, 1
FROM permissions p
JOIN roles r ON r.code = 'admin' AND r.tenant_id = p.tenant_id AND r.deleted_at IS NULL
WHERE p.tenant_id = $2
ON CONFLICT DO NOTHING"#,
[system_user_id.into(), tenant_id.into()],
)).await?;
if total_new > 0 {
tracing::info!(
total_new,
"New module permissions synced and bound to admin role"
);
}
Ok(())
}

View File

@@ -0,0 +1,37 @@
use axum::Json;
use axum::body::Body;
use axum::http::{Request, StatusCode};
use axum::middleware::Next;
use axum::response::{IntoResponse, Response};
/// 冻结模块路径前缀列表。
///
/// 这些模块前端已通过 FROZEN_ROUTES 守卫拦截,后端也需同步拦截,
/// 防止直接调 API 绕过限制。
const FROZEN_PREFIXES: &[&str] = &[
"/api/v1/health/care-plans",
"/api/v1/health/shifts",
"/api/v1/health/family-proxy",
"/api/v1/health/medications",
"/api/v1/health/dialysis",
"/api/v1/health/schedules",
];
pub async fn frozen_module_middleware(req: Request<Body>, next: Next) -> Response {
let path = req.uri().path();
for prefix in FROZEN_PREFIXES {
if path.starts_with(prefix) {
return (
StatusCode::FORBIDDEN,
Json(serde_json::json!({
"success": false,
"error": "该功能正在优化中,暂不可用"
})),
)
.into_response();
}
}
next.run(req).await
}

View File

@@ -0,0 +1,126 @@
use axum::extract::Request;
use axum::http::Method;
use axum::middleware::Next;
use axum::response::{IntoResponse, Response};
use metrics::{counter, histogram};
use std::time::Instant;
/// HTTP 请求指标中间件。
///
/// 记录两个 Prometheus 指标:
/// - `http_requests_total` — 计数器,标签: method, path, status
/// - `http_request_duration_seconds` — 直方图,标签: method, path, status
pub async fn metrics_middleware(req: Request, next: Next) -> Response {
let method = method_label(req.method());
let path = path_label(req.uri().path());
let start = Instant::now();
let resp = next.run(req).await;
let elapsed = start.elapsed();
let status = resp.status().as_u16().to_string();
let labels = [
("method", method.clone()),
("path", path.clone()),
("status", status.clone()),
];
counter!("http_requests_total", &labels).increment(1);
histogram!("http_request_duration_seconds", &labels).record(elapsed.as_secs_f64());
resp
}
fn method_label(method: &Method) -> String {
method.as_str().to_owned()
}
/// 归一化路径:将 UUID 段替换为 `:id`,避免高基数。
fn path_label(path: &str) -> String {
let parts: Vec<&str> = path
.split('/')
.filter(|s| !s.is_empty())
.map(|s| if looks_like_uuid(s) { ":id" } else { s })
.collect();
if parts.is_empty() {
"/".to_string()
} else {
format!("/{}", parts.join("/"))
}
}
fn looks_like_uuid(s: &str) -> bool {
s.len() == 36
&& s.chars().filter(|c| *c == '-').count() == 4
&& s.chars().all(|c| c.is_ascii_hexdigit() || c == '-')
}
/// 在独立端口启动 Prometheus exporter。
pub fn start_metrics_server(port: u16) {
let builder = metrics_exporter_prometheus::PrometheusBuilder::new();
let recorder = builder.build_recorder();
let handle = recorder.handle();
if let Err(e) = metrics::set_global_recorder(recorder) {
tracing::error!(error = %e, "Failed to install Prometheus recorder");
return;
}
tokio::spawn(async move {
let app = axum::Router::new()
.route(
"/metrics",
axum::routing::get(move || {
let handle = handle.clone();
async move {
let body = handle.render();
axum::response::IntoResponse::into_response((
[(
axum::http::header::CONTENT_TYPE,
"text/plain; version=0.0.4",
)],
body,
))
}
}),
)
.fallback(|| async { axum::http::StatusCode::NOT_FOUND.into_response() as Response });
let addr = format!("0.0.0.0:{port}");
match tokio::net::TcpListener::bind(&addr).await {
Ok(listener) => {
tracing::info!(addr = %addr, "Prometheus metrics server listening");
if let Err(e) = axum::serve(listener, app).await {
tracing::error!(error = %e, "Metrics server error");
}
}
Err(e) => {
tracing::error!(error = %e, addr = %addr, "Failed to bind metrics server");
}
}
});
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn path_label_normalizes_uuids() {
assert_eq!(path_label("/api/v1/users"), "/api/v1/users");
assert_eq!(
path_label("/api/v1/users/01234567-89ab-cdef-0123-456789abcdef/posts"),
"/api/v1/users/:id/posts"
);
assert_eq!(path_label("/"), "/");
assert_eq!(path_label(""), "/");
}
#[test]
fn is_uuid_checks_format() {
assert!(looks_like_uuid("01234567-89ab-cdef-0123-456789abcdef"));
assert!(!looks_like_uuid("not-a-uuid"));
assert!(!looks_like_uuid("short"));
}
}

View File

@@ -0,0 +1,4 @@
pub mod frozen_module;
pub mod metrics;
pub mod rate_limit;
pub mod tenant_rls;

View File

@@ -0,0 +1,326 @@
use axum::body::Body;
use axum::extract::State;
use axum::http::{Request, StatusCode};
use axum::middleware::Next;
use axum::response::{IntoResponse, Response};
use redis::AsyncCommands;
use serde::Serialize;
use std::sync::atomic::{AtomicU64, Ordering};
use crate::state::AppState;
/// Redis 连接失败时间戳缓存毫秒5 秒内复用失败状态避免重复连接尝试
static REDIS_LAST_FAIL_MS: AtomicU64 = AtomicU64::new(0);
const REDIS_FAIL_CACHE_SECS: u64 = 5;
fn now_ms() -> u64 {
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_millis() as u64
}
fn is_redis_cached_failed() -> bool {
let last = REDIS_LAST_FAIL_MS.load(Ordering::Relaxed);
last > 0 && now_ms().saturating_sub(last) < REDIS_FAIL_CACHE_SECS * 1000
}
fn mark_redis_failed() {
REDIS_LAST_FAIL_MS.store(now_ms(), Ordering::Relaxed);
}
/// 限流错误响应。
#[derive(Serialize)]
struct RateLimitResponse {
error: String,
message: String,
}
/// 账户锁定配置。
const ACCOUNT_LOCKOUT_MAX_FAILURES: i64 = 5;
const ACCOUNT_LOCKOUT_TTL_SECS: i64 = 900; // 15 分钟
/// 基于 Redis 的 IP 限流中间件登录等敏感操作5 次/分钟)。
pub async fn rate_limit_by_ip(
State(state): State<AppState>,
req: Request<Body>,
next: Next,
) -> Response {
let identifier = extract_client_ip(req.headers());
let fail_close = state.config.rate_limit.fail_close;
apply_rate_limit(
RateLimitParams {
redis_client: &state.redis,
fail_close,
max_requests: 5,
window_secs: 60,
prefix: "login",
},
&identifier,
req,
next,
)
.await
}
/// 基于 Redis 的 IP 限流中间件Token 刷新30 次/分钟)。
pub async fn rate_limit_refresh_by_ip(
State(state): State<AppState>,
req: Request<Body>,
next: Next,
) -> Response {
let identifier = extract_client_ip(req.headers());
let fail_close = state.config.rate_limit.fail_close;
apply_rate_limit(
RateLimitParams {
redis_client: &state.redis,
fail_close,
max_requests: 30,
window_secs: 60,
prefix: "refresh",
},
&identifier,
req,
next,
)
.await
}
/// 基于 Redis 的用户限流中间件。
///
/// 从 TenantContext 中读取 user_id 作为标识符。
pub async fn rate_limit_by_user(
State(state): State<AppState>,
req: Request<Body>,
next: Next,
) -> Response {
let identifier = req
.extensions()
.get::<erp_core::types::TenantContext>()
.map(|ctx| ctx.user_id.to_string())
.unwrap_or_else(|| "anonymous".to_string());
let fail_close = state.config.rate_limit.fail_close;
apply_rate_limit(
RateLimitParams {
redis_client: &state.redis,
fail_close,
max_requests: 300,
window_secs: 60,
prefix: "api",
},
&identifier,
req,
next,
)
.await
}
/// Redis 不可达时的安全响应fail-close 模式)。
fn service_unavailable(prefix: &str) -> Response {
let body = RateLimitResponse {
error: "Service Unavailable".to_string(),
message: "服务暂时不可用,请稍后重试".to_string(),
};
tracing::error!("Redis 不可达fail-close 模式拒绝请求 [{}]", prefix);
(StatusCode::SERVICE_UNAVAILABLE, axum::Json(body)).into_response()
}
/// 限流参数,打包以避免函数签名过长。
struct RateLimitParams<'a> {
redis_client: &'a redis::Client,
fail_close: bool,
max_requests: u64,
window_secs: u64,
prefix: &'a str,
}
/// 执行限流检查。
async fn apply_rate_limit(
params: RateLimitParams<'_>,
identifier: &str,
req: Request<Body>,
next: Next,
) -> Response {
// 快速路径Redis 在缓存期内已知不可用,跳过连接尝试
if is_redis_cached_failed() {
if params.fail_close {
return service_unavailable(params.prefix);
}
return next.run(req).await;
}
let key = format!("rate_limit:{}:{}", params.prefix, identifier);
let mut conn = match params.redis_client.get_multiplexed_async_connection().await {
Ok(c) => c,
Err(e) => {
mark_redis_failed();
tracing::warn!(error = %e, "Redis 连接失败 [{}]{}秒内不再重试)", params.prefix, REDIS_FAIL_CACHE_SECS);
if params.fail_close {
return service_unavailable(params.prefix);
}
return next.run(req).await;
}
};
let count: i64 = match redis::cmd("INCR").arg(&key).query_async(&mut conn).await {
Ok(n) => n,
Err(e) => {
mark_redis_failed();
tracing::warn!(error = %e, "Redis INCR 失败 [{}]", params.prefix);
if params.fail_close {
return service_unavailable(params.prefix);
}
return next.run(req).await;
}
};
// 首次请求设置 TTL
if count == 1 {
let _: Result<(), _> = conn.expire(&key, params.window_secs as i64).await;
}
if count > params.max_requests as i64 {
let body = RateLimitResponse {
error: "Too Many Requests".to_string(),
message: "请求过于频繁,请稍后重试".to_string(),
};
return (StatusCode::TOO_MANY_REQUESTS, axum::Json(body)).into_response();
}
next.run(req).await
}
/// 账户级登录锁定中间件。
///
/// 针对登录接口POST /api/v1/auth/login在 IP 限流之前执行:
/// 1. 解析请求体提取 username
/// 2. 检查 Redis 中该 username 的失败次数
/// 3. 超过阈值5次则拒绝请求
/// 4. 观察响应状态码401 递增失败计数200 清除计数
pub async fn account_lockout_middleware(
State(state): State<AppState>,
req: Request<Body>,
next: Next,
) -> Response {
let fail_close = state.config.rate_limit.fail_close;
// 获取 Redis 连接
let mut conn = match state.redis.get_multiplexed_async_connection().await {
Ok(c) => c,
Err(e) => {
mark_redis_failed();
tracing::warn!(error = %e, "Redis 连接失败 [login_lockout]");
if fail_close {
return service_unavailable("login_lockout");
}
return next.run(req).await;
}
};
// 读取请求体以提取 username
let (parts, body) = req.into_parts();
let bytes = match axum::body::to_bytes(body, 1024).await {
Ok(b) => b,
Err(e) => {
tracing::warn!(error = %e, "读取登录请求体失败,放行");
let req = Request::from_parts(parts, Body::from(Vec::new()));
return next.run(req).await;
}
};
// 解析 username
let username = serde_json::from_slice::<serde_json::Value>(&bytes)
.ok()
.and_then(|v| v.get("username")?.as_str().map(|s| s.to_string()));
let username = match username {
Some(u) if !u.is_empty() => u,
_ => {
let req = Request::from_parts(parts, Body::from(bytes.to_vec()));
return next.run(req).await;
}
};
// 检查账户锁定状态
let lockout_key = format!("login_fail:{}", username);
let fail_count: i64 = conn.get(&lockout_key).await.unwrap_or(0);
if fail_count >= ACCOUNT_LOCKOUT_MAX_FAILURES {
tracing::warn!(
username = %username,
fail_count = fail_count,
"账户已被临时锁定"
);
let body = RateLimitResponse {
error: "Too Many Requests".to_string(),
message: "账户已被临时锁定请15分钟后重试".to_string(),
};
return (StatusCode::TOO_MANY_REQUESTS, axum::Json(body)).into_response();
}
// 用原始 body 重建请求,转发到 handler
let req = Request::from_parts(parts, Body::from(bytes.to_vec()));
let response = next.run(req).await;
// 观察响应状态码
let status = response.status();
let (parts, body) = response.into_parts();
let body_bytes = axum::body::to_bytes(body, 1024 * 1024)
.await
.unwrap_or_default();
if status == StatusCode::UNAUTHORIZED {
// 登录失败:递增失败计数
let new_count: i64 = match redis::cmd("INCR")
.arg(&lockout_key)
.query_async(&mut conn)
.await
{
Ok(n) => n,
Err(e) => {
tracing::warn!(error = %e, "Redis INCR 失败计数失败");
let resp = Response::from_parts(parts, Body::from(body_bytes.to_vec()));
return resp;
}
};
// 首次失败时设置 TTL
if new_count == 1 {
let _: Result<(), _> = conn.expire(&lockout_key, ACCOUNT_LOCKOUT_TTL_SECS).await;
}
tracing::info!(
username = %username,
fail_count = new_count,
remaining = ACCOUNT_LOCKOUT_MAX_FAILURES - new_count,
"登录失败,递增失败计数"
);
} else if status.is_success() {
// 登录成功:清除失败计数
let _: Result<(), _> = conn.del(&lockout_key).await;
tracing::info!(username = %username, "登录成功,清除失败计数");
}
// 重建并返回原始响应
Response::from_parts(parts, Body::from(body_bytes.to_vec()))
}
/// 从请求头中提取客户端 IP。
fn extract_client_ip(headers: &axum::http::HeaderMap) -> String {
headers
.get("x-forwarded-for")
.or_else(|| headers.get("x-real-ip"))
.and_then(|v| v.to_str().ok())
.map(|s| {
// x-forwarded-for 可能包含多个 IP取第一个
s.split(',').next().unwrap_or(s).trim().to_string()
})
.unwrap_or_else(|| "unknown".to_string())
}
// NOTE: rate_limit_by_gateway was removed during base extraction.
// It depended on erp_health::gateway_auth::GatewayAuthContext.
// Projects needing gateway rate limiting should add it in their own middleware.

View File

@@ -0,0 +1,50 @@
use axum::body::Body;
use axum::http::Request;
use axum::middleware::Next;
use axum::response::Response;
use erp_core::types::TenantContext;
use sea_orm::{ConnectionTrait, DatabaseBackend, Statement};
/// Tenant RLS 中间件。
///
/// 从 request extensions 中提取 `TenantContext`,在数据库连接上设置
/// `app.current_tenant_id`,使 PostgreSQL RLS 策略自动按租户过滤。
///
/// 请求处理完成后自动 RESET 设置,防止连接池复用时泄漏。
///
/// SET 失败时仅 warn 不阻断请求RLS 是安全网,主隔离仍在应用层)。
pub async fn tenant_rls_middleware(
db: sea_orm::DatabaseConnection,
req: Request<Body>,
next: Next,
) -> Response {
let tenant_id = req
.extensions()
.get::<TenantContext>()
.map(|ctx| ctx.tenant_id);
if let Some(tid) = tenant_id {
// SET app.current_tenant_id — RLS 策略读取此值(参数化查询防止注入)
if let Err(e) = db
.execute(Statement::from_sql_and_values(
DatabaseBackend::Postgres,
"SET app.current_tenant_id = $1",
[tid.into()],
))
.await
{
tracing::warn!(tenant_id = %tid, error = %e, "SET app.current_tenant_id 失败RLS 未激活)");
}
}
let response = next.run(req).await;
// RESET — 防止连接池复用时泄漏租户上下文
if tenant_id.is_some()
&& let Err(e) = db.execute_unprepared("RESET app.current_tenant_id").await
{
tracing::debug!(error = %e, "RESET app.current_tenant_id 失败(非致命)");
}
response
}

View File

@@ -0,0 +1,137 @@
use chrono::Utc;
use sea_orm::{
ActiveModelTrait, ColumnTrait, EntityTrait, QueryFilter, QueryOrder, QuerySelect, Set,
};
use sqlx::postgres::PgListener;
use std::time::Duration;
use erp_core::entity::domain_event;
use erp_core::events::{DomainEvent, EventBus};
const MAX_RETRY: i32 = 5;
const FALLBACK_POLL_INTERVAL_SECS: u64 = 30;
const NOTIFY_CHANNEL: &str = "outbox_channel";
const RECONNECT_DELAY_SECS: u64 = 5;
/// 启动 outbox relay 后台任务。
///
/// 先执行一次性扫描(处理服务重启前遗留的 pending 事件),
/// 然后通过 PostgreSQL LISTEN/NOTIFY 监听新事件,配合 30s 兜底轮询。
pub fn start_outbox_relay(
db: sea_orm::DatabaseConnection,
event_bus: EventBus,
database_url: String,
) {
let db_clone = db.clone();
let event_bus_clone = event_bus.clone();
let url = database_url.clone();
tokio::spawn(async move {
// 启动时立即处理一次(恢复重启前未广播的事件)
match process_pending_events(&db_clone, &event_bus_clone).await {
Ok(count) if count > 0 => tracing::info!(count = count, "启动时 outbox relay 恢复完成"),
Ok(_) => tracing::info!("启动时 outbox relay 无待处理事件"),
Err(e) => tracing::warn!(error = %e, "启动时 outbox relay 处理失败"),
}
// 进入 LISTEN/NOTIFY 主循环(带自动重连)
loop {
if let Err(e) = run_listener(&db_clone, &event_bus_clone, &url).await {
tracing::warn!(error = %e, "PgListener 断开连接,{}s 后重连", RECONNECT_DELAY_SECS);
}
tokio::time::sleep(Duration::from_secs(RECONNECT_DELAY_SECS)).await;
// 重连后执行一次兜底扫描
if let Err(e) = process_pending_events(&db_clone, &event_bus_clone).await {
tracing::warn!(error = %e, "重连后 outbox relay 处理失败");
}
}
});
}
/// 运行 PgListener 监听循环。
///
/// 使用 `tokio::select!` 在 LISTEN 通知和 30s 定时器之间竞争,
/// 确保即使 NOTIFY 丢失也能兜底处理。
async fn run_listener(
db: &sea_orm::DatabaseConnection,
event_bus: &EventBus,
database_url: &str,
) -> Result<(), sqlx::Error> {
let mut listener = PgListener::connect(database_url).await?;
listener.listen(NOTIFY_CHANNEL).await?;
tracing::info!("Outbox relay LISTEN/NOTIFY 已连接,监听 {}", NOTIFY_CHANNEL);
let mut fallback = tokio::time::interval(Duration::from_secs(FALLBACK_POLL_INTERVAL_SECS));
loop {
tokio::select! {
// LISTEN/NOTIFY 通知触发
notification = listener.recv() => {
match notification {
Ok(notif) => {
tracing::debug!(
channel = %notif.channel(),
payload = %notif.payload(),
"收到 outbox NOTIFY"
);
if let Err(e) = process_pending_events(db, event_bus).await {
tracing::warn!(error = %e, "NOTIFY 触发的 outbox 处理失败");
}
}
Err(e) => return Err(e),
}
}
// 30s 兜底轮询
_ = fallback.tick() => {
tracing::debug!("outbox relay 兜底轮询触发");
if let Err(e) = process_pending_events(db, event_bus).await {
tracing::warn!(error = %e, "兜底轮询 outbox 处理失败");
}
}
}
}
}
async fn process_pending_events(
db: &sea_orm::DatabaseConnection,
event_bus: &EventBus,
) -> Result<usize, sea_orm::DbErr> {
let pending = domain_event::Entity::find()
.filter(domain_event::Column::Status.eq("pending"))
.filter(domain_event::Column::Attempts.lt(MAX_RETRY))
.order_by_asc(domain_event::Column::CreatedAt)
.limit(100)
.all(db)
.await?;
if pending.is_empty() {
return Ok(0);
}
let count = pending.len();
tracing::info!(count = count, "处理待发领域事件");
for event_model in pending {
// 重建 DomainEvent 并广播(保留原始 ID 和时间戳)
let domain_event = DomainEvent {
id: event_model.id,
event_type: event_model.event_type.clone(),
tenant_id: event_model.tenant_id,
payload: event_model.payload.clone().unwrap_or(serde_json::json!({})),
timestamp: event_model.created_at,
correlation_id: event_model.correlation_id.unwrap_or(event_model.id),
};
event_bus.broadcast(domain_event);
// 标记为 published增加 attempts 计数
let mut active: domain_event::ActiveModel = event_model.into();
active.status = Set("published".to_string());
active.published_at = Set(Some(Utc::now()));
active.attempts = Set(erp_core::sea_orm_ext::bump_version(&active.attempts));
active.update(db).await?;
}
Ok(count)
}

View File

@@ -0,0 +1,121 @@
use std::sync::atomic::AtomicU64;
use std::sync::Arc;
use axum::extract::FromRef;
use sea_orm::DatabaseConnection;
use crate::config::AppConfig;
use erp_core::events::EventBus;
use erp_core::module::ModuleRegistry;
/// Axum shared application state.
/// All handlers access database connections, configuration, etc. through `State<AppState>`.
#[derive(Clone)]
pub struct AppState {
pub db: DatabaseConnection,
pub config: AppConfig,
pub event_bus: EventBus,
pub module_registry: ModuleRegistry,
pub redis: redis::Client,
/// 实际的默认租户 ID从数据库种子数据中获取。
pub default_tenant_id: uuid::Uuid,
/// 插件引擎
pub plugin_engine: erp_plugin::engine::PluginEngine,
/// 插件实体缓存
pub plugin_entity_cache: moka::sync::Cache<String, erp_plugin::state::EntityInfo>,
/// PII 加密服务KEK + DEK 管理)
pub pii_crypto: erp_core::crypto::PiiCrypto,
/// 定时任务心跳unix timestamp secs每个 cron tick 更新
pub cron_heartbeat: Arc<AtomicU64>,
}
/// Allow handlers to extract `DatabaseConnection` directly from `State<AppState>`.
impl FromRef<AppState> for DatabaseConnection {
fn from_ref(state: &AppState) -> Self {
state.db.clone()
}
}
/// Allow handlers to extract `EventBus` directly from `State<AppState>`.
impl FromRef<AppState> for EventBus {
fn from_ref(state: &AppState) -> Self {
state.event_bus.clone()
}
}
/// Allow erp-auth handlers to extract their required state without depending on erp-server.
///
/// This bridges the gap: erp-auth defines `AuthState` with the fields it needs,
/// and erp-server fills them from `AppState`.
impl FromRef<AppState> for erp_auth::AuthState {
fn from_ref(state: &AppState) -> Self {
use erp_auth::auth_state::parse_ttl;
Self {
db: state.db.clone(),
event_bus: state.event_bus.clone(),
jwt_secret: state.config.jwt.secret.clone(),
access_ttl_secs: parse_ttl(&state.config.jwt.access_token_ttl),
refresh_ttl_secs: parse_ttl(&state.config.jwt.refresh_token_ttl),
default_tenant_id: state.default_tenant_id,
wechat_appid: state.config.wechat.appid.clone(),
wechat_secret: state.config.wechat.secret.clone(),
wechat_dev_mode: state.config.wechat.dev_mode,
redis: Some(state.redis.clone()),
crypto: state.pii_crypto.clone(),
}
}
}
/// Allow erp-config handlers to extract their required state without depending on erp-server.
impl FromRef<AppState> for erp_config::ConfigState {
fn from_ref(state: &AppState) -> Self {
Self {
db: state.db.clone(),
event_bus: state.event_bus.clone(),
}
}
}
/// Allow erp-workflow handlers to extract their required state without depending on erp-server.
impl FromRef<AppState> for erp_workflow::WorkflowState {
fn from_ref(state: &AppState) -> Self {
Self {
db: state.db.clone(),
event_bus: state.event_bus.clone(),
}
}
}
/// Allow erp-message handlers to extract their required state without depending on erp-server.
impl FromRef<AppState> for erp_message::MessageState {
fn from_ref(state: &AppState) -> Self {
Self {
db: state.db.clone(),
event_bus: state.event_bus.clone(),
}
}
}
/// Allow erp-plugin handlers to extract their required state.
impl FromRef<AppState> for erp_plugin::state::PluginState {
fn from_ref(state: &AppState) -> Self {
Self {
db: state.db.clone(),
event_bus: state.event_bus.clone(),
engine: state.plugin_engine.clone(),
entity_cache: state.plugin_entity_cache.clone(),
}
}
}
/// Allow erp-diary handlers to extract their required state.
impl FromRef<AppState> for erp_diary::DiaryState {
fn from_ref(state: &AppState) -> Self {
Self {
db: state.db.clone(),
event_bus: state.event_bus.clone(),
crypto: state.pii_crypto.clone(),
}
}
}

View File

@@ -0,0 +1,125 @@
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Duration;
fn touch_heartbeat(heartbeat: &Arc<AtomicU64>) {
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
heartbeat.store(now, Ordering::Relaxed);
}
/// 启动事件清理后台任务。
///
/// 每日执行一次:
/// - 调用 `cleanup_old_published_events()` 归档 >7 天的已发布事件
/// - 调用 `cleanup_old_processed_events()` 清理 >7 天的去重记录
pub fn start_event_cleanup(db: sea_orm::DatabaseConnection, heartbeat: Arc<AtomicU64>) {
tokio::spawn(async move {
let mut interval = tokio::time::interval(Duration::from_secs(86400));
loop {
interval.tick().await;
if let Err(e) = run_cleanup(&db).await {
tracing::warn!(error = %e, "事件清理任务执行失败");
}
touch_heartbeat(&heartbeat);
}
});
tracing::info!("事件清理任务已启动(每 24 小时执行一次)");
}
async fn run_cleanup(db: &sea_orm::DatabaseConnection) -> Result<(), sea_orm::DbErr> {
use sea_orm::ConnectionTrait;
// 归档 >7 天的已发布事件
match db
.execute_unprepared("SELECT cleanup_old_published_events(7, 1000)")
.await
{
Ok(result) => {
tracing::info!(rows_affected = result.rows_affected(), "已发布事件归档完成");
}
Err(e) => tracing::warn!(error = %e, "已发布事件归档失败"),
}
// 清理 >7 天的去重记录
match db
.execute_unprepared("SELECT cleanup_old_processed_events(7, 1000)")
.await
{
Ok(result) => {
tracing::info!(rows_affected = result.rows_affected(), "去重记录清理完成");
}
Err(e) => tracing::warn!(error = %e, "去重记录清理失败"),
}
Ok(())
}
/// 启动 DB 连接池 + EventBus 积压指标采样任务。
///
/// 每 30 秒采样一次并导出为 Prometheus gauge
/// - `db_pool_connections_active` — 当前活跃连接数
/// - `db_pool_connections_idle` — 当前空闲连接数
/// - `eventbus_pending_total` — pending 状态的领域事件数
pub fn start_pool_metrics(db: sea_orm::DatabaseConnection, heartbeat: Arc<AtomicU64>) {
tokio::spawn(async move {
let mut interval = tokio::time::interval(Duration::from_secs(30));
loop {
interval.tick().await;
sample_pool_metrics(&db).await;
sample_eventbus_backlog(&db).await;
touch_heartbeat(&heartbeat);
}
});
tracing::info!("DB 连接池 + EventBus 积压指标采样已启动(每 30 秒采样一次)");
}
async fn sample_pool_metrics(db: &sea_orm::DatabaseConnection) {
use sea_orm::FromQueryResult;
#[derive(FromQueryResult)]
struct CountRow {
cnt: i64,
}
// 通过 pg_stat_activity 查询当前连接数
let stmt = sea_orm::Statement::from_string(
sea_orm::DatabaseBackend::Postgres,
"SELECT COUNT(*)::bigint AS cnt FROM pg_stat_activity WHERE state = 'active'".to_string(),
);
if let Ok(Some(row)) = CountRow::find_by_statement(stmt).one(db).await {
metrics::gauge!("db_pool_connections_active").set(row.cnt as f64);
}
let stmt = sea_orm::Statement::from_string(
sea_orm::DatabaseBackend::Postgres,
"SELECT COUNT(*)::bigint AS cnt FROM pg_stat_activity WHERE state = 'idle'".to_string(),
);
if let Ok(Some(row)) = CountRow::find_by_statement(stmt).one(db).await {
metrics::gauge!("db_pool_connections_idle").set(row.cnt as f64);
}
}
async fn sample_eventbus_backlog(db: &sea_orm::DatabaseConnection) {
use sea_orm::FromQueryResult;
#[derive(FromQueryResult)]
struct CountRow {
cnt: i64,
}
let stmt = sea_orm::Statement::from_string(
sea_orm::DatabaseBackend::Postgres,
"SELECT COUNT(*)::bigint AS cnt FROM domain_events WHERE status = 'pending'".to_string(),
);
match CountRow::find_by_statement(stmt).one(db).await {
Ok(Some(row)) => {
metrics::gauge!("eventbus_pending_total").set(row.cnt as f64);
}
_ => {
tracing::debug!("EventBus 积压采样:无法获取 pending 事件数");
}
}
}