Files
csm/crates/server/src/api/auth.rs
iven 60ee38a3c2 feat: 新增补丁管理和异常检测插件及相关功能
feat(protocol): 添加补丁管理和行为指标协议类型
feat(client): 实现补丁管理插件采集功能
feat(server): 添加补丁管理和异常检测API
feat(database): 新增补丁状态和异常检测相关表
feat(web): 添加补丁管理和异常检测前端页面
fix(security): 增强输入验证和防注入保护
refactor(auth): 重构认证检查逻辑
perf(service): 优化Windows服务恢复策略
style: 统一健康评分显示样式
docs: 更新知识库文档
2026-04-11 15:59:53 +08:00

596 lines
20 KiB
Rust

use axum::{extract::State, Json, http::StatusCode, extract::Request, middleware::Next, response::{Response, IntoResponse}};
use axum::http::header::{SET_COOKIE, HeaderValue};
use serde::{Deserialize, Serialize};
use jsonwebtoken::{encode, decode, Header, EncodingKey, DecodingKey, Validation};
use std::sync::Arc;
use std::collections::HashMap;
use std::time::Instant;
use tokio::sync::Mutex;
use crate::AppState;
use super::ApiResponse;
/// JWT payload used for both access and refresh tokens issued by this module.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Claims {
/// Id of the user the token was issued for.
pub sub: i64,
/// Username copied from the user record when the token was created.
pub username: String,
/// Role copied from the user record; the admin middleware compares it to `"admin"`.
pub role: String,
/// Expiry as seconds since the Unix epoch (issue time plus the configured TTL).
pub exp: u64,
/// Issue time as seconds since the Unix epoch.
pub iat: u64,
/// Either `"access"` or `"refresh"`; handlers reject tokens of the wrong kind.
pub token_type: String,
/// Random family ID for refresh token rotation detection.
/// A fresh UUID is generated at login and on every successful refresh.
pub family: String,
}
/// JSON body accepted by the login endpoint.
#[derive(Debug, Deserialize)]
pub struct LoginRequest {
/// Username to look up in the `users` table.
pub username: String,
/// Plain-text password, verified against the stored bcrypt hash.
pub password: String,
}
/// Payload returned by the login and refresh endpoints.
/// The tokens themselves travel in `Set-Cookie` headers, not in this body.
#[derive(Debug, Serialize)]
pub struct LoginResponse {
/// The authenticated user.
pub user: UserInfo,
}
/// Payload returned by the `me` endpoint.
#[derive(Debug, Serialize)]
pub struct MeResponse {
/// User identity taken from the access token claims.
pub user: UserInfo,
/// Access token expiry formatted as RFC 3339; empty if the timestamp cannot be represented.
pub expires_at: String,
}
/// Public identity of a user (no password material).
#[derive(Debug, Serialize, sqlx::FromRow)]
pub struct UserInfo {
/// Primary key of the user.
pub id: i64,
/// Login name.
pub username: String,
/// Role string; `"admin"` is required by the admin middleware.
pub role: String,
}
/// JSON body accepted by the change-password endpoint.
#[derive(Debug, Deserialize)]
pub struct ChangePasswordRequest {
/// Current password; must verify against the stored hash.
pub old_password: String,
/// Replacement password; length-checked before it is hashed.
pub new_password: String,
}
// ---------------------------------------------------------------------------
// Cookie helpers
// ---------------------------------------------------------------------------
/// Reports whether auth cookies should carry the `Secure` attribute.
///
/// Returns `false` only when reading the `CSM_DEV` environment variable
/// succeeds; any read failure (including "not set") yields `true`.
fn is_secure_cookies() -> bool {
    match std::env::var("CSM_DEV") {
        Ok(_) => false,
        Err(_) => true,
    }
}
/// Builds the `Set-Cookie` value that stores the access token.
///
/// The cookie is `HttpOnly`, `SameSite=Strict`, scoped to `/`, lives for
/// `ttl_secs` seconds, and is marked `Secure` when `is_secure_cookies()` says so.
///
/// # Panics
/// Panics if the formatted cookie is not a valid header value.
fn access_cookie_header(token: &str, ttl_secs: u64) -> HeaderValue {
    let secure_attr = match is_secure_cookies() {
        true => "; Secure",
        false => "",
    };
    let cookie = format!(
        "access_token={}; HttpOnly{}; SameSite=Strict; Path=/; Max-Age={}",
        token, secure_attr, ttl_secs
    );
    HeaderValue::from_str(&cookie).expect("valid cookie header")
}
/// Builds the `Set-Cookie` value that stores the refresh token.
///
/// Same attributes as the access cookie, except the path is restricted to
/// `/api/auth/refresh` so the browser only sends it to the refresh endpoint.
///
/// # Panics
/// Panics if the formatted cookie is not a valid header value.
fn refresh_cookie_header(token: &str, ttl_secs: u64) -> HeaderValue {
    let secure_attr = match is_secure_cookies() {
        true => "; Secure",
        false => "",
    };
    let cookie = format!(
        "refresh_token={}; HttpOnly{}; SameSite=Strict; Path=/api/auth/refresh; Max-Age={}",
        token, secure_attr, ttl_secs
    );
    HeaderValue::from_str(&cookie).expect("valid cookie header")
}
/// Builds the pair of `Set-Cookie` values that expire both auth cookies.
///
/// Each cookie is emitted with an empty value and `Max-Age=0`, using the same
/// path and attributes it was originally set with.
///
/// # Panics
/// Panics if a formatted cookie is not a valid header value.
fn clear_cookie_headers() -> Vec<HeaderValue> {
    let secure_attr = match is_secure_cookies() {
        true => "; Secure",
        false => "",
    };
    let expired = [
        format!("access_token=; HttpOnly{}; SameSite=Strict; Path=/; Max-Age=0", secure_attr),
        format!("refresh_token=; HttpOnly{}; SameSite=Strict; Path=/api/auth/refresh; Max-Age=0", secure_attr),
    ];
    expired
        .iter()
        .map(|cookie| HeaderValue::from_str(cookie).expect("valid"))
        .collect()
}
/// Returns `response` with one `Set-Cookie` header appended per entry in `cookies`.
///
/// Headers are appended, so any `Set-Cookie` already on the response is kept.
fn with_cookies(mut response: Response, cookies: Vec<HeaderValue>) -> Response {
    let header_map = response.headers_mut();
    cookies.into_iter().for_each(|value| {
        header_map.append(SET_COOKIE, value);
    });
    response
}
/// Looks up the cookie called `name` in the request's raw `Cookie` header.
///
/// Returns the value of the first `name=value` pair found, or `None` when the
/// header is absent, is not valid visible ASCII, or has no such cookie.
fn extract_cookie_value(headers: &axum::http::HeaderMap, name: &str) -> Option<String> {
    let raw_header = headers.get("cookie")?.to_str().ok()?;
    // Build the "name=" prefix once instead of once per cookie pair.
    let wanted_prefix = format!("{}=", name);
    raw_header
        .split(';')
        .find_map(|pair| pair.trim().strip_prefix(wanted_prefix.as_str()))
        .map(str::to_owned)
}
// ---------------------------------------------------------------------------
// Rate limiter
// ---------------------------------------------------------------------------
/// In-memory login attempt counter, cheap to clone (clones share the same map).
#[derive(Clone, Default)]
pub struct LoginRateLimiter {
/// Per-key state: the instant the current window started and the attempts counted in it.
attempts: Arc<Mutex<HashMap<String, (Instant, u32)>>>,
}
impl LoginRateLimiter {
    /// Creates a limiter with no recorded attempts.
    pub fn new() -> Self {
        Self::default()
    }

