Files
hms/crates/erp-auth/src/middleware/jwt_auth.rs
iven 988f6cd6a5
Some checks failed
CI / frontend-build (push) Has been cancelled
CI / rust-check (push) Has been cancelled
CI / rust-test (push) Has been cancelled
CI / security-audit (push) Has been cancelled
fix(auth): JWT 中间件支持 query parameter token 回退
SSE/EventSource 无法设置自定义 Authorization 头,前端通过
?token=xxx 传参。中间件现在优先读 Authorization 头,回退到
URL query parameter,修复 SSE 连接永远 401 的问题。
2026-04-28 11:23:53 +08:00

177 lines
6.1 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
use axum::body::Body;
use axum::http::Request;
use axum::middleware::Next;
use axum::response::Response;
use erp_core::error::AppError;
use erp_core::request_info::REQUEST_INFO;
use erp_core::request_info::RequestInfo;
use erp_core::types::{DataScope, TenantContext};
use crate::service::token_service::TokenService;
/// JWT authentication middleware function.
///
/// Extracts the `Bearer` token from the `Authorization` header, validates it
/// using `TokenService::decode_token`, and injects a `TenantContext` into the
/// request extensions so downstream handlers can access tenant/user identity.
///
/// 同时提取请求的 IP 地址和 User-Agent通过 task_local 传递给审计服务,
/// 使所有审计日志自动记录来源信息。
///
/// The `jwt_secret` parameter is passed explicitly by the server crate at
/// middleware construction time, avoiding any circular dependency between
/// erp-auth and erp-server.
///
/// When `db` is provided, the middleware queries `user_departments` to populate
/// `department_ids` in the `TenantContext`. If `db` is `None` or the query fails,
/// `department_ids` defaults to an empty list (equivalent to "all" data scope).
///
/// # Errors
///
/// Returns `AppError::Unauthorized` if:
/// - The `Authorization` header is missing
/// - The header value does not start with `"Bearer "`
/// - The token cannot be decoded or has expired
/// - The token type is not "access"
pub async fn jwt_auth_middleware_fn(
jwt_secret: String,
db: Option<sea_orm::DatabaseConnection>,
req: Request<Body>,
next: Next,
) -> Result<Response, AppError> {
// 优先从 Authorization 头提取 token
// 回退到 URL query parameter ?token=xxxSSE/EventSource 无法设置自定义头)
let token = req
.headers()
.get("Authorization")
.and_then(|v| v.to_str().ok())
.and_then(|h| h.strip_prefix("Bearer "))
.map(String::from)
.or_else(|| {
req.uri().query().and_then(|q| {
q.split('&')
.find_map(|pair| pair.strip_prefix("token="))
.map(String::from)
})
})
.ok_or(AppError::Unauthorized)?;
let claims =
TokenService::decode_token(&token, &jwt_secret).map_err(|_| AppError::Unauthorized)?;
// Verify this is an access token, not a refresh token
if claims.token_type != "access" {
return Err(AppError::Unauthorized);
}
// 查询用户所属部门 ID 列表
let department_ids = match &db {
Some(conn) => fetch_user_department_ids(claims.sub, claims.tid, conn).await,
None => vec![],
};
// 查询每个权限的数据范围
let permission_data_scopes = match &db {
Some(conn) => fetch_permission_data_scopes(claims.sub, claims.tid, conn).await,
None => std::collections::HashMap::new(),
};
// 提取请求来源信息IP + User-Agent用于审计日志
let request_info = RequestInfo::from_headers(req.headers());
let ctx = TenantContext {
tenant_id: claims.tid,
user_id: claims.sub,
roles: claims.roles,
permissions: claims.permissions,
department_ids,
permission_data_scopes,
};
// Reconstruct the request with the TenantContext injected into extensions.
// We cannot borrow `req` mutably after reading headers, so we rebuild.
let (parts, body) = req.into_parts();
let mut req = Request::from_parts(parts, body);
req.extensions_mut().insert(ctx);
// 在 task_local scope 中运行后续处理,审计服务可自动读取请求信息
Ok(REQUEST_INFO.scope(request_info, next.run(req)).await)
}
/// 查询用户所属的所有部门 ID通过 user_departments 关联表)
async fn fetch_user_department_ids(
user_id: uuid::Uuid,
tenant_id: uuid::Uuid,
db: &sea_orm::DatabaseConnection,
) -> Vec<uuid::Uuid> {
use crate::entity::user_department;
use sea_orm::{ColumnTrait, EntityTrait, QueryFilter};
user_department::Entity::find()
.filter(user_department::Column::UserId.eq(user_id))
.filter(user_department::Column::TenantId.eq(tenant_id))
.filter(user_department::Column::DeletedAt.is_null())
.all(db)
.await
.map(|rows| rows.into_iter().map(|r| r.department_id).collect())
.unwrap_or_else(|e| {
tracing::warn!(error = %e, "查询用户部门列表失败,默认为空");
vec![]
})
}
/// Resolves the effective data scope for each permission code the user holds,
/// based on `user_roles` joined with `role_permissions` and `permissions`.
///
/// A user may receive the same permission through several roles with
/// different scopes. Each scope is ranked (`all` = 0, `department_tree` = 1,
/// `department` = 2, `self` = 3; unrecognised values are ranked 0) and the
/// lowest rank wins per permission code. The winning rank is mapped back to
/// its scope name inside the query.
///
/// The scope name must be derived from the rank rather than from
/// `MIN(rp.data_scope)`: a textual minimum sorts `'department'` before
/// `'department_tree'`, which would pick the wrong scope whenever a user holds
/// both.
///
/// Returns a map from permission code to `DataScope`. Rows whose columns
/// cannot be decoded as strings are skipped. This lookup is best-effort: if
/// the query fails, the error is logged at `warn` level and an empty map is
/// returned.
async fn fetch_permission_data_scopes(
    user_id: uuid::Uuid,
    tenant_id: uuid::Uuid,
    db: &sea_orm::DatabaseConnection,
) -> std::collections::HashMap<String, DataScope> {
    use sea_orm::ConnectionTrait;

    // Column order (code, scope_rank, data_scope) is relied upon by the
    // index-based decoding below.
    let sql = r#"
        SELECT ranked.code,
               MIN(ranked.scope_rank) AS scope_rank,
               CASE MIN(ranked.scope_rank)
                   WHEN 0 THEN 'all'
                   WHEN 1 THEN 'department_tree'
                   WHEN 2 THEN 'department'
                   ELSE 'self'
               END AS data_scope
        FROM (
            SELECT p.code AS code,
                   CASE rp.data_scope
                       WHEN 'all' THEN 0
                       WHEN 'department_tree' THEN 1
                       WHEN 'department' THEN 2
                       WHEN 'self' THEN 3
                       ELSE 0
                   END AS scope_rank
            FROM user_roles ur
            JOIN role_permissions rp ON ur.role_id = rp.role_id AND ur.tenant_id = rp.tenant_id
            JOIN permissions p ON rp.permission_id = p.id
            WHERE ur.user_id = $1
              AND ur.tenant_id = $2
              AND ur.deleted_at IS NULL
              AND rp.deleted_at IS NULL
        ) AS ranked
        GROUP BY ranked.code
    "#;

    let stmt = sea_orm::Statement::from_sql_and_values(
        sea_orm::DatabaseBackend::Postgres,
        sql,
        [user_id.into(), tenant_id.into()],
    );

    match db.query_all(stmt).await {
        Ok(rows) => rows
            .into_iter()
            .filter_map(|row| {
                // Index 0 is the permission code, index 2 the scope name.
                let code = row.try_get_by_index::<String>(0).ok()?;
                let scope = row.try_get_by_index::<String>(2).ok()?;
                Some((code, DataScope::from_str(&scope)))
            })
            .collect(),
        Err(e) => {
            // Deliberately swallow the failure; callers treat a missing entry
            // as the default scope.
            tracing::warn!(error = %e, "查询权限数据范围失败,默认全部 All");
            std::collections::HashMap::new()
        }
    }
}