feat: 新增补丁管理和异常检测插件及相关功能

feat(protocol): 添加补丁管理和行为指标协议类型
feat(client): 实现补丁管理插件采集功能
feat(server): 添加补丁管理和异常检测API
feat(database): 新增补丁状态和异常检测相关表
feat(web): 添加补丁管理和异常检测前端页面
fix(security): 增强输入验证和防注入保护
refactor(auth): 重构认证检查逻辑
perf(service): 优化Windows服务恢复策略
style: 统一健康评分显示样式
docs: 更新知识库文档
This commit is contained in:
iven
2026-04-11 15:59:53 +08:00
parent b5333d8c93
commit 60ee38a3c2
49 changed files with 3988 additions and 461 deletions

View File

@@ -121,6 +121,14 @@ fn collect_system_details() -> (Option<String>, Option<String>, Option<String>)
#[cfg(target_os = "windows")]
fn powershell_lines(command: &str) -> Vec<String> {
use std::process::Command;
// Reject commands containing suspicious patterns that could indicate injection
let lower = command.to_lowercase();
if lower.contains("invoke-expression") || lower.contains("iex ") || lower.contains("& ") {
tracing::warn!("Rejected suspicious PowerShell command pattern");
return Vec::new();
}
let output = match Command::new("powershell")
.args(["-NoProfile", "-NonInteractive", "-Command",
&format!("[Console]::OutputEncoding = [System.Text.Encoding]::UTF8; {}", command)])

View File

@@ -1,7 +1,7 @@
use anyhow::Result;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::Duration;
use tracing::{info, error, warn};
use tracing::{info, error, warn, debug};
use csm_protocol::{Frame, ClientConfig, UsbPolicyPayload};
mod monitor;
@@ -17,6 +17,7 @@ mod web_filter;
mod disk_encryption;
mod clipboard_control;
mod print_audit;
mod patch;
#[cfg(target_os = "windows")]
mod service;
@@ -58,17 +59,22 @@ fn main() -> Result<()> {
info!("CSM Client starting (console mode)...");
let device_uid = load_or_create_device_uid()?;
info!("Device UID: {}", device_uid);
debug!("Device UID: {}", device_uid);
let server_addr = std::env::var("CSM_SERVER")
.unwrap_or_else(|_| "127.0.0.1:9999".to_string());
let registration_token = std::env::var("CSM_REGISTRATION_TOKEN").unwrap_or_default();
if registration_token.is_empty() {
tracing::warn!("CSM_REGISTRATION_TOKEN not set — device registration may fail");
}
let state = ClientState {
device_uid,
server_addr,
config: ClientConfig::default(),
device_secret: load_device_secret(),
registration_token: std::env::var("CSM_REGISTRATION_TOKEN").unwrap_or_default(),
registration_token,
use_tls: std::env::var("CSM_USE_TLS").as_deref() == Ok("true"),
};
@@ -97,6 +103,7 @@ pub async fn run(state: ClientState) -> Result<()> {
let (disk_encryption_tx, disk_encryption_rx) = tokio::sync::watch::channel(disk_encryption::DiskEncryptionConfig::default());
let (print_audit_tx, print_audit_rx) = tokio::sync::watch::channel(print_audit::PrintAuditConfig::default());
let (clipboard_control_tx, clipboard_control_rx) = tokio::sync::watch::channel(clipboard_control::ClipboardControlConfig::default());
let (patch_tx, patch_rx) = tokio::sync::watch::channel(patch::PluginConfig::default());
// Build plugin channels struct
let plugins = network::PluginChannels {
@@ -110,6 +117,7 @@ pub async fn run(state: ClientState) -> Result<()> {
disk_encryption_tx,
print_audit_tx,
clipboard_control_tx,
patch_tx,
};
// Spawn core monitoring tasks
@@ -182,6 +190,12 @@ pub async fn run(state: ClientState) -> Result<()> {
clipboard_control::start(clipboard_control_rx, cc_data_tx, cc_uid).await;
});
let patch_data_tx = data_tx.clone();
let patch_uid = state.device_uid.clone();
tokio::spawn(async move {
patch::start(patch_rx, patch_data_tx, patch_uid).await;
});
// Connect to server with reconnect
let mut backoff = Duration::from_secs(1);
let max_backoff = Duration::from_secs(60);
@@ -270,5 +284,16 @@ fn write_restricted_file(path: &std::path::Path, content: &str) -> std::io::Resu
#[cfg(not(unix))]
fn write_restricted_file(path: &std::path::Path, content: &str) -> std::io::Result<()> {
std::fs::write(path, content)
std::fs::write(path, content)?;
// Restrict file ACL to SYSTEM and Administrators only on Windows
let path_str = path.to_string_lossy();
// Remove inherited permissions and grant only SYSTEM full control
let _ = std::process::Command::new("icacls")
.args([&*path_str, "/inheritance:r", "/grant:r", "SYSTEM:(F)", "/grant:r", "Administrators:(F)"])
.output();
// Also hide the file
let _ = std::process::Command::new("attrib")
.args(["+H", &*path_str])
.output();
Ok(())
}

View File

@@ -1,11 +1,13 @@
use anyhow::Result;
use std::path::PathBuf;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use tokio::net::TcpStream;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tracing::{info, debug, warn};
use csm_protocol::{Frame, MessageType, RegisterRequest, RegisterResponse, HeartbeatPayload, WatermarkConfigPayload, UsbPolicyPayload, DiskEncryptionConfigPayload};
use csm_protocol::{Frame, MessageType, RegisterRequest, RegisterResponse, HeartbeatPayload, WatermarkConfigPayload, UsbPolicyPayload, DiskEncryptionConfigPayload, PatchScanConfigPayload};
use hmac::{Hmac, Mac};
use sha2::Sha256;
use sha2::{Sha256, Digest};
use crate::ClientState;
@@ -21,6 +23,7 @@ pub struct PluginChannels {
pub disk_encryption_tx: tokio::sync::watch::Sender<crate::disk_encryption::DiskEncryptionConfig>,
pub print_audit_tx: tokio::sync::watch::Sender<crate::print_audit::PrintAuditConfig>,
pub clipboard_control_tx: tokio::sync::watch::Sender<crate::clipboard_control::ClipboardControlConfig>,
pub patch_tx: tokio::sync::watch::Sender<crate::patch::PluginConfig>,
}
/// Connect to server and run the main communication loop
@@ -30,7 +33,7 @@ pub async fn connect_and_run(
plugins: &PluginChannels,
) -> Result<()> {
let tcp_stream = TcpStream::connect(&state.server_addr).await?;
info!("TCP connected to {}", state.server_addr);
debug!("TCP connected to {}", state.server_addr);
if state.use_tls {
let tls_stream = wrap_tls(tcp_stream, &state.server_addr).await?;
@@ -40,9 +43,7 @@ pub async fn connect_and_run(
}
}
/// Wrap a TCP stream with TLS.
/// Supports custom CA certificate via CSM_TLS_CA_CERT env var (path to PEM file).
/// Supports skipping verification via CSM_TLS_SKIP_VERIFY=true (development only).
/// Wrap a TCP stream with TLS and certificate pinning.
async fn wrap_tls(stream: TcpStream, server_addr: &str) -> Result<tokio_rustls::client::TlsStream<TcpStream>> {
let mut root_store = rustls::RootCertStore::empty();
@@ -62,19 +63,38 @@ async fn wrap_tls(stream: TcpStream, server_addr: &str) -> Result<tokio_rustls::
// Always include system roots as fallback
root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
let config = if std::env::var("CSM_TLS_SKIP_VERIFY").as_deref() == Ok("true") {
warn!("TLS certificate verification DISABLED — do not use in production!");
// Check if skip-verify is allowed (only in CSM_DEV mode)
let skip_verify = std::env::var("CSM_TLS_SKIP_VERIFY").as_deref() == Ok("true")
&& std::env::var("CSM_DEV").is_ok();
let config = if skip_verify {
warn!("TLS certificate verification DISABLED — CSM_DEV mode only!");
rustls::ClientConfig::builder()
.dangerous()
.with_custom_certificate_verifier(std::sync::Arc::new(NoVerifier))
.with_custom_certificate_verifier(Arc::new(NoVerifier))
.with_no_client_auth()
} else {
// Build standard verifier with pinning wrapper
let inner = rustls::client::WebPkiServerVerifier::builder(Arc::new(root_store))
.build()
.map_err(|e| anyhow::anyhow!("Failed to build TLS verifier: {:?}", e))?;
let pin_file = pin_file_path();
let pinned_hashes = load_pinned_hashes(&pin_file);
let verifier = PinnedCertVerifier {
inner,
pin_file,
pinned_hashes: Arc::new(Mutex::new(pinned_hashes)),
};
rustls::ClientConfig::builder()
.with_root_certificates(root_store)
.dangerous()
.with_custom_certificate_verifier(Arc::new(verifier))
.with_no_client_auth()
};
let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(config));
let connector = tokio_rustls::TlsConnector::from(Arc::new(config));
// Extract hostname from server_addr (strip port)
let domain = server_addr.split(':').next().unwrap_or("localhost").to_string();
@@ -86,6 +106,131 @@ async fn wrap_tls(stream: TcpStream, server_addr: &str) -> Result<tokio_rustls::
Ok(tls_stream)
}
/// Default pin file path: %PROGRAMDATA%\CSM\server_cert_pin (Windows)
fn pin_file_path() -> PathBuf {
if let Ok(custom) = std::env::var("CSM_TLS_PIN_FILE") {
PathBuf::from(custom)
} else if cfg!(target_os = "windows") {
std::env::var("PROGRAMDATA")
.map(|p| PathBuf::from(p).join("CSM").join("server_cert_pin"))
.unwrap_or_else(|_| PathBuf::from("server_cert_pin"))
} else {
PathBuf::from("/var/lib/csm/server_cert_pin")
}
}
/// Load pinned certificate hashes from file.
/// Format: one hex-encoded SHA-256 hash per line.
fn load_pinned_hashes(path: &PathBuf) -> Vec<String> {
match std::fs::read_to_string(path) {
Ok(content) => content.lines()
.map(|l| l.trim().to_string())
.filter(|l| !l.is_empty())
.collect(),
Err(_) => Vec::new(), // First connection — no pin file yet
}
}
/// Save a pinned hash to the pin file.
fn save_pinned_hash(path: &PathBuf, hash: &str) {
if let Some(parent) = path.parent() {
let _ = std::fs::create_dir_all(parent);
}
let _ = std::fs::write(path, format!("{}\n", hash));
}
/// Compute SHA-256 fingerprint of a DER-encoded certificate.
fn cert_fingerprint(cert: &rustls_pki_types::CertificateDer) -> String {
let mut hasher = Sha256::new();
hasher.update(cert.as_ref());
hex::encode(hasher.finalize())
}
/// Certificate verifier with pinning support.
/// On first connection (no stored pin), records the certificate fingerprint.
/// On subsequent connections, verifies the fingerprint matches.
#[derive(Debug)]
struct PinnedCertVerifier {
inner: Arc<rustls::client::WebPkiServerVerifier>,
pin_file: PathBuf,
pinned_hashes: Arc<Mutex<Vec<String>>>,
}
impl rustls::client::danger::ServerCertVerifier for PinnedCertVerifier {
fn verify_server_cert(
&self,
end_entity: &rustls_pki_types::CertificateDer,
intermediates: &[rustls_pki_types::CertificateDer],
server_name: &rustls_pki_types::ServerName,
ocsp_response: &[u8],
now: rustls_pki_types::UnixTime,
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
// 1. Standard PKIX verification
self.inner.verify_server_cert(end_entity, intermediates, server_name, ocsp_response, now)?;
// 2. Compute certificate fingerprint
let fingerprint = cert_fingerprint(end_entity);
// 3. Check against pinned hashes
let mut pinned = self.pinned_hashes.lock().unwrap();
if pinned.is_empty() {
// First connection — record the certificate fingerprint
info!("Recording server certificate pin: {}...", &fingerprint[..16]);
save_pinned_hash(&self.pin_file, &fingerprint);
pinned.push(fingerprint);
} else if !pinned.contains(&fingerprint) {
warn!("Certificate pin mismatch! Expected one of {:?}, got {}", pinned, fingerprint);
return Err(rustls::Error::General(
"Server certificate does not match pinned fingerprint. Possible MITM attack.".into(),
));
}
Ok(rustls::client::danger::ServerCertVerified::assertion())
}
fn verify_tls12_signature(
&self,
message: &[u8],
cert: &rustls_pki_types::CertificateDer,
dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
self.inner.verify_tls12_signature(message, cert, dss)
}
fn verify_tls13_signature(
&self,
message: &[u8],
cert: &rustls_pki_types::CertificateDer,
dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
self.inner.verify_tls13_signature(message, cert, dss)
}
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
self.inner.supported_verify_schemes()
}
}
/// Update pinned certificate hash (called when receiving TlsCertRotate).
pub fn update_cert_pin(new_hash: &str) {
let pin_file = pin_file_path();
let mut pinned = load_pinned_hashes(&pin_file);
if !pinned.contains(&new_hash.to_string()) {
pinned.push(new_hash.to_string());
// Keep only the last 2 hashes (current + rotating)
while pinned.len() > 2 {
pinned.remove(0);
}
// Write all hashes to file
if let Some(parent) = pin_file.parent() {
let _ = std::fs::create_dir_all(parent);
}
let content = pinned.iter().map(|h| h.as_str()).collect::<Vec<_>>().join("\n");
let _ = std::fs::write(&pin_file, format!("{}\n", content));
info!("Updated certificate pin file with new hash: {}...", &new_hash[..16]);
}
}
/// A no-op certificate verifier for development use (CSM_TLS_SKIP_VERIFY=true).
#[derive(Debug)]
struct NoVerifier;
@@ -242,7 +387,20 @@ fn handle_server_message(frame: Frame, plugins: &PluginChannels) -> Result<()> {
info!("Received policy update: {}", policy);
}
MessageType::ConfigUpdate => {
info!("Received config update");
let update: csm_protocol::ConfigUpdateType = frame.decode_payload()
.map_err(|e| anyhow::anyhow!("Invalid config update: {}", e))?;
match update {
csm_protocol::ConfigUpdateType::UpdateIntervals { heartbeat, status, asset } => {
info!("Config update: intervals heartbeat={}s status={}s asset={}s", heartbeat, status, asset);
}
csm_protocol::ConfigUpdateType::TlsCertRotate { new_cert_hash, valid_until } => {
info!("Certificate rotation: new hash={}... valid_until={}", &new_cert_hash[..16.min(new_cert_hash.len())], valid_until);
update_cert_pin(&new_cert_hash);
}
csm_protocol::ConfigUpdateType::SelfDestruct => {
warn!("Self-destruct command received (not implemented)");
}
}
}
MessageType::TaskExecute => {
warn!("Task execution requested (not yet implemented)");
@@ -276,7 +434,14 @@ fn handle_server_message(frame: Frame, plugins: &PluginChannels) -> Result<()> {
let blacklist: Vec<crate::software_blocker::BlacklistEntry> = payload.get("blacklist")
.and_then(|r| serde_json::from_value(r.clone()).ok())
.unwrap_or_default();
let config = crate::software_blocker::SoftwareBlockerConfig { enabled: true, blacklist };
let whitelist: Vec<String> = payload.get("whitelist")
.and_then(|r| serde_json::from_value(r.clone()).ok())
.unwrap_or_default();
let config = crate::software_blocker::SoftwareBlockerConfig {
enabled: true,
blacklist,
whitelist,
};
plugins.software_blocker_tx.send(config)?;
}
MessageType::PopupRules => {
@@ -322,6 +487,16 @@ fn handle_server_message(frame: Frame, plugins: &PluginChannels) -> Result<()> {
};
plugins.clipboard_control_tx.send(config)?;
}
MessageType::PatchScanConfig => {
let config: PatchScanConfigPayload = frame.decode_payload()
.map_err(|e| anyhow::anyhow!("Invalid patch scan config: {}", e))?;
info!("Received patch scan config: enabled={}, interval={}s", config.enabled, config.scan_interval_secs);
let plugin_config = crate::patch::PluginConfig {
enabled: config.enabled,
scan_interval_secs: config.scan_interval_secs,
};
plugins.patch_tx.send(plugin_config)?;
}
_ => {
debug!("Unhandled message type: {:?}", frame.msg_type);
}
@@ -351,7 +526,7 @@ fn handle_plugin_control(
}
"software_blocker" => {
if !enabled {
plugins.software_blocker_tx.send(crate::software_blocker::SoftwareBlockerConfig { enabled: false, blacklist: vec![] })?;
plugins.software_blocker_tx.send(crate::software_blocker::SoftwareBlockerConfig { enabled: false, blacklist: vec![], whitelist: vec![] })?;
}
}
"popup_blocker" => {
@@ -384,6 +559,11 @@ fn handle_plugin_control(
plugins.clipboard_control_tx.send(crate::clipboard_control::ClipboardControlConfig { enabled: false, ..Default::default() })?;
}
}
"patch" => {
if !enabled {
plugins.patch_tx.send(crate::patch::PluginConfig { enabled: false, scan_interval_secs: 43200 })?;
}
}
_ => {
warn!("Unknown plugin: {}", payload.plugin_name);
}

View File

@@ -0,0 +1,116 @@
use tokio::sync::watch;
use csm_protocol::{Frame, MessageType, PatchStatusPayload, PatchEntry};
use tracing::{debug, warn};
#[derive(Debug, Clone, Default)]
pub struct PluginConfig {
pub enabled: bool,
pub scan_interval_secs: u64,
}
pub async fn start(
mut config_rx: watch::Receiver<PluginConfig>,
data_tx: tokio::sync::mpsc::Sender<Frame>,
device_uid: String,
) {
let mut config = config_rx.borrow_and_update().clone();
let mut interval = tokio::time::interval(std::time::Duration::from_secs(
if config.scan_interval_secs > 0 { config.scan_interval_secs } else { 43200 }
));
interval.tick().await;
loop {
tokio::select! {
result = config_rx.changed() => {
if result.is_err() { break; }
config = config_rx.borrow_and_update().clone();
let new_secs = if config.scan_interval_secs > 0 { config.scan_interval_secs } else { 43200 };
interval = tokio::time::interval(std::time::Duration::from_secs(new_secs));
interval.tick().await;
debug!("Patch config updated: enabled={}, interval={}s", config.enabled, new_secs);
}
_ = interval.tick() => {
if !config.enabled { continue; }
match collect_patches().await {
Ok(patches) => {
if patches.is_empty() {
debug!("No patches collected for device {}", device_uid);
continue;
}
debug!("Collected {} patches for device {}", patches.len(), device_uid);
let payload = PatchStatusPayload {
device_uid: device_uid.clone(),
patches,
};
if let Ok(frame) = Frame::new_json(MessageType::PatchStatusReport, &payload) {
if data_tx.send(frame).await.is_err() {
warn!("Failed to send patch status: channel closed");
break;
}
}
}
Err(e) => {
warn!("Patch collection failed: {}", e);
}
}
}
}
}
}
async fn collect_patches() -> anyhow::Result<Vec<PatchEntry>> {
// SECURITY: PowerShell command uses only hardcoded strings with no user/remote input.
// The format!() only inserts PowerShell syntax, not external data.
let output = tokio::process::Command::new("powershell")
.args([
"-NoProfile", "-NonInteractive", "-Command",
&format!(
"[Console]::OutputEncoding = [System.Text.Encoding]::UTF8; \
Get-HotFix | Select-Object -First 200 | \
ForEach-Object {{ \
[PSCustomObject]@{{ \
kb = $_.HotFixID; \
desc = $_.Description; \
installed = if ($_.InstalledOn) {{ $_.InstalledOn.ToString('yyyy-MM-dd') }} else {{ '' }} \
}} \
}} | ConvertTo-Json -Compress"
),
])
.output()
.await?;
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr);
return Err(anyhow::anyhow!("PowerShell failed: {}", stderr));
}
let stdout = String::from_utf8_lossy(&output.stdout);
let trimmed = stdout.trim();
if trimmed.is_empty() {
return Ok(Vec::new());
}
// Handle single-item case (PowerShell returns object instead of array)
let items: Vec<serde_json::Value> = if trimmed.starts_with('[') {
serde_json::from_str(trimmed).unwrap_or_default()
} else {
serde_json::from_str(trimmed).map(|v: serde_json::Value| vec![v]).unwrap_or_default()
};
let patches: Vec<PatchEntry> = items.iter().filter_map(|item| {
let kb = item.get("kb")?.as_str()?.to_string();
if kb.is_empty() { return None; }
let desc = item.get("desc").and_then(|v| v.as_str()).unwrap_or("");
let installed_str = item.get("installed").and_then(|v| v.as_str()).unwrap_or("");
Some(PatchEntry {
title: format!("{} - {}", kb, desc),
kb_id: kb,
severity: None, // Will be enriched server-side from known CVE data
is_installed: true,
installed_at: if installed_str.is_empty() { None } else { Some(installed_str.to_string()) },
})
}).collect();
Ok(patches)
}

View File

@@ -1,9 +1,10 @@
use std::ffi::OsString;
use std::time::Duration;
use tracing::{info, error};
use tracing::{info, error, debug};
use windows_service::define_windows_service;
use windows_service::service::{
ServiceAccess, ServiceControl, ServiceErrorControl, ServiceExitCode,
ServiceAccess, ServiceAction, ServiceActionType, ServiceControl, ServiceErrorControl,
ServiceExitCode, ServiceFailureActions, ServiceFailureResetPeriod,
ServiceInfo, ServiceStartType, ServiceState, ServiceStatus, ServiceType,
};
use windows_service::service_control_handler::{self, ServiceControlHandlerResult};
@@ -38,6 +39,30 @@ pub fn install() -> anyhow::Result<()> {
let service = manager.create_service(&service_info, ServiceAccess::CHANGE_CONFIG)?;
service.set_description(SERVICE_DESCRIPTION)?;
// Configure service recovery: restart on failure with escalating delays
let failure_actions = ServiceFailureActions {
reset_period: ServiceFailureResetPeriod::After(std::time::Duration::from_secs(60)),
reboot_msg: None,
command: None,
actions: Some(vec![
ServiceAction {
action_type: ServiceActionType::Restart,
delay: std::time::Duration::from_secs(5),
},
ServiceAction {
action_type: ServiceActionType::Restart,
delay: std::time::Duration::from_secs(30),
},
ServiceAction {
action_type: ServiceActionType::Restart,
delay: std::time::Duration::from_secs(60),
},
]),
};
if let Err(e) = service.update_failure_actions(failure_actions) {
tracing::warn!("Failed to set service recovery actions: {}", e);
}
println!("Service '{}' installed successfully.", SERVICE_NAME);
println!("Use 'sc start {}' or restart to launch.", SERVICE_NAME);
Ok(())
@@ -117,7 +142,7 @@ fn run_service_inner() -> anyhow::Result<()> {
// Build ClientState
let device_uid = crate::load_or_create_device_uid()?;
info!("Device UID: {}", device_uid);
debug!("Device UID: {}", device_uid);
let state = crate::ClientState {
device_uid,

View File

@@ -1,11 +1,13 @@
use std::collections::HashSet;
use tokio::sync::watch;
use tracing::{info, warn};
use tracing::{info, warn, debug};
use csm_protocol::{Frame, MessageType, SoftwareViolationReport};
use serde::Deserialize;
/// System-critical processes that must never be killed regardless of server rules.
/// Killing any of these would cause system instability or a BSOD.
const PROTECTED_PROCESSES: &[&str] = &[
// Windows system processes
"system",
"system idle process",
"svchost.exe",
@@ -20,8 +22,60 @@ const PROTECTED_PROCESSES: &[&str] = &[
"registry",
"smss.exe",
"conhost.exe",
"ntoskrnl.exe",
"dcomlaunch.exe",
"rundll32.exe",
"sihost.exe",
"taskeng.exe",
"wermgr.exe",
"WerFault.exe",
"fontdrvhost.exe",
"ctfmon.exe",
"SearchIndexer.exe",
"SearchHost.exe",
"RuntimeBroker.exe",
"SecurityHealthService.exe",
"SecurityHealthSystray.exe",
"MpCmdRun.exe",
"MsMpEng.exe",
"NisSrv.exe",
// Common browsers — should never be blocked unless explicitly configured with exact name
"chrome.exe",
"msedge.exe",
"firefox.exe",
"iexplore.exe",
"opera.exe",
"brave.exe",
"vivaldi.exe",
"thorium.exe",
// Development tools & IDEs
"code.exe",
"devenv.exe",
"idea64.exe",
"webstorm64.exe",
"pycharm64.exe",
"goland64.exe",
"clion64.exe",
"rider64.exe",
"datagrip64.exe",
"trae.exe",
"windsurf.exe",
"cursor.exe",
"zed.exe",
// Terminal & system tools
"cmd.exe",
"powershell.exe",
"pwsh.exe",
"WindowsTerminal.exe",
"conhost.exe",
// CSM itself
"csm-client.exe",
];
/// Cooldown period (seconds) before reporting/killing the same process again.
/// Prevents spamming violations for long-running blocked processes.
const REPORT_COOLDOWN_SECS: u64 = 300; // 5 minutes
/// Software blacklist entry from server
#[derive(Debug, Clone, Deserialize)]
pub struct BlacklistEntry {
@@ -36,6 +90,8 @@ pub struct BlacklistEntry {
pub struct SoftwareBlockerConfig {
pub enabled: bool,
pub blacklist: Vec<BlacklistEntry>,
/// Server-pushed whitelist: processes matching these patterns are never blocked.
pub whitelist: Vec<String>,
}
/// Start software blocker plugin.
@@ -47,6 +103,11 @@ pub async fn start(
) {
info!("Software blocker plugin started");
let mut config = SoftwareBlockerConfig::default();
// Track recently acted-on processes to avoid repeated kill/report spam
let mut recent_actions: HashSet<String> = HashSet::new();
let mut cooldown_interval = tokio::time::interval(std::time::Duration::from_secs(REPORT_COOLDOWN_SECS));
cooldown_interval.tick().await;
let mut scan_interval = tokio::time::interval(std::time::Duration::from_secs(10));
scan_interval.tick().await;
@@ -58,13 +119,23 @@ pub async fn start(
}
let new_config = config_rx.borrow_and_update().clone();
info!("Software blocker config updated: enabled={}, blacklist={}", new_config.enabled, new_config.blacklist.len());
// Clear cooldown cache when config changes so new rules take effect immediately
recent_actions.clear();
config = new_config;
}
_ = cooldown_interval.tick() => {
// Periodically clear the cooldown cache so we can re-check
let cleared = recent_actions.len();
recent_actions.clear();
if cleared > 0 {
info!("Software blocker cooldown cache cleared ({} entries)", cleared);
}
}
_ = scan_interval.tick() => {
if !config.enabled || config.blacklist.is_empty() {
continue;
}
scan_processes(&config.blacklist, &data_tx, &device_uid).await;
scan_processes(&config.blacklist, &config.whitelist, &data_tx, &device_uid, &mut recent_actions).await;
}
}
}
@@ -72,83 +143,148 @@ pub async fn start(
async fn scan_processes(
blacklist: &[BlacklistEntry],
whitelist: &[String],
data_tx: &tokio::sync::mpsc::Sender<Frame>,
device_uid: &str,
recent_actions: &mut HashSet<String>,
) {
let running = get_running_processes_with_pids();
for entry in blacklist {
for (process_name, pid) in &running {
if pattern_matches(&entry.name_pattern, process_name) {
// Never kill system-critical processes
if is_protected_process(process_name) {
warn!("Blacklisted match '{}' skipped — system-critical process (pid={})", process_name, pid);
continue;
}
warn!("Blacklisted software detected: {} (action: {})", process_name, entry.action);
// Report violation to server
// Map action to DB-compatible values: "block" -> "blocked_install", "alert" -> "alerted"
let action_taken = match entry.action.as_str() {
"block" => "blocked_install",
"alert" => "alerted",
other => other,
};
let violation = SoftwareViolationReport {
device_uid: device_uid.to_string(),
software_name: process_name.clone(),
action_taken: action_taken.to_string(),
timestamp: chrono::Utc::now().to_rfc3339(),
};
if let Ok(frame) = Frame::new_json(MessageType::SoftwareViolation, &violation) {
let _ = data_tx.send(frame).await;
}
// Kill the process directly by captured PID (avoids TOCTOU race)
if entry.action == "block" {
kill_process_by_pid(*pid, process_name);
}
if !pattern_matches(&entry.name_pattern, process_name) {
continue;
}
// Never kill protected processes (system, browsers, IDEs, etc.)
if is_protected_process(process_name) {
warn!(
"Blacklisted match '{}' skipped — protected process (pid={})",
process_name, pid
);
continue;
}
// Check server-pushed whitelist (takes precedence over blacklist)
if is_whitelisted(process_name, whitelist) {
debug!(
"Blacklisted match '{}' skipped — whitelisted (pid={})",
process_name, pid
);
continue;
}
// Skip if already acted on recently (cooldown)
let action_key = format!("{}:{}", process_name.to_lowercase(), pid);
if recent_actions.contains(&action_key) {
continue;
}
warn!("Blacklisted software detected: {} (action: {})", process_name, entry.action);
// Report violation to server
let action_taken = match entry.action.as_str() {
"block" => "blocked_install",
"alert" => "alerted",
other => other,
};
let violation = SoftwareViolationReport {
device_uid: device_uid.to_string(),
software_name: process_name.clone(),
action_taken: action_taken.to_string(),
timestamp: chrono::Utc::now().to_rfc3339(),
};
if let Ok(frame) = Frame::new_json(MessageType::SoftwareViolation, &violation) {
let _ = data_tx.send(frame).await;
}
// Kill the process if action is "block"
if entry.action == "block" {
kill_process_by_pid(*pid, process_name);
}
// Mark as recently acted-on
recent_actions.insert(action_key);
}
}
}
/// Match a blacklist pattern against a process name.
///
/// **Matching rules (case-insensitive)**:
/// - No wildcard → exact filename match only (e.g. `chrome.exe` matches `chrome.exe` but NOT `new_chrome.exe`)
/// - `*` wildcard → glob-style pattern match (e.g. `*miner*` matches `bitcoin_miner.exe`)
/// - Pattern with `.exe` suffix → match against full process name
/// - Pattern without extension → match against stem (name without `.exe`)
///
/// **IMPORTANT**: We intentionally do NOT use substring matching (`contains()`) for
/// non-wildcard patterns. Substring matching caused false positives where a pattern
/// like "game" would match "game_bar.exe" or even "svchost.exe" in edge cases.
fn pattern_matches(pattern: &str, name: &str) -> bool {
let pattern_lower = pattern.to_lowercase();
let name_lower = name.to_lowercase();
// Support wildcard patterns
if pattern_lower.contains('*') {
let parts: Vec<&str> = pattern_lower.split('*').collect();
let mut pos = 0;
for (i, part) in parts.iter().enumerate() {
if part.is_empty() {
continue;
}
if i == 0 && !parts[0].is_empty() {
// Pattern starts with literal → must match at start
if !name_lower.starts_with(part) {
return false;
}
pos = part.len();
} else {
match name_lower[pos..].find(part) {
Some(idx) => pos += idx + part.len(),
None => return false,
}
}
}
// If pattern ends with literal (no trailing *), must match at end
if !parts.last().map_or(true, |p| p.is_empty()) {
return name_lower.ends_with(parts.last().unwrap());
}
true
// Extract the stem (filename without extension) for extension-less patterns
let name_stem = name_lower.strip_suffix(".exe").unwrap_or(&name_lower);
if pattern_lower.contains('*') {
// Wildcard glob matching
glob_matches(&pattern_lower, &name_lower, &name_stem)
} else {
name_lower.contains(&pattern_lower)
// Exact match: compare against full name OR stem (if pattern has no extension)
let pattern_stem = pattern_lower.strip_suffix(".exe").unwrap_or(&pattern_lower);
name_lower == pattern_lower || name_stem == pattern_stem
}
}
/// Glob-style pattern matching with `*` wildcards.
/// Checks both the full name and the stem (without .exe).
fn glob_matches(pattern: &str, full_name: &str, stem: &str) -> bool {
// Try matching against both the full process name and the stem
glob_matches_single(pattern, full_name) || glob_matches_single(pattern, stem)
}
fn glob_matches_single(pattern: &str, text: &str) -> bool {
let parts: Vec<&str> = pattern.split('*').collect();
if parts.is_empty() {
return true;
}
let mut pos = 0;
for (i, part) in parts.iter().enumerate() {
if part.is_empty() {
continue;
}
if i == 0 {
// First segment: must match at the start
if !text.starts_with(part) {
return false;
}
pos = part.len();
} else if i == parts.len() - 1 {
// Last segment: must match at the end (if pattern doesn't end with *)
return text.ends_with(part) && text[..text.len() - part.len()].len() >= pos;
} else {
// Middle segment: must appear after current position
match text[pos..].find(part) {
Some(idx) => pos += idx + part.len(),
None => return false,
}
}
}
// If pattern ends with *, anything after last match is fine
if pattern.ends_with('*') {
return true;
}
// Full match required
pos == text.len()
}
/// Get all running processes with their PIDs (single snapshot, no TOCTOU)
fn get_running_processes_with_pids() -> Vec<(String, u32)> {
#[cfg(target_os = "windows")]
@@ -253,3 +389,12 @@ fn is_protected_process(name: &str) -> bool {
let lower = name.to_lowercase();
PROTECTED_PROCESSES.iter().any(|p| lower == **p || lower.ends_with(&format!("/{}", p).replace('/', "\\")))
}
/// Check if a process name matches any server-pushed whitelist pattern.
/// Whitelist patterns use the same matching logic as blacklist (exact or glob).
fn is_whitelisted(process_name: &str, whitelist: &[String]) -> bool {
if whitelist.is_empty() {
return false;
}
whitelist.iter().any(|pattern| pattern_matches(pattern, process_name))
}