use axum::body::Body;
use axum::http::Request;
use axum::middleware::Next;
use axum::response::Response;
use dashmap::DashMap;
use erp_core::error::AppError;
use erp_core::request_info::REQUEST_INFO;
use erp_core::request_info::RequestInfo;
use erp_core::types::{DataScope, TenantContext};

use crate::service::token_service::TokenService;

type DeptIds = Vec<uuid::Uuid>;
type DataScopes = std::collections::HashMap<String, DataScope>;
type ScopeCacheEntry = (DeptIds, DataScopes, std::time::Instant);

/// Per-user permission data cache: `user_id -> (department_ids, data_scopes, cached_at)`.
///
/// Backed by a sharded `DashMap`, so concurrent access for different users rarely
/// contends on the same lock.
static USER_SCOPE_CACHE: std::sync::LazyLock<DashMap<uuid::Uuid, ScopeCacheEntry>> =
    std::sync::LazyLock::new(DashMap::new);

/// Access-token revocation blacklist: `token_hash(token) -> exp` (the token's `exp`
/// claim, unix seconds).
///
/// The key is the 16-hex-digit output of [`token_hash`] (std `DefaultHasher`, not a
/// cryptographic hash), so raw tokens are never stored. An entry only matters until
/// the token's own `exp`. After that it is treated as not revoked and removed lazily.
static TOKEN_BLACKLIST: std::sync::LazyLock<DashMap<String, i64>> =
    std::sync::LazyLock::new(DashMap::new);

/// Once the blacklist holds more than this many entries, inserting a new revocation
/// first sweeps out entries whose tokens have already expired.
const TOKEN_BLACKLIST_SWEEP_THRESHOLD: usize = 10_000;

/// How long an entry in `USER_SCOPE_CACHE` is considered fresh.
const SCOPE_CACHE_TTL: std::time::Duration = std::time::Duration::from_secs(60);

/// Revokes a single access token until its natural expiry.
///
/// `exp` is the token's `exp` claim (unix seconds). A request presenting this token
/// is rejected while the current time is before `exp`.
pub fn revoke_access_token(token: &str, exp: i64) {
    // Revocations are rare and are the only way the map grows, so bound its size
    // here instead of on the per-request lookup path.
    if TOKEN_BLACKLIST.len() > TOKEN_BLACKLIST_SWEEP_THRESHOLD {
        let now = chrono::Utc::now().timestamp();
        TOKEN_BLACKLIST.retain(|_, exp_ts| *exp_ts > now);
    }
    TOKEN_BLACKLIST.insert(token_hash(token), exp);
}

/// Drops the cached department / data-scope data for `user_id`, so the next
/// authenticated request from this user re-reads it from the database.
///
/// NOTE(review): despite its name, this does NOT invalidate the user's access tokens.
/// `is_token_revoked` only consults `TOKEN_BLACKLIST`, which this function does not
/// touch. Existing tokens stay valid until they expire or are individually passed to
/// `revoke_access_token`. True "log out everywhere" needs a per-user cutoff (e.g. on
/// an issued-at claim) checked by the middleware.
pub fn revoke_all_user_tokens(user_id: uuid::Uuid) {
    USER_SCOPE_CACHE.remove(&user_id);
}

/// Returns `true` if `token` is on the blacklist and its recorded `exp` is still in
/// the future.
///
/// An entry whose `exp` has passed is removed and reported as not revoked. The
/// `_exp` parameter is currently unused; the `exp` stored at revocation time is used
/// instead.
fn is_token_revoked(token: &str, _exp: i64) -> bool {
    let hash = token_hash(token);
    // Copy the timestamp out so the shard read guard is released before any removal.
    let Some(revoked_until) = TOKEN_BLACKLIST.get(&hash).map(|entry| *entry) else {
        return false;
    };
    let now = chrono::Utc::now().timestamp();
    if revoked_until > now {
        return true;
    }
    // The token expired on its own. Remove the entry atomically and only if it is
    // still expired, so a concurrent re-revocation is not lost.
    TOKEN_BLACKLIST.remove_if(&hash, |_, exp_ts| *exp_ts <= now);
    false
}

/// Hashes `token` into a 16-hex-digit key for `TOKEN_BLACKLIST`.
///
/// Uses std `DefaultHasher::new()`. Its algorithm is unspecified but identical for
/// every `new()` instance in one build, which is all an in-memory blacklist needs.
/// It is not cryptographic. A 64-bit collision would make an unrelated token look
/// revoked (fail-closed), not the other way round.
fn token_hash(token: &str) -> String {
    use std::hash::{Hash, Hasher};

    let mut hasher = std::collections::hash_map::DefaultHasher::new();
    token.hash(&mut hasher);
    format!("{:016x}", hasher.finish())
}