    /// Records one attempt for `key` and reports whether the key is rate limited.
    ///
    /// Each key may make 10 attempts per 300-second window. The window starts at
    /// the first attempt and restarts with the first attempt after it has elapsed.
    /// Calls that return `true` do not increase the counter.
    ///
    /// When a new key pushes the map past 1000 entries, entries whose window has
    /// already elapsed are dropped.
    pub async fn is_limited(&self, key: &str) -> bool {
        let mut attempts = self.attempts.lock().await;
        let now = Instant::now();
        let window = std::time::Duration::from_secs(300);
        let max_attempts = 10u32;

        match attempts.get_mut(key) {
            Some((window_start, count)) => {
                if now.duration_since(*window_start) > window {
                    // Previous window is over: this attempt opens a new one.
                    *window_start = now;
                    *count = 1;
                    false
                } else if *count >= max_attempts {
                    true
                } else {
                    *count += 1;
                    false
                }
            }
            None => {
                attempts.insert(key.to_string(), (now, 1));
                if attempts.len() > 1000 {
                    // Compare elapsed time instead of computing `now - window`:
                    // subtracting a Duration from an Instant may panic when the
                    // result is not representable (e.g. shortly after boot on
                    // platforms whose monotonic clock starts near zero).
                    attempts.retain(|_, (started, _)| now.duration_since(*started) < window);
                }
                false
            }
        }
    }
}
// ---------------------------------------------------------------------------
// Endpoints
// ---------------------------------------------------------------------------
pub async fn login(
State(state): State<AppState>,
Json(req): Json<LoginRequest>,
) -> impl IntoResponse {
if state.login_limiter.is_limited(&req.username).await {
return (StatusCode::TOO_MANY_REQUESTS, Json(ApiResponse::<LoginResponse>::error("Too many login attempts. Try again later."))).into_response();
}
if state.login_limiter.is_limited("ip:default").await {
return (StatusCode::TOO_MANY_REQUESTS, Json(ApiResponse::<LoginResponse>::error("Too many login attempts from this location. Try again later."))).into_response();
}
let row: Option<(UserInfo, String)> = sqlx::query_as::<_, (i64, String, String, String)>(
"SELECT id, username, role, password FROM users WHERE username = ?"
)
.bind(&req.username)
.fetch_optional(&state.db)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
.ok()
.flatten()
.map(|(id, username, role, password)| {
(UserInfo { id, username, role }, password)
});
let (user, hash) = match row {
Some(r) => r,
None => {
let _ = bcrypt::verify("timing-constant-dummy", "$2b$12$aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa");
return (StatusCode::UNAUTHORIZED, Json(ApiResponse::<LoginResponse>::error("Invalid credentials"))).into_response();
}
};
if !bcrypt::verify(&req.password, &hash).unwrap_or(false) {
return (StatusCode::UNAUTHORIZED, Json(ApiResponse::<LoginResponse>::error("Invalid credentials"))).into_response();
}
let now = chrono::Utc::now().timestamp() as u64;
let family = uuid::Uuid::new_v4().to_string();
let access_token = match create_token(&user, "access", state.config.auth.access_token_ttl_secs, now, &state.config.auth.jwt_secret, &family) {
Ok(t) => t,
Err(_) => return StatusCode::INTERNAL_SERVER_ERROR.into_response(),
};
let refresh_token = match create_token(&user, "refresh", state.config.auth.refresh_token_ttl_secs, now, &state.config.auth.jwt_secret, &family) {
Ok(t) => t,
Err(_) => return StatusCode::INTERNAL_SERVER_ERROR.into_response(),
};
let refresh_expires = now + state.config.auth.refresh_token_ttl_secs;
let _ = sqlx::query(
"INSERT INTO refresh_tokens (user_id, family, expires_at) VALUES (?, ?, datetime(?, 'unixepoch'))"
)
.bind(user.id)
.bind(&family)
.bind(refresh_expires as i64)
.execute(&state.db)
.await;
let _ = sqlx::query(
"INSERT INTO admin_audit_log (user_id, action, detail) VALUES (?, 'login', ?)"
)
.bind(user.id)
.bind(format!("User {} logged in", user.username))
.execute(&state.db)
.await;
let response = (StatusCode::OK, Json(ApiResponse::ok(LoginResponse { user }))).into_response();
with_cookies(response, vec![
access_cookie_header(&access_token, state.config.auth.access_token_ttl_secs),
refresh_cookie_header(&refresh_token, state.config.auth.refresh_token_ttl_secs),
])
}
pub async fn refresh(
State(state): State<AppState>,
headers: axum::http::HeaderMap,
) -> impl IntoResponse {
let refresh_token = match extract_cookie_value(&headers, "refresh_token") {
Some(t) => t,
None => return with_cookies(
(StatusCode::UNAUTHORIZED, Json(ApiResponse::<LoginResponse>::error("Missing refresh token"))).into_response(),
clear_cookie_headers(),
),
};
let claims = match decode::<Claims>(
&refresh_token,
&DecodingKey::from_secret(state.config.auth.jwt_secret.as_bytes()),
&Validation::default(),
) {
Ok(c) => c.claims,
Err(_) => return with_cookies(
(StatusCode::UNAUTHORIZED, Json(ApiResponse::<LoginResponse>::error("Invalid refresh token"))).into_response(),
clear_cookie_headers(),
),
};
if claims.token_type != "refresh" {
return with_cookies(
(StatusCode::UNAUTHORIZED, Json(ApiResponse::<LoginResponse>::error("Invalid token type"))).into_response(),
clear_cookie_headers(),
);
}
let mut tx = match state.db.begin().await {
Ok(tx) => tx,
Err(_) => return StatusCode::INTERNAL_SERVER_ERROR.into_response(),
};
let revoked: bool = sqlx::query_scalar::<_, i64>(
"SELECT COUNT(*) FROM revoked_token_families WHERE family = ?"
)
.bind(&claims.family)
.fetch_one(&mut *tx)
.await
.unwrap_or(0) > 0;
if revoked {
tx.rollback().await.ok();
tracing::warn!("Refresh token reuse detected for user {} family {}", claims.sub, claims.family);
let _ = sqlx::query("DELETE FROM refresh_tokens WHERE user_id = ?")
.bind(claims.sub)
.execute(&state.db)
.await;
return with_cookies(
(StatusCode::UNAUTHORIZED, Json(ApiResponse::<LoginResponse>::error("Token reuse detected. Please log in again."))).into_response(),
clear_cookie_headers(),
);
}
let family_exists: bool = sqlx::query_scalar::<_, i64>(
"SELECT COUNT(*) FROM refresh_tokens WHERE family = ? AND user_id = ?"
)
.bind(&claims.family)
.bind(claims.sub)
.fetch_one(&mut *tx)
.await
.unwrap_or(0) > 0;
if !family_exists {
tx.rollback().await.ok();
return (StatusCode::UNAUTHORIZED, Json(ApiResponse::<LoginResponse>::error("Invalid refresh token"))).into_response();
}
let user = UserInfo {
id: claims.sub,
username: claims.username,
role: claims.role,
};
let new_family = uuid::Uuid::new_v4().to_string();
let now = chrono::Utc::now().timestamp() as u64;
let access_token = match create_token(&user, "access", state.config.auth.access_token_ttl_secs, now, &state.config.auth.jwt_secret, &new_family) {
Ok(t) => t,
Err(_) => return StatusCode::INTERNAL_SERVER_ERROR.into_response(),
};
let refresh_token = match create_token(&user, "refresh", state.config.auth.refresh_token_ttl_secs, now, &state.config.auth.jwt_secret, &new_family) {
Ok(t) => t,
Err(_) => return StatusCode::INTERNAL_SERVER_ERROR.into_response(),
};
if sqlx::query("INSERT OR IGNORE INTO revoked_token_families (family, user_id, revoked_at) VALUES (?, ?, datetime('now'))")
.bind(&claims.family)
.bind(claims.sub)
.execute(&mut *tx)
.await
.is_err()
{
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
}
let refresh_expires = now + state.config.auth.refresh_token_ttl_secs;
if sqlx::query(
"INSERT INTO refresh_tokens (user_id, family, expires_at) VALUES (?, ?, datetime(?, 'unixepoch'))"
)
.bind(user.id)
.bind(&new_family)
.bind(refresh_expires as i64)
.execute(&mut *tx)
.await
.is_err()
{
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
}
if tx.commit().await.is_err() {
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
}
let response = (StatusCode::OK, Json(ApiResponse::ok(LoginResponse { user }))).into_response();
with_cookies(response, vec![
access_cookie_header(&access_token, state.config.auth.access_token_ttl_secs),
refresh_cookie_header(&refresh_token, state.config.auth.refresh_token_ttl_secs),
])
}
/// Get current authenticated user info from access_token cookie.
pub async fn me(
State(state): State<AppState>,
headers: axum::http::HeaderMap,
) -> impl IntoResponse {
let token = match extract_cookie_value(&headers, "access_token") {
Some(t) => t,
None => return (StatusCode::UNAUTHORIZED, Json(ApiResponse::<MeResponse>::error("Not authenticated"))).into_response(),
};
let claims = match decode::<Claims>(
&token,
&DecodingKey::from_secret(state.config.auth.jwt_secret.as_bytes()),
&Validation::default(),
) {
Ok(c) => c.claims,
Err(_) => return (StatusCode::UNAUTHORIZED, Json(ApiResponse::<MeResponse>::error("Invalid token"))).into_response(),
};
if claims.token_type != "access" {
return (StatusCode::UNAUTHORIZED, Json(ApiResponse::<MeResponse>::error("Invalid token type"))).into_response();
}
let expires_at = chrono::DateTime::from_timestamp(claims.exp as i64, 0)
.map(|t| t.to_rfc3339())
.unwrap_or_default();
(StatusCode::OK, Json(ApiResponse::ok(MeResponse {
user: UserInfo {
id: claims.sub,
username: claims.username,
role: claims.role,
},
expires_at,
}))).into_response()
}
/// Logs the caller out.
///
/// Always responds `200` and expires both auth cookies. When the request
/// carries an `access_token` cookie that still validates, every `refresh_tokens`
/// row of that user is deleted and a `logout` entry is written to the audit
/// log; both database writes are best-effort and their errors are ignored.
///
/// NOTE(review): an expired or missing access token skips the server-side
/// cleanup entirely — confirm that is acceptable.
pub async fn logout(
    State(state): State<AppState>,
    headers: axum::http::HeaderMap,
) -> impl IntoResponse {
    let verified = extract_cookie_value(&headers, "access_token").and_then(|token| {
        decode::<Claims>(
            &token,
            &DecodingKey::from_secret(state.config.auth.jwt_secret.as_bytes()),
            &Validation::default(),
        )
        .ok()
    });

    if let Some(token_data) = verified {
        let session = token_data.claims;

        let _ = sqlx::query("DELETE FROM refresh_tokens WHERE user_id = ?")
            .bind(session.sub)
            .execute(&state.db)
            .await;

        let _ = sqlx::query(
            "INSERT INTO admin_audit_log (user_id, action, detail) VALUES (?, 'logout', ?)"
        )
        .bind(session.sub)
        .bind(format!("User {} logged out", session.username))
        .execute(&state.db)
        .await;
    }

    with_cookies(
        (StatusCode::OK, Json(ApiResponse::ok(()))).into_response(),
        clear_cookie_headers(),
    )
}
// ---------------------------------------------------------------------------
// WebSocket ticket
// ---------------------------------------------------------------------------
/// Payload returned when a WebSocket ticket is issued.
#[derive(Debug, Serialize)]
pub struct WsTicketResponse {
/// Randomly generated ticket id (UUID v4 string).
pub ticket: String,
/// Ticket lifetime in seconds as reported to the client.
pub expires_in: u64,
}
/// Create a one-time ticket for WebSocket authentication.
/// Requires a valid access_token cookie (set by login).
pub async fn create_ws_ticket(
State(state): State<AppState>,
headers: axum::http::HeaderMap,
) -> impl IntoResponse {
let token = match extract_cookie_value(&headers, "access_token") {
Some(t) => t,
None => return (StatusCode::UNAUTHORIZED, Json(ApiResponse::<WsTicketResponse>::error("Not authenticated"))).into_response(),
};
let claims = match decode::<Claims>(
&token,
&DecodingKey::from_secret(state.config.auth.jwt_secret.as_bytes()),
&Validation::default(),
) {
Ok(c) => c.claims,
Err(_) => return (StatusCode::UNAUTHORIZED, Json(ApiResponse::<WsTicketResponse>::error("Invalid token"))).into_response(),
};
if claims.token_type != "access" {
return (StatusCode::UNAUTHORIZED, Json(ApiResponse::<WsTicketResponse>::error("Invalid token type"))).into_response();
}
let ticket = uuid::Uuid::new_v4().to_string();
let claim = crate::ws::TicketClaim {
user_id: claims.sub,
username: claims.username,
role: claims.role,
created_at: std::time::Instant::now(),
};
{
let mut tickets = state.ws_tickets.lock().await;
tickets.insert(ticket.clone(), claim);
// Cleanup expired tickets (>30s old) on every creation
tickets.retain(|_, c| c.created_at.elapsed().as_secs() < 30);
}
(StatusCode::OK, Json(ApiResponse::ok(WsTicketResponse {
ticket,
expires_in: 30,
}))).into_response()
}
// ---------------------------------------------------------------------------
// Internal helpers
// ---------------------------------------------------------------------------
/// Signs a JWT for `user`.
///
/// * `token_type` – stored verbatim in the claims (`"access"` or `"refresh"` at the call sites).
/// * `ttl` / `now` – seconds; the token expires at `now + ttl` and is stamped as issued at `now`.
/// * `secret` – key material used to sign with the default JWT header.
/// * `family` – rotation family id embedded in the claims.
///
/// Returns the encoded token, or `500 Internal Server Error` if encoding fails.
fn create_token(user: &UserInfo, token_type: &str, ttl: u64, now: u64, secret: &str, family: &str) -> Result<String, StatusCode> {
    let signing_key = EncodingKey::from_secret(secret.as_bytes());
    let payload = Claims {
        sub: user.id,
        username: user.username.clone(),
        role: user.role.clone(),
        exp: now + ttl,
        iat: now,
        token_type: token_type.to_owned(),
        family: family.to_owned(),
    };

    match encode(&Header::default(), &payload, &signing_key) {
        Ok(jwt) => Ok(jwt),
        Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR),
    }
}
/// Axum middleware that only lets authenticated requests through.
///
/// Reads the `access_token` cookie, validates it as a JWT of type `"access"`,
/// and stores the decoded [`Claims`] in the request extensions for downstream
/// handlers. Any failure short-circuits with `401 Unauthorized`.
pub async fn require_auth(
    State(state): State<AppState>,
    mut request: Request,
    next: Next,
) -> Result<Response, StatusCode> {
    let Some(token) = extract_cookie_value(request.headers(), "access_token") else {
        return Err(StatusCode::UNAUTHORIZED);
    };

    let key = DecodingKey::from_secret(state.config.auth.jwt_secret.as_bytes());
    let session = match decode::<Claims>(&token, &key, &Validation::default()) {
        Ok(decoded) => decoded.claims,
        Err(_) => return Err(StatusCode::UNAUTHORIZED),
    };

    if session.token_type != "access" {
        return Err(StatusCode::UNAUTHORIZED);
    }

    request.extensions_mut().insert(session);
    Ok(next.run(request).await)
}
/// Axum middleware that restricts a route to admins and audits successful calls.
///
/// Expects [`Claims`] in the request extensions: responds `401` when they are
/// missing and `403` when the role is not `"admin"`. After the inner handler
/// runs, a successful (2xx) response is recorded in `admin_audit_log` as
/// `"<METHOD> <path>"`; a failure to write the audit row is ignored.
pub async fn require_admin(
    State(state): State<AppState>,
    request: Request,
    next: Next,
) -> Result<Response, StatusCode> {
    let (user_id, username) = match request.extensions().get::<Claims>() {
        None => return Err(StatusCode::UNAUTHORIZED),
        Some(session) if session.role != "admin" => return Err(StatusCode::FORBIDDEN),
        Some(session) => (session.sub, session.username.clone()),
    };

    // Capture what the audit entry needs before the request is consumed.
    let method = request.method().clone();
    let path = request.uri().path().to_string();

    let response = next.run(request).await;
    if !response.status().is_success() {
        return Ok(response);
    }

    let action = format!("{} {}", method, path);
    let detail = format!("by {}", username);
    let _ = sqlx::query(
        "INSERT INTO admin_audit_log (user_id, action, detail) VALUES (?, ?, ?)"
    )
    .bind(user_id)
    .bind(&action)
    .bind(&detail)
    .execute(&state.db)
    .await;

    Ok(response)
}
/// Changes the password of the authenticated user.
///
/// The new password must be at least 6 characters long and at most 72 bytes
/// when UTF-8 encoded, and `old_password` must match the stored bcrypt hash.
/// On success, the new hash (cost 12) is stored and a `change_password` entry
/// is written to the audit log on a best-effort basis.
///
/// Returns `200` on success and `400` with an explanatory message for a
/// too-short or too-long new password or a wrong current password.
///
/// # Errors
/// Returns `500 Internal Server Error` when the user row cannot be read,
/// hashing fails, or the update fails.
///
/// NOTE(review): existing refresh tokens are left untouched after a password
/// change — confirm whether other sessions should be invalidated.
pub async fn change_password(
    State(state): State<AppState>,
    claims: axum::Extension<Claims>,
    Json(req): Json<ChangePasswordRequest>,
) -> Result<(StatusCode, Json<ApiResponse<()>>), StatusCode> {
    // Count characters, not bytes: with `len()` two CJK characters (6 bytes)
    // would satisfy the 6-character minimum promised by the message.
    if req.new_password.chars().count() < 6 {
        return Ok((StatusCode::BAD_REQUEST, Json(ApiResponse::error("新密码至少6位"))));
    }
    // The upper bound stays in bytes: bcrypt only processes the first 72 bytes.
    if req.new_password.len() > 72 {
        return Ok((StatusCode::BAD_REQUEST, Json(ApiResponse::error("密码不能超过72位"))));
    }

    let hash: String = sqlx::query_scalar::<_, String>(
        "SELECT password FROM users WHERE id = ?"
    )
    .bind(claims.sub)
    .fetch_one(&state.db)
    .await
    .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;

    // A malformed stored hash is treated the same as a mismatch.
    if !bcrypt::verify(&req.old_password, &hash).unwrap_or(false) {
        return Ok((StatusCode::BAD_REQUEST, Json(ApiResponse::error("当前密码错误"))));
    }

    let new_hash = bcrypt::hash(&req.new_password, 12).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
    sqlx::query("UPDATE users SET password = ? WHERE id = ?")
        .bind(&new_hash)
        .bind(claims.sub)
        .execute(&state.db)
        .await
        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;

    // Audit logging is best-effort; the password has already been changed.
    let _ = sqlx::query(
        "INSERT INTO admin_audit_log (user_id, action, detail) VALUES (?, 'change_password', ?)"
    )
    .bind(claims.sub)
    .bind(format!("User {} changed password", claims.username))
    .execute(&state.db)
    .await;

    Ok((StatusCode::OK, Json(ApiResponse::ok(()))))
}