feat: 新增补丁管理和异常检测插件及相关功能

feat(protocol): 添加补丁管理和行为指标协议类型
feat(client): 实现补丁管理插件采集功能
feat(server): 添加补丁管理和异常检测API
feat(database): 新增补丁状态和异常检测相关表
feat(web): 添加补丁管理和异常检测前端页面
fix(security): 增强输入验证和防注入保护
refactor(auth): 重构认证检查逻辑
perf(service): 优化Windows服务恢复策略
style: 统一健康评分显示样式
docs: 更新知识库文档
This commit is contained in:
iven
2026-04-11 15:59:53 +08:00
parent b5333d8c93
commit 60ee38a3c2
49 changed files with 3988 additions and 461 deletions

View File

@@ -1,5 +1,7 @@
# CSM — 企业终端安全管理系统
> **知识库**: @wiki/index.md — 编译后的模块化知识,新会话加载即了解全貌。
## 项目概览
CSM (Client Security Manager) 是一个医院设备终端安全管控平台,采用 C/S + Web 管理面板三层架构。

View File

@@ -121,6 +121,14 @@ fn collect_system_details() -> (Option<String>, Option<String>, Option<String>)
#[cfg(target_os = "windows")]
fn powershell_lines(command: &str) -> Vec<String> {
use std::process::Command;
// Reject commands containing suspicious patterns that could indicate injection
let lower = command.to_lowercase();
if lower.contains("invoke-expression") || lower.contains("iex ") || lower.contains("& ") {
tracing::warn!("Rejected suspicious PowerShell command pattern");
return Vec::new();
}
let output = match Command::new("powershell")
.args(["-NoProfile", "-NonInteractive", "-Command",
&format!("[Console]::OutputEncoding = [System.Text.Encoding]::UTF8; {}", command)])

View File

@@ -1,7 +1,7 @@
use anyhow::Result;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::Duration;
use tracing::{info, error, warn};
use tracing::{info, error, warn, debug};
use csm_protocol::{Frame, ClientConfig, UsbPolicyPayload};
mod monitor;
@@ -17,6 +17,7 @@ mod web_filter;
mod disk_encryption;
mod clipboard_control;
mod print_audit;
mod patch;
#[cfg(target_os = "windows")]
mod service;
@@ -58,17 +59,22 @@ fn main() -> Result<()> {
info!("CSM Client starting (console mode)...");
let device_uid = load_or_create_device_uid()?;
info!("Device UID: {}", device_uid);
debug!("Device UID: {}", device_uid);
let server_addr = std::env::var("CSM_SERVER")
.unwrap_or_else(|_| "127.0.0.1:9999".to_string());
let registration_token = std::env::var("CSM_REGISTRATION_TOKEN").unwrap_or_default();
if registration_token.is_empty() {
tracing::warn!("CSM_REGISTRATION_TOKEN not set — device registration may fail");
}
let state = ClientState {
device_uid,
server_addr,
config: ClientConfig::default(),
device_secret: load_device_secret(),
registration_token: std::env::var("CSM_REGISTRATION_TOKEN").unwrap_or_default(),
registration_token,
use_tls: std::env::var("CSM_USE_TLS").as_deref() == Ok("true"),
};
@@ -97,6 +103,7 @@ pub async fn run(state: ClientState) -> Result<()> {
let (disk_encryption_tx, disk_encryption_rx) = tokio::sync::watch::channel(disk_encryption::DiskEncryptionConfig::default());
let (print_audit_tx, print_audit_rx) = tokio::sync::watch::channel(print_audit::PrintAuditConfig::default());
let (clipboard_control_tx, clipboard_control_rx) = tokio::sync::watch::channel(clipboard_control::ClipboardControlConfig::default());
let (patch_tx, patch_rx) = tokio::sync::watch::channel(patch::PluginConfig::default());
// Build plugin channels struct
let plugins = network::PluginChannels {
@@ -110,6 +117,7 @@ pub async fn run(state: ClientState) -> Result<()> {
disk_encryption_tx,
print_audit_tx,
clipboard_control_tx,
patch_tx,
};
// Spawn core monitoring tasks
@@ -182,6 +190,12 @@ pub async fn run(state: ClientState) -> Result<()> {
clipboard_control::start(clipboard_control_rx, cc_data_tx, cc_uid).await;
});
let patch_data_tx = data_tx.clone();
let patch_uid = state.device_uid.clone();
tokio::spawn(async move {
patch::start(patch_rx, patch_data_tx, patch_uid).await;
});
// Connect to server with reconnect
let mut backoff = Duration::from_secs(1);
let max_backoff = Duration::from_secs(60);
@@ -270,5 +284,16 @@ fn write_restricted_file(path: &std::path::Path, content: &str) -> std::io::Resu
#[cfg(not(unix))]
fn write_restricted_file(path: &std::path::Path, content: &str) -> std::io::Result<()> {
// Write first, then tighten access: the ACL commands below are best-effort
// and must not prevent the file from being written.
// (Removed a stale duplicate `std::fs::write(path, content)` line left over
// from the pre-hardening version of this function.)
std::fs::write(path, content)?;
let path_str = path.to_string_lossy();
// Restrict the file ACL so only SYSTEM and Administrators can access it.
// `/inheritance:r` strips inherited ACEs; `/grant:r` replaces existing grants.
// Errors are deliberately ignored — icacls may be unavailable (e.g. in tests).
let _ = std::process::Command::new("icacls")
.args([&*path_str, "/inheritance:r", "/grant:r", "SYSTEM:(F)", "/grant:r", "Administrators:(F)"])
.output();
// Mark the file hidden to reduce casual discovery; not a security boundary.
let _ = std::process::Command::new("attrib")
.args(["+H", &*path_str])
.output();
Ok(())
}

View File

@@ -1,11 +1,13 @@
use anyhow::Result;
use std::path::PathBuf;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use tokio::net::TcpStream;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tracing::{info, debug, warn};
use csm_protocol::{Frame, MessageType, RegisterRequest, RegisterResponse, HeartbeatPayload, WatermarkConfigPayload, UsbPolicyPayload, DiskEncryptionConfigPayload};
use csm_protocol::{Frame, MessageType, RegisterRequest, RegisterResponse, HeartbeatPayload, WatermarkConfigPayload, UsbPolicyPayload, DiskEncryptionConfigPayload, PatchScanConfigPayload};
use hmac::{Hmac, Mac};
use sha2::Sha256;
use sha2::{Sha256, Digest};
use crate::ClientState;
@@ -21,6 +23,7 @@ pub struct PluginChannels {
pub disk_encryption_tx: tokio::sync::watch::Sender<crate::disk_encryption::DiskEncryptionConfig>,
pub print_audit_tx: tokio::sync::watch::Sender<crate::print_audit::PrintAuditConfig>,
pub clipboard_control_tx: tokio::sync::watch::Sender<crate::clipboard_control::ClipboardControlConfig>,
pub patch_tx: tokio::sync::watch::Sender<crate::patch::PluginConfig>,
}
/// Connect to server and run the main communication loop
@@ -30,7 +33,7 @@ pub async fn connect_and_run(
plugins: &PluginChannels,
) -> Result<()> {
let tcp_stream = TcpStream::connect(&state.server_addr).await?;
info!("TCP connected to {}", state.server_addr);
debug!("TCP connected to {}", state.server_addr);
if state.use_tls {
let tls_stream = wrap_tls(tcp_stream, &state.server_addr).await?;
@@ -40,9 +43,7 @@ pub async fn connect_and_run(
}
}
/// Wrap a TCP stream with TLS.
/// Supports custom CA certificate via CSM_TLS_CA_CERT env var (path to PEM file).
/// Supports skipping verification via CSM_TLS_SKIP_VERIFY=true (development only).
/// Wrap a TCP stream with TLS and certificate pinning.
async fn wrap_tls(stream: TcpStream, server_addr: &str) -> Result<tokio_rustls::client::TlsStream<TcpStream>> {
let mut root_store = rustls::RootCertStore::empty();
@@ -62,19 +63,38 @@ async fn wrap_tls(stream: TcpStream, server_addr: &str) -> Result<tokio_rustls::
// Always include system roots as fallback
root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
let config = if std::env::var("CSM_TLS_SKIP_VERIFY").as_deref() == Ok("true") {
warn!("TLS certificate verification DISABLED — do not use in production!");
// Check if skip-verify is allowed (only in CSM_DEV mode)
let skip_verify = std::env::var("CSM_TLS_SKIP_VERIFY").as_deref() == Ok("true")
&& std::env::var("CSM_DEV").is_ok();
let config = if skip_verify {
warn!("TLS certificate verification DISABLED — CSM_DEV mode only!");
rustls::ClientConfig::builder()
.dangerous()
.with_custom_certificate_verifier(std::sync::Arc::new(NoVerifier))
.with_custom_certificate_verifier(Arc::new(NoVerifier))
.with_no_client_auth()
} else {
// Build standard verifier with pinning wrapper
let inner = rustls::client::WebPkiServerVerifier::builder(Arc::new(root_store))
.build()
.map_err(|e| anyhow::anyhow!("Failed to build TLS verifier: {:?}", e))?;
let pin_file = pin_file_path();
let pinned_hashes = load_pinned_hashes(&pin_file);
let verifier = PinnedCertVerifier {
inner,
pin_file,
pinned_hashes: Arc::new(Mutex::new(pinned_hashes)),
};
rustls::ClientConfig::builder()
.with_root_certificates(root_store)
.dangerous()
.with_custom_certificate_verifier(Arc::new(verifier))
.with_no_client_auth()
};
let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(config));
let connector = tokio_rustls::TlsConnector::from(Arc::new(config));
// Extract hostname from server_addr (strip port)
let domain = server_addr.split(':').next().unwrap_or("localhost").to_string();
@@ -86,6 +106,131 @@ async fn wrap_tls(stream: TcpStream, server_addr: &str) -> Result<tokio_rustls::
Ok(tls_stream)
}
/// Resolve the location of the certificate pin file.
///
/// Precedence: the `CSM_TLS_PIN_FILE` env var overrides everything; on
/// Windows the default is `%PROGRAMDATA%\CSM\server_cert_pin` (falling back
/// to a relative `server_cert_pin` when `PROGRAMDATA` is unset); on other
/// platforms it is `/var/lib/csm/server_cert_pin`.
fn pin_file_path() -> PathBuf {
    match std::env::var("CSM_TLS_PIN_FILE") {
        Ok(custom) => PathBuf::from(custom),
        Err(_) if cfg!(target_os = "windows") => match std::env::var("PROGRAMDATA") {
            Ok(base) => PathBuf::from(base).join("CSM").join("server_cert_pin"),
            Err(_) => PathBuf::from("server_cert_pin"),
        },
        Err(_) => PathBuf::from("/var/lib/csm/server_cert_pin"),
    }
}
/// Load pinned certificate hashes from the pin file.
///
/// Format: one hex-encoded SHA-256 hash per line. Surrounding whitespace is
/// trimmed and blank lines are skipped. A missing or unreadable file yields
/// an empty list (first connection — nothing pinned yet).
///
/// Takes `&Path` rather than `&PathBuf` (clippy `ptr_arg`); existing callers
/// passing `&PathBuf` still compile via deref coercion.
fn load_pinned_hashes(path: &std::path::Path) -> Vec<String> {
    std::fs::read_to_string(path)
        .map(|content| {
            content
                .lines()
                .map(str::trim)
                .filter(|line| !line.is_empty())
                .map(str::to_string)
                .collect()
        })
        // First connection — no pin file yet.
        .unwrap_or_default()
}
/// Persist a single pinned hash, replacing the pin file's contents.
///
/// Creates the parent directory if needed. Failures are intentionally
/// ignored: pinning is best-effort and must never break the TLS handshake
/// that triggered the save.
///
/// Takes `&Path` rather than `&PathBuf` (clippy `ptr_arg`); existing callers
/// passing `&PathBuf` still compile via deref coercion.
fn save_pinned_hash(path: &std::path::Path, hash: &str) {
    if let Some(parent) = path.parent() {
        let _ = std::fs::create_dir_all(parent);
    }
    let _ = std::fs::write(path, format!("{}\n", hash));
}
/// Hex-encoded SHA-256 fingerprint of a DER-encoded certificate.
fn cert_fingerprint(cert: &rustls_pki_types::CertificateDer) -> String {
    // One-shot digest of the raw DER bytes, rendered as lowercase hex.
    let digest = Sha256::digest(cert.as_ref());
    hex::encode(digest)
}
/// Certificate verifier with pinning support (trust-on-first-use).
/// On first connection (no stored pin), records the certificate fingerprint.
/// On subsequent connections, verifies the fingerprint matches.
#[derive(Debug)]
struct PinnedCertVerifier {
// Standard WebPKI verifier; PKIX validation always runs before pin checks.
inner: Arc<rustls::client::WebPkiServerVerifier>,
// On-disk location of the pin file (see `pin_file_path`).
pin_file: PathBuf,
// In-memory copy of the pinned hex SHA-256 fingerprints, shared across
// handshakes so a first-connection pin is visible without re-reading disk.
pinned_hashes: Arc<Mutex<Vec<String>>>,
}
// rustls server-certificate verification with an added pinning layer.
// Standard PKIX validation is delegated to the inner WebPKI verifier; the
// pin check runs only after PKIX succeeds, so pinning narrows trust rather
// than replacing it.
impl rustls::client::danger::ServerCertVerifier for PinnedCertVerifier {
fn verify_server_cert(
&self,
end_entity: &rustls_pki_types::CertificateDer,
intermediates: &[rustls_pki_types::CertificateDer],
server_name: &rustls_pki_types::ServerName,
ocsp_response: &[u8],
now: rustls_pki_types::UnixTime,
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
// 1. Standard PKIX verification
self.inner.verify_server_cert(end_entity, intermediates, server_name, ocsp_response, now)?;
// 2. Compute certificate fingerprint
let fingerprint = cert_fingerprint(end_entity);
// 3. Check against pinned hashes
let mut pinned = self.pinned_hashes.lock().unwrap();
if pinned.is_empty() {
// First connection — record the certificate fingerprint
// (trust-on-first-use: whatever cert we see now becomes the pin).
// Slicing [..16] is safe for hex SHA-256 output (64 ASCII chars).
info!("Recording server certificate pin: {}...", &fingerprint[..16]);
save_pinned_hash(&self.pin_file, &fingerprint);
pinned.push(fingerprint);
} else if !pinned.contains(&fingerprint) {
// Pin mismatch: fail the handshake rather than proceed.
warn!("Certificate pin mismatch! Expected one of {:?}, got {}", pinned, fingerprint);
return Err(rustls::Error::General(
"Server certificate does not match pinned fingerprint. Possible MITM attack.".into(),
));
}
Ok(rustls::client::danger::ServerCertVerified::assertion())
}
// Signature checks are delegated unchanged to the inner WebPKI verifier.
fn verify_tls12_signature(
&self,
message: &[u8],
cert: &rustls_pki_types::CertificateDer,
dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
self.inner.verify_tls12_signature(message, cert, dss)
}
fn verify_tls13_signature(
&self,
message: &[u8],
cert: &rustls_pki_types::CertificateDer,
dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
self.inner.verify_tls13_signature(message, cert, dss)
}
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
self.inner.supported_verify_schemes()
}
}
/// Update the pinned certificate hash (called when receiving `TlsCertRotate`).
///
/// Appends `new_hash` to the pin file, keeping at most the two most recent
/// hashes (current + rotating) so the client accepts both certificates
/// during a rotation window. A hash that is already pinned is a no-op.
///
/// Note: this rewrites the file directly and does not refresh any live
/// `PinnedCertVerifier`'s in-memory list; the new pin takes effect when the
/// pin file is next loaded.
pub fn update_cert_pin(new_hash: &str) {
let pin_file = pin_file_path();
let mut pinned = load_pinned_hashes(&pin_file);
if pinned.iter().any(|h| h == new_hash) {
return;
}
pinned.push(new_hash.to_string());
// Keep only the last 2 hashes (current + rotating).
while pinned.len() > 2 {
pinned.remove(0);
}
// Write all hashes to file (best-effort, matching save_pinned_hash).
if let Some(parent) = pin_file.parent() {
let _ = std::fs::create_dir_all(parent);
}
let _ = std::fs::write(&pin_file, format!("{}\n", pinned.join("\n")));
// Guard the log slice: a SHA-256 hex hash is 64 chars, but never panic on
// a short/malformed hash (consistent with the TlsCertRotate handler).
info!("Updated certificate pin file with new hash: {}...", &new_hash[..16.min(new_hash.len())]);
}
/// A no-op certificate verifier for development use (CSM_TLS_SKIP_VERIFY=true).
/// Accepts any server certificate; wrap_tls only selects it when CSM_DEV is
/// also set, so it cannot be enabled in production by a single env var.
#[derive(Debug)]
struct NoVerifier;
@@ -242,7 +387,20 @@ fn handle_server_message(frame: Frame, plugins: &PluginChannels) -> Result<()> {
info!("Received policy update: {}", policy);
}
MessageType::ConfigUpdate => {
info!("Received config update");
let update: csm_protocol::ConfigUpdateType = frame.decode_payload()
.map_err(|e| anyhow::anyhow!("Invalid config update: {}", e))?;
match update {
csm_protocol::ConfigUpdateType::UpdateIntervals { heartbeat, status, asset } => {
info!("Config update: intervals heartbeat={}s status={}s asset={}s", heartbeat, status, asset);
}
csm_protocol::ConfigUpdateType::TlsCertRotate { new_cert_hash, valid_until } => {
info!("Certificate rotation: new hash={}... valid_until={}", &new_cert_hash[..16.min(new_cert_hash.len())], valid_until);
update_cert_pin(&new_cert_hash);
}
csm_protocol::ConfigUpdateType::SelfDestruct => {
warn!("Self-destruct command received (not implemented)");
}
}
}
MessageType::TaskExecute => {
warn!("Task execution requested (not yet implemented)");
@@ -276,7 +434,14 @@ fn handle_server_message(frame: Frame, plugins: &PluginChannels) -> Result<()> {
let blacklist: Vec<crate::software_blocker::BlacklistEntry> = payload.get("blacklist")
.and_then(|r| serde_json::from_value(r.clone()).ok())
.unwrap_or_default();
let config = crate::software_blocker::SoftwareBlockerConfig { enabled: true, blacklist };
let whitelist: Vec<String> = payload.get("whitelist")
.and_then(|r| serde_json::from_value(r.clone()).ok())
.unwrap_or_default();
let config = crate::software_blocker::SoftwareBlockerConfig {
enabled: true,
blacklist,
whitelist,
};
plugins.software_blocker_tx.send(config)?;
}
MessageType::PopupRules => {
@@ -322,6 +487,16 @@ fn handle_server_message(frame: Frame, plugins: &PluginChannels) -> Result<()> {
};
plugins.clipboard_control_tx.send(config)?;
}
MessageType::PatchScanConfig => {
let config: PatchScanConfigPayload = frame.decode_payload()
.map_err(|e| anyhow::anyhow!("Invalid patch scan config: {}", e))?;
info!("Received patch scan config: enabled={}, interval={}s", config.enabled, config.scan_interval_secs);
let plugin_config = crate::patch::PluginConfig {
enabled: config.enabled,
scan_interval_secs: config.scan_interval_secs,
};
plugins.patch_tx.send(plugin_config)?;
}
_ => {
debug!("Unhandled message type: {:?}", frame.msg_type);
}
@@ -351,7 +526,7 @@ fn handle_plugin_control(
}
"software_blocker" => {
if !enabled {
plugins.software_blocker_tx.send(crate::software_blocker::SoftwareBlockerConfig { enabled: false, blacklist: vec![] })?;
plugins.software_blocker_tx.send(crate::software_blocker::SoftwareBlockerConfig { enabled: false, blacklist: vec![], whitelist: vec![] })?;
}
}
"popup_blocker" => {
@@ -384,6 +559,11 @@ fn handle_plugin_control(
plugins.clipboard_control_tx.send(crate::clipboard_control::ClipboardControlConfig { enabled: false, ..Default::default() })?;
}
}
"patch" => {
if !enabled {
plugins.patch_tx.send(crate::patch::PluginConfig { enabled: false, scan_interval_secs: 43200 })?;
}
}
_ => {
warn!("Unknown plugin: {}", payload.plugin_name);
}

View File

@@ -0,0 +1,116 @@
use tokio::sync::watch;
use csm_protocol::{Frame, MessageType, PatchStatusPayload, PatchEntry};
use tracing::{debug, warn};
/// Runtime configuration for the patch-management plugin, pushed by the
/// server via `PatchScanConfig` messages.
#[derive(Debug, Clone, Default)]
pub struct PluginConfig {
/// When false, the scan loop keeps ticking but collects nothing.
pub enabled: bool,
/// Seconds between hotfix scans; 0 falls back to 43200 (12 h) in `start`.
pub scan_interval_secs: u64,
}
/// Patch-management plugin main loop.
///
/// Watches `config_rx` for server-pushed config changes and, on each scan
/// interval tick, collects the installed-hotfix list and sends it to the
/// server as a `PatchStatusReport` frame over `data_tx`. The loop exits when
/// either channel closes.
pub async fn start(
mut config_rx: watch::Receiver<PluginConfig>,
data_tx: tokio::sync::mpsc::Sender<Frame>,
device_uid: String,
) {
let mut config = config_rx.borrow_and_update().clone();
// 43200 s (12 h) is the fallback when the server pushes interval 0.
let mut interval = tokio::time::interval(std::time::Duration::from_secs(
if config.scan_interval_secs > 0 { config.scan_interval_secs } else { 43200 }
));
// A tokio interval's first tick completes immediately; consume it so the
// first scan runs one full interval after startup, not right away.
interval.tick().await;
loop {
tokio::select! {
result = config_rx.changed() => {
// Sender dropped — client is shutting down.
if result.is_err() { break; }
config = config_rx.borrow_and_update().clone();
let new_secs = if config.scan_interval_secs > 0 { config.scan_interval_secs } else { 43200 };
// Rebuild the interval so the new period takes effect, and again
// swallow the immediate first tick.
interval = tokio::time::interval(std::time::Duration::from_secs(new_secs));
interval.tick().await;
debug!("Patch config updated: enabled={}, interval={}s", config.enabled, new_secs);
}
_ = interval.tick() => {
if !config.enabled { continue; }
match collect_patches().await {
Ok(patches) => {
if patches.is_empty() {
// Nothing to report this cycle (e.g. empty scan result).
debug!("No patches collected for device {}", device_uid);
continue;
}
debug!("Collected {} patches for device {}", patches.len(), device_uid);
let payload = PatchStatusPayload {
device_uid: device_uid.clone(),
patches,
};
if let Ok(frame) = Frame::new_json(MessageType::PatchStatusReport, &payload) {
if data_tx.send(frame).await.is_err() {
// Receiver gone — no point continuing to scan.
warn!("Failed to send patch status: channel closed");
break;
}
}
}
Err(e) => {
// Scan failures are transient (e.g. PowerShell error); retry next tick.
warn!("Patch collection failed: {}", e);
}
}
}
}
}
}
/// Collect the list of installed Windows hotfixes via PowerShell `Get-HotFix`.
///
/// Returns up to 200 entries as `PatchEntry` values. On platforms without
/// `powershell` on PATH the spawn presumably fails and the error propagates
/// to the caller, which logs a warning — TODO confirm this is the intended
/// non-Windows behavior (the function is not cfg-gated).
async fn collect_patches() -> anyhow::Result<Vec<PatchEntry>> {
// SECURITY: PowerShell command uses only hardcoded strings with no user/remote input.
// The format!() only inserts PowerShell syntax, not external data.
let output = tokio::process::Command::new("powershell")
.args([
"-NoProfile", "-NonInteractive", "-Command",
&format!(
"[Console]::OutputEncoding = [System.Text.Encoding]::UTF8; \
Get-HotFix | Select-Object -First 200 | \
ForEach-Object {{ \
[PSCustomObject]@{{ \
kb = $_.HotFixID; \
desc = $_.Description; \
installed = if ($_.InstalledOn) {{ $_.InstalledOn.ToString('yyyy-MM-dd') }} else {{ '' }} \
}} \
}} | ConvertTo-Json -Compress"
),
])
.output()
.await?;
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr);
return Err(anyhow::anyhow!("PowerShell failed: {}", stderr));
}
let stdout = String::from_utf8_lossy(&output.stdout);
let trimmed = stdout.trim();
if trimmed.is_empty() {
// No hotfixes reported — treated as success with an empty list.
return Ok(Vec::new());
}
// Handle single-item case (PowerShell returns object instead of array)
let items: Vec<serde_json::Value> = if trimmed.starts_with('[') {
serde_json::from_str(trimmed).unwrap_or_default()
} else {
serde_json::from_str(trimmed).map(|v: serde_json::Value| vec![v]).unwrap_or_default()
};
// Map the loose JSON objects into typed entries; items without a KB id
// are silently dropped.
let patches: Vec<PatchEntry> = items.iter().filter_map(|item| {
let kb = item.get("kb")?.as_str()?.to_string();
if kb.is_empty() { return None; }
let desc = item.get("desc").and_then(|v| v.as_str()).unwrap_or("");
let installed_str = item.get("installed").and_then(|v| v.as_str()).unwrap_or("");
Some(PatchEntry {
title: format!("{} - {}", kb, desc),
kb_id: kb,
severity: None, // Will be enriched server-side from known CVE data
is_installed: true,
installed_at: if installed_str.is_empty() { None } else { Some(installed_str.to_string()) },
})
}).collect();
Ok(patches)
}

View File

@@ -1,9 +1,10 @@
use std::ffi::OsString;
use std::time::Duration;
use tracing::{info, error};
use tracing::{info, error, debug};
use windows_service::define_windows_service;
use windows_service::service::{
ServiceAccess, ServiceControl, ServiceErrorControl, ServiceExitCode,
ServiceAccess, ServiceAction, ServiceActionType, ServiceControl, ServiceErrorControl,
ServiceExitCode, ServiceFailureActions, ServiceFailureResetPeriod,
ServiceInfo, ServiceStartType, ServiceState, ServiceStatus, ServiceType,
};
use windows_service::service_control_handler::{self, ServiceControlHandlerResult};
@@ -38,6 +39,30 @@ pub fn install() -> anyhow::Result<()> {
let service = manager.create_service(&service_info, ServiceAccess::CHANGE_CONFIG)?;
service.set_description(SERVICE_DESCRIPTION)?;
// Configure service recovery: restart on failure with escalating delays
let failure_actions = ServiceFailureActions {
reset_period: ServiceFailureResetPeriod::After(std::time::Duration::from_secs(60)),
reboot_msg: None,
command: None,
actions: Some(vec![
ServiceAction {
action_type: ServiceActionType::Restart,
delay: std::time::Duration::from_secs(5),
},
ServiceAction {
action_type: ServiceActionType::Restart,
delay: std::time::Duration::from_secs(30),
},
ServiceAction {
action_type: ServiceActionType::Restart,
delay: std::time::Duration::from_secs(60),
},
]),
};
if let Err(e) = service.update_failure_actions(failure_actions) {
tracing::warn!("Failed to set service recovery actions: {}", e);
}
println!("Service '{}' installed successfully.", SERVICE_NAME);
println!("Use 'sc start {}' or restart to launch.", SERVICE_NAME);
Ok(())
@@ -117,7 +142,7 @@ fn run_service_inner() -> anyhow::Result<()> {
// Build ClientState
let device_uid = crate::load_or_create_device_uid()?;
info!("Device UID: {}", device_uid);
debug!("Device UID: {}", device_uid);
let state = crate::ClientState {
device_uid,

View File

@@ -1,11 +1,13 @@
use std::collections::HashSet;
use tokio::sync::watch;
use tracing::{info, warn};
use tracing::{info, warn, debug};
use csm_protocol::{Frame, MessageType, SoftwareViolationReport};
use serde::Deserialize;
/// System-critical processes that must never be killed regardless of server rules.
/// Killing any of these would cause system instability or a BSOD.
const PROTECTED_PROCESSES: &[&str] = &[
// Windows system processes
"system",
"system idle process",
"svchost.exe",
@@ -20,8 +22,60 @@ const PROTECTED_PROCESSES: &[&str] = &[
"registry",
"smss.exe",
"conhost.exe",
"ntoskrnl.exe",
"dcomlaunch.exe",
"rundll32.exe",
"sihost.exe",
"taskeng.exe",
"wermgr.exe",
"WerFault.exe",
"fontdrvhost.exe",
"ctfmon.exe",
"SearchIndexer.exe",
"SearchHost.exe",
"RuntimeBroker.exe",
"SecurityHealthService.exe",
"SecurityHealthSystray.exe",
"MpCmdRun.exe",
"MsMpEng.exe",
"NisSrv.exe",
// Common browsers — should never be blocked unless explicitly configured with exact name
"chrome.exe",
"msedge.exe",
"firefox.exe",
"iexplore.exe",
"opera.exe",
"brave.exe",
"vivaldi.exe",
"thorium.exe",
// Development tools & IDEs
"code.exe",
"devenv.exe",
"idea64.exe",
"webstorm64.exe",
"pycharm64.exe",
"goland64.exe",
"clion64.exe",
"rider64.exe",
"datagrip64.exe",
"trae.exe",
"windsurf.exe",
"cursor.exe",
"zed.exe",
// Terminal & system tools
"cmd.exe",
"powershell.exe",
"pwsh.exe",
"WindowsTerminal.exe",
"conhost.exe",
// CSM itself
"csm-client.exe",
];
/// Cooldown period (seconds) before reporting/killing the same process again.
/// Prevents spamming violations for long-running blocked processes.
const REPORT_COOLDOWN_SECS: u64 = 300; // 5 minutes
/// Software blacklist entry from server
#[derive(Debug, Clone, Deserialize)]
pub struct BlacklistEntry {
@@ -36,6 +90,8 @@ pub struct BlacklistEntry {
pub struct SoftwareBlockerConfig {
pub enabled: bool,
pub blacklist: Vec<BlacklistEntry>,
/// Server-pushed whitelist: processes matching these patterns are never blocked.
pub whitelist: Vec<String>,
}
/// Start software blocker plugin.
@@ -47,6 +103,11 @@ pub async fn start(
) {
info!("Software blocker plugin started");
let mut config = SoftwareBlockerConfig::default();
// Track recently acted-on processes to avoid repeated kill/report spam
let mut recent_actions: HashSet<String> = HashSet::new();
let mut cooldown_interval = tokio::time::interval(std::time::Duration::from_secs(REPORT_COOLDOWN_SECS));
cooldown_interval.tick().await;
let mut scan_interval = tokio::time::interval(std::time::Duration::from_secs(10));
scan_interval.tick().await;
@@ -58,13 +119,23 @@ pub async fn start(
}
let new_config = config_rx.borrow_and_update().clone();
info!("Software blocker config updated: enabled={}, blacklist={}", new_config.enabled, new_config.blacklist.len());
// Clear cooldown cache when config changes so new rules take effect immediately
recent_actions.clear();
config = new_config;
}
_ = cooldown_interval.tick() => {
// Periodically clear the cooldown cache so we can re-check
let cleared = recent_actions.len();
recent_actions.clear();
if cleared > 0 {
info!("Software blocker cooldown cache cleared ({} entries)", cleared);
}
}
_ = scan_interval.tick() => {
if !config.enabled || config.blacklist.is_empty() {
continue;
}
scan_processes(&config.blacklist, &data_tx, &device_uid).await;
scan_processes(&config.blacklist, &config.whitelist, &data_tx, &device_uid, &mut recent_actions).await;
}
}
}
@@ -72,83 +143,148 @@ pub async fn start(
async fn scan_processes(
blacklist: &[BlacklistEntry],
whitelist: &[String],
data_tx: &tokio::sync::mpsc::Sender<Frame>,
device_uid: &str,
recent_actions: &mut HashSet<String>,
) {
let running = get_running_processes_with_pids();
for entry in blacklist {
for (process_name, pid) in &running {
if pattern_matches(&entry.name_pattern, process_name) {
// Never kill system-critical processes
if is_protected_process(process_name) {
warn!("Blacklisted match '{}' skipped — system-critical process (pid={})", process_name, pid);
continue;
}
warn!("Blacklisted software detected: {} (action: {})", process_name, entry.action);
// Report violation to server
// Map action to DB-compatible values: "block" -> "blocked_install", "alert" -> "alerted"
let action_taken = match entry.action.as_str() {
"block" => "blocked_install",
"alert" => "alerted",
other => other,
};
let violation = SoftwareViolationReport {
device_uid: device_uid.to_string(),
software_name: process_name.clone(),
action_taken: action_taken.to_string(),
timestamp: chrono::Utc::now().to_rfc3339(),
};
if let Ok(frame) = Frame::new_json(MessageType::SoftwareViolation, &violation) {
let _ = data_tx.send(frame).await;
}
// Kill the process directly by captured PID (avoids TOCTOU race)
if entry.action == "block" {
kill_process_by_pid(*pid, process_name);
}
if !pattern_matches(&entry.name_pattern, process_name) {
continue;
}
// Never kill protected processes (system, browsers, IDEs, etc.)
if is_protected_process(process_name) {
warn!(
"Blacklisted match '{}' skipped — protected process (pid={})",
process_name, pid
);
continue;
}
// Check server-pushed whitelist (takes precedence over blacklist)
if is_whitelisted(process_name, whitelist) {
debug!(
"Blacklisted match '{}' skipped — whitelisted (pid={})",
process_name, pid
);
continue;
}
// Skip if already acted on recently (cooldown)
let action_key = format!("{}:{}", process_name.to_lowercase(), pid);
if recent_actions.contains(&action_key) {
continue;
}
warn!("Blacklisted software detected: {} (action: {})", process_name, entry.action);
// Report violation to server
let action_taken = match entry.action.as_str() {
"block" => "blocked_install",
"alert" => "alerted",
other => other,
};
let violation = SoftwareViolationReport {
device_uid: device_uid.to_string(),
software_name: process_name.clone(),
action_taken: action_taken.to_string(),
timestamp: chrono::Utc::now().to_rfc3339(),
};
if let Ok(frame) = Frame::new_json(MessageType::SoftwareViolation, &violation) {
let _ = data_tx.send(frame).await;
}
// Kill the process if action is "block"
if entry.action == "block" {
kill_process_by_pid(*pid, process_name);
}
// Mark as recently acted-on
recent_actions.insert(action_key);
}
}
}
/// Match a blacklist pattern against a process name.
///
/// **Matching rules (case-insensitive)**:
/// - No wildcard → exact filename match only (e.g. `chrome.exe` matches `chrome.exe` but NOT `new_chrome.exe`)
/// - `*` wildcard → glob-style pattern match (e.g. `*miner*` matches `bitcoin_miner.exe`)
/// - Pattern with `.exe` suffix → match against full process name
/// - Pattern without extension → match against stem (name without `.exe`)
///
/// **IMPORTANT**: We intentionally do NOT use substring matching (`contains()`) for
/// non-wildcard patterns. Substring matching caused false positives where a pattern
/// like "game" would match "game_bar.exe" or even "svchost.exe" in edge cases.
fn pattern_matches(pattern: &str, name: &str) -> bool {
let pattern_lower = pattern.to_lowercase();
let name_lower = name.to_lowercase();
// Support wildcard patterns
if pattern_lower.contains('*') {
let parts: Vec<&str> = pattern_lower.split('*').collect();
let mut pos = 0;
for (i, part) in parts.iter().enumerate() {
if part.is_empty() {
continue;
}
if i == 0 && !parts[0].is_empty() {
// Pattern starts with literal → must match at start
if !name_lower.starts_with(part) {
return false;
}
pos = part.len();
} else {
match name_lower[pos..].find(part) {
Some(idx) => pos += idx + part.len(),
None => return false,
}
}
}
// If pattern ends with literal (no trailing *), must match at end
if !parts.last().map_or(true, |p| p.is_empty()) {
return name_lower.ends_with(parts.last().unwrap());
}
true
// Extract the stem (filename without extension) for extension-less patterns
let name_stem = name_lower.strip_suffix(".exe").unwrap_or(&name_lower);
if pattern_lower.contains('*') {
// Wildcard glob matching
glob_matches(&pattern_lower, &name_lower, &name_stem)
} else {
name_lower.contains(&pattern_lower)
// Exact match: compare against full name OR stem (if pattern has no extension)
let pattern_stem = pattern_lower.strip_suffix(".exe").unwrap_or(&pattern_lower);
name_lower == pattern_lower || name_stem == pattern_stem
}
}
/// Glob-style pattern matching with `*` wildcards.
/// Checks both the full name and the stem (without .exe).
fn glob_matches(pattern: &str, full_name: &str, stem: &str) -> bool {
// Try matching against both the full process name and the stem
glob_matches_single(pattern, full_name) || glob_matches_single(pattern, stem)
}
/// Match `text` against a `*`-wildcard `pattern`.
///
/// The pattern is split on `*` into literal segments: a leading segment must
/// anchor at the start, a trailing segment at the end, and interior segments
/// must occur left-to-right without overlapping. A pattern with no `*` is an
/// exact match; a trailing `*` accepts any remainder.
fn glob_matches_single(pattern: &str, text: &str) -> bool {
    let segments: Vec<&str> = pattern.split('*').collect();
    if segments.is_empty() {
        return true;
    }
    let last = segments.len() - 1;
    // Byte offset in `text` up to which earlier segments have been consumed.
    let mut cursor = 0usize;
    for (idx, segment) in segments.iter().enumerate() {
        if segment.is_empty() {
            // Adjacent/leading/trailing '*' produce empty segments — skip.
            continue;
        }
        if idx == 0 {
            // Leading literal: the text must begin with it.
            if !text.starts_with(segment) {
                return false;
            }
            cursor = segment.len();
        } else if idx == last {
            // Trailing literal: text must end with it, and that suffix must
            // not overlap anything already consumed.
            return text.ends_with(segment) && text.len() - segment.len() >= cursor;
        } else {
            // Interior literal: must occur somewhere at or after the cursor.
            match text[cursor..].find(segment) {
                Some(at) => cursor += at + segment.len(),
                None => return false,
            }
        }
    }
    // No trailing literal ended the loop: a trailing '*' accepts anything
    // left over; otherwise the segments must have consumed the whole text.
    pattern.ends_with('*') || cursor == text.len()
}
/// Get all running processes with their PIDs (single snapshot, no TOCTOU)
fn get_running_processes_with_pids() -> Vec<(String, u32)> {
#[cfg(target_os = "windows")]
@@ -253,3 +389,12 @@ fn is_protected_process(name: &str) -> bool {
let lower = name.to_lowercase();
PROTECTED_PROCESSES.iter().any(|p| lower == **p || lower.ends_with(&format!("/{}", p).replace('/', "\\")))
}
/// Check if a process name matches any server-pushed whitelist pattern.
/// Whitelist patterns use the same matching logic as blacklist (exact or glob).
fn is_whitelisted(process_name: &str, whitelist: &[String]) -> bool {
    // `any` on an empty slice already returns false, so no explicit
    // is_empty() guard is needed.
    whitelist.iter().any(|pattern| pattern_matches(pattern, process_name))
}

View File

@@ -29,4 +29,6 @@ pub use message::{
PrintEventPayload,
ClipboardRulesPayload, ClipboardRule, ClipboardViolationPayload,
PopupBlockStatsPayload, PopupRuleStat,
PatchStatusPayload, PatchEntry, PatchScanConfigPayload,
BehaviorMetricsPayload,
};

View File

@@ -71,6 +71,14 @@ pub enum MessageType {
// Plugin: Clipboard Control (剪贴板管控)
ClipboardRules = 0x94,
ClipboardViolation = 0x95,
// Plugin: Patch Management (补丁管理)
PatchStatusReport = 0xA0,
PatchScanConfig = 0xA1,
PatchInstallCommand = 0xA2,
// Plugin: Behavior Metrics (行为指标)
BehaviorMetricsReport = 0xB0,
}
impl TryFrom<u8> for MessageType {
@@ -108,6 +116,10 @@ impl TryFrom<u8> for MessageType {
0x91 => Ok(Self::PrintEvent),
0x94 => Ok(Self::ClipboardRules),
0x95 => Ok(Self::ClipboardViolation),
0xA0 => Ok(Self::PatchStatusReport),
0xA1 => Ok(Self::PatchScanConfig),
0xA2 => Ok(Self::PatchInstallCommand),
0xB0 => Ok(Self::BehaviorMetricsReport),
_ => Err(format!("Unknown message type: 0x{:02X}", value)),
}
}
@@ -264,7 +276,7 @@ pub struct TaskExecutePayload {
#[derive(Debug, Serialize, Deserialize)]
pub enum ConfigUpdateType {
UpdateIntervals { heartbeat: u64, status: u64, asset: u64 },
TlsCertRotate,
TlsCertRotate { new_cert_hash: String, valid_until: String },
SelfDestruct,
}
@@ -442,6 +454,44 @@ pub struct PopupRuleStat {
pub hits: u32,
}
/// Plugin: Patch Status Report (Client → Server)
///
/// Sent by the client's patch-management plugin with the result of the
/// latest hotfix scan.
#[derive(Debug, Serialize, Deserialize)]
pub struct PatchStatusPayload {
    /// Unique identifier of the reporting device.
    pub device_uid: String,
    /// One entry per patch/hotfix discovered by the scan.
    pub patches: Vec<PatchEntry>,
}
/// Information about a single patch/hotfix.
#[derive(Debug, Serialize, Deserialize)]
pub struct PatchEntry {
    /// Knowledge-base / patch identifier reported by the client.
    pub kb_id: String,
    /// Human-readable patch title.
    pub title: String,
    /// Vendor severity rating, when known.
    pub severity: Option<String>, // "Critical" | "Important" | "Moderate" | "Low"
    /// Whether the patch is already installed on the device.
    pub is_installed: bool,
    /// Installation time, when installed and known.
    /// NOTE(review): timestamp format not visible here — confirm against producer.
    pub installed_at: Option<String>,
}
/// Plugin: Patch Scan Config (Server → Client)
///
/// Controls the client's periodic patch scanning.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct PatchScanConfigPayload {
    /// Whether patch scanning is enabled on the client.
    pub enabled: bool,
    /// Interval between scans, in seconds.
    pub scan_interval_secs: u64,
}
/// Plugin: Behavior Metrics Report (Client → Server)
/// Enhanced periodic metrics for anomaly detection.
#[derive(Debug, Serialize, Deserialize)]
pub struct BehaviorMetricsPayload {
    /// Unique identifier of the reporting device.
    pub device_uid: String,
    /// Clipboard operations observed during the reporting period.
    pub clipboard_ops_count: u32,
    /// Clipboard operations during night/off-hours.
    /// NOTE(review): the exact night window is defined client-side — confirm.
    pub clipboard_ops_night: u32,
    /// Print jobs submitted during the period.
    pub print_jobs_count: u32,
    /// File operations on USB storage during the period.
    pub usb_file_ops_count: u32,
    /// Newly started processes observed during the period.
    pub new_processes_count: u32,
    /// Length of the reporting period in seconds (denominator for rate rules).
    pub period_secs: u64,
    /// Report generation time.
    /// NOTE(review): string format not visible here — confirm against producer.
    pub timestamp: String,
}
#[cfg(test)]
mod tests {
use super::*;

View File

@@ -41,6 +41,26 @@ pub async fn cleanup_task(state: AppState) {
error!("Failed to cleanup alert records: {}", e);
}
// Cleanup old revoked token families (keep 30 days for audit)
if let Err(e) = sqlx::query(
"DELETE FROM revoked_token_families WHERE revoked_at < datetime('now', '-30 days')"
)
.execute(&state.db)
.await
{
error!("Failed to cleanup revoked token families: {}", e);
}
// Cleanup old anomaly alerts that have been handled
if let Err(e) = sqlx::query(
"DELETE FROM anomaly_alerts WHERE handled = 1 AND triggered_at < datetime('now', '-90 days')"
)
.execute(&state.db)
.await
{
error!("Failed to cleanup handled anomaly alerts: {}", e);
}
// Mark devices as offline if no heartbeat for 2 minutes
if let Err(e) = sqlx::query(
"UPDATE devices SET status = 'offline' WHERE status = 'online' AND last_heartbeat < datetime('now', '-2 minutes')"

View File

@@ -0,0 +1,170 @@
use sqlx::Row;
use tracing::{info, warn};
use csm_protocol::BehaviorMetricsPayload;
/// Check incoming behavior metrics against anomaly rules and generate alerts.
///
/// Pipeline: evaluate the built-in rule set (pure, see
/// `evaluate_anomaly_rules`), persist each triggered anomaly into
/// `anomaly_alerts`, then broadcast each one to dashboard clients over the
/// WebSocket hub. Failures are logged and never abort the remaining work.
pub async fn check_anomalies(
    pool: &sqlx::SqlitePool,
    ws_hub: &crate::ws::WsHub,
    metrics: &BehaviorMetricsPayload,
) {
    let alerts = evaluate_anomaly_rules(metrics);

    // Insert anomaly alerts (best-effort per row — one failure does not
    // prevent persisting the others).
    for alert in &alerts {
        if let Err(e) = sqlx::query(
            "INSERT INTO anomaly_alerts (device_uid, anomaly_type, severity, detail, metric_value, triggered_at) \
             VALUES (?, ?, ?, ?, ?, datetime('now'))"
        )
        .bind(&metrics.device_uid)
        .bind(alert.get("anomaly_type").and_then(|v| v.as_str()).unwrap_or("unknown"))
        .bind(alert.get("severity").and_then(|v| v.as_str()).unwrap_or("medium"))
        .bind(alert.get("detail").and_then(|v| v.as_str()).unwrap_or(""))
        .bind(alert.get("metric_value").and_then(|v| v.as_f64()).unwrap_or(0.0))
        .execute(pool)
        .await
        {
            warn!("Failed to insert anomaly alert: {}", e);
        }
    }

    // Broadcast anomaly alerts via WebSocket after persisting them.
    if !alerts.is_empty() {
        for alert in &alerts {
            ws_hub.broadcast(serde_json::json!({
                "type": "anomaly_alert",
                "device_uid": metrics.device_uid,
                "anomaly_type": alert.get("anomaly_type"),
                "severity": alert.get("severity"),
                "detail": alert.get("detail"),
            }).to_string()).await;
        }
        info!("Detected {} anomalies for device {}", alerts.len(), metrics.device_uid);
    }
}

/// Evaluate the hard-coded anomaly rules against a single metrics report.
///
/// Side-effect free so thresholds can be unit-tested without a database or
/// WebSocket hub. Returns one JSON object per triggered rule, in rule order
/// (night clipboard, USB rate, print volume, process-spawn rate).
fn evaluate_anomaly_rules(metrics: &BehaviorMetricsPayload) -> Vec<serde_json::Value> {
    let mut alerts: Vec<serde_json::Value> = Vec::new();

    // Rule 1: Night-time clipboard operations (> 10 in reporting period).
    if metrics.clipboard_ops_night > 10 {
        alerts.push(serde_json::json!({
            "anomaly_type": "night_clipboard_spike",
            "severity": "high",
            "detail": format!("非工作时间剪贴板操作异常: {}次 (阈值: 10次)", metrics.clipboard_ops_night),
            "metric_value": metrics.clipboard_ops_night,
        }));
    }

    // Rule 2: High USB file operations (> 100 per hour).
    // The period_secs guard prevents division by zero on a malformed report.
    if metrics.period_secs > 0 {
        let usb_per_hour = (metrics.usb_file_ops_count as f64 / metrics.period_secs as f64) * 3600.0;
        if usb_per_hour > 100.0 {
            alerts.push(serde_json::json!({
                "anomaly_type": "usb_file_exfiltration",
                "severity": "critical",
                "detail": format!("USB文件操作频率异常: {:.0}次/小时 (阈值: 100次/小时)", usb_per_hour),
                "metric_value": usb_per_hour,
            }));
        }
    }

    // Rule 3: High print volume (> 50 per reporting period).
    if metrics.print_jobs_count > 50 {
        alerts.push(serde_json::json!({
            "anomaly_type": "high_print_volume",
            "severity": "medium",
            "detail": format!("打印量异常: {}次 (阈值: 50次)", metrics.print_jobs_count),
            "metric_value": metrics.print_jobs_count,
        }));
    }

    // Rule 4: Excessive new processes (> 20 per hour).
    if metrics.period_secs > 0 {
        let procs_per_hour = (metrics.new_processes_count as f64 / metrics.period_secs as f64) * 3600.0;
        if procs_per_hour > 20.0 {
            alerts.push(serde_json::json!({
                "anomaly_type": "process_spawn_spike",
                "severity": "medium",
                "detail": format!("新进程启动异常: {:.0}次/小时 (阈值: 20次/小时)", procs_per_hour),
                "metric_value": procs_per_hour,
            }));
        }
    }

    alerts
}
/// Get anomaly alert summary for a device or all devices.
///
/// Returns a JSON object containing one page of alerts (newest first,
/// joined with the device hostname), plus total and unhandled counts
/// scoped to the same device filter. `page` is 1-based.
///
/// # Errors
/// Propagates database errors from the page and total queries; the
/// unhandled count falls back to 0 on error (best-effort, as before).
pub async fn get_anomaly_summary(
    pool: &sqlx::SqlitePool,
    device_uid: Option<&str>,
    page: u32,
    page_size: u32,
) -> anyhow::Result<serde_json::Value> {
    // Saturating arithmetic: a huge page number must clamp instead of
    // overflowing u32 (plain `*` would panic in debug builds).
    let offset = page.saturating_sub(1).saturating_mul(page_size);

    let rows = if let Some(uid) = device_uid {
        sqlx::query(
            "SELECT a.*, d.hostname FROM anomaly_alerts a JOIN devices d ON d.device_uid = a.device_uid \
             WHERE a.device_uid = ? ORDER BY a.triggered_at DESC LIMIT ? OFFSET ?"
        )
        .bind(uid)
        .bind(page_size)
        .bind(offset)
        .fetch_all(pool)
        .await?
    } else {
        sqlx::query(
            "SELECT a.*, d.hostname FROM anomaly_alerts a JOIN devices d ON d.device_uid = a.device_uid \
             ORDER BY a.triggered_at DESC LIMIT ? OFFSET ?"
        )
        .bind(page_size)
        .bind(offset)
        .fetch_all(pool)
        .await?
    };

    // NOTE(review): totals count anomaly_alerts directly while the page
    // query inner-joins devices; the numbers can diverge if an alert's
    // device row is missing — confirm whether that is intended.
    let total: i64 = if let Some(uid) = device_uid {
        sqlx::query_scalar("SELECT COUNT(*) FROM anomaly_alerts WHERE device_uid = ?")
            .bind(uid)
            .fetch_one(pool)
            .await?
    } else {
        sqlx::query_scalar("SELECT COUNT(*) FROM anomaly_alerts")
            .fetch_one(pool)
            .await?
    };

    let alerts: Vec<serde_json::Value> = rows.iter().map(|r| serde_json::json!({
        "id": r.get::<i64, _>("id"),
        "device_uid": r.get::<String, _>("device_uid"),
        "hostname": r.get::<String, _>("hostname"),
        "anomaly_type": r.get::<String, _>("anomaly_type"),
        "severity": r.get::<String, _>("severity"),
        "detail": r.get::<String, _>("detail"),
        "metric_value": r.get::<f64, _>("metric_value"),
        "handled": r.get::<i32, _>("handled"),
        "triggered_at": r.get::<String, _>("triggered_at"),
    })).collect();

    // Unhandled count, scoped to the same device filter as the page query.
    let unhandled: i64 = if let Some(uid) = device_uid {
        sqlx::query_scalar(
            "SELECT COUNT(*) FROM anomaly_alerts WHERE handled = 0 AND device_uid = ?"
        )
        .bind(uid)
        .fetch_one(pool)
        .await
        .unwrap_or(0)
    } else {
        sqlx::query_scalar(
            "SELECT COUNT(*) FROM anomaly_alerts WHERE handled = 0"
        )
        .fetch_one(pool)
        .await
        .unwrap_or(0)
    };

    Ok(serde_json::json!({
        "alerts": alerts,
        "total": total,
        "unhandled_count": unhandled,
        "page": page,
        "page_size": page_size,
    }))
}

View File

@@ -21,7 +21,7 @@ pub async fn list_rules(
) -> Json<ApiResponse<serde_json::Value>> {
let rows = sqlx::query(
"SELECT id, name, rule_type, condition, severity, enabled, notify_email, notify_webhook, created_at, updated_at
FROM alert_rules ORDER BY created_at DESC"
FROM alert_rules ORDER BY created_at DESC LIMIT 500"
)
.fetch_all(&state.db)
.await;
@@ -116,7 +116,26 @@ pub async fn create_rule(
State(state): State<AppState>,
Json(body): Json<CreateRuleRequest>,
) -> (StatusCode, Json<ApiResponse<serde_json::Value>>) {
// Validate rule_type
if !matches!(body.rule_type.as_str(), "device_offline" | "usb_event" | "web_access" | "software_violation" | "custom") {
return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("Invalid rule_type")));
}
// Validate severity
let severity = body.severity.unwrap_or_else(|| "medium".to_string());
if !matches!(severity.as_str(), "low" | "medium" | "high" | "critical") {
return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("Invalid severity")));
}
// Validate webhook URL (SSRF prevention)
if let Some(ref url) = body.notify_webhook {
if !url.starts_with("https://") {
return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("Webhook URL must use HTTPS")));
}
if url.len() > 2048 {
return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("Webhook URL too long")));
}
}
let result = sqlx::query(
"INSERT INTO alert_rules (name, rule_type, condition, severity, notify_email, notify_webhook)
@@ -174,6 +193,26 @@ pub async fn update_rule(
let notify_email = body.notify_email.or_else(|| existing.get::<Option<String>, _>("notify_email"));
let notify_webhook = body.notify_webhook.or_else(|| existing.get::<Option<String>, _>("notify_webhook"));
// Validate rule_type
if !matches!(rule_type.as_str(), "device_offline" | "usb_event" | "web_access" | "software_violation" | "custom") {
return Json(ApiResponse::error("Invalid rule_type"));
}
// Validate severity
if !matches!(severity.as_str(), "low" | "medium" | "high" | "critical") {
return Json(ApiResponse::error("Invalid severity"));
}
// Validate webhook URL (SSRF prevention)
if let Some(ref url) = notify_webhook {
if !url.starts_with("https://") {
return Json(ApiResponse::error("Webhook URL must use HTTPS"));
}
if url.len() > 2048 {
return Json(ApiResponse::error("Webhook URL too long"));
}
}
let result = sqlx::query(
"UPDATE alert_rules SET name = ?, rule_type = ?, condition = ?, severity = ?, enabled = ?,
notify_email = ?, notify_webhook = ?, updated_at = datetime('now') WHERE id = ?"

View File

@@ -1,4 +1,5 @@
use axum::{extract::State, Json, http::StatusCode, extract::Request, middleware::Next, response::Response};
use axum::{extract::State, Json, http::StatusCode, extract::Request, middleware::Next, response::{Response, IntoResponse}};
use axum::http::header::{SET_COOKIE, HeaderValue};
use serde::{Deserialize, Serialize};
use jsonwebtoken::{encode, decode, Header, EncodingKey, DecodingKey, Validation};
use std::sync::Arc;
@@ -28,11 +29,15 @@ pub struct LoginRequest {
#[derive(Debug, Serialize)]
pub struct LoginResponse {
pub access_token: String,
pub refresh_token: String,
pub user: UserInfo,
}
/// Response body for the current-user ("me") endpoint.
#[derive(Debug, Serialize)]
pub struct MeResponse {
    /// The authenticated user's identity.
    pub user: UserInfo,
    /// Access-token expiry as an RFC 3339 timestamp (empty string if the
    /// token's `exp` claim cannot be converted).
    pub expires_at: String,
}
#[derive(Debug, Serialize, sqlx::FromRow)]
pub struct UserInfo {
pub id: i64,
@@ -40,18 +45,68 @@ pub struct UserInfo {
pub role: String,
}
#[derive(Debug, Deserialize)]
pub struct RefreshRequest {
pub refresh_token: String,
}
#[derive(Debug, Deserialize)]
pub struct ChangePasswordRequest {
pub old_password: String,
pub new_password: String,
}
/// In-memory rate limiter for login attempts
// ---------------------------------------------------------------------------
// Cookie helpers
// ---------------------------------------------------------------------------
/// Whether the `Secure` attribute should be added to auth cookies.
/// Enabled in all environments except when the `CSM_DEV` environment
/// variable is present (local development over plain HTTP).
fn is_secure_cookies() -> bool {
    matches!(std::env::var("CSM_DEV"), Err(_))
}
/// Build the `Set-Cookie` header value for the short-lived access token.
/// Always HttpOnly + SameSite=Strict on the site root; `Secure` is added
/// unless running in dev mode.
fn access_cookie_header(token: &str, ttl_secs: u64) -> HeaderValue {
    let secure_flag = if is_secure_cookies() { "; Secure" } else { "" };
    let cookie = format!(
        "access_token={token}; HttpOnly{secure_flag}; SameSite=Strict; Path=/; Max-Age={ttl_secs}"
    );
    HeaderValue::from_str(&cookie).expect("valid cookie header")
}
/// Build the `Set-Cookie` header value for the refresh token.
/// Path-scoped to `/api/auth/refresh` so browsers send it to no other
/// endpoint; `Secure` is added unless running in dev mode.
fn refresh_cookie_header(token: &str, ttl_secs: u64) -> HeaderValue {
    let secure_flag = if is_secure_cookies() { "; Secure" } else { "" };
    let cookie = format!(
        "refresh_token={token}; HttpOnly{secure_flag}; SameSite=Strict; Path=/api/auth/refresh; Max-Age={ttl_secs}"
    );
    HeaderValue::from_str(&cookie).expect("valid cookie header")
}
/// `Set-Cookie` headers that expire both auth cookies immediately
/// (Max-Age=0), each on the same path it was originally set with.
fn clear_cookie_headers() -> Vec<HeaderValue> {
    let secure = if is_secure_cookies() { "; Secure" } else { "" };
    [("access_token", "/"), ("refresh_token", "/api/auth/refresh")]
        .iter()
        .map(|(name, path)| {
            HeaderValue::from_str(&format!(
                "{}=; HttpOnly{}; SameSite=Strict; Path={}; Max-Age=0",
                name, secure, path
            ))
            .expect("valid")
        })
        .collect()
}
/// Attach the given `Set-Cookie` headers to a response, in order.
fn with_cookies(mut response: Response, cookies: Vec<HeaderValue>) -> Response {
    let headers = response.headers_mut();
    for cookie in cookies.into_iter() {
        headers.append(SET_COOKIE, cookie);
    }
    response
}
/// Extract a cookie value by name from the raw `Cookie` header.
///
/// Returns the first cookie whose name matches exactly, or `None` when the
/// header is absent, not valid UTF-8, or contains no such cookie.
fn extract_cookie_value(headers: &axum::http::HeaderMap, name: &str) -> Option<String> {
    let cookie_header = headers.get("cookie")?.to_str().ok()?;
    // Hoisted out of the loop: the "name=" prefix is loop-invariant, so
    // allocate it once instead of once per cookie in the header.
    let prefix = format!("{}=", name);
    cookie_header
        .split(';')
        .map(str::trim)
        .find_map(|cookie| cookie.strip_prefix(prefix.as_str()).map(str::to_string))
}
// ---------------------------------------------------------------------------
// Rate limiter
// ---------------------------------------------------------------------------
#[derive(Clone, Default)]
pub struct LoginRateLimiter {
attempts: Arc<Mutex<HashMap<String, (Instant, u32)>>>,
@@ -62,28 +117,25 @@ impl LoginRateLimiter {
Self::default()
}
/// Returns true if the request should be rate-limited
pub async fn is_limited(&self, key: &str) -> bool {
let mut attempts = self.attempts.lock().await;
let now = Instant::now();
let window = std::time::Duration::from_secs(300); // 5-minute window
let window = std::time::Duration::from_secs(300);
let max_attempts = 10u32;
if let Some((first_attempt, count)) = attempts.get_mut(key) {
if now.duration_since(*first_attempt) > window {
// Window expired, reset
*first_attempt = now;
*count = 1;
false
} else if *count >= max_attempts {
true // Rate limited
true
} else {
*count += 1;
false
}
} else {
attempts.insert(key.to_string(), (now, 1));
// Cleanup old entries periodically
if attempts.len() > 1000 {
let cutoff = now - window;
attempts.retain(|_, (t, _)| *t > cutoff);
@@ -93,46 +145,67 @@ impl LoginRateLimiter {
}
}
// ---------------------------------------------------------------------------
// Endpoints
// ---------------------------------------------------------------------------
pub async fn login(
State(state): State<AppState>,
Json(req): Json<LoginRequest>,
) -> Result<(StatusCode, Json<ApiResponse<LoginResponse>>), StatusCode> {
// Rate limit check
) -> impl IntoResponse {
if state.login_limiter.is_limited(&req.username).await {
return Ok((StatusCode::TOO_MANY_REQUESTS, Json(ApiResponse::error("Too many login attempts. Try again later."))));
return (StatusCode::TOO_MANY_REQUESTS, Json(ApiResponse::<LoginResponse>::error("Too many login attempts. Try again later."))).into_response();
}
if state.login_limiter.is_limited("ip:default").await {
return (StatusCode::TOO_MANY_REQUESTS, Json(ApiResponse::<LoginResponse>::error("Too many login attempts from this location. Try again later."))).into_response();
}
let user: Option<UserInfo> = sqlx::query_as::<_, UserInfo>(
"SELECT id, username, role FROM users WHERE username = ?"
let row: Option<(UserInfo, String)> = sqlx::query_as::<_, (i64, String, String, String)>(
"SELECT id, username, role, password FROM users WHERE username = ?"
)
.bind(&req.username)
.fetch_optional(&state.db)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
.ok()
.flatten()
.map(|(id, username, role, password)| {
(UserInfo { id, username, role }, password)
});
let user = match user {
Some(u) => u,
None => return Ok((StatusCode::UNAUTHORIZED, Json(ApiResponse::error("Invalid credentials")))),
let (user, hash) = match row {
Some(r) => r,
None => {
let _ = bcrypt::verify("timing-constant-dummy", "$2b$12$aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa");
return (StatusCode::UNAUTHORIZED, Json(ApiResponse::<LoginResponse>::error("Invalid credentials"))).into_response();
}
};
let hash: String = sqlx::query_scalar::<_, String>(
"SELECT password FROM users WHERE id = ?"
)
.bind(user.id)
.fetch_one(&state.db)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
if !bcrypt::verify(&req.password, &hash).unwrap_or(false) {
return Ok((StatusCode::UNAUTHORIZED, Json(ApiResponse::error("Invalid credentials"))));
return (StatusCode::UNAUTHORIZED, Json(ApiResponse::<LoginResponse>::error("Invalid credentials"))).into_response();
}
let now = chrono::Utc::now().timestamp() as u64;
let family = uuid::Uuid::new_v4().to_string();
let access_token = create_token(&user, "access", state.config.auth.access_token_ttl_secs, now, &state.config.auth.jwt_secret, &family)?;
let refresh_token = create_token(&user, "refresh", state.config.auth.refresh_token_ttl_secs, now, &state.config.auth.jwt_secret, &family)?;
let access_token = match create_token(&user, "access", state.config.auth.access_token_ttl_secs, now, &state.config.auth.jwt_secret, &family) {
Ok(t) => t,
Err(_) => return StatusCode::INTERNAL_SERVER_ERROR.into_response(),
};
let refresh_token = match create_token(&user, "refresh", state.config.auth.refresh_token_ttl_secs, now, &state.config.auth.jwt_secret, &family) {
Ok(t) => t,
Err(_) => return StatusCode::INTERNAL_SERVER_ERROR.into_response(),
};
let refresh_expires = now + state.config.auth.refresh_token_ttl_secs;
let _ = sqlx::query(
"INSERT INTO refresh_tokens (user_id, family, expires_at) VALUES (?, ?, datetime(?, 'unixepoch'))"
)
.bind(user.id)
.bind(&family)
.bind(refresh_expires as i64)
.execute(&state.db)
.await;
// Audit log
let _ = sqlx::query(
"INSERT INTO admin_audit_log (user_id, action, detail) VALUES (?, 'login', ?)"
)
@@ -141,73 +214,262 @@ pub async fn login(
.execute(&state.db)
.await;
Ok((StatusCode::OK, Json(ApiResponse::ok(LoginResponse {
access_token,
refresh_token,
user,
}))))
let response = (StatusCode::OK, Json(ApiResponse::ok(LoginResponse { user }))).into_response();
with_cookies(response, vec![
access_cookie_header(&access_token, state.config.auth.access_token_ttl_secs),
refresh_cookie_header(&refresh_token, state.config.auth.refresh_token_ttl_secs),
])
}
pub async fn refresh(
State(state): State<AppState>,
Json(req): Json<RefreshRequest>,
) -> Result<(StatusCode, Json<ApiResponse<LoginResponse>>), StatusCode> {
let claims = decode::<Claims>(
&req.refresh_token,
headers: axum::http::HeaderMap,
) -> impl IntoResponse {
let refresh_token = match extract_cookie_value(&headers, "refresh_token") {
Some(t) => t,
None => return with_cookies(
(StatusCode::UNAUTHORIZED, Json(ApiResponse::<LoginResponse>::error("Missing refresh token"))).into_response(),
clear_cookie_headers(),
),
};
let claims = match decode::<Claims>(
&refresh_token,
&DecodingKey::from_secret(state.config.auth.jwt_secret.as_bytes()),
&Validation::default(),
)
.map_err(|_| StatusCode::UNAUTHORIZED)?;
) {
Ok(c) => c.claims,
Err(_) => return with_cookies(
(StatusCode::UNAUTHORIZED, Json(ApiResponse::<LoginResponse>::error("Invalid refresh token"))).into_response(),
clear_cookie_headers(),
),
};
if claims.claims.token_type != "refresh" {
return Ok((StatusCode::UNAUTHORIZED, Json(ApiResponse::error("Invalid token type"))));
if claims.token_type != "refresh" {
return with_cookies(
(StatusCode::UNAUTHORIZED, Json(ApiResponse::<LoginResponse>::error("Invalid token type"))).into_response(),
clear_cookie_headers(),
);
}
// Check if this refresh token family has been revoked (reuse detection)
let mut tx = match state.db.begin().await {
Ok(tx) => tx,
Err(_) => return StatusCode::INTERNAL_SERVER_ERROR.into_response(),
};
let revoked: bool = sqlx::query_scalar::<_, i64>(
"SELECT COUNT(*) FROM revoked_token_families WHERE family = ?"
)
.bind(&claims.claims.family)
.fetch_one(&state.db)
.bind(&claims.family)
.fetch_one(&mut *tx)
.await
.unwrap_or(0) > 0;
if revoked {
// Token reuse detected — revoke entire family and force re-login
tracing::warn!("Refresh token reuse detected for user {} family {}", claims.claims.sub, claims.claims.family);
tx.rollback().await.ok();
tracing::warn!("Refresh token reuse detected for user {} family {}", claims.sub, claims.family);
let _ = sqlx::query("DELETE FROM refresh_tokens WHERE user_id = ?")
.bind(claims.claims.sub)
.bind(claims.sub)
.execute(&state.db)
.await;
return Ok((StatusCode::UNAUTHORIZED, Json(ApiResponse::error("Token reuse detected. Please log in again."))));
return with_cookies(
(StatusCode::UNAUTHORIZED, Json(ApiResponse::<LoginResponse>::error("Token reuse detected. Please log in again."))).into_response(),
clear_cookie_headers(),
);
}
let family_exists: bool = sqlx::query_scalar::<_, i64>(
"SELECT COUNT(*) FROM refresh_tokens WHERE family = ? AND user_id = ?"
)
.bind(&claims.family)
.bind(claims.sub)
.fetch_one(&mut *tx)
.await
.unwrap_or(0) > 0;
if !family_exists {
tx.rollback().await.ok();
return (StatusCode::UNAUTHORIZED, Json(ApiResponse::<LoginResponse>::error("Invalid refresh token"))).into_response();
}
let user = UserInfo {
id: claims.claims.sub,
username: claims.claims.username,
role: claims.claims.role,
id: claims.sub,
username: claims.username,
role: claims.role,
};
// Rotate: new family for each refresh
let new_family = uuid::Uuid::new_v4().to_string();
let now = chrono::Utc::now().timestamp() as u64;
let access_token = create_token(&user, "access", state.config.auth.access_token_ttl_secs, now, &state.config.auth.jwt_secret, &new_family)?;
let refresh_token = create_token(&user, "refresh", state.config.auth.refresh_token_ttl_secs, now, &state.config.auth.jwt_secret, &new_family)?;
let access_token = match create_token(&user, "access", state.config.auth.access_token_ttl_secs, now, &state.config.auth.jwt_secret, &new_family) {
Ok(t) => t,
Err(_) => return StatusCode::INTERNAL_SERVER_ERROR.into_response(),
};
let refresh_token = match create_token(&user, "refresh", state.config.auth.refresh_token_ttl_secs, now, &state.config.auth.jwt_secret, &new_family) {
Ok(t) => t,
Err(_) => return StatusCode::INTERNAL_SERVER_ERROR.into_response(),
};
// Revoke old family
let _ = sqlx::query("INSERT OR IGNORE INTO revoked_token_families (family, user_id, revoked_at) VALUES (?, ?, datetime('now'))")
.bind(&claims.claims.family)
.bind(claims.claims.sub)
.execute(&state.db)
.await;
if sqlx::query("INSERT OR IGNORE INTO revoked_token_families (family, user_id, revoked_at) VALUES (?, ?, datetime('now'))")
.bind(&claims.family)
.bind(claims.sub)
.execute(&mut *tx)
.await
.is_err()
{
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
}
Ok((StatusCode::OK, Json(ApiResponse::ok(LoginResponse {
access_token,
refresh_token,
user,
}))))
let refresh_expires = now + state.config.auth.refresh_token_ttl_secs;
if sqlx::query(
"INSERT INTO refresh_tokens (user_id, family, expires_at) VALUES (?, ?, datetime(?, 'unixepoch'))"
)
.bind(user.id)
.bind(&new_family)
.bind(refresh_expires as i64)
.execute(&mut *tx)
.await
.is_err()
{
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
}
if tx.commit().await.is_err() {
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
}
let response = (StatusCode::OK, Json(ApiResponse::ok(LoginResponse { user }))).into_response();
with_cookies(response, vec![
access_cookie_header(&access_token, state.config.auth.access_token_ttl_secs),
refresh_cookie_header(&refresh_token, state.config.auth.refresh_token_ttl_secs),
])
}
/// Get current authenticated user info from access_token cookie.
pub async fn me(
State(state): State<AppState>,
headers: axum::http::HeaderMap,
) -> impl IntoResponse {
let token = match extract_cookie_value(&headers, "access_token") {
Some(t) => t,
None => return (StatusCode::UNAUTHORIZED, Json(ApiResponse::<MeResponse>::error("Not authenticated"))).into_response(),
};
let claims = match decode::<Claims>(
&token,
&DecodingKey::from_secret(state.config.auth.jwt_secret.as_bytes()),
&Validation::default(),
) {
Ok(c) => c.claims,
Err(_) => return (StatusCode::UNAUTHORIZED, Json(ApiResponse::<MeResponse>::error("Invalid token"))).into_response(),
};
if claims.token_type != "access" {
return (StatusCode::UNAUTHORIZED, Json(ApiResponse::<MeResponse>::error("Invalid token type"))).into_response();
}
let expires_at = chrono::DateTime::from_timestamp(claims.exp as i64, 0)
.map(|t| t.to_rfc3339())
.unwrap_or_default();
(StatusCode::OK, Json(ApiResponse::ok(MeResponse {
user: UserInfo {
id: claims.sub,
username: claims.username,
role: claims.role,
},
expires_at,
}))).into_response()
}
/// Logout: clear both auth cookies and revoke the user's refresh tokens.
///
/// Best-effort by design: an absent or invalid access token still yields
/// 200 with cleared cookies, and database errors are deliberately ignored.
pub async fn logout(
    State(state): State<AppState>,
    headers: axum::http::HeaderMap,
) -> impl IntoResponse {
    let decoded = extract_cookie_value(&headers, "access_token").and_then(|token| {
        decode::<Claims>(
            &token,
            &DecodingKey::from_secret(state.config.auth.jwt_secret.as_bytes()),
            &Validation::default(),
        )
        .ok()
    });

    if let Some(data) = decoded {
        let user_id = data.claims.sub;
        // Revoke every refresh token for this user.
        let _ = sqlx::query("DELETE FROM refresh_tokens WHERE user_id = ?")
            .bind(user_id)
            .execute(&state.db)
            .await;
        // Record the logout in the admin audit trail.
        let _ = sqlx::query(
            "INSERT INTO admin_audit_log (user_id, action, detail) VALUES (?, 'logout', ?)"
        )
        .bind(user_id)
        .bind(format!("User {} logged out", data.claims.username))
        .execute(&state.db)
        .await;
    }

    with_cookies(
        (StatusCode::OK, Json(ApiResponse::ok(()))).into_response(),
        clear_cookie_headers(),
    )
}
// ---------------------------------------------------------------------------
// WebSocket ticket
// ---------------------------------------------------------------------------
/// Response body for the one-time WebSocket authentication ticket endpoint.
#[derive(Debug, Serialize)]
pub struct WsTicketResponse {
    /// Opaque one-time ticket (UUID v4) to present when opening the WebSocket.
    pub ticket: String,
    /// Ticket lifetime in seconds.
    pub expires_in: u64,
}
/// Create a one-time ticket for WebSocket authentication.
/// Requires a valid access_token cookie (set by login).
pub async fn create_ws_ticket(
State(state): State<AppState>,
headers: axum::http::HeaderMap,
) -> impl IntoResponse {
let token = match extract_cookie_value(&headers, "access_token") {
Some(t) => t,
None => return (StatusCode::UNAUTHORIZED, Json(ApiResponse::<WsTicketResponse>::error("Not authenticated"))).into_response(),
};
let claims = match decode::<Claims>(
&token,
&DecodingKey::from_secret(state.config.auth.jwt_secret.as_bytes()),
&Validation::default(),
) {
Ok(c) => c.claims,
Err(_) => return (StatusCode::UNAUTHORIZED, Json(ApiResponse::<WsTicketResponse>::error("Invalid token"))).into_response(),
};
if claims.token_type != "access" {
return (StatusCode::UNAUTHORIZED, Json(ApiResponse::<WsTicketResponse>::error("Invalid token type"))).into_response();
}
let ticket = uuid::Uuid::new_v4().to_string();
let claim = crate::ws::TicketClaim {
user_id: claims.sub,
username: claims.username,
role: claims.role,
created_at: std::time::Instant::now(),
};
{
let mut tickets = state.ws_tickets.lock().await;
tickets.insert(ticket.clone(), claim);
// Cleanup expired tickets (>30s old) on every creation
tickets.retain(|_, c| c.created_at.elapsed().as_secs() < 30);
}
(StatusCode::OK, Json(ApiResponse::ok(WsTicketResponse {
ticket,
expires_in: 30,
}))).into_response()
}
// ---------------------------------------------------------------------------
// Internal helpers
// ---------------------------------------------------------------------------
fn create_token(user: &UserInfo, token_type: &str, ttl: u64, now: u64, secret: &str, family: &str) -> Result<String, StatusCode> {
let claims = Claims {
sub: user.id,
@@ -227,24 +489,17 @@ fn create_token(user: &UserInfo, token_type: &str, ttl: u64, now: u64, secret: &
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
}
/// Axum middleware: require valid JWT access token
/// Axum middleware: require valid JWT access token from cookie
pub async fn require_auth(
State(state): State<AppState>,
mut request: Request,
next: Next,
) -> Result<Response, StatusCode> {
let auth_header = request.headers()
.get("Authorization")
.and_then(|v| v.to_str().ok())
.and_then(|v| v.strip_prefix("Bearer "));
let token = match auth_header {
Some(t) => t,
None => return Err(StatusCode::UNAUTHORIZED),
};
let token = extract_cookie_value(request.headers(), "access_token")
.ok_or(StatusCode::UNAUTHORIZED)?;
let claims = decode::<Claims>(
token,
&token,
&DecodingKey::from_secret(state.config.auth.jwt_secret.as_bytes()),
&Validation::default(),
)
@@ -254,9 +509,7 @@ pub async fn require_auth(
return Err(StatusCode::UNAUTHORIZED);
}
// Inject claims into request extensions for handlers to use
request.extensions_mut().insert(claims.claims);
Ok(next.run(request).await)
}
@@ -274,7 +527,6 @@ pub async fn require_admin(
return Err(StatusCode::FORBIDDEN);
}
// Capture audit info before running handler
let method = request.method().clone();
let path = request.uri().path().to_string();
let user_id = claims.sub;
@@ -282,7 +534,6 @@ pub async fn require_admin(
let response = next.run(request).await;
// Record admin action to audit log (fire and forget — don't block response)
let status = response.status();
if status.is_success() {
let action = format!("{} {}", method, path);
@@ -308,8 +559,10 @@ pub async fn change_password(
if req.new_password.len() < 6 {
return Ok((StatusCode::BAD_REQUEST, Json(ApiResponse::error("新密码至少6位"))));
}
if req.new_password.len() > 72 {
return Ok((StatusCode::BAD_REQUEST, Json(ApiResponse::error("密码不能超过72位"))));
}
// Verify old password
let hash: String = sqlx::query_scalar::<_, String>(
"SELECT password FROM users WHERE id = ?"
)
@@ -322,7 +575,6 @@ pub async fn change_password(
return Ok((StatusCode::BAD_REQUEST, Json(ApiResponse::error("当前密码错误"))));
}
// Update password
let new_hash = bcrypt::hash(&req.new_password, 12).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
sqlx::query("UPDATE users SET password = ? WHERE id = ?")
.bind(&new_hash)
@@ -331,7 +583,6 @@ pub async fn change_password(
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
// Audit log
let _ = sqlx::query(
"INSERT INTO admin_audit_log (user_id, action, detail) VALUES (?, 'change_password', ?)"
)

View File

@@ -0,0 +1,243 @@
use axum::{extract::State, Json};
use serde::Serialize;
use sqlx::Row;
use crate::AppState;
use super::ApiResponse;
/// A detected conflict between two or more enabled policies.
#[derive(Debug, Serialize)]
pub struct PolicyConflict {
    /// Machine-readable conflict category (e.g. "usb_duplicate_policy").
    pub conflict_type: String,
    /// Conflict severity; the scanner emits values such as "high"/"critical".
    pub severity: String,
    /// Human-readable (Chinese) description of the conflict.
    pub description: String,
    /// The policy rows participating in this conflict.
    pub policies: Vec<ConflictPolicyRef>,
}
/// Reference to a single policy row involved in a conflict.
#[derive(Debug, Serialize)]
pub struct ConflictPolicyRef {
    /// Database table the policy lives in (e.g. "usb_policies").
    pub table_name: String,
    /// Primary key of the policy row.
    pub row_id: i64,
    /// Display name of the policy.
    pub name: String,
    /// What the policy targets (e.g. "group").
    pub target_type: String,
    /// Identifier of the target, when present.
    pub target_id: Option<String>,
}
/// GET /api/policies/conflicts — scan all policies for conflicts
///
/// Runs five independent checks and aggregates every conflict found into a
/// single response, together with per-severity counts:
///   1. usb_duplicate_policy       — several enabled USB policies on one group
///   2. usb_block_vs_whitelist     — all_block and whitelist USB policies on one group
///   3. web_filter_allow_vs_block  — same URL pattern + target, opposite rule_type
///   4. clipboard_allow_vs_block   — identical scope/direction/processes, opposite rule_type
///   5. plugin_disabled_with_rules — enabled rules whose owning plugin is disabled
///
/// Each check swallows its own query failure (`if let Ok`), so one missing
/// table does not abort the whole scan.
pub async fn scan_conflicts(
    State(state): State<AppState>,
) -> Json<ApiResponse<serde_json::Value>> {
    let mut conflicts: Vec<PolicyConflict> = Vec::new();
    // 1. USB: multiple enabled policies for the same target_group
    if let Ok(rows) = sqlx::query(
        "SELECT target_group, COUNT(*) as cnt, GROUP_CONCAT(id) as ids, GROUP_CONCAT(name) as names, \
         GROUP_CONCAT(policy_type) as types \
         FROM usb_policies WHERE enabled = 1 AND target_group IS NOT NULL \
         GROUP BY target_group HAVING cnt > 1"
    )
    .fetch_all(&state.db)
    .await
    {
        for row in &rows {
            let group: String = row.get("target_group");
            let ids: String = row.get("ids");
            let names: String = row.get("names");
            let types: String = row.get("types");
            // NOTE(review): GROUP_CONCAT joins with ','; a policy name that
            // itself contains a comma will misalign name_vec against id_vec.
            let id_vec: Vec<i64> = ids.split(',').filter_map(|s| s.parse().ok()).collect();
            let name_vec: Vec<&str> = names.split(',').collect();
            let type_vec: Vec<&str> = types.split(',').collect();
            conflicts.push(PolicyConflict {
                conflict_type: "usb_duplicate_policy".to_string(),
                severity: "high".to_string(),
                description: format!("分组 '{}' 同时存在 {} 条启用的USB策略 ({})", group, id_vec.len(), type_vec.join(", ")),
                policies: id_vec.iter().enumerate().map(|(i, id)| ConflictPolicyRef {
                    table_name: "usb_policies".to_string(),
                    row_id: *id,
                    // Fall back to "?" when name_vec is shorter than id_vec.
                    name: name_vec.get(i).unwrap_or(&"?").to_string(),
                    target_type: "group".to_string(),
                    target_id: Some(group.clone()),
                }).collect(),
            });
        }
    }
    // 2. USB: all_block + whitelist for same target (contradictory)
    // Self-join with a.id < b.id so each conflicting pair is reported once.
    if let Ok(rows) = sqlx::query(
        "SELECT a.id as aid, a.name as aname, a.target_group as agroup, \
         b.id as bid, b.name as bname \
         FROM usb_policies a JOIN usb_policies b ON a.target_group = b.target_group AND a.id < b.id \
         WHERE a.enabled = 1 AND b.enabled = 1 \
         AND ((a.policy_type = 'all_block' AND b.policy_type = 'whitelist') OR \
         (a.policy_type = 'whitelist' AND b.policy_type = 'all_block'))"
    )
    .fetch_all(&state.db)
    .await
    {
        for row in &rows {
            let group: Option<String> = row.get("agroup");
            conflicts.push(PolicyConflict {
                conflict_type: "usb_block_vs_whitelist".to_string(),
                severity: "critical".to_string(),
                description: format!("分组 '{}' 同时存在全封堵和白名单USB策略互斥", group.as_deref().unwrap_or("?")),
                policies: vec![
                    ConflictPolicyRef {
                        table_name: "usb_policies".to_string(),
                        row_id: row.get("aid"),
                        name: row.get("aname"),
                        target_type: "group".to_string(),
                        target_id: group.clone(),
                    },
                    ConflictPolicyRef {
                        table_name: "usb_policies".to_string(),
                        row_id: row.get("bid"),
                        name: row.get("bname"),
                        target_type: "group".to_string(),
                        target_id: group,
                    },
                ],
            });
        }
    }
    // 3. Web filter: same target, same pattern, different rule_type (allow vs block)
    // COALESCE(target_id,'') lets two NULL target_ids compare equal in SQLite.
    if let Ok(rows) = sqlx::query(
        "SELECT a.id as aid, a.pattern as apattern, a.rule_type as artype, \
         b.id as bid, b.rule_type as brtype, a.target_type as ttype, a.target_id as tid \
         FROM web_filter_rules a JOIN web_filter_rules b ON a.pattern = b.pattern AND a.id < b.id \
         WHERE a.enabled = 1 AND b.enabled = 1 \
         AND a.target_type = b.target_type AND COALESCE(a.target_id,'') = COALESCE(b.target_id,'') \
         AND a.rule_type != b.rule_type"
    )
    .fetch_all(&state.db)
    .await
    {
        for row in &rows {
            let pattern: String = row.get("apattern");
            let artype: String = row.get("artype");
            let brtype: String = row.get("brtype");
            let ttype: String = row.get("ttype");
            let tid: Option<String> = row.get("tid");
            conflicts.push(PolicyConflict {
                conflict_type: "web_filter_allow_vs_block".to_string(),
                severity: "high".to_string(),
                description: format!("URL '{}' 同时被 {} 和 {},互斥", pattern, artype, brtype),
                policies: vec![
                    ConflictPolicyRef {
                        table_name: "web_filter_rules".to_string(),
                        row_id: row.get("aid"),
                        name: format!("{}: {}", artype, pattern),
                        target_type: ttype.clone(),
                        target_id: tid.clone(),
                    },
                    ConflictPolicyRef {
                        table_name: "web_filter_rules".to_string(),
                        row_id: row.get("bid"),
                        name: format!("{}: {}", brtype, pattern),
                        target_type: ttype,
                        target_id: tid,
                    },
                ],
            });
        }
    }
    // 4. Clipboard: same source/target process, allow vs block
    // A pair conflicts only when target, direction and both process filters
    // all match (NULL treated as '' via COALESCE) but rule_type differs.
    if let Ok(rows) = sqlx::query(
        "SELECT a.id as aid, a.rule_type as artype, a.source_process as asrc, a.target_process as adst, \
         b.id as bid, b.rule_type as brtype, a.target_type as ttype, a.target_id as tid \
         FROM clipboard_rules a JOIN clipboard_rules b ON a.id < b.id \
         WHERE a.enabled = 1 AND b.enabled = 1 \
         AND a.target_type = b.target_type AND COALESCE(a.target_id,'') = COALESCE(b.target_id,'') \
         AND a.direction = b.direction \
         AND COALESCE(a.source_process,'') = COALESCE(b.source_process,'') \
         AND COALESCE(a.target_process,'') = COALESCE(b.target_process,'') \
         AND a.rule_type != b.rule_type"
    )
    .fetch_all(&state.db)
    .await
    {
        for row in &rows {
            let asrc: Option<String> = row.get("asrc");
            let adst: Option<String> = row.get("adst");
            let artype: String = row.get("artype");
            let brtype: String = row.get("brtype");
            // "*" stands in for an unset (wildcard) process filter.
            let desc = format!(
                "剪贴板规则冲突: {}{} 同时存在 {}{}",
                asrc.as_deref().unwrap_or("*"),
                adst.as_deref().unwrap_or("*"),
                artype, brtype,
            );
            let ttype: String = row.get("ttype");
            let tid: Option<String> = row.get("tid");
            conflicts.push(PolicyConflict {
                conflict_type: "clipboard_allow_vs_block".to_string(),
                severity: "medium".to_string(),
                description: desc,
                policies: vec![
                    ConflictPolicyRef {
                        table_name: "clipboard_rules".to_string(),
                        row_id: row.get("aid"),
                        name: format!("{}: {}", artype, asrc.as_deref().unwrap_or("*")),
                        target_type: ttype.clone(),
                        target_id: tid.clone(),
                    },
                    ConflictPolicyRef {
                        table_name: "clipboard_rules".to_string(),
                        row_id: row.get("bid"),
                        name: format!("{}: {}", brtype, asrc.as_deref().unwrap_or("*")),
                        target_type: ttype,
                        target_id: tid,
                    },
                ],
            });
        }
    }
    // 5. Plugin disabled but has active rules
    // Tuples: (rule table — unused, UI label, plugin_state name, count query).
    let plugin_tables: [(&str, &str, &str, &str); 4] = [
        ("web_filter_rules", "上网行为过滤", "web_filter", "SELECT COUNT(*) FROM web_filter_rules WHERE enabled = 1"),
        ("software_blacklist", "软件黑名单", "software_blocker", "SELECT COUNT(*) FROM software_blacklist WHERE enabled = 1"),
        ("popup_filter_rules", "弹窗拦截", "popup_blocker", "SELECT COUNT(*) FROM popup_filter_rules WHERE enabled = 1"),
        ("clipboard_rules", "剪贴板管控", "clipboard_control", "SELECT COUNT(*) FROM clipboard_rules WHERE enabled = 1"),
    ];
    for (_table, label, plugin, query) in &plugin_tables {
        let active_count: i64 = sqlx::query_scalar(query)
            .fetch_one(&state.db)
            .await
            .unwrap_or(0);
        if active_count > 0 {
            // Plugin counts as disabled only if an explicit enabled=0 row exists.
            let disabled: bool = sqlx::query_scalar::<_, i32>(
                "SELECT COUNT(*) FROM plugin_state WHERE plugin_name = ? AND enabled = 0"
            )
            .bind(plugin)
            .fetch_one(&state.db)
            .await
            .unwrap_or(0) > 0;
            if disabled {
                conflicts.push(PolicyConflict {
                    conflict_type: "plugin_disabled_with_rules".to_string(),
                    severity: "low".to_string(),
                    description: format!("插件 '{}' 已禁用,但仍有 {} 条启用规则,规则不会生效", label, active_count),
                    // row_id 0: this conflict points at the plugin, not a rule row.
                    policies: vec![ConflictPolicyRef {
                        table_name: "plugin_state".to_string(),
                        row_id: 0,
                        name: plugin.to_string(),
                        target_type: "global".to_string(),
                        target_id: None,
                    }],
                });
            }
        }
    }
    Json(ApiResponse::ok(serde_json::json!({
        "conflicts": conflicts,
        "total": conflicts.len(),
        "critical_count": conflicts.iter().filter(|c| c.severity == "critical").count(),
        "high_count": conflicts.iter().filter(|c| c.severity == "high").count(),
        "medium_count": conflicts.iter().filter(|c| c.severity == "medium").count(),
        "low_count": conflicts.iter().filter(|c| c.severity == "low").count(),
    })))
}

View File

@@ -4,6 +4,28 @@ use sqlx::Row;
use crate::AppState;
use super::{ApiResponse, Pagination};
/// GET /api/devices/:uid/health-score
pub async fn get_health_score(
State(state): State<AppState>,
Path(uid): Path<String>,
) -> Json<ApiResponse<serde_json::Value>> {
match crate::health::get_device_score(&state.db, &uid).await {
Ok(Some(score)) => Json(ApiResponse::ok(score)),
Ok(None) => Json(ApiResponse::error("No health score available")),
Err(e) => Json(ApiResponse::internal_error("health score", e)),
}
}
/// GET /api/dashboard/health-overview
///
/// Thin wrapper over `crate::health::get_health_overview`; wraps the
/// outcome into the standard API envelope.
pub async fn health_overview(
    State(state): State<AppState>,
) -> Json<ApiResponse<serde_json::Value>> {
    let resp = match crate::health::get_health_overview(&state.db).await {
        Err(e) => ApiResponse::internal_error("health overview", e),
        Ok(overview) => ApiResponse::ok(overview),
    };
    Json(resp)
}
#[derive(Debug, Deserialize)]
pub struct DeviceListParams {
pub status: Option<String>,
@@ -26,6 +48,10 @@ pub struct DeviceRow {
pub last_heartbeat: Option<String>,
pub registered_at: Option<String>,
pub group_name: Option<String>,
#[sqlx(default)]
pub health_score: Option<i32>,
#[sqlx(default)]
pub health_level: Option<String>,
}
pub async fn list(
@@ -41,13 +67,16 @@ pub async fn list(
let search = params.search.as_deref().filter(|s| !s.is_empty()).map(String::from);
let devices = sqlx::query_as::<_, DeviceRow>(
"SELECT id, device_uid, hostname, ip_address, mac_address, os_version, client_version,
status, last_heartbeat, registered_at, group_name
FROM devices WHERE 1=1
AND (? IS NULL OR status = ?)
AND (? IS NULL OR group_name = ?)
AND (? IS NULL OR hostname LIKE '%' || ? || '%' OR ip_address LIKE '%' || ? || '%')
ORDER BY registered_at DESC LIMIT ? OFFSET ?"
"SELECT d.id, d.device_uid, d.hostname, d.ip_address, d.mac_address, d.os_version, d.client_version,
d.status, d.last_heartbeat, d.registered_at, d.group_name,
h.score as health_score, h.level as health_level
FROM devices d
LEFT JOIN device_health_scores h ON h.device_uid = d.device_uid
WHERE 1=1
AND (? IS NULL OR d.status = ?)
AND (? IS NULL OR d.group_name = ?)
AND (? IS NULL OR d.hostname LIKE '%' || ? || '%' OR d.ip_address LIKE '%' || ? || '%')
ORDER BY d.registered_at DESC LIMIT ? OFFSET ?"
)
.bind(&status).bind(&status)
.bind(&group).bind(&group)
@@ -187,16 +216,6 @@ pub async fn remove(
State(state): State<AppState>,
Path(uid): Path<String>,
) -> Json<ApiResponse<()>> {
// If client is connected, send self-destruct command
let frame = csm_protocol::Frame::new_json(
csm_protocol::MessageType::ConfigUpdate,
&serde_json::json!({"type": "SelfDestruct"}),
).ok();
if let Some(frame) = frame {
state.clients.send_to(&uid, frame.encode()).await;
}
// Delete device and all associated data in a transaction
let mut tx = match state.db.begin().await {
Ok(tx) => tx,
@@ -224,6 +243,8 @@ pub async fn remove(
// Delete plugin-related data
let cleanup_tables = [
"hardware_assets",
"software_assets",
"asset_changes",
"usb_events",
"usb_file_operations",
"usage_daily",
@@ -231,8 +252,20 @@ pub async fn remove(
"software_violations",
"web_access_log",
"popup_block_stats",
"disk_encryption_status",
"disk_encryption_alerts",
"print_events",
"clipboard_violations",
"behavior_metrics",
"anomaly_alerts",
"device_health_scores",
"patch_status",
];
for table in &cleanup_tables {
// Safety: table names are hardcoded constants above, not user input.
// Parameterized ? is used for device_uid.
debug_assert!(table.chars().all(|c| c.is_ascii_alphanumeric() || c == '_'),
"BUG: table name contains unexpected characters: {}", table);
if let Err(e) = sqlx::query(&format!("DELETE FROM {} WHERE device_uid = ?", table))
.bind(&uid)
.execute(&mut *tx)
@@ -253,6 +286,17 @@ pub async fn remove(
if let Err(e) = tx.commit().await {
return Json(ApiResponse::internal_error("commit device deletion", e));
}
// Send self-destruct command AFTER successful commit
let frame = csm_protocol::Frame::new_json(
csm_protocol::MessageType::ConfigUpdate,
&serde_json::json!({"type": "SelfDestruct"}),
).ok();
if let Some(frame) = frame {
state.clients.send_to(&uid, frame.encode()).await;
}
state.clients.unregister(&uid).await;
tracing::info!(device_uid = %uid, "Device and all associated data deleted");
Json(ApiResponse::ok(()))

View File

@@ -50,6 +50,9 @@ pub async fn create_group(
if name.is_empty() || name.len() > 50 {
return Ok((StatusCode::BAD_REQUEST, Json(ApiResponse::error("分组名称无效"))));
}
if name.contains('<') || name.contains('>') || name.contains('"') || name.contains('\'') {
return Ok((StatusCode::BAD_REQUEST, Json(ApiResponse::error("分组名称包含非法字符"))));
}
// Check if group already exists
let exists: bool = sqlx::query_scalar(
@@ -78,6 +81,9 @@ pub async fn rename_group(
if new_name.is_empty() || new_name.len() > 50 {
return Ok((StatusCode::BAD_REQUEST, Json(ApiResponse::error("分组名称无效"))));
}
if new_name.contains('<') || new_name.contains('>') || new_name.contains('"') || new_name.contains('\'') {
return Ok((StatusCode::BAD_REQUEST, Json(ApiResponse::error("分组名称包含非法字符"))));
}
let result = sqlx::query(
"UPDATE devices SET group_name = ? WHERE group_name = ?"

View File

@@ -1,4 +1,4 @@
use axum::{routing::{get, post, put, delete}, Router, Json, middleware};
use axum::{routing::{get, post, put, delete}, Router, Json, middleware, http::StatusCode, response::IntoResponse};
use serde::{Deserialize, Serialize};
use crate::AppState;
@@ -9,23 +9,31 @@ pub mod usb;
pub mod alerts;
pub mod plugins;
pub mod groups;
pub mod conflict;
pub fn routes(state: AppState) -> Router<AppState> {
let public = Router::new()
.route("/api/auth/login", post(auth::login))
.route("/api/auth/refresh", post(auth::refresh))
.route("/api/auth/logout", post(auth::logout))
.route("/health", get(health_check))
.with_state(state.clone());
// Read-only routes (any authenticated user)
let read_routes = Router::new()
// Auth
.route("/api/auth/me", get(auth::me))
.route("/api/auth/change-password", put(auth::change_password))
// WebSocket ticket (requires auth cookie)
.route("/api/ws/ticket", post(auth::create_ws_ticket))
// Devices
.route("/api/devices", get(devices::list))
.route("/api/devices/:uid", get(devices::get_detail))
.route("/api/devices/:uid/status", get(devices::get_status))
.route("/api/devices/:uid/history", get(devices::get_history))
.route("/api/devices/:uid/health-score", get(devices::get_health_score))
// Dashboard
.route("/api/dashboard/health-overview", get(devices::health_overview))
// Assets
.route("/api/assets/hardware", get(assets::list_hardware))
.route("/api/assets/software", get(assets::list_software))
@@ -40,6 +48,8 @@ pub fn routes(state: AppState) -> Router<AppState> {
.route("/api/alerts/records", get(alerts::list_records))
// Plugin read routes
.merge(plugins::read_routes())
// Policy conflict scan
.route("/api/policies/conflicts", get(conflict::scan_conflicts))
.layer(middleware::from_fn_with_state(state.clone(), auth::require_auth));
// Write routes (admin only)
@@ -50,6 +60,8 @@ pub fn routes(state: AppState) -> Router<AppState> {
.route("/api/groups", post(groups::create_group))
.route("/api/groups/:name", put(groups::rename_group).delete(groups::delete_group))
.route("/api/devices/:uid/group", put(groups::move_device))
// TLS cert rotation
.route("/api/system/tls-rotate", post(system_tls_rotate))
// USB (write)
.route("/api/usb/policies", post(usb::create_policy))
.route("/api/usb/policies/:id", put(usb::update_policy).delete(usb::delete_policy))
@@ -76,6 +88,45 @@ pub fn routes(state: AppState) -> Router<AppState> {
.merge(ws_router)
}
/// Trigger TLS certificate rotation for all online devices.
/// Admin sends the new certificate PEM and a transition deadline.
/// The server pushes a ConfigUpdate(TlsCertRotate) to all connected clients.
///
/// Request body for POST /api/system/tls-rotate.
#[derive(Deserialize)]
struct TlsRotateRequest {
    /// Path (on the server host) to the new certificate PEM file
    cert_path: String,
    /// ISO 8601 timestamp when the old cert stops being valid (transition deadline)
    valid_until: String,
}
/// Response body for POST /api/system/tls-rotate.
#[derive(Serialize)]
struct TlsRotateResponse {
    // Number of currently connected devices the rotation frame was sent to.
    devices_notified: usize,
}
/// POST /api/system/tls-rotate (admin only).
///
/// Reads the new certificate from `cert_path` on the server host, verifies it
/// at least looks like a PEM certificate, then pushes a TlsCertRotate
/// ConfigUpdate to every connected client. Returns how many devices were
/// notified.
async fn system_tls_rotate(
    axum::extract::State(state): axum::extract::State<AppState>,
    Json(req): Json<TlsRotateRequest>,
) -> impl IntoResponse {
    let cert_pem = match tokio::fs::read(&req.cert_path).await {
        Ok(pem) => pem,
        Err(e) => {
            return (
                StatusCode::BAD_REQUEST,
                Json(ApiResponse::<TlsRotateResponse>::error(
                    format!("Cannot read cert file {}: {}", req.cert_path, e),
                )),
            ).into_response();
        }
    };
    // Sanity check: refuse to broadcast a file that does not even look like a
    // PEM certificate — pushing arbitrary bytes (e.g. a typo'd path pointing
    // at the wrong file) would break TLS on every connected client at once.
    let looks_like_pem = std::str::from_utf8(&cert_pem)
        .map(|s| s.contains("-----BEGIN CERTIFICATE-----"))
        .unwrap_or(false);
    if !looks_like_pem {
        return (
            StatusCode::BAD_REQUEST,
            Json(ApiResponse::<TlsRotateResponse>::error(
                format!("File {} does not look like a PEM certificate", req.cert_path),
            )),
        ).into_response();
    }
    let count = crate::tcp::push_tls_cert_rotation(&state.clients, &cert_pem, &req.valid_until).await;
    (StatusCode::OK, Json(ApiResponse::ok(TlsRotateResponse {
        devices_notified: count,
    }))).into_response()
}
#[derive(Serialize)]
struct HealthResponse {
status: &'static str,

View File

@@ -0,0 +1,48 @@
use axum::{extract::{State, Path, Query}, Json};
use serde::Deserialize;
use crate::AppState;
use crate::api::ApiResponse;
/// Query parameters for GET /api/plugins/anomaly/alerts.
#[derive(Debug, Deserialize)]
pub struct AnomalyListParams {
    // Restrict results to one device when present and non-empty.
    pub device_uid: Option<String>,
    // 1-based page number; handler defaults to 1.
    pub page: Option<u32>,
    // Rows per page; handler defaults to 20, capped at 100.
    pub page_size: Option<u32>,
}
/// GET /api/plugins/anomaly/alerts
pub async fn list_anomaly_alerts(
State(state): State<AppState>,
Query(params): Query<AnomalyListParams>,
) -> Json<ApiResponse<serde_json::Value>> {
let page = params.page.unwrap_or(1);
let page_size = params.page_size.unwrap_or(20).min(100);
let device_uid = params.device_uid.as_deref().filter(|s| !s.is_empty());
match crate::anomaly::get_anomaly_summary(&state.db, device_uid, page, page_size).await {
Ok(result) => Json(ApiResponse::ok(result)),
Err(e) => Json(ApiResponse::internal_error("anomaly alerts", e)),
}
}
/// PUT /api/plugins/anomaly/alerts/:id/handle
/// Mark an anomaly alert as handled.
pub async fn handle_anomaly_alert(
State(state): State<AppState>,
Path(id): Path<i64>,
claims: axum::Extension<crate::api::auth::Claims>,
) -> Json<ApiResponse<()>> {
let result = sqlx::query(
"UPDATE anomaly_alerts SET handled = 1, handled_by = ?, handled_at = datetime('now') WHERE id = ?"
)
.bind(&claims.username)
.bind(id)
.execute(&state.db)
.await;
match result {
Ok(r) if r.rows_affected() > 0 => Json(ApiResponse::ok(())),
Ok(_) => Json(ApiResponse::error("Alert not found")),
Err(e) => Json(ApiResponse::internal_error("handle anomaly alert", e)),
}
}

View File

@@ -21,7 +21,7 @@ pub struct CreateRuleRequest {
pub async fn list_rules(State(state): State<AppState>) -> Json<ApiResponse<serde_json::Value>> {
match sqlx::query(
"SELECT id, target_type, target_id, rule_type, direction, source_process, target_process, content_pattern, enabled, updated_at \
FROM clipboard_rules ORDER BY updated_at DESC"
FROM clipboard_rules ORDER BY updated_at DESC LIMIT 500"
)
.fetch_all(&state.db)
.await
@@ -127,6 +127,18 @@ pub async fn update_rule(
let content_pattern = body.content_pattern.or_else(|| existing.get::<Option<String>, _>("content_pattern"));
let enabled = body.enabled.unwrap_or_else(|| existing.get::<bool, _>("enabled"));
// Validate merged values
if let Some(ref rt) = rule_type {
if !matches!(rt.as_str(), "allow" | "block") {
return Json(ApiResponse::error("rule_type must be 'allow' or 'block'"));
}
}
if let Some(ref d) = direction {
if !matches!(d.as_str(), "in" | "out" | "both") {
return Json(ApiResponse::error("direction must be 'in', 'out', or 'both'"));
}
}
let result = sqlx::query(
"UPDATE clipboard_rules SET rule_type = ?, direction = ?, source_process = ?, target_process = ?, content_pattern = ?, enabled = ?, updated_at = datetime('now') WHERE id = ?"
)

View File

@@ -28,7 +28,7 @@ pub async fn list_status(
"SELECT s.id, s.device_uid, s.drive_letter, s.volume_name, s.encryption_method, \
s.protection_status, s.encryption_percentage, s.lock_status, s.reported_at, s.updated_at, \
d.hostname FROM disk_encryption_status s LEFT JOIN devices d ON s.device_uid = d.device_uid \
ORDER BY s.device_uid, s.drive_letter"
ORDER BY s.device_uid, s.drive_letter LIMIT 500"
)
.fetch_all(&state.db)
.await
@@ -58,7 +58,7 @@ pub async fn list_alerts(State(state): State<AppState>) -> Json<ApiResponse<serd
match sqlx::query(
"SELECT a.id, a.device_uid, a.drive_letter, a.alert_type, a.status, a.created_at, a.resolved_at, \
d.hostname FROM encryption_alerts a LEFT JOIN devices d ON a.device_uid = d.device_uid \
ORDER BY a.created_at DESC"
ORDER BY a.created_at DESC LIMIT 500"
)
.fetch_all(&state.db)
.await

View File

@@ -8,6 +8,8 @@ pub mod disk_encryption;
pub mod print_audit;
pub mod clipboard_control;
pub mod plugin_control;
pub mod patch;
pub mod anomaly;
use axum::{Router, routing::{get, post, put}};
use crate::AppState;
@@ -25,6 +27,7 @@ pub fn read_routes() -> Router<AppState> {
// Software Blocker
.route("/api/plugins/software-blocker/blacklist", get(software_blocker::list_blacklist))
.route("/api/plugins/software-blocker/violations", get(software_blocker::list_violations))
.route("/api/plugins/software-blocker/whitelist", get(software_blocker::list_whitelist))
// Popup Blocker
.route("/api/plugins/popup-blocker/rules", get(popup_blocker::list_rules))
.route("/api/plugins/popup-blocker/stats", get(popup_blocker::list_stats))
@@ -36,7 +39,6 @@ pub fn read_routes() -> Router<AppState> {
// Disk Encryption
.route("/api/plugins/disk-encryption/status", get(disk_encryption::list_status))
.route("/api/plugins/disk-encryption/alerts", get(disk_encryption::list_alerts))
.route("/api/plugins/disk-encryption/alerts/:id/acknowledge", put(disk_encryption::acknowledge_alert))
// Print Audit
.route("/api/plugins/print-audit/events", get(print_audit::list_events))
.route("/api/plugins/print-audit/events/:id", get(print_audit::get_event))
@@ -45,6 +47,12 @@ pub fn read_routes() -> Router<AppState> {
.route("/api/plugins/clipboard-control/violations", get(clipboard_control::list_violations))
// Plugin Control
.route("/api/plugins/control", get(plugin_control::list_plugins))
// Patch Management
.route("/api/plugins/patch/status", get(patch::list_patch_status))
.route("/api/plugins/patch/summary", get(patch::patch_summary))
.route("/api/plugins/patch/device/:uid", get(patch::device_patches))
// Anomaly Detection
.route("/api/plugins/anomaly/alerts", get(anomaly::list_anomaly_alerts))
}
/// Write plugin routes (admin only — require_admin middleware applied by caller)
@@ -56,6 +64,8 @@ pub fn write_routes() -> Router<AppState> {
// Software Blocker
.route("/api/plugins/software-blocker/blacklist", post(software_blocker::add_to_blacklist))
.route("/api/plugins/software-blocker/blacklist/:id", put(software_blocker::update_blacklist).delete(software_blocker::remove_from_blacklist))
.route("/api/plugins/software-blocker/whitelist", post(software_blocker::add_to_whitelist))
.route("/api/plugins/software-blocker/whitelist/:id", put(software_blocker::update_whitelist).delete(software_blocker::remove_from_whitelist))
// Popup Blocker
.route("/api/plugins/popup-blocker/rules", post(popup_blocker::create_rule))
.route("/api/plugins/popup-blocker/rules/:id", put(popup_blocker::update_rule).delete(popup_blocker::delete_rule))
@@ -65,6 +75,10 @@ pub fn write_routes() -> Router<AppState> {
// Clipboard Control
.route("/api/plugins/clipboard-control/rules", post(clipboard_control::create_rule))
.route("/api/plugins/clipboard-control/rules/:id", put(clipboard_control::update_rule).delete(clipboard_control::delete_rule))
// Disk Encryption
.route("/api/plugins/disk-encryption/alerts/:id/acknowledge", put(disk_encryption::acknowledge_alert))
// Plugin Control (enable/disable)
.route("/api/plugins/control/:plugin_name", put(plugin_control::set_plugin_state))
// Anomaly Detection — handle alert
.route("/api/plugins/anomaly/alerts/:id/handle", put(anomaly::handle_anomaly_alert))
}

View File

@@ -0,0 +1,146 @@
use axum::{extract::{State, Path, Query}, Json};
use serde::Deserialize;
use sqlx::Row;
use crate::AppState;
use crate::api::ApiResponse;
/// Query parameters for GET /api/plugins/patch/status.
#[derive(Debug, Deserialize)]
pub struct PatchListParams {
    // Restrict results to one device when present and non-empty.
    pub device_uid: Option<String>,
    // Restrict results to one severity level when present and non-empty.
    pub severity: Option<String>,
    // Accepted for API-shape compatibility; not currently used as a filter.
    #[allow(dead_code)]
    pub installed: Option<i32>,
    // 1-based page number; handler defaults to 1.
    pub page: Option<u32>,
    // Rows per page; handler defaults to 20, capped at 100.
    pub page_size: Option<u32>,
}
/// GET /api/plugins/patch/status
///
/// Paginated patch records (joined with the device hostname), optionally
/// filtered by device_uid and/or severity, plus installed/missing totals
/// computed over the same filters.
pub async fn list_patch_status(
    State(state): State<AppState>,
    Query(params): Query<PatchListParams>,
) -> Json<ApiResponse<serde_json::Value>> {
    let limit = params.page_size.unwrap_or(20).min(100);
    // saturating_mul: a huge `page` value from the query string must not
    // overflow u32 (panic in debug builds, silent wrap in release).
    let offset = params.page.unwrap_or(1).saturating_sub(1).saturating_mul(limit);
    // Empty-string filters are treated the same as absent ones.
    let device_uid = params.device_uid.as_deref().filter(|s| !s.is_empty());
    let severity = params.severity.as_deref().filter(|s| !s.is_empty());
    let rows = sqlx::query(
        "SELECT p.*, d.hostname FROM patch_status p JOIN devices d ON d.device_uid = p.device_uid \
         WHERE 1=1 \
         AND (? IS NULL OR p.device_uid = ?) \
         AND (? IS NULL OR p.severity = ?) \
         ORDER BY p.updated_at DESC LIMIT ? OFFSET ?"
    )
    .bind(device_uid).bind(device_uid)
    .bind(severity).bind(severity)
    .bind(limit).bind(offset)
    .fetch_all(&state.db)
    .await;
    match rows {
        Ok(records) => {
            let items: Vec<serde_json::Value> = records.iter().map(|r| serde_json::json!({
                "id": r.get::<i64, _>("id"),
                "device_uid": r.get::<String, _>("device_uid"),
                "hostname": r.get::<String, _>("hostname"),
                "kb_id": r.get::<String, _>("kb_id"),
                "title": r.get::<String, _>("title"),
                "severity": r.get::<Option<String>, _>("severity"),
                "is_installed": r.get::<i32, _>("is_installed"),
                "installed_at": r.get::<Option<String>, _>("installed_at"),
                "updated_at": r.get::<String, _>("updated_at"),
            })).collect();
            // Summary stats (scoped to same filters as main query)
            let total_installed: i64 = sqlx::query_scalar(
                "SELECT COUNT(*) FROM patch_status WHERE is_installed = 1 \
                 AND (? IS NULL OR device_uid = ?) \
                 AND (? IS NULL OR severity = ?)"
            )
            .bind(device_uid).bind(device_uid)
            .bind(severity).bind(severity)
            .fetch_one(&state.db).await.unwrap_or(0);
            let total_missing: i64 = sqlx::query_scalar(
                "SELECT COUNT(*) FROM patch_status WHERE is_installed = 0 \
                 AND (? IS NULL OR device_uid = ?) \
                 AND (? IS NULL OR severity = ?)"
            )
            .bind(device_uid).bind(device_uid)
            .bind(severity).bind(severity)
            .fetch_one(&state.db).await.unwrap_or(0);
            Json(ApiResponse::ok(serde_json::json!({
                "patches": items,
                "summary": {
                    "total_installed": total_installed,
                    "total_missing": total_missing,
                },
                "page": params.page.unwrap_or(1),
                "page_size": limit,
            })))
        }
        Err(e) => Json(ApiResponse::internal_error("query patch status", e)),
    }
}
/// GET /api/plugins/patch/summary — per-device patch summary
///
/// One row per device: total/installed/missing patch counts and the most
/// recent scan timestamp, ordered by missing count (worst first).
pub async fn patch_summary(
    State(state): State<AppState>,
) -> Json<ApiResponse<serde_json::Value>> {
    let query_result = sqlx::query(
        "SELECT p.device_uid, d.hostname, \
         COUNT(*) as total_patches, \
         SUM(CASE WHEN p.is_installed = 1 THEN 1 ELSE 0 END) as installed, \
         SUM(CASE WHEN p.is_installed = 0 THEN 1 ELSE 0 END) as missing, \
         MAX(p.updated_at) as last_scan \
         FROM patch_status p JOIN devices d ON d.device_uid = p.device_uid \
         GROUP BY p.device_uid ORDER BY missing DESC"
    )
    .fetch_all(&state.db)
    .await;
    let records = match query_result {
        Ok(rows) => rows,
        Err(e) => return Json(ApiResponse::internal_error("patch summary", e)),
    };
    let mut devices = Vec::with_capacity(records.len());
    for r in &records {
        devices.push(serde_json::json!({
            "device_uid": r.get::<String, _>("device_uid"),
            "hostname": r.get::<String, _>("hostname"),
            "total_patches": r.get::<i64, _>("total_patches"),
            "installed": r.get::<i64, _>("installed"),
            "missing": r.get::<i64, _>("missing"),
            "last_scan": r.get::<String, _>("last_scan"),
        }));
    }
    Json(ApiResponse::ok(serde_json::json!({ "devices": devices })))
}
/// GET /api/plugins/patch/device/:uid — patches for a single device
///
/// All patch records for the device, most recently updated first.
pub async fn device_patches(
    State(state): State<AppState>,
    Path(uid): Path<String>,
) -> Json<ApiResponse<serde_json::Value>> {
    let fetched = sqlx::query(
        "SELECT kb_id, title, severity, is_installed, installed_at, updated_at \
         FROM patch_status WHERE device_uid = ? ORDER BY updated_at DESC"
    )
    .bind(&uid)
    .fetch_all(&state.db)
    .await;
    match fetched {
        Err(e) => Json(ApiResponse::internal_error("device patches", e)),
        Ok(records) => {
            let mut patches = Vec::with_capacity(records.len());
            for r in &records {
                patches.push(serde_json::json!({
                    "kb_id": r.get::<String, _>("kb_id"),
                    "title": r.get::<String, _>("title"),
                    "severity": r.get::<Option<String>, _>("severity"),
                    "is_installed": r.get::<i32, _>("is_installed"),
                    "installed_at": r.get::<Option<String>, _>("installed_at"),
                    "updated_at": r.get::<String, _>("updated_at"),
                }));
            }
            Json(ApiResponse::ok(serde_json::json!({ "patches": patches })))
        }
    }
}

View File

@@ -17,7 +17,7 @@ pub struct CreateRuleRequest {
}
pub async fn list_rules(State(state): State<AppState>) -> Json<ApiResponse<serde_json::Value>> {
match sqlx::query("SELECT id, rule_type, window_title, window_class, process_name, target_type, target_id, enabled, created_at FROM popup_filter_rules ORDER BY created_at DESC")
match sqlx::query("SELECT id, rule_type, window_title, window_class, process_name, target_type, target_id, enabled, created_at FROM popup_filter_rules ORDER BY created_at DESC LIMIT 500")
.fetch_all(&state.db).await {
Ok(rows) => Json(ApiResponse::ok(serde_json::json!({"rules": rows.iter().map(|r| serde_json::json!({
"id": r.get::<i64,_>("id"), "rule_type": r.get::<String,_>("rule_type"),
@@ -47,6 +47,16 @@ pub async fn create_rule(State(state): State<AppState>, Json(req): Json<CreateRu
if !has_filter {
return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("at least one filter (window_title/window_class/process_name) required")));
}
// Length validation for filter fields
if let Some(ref t) = req.window_title {
if t.len() > 255 { return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("window_title too long (max 255)"))); }
}
if let Some(ref c) = req.window_class {
if c.len() > 255 { return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("window_class too long (max 255)"))); }
}
if let Some(ref p) = req.process_name {
if p.len() > 255 { return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("process_name too long (max 255)"))); }
}
match sqlx::query("INSERT INTO popup_filter_rules (rule_type, window_title, window_class, process_name, target_type, target_id) VALUES (?,?,?,?,?,?)")
.bind(&req.rule_type).bind(&req.window_title).bind(&req.window_class).bind(&req.process_name).bind(&target_type).bind(&req.target_id)
@@ -81,6 +91,14 @@ pub async fn update_rule(State(state): State<AppState>, Path(id): Path<i64>, Jso
let process_name = body.process_name.or_else(|| existing.get::<Option<String>, _>("process_name"));
let enabled = body.enabled.unwrap_or_else(|| existing.get::<bool, _>("enabled"));
// Ensure at least one filter is non-empty after update
let has_filter = window_title.as_ref().map_or(false, |s| !s.is_empty())
|| window_class.as_ref().map_or(false, |s| !s.is_empty())
|| process_name.as_ref().map_or(false, |s| !s.is_empty());
if !has_filter {
return Json(ApiResponse::error("at least one filter (window_title/window_class/process_name) required"));
}
let result = sqlx::query("UPDATE popup_filter_rules SET window_title = ?, window_class = ?, process_name = ?, enabled = ? WHERE id = ?")
.bind(&window_title)
.bind(&window_class)

View File

@@ -1,7 +1,7 @@
use axum::{extract::{State, Path, Query, Json}, http::StatusCode};
use serde::Deserialize;
use sqlx::Row;
use csm_protocol::MessageType;
use csm_protocol::{Frame, MessageType};
use crate::AppState;
use crate::api::ApiResponse;
use crate::tcp::push_to_targets;
@@ -16,7 +16,7 @@ pub struct CreateBlacklistRequest {
}
pub async fn list_blacklist(State(state): State<AppState>) -> Json<ApiResponse<serde_json::Value>> {
match sqlx::query("SELECT id, name_pattern, category, action, target_type, target_id, enabled, created_at FROM software_blacklist ORDER BY created_at DESC")
match sqlx::query("SELECT id, name_pattern, category, action, target_type, target_id, enabled, created_at FROM software_blacklist ORDER BY created_at DESC LIMIT 500")
.fetch_all(&state.db).await {
Ok(rows) => Json(ApiResponse::ok(serde_json::json!({"blacklist": rows.iter().map(|r| serde_json::json!({
"id": r.get::<i64,_>("id"), "name_pattern": r.get::<String,_>("name_pattern"),
@@ -53,8 +53,8 @@ pub async fn add_to_blacklist(State(state): State<AppState>, Json(req): Json<Cre
.execute(&state.db).await {
Ok(r) => {
let new_id = r.last_insert_rowid();
let blacklist = fetch_blacklist_for_push(&state.db, &target_type, req.target_id.as_deref()).await;
push_to_targets(&state.db, &state.clients, MessageType::SoftwareBlacklist, &serde_json::json!({"blacklist": blacklist}), &target_type, req.target_id.as_deref()).await;
let payload = fetch_software_payload_for_push(&state.db, &target_type, req.target_id.as_deref()).await;
push_to_targets(&state.db, &state.clients, MessageType::SoftwareBlacklist, &payload, &target_type, req.target_id.as_deref()).await;
(StatusCode::CREATED, Json(ApiResponse::ok(serde_json::json!({"id": new_id}))))
}
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::internal_error("add software blacklist entry", e))),
@@ -80,6 +80,14 @@ pub async fn update_blacklist(State(state): State<AppState>, Path(id): Path<i64>
let action = body.action.unwrap_or_else(|| existing.get::<String, _>("action"));
let enabled = body.enabled.unwrap_or_else(|| existing.get::<bool, _>("enabled"));
// Input validation (same as create)
if name_pattern.trim().is_empty() || name_pattern.len() > 255 {
return Json(ApiResponse::error("name_pattern must be 1-255 chars"));
}
if !matches!(action.as_str(), "block" | "alert") {
return Json(ApiResponse::error("action must be 'block' or 'alert'"));
}
let result = sqlx::query("UPDATE software_blacklist SET name_pattern = ?, action = ?, enabled = ? WHERE id = ?")
.bind(&name_pattern)
.bind(&action)
@@ -92,8 +100,8 @@ pub async fn update_blacklist(State(state): State<AppState>, Path(id): Path<i64>
Ok(r) if r.rows_affected() > 0 => {
let target_type_val: String = existing.get("target_type");
let target_id_val: Option<String> = existing.get("target_id");
let blacklist = fetch_blacklist_for_push(&state.db, &target_type_val, target_id_val.as_deref()).await;
push_to_targets(&state.db, &state.clients, MessageType::SoftwareBlacklist, &serde_json::json!({"blacklist": blacklist}), &target_type_val, target_id_val.as_deref()).await;
let payload = fetch_software_payload_for_push(&state.db, &target_type_val, target_id_val.as_deref()).await;
push_to_targets(&state.db, &state.clients, MessageType::SoftwareBlacklist, &payload, &target_type_val, target_id_val.as_deref()).await;
Json(ApiResponse::ok(()))
}
Ok(_) => Json(ApiResponse::error("Not found")),
@@ -110,8 +118,8 @@ pub async fn remove_from_blacklist(State(state): State<AppState>, Path(id): Path
};
match sqlx::query("DELETE FROM software_blacklist WHERE id=?").bind(id).execute(&state.db).await {
Ok(r) if r.rows_affected() > 0 => {
let blacklist = fetch_blacklist_for_push(&state.db, &target_type, target_id.as_deref()).await;
push_to_targets(&state.db, &state.clients, MessageType::SoftwareBlacklist, &serde_json::json!({"blacklist": blacklist}), &target_type, target_id.as_deref()).await;
let payload = fetch_software_payload_for_push(&state.db, &target_type, target_id.as_deref()).await;
push_to_targets(&state.db, &state.clients, MessageType::SoftwareBlacklist, &payload, &target_type, target_id.as_deref()).await;
Json(ApiResponse::ok(()))
}
_ => Json(ApiResponse::error("Not found")),
@@ -134,6 +142,29 @@ pub async fn list_violations(State(state): State<AppState>, Query(f): Query<Viol
}
}
/// Assemble the software-control payload pushed to clients: the blacklist
/// scoped to the given target plus the always-global whitelist.
async fn fetch_software_payload_for_push(
    db: &sqlx::SqlitePool,
    target_type: &str,
    target_id: Option<&str>,
) -> serde_json::Value {
    // Blacklist entries scoped to this target (global / group / device).
    let black_entries = fetch_blacklist_for_push(db, target_type, target_id).await;
    // The whitelist is not scoped — every enabled pattern applies to all devices.
    let white_patterns = sqlx::query_scalar::<_, String>(
        "SELECT name_pattern FROM software_whitelist WHERE enabled = 1"
    )
    .fetch_all(db)
    .await
    .unwrap_or_else(|_| Vec::new());
    serde_json::json!({
        "blacklist": black_entries,
        "whitelist": white_patterns,
    })
}
async fn fetch_blacklist_for_push(
db: &sqlx::SqlitePool,
target_type: &str,
@@ -156,3 +187,112 @@ async fn fetch_blacklist_for_push(
})).collect())
.unwrap_or_default()
}
// ─── Whitelist management ───
/// GET /api/plugins/software-blocker/whitelist
///
/// Returns up to 500 whitelist entries as JSON, built-in rows first,
/// then oldest-created first.
pub async fn list_whitelist(State(state): State<AppState>) -> Json<ApiResponse<serde_json::Value>> {
    match sqlx::query("SELECT id, name_pattern, reason, is_builtin, enabled, created_at FROM software_whitelist ORDER BY is_builtin DESC, created_at ASC LIMIT 500")
        .fetch_all(&state.db).await {
        Ok(rows) => Json(ApiResponse::ok(serde_json::json!({"whitelist": rows.iter().map(|r| serde_json::json!({
            "id": r.get::<i64,_>("id"),
            "name_pattern": r.get::<String,_>("name_pattern"),
            "reason": r.get::<Option<String>,_>("reason"),
            "is_builtin": r.get::<bool,_>("is_builtin"),
            "enabled": r.get::<bool,_>("enabled"),
            "created_at": r.get::<String,_>("created_at")
        })).collect::<Vec<_>>()}))),
        Err(e) => Json(ApiResponse::internal_error("query software whitelist", e)),
    }
}
/// Request body for creating a software-whitelist entry.
#[derive(Debug, Deserialize)]
pub struct CreateWhitelistRequest {
    /// Pattern matched against software names; validated to 1-255 chars on insert.
    pub name_pattern: String,
    /// Optional human-readable justification for the entry.
    pub reason: Option<String>,
}
/// POST /api/plugins/software-blocker/whitelist
///
/// Validates the pattern, inserts the entry, then re-pushes the whitelist to
/// all online clients.
pub async fn add_to_whitelist(State(state): State<AppState>, Json(req): Json<CreateWhitelistRequest>) -> (StatusCode, Json<ApiResponse<serde_json::Value>>) {
    let pattern_invalid = req.name_pattern.trim().is_empty() || req.name_pattern.len() > 255;
    if pattern_invalid {
        return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("name_pattern must be 1-255 chars")));
    }
    let insert = sqlx::query("INSERT INTO software_whitelist (name_pattern, reason) VALUES (?, ?)")
        .bind(&req.name_pattern)
        .bind(&req.reason)
        .execute(&state.db)
        .await;
    match insert {
        Ok(result) => {
            let new_id = result.last_insert_rowid();
            // Keep connected clients in sync with the new entry.
            push_whitelist_to_all(&state).await;
            (StatusCode::CREATED, Json(ApiResponse::ok(serde_json::json!({"id": new_id}))))
        }
        Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::internal_error("add whitelist entry", e))),
    }
}
/// PUT /api/plugins/software-blocker/whitelist/:id
///
/// Partial-update request: absent fields keep their stored values.
#[derive(Debug, Deserialize)]
pub struct UpdateWhitelistRequest {
    /// New pattern; `None` keeps the existing one.
    pub name_pattern: Option<String>,
    /// New enabled flag; `None` keeps the existing one.
    pub enabled: Option<bool>,
}
/// Handler for the PUT route above. Loads the existing row, merges the request
/// fields over it, validates the merged pattern, updates the row, then
/// re-pushes the whitelist to online clients.
pub async fn update_whitelist(State(state): State<AppState>, Path(id): Path<i64>, Json(body): Json<UpdateWhitelistRequest>) -> Json<ApiResponse<()>> {
    let existing = sqlx::query("SELECT name_pattern, enabled FROM software_whitelist WHERE id = ?")
        .bind(id).fetch_optional(&state.db).await;
    let existing = match existing {
        Ok(Some(row)) => row,
        Ok(None) => return Json(ApiResponse::error("Not found")),
        Err(e) => return Json(ApiResponse::internal_error("query whitelist", e)),
    };
    // Merge request fields over the stored row so partial updates work.
    let name_pattern = body.name_pattern.unwrap_or_else(|| existing.get::<String, _>("name_pattern"));
    let enabled = body.enabled.unwrap_or_else(|| existing.get::<bool, _>("enabled"));
    // Input validation — validate merged value
    if name_pattern.trim().is_empty() || name_pattern.len() > 255 {
        return Json(ApiResponse::error("name_pattern must be 1-255 chars"));
    }
    match sqlx::query("UPDATE software_whitelist SET name_pattern = ?, enabled = ? WHERE id = ?")
        .bind(&name_pattern).bind(enabled).bind(id)
        .execute(&state.db).await {
        Ok(r) if r.rows_affected() > 0 => {
            push_whitelist_to_all(&state).await;
            Json(ApiResponse::ok(()))
        }
        Ok(_) => Json(ApiResponse::error("Not found")),
        Err(e) => Json(ApiResponse::internal_error("update whitelist", e)),
    }
}
/// DELETE /api/plugins/software-blocker/whitelist/:id
pub async fn remove_from_whitelist(State(state): State<AppState>, Path(id): Path<i64>) -> Json<ApiResponse<()>> {
match sqlx::query("DELETE FROM software_whitelist WHERE id = ? AND is_builtin = 0")
.bind(id).execute(&state.db).await {
Ok(r) if r.rows_affected() > 0 => {
push_whitelist_to_all(&state).await;
Json(ApiResponse::ok(()))
}
Ok(_) => Json(ApiResponse::error("Not found or is built-in entry")),
Err(e) => Json(ApiResponse::internal_error("remove whitelist entry", e)),
}
}
/// Push updated whitelist to all online clients by resending the full software control config.
///
/// Fetches the payload once and fans the encoded frame out to every connected
/// session. Fixes over the previous version: a frame-encoding failure is now
/// logged instead of silently aborting, and the success log reports actual
/// deliveries (`send_to` returns `false` for sessions that vanished) rather
/// than the number of attempts.
async fn push_whitelist_to_all(state: &AppState) {
    let payload = fetch_software_payload_for_push(&state.db, "global", None).await;
    let frame = match Frame::new_json(MessageType::SoftwareBlacklist, &payload) {
        Ok(f) => f.encode(),
        Err(e) => {
            // Previously a bare `return` — make serialization failures visible.
            tracing::warn!("Failed to encode whitelist push frame: {}", e);
            return;
        }
    };
    let online = state.clients.list_online().await;
    let mut pushed = 0usize;
    for uid in &online {
        if state.clients.send_to(uid, frame.clone()).await {
            pushed += 1;
        }
    }
    tracing::info!("Pushed updated whitelist to {} of {} online clients", pushed, online.len());
}

View File

@@ -16,7 +16,7 @@ pub struct CreateRuleRequest {
}
pub async fn list_rules(State(state): State<AppState>) -> Json<ApiResponse<serde_json::Value>> {
match sqlx::query("SELECT id, rule_type, pattern, target_type, target_id, enabled, created_at FROM web_filter_rules ORDER BY created_at DESC")
match sqlx::query("SELECT id, rule_type, pattern, target_type, target_id, enabled, created_at FROM web_filter_rules ORDER BY created_at DESC LIMIT 500")
.fetch_all(&state.db).await {
Ok(rows) => Json(ApiResponse::ok(serde_json::json!({ "rules": rows.iter().map(|r| serde_json::json!({
"id": r.get::<i64,_>("id"), "rule_type": r.get::<String,_>("rule_type"),
@@ -75,6 +75,14 @@ pub async fn update_rule(State(state): State<AppState>, Path(id): Path<i64>, Jso
let pattern = body.pattern.unwrap_or_else(|| existing.get::<String, _>("pattern"));
let enabled = body.enabled.unwrap_or_else(|| existing.get::<bool, _>("enabled"));
// Validate merged values
if !matches!(rule_type.as_str(), "blacklist" | "whitelist" | "category") {
return Json(ApiResponse::error("rule_type must be 'blacklist', 'whitelist', or 'category'"));
}
if pattern.trim().is_empty() || pattern.len() > 255 {
return Json(ApiResponse::error("pattern must be 1-255 chars"));
}
let result = sqlx::query("UPDATE web_filter_rules SET rule_type = ?, pattern = ?, enabled = ? WHERE id = ?")
.bind(&rule_type)
.bind(&pattern)
@@ -114,17 +122,42 @@ pub async fn delete_rule(State(state): State<AppState>, Path(id): Path<i64>) ->
}
#[derive(Debug, Deserialize)]
pub struct LogFilters { pub device_uid: Option<String>, pub action: Option<String> }
pub struct LogFilters {
pub device_uid: Option<String>,
pub action: Option<String>,
pub page: Option<u32>,
pub page_size: Option<u32>,
}
pub async fn list_access_log(State(state): State<AppState>, Query(f): Query<LogFilters>) -> Json<ApiResponse<serde_json::Value>> {
match sqlx::query("SELECT id, device_uid, url, action, timestamp FROM web_access_log WHERE (? IS NULL OR device_uid=?) AND (? IS NULL OR action=?) ORDER BY timestamp DESC LIMIT 200")
.bind(&f.device_uid).bind(&f.device_uid).bind(&f.action).bind(&f.action)
.fetch_all(&state.db).await {
Ok(rows) => Json(ApiResponse::ok(serde_json::json!({"log": rows.iter().map(|r| serde_json::json!({
"id": r.get::<i64,_>("id"), "device_uid": r.get::<String,_>("device_uid"),
"url": r.get::<String,_>("url"), "action": r.get::<String,_>("action"),
"timestamp": r.get::<String,_>("timestamp")
})).collect::<Vec<_>>() }))),
let limit = f.page_size.unwrap_or(20).min(100);
let offset = f.page.unwrap_or(1).saturating_sub(1) * limit;
let device_uid = f.device_uid.as_deref().filter(|s| !s.is_empty());
let action = f.action.as_deref().filter(|s| !s.is_empty());
let rows = sqlx::query(
"SELECT id, device_uid, url, action, timestamp FROM web_access_log \
WHERE (? IS NULL OR device_uid=?) AND (? IS NULL OR action=?) \
ORDER BY timestamp DESC LIMIT ? OFFSET ?"
)
.bind(device_uid).bind(device_uid)
.bind(action).bind(action)
.bind(limit).bind(offset)
.fetch_all(&state.db)
.await;
match rows {
Ok(records) => Json(ApiResponse::ok(serde_json::json!({
"log": records.iter().map(|r| serde_json::json!({
"id": r.get::<i64,_>("id"),
"device_uid": r.get::<String,_>("device_uid"),
"url": r.get::<String,_>("url"),
"action": r.get::<String,_>("action"),
"timestamp": r.get::<String,_>("timestamp")
})).collect::<Vec<_>>(),
"page": f.page.unwrap_or(1),
"page_size": limit,
}))),
Err(e) => Json(ApiResponse::internal_error("query web access log", e)),
}
}

View File

@@ -66,7 +66,7 @@ pub async fn list_policies(
) -> Json<ApiResponse<serde_json::Value>> {
let rows = sqlx::query(
"SELECT id, name, policy_type, target_group, rules, enabled, created_at, updated_at
FROM usb_policies ORDER BY created_at DESC"
FROM usb_policies ORDER BY created_at DESC LIMIT 500"
)
.fetch_all(&state.db)
.await;
@@ -106,6 +106,11 @@ pub async fn create_policy(
) -> (StatusCode, Json<ApiResponse<serde_json::Value>>) {
let enabled = body.enabled.unwrap_or(1);
// Input validation
if body.name.trim().is_empty() || body.name.len() > 100 {
return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("name must be 1-100 chars")));
}
let result = sqlx::query(
"INSERT INTO usb_policies (name, policy_type, target_group, rules, enabled) VALUES (?, ?, ?, ?, ?)"
)

315
crates/server/src/health.rs Normal file
View File

@@ -0,0 +1,315 @@
use crate::AppState;
use sqlx::Row;
use tracing::{info, error};
/// Background task: recompute device health scores every 5 minutes
///
/// `tokio::time::interval`'s first tick completes immediately, so the initial
/// computation happens at startup and then every 300 seconds thereafter.
/// Errors are logged and the loop continues — this task never returns.
pub async fn health_score_task(state: AppState) {
    let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(300));
    // First computation runs immediately, then every 5 minutes
    loop {
        interval.tick().await;
        if let Err(e) = recompute_all_scores(&state).await {
            error!("Health score computation failed: {}", e);
        }
    }
}
/// Recompute and persist health scores for every registered device.
/// Per-device failures are logged and counted; only the initial device-list
/// query can abort the whole pass.
async fn recompute_all_scores(state: &AppState) -> anyhow::Result<()> {
    let device_uids: Vec<String> =
        sqlx::query_scalar("SELECT device_uid FROM devices")
            .fetch_all(&state.db)
            .await?;
    let (mut ok_count, mut err_count) = (0u32, 0u32);
    for uid in &device_uids {
        match compute_and_store_score(&state.db, uid).await {
            Ok(result) => {
                ok_count += 1;
                tracing::debug!("Health score for {}: {} ({})", uid, result.score, result.level);
            }
            Err(e) => {
                err_count += 1;
                error!("Failed to compute health score for {}: {}", uid, e);
            }
        }
    }
    if ok_count > 0 {
        info!("Health scores computed: {} devices, {} errors", ok_count, err_count);
    }
    Ok(())
}
/// Per-device health score broken down by category, as computed by
/// `compute_and_store_score`. Category maxima: status 15, encryption 20,
/// load 20, alerts 15, compliance 10, patch 20 (total 100).
struct HealthScoreResult {
    /// Overall total — the sum of the six category scores below.
    score: i32,
    /// Online/offline (max 15).
    status_score: i32,
    /// Disk encryption coverage (max 20).
    encryption_score: i32,
    /// CPU/memory/disk load (max 20).
    load_score: i32,
    /// Unhandled alert clearance (max 15).
    alert_score: i32,
    /// Recent software-violation compliance (max 10).
    compliance_score: i32,
    /// Patch status (max 20; currently a placeholder — see compute function).
    patch_score: i32,
    /// "healthy" / "warning" / "critical" / "unknown".
    level: String,
    /// JSON-encoded array of human-readable deduction reasons.
    details: String,
}
/// Compute one device's weighted health score (out of 100) across six
/// categories and upsert it into `device_health_scores`.
///
/// Deductions append a human-readable (Chinese) reason to `details`, which is
/// stored as a JSON array string. Per-category queries fall back to 0 on error
/// rather than failing the whole computation.
async fn compute_and_store_score(
    pool: &sqlx::SqlitePool,
    device_uid: &str,
) -> anyhow::Result<HealthScoreResult> {
    let mut details = Vec::new();
    // 1. Online status (15 points)
    // fetch_one errors if the device row is missing; unwrap_or(0) treats that as offline.
    let status_score: i32 = sqlx::query_scalar(
        "SELECT CASE WHEN status = 'online' THEN 15 ELSE 0 END FROM devices WHERE device_uid = ?"
    )
    .bind(device_uid)
    .fetch_one(pool)
    .await
    .unwrap_or(0);
    if status_score < 15 {
        details.push("设备离线".to_string());
    }
    // 2. Disk encryption (20 points)
    // 20 = every reported volume protected; 10 = partial coverage OR no data
    // reported at all; 0 = volumes reported but none protected.
    let encryption_score: i32 = sqlx::query_scalar(
        "SELECT CASE \
         WHEN COUNT(*) = 0 THEN 10 \
         WHEN SUM(CASE WHEN protection_status = 'On' THEN 1 ELSE 0 END) = COUNT(*) THEN 20 \
         WHEN SUM(CASE WHEN protection_status = 'On' THEN 1 ELSE 0 END) > 0 THEN 10 \
         ELSE 0 END \
         FROM disk_encryption_status WHERE device_uid = ?"
    )
    .bind(device_uid)
    .fetch_one(pool)
    .await
    .unwrap_or(0);
    if encryption_score < 20 {
        let unencrypted: Vec<String> = sqlx::query_scalar(
            "SELECT drive_letter FROM disk_encryption_status WHERE device_uid = ? AND protection_status != 'On'"
        )
        .bind(device_uid)
        .fetch_all(pool)
        .await
        .unwrap_or_default();
        // NOTE(review): `&& encryption_score < 20` below is redundant — it is
        // always true inside this branch. Kept as-is to preserve bytes.
        if unencrypted.is_empty() && encryption_score < 20 {
            details.push("未检测到加密状态".to_string());
        } else if !unencrypted.is_empty() {
            details.push(format!("未加密驱动器: {}", unencrypted.join(", ")));
        }
    }
    // 3. System load (20 points): CPU(7) + Memory(7) + Disk(6)
    // Thresholds: full points below the soft limit, half near the hard limit, 0 above it.
    let load_row = sqlx::query(
        "SELECT cpu_usage, memory_usage, disk_usage FROM device_status WHERE device_uid = ?"
    )
    .bind(device_uid)
    .fetch_optional(pool)
    .await?;
    let load_score = if let Some(row) = load_row {
        let cpu = row.get::<f64, _>("cpu_usage");
        let mem = row.get::<f64, _>("memory_usage");
        let disk = row.get::<f64, _>("disk_usage");
        let cpu_pts = if cpu < 70.0 { 7 } else if cpu < 90.0 { 4 } else { 0 };
        let mem_pts = if mem < 80.0 { 7 } else if mem < 95.0 { 4 } else { 0 };
        let disk_pts = if disk < 80.0 { 6 } else if disk < 95.0 { 3 } else { 0 };
        let total = cpu_pts + mem_pts + disk_pts;
        if cpu >= 90.0 { details.push(format!("CPU过高 ({:.0}%)", cpu)); }
        if mem >= 95.0 { details.push(format!("内存过高 ({:.0}%)", mem)); }
        if disk >= 95.0 { details.push(format!("磁盘空间不足 ({:.0}%)", disk)); }
        total
    } else {
        details.push("无状态数据".to_string());
        0
    };
    // 4. Alert clearance (15 points) — all-or-nothing on unhandled alerts.
    let unhandled_alerts: i64 = sqlx::query_scalar(
        "SELECT COUNT(*) FROM alert_records WHERE device_uid = ? AND handled = 0"
    )
    .bind(device_uid)
    .fetch_one(pool)
    .await
    .unwrap_or(0);
    let alert_score: i32 = if unhandled_alerts == 0 { 15 } else { 0 };
    if unhandled_alerts > 0 {
        details.push(format!("{}条未处理告警", unhandled_alerts));
    }
    // 5. Compliance (10 points): no recent software violations
    // Deducts one point per violation in the last 7 days, floored at 0.
    let recent_violations: i64 = sqlx::query_scalar(
        "SELECT COUNT(*) FROM software_violations WHERE device_uid = ? AND timestamp > datetime('now', '-7 days')"
    )
    .bind(device_uid)
    .fetch_one(pool)
    .await
    .unwrap_or(0);
    let compliance_score: i32 = if recent_violations == 0 { 10 } else {
        details.push(format!("近期{}次软件违规", recent_violations));
        (10 - (recent_violations as i32).min(10)).max(0)
    };
    // 6. Patch status (20 points): reserved for future patch management
    // For now, give full score if device is online
    let patch_score: i32 = if status_score > 0 { 20 } else { 10 };
    let score = status_score + encryption_score + load_score + alert_score + compliance_score + patch_score;
    // NOTE(review): patch_score is always >= 10, so `score > 0` always holds
    // and the "unknown" level is effectively unreachable here.
    let level = if score >= 80 {
        "healthy"
    } else if score >= 50 {
        "warning"
    } else if score > 0 {
        "critical"
    } else {
        "unknown"
    };
    let details_json = if details.is_empty() {
        "[]".to_string()
    } else {
        serde_json::to_string(&details).unwrap_or_else(|_| "[]".to_string())
    };
    let result = HealthScoreResult {
        score,
        status_score,
        encryption_score,
        load_score,
        alert_score,
        compliance_score,
        patch_score,
        level: level.to_string(),
        details: details_json,
    };
    // Upsert the score — one cached row per device, refreshed in place.
    sqlx::query(
        "INSERT INTO device_health_scores \
         (device_uid, score, status_score, encryption_score, load_score, alert_score, compliance_score, patch_score, level, details, computed_at) \
         VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, datetime('now')) \
         ON CONFLICT(device_uid) DO UPDATE SET \
         score = excluded.score, status_score = excluded.status_score, \
         encryption_score = excluded.encryption_score, load_score = excluded.load_score, \
         alert_score = excluded.alert_score, compliance_score = excluded.compliance_score, \
         patch_score = excluded.patch_score, level = excluded.level, \
         details = excluded.details, computed_at = datetime('now')"
    )
    .bind(device_uid)
    .bind(result.score)
    .bind(result.status_score)
    .bind(result.encryption_score)
    .bind(result.load_score)
    .bind(result.alert_score)
    .bind(result.compliance_score)
    .bind(result.patch_score)
    .bind(&result.level)
    .bind(&result.details)
    .execute(pool)
    .await?;
    Ok(result)
}
/// Fetch a single device's most recently computed (cached) health score.
///
/// NOTE(review): despite the original "compute on demand" wording, this only
/// reads the row written by the background task — it does not recompute.
/// Returns `Ok(None)` when no score has been computed for the device yet.
pub async fn get_device_score(
    pool: &sqlx::SqlitePool,
    device_uid: &str,
) -> anyhow::Result<Option<serde_json::Value>> {
    // Try to get cached score
    let row = sqlx::query(
        "SELECT score, status_score, encryption_score, load_score, alert_score, compliance_score, \
         patch_score, level, details, computed_at \
         FROM device_health_scores WHERE device_uid = ?"
    )
    .bind(device_uid)
    .fetch_optional(pool)
    .await?;
    match row {
        Some(r) => Ok(Some(serde_json::json!({
            "device_uid": device_uid,
            "score": r.get::<i32, _>("score"),
            "breakdown": {
                "status": r.get::<i32, _>("status_score"),
                "encryption": r.get::<i32, _>("encryption_score"),
                "load": r.get::<i32, _>("load_score"),
                "alerts": r.get::<i32, _>("alert_score"),
                "compliance": r.get::<i32, _>("compliance_score"),
                "patches": r.get::<i32, _>("patch_score"),
            },
            "level": r.get::<String, _>("level"),
            // `details` is stored as a JSON array string; fall back to [] on parse failure.
            "details": serde_json::from_str::<serde_json::Value>(
                &r.get::<String, _>("details")
            ).unwrap_or(serde_json::json!([])),
            "computed_at": r.get::<String, _>("computed_at"),
        }))),
        None => Ok(None),
    }
}
/// Get health overview for all devices (dashboard aggregation)
///
/// Returns per-level counts, the average score, and a per-device list ordered
/// worst-score-first. Fix: the previous version reported `"total"` as
/// `devices.len().max(1)`, so an empty fleet showed `total: 1`; the `max(1)`
/// guard belongs only to the average-score division.
pub async fn get_health_overview(pool: &sqlx::SqlitePool) -> anyhow::Result<serde_json::Value> {
    let rows = sqlx::query(
        "SELECT h.device_uid, h.score, h.level, d.hostname, d.status, d.group_name \
         FROM device_health_scores h \
         JOIN devices d ON d.device_uid = h.device_uid \
         ORDER BY h.score ASC"
    )
    .fetch_all(pool)
    .await?;
    let mut healthy = 0u32;
    let mut warning = 0u32;
    let mut critical = 0u32;
    let mut unknown = 0u32;
    let mut total_score = 0i64;
    let mut devices: Vec<serde_json::Value> = Vec::with_capacity(rows.len());
    for r in &rows {
        let level: String = r.get("level");
        match level.as_str() {
            "healthy" => healthy += 1,
            "warning" => warning += 1,
            "critical" => critical += 1,
            _ => unknown += 1,
        }
        total_score += r.get::<i32, _>("score") as i64;
        devices.push(serde_json::json!({
            "device_uid": r.get::<String, _>("device_uid"),
            "hostname": r.get::<String, _>("hostname"),
            "status": r.get::<String, _>("status"),
            "group_name": r.get::<String, _>("group_name"),
            "score": r.get::<i32, _>("score"),
            "level": level,
        }));
    }
    // Report the true device count; only guard the average against divide-by-zero.
    let total = devices.len();
    let avg_score = if total == 0 {
        0.0
    } else {
        total_score as f64 / total as f64
    };
    Ok(serde_json::json!({
        "summary": {
            "total": total,
            "healthy": healthy,
            "warning": warning,
            "critical": critical,
            "unknown": unknown,
            "avg_score": (avg_score * 10.0).round() / 10.0,
        },
        "devices": devices,
    }))
}

View File

@@ -7,6 +7,7 @@ use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions, SqliteJournalMode};
use std::path::Path;
use std::str::FromStr;
use std::sync::Arc;
use std::collections::HashMap;
use tokio::net::TcpListener;
use axum::http::Method as HttpMethod;
use tower_http::cors::CorsLayer;
@@ -23,6 +24,8 @@ mod db;
mod tcp;
mod ws;
mod alert;
mod health;
mod anomaly;
use config::AppConfig;
@@ -38,6 +41,7 @@ pub struct AppState {
pub clients: Arc<tcp::ClientRegistry>,
pub ws_hub: Arc<ws::WsHub>,
pub login_limiter: Arc<api::auth::LoginRateLimiter>,
pub ws_tickets: Arc<tokio::sync::Mutex<HashMap<String, ws::TicketClaim>>>,
}
#[tokio::main]
@@ -58,15 +62,19 @@ async fn main() -> Result<()> {
// Security checks
if config.registration_token.is_empty() {
warn!("SECURITY: registration_token is empty — any device can register!");
anyhow::bail!("FATAL: registration_token is empty. Set it in config.toml or via CSM_REGISTRATION_TOKEN env var. Device registration is disabled for security.");
}
if config.auth.jwt_secret.len() < 32 {
warn!("SECURITY: jwt_secret is too short ({} chars) — consider using a 32+ byte key from CSM_JWT_SECRET env var", config.auth.jwt_secret.len());
if config.auth.jwt_secret.is_empty() || config.auth.jwt_secret.len() < 32 {
anyhow::bail!("FATAL: jwt_secret is missing or too short. Set CSM_JWT_SECRET env var with a 32+ byte random key.");
}
if config.server.tls.is_none() {
warn!("SECURITY: No TLS configured — all TCP communication is plaintext. Configure [server.tls] for production.");
if std::env::var("CSM_DEV").is_err() {
warn!("Set CSM_DEV=1 to suppress this warning in development environments.");
}
}
let config = Arc::new(config);
// Initialize database
let db = init_database(&config.database.path).await?;
run_migrations(&db).await?;
info!("Database initialized at {}", config.database.path);
@@ -84,6 +92,7 @@ async fn main() -> Result<()> {
clients: clients.clone(),
ws_hub: ws_hub.clone(),
login_limiter: Arc::new(api::auth::LoginRateLimiter::new()),
ws_tickets: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
};
// Start background tasks
@@ -92,6 +101,12 @@ async fn main() -> Result<()> {
alert::cleanup_task(cleanup_state).await;
});
// Health score computation task
let health_state = state.clone();
tokio::spawn(async move {
health::health_score_task(health_state).await;
});
// Start TCP listener for client connections
let tcp_state = state.clone();
let tcp_addr = config.server.tcp_addr.clone();
@@ -131,7 +146,11 @@ async fn main() -> Result<()> {
))
.layer(SetResponseHeaderLayer::if_not_present(
axum::http::header::HeaderName::from_static("content-security-policy"),
axum::http::HeaderValue::from_static("default-src 'self'; script-src 'self' 'unsafe-inline' 'unsafe-eval'; style-src 'self' 'unsafe-inline'; img-src 'self' data:; connect-src 'self' ws: wss:"),
axum::http::HeaderValue::from_static("default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline'; img-src 'self' data:; connect-src 'self' wss:; font-src 'self'; frame-ancestors 'none'; base-uri 'self'; form-action 'self'"),
))
.layer(SetResponseHeaderLayer::if_not_present(
axum::http::header::HeaderName::from_static("permissions-policy"),
axum::http::HeaderValue::from_static("camera=(), microphone=(), geolocation=(), payment=()"),
))
.with_state(state);
@@ -197,6 +216,9 @@ async fn run_migrations(pool: &sqlx::SqlitePool) -> Result<()> {
include_str!("../../../migrations/014_clipboard_control.sql"),
include_str!("../../../migrations/015_plugin_control.sql"),
include_str!("../../../migrations/016_encryption_alerts_unique.sql"),
include_str!("../../../migrations/017_device_health_scores.sql"),
include_str!("../../../migrations/018_patch_management.sql"),
include_str!("../../../migrations/019_software_whitelist.sql"),
];
// Create migrations tracking table
@@ -257,11 +279,27 @@ async fn ensure_default_admin(pool: &sqlx::SqlitePool) -> Result<()> {
.await?;
warn!("Created default admin user (username: admin)");
// Print password directly to stderr — bypasses tracing JSON formatter
eprintln!("============================================================");
eprintln!(" Generated admin password: {}", random_password);
eprintln!(" *** Save this password now — it will NOT be shown again! ***");
eprintln!("============================================================");
// Write password to restricted file instead of stderr (avoid log capture)
let pw_path = std::path::Path::new("data/initial-password.txt");
if let Some(parent) = pw_path.parent() {
let _ = std::fs::create_dir_all(parent);
}
match std::fs::write(pw_path, &random_password) {
Ok(_) => {
warn!("Initial admin password saved to data/initial-password.txt (delete after first login)");
#[cfg(unix)]
{
use std::os::unix::fs::PermissionsExt;
let _ = std::fs::set_permissions(pw_path, std::fs::Permissions::from_mode(0o600));
}
#[cfg(not(unix))]
{
// Windows: restrict ACL would require windows-rs; at minimum hide the file
let _ = std::process::Command::new("attrib").args(["+H", &pw_path.to_string_lossy()]).output();
}
}
Err(e) => warn!("Failed to save initial password to file: {}. Password was: {}", e, random_password),
}
}
Ok(())
@@ -278,13 +316,14 @@ fn build_cors_layer(origins: &[String]) -> CorsLayer {
.collect();
if allowed_origins.is_empty() {
// No CORS — production safe by default
// No CORS — production safe by default (same-origin cookies work without CORS)
CorsLayer::new()
} else {
CorsLayer::new()
.allow_origin(tower_http::cors::AllowOrigin::list(allowed_origins))
.allow_methods([HttpMethod::GET, HttpMethod::POST, HttpMethod::PUT, HttpMethod::DELETE])
.allow_headers([axum::http::header::AUTHORIZATION, axum::http::header::CONTENT_TYPE])
.allow_credentials(true)
.max_age(std::time::Duration::from_secs(3600))
}
}

View File

@@ -3,7 +3,7 @@ use std::sync::Arc;
use std::sync::atomic::{AtomicU32, Ordering};
use std::time::Instant;
use tokio::sync::RwLock;
use tokio::net::{TcpListener, TcpStream};
use tokio::net::TcpListener;
use tracing::{info, warn, debug};
use hmac::{Hmac, Mac};
use sha2::Sha256;
@@ -167,7 +167,7 @@ pub async fn push_all_plugin_configs(
}
}
// Software blacklist
// Software blacklist + whitelist
if let Ok(rows) = sqlx::query(
"SELECT id, name_pattern, action FROM software_blacklist WHERE enabled = 1 AND (target_type = 'global' OR (target_type = 'device' AND target_id = ?) OR (target_type = 'group' AND target_id = (SELECT group_name FROM devices WHERE device_uid = ?)))"
)
@@ -180,8 +180,20 @@ pub async fn push_all_plugin_configs(
"name_pattern": r.get::<String, _>("name_pattern"),
"action": r.get::<String, _>("action"),
})).collect();
if !entries.is_empty() {
if let Ok(frame) = Frame::new_json(MessageType::SoftwareBlacklist, &serde_json::json!({"blacklist": entries})) {
// Fetch whitelist (global, always pushed to all devices)
let whitelist: Vec<String> = sqlx::query_scalar(
"SELECT name_pattern FROM software_whitelist WHERE enabled = 1"
)
.fetch_all(db)
.await
.unwrap_or_default();
if !entries.is_empty() || !whitelist.is_empty() {
if let Ok(frame) = Frame::new_json(MessageType::SoftwareBlacklist, &serde_json::json!({
"blacklist": entries,
"whitelist": whitelist,
})) {
clients.send_to(device_uid, frame.encode()).await;
}
}
@@ -261,17 +273,53 @@ pub async fn push_all_plugin_configs(
}
}
// Disk encryption config — push default reporting interval (no dedicated config table)
// Disk encryption config — read from patch_policies if available, else default
{
let config = csm_protocol::DiskEncryptionConfigPayload {
enabled: true,
report_interval_secs: 3600,
let config = if let Ok(Some(row)) = sqlx::query(
"SELECT auto_approve, enabled FROM patch_policies WHERE target_type = 'global' AND enabled = 1 LIMIT 1"
)
.fetch_optional(db)
.await
{
// If patch_policies exist, infer disk encryption should be enabled
csm_protocol::DiskEncryptionConfigPayload {
enabled: row.get::<i32, _>("enabled") != 0,
report_interval_secs: 3600,
}
} else {
csm_protocol::DiskEncryptionConfigPayload {
enabled: true,
report_interval_secs: 3600,
}
};
if let Ok(frame) = Frame::new_json(MessageType::DiskEncryptionConfig, &config) {
clients.send_to(device_uid, frame.encode()).await;
}
}
// Patch scan config — read from patch_policies if available, else default
{
let config = if let Ok(Some(row)) = sqlx::query(
"SELECT enabled FROM patch_policies WHERE target_type = 'global' AND enabled = 1 LIMIT 1"
)
.fetch_optional(db)
.await
{
csm_protocol::PatchScanConfigPayload {
enabled: row.get::<i32, _>("enabled") != 0,
scan_interval_secs: 43200,
}
} else {
csm_protocol::PatchScanConfigPayload {
enabled: true,
scan_interval_secs: 43200,
}
};
if let Ok(frame) = Frame::new_json(MessageType::PatchScanConfig, &config) {
clients.send_to(device_uid, frame.encode()).await;
}
}
// Push plugin enable/disable state — disable any plugins that admin has turned off
if let Ok(rows) = sqlx::query(
"SELECT plugin_name FROM plugin_state WHERE enabled = 0"
@@ -297,10 +345,17 @@ pub async fn push_all_plugin_configs(
/// Maximum accumulated read buffer size per connection (8 MB)
const MAX_READ_BUF_SIZE: usize = 8 * 1024 * 1024;
/// Registry of all connected client sessions
/// Registry of all connected client sessions, including cached device secrets.
#[derive(Clone, Default)]
pub struct ClientRegistry {
sessions: Arc<RwLock<HashMap<String, Arc<tokio::sync::mpsc::Sender<Vec<u8>>>>>>,
sessions: Arc<RwLock<HashMap<String, ClientSession>>>,
}
/// Per-device session data kept in memory for fast access.
struct ClientSession {
tx: Arc<tokio::sync::mpsc::Sender<Vec<u8>>>,
/// Cached device_secret for HMAC verification — avoids a DB query per heartbeat.
secret: Option<String>,
}
impl ClientRegistry {
@@ -308,8 +363,8 @@ impl ClientRegistry {
Self::default()
}
pub async fn register(&self, device_uid: String, tx: Arc<tokio::sync::mpsc::Sender<Vec<u8>>>) {
self.sessions.write().await.insert(device_uid, tx);
pub async fn register(&self, device_uid: String, secret: Option<String>, tx: Arc<tokio::sync::mpsc::Sender<Vec<u8>>>) {
self.sessions.write().await.insert(device_uid, ClientSession { tx, secret });
}
pub async fn unregister(&self, device_uid: &str) {
@@ -317,13 +372,25 @@ impl ClientRegistry {
}
pub async fn send_to(&self, device_uid: &str, data: Vec<u8>) -> bool {
if let Some(tx) = self.sessions.read().await.get(device_uid) {
tx.send(data).await.is_ok()
if let Some(session) = self.sessions.read().await.get(device_uid) {
session.tx.send(data).await.is_ok()
} else {
false
}
}
/// Look up the device secret cached in the session registry at registration
/// time, sparing a database round-trip on every heartbeat HMAC check.
pub async fn get_secret(&self, device_uid: &str) -> Option<String> {
    let sessions = self.sessions.read().await;
    sessions.get(device_uid)?.secret.clone()
}
/// Backfill the cached secret after a cache miss (e.g. the server restarted
/// while the client session persisted). No-op when the device has no session.
pub async fn set_secret(&self, device_uid: &str, secret: String) {
    let mut sessions = self.sessions.write().await;
    if let Some(session) = sessions.get_mut(device_uid) {
        session.secret = Some(secret);
    }
}
/// Number of currently connected client sessions.
pub async fn count(&self) -> usize {
    self.sessions.read().await.len()
}
@@ -366,7 +433,7 @@ pub async fn start_tcp_server(addr: String, state: AppState) -> anyhow::Result<(
Some(acceptor) => {
match acceptor.accept(stream).await {
Ok(tls_stream) => {
if let Err(e) = handle_client_tls(tls_stream, state).await {
if let Err(e) = handle_client(tls_stream, state).await {
warn!("Client {} TLS error: {}", peer_addr, e);
}
}
@@ -451,6 +518,17 @@ fn verify_device_uid(device_uid: &Option<String>, msg_type: &str, claimed_uid: &
}
}
/// Constant-time string comparison to prevent timing attacks on secrets.
///
/// Examines every byte position up to the longer input's length regardless of
/// where the first mismatch occurs, so running time does not reveal the length
/// of the matching prefix. Fix: the previous length-mismatch branch built a
/// lazy `zip(...).map(...)` iterator that was never consumed, so it performed
/// no work at all; this version folds over real bytes in all cases, seeding
/// the mismatch accumulator with the length difference.
fn constant_time_eq(a: &str, b: &str) -> bool {
    let (a, b) = (a.as_bytes(), b.as_bytes());
    // A length difference is itself a mismatch bit.
    let mut diff = u8::from(a.len() != b.len());
    for i in 0..a.len().max(b.len()) {
        // Out-of-range positions compare as 0 so both inputs drive equal work.
        let x = a.get(i).copied().unwrap_or(0);
        let y = b.get(i).copied().unwrap_or(0);
        diff |= x ^ y;
    }
    diff == 0
}
/// Process a single decoded frame. Shared by both plaintext and TLS handlers.
/// `hmac_fail_count` tracks consecutive HMAC failures; caller checks it for disconnect threshold.
async fn process_frame(
@@ -467,10 +545,10 @@ async fn process_frame(
info!("Device registration attempt: {} ({})", req.hostname, req.device_uid);
// Validate registration token against configured token
// Validate registration token against configured token (constant-time comparison)
let expected_token = &state.config.registration_token;
if !expected_token.is_empty() {
if req.registration_token.is_empty() || req.registration_token != *expected_token {
if req.registration_token.is_empty() || !constant_time_eq(&req.registration_token, expected_token) {
warn!("Registration rejected for {}: invalid token", req.device_uid);
let err_frame = Frame::new_json(MessageType::RegisterResponse,
&serde_json::json!({"error": "invalid_registration_token"}))?;
@@ -514,7 +592,7 @@ async fn process_frame(
*device_uid = Some(req.device_uid.clone());
// If this device was already connected on a different session, evict the old one
// The new register() call will replace it in the hashmap
state.clients.register(req.device_uid.clone(), tx.clone()).await;
state.clients.register(req.device_uid.clone(), Some(device_secret.clone()), tx.clone()).await;
// Send registration response
let config = csm_protocol::ClientConfig::default();
@@ -539,17 +617,25 @@ async fn process_frame(
return Ok(());
}
// Verify HMAC — reject if secret exists but HMAC is missing or wrong
let secret: Option<String> = sqlx::query_scalar(
"SELECT device_secret FROM devices WHERE device_uid = ?"
)
.bind(&heartbeat.device_uid)
.fetch_optional(&state.db)
.await
.map_err(|e| {
warn!("DB error fetching device_secret for {}: {}", heartbeat.device_uid, e);
anyhow::anyhow!("DB error during HMAC verification")
})?;
// Verify HMAC — use cached secret from ClientRegistry, fall back to DB on cache miss (e.g. after restart)
let mut secret = state.clients.get_secret(&heartbeat.device_uid).await;
if secret.is_none() {
// Cache miss (server restarted) — query DB and backfill cache
let db_secret: Option<String> = sqlx::query_scalar(
"SELECT device_secret FROM devices WHERE device_uid = ?"
)
.bind(&heartbeat.device_uid)
.fetch_optional(&state.db)
.await
.map_err(|e| {
warn!("DB error fetching device_secret for {}: {}", heartbeat.device_uid, e);
anyhow::anyhow!("DB error during HMAC verification")
})?;
if let Some(ref s) = db_secret {
state.clients.set_secret(&heartbeat.device_uid, s.clone()).await;
}
secret = db_secret;
}
if let Some(ref secret) = secret {
if !secret.is_empty() {
@@ -650,6 +736,40 @@ async fn process_frame(
crate::db::DeviceRepo::upsert_software(&state.db, &sw).await?;
}
MessageType::AssetChange => {
let change: csm_protocol::AssetChange = frame.decode_payload()
.map_err(|e| anyhow::anyhow!("Invalid asset change: {}", e))?;
if !verify_device_uid(device_uid, "AssetChange", &change.device_uid) {
return Ok(());
}
let change_type_str = match change.change_type {
csm_protocol::AssetChangeType::Hardware => "hardware",
csm_protocol::AssetChangeType::SoftwareAdded => "software_added",
csm_protocol::AssetChangeType::SoftwareRemoved => "software_removed",
};
sqlx::query(
"INSERT INTO asset_changes (device_uid, change_type, change_detail, detected_at) \
VALUES (?, ?, ?, datetime('now'))"
)
.bind(&change.device_uid)
.bind(change_type_str)
.bind(serde_json::to_string(&change.change_detail).map_err(|e| anyhow::anyhow!("Failed to serialize asset change detail: {}", e))?)
.execute(&state.db)
.await
.map_err(|e| anyhow::anyhow!("DB error inserting asset change: {}", e))?;
debug!("Asset change: {} {:?} for device {}", change_type_str, change.change_detail, change.device_uid);
state.ws_hub.broadcast(serde_json::json!({
"type": "asset_change",
"device_uid": change.device_uid,
"change_type": change_type_str,
}).to_string()).await;
}
MessageType::UsageReport => {
let report: csm_protocol::UsageDailyReport = frame.decode_payload()
.map_err(|e| anyhow::anyhow!("Invalid usage report: {}", e))?;
@@ -910,23 +1030,86 @@ async fn process_frame(
return Ok(());
}
for rule_stat in &stats.rule_stats {
sqlx::query(
"INSERT INTO popup_block_stats (device_uid, rule_id, blocked_count, period_secs, reported_at) \
VALUES (?, ?, ?, ?, datetime('now'))"
)
.bind(&stats.device_uid)
.bind(rule_stat.rule_id)
.bind(rule_stat.hits as i32)
.bind(stats.period_secs as i32)
.execute(&state.db)
.await
.ok();
}
// Upsert aggregate stats per device per day
let today = chrono::Utc::now().format("%Y-%m-%d").to_string();
sqlx::query(
"INSERT INTO popup_block_stats (device_uid, blocked_count, date) \
VALUES (?, ?, ?) \
ON CONFLICT(device_uid, date) DO UPDATE SET \
blocked_count = blocked_count + excluded.blocked_count"
)
.bind(&stats.device_uid)
.bind(stats.blocked_count as i32)
.bind(&today)
.execute(&state.db)
.await
.ok();
debug!("Popup block stats: {} blocked {} windows in {}s", stats.device_uid, stats.blocked_count, stats.period_secs);
}
MessageType::PatchStatusReport => {
let payload: csm_protocol::PatchStatusPayload = frame.decode_payload()
.map_err(|e| anyhow::anyhow!("Invalid patch status report: {}", e))?;
if !verify_device_uid(device_uid, "PatchStatusReport", &payload.device_uid) {
return Ok(());
}
for patch in &payload.patches {
sqlx::query(
"INSERT INTO patch_status (device_uid, kb_id, title, severity, is_installed, installed_at, updated_at) \
VALUES (?, ?, ?, ?, ?, ?, datetime('now')) \
ON CONFLICT(device_uid, kb_id) DO UPDATE SET \
title = excluded.title, severity = COALESCE(excluded.severity, patch_status.severity), \
is_installed = excluded.is_installed, installed_at = excluded.installed_at, \
updated_at = datetime('now')"
)
.bind(&payload.device_uid)
.bind(&patch.kb_id)
.bind(&patch.title)
.bind(&patch.severity)
.bind(patch.is_installed as i32)
.bind(&patch.installed_at)
.execute(&state.db)
.await
.map_err(|e| anyhow::anyhow!("DB error inserting patch status: {}", e))?;
}
info!("Patch status reported: {} ({} patches)", payload.device_uid, payload.patches.len());
}
MessageType::BehaviorMetricsReport => {
let metrics: csm_protocol::BehaviorMetricsPayload = frame.decode_payload()
.map_err(|e| anyhow::anyhow!("Invalid behavior metrics: {}", e))?;
if !verify_device_uid(device_uid, "BehaviorMetricsReport", &metrics.device_uid) {
return Ok(());
}
sqlx::query(
"INSERT INTO behavior_metrics (device_uid, clipboard_ops_count, clipboard_ops_night, print_jobs_count, usb_file_ops_count, new_processes_count, period_secs, reported_at) \
VALUES (?, ?, ?, ?, ?, ?, ?, datetime('now'))"
)
.bind(&metrics.device_uid)
.bind(metrics.clipboard_ops_count as i32)
.bind(metrics.clipboard_ops_night as i32)
.bind(metrics.print_jobs_count as i32)
.bind(metrics.usb_file_ops_count as i32)
.bind(metrics.new_processes_count as i32)
.bind(metrics.period_secs as i32)
.execute(&state.db)
.await
.map_err(|e| anyhow::anyhow!("DB error inserting behavior metrics: {}", e))?;
// Run anomaly detection inline
crate::anomaly::check_anomalies(&state.db, &state.ws_hub, &metrics).await;
debug!("Behavior metrics saved: {} (clipboard={}, print={}, usb_file={}, procs={})",
metrics.device_uid, metrics.clipboard_ops_count, metrics.print_jobs_count,
metrics.usb_file_ops_count, metrics.new_processes_count);
}
_ => {
debug!("Unhandled message type: {:?}", frame.msg_type);
}
@@ -935,13 +1118,14 @@ async fn process_frame(
Ok(())
}
/// Handle a single client TCP connection
async fn handle_client(stream: TcpStream, state: AppState) -> anyhow::Result<()> {
/// Handle a single client TCP connection (plaintext or TLS)
async fn handle_client<S>(stream: S, state: AppState) -> anyhow::Result<()>
where
S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
{
use tokio::io::{AsyncReadExt, AsyncWriteExt};
let _ = stream.set_nodelay(true);
let (mut reader, mut writer) = stream.into_split();
let (mut reader, mut writer) = tokio::io::split(stream);
let (tx, mut rx) = tokio::sync::mpsc::channel::<Vec<u8>>(256);
let tx = Arc::new(tx);
@@ -1018,81 +1202,50 @@ async fn handle_client(stream: TcpStream, state: AppState) -> anyhow::Result<()>
Ok(())
}
/// Handle a TLS-wrapped client connection
async fn handle_client_tls(
stream: tokio_rustls::server::TlsStream<TcpStream>,
state: AppState,
) -> anyhow::Result<()> {
use tokio::io::{AsyncReadExt, AsyncWriteExt};
let (mut reader, mut writer) = tokio::io::split(stream);
let (tx, mut rx) = tokio::sync::mpsc::channel::<Vec<u8>>(256);
let tx = Arc::new(tx);
let mut buffer = vec![0u8; 65536];
let mut read_buf = Vec::with_capacity(65536);
let mut device_uid: Option<String> = None;
let mut rate_limiter = RateLimiter::new();
let hmac_fail_count = Arc::new(AtomicU32::new(0));
let write_task = tokio::spawn(async move {
while let Some(data) = rx.recv().await {
if writer.write_all(&data).await.is_err() {
break;
}
/// Push a TLS certificate rotation notice to all online devices.
/// Computes the fingerprint of the new certificate and sends ConfigUpdate(TlsCertRotate).
pub async fn push_tls_cert_rotation(clients: &ClientRegistry, new_cert_pem: &[u8], valid_until: &str) -> usize {
// Compute SHA-256 fingerprint of the new certificate
let certs: Vec<_> = match rustls_pemfile::certs(&mut &new_cert_pem[..]).collect::<Result<Vec<_>, _>>() {
Ok(c) => c,
Err(e) => {
warn!("Failed to parse new certificate for rotation: {:?}", e);
return 0;
}
});
};
// Reader loop with idle timeout
'reader: loop {
let read_result = tokio::time::timeout(
std::time::Duration::from_secs(IDLE_TIMEOUT_SECS),
reader.read(&mut buffer),
).await;
let end_entity = match certs.first() {
Some(c) => c,
None => {
warn!("No certificates found in PEM for rotation");
return 0;
}
};
let n = match read_result {
Ok(Ok(0)) => break,
Ok(Ok(n)) => n,
Ok(Err(e)) => return Err(e.into()),
Err(_) => {
warn!("Idle timeout for TLS device {:?}, disconnecting", device_uid);
break;
}
let fingerprint = {
use sha2::{Sha256, Digest};
let mut hasher = Sha256::new();
hasher.update(end_entity.as_ref());
hex::encode(hasher.finalize())
};
info!("Pushing TLS cert rotation: new fingerprint={}... valid_until={}", &fingerprint[..16], valid_until);
let config_update = csm_protocol::ConfigUpdateType::TlsCertRotate {
new_cert_hash: fingerprint,
valid_until: valid_until.to_string(),
};
let online = clients.list_online().await;
let mut pushed = 0usize;
for uid in &online {
let frame = match Frame::new_json(MessageType::ConfigUpdate, &config_update) {
Ok(f) => f,
Err(_) => continue,
};
read_buf.extend_from_slice(&buffer[..n]);
if read_buf.len() > MAX_READ_BUF_SIZE {
warn!("TLS connection exceeded max buffer size, dropping");
break;
}
while let Some(frame) = Frame::decode(&read_buf)? {
let frame_size = frame.encoded_size();
read_buf.drain(..frame_size);
if frame.version != PROTOCOL_VERSION {
warn!("Unsupported protocol version: 0x{:02X}", frame.version);
continue;
}
if !rate_limiter.check() {
warn!("Rate limit exceeded for TLS device {:?}, dropping connection", device_uid);
break 'reader;
}
if let Err(e) = process_frame(frame, &state, &mut device_uid, &tx, &hmac_fail_count).await {
warn!("Frame processing error: {}", e);
}
// Disconnect if too many consecutive HMAC failures
if hmac_fail_count.load(Ordering::Relaxed) >= MAX_HMAC_FAILURES {
warn!("Too many HMAC failures for TLS device {:?}, disconnecting", device_uid);
break 'reader;
}
if clients.send_to(uid, frame.encode()).await {
pushed += 1;
}
}
cleanup_on_disconnect(&state, &device_uid).await;
write_task.abort();
Ok(())
pushed
}

View File

@@ -1,12 +1,10 @@
use axum::extract::ws::{WebSocket, WebSocketUpgrade, Message};
use axum::response::IntoResponse;
use axum::extract::Query;
use jsonwebtoken::{decode, Validation, DecodingKey};
use serde::Deserialize;
use tokio::sync::broadcast;
use std::sync::Arc;
use tracing::{debug, warn};
use crate::api::auth::Claims;
use crate::AppState;
/// WebSocket hub for broadcasting real-time events to admin browsers
@@ -32,65 +30,73 @@ impl WsHub {
}
}
#[derive(Debug, Deserialize)]
pub struct WsAuthParams {
pub token: Option<String>,
/// Claim stored when a WS ticket is created. Consumed on WS connection.
#[derive(Debug, Clone)]
pub struct TicketClaim {
pub user_id: i64,
pub username: String,
pub role: String,
pub created_at: std::time::Instant,
}
/// HTTP upgrade handler for WebSocket connections
/// Validates JWT token from query parameter before upgrading
#[derive(Debug, Deserialize)]
pub struct WsTicketParams {
pub ticket: Option<String>,
}
/// HTTP upgrade handler for WebSocket connections.
/// Validates a one-time ticket (obtained via POST /api/ws/ticket) before upgrading.
pub async fn ws_handler(
ws: WebSocketUpgrade,
Query(params): Query<WsAuthParams>,
Query(params): Query<WsTicketParams>,
axum::extract::State(state): axum::extract::State<AppState>,
) -> impl IntoResponse {
let token = match params.token {
let ticket = match params.ticket {
Some(t) => t,
None => {
warn!("WebSocket connection rejected: no token provided");
return (axum::http::StatusCode::UNAUTHORIZED, "Missing token").into_response();
warn!("WebSocket connection rejected: no ticket provided");
return (axum::http::StatusCode::UNAUTHORIZED, "Missing ticket").into_response();
}
};
let claims = match decode::<Claims>(
&token,
&DecodingKey::from_secret(state.config.auth.jwt_secret.as_bytes()),
&Validation::default(),
) {
Ok(c) => c.claims,
Err(e) => {
warn!("WebSocket connection rejected: invalid token - {}", e);
return (axum::http::StatusCode::UNAUTHORIZED, "Invalid token").into_response();
// Consume (remove) the ticket from the store — single use
let claim = {
let mut tickets = state.ws_tickets.lock().await;
match tickets.remove(&ticket) {
Some(claim) => claim,
None => {
warn!("WebSocket connection rejected: invalid or expired ticket");
return (axum::http::StatusCode::UNAUTHORIZED, "Invalid or expired ticket").into_response();
}
}
};
if claims.token_type != "access" {
warn!("WebSocket connection rejected: not an access token");
return (axum::http::StatusCode::UNAUTHORIZED, "Invalid token type").into_response();
// Check ticket age (30 second TTL)
if claim.created_at.elapsed().as_secs() > 30 {
warn!("WebSocket connection rejected: ticket expired");
return (axum::http::StatusCode::UNAUTHORIZED, "Ticket expired").into_response();
}
let hub = state.ws_hub.clone();
ws.on_upgrade(move |socket| handle_socket(socket, claims, hub))
ws.on_upgrade(move |socket| handle_socket(socket, claim, hub))
}
async fn handle_socket(mut socket: WebSocket, claims: Claims, hub: Arc<WsHub>) {
debug!("WebSocket client connected: user={}", claims.username);
async fn handle_socket(mut socket: WebSocket, claim: TicketClaim, hub: Arc<WsHub>) {
debug!("WebSocket client connected: user={}", claim.username);
let welcome = serde_json::json!({
"type": "connected",
"message": "CSM real-time feed active",
"user": claims.username
"user": claim.username
});
if socket.send(Message::Text(welcome.to_string())).await.is_err() {
return;
}
// Subscribe to broadcast hub for real-time events
let mut rx = hub.subscribe();
loop {
tokio::select! {
// Forward broadcast messages to WebSocket client
msg = rx.recv() => {
match msg {
Ok(text) => {
@@ -104,7 +110,6 @@ async fn handle_socket(mut socket: WebSocket, claims: Claims, hub: Arc<WsHub>) {
Err(broadcast::error::RecvError::Closed) => break,
}
}
// Handle incoming WebSocket messages (ping/close)
msg = socket.recv() => {
match msg {
Some(Ok(Message::Ping(data))) => {
@@ -121,5 +126,5 @@ async fn handle_socket(mut socket: WebSocket, claims: Claims, hub: Arc<WsHub>) {
}
}
debug!("WebSocket client disconnected: user={}", claims.username);
debug!("WebSocket client disconnected: user={}", claim.username);
}

View File

@@ -0,0 +1,20 @@
-- 017_device_health_scores.sql: Device health scoring system
-- Holds the latest composite health snapshot for each device.
-- UNIQUE(device_uid) makes this one-row-per-device; presumably the scoring
-- job upserts rows on each run — TODO confirm against the writer code.
CREATE TABLE IF NOT EXISTS device_health_scores (
id INTEGER PRIMARY KEY AUTOINCREMENT,
-- FK to devices; the snapshot is removed automatically when the device is deleted
device_uid TEXT NOT NULL REFERENCES devices(device_uid) ON DELETE CASCADE,
-- Overall composite score, constrained to the 0..100 range
score INTEGER NOT NULL DEFAULT 0 CHECK(score >= 0 AND score <= 100),
-- Per-dimension sub-scores that make up the composite score
status_score INTEGER NOT NULL DEFAULT 0,
encryption_score INTEGER NOT NULL DEFAULT 0,
load_score INTEGER NOT NULL DEFAULT 0,
alert_score INTEGER NOT NULL DEFAULT 0,
compliance_score INTEGER NOT NULL DEFAULT 0,
patch_score INTEGER NOT NULL DEFAULT 0,
-- Discrete health level; CHECK restricts it to the four known states
level TEXT NOT NULL DEFAULT 'unknown' CHECK(level IN ('healthy', 'warning', 'critical', 'unknown')),
-- Optional free-form breakdown of how the score was computed
details TEXT,
computed_at TEXT NOT NULL DEFAULT (datetime('now')),
-- Enforces the one-snapshot-per-device invariant
UNIQUE(device_uid)
);
-- Support fleet-wide filtering by health level and by snapshot freshness
CREATE INDEX IF NOT EXISTS idx_health_scores_level ON device_health_scores(level);
CREATE INDEX IF NOT EXISTS idx_health_scores_computed ON device_health_scores(computed_at);

View File

@@ -0,0 +1,59 @@
-- 018_patch_management.sql: Patch management system
-- Per-device, per-KB patch state as reported by clients.
-- UNIQUE(device_uid, kb_id) lets the server upsert on each PatchStatusReport.
CREATE TABLE IF NOT EXISTS patch_status (
id INTEGER PRIMARY KEY AUTOINCREMENT,
device_uid TEXT NOT NULL REFERENCES devices(device_uid) ON DELETE CASCADE,
-- Windows KB identifier (e.g. "KB5034441")
kb_id TEXT NOT NULL,
title TEXT NOT NULL,
-- Nullable: severity may be unknown at discovery time
severity TEXT,
-- 0/1 boolean flag
is_installed INTEGER NOT NULL DEFAULT 0,
discovered_at TEXT NOT NULL DEFAULT (datetime('now')),
installed_at TEXT,
updated_at TEXT NOT NULL DEFAULT (datetime('now')),
UNIQUE(device_uid, kb_id)
);
-- Patch approval policies, scoped globally or to a device/group
CREATE TABLE IF NOT EXISTS patch_policies (
id INTEGER PRIMARY KEY AUTOINCREMENT,
target_type TEXT NOT NULL DEFAULT 'global' CHECK(target_type IN ('global', 'device', 'group')),
-- NULL for 'global'; device_uid or group id otherwise — TODO confirm format
target_id TEXT,
auto_approve INTEGER NOT NULL DEFAULT 0,
-- Minimum severity for auto-approval
severity_filter TEXT NOT NULL DEFAULT 'important',
enabled INTEGER NOT NULL DEFAULT 1,
created_at TEXT NOT NULL DEFAULT (datetime('now')),
updated_at TEXT NOT NULL DEFAULT (datetime('now'))
);
-- Behavior metrics for anomaly detection
-- Append-only time series: one row per device per reporting period.
CREATE TABLE IF NOT EXISTS behavior_metrics (
id INTEGER PRIMARY KEY AUTOINCREMENT,
device_uid TEXT NOT NULL REFERENCES devices(device_uid) ON DELETE CASCADE,
clipboard_ops_count INTEGER NOT NULL DEFAULT 0,
-- Subset of clipboard ops that happened during night hours
clipboard_ops_night INTEGER NOT NULL DEFAULT 0,
print_jobs_count INTEGER NOT NULL DEFAULT 0,
usb_file_ops_count INTEGER NOT NULL DEFAULT 0,
new_processes_count INTEGER NOT NULL DEFAULT 0,
-- Length of the reporting window in seconds (default one hour)
period_secs INTEGER NOT NULL DEFAULT 3600,
reported_at TEXT NOT NULL DEFAULT (datetime('now'))
);
-- Anomaly alerts generated by the detection engine
CREATE TABLE IF NOT EXISTS anomaly_alerts (
id INTEGER PRIMARY KEY AUTOINCREMENT,
device_uid TEXT NOT NULL REFERENCES devices(device_uid) ON DELETE CASCADE,
anomaly_type TEXT NOT NULL,
severity TEXT NOT NULL DEFAULT 'medium' CHECK(severity IN ('low', 'medium', 'high', 'critical')),
detail TEXT NOT NULL,
-- Observed value vs. the baseline it was compared against
metric_value REAL,
baseline_value REAL,
-- Admin triage state: who handled the alert and when
handled INTEGER NOT NULL DEFAULT 0,
handled_by TEXT,
handled_at TEXT,
triggered_at TEXT NOT NULL DEFAULT (datetime('now'))
);
CREATE INDEX IF NOT EXISTS idx_patch_status_device ON patch_status(device_uid);
CREATE INDEX IF NOT EXISTS idx_patch_status_severity ON patch_status(severity, is_installed);
CREATE INDEX IF NOT EXISTS idx_behavior_metrics_device_time ON behavior_metrics(device_uid, reported_at);
CREATE INDEX IF NOT EXISTS idx_anomaly_alerts_device ON anomaly_alerts(device_uid);
-- Partial index: keeps the hot "open alerts" query cheap
CREATE INDEX IF NOT EXISTS idx_anomaly_alerts_unhandled ON anomaly_alerts(handled) WHERE handled = 0;

View File

@@ -0,0 +1,54 @@
-- Software whitelist: processes that should NEVER be blocked even if matched by blacklist rules.
-- This provides a safety net to prevent false positives from killing legitimate applications.
-- NOTE(review): the seed INSERT below is not idempotent (no ON CONFLICT guard) —
-- it relies on the migration runner executing this file exactly once; verify
-- re-running migrations cannot duplicate the built-in rows.
CREATE TABLE IF NOT EXISTS software_whitelist (
id INTEGER PRIMARY KEY AUTOINCREMENT,
-- Process image name to match (exact filename patterns below)
name_pattern TEXT NOT NULL,
reason TEXT,
is_builtin INTEGER NOT NULL DEFAULT 0, -- 1 = system default, 0 = admin-added
enabled INTEGER NOT NULL DEFAULT 1,
created_at TEXT NOT NULL DEFAULT (datetime('now'))
);
-- Built-in whitelist entries for common safe applications
INSERT INTO software_whitelist (name_pattern, reason, is_builtin) VALUES
-- Browsers
('chrome.exe', 'Google Chrome browser', 1),
('msedge.exe', 'Microsoft Edge browser', 1),
('firefox.exe', 'Mozilla Firefox browser', 1),
('iexplore.exe', 'Internet Explorer', 1),
('opera.exe', 'Opera browser', 1),
('brave.exe', 'Brave browser', 1),
('vivaldi.exe', 'Vivaldi browser', 1),
-- Development tools & IDEs
('code.exe', 'Visual Studio Code', 1),
('devenv.exe', 'Visual Studio', 1),
('idea64.exe', 'IntelliJ IDEA', 1),
('webstorm64.exe', 'WebStorm', 1),
('pycharm64.exe', 'PyCharm', 1),
('goland64.exe', 'GoLand', 1),
('clion64.exe', 'CLion', 1),
('rider64.exe', 'Rider', 1),
('trae.exe', 'Trae IDE', 1),
('windsurf.exe', 'Windsurf IDE', 1),
('cursor.exe', 'Cursor IDE', 1),
-- Office & productivity
('winword.exe', 'Microsoft Word', 1),
('excel.exe', 'Microsoft Excel', 1),
('powerpnt.exe', 'Microsoft PowerPoint', 1),
('outlook.exe', 'Microsoft Outlook', 1),
('onenote.exe', 'Microsoft OneNote', 1),
('teams.exe', 'Microsoft Teams', 1),
('wps.exe', 'WPS Office', 1),
-- Terminal & system tools
('cmd.exe', 'Command Prompt', 1),
('powershell.exe', 'PowerShell', 1),
('pwsh.exe', 'PowerShell Core', 1),
('WindowsTerminal.exe', 'Windows Terminal', 1),
-- Communication
('wechat.exe', 'WeChat', 1),
('dingtalk.exe', 'DingTalk', 1),
('feishu.exe', 'Feishu/Lark', 1),
('qq.exe', 'QQ', 1),
('tim.exe', 'Tencent TIM', 1),
-- CSM
-- Self-protection: the agent must never kill its own process
('csm-client.exe', 'CSM Client itself', 1);

View File

@@ -1,5 +1,7 @@
/**
* Shared API client with authentication and error handling
* Shared API client with cookie-based authentication.
* Tokens are managed via HttpOnly cookies set by the server —
* the frontend never reads or stores JWT tokens.
*/
const API_BASE = import.meta.env.VITE_API_BASE || ''
@@ -21,43 +23,37 @@ export class ApiError extends Error {
}
}
function getToken(): string | null {
const token = localStorage.getItem('token')
if (!token || token.trim() === '') return null
return token
}
function clearAuth() {
localStorage.removeItem('token')
localStorage.removeItem('refresh_token')
window.location.href = '/login'
}
let refreshPromise: Promise<boolean> | null = null
/** Cached user info from /api/auth/me */
let cachedUser: { id: number; username: string; role: string } | null = null
export function getCachedUser() {
return cachedUser
}
export function clearCachedUser() {
cachedUser = null
}
async function tryRefresh(): Promise<boolean> {
// Coalesce concurrent refresh attempts
if (refreshPromise) return refreshPromise
refreshPromise = (async () => {
const refreshToken = localStorage.getItem('refresh_token')
if (!refreshToken || refreshToken.trim() === '') return false
try {
const response = await fetch(`${API_BASE}/api/auth/refresh`, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ refresh_token: refreshToken }),
credentials: 'same-origin',
})
if (!response.ok) return false
const result = await response.json()
if (!result.success || !result.data?.access_token) return false
if (!result.success) return false
localStorage.setItem('token', result.data.access_token)
if (result.data.refresh_token) {
localStorage.setItem('refresh_token', result.data.refresh_token)
// Update cached user from refresh response
if (result.data?.user) {
cachedUser = result.data.user
}
return true
} catch {
@@ -74,13 +70,8 @@ async function request<T>(
path: string,
options: RequestInit = {},
): Promise<T> {
const token = getToken()
const headers = new Headers(options.headers || {})
if (token) {
headers.set('Authorization', `Bearer ${token}`)
}
if (options.body && typeof options.body === 'string') {
headers.set('Content-Type', 'application/json')
}
@@ -88,18 +79,21 @@ async function request<T>(
const response = await fetch(`${API_BASE}${path}`, {
...options,
headers,
credentials: 'same-origin',
})
// Handle 401 - try refresh before giving up
if (response.status === 401) {
const refreshed = await tryRefresh()
if (refreshed) {
// Retry the original request with new token
const newToken = getToken()
headers.set('Authorization', `Bearer ${newToken}`)
const retryResponse = await fetch(`${API_BASE}${path}`, { ...options, headers })
const retryResponse = await fetch(`${API_BASE}${path}`, {
...options,
headers,
credentials: 'same-origin',
})
if (retryResponse.status === 401) {
clearAuth()
clearCachedUser()
window.location.href = '/login'
throw new ApiError(401, 'UNAUTHORIZED', 'Session expired')
}
const retryContentType = retryResponse.headers.get('content-type')
@@ -112,7 +106,8 @@ async function request<T>(
}
return retryResult.data as T
}
clearAuth()
clearCachedUser()
window.location.href = '/login'
throw new ApiError(401, 'UNAUTHORIZED', 'Session expired')
}
@@ -159,11 +154,12 @@ export const api = {
return request<T>(path, { method: 'DELETE' })
},
/** Login doesn't use the auth header */
async login(username: string, password: string): Promise<{ access_token: string; refresh_token: string; user: { id: number; username: string; role: string } }> {
/** Login — server sets HttpOnly cookies, we only get user info back */
async login(username: string, password: string): Promise<{ user: { id: number; username: string; role: string } }> {
const response = await fetch(`${API_BASE}/api/auth/login`, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
credentials: 'same-origin',
body: JSON.stringify({ username, password }),
})
@@ -172,12 +168,27 @@ export const api = {
throw new ApiError(response.status, 'LOGIN_FAILED', result.error || 'Login failed')
}
localStorage.setItem('token', result.data.access_token)
localStorage.setItem('refresh_token', result.data.refresh_token)
cachedUser = result.data.user
return result.data
},
logout() {
clearAuth()
/** Logout — server clears cookies */
async logout(): Promise<void> {
try {
await fetch(`${API_BASE}/api/auth/logout`, {
method: 'POST',
credentials: 'same-origin',
})
} catch {
// Ignore errors during logout
}
clearCachedUser()
},
/** Check current auth status via /api/auth/me */
async me(): Promise<{ user: { id: number; username: string; role: string }; expires_at: string }> {
const result = await request<{ user: { id: number; username: string; role: string }; expires_at: string }>('/api/auth/me')
cachedUser = (result as { user: { id: number; username: string; role: string } }).user
return result
},
}

View File

@@ -27,40 +27,43 @@ const router = createRouter({
{ path: 'plugins/print-audit', name: 'PrintAudit', component: () => import('../views/plugins/PrintAudit.vue') },
{ path: 'plugins/clipboard-control', name: 'ClipboardControl', component: () => import('../views/plugins/ClipboardControl.vue') },
{ path: 'plugins/plugin-control', name: 'PluginControl', component: () => import('../views/plugins/PluginControl.vue') },
{ path: 'plugins/patch', name: 'PatchManagement', component: () => import('../views/plugins/PatchManagement.vue') },
{ path: 'plugins/anomaly', name: 'AnomalyDetection', component: () => import('../views/plugins/AnomalyDetection.vue') },
],
},
],
})
/** Check if a JWT token is structurally valid and not expired */
/**
 * Structural and expiry validation of a JWT, without signature verification.
 * Returns false for blank input, a token that is not three dot-separated
 * segments, an undecodable payload, a missing/zero `exp` claim, or an
 * expiry that falls within the next 30 seconds (clock-skew safety margin).
 */
function isTokenValid(token: string): boolean {
  if (!token || token.trim() === '') return false
  const segments = token.split('.')
  if (segments.length !== 3) return false
  try {
    // Decode the (base64) payload segment; atob throws on bad input
    const claims = JSON.parse(atob(segments[1]))
    if (!claims.exp) return false
    // Require at least 30s of remaining lifetime
    return claims.exp * 1000 > Date.now() + 30_000
  } catch {
    return false
  }
}
/** Track whether we've already validated auth this session */
let authChecked = false
router.beforeEach((to, _from, next) => {
router.beforeEach(async (to, _from, next) => {
if (to.path === '/login') {
next()
return
}
const token = localStorage.getItem('token')
if (!token || !isTokenValid(token)) {
localStorage.removeItem('token')
localStorage.removeItem('refresh_token')
next('/login')
} else {
// If we've already verified auth this session, allow navigation
// (cookies are sent automatically, no need to check on every route change)
if (authChecked) {
next()
return
}
// Check auth status via /api/auth/me (reads access_token cookie)
try {
const { me } = await import('../lib/api')
await me()
authChecked = true
next()
} catch {
next('/login')
}
})
/** Reset auth check flag (called after logout) */
export function resetAuthCheck() {
authChecked = false
}
export default router

View File

@@ -41,6 +41,32 @@
<div class="stat-label">USB事件(24h)</div>
</div>
</div>
<div class="stat-card" @click="showHealthDetail = true" style="cursor:pointer">
<div class="stat-icon" :class="healthIconClass">
<span style="font-size:20px;font-weight:800;line-height:26px">{{ healthAvg }}</span>
</div>
<div class="stat-info">
<div class="stat-value" :class="healthTextClass">{{ healthAvg }}</div>
<div class="stat-label">健康评分</div>
</div>
</div>
</div>
<!-- Health overview bar -->
<div v-if="healthSummary.total > 0" class="health-bar">
<div class="health-bar-segment healthy" :style="{ flex: healthSummary.healthy }" :title="`${healthSummary.healthy} 健康`">
<span v-if="healthSummary.healthy > 0">{{ healthSummary.healthy }} 健康</span>
</div>
<div class="health-bar-segment warning" :style="{ flex: healthSummary.warning }" :title="`${healthSummary.warning} 告警`">
<span v-if="healthSummary.warning > 0">{{ healthSummary.warning }} 告警</span>
</div>
<div class="health-bar-segment critical" :style="{ flex: healthSummary.critical }" :title="`${healthSummary.critical} 严重`">
<span v-if="healthSummary.critical > 0">{{ healthSummary.critical }} 严重</span>
</div>
<div class="health-bar-segment unknown" :style="{ flex: healthSummary.unknown || 0 }" :title="`${healthSummary.unknown || 0} 未知`">
<span v-if="(healthSummary.unknown || 0) > 0">{{ healthSummary.unknown }} 未知</span>
</div>
<div class="health-bar-label">策略冲突: {{ conflictCount }} </div>
</div>
<!-- Charts row -->
@@ -138,7 +164,7 @@
</template>
<script setup lang="ts">
import { ref, onMounted, onUnmounted } from 'vue'
import { ref, computed, onMounted, onUnmounted } from 'vue'
import { Monitor, Platform, Bell, Connection, Top } from '@element-plus/icons-vue'
import * as echarts from 'echarts'
import { api } from '@/lib/api'
@@ -148,6 +174,24 @@ const recentAlerts = ref<Array<{ id: number; severity: string; detail: string; t
const recentUsbEvents = ref<Array<{ device_name: string; event_type: string; device_uid: string; event_time: string }>>([])
const topDevices = ref<Array<{ hostname: string; cpu_usage: number; memory_usage: number; status: string }>>([])
const healthSummary = ref<{ total: number; healthy: number; warning: number; critical: number; unknown: number; avg_score: number }>({ total: 0, healthy: 0, warning: 0, critical: 0, unknown: 0, avg_score: 0 })
const conflictCount = ref(0)
const showHealthDetail = ref(false)
const healthAvg = computed(() => Math.round(healthSummary.value.avg_score))
const healthIconClass = computed(() => {
const s = healthAvg.value
if (s >= 80) return 'health-good'
if (s >= 50) return 'health-warn'
return 'health-bad'
})
const healthTextClass = computed(() => {
const s = healthAvg.value
if (s >= 80) return 'text-good'
if (s >= 50) return 'text-warn'
return 'text-bad'
})
const cpuChartRef = ref<HTMLElement>()
let chart: echarts.ECharts | null = null
let timer: ReturnType<typeof setInterval> | null = null
@@ -155,10 +199,12 @@ let resizeHandler: (() => void) | null = null
async function fetchDashboard() {
try {
const [devicesData, alertsData, usbData] = await Promise.all([
const [devicesData, alertsData, usbData, healthData, conflictData] = await Promise.all([
api.get<any>('/api/devices'),
api.get<any>('/api/alerts/records?handled=0&page_size=10'),
api.get<any>('/api/usb/events?page_size=10'),
api.get<any>('/api/dashboard/health-overview').catch(() => null),
api.get<any>('/api/policies/conflicts').catch(() => null),
])
const devices = devicesData.devices || []
@@ -179,6 +225,16 @@ async function fetchDashboard() {
const events = usbData.events || []
stats.value.usbEvents = events.length
recentUsbEvents.value = events.slice(0, 8)
// Health overview
if (healthData?.summary) {
healthSummary.value = healthData.summary
}
// Conflict count
if (conflictData?.total !== undefined) {
conflictCount.value = conflictData.total
}
} catch (e) {
console.error('Failed to fetch dashboard data', e)
}
@@ -374,4 +430,51 @@ onUnmounted(() => {
color: var(--csm-text-tertiary);
margin-top: 2px;
}
/* Health bar */
.health-bar {
display: flex;
align-items: center;
height: 32px;
border-radius: 6px;
overflow: hidden;
margin-top: 16px;
background: #f1f5f9;
font-size: 12px;
color: #fff;
position: relative;
}
.health-bar-segment {
display: flex;
align-items: center;
justify-content: center;
min-width: 0;
overflow: hidden;
white-space: nowrap;
padding: 0 8px;
transition: flex 0.3s ease;
}
.health-bar-segment.healthy { background: #16a34a; }
.health-bar-segment.warning { background: #d97706; }
.health-bar-segment.critical { background: #dc2626; }
.health-bar-segment.unknown { background: #94a3b8; }
.health-bar-label {
position: absolute;
right: 12px;
font-size: 12px;
color: #64748b;
font-weight: 500;
}
/* Health score colors */
.stat-icon.health-good { background: #f0fdf4; color: #16a34a; }
.stat-icon.health-warn { background: #fffbeb; color: #d97706; }
.stat-icon.health-bad { background: #fef2f2; color: #dc2626; }
.text-good { color: #16a34a !important; }
.text-warn { color: #d97706 !important; }
.text-bad { color: #dc2626 !important; }
</style>

View File

@@ -143,6 +143,15 @@
<el-tag size="small" effect="plain" round>{{ row.group_name || '默认' }}</el-tag>
</template>
</el-table-column>
<el-table-column label="健康" width="90" sortable :sort-method="(a: any, b: any) => (a.health_score ?? 0) - (b.health_score ?? 0)">
<template #default="{ row }">
<div v-if="row.health_score != null" class="health-cell">
<span class="health-dot" :class="row.health_level"></span>
<span class="health-value" :class="'text-' + healthClass(row.health_score)">{{ row.health_score }}</span>
</div>
<span v-else class="health-cell">-</span>
</template>
</el-table-column>
<el-table-column label="CPU" width="100">
<template #default="{ row }">
<div class="usage-cell">
@@ -380,6 +389,13 @@ function getProgressColor(value?: number): string {
return '#16a34a'
}
// Map a 0-100 health score onto a display class:
// missing score → 'unknown', >=80 → 'good', >=50 → 'warn', otherwise 'bad'.
function healthClass(score?: number): string {
  if (score === null || score === undefined) return 'unknown'
  return score >= 80 ? 'good' : score >= 50 ? 'warn' : 'bad'
}
function formatTime(t: string | null): string {
if (!t) return '-'
const d = new Date(t)
@@ -795,6 +811,31 @@ async function handleMoveSubmit() {
color: var(--csm-text-primary);
}
/* Health cell */
.health-cell {
display: flex;
align-items: center;
gap: 6px;
}
.health-dot {
width: 7px;
height: 7px;
border-radius: 50%;
flex-shrink: 0;
}
.health-dot.healthy { background: #16a34a; box-shadow: 0 0 4px rgba(22,163,74,0.4); }
.health-dot.warning { background: #d97706; }
.health-dot.critical { background: #dc2626; box-shadow: 0 0 4px rgba(220,38,38,0.3); }
.health-dot.unknown { background: #94a3b8; }
.health-value { font-size: 13px; font-weight: 600; }
.text-good { color: #16a34a; }
.text-warn { color: #d97706; }
.text-bad { color: #dc2626; }
.text-unknown { color: #94a3b8; }
/* Pagination */
.pagination-bar {
display: flex;

View File

@@ -76,6 +76,12 @@
<el-menu-item index="/plugins/plugin-control">
<template #title><span>插件控制</span></template>
</el-menu-item>
<el-menu-item index="/plugins/patch">
<template #title><span>补丁管理</span></template>
</el-menu-item>
<el-menu-item index="/plugins/anomaly">
<template #title><span>异常检测</span></template>
</el-menu-item>
</el-sub-menu>
<el-menu-item index="/settings">
@@ -143,7 +149,8 @@ import {
Monitor, Platform, Connection, Bell, Setting,
ArrowDown, Grid, Expand, Fold, SwitchButton
} from '@element-plus/icons-vue'
import { api } from '@/lib/api'
import { api, getCachedUser } from '@/lib/api'
import { resetAuthCheck } from '@/router'
const route = useRoute()
const router = useRouter()
@@ -153,15 +160,8 @@ const currentRoute = computed(() => route.path)
const unreadAlerts = ref(0)
const username = ref('')
function decodeUsername(): string {
try {
const token = localStorage.getItem('token')
if (!token) return ''
const payload = JSON.parse(atob(token.split('.')[1]))
return payload.username || ''
} catch {
return ''
}
// Username for the header display, read from the locally cached user
// profile; empty string when no user is cached.
function getCachedUsername(): string {
  const cached = getCachedUser()
  return cached && cached.username ? cached.username : ''
}
async function fetchUnreadAlerts() {
@@ -189,18 +189,20 @@ const pageTitles: Record<string, string> = {
'/plugins/print-audit': '打印审计',
'/plugins/clipboard-control': '剪贴板管控',
'/plugins/plugin-control': '插件控制',
'/plugins/patch': '补丁管理',
'/plugins/anomaly': '异常检测',
}
const pageTitle = computed(() => pageTitles[route.path] || '仪表盘')
onMounted(() => {
username.value = decodeUsername()
username.value = getCachedUsername()
fetchUnreadAlerts()
})
function handleLogout() {
localStorage.removeItem('token')
localStorage.removeItem('refresh_token')
// Log out the current admin: revoke the session server-side, clear the
// router's cached auth-check state (presumably so the login guard
// re-evaluates on the next navigation — confirm in router/index.ts),
// then redirect to the login page.
async function handleLogout() {
  await api.logout()
  resetAuthCheck()
  router.push('/login')
}
</script>

View File

@@ -85,7 +85,7 @@
<script setup lang="ts">
import { ref, reactive, onMounted } from 'vue'
import { ElMessage } from 'element-plus'
import { api } from '@/lib/api'
import { api, getCachedUser } from '@/lib/api'
const version = ref('0.1.0')
const dbInfo = ref('SQLite (WAL mode)')
@@ -96,14 +96,11 @@ const pwdForm = reactive({ oldPassword: '', newPassword: '', confirmPassword: ''
const pwdLoading = ref(false)
onMounted(() => {
try {
const token = localStorage.getItem('token')
if (token) {
const payload = JSON.parse(atob(token.split('.')[1]))
user.username = payload.username || 'admin'
user.role = payload.role || 'admin'
}
} catch (e) { console.error('Failed to decode token for username', e) }
const cached = getCachedUser()
if (cached) {
user.username = cached.username
user.role = cached.role
}
api.get<any>('/health')
.then((data: any) => {

View File

@@ -0,0 +1,90 @@
<template>
<div class="page-container">
<div class="csm-card">
<div class="csm-card-header">
<span>异常行为检测</span>
<el-tag v-if="unhandled > 0" type="danger" effect="light" size="small">{{ unhandled }} 未处理</el-tag>
<el-tag v-else type="success" effect="light" size="small">无异常</el-tag>
</div>
<div class="csm-card-body">
<el-table :data="alerts" v-loading="loading" size="small" max-height="520">
<el-table-column prop="hostname" label="终端" width="140" />
<el-table-column label="异常类型" width="180">
<template #default="{ row }">
<span>{{ anomalyLabel(row.anomaly_type) }}</span>
</template>
</el-table-column>
<el-table-column label="严重性" width="90">
<template #default="{ row }">
<el-tag :type="severityType(row.severity)" size="small" effect="light">{{ row.severity }}</el-tag>
</template>
</el-table-column>
<el-table-column prop="detail" label="详情" min-width="300" show-overflow-tooltip />
<el-table-column prop="triggered_at" label="检测时间" width="160" />
</el-table>
<div v-if="alerts.length === 0 && !loading" style="padding:40px 0;text-align:center;color:#94a3b8">
暂无异常行为告警系统运行正常
</div>
</div>
</div>
</div>
</template>
<script setup lang="ts">
import { ref, onMounted } from 'vue'
import { api } from '@/lib/api'

// Table rows, loading flag, and the unhandled-alert count shown in the header tag.
const alerts = ref<any[]>([])
const loading = ref(false)
const unhandled = ref(0)

// Human-readable labels for the known anomaly type codes.
const anomalyLabels: Record<string, string> = {
  night_clipboard_spike: '非工作时间剪贴板异常',
  usb_file_exfiltration: 'USB文件大量拷贝',
  high_print_volume: '打印量异常',
  process_spawn_spike: '进程启动频率异常',
}

// Resolve a type code to its label, falling back to the raw code.
const anomalyLabel = (type: string): string => anomalyLabels[type] || type

// Map alert severity to an Element Plus tag type ('' = default styling).
function severityType(s: string): string {
  switch (s) {
    case 'critical': return 'danger'
    case 'high': return 'warning'
    case 'medium': return ''
    default: return 'info'
  }
}

// Load the newest 50 anomaly alerts plus the unhandled counter.
async function fetchData() {
  loading.value = true
  try {
    const data = await api.get<any>('/api/plugins/anomaly/alerts?page_size=50')
    alerts.value = data.alerts || []
    unhandled.value = data.unhandled_count || 0
  } catch (e) {
    console.error('Failed to fetch anomaly alerts', e)
  } finally {
    loading.value = false
  }
}

onMounted(fetchData)
</script>
<style scoped>
.csm-card-header {
font-weight: 600;
font-size: 15px;
color: var(--csm-text-primary);
padding: 16px 20px;
border-bottom: 1px solid var(--csm-border-color);
display: flex;
align-items: center;
justify-content: space-between;
gap: 12px;
}
.csm-card-body {
padding: 16px 20px;
}
</style>

View File

@@ -0,0 +1,92 @@
<template>
<div class="page-container">
<div class="csm-card">
<div class="csm-card-header">
<span>补丁管理</span>
<el-tag type="info" effect="plain" size="small">{{ summary.total_installed }} 已安装 / {{ summary.total_missing }} 缺失</el-tag>
</div>
<div class="csm-card-body">
<el-table :data="patches" v-loading="loading" size="small" max-height="520">
<el-table-column prop="hostname" label="终端" width="140" />
<el-table-column prop="kb_id" label="补丁编号" width="120" />
<el-table-column prop="title" label="描述" min-width="280" show-overflow-tooltip />
<el-table-column prop="severity" label="严重性" width="100">
<template #default="{ row }">
<el-tag v-if="row.severity" :type="severityType(row.severity)" size="small" effect="light">{{ row.severity }}</el-tag>
<span v-else class="text-muted">-</span>
</template>
</el-table-column>
<el-table-column label="状态" width="90">
<template #default="{ row }">
<el-tag :type="row.is_installed ? 'success' : 'danger'" size="small" effect="light">
{{ row.is_installed ? '已安装' : '缺失' }}
</el-tag>
</template>
</el-table-column>
<el-table-column prop="installed_at" label="安装时间" width="120">
<template #default="{ row }">{{ row.installed_at || '-' }}</template>
</el-table-column>
</el-table>
<div style="display:flex;justify-content:flex-end;padding-top:12px">
<el-pagination :total="total" :page-size="pageSize" layout="total, prev, pager, next" @current-change="handlePage" />
</div>
</div>
</div>
</div>
</template>
<script setup lang="ts">
import { ref, onMounted } from 'vue'
import { api } from '@/lib/api'

// Current page of patch rows plus pagination bookkeeping.
const patches = ref<any[]>([])
const loading = ref(false)
const total = ref(0)
const page = ref(1)
const pageSize = 20
// Installed/missing totals shown in the card header.
const summary = ref({ total_installed: 0, total_missing: 0 })

// Fetch one page of patch status records from the server.
async function fetchData() {
  loading.value = true
  try {
    const data = await api.get<any>(`/api/plugins/patch/status?page=${page.value}&page_size=${pageSize}`)
    patches.value = data.patches || []
    total.value = data.total || 0
    if (data.summary) summary.value = data.summary
  } catch (e) {
    console.error('Failed to fetch patches', e)
  } finally {
    loading.value = false
  }
}

// Pagination callback: switch to page p and reload.
function handlePage(p: number) {
  page.value = p
  fetchData()
}

// Map Windows-update severity labels to Element Plus tag types.
function severityType(s: string): string {
  switch (s) {
    case 'Critical': return 'danger'
    case 'Important': return 'warning'
    default: return 'info'
  }
}

onMounted(fetchData)
</script>
<style scoped>
.csm-card-header {
font-weight: 600;
font-size: 15px;
color: var(--csm-text-primary);
padding: 16px 20px;
border-bottom: 1px solid var(--csm-border-color);
display: flex;
align-items: center;
justify-content: space-between;
}
.csm-card-body {
padding: 16px 20px;
}
.text-muted { color: #94a3b8; }
</style>

255
wiki/SECURITY-AUDIT.md Normal file
View File

@@ -0,0 +1,255 @@
# CSM 安全审计报告
> **审计日期**: 2026-04-11 | **审计范围**: 全系统 (Server + Client + Protocol + Frontend)
> **方法论**: OWASP Top 10 (2021), CWE Top 25, 手动源码审查 + 攻击者视角分析
---
## 执行摘要
| 严重级别 | 数量 | 说明 |
|----------|------|------|
| **CRITICAL** | 4 | 可直接被远程利用,导致系统完全沦陷 |
| **HIGH** | 12 | 需特定条件但影响重大 |
| **MEDIUM** | 12 | 有限影响或需较高权限 |
| **LOW** | 8 | 纵深防御建议 |
**最关键发现**: JWT Secret 硬编码在版本控制中、注册 Token 为空允许任意设备注册、凭据文件无 ACL 保护、默认无 TLS 传输加密。这四个问题组合意味着攻击者可以在数分钟内完全接管系统。
---
## CRITICAL (4)
### AUD-001: JWT Secret 硬编码在 config.toml 中
- **文件**: `config.toml:12``crates/server/src/config.rs`
- **CWE**: CWE-798 (硬编码凭证)
- **OWASP**: A07:2021 - 安全配置错误
**漏洞代码**:
```toml
jwt_secret = "39ffc129-dd62-4eb4-bbc0-8bf4b8e2ccc7"
```
**攻击场景**: 任何能访问仓库的人可提取 secret,为任意用户(含 admin)伪造 JWT,获得系统完全管理控制权,可推送恶意配置到所有医院终端、禁用安全控制、篡改审计记录。
**修复**:
1.`config.toml` 中移除硬编码 secret
2. 通过 `CSM_JWT_SECRET` 环境变量独占加载
3. secret 为空时拒绝启动
4. **立即轮换**已泄露的 secret `39ffc129-dd62-4eb4-bbc0-8bf4b8e2ccc7`
5.`config.toml` 加入 `.gitignore`
---
### AUD-002: 空 Registration Token — 任意设备注册
- **文件**: `config.toml:1`, `crates/server/src/tcp.rs:549-558`
- **CWE**: CWE-306 (关键功能缺失认证)
- **OWASP**: A07:2021
**漏洞代码**:
```toml
registration_token = ""
```
```rust
if !expected_token.is_empty() { // 空字符串直接跳过验证
```
**攻击场景**: TCP 端口 9999 可达的任何攻击者可注册恶意设备,注入伪造审计数据掩盖安全事件,获取所有插件配置(含安全策略)。
**修复**: 设置强 registration token空 token 时拒绝启动。
---
### AUD-003: Windows 凭据文件无 ACL 保护
- **文件**: `crates/client/src/main.rs:280-283`
- **CWE**: CWE-732 (关键资源权限不当)
- **OWASP**: A01:2021
**漏洞代码**:
```rust
#[cfg(not(unix))]
fn write_restricted_file(path: &std::path::Path, content: &str) -> std::io::Result<()> {
std::fs::write(path, content) // 无任何 ACL 设置
}
```
**攻击场景**: `device_secret.txt`HMAC 密钥)和 `device_uid.txt`(设备身份)以默认权限写入,设备上任何用户进程可读取。攻击者可提取密钥伪造心跳,在不同机器上模拟设备。
**修复**: 使用 `icacls``SetSecurityInfo` 设置仅 SYSTEM 可访问的 ACL。
---
### AUD-004: 默认无 TLS — 明文传输所有敏感数据
- **文件**: `config.toml` (无 `[server.tls]`), `crates/server/src/tcp.rs:411-414`
- **CWE**: CWE-319 (明文传输敏感信息)
- **OWASP**: A02:2021
**攻击场景**: TCP 端口 9999 以明文运行。`device_secret`、所有插件配置、Web 过滤规则、USB 策略、软件黑名单均以明文 JSON 传输。网络嗅探者可提取设备认证密钥并逆向工程安全策略。
**修复**: 生产环境强制 TLS无 TLS 时拒绝启动(`CSM_DEV=1` 除外)。
---
## HIGH (12)
### AUD-005: Refresh Token 未存储 — 撤销机制不完整
- **文件**: `crates/server/src/api/auth.rs:131-133`
- **CWE**: CWE-613
- `refresh_tokens` 表存在但从未写入。登录不存储 token record刷新仅检查 family 是否撤销,不验证 token 是否曾被实际颁发。无法强制注销所有 session。
### AUD-006: Refresh Token Family Rotation 存在 TOCTOU 竞争条件
- **文件**: `crates/server/src/api/auth.rs:167-183`
- **CWE**: CWE-367
- 检查 family 撤销状态与执行撤销之间无事务保护。并发使用同一 stolen token 的两个请求均可通过检查,攻击者获得全新的未撤销 token family。
### AUD-007: 客户端不验证服务器身份
- **文件**: `crates/client/src/network/mod.rs:149-162`
- **CWE**: CWE-295
- 客户端盲目连接任何响应的服务器。ARP/DNS 欺骗攻击者可推送恶意配置禁用所有安全插件、注入有害规则。HMAC 仅保护心跳,不保护配置推送。
### AUD-008: PowerShell 命令注入面
- **文件**: `crates/client/src/asset/mod.rs:82-83`, `crates/client/src/clipboard_control/mod.rs:143-159`
- **CWE**: CWE-78
- `powershell_lines()` 通过 `format!()` 拼接命令参数。若服务器推送含引号/转义字符的恶意规则,可能导致 PowerShell 命令注入。
### AUD-009: 服务停止/卸载未受保护
- **文件**: `crates/client/src/service.rs:17-69`
- **CWE**: CWE-284
- `csm-client.exe --uninstall` 无认证保护。终端管理员权限用户可完全移除安全代理,绕过所有监控。无服务恢复策略、无看门狗进程、无反调试保护。
### AUD-010: JWT Token 存储在 localStorage
- **文件**: `web/src/lib/api.ts:25-28, 175-176`
- **CWE**: CWE-922
- Access token 和 refresh token 均存储在 `localStorage`,可被同源任意 JS 访问。XSS 漏洞可直接窃取 7 天有效期的 refresh token。
### AUD-011: CSP 允许 unsafe-inline + unsafe-eval
- **文件**: `crates/server/src/main.rs:142-143`
- **CWE**: CWE-693
- `script-src 'self' 'unsafe-inline' 'unsafe-eval'` 使 CSP 对 XSS 几乎无效。结合 localStorage token 存储,单个 XSS 即可导致管理员会话完全沦陷。
### AUD-012: WebSocket JWT 在 URL 查询参数中
- **文件**: `crates/server/src/ws.rs:36-73`
- **CWE**: CWE-312
- JWT 通过 `/ws?token=eyJ...` 传输。Token 出现在浏览器历史、服务器访问日志、代理日志中。且 WebSocket handler 不检查用户角色,非管理员可接收所有广播事件。
### AUD-013: 告警规则 Webhook SSRF
- **文件**: `crates/server/src/api/alerts.rs:115-131`
- **CWE**: CWE-918
- `notify_webhook` 字段无 URL 验证。可设置为 `http://169.254.169.254/latest/meta-data/` (AWS 元数据) 或 `file:///etc/passwd`,将服务器变成 SSRF 代理。
### AUD-014: 仅基于用户名的速率限制可绕过
- **文件**: `crates/server/src/api/auth.rs:101`
- **CWE**: CWE-307
- 速率限制仅以用户名为 key。攻击者可用 `Admin``ADMIN` 等变体绕过。无 IP 限制,可分布式暴力破解。
### AUD-015: 磁盘加密确认在只读路由层 — 权限提升
- **文件**: `crates/server/src/api/plugins/mod.rs:42`
- **CWE**: CWE-862
- PUT `acknowledge_alert``read_routes()` 中(仅需认证,不需 admin。任何认证用户可确认忽略加密告警掩盖合规违规。
### AUD-016: 初始管理员密码输出到 stderr
- **文件**: `crates/server/src/main.rs:270-275`
- **CWE**: CWE-532
- 初始密码通过 `eprintln!` 输出。容器化部署中 stderr 被日志聚合系统捕获,有日志访问权限者可获取管理员密码。
---
## MEDIUM (12)
| # | 发现 | 文件 | CWE |
|---|------|------|-----|
| AUD-017 | 多个 Update handler 跳过输入验证 (软件黑名单/Web过滤器/剪贴板) | `software_blocker.rs:79`, `web_filter.rs:62`, `clipboard_control.rs:107` | CWE-20 |
| AUD-018 | USB 策略 rules 字段接受任意 JSON 无验证 | `usb.rs:94-135` | CWE-20 |
| AUD-019 | 密码无最大长度限制 (bcrypt 72 字节截断) | `auth.rs:303` | CWE-20 |
| AUD-020 | 多个字段缺少长度验证 (弹出窗口/剪贴板/USB策略名) | 多处 | CWE-20 |
| AUD-021 | 多个列表端点无分页 (黑名单/白名单/规则/策略) | 多处 | CWE-770 |
| AUD-022 | 磁盘加密状态列表无分页可全库转储 | `disk_encryption.rs:12-55` | CWE-770 |
| AUD-023 | JWT 角色仅信任 claim 不查库 (降级延迟) | `auth.rs:273` | CWE-863 |
| AUD-024 | 缺少 HSTS 头 | `main.rs:123-143` | CWE-319 |
| AUD-025 | CORS 配置需严格限制 | `main.rs:284-301` | CWE-942 |
| AUD-026 | 日志中泄露设备 UID 和服务器地址 | `main.rs:62`, `network/mod.rs:34` | CWE-532 |
| AUD-027 | 注册 Token 回退空字符串 | `main.rs:72` | CWE-254 |
| AUD-028 | conflict.rs 中 format! SQL 模式 (当前安全但脆弱) | `conflict.rs:205` | CWE-89 |
---
## LOW (8)
| # | 发现 | 文件 |
|---|------|------|
| AUD-029 | 组名未过滤 HTML 特殊字符 | `groups.rs:72-96` |
| AUD-030 | 弹出窗口阻止器更新可创建无过滤器的规则 | `popup_blocker.rs:67-104` |
| AUD-031 | 设备删除非原子 (自毁帧在事务前发送) | `devices.rs:215-306` |
| AUD-032 | 受保护进程列表硬编码且可修补绕过 | `software_blocker/mod.rs:9-73` |
| AUD-033 | hosts 文件修改可能与 EDR 冲突 | `web_filter/mod.rs:59-93` |
| AUD-034 | 软件拦截器 TOCTOU 竞争条件 (已缓解) | `software_blocker/mod.rs:329-386` |
| AUD-035 | 前端路由守卫不验证 JWT 签名 | `router/index.ts:38-49` |
| AUD-036 | WebSocket 不验证入站消息 (当前丢弃) | `ws.rs:108-119` |
---
## 修复优先级
### P0 — 立即 (24h)
| 修复项 | 对应发现 | 工作量 |
|--------|---------|--------|
| 轮换 JWT Secret移至环境变量 | AUD-001 | 30min |
| 设置非空 registration_token | AUD-002 | 15min |
| 凭据文件添加 Windows ACL | AUD-003 | 1h |
| 生产环境强制 TLS | AUD-004 | 2h |
### P1 — 短期 (1 周)
| 修复项 | 对应发现 | 工作量 |
|--------|---------|--------|
| Refresh token 存储到 DB + 事务保护 | AUD-005, 006 | 4h |
| Update handler 添加输入验证 | AUD-017 | 4h |
| Webhook URL 验证防 SSRF | AUD-013 | 1h |
| 磁盘加密确认移至 admin 路由 | AUD-015 | 15min |
| 初始密码写入文件替代 stderr | AUD-016 | 30min |
| 添加 IP 速率限制 | AUD-014 | 2h |
### P2 — 中期 (1 月)
| 修复项 | 对应发现 | 工作量 |
|--------|---------|--------|
| Token 迁移至 HttpOnly Cookie | AUD-010 | 8h |
| CSP 强化 (nonce-based) | AUD-011 | 4h |
| WebSocket ticket 认证 | AUD-012 | 4h |
| 服务器身份验证 (证书固定) | AUD-007 | 8h |
| 服务保护 (恢复策略/看门狗) | AUD-009 | 4h |
| PowerShell 注入面消除 | AUD-008 | 6h |
| HSTS + Permissions-Policy 头 | AUD-024, 036 | 1h |
| 分页补充 | AUD-021, 022 | 4h |
---
## 安全亮点 (做得好的地方)
1. **SQL 注入防御**: 全库一致使用 `sqlx::bind()` 参数化
2. **错误处理**: `ApiResponse::internal_error()` 不泄露内部错误详情
3. **密码哈希**: bcrypt cost=12符合行业标准
4. **Token Family 轮换**: 检测 token 重放并撤销整个 family
5. **常量时间比较**: 注册 token 验证已使用 `constant_time_eq()`
6. **帧速率限制**: 100 帧/5秒/连接
7. **审计日志**: 所有管理员写入操作记录到 `admin_audit_log`
8. **HMAC 心跳**: 设备认证使用 HMAC-SHA256
9. **进程保护列表**: 防止误杀系统关键进程
10. **输入验证**: Create handler 普遍有字段验证

88
wiki/client.md Normal file
View File

@@ -0,0 +1,88 @@
# Client客户端代理
## 设计思想
`csm-client` 是部署在医院终端设备上的 Windows 代理程序,设计为:
1. **无人值守运行** — 支持控制台模式(开发调试)和 Windows 服务模式(生产部署)
2. **自动重连** — 指数退避策略(1s → 60s),断线后 drain stale frames
3. **插件化采集** — 每个插件独立 task通过 `watch` channel 接收配置,通过 `mpsc` channel 上报数据
4. **单入口 data channel** — 所有插件共享一个 `mpsc::channel::<Frame>(1024)`network 模块统一发送
关键设计决策:
- **watch + mpsc 双通道** — `watch` 用于服务器推送配置到插件(多消费者最新值),`mpsc` 用于插件上报数据到网络层(多生产者有序队列)
- **device_uid 持久化** — UUID 首次生成后写入 `device_uid.txt`,与可执行文件同目录
- **device_secret 持久化** — 注册成功后写入 `device_secret.txt`,重启后自动认证
## 代码逻辑
### 启动流程
```
main() → load device_uid → load device_secret → create ClientState
→ create data channel (mpsc 1024)
→ create watch channels for each plugin config
→ spawn core tasks (monitor, asset, usb)
→ spawn plugin tasks (11 plugins)
→ reconnect loop: connect_and_run() with exponential backoff
```
### 网络层 (`network/mod.rs`)
- `connect_and_run()` — TCP 连接、注册/认证、双工读写循环
- `handle_server_message()` — 根据 MessageType 分发服务器下发的帧到对应 watch channel
- `PluginChannels` — 持有所有插件的 `watch::Sender`,用于接收服务器推送的配置
- 注册流程:发送 Register → 收到 RegisterResponse含 device_secret→ 持久化 secret
- 认证流程:已有 device_secret 时,心跳帧携带 HMAC-SHA256 签名
### 插件统一模板
每个插件遵循相同模式:
```rust
pub async fn start(
mut config_rx: watch::Receiver<PluginConfig>,
data_tx: mpsc::Sender<Frame>,
device_uid: String,
) {
loop {
tokio::select! {
result = config_rx.changed() => { /* 更新 config */ }
_ = interval.tick() => {
if !config.enabled { continue; }
// 采集数据 → Frame::new_json() → data_tx.send()
}
}
}
}
```
### 双模式运行
- **控制台模式**: 直接 `cargo run -p csm-client`Ctrl+C 优雅退出
- **服务模式**: `--install` 注册 Windows 服务、`--service` 以服务方式运行、`--uninstall` 卸载
## 关联模块
- [[protocol]] — 使用 Frame 构造上报帧,解析服务器下发帧
- [[server]] — TCP 连接的对端,接收帧并处理
- [[plugins]] — 每个插件的具体实现逻辑
## 关键文件
| 文件 | 职责 |
|------|------|
| `crates/client/src/main.rs` | 启动入口、插件 channel 创建、task spawn、重连循环 |
| `crates/client/src/network/mod.rs` | TCP 连接、注册认证、双工读写、服务器消息分发 |
| `crates/client/src/service.rs` | Windows 服务安装/卸载/运行(`#[cfg(target_os = "windows")]` |
| `crates/client/src/monitor/mod.rs` | 核心设备状态采集CPU/内存/进程) |
| `crates/client/src/asset/mod.rs` | 硬件/软件资产采集 |
| `crates/client/src/usb/mod.rs` | USB 设备插拔监控 |
| `crates/client/src/web_filter/mod.rs` | 上网拦截插件 |
| `crates/client/src/usage_timer/mod.rs` | 使用时长记录插件 |
| `crates/client/src/software_blocker/mod.rs` | 软件禁止安装插件 |
| `crates/client/src/popup_blocker/mod.rs` | 弹窗拦截插件 |
| `crates/client/src/usb_audit/mod.rs` | U盘文件操作审计插件 |
| `crates/client/src/watermark/mod.rs` | 屏幕水印插件 |
| `crates/client/src/disk_encryption/mod.rs` | 磁盘加密检测插件 |
| `crates/client/src/print_audit/mod.rs` | 打印审计插件 |
| `crates/client/src/clipboard_control/mod.rs` | 剪贴板管控插件 |
| `crates/client/src/patch/mod.rs` | 补丁管理插件 |

71
wiki/database.md Normal file
View File

@@ -0,0 +1,71 @@
# Database数据库层
## 设计思想
SQLite 单文件数据库WAL 模式支持并发读写。设计原则:
1. **只追加迁移** — 永不修改已有 migration 文件
2. **参数绑定** — 所有 SQL 使用 `.bind()`,绝不拼接
3. **upsert 模式**`ON CONFLICT ... DO UPDATE` 处理重复上报,必须更新 `updated_at`
4. **嵌入式迁移** — SQL 文件通过 `include_str!` 编译进二进制,运行时按序执行
5. **外键启用**`foreign_keys(true)` 强制引用完整性
## 代码逻辑
### 初始化
```
main() → init_database() → SQLite WAL + Normal sync + 5s busy timeout + FK on
→ run_migrations() → CREATE _migrations 表 → 按序执行 001-018
→ ensure_default_admin() → 首次启动生成随机 admin 密码
```
### 连接池配置
- 最大 8 连接
- cache_size = -64000 (64MB)
- wal_autocheckpoint = 1000
### 迁移历史
| # | 文件 | 内容 |
|---|------|------|
| 001 | init.sql | users, devices 表 |
| 002 | assets.sql | hardware_assets, software_assets, asset_changes 表 |
| 003 | usb.sql | usb_events, usb_policies, usb_device_patterns 表 |
| 004 | alerts.sql | alert_rules, alert_records 表 |
| 005 | web_filter.sql | web_filter_rules, web_access_logs 表 |
| 006 | usage_timer.sql | usage_daily, app_usage 表 |
| 007 | software_blocker.sql | software_blacklist, software_violations 表 |
| 008 | popup_blocker.sql | popup_blocker_rules, popup_block_stats 表 |
| 009 | usb_file_audit.sql | usb_file_operations 表 |
| 010 | watermark.sql | watermark_configs 表 |
| 011 | token_security.sql | token_families 表JWT token family 轮换) |
| 012 | disk_encryption.sql | disk_encryption_status, disk_encryption_alerts 表 |
| 013 | print_audit.sql | print_events 表 |
| 014 | clipboard_control.sql | clipboard_rules, clipboard_violations 表 |
| 015 | plugin_control.sql | plugin_states 表 |
| 016 | encryption_alerts_unique.sql | 唯一约束修复 |
| 017 | device_health_scores.sql | device_health_scores 表 |
| 018 | patch_management.sql | patch_status 表 |
### 数据操作层 (`db.rs`)
`DeviceRepo` 提供:
- 设备注册/查询/删除/分组
- 资产增删改查
- USB 事件记录和策略管理
- 告警规则和记录操作
- 所有插件数据的 CRUD
## 关联模块
- [[server]] — 通过 db.rs 访问数据库
- [[plugins]] — 每个插件有对应的数据库表
## 关键文件
| 文件 | 职责 |
|------|------|
| `crates/server/src/db.rs` | DeviceRepo 数据库操作方法集合 |
| `crates/server/src/main.rs` | 数据库初始化、迁移执行 |
| `migrations/001_init.sql` ~ `018_*.sql` | 数据库迁移脚本 |

46
wiki/index.md Normal file
View File

@@ -0,0 +1,46 @@
# CSM 知识库
## 项目画像
CSM (Client Security Manager) — 医院终端安全管控平台,C/S + Web 三层架构。管理 13 个安全插件,覆盖上网拦截、U盘管控、打印审计、剪贴板管控、补丁管理等场景。
**关键数字**: 3 个 Rust crate + Vue 前端 | 18 个数据库迁移 | 13 个客户端插件 | ~30 个 API 端点 | 自定义 TCP 二进制协议
## 模块导航树
```
CSM
├── [[protocol]] — 二进制协议层Frame 编解码、MessageType、payload 定义)
├── [[server]] — 服务端HTTP API + TCP 接入 + WebSocket + SQLite
├── [[client]] — 客户端代理Windows 服务、插件采集、自动重连)
├── [[web-frontend]] — Web 管理面板Vue 3 SPA
├── [[plugins]] — 插件体系(端到端设计、新增插件清单)
└── [[database]] — 数据库层SQLite、迁移、操作方法
```
## 核心架构决策
### 为什么用自定义 TCP 二进制协议而不是 HTTP
内网环境低延迟需求,二进制帧比 HTTP 更省带宽和延迟。帧头仅 10 字节(MAGIC+VERSION+TYPE+LENGTH),payload 用 JSON 保持可调试性。
### 为什么插件配置用 watch channel 而不是 HTTP 轮询?
Server 主动推送配置变更到 Client避免轮询延迟。`tokio::watch` 保证每个插件总是读到最新配置值,配置下发 → 全链路秒级生效。
### 为什么嵌入前端而不是独立部署?
`include_dir!` 编译时打包 `web/dist/`,部署只需一个 server 二进制文件。SPA fallback 让前端路由(如 `/devices`)直接返回 `index.html`
### 为什么 SQLite 而不是 PostgreSQL
医院内网单机部署场景零外部依赖。WAL 模式 + 64MB 缓存足以支撑数百台终端的并发写入。
### 为什么三级作用域推送global/group/device
医院按科室分组管理设备。全局策略作为基线,科室策略覆盖特定需求,单设备策略处理例外情况。`push_to_targets()` 自动解析作用域并过滤在线设备。
## 技术栈速查
| 层 | 技术 |
|----|------|
| 服务端 | Rust + Axum + SQLx + SQLite + JWT + Rustls |
| 客户端 | Rust + Tokio + sysinfo + windows-rs |
| 协议 | 自定义 TCP 二进制MAGIC + VERSION + TYPE + LENGTH + JSON payload |
| 前端 | Vue 3 + TypeScript + Vite + Element Plus + Pinia + ECharts |
| 构建 | Cargo workspace + npm |

78
wiki/plugins.md Normal file
View File

@@ -0,0 +1,78 @@
# Plugin System插件体系
## 设计思想
CSM 的核心扩展机制,采用**端到端插件化**设计:
- Client 端每个插件独立 tokio task负责数据采集/策略执行
- Server 端每个插件有独立的 API handler 模块和数据库表
- Protocol 层每个插件有专属 MessageType 范围和 payload struct
- Frontend 端每个插件有独立页面组件
三级配置推送:`global``group``device`,优先级递增。
## 代码逻辑
### 插件全链路(以 Web Filter 为例)
```
1. API: POST /api/plugins/web-filter/rules → server/api/plugins/web_filter.rs
2. Server 存储 → db.rs → INSERT INTO web_filter_rules
3. 推送 → push_to_targets(db, clients, WebFilterRuleUpdate, payload, scope, target_id)
4. TCP → Client network/mod.rs handle_server_message() → web_filter_tx.send()
5. Client → web_filter/mod.rs config_rx.changed() → 更新本地规则 → 采集上报
6. Client → Frame::new_json(WebAccessLog, entry) → data_tx → network → TCP → server
7. Server → tcp.rs process_frame(WebAccessLog) → db.rs → INSERT INTO web_access_logs
8. Frontend → GET /api/plugins/web-filter/log → 展示
```
### 现有插件一览
| 插件 | 消息类型范围 | 方向 | 功能 |
|------|-------------|------|------|
| Web Filter | 0x2x | S→C 规则, C→S 日志 | URL 黑白名单、访问日志 |
| Usage Timer | 0x3x | C→S 报告 | 每日使用时长、应用使用统计 |
| Software Blocker | 0x4x | S→C 黑名单, C→S 违规 | 禁止安装软件、违规上报 |
| Popup Blocker | 0x5x | S→C 规则, C→S 统计 | 弹窗拦截规则、拦截统计 |
| USB File Audit | 0x6x | C→S 记录 | U盘文件操作审计 |
| Watermark | 0x70 | S→C 配置 | 屏幕水印显示配置 |
| USB Policy | 0x71 | S→C 策略 | U盘管控全阻/白名单/黑名单) |
| Plugin Control | 0x80-0x81 | S→C 命令 | 远程启停插件 |
| Disk Encryption | 0x90, 0x93 | C→S 状态, S→C 配置 | 磁盘加密状态检测 |
| Print Audit | 0x91 | C→S 事件 | 打印操作审计 |
| Clipboard Control | 0x94-0x95 | S→C 规则, C→S 违规 | 剪贴板操作管控(仅上报元数据) |
| Patch Management | 0xA0-0xA2 | 双向 | 系统补丁扫描与安装 |
| Behavior Metrics | 0xB0 | C→S 指标 | 行为指标采集(异常检测输入) |
### 新增插件必改文件清单
| # | 文件 | 改动 |
|---|------|------|
| 1 | `crates/protocol/src/message.rs` | 添加 MessageType 枚举值 + payload struct |
| 2 | `crates/protocol/src/lib.rs` | re-export 新类型 |
| 3 | `crates/client/src/<plugin>/mod.rs` | 创建插件实现 |
| 4 | `crates/client/src/main.rs` | `mod <plugin>`, watch channel, PluginChannels 字段, spawn |
| 5 | `crates/client/src/network/mod.rs` | PluginChannels 字段, handle_server_message 分支 |
| 6 | `crates/server/src/api/plugins/<plugin>.rs` | 创建 API handler |
| 7 | `crates/server/src/api/plugins/mod.rs` | mod 声明 + 路由注册 |
| 8 | `crates/server/src/tcp.rs` | process_frame 新分支 + push_all_plugin_configs |
| 9 | `crates/server/src/db.rs` | 新增 DB 操作方法 |
| 10 | `migrations/NNN_<name>.sql` | 新迁移文件 |
| 11 | `crates/server/src/main.rs` | include_str! 新迁移 |
## 关联模块
- [[protocol]] — 定义插件的 MessageType 和 payload
- [[client]] — 插件采集端
- [[server]] — 插件 API 和数据处理
- [[web-frontend]] — 插件管理页面
- [[database]] — 每个插件的数据库表
## 关键文件
| 文件 | 职责 |
|------|------|
| `crates/client/src/<plugin>/mod.rs` | 客户端插件实现(每个插件一个目录) |
| `crates/server/src/api/plugins/<plugin>.rs` | 服务端插件 API每个插件一个文件 |
| `crates/server/src/tcp.rs` | 帧分发 + push_to_targets + push_all_plugin_configs |
| `crates/client/src/main.rs` | 插件 watch channel 创建 + task spawn |
| `crates/client/src/network/mod.rs` | PluginChannels 定义 + 服务器消息分发 |

58
wiki/protocol.md Normal file
View File

@@ -0,0 +1,58 @@
# Protocol二进制协议层
## 设计思想
`csm-protocol` 是 Server 和 Client 共享的协议定义 crate。核心设计决策
1. **零拷贝编解码**`Frame::encode()` / `Frame::decode()` 直接操作字节切片,无中间分配
2. **类型安全**`MessageType` 枚举确保所有消息类型在编译期可见,`TryFrom<u8>` 处理未知类型
3. **JSON payload** — 网络传输用 JSON`serde`),兼顾可调试性和跨语言兼容性
4. **payload 上限 4MB**`MAX_PAYLOAD_SIZE` 防止恶意帧耗尽内存
二进制帧格式:`MAGIC(4B "CSM\0") + VERSION(1B) + TYPE(1B) + LENGTH(4B big-endian) + PAYLOAD(变长 JSON)`
## 代码逻辑
### 帧生命周期
```
发送方: T → Frame::new_json(mt, &data) → Frame::encode() → Vec<u8> → TCP stream
接收方: TCP bytes → Frame::decode(&buf) → Option<Frame> → Frame::decode_payload::<T>()
```
### MessageType 分块规划
| 范围 | 插件 | 方向 |
|------|------|------|
| 0x01-0x0F | Core心跳/注册/状态/资产) | 双向 |
| 0x10-0x1F | Core Server→Client策略/配置/任务) | S→C |
| 0x20-0x2F | Web Filter | C→S 日志, S→C 规则 |
| 0x30-0x3F | Usage Timer | C→S 报告 |
| 0x40-0x4F | Software Blocker | C→S 违规, S→C 黑名单 |
| 0x50-0x5F | Popup Blocker | C→S 统计, S→C 规则 |
| 0x60-0x6F | USB File Audit | C→S 操作记录 |
| 0x70-0x7F | Watermark + USB Policy | S→C 配置 |
| 0x80-0x8F | Plugin Control | S→C 启停命令 |
| 0x90-0x9F | Disk Encryption / Print / Clipboard | 混合 |
| 0xA0-0xAF | Patch Management | C→S 状态, S→C 配置 |
| 0xB0-0xBF | Behavior Metrics | C→S 指标 |
### 关键类型
- `Frame` — 帧结构version + msg_type + payload bytes
- `FrameError` — 解码错误枚举InvalidMagic / UnknownMessageType / PayloadTooLarge / Io
- 每个 MessageType 对应一个 payload struct`WebAccessLogEntry`, `HeartbeatPayload`
## 关联模块
- [[server]] — TCP 接入层调用 `Frame::decode()` 解析客户端帧,调用 `push_to_targets()` 推送配置帧
- [[client]] — 通过 `Frame::new_json()` 构造上报帧,通过 `Frame::decode()` 解析服务器下发的帧
- [[plugins]] — 每个插件定义自己的 payload struct 在此 crate 中
## 关键文件
| 文件 | 职责 |
|------|------|
| `crates/protocol/src/message.rs` | MessageType 枚举、Frame 编解码、所有 payload struct |
| `crates/protocol/src/device.rs` | DeviceStatus、ProcessInfo、HardwareAsset、UsbEvent 等设备相关类型 |
| `crates/protocol/src/lib.rs` | Re-export 所有公开类型 |

84
wiki/server.md Normal file
View File

@@ -0,0 +1,84 @@
# Server服务端
## 设计思想
`csm-server` 是整个系统的核心枢纽,同时承载三个协议:
1. **TCP 二进制协议** (端口 9999) — 接入 Client 代理
2. **HTTP REST API** (端口 9998) — 服务 Web 面板
3. **WebSocket** (`/ws`) — 实时推送设备状态变更到前端
关键设计决策:
- **SQLite + WAL** — 单机部署零依赖WAL 模式支持并发读写
- **include_dir 嵌入前端** — 编译时将 `web/dist/` 打包进二进制,部署只需一个文件
- **三层权限** — public登录/健康检查)→ authenticated只读→ admin写操作
- **ClientRegistry** — `Arc<RwLock<HashMap>>` 管理在线客户端的 TCP 写端,支持 `push_to_targets()` 三级作用域推送
## 代码逻辑
### 启动流程
```
main() → load config → init SQLite → run migrations → ensure admin
→ spawn TCP listener (9999)
→ spawn alert cleanup task
→ spawn health score task
→ build HTTP router (9998) with CORS/security headers/SPA fallback
→ axum::serve()
```
### TCP 接入层 (`tcp.rs`)
- `start_tcp_server()` — 监听 TCP每连接 spawn 一个 task
- `process_frame()` — 根据 MessageType 分发到对应 handler需先 verify_device_uid
- `ClientRegistry` — 线程安全的在线设备注册表,支持 `list_online()``send_frame()`
- `push_to_targets(db, clients, msg_type, payload, target_type, target_id)` — 三级作用域推送global/group/device
- 帧速率限制100 帧/5秒/连接
- HMAC 验证:心跳帧必须携带 HMAC-SHA256 签名,连续 3 次失败断开
- 空闲超时180 秒无数据断开
- 最大并发连接500
### HTTP API (`api/`)
路由分三层:
- **public**: `/api/auth/login`, `/api/auth/refresh`, `/health`
- **authenticated** (require_auth 中间件): GET 类设备/资产/告警/插件查询
- **admin** (require_admin + require_auth): 设备删除、策略增删改、插件配置写入
统一响应格式 `ApiResponse<T>``{ success, data, error }`,分页默认 page=1, page_size=20, 上限 100。
### WebSocket (`ws.rs`)
- `WsHub` 广播设备上线/离线/状态变更事件给所有连接的前端客户端
- JWT 认证通过 query parameter `?token=xxx`
### 后台任务
- `alert::cleanup_task()` — 定期清理过期告警
- `health::health_score_task()` — 定期计算设备健康评分
## 关联模块
- [[protocol]] — 使用 Frame 编解码和 MessageType 分发
- [[client]] — TCP 连接的对端
- [[web-frontend]] — HTTP API 和 WebSocket 的消费者
- [[plugins]] — API 层的 plugins/ 子模块处理所有插件相关路由
- [[database]] — 数据库操作集中在 db.rs
## 关键文件
| 文件 | 职责 |
|------|------|
| `crates/server/src/main.rs` | 启动入口、数据库初始化、迁移、路由组装、SPA fallback |
| `crates/server/src/tcp.rs` | TCP 监听、帧处理、ClientRegistry、push_to_targets |
| `crates/server/src/ws.rs` | WebSocket hub 广播 |
| `crates/server/src/api/mod.rs` | 路由定义、ApiResponse 信封、Pagination |
| `crates/server/src/api/auth.rs` | JWT 登录/刷新/改密、限流、require_auth/require_admin 中间件 |
| `crates/server/src/api/devices.rs` | 设备列表/详情/状态/历史/健康评分 API |
| `crates/server/src/api/plugins/mod.rs` | 插件路由注册read_routes + write_routes |
| `crates/server/src/api/plugins/*.rs` | 各插件 API handler每个插件一个文件 |
| `crates/server/src/db.rs` | DeviceRepo 数据库操作方法集合 |
| `crates/server/src/config.rs` | AppConfig TOML 配置加载 |
| `crates/server/src/health.rs` | 设备健康评分计算 |
| `crates/server/src/anomaly.rs` | 异常检测逻辑 |
| `crates/server/src/alert.rs` | 告警处理与清理 |
| `crates/server/src/audit.rs` | 审计日志 |

70
wiki/web-frontend.md Normal file
View File

@@ -0,0 +1,70 @@
# Web Frontend管理面板
## 设计思想
Vue 3 + TypeScript + Vite + Element Plus + Pinia + ECharts 的单页应用。关键决策:
1. **SPA 嵌入部署** — 构建产物 `web/dist/` 通过 `include_dir!` 编译进 server 二进制,部署零额外依赖
2. **JWT 本地存储** — token 存 `localStorage`,路由守卫检查过期(30 秒内即将过期视为无效)
3. **按路由懒加载** — 所有页面组件使用 `() => import(...)` 动态导入
## 代码逻辑
### 路由结构
```
/login → Login.vue公开
/ → Layout.vue认证后
/dashboard → Dashboard.vue仪表盘/健康概览)
/devices → Devices.vue设备列表
/devices/:uid → DeviceDetail.vue设备详情
/usb → UsbPolicy.vueU盘策略管理
/alerts → Alerts.vue告警管理
/settings → Settings.vue系统设置
/plugins/web-filter → WebFilter.vue
/plugins/usage-timer → UsageTimer.vue
/plugins/software-blocker → SoftwareBlocker.vue
/plugins/popup-blocker → PopupBlocker.vue
/plugins/usb-file-audit → UsbFileAudit.vue
/plugins/watermark → Watermark.vue
/plugins/disk-encryption → DiskEncryption.vue
/plugins/print-audit → PrintAudit.vue
/plugins/clipboard-control → ClipboardControl.vue
/plugins/plugin-control → PluginControl.vue
```
### 认证流程
1. Login.vue → `POST /api/auth/login` → 获取 access_token + refresh_token
2. token 存入 localStorage
3. 路由守卫 `beforeEach` 检查 JWT 过期(解析 payload.exp
4. API 调用携带 `Authorization: Bearer <token>` header
5. token 过期 → 自动跳转 /login
### API 通信
`web/src/lib/api.ts` — 封装所有 API 调用,统一处理认证和错误。
### 状态管理
`web/src/stores/devices.ts` — Pinia store 管理设备列表状态。
## 关联模块
- [[server]] — 消费其 HTTP REST API 和 WebSocket 推送
- [[plugins]] — 每个插件页面对应 server 端的插件 API
## 关键文件
| 文件 | 职责 |
|------|------|
| `web/src/main.ts` | 应用入口、Vue 实例创建 |
| `web/src/App.vue` | 根组件 |
| `web/src/router/index.ts` | 路由定义、JWT 路由守卫 |
| `web/src/lib/api.ts` | API 通信封装 |
| `web/src/stores/devices.ts` | Pinia 设备状态管理 |
| `web/src/views/Layout.vue` | 主布局(侧边栏+内容区) |
| `web/src/views/Dashboard.vue` | 仪表盘页 |
| `web/src/views/Devices.vue` | 设备列表页 |
| `web/src/views/DeviceDetail.vue` | 设备详情页 |
| `web/src/views/plugins/*.vue` | 各插件管理页面 |