feat: 新增补丁管理和异常检测插件及相关功能

feat(protocol): 添加补丁管理和行为指标协议类型
feat(client): 实现补丁管理插件采集功能
feat(server): 添加补丁管理和异常检测API
feat(database): 新增补丁状态和异常检测相关表
feat(web): 添加补丁管理和异常检测前端页面
fix(security): 增强输入验证和防注入保护
refactor(auth): 重构认证检查逻辑
perf(service): 优化Windows服务恢复策略
style: 统一健康评分显示样式
docs: 更新知识库文档
This commit is contained in:
iven
2026-04-11 15:59:53 +08:00
parent b5333d8c93
commit 60ee38a3c2
49 changed files with 3988 additions and 461 deletions

View File

@@ -41,6 +41,26 @@ pub async fn cleanup_task(state: AppState) {
error!("Failed to cleanup alert records: {}", e);
}
// Cleanup old revoked token families (keep 30 days for audit)
if let Err(e) = sqlx::query(
"DELETE FROM revoked_token_families WHERE revoked_at < datetime('now', '-30 days')"
)
.execute(&state.db)
.await
{
error!("Failed to cleanup revoked token families: {}", e);
}
// Cleanup old anomaly alerts that have been handled
if let Err(e) = sqlx::query(
"DELETE FROM anomaly_alerts WHERE handled = 1 AND triggered_at < datetime('now', '-90 days')"
)
.execute(&state.db)
.await
{
error!("Failed to cleanup handled anomaly alerts: {}", e);
}
// Mark devices as offline if no heartbeat for 2 minutes
if let Err(e) = sqlx::query(
"UPDATE devices SET status = 'offline' WHERE status = 'online' AND last_heartbeat < datetime('now', '-2 minutes')"

View File

@@ -0,0 +1,170 @@
use sqlx::Row;
use tracing::{info, warn};
use csm_protocol::BehaviorMetricsPayload;
/// Evaluate incoming behavior metrics against the built-in anomaly rules,
/// persist every triggered alert, and broadcast them to live dashboards.
pub async fn check_anomalies(
    pool: &sqlx::SqlitePool,
    ws_hub: &crate::ws::WsHub,
    metrics: &BehaviorMetricsPayload,
) {
    // Each triggered rule: (anomaly_type, severity, human-readable detail, metric value).
    let mut triggered: Vec<(&'static str, &'static str, String, f64)> = Vec::new();

    // Rule 1: Night-time clipboard operations (> 10 in reporting period)
    if metrics.clipboard_ops_night > 10 {
        triggered.push((
            "night_clipboard_spike",
            "high",
            format!("非工作时间剪贴板操作异常: {}次 (阈值: 10次)", metrics.clipboard_ops_night),
            metrics.clipboard_ops_night as f64,
        ));
    }

    // Rule 2: High USB file operations (> 100 per hour) — rate-based, so a
    // non-zero reporting period is required.
    if metrics.period_secs > 0 {
        let usb_per_hour = (metrics.usb_file_ops_count as f64 / metrics.period_secs as f64) * 3600.0;
        if usb_per_hour > 100.0 {
            triggered.push((
                "usb_file_exfiltration",
                "critical",
                format!("USB文件操作频率异常: {:.0}次/小时 (阈值: 100次/小时)", usb_per_hour),
                usb_per_hour,
            ));
        }
    }

    // Rule 3: High print volume (> 50 per reporting period)
    if metrics.print_jobs_count > 50 {
        triggered.push((
            "high_print_volume",
            "medium",
            format!("打印量异常: {}次 (阈值: 50次)", metrics.print_jobs_count),
            metrics.print_jobs_count as f64,
        ));
    }

    // Rule 4: Excessive new processes (> 20 per hour) — also rate-based.
    if metrics.period_secs > 0 {
        let procs_per_hour = (metrics.new_processes_count as f64 / metrics.period_secs as f64) * 3600.0;
        if procs_per_hour > 20.0 {
            triggered.push((
                "process_spawn_spike",
                "medium",
                format!("新进程启动异常: {:.0}次/小时 (阈值: 20次/小时)", procs_per_hour),
                procs_per_hour,
            ));
        }
    }

    // Persist each alert; a failed insert is logged but does not abort the batch.
    for (atype, severity, detail, value) in &triggered {
        if let Err(e) = sqlx::query(
            "INSERT INTO anomaly_alerts (device_uid, anomaly_type, severity, detail, metric_value, triggered_at) \
             VALUES (?, ?, ?, ?, ?, datetime('now'))"
        )
        .bind(&metrics.device_uid)
        .bind(*atype)
        .bind(*severity)
        .bind(detail.as_str())
        .bind(*value)
        .execute(pool)
        .await
        {
            warn!("Failed to insert anomaly alert: {}", e);
        }
    }

    // Push alerts to connected WebSocket sessions and log a summary line.
    if !triggered.is_empty() {
        for (atype, severity, detail, _) in &triggered {
            ws_hub.broadcast(serde_json::json!({
                "type": "anomaly_alert",
                "device_uid": metrics.device_uid,
                "anomaly_type": atype,
                "severity": severity,
                "detail": detail,
            }).to_string()).await;
        }
        info!("Detected {} anomalies for device {}", triggered.len(), metrics.device_uid);
    }
}
/// Fetch a paginated, newest-first list of anomaly alerts plus summary counts.
///
/// * `device_uid` — `Some(uid)` scopes every query (list, total, unhandled)
///   to that device; `None` reports across all devices.
/// * `page` — 1-based page number; `0` is treated like page 1.
/// * `page_size` — rows per page.
///
/// Returns `{ alerts, total, unhandled_count, page, page_size }` as JSON.
///
/// # Errors
/// Propagates SQL errors from the list and total queries; the unhandled
/// count is deliberately best-effort and falls back to 0 on failure.
pub async fn get_anomaly_summary(
    pool: &sqlx::SqlitePool,
    device_uid: Option<&str>,
    page: u32,
    page_size: u32,
) -> anyhow::Result<serde_json::Value> {
    // Saturating arithmetic: the previous `page.saturating_sub(1) * page_size`
    // could overflow u32 (a panic in debug builds) for adversarial page numbers.
    let offset = page.saturating_sub(1).saturating_mul(page_size);
    // NOTE(review): the INNER JOIN silently drops alerts whose device row was
    // deleted — confirm a LEFT JOIN isn't wanted here.
    let rows = if let Some(uid) = device_uid {
        sqlx::query(
            "SELECT a.*, d.hostname FROM anomaly_alerts a JOIN devices d ON d.device_uid = a.device_uid \
             WHERE a.device_uid = ? ORDER BY a.triggered_at DESC LIMIT ? OFFSET ?"
        )
        .bind(uid)
        .bind(page_size)
        .bind(offset)
        .fetch_all(pool)
        .await?
    } else {
        sqlx::query(
            "SELECT a.*, d.hostname FROM anomaly_alerts a JOIN devices d ON d.device_uid = a.device_uid \
             ORDER BY a.triggered_at DESC LIMIT ? OFFSET ?"
        )
        .bind(page_size)
        .bind(offset)
        .fetch_all(pool)
        .await?
    };
    // Unpaginated total under the same device filter.
    let total: i64 = if let Some(uid) = device_uid {
        sqlx::query_scalar("SELECT COUNT(*) FROM anomaly_alerts WHERE device_uid = ?")
            .bind(uid)
            .fetch_one(pool)
            .await?
    } else {
        sqlx::query_scalar("SELECT COUNT(*) FROM anomaly_alerts")
            .fetch_one(pool)
            .await?
    };
    // Project each row into the wire shape expected by the web UI.
    let alerts: Vec<serde_json::Value> = rows.iter().map(|r| serde_json::json!({
        "id": r.get::<i64, _>("id"),
        "device_uid": r.get::<String, _>("device_uid"),
        "hostname": r.get::<String, _>("hostname"),
        "anomaly_type": r.get::<String, _>("anomaly_type"),
        "severity": r.get::<String, _>("severity"),
        "detail": r.get::<String, _>("detail"),
        "metric_value": r.get::<f64, _>("metric_value"),
        "handled": r.get::<i32, _>("handled"),
        "triggered_at": r.get::<String, _>("triggered_at"),
    })).collect();
    // Summary counts (scoped to same filter); best-effort, defaults to 0.
    let unhandled: i64 = if let Some(uid) = device_uid {
        sqlx::query_scalar(
            "SELECT COUNT(*) FROM anomaly_alerts WHERE handled = 0 AND device_uid = ?"
        )
        .bind(uid)
        .fetch_one(pool)
        .await
        .unwrap_or(0)
    } else {
        sqlx::query_scalar(
            "SELECT COUNT(*) FROM anomaly_alerts WHERE handled = 0"
        )
        .fetch_one(pool)
        .await
        .unwrap_or(0)
    };
    Ok(serde_json::json!({
        "alerts": alerts,
        "total": total,
        "unhandled_count": unhandled,
        "page": page,
        "page_size": page_size,
    }))
}

View File

@@ -21,7 +21,7 @@ pub async fn list_rules(
) -> Json<ApiResponse<serde_json::Value>> {
let rows = sqlx::query(
"SELECT id, name, rule_type, condition, severity, enabled, notify_email, notify_webhook, created_at, updated_at
FROM alert_rules ORDER BY created_at DESC"
FROM alert_rules ORDER BY created_at DESC LIMIT 500"
)
.fetch_all(&state.db)
.await;
@@ -116,7 +116,26 @@ pub async fn create_rule(
State(state): State<AppState>,
Json(body): Json<CreateRuleRequest>,
) -> (StatusCode, Json<ApiResponse<serde_json::Value>>) {
// Validate rule_type
if !matches!(body.rule_type.as_str(), "device_offline" | "usb_event" | "web_access" | "software_violation" | "custom") {
return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("Invalid rule_type")));
}
// Validate severity
let severity = body.severity.unwrap_or_else(|| "medium".to_string());
if !matches!(severity.as_str(), "low" | "medium" | "high" | "critical") {
return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("Invalid severity")));
}
// Validate webhook URL (SSRF prevention)
if let Some(ref url) = body.notify_webhook {
if !url.starts_with("https://") {
return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("Webhook URL must use HTTPS")));
}
if url.len() > 2048 {
return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("Webhook URL too long")));
}
}
let result = sqlx::query(
"INSERT INTO alert_rules (name, rule_type, condition, severity, notify_email, notify_webhook)
@@ -174,6 +193,26 @@ pub async fn update_rule(
let notify_email = body.notify_email.or_else(|| existing.get::<Option<String>, _>("notify_email"));
let notify_webhook = body.notify_webhook.or_else(|| existing.get::<Option<String>, _>("notify_webhook"));
// Validate rule_type
if !matches!(rule_type.as_str(), "device_offline" | "usb_event" | "web_access" | "software_violation" | "custom") {
return Json(ApiResponse::error("Invalid rule_type"));
}
// Validate severity
if !matches!(severity.as_str(), "low" | "medium" | "high" | "critical") {
return Json(ApiResponse::error("Invalid severity"));
}
// Validate webhook URL (SSRF prevention)
if let Some(ref url) = notify_webhook {
if !url.starts_with("https://") {
return Json(ApiResponse::error("Webhook URL must use HTTPS"));
}
if url.len() > 2048 {
return Json(ApiResponse::error("Webhook URL too long"));
}
}
let result = sqlx::query(
"UPDATE alert_rules SET name = ?, rule_type = ?, condition = ?, severity = ?, enabled = ?,
notify_email = ?, notify_webhook = ?, updated_at = datetime('now') WHERE id = ?"

View File

@@ -1,4 +1,5 @@
use axum::{extract::State, Json, http::StatusCode, extract::Request, middleware::Next, response::Response};
use axum::{extract::State, Json, http::StatusCode, extract::Request, middleware::Next, response::{Response, IntoResponse}};
use axum::http::header::{SET_COOKIE, HeaderValue};
use serde::{Deserialize, Serialize};
use jsonwebtoken::{encode, decode, Header, EncodingKey, DecodingKey, Validation};
use std::sync::Arc;
@@ -28,11 +29,15 @@ pub struct LoginRequest {
#[derive(Debug, Serialize)]
pub struct LoginResponse {
pub access_token: String,
pub refresh_token: String,
pub user: UserInfo,
}
// Response body for GET /api/auth/me.
#[derive(Debug, Serialize)]
pub struct MeResponse {
    // The authenticated user's identity (id, username, role).
    pub user: UserInfo,
    // Access-token expiry, formatted as an RFC 3339 timestamp string.
    pub expires_at: String,
}
#[derive(Debug, Serialize, sqlx::FromRow)]
pub struct UserInfo {
pub id: i64,
@@ -40,18 +45,68 @@ pub struct UserInfo {
pub role: String,
}
#[derive(Debug, Deserialize)]
pub struct RefreshRequest {
pub refresh_token: String,
}
#[derive(Debug, Deserialize)]
pub struct ChangePasswordRequest {
pub old_password: String,
pub new_password: String,
}
/// In-memory rate limiter for login attempts
// ---------------------------------------------------------------------------
// Cookie helpers
// ---------------------------------------------------------------------------
/// Cookies carry the `Secure` attribute unless the `CSM_DEV` environment
/// variable is set (local development over plain HTTP).
fn is_secure_cookies() -> bool {
    match std::env::var("CSM_DEV") {
        Ok(_) => false,
        Err(_) => true,
    }
}
/// Build the `Set-Cookie` header for the short-lived access token.
/// Scoped to the whole site (`Path=/`), HttpOnly + SameSite=Strict.
fn access_cookie_header(token: &str, ttl_secs: u64) -> HeaderValue {
    let secure_flag = if is_secure_cookies() { "; Secure" } else { "" };
    let cookie = format!(
        "access_token={}; HttpOnly{}; SameSite=Strict; Path=/; Max-Age={}",
        token, secure_flag, ttl_secs
    );
    HeaderValue::from_str(&cookie).expect("valid cookie header")
}
/// Build the `Set-Cookie` header for the refresh token. Deliberately scoped
/// to `/api/auth/refresh` so the browser never sends it to other endpoints.
fn refresh_cookie_header(token: &str, ttl_secs: u64) -> HeaderValue {
    let secure_flag = if is_secure_cookies() { "; Secure" } else { "" };
    let cookie = format!(
        "refresh_token={}; HttpOnly{}; SameSite=Strict; Path=/api/auth/refresh; Max-Age={}",
        token, secure_flag, ttl_secs
    );
    HeaderValue::from_str(&cookie).expect("valid cookie header")
}
/// `Set-Cookie` headers that expire both auth cookies immediately
/// (`Max-Age=0`), mirroring the paths used when they were set.
fn clear_cookie_headers() -> Vec<HeaderValue> {
    let secure = if is_secure_cookies() { "; Secure" } else { "" };
    [
        format!("access_token=; HttpOnly{}; SameSite=Strict; Path=/; Max-Age=0", secure),
        format!("refresh_token=; HttpOnly{}; SameSite=Strict; Path=/api/auth/refresh; Max-Age=0", secure),
    ]
    .into_iter()
    .map(|cookie| HeaderValue::from_str(&cookie).expect("valid"))
    .collect()
}
/// Attach the given `Set-Cookie` headers to a response and return it.
fn with_cookies(mut response: Response, cookies: Vec<HeaderValue>) -> Response {
    let headers = response.headers_mut();
    // append (not insert) so multiple Set-Cookie headers coexist.
    for cookie in cookies.into_iter() {
        headers.append(SET_COOKIE, cookie);
    }
    response
}
/// Extract a cookie value by name from the raw `Cookie` header.
///
/// Returns the value of the first cookie whose name matches exactly, or
/// `None` when the header is missing, not valid UTF-8, or has no such cookie.
fn extract_cookie_value(headers: &axum::http::HeaderMap, name: &str) -> Option<String> {
    let cookie_header = headers.get("cookie")?.to_str().ok()?;
    cookie_header.split(';').find_map(|cookie| {
        // split_once avoids the per-cookie `format!("{}=", name)` allocation
        // of the previous strip_prefix approach; entries without '=' never
        // matched before and are still skipped here.
        let (key, value) = cookie.trim().split_once('=')?;
        (key == name).then(|| value.to_string())
    })
}
// ---------------------------------------------------------------------------
// Rate limiter
// ---------------------------------------------------------------------------
#[derive(Clone, Default)]
pub struct LoginRateLimiter {
attempts: Arc<Mutex<HashMap<String, (Instant, u32)>>>,
@@ -62,28 +117,25 @@ impl LoginRateLimiter {
Self::default()
}
/// Returns true if the request should be rate-limited
pub async fn is_limited(&self, key: &str) -> bool {
let mut attempts = self.attempts.lock().await;
let now = Instant::now();
let window = std::time::Duration::from_secs(300); // 5-minute window
let window = std::time::Duration::from_secs(300);
let max_attempts = 10u32;
if let Some((first_attempt, count)) = attempts.get_mut(key) {
if now.duration_since(*first_attempt) > window {
// Window expired, reset
*first_attempt = now;
*count = 1;
false
} else if *count >= max_attempts {
true // Rate limited
true
} else {
*count += 1;
false
}
} else {
attempts.insert(key.to_string(), (now, 1));
// Cleanup old entries periodically
if attempts.len() > 1000 {
let cutoff = now - window;
attempts.retain(|_, (t, _)| *t > cutoff);
@@ -93,46 +145,67 @@ impl LoginRateLimiter {
}
}
// ---------------------------------------------------------------------------
// Endpoints
// ---------------------------------------------------------------------------
pub async fn login(
State(state): State<AppState>,
Json(req): Json<LoginRequest>,
) -> Result<(StatusCode, Json<ApiResponse<LoginResponse>>), StatusCode> {
// Rate limit check
) -> impl IntoResponse {
if state.login_limiter.is_limited(&req.username).await {
return Ok((StatusCode::TOO_MANY_REQUESTS, Json(ApiResponse::error("Too many login attempts. Try again later."))));
return (StatusCode::TOO_MANY_REQUESTS, Json(ApiResponse::<LoginResponse>::error("Too many login attempts. Try again later."))).into_response();
}
if state.login_limiter.is_limited("ip:default").await {
return (StatusCode::TOO_MANY_REQUESTS, Json(ApiResponse::<LoginResponse>::error("Too many login attempts from this location. Try again later."))).into_response();
}
let user: Option<UserInfo> = sqlx::query_as::<_, UserInfo>(
"SELECT id, username, role FROM users WHERE username = ?"
let row: Option<(UserInfo, String)> = sqlx::query_as::<_, (i64, String, String, String)>(
"SELECT id, username, role, password FROM users WHERE username = ?"
)
.bind(&req.username)
.fetch_optional(&state.db)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
.ok()
.flatten()
.map(|(id, username, role, password)| {
(UserInfo { id, username, role }, password)
});
let user = match user {
Some(u) => u,
None => return Ok((StatusCode::UNAUTHORIZED, Json(ApiResponse::error("Invalid credentials")))),
let (user, hash) = match row {
Some(r) => r,
None => {
let _ = bcrypt::verify("timing-constant-dummy", "$2b$12$aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa");
return (StatusCode::UNAUTHORIZED, Json(ApiResponse::<LoginResponse>::error("Invalid credentials"))).into_response();
}
};
let hash: String = sqlx::query_scalar::<_, String>(
"SELECT password FROM users WHERE id = ?"
)
.bind(user.id)
.fetch_one(&state.db)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
if !bcrypt::verify(&req.password, &hash).unwrap_or(false) {
return Ok((StatusCode::UNAUTHORIZED, Json(ApiResponse::error("Invalid credentials"))));
return (StatusCode::UNAUTHORIZED, Json(ApiResponse::<LoginResponse>::error("Invalid credentials"))).into_response();
}
let now = chrono::Utc::now().timestamp() as u64;
let family = uuid::Uuid::new_v4().to_string();
let access_token = create_token(&user, "access", state.config.auth.access_token_ttl_secs, now, &state.config.auth.jwt_secret, &family)?;
let refresh_token = create_token(&user, "refresh", state.config.auth.refresh_token_ttl_secs, now, &state.config.auth.jwt_secret, &family)?;
let access_token = match create_token(&user, "access", state.config.auth.access_token_ttl_secs, now, &state.config.auth.jwt_secret, &family) {
Ok(t) => t,
Err(_) => return StatusCode::INTERNAL_SERVER_ERROR.into_response(),
};
let refresh_token = match create_token(&user, "refresh", state.config.auth.refresh_token_ttl_secs, now, &state.config.auth.jwt_secret, &family) {
Ok(t) => t,
Err(_) => return StatusCode::INTERNAL_SERVER_ERROR.into_response(),
};
let refresh_expires = now + state.config.auth.refresh_token_ttl_secs;
let _ = sqlx::query(
"INSERT INTO refresh_tokens (user_id, family, expires_at) VALUES (?, ?, datetime(?, 'unixepoch'))"
)
.bind(user.id)
.bind(&family)
.bind(refresh_expires as i64)
.execute(&state.db)
.await;
// Audit log
let _ = sqlx::query(
"INSERT INTO admin_audit_log (user_id, action, detail) VALUES (?, 'login', ?)"
)
@@ -141,73 +214,262 @@ pub async fn login(
.execute(&state.db)
.await;
Ok((StatusCode::OK, Json(ApiResponse::ok(LoginResponse {
access_token,
refresh_token,
user,
}))))
let response = (StatusCode::OK, Json(ApiResponse::ok(LoginResponse { user }))).into_response();
with_cookies(response, vec![
access_cookie_header(&access_token, state.config.auth.access_token_ttl_secs),
refresh_cookie_header(&refresh_token, state.config.auth.refresh_token_ttl_secs),
])
}
pub async fn refresh(
State(state): State<AppState>,
Json(req): Json<RefreshRequest>,
) -> Result<(StatusCode, Json<ApiResponse<LoginResponse>>), StatusCode> {
let claims = decode::<Claims>(
&req.refresh_token,
headers: axum::http::HeaderMap,
) -> impl IntoResponse {
let refresh_token = match extract_cookie_value(&headers, "refresh_token") {
Some(t) => t,
None => return with_cookies(
(StatusCode::UNAUTHORIZED, Json(ApiResponse::<LoginResponse>::error("Missing refresh token"))).into_response(),
clear_cookie_headers(),
),
};
let claims = match decode::<Claims>(
&refresh_token,
&DecodingKey::from_secret(state.config.auth.jwt_secret.as_bytes()),
&Validation::default(),
)
.map_err(|_| StatusCode::UNAUTHORIZED)?;
) {
Ok(c) => c.claims,
Err(_) => return with_cookies(
(StatusCode::UNAUTHORIZED, Json(ApiResponse::<LoginResponse>::error("Invalid refresh token"))).into_response(),
clear_cookie_headers(),
),
};
if claims.claims.token_type != "refresh" {
return Ok((StatusCode::UNAUTHORIZED, Json(ApiResponse::error("Invalid token type"))));
if claims.token_type != "refresh" {
return with_cookies(
(StatusCode::UNAUTHORIZED, Json(ApiResponse::<LoginResponse>::error("Invalid token type"))).into_response(),
clear_cookie_headers(),
);
}
// Check if this refresh token family has been revoked (reuse detection)
let mut tx = match state.db.begin().await {
Ok(tx) => tx,
Err(_) => return StatusCode::INTERNAL_SERVER_ERROR.into_response(),
};
let revoked: bool = sqlx::query_scalar::<_, i64>(
"SELECT COUNT(*) FROM revoked_token_families WHERE family = ?"
)
.bind(&claims.claims.family)
.fetch_one(&state.db)
.bind(&claims.family)
.fetch_one(&mut *tx)
.await
.unwrap_or(0) > 0;
if revoked {
// Token reuse detected — revoke entire family and force re-login
tracing::warn!("Refresh token reuse detected for user {} family {}", claims.claims.sub, claims.claims.family);
tx.rollback().await.ok();
tracing::warn!("Refresh token reuse detected for user {} family {}", claims.sub, claims.family);
let _ = sqlx::query("DELETE FROM refresh_tokens WHERE user_id = ?")
.bind(claims.claims.sub)
.bind(claims.sub)
.execute(&state.db)
.await;
return Ok((StatusCode::UNAUTHORIZED, Json(ApiResponse::error("Token reuse detected. Please log in again."))));
return with_cookies(
(StatusCode::UNAUTHORIZED, Json(ApiResponse::<LoginResponse>::error("Token reuse detected. Please log in again."))).into_response(),
clear_cookie_headers(),
);
}
let family_exists: bool = sqlx::query_scalar::<_, i64>(
"SELECT COUNT(*) FROM refresh_tokens WHERE family = ? AND user_id = ?"
)
.bind(&claims.family)
.bind(claims.sub)
.fetch_one(&mut *tx)
.await
.unwrap_or(0) > 0;
if !family_exists {
tx.rollback().await.ok();
return (StatusCode::UNAUTHORIZED, Json(ApiResponse::<LoginResponse>::error("Invalid refresh token"))).into_response();
}
let user = UserInfo {
id: claims.claims.sub,
username: claims.claims.username,
role: claims.claims.role,
id: claims.sub,
username: claims.username,
role: claims.role,
};
// Rotate: new family for each refresh
let new_family = uuid::Uuid::new_v4().to_string();
let now = chrono::Utc::now().timestamp() as u64;
let access_token = create_token(&user, "access", state.config.auth.access_token_ttl_secs, now, &state.config.auth.jwt_secret, &new_family)?;
let refresh_token = create_token(&user, "refresh", state.config.auth.refresh_token_ttl_secs, now, &state.config.auth.jwt_secret, &new_family)?;
let access_token = match create_token(&user, "access", state.config.auth.access_token_ttl_secs, now, &state.config.auth.jwt_secret, &new_family) {
Ok(t) => t,
Err(_) => return StatusCode::INTERNAL_SERVER_ERROR.into_response(),
};
let refresh_token = match create_token(&user, "refresh", state.config.auth.refresh_token_ttl_secs, now, &state.config.auth.jwt_secret, &new_family) {
Ok(t) => t,
Err(_) => return StatusCode::INTERNAL_SERVER_ERROR.into_response(),
};
// Revoke old family
let _ = sqlx::query("INSERT OR IGNORE INTO revoked_token_families (family, user_id, revoked_at) VALUES (?, ?, datetime('now'))")
.bind(&claims.claims.family)
.bind(claims.claims.sub)
.execute(&state.db)
.await;
if sqlx::query("INSERT OR IGNORE INTO revoked_token_families (family, user_id, revoked_at) VALUES (?, ?, datetime('now'))")
.bind(&claims.family)
.bind(claims.sub)
.execute(&mut *tx)
.await
.is_err()
{
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
}
Ok((StatusCode::OK, Json(ApiResponse::ok(LoginResponse {
access_token,
refresh_token,
user,
}))))
let refresh_expires = now + state.config.auth.refresh_token_ttl_secs;
if sqlx::query(
"INSERT INTO refresh_tokens (user_id, family, expires_at) VALUES (?, ?, datetime(?, 'unixepoch'))"
)
.bind(user.id)
.bind(&new_family)
.bind(refresh_expires as i64)
.execute(&mut *tx)
.await
.is_err()
{
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
}
if tx.commit().await.is_err() {
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
}
let response = (StatusCode::OK, Json(ApiResponse::ok(LoginResponse { user }))).into_response();
with_cookies(response, vec![
access_cookie_header(&access_token, state.config.auth.access_token_ttl_secs),
refresh_cookie_header(&refresh_token, state.config.auth.refresh_token_ttl_secs),
])
}
/// Get current authenticated user info from access_token cookie.
pub async fn me(
State(state): State<AppState>,
headers: axum::http::HeaderMap,
) -> impl IntoResponse {
let token = match extract_cookie_value(&headers, "access_token") {
Some(t) => t,
None => return (StatusCode::UNAUTHORIZED, Json(ApiResponse::<MeResponse>::error("Not authenticated"))).into_response(),
};
let claims = match decode::<Claims>(
&token,
&DecodingKey::from_secret(state.config.auth.jwt_secret.as_bytes()),
&Validation::default(),
) {
Ok(c) => c.claims,
Err(_) => return (StatusCode::UNAUTHORIZED, Json(ApiResponse::<MeResponse>::error("Invalid token"))).into_response(),
};
if claims.token_type != "access" {
return (StatusCode::UNAUTHORIZED, Json(ApiResponse::<MeResponse>::error("Invalid token type"))).into_response();
}
let expires_at = chrono::DateTime::from_timestamp(claims.exp as i64, 0)
.map(|t| t.to_rfc3339())
.unwrap_or_default();
(StatusCode::OK, Json(ApiResponse::ok(MeResponse {
user: UserInfo {
id: claims.sub,
username: claims.username,
role: claims.role,
},
expires_at,
}))).into_response()
}
/// Logout: clear auth cookies and revoke the user's refresh tokens.
///
/// Best-effort: even with a missing or invalid access token the cookies are
/// still cleared, so the browser always ends up logged out.
pub async fn logout(
    State(state): State<AppState>,
    headers: axum::http::HeaderMap,
) -> impl IntoResponse {
    let decoded = extract_cookie_value(&headers, "access_token")
        .as_deref()
        .and_then(|token| {
            decode::<Claims>(
                token,
                &DecodingKey::from_secret(state.config.auth.jwt_secret.as_bytes()),
                &Validation::default(),
            )
            .ok()
        });
    if let Some(token_data) = decoded {
        let claims = token_data.claims;
        // Drop every refresh token for this user so no session can be renewed.
        let _ = sqlx::query("DELETE FROM refresh_tokens WHERE user_id = ?")
            .bind(claims.sub)
            .execute(&state.db)
            .await;
        // Record the logout in the admin audit trail (fire and forget).
        let _ = sqlx::query(
            "INSERT INTO admin_audit_log (user_id, action, detail) VALUES (?, 'logout', ?)"
        )
        .bind(claims.sub)
        .bind(format!("User {} logged out", claims.username))
        .execute(&state.db)
        .await;
    }
    with_cookies(
        (StatusCode::OK, Json(ApiResponse::ok(()))).into_response(),
        clear_cookie_headers(),
    )
}
// ---------------------------------------------------------------------------
// WebSocket ticket
// ---------------------------------------------------------------------------
// Response body for the WebSocket-ticket endpoint.
#[derive(Debug, Serialize)]
pub struct WsTicketResponse {
    // One-time ticket value (a random UUID) to present when opening the WebSocket.
    pub ticket: String,
    // Ticket lifetime in seconds (tickets older than this are pruned server-side).
    pub expires_in: u64,
}
/// Create a one-time ticket for WebSocket authentication.
/// Requires a valid access_token cookie (set by login).
///
/// The ticket is a random UUID mapped to the caller's identity in the shared
/// in-memory `ws_tickets` table; entries older than 30 seconds are pruned on
/// every call, so a ticket must be redeemed within that window.
pub async fn create_ws_ticket(
    State(state): State<AppState>,
    headers: axum::http::HeaderMap,
) -> impl IntoResponse {
    // Session comes from the HttpOnly access_token cookie, not a header.
    let token = match extract_cookie_value(&headers, "access_token") {
        Some(t) => t,
        None => return (StatusCode::UNAUTHORIZED, Json(ApiResponse::<WsTicketResponse>::error("Not authenticated"))).into_response(),
    };
    // Verify signature and standard claims with the server secret.
    let claims = match decode::<Claims>(
        &token,
        &DecodingKey::from_secret(state.config.auth.jwt_secret.as_bytes()),
        &Validation::default(),
    ) {
        Ok(c) => c.claims,
        Err(_) => return (StatusCode::UNAUTHORIZED, Json(ApiResponse::<WsTicketResponse>::error("Invalid token"))).into_response(),
    };
    // Only short-lived access tokens may mint tickets; refresh tokens are rejected.
    if claims.token_type != "access" {
        return (StatusCode::UNAUTHORIZED, Json(ApiResponse::<WsTicketResponse>::error("Invalid token type"))).into_response();
    }
    let ticket = uuid::Uuid::new_v4().to_string();
    // Snapshot the caller's identity; created_at drives the 30s expiry below.
    let claim = crate::ws::TicketClaim {
        user_id: claims.sub,
        username: claims.username,
        role: claims.role,
        created_at: std::time::Instant::now(),
    };
    // Scope the lock tightly: insert the new ticket, then prune stale ones.
    {
        let mut tickets = state.ws_tickets.lock().await;
        tickets.insert(ticket.clone(), claim);
        // Cleanup expired tickets (>30s old) on every creation
        tickets.retain(|_, c| c.created_at.elapsed().as_secs() < 30);
    }
    (StatusCode::OK, Json(ApiResponse::ok(WsTicketResponse {
        ticket,
        expires_in: 30,
    }))).into_response()
}
// ---------------------------------------------------------------------------
// Internal helpers
// ---------------------------------------------------------------------------
fn create_token(user: &UserInfo, token_type: &str, ttl: u64, now: u64, secret: &str, family: &str) -> Result<String, StatusCode> {
let claims = Claims {
sub: user.id,
@@ -227,24 +489,17 @@ fn create_token(user: &UserInfo, token_type: &str, ttl: u64, now: u64, secret: &
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
}
/// Axum middleware: require valid JWT access token
/// Axum middleware: require valid JWT access token from cookie
pub async fn require_auth(
State(state): State<AppState>,
mut request: Request,
next: Next,
) -> Result<Response, StatusCode> {
let auth_header = request.headers()
.get("Authorization")
.and_then(|v| v.to_str().ok())
.and_then(|v| v.strip_prefix("Bearer "));
let token = match auth_header {
Some(t) => t,
None => return Err(StatusCode::UNAUTHORIZED),
};
let token = extract_cookie_value(request.headers(), "access_token")
.ok_or(StatusCode::UNAUTHORIZED)?;
let claims = decode::<Claims>(
token,
&token,
&DecodingKey::from_secret(state.config.auth.jwt_secret.as_bytes()),
&Validation::default(),
)
@@ -254,9 +509,7 @@ pub async fn require_auth(
return Err(StatusCode::UNAUTHORIZED);
}
// Inject claims into request extensions for handlers to use
request.extensions_mut().insert(claims.claims);
Ok(next.run(request).await)
}
@@ -274,7 +527,6 @@ pub async fn require_admin(
return Err(StatusCode::FORBIDDEN);
}
// Capture audit info before running handler
let method = request.method().clone();
let path = request.uri().path().to_string();
let user_id = claims.sub;
@@ -282,7 +534,6 @@ pub async fn require_admin(
let response = next.run(request).await;
// Record admin action to audit log (fire and forget — don't block response)
let status = response.status();
if status.is_success() {
let action = format!("{} {}", method, path);
@@ -308,8 +559,10 @@ pub async fn change_password(
if req.new_password.len() < 6 {
return Ok((StatusCode::BAD_REQUEST, Json(ApiResponse::error("新密码至少6位"))));
}
if req.new_password.len() > 72 {
return Ok((StatusCode::BAD_REQUEST, Json(ApiResponse::error("密码不能超过72位"))));
}
// Verify old password
let hash: String = sqlx::query_scalar::<_, String>(
"SELECT password FROM users WHERE id = ?"
)
@@ -322,7 +575,6 @@ pub async fn change_password(
return Ok((StatusCode::BAD_REQUEST, Json(ApiResponse::error("当前密码错误"))));
}
// Update password
let new_hash = bcrypt::hash(&req.new_password, 12).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
sqlx::query("UPDATE users SET password = ? WHERE id = ?")
.bind(&new_hash)
@@ -331,7 +583,6 @@ pub async fn change_password(
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
// Audit log
let _ = sqlx::query(
"INSERT INTO admin_audit_log (user_id, action, detail) VALUES (?, 'change_password', ?)"
)

View File

@@ -0,0 +1,243 @@
use axum::{extract::State, Json};
use serde::Serialize;
use sqlx::Row;
use crate::AppState;
use super::ApiResponse;
// One detected conflict between two or more policies.
#[derive(Debug, Serialize)]
pub struct PolicyConflict {
    // Machine-readable conflict category, e.g. "usb_duplicate_policy".
    pub conflict_type: String,
    // Assigned by the scan rules: "medium", "high", or "critical".
    pub severity: String,
    // Human-readable (Chinese) description of the conflict.
    pub description: String,
    // The policy rows participating in this conflict.
    pub policies: Vec<ConflictPolicyRef>,
}
// Reference to a single policy row involved in a conflict.
#[derive(Debug, Serialize)]
pub struct ConflictPolicyRef {
    // Database table the policy row lives in (e.g. "usb_policies").
    pub table_name: String,
    // Primary-key id of the conflicting row.
    pub row_id: i64,
    // Display name for the policy (or a synthesized label).
    pub name: String,
    // Targeting scope kind, e.g. "group".
    pub target_type: String,
    // Target identifier; may be absent for broadly-targeted policies.
    pub target_id: Option<String>,
}
/// GET /api/policies/conflicts — scan all policies for conflicts
pub async fn scan_conflicts(
State(state): State<AppState>,
) -> Json<ApiResponse<serde_json::Value>> {
let mut conflicts: Vec<PolicyConflict> = Vec::new();
// 1. USB: multiple enabled policies for the same target_group
if let Ok(rows) = sqlx::query(
"SELECT target_group, COUNT(*) as cnt, GROUP_CONCAT(id) as ids, GROUP_CONCAT(name) as names, \
GROUP_CONCAT(policy_type) as types \
FROM usb_policies WHERE enabled = 1 AND target_group IS NOT NULL \
GROUP BY target_group HAVING cnt > 1"
)
.fetch_all(&state.db)
.await
{
for row in &rows {
let group: String = row.get("target_group");
let ids: String = row.get("ids");
let names: String = row.get("names");
let types: String = row.get("types");
let id_vec: Vec<i64> = ids.split(',').filter_map(|s| s.parse().ok()).collect();
let name_vec: Vec<&str> = names.split(',').collect();
let type_vec: Vec<&str> = types.split(',').collect();
conflicts.push(PolicyConflict {
conflict_type: "usb_duplicate_policy".to_string(),
severity: "high".to_string(),
description: format!("分组 '{}' 同时存在 {} 条启用的USB策略 ({})", group, id_vec.len(), type_vec.join(", ")),
policies: id_vec.iter().enumerate().map(|(i, id)| ConflictPolicyRef {
table_name: "usb_policies".to_string(),
row_id: *id,
name: name_vec.get(i).unwrap_or(&"?").to_string(),
target_type: "group".to_string(),
target_id: Some(group.clone()),
}).collect(),
});
}
}
// 2. USB: all_block + whitelist for same target (contradictory)
if let Ok(rows) = sqlx::query(
"SELECT a.id as aid, a.name as aname, a.target_group as agroup, \
b.id as bid, b.name as bname \
FROM usb_policies a JOIN usb_policies b ON a.target_group = b.target_group AND a.id < b.id \
WHERE a.enabled = 1 AND b.enabled = 1 \
AND ((a.policy_type = 'all_block' AND b.policy_type = 'whitelist') OR \
(a.policy_type = 'whitelist' AND b.policy_type = 'all_block'))"
)
.fetch_all(&state.db)
.await
{
for row in &rows {
let group: Option<String> = row.get("agroup");
conflicts.push(PolicyConflict {
conflict_type: "usb_block_vs_whitelist".to_string(),
severity: "critical".to_string(),
description: format!("分组 '{}' 同时存在全封堵和白名单USB策略互斥", group.as_deref().unwrap_or("?")),
policies: vec![
ConflictPolicyRef {
table_name: "usb_policies".to_string(),
row_id: row.get("aid"),
name: row.get("aname"),
target_type: "group".to_string(),
target_id: group.clone(),
},
ConflictPolicyRef {
table_name: "usb_policies".to_string(),
row_id: row.get("bid"),
name: row.get("bname"),
target_type: "group".to_string(),
target_id: group,
},
],
});
}
}
// 3. Web filter: same target, same pattern, different rule_type (allow vs block)
if let Ok(rows) = sqlx::query(
"SELECT a.id as aid, a.pattern as apattern, a.rule_type as artype, \
b.id as bid, b.rule_type as brtype, a.target_type as ttype, a.target_id as tid \
FROM web_filter_rules a JOIN web_filter_rules b ON a.pattern = b.pattern AND a.id < b.id \
WHERE a.enabled = 1 AND b.enabled = 1 \
AND a.target_type = b.target_type AND COALESCE(a.target_id,'') = COALESCE(b.target_id,'') \
AND a.rule_type != b.rule_type"
)
.fetch_all(&state.db)
.await
{
for row in &rows {
let pattern: String = row.get("apattern");
let artype: String = row.get("artype");
let brtype: String = row.get("brtype");
let ttype: String = row.get("ttype");
let tid: Option<String> = row.get("tid");
conflicts.push(PolicyConflict {
conflict_type: "web_filter_allow_vs_block".to_string(),
severity: "high".to_string(),
description: format!("URL '{}' 同时被 {} 和 {},互斥", pattern, artype, brtype),
policies: vec![
ConflictPolicyRef {
table_name: "web_filter_rules".to_string(),
row_id: row.get("aid"),
name: format!("{}: {}", artype, pattern),
target_type: ttype.clone(),
target_id: tid.clone(),
},
ConflictPolicyRef {
table_name: "web_filter_rules".to_string(),
row_id: row.get("bid"),
name: format!("{}: {}", brtype, pattern),
target_type: ttype,
target_id: tid,
},
],
});
}
}
// 4. Clipboard: same source/target process, allow vs block
if let Ok(rows) = sqlx::query(
"SELECT a.id as aid, a.rule_type as artype, a.source_process as asrc, a.target_process as adst, \
b.id as bid, b.rule_type as brtype, a.target_type as ttype, a.target_id as tid \
FROM clipboard_rules a JOIN clipboard_rules b ON a.id < b.id \
WHERE a.enabled = 1 AND b.enabled = 1 \
AND a.target_type = b.target_type AND COALESCE(a.target_id,'') = COALESCE(b.target_id,'') \
AND a.direction = b.direction \
AND COALESCE(a.source_process,'') = COALESCE(b.source_process,'') \
AND COALESCE(a.target_process,'') = COALESCE(b.target_process,'') \
AND a.rule_type != b.rule_type"
)
.fetch_all(&state.db)
.await
{
for row in &rows {
let asrc: Option<String> = row.get("asrc");
let adst: Option<String> = row.get("adst");
let artype: String = row.get("artype");
let brtype: String = row.get("brtype");
let desc = format!(
"剪贴板规则冲突: {}{} 同时存在 {}{}",
asrc.as_deref().unwrap_or("*"),
adst.as_deref().unwrap_or("*"),
artype, brtype,
);
let ttype: String = row.get("ttype");
let tid: Option<String> = row.get("tid");
conflicts.push(PolicyConflict {
conflict_type: "clipboard_allow_vs_block".to_string(),
severity: "medium".to_string(),
description: desc,
policies: vec![
ConflictPolicyRef {
table_name: "clipboard_rules".to_string(),
row_id: row.get("aid"),
name: format!("{}: {}", artype, asrc.as_deref().unwrap_or("*")),
target_type: ttype.clone(),
target_id: tid.clone(),
},
ConflictPolicyRef {
table_name: "clipboard_rules".to_string(),
row_id: row.get("bid"),
name: format!("{}: {}", brtype, asrc.as_deref().unwrap_or("*")),
target_type: ttype,
target_id: tid,
},
],
});
}
}
// 5. Plugin disabled but has active rules
let plugin_tables: [(&str, &str, &str, &str); 4] = [
("web_filter_rules", "上网行为过滤", "web_filter", "SELECT COUNT(*) FROM web_filter_rules WHERE enabled = 1"),
("software_blacklist", "软件黑名单", "software_blocker", "SELECT COUNT(*) FROM software_blacklist WHERE enabled = 1"),
("popup_filter_rules", "弹窗拦截", "popup_blocker", "SELECT COUNT(*) FROM popup_filter_rules WHERE enabled = 1"),
("clipboard_rules", "剪贴板管控", "clipboard_control", "SELECT COUNT(*) FROM clipboard_rules WHERE enabled = 1"),
];
for (_table, label, plugin, query) in &plugin_tables {
let active_count: i64 = sqlx::query_scalar(query)
.fetch_one(&state.db)
.await
.unwrap_or(0);
if active_count > 0 {
let disabled: bool = sqlx::query_scalar::<_, i32>(
"SELECT COUNT(*) FROM plugin_state WHERE plugin_name = ? AND enabled = 0"
)
.bind(plugin)
.fetch_one(&state.db)
.await
.unwrap_or(0) > 0;
if disabled {
conflicts.push(PolicyConflict {
conflict_type: "plugin_disabled_with_rules".to_string(),
severity: "low".to_string(),
description: format!("插件 '{}' 已禁用,但仍有 {} 条启用规则,规则不会生效", label, active_count),
policies: vec![ConflictPolicyRef {
table_name: "plugin_state".to_string(),
row_id: 0,
name: plugin.to_string(),
target_type: "global".to_string(),
target_id: None,
}],
});
}
}
}
Json(ApiResponse::ok(serde_json::json!({
"conflicts": conflicts,
"total": conflicts.len(),
"critical_count": conflicts.iter().filter(|c| c.severity == "critical").count(),
"high_count": conflicts.iter().filter(|c| c.severity == "high").count(),
"medium_count": conflicts.iter().filter(|c| c.severity == "medium").count(),
"low_count": conflicts.iter().filter(|c| c.severity == "low").count(),
})))
}

View File

@@ -4,6 +4,28 @@ use sqlx::Row;
use crate::AppState;
use super::{ApiResponse, Pagination};
/// GET /api/devices/:uid/health-score
///
/// Returns the stored health score for one device, or an error payload
/// when the device has not been scored yet.
pub async fn get_health_score(
    State(state): State<AppState>,
    Path(uid): Path<String>,
) -> Json<ApiResponse<serde_json::Value>> {
    let lookup = crate::health::get_device_score(&state.db, &uid).await;
    match lookup {
        Ok(Some(score)) => Json(ApiResponse::ok(score)),
        // No row for this device: distinct from a query failure.
        Ok(None) => Json(ApiResponse::error("No health score available")),
        Err(e) => Json(ApiResponse::internal_error("health score", e)),
    }
}
/// GET /api/dashboard/health-overview
///
/// Fleet-wide health aggregation for the dashboard widget.
pub async fn health_overview(
    State(state): State<AppState>,
) -> Json<ApiResponse<serde_json::Value>> {
    crate::health::get_health_overview(&state.db)
        .await
        .map(|overview| Json(ApiResponse::ok(overview)))
        .unwrap_or_else(|e| Json(ApiResponse::internal_error("health overview", e)))
}
#[derive(Debug, Deserialize)]
pub struct DeviceListParams {
pub status: Option<String>,
@@ -26,6 +48,10 @@ pub struct DeviceRow {
pub last_heartbeat: Option<String>,
pub registered_at: Option<String>,
pub group_name: Option<String>,
#[sqlx(default)]
pub health_score: Option<i32>,
#[sqlx(default)]
pub health_level: Option<String>,
}
pub async fn list(
@@ -41,13 +67,16 @@ pub async fn list(
let search = params.search.as_deref().filter(|s| !s.is_empty()).map(String::from);
let devices = sqlx::query_as::<_, DeviceRow>(
"SELECT id, device_uid, hostname, ip_address, mac_address, os_version, client_version,
status, last_heartbeat, registered_at, group_name
FROM devices WHERE 1=1
AND (? IS NULL OR status = ?)
AND (? IS NULL OR group_name = ?)
AND (? IS NULL OR hostname LIKE '%' || ? || '%' OR ip_address LIKE '%' || ? || '%')
ORDER BY registered_at DESC LIMIT ? OFFSET ?"
"SELECT d.id, d.device_uid, d.hostname, d.ip_address, d.mac_address, d.os_version, d.client_version,
d.status, d.last_heartbeat, d.registered_at, d.group_name,
h.score as health_score, h.level as health_level
FROM devices d
LEFT JOIN device_health_scores h ON h.device_uid = d.device_uid
WHERE 1=1
AND (? IS NULL OR d.status = ?)
AND (? IS NULL OR d.group_name = ?)
AND (? IS NULL OR d.hostname LIKE '%' || ? || '%' OR d.ip_address LIKE '%' || ? || '%')
ORDER BY d.registered_at DESC LIMIT ? OFFSET ?"
)
.bind(&status).bind(&status)
.bind(&group).bind(&group)
@@ -187,16 +216,6 @@ pub async fn remove(
State(state): State<AppState>,
Path(uid): Path<String>,
) -> Json<ApiResponse<()>> {
// If client is connected, send self-destruct command
let frame = csm_protocol::Frame::new_json(
csm_protocol::MessageType::ConfigUpdate,
&serde_json::json!({"type": "SelfDestruct"}),
).ok();
if let Some(frame) = frame {
state.clients.send_to(&uid, frame.encode()).await;
}
// Delete device and all associated data in a transaction
let mut tx = match state.db.begin().await {
Ok(tx) => tx,
@@ -224,6 +243,8 @@ pub async fn remove(
// Delete plugin-related data
let cleanup_tables = [
"hardware_assets",
"software_assets",
"asset_changes",
"usb_events",
"usb_file_operations",
"usage_daily",
@@ -231,8 +252,20 @@ pub async fn remove(
"software_violations",
"web_access_log",
"popup_block_stats",
"disk_encryption_status",
"disk_encryption_alerts",
"print_events",
"clipboard_violations",
"behavior_metrics",
"anomaly_alerts",
"device_health_scores",
"patch_status",
];
for table in &cleanup_tables {
// Safety: table names are hardcoded constants above, not user input.
// Parameterized ? is used for device_uid.
debug_assert!(table.chars().all(|c| c.is_ascii_alphanumeric() || c == '_'),
"BUG: table name contains unexpected characters: {}", table);
if let Err(e) = sqlx::query(&format!("DELETE FROM {} WHERE device_uid = ?", table))
.bind(&uid)
.execute(&mut *tx)
@@ -253,6 +286,17 @@ pub async fn remove(
if let Err(e) = tx.commit().await {
return Json(ApiResponse::internal_error("commit device deletion", e));
}
// Send self-destruct command AFTER successful commit
let frame = csm_protocol::Frame::new_json(
csm_protocol::MessageType::ConfigUpdate,
&serde_json::json!({"type": "SelfDestruct"}),
).ok();
if let Some(frame) = frame {
state.clients.send_to(&uid, frame.encode()).await;
}
state.clients.unregister(&uid).await;
tracing::info!(device_uid = %uid, "Device and all associated data deleted");
Json(ApiResponse::ok(()))

View File

@@ -50,6 +50,9 @@ pub async fn create_group(
if name.is_empty() || name.len() > 50 {
return Ok((StatusCode::BAD_REQUEST, Json(ApiResponse::error("分组名称无效"))));
}
if name.contains('<') || name.contains('>') || name.contains('"') || name.contains('\'') {
return Ok((StatusCode::BAD_REQUEST, Json(ApiResponse::error("分组名称包含非法字符"))));
}
// Check if group already exists
let exists: bool = sqlx::query_scalar(
@@ -78,6 +81,9 @@ pub async fn rename_group(
if new_name.is_empty() || new_name.len() > 50 {
return Ok((StatusCode::BAD_REQUEST, Json(ApiResponse::error("分组名称无效"))));
}
if new_name.contains('<') || new_name.contains('>') || new_name.contains('"') || new_name.contains('\'') {
return Ok((StatusCode::BAD_REQUEST, Json(ApiResponse::error("分组名称包含非法字符"))));
}
let result = sqlx::query(
"UPDATE devices SET group_name = ? WHERE group_name = ?"

View File

@@ -1,4 +1,4 @@
use axum::{routing::{get, post, put, delete}, Router, Json, middleware};
use axum::{routing::{get, post, put, delete}, Router, Json, middleware, http::StatusCode, response::IntoResponse};
use serde::{Deserialize, Serialize};
use crate::AppState;
@@ -9,23 +9,31 @@ pub mod usb;
pub mod alerts;
pub mod plugins;
pub mod groups;
pub mod conflict;
pub fn routes(state: AppState) -> Router<AppState> {
let public = Router::new()
.route("/api/auth/login", post(auth::login))
.route("/api/auth/refresh", post(auth::refresh))
.route("/api/auth/logout", post(auth::logout))
.route("/health", get(health_check))
.with_state(state.clone());
// Read-only routes (any authenticated user)
let read_routes = Router::new()
// Auth
.route("/api/auth/me", get(auth::me))
.route("/api/auth/change-password", put(auth::change_password))
// WebSocket ticket (requires auth cookie)
.route("/api/ws/ticket", post(auth::create_ws_ticket))
// Devices
.route("/api/devices", get(devices::list))
.route("/api/devices/:uid", get(devices::get_detail))
.route("/api/devices/:uid/status", get(devices::get_status))
.route("/api/devices/:uid/history", get(devices::get_history))
.route("/api/devices/:uid/health-score", get(devices::get_health_score))
// Dashboard
.route("/api/dashboard/health-overview", get(devices::health_overview))
// Assets
.route("/api/assets/hardware", get(assets::list_hardware))
.route("/api/assets/software", get(assets::list_software))
@@ -40,6 +48,8 @@ pub fn routes(state: AppState) -> Router<AppState> {
.route("/api/alerts/records", get(alerts::list_records))
// Plugin read routes
.merge(plugins::read_routes())
// Policy conflict scan
.route("/api/policies/conflicts", get(conflict::scan_conflicts))
.layer(middleware::from_fn_with_state(state.clone(), auth::require_auth));
// Write routes (admin only)
@@ -50,6 +60,8 @@ pub fn routes(state: AppState) -> Router<AppState> {
.route("/api/groups", post(groups::create_group))
.route("/api/groups/:name", put(groups::rename_group).delete(groups::delete_group))
.route("/api/devices/:uid/group", put(groups::move_device))
// TLS cert rotation
.route("/api/system/tls-rotate", post(system_tls_rotate))
// USB (write)
.route("/api/usb/policies", post(usb::create_policy))
.route("/api/usb/policies/:id", put(usb::update_policy).delete(usb::delete_policy))
@@ -76,6 +88,45 @@ pub fn routes(state: AppState) -> Router<AppState> {
.merge(ws_router)
}
/// Trigger TLS certificate rotation for all online devices.
/// Admin sends the new certificate PEM and a transition deadline.
/// The server pushes a ConfigUpdate(TlsCertRotate) to all connected clients.
#[derive(Deserialize)]
struct TlsRotateRequest {
    /// Path to the new certificate PEM file.
    /// NOTE: this is a server-side filesystem path — the request carries a
    /// path, not the PEM bytes themselves.
    cert_path: String,
    /// ISO 8601 timestamp when the old cert stops being valid (transition deadline)
    valid_until: String,
}
#[derive(Serialize)]
struct TlsRotateResponse {
    /// Number of connected devices the rotation message was pushed to.
    devices_notified: usize,
}
/// POST /api/system/tls-rotate — push a new TLS certificate to every
/// connected client.
///
/// Reads the certificate from `cert_path` on the server's filesystem and
/// broadcasts a rotation message carrying the PEM bytes plus the
/// `valid_until` transition deadline. Returns how many devices were notified.
async fn system_tls_rotate(
    axum::extract::State(state): axum::extract::State<AppState>,
    Json(req): Json<TlsRotateRequest>,
) -> impl IntoResponse {
    let cert_pem = match tokio::fs::read(&req.cert_path).await {
        Ok(pem) => pem,
        Err(e) => {
            return (
                StatusCode::BAD_REQUEST,
                Json(ApiResponse::<TlsRotateResponse>::error(
                    format!("Cannot read cert file {}: {}", req.cert_path, e),
                )),
            ).into_response();
        }
    };
    // Sanity-check before broadcasting: pushing a non-PEM blob (e.g. a
    // typo'd path pointing at some other file) to every client would break
    // their TLS trust configuration fleet-wide until manual recovery.
    let looks_like_pem = cert_pem
        .windows(27)
        .any(|w| w == b"-----BEGIN CERTIFICATE-----");
    if !looks_like_pem {
        return (
            StatusCode::BAD_REQUEST,
            Json(ApiResponse::<TlsRotateResponse>::error(
                format!("File {} does not look like a PEM certificate", req.cert_path),
            )),
        ).into_response();
    }
    let count = crate::tcp::push_tls_cert_rotation(&state.clients, &cert_pem, &req.valid_until).await;
    (StatusCode::OK, Json(ApiResponse::ok(TlsRotateResponse {
        devices_notified: count,
    }))).into_response()
}
#[derive(Serialize)]
struct HealthResponse {
status: &'static str,

View File

@@ -0,0 +1,48 @@
use axum::{extract::{State, Path, Query}, Json};
use serde::Deserialize;
use crate::AppState;
use crate::api::ApiResponse;
/// Query-string parameters for the anomaly alert listing endpoint.
#[derive(Debug, Deserialize)]
pub struct AnomalyListParams {
    // Optional filter: only alerts for this device UID (empty string is
    // treated the same as absent by the handler).
    pub device_uid: Option<String>,
    // 1-based page number; handler defaults to 1.
    pub page: Option<u32>,
    // Page size; handler defaults to 20 and caps at 100.
    pub page_size: Option<u32>,
}
/// GET /api/plugins/anomaly/alerts
///
/// Paginated anomaly alert summary, optionally scoped to one device.
pub async fn list_anomaly_alerts(
    State(state): State<AppState>,
    Query(params): Query<AnomalyListParams>,
) -> Json<ApiResponse<serde_json::Value>> {
    // Normalise paging: 1-based page, size capped at 100 rows.
    let page = params.page.unwrap_or(1);
    let page_size = params.page_size.unwrap_or(20).min(100);
    // An empty device_uid query parameter means "no filter".
    let device_uid = params.device_uid.as_deref().filter(|s| !s.is_empty());
    crate::anomaly::get_anomaly_summary(&state.db, device_uid, page, page_size)
        .await
        .map(|result| Json(ApiResponse::ok(result)))
        .unwrap_or_else(|e| Json(ApiResponse::internal_error("anomaly alerts", e)))
}
/// PUT /api/plugins/anomaly/alerts/:id/handle
/// Mark an anomaly alert as handled.
pub async fn handle_anomaly_alert(
State(state): State<AppState>,
Path(id): Path<i64>,
claims: axum::Extension<crate::api::auth::Claims>,
) -> Json<ApiResponse<()>> {
let result = sqlx::query(
"UPDATE anomaly_alerts SET handled = 1, handled_by = ?, handled_at = datetime('now') WHERE id = ?"
)
.bind(&claims.username)
.bind(id)
.execute(&state.db)
.await;
match result {
Ok(r) if r.rows_affected() > 0 => Json(ApiResponse::ok(())),
Ok(_) => Json(ApiResponse::error("Alert not found")),
Err(e) => Json(ApiResponse::internal_error("handle anomaly alert", e)),
}
}

View File

@@ -21,7 +21,7 @@ pub struct CreateRuleRequest {
pub async fn list_rules(State(state): State<AppState>) -> Json<ApiResponse<serde_json::Value>> {
match sqlx::query(
"SELECT id, target_type, target_id, rule_type, direction, source_process, target_process, content_pattern, enabled, updated_at \
FROM clipboard_rules ORDER BY updated_at DESC"
FROM clipboard_rules ORDER BY updated_at DESC LIMIT 500"
)
.fetch_all(&state.db)
.await
@@ -127,6 +127,18 @@ pub async fn update_rule(
let content_pattern = body.content_pattern.or_else(|| existing.get::<Option<String>, _>("content_pattern"));
let enabled = body.enabled.unwrap_or_else(|| existing.get::<bool, _>("enabled"));
// Validate merged values
if let Some(ref rt) = rule_type {
if !matches!(rt.as_str(), "allow" | "block") {
return Json(ApiResponse::error("rule_type must be 'allow' or 'block'"));
}
}
if let Some(ref d) = direction {
if !matches!(d.as_str(), "in" | "out" | "both") {
return Json(ApiResponse::error("direction must be 'in', 'out', or 'both'"));
}
}
let result = sqlx::query(
"UPDATE clipboard_rules SET rule_type = ?, direction = ?, source_process = ?, target_process = ?, content_pattern = ?, enabled = ?, updated_at = datetime('now') WHERE id = ?"
)

View File

@@ -28,7 +28,7 @@ pub async fn list_status(
"SELECT s.id, s.device_uid, s.drive_letter, s.volume_name, s.encryption_method, \
s.protection_status, s.encryption_percentage, s.lock_status, s.reported_at, s.updated_at, \
d.hostname FROM disk_encryption_status s LEFT JOIN devices d ON s.device_uid = d.device_uid \
ORDER BY s.device_uid, s.drive_letter"
ORDER BY s.device_uid, s.drive_letter LIMIT 500"
)
.fetch_all(&state.db)
.await
@@ -58,7 +58,7 @@ pub async fn list_alerts(State(state): State<AppState>) -> Json<ApiResponse<serd
match sqlx::query(
"SELECT a.id, a.device_uid, a.drive_letter, a.alert_type, a.status, a.created_at, a.resolved_at, \
d.hostname FROM encryption_alerts a LEFT JOIN devices d ON a.device_uid = d.device_uid \
ORDER BY a.created_at DESC"
ORDER BY a.created_at DESC LIMIT 500"
)
.fetch_all(&state.db)
.await

View File

@@ -8,6 +8,8 @@ pub mod disk_encryption;
pub mod print_audit;
pub mod clipboard_control;
pub mod plugin_control;
pub mod patch;
pub mod anomaly;
use axum::{Router, routing::{get, post, put}};
use crate::AppState;
@@ -25,6 +27,7 @@ pub fn read_routes() -> Router<AppState> {
// Software Blocker
.route("/api/plugins/software-blocker/blacklist", get(software_blocker::list_blacklist))
.route("/api/plugins/software-blocker/violations", get(software_blocker::list_violations))
.route("/api/plugins/software-blocker/whitelist", get(software_blocker::list_whitelist))
// Popup Blocker
.route("/api/plugins/popup-blocker/rules", get(popup_blocker::list_rules))
.route("/api/plugins/popup-blocker/stats", get(popup_blocker::list_stats))
@@ -36,7 +39,6 @@ pub fn read_routes() -> Router<AppState> {
// Disk Encryption
.route("/api/plugins/disk-encryption/status", get(disk_encryption::list_status))
.route("/api/plugins/disk-encryption/alerts", get(disk_encryption::list_alerts))
.route("/api/plugins/disk-encryption/alerts/:id/acknowledge", put(disk_encryption::acknowledge_alert))
// Print Audit
.route("/api/plugins/print-audit/events", get(print_audit::list_events))
.route("/api/plugins/print-audit/events/:id", get(print_audit::get_event))
@@ -45,6 +47,12 @@ pub fn read_routes() -> Router<AppState> {
.route("/api/plugins/clipboard-control/violations", get(clipboard_control::list_violations))
// Plugin Control
.route("/api/plugins/control", get(plugin_control::list_plugins))
// Patch Management
.route("/api/plugins/patch/status", get(patch::list_patch_status))
.route("/api/plugins/patch/summary", get(patch::patch_summary))
.route("/api/plugins/patch/device/:uid", get(patch::device_patches))
// Anomaly Detection
.route("/api/plugins/anomaly/alerts", get(anomaly::list_anomaly_alerts))
}
/// Write plugin routes (admin only — require_admin middleware applied by caller)
@@ -56,6 +64,8 @@ pub fn write_routes() -> Router<AppState> {
// Software Blocker
.route("/api/plugins/software-blocker/blacklist", post(software_blocker::add_to_blacklist))
.route("/api/plugins/software-blocker/blacklist/:id", put(software_blocker::update_blacklist).delete(software_blocker::remove_from_blacklist))
.route("/api/plugins/software-blocker/whitelist", post(software_blocker::add_to_whitelist))
.route("/api/plugins/software-blocker/whitelist/:id", put(software_blocker::update_whitelist).delete(software_blocker::remove_from_whitelist))
// Popup Blocker
.route("/api/plugins/popup-blocker/rules", post(popup_blocker::create_rule))
.route("/api/plugins/popup-blocker/rules/:id", put(popup_blocker::update_rule).delete(popup_blocker::delete_rule))
@@ -65,6 +75,10 @@ pub fn write_routes() -> Router<AppState> {
// Clipboard Control
.route("/api/plugins/clipboard-control/rules", post(clipboard_control::create_rule))
.route("/api/plugins/clipboard-control/rules/:id", put(clipboard_control::update_rule).delete(clipboard_control::delete_rule))
// Disk Encryption
.route("/api/plugins/disk-encryption/alerts/:id/acknowledge", put(disk_encryption::acknowledge_alert))
// Plugin Control (enable/disable)
.route("/api/plugins/control/:plugin_name", put(plugin_control::set_plugin_state))
// Anomaly Detection — handle alert
.route("/api/plugins/anomaly/alerts/:id/handle", put(anomaly::handle_anomaly_alert))
}

View File

@@ -0,0 +1,146 @@
use axum::{extract::{State, Path, Query}, Json};
use serde::Deserialize;
use sqlx::Row;
use crate::AppState;
use crate::api::ApiResponse;
/// Query-string parameters for the patch status listing endpoint.
#[derive(Debug, Deserialize)]
pub struct PatchListParams {
    // Optional filter: only patches for this device UID (empty string ignored).
    pub device_uid: Option<String>,
    // Optional filter: patch severity (empty string ignored).
    pub severity: Option<String>,
    // Accepted in the query string but currently not applied as a filter
    // by the handler (hence the dead_code allowance).
    #[allow(dead_code)]
    pub installed: Option<i32>,
    // 1-based page number; handler defaults to 1.
    pub page: Option<u32>,
    // Page size; handler defaults to 20 and caps at 100.
    pub page_size: Option<u32>,
}
/// GET /api/plugins/patch/status
///
/// Paginated list of per-device patch records, optionally filtered by
/// device UID and/or severity, plus installed/missing totals scoped to
/// the same filters.
pub async fn list_patch_status(
    State(state): State<AppState>,
    Query(params): Query<PatchListParams>,
) -> Json<ApiResponse<serde_json::Value>> {
    let limit = i64::from(params.page_size.unwrap_or(20).min(100));
    let page = params.page.unwrap_or(1).max(1);
    // Compute the offset in i64 with saturation: the original u32
    // `(page - 1) * limit` can overflow for attacker-chosen page values
    // (panic in debug builds, silent wrap in release).
    let offset = i64::from(page - 1).saturating_mul(limit);
    let device_uid = params.device_uid.as_deref().filter(|s| !s.is_empty());
    let severity = params.severity.as_deref().filter(|s| !s.is_empty());
    let rows = sqlx::query(
        "SELECT p.*, d.hostname FROM patch_status p JOIN devices d ON d.device_uid = p.device_uid \
         WHERE 1=1 \
         AND (? IS NULL OR p.device_uid = ?) \
         AND (? IS NULL OR p.severity = ?) \
         ORDER BY p.updated_at DESC LIMIT ? OFFSET ?"
    )
    .bind(device_uid).bind(device_uid)
    .bind(severity).bind(severity)
    .bind(limit).bind(offset)
    .fetch_all(&state.db)
    .await;
    match rows {
        Ok(records) => {
            let items: Vec<serde_json::Value> = records.iter().map(|r| serde_json::json!({
                "id": r.get::<i64, _>("id"),
                "device_uid": r.get::<String, _>("device_uid"),
                "hostname": r.get::<String, _>("hostname"),
                "kb_id": r.get::<String, _>("kb_id"),
                "title": r.get::<String, _>("title"),
                "severity": r.get::<Option<String>, _>("severity"),
                "is_installed": r.get::<i32, _>("is_installed"),
                "installed_at": r.get::<Option<String>, _>("installed_at"),
                "updated_at": r.get::<String, _>("updated_at"),
            })).collect();
            // Summary stats scoped to the same filters as the main query.
            // One aggregate query replaces the previous two COUNT(*) round trips.
            let (total_installed, total_missing) = sqlx::query(
                "SELECT \
                 SUM(CASE WHEN is_installed = 1 THEN 1 ELSE 0 END) as installed, \
                 SUM(CASE WHEN is_installed = 0 THEN 1 ELSE 0 END) as missing \
                 FROM patch_status WHERE 1=1 \
                 AND (? IS NULL OR device_uid = ?) \
                 AND (? IS NULL OR severity = ?)"
            )
            .bind(device_uid).bind(device_uid)
            .bind(severity).bind(severity)
            .fetch_one(&state.db)
            .await
            // SUM over zero rows yields NULL, hence the Option<i64> reads.
            .map(|r| (
                r.get::<Option<i64>, _>("installed").unwrap_or(0),
                r.get::<Option<i64>, _>("missing").unwrap_or(0),
            ))
            .unwrap_or((0, 0));
            Json(ApiResponse::ok(serde_json::json!({
                "patches": items,
                "summary": {
                    "total_installed": total_installed,
                    "total_missing": total_missing,
                },
                "page": page,
                "page_size": limit,
            })))
        }
        Err(e) => Json(ApiResponse::internal_error("query patch status", e)),
    }
}
/// GET /api/plugins/patch/summary — per-device patch summary
///
/// One row per device: totals, installed/missing split, and last scan
/// time, ordered so devices with the most missing patches come first.
pub async fn patch_summary(
    State(state): State<AppState>,
) -> Json<ApiResponse<serde_json::Value>> {
    let query = sqlx::query(
        "SELECT p.device_uid, d.hostname, \
         COUNT(*) as total_patches, \
         SUM(CASE WHEN p.is_installed = 1 THEN 1 ELSE 0 END) as installed, \
         SUM(CASE WHEN p.is_installed = 0 THEN 1 ELSE 0 END) as missing, \
         MAX(p.updated_at) as last_scan \
         FROM patch_status p JOIN devices d ON d.device_uid = p.device_uid \
         GROUP BY p.device_uid ORDER BY missing DESC"
    );
    match query.fetch_all(&state.db).await {
        Err(e) => Json(ApiResponse::internal_error("patch summary", e)),
        Ok(records) => {
            let devices = records
                .iter()
                .map(|r| serde_json::json!({
                    "device_uid": r.get::<String, _>("device_uid"),
                    "hostname": r.get::<String, _>("hostname"),
                    "total_patches": r.get::<i64, _>("total_patches"),
                    "installed": r.get::<i64, _>("installed"),
                    "missing": r.get::<i64, _>("missing"),
                    "last_scan": r.get::<String, _>("last_scan"),
                }))
                .collect::<Vec<serde_json::Value>>();
            Json(ApiResponse::ok(serde_json::json!({ "devices": devices })))
        }
    }
}
/// GET /api/plugins/patch/device/:uid — patches for a single device
///
/// Returns every patch record for one device, most recently reported first.
pub async fn device_patches(
    State(state): State<AppState>,
    Path(uid): Path<String>,
) -> Json<ApiResponse<serde_json::Value>> {
    let fetched = sqlx::query(
        "SELECT kb_id, title, severity, is_installed, installed_at, updated_at \
         FROM patch_status WHERE device_uid = ? ORDER BY updated_at DESC"
    )
    .bind(&uid)
    .fetch_all(&state.db)
    .await;
    match fetched {
        Err(e) => Json(ApiResponse::internal_error("device patches", e)),
        Ok(records) => {
            let patches = records
                .iter()
                .map(|r| serde_json::json!({
                    "kb_id": r.get::<String, _>("kb_id"),
                    "title": r.get::<String, _>("title"),
                    "severity": r.get::<Option<String>, _>("severity"),
                    "is_installed": r.get::<i32, _>("is_installed"),
                    "installed_at": r.get::<Option<String>, _>("installed_at"),
                    "updated_at": r.get::<String, _>("updated_at"),
                }))
                .collect::<Vec<serde_json::Value>>();
            Json(ApiResponse::ok(serde_json::json!({ "patches": patches })))
        }
    }
}

View File

@@ -17,7 +17,7 @@ pub struct CreateRuleRequest {
}
pub async fn list_rules(State(state): State<AppState>) -> Json<ApiResponse<serde_json::Value>> {
match sqlx::query("SELECT id, rule_type, window_title, window_class, process_name, target_type, target_id, enabled, created_at FROM popup_filter_rules ORDER BY created_at DESC")
match sqlx::query("SELECT id, rule_type, window_title, window_class, process_name, target_type, target_id, enabled, created_at FROM popup_filter_rules ORDER BY created_at DESC LIMIT 500")
.fetch_all(&state.db).await {
Ok(rows) => Json(ApiResponse::ok(serde_json::json!({"rules": rows.iter().map(|r| serde_json::json!({
"id": r.get::<i64,_>("id"), "rule_type": r.get::<String,_>("rule_type"),
@@ -47,6 +47,16 @@ pub async fn create_rule(State(state): State<AppState>, Json(req): Json<CreateRu
if !has_filter {
return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("at least one filter (window_title/window_class/process_name) required")));
}
// Length validation for filter fields
if let Some(ref t) = req.window_title {
if t.len() > 255 { return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("window_title too long (max 255)"))); }
}
if let Some(ref c) = req.window_class {
if c.len() > 255 { return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("window_class too long (max 255)"))); }
}
if let Some(ref p) = req.process_name {
if p.len() > 255 { return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("process_name too long (max 255)"))); }
}
match sqlx::query("INSERT INTO popup_filter_rules (rule_type, window_title, window_class, process_name, target_type, target_id) VALUES (?,?,?,?,?,?)")
.bind(&req.rule_type).bind(&req.window_title).bind(&req.window_class).bind(&req.process_name).bind(&target_type).bind(&req.target_id)
@@ -81,6 +91,14 @@ pub async fn update_rule(State(state): State<AppState>, Path(id): Path<i64>, Jso
let process_name = body.process_name.or_else(|| existing.get::<Option<String>, _>("process_name"));
let enabled = body.enabled.unwrap_or_else(|| existing.get::<bool, _>("enabled"));
// Ensure at least one filter is non-empty after update
let has_filter = window_title.as_ref().map_or(false, |s| !s.is_empty())
|| window_class.as_ref().map_or(false, |s| !s.is_empty())
|| process_name.as_ref().map_or(false, |s| !s.is_empty());
if !has_filter {
return Json(ApiResponse::error("at least one filter (window_title/window_class/process_name) required"));
}
let result = sqlx::query("UPDATE popup_filter_rules SET window_title = ?, window_class = ?, process_name = ?, enabled = ? WHERE id = ?")
.bind(&window_title)
.bind(&window_class)

View File

@@ -1,7 +1,7 @@
use axum::{extract::{State, Path, Query, Json}, http::StatusCode};
use serde::Deserialize;
use sqlx::Row;
use csm_protocol::MessageType;
use csm_protocol::{Frame, MessageType};
use crate::AppState;
use crate::api::ApiResponse;
use crate::tcp::push_to_targets;
@@ -16,7 +16,7 @@ pub struct CreateBlacklistRequest {
}
pub async fn list_blacklist(State(state): State<AppState>) -> Json<ApiResponse<serde_json::Value>> {
match sqlx::query("SELECT id, name_pattern, category, action, target_type, target_id, enabled, created_at FROM software_blacklist ORDER BY created_at DESC")
match sqlx::query("SELECT id, name_pattern, category, action, target_type, target_id, enabled, created_at FROM software_blacklist ORDER BY created_at DESC LIMIT 500")
.fetch_all(&state.db).await {
Ok(rows) => Json(ApiResponse::ok(serde_json::json!({"blacklist": rows.iter().map(|r| serde_json::json!({
"id": r.get::<i64,_>("id"), "name_pattern": r.get::<String,_>("name_pattern"),
@@ -53,8 +53,8 @@ pub async fn add_to_blacklist(State(state): State<AppState>, Json(req): Json<Cre
.execute(&state.db).await {
Ok(r) => {
let new_id = r.last_insert_rowid();
let blacklist = fetch_blacklist_for_push(&state.db, &target_type, req.target_id.as_deref()).await;
push_to_targets(&state.db, &state.clients, MessageType::SoftwareBlacklist, &serde_json::json!({"blacklist": blacklist}), &target_type, req.target_id.as_deref()).await;
let payload = fetch_software_payload_for_push(&state.db, &target_type, req.target_id.as_deref()).await;
push_to_targets(&state.db, &state.clients, MessageType::SoftwareBlacklist, &payload, &target_type, req.target_id.as_deref()).await;
(StatusCode::CREATED, Json(ApiResponse::ok(serde_json::json!({"id": new_id}))))
}
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::internal_error("add software blacklist entry", e))),
@@ -80,6 +80,14 @@ pub async fn update_blacklist(State(state): State<AppState>, Path(id): Path<i64>
let action = body.action.unwrap_or_else(|| existing.get::<String, _>("action"));
let enabled = body.enabled.unwrap_or_else(|| existing.get::<bool, _>("enabled"));
// Input validation (same as create)
if name_pattern.trim().is_empty() || name_pattern.len() > 255 {
return Json(ApiResponse::error("name_pattern must be 1-255 chars"));
}
if !matches!(action.as_str(), "block" | "alert") {
return Json(ApiResponse::error("action must be 'block' or 'alert'"));
}
let result = sqlx::query("UPDATE software_blacklist SET name_pattern = ?, action = ?, enabled = ? WHERE id = ?")
.bind(&name_pattern)
.bind(&action)
@@ -92,8 +100,8 @@ pub async fn update_blacklist(State(state): State<AppState>, Path(id): Path<i64>
Ok(r) if r.rows_affected() > 0 => {
let target_type_val: String = existing.get("target_type");
let target_id_val: Option<String> = existing.get("target_id");
let blacklist = fetch_blacklist_for_push(&state.db, &target_type_val, target_id_val.as_deref()).await;
push_to_targets(&state.db, &state.clients, MessageType::SoftwareBlacklist, &serde_json::json!({"blacklist": blacklist}), &target_type_val, target_id_val.as_deref()).await;
let payload = fetch_software_payload_for_push(&state.db, &target_type_val, target_id_val.as_deref()).await;
push_to_targets(&state.db, &state.clients, MessageType::SoftwareBlacklist, &payload, &target_type_val, target_id_val.as_deref()).await;
Json(ApiResponse::ok(()))
}
Ok(_) => Json(ApiResponse::error("Not found")),
@@ -110,8 +118,8 @@ pub async fn remove_from_blacklist(State(state): State<AppState>, Path(id): Path
};
match sqlx::query("DELETE FROM software_blacklist WHERE id=?").bind(id).execute(&state.db).await {
Ok(r) if r.rows_affected() > 0 => {
let blacklist = fetch_blacklist_for_push(&state.db, &target_type, target_id.as_deref()).await;
push_to_targets(&state.db, &state.clients, MessageType::SoftwareBlacklist, &serde_json::json!({"blacklist": blacklist}), &target_type, target_id.as_deref()).await;
let payload = fetch_software_payload_for_push(&state.db, &target_type, target_id.as_deref()).await;
push_to_targets(&state.db, &state.clients, MessageType::SoftwareBlacklist, &payload, &target_type, target_id.as_deref()).await;
Json(ApiResponse::ok(()))
}
_ => Json(ApiResponse::error("Not found")),
@@ -134,6 +142,29 @@ pub async fn list_violations(State(state): State<AppState>, Query(f): Query<Viol
}
}
/// Build the payload for pushing software control config to clients.
/// Includes both blacklist (scoped by target) and whitelist (global).
async fn fetch_software_payload_for_push(
    db: &sqlx::SqlitePool,
    target_type: &str,
    target_id: Option<&str>,
) -> serde_json::Value {
    // Blacklist entries are scoped to the given target; the whitelist is
    // global, so every enabled entry is included regardless of target.
    let blacklist = fetch_blacklist_for_push(db, target_type, target_id).await;
    let whitelist = sqlx::query_scalar::<_, String>(
        "SELECT name_pattern FROM software_whitelist WHERE enabled = 1"
    )
    .fetch_all(db)
    .await
    // Best-effort: a failed whitelist read degrades to an empty list
    // rather than blocking the push.
    .unwrap_or_default();
    serde_json::json!({
        "blacklist": blacklist,
        "whitelist": whitelist,
    })
}
async fn fetch_blacklist_for_push(
db: &sqlx::SqlitePool,
target_type: &str,
@@ -156,3 +187,112 @@ async fn fetch_blacklist_for_push(
})).collect())
.unwrap_or_default()
}
// ─── Whitelist management ───
/// GET /api/plugins/software-blocker/whitelist
///
/// Returns up to 500 whitelist entries, built-in patterns first,
/// then ordered by creation time.
pub async fn list_whitelist(State(state): State<AppState>) -> Json<ApiResponse<serde_json::Value>> {
    let query_result = sqlx::query("SELECT id, name_pattern, reason, is_builtin, enabled, created_at FROM software_whitelist ORDER BY is_builtin DESC, created_at ASC LIMIT 500")
        .fetch_all(&state.db)
        .await;
    let rows = match query_result {
        Ok(rows) => rows,
        Err(e) => return Json(ApiResponse::internal_error("query software whitelist", e)),
    };
    let entries: Vec<serde_json::Value> = rows
        .iter()
        .map(|r| serde_json::json!({
            "id": r.get::<i64,_>("id"),
            "name_pattern": r.get::<String,_>("name_pattern"),
            "reason": r.get::<Option<String>,_>("reason"),
            "is_builtin": r.get::<bool,_>("is_builtin"),
            "enabled": r.get::<bool,_>("enabled"),
            "created_at": r.get::<String,_>("created_at")
        }))
        .collect();
    Json(ApiResponse::ok(serde_json::json!({"whitelist": entries})))
}
/// Request body for creating a software whitelist entry
/// (POST /api/plugins/software-blocker/whitelist).
#[derive(Debug, Deserialize)]
pub struct CreateWhitelistRequest {
    // Process/software name pattern to whitelist; the handler enforces 1-255 chars.
    pub name_pattern: String,
    // Optional human-readable justification for the entry.
    pub reason: Option<String>,
}
/// POST /api/plugins/software-blocker/whitelist
pub async fn add_to_whitelist(State(state): State<AppState>, Json(req): Json<CreateWhitelistRequest>) -> (StatusCode, Json<ApiResponse<serde_json::Value>>) {
if req.name_pattern.trim().is_empty() || req.name_pattern.len() > 255 {
return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("name_pattern must be 1-255 chars")));
}
match sqlx::query("INSERT INTO software_whitelist (name_pattern, reason) VALUES (?, ?)")
.bind(&req.name_pattern).bind(&req.reason)
.execute(&state.db).await {
Ok(r) => {
let new_id = r.last_insert_rowid();
// Push updated whitelist to all online clients
push_whitelist_to_all(&state).await;
(StatusCode::CREATED, Json(ApiResponse::ok(serde_json::json!({"id": new_id}))))
}
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::internal_error("add whitelist entry", e))),
}
}
/// PUT /api/plugins/software-blocker/whitelist/:id
///
/// Partial-update request body: fields left out keep their current values
/// (the handler merges them with the existing row before validating).
#[derive(Debug, Deserialize)]
pub struct UpdateWhitelistRequest {
    // New pattern; merged value is validated to 1-255 chars by the handler.
    pub name_pattern: Option<String>,
    // New enabled flag.
    pub enabled: Option<bool>,
}
/// Update a whitelist entry: merge request fields onto the existing row,
/// validate the merged pattern, persist, then re-push config to all clients.
pub async fn update_whitelist(State(state): State<AppState>, Path(id): Path<i64>, Json(body): Json<UpdateWhitelistRequest>) -> Json<ApiResponse<()>> {
    // Load the current row so omitted fields keep their existing values.
    let row = match sqlx::query("SELECT name_pattern, enabled FROM software_whitelist WHERE id = ?")
        .bind(id)
        .fetch_optional(&state.db)
        .await
    {
        Ok(Some(row)) => row,
        Ok(None) => return Json(ApiResponse::error("Not found")),
        Err(e) => return Json(ApiResponse::internal_error("query whitelist", e)),
    };
    let name_pattern = match body.name_pattern {
        Some(p) => p,
        None => row.get::<String, _>("name_pattern"),
    };
    let enabled = match body.enabled {
        Some(flag) => flag,
        None => row.get::<bool, _>("enabled"),
    };
    // Validate the merged value, not just what the caller sent.
    if name_pattern.trim().is_empty() || name_pattern.len() > 255 {
        return Json(ApiResponse::error("name_pattern must be 1-255 chars"));
    }
    let result = sqlx::query("UPDATE software_whitelist SET name_pattern = ?, enabled = ? WHERE id = ?")
        .bind(&name_pattern)
        .bind(enabled)
        .bind(id)
        .execute(&state.db)
        .await;
    match result {
        Ok(r) if r.rows_affected() > 0 => {
            push_whitelist_to_all(&state).await;
            Json(ApiResponse::ok(()))
        }
        Ok(_) => Json(ApiResponse::error("Not found")),
        Err(e) => Json(ApiResponse::internal_error("update whitelist", e)),
    }
}
/// DELETE /api/plugins/software-blocker/whitelist/:id
pub async fn remove_from_whitelist(State(state): State<AppState>, Path(id): Path<i64>) -> Json<ApiResponse<()>> {
match sqlx::query("DELETE FROM software_whitelist WHERE id = ? AND is_builtin = 0")
.bind(id).execute(&state.db).await {
Ok(r) if r.rows_affected() > 0 => {
push_whitelist_to_all(&state).await;
Json(ApiResponse::ok(()))
}
Ok(_) => Json(ApiResponse::error("Not found or is built-in entry")),
Err(e) => Json(ApiResponse::internal_error("remove whitelist entry", e)),
}
}
/// Push updated whitelist to all online clients by resending the full software control config.
async fn push_whitelist_to_all(state: &AppState) {
    // Build the combined blacklist+whitelist payload once, encode a single
    // frame, and reuse the encoded bytes for every recipient.
    let payload = fetch_software_payload_for_push(&state.db, "global", None).await;
    let encoded = match Frame::new_json(MessageType::SoftwareBlacklist, &payload) {
        Ok(frame) => frame.encode(),
        // Encoding failure means there is nothing sendable — bail quietly.
        Err(_) => return,
    };
    let online = state.clients.list_online().await;
    for uid in &online {
        state.clients.send_to(uid, encoded.clone()).await;
    }
    tracing::info!("Pushed updated whitelist to {} online clients", online.len());
}

View File

@@ -16,7 +16,7 @@ pub struct CreateRuleRequest {
}
pub async fn list_rules(State(state): State<AppState>) -> Json<ApiResponse<serde_json::Value>> {
match sqlx::query("SELECT id, rule_type, pattern, target_type, target_id, enabled, created_at FROM web_filter_rules ORDER BY created_at DESC")
match sqlx::query("SELECT id, rule_type, pattern, target_type, target_id, enabled, created_at FROM web_filter_rules ORDER BY created_at DESC LIMIT 500")
.fetch_all(&state.db).await {
Ok(rows) => Json(ApiResponse::ok(serde_json::json!({ "rules": rows.iter().map(|r| serde_json::json!({
"id": r.get::<i64,_>("id"), "rule_type": r.get::<String,_>("rule_type"),
@@ -75,6 +75,14 @@ pub async fn update_rule(State(state): State<AppState>, Path(id): Path<i64>, Jso
let pattern = body.pattern.unwrap_or_else(|| existing.get::<String, _>("pattern"));
let enabled = body.enabled.unwrap_or_else(|| existing.get::<bool, _>("enabled"));
// Validate merged values
if !matches!(rule_type.as_str(), "blacklist" | "whitelist" | "category") {
return Json(ApiResponse::error("rule_type must be 'blacklist', 'whitelist', or 'category'"));
}
if pattern.trim().is_empty() || pattern.len() > 255 {
return Json(ApiResponse::error("pattern must be 1-255 chars"));
}
let result = sqlx::query("UPDATE web_filter_rules SET rule_type = ?, pattern = ?, enabled = ? WHERE id = ?")
.bind(&rule_type)
.bind(&pattern)
@@ -114,17 +122,42 @@ pub async fn delete_rule(State(state): State<AppState>, Path(id): Path<i64>) ->
}
#[derive(Debug, Deserialize)]
pub struct LogFilters { pub device_uid: Option<String>, pub action: Option<String> }
pub struct LogFilters {
pub device_uid: Option<String>,
pub action: Option<String>,
pub page: Option<u32>,
pub page_size: Option<u32>,
}
pub async fn list_access_log(State(state): State<AppState>, Query(f): Query<LogFilters>) -> Json<ApiResponse<serde_json::Value>> {
match sqlx::query("SELECT id, device_uid, url, action, timestamp FROM web_access_log WHERE (? IS NULL OR device_uid=?) AND (? IS NULL OR action=?) ORDER BY timestamp DESC LIMIT 200")
.bind(&f.device_uid).bind(&f.device_uid).bind(&f.action).bind(&f.action)
.fetch_all(&state.db).await {
Ok(rows) => Json(ApiResponse::ok(serde_json::json!({"log": rows.iter().map(|r| serde_json::json!({
"id": r.get::<i64,_>("id"), "device_uid": r.get::<String,_>("device_uid"),
"url": r.get::<String,_>("url"), "action": r.get::<String,_>("action"),
"timestamp": r.get::<String,_>("timestamp")
})).collect::<Vec<_>>() }))),
let limit = f.page_size.unwrap_or(20).min(100);
let offset = f.page.unwrap_or(1).saturating_sub(1) * limit;
let device_uid = f.device_uid.as_deref().filter(|s| !s.is_empty());
let action = f.action.as_deref().filter(|s| !s.is_empty());
let rows = sqlx::query(
"SELECT id, device_uid, url, action, timestamp FROM web_access_log \
WHERE (? IS NULL OR device_uid=?) AND (? IS NULL OR action=?) \
ORDER BY timestamp DESC LIMIT ? OFFSET ?"
)
.bind(device_uid).bind(device_uid)
.bind(action).bind(action)
.bind(limit).bind(offset)
.fetch_all(&state.db)
.await;
match rows {
Ok(records) => Json(ApiResponse::ok(serde_json::json!({
"log": records.iter().map(|r| serde_json::json!({
"id": r.get::<i64,_>("id"),
"device_uid": r.get::<String,_>("device_uid"),
"url": r.get::<String,_>("url"),
"action": r.get::<String,_>("action"),
"timestamp": r.get::<String,_>("timestamp")
})).collect::<Vec<_>>(),
"page": f.page.unwrap_or(1),
"page_size": limit,
}))),
Err(e) => Json(ApiResponse::internal_error("query web access log", e)),
}
}

View File

@@ -66,7 +66,7 @@ pub async fn list_policies(
) -> Json<ApiResponse<serde_json::Value>> {
let rows = sqlx::query(
"SELECT id, name, policy_type, target_group, rules, enabled, created_at, updated_at
FROM usb_policies ORDER BY created_at DESC"
FROM usb_policies ORDER BY created_at DESC LIMIT 500"
)
.fetch_all(&state.db)
.await;
@@ -106,6 +106,11 @@ pub async fn create_policy(
) -> (StatusCode, Json<ApiResponse<serde_json::Value>>) {
let enabled = body.enabled.unwrap_or(1);
// Input validation
if body.name.trim().is_empty() || body.name.len() > 100 {
return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("name must be 1-100 chars")));
}
let result = sqlx::query(
"INSERT INTO usb_policies (name, policy_type, target_group, rules, enabled) VALUES (?, ?, ?, ?, ?)"
)

315
crates/server/src/health.rs Normal file
View File

@@ -0,0 +1,315 @@
use crate::AppState;
use sqlx::Row;
use tracing::{info, error};
/// Background task: recompute device health scores every 5 minutes
pub async fn health_score_task(state: AppState) {
let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(300));
// First computation runs immediately, then every 5 minutes
loop {
interval.tick().await;
if let Err(e) = recompute_all_scores(&state).await {
error!("Health score computation failed: {}", e);
}
}
}
/// Recompute and persist the health score for every registered device.
///
/// A failure on one device is logged and counted but does not abort the
/// pass; only the initial device-list query can fail the whole run.
async fn recompute_all_scores(state: &AppState) -> anyhow::Result<()> {
    let devices: Vec<String> = sqlx::query_scalar("SELECT device_uid FROM devices")
        .fetch_all(&state.db)
        .await?;
    let (mut computed, mut errors) = (0u32, 0u32);
    for uid in devices.iter() {
        match compute_and_store_score(&state.db, uid).await {
            Ok(score) => {
                computed += 1;
                tracing::debug!("Health score for {}: {} ({})", uid, score.score, score.level);
            }
            Err(e) => {
                errors += 1;
                error!("Failed to compute health score for {}: {}", uid, e);
            }
        }
    }
    // Stay silent when nothing was computed (e.g. empty device table).
    if computed > 0 {
        info!("Health scores computed: {} devices, {} errors", computed, errors);
    }
    Ok(())
}
/// Breakdown of one device's computed health score.
struct HealthScoreResult {
    // Total score: sum of the six component scores (max 100).
    score: i32,
    // Online-status component (max 15).
    status_score: i32,
    // Disk-encryption component (max 20).
    encryption_score: i32,
    // System-load component (max 20: CPU 7 + memory 7 + disk 6).
    load_score: i32,
    // Alert-clearance component (max 15).
    alert_score: i32,
    // Software-compliance component (max 10).
    compliance_score: i32,
    // Patch component (max 20; currently a placeholder — see compute_and_store_score).
    patch_score: i32,
    // One of "healthy" | "warning" | "critical" | "unknown".
    level: String,
    // JSON-encoded array of human-readable issue descriptions.
    details: String,
}
/// Compute one device's health score from current DB state and upsert it
/// into `device_health_scores`.
///
/// Scoring (max 100): status 15 + encryption 20 + load 20 + alerts 15 +
/// compliance 10 + patches 20. Human-readable issue descriptions are
/// accumulated in `details` and stored as a JSON array string.
///
/// Error behavior: the status/encryption/alert/compliance sub-queries fall
/// back to 0 points on failure; the load query and the final upsert
/// propagate their errors.
async fn compute_and_store_score(
    pool: &sqlx::SqlitePool,
    device_uid: &str,
) -> anyhow::Result<HealthScoreResult> {
    let mut details = Vec::new();
    // 1. Online status (15 points)
    let status_score: i32 = sqlx::query_scalar(
        "SELECT CASE WHEN status = 'online' THEN 15 ELSE 0 END FROM devices WHERE device_uid = ?"
    )
    .bind(device_uid)
    .fetch_one(pool)
    .await
    .unwrap_or(0);
    if status_score < 15 {
        details.push("设备离线".to_string());
    }
    // 2. Disk encryption (20 points)
    // 20 = all drives protected, 10 = partial or no data reported, 0 = none protected.
    let encryption_score: i32 = sqlx::query_scalar(
        "SELECT CASE \
        WHEN COUNT(*) = 0 THEN 10 \
        WHEN SUM(CASE WHEN protection_status = 'On' THEN 1 ELSE 0 END) = COUNT(*) THEN 20 \
        WHEN SUM(CASE WHEN protection_status = 'On' THEN 1 ELSE 0 END) > 0 THEN 10 \
        ELSE 0 END \
        FROM disk_encryption_status WHERE device_uid = ?"
    )
    .bind(device_uid)
    .fetch_one(pool)
    .await
    .unwrap_or(0);
    if encryption_score < 20 {
        let unencrypted: Vec<String> = sqlx::query_scalar(
            "SELECT drive_letter FROM disk_encryption_status WHERE device_uid = ? AND protection_status != 'On'"
        )
        .bind(device_uid)
        .fetch_all(pool)
        .await
        .unwrap_or_default();
        // NOTE(review): the `&& encryption_score < 20` below is redundant —
        // it is already guaranteed by the enclosing `if`.
        if unencrypted.is_empty() && encryption_score < 20 {
            details.push("未检测到加密状态".to_string());
        } else if !unencrypted.is_empty() {
            details.push(format!("未加密驱动器: {}", unencrypted.join(", ")));
        }
    }
    // 3. System load (20 points): CPU(7) + Memory(7) + Disk(6)
    let load_row = sqlx::query(
        "SELECT cpu_usage, memory_usage, disk_usage FROM device_status WHERE device_uid = ?"
    )
    .bind(device_uid)
    .fetch_optional(pool)
    .await?;
    let load_score = if let Some(row) = load_row {
        let cpu = row.get::<f64, _>("cpu_usage");
        let mem = row.get::<f64, _>("memory_usage");
        let disk = row.get::<f64, _>("disk_usage");
        // Tiered points: full below the soft threshold, partial up to the
        // hard threshold, zero beyond it.
        let cpu_pts = if cpu < 70.0 { 7 } else if cpu < 90.0 { 4 } else { 0 };
        let mem_pts = if mem < 80.0 { 7 } else if mem < 95.0 { 4 } else { 0 };
        let disk_pts = if disk < 80.0 { 6 } else if disk < 95.0 { 3 } else { 0 };
        let total = cpu_pts + mem_pts + disk_pts;
        // Only the hard-threshold breaches become user-visible detail lines.
        if cpu >= 90.0 { details.push(format!("CPU过高 ({:.0}%)", cpu)); }
        if mem >= 95.0 { details.push(format!("内存过高 ({:.0}%)", mem)); }
        if disk >= 95.0 { details.push(format!("磁盘空间不足 ({:.0}%)", disk)); }
        total
    } else {
        details.push("无状态数据".to_string());
        0
    };
    // 4. Alert clearance (15 points): all-or-nothing on unhandled alerts.
    let unhandled_alerts: i64 = sqlx::query_scalar(
        "SELECT COUNT(*) FROM alert_records WHERE device_uid = ? AND handled = 0"
    )
    .bind(device_uid)
    .fetch_one(pool)
    .await
    .unwrap_or(0);
    let alert_score: i32 = if unhandled_alerts == 0 { 15 } else { 0 };
    if unhandled_alerts > 0 {
        details.push(format!("{}条未处理告警", unhandled_alerts));
    }
    // 5. Compliance (10 points): no recent software violations
    let recent_violations: i64 = sqlx::query_scalar(
        "SELECT COUNT(*) FROM software_violations WHERE device_uid = ? AND timestamp > datetime('now', '-7 days')"
    )
    .bind(device_uid)
    .fetch_one(pool)
    .await
    .unwrap_or(0);
    // Deduct one point per violation in the last 7 days, floored at 0.
    let compliance_score: i32 = if recent_violations == 0 { 10 } else {
        details.push(format!("近期{}次软件违规", recent_violations));
        (10 - (recent_violations as i32).min(10)).max(0)
    };
    // 6. Patch status (20 points): reserved for future patch management
    // For now, give full score if device is online
    let patch_score: i32 = if status_score > 0 { 20 } else { 10 };
    let score = status_score + encryption_score + load_score + alert_score + compliance_score + patch_score;
    // NOTE(review): patch_score is always >= 10, so `score > 0` always holds
    // and the "unknown" branch appears unreachable — confirm intent.
    let level = if score >= 80 {
        "healthy"
    } else if score >= 50 {
        "warning"
    } else if score > 0 {
        "critical"
    } else {
        "unknown"
    };
    let details_json = if details.is_empty() {
        "[]".to_string()
    } else {
        serde_json::to_string(&details).unwrap_or_else(|_| "[]".to_string())
    };
    let result = HealthScoreResult {
        score,
        status_score,
        encryption_score,
        load_score,
        alert_score,
        compliance_score,
        patch_score,
        level: level.to_string(),
        details: details_json,
    };
    // Upsert the score
    sqlx::query(
        "INSERT INTO device_health_scores \
        (device_uid, score, status_score, encryption_score, load_score, alert_score, compliance_score, patch_score, level, details, computed_at) \
        VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, datetime('now')) \
        ON CONFLICT(device_uid) DO UPDATE SET \
        score = excluded.score, status_score = excluded.status_score, \
        encryption_score = excluded.encryption_score, load_score = excluded.load_score, \
        alert_score = excluded.alert_score, compliance_score = excluded.compliance_score, \
        patch_score = excluded.patch_score, level = excluded.level, \
        details = excluded.details, computed_at = datetime('now')"
    )
    .bind(device_uid)
    .bind(result.score)
    .bind(result.status_score)
    .bind(result.encryption_score)
    .bind(result.load_score)
    .bind(result.alert_score)
    .bind(result.compliance_score)
    .bind(result.patch_score)
    .bind(&result.level)
    .bind(&result.details)
    .execute(pool)
    .await?;
    Ok(result)
}
/// Compute a single device's health score on demand
///
/// Reads the cached row written by the background task; returns `None`
/// when no score has been computed for this device yet.
pub async fn get_device_score(
    pool: &sqlx::SqlitePool,
    device_uid: &str,
) -> anyhow::Result<Option<serde_json::Value>> {
    let row = sqlx::query(
        "SELECT score, status_score, encryption_score, load_score, alert_score, compliance_score, \
         patch_score, level, details, computed_at \
         FROM device_health_scores WHERE device_uid = ?"
    )
    .bind(device_uid)
    .fetch_optional(pool)
    .await?;
    Ok(row.map(|r| {
        // `details` is stored as a JSON string; degrade to [] on parse failure.
        let details = serde_json::from_str::<serde_json::Value>(&r.get::<String, _>("details"))
            .unwrap_or(serde_json::json!([]));
        serde_json::json!({
            "device_uid": device_uid,
            "score": r.get::<i32, _>("score"),
            "breakdown": {
                "status": r.get::<i32, _>("status_score"),
                "encryption": r.get::<i32, _>("encryption_score"),
                "load": r.get::<i32, _>("load_score"),
                "alerts": r.get::<i32, _>("alert_score"),
                "compliance": r.get::<i32, _>("compliance_score"),
                "patches": r.get::<i32, _>("patch_score"),
            },
            "level": r.get::<String, _>("level"),
            "details": details,
            "computed_at": r.get::<String, _>("computed_at"),
        })
    }))
}
/// Get health overview for all devices (dashboard aggregation)
///
/// Returns a summary (per-level counts, real device count, average score)
/// plus the per-device list ordered worst-first. Only devices that already
/// have a computed health score row are included.
pub async fn get_health_overview(pool: &sqlx::SqlitePool) -> anyhow::Result<serde_json::Value> {
    let rows = sqlx::query(
        "SELECT h.device_uid, h.score, h.level, d.hostname, d.status, d.group_name \
         FROM device_health_scores h \
         JOIN devices d ON d.device_uid = h.device_uid \
         ORDER BY h.score ASC"
    )
    .fetch_all(pool)
    .await?;
    let mut healthy = 0u32;
    let mut warning = 0u32;
    let mut critical = 0u32;
    let mut unknown = 0u32;
    let mut total_score = 0i64;
    let mut devices: Vec<serde_json::Value> = Vec::with_capacity(rows.len());
    for r in &rows {
        let level: String = r.get("level");
        match level.as_str() {
            "healthy" => healthy += 1,
            "warning" => warning += 1,
            "critical" => critical += 1,
            _ => unknown += 1,
        }
        total_score += r.get::<i32, _>("score") as i64;
        devices.push(serde_json::json!({
            "device_uid": r.get::<String, _>("device_uid"),
            "hostname": r.get::<String, _>("hostname"),
            "status": r.get::<String, _>("status"),
            "group_name": r.get::<String, _>("group_name"),
            "score": r.get::<i32, _>("score"),
            "level": level,
        }));
    }
    // FIX: previously `total` was clamped with `.max(1)` to guard the average
    // division, which misreported "total": 1 when no device had a score.
    // Report the real count and guard the division separately.
    let total = devices.len();
    let avg_score = if total == 0 {
        0.0
    } else {
        total_score as f64 / total as f64
    };
    Ok(serde_json::json!({
        "summary": {
            "total": total,
            "healthy": healthy,
            "warning": warning,
            "critical": critical,
            "unknown": unknown,
            // Round to one decimal place for display.
            "avg_score": (avg_score * 10.0).round() / 10.0,
        },
        "devices": devices,
    }))
}

View File

@@ -7,6 +7,7 @@ use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions, SqliteJournalMode};
use std::path::Path;
use std::str::FromStr;
use std::sync::Arc;
use std::collections::HashMap;
use tokio::net::TcpListener;
use axum::http::Method as HttpMethod;
use tower_http::cors::CorsLayer;
@@ -23,6 +24,8 @@ mod db;
mod tcp;
mod ws;
mod alert;
mod health;
mod anomaly;
use config::AppConfig;
@@ -38,6 +41,7 @@ pub struct AppState {
pub clients: Arc<tcp::ClientRegistry>,
pub ws_hub: Arc<ws::WsHub>,
pub login_limiter: Arc<api::auth::LoginRateLimiter>,
pub ws_tickets: Arc<tokio::sync::Mutex<HashMap<String, ws::TicketClaim>>>,
}
#[tokio::main]
@@ -58,15 +62,19 @@ async fn main() -> Result<()> {
// Security checks
if config.registration_token.is_empty() {
warn!("SECURITY: registration_token is empty — any device can register!");
anyhow::bail!("FATAL: registration_token is empty. Set it in config.toml or via CSM_REGISTRATION_TOKEN env var. Device registration is disabled for security.");
}
if config.auth.jwt_secret.len() < 32 {
warn!("SECURITY: jwt_secret is too short ({} chars) — consider using a 32+ byte key from CSM_JWT_SECRET env var", config.auth.jwt_secret.len());
if config.auth.jwt_secret.is_empty() || config.auth.jwt_secret.len() < 32 {
anyhow::bail!("FATAL: jwt_secret is missing or too short. Set CSM_JWT_SECRET env var with a 32+ byte random key.");
}
if config.server.tls.is_none() {
warn!("SECURITY: No TLS configured — all TCP communication is plaintext. Configure [server.tls] for production.");
if std::env::var("CSM_DEV").is_err() {
warn!("Set CSM_DEV=1 to suppress this warning in development environments.");
}
}
let config = Arc::new(config);
// Initialize database
let db = init_database(&config.database.path).await?;
run_migrations(&db).await?;
info!("Database initialized at {}", config.database.path);
@@ -84,6 +92,7 @@ async fn main() -> Result<()> {
clients: clients.clone(),
ws_hub: ws_hub.clone(),
login_limiter: Arc::new(api::auth::LoginRateLimiter::new()),
ws_tickets: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
};
// Start background tasks
@@ -92,6 +101,12 @@ async fn main() -> Result<()> {
alert::cleanup_task(cleanup_state).await;
});
// Health score computation task
let health_state = state.clone();
tokio::spawn(async move {
health::health_score_task(health_state).await;
});
// Start TCP listener for client connections
let tcp_state = state.clone();
let tcp_addr = config.server.tcp_addr.clone();
@@ -131,7 +146,11 @@ async fn main() -> Result<()> {
))
.layer(SetResponseHeaderLayer::if_not_present(
axum::http::header::HeaderName::from_static("content-security-policy"),
axum::http::HeaderValue::from_static("default-src 'self'; script-src 'self' 'unsafe-inline' 'unsafe-eval'; style-src 'self' 'unsafe-inline'; img-src 'self' data:; connect-src 'self' ws: wss:"),
axum::http::HeaderValue::from_static("default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline'; img-src 'self' data:; connect-src 'self' wss:; font-src 'self'; frame-ancestors 'none'; base-uri 'self'; form-action 'self'"),
))
.layer(SetResponseHeaderLayer::if_not_present(
axum::http::header::HeaderName::from_static("permissions-policy"),
axum::http::HeaderValue::from_static("camera=(), microphone=(), geolocation=(), payment=()"),
))
.with_state(state);
@@ -197,6 +216,9 @@ async fn run_migrations(pool: &sqlx::SqlitePool) -> Result<()> {
include_str!("../../../migrations/014_clipboard_control.sql"),
include_str!("../../../migrations/015_plugin_control.sql"),
include_str!("../../../migrations/016_encryption_alerts_unique.sql"),
include_str!("../../../migrations/017_device_health_scores.sql"),
include_str!("../../../migrations/018_patch_management.sql"),
include_str!("../../../migrations/019_software_whitelist.sql"),
];
// Create migrations tracking table
@@ -257,11 +279,27 @@ async fn ensure_default_admin(pool: &sqlx::SqlitePool) -> Result<()> {
.await?;
warn!("Created default admin user (username: admin)");
// Print password directly to stderr — bypasses tracing JSON formatter
eprintln!("============================================================");
eprintln!(" Generated admin password: {}", random_password);
eprintln!(" *** Save this password now — it will NOT be shown again! ***");
eprintln!("============================================================");
// Write password to restricted file instead of stderr (avoid log capture)
let pw_path = std::path::Path::new("data/initial-password.txt");
if let Some(parent) = pw_path.parent() {
let _ = std::fs::create_dir_all(parent);
}
match std::fs::write(pw_path, &random_password) {
Ok(_) => {
warn!("Initial admin password saved to data/initial-password.txt (delete after first login)");
#[cfg(unix)]
{
use std::os::unix::fs::PermissionsExt;
let _ = std::fs::set_permissions(pw_path, std::fs::Permissions::from_mode(0o600));
}
#[cfg(not(unix))]
{
// Windows: restrict ACL would require windows-rs; at minimum hide the file
let _ = std::process::Command::new("attrib").args(["+H", &pw_path.to_string_lossy()]).output();
}
}
Err(e) => warn!("Failed to save initial password to file: {}. Password was: {}", e, random_password),
}
}
Ok(())
@@ -278,13 +316,14 @@ fn build_cors_layer(origins: &[String]) -> CorsLayer {
.collect();
if allowed_origins.is_empty() {
// No CORS — production safe by default
// No CORS — production safe by default (same-origin cookies work without CORS)
CorsLayer::new()
} else {
CorsLayer::new()
.allow_origin(tower_http::cors::AllowOrigin::list(allowed_origins))
.allow_methods([HttpMethod::GET, HttpMethod::POST, HttpMethod::PUT, HttpMethod::DELETE])
.allow_headers([axum::http::header::AUTHORIZATION, axum::http::header::CONTENT_TYPE])
.allow_credentials(true)
.max_age(std::time::Duration::from_secs(3600))
}
}

View File

@@ -3,7 +3,7 @@ use std::sync::Arc;
use std::sync::atomic::{AtomicU32, Ordering};
use std::time::Instant;
use tokio::sync::RwLock;
use tokio::net::{TcpListener, TcpStream};
use tokio::net::TcpListener;
use tracing::{info, warn, debug};
use hmac::{Hmac, Mac};
use sha2::Sha256;
@@ -167,7 +167,7 @@ pub async fn push_all_plugin_configs(
}
}
// Software blacklist
// Software blacklist + whitelist
if let Ok(rows) = sqlx::query(
"SELECT id, name_pattern, action FROM software_blacklist WHERE enabled = 1 AND (target_type = 'global' OR (target_type = 'device' AND target_id = ?) OR (target_type = 'group' AND target_id = (SELECT group_name FROM devices WHERE device_uid = ?)))"
)
@@ -180,8 +180,20 @@ pub async fn push_all_plugin_configs(
"name_pattern": r.get::<String, _>("name_pattern"),
"action": r.get::<String, _>("action"),
})).collect();
if !entries.is_empty() {
if let Ok(frame) = Frame::new_json(MessageType::SoftwareBlacklist, &serde_json::json!({"blacklist": entries})) {
// Fetch whitelist (global, always pushed to all devices)
let whitelist: Vec<String> = sqlx::query_scalar(
"SELECT name_pattern FROM software_whitelist WHERE enabled = 1"
)
.fetch_all(db)
.await
.unwrap_or_default();
if !entries.is_empty() || !whitelist.is_empty() {
if let Ok(frame) = Frame::new_json(MessageType::SoftwareBlacklist, &serde_json::json!({
"blacklist": entries,
"whitelist": whitelist,
})) {
clients.send_to(device_uid, frame.encode()).await;
}
}
@@ -261,17 +273,53 @@ pub async fn push_all_plugin_configs(
}
}
// Disk encryption config — push default reporting interval (no dedicated config table)
// Disk encryption config — read from patch_policies if available, else default
{
let config = csm_protocol::DiskEncryptionConfigPayload {
enabled: true,
report_interval_secs: 3600,
let config = if let Ok(Some(row)) = sqlx::query(
"SELECT auto_approve, enabled FROM patch_policies WHERE target_type = 'global' AND enabled = 1 LIMIT 1"
)
.fetch_optional(db)
.await
{
// If patch_policies exist, infer disk encryption should be enabled
csm_protocol::DiskEncryptionConfigPayload {
enabled: row.get::<i32, _>("enabled") != 0,
report_interval_secs: 3600,
}
} else {
csm_protocol::DiskEncryptionConfigPayload {
enabled: true,
report_interval_secs: 3600,
}
};
if let Ok(frame) = Frame::new_json(MessageType::DiskEncryptionConfig, &config) {
clients.send_to(device_uid, frame.encode()).await;
}
}
// Patch scan config — read from patch_policies if available, else default
{
let config = if let Ok(Some(row)) = sqlx::query(
"SELECT enabled FROM patch_policies WHERE target_type = 'global' AND enabled = 1 LIMIT 1"
)
.fetch_optional(db)
.await
{
csm_protocol::PatchScanConfigPayload {
enabled: row.get::<i32, _>("enabled") != 0,
scan_interval_secs: 43200,
}
} else {
csm_protocol::PatchScanConfigPayload {
enabled: true,
scan_interval_secs: 43200,
}
};
if let Ok(frame) = Frame::new_json(MessageType::PatchScanConfig, &config) {
clients.send_to(device_uid, frame.encode()).await;
}
}
// Push plugin enable/disable state — disable any plugins that admin has turned off
if let Ok(rows) = sqlx::query(
"SELECT plugin_name FROM plugin_state WHERE enabled = 0"
@@ -297,10 +345,17 @@ pub async fn push_all_plugin_configs(
/// Maximum accumulated read buffer size per connection (8 MB)
const MAX_READ_BUF_SIZE: usize = 8 * 1024 * 1024;
/// Registry of all connected client sessions
/// Registry of all connected client sessions, including cached device secrets.
#[derive(Clone, Default)]
pub struct ClientRegistry {
sessions: Arc<RwLock<HashMap<String, Arc<tokio::sync::mpsc::Sender<Vec<u8>>>>>>,
sessions: Arc<RwLock<HashMap<String, ClientSession>>>,
}
/// Per-device session data kept in memory for fast access.
struct ClientSession {
tx: Arc<tokio::sync::mpsc::Sender<Vec<u8>>>,
/// Cached device_secret for HMAC verification — avoids a DB query per heartbeat.
secret: Option<String>,
}
impl ClientRegistry {
@@ -308,8 +363,8 @@ impl ClientRegistry {
Self::default()
}
pub async fn register(&self, device_uid: String, tx: Arc<tokio::sync::mpsc::Sender<Vec<u8>>>) {
self.sessions.write().await.insert(device_uid, tx);
pub async fn register(&self, device_uid: String, secret: Option<String>, tx: Arc<tokio::sync::mpsc::Sender<Vec<u8>>>) {
self.sessions.write().await.insert(device_uid, ClientSession { tx, secret });
}
pub async fn unregister(&self, device_uid: &str) {
@@ -317,13 +372,25 @@ impl ClientRegistry {
}
pub async fn send_to(&self, device_uid: &str, data: Vec<u8>) -> bool {
if let Some(tx) = self.sessions.read().await.get(device_uid) {
tx.send(data).await.is_ok()
if let Some(session) = self.sessions.read().await.get(device_uid) {
session.tx.send(data).await.is_ok()
} else {
false
}
}
/// Get cached device secret for HMAC verification (avoids DB query per heartbeat).
pub async fn get_secret(&self, device_uid: &str) -> Option<String> {
self.sessions.read().await.get(device_uid).and_then(|s| s.secret.clone())
}
/// Backfill cached device secret after a cache miss (e.g. server restart).
pub async fn set_secret(&self, device_uid: &str, secret: String) {
if let Some(session) = self.sessions.write().await.get_mut(device_uid) {
session.secret = Some(secret);
}
}
pub async fn count(&self) -> usize {
self.sessions.read().await.len()
}
@@ -366,7 +433,7 @@ pub async fn start_tcp_server(addr: String, state: AppState) -> anyhow::Result<(
Some(acceptor) => {
match acceptor.accept(stream).await {
Ok(tls_stream) => {
if let Err(e) = handle_client_tls(tls_stream, state).await {
if let Err(e) = handle_client(tls_stream, state).await {
warn!("Client {} TLS error: {}", peer_addr, e);
}
}
@@ -451,6 +518,17 @@ fn verify_device_uid(device_uid: &Option<String>, msg_type: &str, claimed_uid: &
}
}
/// Constant-time string comparison to prevent timing attacks on secrets.
///
/// Always scans `max(a.len(), b.len())` bytes with no early return — neither
/// a length mismatch nor the position of the first differing byte changes
/// the running time. (The previous version returned early on a length
/// mismatch, and its "dummy comparison" was a lazy iterator that was never
/// consumed, so it compiled to nothing.)
fn constant_time_eq(a: &str, b: &str) -> bool {
    let (a, b) = (a.as_bytes(), b.as_bytes());
    // Fold the length difference into the accumulator instead of branching on it.
    let mut diff = a.len() ^ b.len();
    for i in 0..a.len().max(b.len()) {
        // Out-of-range positions read as 0 so both inputs are walked to the
        // same length regardless of which one is shorter.
        let x = a.get(i).copied().unwrap_or(0);
        let y = b.get(i).copied().unwrap_or(0);
        diff |= usize::from(x ^ y);
    }
    diff == 0
}
/// Process a single decoded frame. Shared by both plaintext and TLS handlers.
/// `hmac_fail_count` tracks consecutive HMAC failures; caller checks it for disconnect threshold.
async fn process_frame(
@@ -467,10 +545,10 @@ async fn process_frame(
info!("Device registration attempt: {} ({})", req.hostname, req.device_uid);
// Validate registration token against configured token
// Validate registration token against configured token (constant-time comparison)
let expected_token = &state.config.registration_token;
if !expected_token.is_empty() {
if req.registration_token.is_empty() || req.registration_token != *expected_token {
if req.registration_token.is_empty() || !constant_time_eq(&req.registration_token, expected_token) {
warn!("Registration rejected for {}: invalid token", req.device_uid);
let err_frame = Frame::new_json(MessageType::RegisterResponse,
&serde_json::json!({"error": "invalid_registration_token"}))?;
@@ -514,7 +592,7 @@ async fn process_frame(
*device_uid = Some(req.device_uid.clone());
// If this device was already connected on a different session, evict the old one
// The new register() call will replace it in the hashmap
state.clients.register(req.device_uid.clone(), tx.clone()).await;
state.clients.register(req.device_uid.clone(), Some(device_secret.clone()), tx.clone()).await;
// Send registration response
let config = csm_protocol::ClientConfig::default();
@@ -539,17 +617,25 @@ async fn process_frame(
return Ok(());
}
// Verify HMAC — reject if secret exists but HMAC is missing or wrong
let secret: Option<String> = sqlx::query_scalar(
"SELECT device_secret FROM devices WHERE device_uid = ?"
)
.bind(&heartbeat.device_uid)
.fetch_optional(&state.db)
.await
.map_err(|e| {
warn!("DB error fetching device_secret for {}: {}", heartbeat.device_uid, e);
anyhow::anyhow!("DB error during HMAC verification")
})?;
// Verify HMAC — use cached secret from ClientRegistry, fall back to DB on cache miss (e.g. after restart)
let mut secret = state.clients.get_secret(&heartbeat.device_uid).await;
if secret.is_none() {
// Cache miss (server restarted) — query DB and backfill cache
let db_secret: Option<String> = sqlx::query_scalar(
"SELECT device_secret FROM devices WHERE device_uid = ?"
)
.bind(&heartbeat.device_uid)
.fetch_optional(&state.db)
.await
.map_err(|e| {
warn!("DB error fetching device_secret for {}: {}", heartbeat.device_uid, e);
anyhow::anyhow!("DB error during HMAC verification")
})?;
if let Some(ref s) = db_secret {
state.clients.set_secret(&heartbeat.device_uid, s.clone()).await;
}
secret = db_secret;
}
if let Some(ref secret) = secret {
if !secret.is_empty() {
@@ -650,6 +736,40 @@ async fn process_frame(
crate::db::DeviceRepo::upsert_software(&state.db, &sw).await?;
}
MessageType::AssetChange => {
let change: csm_protocol::AssetChange = frame.decode_payload()
.map_err(|e| anyhow::anyhow!("Invalid asset change: {}", e))?;
if !verify_device_uid(device_uid, "AssetChange", &change.device_uid) {
return Ok(());
}
let change_type_str = match change.change_type {
csm_protocol::AssetChangeType::Hardware => "hardware",
csm_protocol::AssetChangeType::SoftwareAdded => "software_added",
csm_protocol::AssetChangeType::SoftwareRemoved => "software_removed",
};
sqlx::query(
"INSERT INTO asset_changes (device_uid, change_type, change_detail, detected_at) \
VALUES (?, ?, ?, datetime('now'))"
)
.bind(&change.device_uid)
.bind(change_type_str)
.bind(serde_json::to_string(&change.change_detail).map_err(|e| anyhow::anyhow!("Failed to serialize asset change detail: {}", e))?)
.execute(&state.db)
.await
.map_err(|e| anyhow::anyhow!("DB error inserting asset change: {}", e))?;
debug!("Asset change: {} {:?} for device {}", change_type_str, change.change_detail, change.device_uid);
state.ws_hub.broadcast(serde_json::json!({
"type": "asset_change",
"device_uid": change.device_uid,
"change_type": change_type_str,
}).to_string()).await;
}
MessageType::UsageReport => {
let report: csm_protocol::UsageDailyReport = frame.decode_payload()
.map_err(|e| anyhow::anyhow!("Invalid usage report: {}", e))?;
@@ -910,23 +1030,86 @@ async fn process_frame(
return Ok(());
}
for rule_stat in &stats.rule_stats {
sqlx::query(
"INSERT INTO popup_block_stats (device_uid, rule_id, blocked_count, period_secs, reported_at) \
VALUES (?, ?, ?, ?, datetime('now'))"
)
.bind(&stats.device_uid)
.bind(rule_stat.rule_id)
.bind(rule_stat.hits as i32)
.bind(stats.period_secs as i32)
.execute(&state.db)
.await
.ok();
}
// Upsert aggregate stats per device per day
let today = chrono::Utc::now().format("%Y-%m-%d").to_string();
sqlx::query(
"INSERT INTO popup_block_stats (device_uid, blocked_count, date) \
VALUES (?, ?, ?) \
ON CONFLICT(device_uid, date) DO UPDATE SET \
blocked_count = blocked_count + excluded.blocked_count"
)
.bind(&stats.device_uid)
.bind(stats.blocked_count as i32)
.bind(&today)
.execute(&state.db)
.await
.ok();
debug!("Popup block stats: {} blocked {} windows in {}s", stats.device_uid, stats.blocked_count, stats.period_secs);
}
MessageType::PatchStatusReport => {
let payload: csm_protocol::PatchStatusPayload = frame.decode_payload()
.map_err(|e| anyhow::anyhow!("Invalid patch status report: {}", e))?;
if !verify_device_uid(device_uid, "PatchStatusReport", &payload.device_uid) {
return Ok(());
}
for patch in &payload.patches {
sqlx::query(
"INSERT INTO patch_status (device_uid, kb_id, title, severity, is_installed, installed_at, updated_at) \
VALUES (?, ?, ?, ?, ?, ?, datetime('now')) \
ON CONFLICT(device_uid, kb_id) DO UPDATE SET \
title = excluded.title, severity = COALESCE(excluded.severity, patch_status.severity), \
is_installed = excluded.is_installed, installed_at = excluded.installed_at, \
updated_at = datetime('now')"
)
.bind(&payload.device_uid)
.bind(&patch.kb_id)
.bind(&patch.title)
.bind(&patch.severity)
.bind(patch.is_installed as i32)
.bind(&patch.installed_at)
.execute(&state.db)
.await
.map_err(|e| anyhow::anyhow!("DB error inserting patch status: {}", e))?;
}
info!("Patch status reported: {} ({} patches)", payload.device_uid, payload.patches.len());
}
MessageType::BehaviorMetricsReport => {
let metrics: csm_protocol::BehaviorMetricsPayload = frame.decode_payload()
.map_err(|e| anyhow::anyhow!("Invalid behavior metrics: {}", e))?;
if !verify_device_uid(device_uid, "BehaviorMetricsReport", &metrics.device_uid) {
return Ok(());
}
sqlx::query(
"INSERT INTO behavior_metrics (device_uid, clipboard_ops_count, clipboard_ops_night, print_jobs_count, usb_file_ops_count, new_processes_count, period_secs, reported_at) \
VALUES (?, ?, ?, ?, ?, ?, ?, datetime('now'))"
)
.bind(&metrics.device_uid)
.bind(metrics.clipboard_ops_count as i32)
.bind(metrics.clipboard_ops_night as i32)
.bind(metrics.print_jobs_count as i32)
.bind(metrics.usb_file_ops_count as i32)
.bind(metrics.new_processes_count as i32)
.bind(metrics.period_secs as i32)
.execute(&state.db)
.await
.map_err(|e| anyhow::anyhow!("DB error inserting behavior metrics: {}", e))?;
// Run anomaly detection inline
crate::anomaly::check_anomalies(&state.db, &state.ws_hub, &metrics).await;
debug!("Behavior metrics saved: {} (clipboard={}, print={}, usb_file={}, procs={})",
metrics.device_uid, metrics.clipboard_ops_count, metrics.print_jobs_count,
metrics.usb_file_ops_count, metrics.new_processes_count);
}
_ => {
debug!("Unhandled message type: {:?}", frame.msg_type);
}
@@ -935,13 +1118,14 @@ async fn process_frame(
Ok(())
}
/// Handle a single client TCP connection
async fn handle_client(stream: TcpStream, state: AppState) -> anyhow::Result<()> {
/// Handle a single client TCP connection (plaintext or TLS)
async fn handle_client<S>(stream: S, state: AppState) -> anyhow::Result<()>
where
S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
{
use tokio::io::{AsyncReadExt, AsyncWriteExt};
let _ = stream.set_nodelay(true);
let (mut reader, mut writer) = stream.into_split();
let (mut reader, mut writer) = tokio::io::split(stream);
let (tx, mut rx) = tokio::sync::mpsc::channel::<Vec<u8>>(256);
let tx = Arc::new(tx);
@@ -1018,81 +1202,50 @@ async fn handle_client(stream: TcpStream, state: AppState) -> anyhow::Result<()>
Ok(())
}
/// Handle a TLS-wrapped client connection
async fn handle_client_tls(
stream: tokio_rustls::server::TlsStream<TcpStream>,
state: AppState,
) -> anyhow::Result<()> {
use tokio::io::{AsyncReadExt, AsyncWriteExt};
let (mut reader, mut writer) = tokio::io::split(stream);
let (tx, mut rx) = tokio::sync::mpsc::channel::<Vec<u8>>(256);
let tx = Arc::new(tx);
let mut buffer = vec![0u8; 65536];
let mut read_buf = Vec::with_capacity(65536);
let mut device_uid: Option<String> = None;
let mut rate_limiter = RateLimiter::new();
let hmac_fail_count = Arc::new(AtomicU32::new(0));
let write_task = tokio::spawn(async move {
while let Some(data) = rx.recv().await {
if writer.write_all(&data).await.is_err() {
break;
}
/// Push a TLS certificate rotation notice to all online devices.
/// Computes the fingerprint of the new certificate and sends ConfigUpdate(TlsCertRotate).
pub async fn push_tls_cert_rotation(clients: &ClientRegistry, new_cert_pem: &[u8], valid_until: &str) -> usize {
// Compute SHA-256 fingerprint of the new certificate
let certs: Vec<_> = match rustls_pemfile::certs(&mut &new_cert_pem[..]).collect::<Result<Vec<_>, _>>() {
Ok(c) => c,
Err(e) => {
warn!("Failed to parse new certificate for rotation: {:?}", e);
return 0;
}
});
};
// Reader loop with idle timeout
'reader: loop {
let read_result = tokio::time::timeout(
std::time::Duration::from_secs(IDLE_TIMEOUT_SECS),
reader.read(&mut buffer),
).await;
let end_entity = match certs.first() {
Some(c) => c,
None => {
warn!("No certificates found in PEM for rotation");
return 0;
}
};
let n = match read_result {
Ok(Ok(0)) => break,
Ok(Ok(n)) => n,
Ok(Err(e)) => return Err(e.into()),
Err(_) => {
warn!("Idle timeout for TLS device {:?}, disconnecting", device_uid);
break;
}
let fingerprint = {
use sha2::{Sha256, Digest};
let mut hasher = Sha256::new();
hasher.update(end_entity.as_ref());
hex::encode(hasher.finalize())
};
info!("Pushing TLS cert rotation: new fingerprint={}... valid_until={}", &fingerprint[..16], valid_until);
let config_update = csm_protocol::ConfigUpdateType::TlsCertRotate {
new_cert_hash: fingerprint,
valid_until: valid_until.to_string(),
};
let online = clients.list_online().await;
let mut pushed = 0usize;
for uid in &online {
let frame = match Frame::new_json(MessageType::ConfigUpdate, &config_update) {
Ok(f) => f,
Err(_) => continue,
};
read_buf.extend_from_slice(&buffer[..n]);
if read_buf.len() > MAX_READ_BUF_SIZE {
warn!("TLS connection exceeded max buffer size, dropping");
break;
}
while let Some(frame) = Frame::decode(&read_buf)? {
let frame_size = frame.encoded_size();
read_buf.drain(..frame_size);
if frame.version != PROTOCOL_VERSION {
warn!("Unsupported protocol version: 0x{:02X}", frame.version);
continue;
}
if !rate_limiter.check() {
warn!("Rate limit exceeded for TLS device {:?}, dropping connection", device_uid);
break 'reader;
}
if let Err(e) = process_frame(frame, &state, &mut device_uid, &tx, &hmac_fail_count).await {
warn!("Frame processing error: {}", e);
}
// Disconnect if too many consecutive HMAC failures
if hmac_fail_count.load(Ordering::Relaxed) >= MAX_HMAC_FAILURES {
warn!("Too many HMAC failures for TLS device {:?}, disconnecting", device_uid);
break 'reader;
}
if clients.send_to(uid, frame.encode()).await {
pushed += 1;
}
}
cleanup_on_disconnect(&state, &device_uid).await;
write_task.abort();
Ok(())
pushed
}

View File

@@ -1,12 +1,10 @@
use axum::extract::ws::{WebSocket, WebSocketUpgrade, Message};
use axum::response::IntoResponse;
use axum::extract::Query;
use jsonwebtoken::{decode, Validation, DecodingKey};
use serde::Deserialize;
use tokio::sync::broadcast;
use std::sync::Arc;
use tracing::{debug, warn};
use crate::api::auth::Claims;
use crate::AppState;
/// WebSocket hub for broadcasting real-time events to admin browsers
@@ -32,65 +30,73 @@ impl WsHub {
}
}
#[derive(Debug, Deserialize)]
pub struct WsAuthParams {
pub token: Option<String>,
/// Claim stored when a WS ticket is created. Consumed on WS connection.
#[derive(Debug, Clone)]
pub struct TicketClaim {
pub user_id: i64,
pub username: String,
pub role: String,
pub created_at: std::time::Instant,
}
/// HTTP upgrade handler for WebSocket connections
/// Validates JWT token from query parameter before upgrading
#[derive(Debug, Deserialize)]
pub struct WsTicketParams {
pub ticket: Option<String>,
}
/// HTTP upgrade handler for WebSocket connections.
/// Validates a one-time ticket (obtained via POST /api/ws/ticket) before upgrading.
pub async fn ws_handler(
ws: WebSocketUpgrade,
Query(params): Query<WsAuthParams>,
Query(params): Query<WsTicketParams>,
axum::extract::State(state): axum::extract::State<AppState>,
) -> impl IntoResponse {
let token = match params.token {
let ticket = match params.ticket {
Some(t) => t,
None => {
warn!("WebSocket connection rejected: no token provided");
return (axum::http::StatusCode::UNAUTHORIZED, "Missing token").into_response();
warn!("WebSocket connection rejected: no ticket provided");
return (axum::http::StatusCode::UNAUTHORIZED, "Missing ticket").into_response();
}
};
let claims = match decode::<Claims>(
&token,
&DecodingKey::from_secret(state.config.auth.jwt_secret.as_bytes()),
&Validation::default(),
) {
Ok(c) => c.claims,
Err(e) => {
warn!("WebSocket connection rejected: invalid token - {}", e);
return (axum::http::StatusCode::UNAUTHORIZED, "Invalid token").into_response();
// Consume (remove) the ticket from the store — single use
let claim = {
let mut tickets = state.ws_tickets.lock().await;
match tickets.remove(&ticket) {
Some(claim) => claim,
None => {
warn!("WebSocket connection rejected: invalid or expired ticket");
return (axum::http::StatusCode::UNAUTHORIZED, "Invalid or expired ticket").into_response();
}
}
};
if claims.token_type != "access" {
warn!("WebSocket connection rejected: not an access token");
return (axum::http::StatusCode::UNAUTHORIZED, "Invalid token type").into_response();
// Check ticket age (30 second TTL)
if claim.created_at.elapsed().as_secs() > 30 {
warn!("WebSocket connection rejected: ticket expired");
return (axum::http::StatusCode::UNAUTHORIZED, "Ticket expired").into_response();
}
let hub = state.ws_hub.clone();
ws.on_upgrade(move |socket| handle_socket(socket, claims, hub))
ws.on_upgrade(move |socket| handle_socket(socket, claim, hub))
}
async fn handle_socket(mut socket: WebSocket, claims: Claims, hub: Arc<WsHub>) {
debug!("WebSocket client connected: user={}", claims.username);
async fn handle_socket(mut socket: WebSocket, claim: TicketClaim, hub: Arc<WsHub>) {
debug!("WebSocket client connected: user={}", claim.username);
let welcome = serde_json::json!({
"type": "connected",
"message": "CSM real-time feed active",
"user": claims.username
"user": claim.username
});
if socket.send(Message::Text(welcome.to_string())).await.is_err() {
return;
}
// Subscribe to broadcast hub for real-time events
let mut rx = hub.subscribe();
loop {
tokio::select! {
// Forward broadcast messages to WebSocket client
msg = rx.recv() => {
match msg {
Ok(text) => {
@@ -104,7 +110,6 @@ async fn handle_socket(mut socket: WebSocket, claims: Claims, hub: Arc<WsHub>) {
Err(broadcast::error::RecvError::Closed) => break,
}
}
// Handle incoming WebSocket messages (ping/close)
msg = socket.recv() => {
match msg {
Some(Ok(Message::Ping(data))) => {
@@ -121,5 +126,5 @@ async fn handle_socket(mut socket: WebSocket, claims: Claims, hub: Arc<WsHub>) {
}
}
debug!("WebSocket client disconnected: user={}", claims.username);
debug!("WebSocket client disconnected: user={}", claim.username);
}