use axum::body::Body; use axum::http::Request; use axum::middleware::Next; use axum::response::Response; use erp_core::error::AppError; use erp_core::request_info::REQUEST_INFO; use erp_core::request_info::RequestInfo; use erp_core::types::{DataScope, TenantContext}; use crate::service::token_service::TokenService; /// JWT authentication middleware function. /// /// Extracts the `Bearer` token from the `Authorization` header, validates it /// using `TokenService::decode_token`, and injects a `TenantContext` into the /// request extensions so downstream handlers can access tenant/user identity. /// /// 同时提取请求的 IP 地址和 User-Agent,通过 task_local 传递给审计服务, /// 使所有审计日志自动记录来源信息。 /// /// The `jwt_secret` parameter is passed explicitly by the server crate at /// middleware construction time, avoiding any circular dependency between /// erp-auth and erp-server. /// /// When `db` is provided, the middleware queries `user_departments` to populate /// `department_ids` in the `TenantContext`. If `db` is `None` or the query fails, /// `department_ids` defaults to an empty list (equivalent to "all" data scope). /// /// # Errors /// /// Returns `AppError::Unauthorized` if: /// - The `Authorization` header is missing /// - The header value does not start with `"Bearer "` /// - The token cannot be decoded or has expired /// - The token type is not "access" pub async fn jwt_auth_middleware_fn( jwt_secret: String, db: Option, req: Request, next: Next, ) -> Result { // 优先从 Authorization 头提取 token; // 回退到 URL query parameter ?token=xxx(SSE/EventSource 无法设置自定义头) let token = req .headers() .get("Authorization") .and_then(|v| v.to_str().ok()) .and_then(|h| h.strip_prefix("Bearer ")) .map(String::from) .or_else(|| { req.uri().query().and_then(|q| { q.split('&') .find_map(|pair| pair.strip_prefix("token=")) .map(String::from) }) }) .ok_or(AppError::Unauthorized)?; let claims = TokenService::decode_token(&token, &jwt_secret).map_err(|_| AppError::Unauthorized)?; // Verify this is an access token, not a refresh token if claims.token_type != "access" { return Err(AppError::Unauthorized); } // 查询用户所属部门 ID 列表 let department_ids = match &db { Some(conn) => fetch_user_department_ids(claims.sub, claims.tid, conn).await, None => vec![], }; // 查询每个权限的数据范围 let permission_data_scopes = match &db { Some(conn) => fetch_permission_data_scopes(claims.sub, claims.tid, conn).await, None => std::collections::HashMap::new(), }; // 提取请求来源信息(IP + User-Agent),用于审计日志 let request_info = RequestInfo::from_headers(req.headers()); let ctx = TenantContext { tenant_id: claims.tid, user_id: claims.sub, roles: claims.roles, permissions: claims.permissions, department_ids, permission_data_scopes, }; // Reconstruct the request with the TenantContext injected into extensions. // We cannot borrow `req` mutably after reading headers, so we rebuild. let (parts, body) = req.into_parts(); let mut req = Request::from_parts(parts, body); req.extensions_mut().insert(ctx); // 在 task_local scope 中运行后续处理,审计服务可自动读取请求信息 Ok(REQUEST_INFO.scope(request_info, next.run(req)).await) } /// 查询用户所属的所有部门 ID(通过 user_departments 关联表) async fn fetch_user_department_ids( user_id: uuid::Uuid, tenant_id: uuid::Uuid, db: &sea_orm::DatabaseConnection, ) -> Vec { use crate::entity::user_department; use sea_orm::{ColumnTrait, EntityTrait, QueryFilter}; user_department::Entity::find() .filter(user_department::Column::UserId.eq(user_id)) .filter(user_department::Column::TenantId.eq(tenant_id)) .filter(user_department::Column::DeletedAt.is_null()) .all(db) .await .map(|rows| rows.into_iter().map(|r| r.department_id).collect()) .unwrap_or_else(|e| { tracing::warn!(error = %e, "查询用户部门列表失败,默认为空"); vec![] }) } /// 查询用户每个权限的数据范围(从 role_permissions 表) async fn fetch_permission_data_scopes( user_id: uuid::Uuid, tenant_id: uuid::Uuid, db: &sea_orm::DatabaseConnection, ) -> std::collections::HashMap { use sea_orm::ConnectionTrait; let sql = r#" SELECT p.code, MIN( CASE rp.data_scope WHEN 'all' THEN 0 WHEN 'department_tree' THEN 1 WHEN 'department' THEN 2 WHEN 'self' THEN 3 ELSE 0 END ) AS scope_rank, MIN(rp.data_scope) AS data_scope FROM user_roles ur JOIN role_permissions rp ON ur.role_id = rp.role_id AND ur.tenant_id = rp.tenant_id JOIN permissions p ON rp.permission_id = p.id WHERE ur.user_id = $1 AND ur.tenant_id = $2 AND ur.deleted_at IS NULL AND rp.deleted_at IS NULL GROUP BY p.code "#; let stmt = sea_orm::Statement::from_sql_and_values( sea_orm::DatabaseBackend::Postgres, sql, [user_id.into(), tenant_id.into()], ); match db.query_all(stmt).await { Ok(rows) => { let mut scopes = std::collections::HashMap::new(); for row in rows { if let (Ok(code), Ok(scope)) = ( row.try_get_by_index::(0), row.try_get_by_index::(2), ) { scopes.insert(code, DataScope::from_str(&scope)); } } scopes } Err(e) => { tracing::warn!(error = %e, "查询权限数据范围失败,默认全部 All"); std::collections::HashMap::new() } } }