/// JWT authentication middleware function.
/// /// Extracts the `Bearer` token from the `Authorization` header, validates it /// using `TokenService::decode_token`, and injects a `TenantContext` into the /// request extensions so downstream handlers can access tenant/user identity. /// /// 同时提取请求的 IP 地址和 User-Agent,通过 task_local 传递给审计服务, /// 使所有审计日志自动记录来源信息。 /// /// The `jwt_secret` parameter is passed explicitly by the server crate at /// middleware construction time, avoiding any circular dependency between /// erp-auth and erp-server. /// /// When `db` is provided, the middleware queries `user_departments` to populate /// `department_ids` in the `TenantContext`. If `db` is `None` or the query fails, /// `department_ids` defaults to an empty list (equivalent to "all" data scope). /// /// # Errors /// /// Returns `AppError::Unauthorized` if: /// - The `Authorization` header is missing /// - The header value does not start with `"Bearer "` /// - The token cannot be decoded or has expired /// - The token type is not "access" pub async fn jwt_auth_middleware_fn( jwt_secret: String, db: Option, req: Request, next: Next, ) -> Result { // 优先从 Authorization 头提取 token; // 回退到 URL query parameter ?token=xxx(SSE/EventSource 无法设置自定义头) let token = req .headers() .get("Authorization") .and_then(|v| v.to_str().ok()) .and_then(|h| h.strip_prefix("Bearer ")) .map(String::from) .or_else(|| { req.uri().query().and_then(|q| { q.split('&') .find_map(|pair| pair.strip_prefix("token=")) .map(String::from) }) }) .ok_or(AppError::Unauthorized)?; let claims = TokenService::decode_token(&token, &jwt_secret).map_err(|_| AppError::Unauthorized)?; // 检查 token 是否已被吊销(密码修改/管理员强制下线) if is_token_revoked(&token, claims.exp) { return Err(AppError::Unauthorized); } // Verify this is an access token, not a refresh token if claims.token_type != "access" { return Err(AppError::Unauthorized); } // 查询用户所属部门 ID 列表 + 权限数据范围(带 60 秒缓存) let cached = USER_SCOPE_CACHE.get(&claims.sub).and_then(|entry| { let (_, _, at) = entry.value(); if at.elapsed() < 
SCOPE_CACHE_TTL { let (depts, scopes, _) = entry.value(); Some((depts.clone(), scopes.clone())) } else { drop(entry); USER_SCOPE_CACHE.remove(&claims.sub); None } }); let (department_ids, permission_data_scopes) = match cached { Some(hit) => hit, None => fetch_and_cache_scopes(claims.sub, claims.tid, &db).await, }; // 提取请求来源信息(IP + User-Agent),用于审计日志 let request_info = RequestInfo::from_headers(req.headers()); let ctx = TenantContext { tenant_id: claims.tid, user_id: claims.sub, roles: claims.roles, permissions: claims.permissions, department_ids, permission_data_scopes, }; // Reconstruct the request with the TenantContext injected into extensions. // We cannot borrow `req` mutably after reading headers, so we rebuild. let (parts, body) = req.into_parts(); let mut req = Request::from_parts(parts, body); req.extensions_mut().insert(ctx); // 在 task_local scope 中运行后续处理,审计服务可自动读取请求信息 Ok(REQUEST_INFO.scope(request_info, next.run(req)).await) } /// 查询用户所属的所有部门 ID(通过 user_departments 关联表) async fn fetch_user_department_ids( user_id: uuid::Uuid, tenant_id: uuid::Uuid, db: &sea_orm::DatabaseConnection, ) -> Vec { use crate::entity::user_department; use sea_orm::{ColumnTrait, EntityTrait, QueryFilter}; user_department::Entity::find() .filter(user_department::Column::UserId.eq(user_id)) .filter(user_department::Column::TenantId.eq(tenant_id)) .filter(user_department::Column::DeletedAt.is_null()) .all(db) .await .map(|rows| rows.into_iter().map(|r| r.department_id).collect()) .unwrap_or_else(|e| { tracing::warn!(error = %e, "查询用户部门列表失败,默认为空"); vec![] }) } /// 查询用户每个权限的数据范围(从 role_permissions 表) async fn fetch_permission_data_scopes( user_id: uuid::Uuid, tenant_id: uuid::Uuid, db: &sea_orm::DatabaseConnection, ) -> std::collections::HashMap { use sea_orm::ConnectionTrait; let sql = r#" SELECT p.code, MIN( CASE rp.data_scope WHEN 'all' THEN 0 WHEN 'department_tree' THEN 1 WHEN 'department' THEN 2 WHEN 'self' THEN 3 ELSE 0 END ) AS scope_rank, MIN(rp.data_scope) AS data_scope 
FROM user_roles ur JOIN role_permissions rp ON ur.role_id = rp.role_id AND ur.tenant_id = rp.tenant_id JOIN permissions p ON rp.permission_id = p.id WHERE ur.user_id = $1 AND ur.tenant_id = $2 AND ur.deleted_at IS NULL AND rp.deleted_at IS NULL GROUP BY p.code "#; let stmt = sea_orm::Statement::from_sql_and_values( sea_orm::DatabaseBackend::Postgres, sql, [user_id.into(), tenant_id.into()], ); match db.query_all(stmt).await { Ok(rows) => { let mut scopes = std::collections::HashMap::new(); for row in rows { if let (Ok(code), Ok(scope)) = ( row.try_get_by_index::(0), row.try_get_by_index::(2), ) { scopes.insert(code, DataScope::parse_scope(&scope)); } } scopes } Err(e) => { tracing::warn!(error = %e, "查询权限数据范围失败,默认全部 All"); std::collections::HashMap::new() } } } /// 从 DB 查询部门 + 权限范围,并写入缓存 async fn fetch_and_cache_scopes( user_id: uuid::Uuid, tenant_id: uuid::Uuid, db: &Option, ) -> ( Vec, std::collections::HashMap, ) { let depts = match db { Some(conn) => fetch_user_department_ids(user_id, tenant_id, conn).await, None => vec![], }; let scopes = match db { Some(conn) => fetch_permission_data_scopes(user_id, tenant_id, conn).await, None => std::collections::HashMap::new(), }; USER_SCOPE_CACHE.insert( user_id, (depts.clone(), scopes.clone(), std::time::Instant::now()), ); // 惰性淘汰过期条目,防止 DashMap 无限增长 if USER_SCOPE_CACHE.len() > 500 { let now = std::time::Instant::now(); USER_SCOPE_CACHE.retain(|_, (_, _, at)| now.duration_since(*at) < SCOPE_CACHE_TTL); } (depts, scopes) }