feat: initialize ERP base platform (extracted from HMS)

- Stripped 11 business crates (health, ai, dialysis, plugins)
- Cleaned AppState, AppConfig, main.rs from business coupling
- Reduced migrations from 169 to 53 (base-only)
- Removed health_provider trait from erp-core
- Removed business integration tests
- Removed gateway rate limiting middleware
- Base capabilities: auth, RBAC, JWT, config, workflow, message, plugin, audit, crypto, RLS, multi-tenant

Cargo check: OK
Cargo test: OK
This commit is contained in:
iven
2026-05-31 20:35:57 +08:00
commit 59856ac2fc
639 changed files with 124710 additions and 0 deletions

View File

@@ -0,0 +1,26 @@
[package]
name = "erp-core"
version.workspace = true
edition.workspace = true
[dependencies]
tokio.workspace = true
serde.workspace = true
serde_json.workspace = true
uuid.workspace = true
chrono.workspace = true
thiserror.workspace = true
anyhow.workspace = true
tracing.workspace = true
axum.workspace = true
sea-orm.workspace = true
async-trait.workspace = true
utoipa.workspace = true
aes-gcm = "0.10"
hmac = "0.12"
sha2 = "0.10"
base64 = "0.22"
hex = "0.4"
rand = "0.8"
dashmap = "6"
ammonia.workspace = true

View File

@@ -0,0 +1,38 @@
//! 聚合查询容错工具
//!
//! 仪表盘等聚合统计端点通常包含多个独立子查询。
//! 单个子查询失败不应导致整个接口 500。
//! `safe_aggregate` 让每个子查询独立容错,失败时返回默认值并记录警告日志。
use std::future::Future;
/// 执行一个子查询,失败时返回 `T::default()` 并记录警告日志。
///
/// # 使用场景
///
/// 仪表盘统计 API 聚合多个指标(患者数/咨询数/随访数等),
/// 任一子查询失败不应阻塞其他指标返回。
///
/// # 示例
///
/// ```rust,ignore
/// let patients = safe_aggregate(
/// stats_service::get_patient_statistics(&state, tenant_id),
/// "患者统计",
/// ).await;
/// ```
pub async fn safe_aggregate<T: Default, E: std::fmt::Display>(
fut: impl Future<Output = Result<T, E>>,
label: &str,
) -> T {
match fut.await {
Ok(v) => {
tracing::debug!("聚合子查询 [{label}] 成功");
v
}
Err(e) => {
tracing::warn!("聚合子查询 [{label}] 失败,使用默认值: {e}");
T::default()
}
}
}

View File

@@ -0,0 +1,67 @@
use chrono::Utc;
use serde::{Deserialize, Serialize};
use uuid::Uuid;
/// 审计日志记录。
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuditLog {
pub id: Uuid,
pub tenant_id: Uuid,
pub user_id: Option<Uuid>,
pub action: String,
pub resource_type: String,
pub resource_id: Option<Uuid>,
pub old_value: Option<serde_json::Value>,
pub new_value: Option<serde_json::Value>,
pub ip_address: Option<String>,
pub user_agent: Option<String>,
pub created_at: chrono::DateTime<Utc>,
}
impl AuditLog {
/// 创建一条审计日志记录。
pub fn new(
tenant_id: Uuid,
user_id: Option<Uuid>,
action: impl Into<String>,
resource_type: impl Into<String>,
) -> Self {
Self {
id: Uuid::now_v7(),
tenant_id,
user_id,
action: action.into(),
resource_type: resource_type.into(),
resource_id: None,
old_value: None,
new_value: None,
ip_address: None,
user_agent: None,
created_at: Utc::now(),
}
}
/// 设置资源 ID。
pub fn with_resource_id(mut self, id: Uuid) -> Self {
self.resource_id = Some(id);
self
}
/// 设置变更前后的值。
pub fn with_changes(
mut self,
old: Option<serde_json::Value>,
new: Option<serde_json::Value>,
) -> Self {
self.old_value = old;
self.new_value = new;
self
}
/// 设置请求来源信息。
pub fn with_request_info(mut self, ip: Option<String>, user_agent: Option<String>) -> Self {
self.ip_address = ip;
self.user_agent = user_agent;
self
}
}

View File

@@ -0,0 +1,285 @@
use crate::audit::AuditLog;
use crate::entity::audit_log;
use crate::request_info::RequestInfo;
use sea_orm::{
ActiveModelTrait, ColumnTrait, EntityTrait, QueryFilter, QueryOrder, QuerySelect, Set,
};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use tracing;
use uuid::Uuid;
/// 审计日志中需要脱敏的 PII 字段名(小写匹配)
const PII_FIELDS: &[&str] = &[
"id_number",
"phone",
"emergency_contact_phone",
"emergency_contact_name",
"allergy_history",
"medical_history_summary",
"name",
"content",
];
/// 审计日志中需要脱敏的 resource_type 前缀
const PII_RESOURCE_TYPES: &[&str] = &[
"patient",
"consultation",
"follow_up",
"family_member",
"doctor_profile",
];
/// 对 JSON Value 中的 PII 字段进行脱敏
fn sanitize_audit_value(
value: &Option<serde_json::Value>,
resource_type: &str,
) -> Option<serde_json::Value> {
let needs_sanitization = PII_RESOURCE_TYPES
.iter()
.any(|prefix| resource_type.starts_with(prefix));
if !needs_sanitization {
return value.clone();
}
value.as_ref().map(sanitize_json_value)
}
fn sanitize_json_value(v: &serde_json::Value) -> serde_json::Value {
match v {
serde_json::Value::Object(map) => {
let sanitized: serde_json::Map<String, serde_json::Value> = map
.into_iter()
.map(|(k, v)| {
let key_lower = k.to_lowercase();
if PII_FIELDS.iter().any(|f| key_lower.contains(f)) {
(k.clone(), serde_json::Value::String("***".to_string()))
} else {
(k.clone(), sanitize_json_value(v))
}
})
.collect();
serde_json::Value::Object(sanitized)
}
serde_json::Value::Array(arr) => {
serde_json::Value::Array(arr.iter().map(sanitize_json_value).collect())
}
other => other.clone(),
}
}
/// 持久化审计日志到 audit_logs 表。
///
/// 使用 fire-and-forget 模式:失败仅记录 warning 日志,不影响业务操作。
///
/// 自动从 task_local 读取当前请求的 IP 和 User-Agent
/// 如果 AuditLog 中已有 ip_address/user_agent 则不覆盖。
///
/// 哈希链:查询同租户最新一条记录的 record_hash 作为 prev_hash
/// 计算 SHA256(id + action + resource_type + resource_id + created_at + prev_hash) 作为 record_hash。
pub async fn record(mut log: AuditLog, db: &sea_orm::DatabaseConnection) {
// 自动填充请求来源信息(仅当调用方未显式设置时)
if let Some(info) = RequestInfo::try_current() {
if log.ip_address.is_none() {
log.ip_address = info.ip_address;
}
if log.user_agent.is_none() {
log.user_agent = info.user_agent;
}
}
// 查询同租户最新一条记录的 record_hash 作为 prev_hash
let prev_hash = audit_log::Entity::find()
.filter(audit_log::Column::TenantId.eq(log.tenant_id))
.filter(audit_log::Column::RecordHash.is_not_null())
.order_by_desc(audit_log::Column::CreatedAt)
.one(db)
.await
.ok()
.flatten()
.and_then(|m| m.record_hash);
// 计算当前记录的 record_hash
let record_hash = compute_record_hash(&log, prev_hash.as_deref());
// 脱敏处理:对 patient/consultation/follow_up 等资源类型的变更值中 PII 字段进行 mask
let sanitized_old = sanitize_audit_value(&log.old_value, &log.resource_type);
let sanitized_new = sanitize_audit_value(&log.new_value, &log.resource_type);
// 保存日志字段用于错误日志model 构建会 move String 字段)
let err_tenant_id = log.tenant_id;
let err_action = log.action.clone();
let err_resource_type = log.resource_type.clone();
let err_resource_id = log.resource_id;
let model = audit_log::ActiveModel {
id: Set(log.id),
tenant_id: Set(log.tenant_id),
user_id: Set(log.user_id),
action: Set(log.action),
resource_type: Set(log.resource_type),
resource_id: Set(log.resource_id),
old_value: Set(sanitized_old),
new_value: Set(sanitized_new),
ip_address: Set(log.ip_address),
user_agent: Set(log.user_agent),
created_at: Set(log.created_at),
prev_hash: Set(prev_hash),
record_hash: Set(Some(record_hash)),
};
if let Err(e) = model.insert(db).await {
tracing::error!(
error = %e,
tenant_id = ?err_tenant_id,
action = %err_action,
resource_type = %err_resource_type,
resource_id = ?err_resource_id,
"审计日志写入失败 — 数据完整性风险"
);
}
}
/// 计算 record_hash: SHA256(id + action + resource_type + resource_id + created_at + prev_hash)
fn compute_record_hash(log: &AuditLog, prev_hash: Option<&str>) -> String {
let mut hasher = Sha256::new();
hasher.update(log.id.to_string().as_bytes());
hasher.update(log.action.as_bytes());
hasher.update(log.resource_type.as_bytes());
hasher.update(
log.resource_id
.map(|id| id.to_string())
.unwrap_or_default()
.as_bytes(),
);
hasher.update(log.created_at.to_rfc3339().as_bytes());
hasher.update(prev_hash.unwrap_or("").as_bytes());
format!("{:x}", hasher.finalize())
}
/// 验证审计日志哈希链完整性。
///
/// 检查指定租户的所有含 record_hash 的日志记录,
/// 验证每条记录的 prev_hash 是否等于前一条的 record_hash
/// 以及 record_hash 是否可以重新计算验证。
///
/// 返回 (总记录数, 断链数)。
pub async fn verify_hash_chain(
db: &sea_orm::DatabaseConnection,
tenant_id: uuid::Uuid,
) -> Result<(usize, usize), sea_orm::DbErr> {
use sea_orm::QueryOrder;
let records = audit_log::Entity::find()
.filter(audit_log::Column::TenantId.eq(tenant_id))
.filter(audit_log::Column::RecordHash.is_not_null())
.order_by_asc(audit_log::Column::CreatedAt)
.all(db)
.await?;
let total = records.len();
let mut broken = 0;
let mut prev: Option<String> = None;
for record in &records {
// 验证 prev_hash 指向正确
if prev.as_deref() != record.prev_hash.as_deref() {
broken += 1;
}
// 验证 record_hash 可重算
let log = AuditLog {
id: record.id,
tenant_id: record.tenant_id,
user_id: record.user_id,
action: record.action.clone(),
resource_type: record.resource_type.clone(),
resource_id: record.resource_id,
old_value: record.old_value.clone(),
new_value: record.new_value.clone(),
ip_address: record.ip_address.clone(),
user_agent: record.user_agent.clone(),
created_at: record.created_at,
};
let expected = compute_record_hash(&log, record.prev_hash.as_deref());
if Some(expected.as_str()) != record.record_hash.as_deref() {
broken += 1;
}
prev = record.record_hash.clone();
}
Ok((total, broken))
}
/// 哈希链验证结果
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChainVerificationResult {
pub total: usize,
pub passed: usize,
pub failed: usize,
pub failed_ids: Vec<Uuid>,
}
/// 验证最近 N 条审计记录的哈希链完整性。
pub async fn verify_recent_chain(
db: &sea_orm::DatabaseConnection,
tenant_id: Uuid,
limit: u64,
) -> Result<ChainVerificationResult, String> {
let records = audit_log::Entity::find()
.filter(audit_log::Column::TenantId.eq(tenant_id))
.filter(audit_log::Column::RecordHash.is_not_null())
.order_by_desc(audit_log::Column::CreatedAt)
.limit(limit)
.all(db)
.await
.map_err(|e| format!("查询审计日志失败: {}", e))?;
let mut records = records;
records.sort_by(|a, b| a.created_at.cmp(&b.created_at));
let total = records.len();
let mut passed = 0;
let mut failed_ids = Vec::new();
let mut prev: Option<String> = None;
for record in &records {
let mut record_broken = false;
if prev.as_deref() != record.prev_hash.as_deref() {
record_broken = true;
}
let log = AuditLog {
id: record.id,
tenant_id: record.tenant_id,
user_id: record.user_id,
action: record.action.clone(),
resource_type: record.resource_type.clone(),
resource_id: record.resource_id,
old_value: record.old_value.clone(),
new_value: record.new_value.clone(),
ip_address: record.ip_address.clone(),
user_agent: record.user_agent.clone(),
created_at: record.created_at,
};
let expected = compute_record_hash(&log, record.prev_hash.as_deref());
if Some(expected.as_str()) != record.record_hash.as_deref() {
record_broken = true;
}
if record_broken {
failed_ids.push(record.id);
} else {
passed += 1;
}
prev = record.record_hash.clone();
}
let failed = total - passed;
Ok(ChainVerificationResult {
total,
passed,
failed,
failed_ids,
})
}

