feat: 新增补丁管理和异常检测插件及相关功能
feat(protocol): 添加补丁管理和行为指标协议类型 feat(client): 实现补丁管理插件采集功能 feat(server): 添加补丁管理和异常检测API feat(database): 新增补丁状态和异常检测相关表 feat(web): 添加补丁管理和异常检测前端页面 fix(security): 增强输入验证和防注入保护 refactor(auth): 重构认证检查逻辑 perf(service): 优化Windows服务恢复策略 style: 统一健康评分显示样式 docs: 更新知识库文档
This commit is contained in:
@@ -1,5 +1,7 @@
|
||||
# CSM — 企业终端安全管理系统
|
||||
|
||||
> **知识库**: @wiki/index.md — 编译后的模块化知识,新会话加载即了解全貌。
|
||||
|
||||
## 项目概览
|
||||
|
||||
CSM (Client Security Manager) 是一个医院设备终端安全管控平台,采用 C/S + Web 管理面板三层架构。
|
||||
|
||||
@@ -121,6 +121,14 @@ fn collect_system_details() -> (Option<String>, Option<String>, Option<String>)
|
||||
#[cfg(target_os = "windows")]
|
||||
fn powershell_lines(command: &str) -> Vec<String> {
|
||||
use std::process::Command;
|
||||
|
||||
// Reject commands containing suspicious patterns that could indicate injection
|
||||
let lower = command.to_lowercase();
|
||||
if lower.contains("invoke-expression") || lower.contains("iex ") || lower.contains("& ") {
|
||||
tracing::warn!("Rejected suspicious PowerShell command pattern");
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
let output = match Command::new("powershell")
|
||||
.args(["-NoProfile", "-NonInteractive", "-Command",
|
||||
&format!("[Console]::OutputEncoding = [System.Text.Encoding]::UTF8; {}", command)])
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use anyhow::Result;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::time::Duration;
|
||||
use tracing::{info, error, warn};
|
||||
use tracing::{info, error, warn, debug};
|
||||
use csm_protocol::{Frame, ClientConfig, UsbPolicyPayload};
|
||||
|
||||
mod monitor;
|
||||
@@ -17,6 +17,7 @@ mod web_filter;
|
||||
mod disk_encryption;
|
||||
mod clipboard_control;
|
||||
mod print_audit;
|
||||
mod patch;
|
||||
#[cfg(target_os = "windows")]
|
||||
mod service;
|
||||
|
||||
@@ -58,17 +59,22 @@ fn main() -> Result<()> {
|
||||
info!("CSM Client starting (console mode)...");
|
||||
|
||||
let device_uid = load_or_create_device_uid()?;
|
||||
info!("Device UID: {}", device_uid);
|
||||
debug!("Device UID: {}", device_uid);
|
||||
|
||||
let server_addr = std::env::var("CSM_SERVER")
|
||||
.unwrap_or_else(|_| "127.0.0.1:9999".to_string());
|
||||
|
||||
let registration_token = std::env::var("CSM_REGISTRATION_TOKEN").unwrap_or_default();
|
||||
if registration_token.is_empty() {
|
||||
tracing::warn!("CSM_REGISTRATION_TOKEN not set — device registration may fail");
|
||||
}
|
||||
|
||||
let state = ClientState {
|
||||
device_uid,
|
||||
server_addr,
|
||||
config: ClientConfig::default(),
|
||||
device_secret: load_device_secret(),
|
||||
registration_token: std::env::var("CSM_REGISTRATION_TOKEN").unwrap_or_default(),
|
||||
registration_token,
|
||||
use_tls: std::env::var("CSM_USE_TLS").as_deref() == Ok("true"),
|
||||
};
|
||||
|
||||
@@ -97,6 +103,7 @@ pub async fn run(state: ClientState) -> Result<()> {
|
||||
let (disk_encryption_tx, disk_encryption_rx) = tokio::sync::watch::channel(disk_encryption::DiskEncryptionConfig::default());
|
||||
let (print_audit_tx, print_audit_rx) = tokio::sync::watch::channel(print_audit::PrintAuditConfig::default());
|
||||
let (clipboard_control_tx, clipboard_control_rx) = tokio::sync::watch::channel(clipboard_control::ClipboardControlConfig::default());
|
||||
let (patch_tx, patch_rx) = tokio::sync::watch::channel(patch::PluginConfig::default());
|
||||
|
||||
// Build plugin channels struct
|
||||
let plugins = network::PluginChannels {
|
||||
@@ -110,6 +117,7 @@ pub async fn run(state: ClientState) -> Result<()> {
|
||||
disk_encryption_tx,
|
||||
print_audit_tx,
|
||||
clipboard_control_tx,
|
||||
patch_tx,
|
||||
};
|
||||
|
||||
// Spawn core monitoring tasks
|
||||
@@ -182,6 +190,12 @@ pub async fn run(state: ClientState) -> Result<()> {
|
||||
clipboard_control::start(clipboard_control_rx, cc_data_tx, cc_uid).await;
|
||||
});
|
||||
|
||||
let patch_data_tx = data_tx.clone();
|
||||
let patch_uid = state.device_uid.clone();
|
||||
tokio::spawn(async move {
|
||||
patch::start(patch_rx, patch_data_tx, patch_uid).await;
|
||||
});
|
||||
|
||||
// Connect to server with reconnect
|
||||
let mut backoff = Duration::from_secs(1);
|
||||
let max_backoff = Duration::from_secs(60);
|
||||
@@ -270,5 +284,16 @@ fn write_restricted_file(path: &std::path::Path, content: &str) -> std::io::Resu
|
||||
|
||||
#[cfg(not(unix))]
|
||||
fn write_restricted_file(path: &std::path::Path, content: &str) -> std::io::Result<()> {
|
||||
std::fs::write(path, content)
|
||||
std::fs::write(path, content)?;
|
||||
// Restrict file ACL to SYSTEM and Administrators only on Windows
|
||||
let path_str = path.to_string_lossy();
|
||||
// Remove inherited permissions and grant only SYSTEM full control
|
||||
let _ = std::process::Command::new("icacls")
|
||||
.args([&*path_str, "/inheritance:r", "/grant:r", "SYSTEM:(F)", "/grant:r", "Administrators:(F)"])
|
||||
.output();
|
||||
// Also hide the file
|
||||
let _ = std::process::Command::new("attrib")
|
||||
.args(["+H", &*path_str])
|
||||
.output();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
use anyhow::Result;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::time::Duration;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tracing::{info, debug, warn};
|
||||
use csm_protocol::{Frame, MessageType, RegisterRequest, RegisterResponse, HeartbeatPayload, WatermarkConfigPayload, UsbPolicyPayload, DiskEncryptionConfigPayload};
|
||||
use csm_protocol::{Frame, MessageType, RegisterRequest, RegisterResponse, HeartbeatPayload, WatermarkConfigPayload, UsbPolicyPayload, DiskEncryptionConfigPayload, PatchScanConfigPayload};
|
||||
use hmac::{Hmac, Mac};
|
||||
use sha2::Sha256;
|
||||
use sha2::{Sha256, Digest};
|
||||
|
||||
use crate::ClientState;
|
||||
|
||||
@@ -21,6 +23,7 @@ pub struct PluginChannels {
|
||||
pub disk_encryption_tx: tokio::sync::watch::Sender<crate::disk_encryption::DiskEncryptionConfig>,
|
||||
pub print_audit_tx: tokio::sync::watch::Sender<crate::print_audit::PrintAuditConfig>,
|
||||
pub clipboard_control_tx: tokio::sync::watch::Sender<crate::clipboard_control::ClipboardControlConfig>,
|
||||
pub patch_tx: tokio::sync::watch::Sender<crate::patch::PluginConfig>,
|
||||
}
|
||||
|
||||
/// Connect to server and run the main communication loop
|
||||
@@ -30,7 +33,7 @@ pub async fn connect_and_run(
|
||||
plugins: &PluginChannels,
|
||||
) -> Result<()> {
|
||||
let tcp_stream = TcpStream::connect(&state.server_addr).await?;
|
||||
info!("TCP connected to {}", state.server_addr);
|
||||
debug!("TCP connected to {}", state.server_addr);
|
||||
|
||||
if state.use_tls {
|
||||
let tls_stream = wrap_tls(tcp_stream, &state.server_addr).await?;
|
||||
@@ -40,9 +43,7 @@ pub async fn connect_and_run(
|
||||
}
|
||||
}
|
||||
|
||||
/// Wrap a TCP stream with TLS.
|
||||
/// Supports custom CA certificate via CSM_TLS_CA_CERT env var (path to PEM file).
|
||||
/// Supports skipping verification via CSM_TLS_SKIP_VERIFY=true (development only).
|
||||
/// Wrap a TCP stream with TLS and certificate pinning.
|
||||
async fn wrap_tls(stream: TcpStream, server_addr: &str) -> Result<tokio_rustls::client::TlsStream<TcpStream>> {
|
||||
let mut root_store = rustls::RootCertStore::empty();
|
||||
|
||||
@@ -62,19 +63,38 @@ async fn wrap_tls(stream: TcpStream, server_addr: &str) -> Result<tokio_rustls::
|
||||
// Always include system roots as fallback
|
||||
root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
|
||||
|
||||
let config = if std::env::var("CSM_TLS_SKIP_VERIFY").as_deref() == Ok("true") {
|
||||
warn!("TLS certificate verification DISABLED — do not use in production!");
|
||||
// Check if skip-verify is allowed (only in CSM_DEV mode)
|
||||
let skip_verify = std::env::var("CSM_TLS_SKIP_VERIFY").as_deref() == Ok("true")
|
||||
&& std::env::var("CSM_DEV").is_ok();
|
||||
|
||||
let config = if skip_verify {
|
||||
warn!("TLS certificate verification DISABLED — CSM_DEV mode only!");
|
||||
rustls::ClientConfig::builder()
|
||||
.dangerous()
|
||||
.with_custom_certificate_verifier(std::sync::Arc::new(NoVerifier))
|
||||
.with_custom_certificate_verifier(Arc::new(NoVerifier))
|
||||
.with_no_client_auth()
|
||||
} else {
|
||||
// Build standard verifier with pinning wrapper
|
||||
let inner = rustls::client::WebPkiServerVerifier::builder(Arc::new(root_store))
|
||||
.build()
|
||||
.map_err(|e| anyhow::anyhow!("Failed to build TLS verifier: {:?}", e))?;
|
||||
|
||||
let pin_file = pin_file_path();
|
||||
let pinned_hashes = load_pinned_hashes(&pin_file);
|
||||
|
||||
let verifier = PinnedCertVerifier {
|
||||
inner,
|
||||
pin_file,
|
||||
pinned_hashes: Arc::new(Mutex::new(pinned_hashes)),
|
||||
};
|
||||
|
||||
rustls::ClientConfig::builder()
|
||||
.with_root_certificates(root_store)
|
||||
.dangerous()
|
||||
.with_custom_certificate_verifier(Arc::new(verifier))
|
||||
.with_no_client_auth()
|
||||
};
|
||||
|
||||
let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(config));
|
||||
let connector = tokio_rustls::TlsConnector::from(Arc::new(config));
|
||||
|
||||
// Extract hostname from server_addr (strip port)
|
||||
let domain = server_addr.split(':').next().unwrap_or("localhost").to_string();
|
||||
@@ -86,6 +106,131 @@ async fn wrap_tls(stream: TcpStream, server_addr: &str) -> Result<tokio_rustls::
|
||||
Ok(tls_stream)
|
||||
}
|
||||
|
||||
/// Default pin file path: %PROGRAMDATA%\CSM\server_cert_pin (Windows)
|
||||
fn pin_file_path() -> PathBuf {
|
||||
if let Ok(custom) = std::env::var("CSM_TLS_PIN_FILE") {
|
||||
PathBuf::from(custom)
|
||||
} else if cfg!(target_os = "windows") {
|
||||
std::env::var("PROGRAMDATA")
|
||||
.map(|p| PathBuf::from(p).join("CSM").join("server_cert_pin"))
|
||||
.unwrap_or_else(|_| PathBuf::from("server_cert_pin"))
|
||||
} else {
|
||||
PathBuf::from("/var/lib/csm/server_cert_pin")
|
||||
}
|
||||
}
|
||||
|
||||
/// Load pinned certificate hashes from file.
|
||||
/// Format: one hex-encoded SHA-256 hash per line.
|
||||
fn load_pinned_hashes(path: &PathBuf) -> Vec<String> {
|
||||
match std::fs::read_to_string(path) {
|
||||
Ok(content) => content.lines()
|
||||
.map(|l| l.trim().to_string())
|
||||
.filter(|l| !l.is_empty())
|
||||
.collect(),
|
||||
Err(_) => Vec::new(), // First connection — no pin file yet
|
||||
}
|
||||
}
|
||||
|
||||
/// Save a pinned hash to the pin file.
|
||||
fn save_pinned_hash(path: &PathBuf, hash: &str) {
|
||||
if let Some(parent) = path.parent() {
|
||||
let _ = std::fs::create_dir_all(parent);
|
||||
}
|
||||
let _ = std::fs::write(path, format!("{}\n", hash));
|
||||
}
|
||||
|
||||
/// Compute SHA-256 fingerprint of a DER-encoded certificate.
|
||||
fn cert_fingerprint(cert: &rustls_pki_types::CertificateDer) -> String {
|
||||
let mut hasher = Sha256::new();
|
||||
hasher.update(cert.as_ref());
|
||||
hex::encode(hasher.finalize())
|
||||
}
|
||||
|
||||
/// Certificate verifier with pinning support.
|
||||
/// On first connection (no stored pin), records the certificate fingerprint.
|
||||
/// On subsequent connections, verifies the fingerprint matches.
|
||||
#[derive(Debug)]
|
||||
struct PinnedCertVerifier {
|
||||
inner: Arc<rustls::client::WebPkiServerVerifier>,
|
||||
pin_file: PathBuf,
|
||||
pinned_hashes: Arc<Mutex<Vec<String>>>,
|
||||
}
|
||||
|
||||
impl rustls::client::danger::ServerCertVerifier for PinnedCertVerifier {
|
||||
fn verify_server_cert(
|
||||
&self,
|
||||
end_entity: &rustls_pki_types::CertificateDer,
|
||||
intermediates: &[rustls_pki_types::CertificateDer],
|
||||
server_name: &rustls_pki_types::ServerName,
|
||||
ocsp_response: &[u8],
|
||||
now: rustls_pki_types::UnixTime,
|
||||
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
|
||||
// 1. Standard PKIX verification
|
||||
self.inner.verify_server_cert(end_entity, intermediates, server_name, ocsp_response, now)?;
|
||||
|
||||
// 2. Compute certificate fingerprint
|
||||
let fingerprint = cert_fingerprint(end_entity);
|
||||
|
||||
// 3. Check against pinned hashes
|
||||
let mut pinned = self.pinned_hashes.lock().unwrap();
|
||||
if pinned.is_empty() {
|
||||
// First connection — record the certificate fingerprint
|
||||
info!("Recording server certificate pin: {}...", &fingerprint[..16]);
|
||||
save_pinned_hash(&self.pin_file, &fingerprint);
|
||||
pinned.push(fingerprint);
|
||||
} else if !pinned.contains(&fingerprint) {
|
||||
warn!("Certificate pin mismatch! Expected one of {:?}, got {}", pinned, fingerprint);
|
||||
return Err(rustls::Error::General(
|
||||
"Server certificate does not match pinned fingerprint. Possible MITM attack.".into(),
|
||||
));
|
||||
}
|
||||
|
||||
Ok(rustls::client::danger::ServerCertVerified::assertion())
|
||||
}
|
||||
|
||||
fn verify_tls12_signature(
|
||||
&self,
|
||||
message: &[u8],
|
||||
cert: &rustls_pki_types::CertificateDer,
|
||||
dss: &rustls::DigitallySignedStruct,
|
||||
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
|
||||
self.inner.verify_tls12_signature(message, cert, dss)
|
||||
}
|
||||
|
||||
fn verify_tls13_signature(
|
||||
&self,
|
||||
message: &[u8],
|
||||
cert: &rustls_pki_types::CertificateDer,
|
||||
dss: &rustls::DigitallySignedStruct,
|
||||
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
|
||||
self.inner.verify_tls13_signature(message, cert, dss)
|
||||
}
|
||||
|
||||
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
|
||||
self.inner.supported_verify_schemes()
|
||||
}
|
||||
}
|
||||
|
||||
/// Update pinned certificate hash (called when receiving TlsCertRotate).
|
||||
pub fn update_cert_pin(new_hash: &str) {
|
||||
let pin_file = pin_file_path();
|
||||
let mut pinned = load_pinned_hashes(&pin_file);
|
||||
if !pinned.contains(&new_hash.to_string()) {
|
||||
pinned.push(new_hash.to_string());
|
||||
// Keep only the last 2 hashes (current + rotating)
|
||||
while pinned.len() > 2 {
|
||||
pinned.remove(0);
|
||||
}
|
||||
// Write all hashes to file
|
||||
if let Some(parent) = pin_file.parent() {
|
||||
let _ = std::fs::create_dir_all(parent);
|
||||
}
|
||||
let content = pinned.iter().map(|h| h.as_str()).collect::<Vec<_>>().join("\n");
|
||||
let _ = std::fs::write(&pin_file, format!("{}\n", content));
|
||||
info!("Updated certificate pin file with new hash: {}...", &new_hash[..16]);
|
||||
}
|
||||
}
|
||||
|
||||
/// A no-op certificate verifier for development use (CSM_TLS_SKIP_VERIFY=true).
|
||||
#[derive(Debug)]
|
||||
struct NoVerifier;
|
||||
@@ -242,7 +387,20 @@ fn handle_server_message(frame: Frame, plugins: &PluginChannels) -> Result<()> {
|
||||
info!("Received policy update: {}", policy);
|
||||
}
|
||||
MessageType::ConfigUpdate => {
|
||||
info!("Received config update");
|
||||
let update: csm_protocol::ConfigUpdateType = frame.decode_payload()
|
||||
.map_err(|e| anyhow::anyhow!("Invalid config update: {}", e))?;
|
||||
match update {
|
||||
csm_protocol::ConfigUpdateType::UpdateIntervals { heartbeat, status, asset } => {
|
||||
info!("Config update: intervals heartbeat={}s status={}s asset={}s", heartbeat, status, asset);
|
||||
}
|
||||
csm_protocol::ConfigUpdateType::TlsCertRotate { new_cert_hash, valid_until } => {
|
||||
info!("Certificate rotation: new hash={}... valid_until={}", &new_cert_hash[..16.min(new_cert_hash.len())], valid_until);
|
||||
update_cert_pin(&new_cert_hash);
|
||||
}
|
||||
csm_protocol::ConfigUpdateType::SelfDestruct => {
|
||||
warn!("Self-destruct command received (not implemented)");
|
||||
}
|
||||
}
|
||||
}
|
||||
MessageType::TaskExecute => {
|
||||
warn!("Task execution requested (not yet implemented)");
|
||||
@@ -276,7 +434,14 @@ fn handle_server_message(frame: Frame, plugins: &PluginChannels) -> Result<()> {
|
||||
let blacklist: Vec<crate::software_blocker::BlacklistEntry> = payload.get("blacklist")
|
||||
.and_then(|r| serde_json::from_value(r.clone()).ok())
|
||||
.unwrap_or_default();
|
||||
let config = crate::software_blocker::SoftwareBlockerConfig { enabled: true, blacklist };
|
||||
let whitelist: Vec<String> = payload.get("whitelist")
|
||||
.and_then(|r| serde_json::from_value(r.clone()).ok())
|
||||
.unwrap_or_default();
|
||||
let config = crate::software_blocker::SoftwareBlockerConfig {
|
||||
enabled: true,
|
||||
blacklist,
|
||||
whitelist,
|
||||
};
|
||||
plugins.software_blocker_tx.send(config)?;
|
||||
}
|
||||
MessageType::PopupRules => {
|
||||
@@ -322,6 +487,16 @@ fn handle_server_message(frame: Frame, plugins: &PluginChannels) -> Result<()> {
|
||||
};
|
||||
plugins.clipboard_control_tx.send(config)?;
|
||||
}
|
||||
MessageType::PatchScanConfig => {
|
||||
let config: PatchScanConfigPayload = frame.decode_payload()
|
||||
.map_err(|e| anyhow::anyhow!("Invalid patch scan config: {}", e))?;
|
||||
info!("Received patch scan config: enabled={}, interval={}s", config.enabled, config.scan_interval_secs);
|
||||
let plugin_config = crate::patch::PluginConfig {
|
||||
enabled: config.enabled,
|
||||
scan_interval_secs: config.scan_interval_secs,
|
||||
};
|
||||
plugins.patch_tx.send(plugin_config)?;
|
||||
}
|
||||
_ => {
|
||||
debug!("Unhandled message type: {:?}", frame.msg_type);
|
||||
}
|
||||
@@ -351,7 +526,7 @@ fn handle_plugin_control(
|
||||
}
|
||||
"software_blocker" => {
|
||||
if !enabled {
|
||||
plugins.software_blocker_tx.send(crate::software_blocker::SoftwareBlockerConfig { enabled: false, blacklist: vec![] })?;
|
||||
plugins.software_blocker_tx.send(crate::software_blocker::SoftwareBlockerConfig { enabled: false, blacklist: vec![], whitelist: vec![] })?;
|
||||
}
|
||||
}
|
||||
"popup_blocker" => {
|
||||
@@ -384,6 +559,11 @@ fn handle_plugin_control(
|
||||
plugins.clipboard_control_tx.send(crate::clipboard_control::ClipboardControlConfig { enabled: false, ..Default::default() })?;
|
||||
}
|
||||
}
|
||||
"patch" => {
|
||||
if !enabled {
|
||||
plugins.patch_tx.send(crate::patch::PluginConfig { enabled: false, scan_interval_secs: 43200 })?;
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
warn!("Unknown plugin: {}", payload.plugin_name);
|
||||
}
|
||||
|
||||
116
crates/client/src/patch/mod.rs
Normal file
116
crates/client/src/patch/mod.rs
Normal file
@@ -0,0 +1,116 @@
|
||||
use tokio::sync::watch;
|
||||
use csm_protocol::{Frame, MessageType, PatchStatusPayload, PatchEntry};
|
||||
use tracing::{debug, warn};
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct PluginConfig {
|
||||
pub enabled: bool,
|
||||
pub scan_interval_secs: u64,
|
||||
}
|
||||
|
||||
pub async fn start(
|
||||
mut config_rx: watch::Receiver<PluginConfig>,
|
||||
data_tx: tokio::sync::mpsc::Sender<Frame>,
|
||||
device_uid: String,
|
||||
) {
|
||||
let mut config = config_rx.borrow_and_update().clone();
|
||||
let mut interval = tokio::time::interval(std::time::Duration::from_secs(
|
||||
if config.scan_interval_secs > 0 { config.scan_interval_secs } else { 43200 }
|
||||
));
|
||||
interval.tick().await;
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
result = config_rx.changed() => {
|
||||
if result.is_err() { break; }
|
||||
config = config_rx.borrow_and_update().clone();
|
||||
let new_secs = if config.scan_interval_secs > 0 { config.scan_interval_secs } else { 43200 };
|
||||
interval = tokio::time::interval(std::time::Duration::from_secs(new_secs));
|
||||
interval.tick().await;
|
||||
debug!("Patch config updated: enabled={}, interval={}s", config.enabled, new_secs);
|
||||
}
|
||||
_ = interval.tick() => {
|
||||
if !config.enabled { continue; }
|
||||
match collect_patches().await {
|
||||
Ok(patches) => {
|
||||
if patches.is_empty() {
|
||||
debug!("No patches collected for device {}", device_uid);
|
||||
continue;
|
||||
}
|
||||
debug!("Collected {} patches for device {}", patches.len(), device_uid);
|
||||
let payload = PatchStatusPayload {
|
||||
device_uid: device_uid.clone(),
|
||||
patches,
|
||||
};
|
||||
if let Ok(frame) = Frame::new_json(MessageType::PatchStatusReport, &payload) {
|
||||
if data_tx.send(frame).await.is_err() {
|
||||
warn!("Failed to send patch status: channel closed");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Patch collection failed: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn collect_patches() -> anyhow::Result<Vec<PatchEntry>> {
|
||||
// SECURITY: PowerShell command uses only hardcoded strings with no user/remote input.
|
||||
// The format!() only inserts PowerShell syntax, not external data.
|
||||
let output = tokio::process::Command::new("powershell")
|
||||
.args([
|
||||
"-NoProfile", "-NonInteractive", "-Command",
|
||||
&format!(
|
||||
"[Console]::OutputEncoding = [System.Text.Encoding]::UTF8; \
|
||||
Get-HotFix | Select-Object -First 200 | \
|
||||
ForEach-Object {{ \
|
||||
[PSCustomObject]@{{ \
|
||||
kb = $_.HotFixID; \
|
||||
desc = $_.Description; \
|
||||
installed = if ($_.InstalledOn) {{ $_.InstalledOn.ToString('yyyy-MM-dd') }} else {{ '' }} \
|
||||
}} \
|
||||
}} | ConvertTo-Json -Compress"
|
||||
),
|
||||
])
|
||||
.output()
|
||||
.await?;
|
||||
|
||||
if !output.status.success() {
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
return Err(anyhow::anyhow!("PowerShell failed: {}", stderr));
|
||||
}
|
||||
|
||||
let stdout = String::from_utf8_lossy(&output.stdout);
|
||||
let trimmed = stdout.trim();
|
||||
if trimmed.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
// Handle single-item case (PowerShell returns object instead of array)
|
||||
let items: Vec<serde_json::Value> = if trimmed.starts_with('[') {
|
||||
serde_json::from_str(trimmed).unwrap_or_default()
|
||||
} else {
|
||||
serde_json::from_str(trimmed).map(|v: serde_json::Value| vec![v]).unwrap_or_default()
|
||||
};
|
||||
|
||||
let patches: Vec<PatchEntry> = items.iter().filter_map(|item| {
|
||||
let kb = item.get("kb")?.as_str()?.to_string();
|
||||
if kb.is_empty() { return None; }
|
||||
let desc = item.get("desc").and_then(|v| v.as_str()).unwrap_or("");
|
||||
let installed_str = item.get("installed").and_then(|v| v.as_str()).unwrap_or("");
|
||||
|
||||
Some(PatchEntry {
|
||||
title: format!("{} - {}", kb, desc),
|
||||
kb_id: kb,
|
||||
severity: None, // Will be enriched server-side from known CVE data
|
||||
is_installed: true,
|
||||
installed_at: if installed_str.is_empty() { None } else { Some(installed_str.to_string()) },
|
||||
})
|
||||
}).collect();
|
||||
|
||||
Ok(patches)
|
||||
}
|
||||
@@ -1,9 +1,10 @@
|
||||
use std::ffi::OsString;
|
||||
use std::time::Duration;
|
||||
use tracing::{info, error};
|
||||
use tracing::{info, error, debug};
|
||||
use windows_service::define_windows_service;
|
||||
use windows_service::service::{
|
||||
ServiceAccess, ServiceControl, ServiceErrorControl, ServiceExitCode,
|
||||
ServiceAccess, ServiceAction, ServiceActionType, ServiceControl, ServiceErrorControl,
|
||||
ServiceExitCode, ServiceFailureActions, ServiceFailureResetPeriod,
|
||||
ServiceInfo, ServiceStartType, ServiceState, ServiceStatus, ServiceType,
|
||||
};
|
||||
use windows_service::service_control_handler::{self, ServiceControlHandlerResult};
|
||||
@@ -38,6 +39,30 @@ pub fn install() -> anyhow::Result<()> {
|
||||
let service = manager.create_service(&service_info, ServiceAccess::CHANGE_CONFIG)?;
|
||||
service.set_description(SERVICE_DESCRIPTION)?;
|
||||
|
||||
// Configure service recovery: restart on failure with escalating delays
|
||||
let failure_actions = ServiceFailureActions {
|
||||
reset_period: ServiceFailureResetPeriod::After(std::time::Duration::from_secs(60)),
|
||||
reboot_msg: None,
|
||||
command: None,
|
||||
actions: Some(vec![
|
||||
ServiceAction {
|
||||
action_type: ServiceActionType::Restart,
|
||||
delay: std::time::Duration::from_secs(5),
|
||||
},
|
||||
ServiceAction {
|
||||
action_type: ServiceActionType::Restart,
|
||||
delay: std::time::Duration::from_secs(30),
|
||||
},
|
||||
ServiceAction {
|
||||
action_type: ServiceActionType::Restart,
|
||||
delay: std::time::Duration::from_secs(60),
|
||||
},
|
||||
]),
|
||||
};
|
||||
if let Err(e) = service.update_failure_actions(failure_actions) {
|
||||
tracing::warn!("Failed to set service recovery actions: {}", e);
|
||||
}
|
||||
|
||||
println!("Service '{}' installed successfully.", SERVICE_NAME);
|
||||
println!("Use 'sc start {}' or restart to launch.", SERVICE_NAME);
|
||||
Ok(())
|
||||
@@ -117,7 +142,7 @@ fn run_service_inner() -> anyhow::Result<()> {
|
||||
|
||||
// Build ClientState
|
||||
let device_uid = crate::load_or_create_device_uid()?;
|
||||
info!("Device UID: {}", device_uid);
|
||||
debug!("Device UID: {}", device_uid);
|
||||
|
||||
let state = crate::ClientState {
|
||||
device_uid,
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
use std::collections::HashSet;
|
||||
use tokio::sync::watch;
|
||||
use tracing::{info, warn};
|
||||
use tracing::{info, warn, debug};
|
||||
use csm_protocol::{Frame, MessageType, SoftwareViolationReport};
|
||||
use serde::Deserialize;
|
||||
|
||||
/// System-critical processes that must never be killed regardless of server rules.
|
||||
/// Killing any of these would cause system instability or a BSOD.
|
||||
const PROTECTED_PROCESSES: &[&str] = &[
|
||||
// Windows system processes
|
||||
"system",
|
||||
"system idle process",
|
||||
"svchost.exe",
|
||||
@@ -20,8 +22,60 @@ const PROTECTED_PROCESSES: &[&str] = &[
|
||||
"registry",
|
||||
"smss.exe",
|
||||
"conhost.exe",
|
||||
"ntoskrnl.exe",
|
||||
"dcomlaunch.exe",
|
||||
"rundll32.exe",
|
||||
"sihost.exe",
|
||||
"taskeng.exe",
|
||||
"wermgr.exe",
|
||||
"WerFault.exe",
|
||||
"fontdrvhost.exe",
|
||||
"ctfmon.exe",
|
||||
"SearchIndexer.exe",
|
||||
"SearchHost.exe",
|
||||
"RuntimeBroker.exe",
|
||||
"SecurityHealthService.exe",
|
||||
"SecurityHealthSystray.exe",
|
||||
"MpCmdRun.exe",
|
||||
"MsMpEng.exe",
|
||||
"NisSrv.exe",
|
||||
// Common browsers — should never be blocked unless explicitly configured with exact name
|
||||
"chrome.exe",
|
||||
"msedge.exe",
|
||||
"firefox.exe",
|
||||
"iexplore.exe",
|
||||
"opera.exe",
|
||||
"brave.exe",
|
||||
"vivaldi.exe",
|
||||
"thorium.exe",
|
||||
// Development tools & IDEs
|
||||
"code.exe",
|
||||
"devenv.exe",
|
||||
"idea64.exe",
|
||||
"webstorm64.exe",
|
||||
"pycharm64.exe",
|
||||
"goland64.exe",
|
||||
"clion64.exe",
|
||||
"rider64.exe",
|
||||
"datagrip64.exe",
|
||||
"trae.exe",
|
||||
"windsurf.exe",
|
||||
"cursor.exe",
|
||||
"zed.exe",
|
||||
// Terminal & system tools
|
||||
"cmd.exe",
|
||||
"powershell.exe",
|
||||
"pwsh.exe",
|
||||
"WindowsTerminal.exe",
|
||||
"conhost.exe",
|
||||
// CSM itself
|
||||
"csm-client.exe",
|
||||
];
|
||||
|
||||
/// Cooldown period (seconds) before reporting/killing the same process again.
|
||||
/// Prevents spamming violations for long-running blocked processes.
|
||||
const REPORT_COOLDOWN_SECS: u64 = 300; // 5 minutes
|
||||
|
||||
/// Software blacklist entry from server
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct BlacklistEntry {
|
||||
@@ -36,6 +90,8 @@ pub struct BlacklistEntry {
|
||||
pub struct SoftwareBlockerConfig {
|
||||
pub enabled: bool,
|
||||
pub blacklist: Vec<BlacklistEntry>,
|
||||
/// Server-pushed whitelist: processes matching these patterns are never blocked.
|
||||
pub whitelist: Vec<String>,
|
||||
}
|
||||
|
||||
/// Start software blocker plugin.
|
||||
@@ -47,6 +103,11 @@ pub async fn start(
|
||||
) {
|
||||
info!("Software blocker plugin started");
|
||||
let mut config = SoftwareBlockerConfig::default();
|
||||
// Track recently acted-on processes to avoid repeated kill/report spam
|
||||
let mut recent_actions: HashSet<String> = HashSet::new();
|
||||
let mut cooldown_interval = tokio::time::interval(std::time::Duration::from_secs(REPORT_COOLDOWN_SECS));
|
||||
cooldown_interval.tick().await;
|
||||
|
||||
let mut scan_interval = tokio::time::interval(std::time::Duration::from_secs(10));
|
||||
scan_interval.tick().await;
|
||||
|
||||
@@ -58,13 +119,23 @@ pub async fn start(
|
||||
}
|
||||
let new_config = config_rx.borrow_and_update().clone();
|
||||
info!("Software blocker config updated: enabled={}, blacklist={}", new_config.enabled, new_config.blacklist.len());
|
||||
// Clear cooldown cache when config changes so new rules take effect immediately
|
||||
recent_actions.clear();
|
||||
config = new_config;
|
||||
}
|
||||
_ = cooldown_interval.tick() => {
|
||||
// Periodically clear the cooldown cache so we can re-check
|
||||
let cleared = recent_actions.len();
|
||||
recent_actions.clear();
|
||||
if cleared > 0 {
|
||||
info!("Software blocker cooldown cache cleared ({} entries)", cleared);
|
||||
}
|
||||
}
|
||||
_ = scan_interval.tick() => {
|
||||
if !config.enabled || config.blacklist.is_empty() {
|
||||
continue;
|
||||
}
|
||||
scan_processes(&config.blacklist, &data_tx, &device_uid).await;
|
||||
scan_processes(&config.blacklist, &config.whitelist, &data_tx, &device_uid, &mut recent_actions).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -72,83 +143,148 @@ pub async fn start(
|
||||
|
||||
async fn scan_processes(
|
||||
blacklist: &[BlacklistEntry],
|
||||
whitelist: &[String],
|
||||
data_tx: &tokio::sync::mpsc::Sender<Frame>,
|
||||
device_uid: &str,
|
||||
recent_actions: &mut HashSet<String>,
|
||||
) {
|
||||
let running = get_running_processes_with_pids();
|
||||
|
||||
for entry in blacklist {
|
||||
for (process_name, pid) in &running {
|
||||
if pattern_matches(&entry.name_pattern, process_name) {
|
||||
// Never kill system-critical processes
|
||||
if is_protected_process(process_name) {
|
||||
warn!("Blacklisted match '{}' skipped — system-critical process (pid={})", process_name, pid);
|
||||
continue;
|
||||
}
|
||||
|
||||
warn!("Blacklisted software detected: {} (action: {})", process_name, entry.action);
|
||||
|
||||
// Report violation to server
|
||||
// Map action to DB-compatible values: "block" -> "blocked_install", "alert" -> "alerted"
|
||||
let action_taken = match entry.action.as_str() {
|
||||
"block" => "blocked_install",
|
||||
"alert" => "alerted",
|
||||
other => other,
|
||||
};
|
||||
let violation = SoftwareViolationReport {
|
||||
device_uid: device_uid.to_string(),
|
||||
software_name: process_name.clone(),
|
||||
action_taken: action_taken.to_string(),
|
||||
timestamp: chrono::Utc::now().to_rfc3339(),
|
||||
};
|
||||
if let Ok(frame) = Frame::new_json(MessageType::SoftwareViolation, &violation) {
|
||||
let _ = data_tx.send(frame).await;
|
||||
}
|
||||
|
||||
// Kill the process directly by captured PID (avoids TOCTOU race)
|
||||
if entry.action == "block" {
|
||||
kill_process_by_pid(*pid, process_name);
|
||||
}
|
||||
if !pattern_matches(&entry.name_pattern, process_name) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Never kill protected processes (system, browsers, IDEs, etc.)
|
||||
if is_protected_process(process_name) {
|
||||
warn!(
|
||||
"Blacklisted match '{}' skipped — protected process (pid={})",
|
||||
process_name, pid
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check server-pushed whitelist (takes precedence over blacklist)
|
||||
if is_whitelisted(process_name, whitelist) {
|
||||
debug!(
|
||||
"Blacklisted match '{}' skipped — whitelisted (pid={})",
|
||||
process_name, pid
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Skip if already acted on recently (cooldown)
|
||||
let action_key = format!("{}:{}", process_name.to_lowercase(), pid);
|
||||
if recent_actions.contains(&action_key) {
|
||||
continue;
|
||||
}
|
||||
|
||||
warn!("Blacklisted software detected: {} (action: {})", process_name, entry.action);
|
||||
|
||||
// Report violation to server
|
||||
let action_taken = match entry.action.as_str() {
|
||||
"block" => "blocked_install",
|
||||
"alert" => "alerted",
|
||||
other => other,
|
||||
};
|
||||
let violation = SoftwareViolationReport {
|
||||
device_uid: device_uid.to_string(),
|
||||
software_name: process_name.clone(),
|
||||
action_taken: action_taken.to_string(),
|
||||
timestamp: chrono::Utc::now().to_rfc3339(),
|
||||
};
|
||||
if let Ok(frame) = Frame::new_json(MessageType::SoftwareViolation, &violation) {
|
||||
let _ = data_tx.send(frame).await;
|
||||
}
|
||||
|
||||
// Kill the process if action is "block"
|
||||
if entry.action == "block" {
|
||||
kill_process_by_pid(*pid, process_name);
|
||||
}
|
||||
|
||||
// Mark as recently acted-on
|
||||
recent_actions.insert(action_key);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Match a blacklist pattern against a process name.
|
||||
///
|
||||
/// **Matching rules (case-insensitive)**:
|
||||
/// - No wildcard → exact filename match only (e.g. `chrome.exe` matches `chrome.exe` but NOT `new_chrome.exe`)
|
||||
/// - `*` wildcard → glob-style pattern match (e.g. `*miner*` matches `bitcoin_miner.exe`)
|
||||
/// - Pattern with `.exe` suffix → match against full process name
|
||||
/// - Pattern without extension → match against stem (name without `.exe`)
|
||||
///
|
||||
/// **IMPORTANT**: We intentionally do NOT use substring matching (`contains()`) for
|
||||
/// non-wildcard patterns. Substring matching caused false positives where a pattern
|
||||
/// like "game" would match "game_bar.exe" or even "svchost.exe" in edge cases.
|
||||
fn pattern_matches(pattern: &str, name: &str) -> bool {
|
||||
let pattern_lower = pattern.to_lowercase();
|
||||
let name_lower = name.to_lowercase();
|
||||
// Support wildcard patterns
|
||||
if pattern_lower.contains('*') {
|
||||
let parts: Vec<&str> = pattern_lower.split('*').collect();
|
||||
let mut pos = 0;
|
||||
|
||||
for (i, part) in parts.iter().enumerate() {
|
||||
if part.is_empty() {
|
||||
continue;
|
||||
}
|
||||
if i == 0 && !parts[0].is_empty() {
|
||||
// Pattern starts with literal → must match at start
|
||||
if !name_lower.starts_with(part) {
|
||||
return false;
|
||||
}
|
||||
pos = part.len();
|
||||
} else {
|
||||
match name_lower[pos..].find(part) {
|
||||
Some(idx) => pos += idx + part.len(),
|
||||
None => return false,
|
||||
}
|
||||
}
|
||||
}
|
||||
// If pattern ends with literal (no trailing *), must match at end
|
||||
if !parts.last().map_or(true, |p| p.is_empty()) {
|
||||
return name_lower.ends_with(parts.last().unwrap());
|
||||
}
|
||||
true
|
||||
// Extract the stem (filename without extension) for extension-less patterns
|
||||
let name_stem = name_lower.strip_suffix(".exe").unwrap_or(&name_lower);
|
||||
|
||||
if pattern_lower.contains('*') {
|
||||
// Wildcard glob matching
|
||||
glob_matches(&pattern_lower, &name_lower, &name_stem)
|
||||
} else {
|
||||
name_lower.contains(&pattern_lower)
|
||||
// Exact match: compare against full name OR stem (if pattern has no extension)
|
||||
let pattern_stem = pattern_lower.strip_suffix(".exe").unwrap_or(&pattern_lower);
|
||||
name_lower == pattern_lower || name_stem == pattern_stem
|
||||
}
|
||||
}
|
||||
|
||||
/// Glob-style pattern matching with `*` wildcards.
|
||||
/// Checks both the full name and the stem (without .exe).
|
||||
fn glob_matches(pattern: &str, full_name: &str, stem: &str) -> bool {
|
||||
// Try matching against both the full process name and the stem
|
||||
glob_matches_single(pattern, full_name) || glob_matches_single(pattern, stem)
|
||||
}
|
||||
|
||||
fn glob_matches_single(pattern: &str, text: &str) -> bool {
|
||||
let parts: Vec<&str> = pattern.split('*').collect();
|
||||
|
||||
if parts.is_empty() {
|
||||
return true;
|
||||
}
|
||||
|
||||
let mut pos = 0;
|
||||
|
||||
for (i, part) in parts.iter().enumerate() {
|
||||
if part.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
if i == 0 {
|
||||
// First segment: must match at the start
|
||||
if !text.starts_with(part) {
|
||||
return false;
|
||||
}
|
||||
pos = part.len();
|
||||
} else if i == parts.len() - 1 {
|
||||
// Last segment: must match at the end (if pattern doesn't end with *)
|
||||
return text.ends_with(part) && text[..text.len() - part.len()].len() >= pos;
|
||||
} else {
|
||||
// Middle segment: must appear after current position
|
||||
match text[pos..].find(part) {
|
||||
Some(idx) => pos += idx + part.len(),
|
||||
None => return false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If pattern ends with *, anything after last match is fine
|
||||
if pattern.ends_with('*') {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Full match required
|
||||
pos == text.len()
|
||||
}
|
||||
|
||||
/// Get all running processes with their PIDs (single snapshot, no TOCTOU)
|
||||
fn get_running_processes_with_pids() -> Vec<(String, u32)> {
|
||||
#[cfg(target_os = "windows")]
|
||||
@@ -253,3 +389,12 @@ fn is_protected_process(name: &str) -> bool {
|
||||
let lower = name.to_lowercase();
|
||||
PROTECTED_PROCESSES.iter().any(|p| lower == **p || lower.ends_with(&format!("/{}", p).replace('/', "\\")))
|
||||
}
|
||||
|
||||
/// Check if a process name matches any server-pushed whitelist pattern.
|
||||
/// Whitelist patterns use the same matching logic as blacklist (exact or glob).
|
||||
fn is_whitelisted(process_name: &str, whitelist: &[String]) -> bool {
|
||||
if whitelist.is_empty() {
|
||||
return false;
|
||||
}
|
||||
whitelist.iter().any(|pattern| pattern_matches(pattern, process_name))
|
||||
}
|
||||
|
||||
@@ -29,4 +29,6 @@ pub use message::{
|
||||
PrintEventPayload,
|
||||
ClipboardRulesPayload, ClipboardRule, ClipboardViolationPayload,
|
||||
PopupBlockStatsPayload, PopupRuleStat,
|
||||
PatchStatusPayload, PatchEntry, PatchScanConfigPayload,
|
||||
BehaviorMetricsPayload,
|
||||
};
|
||||
|
||||
@@ -71,6 +71,14 @@ pub enum MessageType {
|
||||
// Plugin: Clipboard Control (剪贴板管控)
|
||||
ClipboardRules = 0x94,
|
||||
ClipboardViolation = 0x95,
|
||||
|
||||
// Plugin: Patch Management (补丁管理)
|
||||
PatchStatusReport = 0xA0,
|
||||
PatchScanConfig = 0xA1,
|
||||
PatchInstallCommand = 0xA2,
|
||||
|
||||
// Plugin: Behavior Metrics (行为指标)
|
||||
BehaviorMetricsReport = 0xB0,
|
||||
}
|
||||
|
||||
impl TryFrom<u8> for MessageType {
|
||||
@@ -108,6 +116,10 @@ impl TryFrom<u8> for MessageType {
|
||||
0x91 => Ok(Self::PrintEvent),
|
||||
0x94 => Ok(Self::ClipboardRules),
|
||||
0x95 => Ok(Self::ClipboardViolation),
|
||||
0xA0 => Ok(Self::PatchStatusReport),
|
||||
0xA1 => Ok(Self::PatchScanConfig),
|
||||
0xA2 => Ok(Self::PatchInstallCommand),
|
||||
0xB0 => Ok(Self::BehaviorMetricsReport),
|
||||
_ => Err(format!("Unknown message type: 0x{:02X}", value)),
|
||||
}
|
||||
}
|
||||
@@ -264,7 +276,7 @@ pub struct TaskExecutePayload {
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub enum ConfigUpdateType {
|
||||
UpdateIntervals { heartbeat: u64, status: u64, asset: u64 },
|
||||
TlsCertRotate,
|
||||
TlsCertRotate { new_cert_hash: String, valid_until: String },
|
||||
SelfDestruct,
|
||||
}
|
||||
|
||||
@@ -442,6 +454,44 @@ pub struct PopupRuleStat {
|
||||
pub hits: u32,
|
||||
}
|
||||
|
||||
/// Plugin: Patch Status Report (Client → Server)
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct PatchStatusPayload {
|
||||
pub device_uid: String,
|
||||
pub patches: Vec<PatchEntry>,
|
||||
}
|
||||
|
||||
/// Information about a single patch/hotfix.
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct PatchEntry {
|
||||
pub kb_id: String,
|
||||
pub title: String,
|
||||
pub severity: Option<String>, // "Critical" | "Important" | "Moderate" | "Low"
|
||||
pub is_installed: bool,
|
||||
pub installed_at: Option<String>,
|
||||
}
|
||||
|
||||
/// Plugin: Patch Scan Config (Server → Client)
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct PatchScanConfigPayload {
|
||||
pub enabled: bool,
|
||||
pub scan_interval_secs: u64,
|
||||
}
|
||||
|
||||
/// Plugin: Behavior Metrics Report (Client → Server)
|
||||
/// Enhanced periodic metrics for anomaly detection.
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct BehaviorMetricsPayload {
|
||||
pub device_uid: String,
|
||||
pub clipboard_ops_count: u32,
|
||||
pub clipboard_ops_night: u32,
|
||||
pub print_jobs_count: u32,
|
||||
pub usb_file_ops_count: u32,
|
||||
pub new_processes_count: u32,
|
||||
pub period_secs: u64,
|
||||
pub timestamp: String,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
@@ -41,6 +41,26 @@ pub async fn cleanup_task(state: AppState) {
|
||||
error!("Failed to cleanup alert records: {}", e);
|
||||
}
|
||||
|
||||
// Cleanup old revoked token families (keep 30 days for audit)
|
||||
if let Err(e) = sqlx::query(
|
||||
"DELETE FROM revoked_token_families WHERE revoked_at < datetime('now', '-30 days')"
|
||||
)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
{
|
||||
error!("Failed to cleanup revoked token families: {}", e);
|
||||
}
|
||||
|
||||
// Cleanup old anomaly alerts that have been handled
|
||||
if let Err(e) = sqlx::query(
|
||||
"DELETE FROM anomaly_alerts WHERE handled = 1 AND triggered_at < datetime('now', '-90 days')"
|
||||
)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
{
|
||||
error!("Failed to cleanup handled anomaly alerts: {}", e);
|
||||
}
|
||||
|
||||
// Mark devices as offline if no heartbeat for 2 minutes
|
||||
if let Err(e) = sqlx::query(
|
||||
"UPDATE devices SET status = 'offline' WHERE status = 'online' AND last_heartbeat < datetime('now', '-2 minutes')"
|
||||
|
||||
170
crates/server/src/anomaly.rs
Normal file
170
crates/server/src/anomaly.rs
Normal file
@@ -0,0 +1,170 @@
|
||||
use sqlx::Row;
|
||||
use tracing::{info, warn};
|
||||
use csm_protocol::BehaviorMetricsPayload;
|
||||
|
||||
/// Check incoming behavior metrics against anomaly rules and generate alerts
|
||||
pub async fn check_anomalies(
|
||||
pool: &sqlx::SqlitePool,
|
||||
ws_hub: &crate::ws::WsHub,
|
||||
metrics: &BehaviorMetricsPayload,
|
||||
) {
|
||||
let mut alerts: Vec<serde_json::Value> = Vec::new();
|
||||
|
||||
// Rule 1: Night-time clipboard operations (> 10 in reporting period)
|
||||
if metrics.clipboard_ops_night > 10 {
|
||||
alerts.push(serde_json::json!({
|
||||
"anomaly_type": "night_clipboard_spike",
|
||||
"severity": "high",
|
||||
"detail": format!("非工作时间剪贴板操作异常: {}次 (阈值: 10次)", metrics.clipboard_ops_night),
|
||||
"metric_value": metrics.clipboard_ops_night,
|
||||
}));
|
||||
}
|
||||
|
||||
// Rule 2: High USB file operations (> 100 per hour)
|
||||
if metrics.period_secs > 0 {
|
||||
let usb_per_hour = (metrics.usb_file_ops_count as f64 / metrics.period_secs as f64) * 3600.0;
|
||||
if usb_per_hour > 100.0 {
|
||||
alerts.push(serde_json::json!({
|
||||
"anomaly_type": "usb_file_exfiltration",
|
||||
"severity": "critical",
|
||||
"detail": format!("USB文件操作频率异常: {:.0}次/小时 (阈值: 100次/小时)", usb_per_hour),
|
||||
"metric_value": usb_per_hour,
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
// Rule 3: High print volume (> 50 per reporting period)
|
||||
if metrics.print_jobs_count > 50 {
|
||||
alerts.push(serde_json::json!({
|
||||
"anomaly_type": "high_print_volume",
|
||||
"severity": "medium",
|
||||
"detail": format!("打印量异常: {}次 (阈值: 50次)", metrics.print_jobs_count),
|
||||
"metric_value": metrics.print_jobs_count,
|
||||
}));
|
||||
}
|
||||
|
||||
// Rule 4: Excessive new processes (> 20 per hour)
|
||||
if metrics.period_secs > 0 {
|
||||
let procs_per_hour = (metrics.new_processes_count as f64 / metrics.period_secs as f64) * 3600.0;
|
||||
if procs_per_hour > 20.0 {
|
||||
alerts.push(serde_json::json!({
|
||||
"anomaly_type": "process_spawn_spike",
|
||||
"severity": "medium",
|
||||
"detail": format!("新进程启动异常: {:.0}次/小时 (阈值: 20次/小时)", procs_per_hour),
|
||||
"metric_value": procs_per_hour,
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
// Insert anomaly alerts
|
||||
for alert in &alerts {
|
||||
if let Err(e) = sqlx::query(
|
||||
"INSERT INTO anomaly_alerts (device_uid, anomaly_type, severity, detail, metric_value, triggered_at) \
|
||||
VALUES (?, ?, ?, ?, ?, datetime('now'))"
|
||||
)
|
||||
.bind(&metrics.device_uid)
|
||||
.bind(alert.get("anomaly_type").and_then(|v| v.as_str()).unwrap_or("unknown"))
|
||||
.bind(alert.get("severity").and_then(|v| v.as_str()).unwrap_or("medium"))
|
||||
.bind(alert.get("detail").and_then(|v| v.as_str()).unwrap_or(""))
|
||||
.bind(alert.get("metric_value").and_then(|v| v.as_f64()).unwrap_or(0.0))
|
||||
.execute(pool)
|
||||
.await
|
||||
{
|
||||
warn!("Failed to insert anomaly alert: {}", e);
|
||||
}
|
||||
}
|
||||
|
||||
// Broadcast anomaly alerts via WebSocket
|
||||
if !alerts.is_empty() {
|
||||
for alert in &alerts {
|
||||
ws_hub.broadcast(serde_json::json!({
|
||||
"type": "anomaly_alert",
|
||||
"device_uid": metrics.device_uid,
|
||||
"anomaly_type": alert.get("anomaly_type"),
|
||||
"severity": alert.get("severity"),
|
||||
"detail": alert.get("detail"),
|
||||
}).to_string()).await;
|
||||
}
|
||||
info!("Detected {} anomalies for device {}", alerts.len(), metrics.device_uid);
|
||||
}
|
||||
}
|
||||
|
||||
/// Get anomaly alert summary for a device or all devices
|
||||
pub async fn get_anomaly_summary(
|
||||
pool: &sqlx::SqlitePool,
|
||||
device_uid: Option<&str>,
|
||||
page: u32,
|
||||
page_size: u32,
|
||||
) -> anyhow::Result<serde_json::Value> {
|
||||
let offset = page.saturating_sub(1) * page_size;
|
||||
|
||||
let rows = if let Some(uid) = device_uid {
|
||||
sqlx::query(
|
||||
"SELECT a.*, d.hostname FROM anomaly_alerts a JOIN devices d ON d.device_uid = a.device_uid \
|
||||
WHERE a.device_uid = ? ORDER BY a.triggered_at DESC LIMIT ? OFFSET ?"
|
||||
)
|
||||
.bind(uid)
|
||||
.bind(page_size)
|
||||
.bind(offset)
|
||||
.fetch_all(pool)
|
||||
.await?
|
||||
} else {
|
||||
sqlx::query(
|
||||
"SELECT a.*, d.hostname FROM anomaly_alerts a JOIN devices d ON d.device_uid = a.device_uid \
|
||||
ORDER BY a.triggered_at DESC LIMIT ? OFFSET ?"
|
||||
)
|
||||
.bind(page_size)
|
||||
.bind(offset)
|
||||
.fetch_all(pool)
|
||||
.await?
|
||||
};
|
||||
|
||||
let total: i64 = if let Some(uid) = device_uid {
|
||||
sqlx::query_scalar("SELECT COUNT(*) FROM anomaly_alerts WHERE device_uid = ?")
|
||||
.bind(uid)
|
||||
.fetch_one(pool)
|
||||
.await?
|
||||
} else {
|
||||
sqlx::query_scalar("SELECT COUNT(*) FROM anomaly_alerts")
|
||||
.fetch_one(pool)
|
||||
.await?
|
||||
};
|
||||
|
||||
let alerts: Vec<serde_json::Value> = rows.iter().map(|r| serde_json::json!({
|
||||
"id": r.get::<i64, _>("id"),
|
||||
"device_uid": r.get::<String, _>("device_uid"),
|
||||
"hostname": r.get::<String, _>("hostname"),
|
||||
"anomaly_type": r.get::<String, _>("anomaly_type"),
|
||||
"severity": r.get::<String, _>("severity"),
|
||||
"detail": r.get::<String, _>("detail"),
|
||||
"metric_value": r.get::<f64, _>("metric_value"),
|
||||
"handled": r.get::<i32, _>("handled"),
|
||||
"triggered_at": r.get::<String, _>("triggered_at"),
|
||||
})).collect();
|
||||
|
||||
// Summary counts (scoped to same filter)
|
||||
let unhandled: i64 = if let Some(uid) = device_uid {
|
||||
sqlx::query_scalar(
|
||||
"SELECT COUNT(*) FROM anomaly_alerts WHERE handled = 0 AND device_uid = ?"
|
||||
)
|
||||
.bind(uid)
|
||||
.fetch_one(pool)
|
||||
.await
|
||||
.unwrap_or(0)
|
||||
} else {
|
||||
sqlx::query_scalar(
|
||||
"SELECT COUNT(*) FROM anomaly_alerts WHERE handled = 0"
|
||||
)
|
||||
.fetch_one(pool)
|
||||
.await
|
||||
.unwrap_or(0)
|
||||
};
|
||||
|
||||
Ok(serde_json::json!({
|
||||
"alerts": alerts,
|
||||
"total": total,
|
||||
"unhandled_count": unhandled,
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
}))
|
||||
}
|
||||
@@ -21,7 +21,7 @@ pub async fn list_rules(
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let rows = sqlx::query(
|
||||
"SELECT id, name, rule_type, condition, severity, enabled, notify_email, notify_webhook, created_at, updated_at
|
||||
FROM alert_rules ORDER BY created_at DESC"
|
||||
FROM alert_rules ORDER BY created_at DESC LIMIT 500"
|
||||
)
|
||||
.fetch_all(&state.db)
|
||||
.await;
|
||||
@@ -116,7 +116,26 @@ pub async fn create_rule(
|
||||
State(state): State<AppState>,
|
||||
Json(body): Json<CreateRuleRequest>,
|
||||
) -> (StatusCode, Json<ApiResponse<serde_json::Value>>) {
|
||||
// Validate rule_type
|
||||
if !matches!(body.rule_type.as_str(), "device_offline" | "usb_event" | "web_access" | "software_violation" | "custom") {
|
||||
return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("Invalid rule_type")));
|
||||
}
|
||||
|
||||
// Validate severity
|
||||
let severity = body.severity.unwrap_or_else(|| "medium".to_string());
|
||||
if !matches!(severity.as_str(), "low" | "medium" | "high" | "critical") {
|
||||
return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("Invalid severity")));
|
||||
}
|
||||
|
||||
// Validate webhook URL (SSRF prevention)
|
||||
if let Some(ref url) = body.notify_webhook {
|
||||
if !url.starts_with("https://") {
|
||||
return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("Webhook URL must use HTTPS")));
|
||||
}
|
||||
if url.len() > 2048 {
|
||||
return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("Webhook URL too long")));
|
||||
}
|
||||
}
|
||||
|
||||
let result = sqlx::query(
|
||||
"INSERT INTO alert_rules (name, rule_type, condition, severity, notify_email, notify_webhook)
|
||||
@@ -174,6 +193,26 @@ pub async fn update_rule(
|
||||
let notify_email = body.notify_email.or_else(|| existing.get::<Option<String>, _>("notify_email"));
|
||||
let notify_webhook = body.notify_webhook.or_else(|| existing.get::<Option<String>, _>("notify_webhook"));
|
||||
|
||||
// Validate rule_type
|
||||
if !matches!(rule_type.as_str(), "device_offline" | "usb_event" | "web_access" | "software_violation" | "custom") {
|
||||
return Json(ApiResponse::error("Invalid rule_type"));
|
||||
}
|
||||
|
||||
// Validate severity
|
||||
if !matches!(severity.as_str(), "low" | "medium" | "high" | "critical") {
|
||||
return Json(ApiResponse::error("Invalid severity"));
|
||||
}
|
||||
|
||||
// Validate webhook URL (SSRF prevention)
|
||||
if let Some(ref url) = notify_webhook {
|
||||
if !url.starts_with("https://") {
|
||||
return Json(ApiResponse::error("Webhook URL must use HTTPS"));
|
||||
}
|
||||
if url.len() > 2048 {
|
||||
return Json(ApiResponse::error("Webhook URL too long"));
|
||||
}
|
||||
}
|
||||
|
||||
let result = sqlx::query(
|
||||
"UPDATE alert_rules SET name = ?, rule_type = ?, condition = ?, severity = ?, enabled = ?,
|
||||
notify_email = ?, notify_webhook = ?, updated_at = datetime('now') WHERE id = ?"
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
use axum::{extract::State, Json, http::StatusCode, extract::Request, middleware::Next, response::Response};
|
||||
use axum::{extract::State, Json, http::StatusCode, extract::Request, middleware::Next, response::{Response, IntoResponse}};
|
||||
use axum::http::header::{SET_COOKIE, HeaderValue};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use jsonwebtoken::{encode, decode, Header, EncodingKey, DecodingKey, Validation};
|
||||
use std::sync::Arc;
|
||||
@@ -28,11 +29,15 @@ pub struct LoginRequest {
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct LoginResponse {
|
||||
pub access_token: String,
|
||||
pub refresh_token: String,
|
||||
pub user: UserInfo,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct MeResponse {
|
||||
pub user: UserInfo,
|
||||
pub expires_at: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, sqlx::FromRow)]
|
||||
pub struct UserInfo {
|
||||
pub id: i64,
|
||||
@@ -40,18 +45,68 @@ pub struct UserInfo {
|
||||
pub role: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct RefreshRequest {
|
||||
pub refresh_token: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct ChangePasswordRequest {
|
||||
pub old_password: String,
|
||||
pub new_password: String,
|
||||
}
|
||||
|
||||
/// In-memory rate limiter for login attempts
|
||||
// ---------------------------------------------------------------------------
|
||||
// Cookie helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
fn is_secure_cookies() -> bool {
|
||||
std::env::var("CSM_DEV").is_err()
|
||||
}
|
||||
|
||||
fn access_cookie_header(token: &str, ttl_secs: u64) -> HeaderValue {
|
||||
let secure = if is_secure_cookies() { "; Secure" } else { "" };
|
||||
HeaderValue::from_str(&format!(
|
||||
"access_token={}; HttpOnly{}; SameSite=Strict; Path=/; Max-Age={}",
|
||||
token, secure, ttl_secs
|
||||
)).expect("valid cookie header")
|
||||
}
|
||||
|
||||
fn refresh_cookie_header(token: &str, ttl_secs: u64) -> HeaderValue {
|
||||
let secure = if is_secure_cookies() { "; Secure" } else { "" };
|
||||
HeaderValue::from_str(&format!(
|
||||
"refresh_token={}; HttpOnly{}; SameSite=Strict; Path=/api/auth/refresh; Max-Age={}",
|
||||
token, secure, ttl_secs
|
||||
)).expect("valid cookie header")
|
||||
}
|
||||
|
||||
fn clear_cookie_headers() -> Vec<HeaderValue> {
|
||||
let secure = if is_secure_cookies() { "; Secure" } else { "" };
|
||||
vec![
|
||||
HeaderValue::from_str(&format!("access_token=; HttpOnly{}; SameSite=Strict; Path=/; Max-Age=0", secure)).expect("valid"),
|
||||
HeaderValue::from_str(&format!("refresh_token=; HttpOnly{}; SameSite=Strict; Path=/api/auth/refresh; Max-Age=0", secure)).expect("valid"),
|
||||
]
|
||||
}
|
||||
|
||||
/// Attach Set-Cookie headers to a response.
|
||||
fn with_cookies(mut response: Response, cookies: Vec<HeaderValue>) -> Response {
|
||||
for cookie in cookies {
|
||||
response.headers_mut().append(SET_COOKIE, cookie);
|
||||
}
|
||||
response
|
||||
}
|
||||
|
||||
/// Extract a cookie value by name from the raw Cookie header.
|
||||
fn extract_cookie_value(headers: &axum::http::HeaderMap, name: &str) -> Option<String> {
|
||||
let cookie_header = headers.get("cookie")?.to_str().ok()?;
|
||||
for cookie in cookie_header.split(';') {
|
||||
let cookie = cookie.trim();
|
||||
if let Some(value) = cookie.strip_prefix(&format!("{}=", name)) {
|
||||
return Some(value.to_string());
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Rate limiter
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[derive(Clone, Default)]
|
||||
pub struct LoginRateLimiter {
|
||||
attempts: Arc<Mutex<HashMap<String, (Instant, u32)>>>,
|
||||
@@ -62,28 +117,25 @@ impl LoginRateLimiter {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
/// Returns true if the request should be rate-limited
|
||||
pub async fn is_limited(&self, key: &str) -> bool {
|
||||
let mut attempts = self.attempts.lock().await;
|
||||
let now = Instant::now();
|
||||
let window = std::time::Duration::from_secs(300); // 5-minute window
|
||||
let window = std::time::Duration::from_secs(300);
|
||||
let max_attempts = 10u32;
|
||||
|
||||
if let Some((first_attempt, count)) = attempts.get_mut(key) {
|
||||
if now.duration_since(*first_attempt) > window {
|
||||
// Window expired, reset
|
||||
*first_attempt = now;
|
||||
*count = 1;
|
||||
false
|
||||
} else if *count >= max_attempts {
|
||||
true // Rate limited
|
||||
true
|
||||
} else {
|
||||
*count += 1;
|
||||
false
|
||||
}
|
||||
} else {
|
||||
attempts.insert(key.to_string(), (now, 1));
|
||||
// Cleanup old entries periodically
|
||||
if attempts.len() > 1000 {
|
||||
let cutoff = now - window;
|
||||
attempts.retain(|_, (t, _)| *t > cutoff);
|
||||
@@ -93,46 +145,67 @@ impl LoginRateLimiter {
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Endpoints
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
pub async fn login(
|
||||
State(state): State<AppState>,
|
||||
Json(req): Json<LoginRequest>,
|
||||
) -> Result<(StatusCode, Json<ApiResponse<LoginResponse>>), StatusCode> {
|
||||
// Rate limit check
|
||||
) -> impl IntoResponse {
|
||||
if state.login_limiter.is_limited(&req.username).await {
|
||||
return Ok((StatusCode::TOO_MANY_REQUESTS, Json(ApiResponse::error("Too many login attempts. Try again later."))));
|
||||
return (StatusCode::TOO_MANY_REQUESTS, Json(ApiResponse::<LoginResponse>::error("Too many login attempts. Try again later."))).into_response();
|
||||
}
|
||||
if state.login_limiter.is_limited("ip:default").await {
|
||||
return (StatusCode::TOO_MANY_REQUESTS, Json(ApiResponse::<LoginResponse>::error("Too many login attempts from this location. Try again later."))).into_response();
|
||||
}
|
||||
|
||||
let user: Option<UserInfo> = sqlx::query_as::<_, UserInfo>(
|
||||
"SELECT id, username, role FROM users WHERE username = ?"
|
||||
let row: Option<(UserInfo, String)> = sqlx::query_as::<_, (i64, String, String, String)>(
|
||||
"SELECT id, username, role, password FROM users WHERE username = ?"
|
||||
)
|
||||
.bind(&req.username)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
|
||||
.ok()
|
||||
.flatten()
|
||||
.map(|(id, username, role, password)| {
|
||||
(UserInfo { id, username, role }, password)
|
||||
});
|
||||
|
||||
let user = match user {
|
||||
Some(u) => u,
|
||||
None => return Ok((StatusCode::UNAUTHORIZED, Json(ApiResponse::error("Invalid credentials")))),
|
||||
let (user, hash) = match row {
|
||||
Some(r) => r,
|
||||
None => {
|
||||
let _ = bcrypt::verify("timing-constant-dummy", "$2b$12$aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa");
|
||||
return (StatusCode::UNAUTHORIZED, Json(ApiResponse::<LoginResponse>::error("Invalid credentials"))).into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let hash: String = sqlx::query_scalar::<_, String>(
|
||||
"SELECT password FROM users WHERE id = ?"
|
||||
)
|
||||
.bind(user.id)
|
||||
.fetch_one(&state.db)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
if !bcrypt::verify(&req.password, &hash).unwrap_or(false) {
|
||||
return Ok((StatusCode::UNAUTHORIZED, Json(ApiResponse::error("Invalid credentials"))));
|
||||
return (StatusCode::UNAUTHORIZED, Json(ApiResponse::<LoginResponse>::error("Invalid credentials"))).into_response();
|
||||
}
|
||||
|
||||
let now = chrono::Utc::now().timestamp() as u64;
|
||||
let family = uuid::Uuid::new_v4().to_string();
|
||||
let access_token = create_token(&user, "access", state.config.auth.access_token_ttl_secs, now, &state.config.auth.jwt_secret, &family)?;
|
||||
let refresh_token = create_token(&user, "refresh", state.config.auth.refresh_token_ttl_secs, now, &state.config.auth.jwt_secret, &family)?;
|
||||
let access_token = match create_token(&user, "access", state.config.auth.access_token_ttl_secs, now, &state.config.auth.jwt_secret, &family) {
|
||||
Ok(t) => t,
|
||||
Err(_) => return StatusCode::INTERNAL_SERVER_ERROR.into_response(),
|
||||
};
|
||||
let refresh_token = match create_token(&user, "refresh", state.config.auth.refresh_token_ttl_secs, now, &state.config.auth.jwt_secret, &family) {
|
||||
Ok(t) => t,
|
||||
Err(_) => return StatusCode::INTERNAL_SERVER_ERROR.into_response(),
|
||||
};
|
||||
|
||||
let refresh_expires = now + state.config.auth.refresh_token_ttl_secs;
|
||||
let _ = sqlx::query(
|
||||
"INSERT INTO refresh_tokens (user_id, family, expires_at) VALUES (?, ?, datetime(?, 'unixepoch'))"
|
||||
)
|
||||
.bind(user.id)
|
||||
.bind(&family)
|
||||
.bind(refresh_expires as i64)
|
||||
.execute(&state.db)
|
||||
.await;
|
||||
|
||||
// Audit log
|
||||
let _ = sqlx::query(
|
||||
"INSERT INTO admin_audit_log (user_id, action, detail) VALUES (?, 'login', ?)"
|
||||
)
|
||||
@@ -141,73 +214,262 @@ pub async fn login(
|
||||
.execute(&state.db)
|
||||
.await;
|
||||
|
||||
Ok((StatusCode::OK, Json(ApiResponse::ok(LoginResponse {
|
||||
access_token,
|
||||
refresh_token,
|
||||
user,
|
||||
}))))
|
||||
let response = (StatusCode::OK, Json(ApiResponse::ok(LoginResponse { user }))).into_response();
|
||||
with_cookies(response, vec![
|
||||
access_cookie_header(&access_token, state.config.auth.access_token_ttl_secs),
|
||||
refresh_cookie_header(&refresh_token, state.config.auth.refresh_token_ttl_secs),
|
||||
])
|
||||
}
|
||||
|
||||
pub async fn refresh(
|
||||
State(state): State<AppState>,
|
||||
Json(req): Json<RefreshRequest>,
|
||||
) -> Result<(StatusCode, Json<ApiResponse<LoginResponse>>), StatusCode> {
|
||||
let claims = decode::<Claims>(
|
||||
&req.refresh_token,
|
||||
headers: axum::http::HeaderMap,
|
||||
) -> impl IntoResponse {
|
||||
let refresh_token = match extract_cookie_value(&headers, "refresh_token") {
|
||||
Some(t) => t,
|
||||
None => return with_cookies(
|
||||
(StatusCode::UNAUTHORIZED, Json(ApiResponse::<LoginResponse>::error("Missing refresh token"))).into_response(),
|
||||
clear_cookie_headers(),
|
||||
),
|
||||
};
|
||||
|
||||
let claims = match decode::<Claims>(
|
||||
&refresh_token,
|
||||
&DecodingKey::from_secret(state.config.auth.jwt_secret.as_bytes()),
|
||||
&Validation::default(),
|
||||
)
|
||||
.map_err(|_| StatusCode::UNAUTHORIZED)?;
|
||||
) {
|
||||
Ok(c) => c.claims,
|
||||
Err(_) => return with_cookies(
|
||||
(StatusCode::UNAUTHORIZED, Json(ApiResponse::<LoginResponse>::error("Invalid refresh token"))).into_response(),
|
||||
clear_cookie_headers(),
|
||||
),
|
||||
};
|
||||
|
||||
if claims.claims.token_type != "refresh" {
|
||||
return Ok((StatusCode::UNAUTHORIZED, Json(ApiResponse::error("Invalid token type"))));
|
||||
if claims.token_type != "refresh" {
|
||||
return with_cookies(
|
||||
(StatusCode::UNAUTHORIZED, Json(ApiResponse::<LoginResponse>::error("Invalid token type"))).into_response(),
|
||||
clear_cookie_headers(),
|
||||
);
|
||||
}
|
||||
|
||||
// Check if this refresh token family has been revoked (reuse detection)
|
||||
let mut tx = match state.db.begin().await {
|
||||
Ok(tx) => tx,
|
||||
Err(_) => return StatusCode::INTERNAL_SERVER_ERROR.into_response(),
|
||||
};
|
||||
|
||||
let revoked: bool = sqlx::query_scalar::<_, i64>(
|
||||
"SELECT COUNT(*) FROM revoked_token_families WHERE family = ?"
|
||||
)
|
||||
.bind(&claims.claims.family)
|
||||
.fetch_one(&state.db)
|
||||
.bind(&claims.family)
|
||||
.fetch_one(&mut *tx)
|
||||
.await
|
||||
.unwrap_or(0) > 0;
|
||||
|
||||
if revoked {
|
||||
// Token reuse detected — revoke entire family and force re-login
|
||||
tracing::warn!("Refresh token reuse detected for user {} family {}", claims.claims.sub, claims.claims.family);
|
||||
tx.rollback().await.ok();
|
||||
tracing::warn!("Refresh token reuse detected for user {} family {}", claims.sub, claims.family);
|
||||
let _ = sqlx::query("DELETE FROM refresh_tokens WHERE user_id = ?")
|
||||
.bind(claims.claims.sub)
|
||||
.bind(claims.sub)
|
||||
.execute(&state.db)
|
||||
.await;
|
||||
return Ok((StatusCode::UNAUTHORIZED, Json(ApiResponse::error("Token reuse detected. Please log in again."))));
|
||||
return with_cookies(
|
||||
(StatusCode::UNAUTHORIZED, Json(ApiResponse::<LoginResponse>::error("Token reuse detected. Please log in again."))).into_response(),
|
||||
clear_cookie_headers(),
|
||||
);
|
||||
}
|
||||
|
||||
let family_exists: bool = sqlx::query_scalar::<_, i64>(
|
||||
"SELECT COUNT(*) FROM refresh_tokens WHERE family = ? AND user_id = ?"
|
||||
)
|
||||
.bind(&claims.family)
|
||||
.bind(claims.sub)
|
||||
.fetch_one(&mut *tx)
|
||||
.await
|
||||
.unwrap_or(0) > 0;
|
||||
|
||||
if !family_exists {
|
||||
tx.rollback().await.ok();
|
||||
return (StatusCode::UNAUTHORIZED, Json(ApiResponse::<LoginResponse>::error("Invalid refresh token"))).into_response();
|
||||
}
|
||||
|
||||
let user = UserInfo {
|
||||
id: claims.claims.sub,
|
||||
username: claims.claims.username,
|
||||
role: claims.claims.role,
|
||||
id: claims.sub,
|
||||
username: claims.username,
|
||||
role: claims.role,
|
||||
};
|
||||
|
||||
// Rotate: new family for each refresh
|
||||
let new_family = uuid::Uuid::new_v4().to_string();
|
||||
let now = chrono::Utc::now().timestamp() as u64;
|
||||
let access_token = create_token(&user, "access", state.config.auth.access_token_ttl_secs, now, &state.config.auth.jwt_secret, &new_family)?;
|
||||
let refresh_token = create_token(&user, "refresh", state.config.auth.refresh_token_ttl_secs, now, &state.config.auth.jwt_secret, &new_family)?;
|
||||
let access_token = match create_token(&user, "access", state.config.auth.access_token_ttl_secs, now, &state.config.auth.jwt_secret, &new_family) {
|
||||
Ok(t) => t,
|
||||
Err(_) => return StatusCode::INTERNAL_SERVER_ERROR.into_response(),
|
||||
};
|
||||
let refresh_token = match create_token(&user, "refresh", state.config.auth.refresh_token_ttl_secs, now, &state.config.auth.jwt_secret, &new_family) {
|
||||
Ok(t) => t,
|
||||
Err(_) => return StatusCode::INTERNAL_SERVER_ERROR.into_response(),
|
||||
};
|
||||
|
||||
// Revoke old family
|
||||
let _ = sqlx::query("INSERT OR IGNORE INTO revoked_token_families (family, user_id, revoked_at) VALUES (?, ?, datetime('now'))")
|
||||
.bind(&claims.claims.family)
|
||||
.bind(claims.claims.sub)
|
||||
.execute(&state.db)
|
||||
.await;
|
||||
if sqlx::query("INSERT OR IGNORE INTO revoked_token_families (family, user_id, revoked_at) VALUES (?, ?, datetime('now'))")
|
||||
.bind(&claims.family)
|
||||
.bind(claims.sub)
|
||||
.execute(&mut *tx)
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
|
||||
}
|
||||
|
||||
Ok((StatusCode::OK, Json(ApiResponse::ok(LoginResponse {
|
||||
access_token,
|
||||
refresh_token,
|
||||
user,
|
||||
}))))
|
||||
let refresh_expires = now + state.config.auth.refresh_token_ttl_secs;
|
||||
if sqlx::query(
|
||||
"INSERT INTO refresh_tokens (user_id, family, expires_at) VALUES (?, ?, datetime(?, 'unixepoch'))"
|
||||
)
|
||||
.bind(user.id)
|
||||
.bind(&new_family)
|
||||
.bind(refresh_expires as i64)
|
||||
.execute(&mut *tx)
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
|
||||
}
|
||||
|
||||
if tx.commit().await.is_err() {
|
||||
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
|
||||
}
|
||||
|
||||
let response = (StatusCode::OK, Json(ApiResponse::ok(LoginResponse { user }))).into_response();
|
||||
with_cookies(response, vec![
|
||||
access_cookie_header(&access_token, state.config.auth.access_token_ttl_secs),
|
||||
refresh_cookie_header(&refresh_token, state.config.auth.refresh_token_ttl_secs),
|
||||
])
|
||||
}
|
||||
|
||||
/// Get current authenticated user info from access_token cookie.
|
||||
pub async fn me(
|
||||
State(state): State<AppState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
) -> impl IntoResponse {
|
||||
let token = match extract_cookie_value(&headers, "access_token") {
|
||||
Some(t) => t,
|
||||
None => return (StatusCode::UNAUTHORIZED, Json(ApiResponse::<MeResponse>::error("Not authenticated"))).into_response(),
|
||||
};
|
||||
|
||||
let claims = match decode::<Claims>(
|
||||
&token,
|
||||
&DecodingKey::from_secret(state.config.auth.jwt_secret.as_bytes()),
|
||||
&Validation::default(),
|
||||
) {
|
||||
Ok(c) => c.claims,
|
||||
Err(_) => return (StatusCode::UNAUTHORIZED, Json(ApiResponse::<MeResponse>::error("Invalid token"))).into_response(),
|
||||
};
|
||||
|
||||
if claims.token_type != "access" {
|
||||
return (StatusCode::UNAUTHORIZED, Json(ApiResponse::<MeResponse>::error("Invalid token type"))).into_response();
|
||||
}
|
||||
|
||||
let expires_at = chrono::DateTime::from_timestamp(claims.exp as i64, 0)
|
||||
.map(|t| t.to_rfc3339())
|
||||
.unwrap_or_default();
|
||||
|
||||
(StatusCode::OK, Json(ApiResponse::ok(MeResponse {
|
||||
user: UserInfo {
|
||||
id: claims.sub,
|
||||
username: claims.username,
|
||||
role: claims.role,
|
||||
},
|
||||
expires_at,
|
||||
}))).into_response()
|
||||
}
|
||||
|
||||
/// Logout: clear auth cookies and revoke refresh token family.
|
||||
pub async fn logout(
|
||||
State(state): State<AppState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
) -> impl IntoResponse {
|
||||
if let Some(token) = extract_cookie_value(&headers, "access_token") {
|
||||
if let Ok(claims) = decode::<Claims>(
|
||||
&token,
|
||||
&DecodingKey::from_secret(state.config.auth.jwt_secret.as_bytes()),
|
||||
&Validation::default(),
|
||||
) {
|
||||
let _ = sqlx::query("DELETE FROM refresh_tokens WHERE user_id = ?")
|
||||
.bind(claims.claims.sub)
|
||||
.execute(&state.db)
|
||||
.await;
|
||||
let _ = sqlx::query(
|
||||
"INSERT INTO admin_audit_log (user_id, action, detail) VALUES (?, 'logout', ?)"
|
||||
)
|
||||
.bind(claims.claims.sub)
|
||||
.bind(format!("User {} logged out", claims.claims.username))
|
||||
.execute(&state.db)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
let response = (StatusCode::OK, Json(ApiResponse::ok(()))).into_response();
|
||||
with_cookies(response, clear_cookie_headers())
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// WebSocket ticket
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct WsTicketResponse {
|
||||
pub ticket: String,
|
||||
pub expires_in: u64,
|
||||
}
|
||||
|
||||
/// Create a one-time ticket for WebSocket authentication.
|
||||
/// Requires a valid access_token cookie (set by login).
|
||||
pub async fn create_ws_ticket(
|
||||
State(state): State<AppState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
) -> impl IntoResponse {
|
||||
let token = match extract_cookie_value(&headers, "access_token") {
|
||||
Some(t) => t,
|
||||
None => return (StatusCode::UNAUTHORIZED, Json(ApiResponse::<WsTicketResponse>::error("Not authenticated"))).into_response(),
|
||||
};
|
||||
|
||||
let claims = match decode::<Claims>(
|
||||
&token,
|
||||
&DecodingKey::from_secret(state.config.auth.jwt_secret.as_bytes()),
|
||||
&Validation::default(),
|
||||
) {
|
||||
Ok(c) => c.claims,
|
||||
Err(_) => return (StatusCode::UNAUTHORIZED, Json(ApiResponse::<WsTicketResponse>::error("Invalid token"))).into_response(),
|
||||
};
|
||||
|
||||
if claims.token_type != "access" {
|
||||
return (StatusCode::UNAUTHORIZED, Json(ApiResponse::<WsTicketResponse>::error("Invalid token type"))).into_response();
|
||||
}
|
||||
|
||||
let ticket = uuid::Uuid::new_v4().to_string();
|
||||
let claim = crate::ws::TicketClaim {
|
||||
user_id: claims.sub,
|
||||
username: claims.username,
|
||||
role: claims.role,
|
||||
created_at: std::time::Instant::now(),
|
||||
};
|
||||
|
||||
{
|
||||
let mut tickets = state.ws_tickets.lock().await;
|
||||
tickets.insert(ticket.clone(), claim);
|
||||
|
||||
// Cleanup expired tickets (>30s old) on every creation
|
||||
tickets.retain(|_, c| c.created_at.elapsed().as_secs() < 30);
|
||||
}
|
||||
|
||||
(StatusCode::OK, Json(ApiResponse::ok(WsTicketResponse {
|
||||
ticket,
|
||||
expires_in: 30,
|
||||
}))).into_response()
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Internal helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
fn create_token(user: &UserInfo, token_type: &str, ttl: u64, now: u64, secret: &str, family: &str) -> Result<String, StatusCode> {
|
||||
let claims = Claims {
|
||||
sub: user.id,
|
||||
@@ -227,24 +489,17 @@ fn create_token(user: &UserInfo, token_type: &str, ttl: u64, now: u64, secret: &
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
|
||||
}
|
||||
|
||||
/// Axum middleware: require valid JWT access token
|
||||
/// Axum middleware: require valid JWT access token from cookie
|
||||
pub async fn require_auth(
|
||||
State(state): State<AppState>,
|
||||
mut request: Request,
|
||||
next: Next,
|
||||
) -> Result<Response, StatusCode> {
|
||||
let auth_header = request.headers()
|
||||
.get("Authorization")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.and_then(|v| v.strip_prefix("Bearer "));
|
||||
|
||||
let token = match auth_header {
|
||||
Some(t) => t,
|
||||
None => return Err(StatusCode::UNAUTHORIZED),
|
||||
};
|
||||
let token = extract_cookie_value(request.headers(), "access_token")
|
||||
.ok_or(StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
let claims = decode::<Claims>(
|
||||
token,
|
||||
&token,
|
||||
&DecodingKey::from_secret(state.config.auth.jwt_secret.as_bytes()),
|
||||
&Validation::default(),
|
||||
)
|
||||
@@ -254,9 +509,7 @@ pub async fn require_auth(
|
||||
return Err(StatusCode::UNAUTHORIZED);
|
||||
}
|
||||
|
||||
// Inject claims into request extensions for handlers to use
|
||||
request.extensions_mut().insert(claims.claims);
|
||||
|
||||
Ok(next.run(request).await)
|
||||
}
|
||||
|
||||
@@ -274,7 +527,6 @@ pub async fn require_admin(
|
||||
return Err(StatusCode::FORBIDDEN);
|
||||
}
|
||||
|
||||
// Capture audit info before running handler
|
||||
let method = request.method().clone();
|
||||
let path = request.uri().path().to_string();
|
||||
let user_id = claims.sub;
|
||||
@@ -282,7 +534,6 @@ pub async fn require_admin(
|
||||
|
||||
let response = next.run(request).await;
|
||||
|
||||
// Record admin action to audit log (fire and forget — don't block response)
|
||||
let status = response.status();
|
||||
if status.is_success() {
|
||||
let action = format!("{} {}", method, path);
|
||||
@@ -308,8 +559,10 @@ pub async fn change_password(
|
||||
if req.new_password.len() < 6 {
|
||||
return Ok((StatusCode::BAD_REQUEST, Json(ApiResponse::error("新密码至少6位"))));
|
||||
}
|
||||
if req.new_password.len() > 72 {
|
||||
return Ok((StatusCode::BAD_REQUEST, Json(ApiResponse::error("密码不能超过72位"))));
|
||||
}
|
||||
|
||||
// Verify old password
|
||||
let hash: String = sqlx::query_scalar::<_, String>(
|
||||
"SELECT password FROM users WHERE id = ?"
|
||||
)
|
||||
@@ -322,7 +575,6 @@ pub async fn change_password(
|
||||
return Ok((StatusCode::BAD_REQUEST, Json(ApiResponse::error("当前密码错误"))));
|
||||
}
|
||||
|
||||
// Update password
|
||||
let new_hash = bcrypt::hash(&req.new_password, 12).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
sqlx::query("UPDATE users SET password = ? WHERE id = ?")
|
||||
.bind(&new_hash)
|
||||
@@ -331,7 +583,6 @@ pub async fn change_password(
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
// Audit log
|
||||
let _ = sqlx::query(
|
||||
"INSERT INTO admin_audit_log (user_id, action, detail) VALUES (?, 'change_password', ?)"
|
||||
)
|
||||
|
||||
243
crates/server/src/api/conflict.rs
Normal file
243
crates/server/src/api/conflict.rs
Normal file
@@ -0,0 +1,243 @@
|
||||
use axum::{extract::State, Json};
|
||||
use serde::Serialize;
|
||||
use sqlx::Row;
|
||||
use crate::AppState;
|
||||
use super::ApiResponse;
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct PolicyConflict {
|
||||
pub conflict_type: String,
|
||||
pub severity: String,
|
||||
pub description: String,
|
||||
pub policies: Vec<ConflictPolicyRef>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct ConflictPolicyRef {
|
||||
pub table_name: String,
|
||||
pub row_id: i64,
|
||||
pub name: String,
|
||||
pub target_type: String,
|
||||
pub target_id: Option<String>,
|
||||
}
|
||||
|
||||
/// GET /api/policies/conflicts — scan all policies for conflicts
|
||||
pub async fn scan_conflicts(
|
||||
State(state): State<AppState>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let mut conflicts: Vec<PolicyConflict> = Vec::new();
|
||||
|
||||
// 1. USB: multiple enabled policies for the same target_group
|
||||
if let Ok(rows) = sqlx::query(
|
||||
"SELECT target_group, COUNT(*) as cnt, GROUP_CONCAT(id) as ids, GROUP_CONCAT(name) as names, \
|
||||
GROUP_CONCAT(policy_type) as types \
|
||||
FROM usb_policies WHERE enabled = 1 AND target_group IS NOT NULL \
|
||||
GROUP BY target_group HAVING cnt > 1"
|
||||
)
|
||||
.fetch_all(&state.db)
|
||||
.await
|
||||
{
|
||||
for row in &rows {
|
||||
let group: String = row.get("target_group");
|
||||
let ids: String = row.get("ids");
|
||||
let names: String = row.get("names");
|
||||
let types: String = row.get("types");
|
||||
let id_vec: Vec<i64> = ids.split(',').filter_map(|s| s.parse().ok()).collect();
|
||||
let name_vec: Vec<&str> = names.split(',').collect();
|
||||
let type_vec: Vec<&str> = types.split(',').collect();
|
||||
|
||||
conflicts.push(PolicyConflict {
|
||||
conflict_type: "usb_duplicate_policy".to_string(),
|
||||
severity: "high".to_string(),
|
||||
description: format!("分组 '{}' 同时存在 {} 条启用的USB策略 ({})", group, id_vec.len(), type_vec.join(", ")),
|
||||
policies: id_vec.iter().enumerate().map(|(i, id)| ConflictPolicyRef {
|
||||
table_name: "usb_policies".to_string(),
|
||||
row_id: *id,
|
||||
name: name_vec.get(i).unwrap_or(&"?").to_string(),
|
||||
target_type: "group".to_string(),
|
||||
target_id: Some(group.clone()),
|
||||
}).collect(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// 2. USB: all_block + whitelist for same target (contradictory)
|
||||
if let Ok(rows) = sqlx::query(
|
||||
"SELECT a.id as aid, a.name as aname, a.target_group as agroup, \
|
||||
b.id as bid, b.name as bname \
|
||||
FROM usb_policies a JOIN usb_policies b ON a.target_group = b.target_group AND a.id < b.id \
|
||||
WHERE a.enabled = 1 AND b.enabled = 1 \
|
||||
AND ((a.policy_type = 'all_block' AND b.policy_type = 'whitelist') OR \
|
||||
(a.policy_type = 'whitelist' AND b.policy_type = 'all_block'))"
|
||||
)
|
||||
.fetch_all(&state.db)
|
||||
.await
|
||||
{
|
||||
for row in &rows {
|
||||
let group: Option<String> = row.get("agroup");
|
||||
conflicts.push(PolicyConflict {
|
||||
conflict_type: "usb_block_vs_whitelist".to_string(),
|
||||
severity: "critical".to_string(),
|
||||
description: format!("分组 '{}' 同时存在全封堵和白名单USB策略,互斥", group.as_deref().unwrap_or("?")),
|
||||
policies: vec![
|
||||
ConflictPolicyRef {
|
||||
table_name: "usb_policies".to_string(),
|
||||
row_id: row.get("aid"),
|
||||
name: row.get("aname"),
|
||||
target_type: "group".to_string(),
|
||||
target_id: group.clone(),
|
||||
},
|
||||
ConflictPolicyRef {
|
||||
table_name: "usb_policies".to_string(),
|
||||
row_id: row.get("bid"),
|
||||
name: row.get("bname"),
|
||||
target_type: "group".to_string(),
|
||||
target_id: group,
|
||||
},
|
||||
],
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Web filter: same target, same pattern, different rule_type (allow vs block)
|
||||
if let Ok(rows) = sqlx::query(
|
||||
"SELECT a.id as aid, a.pattern as apattern, a.rule_type as artype, \
|
||||
b.id as bid, b.rule_type as brtype, a.target_type as ttype, a.target_id as tid \
|
||||
FROM web_filter_rules a JOIN web_filter_rules b ON a.pattern = b.pattern AND a.id < b.id \
|
||||
WHERE a.enabled = 1 AND b.enabled = 1 \
|
||||
AND a.target_type = b.target_type AND COALESCE(a.target_id,'') = COALESCE(b.target_id,'') \
|
||||
AND a.rule_type != b.rule_type"
|
||||
)
|
||||
.fetch_all(&state.db)
|
||||
.await
|
||||
{
|
||||
for row in &rows {
|
||||
let pattern: String = row.get("apattern");
|
||||
let artype: String = row.get("artype");
|
||||
let brtype: String = row.get("brtype");
|
||||
let ttype: String = row.get("ttype");
|
||||
let tid: Option<String> = row.get("tid");
|
||||
conflicts.push(PolicyConflict {
|
||||
conflict_type: "web_filter_allow_vs_block".to_string(),
|
||||
severity: "high".to_string(),
|
||||
description: format!("URL '{}' 同时被 {} 和 {},互斥", pattern, artype, brtype),
|
||||
policies: vec![
|
||||
ConflictPolicyRef {
|
||||
table_name: "web_filter_rules".to_string(),
|
||||
row_id: row.get("aid"),
|
||||
name: format!("{}: {}", artype, pattern),
|
||||
target_type: ttype.clone(),
|
||||
target_id: tid.clone(),
|
||||
},
|
||||
ConflictPolicyRef {
|
||||
table_name: "web_filter_rules".to_string(),
|
||||
row_id: row.get("bid"),
|
||||
name: format!("{}: {}", brtype, pattern),
|
||||
target_type: ttype,
|
||||
target_id: tid,
|
||||
},
|
||||
],
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// 4. Clipboard: same source/target process, allow vs block
|
||||
if let Ok(rows) = sqlx::query(
|
||||
"SELECT a.id as aid, a.rule_type as artype, a.source_process as asrc, a.target_process as adst, \
|
||||
b.id as bid, b.rule_type as brtype, a.target_type as ttype, a.target_id as tid \
|
||||
FROM clipboard_rules a JOIN clipboard_rules b ON a.id < b.id \
|
||||
WHERE a.enabled = 1 AND b.enabled = 1 \
|
||||
AND a.target_type = b.target_type AND COALESCE(a.target_id,'') = COALESCE(b.target_id,'') \
|
||||
AND a.direction = b.direction \
|
||||
AND COALESCE(a.source_process,'') = COALESCE(b.source_process,'') \
|
||||
AND COALESCE(a.target_process,'') = COALESCE(b.target_process,'') \
|
||||
AND a.rule_type != b.rule_type"
|
||||
)
|
||||
.fetch_all(&state.db)
|
||||
.await
|
||||
{
|
||||
for row in &rows {
|
||||
let asrc: Option<String> = row.get("asrc");
|
||||
let adst: Option<String> = row.get("adst");
|
||||
let artype: String = row.get("artype");
|
||||
let brtype: String = row.get("brtype");
|
||||
let desc = format!(
|
||||
"剪贴板规则冲突: {} → {} 同时存在 {} 和 {}",
|
||||
asrc.as_deref().unwrap_or("*"),
|
||||
adst.as_deref().unwrap_or("*"),
|
||||
artype, brtype,
|
||||
);
|
||||
let ttype: String = row.get("ttype");
|
||||
let tid: Option<String> = row.get("tid");
|
||||
conflicts.push(PolicyConflict {
|
||||
conflict_type: "clipboard_allow_vs_block".to_string(),
|
||||
severity: "medium".to_string(),
|
||||
description: desc,
|
||||
policies: vec![
|
||||
ConflictPolicyRef {
|
||||
table_name: "clipboard_rules".to_string(),
|
||||
row_id: row.get("aid"),
|
||||
name: format!("{}: {}", artype, asrc.as_deref().unwrap_or("*")),
|
||||
target_type: ttype.clone(),
|
||||
target_id: tid.clone(),
|
||||
},
|
||||
ConflictPolicyRef {
|
||||
table_name: "clipboard_rules".to_string(),
|
||||
row_id: row.get("bid"),
|
||||
name: format!("{}: {}", brtype, asrc.as_deref().unwrap_or("*")),
|
||||
target_type: ttype,
|
||||
target_id: tid,
|
||||
},
|
||||
],
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// 5. Plugin disabled but has active rules
|
||||
let plugin_tables: [(&str, &str, &str, &str); 4] = [
|
||||
("web_filter_rules", "上网行为过滤", "web_filter", "SELECT COUNT(*) FROM web_filter_rules WHERE enabled = 1"),
|
||||
("software_blacklist", "软件黑名单", "software_blocker", "SELECT COUNT(*) FROM software_blacklist WHERE enabled = 1"),
|
||||
("popup_filter_rules", "弹窗拦截", "popup_blocker", "SELECT COUNT(*) FROM popup_filter_rules WHERE enabled = 1"),
|
||||
("clipboard_rules", "剪贴板管控", "clipboard_control", "SELECT COUNT(*) FROM clipboard_rules WHERE enabled = 1"),
|
||||
];
|
||||
for (_table, label, plugin, query) in &plugin_tables {
|
||||
let active_count: i64 = sqlx::query_scalar(query)
|
||||
.fetch_one(&state.db)
|
||||
.await
|
||||
.unwrap_or(0);
|
||||
|
||||
if active_count > 0 {
|
||||
let disabled: bool = sqlx::query_scalar::<_, i32>(
|
||||
"SELECT COUNT(*) FROM plugin_state WHERE plugin_name = ? AND enabled = 0"
|
||||
)
|
||||
.bind(plugin)
|
||||
.fetch_one(&state.db)
|
||||
.await
|
||||
.unwrap_or(0) > 0;
|
||||
|
||||
if disabled {
|
||||
conflicts.push(PolicyConflict {
|
||||
conflict_type: "plugin_disabled_with_rules".to_string(),
|
||||
severity: "low".to_string(),
|
||||
description: format!("插件 '{}' 已禁用,但仍有 {} 条启用规则,规则不会生效", label, active_count),
|
||||
policies: vec![ConflictPolicyRef {
|
||||
table_name: "plugin_state".to_string(),
|
||||
row_id: 0,
|
||||
name: plugin.to_string(),
|
||||
target_type: "global".to_string(),
|
||||
target_id: None,
|
||||
}],
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Json(ApiResponse::ok(serde_json::json!({
|
||||
"conflicts": conflicts,
|
||||
"total": conflicts.len(),
|
||||
"critical_count": conflicts.iter().filter(|c| c.severity == "critical").count(),
|
||||
"high_count": conflicts.iter().filter(|c| c.severity == "high").count(),
|
||||
"medium_count": conflicts.iter().filter(|c| c.severity == "medium").count(),
|
||||
"low_count": conflicts.iter().filter(|c| c.severity == "low").count(),
|
||||
})))
|
||||
}
|
||||
@@ -4,6 +4,28 @@ use sqlx::Row;
|
||||
use crate::AppState;
|
||||
use super::{ApiResponse, Pagination};
|
||||
|
||||
/// GET /api/devices/:uid/health-score
|
||||
pub async fn get_health_score(
|
||||
State(state): State<AppState>,
|
||||
Path(uid): Path<String>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
match crate::health::get_device_score(&state.db, &uid).await {
|
||||
Ok(Some(score)) => Json(ApiResponse::ok(score)),
|
||||
Ok(None) => Json(ApiResponse::error("No health score available")),
|
||||
Err(e) => Json(ApiResponse::internal_error("health score", e)),
|
||||
}
|
||||
}
|
||||
|
||||
/// GET /api/dashboard/health-overview
|
||||
pub async fn health_overview(
|
||||
State(state): State<AppState>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
match crate::health::get_health_overview(&state.db).await {
|
||||
Ok(overview) => Json(ApiResponse::ok(overview)),
|
||||
Err(e) => Json(ApiResponse::internal_error("health overview", e)),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct DeviceListParams {
|
||||
pub status: Option<String>,
|
||||
@@ -26,6 +48,10 @@ pub struct DeviceRow {
|
||||
pub last_heartbeat: Option<String>,
|
||||
pub registered_at: Option<String>,
|
||||
pub group_name: Option<String>,
|
||||
#[sqlx(default)]
|
||||
pub health_score: Option<i32>,
|
||||
#[sqlx(default)]
|
||||
pub health_level: Option<String>,
|
||||
}
|
||||
|
||||
pub async fn list(
|
||||
@@ -41,13 +67,16 @@ pub async fn list(
|
||||
let search = params.search.as_deref().filter(|s| !s.is_empty()).map(String::from);
|
||||
|
||||
let devices = sqlx::query_as::<_, DeviceRow>(
|
||||
"SELECT id, device_uid, hostname, ip_address, mac_address, os_version, client_version,
|
||||
status, last_heartbeat, registered_at, group_name
|
||||
FROM devices WHERE 1=1
|
||||
AND (? IS NULL OR status = ?)
|
||||
AND (? IS NULL OR group_name = ?)
|
||||
AND (? IS NULL OR hostname LIKE '%' || ? || '%' OR ip_address LIKE '%' || ? || '%')
|
||||
ORDER BY registered_at DESC LIMIT ? OFFSET ?"
|
||||
"SELECT d.id, d.device_uid, d.hostname, d.ip_address, d.mac_address, d.os_version, d.client_version,
|
||||
d.status, d.last_heartbeat, d.registered_at, d.group_name,
|
||||
h.score as health_score, h.level as health_level
|
||||
FROM devices d
|
||||
LEFT JOIN device_health_scores h ON h.device_uid = d.device_uid
|
||||
WHERE 1=1
|
||||
AND (? IS NULL OR d.status = ?)
|
||||
AND (? IS NULL OR d.group_name = ?)
|
||||
AND (? IS NULL OR d.hostname LIKE '%' || ? || '%' OR d.ip_address LIKE '%' || ? || '%')
|
||||
ORDER BY d.registered_at DESC LIMIT ? OFFSET ?"
|
||||
)
|
||||
.bind(&status).bind(&status)
|
||||
.bind(&group).bind(&group)
|
||||
@@ -187,16 +216,6 @@ pub async fn remove(
|
||||
State(state): State<AppState>,
|
||||
Path(uid): Path<String>,
|
||||
) -> Json<ApiResponse<()>> {
|
||||
// If client is connected, send self-destruct command
|
||||
let frame = csm_protocol::Frame::new_json(
|
||||
csm_protocol::MessageType::ConfigUpdate,
|
||||
&serde_json::json!({"type": "SelfDestruct"}),
|
||||
).ok();
|
||||
|
||||
if let Some(frame) = frame {
|
||||
state.clients.send_to(&uid, frame.encode()).await;
|
||||
}
|
||||
|
||||
// Delete device and all associated data in a transaction
|
||||
let mut tx = match state.db.begin().await {
|
||||
Ok(tx) => tx,
|
||||
@@ -224,6 +243,8 @@ pub async fn remove(
|
||||
// Delete plugin-related data
|
||||
let cleanup_tables = [
|
||||
"hardware_assets",
|
||||
"software_assets",
|
||||
"asset_changes",
|
||||
"usb_events",
|
||||
"usb_file_operations",
|
||||
"usage_daily",
|
||||
@@ -231,8 +252,20 @@ pub async fn remove(
|
||||
"software_violations",
|
||||
"web_access_log",
|
||||
"popup_block_stats",
|
||||
"disk_encryption_status",
|
||||
"disk_encryption_alerts",
|
||||
"print_events",
|
||||
"clipboard_violations",
|
||||
"behavior_metrics",
|
||||
"anomaly_alerts",
|
||||
"device_health_scores",
|
||||
"patch_status",
|
||||
];
|
||||
for table in &cleanup_tables {
|
||||
// Safety: table names are hardcoded constants above, not user input.
|
||||
// Parameterized ? is used for device_uid.
|
||||
debug_assert!(table.chars().all(|c| c.is_ascii_alphanumeric() || c == '_'),
|
||||
"BUG: table name contains unexpected characters: {}", table);
|
||||
if let Err(e) = sqlx::query(&format!("DELETE FROM {} WHERE device_uid = ?", table))
|
||||
.bind(&uid)
|
||||
.execute(&mut *tx)
|
||||
@@ -253,6 +286,17 @@ pub async fn remove(
|
||||
if let Err(e) = tx.commit().await {
|
||||
return Json(ApiResponse::internal_error("commit device deletion", e));
|
||||
}
|
||||
|
||||
// Send self-destruct command AFTER successful commit
|
||||
let frame = csm_protocol::Frame::new_json(
|
||||
csm_protocol::MessageType::ConfigUpdate,
|
||||
&serde_json::json!({"type": "SelfDestruct"}),
|
||||
).ok();
|
||||
|
||||
if let Some(frame) = frame {
|
||||
state.clients.send_to(&uid, frame.encode()).await;
|
||||
}
|
||||
|
||||
state.clients.unregister(&uid).await;
|
||||
tracing::info!(device_uid = %uid, "Device and all associated data deleted");
|
||||
Json(ApiResponse::ok(()))
|
||||
|
||||
@@ -50,6 +50,9 @@ pub async fn create_group(
|
||||
if name.is_empty() || name.len() > 50 {
|
||||
return Ok((StatusCode::BAD_REQUEST, Json(ApiResponse::error("分组名称无效"))));
|
||||
}
|
||||
if name.contains('<') || name.contains('>') || name.contains('"') || name.contains('\'') {
|
||||
return Ok((StatusCode::BAD_REQUEST, Json(ApiResponse::error("分组名称包含非法字符"))));
|
||||
}
|
||||
|
||||
// Check if group already exists
|
||||
let exists: bool = sqlx::query_scalar(
|
||||
@@ -78,6 +81,9 @@ pub async fn rename_group(
|
||||
if new_name.is_empty() || new_name.len() > 50 {
|
||||
return Ok((StatusCode::BAD_REQUEST, Json(ApiResponse::error("分组名称无效"))));
|
||||
}
|
||||
if new_name.contains('<') || new_name.contains('>') || new_name.contains('"') || new_name.contains('\'') {
|
||||
return Ok((StatusCode::BAD_REQUEST, Json(ApiResponse::error("分组名称包含非法字符"))));
|
||||
}
|
||||
|
||||
let result = sqlx::query(
|
||||
"UPDATE devices SET group_name = ? WHERE group_name = ?"
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use axum::{routing::{get, post, put, delete}, Router, Json, middleware};
|
||||
use axum::{routing::{get, post, put, delete}, Router, Json, middleware, http::StatusCode, response::IntoResponse};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use crate::AppState;
|
||||
|
||||
@@ -9,23 +9,31 @@ pub mod usb;
|
||||
pub mod alerts;
|
||||
pub mod plugins;
|
||||
pub mod groups;
|
||||
pub mod conflict;
|
||||
|
||||
pub fn routes(state: AppState) -> Router<AppState> {
|
||||
let public = Router::new()
|
||||
.route("/api/auth/login", post(auth::login))
|
||||
.route("/api/auth/refresh", post(auth::refresh))
|
||||
.route("/api/auth/logout", post(auth::logout))
|
||||
.route("/health", get(health_check))
|
||||
.with_state(state.clone());
|
||||
|
||||
// Read-only routes (any authenticated user)
|
||||
let read_routes = Router::new()
|
||||
// Auth
|
||||
.route("/api/auth/me", get(auth::me))
|
||||
.route("/api/auth/change-password", put(auth::change_password))
|
||||
// WebSocket ticket (requires auth cookie)
|
||||
.route("/api/ws/ticket", post(auth::create_ws_ticket))
|
||||
// Devices
|
||||
.route("/api/devices", get(devices::list))
|
||||
.route("/api/devices/:uid", get(devices::get_detail))
|
||||
.route("/api/devices/:uid/status", get(devices::get_status))
|
||||
.route("/api/devices/:uid/history", get(devices::get_history))
|
||||
.route("/api/devices/:uid/health-score", get(devices::get_health_score))
|
||||
// Dashboard
|
||||
.route("/api/dashboard/health-overview", get(devices::health_overview))
|
||||
// Assets
|
||||
.route("/api/assets/hardware", get(assets::list_hardware))
|
||||
.route("/api/assets/software", get(assets::list_software))
|
||||
@@ -40,6 +48,8 @@ pub fn routes(state: AppState) -> Router<AppState> {
|
||||
.route("/api/alerts/records", get(alerts::list_records))
|
||||
// Plugin read routes
|
||||
.merge(plugins::read_routes())
|
||||
// Policy conflict scan
|
||||
.route("/api/policies/conflicts", get(conflict::scan_conflicts))
|
||||
.layer(middleware::from_fn_with_state(state.clone(), auth::require_auth));
|
||||
|
||||
// Write routes (admin only)
|
||||
@@ -50,6 +60,8 @@ pub fn routes(state: AppState) -> Router<AppState> {
|
||||
.route("/api/groups", post(groups::create_group))
|
||||
.route("/api/groups/:name", put(groups::rename_group).delete(groups::delete_group))
|
||||
.route("/api/devices/:uid/group", put(groups::move_device))
|
||||
// TLS cert rotation
|
||||
.route("/api/system/tls-rotate", post(system_tls_rotate))
|
||||
// USB (write)
|
||||
.route("/api/usb/policies", post(usb::create_policy))
|
||||
.route("/api/usb/policies/:id", put(usb::update_policy).delete(usb::delete_policy))
|
||||
@@ -76,6 +88,45 @@ pub fn routes(state: AppState) -> Router<AppState> {
|
||||
.merge(ws_router)
|
||||
}
|
||||
|
||||
/// Trigger TLS certificate rotation for all online devices.
|
||||
/// Admin sends the new certificate PEM and a transition deadline.
|
||||
/// The server pushes a ConfigUpdate(TlsCertRotate) to all connected clients.
|
||||
#[derive(Deserialize)]
|
||||
struct TlsRotateRequest {
|
||||
/// Path to the new certificate PEM file
|
||||
cert_path: String,
|
||||
/// ISO 8601 timestamp when the old cert stops being valid (transition deadline)
|
||||
valid_until: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct TlsRotateResponse {
|
||||
devices_notified: usize,
|
||||
}
|
||||
|
||||
async fn system_tls_rotate(
|
||||
axum::extract::State(state): axum::extract::State<AppState>,
|
||||
Json(req): Json<TlsRotateRequest>,
|
||||
) -> impl IntoResponse {
|
||||
let cert_pem = match tokio::fs::read(&req.cert_path).await {
|
||||
Ok(pem) => pem,
|
||||
Err(e) => {
|
||||
return (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(ApiResponse::<TlsRotateResponse>::error(
|
||||
format!("Cannot read cert file {}: {}", req.cert_path, e),
|
||||
)),
|
||||
).into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let count = crate::tcp::push_tls_cert_rotation(&state.clients, &cert_pem, &req.valid_until).await;
|
||||
|
||||
(StatusCode::OK, Json(ApiResponse::ok(TlsRotateResponse {
|
||||
devices_notified: count,
|
||||
}))).into_response()
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct HealthResponse {
|
||||
status: &'static str,
|
||||
|
||||
48
crates/server/src/api/plugins/anomaly.rs
Normal file
48
crates/server/src/api/plugins/anomaly.rs
Normal file
@@ -0,0 +1,48 @@
|
||||
use axum::{extract::{State, Path, Query}, Json};
|
||||
use serde::Deserialize;
|
||||
use crate::AppState;
|
||||
use crate::api::ApiResponse;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct AnomalyListParams {
|
||||
pub device_uid: Option<String>,
|
||||
pub page: Option<u32>,
|
||||
pub page_size: Option<u32>,
|
||||
}
|
||||
|
||||
/// GET /api/plugins/anomaly/alerts
|
||||
pub async fn list_anomaly_alerts(
|
||||
State(state): State<AppState>,
|
||||
Query(params): Query<AnomalyListParams>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let page = params.page.unwrap_or(1);
|
||||
let page_size = params.page_size.unwrap_or(20).min(100);
|
||||
let device_uid = params.device_uid.as_deref().filter(|s| !s.is_empty());
|
||||
|
||||
match crate::anomaly::get_anomaly_summary(&state.db, device_uid, page, page_size).await {
|
||||
Ok(result) => Json(ApiResponse::ok(result)),
|
||||
Err(e) => Json(ApiResponse::internal_error("anomaly alerts", e)),
|
||||
}
|
||||
}
|
||||
|
||||
/// PUT /api/plugins/anomaly/alerts/:id/handle
|
||||
/// Mark an anomaly alert as handled.
|
||||
pub async fn handle_anomaly_alert(
|
||||
State(state): State<AppState>,
|
||||
Path(id): Path<i64>,
|
||||
claims: axum::Extension<crate::api::auth::Claims>,
|
||||
) -> Json<ApiResponse<()>> {
|
||||
let result = sqlx::query(
|
||||
"UPDATE anomaly_alerts SET handled = 1, handled_by = ?, handled_at = datetime('now') WHERE id = ?"
|
||||
)
|
||||
.bind(&claims.username)
|
||||
.bind(id)
|
||||
.execute(&state.db)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(r) if r.rows_affected() > 0 => Json(ApiResponse::ok(())),
|
||||
Ok(_) => Json(ApiResponse::error("Alert not found")),
|
||||
Err(e) => Json(ApiResponse::internal_error("handle anomaly alert", e)),
|
||||
}
|
||||
}
|
||||
@@ -21,7 +21,7 @@ pub struct CreateRuleRequest {
|
||||
pub async fn list_rules(State(state): State<AppState>) -> Json<ApiResponse<serde_json::Value>> {
|
||||
match sqlx::query(
|
||||
"SELECT id, target_type, target_id, rule_type, direction, source_process, target_process, content_pattern, enabled, updated_at \
|
||||
FROM clipboard_rules ORDER BY updated_at DESC"
|
||||
FROM clipboard_rules ORDER BY updated_at DESC LIMIT 500"
|
||||
)
|
||||
.fetch_all(&state.db)
|
||||
.await
|
||||
@@ -127,6 +127,18 @@ pub async fn update_rule(
|
||||
let content_pattern = body.content_pattern.or_else(|| existing.get::<Option<String>, _>("content_pattern"));
|
||||
let enabled = body.enabled.unwrap_or_else(|| existing.get::<bool, _>("enabled"));
|
||||
|
||||
// Validate merged values
|
||||
if let Some(ref rt) = rule_type {
|
||||
if !matches!(rt.as_str(), "allow" | "block") {
|
||||
return Json(ApiResponse::error("rule_type must be 'allow' or 'block'"));
|
||||
}
|
||||
}
|
||||
if let Some(ref d) = direction {
|
||||
if !matches!(d.as_str(), "in" | "out" | "both") {
|
||||
return Json(ApiResponse::error("direction must be 'in', 'out', or 'both'"));
|
||||
}
|
||||
}
|
||||
|
||||
let result = sqlx::query(
|
||||
"UPDATE clipboard_rules SET rule_type = ?, direction = ?, source_process = ?, target_process = ?, content_pattern = ?, enabled = ?, updated_at = datetime('now') WHERE id = ?"
|
||||
)
|
||||
|
||||
@@ -28,7 +28,7 @@ pub async fn list_status(
|
||||
"SELECT s.id, s.device_uid, s.drive_letter, s.volume_name, s.encryption_method, \
|
||||
s.protection_status, s.encryption_percentage, s.lock_status, s.reported_at, s.updated_at, \
|
||||
d.hostname FROM disk_encryption_status s LEFT JOIN devices d ON s.device_uid = d.device_uid \
|
||||
ORDER BY s.device_uid, s.drive_letter"
|
||||
ORDER BY s.device_uid, s.drive_letter LIMIT 500"
|
||||
)
|
||||
.fetch_all(&state.db)
|
||||
.await
|
||||
@@ -58,7 +58,7 @@ pub async fn list_alerts(State(state): State<AppState>) -> Json<ApiResponse<serd
|
||||
match sqlx::query(
|
||||
"SELECT a.id, a.device_uid, a.drive_letter, a.alert_type, a.status, a.created_at, a.resolved_at, \
|
||||
d.hostname FROM encryption_alerts a LEFT JOIN devices d ON a.device_uid = d.device_uid \
|
||||
ORDER BY a.created_at DESC"
|
||||
ORDER BY a.created_at DESC LIMIT 500"
|
||||
)
|
||||
.fetch_all(&state.db)
|
||||
.await
|
||||
|
||||
@@ -8,6 +8,8 @@ pub mod disk_encryption;
|
||||
pub mod print_audit;
|
||||
pub mod clipboard_control;
|
||||
pub mod plugin_control;
|
||||
pub mod patch;
|
||||
pub mod anomaly;
|
||||
|
||||
use axum::{Router, routing::{get, post, put}};
|
||||
use crate::AppState;
|
||||
@@ -25,6 +27,7 @@ pub fn read_routes() -> Router<AppState> {
|
||||
// Software Blocker
|
||||
.route("/api/plugins/software-blocker/blacklist", get(software_blocker::list_blacklist))
|
||||
.route("/api/plugins/software-blocker/violations", get(software_blocker::list_violations))
|
||||
.route("/api/plugins/software-blocker/whitelist", get(software_blocker::list_whitelist))
|
||||
// Popup Blocker
|
||||
.route("/api/plugins/popup-blocker/rules", get(popup_blocker::list_rules))
|
||||
.route("/api/plugins/popup-blocker/stats", get(popup_blocker::list_stats))
|
||||
@@ -36,7 +39,6 @@ pub fn read_routes() -> Router<AppState> {
|
||||
// Disk Encryption
|
||||
.route("/api/plugins/disk-encryption/status", get(disk_encryption::list_status))
|
||||
.route("/api/plugins/disk-encryption/alerts", get(disk_encryption::list_alerts))
|
||||
.route("/api/plugins/disk-encryption/alerts/:id/acknowledge", put(disk_encryption::acknowledge_alert))
|
||||
// Print Audit
|
||||
.route("/api/plugins/print-audit/events", get(print_audit::list_events))
|
||||
.route("/api/plugins/print-audit/events/:id", get(print_audit::get_event))
|
||||
@@ -45,6 +47,12 @@ pub fn read_routes() -> Router<AppState> {
|
||||
.route("/api/plugins/clipboard-control/violations", get(clipboard_control::list_violations))
|
||||
// Plugin Control
|
||||
.route("/api/plugins/control", get(plugin_control::list_plugins))
|
||||
// Patch Management
|
||||
.route("/api/plugins/patch/status", get(patch::list_patch_status))
|
||||
.route("/api/plugins/patch/summary", get(patch::patch_summary))
|
||||
.route("/api/plugins/patch/device/:uid", get(patch::device_patches))
|
||||
// Anomaly Detection
|
||||
.route("/api/plugins/anomaly/alerts", get(anomaly::list_anomaly_alerts))
|
||||
}
|
||||
|
||||
/// Write plugin routes (admin only — require_admin middleware applied by caller)
|
||||
@@ -56,6 +64,8 @@ pub fn write_routes() -> Router<AppState> {
|
||||
// Software Blocker
|
||||
.route("/api/plugins/software-blocker/blacklist", post(software_blocker::add_to_blacklist))
|
||||
.route("/api/plugins/software-blocker/blacklist/:id", put(software_blocker::update_blacklist).delete(software_blocker::remove_from_blacklist))
|
||||
.route("/api/plugins/software-blocker/whitelist", post(software_blocker::add_to_whitelist))
|
||||
.route("/api/plugins/software-blocker/whitelist/:id", put(software_blocker::update_whitelist).delete(software_blocker::remove_from_whitelist))
|
||||
// Popup Blocker
|
||||
.route("/api/plugins/popup-blocker/rules", post(popup_blocker::create_rule))
|
||||
.route("/api/plugins/popup-blocker/rules/:id", put(popup_blocker::update_rule).delete(popup_blocker::delete_rule))
|
||||
@@ -65,6 +75,10 @@ pub fn write_routes() -> Router<AppState> {
|
||||
// Clipboard Control
|
||||
.route("/api/plugins/clipboard-control/rules", post(clipboard_control::create_rule))
|
||||
.route("/api/plugins/clipboard-control/rules/:id", put(clipboard_control::update_rule).delete(clipboard_control::delete_rule))
|
||||
// Disk Encryption
|
||||
.route("/api/plugins/disk-encryption/alerts/:id/acknowledge", put(disk_encryption::acknowledge_alert))
|
||||
// Plugin Control (enable/disable)
|
||||
.route("/api/plugins/control/:plugin_name", put(plugin_control::set_plugin_state))
|
||||
// Anomaly Detection — handle alert
|
||||
.route("/api/plugins/anomaly/alerts/:id/handle", put(anomaly::handle_anomaly_alert))
|
||||
}
|
||||
|
||||
146
crates/server/src/api/plugins/patch.rs
Normal file
146
crates/server/src/api/plugins/patch.rs
Normal file
@@ -0,0 +1,146 @@
|
||||
use axum::{extract::{State, Path, Query}, Json};
|
||||
use serde::Deserialize;
|
||||
use sqlx::Row;
|
||||
use crate::AppState;
|
||||
use crate::api::ApiResponse;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct PatchListParams {
|
||||
pub device_uid: Option<String>,
|
||||
pub severity: Option<String>,
|
||||
#[allow(dead_code)]
|
||||
pub installed: Option<i32>,
|
||||
pub page: Option<u32>,
|
||||
pub page_size: Option<u32>,
|
||||
}
|
||||
|
||||
/// GET /api/plugins/patch/status
|
||||
pub async fn list_patch_status(
|
||||
State(state): State<AppState>,
|
||||
Query(params): Query<PatchListParams>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let limit = params.page_size.unwrap_or(20).min(100);
|
||||
let offset = params.page.unwrap_or(1).saturating_sub(1) * limit;
|
||||
|
||||
let device_uid = params.device_uid.as_deref().filter(|s| !s.is_empty());
|
||||
let severity = params.severity.as_deref().filter(|s| !s.is_empty());
|
||||
|
||||
let rows = sqlx::query(
|
||||
"SELECT p.*, d.hostname FROM patch_status p JOIN devices d ON d.device_uid = p.device_uid \
|
||||
WHERE 1=1 \
|
||||
AND (? IS NULL OR p.device_uid = ?) \
|
||||
AND (? IS NULL OR p.severity = ?) \
|
||||
ORDER BY p.updated_at DESC LIMIT ? OFFSET ?"
|
||||
)
|
||||
.bind(device_uid).bind(device_uid)
|
||||
.bind(severity).bind(severity)
|
||||
.bind(limit).bind(offset)
|
||||
.fetch_all(&state.db)
|
||||
.await;
|
||||
|
||||
match rows {
|
||||
Ok(records) => {
|
||||
let items: Vec<serde_json::Value> = records.iter().map(|r| serde_json::json!({
|
||||
"id": r.get::<i64, _>("id"),
|
||||
"device_uid": r.get::<String, _>("device_uid"),
|
||||
"hostname": r.get::<String, _>("hostname"),
|
||||
"kb_id": r.get::<String, _>("kb_id"),
|
||||
"title": r.get::<String, _>("title"),
|
||||
"severity": r.get::<Option<String>, _>("severity"),
|
||||
"is_installed": r.get::<i32, _>("is_installed"),
|
||||
"installed_at": r.get::<Option<String>, _>("installed_at"),
|
||||
"updated_at": r.get::<String, _>("updated_at"),
|
||||
})).collect();
|
||||
|
||||
// Summary stats (scoped to same filters as main query)
|
||||
let total_installed: i64 = sqlx::query_scalar(
|
||||
"SELECT COUNT(*) FROM patch_status WHERE is_installed = 1 \
|
||||
AND (? IS NULL OR device_uid = ?) \
|
||||
AND (? IS NULL OR severity = ?)"
|
||||
)
|
||||
.bind(device_uid).bind(device_uid)
|
||||
.bind(severity).bind(severity)
|
||||
.fetch_one(&state.db).await.unwrap_or(0);
|
||||
let total_missing: i64 = sqlx::query_scalar(
|
||||
"SELECT COUNT(*) FROM patch_status WHERE is_installed = 0 \
|
||||
AND (? IS NULL OR device_uid = ?) \
|
||||
AND (? IS NULL OR severity = ?)"
|
||||
)
|
||||
.bind(device_uid).bind(device_uid)
|
||||
.bind(severity).bind(severity)
|
||||
.fetch_one(&state.db).await.unwrap_or(0);
|
||||
|
||||
Json(ApiResponse::ok(serde_json::json!({
|
||||
"patches": items,
|
||||
"summary": {
|
||||
"total_installed": total_installed,
|
||||
"total_missing": total_missing,
|
||||
},
|
||||
"page": params.page.unwrap_or(1),
|
||||
"page_size": limit,
|
||||
})))
|
||||
}
|
||||
Err(e) => Json(ApiResponse::internal_error("query patch status", e)),
|
||||
}
|
||||
}
|
||||
|
||||
/// GET /api/plugins/patch/summary — per-device patch summary
|
||||
pub async fn patch_summary(
|
||||
State(state): State<AppState>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let rows = sqlx::query(
|
||||
"SELECT p.device_uid, d.hostname, \
|
||||
COUNT(*) as total_patches, \
|
||||
SUM(CASE WHEN p.is_installed = 1 THEN 1 ELSE 0 END) as installed, \
|
||||
SUM(CASE WHEN p.is_installed = 0 THEN 1 ELSE 0 END) as missing, \
|
||||
MAX(p.updated_at) as last_scan \
|
||||
FROM patch_status p JOIN devices d ON d.device_uid = p.device_uid \
|
||||
GROUP BY p.device_uid ORDER BY missing DESC"
|
||||
)
|
||||
.fetch_all(&state.db)
|
||||
.await;
|
||||
|
||||
match rows {
|
||||
Ok(records) => {
|
||||
let devices: Vec<serde_json::Value> = records.iter().map(|r| serde_json::json!({
|
||||
"device_uid": r.get::<String, _>("device_uid"),
|
||||
"hostname": r.get::<String, _>("hostname"),
|
||||
"total_patches": r.get::<i64, _>("total_patches"),
|
||||
"installed": r.get::<i64, _>("installed"),
|
||||
"missing": r.get::<i64, _>("missing"),
|
||||
"last_scan": r.get::<String, _>("last_scan"),
|
||||
})).collect();
|
||||
Json(ApiResponse::ok(serde_json::json!({ "devices": devices })))
|
||||
}
|
||||
Err(e) => Json(ApiResponse::internal_error("patch summary", e)),
|
||||
}
|
||||
}
|
||||
|
||||
/// GET /api/plugins/patch/device/:uid — patches for a single device
|
||||
pub async fn device_patches(
|
||||
State(state): State<AppState>,
|
||||
Path(uid): Path<String>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let rows = sqlx::query(
|
||||
"SELECT kb_id, title, severity, is_installed, installed_at, updated_at \
|
||||
FROM patch_status WHERE device_uid = ? ORDER BY updated_at DESC"
|
||||
)
|
||||
.bind(&uid)
|
||||
.fetch_all(&state.db)
|
||||
.await;
|
||||
|
||||
match rows {
|
||||
Ok(records) => {
|
||||
let patches: Vec<serde_json::Value> = records.iter().map(|r| serde_json::json!({
|
||||
"kb_id": r.get::<String, _>("kb_id"),
|
||||
"title": r.get::<String, _>("title"),
|
||||
"severity": r.get::<Option<String>, _>("severity"),
|
||||
"is_installed": r.get::<i32, _>("is_installed"),
|
||||
"installed_at": r.get::<Option<String>, _>("installed_at"),
|
||||
"updated_at": r.get::<String, _>("updated_at"),
|
||||
})).collect();
|
||||
Json(ApiResponse::ok(serde_json::json!({ "patches": patches })))
|
||||
}
|
||||
Err(e) => Json(ApiResponse::internal_error("device patches", e)),
|
||||
}
|
||||
}
|
||||
@@ -17,7 +17,7 @@ pub struct CreateRuleRequest {
|
||||
}
|
||||
|
||||
pub async fn list_rules(State(state): State<AppState>) -> Json<ApiResponse<serde_json::Value>> {
|
||||
match sqlx::query("SELECT id, rule_type, window_title, window_class, process_name, target_type, target_id, enabled, created_at FROM popup_filter_rules ORDER BY created_at DESC")
|
||||
match sqlx::query("SELECT id, rule_type, window_title, window_class, process_name, target_type, target_id, enabled, created_at FROM popup_filter_rules ORDER BY created_at DESC LIMIT 500")
|
||||
.fetch_all(&state.db).await {
|
||||
Ok(rows) => Json(ApiResponse::ok(serde_json::json!({"rules": rows.iter().map(|r| serde_json::json!({
|
||||
"id": r.get::<i64,_>("id"), "rule_type": r.get::<String,_>("rule_type"),
|
||||
@@ -47,6 +47,16 @@ pub async fn create_rule(State(state): State<AppState>, Json(req): Json<CreateRu
|
||||
if !has_filter {
|
||||
return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("at least one filter (window_title/window_class/process_name) required")));
|
||||
}
|
||||
// Length validation for filter fields
|
||||
if let Some(ref t) = req.window_title {
|
||||
if t.len() > 255 { return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("window_title too long (max 255)"))); }
|
||||
}
|
||||
if let Some(ref c) = req.window_class {
|
||||
if c.len() > 255 { return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("window_class too long (max 255)"))); }
|
||||
}
|
||||
if let Some(ref p) = req.process_name {
|
||||
if p.len() > 255 { return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("process_name too long (max 255)"))); }
|
||||
}
|
||||
|
||||
match sqlx::query("INSERT INTO popup_filter_rules (rule_type, window_title, window_class, process_name, target_type, target_id) VALUES (?,?,?,?,?,?)")
|
||||
.bind(&req.rule_type).bind(&req.window_title).bind(&req.window_class).bind(&req.process_name).bind(&target_type).bind(&req.target_id)
|
||||
@@ -81,6 +91,14 @@ pub async fn update_rule(State(state): State<AppState>, Path(id): Path<i64>, Jso
|
||||
let process_name = body.process_name.or_else(|| existing.get::<Option<String>, _>("process_name"));
|
||||
let enabled = body.enabled.unwrap_or_else(|| existing.get::<bool, _>("enabled"));
|
||||
|
||||
// Ensure at least one filter is non-empty after update
|
||||
let has_filter = window_title.as_ref().map_or(false, |s| !s.is_empty())
|
||||
|| window_class.as_ref().map_or(false, |s| !s.is_empty())
|
||||
|| process_name.as_ref().map_or(false, |s| !s.is_empty());
|
||||
if !has_filter {
|
||||
return Json(ApiResponse::error("at least one filter (window_title/window_class/process_name) required"));
|
||||
}
|
||||
|
||||
let result = sqlx::query("UPDATE popup_filter_rules SET window_title = ?, window_class = ?, process_name = ?, enabled = ? WHERE id = ?")
|
||||
.bind(&window_title)
|
||||
.bind(&window_class)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use axum::{extract::{State, Path, Query, Json}, http::StatusCode};
|
||||
use serde::Deserialize;
|
||||
use sqlx::Row;
|
||||
use csm_protocol::MessageType;
|
||||
use csm_protocol::{Frame, MessageType};
|
||||
use crate::AppState;
|
||||
use crate::api::ApiResponse;
|
||||
use crate::tcp::push_to_targets;
|
||||
@@ -16,7 +16,7 @@ pub struct CreateBlacklistRequest {
|
||||
}
|
||||
|
||||
pub async fn list_blacklist(State(state): State<AppState>) -> Json<ApiResponse<serde_json::Value>> {
|
||||
match sqlx::query("SELECT id, name_pattern, category, action, target_type, target_id, enabled, created_at FROM software_blacklist ORDER BY created_at DESC")
|
||||
match sqlx::query("SELECT id, name_pattern, category, action, target_type, target_id, enabled, created_at FROM software_blacklist ORDER BY created_at DESC LIMIT 500")
|
||||
.fetch_all(&state.db).await {
|
||||
Ok(rows) => Json(ApiResponse::ok(serde_json::json!({"blacklist": rows.iter().map(|r| serde_json::json!({
|
||||
"id": r.get::<i64,_>("id"), "name_pattern": r.get::<String,_>("name_pattern"),
|
||||
@@ -53,8 +53,8 @@ pub async fn add_to_blacklist(State(state): State<AppState>, Json(req): Json<Cre
|
||||
.execute(&state.db).await {
|
||||
Ok(r) => {
|
||||
let new_id = r.last_insert_rowid();
|
||||
let blacklist = fetch_blacklist_for_push(&state.db, &target_type, req.target_id.as_deref()).await;
|
||||
push_to_targets(&state.db, &state.clients, MessageType::SoftwareBlacklist, &serde_json::json!({"blacklist": blacklist}), &target_type, req.target_id.as_deref()).await;
|
||||
let payload = fetch_software_payload_for_push(&state.db, &target_type, req.target_id.as_deref()).await;
|
||||
push_to_targets(&state.db, &state.clients, MessageType::SoftwareBlacklist, &payload, &target_type, req.target_id.as_deref()).await;
|
||||
(StatusCode::CREATED, Json(ApiResponse::ok(serde_json::json!({"id": new_id}))))
|
||||
}
|
||||
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::internal_error("add software blacklist entry", e))),
|
||||
@@ -80,6 +80,14 @@ pub async fn update_blacklist(State(state): State<AppState>, Path(id): Path<i64>
|
||||
let action = body.action.unwrap_or_else(|| existing.get::<String, _>("action"));
|
||||
let enabled = body.enabled.unwrap_or_else(|| existing.get::<bool, _>("enabled"));
|
||||
|
||||
// Input validation (same as create)
|
||||
if name_pattern.trim().is_empty() || name_pattern.len() > 255 {
|
||||
return Json(ApiResponse::error("name_pattern must be 1-255 chars"));
|
||||
}
|
||||
if !matches!(action.as_str(), "block" | "alert") {
|
||||
return Json(ApiResponse::error("action must be 'block' or 'alert'"));
|
||||
}
|
||||
|
||||
let result = sqlx::query("UPDATE software_blacklist SET name_pattern = ?, action = ?, enabled = ? WHERE id = ?")
|
||||
.bind(&name_pattern)
|
||||
.bind(&action)
|
||||
@@ -92,8 +100,8 @@ pub async fn update_blacklist(State(state): State<AppState>, Path(id): Path<i64>
|
||||
Ok(r) if r.rows_affected() > 0 => {
|
||||
let target_type_val: String = existing.get("target_type");
|
||||
let target_id_val: Option<String> = existing.get("target_id");
|
||||
let blacklist = fetch_blacklist_for_push(&state.db, &target_type_val, target_id_val.as_deref()).await;
|
||||
push_to_targets(&state.db, &state.clients, MessageType::SoftwareBlacklist, &serde_json::json!({"blacklist": blacklist}), &target_type_val, target_id_val.as_deref()).await;
|
||||
let payload = fetch_software_payload_for_push(&state.db, &target_type_val, target_id_val.as_deref()).await;
|
||||
push_to_targets(&state.db, &state.clients, MessageType::SoftwareBlacklist, &payload, &target_type_val, target_id_val.as_deref()).await;
|
||||
Json(ApiResponse::ok(()))
|
||||
}
|
||||
Ok(_) => Json(ApiResponse::error("Not found")),
|
||||
@@ -110,8 +118,8 @@ pub async fn remove_from_blacklist(State(state): State<AppState>, Path(id): Path
|
||||
};
|
||||
match sqlx::query("DELETE FROM software_blacklist WHERE id=?").bind(id).execute(&state.db).await {
|
||||
Ok(r) if r.rows_affected() > 0 => {
|
||||
let blacklist = fetch_blacklist_for_push(&state.db, &target_type, target_id.as_deref()).await;
|
||||
push_to_targets(&state.db, &state.clients, MessageType::SoftwareBlacklist, &serde_json::json!({"blacklist": blacklist}), &target_type, target_id.as_deref()).await;
|
||||
let payload = fetch_software_payload_for_push(&state.db, &target_type, target_id.as_deref()).await;
|
||||
push_to_targets(&state.db, &state.clients, MessageType::SoftwareBlacklist, &payload, &target_type, target_id.as_deref()).await;
|
||||
Json(ApiResponse::ok(()))
|
||||
}
|
||||
_ => Json(ApiResponse::error("Not found")),
|
||||
@@ -134,6 +142,29 @@ pub async fn list_violations(State(state): State<AppState>, Query(f): Query<Viol
|
||||
}
|
||||
}
|
||||
|
||||
/// Build the payload for pushing software control config to clients.
|
||||
/// Includes both blacklist (scoped by target) and whitelist (global).
|
||||
async fn fetch_software_payload_for_push(
|
||||
db: &sqlx::SqlitePool,
|
||||
target_type: &str,
|
||||
target_id: Option<&str>,
|
||||
) -> serde_json::Value {
|
||||
let blacklist = fetch_blacklist_for_push(db, target_type, target_id).await;
|
||||
|
||||
// Whitelist is always global — fetch all enabled entries
|
||||
let whitelist: Vec<String> = sqlx::query_scalar(
|
||||
"SELECT name_pattern FROM software_whitelist WHERE enabled = 1"
|
||||
)
|
||||
.fetch_all(db)
|
||||
.await
|
||||
.unwrap_or_default();
|
||||
|
||||
serde_json::json!({
|
||||
"blacklist": blacklist,
|
||||
"whitelist": whitelist,
|
||||
})
|
||||
}
|
||||
|
||||
async fn fetch_blacklist_for_push(
|
||||
db: &sqlx::SqlitePool,
|
||||
target_type: &str,
|
||||
@@ -156,3 +187,112 @@ async fn fetch_blacklist_for_push(
|
||||
})).collect())
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
// ─── Whitelist management ───
|
||||
|
||||
/// GET /api/plugins/software-blocker/whitelist
|
||||
pub async fn list_whitelist(State(state): State<AppState>) -> Json<ApiResponse<serde_json::Value>> {
|
||||
match sqlx::query("SELECT id, name_pattern, reason, is_builtin, enabled, created_at FROM software_whitelist ORDER BY is_builtin DESC, created_at ASC LIMIT 500")
|
||||
.fetch_all(&state.db).await {
|
||||
Ok(rows) => Json(ApiResponse::ok(serde_json::json!({"whitelist": rows.iter().map(|r| serde_json::json!({
|
||||
"id": r.get::<i64,_>("id"),
|
||||
"name_pattern": r.get::<String,_>("name_pattern"),
|
||||
"reason": r.get::<Option<String>,_>("reason"),
|
||||
"is_builtin": r.get::<bool,_>("is_builtin"),
|
||||
"enabled": r.get::<bool,_>("enabled"),
|
||||
"created_at": r.get::<String,_>("created_at")
|
||||
})).collect::<Vec<_>>()}))),
|
||||
Err(e) => Json(ApiResponse::internal_error("query software whitelist", e)),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct CreateWhitelistRequest {
|
||||
pub name_pattern: String,
|
||||
pub reason: Option<String>,
|
||||
}
|
||||
|
||||
/// POST /api/plugins/software-blocker/whitelist
|
||||
pub async fn add_to_whitelist(State(state): State<AppState>, Json(req): Json<CreateWhitelistRequest>) -> (StatusCode, Json<ApiResponse<serde_json::Value>>) {
|
||||
if req.name_pattern.trim().is_empty() || req.name_pattern.len() > 255 {
|
||||
return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("name_pattern must be 1-255 chars")));
|
||||
}
|
||||
|
||||
match sqlx::query("INSERT INTO software_whitelist (name_pattern, reason) VALUES (?, ?)")
|
||||
.bind(&req.name_pattern).bind(&req.reason)
|
||||
.execute(&state.db).await {
|
||||
Ok(r) => {
|
||||
let new_id = r.last_insert_rowid();
|
||||
// Push updated whitelist to all online clients
|
||||
push_whitelist_to_all(&state).await;
|
||||
(StatusCode::CREATED, Json(ApiResponse::ok(serde_json::json!({"id": new_id}))))
|
||||
}
|
||||
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::internal_error("add whitelist entry", e))),
|
||||
}
|
||||
}
|
||||
|
||||
/// PUT /api/plugins/software-blocker/whitelist/:id
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct UpdateWhitelistRequest {
|
||||
pub name_pattern: Option<String>,
|
||||
pub enabled: Option<bool>,
|
||||
}
|
||||
|
||||
pub async fn update_whitelist(State(state): State<AppState>, Path(id): Path<i64>, Json(body): Json<UpdateWhitelistRequest>) -> Json<ApiResponse<()>> {
|
||||
let existing = sqlx::query("SELECT name_pattern, enabled FROM software_whitelist WHERE id = ?")
|
||||
.bind(id).fetch_optional(&state.db).await;
|
||||
|
||||
let existing = match existing {
|
||||
Ok(Some(row)) => row,
|
||||
Ok(None) => return Json(ApiResponse::error("Not found")),
|
||||
Err(e) => return Json(ApiResponse::internal_error("query whitelist", e)),
|
||||
};
|
||||
|
||||
let name_pattern = body.name_pattern.unwrap_or_else(|| existing.get::<String, _>("name_pattern"));
|
||||
let enabled = body.enabled.unwrap_or_else(|| existing.get::<bool, _>("enabled"));
|
||||
|
||||
// Input validation — validate merged value
|
||||
if name_pattern.trim().is_empty() || name_pattern.len() > 255 {
|
||||
return Json(ApiResponse::error("name_pattern must be 1-255 chars"));
|
||||
}
|
||||
|
||||
match sqlx::query("UPDATE software_whitelist SET name_pattern = ?, enabled = ? WHERE id = ?")
|
||||
.bind(&name_pattern).bind(enabled).bind(id)
|
||||
.execute(&state.db).await {
|
||||
Ok(r) if r.rows_affected() > 0 => {
|
||||
push_whitelist_to_all(&state).await;
|
||||
Json(ApiResponse::ok(()))
|
||||
}
|
||||
Ok(_) => Json(ApiResponse::error("Not found")),
|
||||
Err(e) => Json(ApiResponse::internal_error("update whitelist", e)),
|
||||
}
|
||||
}
|
||||
|
||||
/// DELETE /api/plugins/software-blocker/whitelist/:id
|
||||
pub async fn remove_from_whitelist(State(state): State<AppState>, Path(id): Path<i64>) -> Json<ApiResponse<()>> {
|
||||
match sqlx::query("DELETE FROM software_whitelist WHERE id = ? AND is_builtin = 0")
|
||||
.bind(id).execute(&state.db).await {
|
||||
Ok(r) if r.rows_affected() > 0 => {
|
||||
push_whitelist_to_all(&state).await;
|
||||
Json(ApiResponse::ok(()))
|
||||
}
|
||||
Ok(_) => Json(ApiResponse::error("Not found or is built-in entry")),
|
||||
Err(e) => Json(ApiResponse::internal_error("remove whitelist entry", e)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Push updated whitelist to all online clients by resending the full software control config.
|
||||
async fn push_whitelist_to_all(state: &AppState) {
|
||||
// Fetch payload once, then broadcast to all online clients
|
||||
let payload = fetch_software_payload_for_push(&state.db, "global", None).await;
|
||||
let frame = match Frame::new_json(MessageType::SoftwareBlacklist, &payload) {
|
||||
Ok(f) => f.encode(),
|
||||
Err(_) => return,
|
||||
};
|
||||
|
||||
let online = state.clients.list_online().await;
|
||||
for uid in &online {
|
||||
state.clients.send_to(uid, frame.clone()).await;
|
||||
}
|
||||
tracing::info!("Pushed updated whitelist to {} online clients", online.len());
|
||||
}
|
||||
@@ -16,7 +16,7 @@ pub struct CreateRuleRequest {
|
||||
}
|
||||
|
||||
pub async fn list_rules(State(state): State<AppState>) -> Json<ApiResponse<serde_json::Value>> {
|
||||
match sqlx::query("SELECT id, rule_type, pattern, target_type, target_id, enabled, created_at FROM web_filter_rules ORDER BY created_at DESC")
|
||||
match sqlx::query("SELECT id, rule_type, pattern, target_type, target_id, enabled, created_at FROM web_filter_rules ORDER BY created_at DESC LIMIT 500")
|
||||
.fetch_all(&state.db).await {
|
||||
Ok(rows) => Json(ApiResponse::ok(serde_json::json!({ "rules": rows.iter().map(|r| serde_json::json!({
|
||||
"id": r.get::<i64,_>("id"), "rule_type": r.get::<String,_>("rule_type"),
|
||||
@@ -75,6 +75,14 @@ pub async fn update_rule(State(state): State<AppState>, Path(id): Path<i64>, Jso
|
||||
let pattern = body.pattern.unwrap_or_else(|| existing.get::<String, _>("pattern"));
|
||||
let enabled = body.enabled.unwrap_or_else(|| existing.get::<bool, _>("enabled"));
|
||||
|
||||
// Validate merged values
|
||||
if !matches!(rule_type.as_str(), "blacklist" | "whitelist" | "category") {
|
||||
return Json(ApiResponse::error("rule_type must be 'blacklist', 'whitelist', or 'category'"));
|
||||
}
|
||||
if pattern.trim().is_empty() || pattern.len() > 255 {
|
||||
return Json(ApiResponse::error("pattern must be 1-255 chars"));
|
||||
}
|
||||
|
||||
let result = sqlx::query("UPDATE web_filter_rules SET rule_type = ?, pattern = ?, enabled = ? WHERE id = ?")
|
||||
.bind(&rule_type)
|
||||
.bind(&pattern)
|
||||
@@ -114,17 +122,42 @@ pub async fn delete_rule(State(state): State<AppState>, Path(id): Path<i64>) ->
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct LogFilters { pub device_uid: Option<String>, pub action: Option<String> }
|
||||
pub struct LogFilters {
|
||||
pub device_uid: Option<String>,
|
||||
pub action: Option<String>,
|
||||
pub page: Option<u32>,
|
||||
pub page_size: Option<u32>,
|
||||
}
|
||||
|
||||
pub async fn list_access_log(State(state): State<AppState>, Query(f): Query<LogFilters>) -> Json<ApiResponse<serde_json::Value>> {
|
||||
match sqlx::query("SELECT id, device_uid, url, action, timestamp FROM web_access_log WHERE (? IS NULL OR device_uid=?) AND (? IS NULL OR action=?) ORDER BY timestamp DESC LIMIT 200")
|
||||
.bind(&f.device_uid).bind(&f.device_uid).bind(&f.action).bind(&f.action)
|
||||
.fetch_all(&state.db).await {
|
||||
Ok(rows) => Json(ApiResponse::ok(serde_json::json!({"log": rows.iter().map(|r| serde_json::json!({
|
||||
"id": r.get::<i64,_>("id"), "device_uid": r.get::<String,_>("device_uid"),
|
||||
"url": r.get::<String,_>("url"), "action": r.get::<String,_>("action"),
|
||||
"timestamp": r.get::<String,_>("timestamp")
|
||||
})).collect::<Vec<_>>() }))),
|
||||
let limit = f.page_size.unwrap_or(20).min(100);
|
||||
let offset = f.page.unwrap_or(1).saturating_sub(1) * limit;
|
||||
let device_uid = f.device_uid.as_deref().filter(|s| !s.is_empty());
|
||||
let action = f.action.as_deref().filter(|s| !s.is_empty());
|
||||
|
||||
let rows = sqlx::query(
|
||||
"SELECT id, device_uid, url, action, timestamp FROM web_access_log \
|
||||
WHERE (? IS NULL OR device_uid=?) AND (? IS NULL OR action=?) \
|
||||
ORDER BY timestamp DESC LIMIT ? OFFSET ?"
|
||||
)
|
||||
.bind(device_uid).bind(device_uid)
|
||||
.bind(action).bind(action)
|
||||
.bind(limit).bind(offset)
|
||||
.fetch_all(&state.db)
|
||||
.await;
|
||||
|
||||
match rows {
|
||||
Ok(records) => Json(ApiResponse::ok(serde_json::json!({
|
||||
"log": records.iter().map(|r| serde_json::json!({
|
||||
"id": r.get::<i64,_>("id"),
|
||||
"device_uid": r.get::<String,_>("device_uid"),
|
||||
"url": r.get::<String,_>("url"),
|
||||
"action": r.get::<String,_>("action"),
|
||||
"timestamp": r.get::<String,_>("timestamp")
|
||||
})).collect::<Vec<_>>(),
|
||||
"page": f.page.unwrap_or(1),
|
||||
"page_size": limit,
|
||||
}))),
|
||||
Err(e) => Json(ApiResponse::internal_error("query web access log", e)),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -66,7 +66,7 @@ pub async fn list_policies(
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let rows = sqlx::query(
|
||||
"SELECT id, name, policy_type, target_group, rules, enabled, created_at, updated_at
|
||||
FROM usb_policies ORDER BY created_at DESC"
|
||||
FROM usb_policies ORDER BY created_at DESC LIMIT 500"
|
||||
)
|
||||
.fetch_all(&state.db)
|
||||
.await;
|
||||
@@ -106,6 +106,11 @@ pub async fn create_policy(
|
||||
) -> (StatusCode, Json<ApiResponse<serde_json::Value>>) {
|
||||
let enabled = body.enabled.unwrap_or(1);
|
||||
|
||||
// Input validation
|
||||
if body.name.trim().is_empty() || body.name.len() > 100 {
|
||||
return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("name must be 1-100 chars")));
|
||||
}
|
||||
|
||||
let result = sqlx::query(
|
||||
"INSERT INTO usb_policies (name, policy_type, target_group, rules, enabled) VALUES (?, ?, ?, ?, ?)"
|
||||
)
|
||||
|
||||
315
crates/server/src/health.rs
Normal file
315
crates/server/src/health.rs
Normal file
@@ -0,0 +1,315 @@
|
||||
use crate::AppState;
|
||||
use sqlx::Row;
|
||||
use tracing::{info, error};
|
||||
|
||||
/// Background task: recompute device health scores every 5 minutes
|
||||
pub async fn health_score_task(state: AppState) {
|
||||
let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(300));
|
||||
// First computation runs immediately, then every 5 minutes
|
||||
|
||||
loop {
|
||||
interval.tick().await;
|
||||
if let Err(e) = recompute_all_scores(&state).await {
|
||||
error!("Health score computation failed: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn recompute_all_scores(state: &AppState) -> anyhow::Result<()> {
|
||||
// Get all device UIDs
|
||||
let devices: Vec<String> = sqlx::query_scalar(
|
||||
"SELECT device_uid FROM devices"
|
||||
)
|
||||
.fetch_all(&state.db)
|
||||
.await?;
|
||||
|
||||
let mut computed = 0u32;
|
||||
let mut errors = 0u32;
|
||||
|
||||
for uid in &devices {
|
||||
match compute_and_store_score(&state.db, uid).await {
|
||||
Ok(score) => {
|
||||
computed += 1;
|
||||
tracing::debug!("Health score for {}: {} ({})", uid, score.score, score.level);
|
||||
}
|
||||
Err(e) => {
|
||||
errors += 1;
|
||||
error!("Failed to compute health score for {}: {}", uid, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if computed > 0 {
|
||||
info!("Health scores computed: {} devices, {} errors", computed, errors);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
struct HealthScoreResult {
|
||||
score: i32,
|
||||
status_score: i32,
|
||||
encryption_score: i32,
|
||||
load_score: i32,
|
||||
alert_score: i32,
|
||||
compliance_score: i32,
|
||||
patch_score: i32,
|
||||
level: String,
|
||||
details: String,
|
||||
}
|
||||
|
||||
async fn compute_and_store_score(
|
||||
pool: &sqlx::SqlitePool,
|
||||
device_uid: &str,
|
||||
) -> anyhow::Result<HealthScoreResult> {
|
||||
let mut details = Vec::new();
|
||||
|
||||
// 1. Online status (15 points)
|
||||
let status_score: i32 = sqlx::query_scalar(
|
||||
"SELECT CASE WHEN status = 'online' THEN 15 ELSE 0 END FROM devices WHERE device_uid = ?"
|
||||
)
|
||||
.bind(device_uid)
|
||||
.fetch_one(pool)
|
||||
.await
|
||||
.unwrap_or(0);
|
||||
|
||||
if status_score < 15 {
|
||||
details.push("设备离线".to_string());
|
||||
}
|
||||
|
||||
// 2. Disk encryption (20 points)
|
||||
let encryption_score: i32 = sqlx::query_scalar(
|
||||
"SELECT CASE \
|
||||
WHEN COUNT(*) = 0 THEN 10 \
|
||||
WHEN SUM(CASE WHEN protection_status = 'On' THEN 1 ELSE 0 END) = COUNT(*) THEN 20 \
|
||||
WHEN SUM(CASE WHEN protection_status = 'On' THEN 1 ELSE 0 END) > 0 THEN 10 \
|
||||
ELSE 0 END \
|
||||
FROM disk_encryption_status WHERE device_uid = ?"
|
||||
)
|
||||
.bind(device_uid)
|
||||
.fetch_one(pool)
|
||||
.await
|
||||
.unwrap_or(0);
|
||||
|
||||
if encryption_score < 20 {
|
||||
let unencrypted: Vec<String> = sqlx::query_scalar(
|
||||
"SELECT drive_letter FROM disk_encryption_status WHERE device_uid = ? AND protection_status != 'On'"
|
||||
)
|
||||
.bind(device_uid)
|
||||
.fetch_all(pool)
|
||||
.await
|
||||
.unwrap_or_default();
|
||||
if unencrypted.is_empty() && encryption_score < 20 {
|
||||
details.push("未检测到加密状态".to_string());
|
||||
} else if !unencrypted.is_empty() {
|
||||
details.push(format!("未加密驱动器: {}", unencrypted.join(", ")));
|
||||
}
|
||||
}
|
||||
|
||||
// 3. System load (20 points): CPU(7) + Memory(7) + Disk(6)
|
||||
let load_row = sqlx::query(
|
||||
"SELECT cpu_usage, memory_usage, disk_usage FROM device_status WHERE device_uid = ?"
|
||||
)
|
||||
.bind(device_uid)
|
||||
.fetch_optional(pool)
|
||||
.await?;
|
||||
|
||||
let load_score = if let Some(row) = load_row {
|
||||
let cpu = row.get::<f64, _>("cpu_usage");
|
||||
let mem = row.get::<f64, _>("memory_usage");
|
||||
let disk = row.get::<f64, _>("disk_usage");
|
||||
|
||||
let cpu_pts = if cpu < 70.0 { 7 } else if cpu < 90.0 { 4 } else { 0 };
|
||||
let mem_pts = if mem < 80.0 { 7 } else if mem < 95.0 { 4 } else { 0 };
|
||||
let disk_pts = if disk < 80.0 { 6 } else if disk < 95.0 { 3 } else { 0 };
|
||||
|
||||
let total = cpu_pts + mem_pts + disk_pts;
|
||||
if cpu >= 90.0 { details.push(format!("CPU过高 ({:.0}%)", cpu)); }
|
||||
if mem >= 95.0 { details.push(format!("内存过高 ({:.0}%)", mem)); }
|
||||
if disk >= 95.0 { details.push(format!("磁盘空间不足 ({:.0}%)", disk)); }
|
||||
total
|
||||
} else {
|
||||
details.push("无状态数据".to_string());
|
||||
0
|
||||
};
|
||||
|
||||
// 4. Alert clearance (15 points)
|
||||
let unhandled_alerts: i64 = sqlx::query_scalar(
|
||||
"SELECT COUNT(*) FROM alert_records WHERE device_uid = ? AND handled = 0"
|
||||
)
|
||||
.bind(device_uid)
|
||||
.fetch_one(pool)
|
||||
.await
|
||||
.unwrap_or(0);
|
||||
|
||||
let alert_score: i32 = if unhandled_alerts == 0 { 15 } else { 0 };
|
||||
if unhandled_alerts > 0 {
|
||||
details.push(format!("{}条未处理告警", unhandled_alerts));
|
||||
}
|
||||
|
||||
// 5. Compliance (10 points): no recent software violations
|
||||
let recent_violations: i64 = sqlx::query_scalar(
|
||||
"SELECT COUNT(*) FROM software_violations WHERE device_uid = ? AND timestamp > datetime('now', '-7 days')"
|
||||
)
|
||||
.bind(device_uid)
|
||||
.fetch_one(pool)
|
||||
.await
|
||||
.unwrap_or(0);
|
||||
|
||||
let compliance_score: i32 = if recent_violations == 0 { 10 } else {
|
||||
details.push(format!("近期{}次软件违规", recent_violations));
|
||||
(10 - (recent_violations as i32).min(10)).max(0)
|
||||
};
|
||||
|
||||
// 6. Patch status (20 points): reserved for future patch management
|
||||
// For now, give full score if device is online
|
||||
let patch_score: i32 = if status_score > 0 { 20 } else { 10 };
|
||||
|
||||
let score = status_score + encryption_score + load_score + alert_score + compliance_score + patch_score;
|
||||
let level = if score >= 80 {
|
||||
"healthy"
|
||||
} else if score >= 50 {
|
||||
"warning"
|
||||
} else if score > 0 {
|
||||
"critical"
|
||||
} else {
|
||||
"unknown"
|
||||
};
|
||||
|
||||
let details_json = if details.is_empty() {
|
||||
"[]".to_string()
|
||||
} else {
|
||||
serde_json::to_string(&details).unwrap_or_else(|_| "[]".to_string())
|
||||
};
|
||||
|
||||
let result = HealthScoreResult {
|
||||
score,
|
||||
status_score,
|
||||
encryption_score,
|
||||
load_score,
|
||||
alert_score,
|
||||
compliance_score,
|
||||
patch_score,
|
||||
level: level.to_string(),
|
||||
details: details_json,
|
||||
};
|
||||
|
||||
// Upsert the score
|
||||
sqlx::query(
|
||||
"INSERT INTO device_health_scores \
|
||||
(device_uid, score, status_score, encryption_score, load_score, alert_score, compliance_score, patch_score, level, details, computed_at) \
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, datetime('now')) \
|
||||
ON CONFLICT(device_uid) DO UPDATE SET \
|
||||
score = excluded.score, status_score = excluded.status_score, \
|
||||
encryption_score = excluded.encryption_score, load_score = excluded.load_score, \
|
||||
alert_score = excluded.alert_score, compliance_score = excluded.compliance_score, \
|
||||
patch_score = excluded.patch_score, level = excluded.level, \
|
||||
details = excluded.details, computed_at = datetime('now')"
|
||||
)
|
||||
.bind(device_uid)
|
||||
.bind(result.score)
|
||||
.bind(result.status_score)
|
||||
.bind(result.encryption_score)
|
||||
.bind(result.load_score)
|
||||
.bind(result.alert_score)
|
||||
.bind(result.compliance_score)
|
||||
.bind(result.patch_score)
|
||||
.bind(&result.level)
|
||||
.bind(&result.details)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Compute a single device's health score on demand
|
||||
pub async fn get_device_score(
|
||||
pool: &sqlx::SqlitePool,
|
||||
device_uid: &str,
|
||||
) -> anyhow::Result<Option<serde_json::Value>> {
|
||||
// Try to get cached score
|
||||
let row = sqlx::query(
|
||||
"SELECT score, status_score, encryption_score, load_score, alert_score, compliance_score, \
|
||||
patch_score, level, details, computed_at \
|
||||
FROM device_health_scores WHERE device_uid = ?"
|
||||
)
|
||||
.bind(device_uid)
|
||||
.fetch_optional(pool)
|
||||
.await?;
|
||||
|
||||
match row {
|
||||
Some(r) => Ok(Some(serde_json::json!({
|
||||
"device_uid": device_uid,
|
||||
"score": r.get::<i32, _>("score"),
|
||||
"breakdown": {
|
||||
"status": r.get::<i32, _>("status_score"),
|
||||
"encryption": r.get::<i32, _>("encryption_score"),
|
||||
"load": r.get::<i32, _>("load_score"),
|
||||
"alerts": r.get::<i32, _>("alert_score"),
|
||||
"compliance": r.get::<i32, _>("compliance_score"),
|
||||
"patches": r.get::<i32, _>("patch_score"),
|
||||
},
|
||||
"level": r.get::<String, _>("level"),
|
||||
"details": serde_json::from_str::<serde_json::Value>(
|
||||
&r.get::<String, _>("details")
|
||||
).unwrap_or(serde_json::json!([])),
|
||||
"computed_at": r.get::<String, _>("computed_at"),
|
||||
}))),
|
||||
None => Ok(None),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get health overview for all devices (dashboard aggregation)
|
||||
pub async fn get_health_overview(pool: &sqlx::SqlitePool) -> anyhow::Result<serde_json::Value> {
|
||||
let rows = sqlx::query(
|
||||
"SELECT h.device_uid, h.score, h.level, d.hostname, d.status, d.group_name \
|
||||
FROM device_health_scores h \
|
||||
JOIN devices d ON d.device_uid = h.device_uid \
|
||||
ORDER BY h.score ASC"
|
||||
)
|
||||
.fetch_all(pool)
|
||||
.await?;
|
||||
|
||||
let mut healthy = 0u32;
|
||||
let mut warning = 0u32;
|
||||
let mut critical = 0u32;
|
||||
let mut unknown = 0u32;
|
||||
let mut total_score = 0i64;
|
||||
|
||||
let mut devices: Vec<serde_json::Value> = Vec::with_capacity(rows.len());
|
||||
|
||||
for r in &rows {
|
||||
let level: String = r.get("level");
|
||||
match level.as_str() {
|
||||
"healthy" => healthy += 1,
|
||||
"warning" => warning += 1,
|
||||
"critical" => critical += 1,
|
||||
_ => unknown += 1,
|
||||
}
|
||||
total_score += r.get::<i32, _>("score") as i64;
|
||||
|
||||
devices.push(serde_json::json!({
|
||||
"device_uid": r.get::<String, _>("device_uid"),
|
||||
"hostname": r.get::<String, _>("hostname"),
|
||||
"status": r.get::<String, _>("status"),
|
||||
"group_name": r.get::<String, _>("group_name"),
|
||||
"score": r.get::<i32, _>("score"),
|
||||
"level": level,
|
||||
}));
|
||||
}
|
||||
|
||||
let total = devices.len().max(1);
|
||||
let avg_score = total_score as f64 / total as f64;
|
||||
|
||||
Ok(serde_json::json!({
|
||||
"summary": {
|
||||
"total": total,
|
||||
"healthy": healthy,
|
||||
"warning": warning,
|
||||
"critical": critical,
|
||||
"unknown": unknown,
|
||||
"avg_score": (avg_score * 10.0).round() / 10.0,
|
||||
},
|
||||
"devices": devices,
|
||||
}))
|
||||
}
|
||||
@@ -7,6 +7,7 @@ use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions, SqliteJournalMode};
|
||||
use std::path::Path;
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
use std::collections::HashMap;
|
||||
use tokio::net::TcpListener;
|
||||
use axum::http::Method as HttpMethod;
|
||||
use tower_http::cors::CorsLayer;
|
||||
@@ -23,6 +24,8 @@ mod db;
|
||||
mod tcp;
|
||||
mod ws;
|
||||
mod alert;
|
||||
mod health;
|
||||
mod anomaly;
|
||||
|
||||
use config::AppConfig;
|
||||
|
||||
@@ -38,6 +41,7 @@ pub struct AppState {
|
||||
pub clients: Arc<tcp::ClientRegistry>,
|
||||
pub ws_hub: Arc<ws::WsHub>,
|
||||
pub login_limiter: Arc<api::auth::LoginRateLimiter>,
|
||||
pub ws_tickets: Arc<tokio::sync::Mutex<HashMap<String, ws::TicketClaim>>>,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
@@ -58,15 +62,19 @@ async fn main() -> Result<()> {
|
||||
|
||||
// Security checks
|
||||
if config.registration_token.is_empty() {
|
||||
warn!("SECURITY: registration_token is empty — any device can register!");
|
||||
anyhow::bail!("FATAL: registration_token is empty. Set it in config.toml or via CSM_REGISTRATION_TOKEN env var. Device registration is disabled for security.");
|
||||
}
|
||||
if config.auth.jwt_secret.len() < 32 {
|
||||
warn!("SECURITY: jwt_secret is too short ({} chars) — consider using a 32+ byte key from CSM_JWT_SECRET env var", config.auth.jwt_secret.len());
|
||||
if config.auth.jwt_secret.is_empty() || config.auth.jwt_secret.len() < 32 {
|
||||
anyhow::bail!("FATAL: jwt_secret is missing or too short. Set CSM_JWT_SECRET env var with a 32+ byte random key.");
|
||||
}
|
||||
if config.server.tls.is_none() {
|
||||
warn!("SECURITY: No TLS configured — all TCP communication is plaintext. Configure [server.tls] for production.");
|
||||
if std::env::var("CSM_DEV").is_err() {
|
||||
warn!("Set CSM_DEV=1 to suppress this warning in development environments.");
|
||||
}
|
||||
}
|
||||
|
||||
let config = Arc::new(config);
|
||||
|
||||
// Initialize database
|
||||
let db = init_database(&config.database.path).await?;
|
||||
run_migrations(&db).await?;
|
||||
info!("Database initialized at {}", config.database.path);
|
||||
@@ -84,6 +92,7 @@ async fn main() -> Result<()> {
|
||||
clients: clients.clone(),
|
||||
ws_hub: ws_hub.clone(),
|
||||
login_limiter: Arc::new(api::auth::LoginRateLimiter::new()),
|
||||
ws_tickets: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
|
||||
};
|
||||
|
||||
// Start background tasks
|
||||
@@ -92,6 +101,12 @@ async fn main() -> Result<()> {
|
||||
alert::cleanup_task(cleanup_state).await;
|
||||
});
|
||||
|
||||
// Health score computation task
|
||||
let health_state = state.clone();
|
||||
tokio::spawn(async move {
|
||||
health::health_score_task(health_state).await;
|
||||
});
|
||||
|
||||
// Start TCP listener for client connections
|
||||
let tcp_state = state.clone();
|
||||
let tcp_addr = config.server.tcp_addr.clone();
|
||||
@@ -131,7 +146,11 @@ async fn main() -> Result<()> {
|
||||
))
|
||||
.layer(SetResponseHeaderLayer::if_not_present(
|
||||
axum::http::header::HeaderName::from_static("content-security-policy"),
|
||||
axum::http::HeaderValue::from_static("default-src 'self'; script-src 'self' 'unsafe-inline' 'unsafe-eval'; style-src 'self' 'unsafe-inline'; img-src 'self' data:; connect-src 'self' ws: wss:"),
|
||||
axum::http::HeaderValue::from_static("default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline'; img-src 'self' data:; connect-src 'self' wss:; font-src 'self'; frame-ancestors 'none'; base-uri 'self'; form-action 'self'"),
|
||||
))
|
||||
.layer(SetResponseHeaderLayer::if_not_present(
|
||||
axum::http::header::HeaderName::from_static("permissions-policy"),
|
||||
axum::http::HeaderValue::from_static("camera=(), microphone=(), geolocation=(), payment=()"),
|
||||
))
|
||||
.with_state(state);
|
||||
|
||||
@@ -197,6 +216,9 @@ async fn run_migrations(pool: &sqlx::SqlitePool) -> Result<()> {
|
||||
include_str!("../../../migrations/014_clipboard_control.sql"),
|
||||
include_str!("../../../migrations/015_plugin_control.sql"),
|
||||
include_str!("../../../migrations/016_encryption_alerts_unique.sql"),
|
||||
include_str!("../../../migrations/017_device_health_scores.sql"),
|
||||
include_str!("../../../migrations/018_patch_management.sql"),
|
||||
include_str!("../../../migrations/019_software_whitelist.sql"),
|
||||
];
|
||||
|
||||
// Create migrations tracking table
|
||||
@@ -257,11 +279,27 @@ async fn ensure_default_admin(pool: &sqlx::SqlitePool) -> Result<()> {
|
||||
.await?;
|
||||
|
||||
warn!("Created default admin user (username: admin)");
|
||||
// Print password directly to stderr — bypasses tracing JSON formatter
|
||||
eprintln!("============================================================");
|
||||
eprintln!(" Generated admin password: {}", random_password);
|
||||
eprintln!(" *** Save this password now — it will NOT be shown again! ***");
|
||||
eprintln!("============================================================");
|
||||
// Write password to restricted file instead of stderr (avoid log capture)
|
||||
let pw_path = std::path::Path::new("data/initial-password.txt");
|
||||
if let Some(parent) = pw_path.parent() {
|
||||
let _ = std::fs::create_dir_all(parent);
|
||||
}
|
||||
match std::fs::write(pw_path, &random_password) {
|
||||
Ok(_) => {
|
||||
warn!("Initial admin password saved to data/initial-password.txt (delete after first login)");
|
||||
#[cfg(unix)]
|
||||
{
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
let _ = std::fs::set_permissions(pw_path, std::fs::Permissions::from_mode(0o600));
|
||||
}
|
||||
#[cfg(not(unix))]
|
||||
{
|
||||
// Windows: restrict ACL would require windows-rs; at minimum hide the file
|
||||
let _ = std::process::Command::new("attrib").args(["+H", &pw_path.to_string_lossy()]).output();
|
||||
}
|
||||
}
|
||||
Err(e) => warn!("Failed to save initial password to file: {}. Password was: {}", e, random_password),
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
@@ -278,13 +316,14 @@ fn build_cors_layer(origins: &[String]) -> CorsLayer {
|
||||
.collect();
|
||||
|
||||
if allowed_origins.is_empty() {
|
||||
// No CORS — production safe by default
|
||||
// No CORS — production safe by default (same-origin cookies work without CORS)
|
||||
CorsLayer::new()
|
||||
} else {
|
||||
CorsLayer::new()
|
||||
.allow_origin(tower_http::cors::AllowOrigin::list(allowed_origins))
|
||||
.allow_methods([HttpMethod::GET, HttpMethod::POST, HttpMethod::PUT, HttpMethod::DELETE])
|
||||
.allow_headers([axum::http::header::AUTHORIZATION, axum::http::header::CONTENT_TYPE])
|
||||
.allow_credentials(true)
|
||||
.max_age(std::time::Duration::from_secs(3600))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@ use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicU32, Ordering};
|
||||
use std::time::Instant;
|
||||
use tokio::sync::RwLock;
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
use tokio::net::TcpListener;
|
||||
use tracing::{info, warn, debug};
|
||||
use hmac::{Hmac, Mac};
|
||||
use sha2::Sha256;
|
||||
@@ -167,7 +167,7 @@ pub async fn push_all_plugin_configs(
|
||||
}
|
||||
}
|
||||
|
||||
// Software blacklist
|
||||
// Software blacklist + whitelist
|
||||
if let Ok(rows) = sqlx::query(
|
||||
"SELECT id, name_pattern, action FROM software_blacklist WHERE enabled = 1 AND (target_type = 'global' OR (target_type = 'device' AND target_id = ?) OR (target_type = 'group' AND target_id = (SELECT group_name FROM devices WHERE device_uid = ?)))"
|
||||
)
|
||||
@@ -180,8 +180,20 @@ pub async fn push_all_plugin_configs(
|
||||
"name_pattern": r.get::<String, _>("name_pattern"),
|
||||
"action": r.get::<String, _>("action"),
|
||||
})).collect();
|
||||
if !entries.is_empty() {
|
||||
if let Ok(frame) = Frame::new_json(MessageType::SoftwareBlacklist, &serde_json::json!({"blacklist": entries})) {
|
||||
|
||||
// Fetch whitelist (global, always pushed to all devices)
|
||||
let whitelist: Vec<String> = sqlx::query_scalar(
|
||||
"SELECT name_pattern FROM software_whitelist WHERE enabled = 1"
|
||||
)
|
||||
.fetch_all(db)
|
||||
.await
|
||||
.unwrap_or_default();
|
||||
|
||||
if !entries.is_empty() || !whitelist.is_empty() {
|
||||
if let Ok(frame) = Frame::new_json(MessageType::SoftwareBlacklist, &serde_json::json!({
|
||||
"blacklist": entries,
|
||||
"whitelist": whitelist,
|
||||
})) {
|
||||
clients.send_to(device_uid, frame.encode()).await;
|
||||
}
|
||||
}
|
||||
@@ -261,17 +273,53 @@ pub async fn push_all_plugin_configs(
|
||||
}
|
||||
}
|
||||
|
||||
// Disk encryption config — push default reporting interval (no dedicated config table)
|
||||
// Disk encryption config — read from patch_policies if available, else default
|
||||
{
|
||||
let config = csm_protocol::DiskEncryptionConfigPayload {
|
||||
enabled: true,
|
||||
report_interval_secs: 3600,
|
||||
let config = if let Ok(Some(row)) = sqlx::query(
|
||||
"SELECT auto_approve, enabled FROM patch_policies WHERE target_type = 'global' AND enabled = 1 LIMIT 1"
|
||||
)
|
||||
.fetch_optional(db)
|
||||
.await
|
||||
{
|
||||
// If patch_policies exist, infer disk encryption should be enabled
|
||||
csm_protocol::DiskEncryptionConfigPayload {
|
||||
enabled: row.get::<i32, _>("enabled") != 0,
|
||||
report_interval_secs: 3600,
|
||||
}
|
||||
} else {
|
||||
csm_protocol::DiskEncryptionConfigPayload {
|
||||
enabled: true,
|
||||
report_interval_secs: 3600,
|
||||
}
|
||||
};
|
||||
if let Ok(frame) = Frame::new_json(MessageType::DiskEncryptionConfig, &config) {
|
||||
clients.send_to(device_uid, frame.encode()).await;
|
||||
}
|
||||
}
|
||||
|
||||
// Patch scan config — read from patch_policies if available, else default
|
||||
{
|
||||
let config = if let Ok(Some(row)) = sqlx::query(
|
||||
"SELECT enabled FROM patch_policies WHERE target_type = 'global' AND enabled = 1 LIMIT 1"
|
||||
)
|
||||
.fetch_optional(db)
|
||||
.await
|
||||
{
|
||||
csm_protocol::PatchScanConfigPayload {
|
||||
enabled: row.get::<i32, _>("enabled") != 0,
|
||||
scan_interval_secs: 43200,
|
||||
}
|
||||
} else {
|
||||
csm_protocol::PatchScanConfigPayload {
|
||||
enabled: true,
|
||||
scan_interval_secs: 43200,
|
||||
}
|
||||
};
|
||||
if let Ok(frame) = Frame::new_json(MessageType::PatchScanConfig, &config) {
|
||||
clients.send_to(device_uid, frame.encode()).await;
|
||||
}
|
||||
}
|
||||
|
||||
// Push plugin enable/disable state — disable any plugins that admin has turned off
|
||||
if let Ok(rows) = sqlx::query(
|
||||
"SELECT plugin_name FROM plugin_state WHERE enabled = 0"
|
||||
@@ -297,10 +345,17 @@ pub async fn push_all_plugin_configs(
|
||||
/// Maximum accumulated read buffer size per connection (8 MB)
|
||||
const MAX_READ_BUF_SIZE: usize = 8 * 1024 * 1024;
|
||||
|
||||
/// Registry of all connected client sessions
|
||||
/// Registry of all connected client sessions, including cached device secrets.
|
||||
#[derive(Clone, Default)]
|
||||
pub struct ClientRegistry {
|
||||
sessions: Arc<RwLock<HashMap<String, Arc<tokio::sync::mpsc::Sender<Vec<u8>>>>>>,
|
||||
sessions: Arc<RwLock<HashMap<String, ClientSession>>>,
|
||||
}
|
||||
|
||||
/// Per-device session data kept in memory for fast access.
|
||||
struct ClientSession {
|
||||
tx: Arc<tokio::sync::mpsc::Sender<Vec<u8>>>,
|
||||
/// Cached device_secret for HMAC verification — avoids a DB query per heartbeat.
|
||||
secret: Option<String>,
|
||||
}
|
||||
|
||||
impl ClientRegistry {
|
||||
@@ -308,8 +363,8 @@ impl ClientRegistry {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
pub async fn register(&self, device_uid: String, tx: Arc<tokio::sync::mpsc::Sender<Vec<u8>>>) {
|
||||
self.sessions.write().await.insert(device_uid, tx);
|
||||
pub async fn register(&self, device_uid: String, secret: Option<String>, tx: Arc<tokio::sync::mpsc::Sender<Vec<u8>>>) {
|
||||
self.sessions.write().await.insert(device_uid, ClientSession { tx, secret });
|
||||
}
|
||||
|
||||
pub async fn unregister(&self, device_uid: &str) {
|
||||
@@ -317,13 +372,25 @@ impl ClientRegistry {
|
||||
}
|
||||
|
||||
pub async fn send_to(&self, device_uid: &str, data: Vec<u8>) -> bool {
|
||||
if let Some(tx) = self.sessions.read().await.get(device_uid) {
|
||||
tx.send(data).await.is_ok()
|
||||
if let Some(session) = self.sessions.read().await.get(device_uid) {
|
||||
session.tx.send(data).await.is_ok()
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// Get cached device secret for HMAC verification (avoids DB query per heartbeat).
|
||||
pub async fn get_secret(&self, device_uid: &str) -> Option<String> {
|
||||
self.sessions.read().await.get(device_uid).and_then(|s| s.secret.clone())
|
||||
}
|
||||
|
||||
/// Backfill cached device secret after a cache miss (e.g. server restart).
|
||||
pub async fn set_secret(&self, device_uid: &str, secret: String) {
|
||||
if let Some(session) = self.sessions.write().await.get_mut(device_uid) {
|
||||
session.secret = Some(secret);
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn count(&self) -> usize {
|
||||
self.sessions.read().await.len()
|
||||
}
|
||||
@@ -366,7 +433,7 @@ pub async fn start_tcp_server(addr: String, state: AppState) -> anyhow::Result<(
|
||||
Some(acceptor) => {
|
||||
match acceptor.accept(stream).await {
|
||||
Ok(tls_stream) => {
|
||||
if let Err(e) = handle_client_tls(tls_stream, state).await {
|
||||
if let Err(e) = handle_client(tls_stream, state).await {
|
||||
warn!("Client {} TLS error: {}", peer_addr, e);
|
||||
}
|
||||
}
|
||||
@@ -451,6 +518,17 @@ fn verify_device_uid(device_uid: &Option<String>, msg_type: &str, claimed_uid: &
|
||||
}
|
||||
}
|
||||
|
||||
/// Constant-time string comparison to prevent timing attacks on secrets.
|
||||
fn constant_time_eq(a: &str, b: &str) -> bool {
|
||||
use std::iter;
|
||||
if a.len() != b.len() {
|
||||
// Still do a comparison to avoid leaking length via timing
|
||||
let _ = a.as_bytes().iter().zip(iter::repeat(0u8)).map(|(x, y)| x ^ y);
|
||||
return false;
|
||||
}
|
||||
a.as_bytes().iter().zip(b.as_bytes()).fold(0u8, |acc, (x, y)| acc | (x ^ y)) == 0
|
||||
}
|
||||
|
||||
/// Process a single decoded frame. Shared by both plaintext and TLS handlers.
|
||||
/// `hmac_fail_count` tracks consecutive HMAC failures; caller checks it for disconnect threshold.
|
||||
async fn process_frame(
|
||||
@@ -467,10 +545,10 @@ async fn process_frame(
|
||||
|
||||
info!("Device registration attempt: {} ({})", req.hostname, req.device_uid);
|
||||
|
||||
// Validate registration token against configured token
|
||||
// Validate registration token against configured token (constant-time comparison)
|
||||
let expected_token = &state.config.registration_token;
|
||||
if !expected_token.is_empty() {
|
||||
if req.registration_token.is_empty() || req.registration_token != *expected_token {
|
||||
if req.registration_token.is_empty() || !constant_time_eq(&req.registration_token, expected_token) {
|
||||
warn!("Registration rejected for {}: invalid token", req.device_uid);
|
||||
let err_frame = Frame::new_json(MessageType::RegisterResponse,
|
||||
&serde_json::json!({"error": "invalid_registration_token"}))?;
|
||||
@@ -514,7 +592,7 @@ async fn process_frame(
|
||||
*device_uid = Some(req.device_uid.clone());
|
||||
// If this device was already connected on a different session, evict the old one
|
||||
// The new register() call will replace it in the hashmap
|
||||
state.clients.register(req.device_uid.clone(), tx.clone()).await;
|
||||
state.clients.register(req.device_uid.clone(), Some(device_secret.clone()), tx.clone()).await;
|
||||
|
||||
// Send registration response
|
||||
let config = csm_protocol::ClientConfig::default();
|
||||
@@ -539,17 +617,25 @@ async fn process_frame(
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Verify HMAC — reject if secret exists but HMAC is missing or wrong
|
||||
let secret: Option<String> = sqlx::query_scalar(
|
||||
"SELECT device_secret FROM devices WHERE device_uid = ?"
|
||||
)
|
||||
.bind(&heartbeat.device_uid)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
warn!("DB error fetching device_secret for {}: {}", heartbeat.device_uid, e);
|
||||
anyhow::anyhow!("DB error during HMAC verification")
|
||||
})?;
|
||||
// Verify HMAC — use cached secret from ClientRegistry, fall back to DB on cache miss (e.g. after restart)
|
||||
let mut secret = state.clients.get_secret(&heartbeat.device_uid).await;
|
||||
if secret.is_none() {
|
||||
// Cache miss (server restarted) — query DB and backfill cache
|
||||
let db_secret: Option<String> = sqlx::query_scalar(
|
||||
"SELECT device_secret FROM devices WHERE device_uid = ?"
|
||||
)
|
||||
.bind(&heartbeat.device_uid)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
warn!("DB error fetching device_secret for {}: {}", heartbeat.device_uid, e);
|
||||
anyhow::anyhow!("DB error during HMAC verification")
|
||||
})?;
|
||||
if let Some(ref s) = db_secret {
|
||||
state.clients.set_secret(&heartbeat.device_uid, s.clone()).await;
|
||||
}
|
||||
secret = db_secret;
|
||||
}
|
||||
|
||||
if let Some(ref secret) = secret {
|
||||
if !secret.is_empty() {
|
||||
@@ -650,6 +736,40 @@ async fn process_frame(
|
||||
crate::db::DeviceRepo::upsert_software(&state.db, &sw).await?;
|
||||
}
|
||||
|
||||
MessageType::AssetChange => {
|
||||
let change: csm_protocol::AssetChange = frame.decode_payload()
|
||||
.map_err(|e| anyhow::anyhow!("Invalid asset change: {}", e))?;
|
||||
|
||||
if !verify_device_uid(device_uid, "AssetChange", &change.device_uid) {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let change_type_str = match change.change_type {
|
||||
csm_protocol::AssetChangeType::Hardware => "hardware",
|
||||
csm_protocol::AssetChangeType::SoftwareAdded => "software_added",
|
||||
csm_protocol::AssetChangeType::SoftwareRemoved => "software_removed",
|
||||
};
|
||||
|
||||
sqlx::query(
|
||||
"INSERT INTO asset_changes (device_uid, change_type, change_detail, detected_at) \
|
||||
VALUES (?, ?, ?, datetime('now'))"
|
||||
)
|
||||
.bind(&change.device_uid)
|
||||
.bind(change_type_str)
|
||||
.bind(serde_json::to_string(&change.change_detail).map_err(|e| anyhow::anyhow!("Failed to serialize asset change detail: {}", e))?)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("DB error inserting asset change: {}", e))?;
|
||||
|
||||
debug!("Asset change: {} {:?} for device {}", change_type_str, change.change_detail, change.device_uid);
|
||||
|
||||
state.ws_hub.broadcast(serde_json::json!({
|
||||
"type": "asset_change",
|
||||
"device_uid": change.device_uid,
|
||||
"change_type": change_type_str,
|
||||
}).to_string()).await;
|
||||
}
|
||||
|
||||
MessageType::UsageReport => {
|
||||
let report: csm_protocol::UsageDailyReport = frame.decode_payload()
|
||||
.map_err(|e| anyhow::anyhow!("Invalid usage report: {}", e))?;
|
||||
@@ -910,23 +1030,86 @@ async fn process_frame(
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
for rule_stat in &stats.rule_stats {
|
||||
sqlx::query(
|
||||
"INSERT INTO popup_block_stats (device_uid, rule_id, blocked_count, period_secs, reported_at) \
|
||||
VALUES (?, ?, ?, ?, datetime('now'))"
|
||||
)
|
||||
.bind(&stats.device_uid)
|
||||
.bind(rule_stat.rule_id)
|
||||
.bind(rule_stat.hits as i32)
|
||||
.bind(stats.period_secs as i32)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
.ok();
|
||||
}
|
||||
// Upsert aggregate stats per device per day
|
||||
let today = chrono::Utc::now().format("%Y-%m-%d").to_string();
|
||||
sqlx::query(
|
||||
"INSERT INTO popup_block_stats (device_uid, blocked_count, date) \
|
||||
VALUES (?, ?, ?) \
|
||||
ON CONFLICT(device_uid, date) DO UPDATE SET \
|
||||
blocked_count = blocked_count + excluded.blocked_count"
|
||||
)
|
||||
.bind(&stats.device_uid)
|
||||
.bind(stats.blocked_count as i32)
|
||||
.bind(&today)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
.ok();
|
||||
|
||||
debug!("Popup block stats: {} blocked {} windows in {}s", stats.device_uid, stats.blocked_count, stats.period_secs);
|
||||
}
|
||||
|
||||
MessageType::PatchStatusReport => {
|
||||
let payload: csm_protocol::PatchStatusPayload = frame.decode_payload()
|
||||
.map_err(|e| anyhow::anyhow!("Invalid patch status report: {}", e))?;
|
||||
|
||||
if !verify_device_uid(device_uid, "PatchStatusReport", &payload.device_uid) {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
for patch in &payload.patches {
|
||||
sqlx::query(
|
||||
"INSERT INTO patch_status (device_uid, kb_id, title, severity, is_installed, installed_at, updated_at) \
|
||||
VALUES (?, ?, ?, ?, ?, ?, datetime('now')) \
|
||||
ON CONFLICT(device_uid, kb_id) DO UPDATE SET \
|
||||
title = excluded.title, severity = COALESCE(excluded.severity, patch_status.severity), \
|
||||
is_installed = excluded.is_installed, installed_at = excluded.installed_at, \
|
||||
updated_at = datetime('now')"
|
||||
)
|
||||
.bind(&payload.device_uid)
|
||||
.bind(&patch.kb_id)
|
||||
.bind(&patch.title)
|
||||
.bind(&patch.severity)
|
||||
.bind(patch.is_installed as i32)
|
||||
.bind(&patch.installed_at)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("DB error inserting patch status: {}", e))?;
|
||||
}
|
||||
|
||||
info!("Patch status reported: {} ({} patches)", payload.device_uid, payload.patches.len());
|
||||
}
|
||||
|
||||
MessageType::BehaviorMetricsReport => {
|
||||
let metrics: csm_protocol::BehaviorMetricsPayload = frame.decode_payload()
|
||||
.map_err(|e| anyhow::anyhow!("Invalid behavior metrics: {}", e))?;
|
||||
|
||||
if !verify_device_uid(device_uid, "BehaviorMetricsReport", &metrics.device_uid) {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
sqlx::query(
|
||||
"INSERT INTO behavior_metrics (device_uid, clipboard_ops_count, clipboard_ops_night, print_jobs_count, usb_file_ops_count, new_processes_count, period_secs, reported_at) \
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, datetime('now'))"
|
||||
)
|
||||
.bind(&metrics.device_uid)
|
||||
.bind(metrics.clipboard_ops_count as i32)
|
||||
.bind(metrics.clipboard_ops_night as i32)
|
||||
.bind(metrics.print_jobs_count as i32)
|
||||
.bind(metrics.usb_file_ops_count as i32)
|
||||
.bind(metrics.new_processes_count as i32)
|
||||
.bind(metrics.period_secs as i32)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("DB error inserting behavior metrics: {}", e))?;
|
||||
|
||||
// Run anomaly detection inline
|
||||
crate::anomaly::check_anomalies(&state.db, &state.ws_hub, &metrics).await;
|
||||
|
||||
debug!("Behavior metrics saved: {} (clipboard={}, print={}, usb_file={}, procs={})",
|
||||
metrics.device_uid, metrics.clipboard_ops_count, metrics.print_jobs_count,
|
||||
metrics.usb_file_ops_count, metrics.new_processes_count);
|
||||
}
|
||||
|
||||
_ => {
|
||||
debug!("Unhandled message type: {:?}", frame.msg_type);
|
||||
}
|
||||
@@ -935,13 +1118,14 @@ async fn process_frame(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handle a single client TCP connection
|
||||
async fn handle_client(stream: TcpStream, state: AppState) -> anyhow::Result<()> {
|
||||
/// Handle a single client TCP connection (plaintext or TLS)
|
||||
async fn handle_client<S>(stream: S, state: AppState) -> anyhow::Result<()>
|
||||
where
|
||||
S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
|
||||
let _ = stream.set_nodelay(true);
|
||||
|
||||
let (mut reader, mut writer) = stream.into_split();
|
||||
let (mut reader, mut writer) = tokio::io::split(stream);
|
||||
let (tx, mut rx) = tokio::sync::mpsc::channel::<Vec<u8>>(256);
|
||||
let tx = Arc::new(tx);
|
||||
|
||||
@@ -1018,81 +1202,50 @@ async fn handle_client(stream: TcpStream, state: AppState) -> anyhow::Result<()>
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handle a TLS-wrapped client connection
|
||||
async fn handle_client_tls(
|
||||
stream: tokio_rustls::server::TlsStream<TcpStream>,
|
||||
state: AppState,
|
||||
) -> anyhow::Result<()> {
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
|
||||
let (mut reader, mut writer) = tokio::io::split(stream);
|
||||
let (tx, mut rx) = tokio::sync::mpsc::channel::<Vec<u8>>(256);
|
||||
let tx = Arc::new(tx);
|
||||
|
||||
let mut buffer = vec![0u8; 65536];
|
||||
let mut read_buf = Vec::with_capacity(65536);
|
||||
let mut device_uid: Option<String> = None;
|
||||
let mut rate_limiter = RateLimiter::new();
|
||||
let hmac_fail_count = Arc::new(AtomicU32::new(0));
|
||||
|
||||
let write_task = tokio::spawn(async move {
|
||||
while let Some(data) = rx.recv().await {
|
||||
if writer.write_all(&data).await.is_err() {
|
||||
break;
|
||||
}
|
||||
/// Push a TLS certificate rotation notice to all online devices.
|
||||
/// Computes the fingerprint of the new certificate and sends ConfigUpdate(TlsCertRotate).
|
||||
pub async fn push_tls_cert_rotation(clients: &ClientRegistry, new_cert_pem: &[u8], valid_until: &str) -> usize {
|
||||
// Compute SHA-256 fingerprint of the new certificate
|
||||
let certs: Vec<_> = match rustls_pemfile::certs(&mut &new_cert_pem[..]).collect::<Result<Vec<_>, _>>() {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
warn!("Failed to parse new certificate for rotation: {:?}", e);
|
||||
return 0;
|
||||
}
|
||||
});
|
||||
};
|
||||
|
||||
// Reader loop with idle timeout
|
||||
'reader: loop {
|
||||
let read_result = tokio::time::timeout(
|
||||
std::time::Duration::from_secs(IDLE_TIMEOUT_SECS),
|
||||
reader.read(&mut buffer),
|
||||
).await;
|
||||
let end_entity = match certs.first() {
|
||||
Some(c) => c,
|
||||
None => {
|
||||
warn!("No certificates found in PEM for rotation");
|
||||
return 0;
|
||||
}
|
||||
};
|
||||
|
||||
let n = match read_result {
|
||||
Ok(Ok(0)) => break,
|
||||
Ok(Ok(n)) => n,
|
||||
Ok(Err(e)) => return Err(e.into()),
|
||||
Err(_) => {
|
||||
warn!("Idle timeout for TLS device {:?}, disconnecting", device_uid);
|
||||
break;
|
||||
}
|
||||
let fingerprint = {
|
||||
use sha2::{Sha256, Digest};
|
||||
let mut hasher = Sha256::new();
|
||||
hasher.update(end_entity.as_ref());
|
||||
hex::encode(hasher.finalize())
|
||||
};
|
||||
|
||||
info!("Pushing TLS cert rotation: new fingerprint={}... valid_until={}", &fingerprint[..16], valid_until);
|
||||
|
||||
let config_update = csm_protocol::ConfigUpdateType::TlsCertRotate {
|
||||
new_cert_hash: fingerprint,
|
||||
valid_until: valid_until.to_string(),
|
||||
};
|
||||
|
||||
let online = clients.list_online().await;
|
||||
let mut pushed = 0usize;
|
||||
for uid in &online {
|
||||
let frame = match Frame::new_json(MessageType::ConfigUpdate, &config_update) {
|
||||
Ok(f) => f,
|
||||
Err(_) => continue,
|
||||
};
|
||||
read_buf.extend_from_slice(&buffer[..n]);
|
||||
|
||||
if read_buf.len() > MAX_READ_BUF_SIZE {
|
||||
warn!("TLS connection exceeded max buffer size, dropping");
|
||||
break;
|
||||
}
|
||||
|
||||
while let Some(frame) = Frame::decode(&read_buf)? {
|
||||
let frame_size = frame.encoded_size();
|
||||
read_buf.drain(..frame_size);
|
||||
|
||||
if frame.version != PROTOCOL_VERSION {
|
||||
warn!("Unsupported protocol version: 0x{:02X}", frame.version);
|
||||
continue;
|
||||
}
|
||||
|
||||
if !rate_limiter.check() {
|
||||
warn!("Rate limit exceeded for TLS device {:?}, dropping connection", device_uid);
|
||||
break 'reader;
|
||||
}
|
||||
|
||||
if let Err(e) = process_frame(frame, &state, &mut device_uid, &tx, &hmac_fail_count).await {
|
||||
warn!("Frame processing error: {}", e);
|
||||
}
|
||||
|
||||
// Disconnect if too many consecutive HMAC failures
|
||||
if hmac_fail_count.load(Ordering::Relaxed) >= MAX_HMAC_FAILURES {
|
||||
warn!("Too many HMAC failures for TLS device {:?}, disconnecting", device_uid);
|
||||
break 'reader;
|
||||
}
|
||||
if clients.send_to(uid, frame.encode()).await {
|
||||
pushed += 1;
|
||||
}
|
||||
}
|
||||
|
||||
cleanup_on_disconnect(&state, &device_uid).await;
|
||||
write_task.abort();
|
||||
Ok(())
|
||||
pushed
|
||||
}
|
||||
|
||||
@@ -1,12 +1,10 @@
|
||||
use axum::extract::ws::{WebSocket, WebSocketUpgrade, Message};
|
||||
use axum::response::IntoResponse;
|
||||
use axum::extract::Query;
|
||||
use jsonwebtoken::{decode, Validation, DecodingKey};
|
||||
use serde::Deserialize;
|
||||
use tokio::sync::broadcast;
|
||||
use std::sync::Arc;
|
||||
use tracing::{debug, warn};
|
||||
use crate::api::auth::Claims;
|
||||
use crate::AppState;
|
||||
|
||||
/// WebSocket hub for broadcasting real-time events to admin browsers
|
||||
@@ -32,65 +30,73 @@ impl WsHub {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct WsAuthParams {
|
||||
pub token: Option<String>,
|
||||
/// Claim stored when a WS ticket is created. Consumed on WS connection.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct TicketClaim {
|
||||
pub user_id: i64,
|
||||
pub username: String,
|
||||
pub role: String,
|
||||
pub created_at: std::time::Instant,
|
||||
}
|
||||
|
||||
/// HTTP upgrade handler for WebSocket connections
|
||||
/// Validates JWT token from query parameter before upgrading
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct WsTicketParams {
|
||||
pub ticket: Option<String>,
|
||||
}
|
||||
|
||||
/// HTTP upgrade handler for WebSocket connections.
|
||||
/// Validates a one-time ticket (obtained via POST /api/ws/ticket) before upgrading.
|
||||
pub async fn ws_handler(
|
||||
ws: WebSocketUpgrade,
|
||||
Query(params): Query<WsAuthParams>,
|
||||
Query(params): Query<WsTicketParams>,
|
||||
axum::extract::State(state): axum::extract::State<AppState>,
|
||||
) -> impl IntoResponse {
|
||||
let token = match params.token {
|
||||
let ticket = match params.ticket {
|
||||
Some(t) => t,
|
||||
None => {
|
||||
warn!("WebSocket connection rejected: no token provided");
|
||||
return (axum::http::StatusCode::UNAUTHORIZED, "Missing token").into_response();
|
||||
warn!("WebSocket connection rejected: no ticket provided");
|
||||
return (axum::http::StatusCode::UNAUTHORIZED, "Missing ticket").into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let claims = match decode::<Claims>(
|
||||
&token,
|
||||
&DecodingKey::from_secret(state.config.auth.jwt_secret.as_bytes()),
|
||||
&Validation::default(),
|
||||
) {
|
||||
Ok(c) => c.claims,
|
||||
Err(e) => {
|
||||
warn!("WebSocket connection rejected: invalid token - {}", e);
|
||||
return (axum::http::StatusCode::UNAUTHORIZED, "Invalid token").into_response();
|
||||
// Consume (remove) the ticket from the store — single use
|
||||
let claim = {
|
||||
let mut tickets = state.ws_tickets.lock().await;
|
||||
match tickets.remove(&ticket) {
|
||||
Some(claim) => claim,
|
||||
None => {
|
||||
warn!("WebSocket connection rejected: invalid or expired ticket");
|
||||
return (axum::http::StatusCode::UNAUTHORIZED, "Invalid or expired ticket").into_response();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if claims.token_type != "access" {
|
||||
warn!("WebSocket connection rejected: not an access token");
|
||||
return (axum::http::StatusCode::UNAUTHORIZED, "Invalid token type").into_response();
|
||||
// Check ticket age (30 second TTL)
|
||||
if claim.created_at.elapsed().as_secs() > 30 {
|
||||
warn!("WebSocket connection rejected: ticket expired");
|
||||
return (axum::http::StatusCode::UNAUTHORIZED, "Ticket expired").into_response();
|
||||
}
|
||||
|
||||
let hub = state.ws_hub.clone();
|
||||
ws.on_upgrade(move |socket| handle_socket(socket, claims, hub))
|
||||
ws.on_upgrade(move |socket| handle_socket(socket, claim, hub))
|
||||
}
|
||||
|
||||
async fn handle_socket(mut socket: WebSocket, claims: Claims, hub: Arc<WsHub>) {
|
||||
debug!("WebSocket client connected: user={}", claims.username);
|
||||
async fn handle_socket(mut socket: WebSocket, claim: TicketClaim, hub: Arc<WsHub>) {
|
||||
debug!("WebSocket client connected: user={}", claim.username);
|
||||
|
||||
let welcome = serde_json::json!({
|
||||
"type": "connected",
|
||||
"message": "CSM real-time feed active",
|
||||
"user": claims.username
|
||||
"user": claim.username
|
||||
});
|
||||
if socket.send(Message::Text(welcome.to_string())).await.is_err() {
|
||||
return;
|
||||
}
|
||||
|
||||
// Subscribe to broadcast hub for real-time events
|
||||
let mut rx = hub.subscribe();
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
// Forward broadcast messages to WebSocket client
|
||||
msg = rx.recv() => {
|
||||
match msg {
|
||||
Ok(text) => {
|
||||
@@ -104,7 +110,6 @@ async fn handle_socket(mut socket: WebSocket, claims: Claims, hub: Arc<WsHub>) {
|
||||
Err(broadcast::error::RecvError::Closed) => break,
|
||||
}
|
||||
}
|
||||
// Handle incoming WebSocket messages (ping/close)
|
||||
msg = socket.recv() => {
|
||||
match msg {
|
||||
Some(Ok(Message::Ping(data))) => {
|
||||
@@ -121,5 +126,5 @@ async fn handle_socket(mut socket: WebSocket, claims: Claims, hub: Arc<WsHub>) {
|
||||
}
|
||||
}
|
||||
|
||||
debug!("WebSocket client disconnected: user={}", claims.username);
|
||||
debug!("WebSocket client disconnected: user={}", claim.username);
|
||||
}
|
||||
|
||||
20
migrations/017_device_health_scores.sql
Normal file
20
migrations/017_device_health_scores.sql
Normal file
@@ -0,0 +1,20 @@
|
||||
-- 017_device_health_scores.sql: Device health scoring system
|
||||
|
||||
CREATE TABLE IF NOT EXISTS device_health_scores (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
device_uid TEXT NOT NULL REFERENCES devices(device_uid) ON DELETE CASCADE,
|
||||
score INTEGER NOT NULL DEFAULT 0 CHECK(score >= 0 AND score <= 100),
|
||||
status_score INTEGER NOT NULL DEFAULT 0,
|
||||
encryption_score INTEGER NOT NULL DEFAULT 0,
|
||||
load_score INTEGER NOT NULL DEFAULT 0,
|
||||
alert_score INTEGER NOT NULL DEFAULT 0,
|
||||
compliance_score INTEGER NOT NULL DEFAULT 0,
|
||||
patch_score INTEGER NOT NULL DEFAULT 0,
|
||||
level TEXT NOT NULL DEFAULT 'unknown' CHECK(level IN ('healthy', 'warning', 'critical', 'unknown')),
|
||||
details TEXT,
|
||||
computed_at TEXT NOT NULL DEFAULT (datetime('now')),
|
||||
UNIQUE(device_uid)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_health_scores_level ON device_health_scores(level);
|
||||
CREATE INDEX IF NOT EXISTS idx_health_scores_computed ON device_health_scores(computed_at);
|
||||
59
migrations/018_patch_management.sql
Normal file
59
migrations/018_patch_management.sql
Normal file
@@ -0,0 +1,59 @@
|
||||
-- 018_patch_management.sql: Patch management system
|
||||
|
||||
CREATE TABLE IF NOT EXISTS patch_status (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
device_uid TEXT NOT NULL REFERENCES devices(device_uid) ON DELETE CASCADE,
|
||||
kb_id TEXT NOT NULL,
|
||||
title TEXT NOT NULL,
|
||||
severity TEXT,
|
||||
is_installed INTEGER NOT NULL DEFAULT 0,
|
||||
discovered_at TEXT NOT NULL DEFAULT (datetime('now')),
|
||||
installed_at TEXT,
|
||||
updated_at TEXT NOT NULL DEFAULT (datetime('now')),
|
||||
UNIQUE(device_uid, kb_id)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS patch_policies (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
target_type TEXT NOT NULL DEFAULT 'global' CHECK(target_type IN ('global', 'device', 'group')),
|
||||
target_id TEXT,
|
||||
auto_approve INTEGER NOT NULL DEFAULT 0,
|
||||
severity_filter TEXT NOT NULL DEFAULT 'important',
|
||||
enabled INTEGER NOT NULL DEFAULT 1,
|
||||
created_at TEXT NOT NULL DEFAULT (datetime('now')),
|
||||
updated_at TEXT NOT NULL DEFAULT (datetime('now'))
|
||||
);
|
||||
|
||||
-- Behavior metrics for anomaly detection
|
||||
CREATE TABLE IF NOT EXISTS behavior_metrics (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
device_uid TEXT NOT NULL REFERENCES devices(device_uid) ON DELETE CASCADE,
|
||||
clipboard_ops_count INTEGER NOT NULL DEFAULT 0,
|
||||
clipboard_ops_night INTEGER NOT NULL DEFAULT 0,
|
||||
print_jobs_count INTEGER NOT NULL DEFAULT 0,
|
||||
usb_file_ops_count INTEGER NOT NULL DEFAULT 0,
|
||||
new_processes_count INTEGER NOT NULL DEFAULT 0,
|
||||
period_secs INTEGER NOT NULL DEFAULT 3600,
|
||||
reported_at TEXT NOT NULL DEFAULT (datetime('now'))
|
||||
);
|
||||
|
||||
-- Anomaly alerts generated by the detection engine
|
||||
CREATE TABLE IF NOT EXISTS anomaly_alerts (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
device_uid TEXT NOT NULL REFERENCES devices(device_uid) ON DELETE CASCADE,
|
||||
anomaly_type TEXT NOT NULL,
|
||||
severity TEXT NOT NULL DEFAULT 'medium' CHECK(severity IN ('low', 'medium', 'high', 'critical')),
|
||||
detail TEXT NOT NULL,
|
||||
metric_value REAL,
|
||||
baseline_value REAL,
|
||||
handled INTEGER NOT NULL DEFAULT 0,
|
||||
handled_by TEXT,
|
||||
handled_at TEXT,
|
||||
triggered_at TEXT NOT NULL DEFAULT (datetime('now'))
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_patch_status_device ON patch_status(device_uid);
|
||||
CREATE INDEX IF NOT EXISTS idx_patch_status_severity ON patch_status(severity, is_installed);
|
||||
CREATE INDEX IF NOT EXISTS idx_behavior_metrics_device_time ON behavior_metrics(device_uid, reported_at);
|
||||
CREATE INDEX IF NOT EXISTS idx_anomaly_alerts_device ON anomaly_alerts(device_uid);
|
||||
CREATE INDEX IF NOT EXISTS idx_anomaly_alerts_unhandled ON anomaly_alerts(handled) WHERE handled = 0;
|
||||
54
migrations/019_software_whitelist.sql
Normal file
54
migrations/019_software_whitelist.sql
Normal file
@@ -0,0 +1,54 @@
|
||||
-- Software whitelist: processes that should NEVER be blocked even if matched by blacklist rules.
|
||||
-- This provides a safety net to prevent false positives from killing legitimate applications.
|
||||
CREATE TABLE IF NOT EXISTS software_whitelist (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
name_pattern TEXT NOT NULL,
|
||||
reason TEXT,
|
||||
is_builtin INTEGER NOT NULL DEFAULT 0, -- 1 = system default, 0 = admin-added
|
||||
enabled INTEGER NOT NULL DEFAULT 1,
|
||||
created_at TEXT NOT NULL DEFAULT (datetime('now'))
|
||||
);
|
||||
|
||||
-- Built-in whitelist entries for common safe applications
|
||||
INSERT INTO software_whitelist (name_pattern, reason, is_builtin) VALUES
|
||||
-- Browsers
|
||||
('chrome.exe', 'Google Chrome browser', 1),
|
||||
('msedge.exe', 'Microsoft Edge browser', 1),
|
||||
('firefox.exe', 'Mozilla Firefox browser', 1),
|
||||
('iexplore.exe', 'Internet Explorer', 1),
|
||||
('opera.exe', 'Opera browser', 1),
|
||||
('brave.exe', 'Brave browser', 1),
|
||||
('vivaldi.exe', 'Vivaldi browser', 1),
|
||||
-- Development tools & IDEs
|
||||
('code.exe', 'Visual Studio Code', 1),
|
||||
('devenv.exe', 'Visual Studio', 1),
|
||||
('idea64.exe', 'IntelliJ IDEA', 1),
|
||||
('webstorm64.exe', 'WebStorm', 1),
|
||||
('pycharm64.exe', 'PyCharm', 1),
|
||||
('goland64.exe', 'GoLand', 1),
|
||||
('clion64.exe', 'CLion', 1),
|
||||
('rider64.exe', 'Rider', 1),
|
||||
('trae.exe', 'Trae IDE', 1),
|
||||
('windsurf.exe', 'Windsurf IDE', 1),
|
||||
('cursor.exe', 'Cursor IDE', 1),
|
||||
-- Office & productivity
|
||||
('winword.exe', 'Microsoft Word', 1),
|
||||
('excel.exe', 'Microsoft Excel', 1),
|
||||
('powerpnt.exe', 'Microsoft PowerPoint', 1),
|
||||
('outlook.exe', 'Microsoft Outlook', 1),
|
||||
('onenote.exe', 'Microsoft OneNote', 1),
|
||||
('teams.exe', 'Microsoft Teams', 1),
|
||||
('wps.exe', 'WPS Office', 1),
|
||||
-- Terminal & system tools
|
||||
('cmd.exe', 'Command Prompt', 1),
|
||||
('powershell.exe', 'PowerShell', 1),
|
||||
('pwsh.exe', 'PowerShell Core', 1),
|
||||
('WindowsTerminal.exe', 'Windows Terminal', 1),
|
||||
-- Communication
|
||||
('wechat.exe', 'WeChat', 1),
|
||||
('dingtalk.exe', 'DingTalk', 1),
|
||||
('feishu.exe', 'Feishu/Lark', 1),
|
||||
('qq.exe', 'QQ', 1),
|
||||
('tim.exe', 'Tencent TIM', 1),
|
||||
-- CSM
|
||||
('csm-client.exe', 'CSM Client itself', 1);
|
||||
@@ -1,5 +1,7 @@
|
||||
/**
|
||||
* Shared API client with authentication and error handling
|
||||
* Shared API client with cookie-based authentication.
|
||||
* Tokens are managed via HttpOnly cookies set by the server —
|
||||
* the frontend never reads or stores JWT tokens.
|
||||
*/
|
||||
|
||||
const API_BASE = import.meta.env.VITE_API_BASE || ''
|
||||
@@ -21,43 +23,37 @@ export class ApiError extends Error {
|
||||
}
|
||||
}
|
||||
|
||||
function getToken(): string | null {
|
||||
const token = localStorage.getItem('token')
|
||||
if (!token || token.trim() === '') return null
|
||||
return token
|
||||
}
|
||||
|
||||
function clearAuth() {
|
||||
localStorage.removeItem('token')
|
||||
localStorage.removeItem('refresh_token')
|
||||
window.location.href = '/login'
|
||||
}
|
||||
|
||||
let refreshPromise: Promise<boolean> | null = null
|
||||
|
||||
/** Cached user info from /api/auth/me */
|
||||
let cachedUser: { id: number; username: string; role: string } | null = null
|
||||
|
||||
export function getCachedUser() {
|
||||
return cachedUser
|
||||
}
|
||||
|
||||
export function clearCachedUser() {
|
||||
cachedUser = null
|
||||
}
|
||||
|
||||
async function tryRefresh(): Promise<boolean> {
|
||||
// Coalesce concurrent refresh attempts
|
||||
if (refreshPromise) return refreshPromise
|
||||
|
||||
refreshPromise = (async () => {
|
||||
const refreshToken = localStorage.getItem('refresh_token')
|
||||
if (!refreshToken || refreshToken.trim() === '') return false
|
||||
|
||||
try {
|
||||
const response = await fetch(`${API_BASE}/api/auth/refresh`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ refresh_token: refreshToken }),
|
||||
credentials: 'same-origin',
|
||||
})
|
||||
|
||||
if (!response.ok) return false
|
||||
|
||||
const result = await response.json()
|
||||
if (!result.success || !result.data?.access_token) return false
|
||||
if (!result.success) return false
|
||||
|
||||
localStorage.setItem('token', result.data.access_token)
|
||||
if (result.data.refresh_token) {
|
||||
localStorage.setItem('refresh_token', result.data.refresh_token)
|
||||
// Update cached user from refresh response
|
||||
if (result.data?.user) {
|
||||
cachedUser = result.data.user
|
||||
}
|
||||
return true
|
||||
} catch {
|
||||
@@ -74,13 +70,8 @@ async function request<T>(
|
||||
path: string,
|
||||
options: RequestInit = {},
|
||||
): Promise<T> {
|
||||
const token = getToken()
|
||||
const headers = new Headers(options.headers || {})
|
||||
|
||||
if (token) {
|
||||
headers.set('Authorization', `Bearer ${token}`)
|
||||
}
|
||||
|
||||
if (options.body && typeof options.body === 'string') {
|
||||
headers.set('Content-Type', 'application/json')
|
||||
}
|
||||
@@ -88,18 +79,21 @@ async function request<T>(
|
||||
const response = await fetch(`${API_BASE}${path}`, {
|
||||
...options,
|
||||
headers,
|
||||
credentials: 'same-origin',
|
||||
})
|
||||
|
||||
// Handle 401 - try refresh before giving up
|
||||
if (response.status === 401) {
|
||||
const refreshed = await tryRefresh()
|
||||
if (refreshed) {
|
||||
// Retry the original request with new token
|
||||
const newToken = getToken()
|
||||
headers.set('Authorization', `Bearer ${newToken}`)
|
||||
const retryResponse = await fetch(`${API_BASE}${path}`, { ...options, headers })
|
||||
const retryResponse = await fetch(`${API_BASE}${path}`, {
|
||||
...options,
|
||||
headers,
|
||||
credentials: 'same-origin',
|
||||
})
|
||||
if (retryResponse.status === 401) {
|
||||
clearAuth()
|
||||
clearCachedUser()
|
||||
window.location.href = '/login'
|
||||
throw new ApiError(401, 'UNAUTHORIZED', 'Session expired')
|
||||
}
|
||||
const retryContentType = retryResponse.headers.get('content-type')
|
||||
@@ -112,7 +106,8 @@ async function request<T>(
|
||||
}
|
||||
return retryResult.data as T
|
||||
}
|
||||
clearAuth()
|
||||
clearCachedUser()
|
||||
window.location.href = '/login'
|
||||
throw new ApiError(401, 'UNAUTHORIZED', 'Session expired')
|
||||
}
|
||||
|
||||
@@ -159,11 +154,12 @@ export const api = {
|
||||
return request<T>(path, { method: 'DELETE' })
|
||||
},
|
||||
|
||||
/** Login doesn't use the auth header */
|
||||
async login(username: string, password: string): Promise<{ access_token: string; refresh_token: string; user: { id: number; username: string; role: string } }> {
|
||||
/** Login — server sets HttpOnly cookies, we only get user info back */
|
||||
async login(username: string, password: string): Promise<{ user: { id: number; username: string; role: string } }> {
|
||||
const response = await fetch(`${API_BASE}/api/auth/login`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
credentials: 'same-origin',
|
||||
body: JSON.stringify({ username, password }),
|
||||
})
|
||||
|
||||
@@ -172,12 +168,27 @@ export const api = {
|
||||
throw new ApiError(response.status, 'LOGIN_FAILED', result.error || 'Login failed')
|
||||
}
|
||||
|
||||
localStorage.setItem('token', result.data.access_token)
|
||||
localStorage.setItem('refresh_token', result.data.refresh_token)
|
||||
cachedUser = result.data.user
|
||||
return result.data
|
||||
},
|
||||
|
||||
logout() {
|
||||
clearAuth()
|
||||
/** Logout — server clears cookies */
|
||||
async logout(): Promise<void> {
|
||||
try {
|
||||
await fetch(`${API_BASE}/api/auth/logout`, {
|
||||
method: 'POST',
|
||||
credentials: 'same-origin',
|
||||
})
|
||||
} catch {
|
||||
// Ignore errors during logout
|
||||
}
|
||||
clearCachedUser()
|
||||
},
|
||||
|
||||
/** Check current auth status via /api/auth/me */
|
||||
async me(): Promise<{ user: { id: number; username: string; role: string }; expires_at: string }> {
|
||||
const result = await request<{ user: { id: number; username: string; role: string }; expires_at: string }>('/api/auth/me')
|
||||
cachedUser = (result as { user: { id: number; username: string; role: string } }).user
|
||||
return result
|
||||
},
|
||||
}
|
||||
|
||||
@@ -27,40 +27,43 @@ const router = createRouter({
|
||||
{ path: 'plugins/print-audit', name: 'PrintAudit', component: () => import('../views/plugins/PrintAudit.vue') },
|
||||
{ path: 'plugins/clipboard-control', name: 'ClipboardControl', component: () => import('../views/plugins/ClipboardControl.vue') },
|
||||
{ path: 'plugins/plugin-control', name: 'PluginControl', component: () => import('../views/plugins/PluginControl.vue') },
|
||||
{ path: 'plugins/patch', name: 'PatchManagement', component: () => import('../views/plugins/PatchManagement.vue') },
|
||||
{ path: 'plugins/anomaly', name: 'AnomalyDetection', component: () => import('../views/plugins/AnomalyDetection.vue') },
|
||||
],
|
||||
},
|
||||
],
|
||||
})
|
||||
|
||||
/** Check if a JWT token is structurally valid and not expired */
|
||||
function isTokenValid(token: string): boolean {
|
||||
if (!token || token.trim() === '') return false
|
||||
try {
|
||||
const parts = token.split('.')
|
||||
if (parts.length !== 3) return false
|
||||
const payload = JSON.parse(atob(parts[1]))
|
||||
if (!payload.exp) return false
|
||||
// Reject if token expires within 30 seconds
|
||||
return payload.exp * 1000 > Date.now() + 30_000
|
||||
} catch {
|
||||
return false
|
||||
}
|
||||
}
|
||||
/** Track whether we've already validated auth this session */
|
||||
let authChecked = false
|
||||
|
||||
router.beforeEach((to, _from, next) => {
|
||||
router.beforeEach(async (to, _from, next) => {
|
||||
if (to.path === '/login') {
|
||||
next()
|
||||
return
|
||||
}
|
||||
|
||||
const token = localStorage.getItem('token')
|
||||
if (!token || !isTokenValid(token)) {
|
||||
localStorage.removeItem('token')
|
||||
localStorage.removeItem('refresh_token')
|
||||
next('/login')
|
||||
} else {
|
||||
// If we've already verified auth this session, allow navigation
|
||||
// (cookies are sent automatically, no need to check on every route change)
|
||||
if (authChecked) {
|
||||
next()
|
||||
return
|
||||
}
|
||||
|
||||
// Check auth status via /api/auth/me (reads access_token cookie)
|
||||
try {
|
||||
const { me } = await import('../lib/api')
|
||||
await me()
|
||||
authChecked = true
|
||||
next()
|
||||
} catch {
|
||||
next('/login')
|
||||
}
|
||||
})
|
||||
|
||||
/** Reset auth check flag (called after logout) */
|
||||
export function resetAuthCheck() {
|
||||
authChecked = false
|
||||
}
|
||||
|
||||
export default router
|
||||
|
||||
@@ -41,6 +41,32 @@
|
||||
<div class="stat-label">USB事件(24h)</div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="stat-card" @click="showHealthDetail = true" style="cursor:pointer">
|
||||
<div class="stat-icon" :class="healthIconClass">
|
||||
<span style="font-size:20px;font-weight:800;line-height:26px">{{ healthAvg }}</span>
|
||||
</div>
|
||||
<div class="stat-info">
|
||||
<div class="stat-value" :class="healthTextClass">{{ healthAvg }}</div>
|
||||
<div class="stat-label">健康评分</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Health overview bar -->
|
||||
<div v-if="healthSummary.total > 0" class="health-bar">
|
||||
<div class="health-bar-segment healthy" :style="{ flex: healthSummary.healthy }" :title="`${healthSummary.healthy} 健康`">
|
||||
<span v-if="healthSummary.healthy > 0">{{ healthSummary.healthy }} 健康</span>
|
||||
</div>
|
||||
<div class="health-bar-segment warning" :style="{ flex: healthSummary.warning }" :title="`${healthSummary.warning} 告警`">
|
||||
<span v-if="healthSummary.warning > 0">{{ healthSummary.warning }} 告警</span>
|
||||
</div>
|
||||
<div class="health-bar-segment critical" :style="{ flex: healthSummary.critical }" :title="`${healthSummary.critical} 严重`">
|
||||
<span v-if="healthSummary.critical > 0">{{ healthSummary.critical }} 严重</span>
|
||||
</div>
|
||||
<div class="health-bar-segment unknown" :style="{ flex: healthSummary.unknown || 0 }" :title="`${healthSummary.unknown || 0} 未知`">
|
||||
<span v-if="(healthSummary.unknown || 0) > 0">{{ healthSummary.unknown }} 未知</span>
|
||||
</div>
|
||||
<div class="health-bar-label">策略冲突: {{ conflictCount }} 项</div>
|
||||
</div>
|
||||
|
||||
<!-- Charts row -->
|
||||
@@ -138,7 +164,7 @@
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, onMounted, onUnmounted } from 'vue'
|
||||
import { ref, computed, onMounted, onUnmounted } from 'vue'
|
||||
import { Monitor, Platform, Bell, Connection, Top } from '@element-plus/icons-vue'
|
||||
import * as echarts from 'echarts'
|
||||
import { api } from '@/lib/api'
|
||||
@@ -148,6 +174,24 @@ const recentAlerts = ref<Array<{ id: number; severity: string; detail: string; t
|
||||
const recentUsbEvents = ref<Array<{ device_name: string; event_type: string; device_uid: string; event_time: string }>>([])
|
||||
const topDevices = ref<Array<{ hostname: string; cpu_usage: number; memory_usage: number; status: string }>>([])
|
||||
|
||||
const healthSummary = ref<{ total: number; healthy: number; warning: number; critical: number; unknown: number; avg_score: number }>({ total: 0, healthy: 0, warning: 0, critical: 0, unknown: 0, avg_score: 0 })
|
||||
const conflictCount = ref(0)
|
||||
const showHealthDetail = ref(false)
|
||||
|
||||
const healthAvg = computed(() => Math.round(healthSummary.value.avg_score))
|
||||
const healthIconClass = computed(() => {
|
||||
const s = healthAvg.value
|
||||
if (s >= 80) return 'health-good'
|
||||
if (s >= 50) return 'health-warn'
|
||||
return 'health-bad'
|
||||
})
|
||||
const healthTextClass = computed(() => {
|
||||
const s = healthAvg.value
|
||||
if (s >= 80) return 'text-good'
|
||||
if (s >= 50) return 'text-warn'
|
||||
return 'text-bad'
|
||||
})
|
||||
|
||||
const cpuChartRef = ref<HTMLElement>()
|
||||
let chart: echarts.ECharts | null = null
|
||||
let timer: ReturnType<typeof setInterval> | null = null
|
||||
@@ -155,10 +199,12 @@ let resizeHandler: (() => void) | null = null
|
||||
|
||||
async function fetchDashboard() {
|
||||
try {
|
||||
const [devicesData, alertsData, usbData] = await Promise.all([
|
||||
const [devicesData, alertsData, usbData, healthData, conflictData] = await Promise.all([
|
||||
api.get<any>('/api/devices'),
|
||||
api.get<any>('/api/alerts/records?handled=0&page_size=10'),
|
||||
api.get<any>('/api/usb/events?page_size=10'),
|
||||
api.get<any>('/api/dashboard/health-overview').catch(() => null),
|
||||
api.get<any>('/api/policies/conflicts').catch(() => null),
|
||||
])
|
||||
|
||||
const devices = devicesData.devices || []
|
||||
@@ -179,6 +225,16 @@ async function fetchDashboard() {
|
||||
const events = usbData.events || []
|
||||
stats.value.usbEvents = events.length
|
||||
recentUsbEvents.value = events.slice(0, 8)
|
||||
|
||||
// Health overview
|
||||
if (healthData?.summary) {
|
||||
healthSummary.value = healthData.summary
|
||||
}
|
||||
|
||||
// Conflict count
|
||||
if (conflictData?.total !== undefined) {
|
||||
conflictCount.value = conflictData.total
|
||||
}
|
||||
} catch (e) {
|
||||
console.error('Failed to fetch dashboard data', e)
|
||||
}
|
||||
@@ -374,4 +430,51 @@ onUnmounted(() => {
|
||||
color: var(--csm-text-tertiary);
|
||||
margin-top: 2px;
|
||||
}
|
||||
|
||||
/* Health bar */
|
||||
.health-bar {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
height: 32px;
|
||||
border-radius: 6px;
|
||||
overflow: hidden;
|
||||
margin-top: 16px;
|
||||
background: #f1f5f9;
|
||||
font-size: 12px;
|
||||
color: #fff;
|
||||
position: relative;
|
||||
}
|
||||
|
||||
.health-bar-segment {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
min-width: 0;
|
||||
overflow: hidden;
|
||||
white-space: nowrap;
|
||||
padding: 0 8px;
|
||||
transition: flex 0.3s ease;
|
||||
}
|
||||
|
||||
.health-bar-segment.healthy { background: #16a34a; }
|
||||
.health-bar-segment.warning { background: #d97706; }
|
||||
.health-bar-segment.critical { background: #dc2626; }
|
||||
.health-bar-segment.unknown { background: #94a3b8; }
|
||||
|
||||
.health-bar-label {
|
||||
position: absolute;
|
||||
right: 12px;
|
||||
font-size: 12px;
|
||||
color: #64748b;
|
||||
font-weight: 500;
|
||||
}
|
||||
|
||||
/* Health score colors */
|
||||
.stat-icon.health-good { background: #f0fdf4; color: #16a34a; }
|
||||
.stat-icon.health-warn { background: #fffbeb; color: #d97706; }
|
||||
.stat-icon.health-bad { background: #fef2f2; color: #dc2626; }
|
||||
|
||||
.text-good { color: #16a34a !important; }
|
||||
.text-warn { color: #d97706 !important; }
|
||||
.text-bad { color: #dc2626 !important; }
|
||||
</style>
|
||||
|
||||
@@ -143,6 +143,15 @@
|
||||
<el-tag size="small" effect="plain" round>{{ row.group_name || '默认' }}</el-tag>
|
||||
</template>
|
||||
</el-table-column>
|
||||
<el-table-column label="健康" width="90" sortable :sort-method="(a: any, b: any) => (a.health_score ?? 0) - (b.health_score ?? 0)">
|
||||
<template #default="{ row }">
|
||||
<div v-if="row.health_score != null" class="health-cell">
|
||||
<span class="health-dot" :class="row.health_level"></span>
|
||||
<span class="health-value" :class="'text-' + healthClass(row.health_score)">{{ row.health_score }}</span>
|
||||
</div>
|
||||
<span v-else class="health-cell">-</span>
|
||||
</template>
|
||||
</el-table-column>
|
||||
<el-table-column label="CPU" width="100">
|
||||
<template #default="{ row }">
|
||||
<div class="usage-cell">
|
||||
@@ -380,6 +389,13 @@ function getProgressColor(value?: number): string {
|
||||
return '#16a34a'
|
||||
}
|
||||
|
||||
function healthClass(score?: number): string {
|
||||
if (score == null) return 'unknown'
|
||||
if (score >= 80) return 'good'
|
||||
if (score >= 50) return 'warn'
|
||||
return 'bad'
|
||||
}
|
||||
|
||||
function formatTime(t: string | null): string {
|
||||
if (!t) return '-'
|
||||
const d = new Date(t)
|
||||
@@ -795,6 +811,31 @@ async function handleMoveSubmit() {
|
||||
color: var(--csm-text-primary);
|
||||
}
|
||||
|
||||
/* Health cell */
|
||||
.health-cell {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 6px;
|
||||
}
|
||||
|
||||
.health-dot {
|
||||
width: 7px;
|
||||
height: 7px;
|
||||
border-radius: 50%;
|
||||
flex-shrink: 0;
|
||||
}
|
||||
|
||||
.health-dot.healthy { background: #16a34a; box-shadow: 0 0 4px rgba(22,163,74,0.4); }
|
||||
.health-dot.warning { background: #d97706; }
|
||||
.health-dot.critical { background: #dc2626; box-shadow: 0 0 4px rgba(220,38,38,0.3); }
|
||||
.health-dot.unknown { background: #94a3b8; }
|
||||
|
||||
.health-value { font-size: 13px; font-weight: 600; }
|
||||
.text-good { color: #16a34a; }
|
||||
.text-warn { color: #d97706; }
|
||||
.text-bad { color: #dc2626; }
|
||||
.text-unknown { color: #94a3b8; }
|
||||
|
||||
/* Pagination */
|
||||
.pagination-bar {
|
||||
display: flex;
|
||||
|
||||
@@ -76,6 +76,12 @@
|
||||
<el-menu-item index="/plugins/plugin-control">
|
||||
<template #title><span>插件控制</span></template>
|
||||
</el-menu-item>
|
||||
<el-menu-item index="/plugins/patch">
|
||||
<template #title><span>补丁管理</span></template>
|
||||
</el-menu-item>
|
||||
<el-menu-item index="/plugins/anomaly">
|
||||
<template #title><span>异常检测</span></template>
|
||||
</el-menu-item>
|
||||
</el-sub-menu>
|
||||
|
||||
<el-menu-item index="/settings">
|
||||
@@ -143,7 +149,8 @@ import {
|
||||
Monitor, Platform, Connection, Bell, Setting,
|
||||
ArrowDown, Grid, Expand, Fold, SwitchButton
|
||||
} from '@element-plus/icons-vue'
|
||||
import { api } from '@/lib/api'
|
||||
import { api, getCachedUser } from '@/lib/api'
|
||||
import { resetAuthCheck } from '@/router'
|
||||
|
||||
const route = useRoute()
|
||||
const router = useRouter()
|
||||
@@ -153,15 +160,8 @@ const currentRoute = computed(() => route.path)
|
||||
const unreadAlerts = ref(0)
|
||||
const username = ref('')
|
||||
|
||||
function decodeUsername(): string {
|
||||
try {
|
||||
const token = localStorage.getItem('token')
|
||||
if (!token) return ''
|
||||
const payload = JSON.parse(atob(token.split('.')[1]))
|
||||
return payload.username || ''
|
||||
} catch {
|
||||
return ''
|
||||
}
|
||||
function getCachedUsername(): string {
|
||||
return getCachedUser()?.username || ''
|
||||
}
|
||||
|
||||
async function fetchUnreadAlerts() {
|
||||
@@ -189,18 +189,20 @@ const pageTitles: Record<string, string> = {
|
||||
'/plugins/print-audit': '打印审计',
|
||||
'/plugins/clipboard-control': '剪贴板管控',
|
||||
'/plugins/plugin-control': '插件控制',
|
||||
'/plugins/patch': '补丁管理',
|
||||
'/plugins/anomaly': '异常检测',
|
||||
}
|
||||
|
||||
const pageTitle = computed(() => pageTitles[route.path] || '仪表盘')
|
||||
|
||||
onMounted(() => {
|
||||
username.value = decodeUsername()
|
||||
username.value = getCachedUsername()
|
||||
fetchUnreadAlerts()
|
||||
})
|
||||
|
||||
function handleLogout() {
|
||||
localStorage.removeItem('token')
|
||||
localStorage.removeItem('refresh_token')
|
||||
async function handleLogout() {
|
||||
await api.logout()
|
||||
resetAuthCheck()
|
||||
router.push('/login')
|
||||
}
|
||||
</script>
|
||||
|
||||
@@ -85,7 +85,7 @@
|
||||
<script setup lang="ts">
|
||||
import { ref, reactive, onMounted } from 'vue'
|
||||
import { ElMessage } from 'element-plus'
|
||||
import { api } from '@/lib/api'
|
||||
import { api, getCachedUser } from '@/lib/api'
|
||||
|
||||
const version = ref('0.1.0')
|
||||
const dbInfo = ref('SQLite (WAL mode)')
|
||||
@@ -96,14 +96,11 @@ const pwdForm = reactive({ oldPassword: '', newPassword: '', confirmPassword: ''
|
||||
const pwdLoading = ref(false)
|
||||
|
||||
onMounted(() => {
|
||||
try {
|
||||
const token = localStorage.getItem('token')
|
||||
if (token) {
|
||||
const payload = JSON.parse(atob(token.split('.')[1]))
|
||||
user.username = payload.username || 'admin'
|
||||
user.role = payload.role || 'admin'
|
||||
}
|
||||
} catch (e) { console.error('Failed to decode token for username', e) }
|
||||
const cached = getCachedUser()
|
||||
if (cached) {
|
||||
user.username = cached.username
|
||||
user.role = cached.role
|
||||
}
|
||||
|
||||
api.get<any>('/health')
|
||||
.then((data: any) => {
|
||||
|
||||
90
web/src/views/plugins/AnomalyDetection.vue
Normal file
90
web/src/views/plugins/AnomalyDetection.vue
Normal file
@@ -0,0 +1,90 @@
|
||||
<template>
|
||||
<div class="page-container">
|
||||
<div class="csm-card">
|
||||
<div class="csm-card-header">
|
||||
<span>异常行为检测</span>
|
||||
<el-tag v-if="unhandled > 0" type="danger" effect="light" size="small">{{ unhandled }} 未处理</el-tag>
|
||||
<el-tag v-else type="success" effect="light" size="small">无异常</el-tag>
|
||||
</div>
|
||||
<div class="csm-card-body">
|
||||
<el-table :data="alerts" v-loading="loading" size="small" max-height="520">
|
||||
<el-table-column prop="hostname" label="终端" width="140" />
|
||||
<el-table-column label="异常类型" width="180">
|
||||
<template #default="{ row }">
|
||||
<span>{{ anomalyLabel(row.anomaly_type) }}</span>
|
||||
</template>
|
||||
</el-table-column>
|
||||
<el-table-column label="严重性" width="90">
|
||||
<template #default="{ row }">
|
||||
<el-tag :type="severityType(row.severity)" size="small" effect="light">{{ row.severity }}</el-tag>
|
||||
</template>
|
||||
</el-table-column>
|
||||
<el-table-column prop="detail" label="详情" min-width="300" show-overflow-tooltip />
|
||||
<el-table-column prop="triggered_at" label="检测时间" width="160" />
|
||||
</el-table>
|
||||
<div v-if="alerts.length === 0 && !loading" style="padding:40px 0;text-align:center;color:#94a3b8">
|
||||
暂无异常行为告警,系统运行正常
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, onMounted } from 'vue'
|
||||
import { api } from '@/lib/api'
|
||||
|
||||
const alerts = ref<any[]>([])
|
||||
const loading = ref(false)
|
||||
const unhandled = ref(0)
|
||||
|
||||
const anomalyLabels: Record<string, string> = {
|
||||
night_clipboard_spike: '非工作时间剪贴板异常',
|
||||
usb_file_exfiltration: 'USB文件大量拷贝',
|
||||
high_print_volume: '打印量异常',
|
||||
process_spawn_spike: '进程启动频率异常',
|
||||
}
|
||||
|
||||
function anomalyLabel(type: string): string {
|
||||
return anomalyLabels[type] || type
|
||||
}
|
||||
|
||||
function severityType(s: string): string {
|
||||
if (s === 'critical') return 'danger'
|
||||
if (s === 'high') return 'warning'
|
||||
if (s === 'medium') return ''
|
||||
return 'info'
|
||||
}
|
||||
|
||||
async function fetchData() {
|
||||
loading.value = true
|
||||
try {
|
||||
const data = await api.get<any>('/api/plugins/anomaly/alerts?page_size=50')
|
||||
alerts.value = data.alerts || []
|
||||
unhandled.value = data.unhandled_count || 0
|
||||
} catch (e) {
|
||||
console.error('Failed to fetch anomaly alerts', e)
|
||||
} finally {
|
||||
loading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
onMounted(fetchData)
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.csm-card-header {
|
||||
font-weight: 600;
|
||||
font-size: 15px;
|
||||
color: var(--csm-text-primary);
|
||||
padding: 16px 20px;
|
||||
border-bottom: 1px solid var(--csm-border-color);
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: space-between;
|
||||
gap: 12px;
|
||||
}
|
||||
.csm-card-body {
|
||||
padding: 16px 20px;
|
||||
}
|
||||
</style>
|
||||
92
web/src/views/plugins/PatchManagement.vue
Normal file
92
web/src/views/plugins/PatchManagement.vue
Normal file
@@ -0,0 +1,92 @@
|
||||
<template>
|
||||
<div class="page-container">
|
||||
<div class="csm-card">
|
||||
<div class="csm-card-header">
|
||||
<span>补丁管理</span>
|
||||
<el-tag type="info" effect="plain" size="small">{{ summary.total_installed }} 已安装 / {{ summary.total_missing }} 缺失</el-tag>
|
||||
</div>
|
||||
<div class="csm-card-body">
|
||||
<el-table :data="patches" v-loading="loading" size="small" max-height="520">
|
||||
<el-table-column prop="hostname" label="终端" width="140" />
|
||||
<el-table-column prop="kb_id" label="补丁编号" width="120" />
|
||||
<el-table-column prop="title" label="描述" min-width="280" show-overflow-tooltip />
|
||||
<el-table-column prop="severity" label="严重性" width="100">
|
||||
<template #default="{ row }">
|
||||
<el-tag v-if="row.severity" :type="severityType(row.severity)" size="small" effect="light">{{ row.severity }}</el-tag>
|
||||
<span v-else class="text-muted">-</span>
|
||||
</template>
|
||||
</el-table-column>
|
||||
<el-table-column label="状态" width="90">
|
||||
<template #default="{ row }">
|
||||
<el-tag :type="row.is_installed ? 'success' : 'danger'" size="small" effect="light">
|
||||
{{ row.is_installed ? '已安装' : '缺失' }}
|
||||
</el-tag>
|
||||
</template>
|
||||
</el-table-column>
|
||||
<el-table-column prop="installed_at" label="安装时间" width="120">
|
||||
<template #default="{ row }">{{ row.installed_at || '-' }}</template>
|
||||
</el-table-column>
|
||||
</el-table>
|
||||
<div style="display:flex;justify-content:flex-end;padding-top:12px">
|
||||
<el-pagination :total="total" :page-size="pageSize" layout="total, prev, pager, next" @current-change="handlePage" />
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, onMounted } from 'vue'
|
||||
import { api } from '@/lib/api'
|
||||
|
||||
const patches = ref<any[]>([])
|
||||
const loading = ref(false)
|
||||
const total = ref(0)
|
||||
const page = ref(1)
|
||||
const pageSize = 20
|
||||
const summary = ref({ total_installed: 0, total_missing: 0 })
|
||||
|
||||
async function fetchData() {
|
||||
loading.value = true
|
||||
try {
|
||||
const data = await api.get<any>(`/api/plugins/patch/status?page=${page.value}&page_size=${pageSize}`)
|
||||
patches.value = data.patches || []
|
||||
total.value = data.total || 0
|
||||
if (data.summary) summary.value = data.summary
|
||||
} catch (e) {
|
||||
console.error('Failed to fetch patches', e)
|
||||
} finally {
|
||||
loading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
function handlePage(p: number) {
|
||||
page.value = p
|
||||
fetchData()
|
||||
}
|
||||
|
||||
function severityType(s: string): string {
|
||||
if (s === 'Critical') return 'danger'
|
||||
if (s === 'Important') return 'warning'
|
||||
return 'info'
|
||||
}
|
||||
|
||||
onMounted(fetchData)
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.csm-card-header {
|
||||
font-weight: 600;
|
||||
font-size: 15px;
|
||||
color: var(--csm-text-primary);
|
||||
padding: 16px 20px;
|
||||
border-bottom: 1px solid var(--csm-border-color);
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: space-between;
|
||||
}
|
||||
.csm-card-body {
|
||||
padding: 16px 20px;
|
||||
}
|
||||
.text-muted { color: #94a3b8; }
|
||||
</style>
|
||||
255
wiki/SECURITY-AUDIT.md
Normal file
255
wiki/SECURITY-AUDIT.md
Normal file
@@ -0,0 +1,255 @@
|
||||
# CSM 安全审计报告
|
||||
|
||||
> **审计日期**: 2026-04-11 | **审计范围**: 全系统 (Server + Client + Protocol + Frontend)
|
||||
> **方法论**: OWASP Top 10 (2021), CWE Top 25, 手动源码审查 + 攻击者视角分析
|
||||
|
||||
---
|
||||
|
||||
## 执行摘要
|
||||
|
||||
| 严重级别 | 数量 | 说明 |
|
||||
|----------|------|------|
|
||||
| **CRITICAL** | 4 | 可直接被远程利用,导致系统完全沦陷 |
|
||||
| **HIGH** | 12 | 需特定条件但影响重大 |
|
||||
| **MEDIUM** | 12 | 有限影响或需较高权限 |
|
||||
| **LOW** | 8 | 纵深防御建议 |
|
||||
|
||||
**最关键发现**: JWT Secret 硬编码在版本控制中、注册 Token 为空允许任意设备注册、凭据文件无 ACL 保护、默认无 TLS 传输加密。这四个问题组合意味着攻击者可以在数分钟内完全接管系统。
|
||||
|
||||
---
|
||||
|
||||
## CRITICAL (4)
|
||||
|
||||
### AUD-001: JWT Secret 硬编码在 config.toml 中
|
||||
|
||||
- **文件**: `config.toml:12` → `crates/server/src/config.rs`
|
||||
- **CWE**: CWE-798 (硬编码凭证)
|
||||
- **OWASP**: A07:2021 - 安全配置错误
|
||||
|
||||
**漏洞代码**:
|
||||
```toml
|
||||
jwt_secret = "39ffc129-dd62-4eb4-bbc0-8bf4b8e2ccc7"
|
||||
```
|
||||
|
||||
**攻击场景**: 任何能访问仓库的人可提取 secret,为任意用户(含 admin)伪造 JWT,获得系统完全管理控制权,可推送恶意配置到所有医院终端、禁用安全控制、篡改审计记录。
|
||||
|
||||
**修复**:
|
||||
1. 从 `config.toml` 中移除硬编码 secret
|
||||
2. 通过 `CSM_JWT_SECRET` 环境变量独占加载
|
||||
3. secret 为空时拒绝启动
|
||||
4. **立即轮换**已泄露的 secret `39ffc129-dd62-4eb4-bbc0-8bf4b8e2ccc7`
|
||||
5. 将 `config.toml` 加入 `.gitignore`
|
||||
|
||||
---
|
||||
|
||||
### AUD-002: 空 Registration Token — 任意设备注册
|
||||
|
||||
- **文件**: `config.toml:1`, `crates/server/src/tcp.rs:549-558`
|
||||
- **CWE**: CWE-306 (关键功能缺失认证)
|
||||
- **OWASP**: A07:2021
|
||||
|
||||
**漏洞代码**:
|
||||
```toml
|
||||
registration_token = ""
|
||||
```
|
||||
```rust
|
||||
if !expected_token.is_empty() { // 空字符串直接跳过验证
|
||||
```
|
||||
|
||||
**攻击场景**: TCP 端口 9999 可达的任何攻击者可注册恶意设备,注入伪造审计数据掩盖安全事件,获取所有插件配置(含安全策略)。
|
||||
|
||||
**修复**: 设置强 registration token,空 token 时拒绝启动。
|
||||
|
||||
---
|
||||
|
||||
### AUD-003: Windows 凭据文件无 ACL 保护
|
||||
|
||||
- **文件**: `crates/client/src/main.rs:280-283`
|
||||
- **CWE**: CWE-732 (关键资源权限不当)
|
||||
- **OWASP**: A01:2021
|
||||
|
||||
**漏洞代码**:
|
||||
```rust
|
||||
#[cfg(not(unix))]
|
||||
fn write_restricted_file(path: &std::path::Path, content: &str) -> std::io::Result<()> {
|
||||
std::fs::write(path, content) // 无任何 ACL 设置
|
||||
}
|
||||
```
|
||||
|
||||
**攻击场景**: `device_secret.txt`(HMAC 密钥)和 `device_uid.txt`(设备身份)以默认权限写入,设备上任何用户进程可读取。攻击者可提取密钥伪造心跳,在不同机器上模拟设备。
|
||||
|
||||
**修复**: 使用 `icacls` 或 `SetSecurityInfo` 设置仅 SYSTEM 可访问的 ACL。
|
||||
|
||||
---
|
||||
|
||||
### AUD-004: 默认无 TLS — 明文传输所有敏感数据
|
||||
|
||||
- **文件**: `config.toml` (无 `[server.tls]`), `crates/server/src/tcp.rs:411-414`
|
||||
- **CWE**: CWE-319 (明文传输敏感信息)
|
||||
- **OWASP**: A02:2021
|
||||
|
||||
**攻击场景**: TCP 端口 9999 以明文运行。`device_secret`、所有插件配置、Web 过滤规则、USB 策略、软件黑名单均以明文 JSON 传输。网络嗅探者可提取设备认证密钥并逆向工程安全策略。
|
||||
|
||||
**修复**: 生产环境强制 TLS,无 TLS 时拒绝启动(`CSM_DEV=1` 除外)。
|
||||
|
||||
---
|
||||
|
||||
## HIGH (12)
|
||||
|
||||
### AUD-005: Refresh Token 未存储 — 撤销机制不完整
|
||||
|
||||
- **文件**: `crates/server/src/api/auth.rs:131-133`
|
||||
- **CWE**: CWE-613
|
||||
- `refresh_tokens` 表存在但从未写入。登录不存储 token record,刷新仅检查 family 是否撤销,不验证 token 是否曾被实际颁发。无法强制注销所有 session。
|
||||
|
||||
### AUD-006: Refresh Token Family Rotation 存在 TOCTOU 竞争条件
|
||||
|
||||
- **文件**: `crates/server/src/api/auth.rs:167-183`
|
||||
- **CWE**: CWE-367
|
||||
- 检查 family 撤销状态与执行撤销之间无事务保护。并发使用同一 stolen token 的两个请求均可通过检查,攻击者获得全新的未撤销 token family。
|
||||
|
||||
### AUD-007: 客户端不验证服务器身份
|
||||
|
||||
- **文件**: `crates/client/src/network/mod.rs:149-162`
|
||||
- **CWE**: CWE-295
|
||||
- 客户端盲目连接任何响应的服务器。ARP/DNS 欺骗攻击者可推送恶意配置(禁用所有安全插件、注入有害规则)。HMAC 仅保护心跳,不保护配置推送。
|
||||
|
||||
### AUD-008: PowerShell 命令注入面
|
||||
|
||||
- **文件**: `crates/client/src/asset/mod.rs:82-83`, `crates/client/src/clipboard_control/mod.rs:143-159`
|
||||
- **CWE**: CWE-78
|
||||
- `powershell_lines()` 通过 `format!()` 拼接命令参数。若服务器推送含引号/转义字符的恶意规则,可能导致 PowerShell 命令注入。
|
||||
|
||||
### AUD-009: 服务停止/卸载未受保护
|
||||
|
||||
- **文件**: `crates/client/src/service.rs:17-69`
|
||||
- **CWE**: CWE-284
|
||||
- `csm-client.exe --uninstall` 无认证保护。终端管理员权限用户可完全移除安全代理,绕过所有监控。无服务恢复策略、无看门狗进程、无反调试保护。
|
||||
|
||||
### AUD-010: JWT Token 存储在 localStorage
|
||||
|
||||
- **文件**: `web/src/lib/api.ts:25-28, 175-176`
|
||||
- **CWE**: CWE-922
|
||||
- Access token 和 refresh token 均存储在 `localStorage`,可被同源任意 JS 访问。XSS 漏洞可直接窃取 7 天有效期的 refresh token。
|
||||
|
||||
### AUD-011: CSP 允许 unsafe-inline + unsafe-eval
|
||||
|
||||
- **文件**: `crates/server/src/main.rs:142-143`
|
||||
- **CWE**: CWE-693
|
||||
- `script-src 'self' 'unsafe-inline' 'unsafe-eval'` 使 CSP 对 XSS 几乎无效。结合 localStorage token 存储,单个 XSS 即可导致管理员会话完全沦陷。
|
||||
|
||||
### AUD-012: WebSocket JWT 在 URL 查询参数中
|
||||
|
||||
- **文件**: `crates/server/src/ws.rs:36-73`
|
||||
- **CWE**: CWE-312
|
||||
- JWT 通过 `/ws?token=eyJ...` 传输。Token 出现在浏览器历史、服务器访问日志、代理日志中。且 WebSocket handler 不检查用户角色,非管理员可接收所有广播事件。
|
||||
|
||||
### AUD-013: 告警规则 Webhook SSRF
|
||||
|
||||
- **文件**: `crates/server/src/api/alerts.rs:115-131`
|
||||
- **CWE**: CWE-918
|
||||
- `notify_webhook` 字段无 URL 验证。可设置为 `http://169.254.169.254/latest/meta-data/` (AWS 元数据) 或 `file:///etc/passwd`,将服务器变成 SSRF 代理。
|
||||
|
||||
### AUD-014: 仅基于用户名的速率限制可绕过
|
||||
|
||||
- **文件**: `crates/server/src/api/auth.rs:101`
|
||||
- **CWE**: CWE-307
|
||||
- 速率限制仅以用户名为 key。攻击者可用 `Admin`、`ADMIN` 等变体绕过。无 IP 限制,可分布式暴力破解。
|
||||
|
||||
### AUD-015: 磁盘加密确认在只读路由层 — 权限提升
|
||||
|
||||
- **文件**: `crates/server/src/api/plugins/mod.rs:42`
|
||||
- **CWE**: CWE-862
|
||||
- PUT `acknowledge_alert` 在 `read_routes()` 中(仅需认证,不需 admin)。任何认证用户可确认(忽略)加密告警,掩盖合规违规。
|
||||
|
||||
### AUD-016: 初始管理员密码输出到 stderr
|
||||
|
||||
- **文件**: `crates/server/src/main.rs:270-275`
|
||||
- **CWE**: CWE-532
|
||||
- 初始密码通过 `eprintln!` 输出。容器化部署中 stderr 被日志聚合系统捕获,有日志访问权限者可获取管理员密码。
|
||||
|
||||
---
|
||||
|
||||
## MEDIUM (12)
|
||||
|
||||
| # | 发现 | 文件 | CWE |
|
||||
|---|------|------|-----|
|
||||
| AUD-017 | 多个 Update handler 跳过输入验证 (软件黑名单/Web过滤器/剪贴板) | `software_blocker.rs:79`, `web_filter.rs:62`, `clipboard_control.rs:107` | CWE-20 |
|
||||
| AUD-018 | USB 策略 rules 字段接受任意 JSON 无验证 | `usb.rs:94-135` | CWE-20 |
|
||||
| AUD-019 | 密码无最大长度限制 (bcrypt 72 字节截断) | `auth.rs:303` | CWE-20 |
|
||||
| AUD-020 | 多个字段缺少长度验证 (弹出窗口/剪贴板/USB策略名) | 多处 | CWE-20 |
|
||||
| AUD-021 | 多个列表端点无分页 (黑名单/白名单/规则/策略) | 多处 | CWE-770 |
|
||||
| AUD-022 | 磁盘加密状态列表无分页可全库转储 | `disk_encryption.rs:12-55` | CWE-770 |
|
||||
| AUD-023 | JWT 角色仅信任 claim 不查库 (降级延迟) | `auth.rs:273` | CWE-863 |
|
||||
| AUD-024 | 缺少 HSTS 头 | `main.rs:123-143` | CWE-319 |
|
||||
| AUD-025 | CORS 配置需严格限制 | `main.rs:284-301` | CWE-942 |
|
||||
| AUD-026 | 日志中泄露设备 UID 和服务器地址 | `main.rs:62`, `network/mod.rs:34` | CWE-532 |
|
||||
| AUD-027 | 注册 Token 回退空字符串 | `main.rs:72` | CWE-254 |
|
||||
| AUD-028 | conflict.rs 中 format! SQL 模式 (当前安全但脆弱) | `conflict.rs:205` | CWE-89 |
|
||||
|
||||
---
|
||||
|
||||
## LOW (8)
|
||||
|
||||
| # | 发现 | 文件 |
|
||||
|---|------|------|
|
||||
| AUD-029 | 组名未过滤 HTML 特殊字符 | `groups.rs:72-96` |
|
||||
| AUD-030 | 弹出窗口阻止器更新可创建无过滤器的规则 | `popup_blocker.rs:67-104` |
|
||||
| AUD-031 | 设备删除非原子 (自毁帧在事务前发送) | `devices.rs:215-306` |
|
||||
| AUD-032 | 受保护进程列表硬编码且可修补绕过 | `software_blocker/mod.rs:9-73` |
|
||||
| AUD-033 | hosts 文件修改可能与 EDR 冲突 | `web_filter/mod.rs:59-93` |
|
||||
| AUD-034 | 软件拦截器 TOCTOU 竞争条件 (已缓解) | `software_blocker/mod.rs:329-386` |
|
||||
| AUD-035 | 前端路由守卫不验证 JWT 签名 | `router/index.ts:38-49` |
|
||||
| AUD-036 | WebSocket 不验证入站消息 (当前丢弃) | `ws.rs:108-119` |
|
||||
|
||||
---
|
||||
|
||||
## 修复优先级
|
||||
|
||||
### P0 — 立即 (24h)
|
||||
|
||||
| 修复项 | 对应发现 | 工作量 |
|
||||
|--------|---------|--------|
|
||||
| 轮换 JWT Secret,移至环境变量 | AUD-001 | 30min |
|
||||
| 设置非空 registration_token | AUD-002 | 15min |
|
||||
| 凭据文件添加 Windows ACL | AUD-003 | 1h |
|
||||
| 生产环境强制 TLS | AUD-004 | 2h |
|
||||
|
||||
### P1 — 短期 (1 周)
|
||||
|
||||
| 修复项 | 对应发现 | 工作量 |
|
||||
|--------|---------|--------|
|
||||
| Refresh token 存储到 DB + 事务保护 | AUD-005, 006 | 4h |
|
||||
| Update handler 添加输入验证 | AUD-017 | 4h |
|
||||
| Webhook URL 验证防 SSRF | AUD-013 | 1h |
|
||||
| 磁盘加密确认移至 admin 路由 | AUD-015 | 15min |
|
||||
| 初始密码写入文件替代 stderr | AUD-016 | 30min |
|
||||
| 添加 IP 速率限制 | AUD-014 | 2h |
|
||||
|
||||
### P2 — 中期 (1 月)
|
||||
|
||||
| 修复项 | 对应发现 | 工作量 |
|
||||
|--------|---------|--------|
|
||||
| Token 迁移至 HttpOnly Cookie | AUD-010 | 8h |
|
||||
| CSP 强化 (nonce-based) | AUD-011 | 4h |
|
||||
| WebSocket ticket 认证 | AUD-012 | 4h |
|
||||
| 服务器身份验证 (证书固定) | AUD-007 | 8h |
|
||||
| 服务保护 (恢复策略/看门狗) | AUD-009 | 4h |
|
||||
| PowerShell 注入面消除 | AUD-008 | 6h |
|
||||
| HSTS + Permissions-Policy 头 | AUD-024, 036 | 1h |
|
||||
| 分页补充 | AUD-021, 022 | 4h |
|
||||
|
||||
---
|
||||
|
||||
## 安全亮点 (做得好的地方)
|
||||
|
||||
1. **SQL 注入防御**: 全库一致使用 `sqlx::bind()` 参数化
|
||||
2. **错误处理**: `ApiResponse::internal_error()` 不泄露内部错误详情
|
||||
3. **密码哈希**: bcrypt cost=12,符合行业标准
|
||||
4. **Token Family 轮换**: 检测 token 重放并撤销整个 family
|
||||
5. **常量时间比较**: 注册 token 验证已使用 `constant_time_eq()`
|
||||
6. **帧速率限制**: 100 帧/5秒/连接
|
||||
7. **审计日志**: 所有管理员写入操作记录到 `admin_audit_log`
|
||||
8. **HMAC 心跳**: 设备认证使用 HMAC-SHA256
|
||||
9. **进程保护列表**: 防止误杀系统关键进程
|
||||
10. **输入验证**: Create handler 普遍有字段验证
|
||||
88
wiki/client.md
Normal file
88
wiki/client.md
Normal file
@@ -0,0 +1,88 @@
|
||||
# Client(客户端代理)
|
||||
|
||||
## 设计思想
|
||||
|
||||
`csm-client` 是部署在医院终端设备上的 Windows 代理程序,设计为:
|
||||
1. **无人值守运行** — 支持控制台模式(开发调试)和 Windows 服务模式(生产部署)
|
||||
2. **自动重连** — 指数退避策略(1s → 60s),断线后 drain stale frames
|
||||
3. **插件化采集** — 每个插件独立 task,通过 `watch` channel 接收配置,通过 `mpsc` channel 上报数据
|
||||
4. **单入口 data channel** — 所有插件共享一个 `mpsc::channel::<Frame>(1024)`,network 模块统一发送
|
||||
|
||||
关键设计决策:
|
||||
- **watch + mpsc 双通道** — `watch` 用于服务器推送配置到插件(多消费者最新值),`mpsc` 用于插件上报数据到网络层(多生产者有序队列)
|
||||
- **device_uid 持久化** — UUID 首次生成后写入 `device_uid.txt`,与可执行文件同目录
|
||||
- **device_secret 持久化** — 注册成功后写入 `device_secret.txt`,重启后自动认证
|
||||
|
||||
## 代码逻辑
|
||||
|
||||
### 启动流程
|
||||
|
||||
```
|
||||
main() → load device_uid → load device_secret → create ClientState
|
||||
→ create data channel (mpsc 1024)
|
||||
→ create watch channels for each plugin config
|
||||
→ spawn core tasks (monitor, asset, usb)
|
||||
→ spawn plugin tasks (11 plugins)
|
||||
→ reconnect loop: connect_and_run() with exponential backoff
|
||||
```
|
||||
|
||||
### 网络层 (`network/mod.rs`)
|
||||
|
||||
- `connect_and_run()` — TCP 连接、注册/认证、双工读写循环
|
||||
- `handle_server_message()` — 根据 MessageType 分发服务器下发的帧到对应 watch channel
|
||||
- `PluginChannels` — 持有所有插件的 `watch::Sender`,用于接收服务器推送的配置
|
||||
- 注册流程:发送 Register → 收到 RegisterResponse(含 device_secret)→ 持久化 secret
|
||||
- 认证流程:已有 device_secret 时,心跳帧携带 HMAC-SHA256 签名
|
||||
|
||||
### 插件统一模板
|
||||
|
||||
每个插件遵循相同模式:
|
||||
```rust
|
||||
pub async fn start(
|
||||
mut config_rx: watch::Receiver<PluginConfig>,
|
||||
data_tx: mpsc::Sender<Frame>,
|
||||
device_uid: String,
|
||||
) {
|
||||
loop {
|
||||
tokio::select! {
|
||||
result = config_rx.changed() => { /* 更新 config */ }
|
||||
_ = interval.tick() => {
|
||||
if !config.enabled { continue; }
|
||||
// 采集数据 → Frame::new_json() → data_tx.send()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 双模式运行
|
||||
|
||||
- **控制台模式**: 直接 `cargo run -p csm-client`,Ctrl+C 优雅退出
|
||||
- **服务模式**: `--install` 注册 Windows 服务、`--service` 以服务方式运行、`--uninstall` 卸载
|
||||
|
||||
## 关联模块
|
||||
|
||||
- [[protocol]] — 使用 Frame 构造上报帧,解析服务器下发帧
|
||||
- [[server]] — TCP 连接的对端,接收帧并处理
|
||||
- [[plugins]] — 每个插件的具体实现逻辑
|
||||
|
||||
## 关键文件
|
||||
|
||||
| 文件 | 职责 |
|
||||
|------|------|
|
||||
| `crates/client/src/main.rs` | 启动入口、插件 channel 创建、task spawn、重连循环 |
|
||||
| `crates/client/src/network/mod.rs` | TCP 连接、注册认证、双工读写、服务器消息分发 |
|
||||
| `crates/client/src/service.rs` | Windows 服务安装/卸载/运行(`#[cfg(target_os = "windows")]`) |
|
||||
| `crates/client/src/monitor/mod.rs` | 核心设备状态采集(CPU/内存/进程) |
|
||||
| `crates/client/src/asset/mod.rs` | 硬件/软件资产采集 |
|
||||
| `crates/client/src/usb/mod.rs` | USB 设备插拔监控 |
|
||||
| `crates/client/src/web_filter/mod.rs` | 上网拦截插件 |
|
||||
| `crates/client/src/usage_timer/mod.rs` | 使用时长记录插件 |
|
||||
| `crates/client/src/software_blocker/mod.rs` | 软件禁止安装插件 |
|
||||
| `crates/client/src/popup_blocker/mod.rs` | 弹窗拦截插件 |
|
||||
| `crates/client/src/usb_audit/mod.rs` | U盘文件操作审计插件 |
|
||||
| `crates/client/src/watermark/mod.rs` | 屏幕水印插件 |
|
||||
| `crates/client/src/disk_encryption/mod.rs` | 磁盘加密检测插件 |
|
||||
| `crates/client/src/print_audit/mod.rs` | 打印审计插件 |
|
||||
| `crates/client/src/clipboard_control/mod.rs` | 剪贴板管控插件 |
|
||||
| `crates/client/src/patch/mod.rs` | 补丁管理插件 |
|
||||
71
wiki/database.md
Normal file
71
wiki/database.md
Normal file
@@ -0,0 +1,71 @@
|
||||
# Database(数据库层)
|
||||
|
||||
## 设计思想
|
||||
|
||||
SQLite 单文件数据库,WAL 模式支持并发读写。设计原则:
|
||||
1. **只追加迁移** — 永不修改已有 migration 文件
|
||||
2. **参数绑定** — 所有 SQL 使用 `.bind()`,绝不拼接
|
||||
3. **upsert 模式** — `ON CONFLICT ... DO UPDATE` 处理重复上报,必须更新 `updated_at`
|
||||
4. **嵌入式迁移** — SQL 文件通过 `include_str!` 编译进二进制,运行时按序执行
|
||||
5. **外键启用** — `foreign_keys(true)` 强制引用完整性
|
||||
|
||||
## 代码逻辑
|
||||
|
||||
### 初始化
|
||||
|
||||
```
|
||||
main() → init_database() → SQLite WAL + Normal sync + 5s busy timeout + FK on
|
||||
→ run_migrations() → CREATE _migrations 表 → 按序执行 001-018
|
||||
→ ensure_default_admin() → 首次启动生成随机 admin 密码
|
||||
```
|
||||
|
||||
### 连接池配置
|
||||
|
||||
- 最大 8 连接
|
||||
- cache_size = -64000 (64MB)
|
||||
- wal_autocheckpoint = 1000
|
||||
|
||||
### 迁移历史
|
||||
|
||||
| # | 文件 | 内容 |
|
||||
|---|------|------|
|
||||
| 001 | init.sql | users, devices 表 |
|
||||
| 002 | assets.sql | hardware_assets, software_assets, asset_changes 表 |
|
||||
| 003 | usb.sql | usb_events, usb_policies, usb_device_patterns 表 |
|
||||
| 004 | alerts.sql | alert_rules, alert_records 表 |
|
||||
| 005 | web_filter.sql | web_filter_rules, web_access_logs 表 |
|
||||
| 006 | usage_timer.sql | usage_daily, app_usage 表 |
|
||||
| 007 | software_blocker.sql | software_blacklist, software_violations 表 |
|
||||
| 008 | popup_blocker.sql | popup_blocker_rules, popup_block_stats 表 |
|
||||
| 009 | usb_file_audit.sql | usb_file_operations 表 |
|
||||
| 010 | watermark.sql | watermark_configs 表 |
|
||||
| 011 | token_security.sql | token_families 表(JWT token family 轮换) |
|
||||
| 012 | disk_encryption.sql | disk_encryption_status, disk_encryption_alerts 表 |
|
||||
| 013 | print_audit.sql | print_events 表 |
|
||||
| 014 | clipboard_control.sql | clipboard_rules, clipboard_violations 表 |
|
||||
| 015 | plugin_control.sql | plugin_states 表 |
|
||||
| 016 | encryption_alerts_unique.sql | 唯一约束修复 |
|
||||
| 017 | device_health_scores.sql | device_health_scores 表 |
|
||||
| 018 | patch_management.sql | patch_status 表 |
|
||||
|
||||
### 数据操作层 (`db.rs`)
|
||||
|
||||
`DeviceRepo` 提供:
|
||||
- 设备注册/查询/删除/分组
|
||||
- 资产增删改查
|
||||
- USB 事件记录和策略管理
|
||||
- 告警规则和记录操作
|
||||
- 所有插件数据的 CRUD
|
||||
|
||||
## 关联模块
|
||||
|
||||
- [[server]] — 通过 db.rs 访问数据库
|
||||
- [[plugins]] — 每个插件有对应的数据库表
|
||||
|
||||
## 关键文件
|
||||
|
||||
| 文件 | 职责 |
|
||||
|------|------|
|
||||
| `crates/server/src/db.rs` | DeviceRepo 数据库操作方法集合 |
|
||||
| `crates/server/src/main.rs` | 数据库初始化、迁移执行 |
|
||||
| `migrations/001_init.sql` ~ `018_*.sql` | 数据库迁移脚本 |
|
||||
46
wiki/index.md
Normal file
46
wiki/index.md
Normal file
@@ -0,0 +1,46 @@
|
||||
# CSM 知识库
|
||||
|
||||
## 项目画像
|
||||
|
||||
CSM (Client Security Manager) — 医院终端安全管控平台,C/S + Web 三层架构。管理 11 个安全插件,覆盖上网拦截、U盘管控、打印审计、剪贴板管控、补丁管理等场景。
|
||||
|
||||
**关键数字**: 3 个 Rust crate + Vue 前端 | 18 个数据库迁移 | 13 个客户端插件 | ~30 个 API 端点 | 自定义 TCP 二进制协议
|
||||
|
||||
## 模块导航树
|
||||
|
||||
```
|
||||
CSM
|
||||
├── [[protocol]] — 二进制协议层(Frame 编解码、MessageType、payload 定义)
|
||||
├── [[server]] — 服务端(HTTP API + TCP 接入 + WebSocket + SQLite)
|
||||
├── [[client]] — 客户端代理(Windows 服务、插件采集、自动重连)
|
||||
├── [[web-frontend]] — Web 管理面板(Vue 3 SPA)
|
||||
├── [[plugins]] — 插件体系(端到端设计、新增插件清单)
|
||||
└── [[database]] — 数据库层(SQLite、迁移、操作方法)
|
||||
```
|
||||
|
||||
## 核心架构决策
|
||||
|
||||
### 为什么用自定义 TCP 二进制协议而不是 HTTP?
|
||||
内网环境低延迟需求,二进制帧比 HTTP 更省带宽和延迟。帧头仅 10 字节(MAGIC+VERSION+TYPE+LENGTH),payload 用 JSON 保持可调试性。
|
||||
|
||||
### 为什么插件配置用 watch channel 而不是 HTTP 轮询?
|
||||
Server 主动推送配置变更到 Client,避免轮询延迟。`tokio::watch` 保证每个插件总是读到最新配置值,配置下发 → 全链路秒级生效。
|
||||
|
||||
### 为什么嵌入前端而不是独立部署?
|
||||
`include_dir!` 编译时打包 `web/dist/`,部署只需一个 server 二进制文件。SPA fallback 让前端路由(如 `/devices`)直接返回 `index.html`。
|
||||
|
||||
### 为什么 SQLite 而不是 PostgreSQL?
|
||||
医院内网单机部署场景,零外部依赖。WAL 模式 + 64MB 缓存足以支撑数百台终端的并发写入。
|
||||
|
||||
### 为什么三级作用域推送(global/group/device)?
|
||||
医院按科室分组管理设备。全局策略作为基线,科室策略覆盖特定需求,单设备策略处理例外情况。`push_to_targets()` 自动解析作用域并过滤在线设备。
|
||||
|
||||
## 技术栈速查
|
||||
|
||||
| 层 | 技术 |
|
||||
|----|------|
|
||||
| 服务端 | Rust + Axum + SQLx + SQLite + JWT + Rustls |
|
||||
| 客户端 | Rust + Tokio + sysinfo + windows-rs |
|
||||
| 协议 | 自定义 TCP 二进制(MAGIC + VERSION + TYPE + LENGTH + JSON payload) |
|
||||
| 前端 | Vue 3 + TypeScript + Vite + Element Plus + Pinia + ECharts |
|
||||
| 构建 | Cargo workspace + npm |
|
||||
78
wiki/plugins.md
Normal file
78
wiki/plugins.md
Normal file
@@ -0,0 +1,78 @@
|
||||
# Plugin System(插件体系)
|
||||
|
||||
## 设计思想
|
||||
|
||||
CSM 的核心扩展机制,采用**端到端插件化**设计:
|
||||
- Client 端每个插件独立 tokio task,负责数据采集/策略执行
|
||||
- Server 端每个插件有独立的 API handler 模块和数据库表
|
||||
- Protocol 层每个插件有专属 MessageType 范围和 payload struct
|
||||
- Frontend 端每个插件有独立页面组件
|
||||
|
||||
三级配置推送:`global` → `group` → `device`,优先级递增。
|
||||
|
||||
## 代码逻辑
|
||||
|
||||
### 插件全链路(以 Web Filter 为例)
|
||||
|
||||
```
|
||||
1. API: POST /api/plugins/web-filter/rules → server/api/plugins/web_filter.rs
|
||||
2. Server 存储 → db.rs → INSERT INTO web_filter_rules
|
||||
3. 推送 → push_to_targets(db, clients, WebFilterRuleUpdate, payload, scope, target_id)
|
||||
4. TCP → Client network/mod.rs handle_server_message() → web_filter_tx.send()
|
||||
5. Client → web_filter/mod.rs config_rx.changed() → 更新本地规则 → 采集上报
|
||||
6. Client → Frame::new_json(WebAccessLog, entry) → data_tx → network → TCP → server
|
||||
7. Server → tcp.rs process_frame(WebAccessLog) → db.rs → INSERT INTO web_access_logs
|
||||
8. Frontend → GET /api/plugins/web-filter/log → 展示
|
||||
```
|
||||
|
||||
### 现有插件一览
|
||||
|
||||
| 插件 | 消息类型范围 | 方向 | 功能 |
|
||||
|------|-------------|------|------|
|
||||
| Web Filter | 0x2x | S→C 规则, C→S 日志 | URL 黑白名单、访问日志 |
|
||||
| Usage Timer | 0x3x | C→S 报告 | 每日使用时长、应用使用统计 |
|
||||
| Software Blocker | 0x4x | S→C 黑名单, C→S 违规 | 禁止安装软件、违规上报 |
|
||||
| Popup Blocker | 0x5x | S→C 规则, C→S 统计 | 弹窗拦截规则、拦截统计 |
|
||||
| USB File Audit | 0x6x | C→S 记录 | U盘文件操作审计 |
|
||||
| Watermark | 0x70 | S→C 配置 | 屏幕水印显示配置 |
|
||||
| USB Policy | 0x71 | S→C 策略 | U盘管控(全阻/白名单/黑名单) |
|
||||
| Plugin Control | 0x80-0x81 | S→C 命令 | 远程启停插件 |
|
||||
| Disk Encryption | 0x90, 0x93 | C→S 状态, S→C 配置 | 磁盘加密状态检测 |
|
||||
| Print Audit | 0x91 | C→S 事件 | 打印操作审计 |
|
||||
| Clipboard Control | 0x94-0x95 | S→C 规则, C→S 违规 | 剪贴板操作管控(仅上报元数据) |
|
||||
| Patch Management | 0xA0-0xA2 | 双向 | 系统补丁扫描与安装 |
|
||||
| Behavior Metrics | 0xB0 | C→S 指标 | 行为指标采集(异常检测输入) |
|
||||
|
||||
### 新增插件必改文件清单
|
||||
|
||||
| # | 文件 | 改动 |
|
||||
|---|------|------|
|
||||
| 1 | `crates/protocol/src/message.rs` | 添加 MessageType 枚举值 + payload struct |
|
||||
| 2 | `crates/protocol/src/lib.rs` | re-export 新类型 |
|
||||
| 3 | `crates/client/src/<plugin>/mod.rs` | 创建插件实现 |
|
||||
| 4 | `crates/client/src/main.rs` | `mod <plugin>`, watch channel, PluginChannels 字段, spawn |
|
||||
| 5 | `crates/client/src/network/mod.rs` | PluginChannels 字段, handle_server_message 分支 |
|
||||
| 6 | `crates/server/src/api/plugins/<plugin>.rs` | 创建 API handler |
|
||||
| 7 | `crates/server/src/api/plugins/mod.rs` | mod 声明 + 路由注册 |
|
||||
| 8 | `crates/server/src/tcp.rs` | process_frame 新分支 + push_all_plugin_configs |
|
||||
| 9 | `crates/server/src/db.rs` | 新增 DB 操作方法 |
|
||||
| 10 | `migrations/NNN_<name>.sql` | 新迁移文件 |
|
||||
| 11 | `crates/server/src/main.rs` | include_str! 新迁移 |
|
||||
|
||||
## 关联模块
|
||||
|
||||
- [[protocol]] — 定义插件的 MessageType 和 payload
|
||||
- [[client]] — 插件采集端
|
||||
- [[server]] — 插件 API 和数据处理
|
||||
- [[web-frontend]] — 插件管理页面
|
||||
- [[database]] — 每个插件的数据库表
|
||||
|
||||
## 关键文件
|
||||
|
||||
| 文件 | 职责 |
|
||||
|------|------|
|
||||
| `crates/client/src/<plugin>/mod.rs` | 客户端插件实现(每个插件一个目录) |
|
||||
| `crates/server/src/api/plugins/<plugin>.rs` | 服务端插件 API(每个插件一个文件) |
|
||||
| `crates/server/src/tcp.rs` | 帧分发 + push_to_targets + push_all_plugin_configs |
|
||||
| `crates/client/src/main.rs` | 插件 watch channel 创建 + task spawn |
|
||||
| `crates/client/src/network/mod.rs` | PluginChannels 定义 + 服务器消息分发 |
|
||||
58
wiki/protocol.md
Normal file
58
wiki/protocol.md
Normal file
@@ -0,0 +1,58 @@
|
||||
# Protocol(二进制协议层)
|
||||
|
||||
## 设计思想
|
||||
|
||||
`csm-protocol` 是 Server 和 Client 共享的协议定义 crate。核心设计决策:
|
||||
|
||||
1. **零拷贝编解码** — `Frame::encode()` / `Frame::decode()` 直接操作字节切片,无中间分配
|
||||
2. **类型安全** — `MessageType` 枚举确保所有消息类型在编译期可见,`TryFrom<u8>` 处理未知类型
|
||||
3. **JSON payload** — 网络传输用 JSON(`serde`),兼顾可调试性和跨语言兼容性
|
||||
4. **payload 上限 4MB** — `MAX_PAYLOAD_SIZE` 防止恶意帧耗尽内存
|
||||
|
||||
二进制帧格式:`MAGIC(4B "CSM\0") + VERSION(1B) + TYPE(1B) + LENGTH(4B big-endian) + PAYLOAD(变长 JSON)`
|
||||
|
||||
## 代码逻辑
|
||||
|
||||
### 帧生命周期
|
||||
|
||||
```
|
||||
发送方: T → Frame::new_json(mt, &data) → Frame::encode() → Vec<u8> → TCP stream
|
||||
接收方: TCP bytes → Frame::decode(&buf) → Option<Frame> → Frame::decode_payload::<T>()
|
||||
```
|
||||
|
||||
### MessageType 分块规划
|
||||
|
||||
| 范围 | 插件 | 方向 |
|
||||
|------|------|------|
|
||||
| 0x01-0x0F | Core(心跳/注册/状态/资产) | 双向 |
|
||||
| 0x10-0x1F | Core Server→Client(策略/配置/任务) | S→C |
|
||||
| 0x20-0x2F | Web Filter | C→S 日志, S→C 规则 |
|
||||
| 0x30-0x3F | Usage Timer | C→S 报告 |
|
||||
| 0x40-0x4F | Software Blocker | C→S 违规, S→C 黑名单 |
|
||||
| 0x50-0x5F | Popup Blocker | C→S 统计, S→C 规则 |
|
||||
| 0x60-0x6F | USB File Audit | C→S 操作记录 |
|
||||
| 0x70-0x7F | Watermark + USB Policy | S→C 配置 |
|
||||
| 0x80-0x8F | Plugin Control | S→C 启停命令 |
|
||||
| 0x90-0x9F | Disk Encryption / Print / Clipboard | 混合 |
|
||||
| 0xA0-0xAF | Patch Management | C→S 状态, S→C 配置 |
|
||||
| 0xB0-0xBF | Behavior Metrics | C→S 指标 |
|
||||
|
||||
### 关键类型
|
||||
|
||||
- `Frame` — 帧结构(version + msg_type + payload bytes)
|
||||
- `FrameError` — 解码错误枚举(InvalidMagic / UnknownMessageType / PayloadTooLarge / Io)
|
||||
- 每个 MessageType 对应一个 payload struct(如 `WebAccessLogEntry`, `HeartbeatPayload`)
|
||||
|
||||
## 关联模块
|
||||
|
||||
- [[server]] — TCP 接入层调用 `Frame::decode()` 解析客户端帧,调用 `push_to_targets()` 推送配置帧
|
||||
- [[client]] — 通过 `Frame::new_json()` 构造上报帧,通过 `Frame::decode()` 解析服务器下发的帧
|
||||
- [[plugins]] — 每个插件定义自己的 payload struct 在此 crate 中
|
||||
|
||||
## 关键文件
|
||||
|
||||
| 文件 | 职责 |
|
||||
|------|------|
|
||||
| `crates/protocol/src/message.rs` | MessageType 枚举、Frame 编解码、所有 payload struct |
|
||||
| `crates/protocol/src/device.rs` | DeviceStatus、ProcessInfo、HardwareAsset、UsbEvent 等设备相关类型 |
|
||||
| `crates/protocol/src/lib.rs` | Re-export 所有公开类型 |
|
||||
84
wiki/server.md
Normal file
84
wiki/server.md
Normal file
@@ -0,0 +1,84 @@
|
||||
# Server(服务端)
|
||||
|
||||
## 设计思想
|
||||
|
||||
`csm-server` 是整个系统的核心枢纽,同时承载三个协议:
|
||||
1. **TCP 二进制协议** (端口 9999) — 接入 Client 代理
|
||||
2. **HTTP REST API** (端口 9998) — 服务 Web 面板
|
||||
3. **WebSocket** (`/ws`) — 实时推送设备状态变更到前端
|
||||
|
||||
关键设计决策:
|
||||
- **SQLite + WAL** — 单机部署零依赖,WAL 模式支持并发读写
|
||||
- **include_dir 嵌入前端** — 编译时将 `web/dist/` 打包进二进制,部署只需一个文件
|
||||
- **三层权限** — public(登录/健康检查)→ authenticated(只读)→ admin(写操作)
|
||||
- **ClientRegistry** — `Arc<RwLock<HashMap>>` 管理在线客户端的 TCP 写端,支持 `push_to_targets()` 三级作用域推送
|
||||
|
||||
## 代码逻辑
|
||||
|
||||
### 启动流程
|
||||
|
||||
```
|
||||
main() → load config → init SQLite → run migrations → ensure admin
|
||||
→ spawn TCP listener (9999)
|
||||
→ spawn alert cleanup task
|
||||
→ spawn health score task
|
||||
→ build HTTP router (9998) with CORS/security headers/SPA fallback
|
||||
→ axum::serve()
|
||||
```
|
||||
|
||||
### TCP 接入层 (`tcp.rs`)
|
||||
|
||||
- `start_tcp_server()` — 监听 TCP,每连接 spawn 一个 task
|
||||
- `process_frame()` — 根据 MessageType 分发到对应 handler(需先 verify_device_uid)
|
||||
- `ClientRegistry` — 线程安全的在线设备注册表,支持 `list_online()`、`send_frame()`
|
||||
- `push_to_targets(db, clients, msg_type, payload, target_type, target_id)` — 三级作用域推送(global/group/device)
|
||||
- 帧速率限制:100 帧/5秒/连接
|
||||
- HMAC 验证:心跳帧必须携带 HMAC-SHA256 签名,连续 3 次失败断开
|
||||
- 空闲超时:180 秒无数据断开
|
||||
- 最大并发连接:500
|
||||
|
||||
### HTTP API (`api/`)
|
||||
|
||||
路由分三层:
|
||||
- **public**: `/api/auth/login`, `/api/auth/refresh`, `/health`
|
||||
- **authenticated** (require_auth 中间件): GET 类设备/资产/告警/插件查询
|
||||
- **admin** (require_admin + require_auth): 设备删除、策略增删改、插件配置写入
|
||||
|
||||
统一响应格式 `ApiResponse<T>`:`{ success, data, error }`,分页默认 page=1, page_size=20, 上限 100。
|
||||
|
||||
### WebSocket (`ws.rs`)
|
||||
|
||||
- `WsHub` 广播设备上线/离线/状态变更事件给所有连接的前端客户端
|
||||
- JWT 认证通过 query parameter `?token=xxx`
|
||||
|
||||
### 后台任务
|
||||
|
||||
- `alert::cleanup_task()` — 定期清理过期告警
|
||||
- `health::health_score_task()` — 定期计算设备健康评分
|
||||
|
||||
## 关联模块
|
||||
|
||||
- [[protocol]] — 使用 Frame 编解码和 MessageType 分发
|
||||
- [[client]] — TCP 连接的对端
|
||||
- [[web-frontend]] — HTTP API 和 WebSocket 的消费者
|
||||
- [[plugins]] — API 层的 plugins/ 子模块处理所有插件相关路由
|
||||
- [[database]] — 数据库操作集中在 db.rs
|
||||
|
||||
## 关键文件
|
||||
|
||||
| 文件 | 职责 |
|
||||
|------|------|
|
||||
| `crates/server/src/main.rs` | 启动入口、数据库初始化、迁移、路由组装、SPA fallback |
|
||||
| `crates/server/src/tcp.rs` | TCP 监听、帧处理、ClientRegistry、push_to_targets |
|
||||
| `crates/server/src/ws.rs` | WebSocket hub 广播 |
|
||||
| `crates/server/src/api/mod.rs` | 路由定义、ApiResponse 信封、Pagination |
|
||||
| `crates/server/src/api/auth.rs` | JWT 登录/刷新/改密、限流、require_auth/require_admin 中间件 |
|
||||
| `crates/server/src/api/devices.rs` | 设备列表/详情/状态/历史/健康评分 API |
|
||||
| `crates/server/src/api/plugins/mod.rs` | 插件路由注册(read_routes + write_routes) |
|
||||
| `crates/server/src/api/plugins/*.rs` | 各插件 API handler(每个插件一个文件) |
|
||||
| `crates/server/src/db.rs` | DeviceRepo 数据库操作方法集合 |
|
||||
| `crates/server/src/config.rs` | AppConfig TOML 配置加载 |
|
||||
| `crates/server/src/health.rs` | 设备健康评分计算 |
|
||||
| `crates/server/src/anomaly.rs` | 异常检测逻辑 |
|
||||
| `crates/server/src/alert.rs` | 告警处理与清理 |
|
||||
| `crates/server/src/audit.rs` | 审计日志 |
|
||||
70
wiki/web-frontend.md
Normal file
70
wiki/web-frontend.md
Normal file
@@ -0,0 +1,70 @@
|
||||
# Web Frontend(管理面板)
|
||||
|
||||
## 设计思想
|
||||
|
||||
Vue 3 + TypeScript + Vite + Element Plus + Pinia + ECharts 的单页应用。关键决策:
|
||||
|
||||
1. **SPA 嵌入部署** — 构建产物 `web/dist/` 通过 `include_dir!` 编译进 server 二进制,部署零额外依赖
|
||||
2. **JWT 本地存储** — token 存 `localStorage`,路由守卫检查过期,30 秒内即将过期视为无效
|
||||
3. **按路由懒加载** — 所有页面组件使用 `() => import(...)` 动态导入
|
||||
|
||||
## 代码逻辑
|
||||
|
||||
### 路由结构
|
||||
|
||||
```
|
||||
/login → Login.vue(公开)
|
||||
/ → Layout.vue(认证后)
|
||||
/dashboard → Dashboard.vue(仪表盘/健康概览)
|
||||
/devices → Devices.vue(设备列表)
|
||||
/devices/:uid → DeviceDetail.vue(设备详情)
|
||||
/usb → UsbPolicy.vue(U盘策略管理)
|
||||
/alerts → Alerts.vue(告警管理)
|
||||
/settings → Settings.vue(系统设置)
|
||||
/plugins/web-filter → WebFilter.vue
|
||||
/plugins/usage-timer → UsageTimer.vue
|
||||
/plugins/software-blocker → SoftwareBlocker.vue
|
||||
/plugins/popup-blocker → PopupBlocker.vue
|
||||
/plugins/usb-file-audit → UsbFileAudit.vue
|
||||
/plugins/watermark → Watermark.vue
|
||||
/plugins/disk-encryption → DiskEncryption.vue
|
||||
/plugins/print-audit → PrintAudit.vue
|
||||
/plugins/clipboard-control → ClipboardControl.vue
|
||||
/plugins/plugin-control → PluginControl.vue
|
||||
```
|
||||
|
||||
### 认证流程
|
||||
|
||||
1. Login.vue → `POST /api/auth/login` → 获取 access_token + refresh_token
|
||||
2. token 存入 localStorage
|
||||
3. 路由守卫 `beforeEach` 检查 JWT 过期(解析 payload.exp)
|
||||
4. API 调用携带 `Authorization: Bearer <token>` header
|
||||
5. token 过期 → 自动跳转 /login
|
||||
|
||||
### API 通信
|
||||
|
||||
`web/src/lib/api.ts` — 封装所有 API 调用,统一处理认证和错误。
|
||||
|
||||
### 状态管理
|
||||
|
||||
`web/src/stores/devices.ts` — Pinia store 管理设备列表状态。
|
||||
|
||||
## 关联模块
|
||||
|
||||
- [[server]] — 消费其 HTTP REST API 和 WebSocket 推送
|
||||
- [[plugins]] — 每个插件页面对应 server 端的插件 API
|
||||
|
||||
## 关键文件
|
||||
|
||||
| 文件 | 职责 |
|
||||
|------|------|
|
||||
| `web/src/main.ts` | 应用入口、Vue 实例创建 |
|
||||
| `web/src/App.vue` | 根组件 |
|
||||
| `web/src/router/index.ts` | 路由定义、JWT 路由守卫 |
|
||||
| `web/src/lib/api.ts` | API 通信封装 |
|
||||
| `web/src/stores/devices.ts` | Pinia 设备状态管理 |
|
||||
| `web/src/views/Layout.vue` | 主布局(侧边栏+内容区) |
|
||||
| `web/src/views/Dashboard.vue` | 仪表盘页 |
|
||||
| `web/src/views/Devices.vue` | 设备列表页 |
|
||||
| `web/src/views/DeviceDetail.vue` | 设备详情页 |
|
||||
| `web/src/views/plugins/*.vue` | 各插件管理页面 |
|
||||
Reference in New Issue
Block a user