From 60ee38a3c218af1d00df83fa5fa1136fbb231b5a Mon Sep 17 00:00:00 2001 From: iven Date: Sat, 11 Apr 2026 15:59:53 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=96=B0=E5=A2=9E=E8=A1=A5=E4=B8=81?= =?UTF-8?q?=E7=AE=A1=E7=90=86=E5=92=8C=E5=BC=82=E5=B8=B8=E6=A3=80=E6=B5=8B?= =?UTF-8?q?=E6=8F=92=E4=BB=B6=E5=8F=8A=E7=9B=B8=E5=85=B3=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit feat(protocol): 添加补丁管理和行为指标协议类型 feat(client): 实现补丁管理插件采集功能 feat(server): 添加补丁管理和异常检测API feat(database): 新增补丁状态和异常检测相关表 feat(web): 添加补丁管理和异常检测前端页面 fix(security): 增强输入验证和防注入保护 refactor(auth): 重构认证检查逻辑 perf(service): 优化Windows服务恢复策略 style: 统一健康评分显示样式 docs: 更新知识库文档 --- CLAUDE.md | 2 + crates/client/src/asset/mod.rs | 8 + crates/client/src/main.rs | 33 +- crates/client/src/network/mod.rs | 208 ++++++++- crates/client/src/patch/mod.rs | 116 +++++ crates/client/src/service.rs | 31 +- crates/client/src/software_blocker/mod.rs | 263 ++++++++--- crates/protocol/src/lib.rs | 2 + crates/protocol/src/message.rs | 52 ++- crates/server/src/alert.rs | 20 + crates/server/src/anomaly.rs | 170 +++++++ crates/server/src/api/alerts.rs | 41 +- crates/server/src/api/auth.rs | 431 ++++++++++++++---- crates/server/src/api/conflict.rs | 243 ++++++++++ crates/server/src/api/devices.rs | 78 +++- crates/server/src/api/groups.rs | 6 + crates/server/src/api/mod.rs | 53 ++- crates/server/src/api/plugins/anomaly.rs | 48 ++ .../src/api/plugins/clipboard_control.rs | 14 +- .../server/src/api/plugins/disk_encryption.rs | 4 +- crates/server/src/api/plugins/mod.rs | 16 +- crates/server/src/api/plugins/patch.rs | 146 ++++++ .../server/src/api/plugins/popup_blocker.rs | 20 +- .../src/api/plugins/software_blocker.rs | 156 ++++++- crates/server/src/api/plugins/web_filter.rs | 53 ++- crates/server/src/api/usb.rs | 7 +- crates/server/src/health.rs | 315 +++++++++++++ crates/server/src/main.rs | 63 ++- crates/server/src/tcp.rs | 389 +++++++++++----- crates/server/src/ws.rs | 67 +-- 
migrations/017_device_health_scores.sql | 20 + migrations/018_patch_management.sql | 59 +++ migrations/019_software_whitelist.sql | 54 +++ web/src/lib/api.ts | 91 ++-- web/src/router/index.ts | 45 +- web/src/views/Dashboard.vue | 107 ++++- web/src/views/Devices.vue | 41 ++ web/src/views/Layout.vue | 30 +- web/src/views/Settings.vue | 15 +- web/src/views/plugins/AnomalyDetection.vue | 90 ++++ web/src/views/plugins/PatchManagement.vue | 92 ++++ wiki/SECURITY-AUDIT.md | 255 +++++++++++ wiki/client.md | 88 ++++ wiki/database.md | 71 +++ wiki/index.md | 46 ++ wiki/plugins.md | 78 ++++ wiki/protocol.md | 58 +++ wiki/server.md | 84 ++++ wiki/web-frontend.md | 70 +++ 49 files changed, 3988 insertions(+), 461 deletions(-) create mode 100644 crates/client/src/patch/mod.rs create mode 100644 crates/server/src/anomaly.rs create mode 100644 crates/server/src/api/conflict.rs create mode 100644 crates/server/src/api/plugins/anomaly.rs create mode 100644 crates/server/src/api/plugins/patch.rs create mode 100644 crates/server/src/health.rs create mode 100644 migrations/017_device_health_scores.sql create mode 100644 migrations/018_patch_management.sql create mode 100644 migrations/019_software_whitelist.sql create mode 100644 web/src/views/plugins/AnomalyDetection.vue create mode 100644 web/src/views/plugins/PatchManagement.vue create mode 100644 wiki/SECURITY-AUDIT.md create mode 100644 wiki/client.md create mode 100644 wiki/database.md create mode 100644 wiki/index.md create mode 100644 wiki/plugins.md create mode 100644 wiki/protocol.md create mode 100644 wiki/server.md create mode 100644 wiki/web-frontend.md diff --git a/CLAUDE.md b/CLAUDE.md index ca5616d..4d1ef7f 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -1,5 +1,7 @@ # CSM — 企业终端安全管理系统 +> **知识库**: @wiki/index.md — 编译后的模块化知识,新会话加载即了解全貌。 + ## 项目概览 CSM (Client Security Manager) 是一个医院设备终端安全管控平台,采用 C/S + Web 管理面板三层架构。 diff --git a/crates/client/src/asset/mod.rs b/crates/client/src/asset/mod.rs index ea480fd..0b780bb 100644 --- 
a/crates/client/src/asset/mod.rs +++ b/crates/client/src/asset/mod.rs @@ -121,6 +121,14 @@ fn collect_system_details() -> (Option, Option, Option) #[cfg(target_os = "windows")] fn powershell_lines(command: &str) -> Vec { use std::process::Command; + + // Reject commands containing suspicious patterns that could indicate injection + let lower = command.to_lowercase(); + if lower.contains("invoke-expression") || lower.contains("iex ") || lower.contains("& ") { + tracing::warn!("Rejected suspicious PowerShell command pattern"); + return Vec::new(); + } + let output = match Command::new("powershell") .args(["-NoProfile", "-NonInteractive", "-Command", &format!("[Console]::OutputEncoding = [System.Text.Encoding]::UTF8; {}", command)]) diff --git a/crates/client/src/main.rs b/crates/client/src/main.rs index 07f9907..6535c14 100644 --- a/crates/client/src/main.rs +++ b/crates/client/src/main.rs @@ -1,7 +1,7 @@ use anyhow::Result; use std::sync::atomic::{AtomicBool, Ordering}; use std::time::Duration; -use tracing::{info, error, warn}; +use tracing::{info, error, warn, debug}; use csm_protocol::{Frame, ClientConfig, UsbPolicyPayload}; mod monitor; @@ -17,6 +17,7 @@ mod web_filter; mod disk_encryption; mod clipboard_control; mod print_audit; +mod patch; #[cfg(target_os = "windows")] mod service; @@ -58,17 +59,22 @@ fn main() -> Result<()> { info!("CSM Client starting (console mode)..."); let device_uid = load_or_create_device_uid()?; - info!("Device UID: {}", device_uid); + debug!("Device UID: {}", device_uid); let server_addr = std::env::var("CSM_SERVER") .unwrap_or_else(|_| "127.0.0.1:9999".to_string()); + let registration_token = std::env::var("CSM_REGISTRATION_TOKEN").unwrap_or_default(); + if registration_token.is_empty() { + tracing::warn!("CSM_REGISTRATION_TOKEN not set — device registration may fail"); + } + let state = ClientState { device_uid, server_addr, config: ClientConfig::default(), device_secret: load_device_secret(), - registration_token: 
std::env::var("CSM_REGISTRATION_TOKEN").unwrap_or_default(), + registration_token, use_tls: std::env::var("CSM_USE_TLS").as_deref() == Ok("true"), }; @@ -97,6 +103,7 @@ pub async fn run(state: ClientState) -> Result<()> { let (disk_encryption_tx, disk_encryption_rx) = tokio::sync::watch::channel(disk_encryption::DiskEncryptionConfig::default()); let (print_audit_tx, print_audit_rx) = tokio::sync::watch::channel(print_audit::PrintAuditConfig::default()); let (clipboard_control_tx, clipboard_control_rx) = tokio::sync::watch::channel(clipboard_control::ClipboardControlConfig::default()); + let (patch_tx, patch_rx) = tokio::sync::watch::channel(patch::PluginConfig::default()); // Build plugin channels struct let plugins = network::PluginChannels { @@ -110,6 +117,7 @@ pub async fn run(state: ClientState) -> Result<()> { disk_encryption_tx, print_audit_tx, clipboard_control_tx, + patch_tx, }; // Spawn core monitoring tasks @@ -182,6 +190,12 @@ pub async fn run(state: ClientState) -> Result<()> { clipboard_control::start(clipboard_control_rx, cc_data_tx, cc_uid).await; }); + let patch_data_tx = data_tx.clone(); + let patch_uid = state.device_uid.clone(); + tokio::spawn(async move { + patch::start(patch_rx, patch_data_tx, patch_uid).await; + }); + // Connect to server with reconnect let mut backoff = Duration::from_secs(1); let max_backoff = Duration::from_secs(60); @@ -270,5 +284,16 @@ fn write_restricted_file(path: &std::path::Path, content: &str) -> std::io::Resu #[cfg(not(unix))] fn write_restricted_file(path: &std::path::Path, content: &str) -> std::io::Result<()> { - std::fs::write(path, content) + std::fs::write(path, content)?; + // Restrict file ACL to SYSTEM and Administrators only on Windows + let path_str = path.to_string_lossy(); + // Remove inherited permissions and grant only SYSTEM full control + let _ = std::process::Command::new("icacls") + .args([&*path_str, "/inheritance:r", "/grant:r", "SYSTEM:(F)", "/grant:r", "Administrators:(F)"]) + .output(); + // 
Also hide the file + let _ = std::process::Command::new("attrib") + .args(["+H", &*path_str]) + .output(); + Ok(()) } diff --git a/crates/client/src/network/mod.rs b/crates/client/src/network/mod.rs index 737092a..9d36f08 100644 --- a/crates/client/src/network/mod.rs +++ b/crates/client/src/network/mod.rs @@ -1,11 +1,13 @@ use anyhow::Result; +use std::path::PathBuf; +use std::sync::{Arc, Mutex}; use std::time::Duration; use tokio::net::TcpStream; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tracing::{info, debug, warn}; -use csm_protocol::{Frame, MessageType, RegisterRequest, RegisterResponse, HeartbeatPayload, WatermarkConfigPayload, UsbPolicyPayload, DiskEncryptionConfigPayload}; +use csm_protocol::{Frame, MessageType, RegisterRequest, RegisterResponse, HeartbeatPayload, WatermarkConfigPayload, UsbPolicyPayload, DiskEncryptionConfigPayload, PatchScanConfigPayload}; use hmac::{Hmac, Mac}; -use sha2::Sha256; +use sha2::{Sha256, Digest}; use crate::ClientState; @@ -21,6 +23,7 @@ pub struct PluginChannels { pub disk_encryption_tx: tokio::sync::watch::Sender, pub print_audit_tx: tokio::sync::watch::Sender, pub clipboard_control_tx: tokio::sync::watch::Sender, + pub patch_tx: tokio::sync::watch::Sender, } /// Connect to server and run the main communication loop @@ -30,7 +33,7 @@ pub async fn connect_and_run( plugins: &PluginChannels, ) -> Result<()> { let tcp_stream = TcpStream::connect(&state.server_addr).await?; - info!("TCP connected to {}", state.server_addr); + debug!("TCP connected to {}", state.server_addr); if state.use_tls { let tls_stream = wrap_tls(tcp_stream, &state.server_addr).await?; @@ -40,9 +43,7 @@ pub async fn connect_and_run( } } -/// Wrap a TCP stream with TLS. -/// Supports custom CA certificate via CSM_TLS_CA_CERT env var (path to PEM file). -/// Supports skipping verification via CSM_TLS_SKIP_VERIFY=true (development only). +/// Wrap a TCP stream with TLS and certificate pinning. 
async fn wrap_tls(stream: TcpStream, server_addr: &str) -> Result> { let mut root_store = rustls::RootCertStore::empty(); @@ -62,19 +63,38 @@ async fn wrap_tls(stream: TcpStream, server_addr: &str) -> Result Result PathBuf { + if let Ok(custom) = std::env::var("CSM_TLS_PIN_FILE") { + PathBuf::from(custom) + } else if cfg!(target_os = "windows") { + std::env::var("PROGRAMDATA") + .map(|p| PathBuf::from(p).join("CSM").join("server_cert_pin")) + .unwrap_or_else(|_| PathBuf::from("server_cert_pin")) + } else { + PathBuf::from("/var/lib/csm/server_cert_pin") + } +} + +/// Load pinned certificate hashes from file. +/// Format: one hex-encoded SHA-256 hash per line. +fn load_pinned_hashes(path: &PathBuf) -> Vec { + match std::fs::read_to_string(path) { + Ok(content) => content.lines() + .map(|l| l.trim().to_string()) + .filter(|l| !l.is_empty()) + .collect(), + Err(_) => Vec::new(), // First connection — no pin file yet + } +} + +/// Save a pinned hash to the pin file. +fn save_pinned_hash(path: &PathBuf, hash: &str) { + if let Some(parent) = path.parent() { + let _ = std::fs::create_dir_all(parent); + } + let _ = std::fs::write(path, format!("{}\n", hash)); +} + +/// Compute SHA-256 fingerprint of a DER-encoded certificate. +fn cert_fingerprint(cert: &rustls_pki_types::CertificateDer) -> String { + let mut hasher = Sha256::new(); + hasher.update(cert.as_ref()); + hex::encode(hasher.finalize()) +} + +/// Certificate verifier with pinning support. +/// On first connection (no stored pin), records the certificate fingerprint. +/// On subsequent connections, verifies the fingerprint matches. 
+#[derive(Debug)] +struct PinnedCertVerifier { + inner: Arc, + pin_file: PathBuf, + pinned_hashes: Arc>>, +} + +impl rustls::client::danger::ServerCertVerifier for PinnedCertVerifier { + fn verify_server_cert( + &self, + end_entity: &rustls_pki_types::CertificateDer, + intermediates: &[rustls_pki_types::CertificateDer], + server_name: &rustls_pki_types::ServerName, + ocsp_response: &[u8], + now: rustls_pki_types::UnixTime, + ) -> Result { + // 1. Standard PKIX verification + self.inner.verify_server_cert(end_entity, intermediates, server_name, ocsp_response, now)?; + + // 2. Compute certificate fingerprint + let fingerprint = cert_fingerprint(end_entity); + + // 3. Check against pinned hashes + let mut pinned = self.pinned_hashes.lock().unwrap(); + if pinned.is_empty() { + // First connection — record the certificate fingerprint + info!("Recording server certificate pin: {}...", &fingerprint[..16]); + save_pinned_hash(&self.pin_file, &fingerprint); + pinned.push(fingerprint); + } else if !pinned.contains(&fingerprint) { + warn!("Certificate pin mismatch! Expected one of {:?}, got {}", pinned, fingerprint); + return Err(rustls::Error::General( + "Server certificate does not match pinned fingerprint. Possible MITM attack.".into(), + )); + } + + Ok(rustls::client::danger::ServerCertVerified::assertion()) + } + + fn verify_tls12_signature( + &self, + message: &[u8], + cert: &rustls_pki_types::CertificateDer, + dss: &rustls::DigitallySignedStruct, + ) -> Result { + self.inner.verify_tls12_signature(message, cert, dss) + } + + fn verify_tls13_signature( + &self, + message: &[u8], + cert: &rustls_pki_types::CertificateDer, + dss: &rustls::DigitallySignedStruct, + ) -> Result { + self.inner.verify_tls13_signature(message, cert, dss) + } + + fn supported_verify_schemes(&self) -> Vec { + self.inner.supported_verify_schemes() + } +} + +/// Update pinned certificate hash (called when receiving TlsCertRotate). 
+pub fn update_cert_pin(new_hash: &str) { + let pin_file = pin_file_path(); + let mut pinned = load_pinned_hashes(&pin_file); + if !pinned.contains(&new_hash.to_string()) { + pinned.push(new_hash.to_string()); + // Keep only the last 2 hashes (current + rotating) + while pinned.len() > 2 { + pinned.remove(0); + } + // Write all hashes to file + if let Some(parent) = pin_file.parent() { + let _ = std::fs::create_dir_all(parent); + } + let content = pinned.iter().map(|h| h.as_str()).collect::>().join("\n"); + let _ = std::fs::write(&pin_file, format!("{}\n", content)); + info!("Updated certificate pin file with new hash: {}...", &new_hash[..16]); + } +} + /// A no-op certificate verifier for development use (CSM_TLS_SKIP_VERIFY=true). #[derive(Debug)] struct NoVerifier; @@ -242,7 +387,20 @@ fn handle_server_message(frame: Frame, plugins: &PluginChannels) -> Result<()> { info!("Received policy update: {}", policy); } MessageType::ConfigUpdate => { - info!("Received config update"); + let update: csm_protocol::ConfigUpdateType = frame.decode_payload() + .map_err(|e| anyhow::anyhow!("Invalid config update: {}", e))?; + match update { + csm_protocol::ConfigUpdateType::UpdateIntervals { heartbeat, status, asset } => { + info!("Config update: intervals heartbeat={}s status={}s asset={}s", heartbeat, status, asset); + } + csm_protocol::ConfigUpdateType::TlsCertRotate { new_cert_hash, valid_until } => { + info!("Certificate rotation: new hash={}... 
valid_until={}", &new_cert_hash[..16.min(new_cert_hash.len())], valid_until); + update_cert_pin(&new_cert_hash); + } + csm_protocol::ConfigUpdateType::SelfDestruct => { + warn!("Self-destruct command received (not implemented)"); + } + } } MessageType::TaskExecute => { warn!("Task execution requested (not yet implemented)"); @@ -276,7 +434,14 @@ fn handle_server_message(frame: Frame, plugins: &PluginChannels) -> Result<()> { let blacklist: Vec = payload.get("blacklist") .and_then(|r| serde_json::from_value(r.clone()).ok()) .unwrap_or_default(); - let config = crate::software_blocker::SoftwareBlockerConfig { enabled: true, blacklist }; + let whitelist: Vec = payload.get("whitelist") + .and_then(|r| serde_json::from_value(r.clone()).ok()) + .unwrap_or_default(); + let config = crate::software_blocker::SoftwareBlockerConfig { + enabled: true, + blacklist, + whitelist, + }; plugins.software_blocker_tx.send(config)?; } MessageType::PopupRules => { @@ -322,6 +487,16 @@ fn handle_server_message(frame: Frame, plugins: &PluginChannels) -> Result<()> { }; plugins.clipboard_control_tx.send(config)?; } + MessageType::PatchScanConfig => { + let config: PatchScanConfigPayload = frame.decode_payload() + .map_err(|e| anyhow::anyhow!("Invalid patch scan config: {}", e))?; + info!("Received patch scan config: enabled={}, interval={}s", config.enabled, config.scan_interval_secs); + let plugin_config = crate::patch::PluginConfig { + enabled: config.enabled, + scan_interval_secs: config.scan_interval_secs, + }; + plugins.patch_tx.send(plugin_config)?; + } _ => { debug!("Unhandled message type: {:?}", frame.msg_type); } @@ -351,7 +526,7 @@ fn handle_plugin_control( } "software_blocker" => { if !enabled { - plugins.software_blocker_tx.send(crate::software_blocker::SoftwareBlockerConfig { enabled: false, blacklist: vec![] })?; + plugins.software_blocker_tx.send(crate::software_blocker::SoftwareBlockerConfig { enabled: false, blacklist: vec![], whitelist: vec![] })?; } } "popup_blocker" => 
{ @@ -384,6 +559,11 @@ fn handle_plugin_control( plugins.clipboard_control_tx.send(crate::clipboard_control::ClipboardControlConfig { enabled: false, ..Default::default() })?; } } + "patch" => { + if !enabled { + plugins.patch_tx.send(crate::patch::PluginConfig { enabled: false, scan_interval_secs: 43200 })?; + } + } _ => { warn!("Unknown plugin: {}", payload.plugin_name); } diff --git a/crates/client/src/patch/mod.rs b/crates/client/src/patch/mod.rs new file mode 100644 index 0000000..bf9cd1f --- /dev/null +++ b/crates/client/src/patch/mod.rs @@ -0,0 +1,116 @@ +use tokio::sync::watch; +use csm_protocol::{Frame, MessageType, PatchStatusPayload, PatchEntry}; +use tracing::{debug, warn}; + +#[derive(Debug, Clone, Default)] +pub struct PluginConfig { + pub enabled: bool, + pub scan_interval_secs: u64, +} + +pub async fn start( + mut config_rx: watch::Receiver, + data_tx: tokio::sync::mpsc::Sender, + device_uid: String, +) { + let mut config = config_rx.borrow_and_update().clone(); + let mut interval = tokio::time::interval(std::time::Duration::from_secs( + if config.scan_interval_secs > 0 { config.scan_interval_secs } else { 43200 } + )); + interval.tick().await; + + loop { + tokio::select! 
{ + result = config_rx.changed() => { + if result.is_err() { break; } + config = config_rx.borrow_and_update().clone(); + let new_secs = if config.scan_interval_secs > 0 { config.scan_interval_secs } else { 43200 }; + interval = tokio::time::interval(std::time::Duration::from_secs(new_secs)); + interval.tick().await; + debug!("Patch config updated: enabled={}, interval={}s", config.enabled, new_secs); + } + _ = interval.tick() => { + if !config.enabled { continue; } + match collect_patches().await { + Ok(patches) => { + if patches.is_empty() { + debug!("No patches collected for device {}", device_uid); + continue; + } + debug!("Collected {} patches for device {}", patches.len(), device_uid); + let payload = PatchStatusPayload { + device_uid: device_uid.clone(), + patches, + }; + if let Ok(frame) = Frame::new_json(MessageType::PatchStatusReport, &payload) { + if data_tx.send(frame).await.is_err() { + warn!("Failed to send patch status: channel closed"); + break; + } + } + } + Err(e) => { + warn!("Patch collection failed: {}", e); + } + } + } + } + } +} + +async fn collect_patches() -> anyhow::Result> { + // SECURITY: PowerShell command uses only hardcoded strings with no user/remote input. + // The format!() only inserts PowerShell syntax, not external data. 
+ let output = tokio::process::Command::new("powershell") + .args([ + "-NoProfile", "-NonInteractive", "-Command", + &format!( + "[Console]::OutputEncoding = [System.Text.Encoding]::UTF8; \ + Get-HotFix | Select-Object -First 200 | \ + ForEach-Object {{ \ + [PSCustomObject]@{{ \ + kb = $_.HotFixID; \ + desc = $_.Description; \ + installed = if ($_.InstalledOn) {{ $_.InstalledOn.ToString('yyyy-MM-dd') }} else {{ '' }} \ + }} \ + }} | ConvertTo-Json -Compress" + ), + ]) + .output() + .await?; + + if !output.status.success() { + let stderr = String::from_utf8_lossy(&output.stderr); + return Err(anyhow::anyhow!("PowerShell failed: {}", stderr)); + } + + let stdout = String::from_utf8_lossy(&output.stdout); + let trimmed = stdout.trim(); + if trimmed.is_empty() { + return Ok(Vec::new()); + } + + // Handle single-item case (PowerShell returns object instead of array) + let items: Vec = if trimmed.starts_with('[') { + serde_json::from_str(trimmed).unwrap_or_default() + } else { + serde_json::from_str(trimmed).map(|v: serde_json::Value| vec![v]).unwrap_or_default() + }; + + let patches: Vec = items.iter().filter_map(|item| { + let kb = item.get("kb")?.as_str()?.to_string(); + if kb.is_empty() { return None; } + let desc = item.get("desc").and_then(|v| v.as_str()).unwrap_or(""); + let installed_str = item.get("installed").and_then(|v| v.as_str()).unwrap_or(""); + + Some(PatchEntry { + title: format!("{} - {}", kb, desc), + kb_id: kb, + severity: None, // Will be enriched server-side from known CVE data + is_installed: true, + installed_at: if installed_str.is_empty() { None } else { Some(installed_str.to_string()) }, + }) + }).collect(); + + Ok(patches) +} diff --git a/crates/client/src/service.rs b/crates/client/src/service.rs index b1be0aa..917054f 100644 --- a/crates/client/src/service.rs +++ b/crates/client/src/service.rs @@ -1,9 +1,10 @@ use std::ffi::OsString; use std::time::Duration; -use tracing::{info, error}; +use tracing::{info, error, debug}; use 
windows_service::define_windows_service; use windows_service::service::{ - ServiceAccess, ServiceControl, ServiceErrorControl, ServiceExitCode, + ServiceAccess, ServiceAction, ServiceActionType, ServiceControl, ServiceErrorControl, + ServiceExitCode, ServiceFailureActions, ServiceFailureResetPeriod, ServiceInfo, ServiceStartType, ServiceState, ServiceStatus, ServiceType, }; use windows_service::service_control_handler::{self, ServiceControlHandlerResult}; @@ -38,6 +39,30 @@ pub fn install() -> anyhow::Result<()> { let service = manager.create_service(&service_info, ServiceAccess::CHANGE_CONFIG)?; service.set_description(SERVICE_DESCRIPTION)?; + // Configure service recovery: restart on failure with escalating delays + let failure_actions = ServiceFailureActions { + reset_period: ServiceFailureResetPeriod::After(std::time::Duration::from_secs(60)), + reboot_msg: None, + command: None, + actions: Some(vec![ + ServiceAction { + action_type: ServiceActionType::Restart, + delay: std::time::Duration::from_secs(5), + }, + ServiceAction { + action_type: ServiceActionType::Restart, + delay: std::time::Duration::from_secs(30), + }, + ServiceAction { + action_type: ServiceActionType::Restart, + delay: std::time::Duration::from_secs(60), + }, + ]), + }; + if let Err(e) = service.update_failure_actions(failure_actions) { + tracing::warn!("Failed to set service recovery actions: {}", e); + } + println!("Service '{}' installed successfully.", SERVICE_NAME); println!("Use 'sc start {}' or restart to launch.", SERVICE_NAME); Ok(()) @@ -117,7 +142,7 @@ fn run_service_inner() -> anyhow::Result<()> { // Build ClientState let device_uid = crate::load_or_create_device_uid()?; - info!("Device UID: {}", device_uid); + debug!("Device UID: {}", device_uid); let state = crate::ClientState { device_uid, diff --git a/crates/client/src/software_blocker/mod.rs b/crates/client/src/software_blocker/mod.rs index 1e7cebe..235fa13 100644 --- a/crates/client/src/software_blocker/mod.rs +++ 
b/crates/client/src/software_blocker/mod.rs @@ -1,11 +1,13 @@ +use std::collections::HashSet; use tokio::sync::watch; -use tracing::{info, warn}; +use tracing::{info, warn, debug}; use csm_protocol::{Frame, MessageType, SoftwareViolationReport}; use serde::Deserialize; /// System-critical processes that must never be killed regardless of server rules. /// Killing any of these would cause system instability or a BSOD. const PROTECTED_PROCESSES: &[&str] = &[ + // Windows system processes "system", "system idle process", "svchost.exe", @@ -20,8 +22,60 @@ const PROTECTED_PROCESSES: &[&str] = &[ "registry", "smss.exe", "conhost.exe", + "ntoskrnl.exe", + "dcomlaunch.exe", + "rundll32.exe", + "sihost.exe", + "taskeng.exe", + "wermgr.exe", + "WerFault.exe", + "fontdrvhost.exe", + "ctfmon.exe", + "SearchIndexer.exe", + "SearchHost.exe", + "RuntimeBroker.exe", + "SecurityHealthService.exe", + "SecurityHealthSystray.exe", + "MpCmdRun.exe", + "MsMpEng.exe", + "NisSrv.exe", + // Common browsers — should never be blocked unless explicitly configured with exact name + "chrome.exe", + "msedge.exe", + "firefox.exe", + "iexplore.exe", + "opera.exe", + "brave.exe", + "vivaldi.exe", + "thorium.exe", + // Development tools & IDEs + "code.exe", + "devenv.exe", + "idea64.exe", + "webstorm64.exe", + "pycharm64.exe", + "goland64.exe", + "clion64.exe", + "rider64.exe", + "datagrip64.exe", + "trae.exe", + "windsurf.exe", + "cursor.exe", + "zed.exe", + // Terminal & system tools + "cmd.exe", + "powershell.exe", + "pwsh.exe", + "WindowsTerminal.exe", + "conhost.exe", + // CSM itself + "csm-client.exe", ]; +/// Cooldown period (seconds) before reporting/killing the same process again. +/// Prevents spamming violations for long-running blocked processes. 
+const REPORT_COOLDOWN_SECS: u64 = 300; // 5 minutes + /// Software blacklist entry from server #[derive(Debug, Clone, Deserialize)] pub struct BlacklistEntry { @@ -36,6 +90,8 @@ pub struct BlacklistEntry { pub struct SoftwareBlockerConfig { pub enabled: bool, pub blacklist: Vec, + /// Server-pushed whitelist: processes matching these patterns are never blocked. + pub whitelist: Vec, } /// Start software blocker plugin. @@ -47,6 +103,11 @@ pub async fn start( ) { info!("Software blocker plugin started"); let mut config = SoftwareBlockerConfig::default(); + // Track recently acted-on processes to avoid repeated kill/report spam + let mut recent_actions: HashSet = HashSet::new(); + let mut cooldown_interval = tokio::time::interval(std::time::Duration::from_secs(REPORT_COOLDOWN_SECS)); + cooldown_interval.tick().await; + let mut scan_interval = tokio::time::interval(std::time::Duration::from_secs(10)); scan_interval.tick().await; @@ -58,13 +119,23 @@ pub async fn start( } let new_config = config_rx.borrow_and_update().clone(); info!("Software blocker config updated: enabled={}, blacklist={}", new_config.enabled, new_config.blacklist.len()); + // Clear cooldown cache when config changes so new rules take effect immediately + recent_actions.clear(); config = new_config; } + _ = cooldown_interval.tick() => { + // Periodically clear the cooldown cache so we can re-check + let cleared = recent_actions.len(); + recent_actions.clear(); + if cleared > 0 { + info!("Software blocker cooldown cache cleared ({} entries)", cleared); + } + } _ = scan_interval.tick() => { if !config.enabled || config.blacklist.is_empty() { continue; } - scan_processes(&config.blacklist, &data_tx, &device_uid).await; + scan_processes(&config.blacklist, &config.whitelist, &data_tx, &device_uid, &mut recent_actions).await; } } } @@ -72,83 +143,148 @@ pub async fn start( async fn scan_processes( blacklist: &[BlacklistEntry], + whitelist: &[String], data_tx: &tokio::sync::mpsc::Sender, device_uid: &str, 
+ recent_actions: &mut HashSet, ) { let running = get_running_processes_with_pids(); for entry in blacklist { for (process_name, pid) in &running { - if pattern_matches(&entry.name_pattern, process_name) { - // Never kill system-critical processes - if is_protected_process(process_name) { - warn!("Blacklisted match '{}' skipped — system-critical process (pid={})", process_name, pid); - continue; - } - - warn!("Blacklisted software detected: {} (action: {})", process_name, entry.action); - - // Report violation to server - // Map action to DB-compatible values: "block" -> "blocked_install", "alert" -> "alerted" - let action_taken = match entry.action.as_str() { - "block" => "blocked_install", - "alert" => "alerted", - other => other, - }; - let violation = SoftwareViolationReport { - device_uid: device_uid.to_string(), - software_name: process_name.clone(), - action_taken: action_taken.to_string(), - timestamp: chrono::Utc::now().to_rfc3339(), - }; - if let Ok(frame) = Frame::new_json(MessageType::SoftwareViolation, &violation) { - let _ = data_tx.send(frame).await; - } - - // Kill the process directly by captured PID (avoids TOCTOU race) - if entry.action == "block" { - kill_process_by_pid(*pid, process_name); - } + if !pattern_matches(&entry.name_pattern, process_name) { + continue; } + + // Never kill protected processes (system, browsers, IDEs, etc.) 
+ if is_protected_process(process_name) { + warn!( + "Blacklisted match '{}' skipped — protected process (pid={})", + process_name, pid + ); + continue; + } + + // Check server-pushed whitelist (takes precedence over blacklist) + if is_whitelisted(process_name, whitelist) { + debug!( + "Blacklisted match '{}' skipped — whitelisted (pid={})", + process_name, pid + ); + continue; + } + + // Skip if already acted on recently (cooldown) + let action_key = format!("{}:{}", process_name.to_lowercase(), pid); + if recent_actions.contains(&action_key) { + continue; + } + + warn!("Blacklisted software detected: {} (action: {})", process_name, entry.action); + + // Report violation to server + let action_taken = match entry.action.as_str() { + "block" => "blocked_install", + "alert" => "alerted", + other => other, + }; + let violation = SoftwareViolationReport { + device_uid: device_uid.to_string(), + software_name: process_name.clone(), + action_taken: action_taken.to_string(), + timestamp: chrono::Utc::now().to_rfc3339(), + }; + if let Ok(frame) = Frame::new_json(MessageType::SoftwareViolation, &violation) { + let _ = data_tx.send(frame).await; + } + + // Kill the process if action is "block" + if entry.action == "block" { + kill_process_by_pid(*pid, process_name); + } + + // Mark as recently acted-on + recent_actions.insert(action_key); } } } +/// Match a blacklist pattern against a process name. +/// +/// **Matching rules (case-insensitive)**: +/// - No wildcard → exact filename match only (e.g. `chrome.exe` matches `chrome.exe` but NOT `new_chrome.exe`) +/// - `*` wildcard → glob-style pattern match (e.g. `*miner*` matches `bitcoin_miner.exe`) +/// - Pattern with `.exe` suffix → match against full process name +/// - Pattern without extension → match against stem (name without `.exe`) +/// +/// **IMPORTANT**: We intentionally do NOT use substring matching (`contains()`) for +/// non-wildcard patterns. 
Substring matching caused false positives where a pattern +/// like "game" would match "game_bar.exe" or even "svchost.exe" in edge cases. fn pattern_matches(pattern: &str, name: &str) -> bool { let pattern_lower = pattern.to_lowercase(); let name_lower = name.to_lowercase(); - // Support wildcard patterns - if pattern_lower.contains('*') { - let parts: Vec<&str> = pattern_lower.split('*').collect(); - let mut pos = 0; - for (i, part) in parts.iter().enumerate() { - if part.is_empty() { - continue; - } - if i == 0 && !parts[0].is_empty() { - // Pattern starts with literal → must match at start - if !name_lower.starts_with(part) { - return false; - } - pos = part.len(); - } else { - match name_lower[pos..].find(part) { - Some(idx) => pos += idx + part.len(), - None => return false, - } - } - } - // If pattern ends with literal (no trailing *), must match at end - if !parts.last().map_or(true, |p| p.is_empty()) { - return name_lower.ends_with(parts.last().unwrap()); - } - true + // Extract the stem (filename without extension) for extension-less patterns + let name_stem = name_lower.strip_suffix(".exe").unwrap_or(&name_lower); + + if pattern_lower.contains('*') { + // Wildcard glob matching + glob_matches(&pattern_lower, &name_lower, &name_stem) } else { - name_lower.contains(&pattern_lower) + // Exact match: compare against full name OR stem (if pattern has no extension) + let pattern_stem = pattern_lower.strip_suffix(".exe").unwrap_or(&pattern_lower); + name_lower == pattern_lower || name_stem == pattern_stem } } +/// Glob-style pattern matching with `*` wildcards. +/// Checks both the full name and the stem (without .exe). 
+fn glob_matches(pattern: &str, full_name: &str, stem: &str) -> bool { + // Try matching against both the full process name and the stem + glob_matches_single(pattern, full_name) || glob_matches_single(pattern, stem) +} + +fn glob_matches_single(pattern: &str, text: &str) -> bool { + let parts: Vec<&str> = pattern.split('*').collect(); + + if parts.is_empty() { + return true; + } + + let mut pos = 0; + + for (i, part) in parts.iter().enumerate() { + if part.is_empty() { + continue; + } + + if i == 0 { + // First segment: must match at the start + if !text.starts_with(part) { + return false; + } + pos = part.len(); + } else if i == parts.len() - 1 { + // Last segment: must match at the end (if pattern doesn't end with *) + return text.ends_with(part) && text[..text.len() - part.len()].len() >= pos; + } else { + // Middle segment: must appear after current position + match text[pos..].find(part) { + Some(idx) => pos += idx + part.len(), + None => return false, + } + } + } + + // If pattern ends with *, anything after last match is fine + if pattern.ends_with('*') { + return true; + } + + // Full match required + pos == text.len() +} + /// Get all running processes with their PIDs (single snapshot, no TOCTOU) fn get_running_processes_with_pids() -> Vec<(String, u32)> { #[cfg(target_os = "windows")] @@ -253,3 +389,12 @@ fn is_protected_process(name: &str) -> bool { let lower = name.to_lowercase(); PROTECTED_PROCESSES.iter().any(|p| lower == **p || lower.ends_with(&format!("/{}", p).replace('/', "\\"))) } + +/// Check if a process name matches any server-pushed whitelist pattern. +/// Whitelist patterns use the same matching logic as blacklist (exact or glob). 
+fn is_whitelisted(process_name: &str, whitelist: &[String]) -> bool { + if whitelist.is_empty() { + return false; + } + whitelist.iter().any(|pattern| pattern_matches(pattern, process_name)) +} diff --git a/crates/protocol/src/lib.rs b/crates/protocol/src/lib.rs index 2927587..aa3e446 100644 --- a/crates/protocol/src/lib.rs +++ b/crates/protocol/src/lib.rs @@ -29,4 +29,6 @@ pub use message::{ PrintEventPayload, ClipboardRulesPayload, ClipboardRule, ClipboardViolationPayload, PopupBlockStatsPayload, PopupRuleStat, + PatchStatusPayload, PatchEntry, PatchScanConfigPayload, + BehaviorMetricsPayload, }; diff --git a/crates/protocol/src/message.rs b/crates/protocol/src/message.rs index bec7280..d275144 100644 --- a/crates/protocol/src/message.rs +++ b/crates/protocol/src/message.rs @@ -71,6 +71,14 @@ pub enum MessageType { // Plugin: Clipboard Control (剪贴板管控) ClipboardRules = 0x94, ClipboardViolation = 0x95, + + // Plugin: Patch Management (补丁管理) + PatchStatusReport = 0xA0, + PatchScanConfig = 0xA1, + PatchInstallCommand = 0xA2, + + // Plugin: Behavior Metrics (行为指标) + BehaviorMetricsReport = 0xB0, } impl TryFrom for MessageType { @@ -108,6 +116,10 @@ impl TryFrom for MessageType { 0x91 => Ok(Self::PrintEvent), 0x94 => Ok(Self::ClipboardRules), 0x95 => Ok(Self::ClipboardViolation), + 0xA0 => Ok(Self::PatchStatusReport), + 0xA1 => Ok(Self::PatchScanConfig), + 0xA2 => Ok(Self::PatchInstallCommand), + 0xB0 => Ok(Self::BehaviorMetricsReport), _ => Err(format!("Unknown message type: 0x{:02X}", value)), } } @@ -264,7 +276,7 @@ pub struct TaskExecutePayload { #[derive(Debug, Serialize, Deserialize)] pub enum ConfigUpdateType { UpdateIntervals { heartbeat: u64, status: u64, asset: u64 }, - TlsCertRotate, + TlsCertRotate { new_cert_hash: String, valid_until: String }, SelfDestruct, } @@ -442,6 +454,44 @@ pub struct PopupRuleStat { pub hits: u32, } +/// Plugin: Patch Status Report (Client → Server) +#[derive(Debug, Serialize, Deserialize)] +pub struct PatchStatusPayload { + pub 
device_uid: String, + pub patches: Vec, +} + +/// Information about a single patch/hotfix. +#[derive(Debug, Serialize, Deserialize)] +pub struct PatchEntry { + pub kb_id: String, + pub title: String, + pub severity: Option, // "Critical" | "Important" | "Moderate" | "Low" + pub is_installed: bool, + pub installed_at: Option, +} + +/// Plugin: Patch Scan Config (Server → Client) +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct PatchScanConfigPayload { + pub enabled: bool, + pub scan_interval_secs: u64, +} + +/// Plugin: Behavior Metrics Report (Client → Server) +/// Enhanced periodic metrics for anomaly detection. +#[derive(Debug, Serialize, Deserialize)] +pub struct BehaviorMetricsPayload { + pub device_uid: String, + pub clipboard_ops_count: u32, + pub clipboard_ops_night: u32, + pub print_jobs_count: u32, + pub usb_file_ops_count: u32, + pub new_processes_count: u32, + pub period_secs: u64, + pub timestamp: String, +} + #[cfg(test)] mod tests { use super::*; diff --git a/crates/server/src/alert.rs b/crates/server/src/alert.rs index 9047098..167e74b 100644 --- a/crates/server/src/alert.rs +++ b/crates/server/src/alert.rs @@ -41,6 +41,26 @@ pub async fn cleanup_task(state: AppState) { error!("Failed to cleanup alert records: {}", e); } + // Cleanup old revoked token families (keep 30 days for audit) + if let Err(e) = sqlx::query( + "DELETE FROM revoked_token_families WHERE revoked_at < datetime('now', '-30 days')" + ) + .execute(&state.db) + .await + { + error!("Failed to cleanup revoked token families: {}", e); + } + + // Cleanup old anomaly alerts that have been handled + if let Err(e) = sqlx::query( + "DELETE FROM anomaly_alerts WHERE handled = 1 AND triggered_at < datetime('now', '-90 days')" + ) + .execute(&state.db) + .await + { + error!("Failed to cleanup handled anomaly alerts: {}", e); + } + // Mark devices as offline if no heartbeat for 2 minutes if let Err(e) = sqlx::query( "UPDATE devices SET status = 'offline' WHERE status = 'online' AND 
last_heartbeat < datetime('now', '-2 minutes')" diff --git a/crates/server/src/anomaly.rs b/crates/server/src/anomaly.rs new file mode 100644 index 0000000..de4cb40 --- /dev/null +++ b/crates/server/src/anomaly.rs @@ -0,0 +1,170 @@ +use sqlx::Row; +use tracing::{info, warn}; +use csm_protocol::BehaviorMetricsPayload; + +/// Check incoming behavior metrics against anomaly rules and generate alerts +pub async fn check_anomalies( + pool: &sqlx::SqlitePool, + ws_hub: &crate::ws::WsHub, + metrics: &BehaviorMetricsPayload, +) { + let mut alerts: Vec = Vec::new(); + + // Rule 1: Night-time clipboard operations (> 10 in reporting period) + if metrics.clipboard_ops_night > 10 { + alerts.push(serde_json::json!({ + "anomaly_type": "night_clipboard_spike", + "severity": "high", + "detail": format!("非工作时间剪贴板操作异常: {}次 (阈值: 10次)", metrics.clipboard_ops_night), + "metric_value": metrics.clipboard_ops_night, + })); + } + + // Rule 2: High USB file operations (> 100 per hour) + if metrics.period_secs > 0 { + let usb_per_hour = (metrics.usb_file_ops_count as f64 / metrics.period_secs as f64) * 3600.0; + if usb_per_hour > 100.0 { + alerts.push(serde_json::json!({ + "anomaly_type": "usb_file_exfiltration", + "severity": "critical", + "detail": format!("USB文件操作频率异常: {:.0}次/小时 (阈值: 100次/小时)", usb_per_hour), + "metric_value": usb_per_hour, + })); + } + } + + // Rule 3: High print volume (> 50 per reporting period) + if metrics.print_jobs_count > 50 { + alerts.push(serde_json::json!({ + "anomaly_type": "high_print_volume", + "severity": "medium", + "detail": format!("打印量异常: {}次 (阈值: 50次)", metrics.print_jobs_count), + "metric_value": metrics.print_jobs_count, + })); + } + + // Rule 4: Excessive new processes (> 20 per hour) + if metrics.period_secs > 0 { + let procs_per_hour = (metrics.new_processes_count as f64 / metrics.period_secs as f64) * 3600.0; + if procs_per_hour > 20.0 { + alerts.push(serde_json::json!({ + "anomaly_type": "process_spawn_spike", + "severity": "medium", + "detail": 
format!("新进程启动异常: {:.0}次/小时 (阈值: 20次/小时)", procs_per_hour), + "metric_value": procs_per_hour, + })); + } + } + + // Insert anomaly alerts + for alert in &alerts { + if let Err(e) = sqlx::query( + "INSERT INTO anomaly_alerts (device_uid, anomaly_type, severity, detail, metric_value, triggered_at) \ + VALUES (?, ?, ?, ?, ?, datetime('now'))" + ) + .bind(&metrics.device_uid) + .bind(alert.get("anomaly_type").and_then(|v| v.as_str()).unwrap_or("unknown")) + .bind(alert.get("severity").and_then(|v| v.as_str()).unwrap_or("medium")) + .bind(alert.get("detail").and_then(|v| v.as_str()).unwrap_or("")) + .bind(alert.get("metric_value").and_then(|v| v.as_f64()).unwrap_or(0.0)) + .execute(pool) + .await + { + warn!("Failed to insert anomaly alert: {}", e); + } + } + + // Broadcast anomaly alerts via WebSocket + if !alerts.is_empty() { + for alert in &alerts { + ws_hub.broadcast(serde_json::json!({ + "type": "anomaly_alert", + "device_uid": metrics.device_uid, + "anomaly_type": alert.get("anomaly_type"), + "severity": alert.get("severity"), + "detail": alert.get("detail"), + }).to_string()).await; + } + info!("Detected {} anomalies for device {}", alerts.len(), metrics.device_uid); + } +} + +/// Get anomaly alert summary for a device or all devices +pub async fn get_anomaly_summary( + pool: &sqlx::SqlitePool, + device_uid: Option<&str>, + page: u32, + page_size: u32, +) -> anyhow::Result { + let offset = page.saturating_sub(1) * page_size; + + let rows = if let Some(uid) = device_uid { + sqlx::query( + "SELECT a.*, d.hostname FROM anomaly_alerts a JOIN devices d ON d.device_uid = a.device_uid \ + WHERE a.device_uid = ? ORDER BY a.triggered_at DESC LIMIT ? OFFSET ?" + ) + .bind(uid) + .bind(page_size) + .bind(offset) + .fetch_all(pool) + .await? + } else { + sqlx::query( + "SELECT a.*, d.hostname FROM anomaly_alerts a JOIN devices d ON d.device_uid = a.device_uid \ + ORDER BY a.triggered_at DESC LIMIT ? OFFSET ?" 
+ ) + .bind(page_size) + .bind(offset) + .fetch_all(pool) + .await? + }; + + let total: i64 = if let Some(uid) = device_uid { + sqlx::query_scalar("SELECT COUNT(*) FROM anomaly_alerts WHERE device_uid = ?") + .bind(uid) + .fetch_one(pool) + .await? + } else { + sqlx::query_scalar("SELECT COUNT(*) FROM anomaly_alerts") + .fetch_one(pool) + .await? + }; + + let alerts: Vec = rows.iter().map(|r| serde_json::json!({ + "id": r.get::("id"), + "device_uid": r.get::("device_uid"), + "hostname": r.get::("hostname"), + "anomaly_type": r.get::("anomaly_type"), + "severity": r.get::("severity"), + "detail": r.get::("detail"), + "metric_value": r.get::("metric_value"), + "handled": r.get::("handled"), + "triggered_at": r.get::("triggered_at"), + })).collect(); + + // Summary counts (scoped to same filter) + let unhandled: i64 = if let Some(uid) = device_uid { + sqlx::query_scalar( + "SELECT COUNT(*) FROM anomaly_alerts WHERE handled = 0 AND device_uid = ?" + ) + .bind(uid) + .fetch_one(pool) + .await + .unwrap_or(0) + } else { + sqlx::query_scalar( + "SELECT COUNT(*) FROM anomaly_alerts WHERE handled = 0" + ) + .fetch_one(pool) + .await + .unwrap_or(0) + }; + + Ok(serde_json::json!({ + "alerts": alerts, + "total": total, + "unhandled_count": unhandled, + "page": page, + "page_size": page_size, + })) +} diff --git a/crates/server/src/api/alerts.rs b/crates/server/src/api/alerts.rs index 7d94ea4..cf326b2 100644 --- a/crates/server/src/api/alerts.rs +++ b/crates/server/src/api/alerts.rs @@ -21,7 +21,7 @@ pub async fn list_rules( ) -> Json> { let rows = sqlx::query( "SELECT id, name, rule_type, condition, severity, enabled, notify_email, notify_webhook, created_at, updated_at - FROM alert_rules ORDER BY created_at DESC" + FROM alert_rules ORDER BY created_at DESC LIMIT 500" ) .fetch_all(&state.db) .await; @@ -116,7 +116,26 @@ pub async fn create_rule( State(state): State, Json(body): Json, ) -> (StatusCode, Json>) { + // Validate rule_type + if !matches!(body.rule_type.as_str(), 
"device_offline" | "usb_event" | "web_access" | "software_violation" | "custom") { + return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("Invalid rule_type"))); + } + + // Validate severity let severity = body.severity.unwrap_or_else(|| "medium".to_string()); + if !matches!(severity.as_str(), "low" | "medium" | "high" | "critical") { + return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("Invalid severity"))); + } + + // Validate webhook URL (SSRF prevention) + if let Some(ref url) = body.notify_webhook { + if !url.starts_with("https://") { + return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("Webhook URL must use HTTPS"))); + } + if url.len() > 2048 { + return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("Webhook URL too long"))); + } + } let result = sqlx::query( "INSERT INTO alert_rules (name, rule_type, condition, severity, notify_email, notify_webhook) @@ -174,6 +193,26 @@ pub async fn update_rule( let notify_email = body.notify_email.or_else(|| existing.get::, _>("notify_email")); let notify_webhook = body.notify_webhook.or_else(|| existing.get::, _>("notify_webhook")); + // Validate rule_type + if !matches!(rule_type.as_str(), "device_offline" | "usb_event" | "web_access" | "software_violation" | "custom") { + return Json(ApiResponse::error("Invalid rule_type")); + } + + // Validate severity + if !matches!(severity.as_str(), "low" | "medium" | "high" | "critical") { + return Json(ApiResponse::error("Invalid severity")); + } + + // Validate webhook URL (SSRF prevention) + if let Some(ref url) = notify_webhook { + if !url.starts_with("https://") { + return Json(ApiResponse::error("Webhook URL must use HTTPS")); + } + if url.len() > 2048 { + return Json(ApiResponse::error("Webhook URL too long")); + } + } + let result = sqlx::query( "UPDATE alert_rules SET name = ?, rule_type = ?, condition = ?, severity = ?, enabled = ?, notify_email = ?, notify_webhook = ?, updated_at = datetime('now') WHERE id = ?" 
diff --git a/crates/server/src/api/auth.rs b/crates/server/src/api/auth.rs index a3c0073..fdc2498 100644 --- a/crates/server/src/api/auth.rs +++ b/crates/server/src/api/auth.rs @@ -1,4 +1,5 @@ -use axum::{extract::State, Json, http::StatusCode, extract::Request, middleware::Next, response::Response}; +use axum::{extract::State, Json, http::StatusCode, extract::Request, middleware::Next, response::{Response, IntoResponse}}; +use axum::http::header::{SET_COOKIE, HeaderValue}; use serde::{Deserialize, Serialize}; use jsonwebtoken::{encode, decode, Header, EncodingKey, DecodingKey, Validation}; use std::sync::Arc; @@ -28,11 +29,15 @@ pub struct LoginRequest { #[derive(Debug, Serialize)] pub struct LoginResponse { - pub access_token: String, - pub refresh_token: String, pub user: UserInfo, } +#[derive(Debug, Serialize)] +pub struct MeResponse { + pub user: UserInfo, + pub expires_at: String, +} + #[derive(Debug, Serialize, sqlx::FromRow)] pub struct UserInfo { pub id: i64, @@ -40,18 +45,68 @@ pub struct UserInfo { pub role: String, } -#[derive(Debug, Deserialize)] -pub struct RefreshRequest { - pub refresh_token: String, -} - #[derive(Debug, Deserialize)] pub struct ChangePasswordRequest { pub old_password: String, pub new_password: String, } -/// In-memory rate limiter for login attempts +// --------------------------------------------------------------------------- +// Cookie helpers +// --------------------------------------------------------------------------- + +fn is_secure_cookies() -> bool { + std::env::var("CSM_DEV").is_err() +} + +fn access_cookie_header(token: &str, ttl_secs: u64) -> HeaderValue { + let secure = if is_secure_cookies() { "; Secure" } else { "" }; + HeaderValue::from_str(&format!( + "access_token={}; HttpOnly{}; SameSite=Strict; Path=/; Max-Age={}", + token, secure, ttl_secs + )).expect("valid cookie header") +} + +fn refresh_cookie_header(token: &str, ttl_secs: u64) -> HeaderValue { + let secure = if is_secure_cookies() { "; Secure" } else { 
"" }; + HeaderValue::from_str(&format!( + "refresh_token={}; HttpOnly{}; SameSite=Strict; Path=/api/auth/refresh; Max-Age={}", + token, secure, ttl_secs + )).expect("valid cookie header") +} + +fn clear_cookie_headers() -> Vec { + let secure = if is_secure_cookies() { "; Secure" } else { "" }; + vec![ + HeaderValue::from_str(&format!("access_token=; HttpOnly{}; SameSite=Strict; Path=/; Max-Age=0", secure)).expect("valid"), + HeaderValue::from_str(&format!("refresh_token=; HttpOnly{}; SameSite=Strict; Path=/api/auth/refresh; Max-Age=0", secure)).expect("valid"), + ] +} + +/// Attach Set-Cookie headers to a response. +fn with_cookies(mut response: Response, cookies: Vec) -> Response { + for cookie in cookies { + response.headers_mut().append(SET_COOKIE, cookie); + } + response +} + +/// Extract a cookie value by name from the raw Cookie header. +fn extract_cookie_value(headers: &axum::http::HeaderMap, name: &str) -> Option { + let cookie_header = headers.get("cookie")?.to_str().ok()?; + for cookie in cookie_header.split(';') { + let cookie = cookie.trim(); + if let Some(value) = cookie.strip_prefix(&format!("{}=", name)) { + return Some(value.to_string()); + } + } + None +} + +// --------------------------------------------------------------------------- +// Rate limiter +// --------------------------------------------------------------------------- + #[derive(Clone, Default)] pub struct LoginRateLimiter { attempts: Arc>>, @@ -62,28 +117,25 @@ impl LoginRateLimiter { Self::default() } - /// Returns true if the request should be rate-limited pub async fn is_limited(&self, key: &str) -> bool { let mut attempts = self.attempts.lock().await; let now = Instant::now(); - let window = std::time::Duration::from_secs(300); // 5-minute window + let window = std::time::Duration::from_secs(300); let max_attempts = 10u32; if let Some((first_attempt, count)) = attempts.get_mut(key) { if now.duration_since(*first_attempt) > window { - // Window expired, reset *first_attempt = now; 
*count = 1; false } else if *count >= max_attempts { - true // Rate limited + true } else { *count += 1; false } } else { attempts.insert(key.to_string(), (now, 1)); - // Cleanup old entries periodically if attempts.len() > 1000 { let cutoff = now - window; attempts.retain(|_, (t, _)| *t > cutoff); @@ -93,46 +145,67 @@ impl LoginRateLimiter { } } +// --------------------------------------------------------------------------- +// Endpoints +// --------------------------------------------------------------------------- + pub async fn login( State(state): State, Json(req): Json, -) -> Result<(StatusCode, Json>), StatusCode> { - // Rate limit check +) -> impl IntoResponse { if state.login_limiter.is_limited(&req.username).await { - return Ok((StatusCode::TOO_MANY_REQUESTS, Json(ApiResponse::error("Too many login attempts. Try again later.")))); + return (StatusCode::TOO_MANY_REQUESTS, Json(ApiResponse::::error("Too many login attempts. Try again later."))).into_response(); + } + if state.login_limiter.is_limited("ip:default").await { + return (StatusCode::TOO_MANY_REQUESTS, Json(ApiResponse::::error("Too many login attempts from this location. Try again later."))).into_response(); } - let user: Option = sqlx::query_as::<_, UserInfo>( - "SELECT id, username, role FROM users WHERE username = ?" + let row: Option<(UserInfo, String)> = sqlx::query_as::<_, (i64, String, String, String)>( + "SELECT id, username, role, password FROM users WHERE username = ?" 
) .bind(&req.username) .fetch_optional(&state.db) .await - .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR) + .ok() + .flatten() + .map(|(id, username, role, password)| { + (UserInfo { id, username, role }, password) + }); - let user = match user { - Some(u) => u, - None => return Ok((StatusCode::UNAUTHORIZED, Json(ApiResponse::error("Invalid credentials")))), + let (user, hash) = match row { + Some(r) => r, + None => { + let _ = bcrypt::verify("timing-constant-dummy", "$2b$12$aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"); + return (StatusCode::UNAUTHORIZED, Json(ApiResponse::::error("Invalid credentials"))).into_response(); + } }; - let hash: String = sqlx::query_scalar::<_, String>( - "SELECT password FROM users WHERE id = ?" - ) - .bind(user.id) - .fetch_one(&state.db) - .await - .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; - if !bcrypt::verify(&req.password, &hash).unwrap_or(false) { - return Ok((StatusCode::UNAUTHORIZED, Json(ApiResponse::error("Invalid credentials")))); + return (StatusCode::UNAUTHORIZED, Json(ApiResponse::::error("Invalid credentials"))).into_response(); } let now = chrono::Utc::now().timestamp() as u64; let family = uuid::Uuid::new_v4().to_string(); - let access_token = create_token(&user, "access", state.config.auth.access_token_ttl_secs, now, &state.config.auth.jwt_secret, &family)?; - let refresh_token = create_token(&user, "refresh", state.config.auth.refresh_token_ttl_secs, now, &state.config.auth.jwt_secret, &family)?; + let access_token = match create_token(&user, "access", state.config.auth.access_token_ttl_secs, now, &state.config.auth.jwt_secret, &family) { + Ok(t) => t, + Err(_) => return StatusCode::INTERNAL_SERVER_ERROR.into_response(), + }; + let refresh_token = match create_token(&user, "refresh", state.config.auth.refresh_token_ttl_secs, now, &state.config.auth.jwt_secret, &family) { + Ok(t) => t, + Err(_) => return 
StatusCode::INTERNAL_SERVER_ERROR.into_response(), + }; + + let refresh_expires = now + state.config.auth.refresh_token_ttl_secs; + let _ = sqlx::query( + "INSERT INTO refresh_tokens (user_id, family, expires_at) VALUES (?, ?, datetime(?, 'unixepoch'))" + ) + .bind(user.id) + .bind(&family) + .bind(refresh_expires as i64) + .execute(&state.db) + .await; - // Audit log let _ = sqlx::query( "INSERT INTO admin_audit_log (user_id, action, detail) VALUES (?, 'login', ?)" ) @@ -141,73 +214,262 @@ pub async fn login( .execute(&state.db) .await; - Ok((StatusCode::OK, Json(ApiResponse::ok(LoginResponse { - access_token, - refresh_token, - user, - })))) + let response = (StatusCode::OK, Json(ApiResponse::ok(LoginResponse { user }))).into_response(); + with_cookies(response, vec![ + access_cookie_header(&access_token, state.config.auth.access_token_ttl_secs), + refresh_cookie_header(&refresh_token, state.config.auth.refresh_token_ttl_secs), + ]) } pub async fn refresh( State(state): State, - Json(req): Json, -) -> Result<(StatusCode, Json>), StatusCode> { - let claims = decode::( - &req.refresh_token, + headers: axum::http::HeaderMap, +) -> impl IntoResponse { + let refresh_token = match extract_cookie_value(&headers, "refresh_token") { + Some(t) => t, + None => return with_cookies( + (StatusCode::UNAUTHORIZED, Json(ApiResponse::::error("Missing refresh token"))).into_response(), + clear_cookie_headers(), + ), + }; + + let claims = match decode::( + &refresh_token, &DecodingKey::from_secret(state.config.auth.jwt_secret.as_bytes()), &Validation::default(), - ) - .map_err(|_| StatusCode::UNAUTHORIZED)?; + ) { + Ok(c) => c.claims, + Err(_) => return with_cookies( + (StatusCode::UNAUTHORIZED, Json(ApiResponse::::error("Invalid refresh token"))).into_response(), + clear_cookie_headers(), + ), + }; - if claims.claims.token_type != "refresh" { - return Ok((StatusCode::UNAUTHORIZED, Json(ApiResponse::error("Invalid token type")))); + if claims.token_type != "refresh" { + return 
with_cookies( + (StatusCode::UNAUTHORIZED, Json(ApiResponse::::error("Invalid token type"))).into_response(), + clear_cookie_headers(), + ); } - // Check if this refresh token family has been revoked (reuse detection) + let mut tx = match state.db.begin().await { + Ok(tx) => tx, + Err(_) => return StatusCode::INTERNAL_SERVER_ERROR.into_response(), + }; + let revoked: bool = sqlx::query_scalar::<_, i64>( "SELECT COUNT(*) FROM revoked_token_families WHERE family = ?" ) - .bind(&claims.claims.family) - .fetch_one(&state.db) + .bind(&claims.family) + .fetch_one(&mut *tx) .await .unwrap_or(0) > 0; if revoked { - // Token reuse detected — revoke entire family and force re-login - tracing::warn!("Refresh token reuse detected for user {} family {}", claims.claims.sub, claims.claims.family); + tx.rollback().await.ok(); + tracing::warn!("Refresh token reuse detected for user {} family {}", claims.sub, claims.family); let _ = sqlx::query("DELETE FROM refresh_tokens WHERE user_id = ?") - .bind(claims.claims.sub) + .bind(claims.sub) .execute(&state.db) .await; - return Ok((StatusCode::UNAUTHORIZED, Json(ApiResponse::error("Token reuse detected. Please log in again.")))); + return with_cookies( + (StatusCode::UNAUTHORIZED, Json(ApiResponse::::error("Token reuse detected. Please log in again."))).into_response(), + clear_cookie_headers(), + ); + } + + let family_exists: bool = sqlx::query_scalar::<_, i64>( + "SELECT COUNT(*) FROM refresh_tokens WHERE family = ? AND user_id = ?" 
+ ) + .bind(&claims.family) + .bind(claims.sub) + .fetch_one(&mut *tx) + .await + .unwrap_or(0) > 0; + + if !family_exists { + tx.rollback().await.ok(); + return (StatusCode::UNAUTHORIZED, Json(ApiResponse::::error("Invalid refresh token"))).into_response(); } let user = UserInfo { - id: claims.claims.sub, - username: claims.claims.username, - role: claims.claims.role, + id: claims.sub, + username: claims.username, + role: claims.role, }; - // Rotate: new family for each refresh let new_family = uuid::Uuid::new_v4().to_string(); let now = chrono::Utc::now().timestamp() as u64; - let access_token = create_token(&user, "access", state.config.auth.access_token_ttl_secs, now, &state.config.auth.jwt_secret, &new_family)?; - let refresh_token = create_token(&user, "refresh", state.config.auth.refresh_token_ttl_secs, now, &state.config.auth.jwt_secret, &new_family)?; + let access_token = match create_token(&user, "access", state.config.auth.access_token_ttl_secs, now, &state.config.auth.jwt_secret, &new_family) { + Ok(t) => t, + Err(_) => return StatusCode::INTERNAL_SERVER_ERROR.into_response(), + }; + let refresh_token = match create_token(&user, "refresh", state.config.auth.refresh_token_ttl_secs, now, &state.config.auth.jwt_secret, &new_family) { + Ok(t) => t, + Err(_) => return StatusCode::INTERNAL_SERVER_ERROR.into_response(), + }; - // Revoke old family - let _ = sqlx::query("INSERT OR IGNORE INTO revoked_token_families (family, user_id, revoked_at) VALUES (?, ?, datetime('now'))") - .bind(&claims.claims.family) - .bind(claims.claims.sub) - .execute(&state.db) - .await; + if sqlx::query("INSERT OR IGNORE INTO revoked_token_families (family, user_id, revoked_at) VALUES (?, ?, datetime('now'))") + .bind(&claims.family) + .bind(claims.sub) + .execute(&mut *tx) + .await + .is_err() + { + return StatusCode::INTERNAL_SERVER_ERROR.into_response(); + } - Ok((StatusCode::OK, Json(ApiResponse::ok(LoginResponse { - access_token, - refresh_token, - user, - })))) + let 
refresh_expires = now + state.config.auth.refresh_token_ttl_secs; + if sqlx::query( + "INSERT INTO refresh_tokens (user_id, family, expires_at) VALUES (?, ?, datetime(?, 'unixepoch'))" + ) + .bind(user.id) + .bind(&new_family) + .bind(refresh_expires as i64) + .execute(&mut *tx) + .await + .is_err() + { + return StatusCode::INTERNAL_SERVER_ERROR.into_response(); + } + + if tx.commit().await.is_err() { + return StatusCode::INTERNAL_SERVER_ERROR.into_response(); + } + + let response = (StatusCode::OK, Json(ApiResponse::ok(LoginResponse { user }))).into_response(); + with_cookies(response, vec![ + access_cookie_header(&access_token, state.config.auth.access_token_ttl_secs), + refresh_cookie_header(&refresh_token, state.config.auth.refresh_token_ttl_secs), + ]) } +/// Get current authenticated user info from access_token cookie. +pub async fn me( + State(state): State, + headers: axum::http::HeaderMap, +) -> impl IntoResponse { + let token = match extract_cookie_value(&headers, "access_token") { + Some(t) => t, + None => return (StatusCode::UNAUTHORIZED, Json(ApiResponse::::error("Not authenticated"))).into_response(), + }; + + let claims = match decode::( + &token, + &DecodingKey::from_secret(state.config.auth.jwt_secret.as_bytes()), + &Validation::default(), + ) { + Ok(c) => c.claims, + Err(_) => return (StatusCode::UNAUTHORIZED, Json(ApiResponse::::error("Invalid token"))).into_response(), + }; + + if claims.token_type != "access" { + return (StatusCode::UNAUTHORIZED, Json(ApiResponse::::error("Invalid token type"))).into_response(); + } + + let expires_at = chrono::DateTime::from_timestamp(claims.exp as i64, 0) + .map(|t| t.to_rfc3339()) + .unwrap_or_default(); + + (StatusCode::OK, Json(ApiResponse::ok(MeResponse { + user: UserInfo { + id: claims.sub, + username: claims.username, + role: claims.role, + }, + expires_at, + }))).into_response() +} + +/// Logout: clear auth cookies and revoke refresh token family. 
+pub async fn logout( + State(state): State, + headers: axum::http::HeaderMap, +) -> impl IntoResponse { + if let Some(token) = extract_cookie_value(&headers, "access_token") { + if let Ok(claims) = decode::( + &token, + &DecodingKey::from_secret(state.config.auth.jwt_secret.as_bytes()), + &Validation::default(), + ) { + let _ = sqlx::query("DELETE FROM refresh_tokens WHERE user_id = ?") + .bind(claims.claims.sub) + .execute(&state.db) + .await; + let _ = sqlx::query( + "INSERT INTO admin_audit_log (user_id, action, detail) VALUES (?, 'logout', ?)" + ) + .bind(claims.claims.sub) + .bind(format!("User {} logged out", claims.claims.username)) + .execute(&state.db) + .await; + } + } + + let response = (StatusCode::OK, Json(ApiResponse::ok(()))).into_response(); + with_cookies(response, clear_cookie_headers()) +} + +// --------------------------------------------------------------------------- +// WebSocket ticket +// --------------------------------------------------------------------------- + +#[derive(Debug, Serialize)] +pub struct WsTicketResponse { + pub ticket: String, + pub expires_in: u64, +} + +/// Create a one-time ticket for WebSocket authentication. +/// Requires a valid access_token cookie (set by login). 
+pub async fn create_ws_ticket( + State(state): State, + headers: axum::http::HeaderMap, +) -> impl IntoResponse { + let token = match extract_cookie_value(&headers, "access_token") { + Some(t) => t, + None => return (StatusCode::UNAUTHORIZED, Json(ApiResponse::::error("Not authenticated"))).into_response(), + }; + + let claims = match decode::( + &token, + &DecodingKey::from_secret(state.config.auth.jwt_secret.as_bytes()), + &Validation::default(), + ) { + Ok(c) => c.claims, + Err(_) => return (StatusCode::UNAUTHORIZED, Json(ApiResponse::::error("Invalid token"))).into_response(), + }; + + if claims.token_type != "access" { + return (StatusCode::UNAUTHORIZED, Json(ApiResponse::::error("Invalid token type"))).into_response(); + } + + let ticket = uuid::Uuid::new_v4().to_string(); + let claim = crate::ws::TicketClaim { + user_id: claims.sub, + username: claims.username, + role: claims.role, + created_at: std::time::Instant::now(), + }; + + { + let mut tickets = state.ws_tickets.lock().await; + tickets.insert(ticket.clone(), claim); + + // Cleanup expired tickets (>30s old) on every creation + tickets.retain(|_, c| c.created_at.elapsed().as_secs() < 30); + } + + (StatusCode::OK, Json(ApiResponse::ok(WsTicketResponse { + ticket, + expires_in: 30, + }))).into_response() +} + +// --------------------------------------------------------------------------- +// Internal helpers +// --------------------------------------------------------------------------- + fn create_token(user: &UserInfo, token_type: &str, ttl: u64, now: u64, secret: &str, family: &str) -> Result { let claims = Claims { sub: user.id, @@ -227,24 +489,17 @@ fn create_token(user: &UserInfo, token_type: &str, ttl: u64, now: u64, secret: & .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR) } -/// Axum middleware: require valid JWT access token +/// Axum middleware: require valid JWT access token from cookie pub async fn require_auth( State(state): State, mut request: Request, next: Next, ) -> Result { - let 
auth_header = request.headers() - .get("Authorization") - .and_then(|v| v.to_str().ok()) - .and_then(|v| v.strip_prefix("Bearer ")); - - let token = match auth_header { - Some(t) => t, - None => return Err(StatusCode::UNAUTHORIZED), - }; + let token = extract_cookie_value(request.headers(), "access_token") + .ok_or(StatusCode::UNAUTHORIZED)?; let claims = decode::( - token, + &token, &DecodingKey::from_secret(state.config.auth.jwt_secret.as_bytes()), &Validation::default(), ) @@ -254,9 +509,7 @@ pub async fn require_auth( return Err(StatusCode::UNAUTHORIZED); } - // Inject claims into request extensions for handlers to use request.extensions_mut().insert(claims.claims); - Ok(next.run(request).await) } @@ -274,7 +527,6 @@ pub async fn require_admin( return Err(StatusCode::FORBIDDEN); } - // Capture audit info before running handler let method = request.method().clone(); let path = request.uri().path().to_string(); let user_id = claims.sub; @@ -282,7 +534,6 @@ pub async fn require_admin( let response = next.run(request).await; - // Record admin action to audit log (fire and forget — don't block response) let status = response.status(); if status.is_success() { let action = format!("{} {}", method, path); @@ -308,8 +559,10 @@ pub async fn change_password( if req.new_password.len() < 6 { return Ok((StatusCode::BAD_REQUEST, Json(ApiResponse::error("新密码至少6位")))); } + if req.new_password.len() > 72 { + return Ok((StatusCode::BAD_REQUEST, Json(ApiResponse::error("密码不能超过72位")))); + } - // Verify old password let hash: String = sqlx::query_scalar::<_, String>( "SELECT password FROM users WHERE id = ?" ) @@ -322,7 +575,6 @@ pub async fn change_password( return Ok((StatusCode::BAD_REQUEST, Json(ApiResponse::error("当前密码错误")))); } - // Update password let new_hash = bcrypt::hash(&req.new_password, 12).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; sqlx::query("UPDATE users SET password = ? 
WHERE id = ?") .bind(&new_hash) @@ -331,7 +583,6 @@ pub async fn change_password( .await .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; - // Audit log let _ = sqlx::query( "INSERT INTO admin_audit_log (user_id, action, detail) VALUES (?, 'change_password', ?)" ) diff --git a/crates/server/src/api/conflict.rs b/crates/server/src/api/conflict.rs new file mode 100644 index 0000000..ca8133f --- /dev/null +++ b/crates/server/src/api/conflict.rs @@ -0,0 +1,243 @@ +use axum::{extract::State, Json}; +use serde::Serialize; +use sqlx::Row; +use crate::AppState; +use super::ApiResponse; + +#[derive(Debug, Serialize)] +pub struct PolicyConflict { + pub conflict_type: String, + pub severity: String, + pub description: String, + pub policies: Vec, +} + +#[derive(Debug, Serialize)] +pub struct ConflictPolicyRef { + pub table_name: String, + pub row_id: i64, + pub name: String, + pub target_type: String, + pub target_id: Option, +} + +/// GET /api/policies/conflicts — scan all policies for conflicts +pub async fn scan_conflicts( + State(state): State, +) -> Json> { + let mut conflicts: Vec = Vec::new(); + + // 1. 
USB: multiple enabled policies for the same target_group + if let Ok(rows) = sqlx::query( + "SELECT target_group, COUNT(*) as cnt, GROUP_CONCAT(id) as ids, GROUP_CONCAT(name) as names, \ + GROUP_CONCAT(policy_type) as types \ + FROM usb_policies WHERE enabled = 1 AND target_group IS NOT NULL \ + GROUP BY target_group HAVING cnt > 1" + ) + .fetch_all(&state.db) + .await + { + for row in &rows { + let group: String = row.get("target_group"); + let ids: String = row.get("ids"); + let names: String = row.get("names"); + let types: String = row.get("types"); + let id_vec: Vec = ids.split(',').filter_map(|s| s.parse().ok()).collect(); + let name_vec: Vec<&str> = names.split(',').collect(); + let type_vec: Vec<&str> = types.split(',').collect(); + + conflicts.push(PolicyConflict { + conflict_type: "usb_duplicate_policy".to_string(), + severity: "high".to_string(), + description: format!("分组 '{}' 同时存在 {} 条启用的USB策略 ({})", group, id_vec.len(), type_vec.join(", ")), + policies: id_vec.iter().enumerate().map(|(i, id)| ConflictPolicyRef { + table_name: "usb_policies".to_string(), + row_id: *id, + name: name_vec.get(i).unwrap_or(&"?").to_string(), + target_type: "group".to_string(), + target_id: Some(group.clone()), + }).collect(), + }); + } + } + + // 2. 
USB: all_block + whitelist for same target (contradictory) + if let Ok(rows) = sqlx::query( + "SELECT a.id as aid, a.name as aname, a.target_group as agroup, \ + b.id as bid, b.name as bname \ + FROM usb_policies a JOIN usb_policies b ON a.target_group = b.target_group AND a.id < b.id \ + WHERE a.enabled = 1 AND b.enabled = 1 \ + AND ((a.policy_type = 'all_block' AND b.policy_type = 'whitelist') OR \ + (a.policy_type = 'whitelist' AND b.policy_type = 'all_block'))" + ) + .fetch_all(&state.db) + .await + { + for row in &rows { + let group: Option = row.get("agroup"); + conflicts.push(PolicyConflict { + conflict_type: "usb_block_vs_whitelist".to_string(), + severity: "critical".to_string(), + description: format!("分组 '{}' 同时存在全封堵和白名单USB策略,互斥", group.as_deref().unwrap_or("?")), + policies: vec![ + ConflictPolicyRef { + table_name: "usb_policies".to_string(), + row_id: row.get("aid"), + name: row.get("aname"), + target_type: "group".to_string(), + target_id: group.clone(), + }, + ConflictPolicyRef { + table_name: "usb_policies".to_string(), + row_id: row.get("bid"), + name: row.get("bname"), + target_type: "group".to_string(), + target_id: group, + }, + ], + }); + } + } + + // 3. 
Web filter: same target, same pattern, different rule_type (allow vs block) + if let Ok(rows) = sqlx::query( + "SELECT a.id as aid, a.pattern as apattern, a.rule_type as artype, \ + b.id as bid, b.rule_type as brtype, a.target_type as ttype, a.target_id as tid \ + FROM web_filter_rules a JOIN web_filter_rules b ON a.pattern = b.pattern AND a.id < b.id \ + WHERE a.enabled = 1 AND b.enabled = 1 \ + AND a.target_type = b.target_type AND COALESCE(a.target_id,'') = COALESCE(b.target_id,'') \ + AND a.rule_type != b.rule_type" + ) + .fetch_all(&state.db) + .await + { + for row in &rows { + let pattern: String = row.get("apattern"); + let artype: String = row.get("artype"); + let brtype: String = row.get("brtype"); + let ttype: String = row.get("ttype"); + let tid: Option = row.get("tid"); + conflicts.push(PolicyConflict { + conflict_type: "web_filter_allow_vs_block".to_string(), + severity: "high".to_string(), + description: format!("URL '{}' 同时被 {} 和 {},互斥", pattern, artype, brtype), + policies: vec![ + ConflictPolicyRef { + table_name: "web_filter_rules".to_string(), + row_id: row.get("aid"), + name: format!("{}: {}", artype, pattern), + target_type: ttype.clone(), + target_id: tid.clone(), + }, + ConflictPolicyRef { + table_name: "web_filter_rules".to_string(), + row_id: row.get("bid"), + name: format!("{}: {}", brtype, pattern), + target_type: ttype, + target_id: tid, + }, + ], + }); + } + } + + // 4. 
Clipboard: same source/target process, allow vs block + if let Ok(rows) = sqlx::query( + "SELECT a.id as aid, a.rule_type as artype, a.source_process as asrc, a.target_process as adst, \ + b.id as bid, b.rule_type as brtype, a.target_type as ttype, a.target_id as tid \ + FROM clipboard_rules a JOIN clipboard_rules b ON a.id < b.id \ + WHERE a.enabled = 1 AND b.enabled = 1 \ + AND a.target_type = b.target_type AND COALESCE(a.target_id,'') = COALESCE(b.target_id,'') \ + AND a.direction = b.direction \ + AND COALESCE(a.source_process,'') = COALESCE(b.source_process,'') \ + AND COALESCE(a.target_process,'') = COALESCE(b.target_process,'') \ + AND a.rule_type != b.rule_type" + ) + .fetch_all(&state.db) + .await + { + for row in &rows { + let asrc: Option = row.get("asrc"); + let adst: Option = row.get("adst"); + let artype: String = row.get("artype"); + let brtype: String = row.get("brtype"); + let desc = format!( + "剪贴板规则冲突: {} → {} 同时存在 {} 和 {}", + asrc.as_deref().unwrap_or("*"), + adst.as_deref().unwrap_or("*"), + artype, brtype, + ); + let ttype: String = row.get("ttype"); + let tid: Option = row.get("tid"); + conflicts.push(PolicyConflict { + conflict_type: "clipboard_allow_vs_block".to_string(), + severity: "medium".to_string(), + description: desc, + policies: vec![ + ConflictPolicyRef { + table_name: "clipboard_rules".to_string(), + row_id: row.get("aid"), + name: format!("{}: {}", artype, asrc.as_deref().unwrap_or("*")), + target_type: ttype.clone(), + target_id: tid.clone(), + }, + ConflictPolicyRef { + table_name: "clipboard_rules".to_string(), + row_id: row.get("bid"), + name: format!("{}: {}", brtype, asrc.as_deref().unwrap_or("*")), + target_type: ttype, + target_id: tid, + }, + ], + }); + } + } + + // 5. 
Plugin disabled but has active rules + let plugin_tables: [(&str, &str, &str, &str); 4] = [ + ("web_filter_rules", "上网行为过滤", "web_filter", "SELECT COUNT(*) FROM web_filter_rules WHERE enabled = 1"), + ("software_blacklist", "软件黑名单", "software_blocker", "SELECT COUNT(*) FROM software_blacklist WHERE enabled = 1"), + ("popup_filter_rules", "弹窗拦截", "popup_blocker", "SELECT COUNT(*) FROM popup_filter_rules WHERE enabled = 1"), + ("clipboard_rules", "剪贴板管控", "clipboard_control", "SELECT COUNT(*) FROM clipboard_rules WHERE enabled = 1"), + ]; + for (_table, label, plugin, query) in &plugin_tables { + let active_count: i64 = sqlx::query_scalar(query) + .fetch_one(&state.db) + .await + .unwrap_or(0); + + if active_count > 0 { + let disabled: bool = sqlx::query_scalar::<_, i32>( + "SELECT COUNT(*) FROM plugin_state WHERE plugin_name = ? AND enabled = 0" + ) + .bind(plugin) + .fetch_one(&state.db) + .await + .unwrap_or(0) > 0; + + if disabled { + conflicts.push(PolicyConflict { + conflict_type: "plugin_disabled_with_rules".to_string(), + severity: "low".to_string(), + description: format!("插件 '{}' 已禁用,但仍有 {} 条启用规则,规则不会生效", label, active_count), + policies: vec![ConflictPolicyRef { + table_name: "plugin_state".to_string(), + row_id: 0, + name: plugin.to_string(), + target_type: "global".to_string(), + target_id: None, + }], + }); + } + } + } + + Json(ApiResponse::ok(serde_json::json!({ + "conflicts": conflicts, + "total": conflicts.len(), + "critical_count": conflicts.iter().filter(|c| c.severity == "critical").count(), + "high_count": conflicts.iter().filter(|c| c.severity == "high").count(), + "medium_count": conflicts.iter().filter(|c| c.severity == "medium").count(), + "low_count": conflicts.iter().filter(|c| c.severity == "low").count(), + }))) +} diff --git a/crates/server/src/api/devices.rs b/crates/server/src/api/devices.rs index c8bdcf4..99796fb 100644 --- a/crates/server/src/api/devices.rs +++ b/crates/server/src/api/devices.rs @@ -4,6 +4,28 @@ use sqlx::Row; use 
crate::AppState; use super::{ApiResponse, Pagination}; +/// GET /api/devices/:uid/health-score +pub async fn get_health_score( + State(state): State, + Path(uid): Path, +) -> Json> { + match crate::health::get_device_score(&state.db, &uid).await { + Ok(Some(score)) => Json(ApiResponse::ok(score)), + Ok(None) => Json(ApiResponse::error("No health score available")), + Err(e) => Json(ApiResponse::internal_error("health score", e)), + } +} + +/// GET /api/dashboard/health-overview +pub async fn health_overview( + State(state): State, +) -> Json> { + match crate::health::get_health_overview(&state.db).await { + Ok(overview) => Json(ApiResponse::ok(overview)), + Err(e) => Json(ApiResponse::internal_error("health overview", e)), + } +} + #[derive(Debug, Deserialize)] pub struct DeviceListParams { pub status: Option, @@ -26,6 +48,10 @@ pub struct DeviceRow { pub last_heartbeat: Option, pub registered_at: Option, pub group_name: Option, + #[sqlx(default)] + pub health_score: Option, + #[sqlx(default)] + pub health_level: Option, } pub async fn list( @@ -41,13 +67,16 @@ pub async fn list( let search = params.search.as_deref().filter(|s| !s.is_empty()).map(String::from); let devices = sqlx::query_as::<_, DeviceRow>( - "SELECT id, device_uid, hostname, ip_address, mac_address, os_version, client_version, - status, last_heartbeat, registered_at, group_name - FROM devices WHERE 1=1 - AND (? IS NULL OR status = ?) - AND (? IS NULL OR group_name = ?) - AND (? IS NULL OR hostname LIKE '%' || ? || '%' OR ip_address LIKE '%' || ? || '%') - ORDER BY registered_at DESC LIMIT ? OFFSET ?" + "SELECT d.id, d.device_uid, d.hostname, d.ip_address, d.mac_address, d.os_version, d.client_version, + d.status, d.last_heartbeat, d.registered_at, d.group_name, + h.score as health_score, h.level as health_level + FROM devices d + LEFT JOIN device_health_scores h ON h.device_uid = d.device_uid + WHERE 1=1 + AND (? IS NULL OR d.status = ?) + AND (? IS NULL OR d.group_name = ?) + AND (? 
IS NULL OR d.hostname LIKE '%' || ? || '%' OR d.ip_address LIKE '%' || ? || '%') + ORDER BY d.registered_at DESC LIMIT ? OFFSET ?" ) .bind(&status).bind(&status) .bind(&group).bind(&group) @@ -187,16 +216,6 @@ pub async fn remove( State(state): State, Path(uid): Path, ) -> Json> { - // If client is connected, send self-destruct command - let frame = csm_protocol::Frame::new_json( - csm_protocol::MessageType::ConfigUpdate, - &serde_json::json!({"type": "SelfDestruct"}), - ).ok(); - - if let Some(frame) = frame { - state.clients.send_to(&uid, frame.encode()).await; - } - // Delete device and all associated data in a transaction let mut tx = match state.db.begin().await { Ok(tx) => tx, @@ -224,6 +243,8 @@ pub async fn remove( // Delete plugin-related data let cleanup_tables = [ "hardware_assets", + "software_assets", + "asset_changes", "usb_events", "usb_file_operations", "usage_daily", @@ -231,8 +252,20 @@ pub async fn remove( "software_violations", "web_access_log", "popup_block_stats", + "disk_encryption_status", + "disk_encryption_alerts", + "print_events", + "clipboard_violations", + "behavior_metrics", + "anomaly_alerts", + "device_health_scores", + "patch_status", ]; for table in &cleanup_tables { + // Safety: table names are hardcoded constants above, not user input. + // Parameterized ? is used for device_uid. 
+ debug_assert!(table.chars().all(|c| c.is_ascii_alphanumeric() || c == '_'), + "BUG: table name contains unexpected characters: {}", table); if let Err(e) = sqlx::query(&format!("DELETE FROM {} WHERE device_uid = ?", table)) .bind(&uid) .execute(&mut *tx) @@ -253,6 +286,17 @@ pub async fn remove( if let Err(e) = tx.commit().await { return Json(ApiResponse::internal_error("commit device deletion", e)); } + + // Send self-destruct command AFTER successful commit + let frame = csm_protocol::Frame::new_json( + csm_protocol::MessageType::ConfigUpdate, + &serde_json::json!({"type": "SelfDestruct"}), + ).ok(); + + if let Some(frame) = frame { + state.clients.send_to(&uid, frame.encode()).await; + } + state.clients.unregister(&uid).await; tracing::info!(device_uid = %uid, "Device and all associated data deleted"); Json(ApiResponse::ok(())) diff --git a/crates/server/src/api/groups.rs b/crates/server/src/api/groups.rs index 23d4d21..33640d9 100644 --- a/crates/server/src/api/groups.rs +++ b/crates/server/src/api/groups.rs @@ -50,6 +50,9 @@ pub async fn create_group( if name.is_empty() || name.len() > 50 { return Ok((StatusCode::BAD_REQUEST, Json(ApiResponse::error("分组名称无效")))); } + if name.contains('<') || name.contains('>') || name.contains('"') || name.contains('\'') { + return Ok((StatusCode::BAD_REQUEST, Json(ApiResponse::error("分组名称包含非法字符")))); + } // Check if group already exists let exists: bool = sqlx::query_scalar( @@ -78,6 +81,9 @@ pub async fn rename_group( if new_name.is_empty() || new_name.len() > 50 { return Ok((StatusCode::BAD_REQUEST, Json(ApiResponse::error("分组名称无效")))); } + if new_name.contains('<') || new_name.contains('>') || new_name.contains('"') || new_name.contains('\'') { + return Ok((StatusCode::BAD_REQUEST, Json(ApiResponse::error("分组名称包含非法字符")))); + } let result = sqlx::query( "UPDATE devices SET group_name = ? WHERE group_name = ?" 
diff --git a/crates/server/src/api/mod.rs b/crates/server/src/api/mod.rs index 0280b85..93ec840 100644 --- a/crates/server/src/api/mod.rs +++ b/crates/server/src/api/mod.rs @@ -1,4 +1,4 @@ -use axum::{routing::{get, post, put, delete}, Router, Json, middleware}; +use axum::{routing::{get, post, put, delete}, Router, Json, middleware, http::StatusCode, response::IntoResponse}; use serde::{Deserialize, Serialize}; use crate::AppState; @@ -9,23 +9,31 @@ pub mod usb; pub mod alerts; pub mod plugins; pub mod groups; +pub mod conflict; pub fn routes(state: AppState) -> Router { let public = Router::new() .route("/api/auth/login", post(auth::login)) .route("/api/auth/refresh", post(auth::refresh)) + .route("/api/auth/logout", post(auth::logout)) .route("/health", get(health_check)) .with_state(state.clone()); // Read-only routes (any authenticated user) let read_routes = Router::new() // Auth + .route("/api/auth/me", get(auth::me)) .route("/api/auth/change-password", put(auth::change_password)) + // WebSocket ticket (requires auth cookie) + .route("/api/ws/ticket", post(auth::create_ws_ticket)) // Devices .route("/api/devices", get(devices::list)) .route("/api/devices/:uid", get(devices::get_detail)) .route("/api/devices/:uid/status", get(devices::get_status)) .route("/api/devices/:uid/history", get(devices::get_history)) + .route("/api/devices/:uid/health-score", get(devices::get_health_score)) + // Dashboard + .route("/api/dashboard/health-overview", get(devices::health_overview)) // Assets .route("/api/assets/hardware", get(assets::list_hardware)) .route("/api/assets/software", get(assets::list_software)) @@ -40,6 +48,8 @@ pub fn routes(state: AppState) -> Router { .route("/api/alerts/records", get(alerts::list_records)) // Plugin read routes .merge(plugins::read_routes()) + // Policy conflict scan + .route("/api/policies/conflicts", get(conflict::scan_conflicts)) .layer(middleware::from_fn_with_state(state.clone(), auth::require_auth)); // Write routes (admin only) @@ 
-50,6 +60,8 @@ pub fn routes(state: AppState) -> Router { .route("/api/groups", post(groups::create_group)) .route("/api/groups/:name", put(groups::rename_group).delete(groups::delete_group)) .route("/api/devices/:uid/group", put(groups::move_device)) + // TLS cert rotation + .route("/api/system/tls-rotate", post(system_tls_rotate)) // USB (write) .route("/api/usb/policies", post(usb::create_policy)) .route("/api/usb/policies/:id", put(usb::update_policy).delete(usb::delete_policy)) @@ -76,6 +88,45 @@ pub fn routes(state: AppState) -> Router { .merge(ws_router) } +/// Trigger TLS certificate rotation for all online devices. +/// Admin sends the new certificate PEM and a transition deadline. +/// The server pushes a ConfigUpdate(TlsCertRotate) to all connected clients. +#[derive(Deserialize)] +struct TlsRotateRequest { + /// Path to the new certificate PEM file + cert_path: String, + /// ISO 8601 timestamp when the old cert stops being valid (transition deadline) + valid_until: String, +} + +#[derive(Serialize)] +struct TlsRotateResponse { + devices_notified: usize, +} + +async fn system_tls_rotate( + axum::extract::State(state): axum::extract::State, + Json(req): Json, +) -> impl IntoResponse { + let cert_pem = match tokio::fs::read(&req.cert_path).await { + Ok(pem) => pem, + Err(e) => { + return ( + StatusCode::BAD_REQUEST, + Json(ApiResponse::::error( + format!("Cannot read cert file {}: {}", req.cert_path, e), + )), + ).into_response(); + } + }; + + let count = crate::tcp::push_tls_cert_rotation(&state.clients, &cert_pem, &req.valid_until).await; + + (StatusCode::OK, Json(ApiResponse::ok(TlsRotateResponse { + devices_notified: count, + }))).into_response() +} + #[derive(Serialize)] struct HealthResponse { status: &'static str, diff --git a/crates/server/src/api/plugins/anomaly.rs b/crates/server/src/api/plugins/anomaly.rs new file mode 100644 index 0000000..7a1fb33 --- /dev/null +++ b/crates/server/src/api/plugins/anomaly.rs @@ -0,0 +1,48 @@ +use 
axum::{extract::{State, Path, Query}, Json}; +use serde::Deserialize; +use crate::AppState; +use crate::api::ApiResponse; + +#[derive(Debug, Deserialize)] +pub struct AnomalyListParams { + pub device_uid: Option, + pub page: Option, + pub page_size: Option, +} + +/// GET /api/plugins/anomaly/alerts +pub async fn list_anomaly_alerts( + State(state): State, + Query(params): Query, +) -> Json> { + let page = params.page.unwrap_or(1); + let page_size = params.page_size.unwrap_or(20).min(100); + let device_uid = params.device_uid.as_deref().filter(|s| !s.is_empty()); + + match crate::anomaly::get_anomaly_summary(&state.db, device_uid, page, page_size).await { + Ok(result) => Json(ApiResponse::ok(result)), + Err(e) => Json(ApiResponse::internal_error("anomaly alerts", e)), + } +} + +/// PUT /api/plugins/anomaly/alerts/:id/handle +/// Mark an anomaly alert as handled. +pub async fn handle_anomaly_alert( + State(state): State, + Path(id): Path, + claims: axum::Extension, +) -> Json> { + let result = sqlx::query( + "UPDATE anomaly_alerts SET handled = 1, handled_by = ?, handled_at = datetime('now') WHERE id = ?" 
+ ) + .bind(&claims.username) + .bind(id) + .execute(&state.db) + .await; + + match result { + Ok(r) if r.rows_affected() > 0 => Json(ApiResponse::ok(())), + Ok(_) => Json(ApiResponse::error("Alert not found")), + Err(e) => Json(ApiResponse::internal_error("handle anomaly alert", e)), + } +} diff --git a/crates/server/src/api/plugins/clipboard_control.rs b/crates/server/src/api/plugins/clipboard_control.rs index adc88d2..62f2312 100644 --- a/crates/server/src/api/plugins/clipboard_control.rs +++ b/crates/server/src/api/plugins/clipboard_control.rs @@ -21,7 +21,7 @@ pub struct CreateRuleRequest { pub async fn list_rules(State(state): State) -> Json> { match sqlx::query( "SELECT id, target_type, target_id, rule_type, direction, source_process, target_process, content_pattern, enabled, updated_at \ - FROM clipboard_rules ORDER BY updated_at DESC" + FROM clipboard_rules ORDER BY updated_at DESC LIMIT 500" ) .fetch_all(&state.db) .await @@ -127,6 +127,18 @@ pub async fn update_rule( let content_pattern = body.content_pattern.or_else(|| existing.get::, _>("content_pattern")); let enabled = body.enabled.unwrap_or_else(|| existing.get::("enabled")); + // Validate merged values + if let Some(ref rt) = rule_type { + if !matches!(rt.as_str(), "allow" | "block") { + return Json(ApiResponse::error("rule_type must be 'allow' or 'block'")); + } + } + if let Some(ref d) = direction { + if !matches!(d.as_str(), "in" | "out" | "both") { + return Json(ApiResponse::error("direction must be 'in', 'out', or 'both'")); + } + } + let result = sqlx::query( "UPDATE clipboard_rules SET rule_type = ?, direction = ?, source_process = ?, target_process = ?, content_pattern = ?, enabled = ?, updated_at = datetime('now') WHERE id = ?" 
) diff --git a/crates/server/src/api/plugins/disk_encryption.rs b/crates/server/src/api/plugins/disk_encryption.rs index bc9678d..010967a 100644 --- a/crates/server/src/api/plugins/disk_encryption.rs +++ b/crates/server/src/api/plugins/disk_encryption.rs @@ -28,7 +28,7 @@ pub async fn list_status( "SELECT s.id, s.device_uid, s.drive_letter, s.volume_name, s.encryption_method, \ s.protection_status, s.encryption_percentage, s.lock_status, s.reported_at, s.updated_at, \ d.hostname FROM disk_encryption_status s LEFT JOIN devices d ON s.device_uid = d.device_uid \ - ORDER BY s.device_uid, s.drive_letter" + ORDER BY s.device_uid, s.drive_letter LIMIT 500" ) .fetch_all(&state.db) .await @@ -58,7 +58,7 @@ pub async fn list_alerts(State(state): State) -> Json Router { // Software Blocker .route("/api/plugins/software-blocker/blacklist", get(software_blocker::list_blacklist)) .route("/api/plugins/software-blocker/violations", get(software_blocker::list_violations)) + .route("/api/plugins/software-blocker/whitelist", get(software_blocker::list_whitelist)) // Popup Blocker .route("/api/plugins/popup-blocker/rules", get(popup_blocker::list_rules)) .route("/api/plugins/popup-blocker/stats", get(popup_blocker::list_stats)) @@ -36,7 +39,6 @@ pub fn read_routes() -> Router { // Disk Encryption .route("/api/plugins/disk-encryption/status", get(disk_encryption::list_status)) .route("/api/plugins/disk-encryption/alerts", get(disk_encryption::list_alerts)) - .route("/api/plugins/disk-encryption/alerts/:id/acknowledge", put(disk_encryption::acknowledge_alert)) // Print Audit .route("/api/plugins/print-audit/events", get(print_audit::list_events)) .route("/api/plugins/print-audit/events/:id", get(print_audit::get_event)) @@ -45,6 +47,12 @@ pub fn read_routes() -> Router { .route("/api/plugins/clipboard-control/violations", get(clipboard_control::list_violations)) // Plugin Control .route("/api/plugins/control", get(plugin_control::list_plugins)) + // Patch Management + 
.route("/api/plugins/patch/status", get(patch::list_patch_status)) + .route("/api/plugins/patch/summary", get(patch::patch_summary)) + .route("/api/plugins/patch/device/:uid", get(patch::device_patches)) + // Anomaly Detection + .route("/api/plugins/anomaly/alerts", get(anomaly::list_anomaly_alerts)) } /// Write plugin routes (admin only — require_admin middleware applied by caller) @@ -56,6 +64,8 @@ pub fn write_routes() -> Router { // Software Blocker .route("/api/plugins/software-blocker/blacklist", post(software_blocker::add_to_blacklist)) .route("/api/plugins/software-blocker/blacklist/:id", put(software_blocker::update_blacklist).delete(software_blocker::remove_from_blacklist)) + .route("/api/plugins/software-blocker/whitelist", post(software_blocker::add_to_whitelist)) + .route("/api/plugins/software-blocker/whitelist/:id", put(software_blocker::update_whitelist).delete(software_blocker::remove_from_whitelist)) // Popup Blocker .route("/api/plugins/popup-blocker/rules", post(popup_blocker::create_rule)) .route("/api/plugins/popup-blocker/rules/:id", put(popup_blocker::update_rule).delete(popup_blocker::delete_rule)) @@ -65,6 +75,10 @@ pub fn write_routes() -> Router { // Clipboard Control .route("/api/plugins/clipboard-control/rules", post(clipboard_control::create_rule)) .route("/api/plugins/clipboard-control/rules/:id", put(clipboard_control::update_rule).delete(clipboard_control::delete_rule)) + // Disk Encryption + .route("/api/plugins/disk-encryption/alerts/:id/acknowledge", put(disk_encryption::acknowledge_alert)) // Plugin Control (enable/disable) .route("/api/plugins/control/:plugin_name", put(plugin_control::set_plugin_state)) + // Anomaly Detection — handle alert + .route("/api/plugins/anomaly/alerts/:id/handle", put(anomaly::handle_anomaly_alert)) } diff --git a/crates/server/src/api/plugins/patch.rs b/crates/server/src/api/plugins/patch.rs new file mode 100644 index 0000000..b0a72e0 --- /dev/null +++ b/crates/server/src/api/plugins/patch.rs @@ 
-0,0 +1,146 @@ +use axum::{extract::{State, Path, Query}, Json}; +use serde::Deserialize; +use sqlx::Row; +use crate::AppState; +use crate::api::ApiResponse; + +#[derive(Debug, Deserialize)] +pub struct PatchListParams { + pub device_uid: Option, + pub severity: Option, + #[allow(dead_code)] + pub installed: Option, + pub page: Option, + pub page_size: Option, +} + +/// GET /api/plugins/patch/status +pub async fn list_patch_status( + State(state): State, + Query(params): Query, +) -> Json> { + let limit = params.page_size.unwrap_or(20).min(100); + let offset = params.page.unwrap_or(1).saturating_sub(1) * limit; + + let device_uid = params.device_uid.as_deref().filter(|s| !s.is_empty()); + let severity = params.severity.as_deref().filter(|s| !s.is_empty()); + + let rows = sqlx::query( + "SELECT p.*, d.hostname FROM patch_status p JOIN devices d ON d.device_uid = p.device_uid \ + WHERE 1=1 \ + AND (? IS NULL OR p.device_uid = ?) \ + AND (? IS NULL OR p.severity = ?) \ + ORDER BY p.updated_at DESC LIMIT ? OFFSET ?" + ) + .bind(device_uid).bind(device_uid) + .bind(severity).bind(severity) + .bind(limit).bind(offset) + .fetch_all(&state.db) + .await; + + match rows { + Ok(records) => { + let items: Vec = records.iter().map(|r| serde_json::json!({ + "id": r.get::("id"), + "device_uid": r.get::("device_uid"), + "hostname": r.get::("hostname"), + "kb_id": r.get::("kb_id"), + "title": r.get::("title"), + "severity": r.get::, _>("severity"), + "is_installed": r.get::("is_installed"), + "installed_at": r.get::, _>("installed_at"), + "updated_at": r.get::("updated_at"), + })).collect(); + + // Summary stats (scoped to same filters as main query) + let total_installed: i64 = sqlx::query_scalar( + "SELECT COUNT(*) FROM patch_status WHERE is_installed = 1 \ + AND (? IS NULL OR device_uid = ?) \ + AND (? 
IS NULL OR severity = ?)" + ) + .bind(device_uid).bind(device_uid) + .bind(severity).bind(severity) + .fetch_one(&state.db).await.unwrap_or(0); + let total_missing: i64 = sqlx::query_scalar( + "SELECT COUNT(*) FROM patch_status WHERE is_installed = 0 \ + AND (? IS NULL OR device_uid = ?) \ + AND (? IS NULL OR severity = ?)" + ) + .bind(device_uid).bind(device_uid) + .bind(severity).bind(severity) + .fetch_one(&state.db).await.unwrap_or(0); + + Json(ApiResponse::ok(serde_json::json!({ + "patches": items, + "summary": { + "total_installed": total_installed, + "total_missing": total_missing, + }, + "page": params.page.unwrap_or(1), + "page_size": limit, + }))) + } + Err(e) => Json(ApiResponse::internal_error("query patch status", e)), + } +} + +/// GET /api/plugins/patch/summary — per-device patch summary +pub async fn patch_summary( + State(state): State, +) -> Json> { + let rows = sqlx::query( + "SELECT p.device_uid, d.hostname, \ + COUNT(*) as total_patches, \ + SUM(CASE WHEN p.is_installed = 1 THEN 1 ELSE 0 END) as installed, \ + SUM(CASE WHEN p.is_installed = 0 THEN 1 ELSE 0 END) as missing, \ + MAX(p.updated_at) as last_scan \ + FROM patch_status p JOIN devices d ON d.device_uid = p.device_uid \ + GROUP BY p.device_uid ORDER BY missing DESC" + ) + .fetch_all(&state.db) + .await; + + match rows { + Ok(records) => { + let devices: Vec = records.iter().map(|r| serde_json::json!({ + "device_uid": r.get::("device_uid"), + "hostname": r.get::("hostname"), + "total_patches": r.get::("total_patches"), + "installed": r.get::("installed"), + "missing": r.get::("missing"), + "last_scan": r.get::("last_scan"), + })).collect(); + Json(ApiResponse::ok(serde_json::json!({ "devices": devices }))) + } + Err(e) => Json(ApiResponse::internal_error("patch summary", e)), + } +} + +/// GET /api/plugins/patch/device/:uid — patches for a single device +pub async fn device_patches( + State(state): State, + Path(uid): Path, +) -> Json> { + let rows = sqlx::query( + "SELECT kb_id, title, 
severity, is_installed, installed_at, updated_at \ + FROM patch_status WHERE device_uid = ? ORDER BY updated_at DESC" + ) + .bind(&uid) + .fetch_all(&state.db) + .await; + + match rows { + Ok(records) => { + let patches: Vec = records.iter().map(|r| serde_json::json!({ + "kb_id": r.get::("kb_id"), + "title": r.get::("title"), + "severity": r.get::, _>("severity"), + "is_installed": r.get::("is_installed"), + "installed_at": r.get::, _>("installed_at"), + "updated_at": r.get::("updated_at"), + })).collect(); + Json(ApiResponse::ok(serde_json::json!({ "patches": patches }))) + } + Err(e) => Json(ApiResponse::internal_error("device patches", e)), + } +} diff --git a/crates/server/src/api/plugins/popup_blocker.rs b/crates/server/src/api/plugins/popup_blocker.rs index d62bf9c..3523f00 100644 --- a/crates/server/src/api/plugins/popup_blocker.rs +++ b/crates/server/src/api/plugins/popup_blocker.rs @@ -17,7 +17,7 @@ pub struct CreateRuleRequest { } pub async fn list_rules(State(state): State) -> Json> { - match sqlx::query("SELECT id, rule_type, window_title, window_class, process_name, target_type, target_id, enabled, created_at FROM popup_filter_rules ORDER BY created_at DESC") + match sqlx::query("SELECT id, rule_type, window_title, window_class, process_name, target_type, target_id, enabled, created_at FROM popup_filter_rules ORDER BY created_at DESC LIMIT 500") .fetch_all(&state.db).await { Ok(rows) => Json(ApiResponse::ok(serde_json::json!({"rules": rows.iter().map(|r| serde_json::json!({ "id": r.get::("id"), "rule_type": r.get::("rule_type"), @@ -47,6 +47,16 @@ pub async fn create_rule(State(state): State, Json(req): Json 255 { return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("window_title too long (max 255)"))); } + } + if let Some(ref c) = req.window_class { + if c.len() > 255 { return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("window_class too long (max 255)"))); } + } + if let Some(ref p) = req.process_name { + if p.len() > 255 { return 
(StatusCode::BAD_REQUEST, Json(ApiResponse::error("process_name too long (max 255)"))); } + } match sqlx::query("INSERT INTO popup_filter_rules (rule_type, window_title, window_class, process_name, target_type, target_id) VALUES (?,?,?,?,?,?)") .bind(&req.rule_type).bind(&req.window_title).bind(&req.window_class).bind(&req.process_name).bind(&target_type).bind(&req.target_id) @@ -81,6 +91,14 @@ pub async fn update_rule(State(state): State, Path(id): Path, Jso let process_name = body.process_name.or_else(|| existing.get::, _>("process_name")); let enabled = body.enabled.unwrap_or_else(|| existing.get::("enabled")); + // Ensure at least one filter is non-empty after update + let has_filter = window_title.as_ref().map_or(false, |s| !s.is_empty()) + || window_class.as_ref().map_or(false, |s| !s.is_empty()) + || process_name.as_ref().map_or(false, |s| !s.is_empty()); + if !has_filter { + return Json(ApiResponse::error("at least one filter (window_title/window_class/process_name) required")); + } + let result = sqlx::query("UPDATE popup_filter_rules SET window_title = ?, window_class = ?, process_name = ?, enabled = ? 
WHERE id = ?") .bind(&window_title) .bind(&window_class) diff --git a/crates/server/src/api/plugins/software_blocker.rs b/crates/server/src/api/plugins/software_blocker.rs index 23a36b6..d16385b 100644 --- a/crates/server/src/api/plugins/software_blocker.rs +++ b/crates/server/src/api/plugins/software_blocker.rs @@ -1,7 +1,7 @@ use axum::{extract::{State, Path, Query, Json}, http::StatusCode}; use serde::Deserialize; use sqlx::Row; -use csm_protocol::MessageType; +use csm_protocol::{Frame, MessageType}; use crate::AppState; use crate::api::ApiResponse; use crate::tcp::push_to_targets; @@ -16,7 +16,7 @@ pub struct CreateBlacklistRequest { } pub async fn list_blacklist(State(state): State) -> Json> { - match sqlx::query("SELECT id, name_pattern, category, action, target_type, target_id, enabled, created_at FROM software_blacklist ORDER BY created_at DESC") + match sqlx::query("SELECT id, name_pattern, category, action, target_type, target_id, enabled, created_at FROM software_blacklist ORDER BY created_at DESC LIMIT 500") .fetch_all(&state.db).await { Ok(rows) => Json(ApiResponse::ok(serde_json::json!({"blacklist": rows.iter().map(|r| serde_json::json!({ "id": r.get::("id"), "name_pattern": r.get::("name_pattern"), @@ -53,8 +53,8 @@ pub async fn add_to_blacklist(State(state): State, Json(req): Json { let new_id = r.last_insert_rowid(); - let blacklist = fetch_blacklist_for_push(&state.db, &target_type, req.target_id.as_deref()).await; - push_to_targets(&state.db, &state.clients, MessageType::SoftwareBlacklist, &serde_json::json!({"blacklist": blacklist}), &target_type, req.target_id.as_deref()).await; + let payload = fetch_software_payload_for_push(&state.db, &target_type, req.target_id.as_deref()).await; + push_to_targets(&state.db, &state.clients, MessageType::SoftwareBlacklist, &payload, &target_type, req.target_id.as_deref()).await; (StatusCode::CREATED, Json(ApiResponse::ok(serde_json::json!({"id": new_id})))) } Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, 
Json(ApiResponse::internal_error("add software blacklist entry", e))), @@ -80,6 +80,14 @@ pub async fn update_blacklist(State(state): State, Path(id): Path let action = body.action.unwrap_or_else(|| existing.get::("action")); let enabled = body.enabled.unwrap_or_else(|| existing.get::("enabled")); + // Input validation (same as create) + if name_pattern.trim().is_empty() || name_pattern.len() > 255 { + return Json(ApiResponse::error("name_pattern must be 1-255 chars")); + } + if !matches!(action.as_str(), "block" | "alert") { + return Json(ApiResponse::error("action must be 'block' or 'alert'")); + } + let result = sqlx::query("UPDATE software_blacklist SET name_pattern = ?, action = ?, enabled = ? WHERE id = ?") .bind(&name_pattern) .bind(&action) @@ -92,8 +100,8 @@ pub async fn update_blacklist(State(state): State, Path(id): Path Ok(r) if r.rows_affected() > 0 => { let target_type_val: String = existing.get("target_type"); let target_id_val: Option = existing.get("target_id"); - let blacklist = fetch_blacklist_for_push(&state.db, &target_type_val, target_id_val.as_deref()).await; - push_to_targets(&state.db, &state.clients, MessageType::SoftwareBlacklist, &serde_json::json!({"blacklist": blacklist}), &target_type_val, target_id_val.as_deref()).await; + let payload = fetch_software_payload_for_push(&state.db, &target_type_val, target_id_val.as_deref()).await; + push_to_targets(&state.db, &state.clients, MessageType::SoftwareBlacklist, &payload, &target_type_val, target_id_val.as_deref()).await; Json(ApiResponse::ok(())) } Ok(_) => Json(ApiResponse::error("Not found")), @@ -110,8 +118,8 @@ pub async fn remove_from_blacklist(State(state): State, Path(id): Path }; match sqlx::query("DELETE FROM software_blacklist WHERE id=?").bind(id).execute(&state.db).await { Ok(r) if r.rows_affected() > 0 => { - let blacklist = fetch_blacklist_for_push(&state.db, &target_type, target_id.as_deref()).await; - push_to_targets(&state.db, &state.clients, MessageType::SoftwareBlacklist, 
&serde_json::json!({"blacklist": blacklist}), &target_type, target_id.as_deref()).await; + let payload = fetch_software_payload_for_push(&state.db, &target_type, target_id.as_deref()).await; + push_to_targets(&state.db, &state.clients, MessageType::SoftwareBlacklist, &payload, &target_type, target_id.as_deref()).await; Json(ApiResponse::ok(())) } _ => Json(ApiResponse::error("Not found")), @@ -134,6 +142,29 @@ pub async fn list_violations(State(state): State, Query(f): Query, +) -> serde_json::Value { + let blacklist = fetch_blacklist_for_push(db, target_type, target_id).await; + + // Whitelist is always global — fetch all enabled entries + let whitelist: Vec = sqlx::query_scalar( + "SELECT name_pattern FROM software_whitelist WHERE enabled = 1" + ) + .fetch_all(db) + .await + .unwrap_or_default(); + + serde_json::json!({ + "blacklist": blacklist, + "whitelist": whitelist, + }) +} + async fn fetch_blacklist_for_push( db: &sqlx::SqlitePool, target_type: &str, @@ -156,3 +187,112 @@ async fn fetch_blacklist_for_push( })).collect()) .unwrap_or_default() } + +// ─── Whitelist management ─── + +/// GET /api/plugins/software-blocker/whitelist +pub async fn list_whitelist(State(state): State) -> Json> { + match sqlx::query("SELECT id, name_pattern, reason, is_builtin, enabled, created_at FROM software_whitelist ORDER BY is_builtin DESC, created_at ASC LIMIT 500") + .fetch_all(&state.db).await { + Ok(rows) => Json(ApiResponse::ok(serde_json::json!({"whitelist": rows.iter().map(|r| serde_json::json!({ + "id": r.get::("id"), + "name_pattern": r.get::("name_pattern"), + "reason": r.get::,_>("reason"), + "is_builtin": r.get::("is_builtin"), + "enabled": r.get::("enabled"), + "created_at": r.get::("created_at") + })).collect::>()}))), + Err(e) => Json(ApiResponse::internal_error("query software whitelist", e)), + } +} + +#[derive(Debug, Deserialize)] +pub struct CreateWhitelistRequest { + pub name_pattern: String, + pub reason: Option, +} + +/// POST 
/api/plugins/software-blocker/whitelist +pub async fn add_to_whitelist(State(state): State, Json(req): Json) -> (StatusCode, Json>) { + if req.name_pattern.trim().is_empty() || req.name_pattern.len() > 255 { + return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("name_pattern must be 1-255 chars"))); + } + + match sqlx::query("INSERT INTO software_whitelist (name_pattern, reason) VALUES (?, ?)") + .bind(&req.name_pattern).bind(&req.reason) + .execute(&state.db).await { + Ok(r) => { + let new_id = r.last_insert_rowid(); + // Push updated whitelist to all online clients + push_whitelist_to_all(&state).await; + (StatusCode::CREATED, Json(ApiResponse::ok(serde_json::json!({"id": new_id})))) + } + Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::internal_error("add whitelist entry", e))), + } +} + +/// PUT /api/plugins/software-blocker/whitelist/:id +#[derive(Debug, Deserialize)] +pub struct UpdateWhitelistRequest { + pub name_pattern: Option, + pub enabled: Option, +} + +pub async fn update_whitelist(State(state): State, Path(id): Path, Json(body): Json) -> Json> { + let existing = sqlx::query("SELECT name_pattern, enabled FROM software_whitelist WHERE id = ?") + .bind(id).fetch_optional(&state.db).await; + + let existing = match existing { + Ok(Some(row)) => row, + Ok(None) => return Json(ApiResponse::error("Not found")), + Err(e) => return Json(ApiResponse::internal_error("query whitelist", e)), + }; + + let name_pattern = body.name_pattern.unwrap_or_else(|| existing.get::("name_pattern")); + let enabled = body.enabled.unwrap_or_else(|| existing.get::("enabled")); + + // Input validation — validate merged value + if name_pattern.trim().is_empty() || name_pattern.len() > 255 { + return Json(ApiResponse::error("name_pattern must be 1-255 chars")); + } + + match sqlx::query("UPDATE software_whitelist SET name_pattern = ?, enabled = ? 
WHERE id = ?") + .bind(&name_pattern).bind(enabled).bind(id) + .execute(&state.db).await { + Ok(r) if r.rows_affected() > 0 => { + push_whitelist_to_all(&state).await; + Json(ApiResponse::ok(())) + } + Ok(_) => Json(ApiResponse::error("Not found")), + Err(e) => Json(ApiResponse::internal_error("update whitelist", e)), + } +} + +/// DELETE /api/plugins/software-blocker/whitelist/:id +pub async fn remove_from_whitelist(State(state): State, Path(id): Path) -> Json> { + match sqlx::query("DELETE FROM software_whitelist WHERE id = ? AND is_builtin = 0") + .bind(id).execute(&state.db).await { + Ok(r) if r.rows_affected() > 0 => { + push_whitelist_to_all(&state).await; + Json(ApiResponse::ok(())) + } + Ok(_) => Json(ApiResponse::error("Not found or is built-in entry")), + Err(e) => Json(ApiResponse::internal_error("remove whitelist entry", e)), + } +} + +/// Push updated whitelist to all online clients by resending the full software control config. +async fn push_whitelist_to_all(state: &AppState) { + // Fetch payload once, then broadcast to all online clients + let payload = fetch_software_payload_for_push(&state.db, "global", None).await; + let frame = match Frame::new_json(MessageType::SoftwareBlacklist, &payload) { + Ok(f) => f.encode(), + Err(_) => return, + }; + + let online = state.clients.list_online().await; + for uid in &online { + state.clients.send_to(uid, frame.clone()).await; + } + tracing::info!("Pushed updated whitelist to {} online clients", online.len()); +} \ No newline at end of file diff --git a/crates/server/src/api/plugins/web_filter.rs b/crates/server/src/api/plugins/web_filter.rs index a9fc614..dbbcf5f 100644 --- a/crates/server/src/api/plugins/web_filter.rs +++ b/crates/server/src/api/plugins/web_filter.rs @@ -16,7 +16,7 @@ pub struct CreateRuleRequest { } pub async fn list_rules(State(state): State) -> Json> { - match sqlx::query("SELECT id, rule_type, pattern, target_type, target_id, enabled, created_at FROM web_filter_rules ORDER BY created_at 
DESC") + match sqlx::query("SELECT id, rule_type, pattern, target_type, target_id, enabled, created_at FROM web_filter_rules ORDER BY created_at DESC LIMIT 500") .fetch_all(&state.db).await { Ok(rows) => Json(ApiResponse::ok(serde_json::json!({ "rules": rows.iter().map(|r| serde_json::json!({ "id": r.get::("id"), "rule_type": r.get::("rule_type"), @@ -75,6 +75,14 @@ pub async fn update_rule(State(state): State, Path(id): Path, Jso let pattern = body.pattern.unwrap_or_else(|| existing.get::("pattern")); let enabled = body.enabled.unwrap_or_else(|| existing.get::("enabled")); + // Validate merged values + if !matches!(rule_type.as_str(), "blacklist" | "whitelist" | "category") { + return Json(ApiResponse::error("rule_type must be 'blacklist', 'whitelist', or 'category'")); + } + if pattern.trim().is_empty() || pattern.len() > 255 { + return Json(ApiResponse::error("pattern must be 1-255 chars")); + } + let result = sqlx::query("UPDATE web_filter_rules SET rule_type = ?, pattern = ?, enabled = ? WHERE id = ?") .bind(&rule_type) .bind(&pattern) @@ -114,17 +122,42 @@ pub async fn delete_rule(State(state): State, Path(id): Path) -> } #[derive(Debug, Deserialize)] -pub struct LogFilters { pub device_uid: Option, pub action: Option } +pub struct LogFilters { + pub device_uid: Option, + pub action: Option, + pub page: Option, + pub page_size: Option, +} pub async fn list_access_log(State(state): State, Query(f): Query) -> Json> { - match sqlx::query("SELECT id, device_uid, url, action, timestamp FROM web_access_log WHERE (? IS NULL OR device_uid=?) AND (? IS NULL OR action=?) 
ORDER BY timestamp DESC LIMIT 200") - .bind(&f.device_uid).bind(&f.device_uid).bind(&f.action).bind(&f.action) - .fetch_all(&state.db).await { - Ok(rows) => Json(ApiResponse::ok(serde_json::json!({"log": rows.iter().map(|r| serde_json::json!({ - "id": r.get::("id"), "device_uid": r.get::("device_uid"), - "url": r.get::("url"), "action": r.get::("action"), - "timestamp": r.get::("timestamp") - })).collect::>() }))), + let limit = f.page_size.unwrap_or(20).clamp(1, 100); + let offset = f.page.unwrap_or(1).saturating_sub(1).saturating_mul(limit); + let device_uid = f.device_uid.as_deref().filter(|s| !s.is_empty()); + let action = f.action.as_deref().filter(|s| !s.is_empty()); + + let rows = sqlx::query( + "SELECT id, device_uid, url, action, timestamp FROM web_access_log \ + WHERE (? IS NULL OR device_uid=?) AND (? IS NULL OR action=?) \ + ORDER BY timestamp DESC LIMIT ? OFFSET ?" + ) + .bind(device_uid).bind(device_uid) + .bind(action).bind(action) + .bind(limit).bind(offset) + .fetch_all(&state.db) + .await; + + match rows { + Ok(records) => Json(ApiResponse::ok(serde_json::json!({ + "log": records.iter().map(|r| serde_json::json!({ + "id": r.get::("id"), + "device_uid": r.get::("device_uid"), + "url": r.get::("url"), + "action": r.get::("action"), + "timestamp": r.get::("timestamp") + })).collect::>(), + "page": f.page.unwrap_or(1), + "page_size": limit, + }))), Err(e) => Json(ApiResponse::internal_error("query web access log", e)), } } diff --git a/crates/server/src/api/usb.rs b/crates/server/src/api/usb.rs index 9c0ae9f..21ecb82 100644 --- a/crates/server/src/api/usb.rs +++ b/crates/server/src/api/usb.rs @@ -66,7 +66,7 @@ pub async fn list_policies( ) -> Json> { let rows = sqlx::query( "SELECT id, name, policy_type, target_group, rules, enabled, created_at, updated_at - FROM usb_policies ORDER BY created_at DESC" + FROM usb_policies ORDER BY created_at DESC LIMIT 500" ) .fetch_all(&state.db) .await; @@ -106,6 +106,11 @@ pub async fn create_policy( ) -> (StatusCode, Json>) { let
enabled = body.enabled.unwrap_or(1); + // Input validation + if body.name.trim().is_empty() || body.name.len() > 100 { + return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("name must be 1-100 chars"))); + } + let result = sqlx::query( "INSERT INTO usb_policies (name, policy_type, target_group, rules, enabled) VALUES (?, ?, ?, ?, ?)" ) diff --git a/crates/server/src/health.rs b/crates/server/src/health.rs new file mode 100644 index 0000000..044f0f8 --- /dev/null +++ b/crates/server/src/health.rs @@ -0,0 +1,315 @@ +use crate::AppState; +use sqlx::Row; +use tracing::{info, error}; + +/// Background task: recompute device health scores every 5 minutes +pub async fn health_score_task(state: AppState) { + let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(300)); + // First computation runs immediately, then every 5 minutes + + loop { + interval.tick().await; + if let Err(e) = recompute_all_scores(&state).await { + error!("Health score computation failed: {}", e); + } + } +} + +async fn recompute_all_scores(state: &AppState) -> anyhow::Result<()> { + // Get all device UIDs + let devices: Vec = sqlx::query_scalar( + "SELECT device_uid FROM devices" + ) + .fetch_all(&state.db) + .await?; + + let mut computed = 0u32; + let mut errors = 0u32; + + for uid in &devices { + match compute_and_store_score(&state.db, uid).await { + Ok(score) => { + computed += 1; + tracing::debug!("Health score for {}: {} ({})", uid, score.score, score.level); + } + Err(e) => { + errors += 1; + error!("Failed to compute health score for {}: {}", uid, e); + } + } + } + + if computed > 0 { + info!("Health scores computed: {} devices, {} errors", computed, errors); + } + + Ok(()) +} + +struct HealthScoreResult { + score: i32, + status_score: i32, + encryption_score: i32, + load_score: i32, + alert_score: i32, + compliance_score: i32, + patch_score: i32, + level: String, + details: String, +} + +async fn compute_and_store_score( + pool: &sqlx::SqlitePool, + device_uid: &str, 
+) -> anyhow::Result { + let mut details = Vec::new(); + + // 1. Online status (15 points) + let status_score: i32 = sqlx::query_scalar( + "SELECT CASE WHEN status = 'online' THEN 15 ELSE 0 END FROM devices WHERE device_uid = ?" + ) + .bind(device_uid) + .fetch_one(pool) + .await + .unwrap_or(0); + + if status_score < 15 { + details.push("设备离线".to_string()); + } + + // 2. Disk encryption (20 points) + let encryption_score: i32 = sqlx::query_scalar( + "SELECT CASE \ + WHEN COUNT(*) = 0 THEN 10 \ + WHEN SUM(CASE WHEN protection_status = 'On' THEN 1 ELSE 0 END) = COUNT(*) THEN 20 \ + WHEN SUM(CASE WHEN protection_status = 'On' THEN 1 ELSE 0 END) > 0 THEN 10 \ + ELSE 0 END \ + FROM disk_encryption_status WHERE device_uid = ?" + ) + .bind(device_uid) + .fetch_one(pool) + .await + .unwrap_or(0); + + if encryption_score < 20 { + let unencrypted: Vec = sqlx::query_scalar( + "SELECT drive_letter FROM disk_encryption_status WHERE device_uid = ? AND protection_status != 'On'" + ) + .bind(device_uid) + .fetch_all(pool) + .await + .unwrap_or_default(); + if unencrypted.is_empty() && encryption_score < 20 { + details.push("未检测到加密状态".to_string()); + } else if !unencrypted.is_empty() { + details.push(format!("未加密驱动器: {}", unencrypted.join(", "))); + } + } + + // 3. System load (20 points): CPU(7) + Memory(7) + Disk(6) + let load_row = sqlx::query( + "SELECT cpu_usage, memory_usage, disk_usage FROM device_status WHERE device_uid = ?" 
+ ) + .bind(device_uid) + .fetch_optional(pool) + .await?; + + let load_score = if let Some(row) = load_row { + let cpu = row.get::("cpu_usage"); + let mem = row.get::("memory_usage"); + let disk = row.get::("disk_usage"); + + let cpu_pts = if cpu < 70.0 { 7 } else if cpu < 90.0 { 4 } else { 0 }; + let mem_pts = if mem < 80.0 { 7 } else if mem < 95.0 { 4 } else { 0 }; + let disk_pts = if disk < 80.0 { 6 } else if disk < 95.0 { 3 } else { 0 }; + + let total = cpu_pts + mem_pts + disk_pts; + if cpu >= 90.0 { details.push(format!("CPU过高 ({:.0}%)", cpu)); } + if mem >= 95.0 { details.push(format!("内存过高 ({:.0}%)", mem)); } + if disk >= 95.0 { details.push(format!("磁盘空间不足 ({:.0}%)", disk)); } + total + } else { + details.push("无状态数据".to_string()); + 0 + }; + + // 4. Alert clearance (15 points) + let unhandled_alerts: i64 = sqlx::query_scalar( + "SELECT COUNT(*) FROM alert_records WHERE device_uid = ? AND handled = 0" + ) + .bind(device_uid) + .fetch_one(pool) + .await + .unwrap_or(0); + + let alert_score: i32 = if unhandled_alerts == 0 { 15 } else { 0 }; + if unhandled_alerts > 0 { + details.push(format!("{}条未处理告警", unhandled_alerts)); + } + + // 5. Compliance (10 points): no recent software violations + let recent_violations: i64 = sqlx::query_scalar( + "SELECT COUNT(*) FROM software_violations WHERE device_uid = ? AND timestamp > datetime('now', '-7 days')" + ) + .bind(device_uid) + .fetch_one(pool) + .await + .unwrap_or(0); + + let compliance_score: i32 = if recent_violations == 0 { 10 } else { + details.push(format!("近期{}次软件违规", recent_violations)); + (10 - (recent_violations as i32).min(10)).max(0) + }; + + // 6. 
Patch status (20 points): reserved for future patch management + // For now, give full score if device is online + let patch_score: i32 = if status_score > 0 { 20 } else { 10 }; + + let score = status_score + encryption_score + load_score + alert_score + compliance_score + patch_score; + let level = if score >= 80 { + "healthy" + } else if score >= 50 { + "warning" + } else if score > 0 { + "critical" + } else { + "unknown" + }; + + let details_json = if details.is_empty() { + "[]".to_string() + } else { + serde_json::to_string(&details).unwrap_or_else(|_| "[]".to_string()) + }; + + let result = HealthScoreResult { + score, + status_score, + encryption_score, + load_score, + alert_score, + compliance_score, + patch_score, + level: level.to_string(), + details: details_json, + }; + + // Upsert the score + sqlx::query( + "INSERT INTO device_health_scores \ + (device_uid, score, status_score, encryption_score, load_score, alert_score, compliance_score, patch_score, level, details, computed_at) \ + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, datetime('now')) \ + ON CONFLICT(device_uid) DO UPDATE SET \ + score = excluded.score, status_score = excluded.status_score, \ + encryption_score = excluded.encryption_score, load_score = excluded.load_score, \ + alert_score = excluded.alert_score, compliance_score = excluded.compliance_score, \ + patch_score = excluded.patch_score, level = excluded.level, \ + details = excluded.details, computed_at = datetime('now')" + ) + .bind(device_uid) + .bind(result.score) + .bind(result.status_score) + .bind(result.encryption_score) + .bind(result.load_score) + .bind(result.alert_score) + .bind(result.compliance_score) + .bind(result.patch_score) + .bind(&result.level) + .bind(&result.details) + .execute(pool) + .await?; + + Ok(result) +} + +/// Compute a single device's health score on demand +pub async fn get_device_score( + pool: &sqlx::SqlitePool, + device_uid: &str, +) -> anyhow::Result> { + // Try to get cached score + let row = 
sqlx::query( + "SELECT score, status_score, encryption_score, load_score, alert_score, compliance_score, \ + patch_score, level, details, computed_at \ + FROM device_health_scores WHERE device_uid = ?" + ) + .bind(device_uid) + .fetch_optional(pool) + .await?; + + match row { + Some(r) => Ok(Some(serde_json::json!({ + "device_uid": device_uid, + "score": r.get::("score"), + "breakdown": { + "status": r.get::("status_score"), + "encryption": r.get::("encryption_score"), + "load": r.get::("load_score"), + "alerts": r.get::("alert_score"), + "compliance": r.get::("compliance_score"), + "patches": r.get::("patch_score"), + }, + "level": r.get::("level"), + "details": serde_json::from_str::( + &r.get::("details") + ).unwrap_or(serde_json::json!([])), + "computed_at": r.get::("computed_at"), + }))), + None => Ok(None), + } +} + +/// Get health overview for all devices (dashboard aggregation) +pub async fn get_health_overview(pool: &sqlx::SqlitePool) -> anyhow::Result { + let rows = sqlx::query( + "SELECT h.device_uid, h.score, h.level, d.hostname, d.status, d.group_name \ + FROM device_health_scores h \ + JOIN devices d ON d.device_uid = h.device_uid \ + ORDER BY h.score ASC" + ) + .fetch_all(pool) + .await?; + + let mut healthy = 0u32; + let mut warning = 0u32; + let mut critical = 0u32; + let mut unknown = 0u32; + let mut total_score = 0i64; + + let mut devices: Vec = Vec::with_capacity(rows.len()); + + for r in &rows { + let level: String = r.get("level"); + match level.as_str() { + "healthy" => healthy += 1, + "warning" => warning += 1, + "critical" => critical += 1, + _ => unknown += 1, + } + total_score += r.get::("score") as i64; + + devices.push(serde_json::json!({ + "device_uid": r.get::("device_uid"), + "hostname": r.get::("hostname"), + "status": r.get::("status"), + "group_name": r.get::("group_name"), + "score": r.get::("score"), + "level": level, + })); + } + + let total = devices.len().max(1); + let avg_score = total_score as f64 / total as f64; + + 
Ok(serde_json::json!({ + "summary": { + "total": total, + "healthy": healthy, + "warning": warning, + "critical": critical, + "unknown": unknown, + "avg_score": (avg_score * 10.0).round() / 10.0, + }, + "devices": devices, + })) +} diff --git a/crates/server/src/main.rs b/crates/server/src/main.rs index fa98af1..3ac2f10 100644 --- a/crates/server/src/main.rs +++ b/crates/server/src/main.rs @@ -7,6 +7,7 @@ use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions, SqliteJournalMode}; use std::path::Path; use std::str::FromStr; use std::sync::Arc; +use std::collections::HashMap; use tokio::net::TcpListener; use axum::http::Method as HttpMethod; use tower_http::cors::CorsLayer; @@ -23,6 +24,8 @@ mod db; mod tcp; mod ws; mod alert; +mod health; +mod anomaly; use config::AppConfig; @@ -38,6 +41,7 @@ pub struct AppState { pub clients: Arc, pub ws_hub: Arc, pub login_limiter: Arc, + pub ws_tickets: Arc>>, } #[tokio::main] @@ -58,15 +62,19 @@ async fn main() -> Result<()> { // Security checks if config.registration_token.is_empty() { - warn!("SECURITY: registration_token is empty — any device can register!"); + anyhow::bail!("FATAL: registration_token is empty. Set it in config.toml or via CSM_REGISTRATION_TOKEN env var. Device registration is disabled for security."); } - if config.auth.jwt_secret.len() < 32 { - warn!("SECURITY: jwt_secret is too short ({} chars) — consider using a 32+ byte key from CSM_JWT_SECRET env var", config.auth.jwt_secret.len()); + if config.auth.jwt_secret.is_empty() || config.auth.jwt_secret.len() < 32 { + anyhow::bail!("FATAL: jwt_secret is missing or too short. Set CSM_JWT_SECRET env var with a 32+ byte random key."); + } + if config.server.tls.is_none() { + warn!("SECURITY: No TLS configured — all TCP communication is plaintext. 
Configure [server.tls] for production."); + if std::env::var("CSM_DEV").is_err() { + warn!("Set CSM_DEV=1 to suppress this warning in development environments."); + } } let config = Arc::new(config); - - // Initialize database let db = init_database(&config.database.path).await?; run_migrations(&db).await?; info!("Database initialized at {}", config.database.path); @@ -84,6 +92,7 @@ async fn main() -> Result<()> { clients: clients.clone(), ws_hub: ws_hub.clone(), login_limiter: Arc::new(api::auth::LoginRateLimiter::new()), + ws_tickets: Arc::new(tokio::sync::Mutex::new(HashMap::new())), }; // Start background tasks @@ -92,6 +101,12 @@ async fn main() -> Result<()> { alert::cleanup_task(cleanup_state).await; }); + // Health score computation task + let health_state = state.clone(); + tokio::spawn(async move { + health::health_score_task(health_state).await; + }); + // Start TCP listener for client connections let tcp_state = state.clone(); let tcp_addr = config.server.tcp_addr.clone(); @@ -131,7 +146,11 @@ async fn main() -> Result<()> { )) .layer(SetResponseHeaderLayer::if_not_present( axum::http::header::HeaderName::from_static("content-security-policy"), - axum::http::HeaderValue::from_static("default-src 'self'; script-src 'self' 'unsafe-inline' 'unsafe-eval'; style-src 'self' 'unsafe-inline'; img-src 'self' data:; connect-src 'self' ws: wss:"), + axum::http::HeaderValue::from_static("default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline'; img-src 'self' data:; connect-src 'self' wss:; font-src 'self'; frame-ancestors 'none'; base-uri 'self'; form-action 'self'"), + )) + .layer(SetResponseHeaderLayer::if_not_present( + axum::http::header::HeaderName::from_static("permissions-policy"), + axum::http::HeaderValue::from_static("camera=(), microphone=(), geolocation=(), payment=()"), )) .with_state(state); @@ -197,6 +216,9 @@ async fn run_migrations(pool: &sqlx::SqlitePool) -> Result<()> { 
include_str!("../../../migrations/014_clipboard_control.sql"), include_str!("../../../migrations/015_plugin_control.sql"), include_str!("../../../migrations/016_encryption_alerts_unique.sql"), + include_str!("../../../migrations/017_device_health_scores.sql"), + include_str!("../../../migrations/018_patch_management.sql"), + include_str!("../../../migrations/019_software_whitelist.sql"), ]; // Create migrations tracking table @@ -257,11 +279,27 @@ async fn ensure_default_admin(pool: &sqlx::SqlitePool) -> Result<()> { .await?; warn!("Created default admin user (username: admin)"); - // Print password directly to stderr — bypasses tracing JSON formatter - eprintln!("============================================================"); - eprintln!(" Generated admin password: {}", random_password); - eprintln!(" *** Save this password now — it will NOT be shown again! ***"); - eprintln!("============================================================"); + // Write password to restricted file instead of stderr (avoid log capture) + let pw_path = std::path::Path::new("data/initial-password.txt"); + if let Some(parent) = pw_path.parent() { + let _ = std::fs::create_dir_all(parent); + } + match std::fs::write(pw_path, &random_password) { + Ok(_) => { + warn!("Initial admin password saved to data/initial-password.txt (delete after first login)"); + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt; + let _ = std::fs::set_permissions(pw_path, std::fs::Permissions::from_mode(0o600)); + } + #[cfg(not(unix))] + { + // Windows: restrict ACL would require windows-rs; at minimum hide the file + let _ = std::process::Command::new("attrib").args(["+H", &pw_path.to_string_lossy()]).output(); + } + } + Err(e) => warn!("Failed to save initial password to file: {}. 
Password was: {}", e, random_password), + } } Ok(()) @@ -278,13 +316,14 @@ fn build_cors_layer(origins: &[String]) -> CorsLayer { .collect(); if allowed_origins.is_empty() { - // No CORS — production safe by default + // No CORS — production safe by default (same-origin cookies work without CORS) CorsLayer::new() } else { CorsLayer::new() .allow_origin(tower_http::cors::AllowOrigin::list(allowed_origins)) .allow_methods([HttpMethod::GET, HttpMethod::POST, HttpMethod::PUT, HttpMethod::DELETE]) .allow_headers([axum::http::header::AUTHORIZATION, axum::http::header::CONTENT_TYPE]) + .allow_credentials(true) .max_age(std::time::Duration::from_secs(3600)) } } diff --git a/crates/server/src/tcp.rs b/crates/server/src/tcp.rs index 869c718..b63049a 100644 --- a/crates/server/src/tcp.rs +++ b/crates/server/src/tcp.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use std::sync::atomic::{AtomicU32, Ordering}; use std::time::Instant; use tokio::sync::RwLock; -use tokio::net::{TcpListener, TcpStream}; +use tokio::net::TcpListener; use tracing::{info, warn, debug}; use hmac::{Hmac, Mac}; use sha2::Sha256; @@ -167,7 +167,7 @@ pub async fn push_all_plugin_configs( } } - // Software blacklist + // Software blacklist + whitelist if let Ok(rows) = sqlx::query( "SELECT id, name_pattern, action FROM software_blacklist WHERE enabled = 1 AND (target_type = 'global' OR (target_type = 'device' AND target_id = ?) 
OR (target_type = 'group' AND target_id = (SELECT group_name FROM devices WHERE device_uid = ?)))" ) @@ -180,8 +180,20 @@ pub async fn push_all_plugin_configs( "name_pattern": r.get::("name_pattern"), "action": r.get::("action"), })).collect(); - if !entries.is_empty() { - if let Ok(frame) = Frame::new_json(MessageType::SoftwareBlacklist, &serde_json::json!({"blacklist": entries})) { + + // Fetch whitelist (global, always pushed to all devices) + let whitelist: Vec = sqlx::query_scalar( + "SELECT name_pattern FROM software_whitelist WHERE enabled = 1" + ) + .fetch_all(db) + .await + .unwrap_or_default(); + + if !entries.is_empty() || !whitelist.is_empty() { + if let Ok(frame) = Frame::new_json(MessageType::SoftwareBlacklist, &serde_json::json!({ + "blacklist": entries, + "whitelist": whitelist, + })) { clients.send_to(device_uid, frame.encode()).await; } } @@ -261,17 +273,53 @@ pub async fn push_all_plugin_configs( } } - // Disk encryption config — push default reporting interval (no dedicated config table) + // Disk encryption config — read from patch_policies if available, else default { - let config = csm_protocol::DiskEncryptionConfigPayload { - enabled: true, - report_interval_secs: 3600, + let config = if let Ok(Some(row)) = sqlx::query( + "SELECT auto_approve, enabled FROM patch_policies WHERE target_type = 'global' AND enabled = 1 LIMIT 1" + ) + .fetch_optional(db) + .await + { + // If patch_policies exist, infer disk encryption should be enabled + csm_protocol::DiskEncryptionConfigPayload { + enabled: row.get::("enabled") != 0, + report_interval_secs: 3600, + } + } else { + csm_protocol::DiskEncryptionConfigPayload { + enabled: true, + report_interval_secs: 3600, + } }; if let Ok(frame) = Frame::new_json(MessageType::DiskEncryptionConfig, &config) { clients.send_to(device_uid, frame.encode()).await; } } + // Patch scan config — read from patch_policies if available, else default + { + let config = if let Ok(Some(row)) = sqlx::query( + "SELECT enabled FROM 
patch_policies WHERE target_type = 'global' AND enabled = 1 LIMIT 1" + ) + .fetch_optional(db) + .await + { + csm_protocol::PatchScanConfigPayload { + enabled: row.get::("enabled") != 0, + scan_interval_secs: 43200, + } + } else { + csm_protocol::PatchScanConfigPayload { + enabled: true, + scan_interval_secs: 43200, + } + }; + if let Ok(frame) = Frame::new_json(MessageType::PatchScanConfig, &config) { + clients.send_to(device_uid, frame.encode()).await; + } + } + // Push plugin enable/disable state — disable any plugins that admin has turned off if let Ok(rows) = sqlx::query( "SELECT plugin_name FROM plugin_state WHERE enabled = 0" @@ -297,10 +345,17 @@ pub async fn push_all_plugin_configs( /// Maximum accumulated read buffer size per connection (8 MB) const MAX_READ_BUF_SIZE: usize = 8 * 1024 * 1024; -/// Registry of all connected client sessions +/// Registry of all connected client sessions, including cached device secrets. #[derive(Clone, Default)] pub struct ClientRegistry { - sessions: Arc>>>>>, + sessions: Arc>>, +} + +/// Per-device session data kept in memory for fast access. +struct ClientSession { + tx: Arc>>, + /// Cached device_secret for HMAC verification — avoids a DB query per heartbeat. 
+ secret: Option, } impl ClientRegistry { @@ -308,8 +363,8 @@ impl ClientRegistry { Self::default() } - pub async fn register(&self, device_uid: String, tx: Arc>>) { - self.sessions.write().await.insert(device_uid, tx); + pub async fn register(&self, device_uid: String, secret: Option, tx: Arc>>) { + self.sessions.write().await.insert(device_uid, ClientSession { tx, secret }); } pub async fn unregister(&self, device_uid: &str) { @@ -317,13 +372,25 @@ impl ClientRegistry { } pub async fn send_to(&self, device_uid: &str, data: Vec) -> bool { - if let Some(tx) = self.sessions.read().await.get(device_uid) { - tx.send(data).await.is_ok() + if let Some(session) = self.sessions.read().await.get(device_uid) { + session.tx.send(data).await.is_ok() } else { false } } + /// Get cached device secret for HMAC verification (avoids DB query per heartbeat). + pub async fn get_secret(&self, device_uid: &str) -> Option { + self.sessions.read().await.get(device_uid).and_then(|s| s.secret.clone()) + } + + /// Backfill cached device secret after a cache miss (e.g. server restart). + pub async fn set_secret(&self, device_uid: &str, secret: String) { + if let Some(session) = self.sessions.write().await.get_mut(device_uid) { + session.secret = Some(secret); + } + } + pub async fn count(&self) -> usize { self.sessions.read().await.len() } @@ -366,7 +433,7 @@ pub async fn start_tcp_server(addr: String, state: AppState) -> anyhow::Result<( Some(acceptor) => { match acceptor.accept(stream).await { Ok(tls_stream) => { - if let Err(e) = handle_client_tls(tls_stream, state).await { + if let Err(e) = handle_client(tls_stream, state).await { warn!("Client {} TLS error: {}", peer_addr, e); } } @@ -451,6 +518,17 @@ fn verify_device_uid(device_uid: &Option, msg_type: &str, claimed_uid: & } } +/// Constant-time string comparison to prevent timing attacks on secrets. 
+fn constant_time_eq(a: &str, b: &str) -> bool { + use std::iter; + if a.len() != b.len() { + // Still do a comparison to avoid leaking length via timing + let _ = a.as_bytes().iter().zip(iter::repeat(0u8)).map(|(x, y)| x ^ y); + return false; + } + a.as_bytes().iter().zip(b.as_bytes()).fold(0u8, |acc, (x, y)| acc | (x ^ y)) == 0 +} + /// Process a single decoded frame. Shared by both plaintext and TLS handlers. /// `hmac_fail_count` tracks consecutive HMAC failures; caller checks it for disconnect threshold. async fn process_frame( @@ -467,10 +545,10 @@ async fn process_frame( info!("Device registration attempt: {} ({})", req.hostname, req.device_uid); - // Validate registration token against configured token + // Validate registration token against configured token (constant-time comparison) let expected_token = &state.config.registration_token; if !expected_token.is_empty() { - if req.registration_token.is_empty() || req.registration_token != *expected_token { + if req.registration_token.is_empty() || !constant_time_eq(&req.registration_token, expected_token) { warn!("Registration rejected for {}: invalid token", req.device_uid); let err_frame = Frame::new_json(MessageType::RegisterResponse, &serde_json::json!({"error": "invalid_registration_token"}))?; @@ -514,7 +592,7 @@ async fn process_frame( *device_uid = Some(req.device_uid.clone()); // If this device was already connected on a different session, evict the old one // The new register() call will replace it in the hashmap - state.clients.register(req.device_uid.clone(), tx.clone()).await; + state.clients.register(req.device_uid.clone(), Some(device_secret.clone()), tx.clone()).await; // Send registration response let config = csm_protocol::ClientConfig::default(); @@ -539,17 +617,25 @@ async fn process_frame( return Ok(()); } - // Verify HMAC — reject if secret exists but HMAC is missing or wrong - let secret: Option = sqlx::query_scalar( - "SELECT device_secret FROM devices WHERE device_uid = ?" 
- ) - .bind(&heartbeat.device_uid) - .fetch_optional(&state.db) - .await - .map_err(|e| { - warn!("DB error fetching device_secret for {}: {}", heartbeat.device_uid, e); - anyhow::anyhow!("DB error during HMAC verification") - })?; + // Verify HMAC — use cached secret from ClientRegistry, fall back to DB on cache miss (e.g. after restart) + let mut secret = state.clients.get_secret(&heartbeat.device_uid).await; + if secret.is_none() { + // Cache miss (server restarted) — query DB and backfill cache + let db_secret: Option = sqlx::query_scalar( + "SELECT device_secret FROM devices WHERE device_uid = ?" + ) + .bind(&heartbeat.device_uid) + .fetch_optional(&state.db) + .await + .map_err(|e| { + warn!("DB error fetching device_secret for {}: {}", heartbeat.device_uid, e); + anyhow::anyhow!("DB error during HMAC verification") + })?; + if let Some(ref s) = db_secret { + state.clients.set_secret(&heartbeat.device_uid, s.clone()).await; + } + secret = db_secret; + } if let Some(ref secret) = secret { if !secret.is_empty() { @@ -650,6 +736,40 @@ async fn process_frame( crate::db::DeviceRepo::upsert_software(&state.db, &sw).await?; } + MessageType::AssetChange => { + let change: csm_protocol::AssetChange = frame.decode_payload() + .map_err(|e| anyhow::anyhow!("Invalid asset change: {}", e))?; + + if !verify_device_uid(device_uid, "AssetChange", &change.device_uid) { + return Ok(()); + } + + let change_type_str = match change.change_type { + csm_protocol::AssetChangeType::Hardware => "hardware", + csm_protocol::AssetChangeType::SoftwareAdded => "software_added", + csm_protocol::AssetChangeType::SoftwareRemoved => "software_removed", + }; + + sqlx::query( + "INSERT INTO asset_changes (device_uid, change_type, change_detail, detected_at) \ + VALUES (?, ?, ?, datetime('now'))" + ) + .bind(&change.device_uid) + .bind(change_type_str) + .bind(serde_json::to_string(&change.change_detail).map_err(|e| anyhow::anyhow!("Failed to serialize asset change detail: {}", e))?) 
+ .execute(&state.db) + .await + .map_err(|e| anyhow::anyhow!("DB error inserting asset change: {}", e))?; + + debug!("Asset change: {} {:?} for device {}", change_type_str, change.change_detail, change.device_uid); + + state.ws_hub.broadcast(serde_json::json!({ + "type": "asset_change", + "device_uid": change.device_uid, + "change_type": change_type_str, + }).to_string()).await; + } + MessageType::UsageReport => { let report: csm_protocol::UsageDailyReport = frame.decode_payload() .map_err(|e| anyhow::anyhow!("Invalid usage report: {}", e))?; @@ -910,23 +1030,86 @@ async fn process_frame( return Ok(()); } - for rule_stat in &stats.rule_stats { - sqlx::query( - "INSERT INTO popup_block_stats (device_uid, rule_id, blocked_count, period_secs, reported_at) \ - VALUES (?, ?, ?, ?, datetime('now'))" - ) - .bind(&stats.device_uid) - .bind(rule_stat.rule_id) - .bind(rule_stat.hits as i32) - .bind(stats.period_secs as i32) - .execute(&state.db) - .await - .ok(); - } + // Upsert aggregate stats per device per day + let today = chrono::Utc::now().format("%Y-%m-%d").to_string(); + sqlx::query( + "INSERT INTO popup_block_stats (device_uid, blocked_count, date) \ + VALUES (?, ?, ?) 
\ + ON CONFLICT(device_uid, date) DO UPDATE SET \ + blocked_count = blocked_count + excluded.blocked_count" + ) + .bind(&stats.device_uid) + .bind(stats.blocked_count as i32) + .bind(&today) + .execute(&state.db) + .await + .ok(); debug!("Popup block stats: {} blocked {} windows in {}s", stats.device_uid, stats.blocked_count, stats.period_secs); } + MessageType::PatchStatusReport => { + let payload: csm_protocol::PatchStatusPayload = frame.decode_payload() + .map_err(|e| anyhow::anyhow!("Invalid patch status report: {}", e))?; + + if !verify_device_uid(device_uid, "PatchStatusReport", &payload.device_uid) { + return Ok(()); + } + + for patch in &payload.patches { + sqlx::query( + "INSERT INTO patch_status (device_uid, kb_id, title, severity, is_installed, installed_at, updated_at) \ + VALUES (?, ?, ?, ?, ?, ?, datetime('now')) \ + ON CONFLICT(device_uid, kb_id) DO UPDATE SET \ + title = excluded.title, severity = COALESCE(excluded.severity, patch_status.severity), \ + is_installed = excluded.is_installed, installed_at = excluded.installed_at, \ + updated_at = datetime('now')" + ) + .bind(&payload.device_uid) + .bind(&patch.kb_id) + .bind(&patch.title) + .bind(&patch.severity) + .bind(patch.is_installed as i32) + .bind(&patch.installed_at) + .execute(&state.db) + .await + .map_err(|e| anyhow::anyhow!("DB error inserting patch status: {}", e))?; + } + + info!("Patch status reported: {} ({} patches)", payload.device_uid, payload.patches.len()); + } + + MessageType::BehaviorMetricsReport => { + let metrics: csm_protocol::BehaviorMetricsPayload = frame.decode_payload() + .map_err(|e| anyhow::anyhow!("Invalid behavior metrics: {}", e))?; + + if !verify_device_uid(device_uid, "BehaviorMetricsReport", &metrics.device_uid) { + return Ok(()); + } + + sqlx::query( + "INSERT INTO behavior_metrics (device_uid, clipboard_ops_count, clipboard_ops_night, print_jobs_count, usb_file_ops_count, new_processes_count, period_secs, reported_at) \ + VALUES (?, ?, ?, ?, ?, ?, ?, 
datetime('now'))" + ) + .bind(&metrics.device_uid) + .bind(metrics.clipboard_ops_count as i32) + .bind(metrics.clipboard_ops_night as i32) + .bind(metrics.print_jobs_count as i32) + .bind(metrics.usb_file_ops_count as i32) + .bind(metrics.new_processes_count as i32) + .bind(metrics.period_secs as i32) + .execute(&state.db) + .await + .map_err(|e| anyhow::anyhow!("DB error inserting behavior metrics: {}", e))?; + + // Run anomaly detection inline + crate::anomaly::check_anomalies(&state.db, &state.ws_hub, &metrics).await; + + debug!("Behavior metrics saved: {} (clipboard={}, print={}, usb_file={}, procs={})", + metrics.device_uid, metrics.clipboard_ops_count, metrics.print_jobs_count, + metrics.usb_file_ops_count, metrics.new_processes_count); + } + _ => { debug!("Unhandled message type: {:?}", frame.msg_type); } @@ -935,13 +1118,14 @@ async fn process_frame( Ok(()) } -/// Handle a single client TCP connection -async fn handle_client(stream: TcpStream, state: AppState) -> anyhow::Result<()> { +/// Handle a single client TCP connection (plaintext or TLS) +async fn handle_client(stream: S, state: AppState) -> anyhow::Result<()> +where + S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static, +{ use tokio::io::{AsyncReadExt, AsyncWriteExt}; - let _ = stream.set_nodelay(true); - - let (mut reader, mut writer) = stream.into_split(); + let (mut reader, mut writer) = tokio::io::split(stream); let (tx, mut rx) = tokio::sync::mpsc::channel::>(256); let tx = Arc::new(tx); @@ -1018,81 +1202,50 @@ async fn handle_client(stream: TcpStream, state: AppState) -> anyhow::Result<()> Ok(()) } -/// Handle a TLS-wrapped client connection -async fn handle_client_tls( - stream: tokio_rustls::server::TlsStream, - state: AppState, -) -> anyhow::Result<()> { - use tokio::io::{AsyncReadExt, AsyncWriteExt}; - - let (mut reader, mut writer) = tokio::io::split(stream); - let (tx, mut rx) = tokio::sync::mpsc::channel::>(256); - let tx = Arc::new(tx); - - let mut buffer = 
vec![0u8; 65536]; - let mut read_buf = Vec::with_capacity(65536); - let mut device_uid: Option = None; - let mut rate_limiter = RateLimiter::new(); - let hmac_fail_count = Arc::new(AtomicU32::new(0)); - - let write_task = tokio::spawn(async move { - while let Some(data) = rx.recv().await { - if writer.write_all(&data).await.is_err() { - break; - } +/// Push a TLS certificate rotation notice to all online devices. +/// Computes the fingerprint of the new certificate and sends ConfigUpdate(TlsCertRotate). +pub async fn push_tls_cert_rotation(clients: &ClientRegistry, new_cert_pem: &[u8], valid_until: &str) -> usize { + // Compute SHA-256 fingerprint of the new certificate + let certs: Vec<_> = match rustls_pemfile::certs(&mut &new_cert_pem[..]).collect::, _>>() { + Ok(c) => c, + Err(e) => { + warn!("Failed to parse new certificate for rotation: {:?}", e); + return 0; } - }); + }; - // Reader loop with idle timeout - 'reader: loop { - let read_result = tokio::time::timeout( - std::time::Duration::from_secs(IDLE_TIMEOUT_SECS), - reader.read(&mut buffer), - ).await; + let end_entity = match certs.first() { + Some(c) => c, + None => { + warn!("No certificates found in PEM for rotation"); + return 0; + } + }; - let n = match read_result { - Ok(Ok(0)) => break, - Ok(Ok(n)) => n, - Ok(Err(e)) => return Err(e.into()), - Err(_) => { - warn!("Idle timeout for TLS device {:?}, disconnecting", device_uid); - break; - } + let fingerprint = { + use sha2::{Sha256, Digest}; + let mut hasher = Sha256::new(); + hasher.update(end_entity.as_ref()); + hex::encode(hasher.finalize()) + }; + + info!("Pushing TLS cert rotation: new fingerprint={}... 
valid_until={}", &fingerprint[..16], valid_until); + + let config_update = csm_protocol::ConfigUpdateType::TlsCertRotate { + new_cert_hash: fingerprint, + valid_until: valid_until.to_string(), + }; + + let online = clients.list_online().await; + let mut pushed = 0usize; + for uid in &online { + let frame = match Frame::new_json(MessageType::ConfigUpdate, &config_update) { + Ok(f) => f, + Err(_) => continue, }; - read_buf.extend_from_slice(&buffer[..n]); - - if read_buf.len() > MAX_READ_BUF_SIZE { - warn!("TLS connection exceeded max buffer size, dropping"); - break; - } - - while let Some(frame) = Frame::decode(&read_buf)? { - let frame_size = frame.encoded_size(); - read_buf.drain(..frame_size); - - if frame.version != PROTOCOL_VERSION { - warn!("Unsupported protocol version: 0x{:02X}", frame.version); - continue; - } - - if !rate_limiter.check() { - warn!("Rate limit exceeded for TLS device {:?}, dropping connection", device_uid); - break 'reader; - } - - if let Err(e) = process_frame(frame, &state, &mut device_uid, &tx, &hmac_fail_count).await { - warn!("Frame processing error: {}", e); - } - - // Disconnect if too many consecutive HMAC failures - if hmac_fail_count.load(Ordering::Relaxed) >= MAX_HMAC_FAILURES { - warn!("Too many HMAC failures for TLS device {:?}, disconnecting", device_uid); - break 'reader; - } + if clients.send_to(uid, frame.encode()).await { + pushed += 1; } } - - cleanup_on_disconnect(&state, &device_uid).await; - write_task.abort(); - Ok(()) + pushed } diff --git a/crates/server/src/ws.rs b/crates/server/src/ws.rs index 6c93aa1..233e040 100644 --- a/crates/server/src/ws.rs +++ b/crates/server/src/ws.rs @@ -1,12 +1,10 @@ use axum::extract::ws::{WebSocket, WebSocketUpgrade, Message}; use axum::response::IntoResponse; use axum::extract::Query; -use jsonwebtoken::{decode, Validation, DecodingKey}; use serde::Deserialize; use tokio::sync::broadcast; use std::sync::Arc; use tracing::{debug, warn}; -use crate::api::auth::Claims; use 
crate::AppState; /// WebSocket hub for broadcasting real-time events to admin browsers @@ -32,65 +30,73 @@ impl WsHub { } } -#[derive(Debug, Deserialize)] -pub struct WsAuthParams { - pub token: Option, +/// Claim stored when a WS ticket is created. Consumed on WS connection. +#[derive(Debug, Clone)] +pub struct TicketClaim { + pub user_id: i64, + pub username: String, + pub role: String, + pub created_at: std::time::Instant, } -/// HTTP upgrade handler for WebSocket connections -/// Validates JWT token from query parameter before upgrading +#[derive(Debug, Deserialize)] +pub struct WsTicketParams { + pub ticket: Option, +} + +/// HTTP upgrade handler for WebSocket connections. +/// Validates a one-time ticket (obtained via POST /api/ws/ticket) before upgrading. pub async fn ws_handler( ws: WebSocketUpgrade, - Query(params): Query, + Query(params): Query, axum::extract::State(state): axum::extract::State, ) -> impl IntoResponse { - let token = match params.token { + let ticket = match params.ticket { Some(t) => t, None => { - warn!("WebSocket connection rejected: no token provided"); - return (axum::http::StatusCode::UNAUTHORIZED, "Missing token").into_response(); + warn!("WebSocket connection rejected: no ticket provided"); + return (axum::http::StatusCode::UNAUTHORIZED, "Missing ticket").into_response(); } }; - let claims = match decode::( - &token, - &DecodingKey::from_secret(state.config.auth.jwt_secret.as_bytes()), - &Validation::default(), - ) { - Ok(c) => c.claims, - Err(e) => { - warn!("WebSocket connection rejected: invalid token - {}", e); - return (axum::http::StatusCode::UNAUTHORIZED, "Invalid token").into_response(); + // Consume (remove) the ticket from the store — single use + let claim = { + let mut tickets = state.ws_tickets.lock().await; + match tickets.remove(&ticket) { + Some(claim) => claim, + None => { + warn!("WebSocket connection rejected: invalid or expired ticket"); + return (axum::http::StatusCode::UNAUTHORIZED, "Invalid or expired 
ticket").into_response(); + } } }; - if claims.token_type != "access" { - warn!("WebSocket connection rejected: not an access token"); - return (axum::http::StatusCode::UNAUTHORIZED, "Invalid token type").into_response(); + // Check ticket age (30 second TTL) + if claim.created_at.elapsed().as_secs() > 30 { + warn!("WebSocket connection rejected: ticket expired"); + return (axum::http::StatusCode::UNAUTHORIZED, "Ticket expired").into_response(); } let hub = state.ws_hub.clone(); - ws.on_upgrade(move |socket| handle_socket(socket, claims, hub)) + ws.on_upgrade(move |socket| handle_socket(socket, claim, hub)) } -async fn handle_socket(mut socket: WebSocket, claims: Claims, hub: Arc) { - debug!("WebSocket client connected: user={}", claims.username); +async fn handle_socket(mut socket: WebSocket, claim: TicketClaim, hub: Arc) { + debug!("WebSocket client connected: user={}", claim.username); let welcome = serde_json::json!({ "type": "connected", "message": "CSM real-time feed active", - "user": claims.username + "user": claim.username }); if socket.send(Message::Text(welcome.to_string())).await.is_err() { return; } - // Subscribe to broadcast hub for real-time events let mut rx = hub.subscribe(); loop { tokio::select! 
{ - // Forward broadcast messages to WebSocket client msg = rx.recv() => { match msg { Ok(text) => { @@ -104,7 +110,6 @@ async fn handle_socket(mut socket: WebSocket, claims: Claims, hub: Arc) { Err(broadcast::error::RecvError::Closed) => break, } } - // Handle incoming WebSocket messages (ping/close) msg = socket.recv() => { match msg { Some(Ok(Message::Ping(data))) => { @@ -121,5 +126,5 @@ async fn handle_socket(mut socket: WebSocket, claims: Claims, hub: Arc) { } } - debug!("WebSocket client disconnected: user={}", claims.username); + debug!("WebSocket client disconnected: user={}", claim.username); } diff --git a/migrations/017_device_health_scores.sql b/migrations/017_device_health_scores.sql new file mode 100644 index 0000000..af9fa77 --- /dev/null +++ b/migrations/017_device_health_scores.sql @@ -0,0 +1,20 @@ +-- 017_device_health_scores.sql: Device health scoring system + +CREATE TABLE IF NOT EXISTS device_health_scores ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + device_uid TEXT NOT NULL REFERENCES devices(device_uid) ON DELETE CASCADE, + score INTEGER NOT NULL DEFAULT 0 CHECK(score >= 0 AND score <= 100), + status_score INTEGER NOT NULL DEFAULT 0, + encryption_score INTEGER NOT NULL DEFAULT 0, + load_score INTEGER NOT NULL DEFAULT 0, + alert_score INTEGER NOT NULL DEFAULT 0, + compliance_score INTEGER NOT NULL DEFAULT 0, + patch_score INTEGER NOT NULL DEFAULT 0, + level TEXT NOT NULL DEFAULT 'unknown' CHECK(level IN ('healthy', 'warning', 'critical', 'unknown')), + details TEXT, + computed_at TEXT NOT NULL DEFAULT (datetime('now')), + UNIQUE(device_uid) +); + +CREATE INDEX IF NOT EXISTS idx_health_scores_level ON device_health_scores(level); +CREATE INDEX IF NOT EXISTS idx_health_scores_computed ON device_health_scores(computed_at); diff --git a/migrations/018_patch_management.sql b/migrations/018_patch_management.sql new file mode 100644 index 0000000..08c6c6f --- /dev/null +++ b/migrations/018_patch_management.sql @@ -0,0 +1,59 @@ +-- 
018_patch_management.sql: Patch management system + +CREATE TABLE IF NOT EXISTS patch_status ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + device_uid TEXT NOT NULL REFERENCES devices(device_uid) ON DELETE CASCADE, + kb_id TEXT NOT NULL, + title TEXT NOT NULL, + severity TEXT, + is_installed INTEGER NOT NULL DEFAULT 0, + discovered_at TEXT NOT NULL DEFAULT (datetime('now')), + installed_at TEXT, + updated_at TEXT NOT NULL DEFAULT (datetime('now')), + UNIQUE(device_uid, kb_id) +); + +CREATE TABLE IF NOT EXISTS patch_policies ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + target_type TEXT NOT NULL DEFAULT 'global' CHECK(target_type IN ('global', 'device', 'group')), + target_id TEXT, + auto_approve INTEGER NOT NULL DEFAULT 0, + severity_filter TEXT NOT NULL DEFAULT 'important', + enabled INTEGER NOT NULL DEFAULT 1, + created_at TEXT NOT NULL DEFAULT (datetime('now')), + updated_at TEXT NOT NULL DEFAULT (datetime('now')) +); + +-- Behavior metrics for anomaly detection +CREATE TABLE IF NOT EXISTS behavior_metrics ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + device_uid TEXT NOT NULL REFERENCES devices(device_uid) ON DELETE CASCADE, + clipboard_ops_count INTEGER NOT NULL DEFAULT 0, + clipboard_ops_night INTEGER NOT NULL DEFAULT 0, + print_jobs_count INTEGER NOT NULL DEFAULT 0, + usb_file_ops_count INTEGER NOT NULL DEFAULT 0, + new_processes_count INTEGER NOT NULL DEFAULT 0, + period_secs INTEGER NOT NULL DEFAULT 3600, + reported_at TEXT NOT NULL DEFAULT (datetime('now')) +); + +-- Anomaly alerts generated by the detection engine +CREATE TABLE IF NOT EXISTS anomaly_alerts ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + device_uid TEXT NOT NULL REFERENCES devices(device_uid) ON DELETE CASCADE, + anomaly_type TEXT NOT NULL, + severity TEXT NOT NULL DEFAULT 'medium' CHECK(severity IN ('low', 'medium', 'high', 'critical')), + detail TEXT NOT NULL, + metric_value REAL, + baseline_value REAL, + handled INTEGER NOT NULL DEFAULT 0, + handled_by TEXT, + handled_at TEXT, + triggered_at TEXT 
NOT NULL DEFAULT (datetime('now')) +); + +CREATE INDEX IF NOT EXISTS idx_patch_status_device ON patch_status(device_uid); +CREATE INDEX IF NOT EXISTS idx_patch_status_severity ON patch_status(severity, is_installed); +CREATE INDEX IF NOT EXISTS idx_behavior_metrics_device_time ON behavior_metrics(device_uid, reported_at); +CREATE INDEX IF NOT EXISTS idx_anomaly_alerts_device ON anomaly_alerts(device_uid); +CREATE INDEX IF NOT EXISTS idx_anomaly_alerts_unhandled ON anomaly_alerts(handled) WHERE handled = 0; diff --git a/migrations/019_software_whitelist.sql b/migrations/019_software_whitelist.sql new file mode 100644 index 0000000..0ed0d88 --- /dev/null +++ b/migrations/019_software_whitelist.sql @@ -0,0 +1,54 @@ +-- Software whitelist: processes that should NEVER be blocked even if matched by blacklist rules. +-- This provides a safety net to prevent false positives from killing legitimate applications. +CREATE TABLE IF NOT EXISTS software_whitelist ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name_pattern TEXT NOT NULL, + reason TEXT, + is_builtin INTEGER NOT NULL DEFAULT 0, -- 1 = system default, 0 = admin-added + enabled INTEGER NOT NULL DEFAULT 1, + created_at TEXT NOT NULL DEFAULT (datetime('now')) +); + +-- Built-in whitelist entries for common safe applications +INSERT INTO software_whitelist (name_pattern, reason, is_builtin) VALUES + -- Browsers + ('chrome.exe', 'Google Chrome browser', 1), + ('msedge.exe', 'Microsoft Edge browser', 1), + ('firefox.exe', 'Mozilla Firefox browser', 1), + ('iexplore.exe', 'Internet Explorer', 1), + ('opera.exe', 'Opera browser', 1), + ('brave.exe', 'Brave browser', 1), + ('vivaldi.exe', 'Vivaldi browser', 1), + -- Development tools & IDEs + ('code.exe', 'Visual Studio Code', 1), + ('devenv.exe', 'Visual Studio', 1), + ('idea64.exe', 'IntelliJ IDEA', 1), + ('webstorm64.exe', 'WebStorm', 1), + ('pycharm64.exe', 'PyCharm', 1), + ('goland64.exe', 'GoLand', 1), + ('clion64.exe', 'CLion', 1), + ('rider64.exe', 'Rider', 1), + 
('trae.exe', 'Trae IDE', 1), + ('windsurf.exe', 'Windsurf IDE', 1), + ('cursor.exe', 'Cursor IDE', 1), + -- Office & productivity + ('winword.exe', 'Microsoft Word', 1), + ('excel.exe', 'Microsoft Excel', 1), + ('powerpnt.exe', 'Microsoft PowerPoint', 1), + ('outlook.exe', 'Microsoft Outlook', 1), + ('onenote.exe', 'Microsoft OneNote', 1), + ('teams.exe', 'Microsoft Teams', 1), + ('wps.exe', 'WPS Office', 1), + -- Terminal & system tools + ('cmd.exe', 'Command Prompt', 1), + ('powershell.exe', 'PowerShell', 1), + ('pwsh.exe', 'PowerShell Core', 1), + ('WindowsTerminal.exe', 'Windows Terminal', 1), + -- Communication + ('wechat.exe', 'WeChat', 1), + ('dingtalk.exe', 'DingTalk', 1), + ('feishu.exe', 'Feishu/Lark', 1), + ('qq.exe', 'QQ', 1), + ('tim.exe', 'Tencent TIM', 1), + -- CSM + ('csm-client.exe', 'CSM Client itself', 1); diff --git a/web/src/lib/api.ts b/web/src/lib/api.ts index f2cad51..a8ac45d 100644 --- a/web/src/lib/api.ts +++ b/web/src/lib/api.ts @@ -1,5 +1,7 @@ /** - * Shared API client with authentication and error handling + * Shared API client with cookie-based authentication. + * Tokens are managed via HttpOnly cookies set by the server — + * the frontend never reads or stores JWT tokens. 
*/ const API_BASE = import.meta.env.VITE_API_BASE || '' @@ -21,43 +23,37 @@ export class ApiError extends Error { } } -function getToken(): string | null { - const token = localStorage.getItem('token') - if (!token || token.trim() === '') return null - return token -} - -function clearAuth() { - localStorage.removeItem('token') - localStorage.removeItem('refresh_token') - window.location.href = '/login' -} - let refreshPromise: Promise | null = null +/** Cached user info from /api/auth/me */ +let cachedUser: { id: number; username: string; role: string } | null = null + +export function getCachedUser() { + return cachedUser +} + +export function clearCachedUser() { + cachedUser = null +} + async function tryRefresh(): Promise { - // Coalesce concurrent refresh attempts if (refreshPromise) return refreshPromise refreshPromise = (async () => { - const refreshToken = localStorage.getItem('refresh_token') - if (!refreshToken || refreshToken.trim() === '') return false - try { const response = await fetch(`${API_BASE}/api/auth/refresh`, { method: 'POST', - headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify({ refresh_token: refreshToken }), + credentials: 'same-origin', }) if (!response.ok) return false const result = await response.json() - if (!result.success || !result.data?.access_token) return false + if (!result.success) return false - localStorage.setItem('token', result.data.access_token) - if (result.data.refresh_token) { - localStorage.setItem('refresh_token', result.data.refresh_token) + // Update cached user from refresh response + if (result.data?.user) { + cachedUser = result.data.user } return true } catch { @@ -74,13 +70,8 @@ async function request( path: string, options: RequestInit = {}, ): Promise { - const token = getToken() const headers = new Headers(options.headers || {}) - if (token) { - headers.set('Authorization', `Bearer ${token}`) - } - if (options.body && typeof options.body === 'string') { headers.set('Content-Type', 
'application/json') } @@ -88,18 +79,21 @@ async function request( const response = await fetch(`${API_BASE}${path}`, { ...options, headers, + credentials: 'same-origin', }) // Handle 401 - try refresh before giving up if (response.status === 401) { const refreshed = await tryRefresh() if (refreshed) { - // Retry the original request with new token - const newToken = getToken() - headers.set('Authorization', `Bearer ${newToken}`) - const retryResponse = await fetch(`${API_BASE}${path}`, { ...options, headers }) + const retryResponse = await fetch(`${API_BASE}${path}`, { + ...options, + headers, + credentials: 'same-origin', + }) if (retryResponse.status === 401) { - clearAuth() + clearCachedUser() + window.location.href = '/login' throw new ApiError(401, 'UNAUTHORIZED', 'Session expired') } const retryContentType = retryResponse.headers.get('content-type') @@ -112,7 +106,8 @@ async function request( } return retryResult.data as T } - clearAuth() + clearCachedUser() + window.location.href = '/login' throw new ApiError(401, 'UNAUTHORIZED', 'Session expired') } @@ -159,11 +154,12 @@ export const api = { return request(path, { method: 'DELETE' }) }, - /** Login doesn't use the auth header */ - async login(username: string, password: string): Promise<{ access_token: string; refresh_token: string; user: { id: number; username: string; role: string } }> { + /** Login — server sets HttpOnly cookies, we only get user info back */ + async login(username: string, password: string): Promise<{ user: { id: number; username: string; role: string } }> { const response = await fetch(`${API_BASE}/api/auth/login`, { method: 'POST', headers: { 'Content-Type': 'application/json' }, + credentials: 'same-origin', body: JSON.stringify({ username, password }), }) @@ -172,12 +168,27 @@ export const api = { throw new ApiError(response.status, 'LOGIN_FAILED', result.error || 'Login failed') } - localStorage.setItem('token', result.data.access_token) - localStorage.setItem('refresh_token', 
result.data.refresh_token) + cachedUser = result.data.user return result.data }, - logout() { - clearAuth() + /** Logout — server clears cookies */ + async logout(): Promise { + try { + await fetch(`${API_BASE}/api/auth/logout`, { + method: 'POST', + credentials: 'same-origin', + }) + } catch { + // Ignore errors during logout + } + clearCachedUser() + }, + + /** Check current auth status via /api/auth/me */ + async me(): Promise<{ user: { id: number; username: string; role: string }; expires_at: string }> { + const result = await request<{ user: { id: number; username: string; role: string }; expires_at: string }>('/api/auth/me') + cachedUser = (result as { user: { id: number; username: string; role: string } }).user + return result }, } diff --git a/web/src/router/index.ts b/web/src/router/index.ts index e88d76c..a6025d1 100644 --- a/web/src/router/index.ts +++ b/web/src/router/index.ts @@ -27,40 +27,43 @@ const router = createRouter({ { path: 'plugins/print-audit', name: 'PrintAudit', component: () => import('../views/plugins/PrintAudit.vue') }, { path: 'plugins/clipboard-control', name: 'ClipboardControl', component: () => import('../views/plugins/ClipboardControl.vue') }, { path: 'plugins/plugin-control', name: 'PluginControl', component: () => import('../views/plugins/PluginControl.vue') }, + { path: 'plugins/patch', name: 'PatchManagement', component: () => import('../views/plugins/PatchManagement.vue') }, + { path: 'plugins/anomaly', name: 'AnomalyDetection', component: () => import('../views/plugins/AnomalyDetection.vue') }, ], }, ], }) -/** Check if a JWT token is structurally valid and not expired */ -function isTokenValid(token: string): boolean { - if (!token || token.trim() === '') return false - try { - const parts = token.split('.') - if (parts.length !== 3) return false - const payload = JSON.parse(atob(parts[1])) - if (!payload.exp) return false - // Reject if token expires within 30 seconds - return payload.exp * 1000 > Date.now() + 30_000 - } 
catch { - return false - } -} +/** Track whether we've already validated auth this session */ +let authChecked = false -router.beforeEach((to, _from, next) => { +router.beforeEach(async (to, _from, next) => { if (to.path === '/login') { next() return } - const token = localStorage.getItem('token') - if (!token || !isTokenValid(token)) { - localStorage.removeItem('token') - localStorage.removeItem('refresh_token') - next('/login') - } else { + // If we've already verified auth this session, allow navigation + // (cookies are sent automatically, no need to check on every route change) + if (authChecked) { next() + return + } + + // Check auth status via /api/auth/me (reads access_token cookie) + try { + const { me } = await import('../lib/api') + await me() + authChecked = true + next() + } catch { + next('/login') } }) +/** Reset auth check flag (called after logout) */ +export function resetAuthCheck() { + authChecked = false +} + export default router diff --git a/web/src/views/Dashboard.vue b/web/src/views/Dashboard.vue index dc8d3d9..4162b8e 100644 --- a/web/src/views/Dashboard.vue +++ b/web/src/views/Dashboard.vue @@ -41,6 +41,32 @@
USB事件(24h)
+
+
+ {{ healthAvg }} +
+
+
{{ healthAvg }}
+
健康评分
+
+
+ + + +
+
+ {{ healthSummary.healthy }} 健康 +
+
+ {{ healthSummary.warning }} 告警 +
+
+ {{ healthSummary.critical }} 严重 +
+
+ {{ healthSummary.unknown }} 未知 +
+
策略冲突: {{ conflictCount }} 项
@@ -138,7 +164,7 @@ diff --git a/web/src/views/Settings.vue b/web/src/views/Settings.vue index e571c6c..1bc34db 100644 --- a/web/src/views/Settings.vue +++ b/web/src/views/Settings.vue @@ -85,7 +85,7 @@ + + diff --git a/web/src/views/plugins/PatchManagement.vue b/web/src/views/plugins/PatchManagement.vue new file mode 100644 index 0000000..982d836 --- /dev/null +++ b/web/src/views/plugins/PatchManagement.vue @@ -0,0 +1,92 @@ + + + + + diff --git a/wiki/SECURITY-AUDIT.md b/wiki/SECURITY-AUDIT.md new file mode 100644 index 0000000..8fc5abd --- /dev/null +++ b/wiki/SECURITY-AUDIT.md @@ -0,0 +1,255 @@ +# CSM 安全审计报告 + +> **审计日期**: 2026-04-11 | **审计范围**: 全系统 (Server + Client + Protocol + Frontend) +> **方法论**: OWASP Top 10 (2021), CWE Top 25, 手动源码审查 + 攻击者视角分析 + +--- + +## 执行摘要 + +| 严重级别 | 数量 | 说明 | +|----------|------|------| +| **CRITICAL** | 4 | 可直接被远程利用,导致系统完全沦陷 | +| **HIGH** | 12 | 需特定条件但影响重大 | +| **MEDIUM** | 12 | 有限影响或需较高权限 | +| **LOW** | 8 | 纵深防御建议 | + +**最关键发现**: JWT Secret 硬编码在版本控制中、注册 Token 为空允许任意设备注册、凭据文件无 ACL 保护、默认无 TLS 传输加密。这四个问题组合意味着攻击者可以在数分钟内完全接管系统。 + +--- + +## CRITICAL (4) + +### AUD-001: JWT Secret 硬编码在 config.toml 中 + +- **文件**: `config.toml:12` → `crates/server/src/config.rs` +- **CWE**: CWE-798 (硬编码凭证) +- **OWASP**: A07:2021 - 安全配置错误 + +**漏洞代码**: +```toml +jwt_secret = "39ffc129-dd62-4eb4-bbc0-8bf4b8e2ccc7" +``` + +**攻击场景**: 任何能访问仓库的人可提取 secret,为任意用户(含 admin)伪造 JWT,获得系统完全管理控制权,可推送恶意配置到所有医院终端、禁用安全控制、篡改审计记录。 + +**修复**: +1. 从 `config.toml` 中移除硬编码 secret +2. 通过 `CSM_JWT_SECRET` 环境变量独占加载 +3. secret 为空时拒绝启动 +4. **立即轮换**已泄露的 secret `39ffc129-dd62-4eb4-bbc0-8bf4b8e2ccc7` +5. 
将 `config.toml` 加入 `.gitignore` + +--- + +### AUD-002: 空 Registration Token — 任意设备注册 + +- **文件**: `config.toml:1`, `crates/server/src/tcp.rs:549-558` +- **CWE**: CWE-306 (关键功能缺失认证) +- **OWASP**: A07:2021 + +**漏洞代码**: +```toml +registration_token = "" +``` +```rust +if !expected_token.is_empty() { // 空字符串直接跳过验证 +``` + +**攻击场景**: TCP 端口 9999 可达的任何攻击者可注册恶意设备,注入伪造审计数据掩盖安全事件,获取所有插件配置(含安全策略)。 + +**修复**: 设置强 registration token,空 token 时拒绝启动。 + +--- + +### AUD-003: Windows 凭据文件无 ACL 保护 + +- **文件**: `crates/client/src/main.rs:280-283` +- **CWE**: CWE-732 (关键资源权限不当) +- **OWASP**: A01:2021 + +**漏洞代码**: +```rust +#[cfg(not(unix))] +fn write_restricted_file(path: &std::path::Path, content: &str) -> std::io::Result<()> { + std::fs::write(path, content) // 无任何 ACL 设置 +} +``` + +**攻击场景**: `device_secret.txt`(HMAC 密钥)和 `device_uid.txt`(设备身份)以默认权限写入,设备上任何用户进程可读取。攻击者可提取密钥伪造心跳,在不同机器上模拟设备。 + +**修复**: 使用 `icacls` 或 `SetSecurityInfo` 设置仅 SYSTEM 可访问的 ACL。 + +--- + +### AUD-004: 默认无 TLS — 明文传输所有敏感数据 + +- **文件**: `config.toml` (无 `[server.tls]`), `crates/server/src/tcp.rs:411-414` +- **CWE**: CWE-319 (明文传输敏感信息) +- **OWASP**: A02:2021 + +**攻击场景**: TCP 端口 9999 以明文运行。`device_secret`、所有插件配置、Web 过滤规则、USB 策略、软件黑名单均以明文 JSON 传输。网络嗅探者可提取设备认证密钥并逆向工程安全策略。 + +**修复**: 生产环境强制 TLS,无 TLS 时拒绝启动(`CSM_DEV=1` 除外)。 + +--- + +## HIGH (12) + +### AUD-005: Refresh Token 未存储 — 撤销机制不完整 + +- **文件**: `crates/server/src/api/auth.rs:131-133` +- **CWE**: CWE-613 +- `refresh_tokens` 表存在但从未写入。登录不存储 token record,刷新仅检查 family 是否撤销,不验证 token 是否曾被实际颁发。无法强制注销所有 session。 + +### AUD-006: Refresh Token Family Rotation 存在 TOCTOU 竞争条件 + +- **文件**: `crates/server/src/api/auth.rs:167-183` +- **CWE**: CWE-367 +- 检查 family 撤销状态与执行撤销之间无事务保护。并发使用同一 stolen token 的两个请求均可通过检查,攻击者获得全新的未撤销 token family。 + +### AUD-007: 客户端不验证服务器身份 + +- **文件**: `crates/client/src/network/mod.rs:149-162` +- **CWE**: CWE-295 +- 客户端盲目连接任何响应的服务器。ARP/DNS 欺骗攻击者可推送恶意配置(禁用所有安全插件、注入有害规则)。HMAC 仅保护心跳,不保护配置推送。 + +### AUD-008: PowerShell 命令注入面 + +- **文件**: 
`crates/client/src/asset/mod.rs:82-83`, `crates/client/src/clipboard_control/mod.rs:143-159` +- **CWE**: CWE-78 +- `powershell_lines()` 通过 `format!()` 拼接命令参数。若服务器推送含引号/转义字符的恶意规则,可能导致 PowerShell 命令注入。 + +### AUD-009: 服务停止/卸载未受保护 + +- **文件**: `crates/client/src/service.rs:17-69` +- **CWE**: CWE-284 +- `csm-client.exe --uninstall` 无认证保护。终端管理员权限用户可完全移除安全代理,绕过所有监控。无服务恢复策略、无看门狗进程、无反调试保护。 + +### AUD-010: JWT Token 存储在 localStorage + +- **文件**: `web/src/lib/api.ts:25-28, 175-176` +- **CWE**: CWE-922 +- Access token 和 refresh token 均存储在 `localStorage`,可被同源任意 JS 访问。XSS 漏洞可直接窃取 7 天有效期的 refresh token。 + +### AUD-011: CSP 允许 unsafe-inline + unsafe-eval + +- **文件**: `crates/server/src/main.rs:142-143` +- **CWE**: CWE-693 +- `script-src 'self' 'unsafe-inline' 'unsafe-eval'` 使 CSP 对 XSS 几乎无效。结合 localStorage token 存储,单个 XSS 即可导致管理员会话完全沦陷。 + +### AUD-012: WebSocket JWT 在 URL 查询参数中 + +- **文件**: `crates/server/src/ws.rs:36-73` +- **CWE**: CWE-312 +- JWT 通过 `/ws?token=eyJ...` 传输。Token 出现在浏览器历史、服务器访问日志、代理日志中。且 WebSocket handler 不检查用户角色,非管理员可接收所有广播事件。 + +### AUD-013: 告警规则 Webhook SSRF + +- **文件**: `crates/server/src/api/alerts.rs:115-131` +- **CWE**: CWE-918 +- `notify_webhook` 字段无 URL 验证。可设置为 `http://169.254.169.254/latest/meta-data/` (AWS 元数据) 或 `file:///etc/passwd`,将服务器变成 SSRF 代理。 + +### AUD-014: 仅基于用户名的速率限制可绕过 + +- **文件**: `crates/server/src/api/auth.rs:101` +- **CWE**: CWE-307 +- 速率限制仅以用户名为 key。攻击者可用 `Admin`、`ADMIN` 等变体绕过。无 IP 限制,可分布式暴力破解。 + +### AUD-015: 磁盘加密确认在只读路由层 — 权限提升 + +- **文件**: `crates/server/src/api/plugins/mod.rs:42` +- **CWE**: CWE-862 +- PUT `acknowledge_alert` 在 `read_routes()` 中(仅需认证,不需 admin)。任何认证用户可确认(忽略)加密告警,掩盖合规违规。 + +### AUD-016: 初始管理员密码输出到 stderr + +- **文件**: `crates/server/src/main.rs:270-275` +- **CWE**: CWE-532 +- 初始密码通过 `eprintln!` 输出。容器化部署中 stderr 被日志聚合系统捕获,有日志访问权限者可获取管理员密码。 + +--- + +## MEDIUM (12) + +| # | 发现 | 文件 | CWE | +|---|------|------|-----| +| AUD-017 | 多个 Update handler 跳过输入验证 (软件黑名单/Web过滤器/剪贴板) | `software_blocker.rs:79`, `web_filter.rs:62`, 
`clipboard_control.rs:107` | CWE-20 | +| AUD-018 | USB 策略 rules 字段接受任意 JSON 无验证 | `usb.rs:94-135` | CWE-20 | +| AUD-019 | 密码无最大长度限制 (bcrypt 72 字节截断) | `auth.rs:303` | CWE-20 | +| AUD-020 | 多个字段缺少长度验证 (弹出窗口/剪贴板/USB策略名) | 多处 | CWE-20 | +| AUD-021 | 多个列表端点无分页 (黑名单/白名单/规则/策略) | 多处 | CWE-770 | +| AUD-022 | 磁盘加密状态列表无分页可全库转储 | `disk_encryption.rs:12-55` | CWE-770 | +| AUD-023 | JWT 角色仅信任 claim 不查库 (降级延迟) | `auth.rs:273` | CWE-863 | +| AUD-024 | 缺少 HSTS 头 | `main.rs:123-143` | CWE-319 | +| AUD-025 | CORS 配置需严格限制 | `main.rs:284-301` | CWE-942 | +| AUD-026 | 日志中泄露设备 UID 和服务器地址 | `main.rs:62`, `network/mod.rs:34` | CWE-532 | +| AUD-027 | 注册 Token 回退空字符串 | `main.rs:72` | CWE-254 | +| AUD-028 | conflict.rs 中 format! SQL 模式 (当前安全但脆弱) | `conflict.rs:205` | CWE-89 | + +--- + +## LOW (8) + +| # | 发现 | 文件 | +|---|------|------| +| AUD-029 | 组名未过滤 HTML 特殊字符 | `groups.rs:72-96` | +| AUD-030 | 弹出窗口阻止器更新可创建无过滤器的规则 | `popup_blocker.rs:67-104` | +| AUD-031 | 设备删除非原子 (自毁帧在事务前发送) | `devices.rs:215-306` | +| AUD-032 | 受保护进程列表硬编码且可修补绕过 | `software_blocker/mod.rs:9-73` | +| AUD-033 | hosts 文件修改可能与 EDR 冲突 | `web_filter/mod.rs:59-93` | +| AUD-034 | 软件拦截器 TOCTOU 竞争条件 (已缓解) | `software_blocker/mod.rs:329-386` | +| AUD-035 | 前端路由守卫不验证 JWT 签名 | `router/index.ts:38-49` | +| AUD-036 | WebSocket 不验证入站消息 (当前丢弃) | `ws.rs:108-119` | + +--- + +## 修复优先级 + +### P0 — 立即 (24h) + +| 修复项 | 对应发现 | 工作量 | +|--------|---------|--------| +| 轮换 JWT Secret,移至环境变量 | AUD-001 | 30min | +| 设置非空 registration_token | AUD-002 | 15min | +| 凭据文件添加 Windows ACL | AUD-003 | 1h | +| 生产环境强制 TLS | AUD-004 | 2h | + +### P1 — 短期 (1 周) + +| 修复项 | 对应发现 | 工作量 | +|--------|---------|--------| +| Refresh token 存储到 DB + 事务保护 | AUD-005, 006 | 4h | +| Update handler 添加输入验证 | AUD-017 | 4h | +| Webhook URL 验证防 SSRF | AUD-013 | 1h | +| 磁盘加密确认移至 admin 路由 | AUD-015 | 15min | +| 初始密码写入文件替代 stderr | AUD-016 | 30min | +| 添加 IP 速率限制 | AUD-014 | 2h | + +### P2 — 中期 (1 月) + +| 修复项 | 对应发现 | 工作量 | +|--------|---------|--------| +| Token 迁移至 HttpOnly Cookie 
| AUD-010 | 8h | +| CSP 强化 (nonce-based) | AUD-011 | 4h | +| WebSocket ticket 认证 | AUD-012 | 4h | +| 服务器身份验证 (证书固定) | AUD-007 | 8h | +| 服务保护 (恢复策略/看门狗) | AUD-009 | 4h | +| PowerShell 注入面消除 | AUD-008 | 6h | +| HSTS + Permissions-Policy 头 | AUD-024, 036 | 1h | +| 分页补充 | AUD-021, 022 | 4h | + +--- + +## 安全亮点 (做得好的地方) + +1. **SQL 注入防御**: 全库一致使用 `sqlx::bind()` 参数化 +2. **错误处理**: `ApiResponse::internal_error()` 不泄露内部错误详情 +3. **密码哈希**: bcrypt cost=12,符合行业标准 +4. **Token Family 轮换**: 检测 token 重放并撤销整个 family +5. **常量时间比较**: 注册 token 验证已使用 `constant_time_eq()` +6. **帧速率限制**: 100 帧/5秒/连接 +7. **审计日志**: 所有管理员写入操作记录到 `admin_audit_log` +8. **HMAC 心跳**: 设备认证使用 HMAC-SHA256 +9. **进程保护列表**: 防止误杀系统关键进程 +10. **输入验证**: Create handler 普遍有字段验证 diff --git a/wiki/client.md b/wiki/client.md new file mode 100644 index 0000000..0e1ad94 --- /dev/null +++ b/wiki/client.md @@ -0,0 +1,88 @@ +# Client(客户端代理) + +## 设计思想 + +`csm-client` 是部署在医院终端设备上的 Windows 代理程序,设计为: +1. **无人值守运行** — 支持控制台模式(开发调试)和 Windows 服务模式(生产部署) +2. **自动重连** — 指数退避策略(1s → 60s),断线后 drain stale frames +3. **插件化采集** — 每个插件独立 task,通过 `watch` channel 接收配置,通过 `mpsc` channel 上报数据 +4. 
**单入口 data channel** — 所有插件共享一个 `mpsc::channel::<Frame>(1024)`,network 模块统一发送 + +关键设计决策: +- **watch + mpsc 双通道** — `watch` 用于服务器推送配置到插件(多消费者最新值),`mpsc` 用于插件上报数据到网络层(多生产者有序队列) +- **device_uid 持久化** — UUID 首次生成后写入 `device_uid.txt`,与可执行文件同目录 +- **device_secret 持久化** — 注册成功后写入 `device_secret.txt`,重启后自动认证 + +## 代码逻辑 + +### 启动流程 + +``` +main() → load device_uid → load device_secret → create ClientState + → create data channel (mpsc 1024) + → create watch channels for each plugin config + → spawn core tasks (monitor, asset, usb) + → spawn plugin tasks (11 plugins) + → reconnect loop: connect_and_run() with exponential backoff +``` + +### 网络层 (`network/mod.rs`) + +- `connect_and_run()` — TCP 连接、注册/认证、双工读写循环 +- `handle_server_message()` — 根据 MessageType 分发服务器下发的帧到对应 watch channel +- `PluginChannels` — 持有所有插件的 `watch::Sender`,用于接收服务器推送的配置 +- 注册流程:发送 Register → 收到 RegisterResponse(含 device_secret)→ 持久化 secret +- 认证流程:已有 device_secret 时,心跳帧携带 HMAC-SHA256 签名 + +### 插件统一模板 + +每个插件遵循相同模式: +```rust +pub async fn start( + mut config_rx: watch::Receiver<Config>, + data_tx: mpsc::Sender<Frame>, + device_uid: String, +) { + loop { + tokio::select!
{ + result = config_rx.changed() => { /* 更新 config */ } + _ = interval.tick() => { + if !config.enabled { continue; } + // 采集数据 → Frame::new_json() → data_tx.send() + } + } + } +} +``` + +### 双模式运行 + +- **控制台模式**: 直接 `cargo run -p csm-client`,Ctrl+C 优雅退出 +- **服务模式**: `--install` 注册 Windows 服务、`--service` 以服务方式运行、`--uninstall` 卸载 + +## 关联模块 + +- [[protocol]] — 使用 Frame 构造上报帧,解析服务器下发帧 +- [[server]] — TCP 连接的对端,接收帧并处理 +- [[plugins]] — 每个插件的具体实现逻辑 + +## 关键文件 + +| 文件 | 职责 | +|------|------| +| `crates/client/src/main.rs` | 启动入口、插件 channel 创建、task spawn、重连循环 | +| `crates/client/src/network/mod.rs` | TCP 连接、注册认证、双工读写、服务器消息分发 | +| `crates/client/src/service.rs` | Windows 服务安装/卸载/运行(`#[cfg(target_os = "windows")]`) | +| `crates/client/src/monitor/mod.rs` | 核心设备状态采集(CPU/内存/进程) | +| `crates/client/src/asset/mod.rs` | 硬件/软件资产采集 | +| `crates/client/src/usb/mod.rs` | USB 设备插拔监控 | +| `crates/client/src/web_filter/mod.rs` | 上网拦截插件 | +| `crates/client/src/usage_timer/mod.rs` | 使用时长记录插件 | +| `crates/client/src/software_blocker/mod.rs` | 软件禁止安装插件 | +| `crates/client/src/popup_blocker/mod.rs` | 弹窗拦截插件 | +| `crates/client/src/usb_audit/mod.rs` | U盘文件操作审计插件 | +| `crates/client/src/watermark/mod.rs` | 屏幕水印插件 | +| `crates/client/src/disk_encryption/mod.rs` | 磁盘加密检测插件 | +| `crates/client/src/print_audit/mod.rs` | 打印审计插件 | +| `crates/client/src/clipboard_control/mod.rs` | 剪贴板管控插件 | +| `crates/client/src/patch/mod.rs` | 补丁管理插件 | diff --git a/wiki/database.md b/wiki/database.md new file mode 100644 index 0000000..023f7ef --- /dev/null +++ b/wiki/database.md @@ -0,0 +1,71 @@ +# Database(数据库层) + +## 设计思想 + +SQLite 单文件数据库,WAL 模式支持并发读写。设计原则: +1. **只追加迁移** — 永不修改已有 migration 文件 +2. **参数绑定** — 所有 SQL 使用 `.bind()`,绝不拼接 +3. **upsert 模式** — `ON CONFLICT ... DO UPDATE` 处理重复上报,必须更新 `updated_at` +4. **嵌入式迁移** — SQL 文件通过 `include_str!` 编译进二进制,运行时按序执行 +5. 
**外键启用** — `foreign_keys(true)` 强制引用完整性 + +## 代码逻辑 + +### 初始化 + +``` +main() → init_database() → SQLite WAL + Normal sync + 5s busy timeout + FK on + → run_migrations() → CREATE _migrations 表 → 按序执行 001-018 + → ensure_default_admin() → 首次启动生成随机 admin 密码 +``` + +### 连接池配置 + +- 最大 8 连接 +- cache_size = -64000 (64MB) +- wal_autocheckpoint = 1000 + +### 迁移历史 + +| # | 文件 | 内容 | +|---|------|------| +| 001 | init.sql | users, devices 表 | +| 002 | assets.sql | hardware_assets, software_assets, asset_changes 表 | +| 003 | usb.sql | usb_events, usb_policies, usb_device_patterns 表 | +| 004 | alerts.sql | alert_rules, alert_records 表 | +| 005 | web_filter.sql | web_filter_rules, web_access_logs 表 | +| 006 | usage_timer.sql | usage_daily, app_usage 表 | +| 007 | software_blocker.sql | software_blacklist, software_violations 表 | +| 008 | popup_blocker.sql | popup_blocker_rules, popup_block_stats 表 | +| 009 | usb_file_audit.sql | usb_file_operations 表 | +| 010 | watermark.sql | watermark_configs 表 | +| 011 | token_security.sql | token_families 表(JWT token family 轮换) | +| 012 | disk_encryption.sql | disk_encryption_status, disk_encryption_alerts 表 | +| 013 | print_audit.sql | print_events 表 | +| 014 | clipboard_control.sql | clipboard_rules, clipboard_violations 表 | +| 015 | plugin_control.sql | plugin_states 表 | +| 016 | encryption_alerts_unique.sql | 唯一约束修复 | +| 017 | device_health_scores.sql | device_health_scores 表 | +| 018 | patch_management.sql | patch_status 表 | + +### 数据操作层 (`db.rs`) + +`DeviceRepo` 提供: +- 设备注册/查询/删除/分组 +- 资产增删改查 +- USB 事件记录和策略管理 +- 告警规则和记录操作 +- 所有插件数据的 CRUD + +## 关联模块 + +- [[server]] — 通过 db.rs 访问数据库 +- [[plugins]] — 每个插件有对应的数据库表 + +## 关键文件 + +| 文件 | 职责 | +|------|------| +| `crates/server/src/db.rs` | DeviceRepo 数据库操作方法集合 | +| `crates/server/src/main.rs` | 数据库初始化、迁移执行 | +| `migrations/001_init.sql` ~ `018_*.sql` | 数据库迁移脚本 | diff --git a/wiki/index.md b/wiki/index.md new file mode 100644 index 0000000..ba8f74b --- /dev/null +++ b/wiki/index.md @@ -0,0 +1,46 
@@ +# CSM 知识库 + +## 项目画像 + +CSM (Client Security Manager) — 医院终端安全管控平台,C/S + Web 三层架构。管理 11 个安全插件,覆盖上网拦截、U盘管控、打印审计、剪贴板管控、补丁管理等场景。 + +**关键数字**: 3 个 Rust crate + Vue 前端 | 18 个数据库迁移 | 13 个客户端插件 | ~30 个 API 端点 | 自定义 TCP 二进制协议 + +## 模块导航树 + +``` +CSM +├── [[protocol]] — 二进制协议层(Frame 编解码、MessageType、payload 定义) +├── [[server]] — 服务端(HTTP API + TCP 接入 + WebSocket + SQLite) +├── [[client]] — 客户端代理(Windows 服务、插件采集、自动重连) +├── [[web-frontend]] — Web 管理面板(Vue 3 SPA) +├── [[plugins]] — 插件体系(端到端设计、新增插件清单) +└── [[database]] — 数据库层(SQLite、迁移、操作方法) +``` + +## 核心架构决策 + +### 为什么用自定义 TCP 二进制协议而不是 HTTP? +内网环境低延迟需求,二进制帧比 HTTP 更省带宽和延迟。帧头仅 10 字节(MAGIC+VERSION+TYPE+LENGTH),payload 用 JSON 保持可调试性。 + +### 为什么插件配置用 watch channel 而不是 HTTP 轮询? +Server 主动推送配置变更到 Client,避免轮询延迟。`tokio::watch` 保证每个插件总是读到最新配置值,配置下发 → 全链路秒级生效。 + +### 为什么嵌入前端而不是独立部署? +`include_dir!` 编译时打包 `web/dist/`,部署只需一个 server 二进制文件。SPA fallback 让前端路由(如 `/devices`)直接返回 `index.html`。 + +### 为什么 SQLite 而不是 PostgreSQL? +医院内网单机部署场景,零外部依赖。WAL 模式 + 64MB 缓存足以支撑数百台终端的并发写入。 + +### 为什么三级作用域推送(global/group/device)? +医院按科室分组管理设备。全局策略作为基线,科室策略覆盖特定需求,单设备策略处理例外情况。`push_to_targets()` 自动解析作用域并过滤在线设备。 + +## 技术栈速查 + +| 层 | 技术 | +|----|------| +| 服务端 | Rust + Axum + SQLx + SQLite + JWT + Rustls | +| 客户端 | Rust + Tokio + sysinfo + windows-rs | +| 协议 | 自定义 TCP 二进制(MAGIC + VERSION + TYPE + LENGTH + JSON payload) | +| 前端 | Vue 3 + TypeScript + Vite + Element Plus + Pinia + ECharts | +| 构建 | Cargo workspace + npm | diff --git a/wiki/plugins.md b/wiki/plugins.md new file mode 100644 index 0000000..943cb4f --- /dev/null +++ b/wiki/plugins.md @@ -0,0 +1,78 @@ +# Plugin System(插件体系) + +## 设计思想 + +CSM 的核心扩展机制,采用**端到端插件化**设计: +- Client 端每个插件独立 tokio task,负责数据采集/策略执行 +- Server 端每个插件有独立的 API handler 模块和数据库表 +- Protocol 层每个插件有专属 MessageType 范围和 payload struct +- Frontend 端每个插件有独立页面组件 + +三级配置推送:`global` → `group` → `device`,优先级递增。 + +## 代码逻辑 + +### 插件全链路(以 Web Filter 为例) + +``` +1. API: POST /api/plugins/web-filter/rules → server/api/plugins/web_filter.rs +2. 
Server 存储 → db.rs → INSERT INTO web_filter_rules +3. 推送 → push_to_targets(db, clients, WebFilterRuleUpdate, payload, scope, target_id) +4. TCP → Client network/mod.rs handle_server_message() → web_filter_tx.send() +5. Client → web_filter/mod.rs config_rx.changed() → 更新本地规则 → 采集上报 +6. Client → Frame::new_json(WebAccessLog, entry) → data_tx → network → TCP → server +7. Server → tcp.rs process_frame(WebAccessLog) → db.rs → INSERT INTO web_access_logs +8. Frontend → GET /api/plugins/web-filter/log → 展示 +``` + +### 现有插件一览 + +| 插件 | 消息类型范围 | 方向 | 功能 | +|------|-------------|------|------| +| Web Filter | 0x2x | S→C 规则, C→S 日志 | URL 黑白名单、访问日志 | +| Usage Timer | 0x3x | C→S 报告 | 每日使用时长、应用使用统计 | +| Software Blocker | 0x4x | S→C 黑名单, C→S 违规 | 禁止安装软件、违规上报 | +| Popup Blocker | 0x5x | S→C 规则, C→S 统计 | 弹窗拦截规则、拦截统计 | +| USB File Audit | 0x6x | C→S 记录 | U盘文件操作审计 | +| Watermark | 0x70 | S→C 配置 | 屏幕水印显示配置 | +| USB Policy | 0x71 | S→C 策略 | U盘管控(全阻/白名单/黑名单) | +| Plugin Control | 0x80-0x81 | S→C 命令 | 远程启停插件 | +| Disk Encryption | 0x90, 0x93 | C→S 状态, S→C 配置 | 磁盘加密状态检测 | +| Print Audit | 0x91 | C→S 事件 | 打印操作审计 | +| Clipboard Control | 0x94-0x95 | S→C 规则, C→S 违规 | 剪贴板操作管控(仅上报元数据) | +| Patch Management | 0xA0-0xA2 | 双向 | 系统补丁扫描与安装 | +| Behavior Metrics | 0xB0 | C→S 指标 | 行为指标采集(异常检测输入) | + +### 新增插件必改文件清单 + +| # | 文件 | 改动 | +|---|------|------| +| 1 | `crates/protocol/src/message.rs` | 添加 MessageType 枚举值 + payload struct | +| 2 | `crates/protocol/src/lib.rs` | re-export 新类型 | +| 3 | `crates/client/src/<plugin>/mod.rs` | 创建插件实现 | +| 4 | `crates/client/src/main.rs` | `mod <plugin>`, watch channel, PluginChannels 字段, spawn | +| 5 | `crates/client/src/network/mod.rs` | PluginChannels 字段, handle_server_message 分支 | +| 6 | `crates/server/src/api/plugins/<plugin>.rs` | 创建 API handler | +| 7 | `crates/server/src/api/plugins/mod.rs` | mod 声明 + 路由注册 | +| 8 | `crates/server/src/tcp.rs` | process_frame 新分支 + push_all_plugin_configs | +| 9 | `crates/server/src/db.rs` | 新增 DB 操作方法 | +| 10 | `migrations/NNN_<plugin>.sql` | 新迁移文件 | +| 11
| `crates/server/src/main.rs` | include_str! 新迁移 | + +## 关联模块 + +- [[protocol]] — 定义插件的 MessageType 和 payload +- [[client]] — 插件采集端 +- [[server]] — 插件 API 和数据处理 +- [[web-frontend]] — 插件管理页面 +- [[database]] — 每个插件的数据库表 + +## 关键文件 + +| 文件 | 职责 | +|------|------| +| `crates/client/src/<plugin>/mod.rs` | 客户端插件实现(每个插件一个目录) | +| `crates/server/src/api/plugins/<plugin>.rs` | 服务端插件 API(每个插件一个文件) | +| `crates/server/src/tcp.rs` | 帧分发 + push_to_targets + push_all_plugin_configs | +| `crates/client/src/main.rs` | 插件 watch channel 创建 + task spawn | +| `crates/client/src/network/mod.rs` | PluginChannels 定义 + 服务器消息分发 | diff --git a/wiki/protocol.md b/wiki/protocol.md new file mode 100644 index 0000000..4e8c275 --- /dev/null +++ b/wiki/protocol.md @@ -0,0 +1,58 @@ +# Protocol(二进制协议层) + +## 设计思想 + +`csm-protocol` 是 Server 和 Client 共享的协议定义 crate。核心设计决策: + +1. **零拷贝编解码** — `Frame::encode()` / `Frame::decode()` 直接操作字节切片,无中间分配 +2. **类型安全** — `MessageType` 枚举确保所有消息类型在编译期可见,`TryFrom<u8>` 处理未知类型 +3. **JSON payload** — 网络传输用 JSON(`serde`),兼顾可调试性和跨语言兼容性 +4.
**payload 上限 4MB** — `MAX_PAYLOAD_SIZE` 防止恶意帧耗尽内存 + +二进制帧格式:`MAGIC(4B "CSM\0") + VERSION(1B) + TYPE(1B) + LENGTH(4B big-endian) + PAYLOAD(变长 JSON)` + +## 代码逻辑 + +### 帧生命周期 + +``` +发送方: T → Frame::new_json(mt, &data) → Frame::encode() → Vec<u8> → TCP stream +接收方: TCP bytes → Frame::decode(&buf) → Option<Frame> → Frame::decode_payload::<T>() +``` + +### MessageType 分块规划 + +| 范围 | 插件 | 方向 | +|------|------|------| +| 0x01-0x0F | Core(心跳/注册/状态/资产) | 双向 | +| 0x10-0x1F | Core Server→Client(策略/配置/任务) | S→C | +| 0x20-0x2F | Web Filter | C→S 日志, S→C 规则 | +| 0x30-0x3F | Usage Timer | C→S 报告 | +| 0x40-0x4F | Software Blocker | C→S 违规, S→C 黑名单 | +| 0x50-0x5F | Popup Blocker | C→S 统计, S→C 规则 | +| 0x60-0x6F | USB File Audit | C→S 操作记录 | +| 0x70-0x7F | Watermark + USB Policy | S→C 配置 | +| 0x80-0x8F | Plugin Control | S→C 启停命令 | +| 0x90-0x9F | Disk Encryption / Print / Clipboard | 混合 | +| 0xA0-0xAF | Patch Management | C→S 状态, S→C 配置 | +| 0xB0-0xBF | Behavior Metrics | C→S 指标 | + +### 关键类型 + +- `Frame` — 帧结构(version + msg_type + payload bytes) +- `FrameError` — 解码错误枚举(InvalidMagic / UnknownMessageType / PayloadTooLarge / Io) +- 每个 MessageType 对应一个 payload struct(如 `WebAccessLogEntry`, `HeartbeatPayload`) + +## 关联模块 + +- [[server]] — TCP 接入层调用 `Frame::decode()` 解析客户端帧,调用 `push_to_targets()` 推送配置帧 +- [[client]] — 通过 `Frame::new_json()` 构造上报帧,通过 `Frame::decode()` 解析服务器下发的帧 +- [[plugins]] — 每个插件定义自己的 payload struct 在此 crate 中 + +## 关键文件 + +| 文件 | 职责 | +|------|------| +| `crates/protocol/src/message.rs` | MessageType 枚举、Frame 编解码、所有 payload struct | +| `crates/protocol/src/device.rs` | DeviceStatus、ProcessInfo、HardwareAsset、UsbEvent 等设备相关类型 | +| `crates/protocol/src/lib.rs` | Re-export 所有公开类型 | diff --git a/wiki/server.md b/wiki/server.md new file mode 100644 index 0000000..9c81e0c --- /dev/null +++ b/wiki/server.md @@ -0,0 +1,84 @@ +# Server(服务端) + +## 设计思想 + +`csm-server` 是整个系统的核心枢纽,同时承载三个协议: +1. **TCP 二进制协议** (端口 9999) — 接入 Client 代理 +2. **HTTP REST API** (端口 9998) — 服务 Web 面板 +3.
**WebSocket** (`/ws`) — 实时推送设备状态变更到前端 + +关键设计决策: +- **SQLite + WAL** — 单机部署零依赖,WAL 模式支持并发读写 +- **include_dir 嵌入前端** — 编译时将 `web/dist/` 打包进二进制,部署只需一个文件 +- **三层权限** — public(登录/健康检查)→ authenticated(只读)→ admin(写操作) +- **ClientRegistry** — `Arc>` 管理在线客户端的 TCP 写端,支持 `push_to_targets()` 三级作用域推送 + +## 代码逻辑 + +### 启动流程 + +``` +main() → load config → init SQLite → run migrations → ensure admin + → spawn TCP listener (9999) + → spawn alert cleanup task + → spawn health score task + → build HTTP router (9998) with CORS/security headers/SPA fallback + → axum::serve() +``` + +### TCP 接入层 (`tcp.rs`) + +- `start_tcp_server()` — 监听 TCP,每连接 spawn 一个 task +- `process_frame()` — 根据 MessageType 分发到对应 handler(需先 verify_device_uid) +- `ClientRegistry` — 线程安全的在线设备注册表,支持 `list_online()`、`send_frame()` +- `push_to_targets(db, clients, msg_type, payload, target_type, target_id)` — 三级作用域推送(global/group/device) +- 帧速率限制:100 帧/5秒/连接 +- HMAC 验证:心跳帧必须携带 HMAC-SHA256 签名,连续 3 次失败断开 +- 空闲超时:180 秒无数据断开 +- 最大并发连接:500 + +### HTTP API (`api/`) + +路由分三层: +- **public**: `/api/auth/login`, `/api/auth/refresh`, `/health` +- **authenticated** (require_auth 中间件): GET 类设备/资产/告警/插件查询 +- **admin** (require_admin + require_auth): 设备删除、策略增删改、插件配置写入 + +统一响应格式 `ApiResponse`:`{ success, data, error }`,分页默认 page=1, page_size=20, 上限 100。 + +### WebSocket (`ws.rs`) + +- `WsHub` 广播设备上线/离线/状态变更事件给所有连接的前端客户端 +- JWT 认证通过 query parameter `?token=xxx` + +### 后台任务 + +- `alert::cleanup_task()` — 定期清理过期告警 +- `health::health_score_task()` — 定期计算设备健康评分 + +## 关联模块 + +- [[protocol]] — 使用 Frame 编解码和 MessageType 分发 +- [[client]] — TCP 连接的对端 +- [[web-frontend]] — HTTP API 和 WebSocket 的消费者 +- [[plugins]] — API 层的 plugins/ 子模块处理所有插件相关路由 +- [[database]] — 数据库操作集中在 db.rs + +## 关键文件 + +| 文件 | 职责 | +|------|------| +| `crates/server/src/main.rs` | 启动入口、数据库初始化、迁移、路由组装、SPA fallback | +| `crates/server/src/tcp.rs` | TCP 监听、帧处理、ClientRegistry、push_to_targets | +| `crates/server/src/ws.rs` | WebSocket hub 广播 | +| `crates/server/src/api/mod.rs` | 
路由定义、ApiResponse 信封、Pagination | +| `crates/server/src/api/auth.rs` | JWT 登录/刷新/改密、限流、require_auth/require_admin 中间件 | +| `crates/server/src/api/devices.rs` | 设备列表/详情/状态/历史/健康评分 API | +| `crates/server/src/api/plugins/mod.rs` | 插件路由注册(read_routes + write_routes) | +| `crates/server/src/api/plugins/*.rs` | 各插件 API handler(每个插件一个文件) | +| `crates/server/src/db.rs` | DeviceRepo 数据库操作方法集合 | +| `crates/server/src/config.rs` | AppConfig TOML 配置加载 | +| `crates/server/src/health.rs` | 设备健康评分计算 | +| `crates/server/src/anomaly.rs` | 异常检测逻辑 | +| `crates/server/src/alert.rs` | 告警处理与清理 | +| `crates/server/src/audit.rs` | 审计日志 | diff --git a/wiki/web-frontend.md b/wiki/web-frontend.md new file mode 100644 index 0000000..23aef64 --- /dev/null +++ b/wiki/web-frontend.md @@ -0,0 +1,70 @@ +# Web Frontend(管理面板) + +## 设计思想 + +Vue 3 + TypeScript + Vite + Element Plus + Pinia + ECharts 的单页应用。关键决策: + +1. **SPA 嵌入部署** — 构建产物 `web/dist/` 通过 `include_dir!` 编译进 server 二进制,部署零额外依赖 +2. **JWT 本地存储** — token 存 `localStorage`,路由守卫检查过期,30 秒内即将过期视为无效 +3. **按路由懒加载** — 所有页面组件使用 `() => import(...)` 动态导入 + +## 代码逻辑 + +### 路由结构 + +``` +/login → Login.vue(公开) +/ → Layout.vue(认证后) + /dashboard → Dashboard.vue(仪表盘/健康概览) + /devices → Devices.vue(设备列表) + /devices/:uid → DeviceDetail.vue(设备详情) + /usb → UsbPolicy.vue(U盘策略管理) + /alerts → Alerts.vue(告警管理) + /settings → Settings.vue(系统设置) + /plugins/web-filter → WebFilter.vue + /plugins/usage-timer → UsageTimer.vue + /plugins/software-blocker → SoftwareBlocker.vue + /plugins/popup-blocker → PopupBlocker.vue + /plugins/usb-file-audit → UsbFileAudit.vue + /plugins/watermark → Watermark.vue + /plugins/disk-encryption → DiskEncryption.vue + /plugins/print-audit → PrintAudit.vue + /plugins/clipboard-control → ClipboardControl.vue + /plugins/plugin-control → PluginControl.vue +``` + +### 认证流程 + +1. Login.vue → `POST /api/auth/login` → 获取 access_token + refresh_token +2. token 存入 localStorage +3. 路由守卫 `beforeEach` 检查 JWT 过期(解析 payload.exp) +4. 
API 调用携带 `Authorization: Bearer <token>` header +5. token 过期 → 自动跳转 /login + +### API 通信 + +`web/src/lib/api.ts` — 封装所有 API 调用,统一处理认证和错误。 + +### 状态管理 + +`web/src/stores/devices.ts` — Pinia store 管理设备列表状态。 + +## 关联模块 + +- [[server]] — 消费其 HTTP REST API 和 WebSocket 推送 +- [[plugins]] — 每个插件页面对应 server 端的插件 API + +## 关键文件 + +| 文件 | 职责 | +|------|------| +| `web/src/main.ts` | 应用入口、Vue 实例创建 | +| `web/src/App.vue` | 根组件 | +| `web/src/router/index.ts` | 路由定义、JWT 路由守卫 | +| `web/src/lib/api.ts` | API 通信封装 | +| `web/src/stores/devices.ts` | Pinia 设备状态管理 | +| `web/src/views/Layout.vue` | 主布局(侧边栏+内容区) | +| `web/src/views/Dashboard.vue` | 仪表盘页 | +| `web/src/views/Devices.vue` | 设备列表页 | +| `web/src/views/DeviceDetail.vue` | 设备详情页 | +| `web/src/views/plugins/*.vue` | 各插件管理页面 |