View File

@@ -0,0 +1,48 @@
use aes_gcm::aead::Aead;
use aes_gcm::{Aes256Gcm, KeyInit, Nonce};
use base64::{Engine, engine::general_purpose::STANDARD as BASE64};
use rand::RngCore;
const CIPHER_VERSION: u8 = 0x01;
/// AES-256-GCM 加密。输出格式: Base64(0x01 || nonce[12] || ciphertext + tag)
pub fn encrypt(key: &[u8; 32], plaintext: &str) -> Result<String, String> {
let cipher = Aes256Gcm::new_from_slice(key).map_err(|e| e.to_string())?;
let mut nonce_bytes = [0u8; 12];
rand::thread_rng().fill_bytes(&mut nonce_bytes);
let nonce = Nonce::from_slice(&nonce_bytes);
let ciphertext = cipher
.encrypt(nonce, plaintext.as_bytes())
.map_err(|e| e.to_string())?;
let mut combined = vec![CIPHER_VERSION];
combined.extend_from_slice(&nonce_bytes);
combined.extend_from_slice(&ciphertext);
Ok(BASE64.encode(&combined))
}
/// AES-256-GCM 解密。支持 v1 格式: Base64(0x01 || nonce[12] || ciphertext + tag)
/// 兼容旧格式: Base64(nonce[12] || ciphertext + tag)
pub fn decrypt(key: &[u8; 32], encoded: &str) -> Result<String, String> {
let bytes = BASE64.decode(encoded).map_err(|e| e.to_string())?;
if bytes.len() < 13 {
return Err("ciphertext too short".into());
}
let (nonce_bytes, ciphertext) = if bytes[0] == CIPHER_VERSION {
// v1: version(1) + nonce(12) + ciphertext
if bytes.len() < 14 {
return Err("v1 ciphertext too short".into());
}
(&bytes[1..13], &bytes[13..])
} else {
// 旧格式: nonce(12) + ciphertext向后兼容
(&bytes[0..12], &bytes[12..])
};
let cipher = Aes256Gcm::new_from_slice(key).map_err(|e| e.to_string())?;
let nonce = Nonce::from_slice(nonce_bytes);
let plaintext = cipher
.decrypt(nonce, ciphertext)
.map_err(|e| e.to_string())?;
String::from_utf8(plaintext).map_err(|e| e.to_string())
}

View File

@@ -0,0 +1,24 @@
use hmac::{Hmac, Mac};
use sha2::Sha256;
type HmacSha256 = Hmac<Sha256>;
/// HMAC-SHA256 搜索索引。使用 KEK 派生的独立子密钥,与加密密钥分离。
pub fn hmac_hash(key: &[u8; 32], value: &str) -> String {
let hmac_key = derive_hmac_key(key);
let mut mac = HmacSha256::new_from_slice(&hmac_key).expect("HMAC key length is valid");
mac.update(value.as_bytes());
hex::encode(mac.finalize().into_bytes())
}
/// 从 KEK 派生独立的 HMAC 子密钥,避免密钥复用
fn derive_hmac_key(kek: &[u8; 32]) -> [u8; 32] {
use sha2::Digest;
let derived = <Sha256 as Digest>::new()
.chain_update(b"pii-hmac-index-v1")
.chain_update(kek)
.finalize();
let mut key = [0u8; 32];
key.copy_from_slice(&derived);
key
}

View File

