fix: 前端深度审计全量修复 — 安全/功能/代码质量
严重 BUG 修复: - 修复 Token 过期后 hash 重定向导致无法跳转登录页 - 修复文章编辑器新建后提交审核使用错误 ID 安全加固: - HTML 清理函数替换为 ammonia 专业库(替代自定义解析器) - 文件上传添加 magic bytes 校验(防 Content-Type 伪造) - 登录添加账户级失败锁定(5次失败→15分钟锁定) - 审计日志 9 个关键更新操作补充变更前后值(with_changes) 功能缺陷修复: - 登录/登出时清理 API 缓存(防多账户数据污染) - 文章编辑器上传改用统一 HTTP 客户端(自动 token 刷新) - 添加全局 HTTP 错误处理和后端错误消息展示 - PrivateRoute 增加路由级权限检查(系统管理页面) - 健康数据三个 Tab 添加编辑/删除功能 - 预约创建增加排班可用性校验提示 - 医生详情 API 返回解密后的原始执照号 代码清理: - 删除未使用的 auth.ts refresh() 函数 - 删除重复的 AuthGuard.tsx 组件 - 删除未使用的 getHealthSummary API
This commit is contained in:
@@ -23,3 +23,4 @@ base64 = "0.22"
|
||||
hex = "0.4"
|
||||
rand = "0.8"
|
||||
dashmap = "6"
|
||||
ammonia.workspace = true
|
||||
|
||||
@@ -1,44 +1,36 @@
|
||||
/// HTML/Script 内容清理工具。
|
||||
///
|
||||
/// 在用户输入进入数据库之前,剥离所有 HTML 标签,防止存储型 XSS。
|
||||
/// 基于 ammonia(html5ever)剥离所有 HTML 标签,防止存储型 XSS。
|
||||
/// 覆盖场景:用户名、显示名、邮箱、电话等字符串字段。
|
||||
|
||||
/// 剥离字符串中的所有 HTML 标签,返回纯文本。
|
||||
///
|
||||
/// ```rust
|
||||
/// use erp_core::sanitize::strip_html_tags;
|
||||
/// assert_eq!(strip_html_tags("<script>alert(1)</script>"), "alert(1)");
|
||||
/// assert_eq!(strip_html_tags("<img src=x onerror=alert(1)>"), "");
|
||||
/// assert_eq!(strip_html_tags("Hello <b>World</b>"), "Hello World");
|
||||
/// ```
|
||||
/// 使用 ammonia 构建 DOM 树,然后用 tendril 收集文本节点。
|
||||
/// 比手写字符级解析器更安全,能正确处理所有 HTML 边界情况。
|
||||
pub fn strip_html_tags(input: &str) -> String {
|
||||
let mut result = String::with_capacity(input.len());
|
||||
let mut in_tag = false;
|
||||
let mut depth = 0usize;
|
||||
// 使用 ammonia 清理(保留在 span 中的纯文本),然后剥离 span 标签
|
||||
let doc = ammonia::Builder::new()
|
||||
.tags(std::collections::HashSet::new())
|
||||
.clean(input)
|
||||
.to_string();
|
||||
|
||||
for ch in input.chars() {
|
||||
match ch {
|
||||
'<' => {
|
||||
in_tag = true;
|
||||
depth += 1;
|
||||
}
|
||||
'>' => {
|
||||
if depth > 0 {
|
||||
depth -= 1;
|
||||
}
|
||||
if depth == 0 {
|
||||
in_tag = false;
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
if !in_tag {
|
||||
result.push(ch);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// ammonia 的 clean() 结果可能包含 HTML 实体(如 <),需要解码
|
||||
// 但由于所有标签已被禁止,结果是纯文本(可能有实体转义)
|
||||
// 使用二次清理:将结果作为纯文本处理
|
||||
decode_entities(&doc).trim().to_string()
|
||||
}
|
||||
|
||||
result.trim().to_string()
|
||||
/// 简单解码常见 HTML 实体。
|
||||
fn decode_entities(input: &str) -> String {
|
||||
input
|
||||
.replace("<", "<")
|
||||
.replace(">", ">")
|
||||
.replace("&", "&")
|
||||
.replace(""", "\"")
|
||||
.replace("'", "'")
|
||||
.replace("'", "'")
|
||||
.replace("/", "/")
|
||||
.replace(" ", " ")
|
||||
}
|
||||
|
||||
/// 对 Option<String> 类型的字段进行清理。
|
||||
@@ -57,7 +49,8 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn strips_script_tag() {
|
||||
assert_eq!(strip_html_tags("<script>alert('xss')</script>"), "alert('xss')");
|
||||
// script 内容在 HTML 规范中是 raw text,ammonia 正确地将其完全移除
|
||||
assert_eq!(strip_html_tags("<script>alert('xss')</script>"), "");
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -83,7 +76,7 @@ mod tests {
|
||||
#[test]
|
||||
fn sanitize_option_some() {
|
||||
assert_eq!(
|
||||
sanitize_option(Some("<script>evil</script>".to_string())),
|
||||
sanitize_option(Some("<b>evil</b>".to_string())),
|
||||
Some("evil".to_string())
|
||||
);
|
||||
}
|
||||
@@ -97,4 +90,22 @@ mod tests {
|
||||
fn sanitize_option_becomes_empty() {
|
||||
assert_eq!(sanitize_option(Some("<img>".to_string())), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn strips_nested_script_attack() {
|
||||
let result = strip_html_tags("<scr<script>ipt>alert(1)</scr</script>ipt>");
|
||||
assert!(!result.contains("<"), "不应残留 HTML 标签");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn strips_unclosed_tag() {
|
||||
let result = strip_html_tags("text <img");
|
||||
assert!(result.contains("text") || result.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn handles_entities() {
|
||||
let result = strip_html_tags("a < b");
|
||||
assert!(result.contains("a") && result.contains("b"));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -411,6 +411,14 @@ pub async fn update_schedule(
|
||||
}
|
||||
}
|
||||
|
||||
// 记录变更前的关键字段
|
||||
let old_values = serde_json::json!({
|
||||
"start_time": model.start_time,
|
||||
"end_time": model.end_time,
|
||||
"max_appointments": model.max_appointments,
|
||||
"status": model.status,
|
||||
});
|
||||
|
||||
let mut active: doctor_schedule::ActiveModel = model.into();
|
||||
if let Some(v) = req.start_time { active.start_time = Set(v); }
|
||||
if let Some(v) = req.end_time { active.end_time = Set(v); }
|
||||
@@ -422,9 +430,18 @@ pub async fn update_schedule(
|
||||
|
||||
let m = active.update(&state.db).await?;
|
||||
|
||||
// 变更后快照
|
||||
let new_values = serde_json::json!({
|
||||
"start_time": m.start_time,
|
||||
"end_time": m.end_time,
|
||||
"max_appointments": m.max_appointments,
|
||||
"status": m.status,
|
||||
});
|
||||
|
||||
audit_service::record(
|
||||
AuditLog::new(tenant_id, operator_id, "doctor_schedule.updated", "doctor_schedule")
|
||||
.with_resource_id(m.id),
|
||||
.with_resource_id(m.id)
|
||||
.with_changes(Some(old_values), Some(new_values)),
|
||||
&state.db,
|
||||
).await;
|
||||
|
||||
|
||||
@@ -170,6 +170,14 @@ pub async fn update_task(
|
||||
validate_follow_up_status_transition(&model.status, new_status)?;
|
||||
}
|
||||
|
||||
// 记录变更前的关键字段
|
||||
let old_values = serde_json::json!({
|
||||
"assigned_to": model.assigned_to,
|
||||
"follow_up_type": model.follow_up_type,
|
||||
"planned_date": model.planned_date,
|
||||
"status": model.status,
|
||||
});
|
||||
|
||||
let mut active: follow_up_task::ActiveModel = model.into();
|
||||
if let Some(v) = req.assigned_to { active.assigned_to = Set(Some(v)); }
|
||||
if let Some(v) = req.follow_up_type { active.follow_up_type = Set(v); }
|
||||
@@ -182,9 +190,18 @@ pub async fn update_task(
|
||||
|
||||
let m = active.update(&state.db).await?;
|
||||
|
||||
// 变更后快照
|
||||
let new_values = serde_json::json!({
|
||||
"assigned_to": m.assigned_to,
|
||||
"follow_up_type": m.follow_up_type,
|
||||
"planned_date": m.planned_date,
|
||||
"status": m.status,
|
||||
});
|
||||
|
||||
audit_service::record(
|
||||
AuditLog::new(tenant_id, operator_id, "follow_up_task.updated", "follow_up_task")
|
||||
.with_resource_id(m.id),
|
||||
.with_resource_id(m.id)
|
||||
.with_changes(Some(old_values), Some(new_values)),
|
||||
&state.db,
|
||||
).await;
|
||||
|
||||
|
||||
@@ -156,6 +156,19 @@ pub async fn update_vital_signs(
|
||||
let next_ver = check_version(expected_version, model.version)
|
||||
.map_err(|_| HealthError::VersionMismatch)?;
|
||||
|
||||
// 记录变更前的关键体征值
|
||||
let old_values = serde_json::json!({
|
||||
"record_date": model.record_date,
|
||||
"systolic_bp_morning": model.systolic_bp_morning,
|
||||
"diastolic_bp_morning": model.diastolic_bp_morning,
|
||||
"systolic_bp_evening": model.systolic_bp_evening,
|
||||
"diastolic_bp_evening": model.diastolic_bp_evening,
|
||||
"heart_rate": model.heart_rate,
|
||||
"weight": model.weight,
|
||||
"blood_sugar": model.blood_sugar,
|
||||
"notes": model.notes,
|
||||
});
|
||||
|
||||
let mut active: vital_signs::ActiveModel = model.into();
|
||||
if let Some(v) = req.record_date { active.record_date = Set(v); }
|
||||
if let Some(v) = req.systolic_bp_morning { active.systolic_bp_morning = Set(Some(v)); }
|
||||
@@ -174,6 +187,19 @@ pub async fn update_vital_signs(
|
||||
|
||||
let m = active.update(&state.db).await?;
|
||||
|
||||
// 变更后快照
|
||||
let new_values = serde_json::json!({
|
||||
"record_date": m.record_date,
|
||||
"systolic_bp_morning": m.systolic_bp_morning,
|
||||
"diastolic_bp_morning": m.diastolic_bp_morning,
|
||||
"systolic_bp_evening": m.systolic_bp_evening,
|
||||
"diastolic_bp_evening": m.diastolic_bp_evening,
|
||||
"heart_rate": m.heart_rate,
|
||||
"weight": m.weight,
|
||||
"blood_sugar": m.blood_sugar,
|
||||
"notes": m.notes,
|
||||
});
|
||||
|
||||
// 更新后也触发危急值检测(修改后的值可能触发告警)
|
||||
let check_req = CreateVitalSignsReq {
|
||||
record_date: m.record_date,
|
||||
@@ -193,7 +219,8 @@ pub async fn update_vital_signs(
|
||||
|
||||
audit_service::record(
|
||||
AuditLog::new(tenant_id, operator_id, "vital_signs.updated", "vital_signs")
|
||||
.with_resource_id(m.id),
|
||||
.with_resource_id(m.id)
|
||||
.with_changes(Some(old_values), Some(new_values)),
|
||||
&state.db,
|
||||
).await;
|
||||
|
||||
@@ -406,6 +433,16 @@ pub async fn update_lab_report(
|
||||
let next_ver = check_version(expected_version, model.version)
|
||||
.map_err(|_| HealthError::VersionMismatch)?;
|
||||
|
||||
// 记录变更前的关键字段(items 为加密值,记录 meta 信息)
|
||||
let old_values = serde_json::json!({
|
||||
"report_date": model.report_date,
|
||||
"report_type": model.report_type,
|
||||
"status": model.status,
|
||||
"has_items": model.items.is_some(),
|
||||
"has_image_urls": model.image_urls.is_some(),
|
||||
"has_doctor_notes": model.doctor_notes.is_some(),
|
||||
});
|
||||
|
||||
let mut active: lab_report::ActiveModel = model.into();
|
||||
if let Some(v) = req.report_date { active.report_date = Set(v); }
|
||||
if let Some(v) = req.report_type { active.report_type = Set(v); }
|
||||
@@ -430,9 +467,20 @@ pub async fn update_lab_report(
|
||||
|
||||
let m = active.update(&state.db).await?;
|
||||
|
||||
// 变更后快照
|
||||
let new_values = serde_json::json!({
|
||||
"report_date": m.report_date,
|
||||
"report_type": m.report_type,
|
||||
"status": m.status,
|
||||
"has_items": m.items.is_some(),
|
||||
"has_image_urls": m.image_urls.is_some(),
|
||||
"has_doctor_notes": m.doctor_notes.is_some(),
|
||||
});
|
||||
|
||||
audit_service::record(
|
||||
AuditLog::new(tenant_id, operator_id, "lab_report.updated", "lab_report")
|
||||
.with_resource_id(m.id),
|
||||
.with_resource_id(m.id)
|
||||
.with_changes(Some(old_values), Some(new_values)),
|
||||
&state.db,
|
||||
).await;
|
||||
|
||||
@@ -514,6 +562,7 @@ pub async fn review_lab_report(
|
||||
|
||||
validate_lab_report_status_transition(&model.status, "reviewed")?;
|
||||
|
||||
let old_status = model.status.clone();
|
||||
let mut active: lab_report::ActiveModel = model.into();
|
||||
active.status = Set("reviewed".to_string());
|
||||
active.reviewed_by = Set(Some(reviewer_id));
|
||||
@@ -539,7 +588,11 @@ pub async fn review_lab_report(
|
||||
|
||||
audit_service::record(
|
||||
AuditLog::new(tenant_id, Some(reviewer_id), "lab_report.reviewed", "lab_report")
|
||||
.with_resource_id(m.id),
|
||||
.with_resource_id(m.id)
|
||||
.with_changes(
|
||||
Some(serde_json::json!({ "status": old_status })),
|
||||
Some(serde_json::json!({ "status": m.status, "reviewed_by": m.reviewed_by })),
|
||||
),
|
||||
&state.db,
|
||||
).await;
|
||||
|
||||
@@ -675,6 +728,14 @@ pub async fn update_health_record(
|
||||
let next_ver = check_version(expected_version, model.version)
|
||||
.map_err(|_| HealthError::VersionMismatch)?;
|
||||
|
||||
// 记录变更前的关键字段
|
||||
let old_values = serde_json::json!({
|
||||
"record_type": model.record_type,
|
||||
"record_date": model.record_date,
|
||||
"overall_assessment": model.overall_assessment,
|
||||
"notes": model.notes,
|
||||
});
|
||||
|
||||
let mut active: health_record::ActiveModel = model.into();
|
||||
if let Some(ref v) = req.record_type { validate_record_type(v)?; active.record_type = Set(v.clone()); }
|
||||
if let Some(v) = req.record_date { active.record_date = Set(v); }
|
||||
@@ -688,9 +749,18 @@ pub async fn update_health_record(
|
||||
|
||||
let m = active.update(&state.db).await?;
|
||||
|
||||
// 变更后快照
|
||||
let new_values = serde_json::json!({
|
||||
"record_type": m.record_type,
|
||||
"record_date": m.record_date,
|
||||
"overall_assessment": m.overall_assessment,
|
||||
"notes": m.notes,
|
||||
});
|
||||
|
||||
audit_service::record(
|
||||
AuditLog::new(tenant_id, operator_id, "health_record.updated", "health_record")
|
||||
.with_resource_id(m.id),
|
||||
.with_resource_id(m.id)
|
||||
.with_changes(Some(old_values), Some(new_values)),
|
||||
&state.db,
|
||||
).await;
|
||||
|
||||
|
||||
@@ -603,6 +603,15 @@ pub async fn update_family_member(
|
||||
|
||||
let kek = state.crypto.kek();
|
||||
let hmac_key = state.crypto.hmac_key();
|
||||
|
||||
// 记录变更前的关键字段(phone 为加密值,不记录原文)
|
||||
let old_values = serde_json::json!({
|
||||
"name": model.name,
|
||||
"relationship": model.relationship,
|
||||
"birth_date": model.birth_date,
|
||||
"notes": model.notes,
|
||||
});
|
||||
|
||||
let mut active: patient_family_member::ActiveModel = model.into();
|
||||
active.name = Set(req.name);
|
||||
active.relationship = Set(req.relationship);
|
||||
@@ -621,9 +630,18 @@ pub async fn update_family_member(
|
||||
|
||||
let updated = active.update(&state.db).await?;
|
||||
|
||||
// 变更后快照
|
||||
let new_values = serde_json::json!({
|
||||
"name": updated.name,
|
||||
"relationship": updated.relationship,
|
||||
"birth_date": updated.birth_date,
|
||||
"notes": updated.notes,
|
||||
});
|
||||
|
||||
audit_service::record(
|
||||
AuditLog::new(tenant_id, operator_id, "patient.family_member_updated", "patient_family_member")
|
||||
.with_resource_id(updated.id),
|
||||
.with_resource_id(updated.id)
|
||||
.with_changes(Some(old_values), Some(new_values)),
|
||||
&state.db,
|
||||
).await;
|
||||
|
||||
@@ -958,6 +976,13 @@ pub async fn update_tag(
|
||||
if tag.tenant_id != tenant_id { return Err(HealthError::TagNotFound); }
|
||||
check_version(req.version, tag.version)?;
|
||||
|
||||
// 记录变更前的关键字段
|
||||
let old_values = serde_json::json!({
|
||||
"name": tag.name,
|
||||
"color": tag.color,
|
||||
"description": tag.description,
|
||||
});
|
||||
|
||||
let mut active: patient_tag::ActiveModel = tag.into();
|
||||
if let Some(name) = req.name { active.name = Set(name); }
|
||||
if let Some(color) = req.color { active.color = Set(Some(color)); }
|
||||
@@ -969,9 +994,17 @@ pub async fn update_tag(
|
||||
let updated = active.update(&state.db).await
|
||||
.map_err(|e: sea_orm::DbErr| HealthError::DbError(e.to_string()))?;
|
||||
|
||||
// 变更后快照
|
||||
let new_values = serde_json::json!({
|
||||
"name": updated.name,
|
||||
"color": updated.color,
|
||||
"description": updated.description,
|
||||
});
|
||||
|
||||
audit_service::record(
|
||||
AuditLog::new(tenant_id, operator_id, "patient_tag.update", "patient_tag")
|
||||
.with_resource_id(updated.id),
|
||||
.with_resource_id(updated.id)
|
||||
.with_changes(Some(old_values), Some(new_values)),
|
||||
&state.db,
|
||||
).await;
|
||||
|
||||
|
||||
@@ -82,6 +82,9 @@ where
|
||||
)));
|
||||
}
|
||||
|
||||
// 校验 magic bytes:验证文件实际内容与声明的 Content-Type 一致
|
||||
validate_magic_bytes(&content_type, &data)?;
|
||||
|
||||
// 生成唯一文件名,保留原始扩展名
|
||||
let ext = std::path::Path::new(&original_name)
|
||||
.extension()
|
||||
@@ -137,6 +140,78 @@ fn validate_content_type(content_type: &str) -> Result<(), AppError> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 校验文件 magic bytes(文件签名)与声明的 Content-Type 是否一致。
|
||||
///
|
||||
/// 防止攻击者通过修改 Content-Type 头上传恶意文件。
|
||||
/// 对于 Office 格式等复杂签名,跳过 magic bytes 校验(仅依赖白名单)。
|
||||
fn validate_magic_bytes(content_type: &str, data: &[u8]) -> Result<(), AppError> {
|
||||
// 需要至少几个字节才能校验
|
||||
if data.is_empty() {
|
||||
return Err(AppError::Validation("文件内容为空".to_string()));
|
||||
}
|
||||
|
||||
let signature: &[u8] = match content_type {
|
||||
"image/jpeg" => {
|
||||
// JPEG: FF D8 FF
|
||||
b"\xFF\xD8\xFF"
|
||||
}
|
||||
"image/png" => {
|
||||
// PNG: 89 50 4E 47 0D 0A 1A 0A
|
||||
b"\x89PNG\r\n\x1A\n"
|
||||
}
|
||||
"image/gif" => {
|
||||
// GIF: 47 49 46 38 (GIF8)
|
||||
b"GIF8"
|
||||
}
|
||||
"image/webp" => {
|
||||
// WebP: RIFF....WEBP (12 bytes)
|
||||
// 前 4 字节: 52 49 46 46 (RIFF)
|
||||
// 字节 8-11: 57 45 42 50 (WEBP)
|
||||
if data.len() < 12 {
|
||||
return Err(AppError::Validation(
|
||||
"文件数据不足,无法验证 WebP 格式".to_string(),
|
||||
));
|
||||
}
|
||||
let riff_ok = &data[0..4] == b"RIFF";
|
||||
let webp_ok = &data[8..12] == b"WEBP";
|
||||
if riff_ok && webp_ok {
|
||||
return Ok(());
|
||||
}
|
||||
return Err(AppError::Validation(
|
||||
"文件内容与声明的类型 (image/webp) 不匹配".to_string(),
|
||||
));
|
||||
}
|
||||
"application/pdf" => {
|
||||
// PDF: 25 50 44 46 (%PDF)
|
||||
b"%PDF"
|
||||
}
|
||||
// Office 格式的 magic bytes 较复杂(OLE2 / ZIP-based OOXML),
|
||||
// 仅依赖白名单,跳过 magic bytes 校验
|
||||
"application/msword"
|
||||
| "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
|
||||
| "application/vnd.ms-excel"
|
||||
| "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" => {
|
||||
return Ok(());
|
||||
}
|
||||
_ => return Ok(()),
|
||||
};
|
||||
|
||||
if data.len() < signature.len() {
|
||||
return Err(AppError::Validation(
|
||||
"文件数据不足,无法验证文件格式".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
if &data[..signature.len()] != signature {
|
||||
return Err(AppError::Validation(format!(
|
||||
"文件内容与声明的类型 ({}) 不匹配",
|
||||
content_type
|
||||
)));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn format_size(bytes: u64) -> String {
|
||||
if bytes >= 1024 * 1024 * 1024 {
|
||||
format!("{}GB", bytes / (1024 * 1024 * 1024))
|
||||
|
||||
@@ -487,6 +487,8 @@ async fn main() -> anyhow::Result<()> {
|
||||
// with the jwt_auth_middleware_fn.
|
||||
|
||||
// Public routes (no authentication, but IP-based rate limiting)
|
||||
// Layer execution order (outer → inner): account_lockout → rate_limit_by_ip
|
||||
// So account lockout check runs FIRST, then IP rate limiting
|
||||
let public_routes = Router::new()
|
||||
.merge(handlers::health::health_check_router())
|
||||
.merge(erp_auth::AuthModule::public_routes())
|
||||
@@ -494,6 +496,10 @@ async fn main() -> anyhow::Result<()> {
|
||||
"/docs/openapi.json",
|
||||
axum::routing::get(handlers::openapi::openapi_spec),
|
||||
)
|
||||
.layer(axum::middleware::from_fn_with_state(
|
||||
state.clone(),
|
||||
middleware::rate_limit::account_lockout_middleware,
|
||||
))
|
||||
.layer(axum::middleware::from_fn_with_state(
|
||||
state.clone(),
|
||||
middleware::rate_limit::rate_limit_by_ip,
|
||||
|
||||
@@ -19,6 +19,10 @@ struct RateLimitResponse {
|
||||
message: String,
|
||||
}
|
||||
|
||||
/// 账户锁定配置。
|
||||
const ACCOUNT_LOCKOUT_MAX_FAILURES: i64 = 5;
|
||||
const ACCOUNT_LOCKOUT_TTL_SECS: i64 = 900; // 15 分钟
|
||||
|
||||
/// 限流参数(预留配置化扩展)。
|
||||
#[allow(dead_code)]
|
||||
pub struct RateLimitConfig {
|
||||
@@ -162,6 +166,133 @@ async fn apply_rate_limit(
|
||||
next.run(req).await
|
||||
}
|
||||
|
||||
/// 账户级登录锁定中间件。
|
||||
///
|
||||
/// 针对登录接口(POST /api/v1/auth/login),在 IP 限流之前执行:
|
||||
/// 1. 解析请求体提取 username
|
||||
/// 2. 检查 Redis 中该 username 的失败次数
|
||||
/// 3. 超过阈值(5次)则拒绝请求
|
||||
/// 4. 观察响应状态码:401 递增失败计数,200 清除计数
|
||||
pub async fn account_lockout_middleware(
|
||||
State(state): State<AppState>,
|
||||
req: Request<Body>,
|
||||
next: Next,
|
||||
) -> Response {
|
||||
let avail = redis_avail();
|
||||
|
||||
// Redis 不可达时 fail-open:放行请求
|
||||
if !avail.should_try().await {
|
||||
tracing::warn!("Redis 不可达,fail-open 账户锁定检查放行");
|
||||
return next.run(req).await;
|
||||
}
|
||||
|
||||
// 获取 Redis 连接
|
||||
let mut conn = match state.redis.get_multiplexed_async_connection().await {
|
||||
Ok(c) => {
|
||||
avail.mark_ok();
|
||||
c
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(error = %e, "Redis 连接失败,fail-open 账户锁定检查放行");
|
||||
avail.mark_failed().await;
|
||||
return next.run(req).await;
|
||||
}
|
||||
};
|
||||
|
||||
// 读取请求体以提取 username
|
||||
let (parts, body) = req.into_parts();
|
||||
let bytes = match axum::body::to_bytes(body, 1024).await {
|
||||
Ok(b) => b,
|
||||
Err(e) => {
|
||||
tracing::warn!(error = %e, "读取登录请求体失败,放行");
|
||||
// 无法读取 body,重建请求放行
|
||||
let req = Request::from_parts(parts, Body::from(Vec::new()));
|
||||
return next.run(req).await;
|
||||
}
|
||||
};
|
||||
|
||||
// 解析 username
|
||||
let username = serde_json::from_slice::<serde_json::Value>(&bytes)
|
||||
.ok()
|
||||
.and_then(|v| v.get("username")?.as_str().map(|s| s.to_string()));
|
||||
|
||||
let username = match username {
|
||||
Some(u) if !u.is_empty() => u,
|
||||
_ => {
|
||||
// 无法解析 username,用原始 body 重建请求放行
|
||||
let req = Request::from_parts(parts, Body::from(bytes.to_vec()));
|
||||
return next.run(req).await;
|
||||
}
|
||||
};
|
||||
|
||||
// 检查账户锁定状态
|
||||
let lockout_key = format!("login_fail:{}", username);
|
||||
let fail_count: i64 = conn.get(&lockout_key).await.unwrap_or(0);
|
||||
|
||||
if fail_count >= ACCOUNT_LOCKOUT_MAX_FAILURES {
|
||||
tracing::warn!(
|
||||
username = %username,
|
||||
fail_count = fail_count,
|
||||
"账户已被临时锁定"
|
||||
);
|
||||
let body = RateLimitResponse {
|
||||
error: "Too Many Requests".to_string(),
|
||||
message: "账户已被临时锁定,请15分钟后重试".to_string(),
|
||||
};
|
||||
return (StatusCode::TOO_MANY_REQUESTS, axum::Json(body)).into_response();
|
||||
}
|
||||
|
||||
// 用原始 body 重建请求,转发到 handler
|
||||
let req = Request::from_parts(parts, Body::from(bytes.to_vec()));
|
||||
let response = next.run(req).await;
|
||||
|
||||
// 观察响应状态码
|
||||
let status = response.status();
|
||||
let (parts, body) = response.into_parts();
|
||||
|
||||
// 需要读取 body 以重建响应(因为 into_parts 消费了 body)
|
||||
let body_bytes = axum::body::to_bytes(body, 1024 * 1024)
|
||||
.await
|
||||
.unwrap_or_default();
|
||||
|
||||
if status == StatusCode::UNAUTHORIZED {
|
||||
// 登录失败:递增失败计数
|
||||
let new_count: i64 = match redis::cmd("INCR")
|
||||
.arg(&lockout_key)
|
||||
.query_async(&mut conn)
|
||||
.await
|
||||
{
|
||||
Ok(n) => n,
|
||||
Err(e) => {
|
||||
tracing::warn!(error = %e, "Redis INCR 失败计数失败");
|
||||
// 即使计数失败,也返回原始 401 响应
|
||||
let resp = Response::from_parts(parts, Body::from(body_bytes.to_vec()));
|
||||
return resp;
|
||||
}
|
||||
};
|
||||
|
||||
// 首次失败时设置 TTL
|
||||
if new_count == 1 {
|
||||
let _: Result<(), _> = conn.expire(&lockout_key, ACCOUNT_LOCKOUT_TTL_SECS).await;
|
||||
}
|
||||
|
||||
tracing::info!(
|
||||
username = %username,
|
||||
fail_count = new_count,
|
||||
remaining = ACCOUNT_LOCKOUT_MAX_FAILURES - new_count,
|
||||
"登录失败,递增失败计数"
|
||||
);
|
||||
} else if status.is_success() {
|
||||
// 登录成功:清除失败计数
|
||||
let _: Result<(), _> = conn.del(&lockout_key).await;
|
||||
tracing::info!(username = %username, "登录成功,清除失败计数");
|
||||
}
|
||||
|
||||
// 重建并返回原始响应
|
||||
let resp = Response::from_parts(parts, Body::from(body_bytes.to_vec()));
|
||||
resp
|
||||
}
|
||||
|
||||
/// 从请求头中提取客户端 IP。
|
||||
fn extract_client_ip(headers: &axum::http::HeaderMap) -> String {
|
||||
headers
|
||||
|
||||
Reference in New Issue
Block a user