- Base platform from base.git (ERP base: auth, core, config, message, workflow, plugin) - Created erp-diary module skeleton (lib.rs, dto.rs, error.rs, event.rs, state.rs) - Integrated erp-diary into workspace and erp-server - Added DiaryModule registration in main.rs - Added DiaryState FromRef in state.rs - Diary routes mounted (empty routes, ready for implementation) - Product design spec v1.2 preserved in docs/ - Implementation plan preserved in plans/ Cargo check: OK Cargo test: OK (78+ base tests passing)
274 lines
9.4 KiB
Rust
274 lines
9.4 KiB
Rust
use axum::body::Body;
|
||
use axum::http::Request;
|
||
use axum::middleware::Next;
|
||
use axum::response::Response;
|
||
use dashmap::DashMap;
|
||
use erp_core::error::AppError;
|
||
use erp_core::request_info::REQUEST_INFO;
|
||
use erp_core::request_info::RequestInfo;
|
||
use erp_core::types::{DataScope, TenantContext};
|
||
|
||
use crate::service::token_service::TokenService;
|
||
|
||
type DeptIds = Vec<uuid::Uuid>;
|
||
type DataScopes = std::collections::HashMap<String, DataScope>;
|
||
type ScopeCacheEntry = (DeptIds, DataScopes, std::time::Instant);
|
||
|
||
/// 用户权限数据缓存(user_id -> (department_ids, data_scopes, cached_at))
|
||
/// DashMap 分片并发,读写无锁竞争
|
||
static USER_SCOPE_CACHE: std::sync::LazyLock<DashMap<uuid::Uuid, ScopeCacheEntry>> =
|
||
std::sync::LazyLock::new(DashMap::new);
|
||
|
||
/// Access Token 吊销黑名单(token_hash -> 过期时间戳)
|
||
/// key = SHA-256(token) 前 16 字符,value = token 的 exp 时间戳
|
||
/// 惰性清理:检查时自动移除过期条目
|
||
static TOKEN_BLACKLIST: std::sync::LazyLock<DashMap<String, i64>> =
|
||
std::sync::LazyLock::new(DashMap::new);
|
||
|
||
/// How long an entry in `USER_SCOPE_CACHE` is considered fresh (60 seconds).
const SCOPE_CACHE_TTL: std::time::Duration = std::time::Duration::from_secs(60);
|
||
|
||
/// 吊销单个 access token(直到其自然过期)
|
||
pub fn revoke_access_token(token: &str, exp: i64) {
|
||
let hash = token_hash(token);
|
||
TOKEN_BLACKLIST.insert(hash, exp);
|
||
}
|
||
|
||
/// 吊销用户所有 token(清除权限缓存,强制下次请求重新认证)
|
||
pub fn revoke_all_user_tokens(user_id: uuid::Uuid) {
|
||
USER_SCOPE_CACHE.remove(&user_id);
|
||
}
|
||
|
||
/// 检查 token 是否已被吊销
|
||
fn is_token_revoked(token: &str, _exp: i64) -> bool {
|
||
let now = chrono::Utc::now().timestamp();
|
||
// 惰性清理过期条目
|
||
if TOKEN_BLACKLIST.len() > 10_000 {
|
||
TOKEN_BLACKLIST.retain(|_, exp_ts| *exp_ts > now);
|
||
}
|
||
let hash = token_hash(token);
|
||
match TOKEN_BLACKLIST.get(&hash) {
|
||
Some(exp_ts) => {
|
||
if *exp_ts <= now {
|
||
drop(exp_ts);
|
||
TOKEN_BLACKLIST.remove(&hash);
|
||
false
|
||
} else {
|
||
true
|
||
}
|
||
}
|
||
None => false,
|
||
}
|
||
}
|
||
|
||
/// Derives the blacklist key for a token.
///
/// Hashes the token with the standard library's `DefaultHasher` and renders
/// the 64-bit result as 16 zero-padded lowercase hex characters. This is a
/// compact in-process fingerprint, not a cryptographic digest.
fn token_hash(token: &str) -> String {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{BuildHasher, BuildHasherDefault};

    let digest = BuildHasherDefault::<DefaultHasher>::default().hash_one(token);
    format!("{digest:016x}")
}
|
||
|
||
/// JWT authentication middleware function.
|
||
///
|
||
/// Extracts the `Bearer` token from the `Authorization` header, validates it
|
||
/// using `TokenService::decode_token`, and injects a `TenantContext` into the
|
||
/// request extensions so downstream handlers can access tenant/user identity.
|
||
///
|
||
/// 同时提取请求的 IP 地址和 User-Agent,通过 task_local 传递给审计服务,
|
||
/// 使所有审计日志自动记录来源信息。
|
||
///
|
||
/// The `jwt_secret` parameter is passed explicitly by the server crate at
|
||
/// middleware construction time, avoiding any circular dependency between
|
||
/// erp-auth and erp-server.
|
||
///
|
||
/// When `db` is provided, the middleware queries `user_departments` to populate
|
||
/// `department_ids` in the `TenantContext`. If `db` is `None` or the query fails,
|
||
/// `department_ids` defaults to an empty list (equivalent to "all" data scope).
|
||
///
|
||
/// # Errors
|
||
///
|
||
/// Returns `AppError::Unauthorized` if:
|
||
/// - The `Authorization` header is missing
|
||
/// - The header value does not start with `"Bearer "`
|
||
/// - The token cannot be decoded or has expired
|
||
/// - The token type is not "access"
|
||
pub async fn jwt_auth_middleware_fn(
|
||
jwt_secret: String,
|
||
db: Option<sea_orm::DatabaseConnection>,
|
||
req: Request<Body>,
|
||
next: Next,
|
||
) -> Result<Response, AppError> {
|
||
// 优先从 Authorization 头提取 token;
|
||
// 回退到 URL query parameter ?token=xxx(SSE/EventSource 无法设置自定义头)
|
||
let token = req
|
||
.headers()
|
||
.get("Authorization")
|
||
.and_then(|v| v.to_str().ok())
|
||
.and_then(|h| h.strip_prefix("Bearer "))
|
||
.map(String::from)
|
||
.or_else(|| {
|
||
req.uri().query().and_then(|q| {
|
||
q.split('&')
|
||
.find_map(|pair| pair.strip_prefix("token="))
|
||
.map(String::from)
|
||
})
|
||
})
|
||
.ok_or(AppError::Unauthorized)?;
|
||
|
||
let claims =
|
||
TokenService::decode_token(&token, &jwt_secret).map_err(|_| AppError::Unauthorized)?;
|
||
|
||
// 检查 token 是否已被吊销(密码修改/管理员强制下线)
|
||
if is_token_revoked(&token, claims.exp) {
|
||
return Err(AppError::Unauthorized);
|
||
}
|
||
|
||
// Verify this is an access token, not a refresh token
|
||
if claims.token_type != "access" {
|
||
return Err(AppError::Unauthorized);
|
||
}
|
||
|
||
// 查询用户所属部门 ID 列表 + 权限数据范围(带 60 秒缓存)
|
||
let cached = USER_SCOPE_CACHE.get(&claims.sub).and_then(|entry| {
|
||
let (_, _, at) = entry.value();
|
||
if at.elapsed() < SCOPE_CACHE_TTL {
|
||
let (depts, scopes, _) = entry.value();
|
||
Some((depts.clone(), scopes.clone()))
|
||
} else {
|
||
drop(entry);
|
||
USER_SCOPE_CACHE.remove(&claims.sub);
|
||
None
|
||
}
|
||
});
|
||
let (department_ids, permission_data_scopes) = match cached {
|
||
Some(hit) => hit,
|
||
None => fetch_and_cache_scopes(claims.sub, claims.tid, &db).await,
|
||
};
|
||
|
||
// 提取请求来源信息(IP + User-Agent),用于审计日志
|
||
let request_info = RequestInfo::from_headers(req.headers());
|
||
|
||
let ctx = TenantContext {
|
||
tenant_id: claims.tid,
|
||
user_id: claims.sub,
|
||
roles: claims.roles,
|
||
permissions: claims.permissions,
|
||
department_ids,
|
||
permission_data_scopes,
|
||
};
|
||
|
||
// Reconstruct the request with the TenantContext injected into extensions.
|
||
// We cannot borrow `req` mutably after reading headers, so we rebuild.
|
||
let (parts, body) = req.into_parts();
|
||
let mut req = Request::from_parts(parts, body);
|
||
req.extensions_mut().insert(ctx);
|
||
|
||
// 在 task_local scope 中运行后续处理,审计服务可自动读取请求信息
|
||
Ok(REQUEST_INFO.scope(request_info, next.run(req)).await)
|
||
}
|
||
|
||
/// 查询用户所属的所有部门 ID(通过 user_departments 关联表)
|
||
async fn fetch_user_department_ids(
|
||
user_id: uuid::Uuid,
|
||
tenant_id: uuid::Uuid,
|
||
db: &sea_orm::DatabaseConnection,
|
||
) -> Vec<uuid::Uuid> {
|
||
use crate::entity::user_department;
|
||
use sea_orm::{ColumnTrait, EntityTrait, QueryFilter};
|
||
|
||
user_department::Entity::find()
|
||
.filter(user_department::Column::UserId.eq(user_id))
|
||
.filter(user_department::Column::TenantId.eq(tenant_id))
|
||
.filter(user_department::Column::DeletedAt.is_null())
|
||
.all(db)
|
||
.await
|
||
.map(|rows| rows.into_iter().map(|r| r.department_id).collect())
|
||
.unwrap_or_else(|e| {
|
||
tracing::warn!(error = %e, "查询用户部门列表失败,默认为空");
|
||
vec![]
|
||
})
|
||
}
|
||
|
||
/// 查询用户每个权限的数据范围(从 role_permissions 表)
|
||
async fn fetch_permission_data_scopes(
|
||
user_id: uuid::Uuid,
|
||
tenant_id: uuid::Uuid,
|
||
db: &sea_orm::DatabaseConnection,
|
||
) -> std::collections::HashMap<String, DataScope> {
|
||
use sea_orm::ConnectionTrait;
|
||
|
||
let sql = r#"
|
||
SELECT p.code, MIN(
|
||
CASE rp.data_scope
|
||
WHEN 'all' THEN 0
|
||
WHEN 'department_tree' THEN 1
|
||
WHEN 'department' THEN 2
|
||
WHEN 'self' THEN 3
|
||
ELSE 0
|
||
END
|
||
) AS scope_rank,
|
||
MIN(rp.data_scope) AS data_scope
|
||
FROM user_roles ur
|
||
JOIN role_permissions rp ON ur.role_id = rp.role_id AND ur.tenant_id = rp.tenant_id
|
||
JOIN permissions p ON rp.permission_id = p.id
|
||
WHERE ur.user_id = $1
|
||
AND ur.tenant_id = $2
|
||
AND ur.deleted_at IS NULL
|
||
AND rp.deleted_at IS NULL
|
||
GROUP BY p.code
|
||
"#;
|
||
|
||
let stmt = sea_orm::Statement::from_sql_and_values(
|
||
sea_orm::DatabaseBackend::Postgres,
|
||
sql,
|
||
[user_id.into(), tenant_id.into()],
|
||
);
|
||
|
||
match db.query_all(stmt).await {
|
||
Ok(rows) => {
|
||
let mut scopes = std::collections::HashMap::new();
|
||
for row in rows {
|
||
if let (Ok(code), Ok(scope)) = (
|
||
row.try_get_by_index::<String>(0),
|
||
row.try_get_by_index::<String>(2),
|
||
) {
|
||
scopes.insert(code, DataScope::parse_scope(&scope));
|
||
}
|
||
}
|
||
scopes
|
||
}
|
||
Err(e) => {
|
||
tracing::warn!(error = %e, "查询权限数据范围失败,默认全部 All");
|
||
std::collections::HashMap::new()
|
||
}
|
||
}
|
||
}
|
||
|
||
/// 从 DB 查询部门 + 权限范围,并写入缓存
|
||
async fn fetch_and_cache_scopes(
|
||
user_id: uuid::Uuid,
|
||
tenant_id: uuid::Uuid,
|
||
db: &Option<sea_orm::DatabaseConnection>,
|
||
) -> (
|
||
Vec<uuid::Uuid>,
|
||
std::collections::HashMap<String, DataScope>,
|
||
) {
|
||
let depts = match db {
|
||
Some(conn) => fetch_user_department_ids(user_id, tenant_id, conn).await,
|
||
None => vec![],
|
||
};
|
||
let scopes = match db {
|
||
Some(conn) => fetch_permission_data_scopes(user_id, tenant_id, conn).await,
|
||
None => std::collections::HashMap::new(),
|
||
};
|
||
USER_SCOPE_CACHE.insert(
|
||
user_id,
|
||
(depts.clone(), scopes.clone(), std::time::Instant::now()),
|
||
);
|
||
// 惰性淘汰过期条目,防止 DashMap 无限增长
|
||
if USER_SCOPE_CACHE.len() > 500 {
|
||
let now = std::time::Instant::now();
|
||
USER_SCOPE_CACHE.retain(|_, (_, _, at)| now.duration_since(*at) < SCOPE_CACHE_TTL);
|
||
}
|
||
(depts, scopes)
|
||
}
|