SSE/EventSource 无法设置自定义 Authorization 头,前端通过 ?token=xxx 传参。中间件现在优先读 Authorization 头,回退到 URL query parameter,修复 SSE 连接永远 401 的问题。
177 lines
6.1 KiB
Rust
177 lines
6.1 KiB
Rust
use axum::body::Body;
|
||
use axum::http::Request;
|
||
use axum::middleware::Next;
|
||
use axum::response::Response;
|
||
use erp_core::error::AppError;
|
||
use erp_core::request_info::REQUEST_INFO;
|
||
use erp_core::request_info::RequestInfo;
|
||
use erp_core::types::{DataScope, TenantContext};
|
||
|
||
use crate::service::token_service::TokenService;
|
||
|
||
/// JWT authentication middleware function.
|
||
///
|
||
/// Extracts the `Bearer` token from the `Authorization` header, validates it
|
||
/// using `TokenService::decode_token`, and injects a `TenantContext` into the
|
||
/// request extensions so downstream handlers can access tenant/user identity.
|
||
///
|
||
/// 同时提取请求的 IP 地址和 User-Agent,通过 task_local 传递给审计服务,
|
||
/// 使所有审计日志自动记录来源信息。
|
||
///
|
||
/// The `jwt_secret` parameter is passed explicitly by the server crate at
|
||
/// middleware construction time, avoiding any circular dependency between
|
||
/// erp-auth and erp-server.
|
||
///
|
||
/// When `db` is provided, the middleware queries `user_departments` to populate
|
||
/// `department_ids` in the `TenantContext`. If `db` is `None` or the query fails,
|
||
/// `department_ids` defaults to an empty list (equivalent to "all" data scope).
|
||
///
|
||
/// # Errors
|
||
///
|
||
/// Returns `AppError::Unauthorized` if:
|
||
/// - The `Authorization` header is missing
|
||
/// - The header value does not start with `"Bearer "`
|
||
/// - The token cannot be decoded or has expired
|
||
/// - The token type is not "access"
|
||
pub async fn jwt_auth_middleware_fn(
|
||
jwt_secret: String,
|
||
db: Option<sea_orm::DatabaseConnection>,
|
||
req: Request<Body>,
|
||
next: Next,
|
||
) -> Result<Response, AppError> {
|
||
// 优先从 Authorization 头提取 token;
|
||
// 回退到 URL query parameter ?token=xxx(SSE/EventSource 无法设置自定义头)
|
||
let token = req
|
||
.headers()
|
||
.get("Authorization")
|
||
.and_then(|v| v.to_str().ok())
|
||
.and_then(|h| h.strip_prefix("Bearer "))
|
||
.map(String::from)
|
||
.or_else(|| {
|
||
req.uri().query().and_then(|q| {
|
||
q.split('&')
|
||
.find_map(|pair| pair.strip_prefix("token="))
|
||
.map(String::from)
|
||
})
|
||
})
|
||
.ok_or(AppError::Unauthorized)?;
|
||
|
||
let claims =
|
||
TokenService::decode_token(&token, &jwt_secret).map_err(|_| AppError::Unauthorized)?;
|
||
|
||
// Verify this is an access token, not a refresh token
|
||
if claims.token_type != "access" {
|
||
return Err(AppError::Unauthorized);
|
||
}
|
||
|
||
// 查询用户所属部门 ID 列表
|
||
let department_ids = match &db {
|
||
Some(conn) => fetch_user_department_ids(claims.sub, claims.tid, conn).await,
|
||
None => vec![],
|
||
};
|
||
|
||
// 查询每个权限的数据范围
|
||
let permission_data_scopes = match &db {
|
||
Some(conn) => fetch_permission_data_scopes(claims.sub, claims.tid, conn).await,
|
||
None => std::collections::HashMap::new(),
|
||
};
|
||
|
||
// 提取请求来源信息(IP + User-Agent),用于审计日志
|
||
let request_info = RequestInfo::from_headers(req.headers());
|
||
|
||
let ctx = TenantContext {
|
||
tenant_id: claims.tid,
|
||
user_id: claims.sub,
|
||
roles: claims.roles,
|
||
permissions: claims.permissions,
|
||
department_ids,
|
||
permission_data_scopes,
|
||
};
|
||
|
||
// Reconstruct the request with the TenantContext injected into extensions.
|
||
// We cannot borrow `req` mutably after reading headers, so we rebuild.
|
||
let (parts, body) = req.into_parts();
|
||
let mut req = Request::from_parts(parts, body);
|
||
req.extensions_mut().insert(ctx);
|
||
|
||
// 在 task_local scope 中运行后续处理,审计服务可自动读取请求信息
|
||
Ok(REQUEST_INFO.scope(request_info, next.run(req)).await)
|
||
}
|
||
|
||
/// 查询用户所属的所有部门 ID(通过 user_departments 关联表)
|
||
async fn fetch_user_department_ids(
|
||
user_id: uuid::Uuid,
|
||
tenant_id: uuid::Uuid,
|
||
db: &sea_orm::DatabaseConnection,
|
||
) -> Vec<uuid::Uuid> {
|
||
use crate::entity::user_department;
|
||
use sea_orm::{ColumnTrait, EntityTrait, QueryFilter};
|
||
|
||
user_department::Entity::find()
|
||
.filter(user_department::Column::UserId.eq(user_id))
|
||
.filter(user_department::Column::TenantId.eq(tenant_id))
|
||
.filter(user_department::Column::DeletedAt.is_null())
|
||
.all(db)
|
||
.await
|
||
.map(|rows| rows.into_iter().map(|r| r.department_id).collect())
|
||
.unwrap_or_else(|e| {
|
||
tracing::warn!(error = %e, "查询用户部门列表失败,默认为空");
|
||
vec![]
|
||
})
|
||
}
|
||
|
||
/// 查询用户每个权限的数据范围(从 role_permissions 表)
|
||
async fn fetch_permission_data_scopes(
|
||
user_id: uuid::Uuid,
|
||
tenant_id: uuid::Uuid,
|
||
db: &sea_orm::DatabaseConnection,
|
||
) -> std::collections::HashMap<String, DataScope> {
|
||
use sea_orm::ConnectionTrait;
|
||
|
||
let sql = r#"
|
||
SELECT p.code, MIN(
|
||
CASE rp.data_scope
|
||
WHEN 'all' THEN 0
|
||
WHEN 'department_tree' THEN 1
|
||
WHEN 'department' THEN 2
|
||
WHEN 'self' THEN 3
|
||
ELSE 0
|
||
END
|
||
) AS scope_rank,
|
||
MIN(rp.data_scope) AS data_scope
|
||
FROM user_roles ur
|
||
JOIN role_permissions rp ON ur.role_id = rp.role_id AND ur.tenant_id = rp.tenant_id
|
||
JOIN permissions p ON rp.permission_id = p.id
|
||
WHERE ur.user_id = $1
|
||
AND ur.tenant_id = $2
|
||
AND ur.deleted_at IS NULL
|
||
AND rp.deleted_at IS NULL
|
||
GROUP BY p.code
|
||
"#;
|
||
|
||
let stmt = sea_orm::Statement::from_sql_and_values(
|
||
sea_orm::DatabaseBackend::Postgres,
|
||
sql,
|
||
[user_id.into(), tenant_id.into()],
|
||
);
|
||
|
||
match db.query_all(stmt).await {
|
||
Ok(rows) => {
|
||
let mut scopes = std::collections::HashMap::new();
|
||
for row in rows {
|
||
if let (Ok(code), Ok(scope)) = (
|
||
row.try_get_by_index::<String>(0),
|
||
row.try_get_by_index::<String>(2),
|
||
) {
|
||
scopes.insert(code, DataScope::from_str(&scope));
|
||
}
|
||
}
|
||
scopes
|
||
}
|
||
Err(e) => {
|
||
tracing::warn!(error = %e, "查询权限数据范围失败,默认全部 All");
|
||
std::collections::HashMap::new()
|
||
}
|
||
}
|
||
}
|