Files
nj/crates/erp-auth/src/middleware/jwt_auth.rs
iven c539e6fd83 feat: initialize Nuanji (Warm Notes) project
- Base platform from base.git (ERP base: auth, core, config, message, workflow, plugin)
- Created erp-diary module skeleton (lib.rs, dto.rs, error.rs, event.rs, state.rs)
- Integrated erp-diary into workspace and erp-server
- Added DiaryModule registration in main.rs
- Added DiaryState FromRef in state.rs
- Diary routes mounted (empty routes, ready for implementation)
- Product design spec v1.2 preserved in docs/
- Implementation plan preserved in plans/

Cargo check: OK
Cargo test: OK (78+ base tests passing)
2026-05-31 20:52:19 +08:00

274 lines
9.4 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
use axum::body::Body;
use axum::http::Request;
use axum::middleware::Next;
use axum::response::Response;
use dashmap::DashMap;
use erp_core::error::AppError;
use erp_core::request_info::REQUEST_INFO;
use erp_core::request_info::RequestInfo;
use erp_core::types::{DataScope, TenantContext};
use crate::service::token_service::TokenService;
/// Department IDs a user belongs to.
type DeptIds = Vec<uuid::Uuid>;
/// Data scope per permission, keyed by permission code.
type DataScopes = std::collections::HashMap<String, DataScope>;
/// Cached scope data plus the instant it was cached (used for TTL checks).
type ScopeCacheEntry = (DeptIds, DataScopes, std::time::Instant);
/// Per-user permission-scope cache: `user_id -> (department_ids, data_scopes, cached_at)`.
///
/// `DashMap` is sharded, so concurrent access to different keys rarely contends
/// on the same lock. Entries older than [`SCOPE_CACHE_TTL`] are treated as stale.
///
/// NOTE(review): the key is the user id only, while the cached data is fetched for
/// a specific tenant. This is only correct if a user id never appears under more
/// than one tenant — confirm against the users schema.
static USER_SCOPE_CACHE: std::sync::LazyLock<DashMap<uuid::Uuid, ScopeCacheEntry>> =
    std::sync::LazyLock::new(DashMap::new);
/// Access-token revocation blacklist: `token_hash -> exp`.
///
/// The key is `token_hash(token)`: the 64-bit std `DefaultHasher` digest of the
/// token rendered as 16 lowercase hex digits. This is not a cryptographic hash; it
/// only keeps raw tokens out of the map. The value is the token's `exp` as a Unix
/// timestamp in seconds. Stale entries are pruned lazily when tokens are checked.
static TOKEN_BLACKLIST: std::sync::LazyLock<DashMap<String, i64>> =
    std::sync::LazyLock::new(DashMap::new);