@@ -0,0 +1,225 @@
use std::time::Instant;
use dashmap::DashMap;
use uuid::Uuid;
use crate::error::{AppError, AppResult};
use super::engine;
/// DEK 缓存条目 — Drop 时清零密钥材料
#[derive(Clone)]
struct CachedDek {
dek: [u8; 32],
version: u32,
loaded_at: Instant,
}
impl Drop for CachedDek {
fn drop(&mut self) {
self.dek.fill(0);
}
}
/// DEK 缓存管理 — 每租户独立 DEKLRU + TTL
#[derive(Clone)]
pub struct DekManager {
cache: DashMap<Uuid, CachedDek>,
ttl_secs: u64,
max_entries: usize,
}
impl DekManager {
pub fn new(ttl_secs: u64, max_entries: usize) -> Self {
Self {
cache: DashMap::new(),
ttl_secs,
max_entries,
}
}
/// 获取或创建租户的 DEK
pub fn get_or_create_dek(
&self,
tenant_id: Uuid,
encrypted_dek: Option<&str>,
kek: &[u8; 32],
) -> AppResult<([u8; 32], u32)> {
// 检查缓存
if let Some(entry) = self.cache.get(&tenant_id)
&& entry.loaded_at.elapsed().as_secs() < self.ttl_secs
{
return Ok((entry.dek, entry.version));
}
// 从加密 DEK 解密
if let Some(enc_dek) = encrypted_dek {
let dek_hex = engine::decrypt(kek, enc_dek).map_err(AppError::Internal)?;
let dek_bytes = hex::decode(&dek_hex).map_err(|e| AppError::Internal(e.to_string()))?;
if dek_bytes.len() != 32 {
return Err(AppError::Internal("DEK must be 32 bytes".into()));
}
let mut dek = [0u8; 32];
dek.copy_from_slice(&dek_bytes);
// 缓存(版本从外部传入时无法确定,使用默认值 1
self.evict_if_full();
self.cache.insert(
tenant_id,
CachedDek {
dek,
version: 1,
loaded_at: Instant::now(),
},
);
return Ok((dek, 1));
}
// 无现有 DEK → 生成新的
let dek = Self::generate_dek();
self.evict_if_full();
self.cache.insert(
tenant_id,
CachedDek {
dek,
version: 1,
loaded_at: Instant::now(),
},
);
Ok((dek, 1))
}
/// 使用 KEK 加密 DEK 以便存储
pub fn encrypt_dek_for_storage(dek: &[u8; 32], kek: &[u8; 32]) -> AppResult<String> {
let dek_hex = hex::encode(dek);
engine::encrypt(kek, &dek_hex).map_err(AppError::Internal)
}
/// 生成新 DEK 并用 KEK 加密,返回 (新 DEK, 加密后的 DEK)
pub fn generate_new_dek(kek: &[u8; 32]) -> AppResult<([u8; 32], String)> {
let dek = Self::generate_dek();
let encrypted = Self::encrypt_dek_for_storage(&dek, kek)?;
Ok((dek, encrypted))
}
/// 使缓存失效(轮换后调用)
pub fn invalidate(&self, tenant_id: Uuid) {
self.cache.remove(&tenant_id);
}
fn generate_dek() -> [u8; 32] {
use rand::RngCore;
let mut dek = [0u8; 32];
rand::thread_rng().fill_bytes(&mut dek);
dek
}
fn evict_if_full(&self) {
if self.cache.len() >= self.max_entries {
let to_remove: Vec<Uuid> = self
.cache
.iter()
.filter(|e| e.loaded_at.elapsed().as_secs() > self.ttl_secs / 2)
.map(|e| *e.key())
.take(self.max_entries / 2)
.collect();
for id in to_remove {
self.cache.remove(&id);
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::crypto::PiiCrypto;
fn test_kek() -> [u8; 32] {
*PiiCrypto::dev_default().kek()
}
fn test_uuid(i: u8) -> Uuid {
let s = format!("00000000-0000-0000-0000-0000000000{:02x}", i);
Uuid::parse_str(&s).unwrap()
}
#[test]
fn generate_new_dek_returns_32_bytes() {
let (dek, _enc) = DekManager::generate_new_dek(&test_kek()).unwrap();
assert_eq!(dek.len(), 32);
}
#[test]
fn generate_new_dek_produces_unique_keys() {
let (dek1, _) = DekManager::generate_new_dek(&test_kek()).unwrap();
let (dek2, _) = DekManager::generate_new_dek(&test_kek()).unwrap();
assert_ne!(dek1, dek2);
}
#[test]
fn encrypt_dek_roundtrip() {
let kek = test_kek();
let (original_dek, encrypted) = DekManager::generate_new_dek(&kek).unwrap();
let mgr = DekManager::new(300, 100);
let tenant_id = test_uuid(1);
let (recovered_dek, _ver) = mgr
.get_or_create_dek(tenant_id, Some(&encrypted), &kek)
.unwrap();
assert_eq!(original_dek, recovered_dek);
}
#[test]
fn get_or_create_generates_when_none() {
let mgr = DekManager::new(300, 100);
let tenant_id = test_uuid(2);
let (dek1, ver1) = mgr.get_or_create_dek(tenant_id, None, &test_kek()).unwrap();
assert_eq!(ver1, 1);
let (dek2, ver2) = mgr.get_or_create_dek(tenant_id, None, &test_kek()).unwrap();
assert_eq!(dek1, dek2);
assert_eq!(ver2, 1);
}
#[test]
fn invalidate_removes_cached_dek() {
let mgr = DekManager::new(300, 100);
let tenant_id = test_uuid(3);
let (dek1, _) = mgr.get_or_create_dek(tenant_id, None, &test_kek()).unwrap();
mgr.invalidate(tenant_id);
let (dek2, _) = mgr.get_or_create_dek(tenant_id, None, &test_kek()).unwrap();
assert_ne!(dek1, dek2);
}
#[test]
fn decrypt_with_wrong_kek_fails() {
let kek1 = test_kek();
let kek2 = [0xffu8; 32];
let (_, encrypted) = DekManager::generate_new_dek(&kek1).unwrap();
let mgr = DekManager::new(300, 100);
let tenant_id = test_uuid(4);
assert!(
mgr.get_or_create_dek(tenant_id, Some(&encrypted), &kek2)
.is_err()
);
}
#[test]
fn expired_entry_not_returned() {
let mgr = DekManager::new(0, 100);
let tenant_id = test_uuid(5);
let (dek1, _) = mgr.get_or_create_dek(tenant_id, None, &test_kek()).unwrap();
let (dek2, _) = mgr.get_or_create_dek(tenant_id, None, &test_kek()).unwrap();
assert_ne!(dek1, dek2);
}
#[test]
fn max_entries_eviction() {
let mgr = DekManager::new(300, 3);
for i in 0..5u8 {
let _ = mgr
.get_or_create_dek(test_uuid(i), None, &test_kek())
.unwrap();
}
assert!(mgr.cache.len() <= 6);
}
}

View File

@@ -0,0 +1,113 @@
/// 身份证号脱敏: 保留前 3 位和后 4 位,中间用 **** 替代
pub fn mask_id_number(s: &str) -> String {
let chars: Vec<char> = s.chars().collect();
if chars.len() >= 7 {
let head: String = chars[..3].iter().collect();
let tail: String = chars[chars.len() - 4..].iter().collect();
format!("{}****{}", head, tail)
} else {
"****".to_string()
}
}
/// 手机号脱敏: 保留前 3 位和后 4 位,中间用 **** 替代
pub fn mask_phone(s: Option<&str>) -> Option<String> {
s.map(|p| {
let chars: Vec<char> = p.chars().collect();
if chars.len() >= 7 {
let head: String = chars[..3].iter().collect();
let tail: String = chars[chars.len() - 4..].iter().collect();
format!("{}****{}", head, tail)
} else {
"****".to_string()
}
})
}
/// 执业证号脱敏: 保留前 2 位和后 2 位,中间用 **** 替代
pub fn mask_license_number(s: &str) -> String {
let chars: Vec<char> = s.chars().collect();
if chars.len() >= 5 {
let head: String = chars[..2].iter().collect();
let tail: String = chars[chars.len() - 2..].iter().collect();
format!("{}****{}", head, tail)
} else {
"****".to_string()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn mask_id_18_digits() {
assert_eq!("110****1234", mask_id_number("110101199001011234"));
}
#[test]
fn mask_id_short() {
assert_eq!("****", mask_id_number("123456"));
}
#[test]
fn mask_id_empty() {
assert_eq!("****", mask_id_number(""));
}
#[test]
fn mask_phone_normal() {
assert_eq!(
Some("138****5678".to_string()),
mask_phone(Some("13812345678"))
);
}
#[test]
fn mask_phone_none() {
assert_eq!(None, mask_phone(None));
}
#[test]
fn mask_phone_short() {
assert_eq!(Some("****".to_string()), mask_phone(Some("123")));
}
#[test]
fn mask_phone_exactly_7() {
assert_eq!(Some("123****4567".to_string()), mask_phone(Some("1234567")));
}
#[test]
fn mask_id_exactly_7() {
assert_eq!("123****4567", mask_id_number("1234567"));
}
#[test]
fn mask_id_unicode_safe() {
assert_eq!("你好世****cdef", mask_id_number("你好世界abcdef"));
}
#[test]
fn mask_phone_unicode_safe() {
assert_eq!(
Some("你好世****cdef".to_string()),
mask_phone(Some("你好世界abcdef"))
);
}
#[test]
fn mask_license_normal() {
assert_eq!("YL****23", mask_license_number("YL-2024-00123"));
}
#[test]
fn mask_license_short() {
assert_eq!("****", mask_license_number("AB"));
}
#[test]
fn mask_license_empty() {
assert_eq!("****", mask_license_number(""));
}
}

View File

@@ -0,0 +1,234 @@
pub mod engine;
pub mod hmac_index;
pub mod key_manager;
pub mod masking;
pub use engine::{decrypt, encrypt};
pub use hmac_index::hmac_hash;
pub use key_manager::DekManager;
pub use masking::{mask_id_number, mask_license_number, mask_phone};
use crate::error::{AppError, AppResult};
/// PII 加密服务 — 封装 KEK 和 DEK 管理
#[derive(Clone)]
pub struct PiiCrypto {
kek: [u8; 32],
hmac_key: [u8; 32],
pub(crate) dek_manager: DekManager,
}
impl PiiCrypto {
/// 从 hex 编码的 KEK 创建。KEK 为 64 字符 hex32 字节)。
pub fn from_kek_hex(kek_hex: &str) -> AppResult<Self> {
let bytes = hex::decode(kek_hex)
.map_err(|e| AppError::Internal(format!("KEK hex decode failed: {}", e)))?;
if bytes.len() != 32 {
return Err(AppError::Internal(
"KEK must be 32 bytes (64 hex chars)".into(),
));
}
let mut kek = [0u8; 32];
kek.copy_from_slice(&bytes);
Ok(Self::from_kek(kek))
}
/// Dev fallback: 从固定字符串派生确定性 KEK。仅用于开发。
pub fn dev_default() -> Self {
use sha2::Digest;
let kek = <sha2::Sha256 as Digest>::digest(b"erp-pii-kek-dev-key-DO-NOT-USE-IN-PROD");
let mut key = [0u8; 32];
key.copy_from_slice(&kek);
Self::from_kek(key)
}
fn from_kek(kek: [u8; 32]) -> Self {
use sha2::Digest;
let hmac_key = <sha2::Sha256 as Digest>::new()
.chain_update(b"pii-hmac-index-v1")
.chain_update(kek)
.finalize();
let mut hk = [0u8; 32];
hk.copy_from_slice(&hmac_key);
Self {
kek,
hmac_key: hk,
dek_manager: DekManager::new(300, 100),
}
}
pub fn kek(&self) -> &[u8; 32] {
&self.kek
}
/// HMAC 搜索索引使用的独立子密钥
pub fn hmac_key(&self) -> &[u8; 32] {
&self.hmac_key
}
/// 使指定租户的 DEK 缓存失效
pub fn invalidate_dek(&self, tenant_id: uuid::Uuid) {
self.dek_manager.invalidate(tenant_id);
}
}
#[cfg(test)]
mod tests {
use super::*;
fn test_crypto() -> PiiCrypto {
PiiCrypto::dev_default()
}
#[test]
fn from_kek_hex_roundtrip() {
let kek_hex = "00".repeat(32);
let crypto = PiiCrypto::from_kek_hex(&kek_hex).unwrap();
assert_eq!(crypto.kek(), &[0u8; 32]);
}
#[test]
fn from_kek_hex_invalid() {
assert!(PiiCrypto::from_kek_hex("not-hex").is_err());
}
#[test]
fn from_kek_hex_wrong_length() {
assert!(PiiCrypto::from_kek_hex("ab").is_err());
}
#[test]
fn encrypt_decrypt_roundtrip() {
let crypto = test_crypto();
let plaintext = "13812345678";
let encrypted = encrypt(crypto.kek(), plaintext).unwrap();
let decrypted = decrypt(crypto.kek(), &encrypted).unwrap();
assert_eq!(plaintext, decrypted);
}
#[test]
fn encrypt_produces_different_ciphertexts() {
let crypto = test_crypto();
let e1 = encrypt(crypto.kek(), "test").unwrap();
let e2 = encrypt(crypto.kek(), "test").unwrap();
assert_ne!(e1, e2);
}
#[test]
fn decrypt_wrong_key_fails() {
let crypto1 = PiiCrypto::dev_default();
let other_key_hex = "ff".repeat(32);
let crypto2 = PiiCrypto::from_kek_hex(&other_key_hex).unwrap();
let encrypted = encrypt(crypto1.kek(), "test").unwrap();
assert!(decrypt(crypto2.kek(), &encrypted).is_err());
}
#[test]
fn hmac_hash_deterministic() {
let crypto = test_crypto();
let h1 = hmac_hash(crypto.hmac_key(), "13812345678");
let h2 = hmac_hash(crypto.hmac_key(), "13812345678");
assert_eq!(h1, h2);
}
#[test]
fn hmac_hash_different_inputs() {
let crypto = test_crypto();
let h1 = hmac_hash(crypto.hmac_key(), "111");
let h2 = hmac_hash(crypto.hmac_key(), "222");
assert_ne!(h1, h2);
}
#[test]
fn hmac_key_differs_from_kek() {
let crypto = test_crypto();
assert_ne!(crypto.kek(), crypto.hmac_key(), "HMAC 密钥应与 KEK 不同");
}
#[test]
fn encrypt_empty_string() {
let crypto = test_crypto();
let encrypted = encrypt(crypto.kek(), "").unwrap();
let decrypted = decrypt(crypto.kek(), &encrypted).unwrap();
assert_eq!("", decrypted);
}
#[test]
fn decrypt_too_short_fails() {
use base64::Engine;
let short = base64::engine::general_purpose::STANDARD.encode(b"short");
assert!(decrypt(&[0u8; 32], &short).is_err());
}
#[test]
fn encrypt_unicode() {
let crypto = test_crypto();
let plaintext = "患者过敏史:青霉素、磺胺类药物";
let encrypted = encrypt(crypto.kek(), plaintext).unwrap();
let decrypted = decrypt(crypto.kek(), &encrypted).unwrap();
assert_eq!(plaintext, decrypted);
}
#[test]
fn ciphertext_has_version_prefix() {
let crypto = test_crypto();
let encrypted = encrypt(crypto.kek(), "test").unwrap();
use base64::Engine;
let bytes = base64::engine::general_purpose::STANDARD
.decode(&encrypted)
.unwrap();
assert_eq!(bytes[0], 0x01, "密文首字节应为版本号 0x01");
}
// ── 性能基准 ──
#[test]
fn bench_encrypt_1000() {
let crypto = test_crypto();
let kek = crypto.kek();
let plaintext = "13812345678";
let start = std::time::Instant::now();
for _ in 0..1000 {
let _ = encrypt(kek, plaintext).unwrap();
}
let elapsed = start.elapsed();
let avg_us = elapsed.as_micros() / 1000;
assert!(avg_us < 50, "encrypt 平均耗时应 < 50μs, 实际: {}μs", avg_us);
eprintln!("encrypt 1000 次: {:?} (avg {}μs)", elapsed, avg_us);
}
#[test]
fn bench_decrypt_1000() {
let crypto = test_crypto();
let kek = crypto.kek();
let ciphertext = encrypt(kek, "13812345678").unwrap();
let start = std::time::Instant::now();
for _ in 0..1000 {
let _ = decrypt(kek, &ciphertext).unwrap();
}
let elapsed = start.elapsed();
let avg_us = elapsed.as_micros() / 1000;
assert!(avg_us < 50, "decrypt 平均耗时应 < 50μs, 实际: {}μs", avg_us);
eprintln!("decrypt 1000 次: {:?} (avg {}μs)", elapsed, avg_us);
}
#[test]
fn bench_batch_decrypt_50() {
let crypto = test_crypto();
let kek = crypto.kek();
let ciphertexts: Vec<String> = (0..50)
.map(|i| encrypt(kek, &format!("数据{}", i)).unwrap())
.collect();
let start = std::time::Instant::now();
for ct in &ciphertexts {
let _ = decrypt(kek, ct).unwrap();
}
let elapsed = start.elapsed();
assert!(
elapsed.as_millis() < 10,
"批量解密 50 条应 < 10ms, 实际: {}ms",
elapsed.as_millis()
);
eprintln!("batch decrypt 50 条: {:?}", elapsed);
}
}

View File

@@ -0,0 +1,29 @@
use sea_orm::entity::prelude::*;
use serde::{Deserialize, Serialize};
/// 审计日志实体 — 映射 audit_logs 表。
#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Serialize, Deserialize)]
#[sea_orm(table_name = "audit_logs")]
pub struct Model {
#[sea_orm(primary_key, auto_increment = false)]
pub id: Uuid,
pub tenant_id: Uuid,
pub user_id: Option<Uuid>,
pub action: String,
pub resource_type: String,
pub resource_id: Option<Uuid>,
pub old_value: Option<serde_json::Value>,
pub new_value: Option<serde_json::Value>,
pub ip_address: Option<String>,
pub user_agent: Option<String>,
pub created_at: DateTimeUtc,
/// 哈希链 — 前一条记录的 record_hash
pub prev_hash: Option<String>,
/// 当前记录的哈希 SHA256(id + action + resource_type + resource_id + created_at + prev_hash)
pub record_hash: Option<String>,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {}
impl ActiveModelBehavior for ActiveModel {}

View File

@@ -0,0 +1,27 @@
use sea_orm::entity::prelude::*;
use serde::{Deserialize, Serialize};
#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Serialize, Deserialize)]
#[sea_orm(table_name = "dead_letter_events")]
pub struct Model {
#[sea_orm(primary_key, auto_increment = false)]
pub id: Uuid,
#[sea_orm(skip_serializing_if = "Option::is_none")]
pub tenant_id: Option<Uuid>,
pub original_event_id: Uuid,
pub event_type: String,
#[sea_orm(skip_serializing_if = "Option::is_none")]
pub payload: Option<serde_json::Value>,
pub consumer_id: String,
pub attempts: i32,
#[sea_orm(skip_serializing_if = "Option::is_none")]
pub last_error: Option<String>,
pub created_at: DateTimeUtc,
#[sea_orm(skip_serializing_if = "Option::is_none")]
pub resolved_at: Option<DateTimeUtc>,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {}
impl ActiveModelBehavior for ActiveModel {}

View File

@@ -0,0 +1,24 @@
use sea_orm::entity::prelude::*;
use serde::{Deserialize, Serialize};
/// 领域事件实体 — 映射 domain_events 表。
#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Serialize, Deserialize)]
#[sea_orm(table_name = "domain_events")]
pub struct Model {
#[sea_orm(primary_key, auto_increment = false)]
pub id: Uuid,
pub tenant_id: Uuid,
pub event_type: String,
pub payload: Option<serde_json::Value>,
pub correlation_id: Option<Uuid>,
pub status: String,
pub attempts: i32,
pub last_error: Option<String>,
pub created_at: DateTimeUtc,
pub published_at: Option<DateTimeUtc>,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {}
impl ActiveModelBehavior for ActiveModel {}

View File

@@ -0,0 +1,4 @@
pub mod audit_log;
pub mod dead_letter_event;
pub mod domain_event;
pub mod processed_event;

View File

@@ -0,0 +1,18 @@
use sea_orm::entity::prelude::*;
use serde::{Deserialize, Serialize};
/// 已处理事件记录 — 幂等性去重表。
#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Serialize, Deserialize)]
#[sea_orm(table_name = "processed_events")]
pub struct Model {
#[sea_orm(primary_key, auto_increment = false)]
pub event_id: Uuid,
#[sea_orm(primary_key, auto_increment = false)]
pub consumer_id: String,
pub processed_at: DateTimeUtc,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {}
impl ActiveModelBehavior for ActiveModel {}

View File

@@ -0,0 +1,188 @@
use axum::Json;
use axum::http::StatusCode;
use axum::response::{IntoResponse, Response};
use serde::Serialize;
/// 统一错误响应格式
#[derive(Debug, Serialize)]
pub struct ErrorResponse {
pub error: String,
pub message: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub details: Option<serde_json::Value>,
}
/// 平台级错误类型
#[derive(Debug, thiserror::Error)]
pub enum AppError {
#[error("资源未找到: {0}")]
NotFound(String),
#[error("验证失败: {0}")]
Validation(String),
#[error("未授权")]
Unauthorized,
#[error("禁止访问: {0}")]
Forbidden(String),
#[error("冲突: {0}")]
Conflict(String),
#[error("版本冲突: 数据已被其他操作修改,请刷新后重试")]
VersionMismatch,
#[error("请求过于频繁,请稍后重试")]
TooManyRequests,
#[error("内部错误: {0}")]
Internal(String),
}
impl IntoResponse for AppError {
fn into_response(self) -> Response {
let (status, message) = match &self {
AppError::NotFound(_) => (StatusCode::NOT_FOUND, self.to_string()),
AppError::Validation(_) => (StatusCode::BAD_REQUEST, self.to_string()),
AppError::Unauthorized => (StatusCode::UNAUTHORIZED, "未授权".to_string()),
AppError::Forbidden(_) => (StatusCode::FORBIDDEN, self.to_string()),
AppError::Conflict(_) => (StatusCode::CONFLICT, self.to_string()),
AppError::VersionMismatch => (StatusCode::CONFLICT, self.to_string()),
AppError::TooManyRequests => (StatusCode::TOO_MANY_REQUESTS, self.to_string()),
AppError::Internal(msg) => {
tracing::error!("Internal error: {}", msg);
(StatusCode::INTERNAL_SERVER_ERROR, "内部错误".to_string())
}
};
let body = ErrorResponse {
error: status.canonical_reason().unwrap_or("Error").to_string(),
message,
details: None,
};
(status, Json(body)).into_response()
}
}
impl From<anyhow::Error> for AppError {
fn from(err: anyhow::Error) -> Self {
AppError::Internal(err.to_string())
}
}
impl From<sea_orm::DbErr> for AppError {
fn from(err: sea_orm::DbErr) -> Self {
match err {
sea_orm::DbErr::RecordNotFound(msg) => AppError::NotFound(msg),
sea_orm::DbErr::Query(sea_orm::RuntimeErr::SqlxError(e))
if e.to_string().contains("duplicate key") =>
{
AppError::Conflict("记录已存在".to_string())
}
_ => AppError::Internal(err.to_string()),
}
}
}
pub type AppResult<T> = Result<T, AppError>;
/// 检查乐观锁版本是否匹配。
///
/// 返回下一个版本号actual + 1或 VersionMismatch 错误。
pub fn check_version(expected: i32, actual: i32) -> AppResult<i32> {
if expected == actual {
Ok(actual + 1)
} else {
Err(AppError::VersionMismatch)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn check_version_ok() {
assert_eq!(check_version(1, 1).unwrap(), 2);
assert_eq!(check_version(5, 5).unwrap(), 6);
}
#[test]
fn check_version_mismatch() {
let result = check_version(1, 2);
assert!(result.is_err());
match result.unwrap_err() {
AppError::VersionMismatch => {}
other => panic!("Expected VersionMismatch, got {:?}", other),
}
}
#[test]
fn db_err_record_not_found_maps_to_not_found() {
let err = sea_orm::DbErr::RecordNotFound("test".to_string());
let app_err: AppError = err.into();
match app_err {
AppError::NotFound(msg) => assert_eq!(msg, "test"),
other => panic!("Expected NotFound, got {:?}", other),
}
}
#[test]
fn db_err_generic_maps_to_internal() {
let db_err = sea_orm::DbErr::Custom("some error".to_string());
let app_err: AppError = db_err.into();
match app_err {
AppError::Internal(msg) => assert!(msg.contains("some error")),
other => panic!("Expected Internal, got {:?}", other),
}
}
#[test]
fn app_error_into_response_status_codes() {
// NotFound -> 404
let resp = AppError::NotFound("test".to_string()).into_response();
assert_eq!(resp.status(), StatusCode::NOT_FOUND);
// Validation -> 400
let resp = AppError::Validation("bad input".to_string()).into_response();
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
// Unauthorized -> 401
let resp = AppError::Unauthorized.into_response();
assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
// Forbidden -> 403
let resp = AppError::Forbidden("no access".to_string()).into_response();
assert_eq!(resp.status(), StatusCode::FORBIDDEN);
// VersionMismatch -> 409
let resp = AppError::VersionMismatch.into_response();
assert_eq!(resp.status(), StatusCode::CONFLICT);
// TooManyRequests -> 429
let resp = AppError::TooManyRequests.into_response();
assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
// Internal -> 500
let resp = AppError::Internal("oops".to_string()).into_response();
assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
}
#[test]
fn app_error_internal_hides_details_from_response() {
// Internal errors should map to 500 with a generic message
let resp = AppError::Internal("sensitive db error detail".to_string()).into_response();
assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
}
#[test]
fn anyhow_error_maps_to_internal() {
let err: AppError = anyhow::anyhow!("something went wrong").into();
match err {
AppError::Internal(msg) => assert_eq!(msg, "something went wrong"),
other => panic!("Expected Internal, got {:?}", other),
}
}
}

View File

@@ -0,0 +1,458 @@
use chrono::Utc;
use sea_orm::{
ActiveModelTrait, ColumnTrait, ConnectionTrait, EntityTrait, PaginatorTrait, QueryFilter, Set,
};
use serde::{Deserialize, Serialize};
use tokio::sync::{broadcast, mpsc};
use tracing::{error, info};
use uuid::Uuid;
use crate::entity::dead_letter_event;
use crate::entity::domain_event;
/// 已知的 PII 字段列表 -- 在事件 payload 中自动脱敏
const PII_FIELDS: &[&str] = &[
"phone",
"id_number",
"emergency_contact_phone",
"emergency_contact_name",
"medical_history_summary",
"allergy_history",
"content",
];
/// 递归脱敏 payload 中的 PII 字段(原地修改)。
fn sanitize_payload(payload: &mut serde_json::Value) {
if let Some(obj) = payload.as_object_mut() {
for field in PII_FIELDS {
if let Some(val) = obj.get_mut(*field)
&& val.is_string()
{
*val = serde_json::Value::String("[REDACTED]".to_string());
}
}
for val in obj.values_mut() {
if val.is_object() {
sanitize_payload(val);
}
}
}
}
/// 领域事件
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DomainEvent {
pub id: Uuid,
pub event_type: String,
pub tenant_id: Uuid,
pub payload: serde_json::Value,
pub timestamp: chrono::DateTime<Utc>,
pub correlation_id: Uuid,
}
impl DomainEvent {
pub fn new(event_type: impl Into<String>, tenant_id: Uuid, payload: serde_json::Value) -> Self {
Self {
id: Uuid::now_v7(),
event_type: event_type.into(),
tenant_id,
payload,
timestamp: Utc::now(),
correlation_id: Uuid::now_v7(),
}
}
}
/// 当前事件 payload schema 版本
pub const EVENT_SCHEMA_VERSION: &str = "v1";
/// 构造统一信封格式的事件 payload。
///
/// 自动注入 `schema_version` 和 `occurred_at`,业务数据通过 `data` 传入。
/// 用法:`build_event_payload(serde_json::json!({ "patient_id": ..., }))`
pub fn build_event_payload(data: serde_json::Value) -> serde_json::Value {
let mut envelope = serde_json::json!({
"schema_version": EVENT_SCHEMA_VERSION,
"occurred_at": Utc::now().to_rfc3339(),
});
if let serde_json::Value::Object(ref mut map) = envelope
&& let serde_json::Value::Object(data_map) = data
{
for (k, v) in data_map {
map.insert(k, v);
}
}
envelope
}
/// 检查事件是否已被指定消费者处理。
///
/// 查询 `processed_events` 表判断 event_id + consumer_id 是否已存在。
pub async fn is_event_processed(
db: &sea_orm::DatabaseConnection,
event_id: Uuid,
consumer_id: &str,
) -> Result<bool, sea_orm::DbErr> {
use sea_orm::{ColumnTrait, EntityTrait, QueryFilter};
let count = crate::entity::processed_event::Entity::find()
.filter(crate::entity::processed_event::Column::EventId.eq(event_id))
.filter(crate::entity::processed_event::Column::ConsumerId.eq(consumer_id))
.count(db)
.await?;
Ok(count > 0)
}
/// 标记事件已被指定消费者处理。
///
/// 插入 `processed_events` 记录,重复插入会因主键冲突被安全忽略。
pub async fn mark_event_processed(
db: &sea_orm::DatabaseConnection,
event_id: Uuid,
consumer_id: &str,
) -> Result<(), sea_orm::DbErr> {
use sea_orm::ActiveModelTrait;
use sea_orm::Set;
let model = crate::entity::processed_event::ActiveModel {
event_id: Set(event_id),
consumer_id: Set(consumer_id.to_string()),
processed_at: Set(Utc::now()),
};
// INSERT ... ON CONFLICT DO NOTHING主键冲突时安全忽略
match model.insert(db).await {
Ok(_) => Ok(()),
Err(e) => {
// 唯一约束冲突 = 已处理,不是错误
if e.to_string().contains("duplicate") || e.to_string().contains("violates unique") {
Ok(())
} else {
Err(e)
}
}
}
}
/// 消费事件 — 带幂等检查和 dead-letter 兜底。
///
/// 如果事件已被处理(幂等),返回 `ConsumeResult::AlreadyProcessed`。
/// 如果处理成功,标记为已处理并返回 `ConsumeResult::Success`。
/// 如果处理失败,将事件转入 dead_letter_events 表并返回 `ConsumeResult::DeadLettered`。
pub async fn consume_with_retry<F, Fut>(
db: &sea_orm::DatabaseConnection,
event: &DomainEvent,
consumer_id: &str,
handler: F,
) -> ConsumeResult
where
F: FnOnce(&DomainEvent) -> Fut,
Fut: std::future::Future<Output = Result<(), String>>,
{
if is_event_processed(db, event.id, consumer_id)
.await
.unwrap_or(false)
{
return ConsumeResult::AlreadyProcessed;
}
match handler(event).await {
Ok(()) => {
if let Err(e) = mark_event_processed(db, event.id, consumer_id).await {
tracing::warn!(
event_id = %event.id,
consumer_id = consumer_id,
error = %e,
"标记事件已处理失败(非致命)"
);
}
ConsumeResult::Success
}
Err(err) => {
tracing::error!(
event_id = %event.id,
event_type = %event.event_type,
consumer_id = consumer_id,
error = %err,
"事件消费失败,转入 dead-letter"
);
if let Err(e) = insert_dead_letter(db, event, consumer_id, &err).await {
tracing::error!(
event_id = %event.id,
error = %e,
"Dead-letter 写入失败"
);
}
ConsumeResult::DeadLettered(err)
}
}
}
/// 消费结果
#[derive(Debug)]
pub enum ConsumeResult {
Success,
AlreadyProcessed,
DeadLettered(String),
}
/// 将失败事件写入 dead_letter_events 表
pub async fn insert_dead_letter(
db: &sea_orm::DatabaseConnection,
event: &DomainEvent,
consumer_id: &str,
error_msg: &str,
) -> Result<(), sea_orm::DbErr> {
let model = dead_letter_event::ActiveModel {
id: Set(Uuid::now_v7()),
tenant_id: Set(Some(event.tenant_id)),
original_event_id: Set(event.id),
event_type: Set(event.event_type.clone()),
payload: Set(Some(event.payload.clone())),
consumer_id: Set(consumer_id.to_string()),
attempts: Set(1),
last_error: Set(Some(error_msg.to_string())),
created_at: Set(Utc::now()),
resolved_at: Set(None),
};
model.insert(db).await?;
Ok(())
}
/// 过滤事件接收器 — 只接收匹配 `event_type_prefix` 的事件
pub struct FilteredEventReceiver {
receiver: mpsc::Receiver<DomainEvent>,
}
impl FilteredEventReceiver {
/// 接收下一个匹配的事件
pub async fn recv(&mut self) -> Option<DomainEvent> {
self.receiver.recv().await
}
}
/// 订阅句柄 — 用于取消过滤订阅
pub struct SubscriptionHandle {
cancel_tx: mpsc::Sender<()>,
join_handle: tokio::task::JoinHandle<()>,
}
impl SubscriptionHandle {
/// 取消订阅并等待后台任务结束
pub async fn cancel(self) {
let _ = self.cancel_tx.send(()).await;
let _ = self.join_handle.await;
}
}
/// 进程内事件总线
#[derive(Clone)]
pub struct EventBus {
sender: broadcast::Sender<DomainEvent>,
}
impl EventBus {
pub fn new(capacity: usize) -> Self {
let (sender, _) = broadcast::channel(capacity);
Self { sender }
}
/// 发布事件:先持久化到 domain_events 表pending 状态),再内存广播,
/// 最后更新为 published 并 NOTIFY outbox relay。
///
/// 两阶段提交保证:即使广播后服务崩溃,事件仍为 pending 状态,
/// 重启后 outbox relay 会重新广播。
pub async fn publish(&self, mut event: DomainEvent, db: &sea_orm::DatabaseConnection) {
// 0. 脱敏 payload 中的 PII 字段
sanitize_payload(&mut event.payload);
// 1. 持久化为 pending 状态
let event_id = event.id;
let model = domain_event::ActiveModel {
id: Set(event.id),
tenant_id: Set(event.tenant_id),
event_type: Set(event.event_type.clone()),
payload: Set(Some(event.payload.clone())),
correlation_id: Set(Some(event.correlation_id)),
status: Set("pending".to_string()),
attempts: Set(0),
last_error: Set(None),
created_at: Set(event.timestamp),
published_at: Set(None),
};
let saved = match model.insert(db).await {
Ok(m) => m,
Err(e) => {
tracing::warn!(event_id = %event_id, error = %e, "领域事件持久化失败");
// 持久化失败仍然广播best-effort
self.broadcast(event);
return;
}
};
// 2. 内存广播
self.broadcast(event);
// 3. 更新为 published
let mut active: domain_event::ActiveModel = saved.into();
active.status = Set("published".to_string());
active.published_at = Set(Some(Utc::now()));
if let Err(e) = active.update(db).await {
tracing::warn!(event_id = %event_id, error = %e, "领域事件状态更新为 published 失败");
}
// 4. NOTIFY outbox relay通知 outbox relay 有新事件到达)
let notify_sql = sea_orm::Statement::from_string(
sea_orm::DatabaseBackend::Postgres,
format!("NOTIFY outbox_channel, '{}'", event_id),
);
if let Err(e) = db.execute(notify_sql).await {
tracing::debug!(event_id = %event_id, error = %e, "NOTIFY outbox_channel 失败(非致命)");
}
}
/// 仅内存广播(不持久化,用于内部测试等场景)。
pub fn broadcast(&self, event: DomainEvent) {
info!(event_type = %event.event_type, event_id = %event.id, "Event broadcast");
if let Err(e) = self.sender.send(event) {
error!("Failed to broadcast event: {}", e);
}
}
/// 订阅所有事件,返回接收端
pub fn subscribe(&self) -> broadcast::Receiver<DomainEvent> {
self.sender.subscribe()
}
/// 按事件类型前缀过滤订阅。
///
/// 为每次调用 spawn 一个 Tokio task 从 broadcast channel 读取,
/// 只转发匹配 `event_type_prefix` 的事件到 mpsc channelcapacity 256
pub fn subscribe_filtered(
&self,
event_type_prefix: String,
) -> (FilteredEventReceiver, SubscriptionHandle) {
let mut broadcast_rx = self.sender.subscribe();
let (mpsc_tx, mpsc_rx) = mpsc::channel(256);
let (cancel_tx, mut cancel_rx) = mpsc::channel::<()>(1);
let prefix = event_type_prefix.clone();
let join_handle = tokio::spawn(async move {
loop {
tokio::select! {
biased;
_ = cancel_rx.recv() => {
tracing::info!(prefix = %prefix, "Filtered subscription cancelled");
break;
}
event = broadcast_rx.recv() => {
match event {
Ok(event) => {
if event.event_type.starts_with(&prefix)
&& mpsc_tx.send(event).await.is_err()
{
break;
}
}
Err(broadcast::error::RecvError::Lagged(n)) => {
tracing::warn!(prefix = %prefix, lagged = n, "Filtered subscriber lagged");
}
Err(broadcast::error::RecvError::Closed) => {
break;
}
}
}
}
}
});
tracing::info!(prefix = %event_type_prefix, "Filtered subscription created");
(
FilteredEventReceiver { receiver: mpsc_rx },
SubscriptionHandle {
cancel_tx,
join_handle,
},
)
}
}
/// 重试 dead_letter_events 中未解决的失败事件(指数退避)。
pub async fn retry_dead_letters(
db: &sea_orm::DatabaseConnection,
bus: &EventBus,
max_attempts: i32,
) -> Result<u64, String> {
// 1. 查询所有未解决且未超过最大重试次数的 dead-letter
let pending = dead_letter_event::Entity::find()
.filter(dead_letter_event::Column::ResolvedAt.is_null())
.filter(dead_letter_event::Column::Attempts.lt(max_attempts))
.all(db)
.await
.map_err(|e| format!("查询 dead_letter_events 失败: {}", e))?;
let retried = pending.len() as u64;
for dl in &pending {
let event = DomainEvent {
id: dl.original_event_id,
event_type: dl.event_type.clone(),
tenant_id: dl.tenant_id.unwrap_or(Uuid::nil()),
payload: dl.payload.clone().unwrap_or(serde_json::Value::Null),
timestamp: dl.created_at,
correlation_id: Uuid::now_v7(),
};
bus.broadcast(event);
let mut active: dead_letter_event::ActiveModel = dl.clone().into();
let new_attempts = dl.attempts + 1;
active.attempts = Set(new_attempts);
active.last_error = Set(Some(format!(
"{} 次自动重试({}",
new_attempts,
Utc::now().to_rfc3339()
)));
if let Err(e) = active.update(db).await {
tracing::warn!(
dead_letter_id = %dl.id,
error = %e,
"更新 dead_letter_events attempts 失败"
);
}
}
// 2. 标记超过最大重试次数的记录为永久失败
let exhausted = dead_letter_event::Entity::find()
.filter(dead_letter_event::Column::ResolvedAt.is_null())
.filter(dead_letter_event::Column::Attempts.gte(max_attempts))
.all(db)
.await
.map_err(|e| format!("查询超限 dead_letter_events 失败: {}", e))?;
for dl in &exhausted {
let mut active: dead_letter_event::ActiveModel = dl.clone().into();
active.resolved_at = Set(Some(Utc::now()));
active.last_error = Set(Some(format!(
"已达最大重试次数 {},标记为永久失败",
max_attempts
)));
if let Err(e) = active.update(db).await {
tracing::warn!(
dead_letter_id = %dl.id,
error = %e,
"标记 dead_letter_event 为永久失败时更新失败"
);
}
}
if retried > 0 || !exhausted.is_empty() {
tracing::info!(
retried = retried,
permanently_failed = exhausted.len(),
"Dead-letter 自动重试完成"
);
}
Ok(retried)
}

View File

@@ -0,0 +1,19 @@
pub mod aggregate;
pub mod audit;
pub mod audit_service;
pub mod crypto;
pub mod entity;
pub mod error;
pub mod events;
pub mod module;
pub mod rbac;
pub mod request_info;
pub mod sanitize;
pub mod sea_orm_ext;
pub mod types;
#[cfg(test)]
pub mod test_helpers;
// 便捷导出
pub use module::{ModuleContext, ModuleType, PermissionDescriptor};

View File

@@ -0,0 +1,357 @@
use std::any::Any;
use std::collections::HashMap;
use std::sync::Arc;
use uuid::Uuid;
use crate::error::{AppError, AppResult};
use crate::events::EventBus;
/// 权限描述符,用于模块声明自己需要的权限。
///
/// 各业务模块通过 `ErpModule::permissions()` 返回此列表,
/// 由 erp-server 在启动时统一注册到权限表。
#[derive(Clone, Debug)]
pub struct PermissionDescriptor {
/// 权限编码,全局唯一,格式建议 `{模块}.{动作}` 如 `plugin.admin`
pub code: String,
/// 权限显示名称
pub name: String,
/// 权限描述
pub description: String,
/// 所属模块名称
pub module: String,
}
/// 模块类型
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ModuleType {
/// 内置模块(编译时链接)
Builtin,
/// 插件模块(运行时加载)
Plugin,
}
/// 模块启动上下文 — 在 on_startup 时提供给模块
pub struct ModuleContext {
pub db: sea_orm::DatabaseConnection,
pub event_bus: EventBus,
}
/// 模块注册接口
/// 所有业务模块Auth, Workflow, Message, Config, 行业模块)都实现此 trait
#[async_trait::async_trait]
pub trait ErpModule: Send + Sync {
/// 模块名称(唯一标识)
fn name(&self) -> &str;
/// 模块唯一 ID默认等于 name
fn id(&self) -> &str {
self.name()
}
/// 模块版本
fn version(&self) -> &str {
env!("CARGO_PKG_VERSION")
}
/// 模块类型
fn module_type(&self) -> ModuleType {
ModuleType::Builtin
}
/// 依赖的其他模块名称
fn dependencies(&self) -> Vec<&str> {
vec![]
}
/// 注册事件处理器
fn register_event_handlers(&self, _bus: &EventBus) {}
/// 模块启动钩子 — 服务启动时调用
async fn on_startup(&self, _ctx: &ModuleContext) -> AppResult<()> {
Ok(())
}
/// 模块关闭钩子 — 服务关闭时调用
async fn on_shutdown(&self) -> AppResult<()> {
Ok(())
}
/// 健康检查
async fn health_check(&self) -> AppResult<serde_json::Value> {
Ok(serde_json::json!({"status": "healthy"}))
}
/// 租户创建时的初始化钩子。
///
/// 用于为新建租户创建默认角色、管理员用户等初始数据。
async fn on_tenant_created(
&self,
_tenant_id: Uuid,
_db: &sea_orm::DatabaseConnection,
_event_bus: &EventBus,
) -> AppResult<()> {
Ok(())
}
/// 租户删除时的清理钩子。
///
/// 用于软删除该租户下的所有关联数据。
async fn on_tenant_deleted(
&self,
_tenant_id: Uuid,
_db: &sea_orm::DatabaseConnection,
) -> AppResult<()> {
Ok(())
}
/// 返回此模块需要注册的权限列表。
///
/// 默认返回空列表,有权限需求的模块(如 plugin可覆写此方法。
fn permissions(&self) -> Vec<PermissionDescriptor> {
vec![]
}
/// Downcast support: return `self` as `&dyn Any` for concrete type access.
///
/// This allows the server crate to retrieve module-specific methods
/// (e.g. `AuthModule::public_routes()`) that are not part of the trait.
fn as_any(&self) -> &dyn Any;
}
/// 模块注册器 — 用 Arc 包装使其可 Clone用于 Axum State
#[derive(Clone, Default)]
pub struct ModuleRegistry {
modules: Arc<Vec<Arc<dyn ErpModule>>>,
}
impl ModuleRegistry {
pub fn new() -> Self {
Self {
modules: Arc::new(vec![]),
}
}
pub fn register(mut self, module: impl ErpModule + 'static) -> Self {
tracing::info!(
module = module.name(),
id = module.id(),
version = module.version(),
module_type = ?module.module_type(),
"Module registered"
);
let mut modules = (*self.modules).clone();
modules.push(Arc::new(module));
self.modules = Arc::new(modules);
self
}
pub fn register_handlers(&self, bus: &EventBus) {
for module in self.modules.iter() {
module.register_event_handlers(bus);
}
}
pub fn modules(&self) -> &[Arc<dyn ErpModule>] {
&self.modules
}
/// 按名称获取模块
pub fn get_module(&self, name: &str) -> Option<Arc<dyn ErpModule>> {
self.modules.iter().find(|m| m.name() == name).cloned()
}
/// 按拓扑排序返回模块(依赖在前,被依赖在后)
///
/// 使用 Kahn 算法,环检测返回 Validation 错误。
pub fn sorted_modules(&self) -> AppResult<Vec<Arc<dyn ErpModule>>> {
let modules = &*self.modules;
let n = modules.len();
if n == 0 {
return Ok(vec![]);
}
// 构建名称到索引的映射
let name_to_idx: HashMap<&str, usize> = modules
.iter()
.enumerate()
.map(|(i, m)| (m.name(), i))
.collect();
// 构建邻接表和入度
let mut adjacency: Vec<Vec<usize>> = vec![vec![]; n];
let mut in_degree: Vec<usize> = vec![0; n];
for (idx, module) in modules.iter().enumerate() {
for dep in module.dependencies() {
if let Some(&dep_idx) = name_to_idx.get(dep) {
adjacency[dep_idx].push(idx);
in_degree[idx] += 1;
}
// 依赖未注册的模块不阻断(可能是可选依赖)
}
}
// Kahn 算法
let mut queue: Vec<usize> = (0..n).filter(|&i| in_degree[i] == 0).collect();
let mut sorted_indices = Vec::with_capacity(n);
while let Some(idx) = queue.pop() {
sorted_indices.push(idx);
for &next in &adjacency[idx] {
in_degree[next] -= 1;
if in_degree[next] == 0 {
queue.push(next);
}
}
}
if sorted_indices.len() != n {
let cycle_modules: Vec<&str> = (0..n)
.filter(|i| !sorted_indices.contains(i))
.filter_map(|i| modules.get(i).map(|m| m.name()))
.collect();
return Err(AppError::Validation(format!(
"模块依赖存在循环: {}",
cycle_modules.join(", ")
)));
}
Ok(sorted_indices
.into_iter()
.map(|i| modules[i].clone())
.collect())
}
/// 按拓扑顺序启动所有模块
pub async fn startup_all(&self, ctx: &ModuleContext) -> AppResult<()> {
let sorted = self.sorted_modules()?;
for module in sorted {
tracing::info!(module = module.name(), "Starting module");
module.on_startup(ctx).await?;
tracing::info!(module = module.name(), "Module started");
}
Ok(())
}
/// 按拓扑逆序关闭所有模块
pub async fn shutdown_all(&self) -> AppResult<()> {
let sorted = self.sorted_modules()?;
for module in sorted.into_iter().rev() {
tracing::info!(module = module.name(), "Shutting down module");
if let Err(e) = module.on_shutdown().await {
tracing::error!(module = module.name(), error = %e, "Module shutdown failed");
}
}
Ok(())
}
/// 对所有模块执行健康检查
pub async fn health_check_all(&self) -> Vec<(String, AppResult<serde_json::Value>)> {
let mut results = Vec::with_capacity(self.modules.len());
for module in self.modules.iter() {
let result = module.health_check().await;
results.push((module.name().to_string(), result));
}
results
}
}
#[cfg(test)]
mod tests {
use super::*;
struct TestModule {
name: &'static str,
deps: Vec<&'static str>,
}
#[async_trait::async_trait]
impl ErpModule for TestModule {
fn name(&self) -> &str {
self.name
}
fn dependencies(&self) -> Vec<&str> {
self.deps.clone()
}
fn as_any(&self) -> &dyn Any {
self
}
}
#[test]
fn sorted_modules_empty() {
let registry = ModuleRegistry::new();
let sorted = registry.sorted_modules().unwrap();
assert!(sorted.is_empty());
}
#[test]
fn sorted_modules_no_deps() {
let registry = ModuleRegistry::new()
.register(TestModule {
name: "a",
deps: vec![],
})
.register(TestModule {
name: "b",
deps: vec![],
});
let sorted = registry.sorted_modules().unwrap();
assert_eq!(sorted.len(), 2);
}
#[test]
fn sorted_modules_with_deps() {
let registry = ModuleRegistry::new()
.register(TestModule {
name: "auth",
deps: vec![],
})
.register(TestModule {
name: "plugin",
deps: vec!["auth", "config"],
})
.register(TestModule {
name: "config",
deps: vec!["auth"],
});
let sorted = registry.sorted_modules().unwrap();
let names: Vec<&str> = sorted.iter().map(|m| m.name()).collect();
let auth_pos = names.iter().position(|&n| n == "auth").unwrap();
let config_pos = names.iter().position(|&n| n == "config").unwrap();
let plugin_pos = names.iter().position(|&n| n == "plugin").unwrap();
assert!(auth_pos < config_pos);
assert!(config_pos < plugin_pos);
}
#[test]
fn sorted_modules_circular_dep() {
let registry = ModuleRegistry::new()
.register(TestModule {
name: "a",
deps: vec!["b"],
})
.register(TestModule {
name: "b",
deps: vec!["a"],
});
let result = registry.sorted_modules();
assert!(result.is_err());
match result.err().unwrap() {
AppError::Validation(msg) => assert!(msg.contains("循环")),
other => panic!("Expected Validation, got {:?}", other),
}
}
#[test]
fn get_module_found() {
let registry = ModuleRegistry::new().register(TestModule {
name: "auth",
deps: vec![],
});
assert!(registry.get_module("auth").is_some());
assert!(registry.get_module("unknown").is_none());
}
}

102
crates/erp-core/src/rbac.rs Normal file
View File

@@ -0,0 +1,102 @@
use crate::error::AppError;
use crate::types::{DataScope, TenantContext};
/// Check whether the `TenantContext` includes the specified permission code.
///
/// Returns `Ok(())` if the permission is present, or `AppError::Forbidden` otherwise.
pub fn require_permission(ctx: &TenantContext, permission: &str) -> Result<(), AppError> {
if ctx.permissions.iter().any(|p| p == permission) {
Ok(())
} else {
Err(AppError::Forbidden("权限不足".to_string()))
}
}
/// Check whether the `TenantContext` includes at least one of the specified permission codes.
///
/// Useful when multiple permissions can grant access to the same resource.
pub fn require_any_permission(ctx: &TenantContext, permissions: &[&str]) -> Result<(), AppError> {
let has_any = permissions
.iter()
.any(|p| ctx.permissions.iter().any(|up| up == *p));
if has_any {
Ok(())
} else {
Err(AppError::Forbidden("权限不足".to_string()))
}
}
/// Check whether the `TenantContext` includes the specified role code.
///
/// Returns `Ok(())` if the role is present, or `AppError::Forbidden` otherwise.
pub fn require_role(ctx: &TenantContext, role: &str) -> Result<(), AppError> {
if ctx.roles.iter().any(|r| r == role) {
Ok(())
} else {
Err(AppError::Forbidden("权限不足".to_string()))
}
}
/// 获取指定权限的数据范围。默认 All向后兼容
///
/// Service 层根据返回值追加对应的查询过滤条件。
pub fn get_data_scope(ctx: &TenantContext, permission: &str) -> DataScope {
ctx.permission_data_scopes
.get(permission)
.cloned()
.unwrap_or(DataScope::All)
}
#[cfg(test)]
mod tests {
use super::*;
use uuid::Uuid;
fn test_ctx(roles: Vec<&str>, permissions: Vec<&str>) -> TenantContext {
TenantContext {
tenant_id: Uuid::now_v7(),
user_id: Uuid::now_v7(),
roles: roles.into_iter().map(String::from).collect(),
permissions: permissions.into_iter().map(String::from).collect(),
department_ids: vec![],
permission_data_scopes: std::collections::HashMap::new(),
}
}
#[test]
fn require_permission_succeeds_when_present() {
let ctx = test_ctx(vec![], vec!["user.read", "user.write"]);
assert!(require_permission(&ctx, "user.read").is_ok());
}
#[test]
fn require_permission_fails_when_missing() {
let ctx = test_ctx(vec![], vec!["user.read"]);
assert!(require_permission(&ctx, "user.delete").is_err());
}
#[test]
fn require_any_permission_succeeds_with_match() {
let ctx = test_ctx(vec![], vec!["user.read"]);
assert!(require_any_permission(&ctx, &["user.delete", "user.read"]).is_ok());
}
#[test]
fn require_any_permission_fails_with_no_match() {
let ctx = test_ctx(vec![], vec!["user.read"]);
assert!(require_any_permission(&ctx, &["user.delete", "user.admin"]).is_err());
}
#[test]
fn require_role_succeeds_when_present() {
let ctx = test_ctx(vec!["admin", "user"], vec![]);
assert!(require_role(&ctx, "admin").is_ok());
}
#[test]
fn require_role_fails_when_missing() {
let ctx = test_ctx(vec!["user"], vec![]);
assert!(require_role(&ctx, "admin").is_err());
}
}

View File

@@ -0,0 +1,54 @@
/// 请求来源信息IP 地址 + User-Agent
///
/// 通过 `tokio::task_local!` 在请求生命周期内传递,
/// JWT 中间件设置,审计服务自动读取。
#[derive(Debug, Clone, Default)]
pub struct RequestInfo {
pub ip_address: Option<String>,
pub user_agent: Option<String>,
}
tokio::task_local! {
/// 当前请求的来源信息。
///
/// 在 JWT 中间件中通过 `REQUEST_INFO.scope(info, future)` 设置,
/// 在 `audit_service::record()` 中自动读取。
pub static REQUEST_INFO: RequestInfo;
}
impl RequestInfo {
/// 从 HTTP 请求头中提取 IP 地址和 User-Agent。
///
/// IP 优先级X-Forwarded-For > X-Real-IP > 直接连接(不记录)。
pub fn from_headers(headers: &axum::http::HeaderMap) -> Self {
let ip_address = headers
.get("X-Forwarded-For")
.and_then(|v| v.to_str().ok())
.map(|s| {
// X-Forwarded-For 可能包含多个 IP取第一个客户端真实 IP
s.split(',').next().unwrap_or(s).trim().to_string()
})
.or_else(|| {
headers
.get("X-Real-IP")
.and_then(|v| v.to_str().ok())
.map(|s| s.trim().to_string())
});
let user_agent = headers
.get("user-agent")
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string());
Self {
ip_address,
user_agent,
}
}
/// 尝试从 task_local 中读取当前请求信息。
/// 如果不在请求上下文中(如后台任务),返回 None。
pub fn try_current() -> Option<Self> {
REQUEST_INFO.try_with(|info| info.clone()).ok()
}
}

View File

@@ -0,0 +1,218 @@
/// HTML/Script 内容清理工具。
///
/// 基于 ammoniahtml5ever剥离所有 HTML 标签,防止存储型 XSS。
/// 覆盖场景:用户名、显示名、邮箱、电话等字符串字段。
///
/// 剥离字符串中的所有 HTML 标签,返回纯文本。
///
/// 使用 ammonia 构建 DOM 树,然后用 tendril 收集文本节点。
/// 比手写字符级解析器更安全,能正确处理所有 HTML 边界情况。
pub fn strip_html_tags(input: &str) -> String {
// 使用 ammonia 清理(保留在 span 中的纯文本),然后剥离 span 标签
let doc = ammonia::Builder::new()
.tags(std::collections::HashSet::new())
.clean(input)
.to_string();
// ammonia 的 clean() 结果可能包含 HTML 实体(如 &lt;),需要解码
// 但由于所有标签已被禁止,结果是纯文本(可能有实体转义)
// 使用二次清理:将结果作为纯文本处理
decode_entities(&doc).trim().to_string()
}
/// 简单解码常见 HTML 实体。
fn decode_entities(input: &str) -> String {
input
.replace("&lt;", "<")
.replace("&gt;", ">")
.replace("&amp;", "&")
.replace("&quot;", "\"")
.replace("&#39;", "'")
.replace("&apos;", "'")
.replace("&#47;", "/")
.replace("&#32;", " ")
}
/// 对 Option<String> 类型的字段进行清理。
pub fn sanitize_option(input: Option<String>) -> Option<String> {
input.map(|s| strip_html_tags(&s)).filter(|s| !s.is_empty())
}
/// 对 String 类型的必填字段进行清理。
pub fn sanitize_string(input: &str) -> String {
strip_html_tags(input)
}
/// 对富文本 HTML 进行安全清理,保留安全的 HTML 标签和内联样式,去除危险元素。
/// 适用于文章内容等需要保留 HTML 排版的场景。
pub fn sanitize_rich_html(input: &str) -> String {
use std::collections::{HashMap, HashSet};
let tag_attrs: HashMap<&str, HashSet<&str>> = [
("div", HashSet::from(["style", "data-w-e-type"])),
("span", HashSet::from(["style"])),
("p", HashSet::from(["style"])),
(
"img",
HashSet::from(["src", "alt", "style", "width", "height"]),
),
("a", HashSet::from(["href", "target"])),
("td", HashSet::from(["style", "colspan", "rowspan"])),
("th", HashSet::from(["style", "colspan", "rowspan"])),
("blockquote", HashSet::from(["style"])),
]
.into_iter()
.collect();
ammonia::Builder::new()
.tags(
[
"p",
"br",
"span",
"div",
"strong",
"b",
"em",
"i",
"u",
"s",
"h1",
"h2",
"h3",
"h4",
"h5",
"h6",
"ul",
"ol",
"li",
"blockquote",
"pre",
"code",
"table",
"thead",
"tbody",
"tr",
"th",
"td",
"img",
"a",
"hr",
]
.into_iter()
.collect(),
)
.tag_attributes(tag_attrs)
.generic_attributes(HashSet::from(["style"]))
.url_relative(ammonia::UrlRelative::PassThrough)
.clean(input)
.to_string()
}
/// 对 Option<String> 的富文本进行安全清理。
pub fn sanitize_rich_html_option(input: Option<String>) -> Option<String> {
input
.map(|s| sanitize_rich_html(&s))
.filter(|s| !s.trim().is_empty())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn strips_script_tag() {
// script 内容在 HTML 规范中是 raw textammonia 正确地将其完全移除
assert_eq!(strip_html_tags("<script>alert('xss')</script>"), "");
}
#[test]
fn strips_img_onerror() {
assert_eq!(strip_html_tags("<img src=x onerror=alert(1)>"), "");
}
#[test]
fn strips_bold_tags() {
assert_eq!(strip_html_tags("Hello <b>World</b>"), "Hello World");
}
#[test]
fn no_tags_passthrough() {
assert_eq!(strip_html_tags("Normal text"), "Normal text");
}
#[test]
fn nested_tags() {
assert_eq!(strip_html_tags("<div><p>text</p></div>"), "text");
}
#[test]
fn sanitize_option_some() {
assert_eq!(
sanitize_option(Some("<b>evil</b>".to_string())),
Some("evil".to_string())
);
}
#[test]
fn sanitize_option_none() {
assert_eq!(sanitize_option(None), None);
}
#[test]
fn sanitize_option_becomes_empty() {
assert_eq!(sanitize_option(Some("<img>".to_string())), None);
}
#[test]
fn strips_nested_script_attack() {
let result = strip_html_tags("<scr<script>ipt>alert(1)</scr</script>ipt>");
assert!(!result.contains("<"), "不应残留 HTML 标签");
}
#[test]
fn strips_unclosed_tag() {
let result = strip_html_tags("text <img");
assert!(result.contains("text") || result.is_empty());
}
#[test]
fn handles_entities() {
let result = strip_html_tags("a &lt; b");
assert!(result.contains("a") && result.contains("b"));
}
#[test]
fn rich_html_preserves_safe_tags() {
let html = r#"<p>Hello</p><div style="background:#f0fdf4;padding:14px">Green box</div><strong>Bold</strong>"#;
let result = sanitize_rich_html(html);
assert!(result.contains("<p>Hello</p>"), "should preserve <p> tags");
assert!(
result.contains("<strong>Bold</strong>"),
"should preserve <strong>"
);
assert!(
result.contains("background"),
"should preserve style attribute"
);
}
#[test]
fn rich_html_removes_script() {
let html = r#"<p>Hello</p><script>alert(1)</script>"#;
let result = sanitize_rich_html(html);
assert!(!result.contains("script"), "should remove script tags");
assert!(result.contains("Hello"));
}
#[test]
fn rich_html_preserves_styled_block() {
let html = r#"<div data-w-e-type="styled-block" style="background:#f0fdf4;border-radius:8px;padding:14px">Tip content</div>"#;
let result = sanitize_rich_html(html);
assert!(
result.contains("styled-block"),
"should preserve data-w-e-type"
);
assert!(result.contains("Tip content"));
}
}

View File

@@ -0,0 +1,17 @@
use sea_orm::ActiveValue;
/// 从 SeaORM ActiveValue<i32> 中安全提取 version 值。
/// Set(v) / Unchanged(v) → 返回 v
/// NotSet → 返回 1首次版本号
/// 绝不 panic。
pub fn safe_version(val: &ActiveValue<i32>) -> i32 {
match val {
ActiveValue::Set(v) | ActiveValue::Unchanged(v) => *v,
ActiveValue::NotSet => 1,
}
}
/// 安全递增 version基于当前值 +1绝不 panic。
pub fn bump_version(current: &ActiveValue<i32>) -> i32 {
safe_version(current) + 1
}

View File

@@ -0,0 +1,37 @@
//! 测试基础设施 — 事务回滚模式解决并行化问题
//!
//! 每个测试在独立事务中执行,测试结束自动回滚,无数据残留。
//! 多个测试共享同一个数据库连接池,无连接竞争。
use sea_orm::{
ConnectOptions, Database, DatabaseConnection, DatabaseTransaction, TransactionTrait,
};
use std::sync::OnceLock;
use tokio::sync::OnceCell;
static DB_POOL: OnceCell<DatabaseConnection> = OnceCell::const_new();
static DB_URL: OnceLock<String> = OnceLock::new();
fn db_url() -> String {
DB_URL
.get_or_init(|| {
std::env::var("TEST_DATABASE_URL")
.unwrap_or_else(|_| "postgres://erp:erp@localhost:5432/erp_test".into())
})
.clone()
}
async fn db_pool() -> &'static DatabaseConnection {
DB_POOL
.get_or_init(|| async {
let opt = ConnectOptions::new(db_url()).max_connections(5).to_owned();
Database::connect(opt).await.expect("测试数据库连接失败")
})
.await
}
/// 创建测试用事务。测试结束自动回滚,无数据残留。
pub async fn test_txn() -> DatabaseTransaction {
let pool = db_pool().await;
pool.begin().await.expect("测试事务创建失败")
}

View File

@@ -0,0 +1,188 @@
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use uuid::Uuid;
/// 所有数据库实体的公共字段
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BaseFields {
pub id: Uuid,
pub tenant_id: Uuid,
pub created_at: DateTime<Utc>,
pub updated_at: DateTime<Utc>,
pub created_by: Uuid,
pub updated_by: Uuid,
pub deleted_at: Option<DateTime<Utc>>,
pub version: i32,
}
/// 分页请求
#[derive(Debug, Deserialize, utoipa::IntoParams)]
pub struct Pagination {
pub page: Option<u64>,
pub page_size: Option<u64>,
}
impl Pagination {
pub fn offset(&self) -> u64 {
(self.page.unwrap_or(1).saturating_sub(1)) * self.limit()
}
pub fn limit(&self) -> u64 {
self.page_size.unwrap_or(20).min(100)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn pagination_defaults() {
let p = Pagination {
page: None,
page_size: None,
};
assert_eq!(p.limit(), 20);
assert_eq!(p.offset(), 0);
}
#[test]
fn pagination_custom_values() {
let p = Pagination {
page: Some(3),
page_size: Some(10),
};
assert_eq!(p.limit(), 10);
assert_eq!(p.offset(), 20); // (3-1) * 10
}
#[test]
fn pagination_max_cap() {
let p = Pagination {
page: Some(1),
page_size: Some(200),
};
assert_eq!(p.limit(), 100); // capped at 100
}
#[test]
fn pagination_page_zero_treated_as_first() {
// page 0 -> saturating_sub wraps to 0 -> offset = 0
let p = Pagination {
page: Some(0),
page_size: Some(10),
};
assert_eq!(p.offset(), 0);
}
#[test]
fn pagination_page_one() {
let p = Pagination {
page: Some(1),
page_size: Some(50),
};
assert_eq!(p.offset(), 0);
}
#[test]
fn paginated_response_total_pages() {
let resp = PaginatedResponse {
data: vec![1, 2, 3],
total: 25,
page: 1,
page_size: 10,
total_pages: 3,
};
assert_eq!(resp.data.len(), 3);
assert_eq!(resp.total, 25);
assert_eq!(resp.total_pages, 3);
}
#[test]
fn api_response_ok() {
let resp = ApiResponse::ok(42);
assert!(resp.success);
assert_eq!(resp.data, Some(42));
assert!(resp.message.is_none());
}
#[test]
fn tenant_context_fields() {
let ctx = TenantContext {
tenant_id: Uuid::now_v7(),
user_id: Uuid::now_v7(),
roles: vec!["admin".to_string()],
permissions: vec!["user.read".to_string()],
department_ids: vec![],
permission_data_scopes: HashMap::new(),
};
assert_eq!(ctx.roles.len(), 1);
assert_eq!(ctx.permissions.len(), 1);
}
}
/// 分页响应
#[derive(Debug, Serialize, utoipa::ToSchema)]
pub struct PaginatedResponse<T> {
pub data: Vec<T>,
pub total: u64,
pub page: u64,
pub page_size: u64,
pub total_pages: u64,
}
/// API 统一响应
#[derive(Debug, Serialize, utoipa::ToSchema)]
pub struct ApiResponse<T: Serialize> {
pub success: bool,
pub data: Option<T>,
pub message: Option<String>,
}
impl<T: Serialize> ApiResponse<T> {
pub fn ok(data: T) -> Self {
Self {
success: true,
data: Some(data),
message: None,
}
}
}
/// 行级数据权限范围
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum DataScope {
/// 查看所有数据
All,
/// 仅查看自己创建的数据
SelfOnly,
/// 仅查看本部门数据
Department,
/// 查看本部门及下属部门数据
DepartmentTree,
}
impl DataScope {
pub fn parse_scope(s: &str) -> Self {
match s {
"self" => Self::SelfOnly,
"department" => Self::Department,
"department_tree" => Self::DepartmentTree,
_ => Self::All,
}
}
}
/// 租户上下文(中间件注入)
#[derive(Debug, Clone)]
pub struct TenantContext {
pub tenant_id: Uuid,
pub user_id: Uuid,
pub roles: Vec<String>,
pub permissions: Vec<String>,
/// 用户所属部门 ID 列表(行级数据权限使用)
pub department_ids: Vec<Uuid>,
/// 每个权限码对应的数据范围(从 role_permissions.data_scope 加载)
pub permission_data_scopes: HashMap<String, DataScope>,
}