/// How long a `USER_SCOPE_CACHE` entry is considered fresh.
const SCOPE_CACHE_TTL: std::time::Duration = std::time::Duration::from_secs(60);
/// Revokes a single access token until it expires naturally.
///
/// Stores `token_hash(token)` in `TOKEN_BLACKLIST` with `exp` (the token's expiry,
/// Unix seconds) as the value. Revoking the same token again overwrites the entry.
pub fn revoke_access_token(token: &str, exp: i64) {
    let key = token_hash(token);
    let _previous_exp = TOKEN_BLACKLIST.insert(key, exp);
}
/// Evicts the user's entry from `USER_SCOPE_CACHE`.
///
/// On the user's next authenticated request, department IDs and permission data
/// scopes are re-queried instead of being served from the cache.
///
/// NOTE(review): despite its name, this does NOT revoke any tokens. Access tokens
/// already issued to the user keep authenticating until they expire or are revoked
/// individually with `revoke_access_token`. Roles and permissions are read from the
/// token claims, so they are not refreshed by this call either. Real per-user
/// revocation would need something like a `user_id -> revoked_at` map checked
/// against the token's issue time in the middleware.
pub fn revoke_all_user_tokens(user_id: uuid::Uuid) {
    USER_SCOPE_CACHE.remove(&user_id);
}
/// Returns `true` if `token` is on the revocation blacklist and its entry is still live.
///
/// An entry stays live until `EXPIRY_GRACE_SECS` after the expiry recorded when the
/// token was revoked. `_exp` (the caller-supplied `exp` claim) is currently unused;
/// the stored expiry decides liveness.
///
/// Housekeeping is lazy and bounded:
/// - A stale entry hit by the lookup is removed via `remove_if`. This re-checks
///   staleness under the shard write lock instead of removing unconditionally.
/// - Once the blacklist exceeds `SWEEP_THRESHOLD` entries, a full sweep runs at
///   most once every `SWEEP_INTERVAL_SECS` seconds. Sweeping on every call above
///   the threshold would make each authenticated request O(n), and would
///   write-lock every shard, while many revocations are live.
fn is_token_revoked(token: &str, _exp: i64) -> bool {
    use std::sync::atomic::{AtomicI64, Ordering};

    // Blacklist size above which a full sweep of stale entries is considered.
    const SWEEP_THRESHOLD: usize = 10_000;
    // Minimum number of seconds between two full sweeps.
    const SWEEP_INTERVAL_SECS: i64 = 60;
    // Seconds an entry stays live past the token's `exp`.
    // NOTE(review): guards against `decode_token` accepting slightly expired tokens
    // (e.g. a clock-skew leeway such as jsonwebtoken's default of 60 s). Without it,
    // a revoked token would be accepted again inside that window. Confirm against
    // TokenService's validation settings.
    const EXPIRY_GRACE_SECS: i64 = 60;
    // Unix timestamp (seconds) of the most recent full sweep.
    static LAST_SWEEP: AtomicI64 = AtomicI64::new(0);

    let now = chrono::Utc::now().timestamp();
    let is_live = |exp_ts: i64| exp_ts.saturating_add(EXPIRY_GRACE_SECS) > now;

    if TOKEN_BLACKLIST.len() > SWEEP_THRESHOLD {
        let last = LAST_SWEEP.load(Ordering::Relaxed);
        let due = now.saturating_sub(last) >= SWEEP_INTERVAL_SECS;
        // compare_exchange lets only one of several concurrent callers run the sweep.
        if due
            && LAST_SWEEP
                .compare_exchange(last, now, Ordering::Relaxed, Ordering::Relaxed)
                .is_ok()
        {
            TOKEN_BLACKLIST.retain(|_, exp_ts| is_live(*exp_ts));
        }
    }

    let hash = token_hash(token);
    // Copy the value out so the shard read guard is released before any removal.
    let Some(exp_ts) = TOKEN_BLACKLIST.get(&hash).map(|entry| *entry.value()) else {
        return false;
    };
    if is_live(exp_ts) {
        return true;
    }
    TOKEN_BLACKLIST.remove_if(&hash, |_, exp_ts| !is_live(*exp_ts));
    false
}
/// Derives the blacklist key for a raw token.
///
/// Hashes the token with a `DefaultHasher` built through `BuildHasherDefault`. That
/// is the same hasher as `DefaultHasher::new()`, so keys are reproducible within one
/// build. The 64-bit digest is rendered as 16 zero-padded lowercase hex digits.
/// This is a non-cryptographic hash used only as a map key.
fn token_hash(token: &str) -> String {
    use std::hash::BuildHasher;

    let digest: u64 =
        std::hash::BuildHasherDefault::<std::collections::hash_map::DefaultHasher>::default()
            .hash_one(token);
    format!("{digest:016x}")
}
/// JWT authentication middleware function.
///
/// Extracts the `Bearer` token from the `Authorization` header, validates it
/// using `TokenService::decode_token`, and injects a `TenantContext` into the
/// request extensions so downstream handlers can access tenant/user identity.
///
/// 同时提取请求的 IP 地址和 User-Agent通过 task_local 传递给审计服务,
/// 使所有审计日志自动记录来源信息。
///
/// The `jwt_secret` parameter is passed explicitly by the server crate at
/// middleware construction time, avoiding any circular dependency between
/// erp-auth and erp-server.
///
/// When `db` is provided, the middleware queries `user_departments` to populate
/// `department_ids` in the `TenantContext`. If `db` is `None` or the query fails,
/// `department_ids` defaults to an empty list (equivalent to "all" data scope).
///
/// # Errors
///
/// Returns `AppError::Unauthorized` if:
/// - The `Authorization` header is missing
/// - The header value does not start with `"Bearer "`
/// - The token cannot be decoded or has expired
/// - The token type is not "access"
pub async fn jwt_auth_middleware_fn(
jwt_secret: String,
db: Option<sea_orm::DatabaseConnection>,
req: Request<Body>,
next: Next,
) -> Result<Response, AppError> {
// 优先从 Authorization 头提取 token
// 回退到 URL query parameter ?token=xxxSSE/EventSource 无法设置自定义头)
let token = req
.headers()
.get("Authorization")
.and_then(|v| v.to_str().ok())
.and_then(|h| h.strip_prefix("Bearer "))
.map(String::from)
.or_else(|| {
req.uri().query().and_then(|q| {
q.split('&')
.find_map(|pair| pair.strip_prefix("token="))
.map(String::from)
})
})
.ok_or(AppError::Unauthorized)?;
let claims =
TokenService::decode_token(&token, &jwt_secret).map_err(|_| AppError::Unauthorized)?;
// 检查 token 是否已被吊销(密码修改/管理员强制下线)
if is_token_revoked(&token, claims.exp) {
return Err(AppError::Unauthorized);
}
// Verify this is an access token, not a refresh token
if claims.token_type != "access" {
return Err(AppError::Unauthorized);
}
// 查询用户所属部门 ID 列表 + 权限数据范围(带 60 秒缓存)
let cached = USER_SCOPE_CACHE.get(&claims.sub).and_then(|entry| {
let (_, _, at) = entry.value();
if at.elapsed() < SCOPE_CACHE_TTL {
let (depts, scopes, _) = entry.value();
Some((depts.clone(), scopes.clone()))
} else {
drop(entry);
USER_SCOPE_CACHE.remove(&claims.sub);
None
}
});
let (department_ids, permission_data_scopes) = match cached {
Some(hit) => hit,
None => fetch_and_cache_scopes(claims.sub, claims.tid, &db).await,
};
// 提取请求来源信息IP + User-Agent用于审计日志
let request_info = RequestInfo::from_headers(req.headers());
let ctx = TenantContext {
tenant_id: claims.tid,
user_id: claims.sub,
roles: claims.roles,
permissions: claims.permissions,
department_ids,
permission_data_scopes,
};
// Reconstruct the request with the TenantContext injected into extensions.
// We cannot borrow `req` mutably after reading headers, so we rebuild.
let (parts, body) = req.into_parts();
let mut req = Request::from_parts(parts, body);
req.extensions_mut().insert(ctx);
// 在 task_local scope 中运行后续处理,审计服务可自动读取请求信息
Ok(REQUEST_INFO.scope(request_info, next.run(req)).await)
}
/// 查询用户所属的所有部门 ID通过 user_departments 关联表)
async fn fetch_user_department_ids(
user_id: uuid::Uuid,
tenant_id: uuid::Uuid,
db: &sea_orm::DatabaseConnection,
) -> Vec<uuid::Uuid> {
use crate::entity::user_department;
use sea_orm::{ColumnTrait, EntityTrait, QueryFilter};
user_department::Entity::find()
.filter(user_department::Column::UserId.eq(user_id))
.filter(user_department::Column::TenantId.eq(tenant_id))
.filter(user_department::Column::DeletedAt.is_null())
.all(db)
.await
.map(|rows| rows.into_iter().map(|r| r.department_id).collect())
.unwrap_or_else(|e| {
tracing::warn!(error = %e, "查询用户部门列表失败,默认为空");
vec![]
})
}
/// Loads the effective data scope for each permission the user holds in `tenant_id`.
///
/// The query joins `user_roles` → `role_permissions` → `permissions` and ignores
/// rows where `ur.deleted_at` or `rp.deleted_at` is set. When several roles grant
/// the same permission code, it keeps the broadest recognised scope by rank:
/// `all` (0) > `department_tree` (1) > `department` (2) > `self` (3).
/// If a code has no recognised scope value at all, it falls back to
/// `MIN(rp.data_scope)`.
///
/// The text scope is derived from the minimum rank rather than from
/// `MIN(rp.data_scope)` directly. Lexicographic order would pick `department` over
/// the broader `department_tree`.
///
/// Returns `permission code -> DataScope`. Rows whose columns cannot be decoded
/// (e.g. a NULL scope) are skipped. On a query error, a warning is logged and an
/// empty map is returned.
async fn fetch_permission_data_scopes(
    user_id: uuid::Uuid,
    tenant_id: uuid::Uuid,
    db: &sea_orm::DatabaseConnection,
) -> std::collections::HashMap<String, DataScope> {
    use sea_orm::ConnectionTrait;
    // Column order is significant: 0 = code, 1 = scope_rank, 2 = data_scope.
    let sql = r#"
        SELECT ranked.code,
               MIN(ranked.scope_rank) AS scope_rank,
               CASE MIN(ranked.scope_rank)
                   WHEN 0 THEN 'all'
                   WHEN 1 THEN 'department_tree'
                   WHEN 2 THEN 'department'
                   WHEN 3 THEN 'self'
                   ELSE MIN(ranked.data_scope)
               END AS data_scope
        FROM (
            SELECT p.code,
                   rp.data_scope,
                   CASE rp.data_scope
                       WHEN 'all' THEN 0
                       WHEN 'department_tree' THEN 1
                       WHEN 'department' THEN 2
                       WHEN 'self' THEN 3
                   END AS scope_rank
            FROM user_roles ur
            JOIN role_permissions rp ON ur.role_id = rp.role_id AND ur.tenant_id = rp.tenant_id
            JOIN permissions p ON rp.permission_id = p.id
            WHERE ur.user_id = $1
              AND ur.tenant_id = $2
              AND ur.deleted_at IS NULL
              AND rp.deleted_at IS NULL
        ) AS ranked
        GROUP BY ranked.code
    "#;
    let stmt = sea_orm::Statement::from_sql_and_values(
        sea_orm::DatabaseBackend::Postgres,
        sql,
        [user_id.into(), tenant_id.into()],
    );
    match db.query_all(stmt).await {
        Ok(rows) => rows
            .iter()
            .filter_map(|row| {
                let code = row.try_get_by_index::<String>(0).ok()?;
                let scope = row.try_get_by_index::<String>(2).ok()?;
                Some((code, DataScope::parse_scope(&scope)))
            })
            .collect(),
        Err(e) => {
            tracing::warn!(error = %e, "查询权限数据范围失败,默认全部 All");
            std::collections::HashMap::new()
        }
    }
}
/// Queries the user's department IDs and permission data scopes, stores them in
/// `USER_SCOPE_CACHE`, and returns them.
///
/// With `db == None`, both results are empty. The result is cached even when it is
/// an empty fallback.
///
/// NOTE(review): the fetch helpers turn query errors into empty results. A transient
/// DB failure is therefore cached for `SCOPE_CACHE_TTL` just like a real answer.
///
/// After the insert, if the cache holds more than 500 entries, every entry older
/// than `SCOPE_CACHE_TTL` is evicted. This keeps the map from growing without bound.
async fn fetch_and_cache_scopes(
    user_id: uuid::Uuid,
    tenant_id: uuid::Uuid,
    db: &Option<sea_orm::DatabaseConnection>,
) -> (
    Vec<uuid::Uuid>,
    std::collections::HashMap<String, DataScope>,
) {
    const EVICTION_THRESHOLD: usize = 500;

    let (depts, scopes) = if let Some(conn) = db {
        let depts = fetch_user_department_ids(user_id, tenant_id, conn).await;
        let scopes = fetch_permission_data_scopes(user_id, tenant_id, conn).await;
        (depts, scopes)
    } else {
        (Vec::new(), std::collections::HashMap::new())
    };

    let entry = (depts.clone(), scopes.clone(), std::time::Instant::now());
    USER_SCOPE_CACHE.insert(user_id, entry);

    if USER_SCOPE_CACHE.len() > EVICTION_THRESHOLD {
        let now = std::time::Instant::now();
        USER_SCOPE_CACHE
            .retain(|_, (_, _, cached_at)| now.duration_since(*cached_at) < SCOPE_CACHE_TTL);
    }

    (depts, scopes)
}