feat: 初始化项目基础架构和核心功能
- 添加项目基础结构:Cargo.toml、.gitignore、设备UID和密钥文件 - 实现前端Vue3项目结构:路由、登录页面、设备管理页面 - 添加核心协议定义(crates/protocol):设备状态、资产、USB事件等 - 实现客户端监控模块:系统状态收集、资产收集 - 实现服务端基础API和插件系统 - 添加数据库迁移脚本:设备管理、资产跟踪、告警系统等 - 实现前端设备状态展示和基本交互 - 添加使用时长统计和水印功能插件
This commit is contained in:
48
crates/client/Cargo.toml
Normal file
48
crates/client/Cargo.toml
Normal file
@@ -0,0 +1,48 @@
|
||||
[package]
|
||||
name = "csm-client"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
|
||||
[dependencies]
|
||||
csm-protocol = { path = "../protocol" }
|
||||
tokio = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
tracing-subscriber = { workspace = true }
|
||||
anyhow = { workspace = true }
|
||||
uuid = { workspace = true }
|
||||
chrono = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
sysinfo = "0.30"
|
||||
tokio-rustls = "0.26"
|
||||
rustls = { version = "0.23", default-features = false, features = ["ring", "std"] }
|
||||
rustls-pki-types = "1"
|
||||
webpki-roots = "0.26"
|
||||
rustls-pemfile = "2"
|
||||
hmac = "0.12"
|
||||
sha2 = "0.10"
|
||||
hex = "0.4"
|
||||
|
||||
[target.'cfg(target_os = "windows")'.dependencies]
|
||||
windows = { version = "0.54", features = [
|
||||
"Win32_Foundation",
|
||||
"Win32_System_SystemInformation",
|
||||
"Win32_System_Registry",
|
||||
"Win32_System_IO",
|
||||
"Win32_Security",
|
||||
"Win32_NetworkManagement_IpHelper",
|
||||
"Win32_Storage_FileSystem",
|
||||
"Win32_UI_WindowsAndMessaging",
|
||||
"Win32_UI_Input_KeyboardAndMouse",
|
||||
"Win32_System_Threading",
|
||||
"Win32_System_Diagnostics_ToolHelp",
|
||||
"Win32_System_LibraryLoader",
|
||||
"Win32_System_Performance",
|
||||
"Win32_Graphics_Gdi",
|
||||
] }
|
||||
windows-service = "0.7"
|
||||
hostname = "0.4"
|
||||
|
||||
[target.'cfg(not(target_os = "windows"))'.dependencies]
|
||||
hostname = "0.4"
|
||||
54
crates/client/src/asset/mod.rs
Normal file
54
crates/client/src/asset/mod.rs
Normal file
@@ -0,0 +1,54 @@
|
||||
use csm_protocol::{Frame, MessageType, HardwareAsset};
|
||||
use std::time::Duration;
|
||||
use tokio::sync::mpsc::Sender;
|
||||
use tracing::{info, error};
|
||||
use sysinfo::System;
|
||||
|
||||
pub async fn start_collecting(tx: Sender<Frame>, device_uid: String) {
|
||||
let interval = Duration::from_secs(86400); // Once per day
|
||||
|
||||
// Initial collection on startup
|
||||
if let Err(e) = collect_and_send(&tx, &device_uid).await {
|
||||
error!("Initial asset collection failed: {}", e);
|
||||
}
|
||||
|
||||
loop {
|
||||
tokio::time::sleep(interval).await;
|
||||
|
||||
if let Err(e) = collect_and_send(&tx, &device_uid).await {
|
||||
error!("Asset collection failed: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn collect_and_send(tx: &Sender<Frame>, device_uid: &str) -> anyhow::Result<()> {
|
||||
let hardware = collect_hardware(device_uid)?;
|
||||
let frame = Frame::new_json(MessageType::AssetReport, &hardware)?;
|
||||
tx.send(frame).await.map_err(|e| anyhow::anyhow!("Channel send failed: {}", e))?;
|
||||
info!("Asset report sent for {}", device_uid);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn collect_hardware(device_uid: &str) -> anyhow::Result<HardwareAsset> {
|
||||
let mut sys = System::new_all();
|
||||
sys.refresh_all();
|
||||
|
||||
let cpu_model = sys.cpus().first()
|
||||
.map(|c| c.brand().to_string())
|
||||
.unwrap_or_else(|| "Unknown".to_string());
|
||||
|
||||
let cpu_cores = sys.cpus().len() as u32;
|
||||
let memory_total_mb = sys.total_memory() / 1024 / 1024; // bytes to MB (sysinfo 0.30)
|
||||
|
||||
Ok(HardwareAsset {
|
||||
device_uid: device_uid.to_string(),
|
||||
cpu_model,
|
||||
cpu_cores,
|
||||
memory_total_mb: memory_total_mb as u64,
|
||||
disk_model: "Unknown".to_string(),
|
||||
disk_total_mb: 0,
|
||||
gpu_model: None,
|
||||
motherboard: None,
|
||||
serial_number: None,
|
||||
})
|
||||
}
|
||||
206
crates/client/src/main.rs
Normal file
206
crates/client/src/main.rs
Normal file
@@ -0,0 +1,206 @@
|
||||
use anyhow::Result;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::time::Duration;
|
||||
use tracing::{info, error, warn};
|
||||
use csm_protocol::{Frame, ClientConfig, UsbPolicyPayload};
|
||||
|
||||
mod monitor;
|
||||
mod asset;
|
||||
mod usb;
|
||||
mod network;
|
||||
mod watermark;
|
||||
mod usage_timer;
|
||||
mod usb_audit;
|
||||
mod popup_blocker;
|
||||
mod software_blocker;
|
||||
mod web_filter;
|
||||
|
||||
/// Shared shutdown flag
|
||||
static SHUTDOWN: AtomicBool = AtomicBool::new(false);
|
||||
|
||||
/// Client configuration
|
||||
struct ClientState {
|
||||
device_uid: String,
|
||||
server_addr: String,
|
||||
config: ClientConfig,
|
||||
device_secret: Option<String>,
|
||||
registration_token: String,
|
||||
/// Whether to use TLS when connecting to the server
|
||||
use_tls: bool,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
tracing_subscriber::fmt()
|
||||
.with_env_filter("csm_client=info")
|
||||
.init();
|
||||
|
||||
info!("CSM Client starting...");
|
||||
|
||||
// Load or generate device identity
|
||||
let device_uid = load_or_create_device_uid()?;
|
||||
info!("Device UID: {}", device_uid);
|
||||
|
||||
// Load server address
|
||||
let server_addr = std::env::var("CSM_SERVER")
|
||||
.unwrap_or_else(|_| "127.0.0.1:9999".to_string());
|
||||
|
||||
let state = ClientState {
|
||||
device_uid,
|
||||
server_addr,
|
||||
config: ClientConfig::default(),
|
||||
device_secret: load_device_secret(),
|
||||
registration_token: std::env::var("CSM_REGISTRATION_TOKEN").unwrap_or_default(),
|
||||
use_tls: std::env::var("CSM_USE_TLS").as_deref() == Ok("true"),
|
||||
};
|
||||
|
||||
// TODO: Register as Windows Service on Windows
|
||||
// For development, run directly
|
||||
|
||||
// Main event loop
|
||||
run(state).await
|
||||
}
|
||||
|
||||
async fn run(state: ClientState) -> Result<()> {
|
||||
let (data_tx, mut data_rx) = tokio::sync::mpsc::channel::<Frame>(1024);
|
||||
|
||||
// Spawn Ctrl+C handler
|
||||
tokio::spawn(async {
|
||||
tokio::signal::ctrl_c().await.ok();
|
||||
info!("Received Ctrl+C, initiating graceful shutdown...");
|
||||
SHUTDOWN.store(true, Ordering::SeqCst);
|
||||
});
|
||||
|
||||
// Create plugin config channels
|
||||
let (watermark_tx, watermark_rx) = tokio::sync::watch::channel(None);
|
||||
let (web_filter_tx, web_filter_rx) = tokio::sync::watch::channel(web_filter::WebFilterConfig::default());
|
||||
let (software_blocker_tx, software_blocker_rx) = tokio::sync::watch::channel(software_blocker::SoftwareBlockerConfig::default());
|
||||
let (popup_blocker_tx, popup_blocker_rx) = tokio::sync::watch::channel(popup_blocker::PopupBlockerConfig::default());
|
||||
let (usb_audit_tx, usb_audit_rx) = tokio::sync::watch::channel(usb_audit::UsbAuditConfig::default());
|
||||
let (usage_timer_tx, usage_timer_rx) = tokio::sync::watch::channel(usage_timer::UsageConfig::default());
|
||||
let (usb_policy_tx, usb_policy_rx) = tokio::sync::watch::channel(None::<UsbPolicyPayload>);
|
||||
|
||||
let plugins = network::PluginChannels {
|
||||
watermark_tx,
|
||||
web_filter_tx,
|
||||
software_blocker_tx,
|
||||
popup_blocker_tx,
|
||||
usb_audit_tx,
|
||||
usage_timer_tx,
|
||||
usb_policy_tx,
|
||||
};
|
||||
|
||||
// Spawn core monitoring tasks
|
||||
let monitor_tx = data_tx.clone();
|
||||
let uid = state.device_uid.clone();
|
||||
tokio::spawn(async move {
|
||||
monitor::start_collecting(monitor_tx, uid).await;
|
||||
});
|
||||
|
||||
let asset_tx = data_tx.clone();
|
||||
let uid = state.device_uid.clone();
|
||||
tokio::spawn(async move {
|
||||
asset::start_collecting(asset_tx, uid).await;
|
||||
});
|
||||
|
||||
let usb_tx = data_tx.clone();
|
||||
let uid = state.device_uid.clone();
|
||||
tokio::spawn(async move {
|
||||
usb::start_monitoring(usb_tx, uid, usb_policy_rx).await;
|
||||
});
|
||||
|
||||
// Spawn plugin tasks
|
||||
tokio::spawn(async move {
|
||||
watermark::start(watermark_rx).await;
|
||||
});
|
||||
|
||||
let usage_data_tx = data_tx.clone();
|
||||
let usage_uid = state.device_uid.clone();
|
||||
tokio::spawn(async move {
|
||||
usage_timer::start(usage_timer_rx, usage_data_tx, usage_uid).await;
|
||||
});
|
||||
|
||||
let audit_data_tx = data_tx.clone();
|
||||
let audit_uid = state.device_uid.clone();
|
||||
tokio::spawn(async move {
|
||||
usb_audit::start(usb_audit_rx, audit_data_tx, audit_uid).await;
|
||||
});
|
||||
|
||||
tokio::spawn(async move {
|
||||
popup_blocker::start(popup_blocker_rx).await;
|
||||
});
|
||||
|
||||
let sw_data_tx = data_tx.clone();
|
||||
let sw_uid = state.device_uid.clone();
|
||||
tokio::spawn(async move {
|
||||
software_blocker::start(software_blocker_rx, sw_data_tx, sw_uid).await;
|
||||
});
|
||||
|
||||
tokio::spawn(async move {
|
||||
web_filter::start(web_filter_rx).await;
|
||||
});
|
||||
|
||||
// Connect to server with reconnect
|
||||
let mut backoff = Duration::from_secs(1);
|
||||
let max_backoff = Duration::from_secs(60);
|
||||
|
||||
loop {
|
||||
if SHUTDOWN.load(Ordering::SeqCst) {
|
||||
info!("Shutting down gracefully...");
|
||||
break Ok(());
|
||||
}
|
||||
|
||||
match network::connect_and_run(&state, &mut data_rx, &plugins).await {
|
||||
Ok(()) => {
|
||||
warn!("Disconnected from server, reconnecting...");
|
||||
// Use a short fixed delay for clean disconnects (server-initiated),
|
||||
// but don't reset to zero to prevent connection storms
|
||||
tokio::time::sleep(Duration::from_secs(2)).await;
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Connection error: {}, reconnecting...", e);
|
||||
tokio::time::sleep(backoff).await;
|
||||
backoff = (backoff * 2).min(max_backoff);
|
||||
}
|
||||
}
|
||||
|
||||
// Drain stale frames that accumulated during disconnection
|
||||
let drained = data_rx.try_recv().ok().map(|_| 1).unwrap_or(0);
|
||||
if drained > 0 {
|
||||
let mut count = drained;
|
||||
while data_rx.try_recv().is_ok() {
|
||||
count += 1;
|
||||
}
|
||||
warn!("Drained {} stale frames from channel", count);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn load_or_create_device_uid() -> Result<String> {
|
||||
// In production, store in Windows Credential Store or local config
|
||||
// For now, use a simple file
|
||||
let uid_file = "device_uid.txt";
|
||||
if std::path::Path::new(uid_file).exists() {
|
||||
let uid = std::fs::read_to_string(uid_file)?;
|
||||
Ok(uid.trim().to_string())
|
||||
} else {
|
||||
let uid = uuid::Uuid::new_v4().to_string();
|
||||
std::fs::write(uid_file, &uid)?;
|
||||
Ok(uid)
|
||||
}
|
||||
}
|
||||
|
||||
/// Load persisted device_secret from disk (if available)
|
||||
pub fn load_device_secret() -> Option<String> {
|
||||
let secret_file = "device_secret.txt";
|
||||
let secret = std::fs::read_to_string(secret_file).ok()?;
|
||||
let trimmed = secret.trim().to_string();
|
||||
if trimmed.is_empty() { None } else { Some(trimmed) }
|
||||
}
|
||||
|
||||
/// Persist device_secret to disk
|
||||
pub fn save_device_secret(secret: &str) {
|
||||
if let Err(e) = std::fs::write("device_secret.txt", secret) {
|
||||
warn!("Failed to persist device_secret: {}", e);
|
||||
}
|
||||
}
|
||||
86
crates/client/src/monitor/mod.rs
Normal file
86
crates/client/src/monitor/mod.rs
Normal file
@@ -0,0 +1,86 @@
|
||||
use anyhow::Result;
|
||||
use csm_protocol::{Frame, MessageType, DeviceStatus, ProcessInfo};
|
||||
use std::time::Duration;
|
||||
use tokio::sync::mpsc::Sender;
|
||||
use tracing::{info, error, debug};
|
||||
use sysinfo::System;
|
||||
|
||||
pub async fn start_collecting(tx: Sender<Frame>, device_uid: String) {
|
||||
let interval = Duration::from_secs(60);
|
||||
|
||||
loop {
|
||||
// Run blocking sysinfo collection on a dedicated thread
|
||||
let uid_clone = device_uid.clone();
|
||||
let result = tokio::task::spawn_blocking(move || {
|
||||
collect_system_status(&uid_clone)
|
||||
}).await;
|
||||
|
||||
match result {
|
||||
Ok(Ok(status)) => {
|
||||
if let Ok(frame) = Frame::new_json(MessageType::StatusReport, &status) {
|
||||
debug!("Sending status report: cpu={:.1}%, mem={:.1}%", status.cpu_usage, status.memory_usage);
|
||||
if tx.send(frame).await.is_err() {
|
||||
info!("Monitor channel closed, exiting");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(Err(e)) => {
|
||||
error!("Failed to collect system status: {}", e);
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Monitor task join error: {}", e);
|
||||
}
|
||||
}
|
||||
|
||||
tokio::time::sleep(interval).await;
|
||||
}
|
||||
}
|
||||
|
||||
fn collect_system_status(device_uid: &str) -> Result<DeviceStatus> {
|
||||
let mut sys = System::new_all();
|
||||
sys.refresh_all();
|
||||
|
||||
// Brief wait for CPU usage to stabilize
|
||||
std::thread::sleep(Duration::from_millis(200));
|
||||
sys.refresh_all();
|
||||
|
||||
let cpu_usage = sys.global_cpu_info().cpu_usage() as f64;
|
||||
|
||||
let total_memory = sys.total_memory() / 1024 / 1024; // Convert bytes to MB (sysinfo 0.30 returns bytes)
|
||||
let used_memory = sys.used_memory() / 1024 / 1024;
|
||||
let memory_usage = if total_memory > 0 {
|
||||
(used_memory as f64 / total_memory as f64) * 100.0
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
// Top processes by CPU
|
||||
let mut processes: Vec<ProcessInfo> = sys.processes()
|
||||
.iter()
|
||||
.map(|(_, p)| {
|
||||
ProcessInfo {
|
||||
name: p.name().to_string(),
|
||||
pid: p.pid().as_u32(),
|
||||
cpu_usage: p.cpu_usage() as f64,
|
||||
memory_mb: p.memory() / 1024 / 1024, // bytes to MB (sysinfo 0.30)
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
processes.sort_by(|a, b| b.cpu_usage.partial_cmp(&a.cpu_usage).unwrap_or(std::cmp::Ordering::Equal));
|
||||
processes.truncate(10);
|
||||
|
||||
Ok(DeviceStatus {
|
||||
device_uid: device_uid.to_string(),
|
||||
cpu_usage,
|
||||
memory_usage,
|
||||
memory_total_mb: total_memory as u64,
|
||||
disk_usage: 0.0, // TODO: implement disk usage via Windows API
|
||||
disk_total_mb: 0,
|
||||
network_rx_rate: 0,
|
||||
network_tx_rate: 0,
|
||||
running_procs: sys.processes().len() as u32,
|
||||
top_processes: processes,
|
||||
})
|
||||
}
|
||||
380
crates/client/src/network/mod.rs
Normal file
380
crates/client/src/network/mod.rs
Normal file
@@ -0,0 +1,380 @@
|
||||
use anyhow::Result;
|
||||
use std::time::Duration;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tracing::{info, debug, warn};
|
||||
use csm_protocol::{Frame, MessageType, RegisterRequest, RegisterResponse, HeartbeatPayload, WatermarkConfigPayload, UsbPolicyPayload};
|
||||
use hmac::{Hmac, Mac};
|
||||
use sha2::Sha256;
|
||||
|
||||
use crate::ClientState;
|
||||
|
||||
/// Holds senders for all plugin config channels
|
||||
pub struct PluginChannels {
|
||||
pub watermark_tx: tokio::sync::watch::Sender<Option<WatermarkConfigPayload>>,
|
||||
pub web_filter_tx: tokio::sync::watch::Sender<crate::web_filter::WebFilterConfig>,
|
||||
pub software_blocker_tx: tokio::sync::watch::Sender<crate::software_blocker::SoftwareBlockerConfig>,
|
||||
pub popup_blocker_tx: tokio::sync::watch::Sender<crate::popup_blocker::PopupBlockerConfig>,
|
||||
pub usb_audit_tx: tokio::sync::watch::Sender<crate::usb_audit::UsbAuditConfig>,
|
||||
pub usage_timer_tx: tokio::sync::watch::Sender<crate::usage_timer::UsageConfig>,
|
||||
pub usb_policy_tx: tokio::sync::watch::Sender<Option<UsbPolicyPayload>>,
|
||||
}
|
||||
|
||||
/// Connect to server and run the main communication loop
|
||||
pub async fn connect_and_run(
|
||||
state: &ClientState,
|
||||
data_rx: &mut tokio::sync::mpsc::Receiver<Frame>,
|
||||
plugins: &PluginChannels,
|
||||
) -> Result<()> {
|
||||
let tcp_stream = TcpStream::connect(&state.server_addr).await?;
|
||||
info!("TCP connected to {}", state.server_addr);
|
||||
|
||||
if state.use_tls {
|
||||
let tls_stream = wrap_tls(tcp_stream, &state.server_addr).await?;
|
||||
run_comm_loop(tls_stream, state, data_rx, plugins).await
|
||||
} else {
|
||||
run_comm_loop(tcp_stream, state, data_rx, plugins).await
|
||||
}
|
||||
}
|
||||
|
||||
/// Wrap a TCP stream with TLS.
|
||||
/// Supports custom CA certificate via CSM_TLS_CA_CERT env var (path to PEM file).
|
||||
/// Supports skipping verification via CSM_TLS_SKIP_VERIFY=true (development only).
|
||||
async fn wrap_tls(stream: TcpStream, server_addr: &str) -> Result<tokio_rustls::client::TlsStream<TcpStream>> {
|
||||
let mut root_store = rustls::RootCertStore::empty();
|
||||
|
||||
// Load custom CA certificate if specified
|
||||
if let Ok(ca_path) = std::env::var("CSM_TLS_CA_CERT") {
|
||||
let ca_pem = std::fs::read(&ca_path)
|
||||
.map_err(|e| anyhow::anyhow!("Failed to read CA cert {}: {}", ca_path, e))?;
|
||||
let certs = rustls_pemfile::certs(&mut &ca_pem[..])
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
.map_err(|e| anyhow::anyhow!("Failed to parse CA cert: {:?}", e))?;
|
||||
for cert in certs {
|
||||
root_store.add(cert)?;
|
||||
}
|
||||
info!("Loaded custom CA certificates from {}", ca_path);
|
||||
}
|
||||
|
||||
// Always include system roots as fallback
|
||||
root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
|
||||
|
||||
let config = if std::env::var("CSM_TLS_SKIP_VERIFY").as_deref() == Ok("true") {
|
||||
warn!("TLS certificate verification DISABLED — do not use in production!");
|
||||
rustls::ClientConfig::builder()
|
||||
.dangerous()
|
||||
.with_custom_certificate_verifier(std::sync::Arc::new(NoVerifier))
|
||||
.with_no_client_auth()
|
||||
} else {
|
||||
rustls::ClientConfig::builder()
|
||||
.with_root_certificates(root_store)
|
||||
.with_no_client_auth()
|
||||
};
|
||||
|
||||
let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(config));
|
||||
|
||||
// Extract hostname from server_addr (strip port)
|
||||
let domain = server_addr.split(':').next().unwrap_or("localhost").to_string();
|
||||
let server_name = rustls_pki_types::ServerName::try_from(domain.clone())
|
||||
.map_err(|e| anyhow::anyhow!("Invalid TLS server name '{}': {:?}", domain, e))?;
|
||||
|
||||
let tls_stream = connector.connect(server_name, stream).await?;
|
||||
info!("TLS handshake completed with {}", domain);
|
||||
Ok(tls_stream)
|
||||
}
|
||||
|
||||
/// A no-op certificate verifier for development use (CSM_TLS_SKIP_VERIFY=true).
|
||||
#[derive(Debug)]
|
||||
struct NoVerifier;
|
||||
|
||||
impl rustls::client::danger::ServerCertVerifier for NoVerifier {
|
||||
fn verify_server_cert(
|
||||
&self,
|
||||
_end_entity: &rustls_pki_types::CertificateDer,
|
||||
_intermediates: &[rustls_pki_types::CertificateDer],
|
||||
_server_name: &rustls_pki_types::ServerName,
|
||||
_ocsp_response: &[u8],
|
||||
_now: rustls_pki_types::UnixTime,
|
||||
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
|
||||
Ok(rustls::client::danger::ServerCertVerified::assertion())
|
||||
}
|
||||
|
||||
fn verify_tls12_signature(
|
||||
&self,
|
||||
_message: &[u8],
|
||||
_cert: &rustls_pki_types::CertificateDer,
|
||||
_dss: &rustls::DigitallySignedStruct,
|
||||
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
|
||||
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
|
||||
}
|
||||
|
||||
fn verify_tls13_signature(
|
||||
&self,
|
||||
_message: &[u8],
|
||||
_cert: &rustls_pki_types::CertificateDer,
|
||||
_dss: &rustls::DigitallySignedStruct,
|
||||
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
|
||||
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
|
||||
}
|
||||
|
||||
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
|
||||
vec![
|
||||
rustls::SignatureScheme::RSA_PKCS1_SHA256,
|
||||
rustls::SignatureScheme::RSA_PKCS1_SHA384,
|
||||
rustls::SignatureScheme::RSA_PKCS1_SHA512,
|
||||
rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
|
||||
rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
|
||||
rustls::SignatureScheme::RSA_PSS_SHA256,
|
||||
rustls::SignatureScheme::RSA_PSS_SHA384,
|
||||
rustls::SignatureScheme::RSA_PSS_SHA512,
|
||||
rustls::SignatureScheme::ED25519,
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
/// Main communication loop over any read+write stream
|
||||
async fn run_comm_loop<S>(
|
||||
mut stream: S,
|
||||
state: &ClientState,
|
||||
data_rx: &mut tokio::sync::mpsc::Receiver<Frame>,
|
||||
plugins: &PluginChannels,
|
||||
) -> Result<()>
|
||||
where
|
||||
S: AsyncReadExt + AsyncWriteExt + Unpin,
|
||||
{
|
||||
// Send registration
|
||||
let register = RegisterRequest {
|
||||
device_uid: state.device_uid.clone(),
|
||||
hostname: hostname::get()
|
||||
.map(|h| h.to_string_lossy().to_string())
|
||||
.unwrap_or_else(|_| "unknown".to_string()),
|
||||
registration_token: state.registration_token.clone(),
|
||||
os_version: get_os_info(),
|
||||
mac_address: None,
|
||||
};
|
||||
|
||||
let frame = Frame::new_json(MessageType::Register, ®ister)?;
|
||||
stream.write_all(&frame.encode()).await?;
|
||||
info!("Registration request sent");
|
||||
|
||||
let mut buffer = vec![0u8; 65536];
|
||||
let mut read_buf = Vec::with_capacity(65536);
|
||||
// Clamp heartbeat interval to sane range [5, 3600] to prevent CPU spin or effective disable
|
||||
let heartbeat_secs = state.config.heartbeat_interval_secs.clamp(5, 3600);
|
||||
let mut heartbeat_interval = tokio::time::interval(Duration::from_secs(heartbeat_secs));
|
||||
heartbeat_interval.tick().await; // Skip first tick
|
||||
|
||||
// HMAC key — set after receiving RegisterResponse
|
||||
let mut device_secret: Option<String> = state.device_secret.clone();
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
// Read from server
|
||||
result = stream.read(&mut buffer) => {
|
||||
let n = result?;
|
||||
if n == 0 {
|
||||
return Err(anyhow::anyhow!("Server closed connection"));
|
||||
}
|
||||
read_buf.extend_from_slice(&buffer[..n]);
|
||||
|
||||
// Guard against unbounded buffer growth from a malicious server
|
||||
if read_buf.len() > 1_048_576 {
|
||||
return Err(anyhow::anyhow!("Read buffer exceeded 1MB, server may be malicious"));
|
||||
}
|
||||
|
||||
// Process complete frames
|
||||
loop {
|
||||
match Frame::decode(&read_buf)? {
|
||||
Some(frame) => {
|
||||
let consumed = frame.encoded_size();
|
||||
read_buf.drain(..consumed);
|
||||
// Capture device_secret from registration response
|
||||
if frame.msg_type == MessageType::RegisterResponse {
|
||||
if let Ok(resp) = frame.decode_payload::<RegisterResponse>() {
|
||||
device_secret = Some(resp.device_secret.clone());
|
||||
crate::save_device_secret(&resp.device_secret);
|
||||
info!("Device secret received and persisted, HMAC enabled for heartbeats");
|
||||
}
|
||||
}
|
||||
handle_server_message(frame, plugins)?;
|
||||
}
|
||||
None => break, // Incomplete frame, wait for more data
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Send queued data
|
||||
frame = data_rx.recv() => {
|
||||
let frame = frame.ok_or_else(|| anyhow::anyhow!("Channel closed"))?;
|
||||
stream.write_all(&frame.encode()).await?;
|
||||
}
|
||||
|
||||
// Heartbeat
|
||||
_ = heartbeat_interval.tick() => {
|
||||
let timestamp = chrono::Utc::now().to_rfc3339();
|
||||
let hmac_value = compute_hmac(device_secret.as_deref(), &state.device_uid, ×tamp);
|
||||
let heartbeat = HeartbeatPayload {
|
||||
device_uid: state.device_uid.clone(),
|
||||
timestamp,
|
||||
hmac: hmac_value,
|
||||
};
|
||||
let frame = Frame::new_json(MessageType::Heartbeat, &heartbeat)?;
|
||||
stream.write_all(&frame.encode()).await?;
|
||||
debug!("Heartbeat sent (hmac={})", !heartbeat.hmac.is_empty());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_server_message(frame: Frame, plugins: &PluginChannels) -> Result<()> {
|
||||
match frame.msg_type {
|
||||
MessageType::RegisterResponse => {
|
||||
let resp: RegisterResponse = frame.decode_payload()
|
||||
.map_err(|e| anyhow::anyhow!("Invalid registration response: {}", e))?;
|
||||
info!("Registration accepted by server (server version: {})", resp.config.server_version);
|
||||
}
|
||||
MessageType::PolicyUpdate => {
|
||||
let policy: serde_json::Value = frame.decode_payload()
|
||||
.map_err(|e| anyhow::anyhow!("Invalid policy update: {}", e))?;
|
||||
info!("Received policy update: {}", policy);
|
||||
}
|
||||
MessageType::ConfigUpdate => {
|
||||
info!("Received config update");
|
||||
}
|
||||
MessageType::TaskExecute => {
|
||||
warn!("Task execution requested (not yet implemented)");
|
||||
}
|
||||
MessageType::WatermarkConfig => {
|
||||
let config: WatermarkConfigPayload = frame.decode_payload()
|
||||
.map_err(|e| anyhow::anyhow!("Invalid watermark config: {}", e))?;
|
||||
info!("Received watermark config: enabled={}", config.enabled);
|
||||
plugins.watermark_tx.send(Some(config))?;
|
||||
}
|
||||
MessageType::UsbPolicyUpdate => {
|
||||
let policy: UsbPolicyPayload = frame.decode_payload()
|
||||
.map_err(|e| anyhow::anyhow!("Invalid USB policy: {}", e))?;
|
||||
info!("Received USB policy: type={}, enabled={}", policy.policy_type, policy.enabled);
|
||||
plugins.usb_policy_tx.send(Some(policy))?;
|
||||
}
|
||||
MessageType::WebFilterRuleUpdate => {
|
||||
let payload: serde_json::Value = frame.decode_payload()
|
||||
.map_err(|e| anyhow::anyhow!("Invalid web filter update: {}", e))?;
|
||||
info!("Received web filter rules update");
|
||||
let rules: Vec<crate::web_filter::WebFilterRule> = payload.get("rules")
|
||||
.and_then(|r| serde_json::from_value(r.clone()).ok())
|
||||
.unwrap_or_default();
|
||||
let config = crate::web_filter::WebFilterConfig { enabled: true, rules };
|
||||
plugins.web_filter_tx.send(config)?;
|
||||
}
|
||||
MessageType::SoftwareBlacklist => {
|
||||
let payload: serde_json::Value = frame.decode_payload()
|
||||
.map_err(|e| anyhow::anyhow!("Invalid software blacklist: {}", e))?;
|
||||
info!("Received software blacklist update");
|
||||
let blacklist: Vec<crate::software_blocker::BlacklistEntry> = payload.get("blacklist")
|
||||
.and_then(|r| serde_json::from_value(r.clone()).ok())
|
||||
.unwrap_or_default();
|
||||
let config = crate::software_blocker::SoftwareBlockerConfig { enabled: true, blacklist };
|
||||
plugins.software_blocker_tx.send(config)?;
|
||||
}
|
||||
MessageType::PopupRules => {
|
||||
let payload: serde_json::Value = frame.decode_payload()
|
||||
.map_err(|e| anyhow::anyhow!("Invalid popup rules: {}", e))?;
|
||||
info!("Received popup blocker rules update");
|
||||
let rules: Vec<crate::popup_blocker::PopupRule> = payload.get("rules")
|
||||
.and_then(|r| serde_json::from_value(r.clone()).ok())
|
||||
.unwrap_or_default();
|
||||
let config = crate::popup_blocker::PopupBlockerConfig { enabled: true, rules };
|
||||
plugins.popup_blocker_tx.send(config)?;
|
||||
}
|
||||
MessageType::PluginEnable => {
|
||||
let payload: csm_protocol::PluginControlPayload = frame.decode_payload()
|
||||
.map_err(|e| anyhow::anyhow!("Invalid plugin enable: {}", e))?;
|
||||
info!("Plugin enabled: {}", payload.plugin_name);
|
||||
// Route to appropriate plugin channel based on plugin_name
|
||||
handle_plugin_control(&payload, plugins, true)?;
|
||||
}
|
||||
MessageType::PluginDisable => {
|
||||
let payload: csm_protocol::PluginControlPayload = frame.decode_payload()
|
||||
.map_err(|e| anyhow::anyhow!("Invalid plugin disable: {}", e))?;
|
||||
info!("Plugin disabled: {}", payload.plugin_name);
|
||||
handle_plugin_control(&payload, plugins, false)?;
|
||||
}
|
||||
_ => {
|
||||
debug!("Unhandled message type: {:?}", frame.msg_type);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn handle_plugin_control(
|
||||
payload: &csm_protocol::PluginControlPayload,
|
||||
plugins: &PluginChannels,
|
||||
enabled: bool,
|
||||
) -> Result<()> {
|
||||
match payload.plugin_name.as_str() {
|
||||
"watermark" => {
|
||||
if !enabled {
|
||||
// Send disabled config to remove overlay
|
||||
plugins.watermark_tx.send(None)?;
|
||||
}
|
||||
// When enabling, server will push the actual config next
|
||||
}
|
||||
"web_filter" => {
|
||||
if !enabled {
|
||||
// Clear hosts rules on disable
|
||||
plugins.web_filter_tx.send(crate::web_filter::WebFilterConfig { enabled: false, rules: vec![] })?;
|
||||
}
|
||||
// When enabling, server will push rules
|
||||
}
|
||||
"software_blocker" => {
|
||||
if !enabled {
|
||||
plugins.software_blocker_tx.send(crate::software_blocker::SoftwareBlockerConfig { enabled: false, blacklist: vec![] })?;
|
||||
}
|
||||
}
|
||||
"popup_blocker" => {
|
||||
if !enabled {
|
||||
plugins.popup_blocker_tx.send(crate::popup_blocker::PopupBlockerConfig { enabled: false, rules: vec![] })?;
|
||||
}
|
||||
}
|
||||
"usb_audit" => {
|
||||
if !enabled {
|
||||
plugins.usb_audit_tx.send(crate::usb_audit::UsbAuditConfig { enabled: false, monitored_extensions: vec![] })?;
|
||||
}
|
||||
}
|
||||
"usage_timer" => {
|
||||
if !enabled {
|
||||
plugins.usage_timer_tx.send(crate::usage_timer::UsageConfig { enabled: false, ..Default::default() })?;
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
warn!("Unknown plugin: {}", payload.plugin_name);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Compute HMAC-SHA256 for heartbeat verification.
|
||||
/// Format: HMAC-SHA256(device_secret, "{device_uid}\n{timestamp}")
|
||||
fn compute_hmac(secret: Option<&str>, device_uid: &str, timestamp: &str) -> String {
|
||||
let secret = match secret {
|
||||
Some(s) if !s.is_empty() => s,
|
||||
_ => return String::new(),
|
||||
};
|
||||
|
||||
type HmacSha256 = Hmac<Sha256>;
|
||||
|
||||
let message = format!("{}\n{}", device_uid, timestamp);
|
||||
let mut mac = match HmacSha256::new_from_slice(secret.as_bytes()) {
|
||||
Ok(m) => m,
|
||||
Err(_) => return String::new(),
|
||||
};
|
||||
mac.update(message.as_bytes());
|
||||
hex::encode(mac.finalize().into_bytes())
|
||||
}
|
||||
|
||||
fn get_os_info() -> String {
|
||||
use sysinfo::System;
|
||||
let name = System::name().unwrap_or_else(|| "Unknown".to_string());
|
||||
let version = System::os_version().unwrap_or_else(|| "Unknown".to_string());
|
||||
format!("{} {}", name, version)
|
||||
}
|
||||
370
crates/client/src/network/mod.rs.tmp.575580.1775308681874
Normal file
370
crates/client/src/network/mod.rs.tmp.575580.1775308681874
Normal file
@@ -0,0 +1,370 @@
|
||||
use anyhow::Result;
|
||||
use std::time::Duration;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tracing::{info, debug, warn};
|
||||
use csm_protocol::{Frame, MessageType, RegisterRequest, RegisterResponse, HeartbeatPayload, WatermarkConfigPayload};
|
||||
use hmac::{Hmac, Mac};
|
||||
use sha2::Sha256;
|
||||
|
||||
use crate::ClientState;
|
||||
|
||||
/// Maximum accumulated read buffer size per connection (8 MB)
|
||||
const MAX_READ_BUF_SIZE: usize = 8 * 1024 * 1024;
|
||||
|
||||
/// Holds senders for all plugin config channels
|
||||
pub struct PluginChannels {
|
||||
pub watermark_tx: tokio::sync::watch::Sender<Option<WatermarkConfigPayload>>,
|
||||
pub web_filter_tx: tokio::sync::watch::Sender<crate::web_filter::WebFilterConfig>,
|
||||
pub software_blocker_tx: tokio::sync::watch::Sender<crate::software_blocker::SoftwareBlockerConfig>,
|
||||
pub popup_blocker_tx: tokio::sync::watch::Sender<crate::popup_blocker::PopupBlockerConfig>,
|
||||
pub usb_audit_tx: tokio::sync::watch::Sender<crate::usb_audit::UsbAuditConfig>,
|
||||
pub usage_timer_tx: tokio::sync::watch::Sender<crate::usage_timer::UsageConfig>,
|
||||
}
|
||||
|
||||
/// Connect to server and run the main communication loop
|
||||
pub async fn connect_and_run(
|
||||
state: &ClientState,
|
||||
data_rx: &mut tokio::sync::mpsc::Receiver<Frame>,
|
||||
plugins: &PluginChannels,
|
||||
) -> Result<()> {
|
||||
let tcp_stream = TcpStream::connect(&state.server_addr).await?;
|
||||
info!("TCP connected to {}", state.server_addr);
|
||||
|
||||
if state.use_tls {
|
||||
let tls_stream = wrap_tls(tcp_stream, &state.server_addr).await?;
|
||||
run_comm_loop(tls_stream, state, data_rx, plugins).await
|
||||
} else {
|
||||
run_comm_loop(tcp_stream, state, data_rx, plugins).await
|
||||
}
|
||||
}
|
||||
|
||||
/// Wrap a TCP stream with TLS.
|
||||
/// Supports custom CA certificate via CSM_TLS_CA_CERT env var (path to PEM file).
|
||||
/// Supports skipping verification via CSM_TLS_SKIP_VERIFY=true (development only).
|
||||
async fn wrap_tls(stream: TcpStream, server_addr: &str) -> Result<tokio_rustls::client::TlsStream<TcpStream>> {
|
||||
let mut root_store = rustls::RootCertStore::empty();
|
||||
|
||||
// Load custom CA certificate if specified
|
||||
if let Ok(ca_path) = std::env::var("CSM_TLS_CA_CERT") {
|
||||
let ca_pem = std::fs::read(&ca_path)
|
||||
.map_err(|e| anyhow::anyhow!("Failed to read CA cert {}: {}", ca_path, e))?;
|
||||
let certs = rustls_pemfile::certs(&mut &ca_pem[..])
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
.map_err(|e| anyhow::anyhow!("Failed to parse CA cert: {:?}", e))?;
|
||||
for cert in certs {
|
||||
root_store.add(cert)?;
|
||||
}
|
||||
info!("Loaded custom CA certificates from {}", ca_path);
|
||||
}
|
||||
|
||||
// Always include system roots as fallback
|
||||
root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
|
||||
|
||||
let config = if std::env::var("CSM_TLS_SKIP_VERIFY").as_deref() == Ok("true") {
|
||||
warn!("TLS certificate verification DISABLED — do not use in production!");
|
||||
rustls::ClientConfig::builder()
|
||||
.dangerous()
|
||||
.with_custom_certificate_verifier(std::sync::Arc::new(NoVerifier))
|
||||
.with_no_client_auth()
|
||||
} else {
|
||||
rustls::ClientConfig::builder()
|
||||
.with_root_certificates(root_store)
|
||||
.with_no_client_auth()
|
||||
};
|
||||
|
||||
let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(config));
|
||||
|
||||
// Extract hostname from server_addr (strip port)
|
||||
let domain = server_addr.split(':').next().unwrap_or("localhost").to_string();
|
||||
let server_name = rustls_pki_types::ServerName::try_from(domain.clone())
|
||||
.map_err(|e| anyhow::anyhow!("Invalid TLS server name '{}': {:?}", domain, e))?;
|
||||
|
||||
let tls_stream = connector.connect(server_name, stream).await?;
|
||||
info!("TLS handshake completed with {}", domain);
|
||||
Ok(tls_stream)
|
||||
}
|
||||
|
||||
/// A no-op certificate verifier for development use (CSM_TLS_SKIP_VERIFY=true).
|
||||
#[derive(Debug)]
|
||||
struct NoVerifier;
|
||||
|
||||
impl rustls::client::danger::ServerCertVerifier for NoVerifier {
|
||||
fn verify_server_cert(
|
||||
&self,
|
||||
_end_entity: &rustls_pki_types::CertificateDer,
|
||||
_intermediates: &[rustls_pki_types::CertificateDer],
|
||||
_server_name: &rustls_pki_types::ServerName,
|
||||
_ocsp_response: &[u8],
|
||||
_now: rustls_pki_types::UnixTime,
|
||||
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
|
||||
Ok(rustls::client::danger::ServerCertVerified::assertion())
|
||||
}
|
||||
|
||||
fn verify_tls12_signature(
|
||||
&self,
|
||||
_message: &[u8],
|
||||
_cert: &rustls_pki_types::CertificateDer,
|
||||
_dss: &rustls::DigitallySignedStruct,
|
||||
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
|
||||
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
|
||||
}
|
||||
|
||||
fn verify_tls13_signature(
|
||||
&self,
|
||||
_message: &[u8],
|
||||
_cert: &rustls_pki_types::CertificateDer,
|
||||
_dss: &rustls::DigitallySignedStruct,
|
||||
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
|
||||
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
|
||||
}
|
||||
|
||||
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
|
||||
vec![
|
||||
rustls::SignatureScheme::RSA_PKCS1_SHA256,
|
||||
rustls::SignatureScheme::RSA_PKCS1_SHA384,
|
||||
rustls::SignatureScheme::RSA_PKCS1_SHA512,
|
||||
rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
|
||||
rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
|
||||
rustls::SignatureScheme::RSA_PSS_SHA256,
|
||||
rustls::SignatureScheme::RSA_PSS_SHA384,
|
||||
rustls::SignatureScheme::RSA_PSS_SHA512,
|
||||
rustls::SignatureScheme::ED25519,
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
/// Main communication loop over any read+write stream
|
||||
async fn run_comm_loop<S>(
|
||||
mut stream: S,
|
||||
state: &ClientState,
|
||||
data_rx: &mut tokio::sync::mpsc::Receiver<Frame>,
|
||||
plugins: &PluginChannels,
|
||||
) -> Result<()>
|
||||
where
|
||||
S: AsyncReadExt + AsyncWriteExt + Unpin,
|
||||
{
|
||||
// Send registration
|
||||
let register = RegisterRequest {
|
||||
device_uid: state.device_uid.clone(),
|
||||
hostname: hostname::get()
|
||||
.map(|h| h.to_string_lossy().to_string())
|
||||
.unwrap_or_else(|_| "unknown".to_string()),
|
||||
registration_token: state.registration_token.clone(),
|
||||
os_version: get_os_info(),
|
||||
mac_address: None,
|
||||
};
|
||||
|
||||
let frame = Frame::new_json(MessageType::Register, ®ister)?;
|
||||
stream.write_all(&frame.encode()).await?;
|
||||
info!("Registration request sent");
|
||||
|
||||
let mut buffer = vec![0u8; 65536];
|
||||
let mut read_buf = Vec::with_capacity(65536);
|
||||
let heartbeat_secs = state.config.heartbeat_interval_secs;
|
||||
let mut heartbeat_interval = tokio::time::interval(Duration::from_secs(heartbeat_secs));
|
||||
heartbeat_interval.tick().await; // Skip first tick
|
||||
|
||||
// HMAC key — set after receiving RegisterResponse
|
||||
let mut device_secret: Option<String> = state.device_secret.clone();
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
// Read from server
|
||||
result = stream.read(&mut buffer) => {
|
||||
let n = result?;
|
||||
if n == 0 {
|
||||
return Err(anyhow::anyhow!("Server closed connection"));
|
||||
}
|
||||
read_buf.extend_from_slice(&buffer[..n]);
|
||||
|
||||
// Process complete frames
|
||||
loop {
|
||||
match Frame::decode(&read_buf)? {
|
||||
Some(frame) => {
|
||||
let consumed = frame.encoded_size();
|
||||
read_buf.drain(..consumed);
|
||||
// Capture device_secret from registration response
|
||||
if frame.msg_type == MessageType::RegisterResponse {
|
||||
if let Ok(resp) = frame.decode_payload::<RegisterResponse>() {
|
||||
device_secret = Some(resp.device_secret.clone());
|
||||
crate::save_device_secret(&resp.device_secret);
|
||||
info!("Device secret received and persisted, HMAC enabled for heartbeats");
|
||||
}
|
||||
}
|
||||
handle_server_message(frame, plugins)?;
|
||||
}
|
||||
None => break, // Incomplete frame, wait for more data
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Send queued data
|
||||
frame = data_rx.recv() => {
|
||||
let frame = frame.ok_or_else(|| anyhow::anyhow!("Channel closed"))?;
|
||||
stream.write_all(&frame.encode()).await?;
|
||||
}
|
||||
|
||||
// Heartbeat
|
||||
_ = heartbeat_interval.tick() => {
|
||||
let timestamp = chrono::Utc::now().to_rfc3339();
|
||||
let hmac_value = compute_hmac(device_secret.as_deref(), &state.device_uid, ×tamp);
|
||||
let heartbeat = HeartbeatPayload {
|
||||
device_uid: state.device_uid.clone(),
|
||||
timestamp,
|
||||
hmac: hmac_value,
|
||||
};
|
||||
let frame = Frame::new_json(MessageType::Heartbeat, &heartbeat)?;
|
||||
stream.write_all(&frame.encode()).await?;
|
||||
debug!("Heartbeat sent (hmac={})", !heartbeat.hmac.is_empty());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_server_message(frame: Frame, plugins: &PluginChannels) -> Result<()> {
|
||||
match frame.msg_type {
|
||||
MessageType::RegisterResponse => {
|
||||
let resp: RegisterResponse = frame.decode_payload()
|
||||
.map_err(|e| anyhow::anyhow!("Invalid registration response: {}", e))?;
|
||||
info!("Registration accepted by server (server version: {})", resp.config.server_version);
|
||||
}
|
||||
MessageType::PolicyUpdate => {
|
||||
let policy: serde_json::Value = frame.decode_payload()
|
||||
.map_err(|e| anyhow::anyhow!("Invalid policy update: {}", e))?;
|
||||
info!("Received policy update: {}", policy);
|
||||
}
|
||||
MessageType::ConfigUpdate => {
|
||||
info!("Received config update");
|
||||
}
|
||||
MessageType::TaskExecute => {
|
||||
warn!("Task execution requested (not yet implemented)");
|
||||
}
|
||||
MessageType::WatermarkConfig => {
|
||||
let config: WatermarkConfigPayload = frame.decode_payload()
|
||||
.map_err(|e| anyhow::anyhow!("Invalid watermark config: {}", e))?;
|
||||
info!("Received watermark config: enabled={}", config.enabled);
|
||||
plugins.watermark_tx.send(Some(config))?;
|
||||
}
|
||||
MessageType::WebFilterRuleUpdate => {
|
||||
let payload: serde_json::Value = frame.decode_payload()
|
||||
.map_err(|e| anyhow::anyhow!("Invalid web filter update: {}", e))?;
|
||||
info!("Received web filter rules update");
|
||||
let rules: Vec<crate::web_filter::WebFilterRule> = payload.get("rules")
|
||||
.and_then(|r| serde_json::from_value(r.clone()).ok())
|
||||
.unwrap_or_default();
|
||||
let config = crate::web_filter::WebFilterConfig { enabled: true, rules };
|
||||
plugins.web_filter_tx.send(config)?;
|
||||
}
|
||||
MessageType::SoftwareBlacklist => {
|
||||
let payload: serde_json::Value = frame.decode_payload()
|
||||
.map_err(|e| anyhow::anyhow!("Invalid software blacklist: {}", e))?;
|
||||
info!("Received software blacklist update");
|
||||
let blacklist: Vec<crate::software_blocker::BlacklistEntry> = payload.get("blacklist")
|
||||
.and_then(|r| serde_json::from_value(r.clone()).ok())
|
||||
.unwrap_or_default();
|
||||
let config = crate::software_blocker::SoftwareBlockerConfig { enabled: true, blacklist };
|
||||
plugins.software_blocker_tx.send(config)?;
|
||||
}
|
||||
MessageType::PopupRules => {
|
||||
let payload: serde_json::Value = frame.decode_payload()
|
||||
.map_err(|e| anyhow::anyhow!("Invalid popup rules: {}", e))?;
|
||||
info!("Received popup blocker rules update");
|
||||
let rules: Vec<crate::popup_blocker::PopupRule> = payload.get("rules")
|
||||
.and_then(|r| serde_json::from_value(r.clone()).ok())
|
||||
.unwrap_or_default();
|
||||
let config = crate::popup_blocker::PopupBlockerConfig { enabled: true, rules };
|
||||
plugins.popup_blocker_tx.send(config)?;
|
||||
}
|
||||
MessageType::PluginEnable => {
|
||||
let payload: csm_protocol::PluginControlPayload = frame.decode_payload()
|
||||
.map_err(|e| anyhow::anyhow!("Invalid plugin enable: {}", e))?;
|
||||
info!("Plugin enabled: {}", payload.plugin_name);
|
||||
// Route to appropriate plugin channel based on plugin_name
|
||||
handle_plugin_control(&payload, plugins, true)?;
|
||||
}
|
||||
MessageType::PluginDisable => {
|
||||
let payload: csm_protocol::PluginControlPayload = frame.decode_payload()
|
||||
.map_err(|e| anyhow::anyhow!("Invalid plugin disable: {}", e))?;
|
||||
info!("Plugin disabled: {}", payload.plugin_name);
|
||||
handle_plugin_control(&payload, plugins, false)?;
|
||||
}
|
||||
_ => {
|
||||
debug!("Unhandled message type: {:?}", frame.msg_type);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn handle_plugin_control(
|
||||
payload: &csm_protocol::PluginControlPayload,
|
||||
plugins: &PluginChannels,
|
||||
enabled: bool,
|
||||
) -> Result<()> {
|
||||
match payload.plugin_name.as_str() {
|
||||
"watermark" => {
|
||||
if !enabled {
|
||||
// Send disabled config to remove overlay
|
||||
plugins.watermark_tx.send(None)?;
|
||||
}
|
||||
// When enabling, server will push the actual config next
|
||||
}
|
||||
"web_filter" => {
|
||||
if !enabled {
|
||||
// Clear hosts rules on disable
|
||||
plugins.web_filter_tx.send(crate::web_filter::WebFilterConfig { enabled: false, rules: vec![] })?;
|
||||
}
|
||||
// When enabling, server will push rules
|
||||
}
|
||||
"software_blocker" => {
|
||||
if !enabled {
|
||||
plugins.software_blocker_tx.send(crate::software_blocker::SoftwareBlockerConfig { enabled: false, blacklist: vec![] })?;
|
||||
}
|
||||
}
|
||||
"popup_blocker" => {
|
||||
if !enabled {
|
||||
plugins.popup_blocker_tx.send(crate::popup_blocker::PopupBlockerConfig { enabled: false, rules: vec![] })?;
|
||||
}
|
||||
}
|
||||
"usb_audit" => {
|
||||
if !enabled {
|
||||
plugins.usb_audit_tx.send(crate::usb_audit::UsbAuditConfig { enabled: false, monitored_extensions: vec![] })?;
|
||||
}
|
||||
}
|
||||
"usage_timer" => {
|
||||
if !enabled {
|
||||
plugins.usage_timer_tx.send(crate::usage_timer::UsageConfig { enabled: false, ..Default::default() })?;
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
warn!("Unknown plugin: {}", payload.plugin_name);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Compute HMAC-SHA256 for heartbeat verification.
|
||||
/// Format: HMAC-SHA256(device_secret, "{device_uid}\n{timestamp}")
|
||||
fn compute_hmac(secret: Option<&str>, device_uid: &str, timestamp: &str) -> String {
|
||||
let secret = match secret {
|
||||
Some(s) if !s.is_empty() => s,
|
||||
_ => return String::new(),
|
||||
};
|
||||
|
||||
type HmacSha256 = Hmac<Sha256>;
|
||||
|
||||
let message = format!("{}\n{}", device_uid, timestamp);
|
||||
let mut mac = match HmacSha256::new_from_slice(secret.as_bytes()) {
|
||||
Ok(m) => m,
|
||||
Err(_) => return String::new(),
|
||||
};
|
||||
mac.update(message.as_bytes());
|
||||
hex::encode(mac.finalize().into_bytes())
|
||||
}
|
||||
|
||||
fn get_os_info() -> String {
|
||||
use sysinfo::System;
|
||||
let name = System::name().unwrap_or_else(|| "Unknown".to_string());
|
||||
let version = System::os_version().unwrap_or_else(|| "Unknown".to_string());
|
||||
format!("{} {}", name, version)
|
||||
}
|
||||
254
crates/client/src/popup_blocker/mod.rs
Normal file
254
crates/client/src/popup_blocker/mod.rs
Normal file
@@ -0,0 +1,254 @@
|
||||
use tokio::sync::watch;
|
||||
use tracing::{info, debug};
|
||||
use serde::Deserialize;
|
||||
|
||||
/// Popup blocker rule from server
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct PopupRule {
|
||||
pub id: i64,
|
||||
pub rule_type: String,
|
||||
pub window_title: Option<String>,
|
||||
pub window_class: Option<String>,
|
||||
pub process_name: Option<String>,
|
||||
}
|
||||
|
||||
/// Popup blocker configuration
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct PopupBlockerConfig {
|
||||
pub enabled: bool,
|
||||
pub rules: Vec<PopupRule>,
|
||||
}
|
||||
|
||||
/// Context passed to EnumWindows callback via LPARAM
|
||||
struct ScanContext {
|
||||
rules: Vec<PopupRule>,
|
||||
blocked_count: u32,
|
||||
}
|
||||
|
||||
/// Start popup blocker plugin.
|
||||
/// Periodically enumerates windows and closes those matching rules.
|
||||
pub async fn start(mut config_rx: watch::Receiver<PopupBlockerConfig>) {
|
||||
info!("Popup blocker plugin started");
|
||||
let mut config = PopupBlockerConfig::default();
|
||||
let mut scan_interval = tokio::time::interval(std::time::Duration::from_secs(2));
|
||||
scan_interval.tick().await;
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
result = config_rx.changed() => {
|
||||
if result.is_err() {
|
||||
break;
|
||||
}
|
||||
let new_config = config_rx.borrow_and_update().clone();
|
||||
info!("Popup blocker config updated: enabled={}, rules={}", new_config.enabled, new_config.rules.len());
|
||||
config = new_config;
|
||||
}
|
||||
_ = scan_interval.tick() => {
|
||||
if !config.enabled || config.rules.is_empty() {
|
||||
continue;
|
||||
}
|
||||
scan_and_block(&config.rules);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn scan_and_block(rules: &[PopupRule]) {
|
||||
#[cfg(target_os = "windows")]
|
||||
{
|
||||
use windows::Win32::UI::WindowsAndMessaging::EnumWindows;
|
||||
use windows::Win32::Foundation::LPARAM;
|
||||
|
||||
let mut ctx = ScanContext {
|
||||
rules: rules.to_vec(),
|
||||
blocked_count: 0,
|
||||
};
|
||||
|
||||
unsafe {
|
||||
let _ = EnumWindows(
|
||||
Some(enum_windows_callback),
|
||||
LPARAM(&mut ctx as *mut ScanContext as isize),
|
||||
);
|
||||
}
|
||||
if ctx.blocked_count > 0 {
|
||||
debug!("Popup scan blocked {} windows", ctx.blocked_count);
|
||||
}
|
||||
}
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
{
|
||||
let _ = rules;
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(target_os = "windows")]
|
||||
unsafe extern "system" fn enum_windows_callback(
|
||||
hwnd: windows::Win32::Foundation::HWND,
|
||||
lparam: windows::Win32::Foundation::LPARAM,
|
||||
) -> windows::Win32::Foundation::BOOL {
|
||||
use windows::Win32::UI::WindowsAndMessaging::*;
|
||||
use windows::Win32::Foundation::*;
|
||||
|
||||
// Only consider visible, top-level windows without an owner
|
||||
if !IsWindowVisible(hwnd).as_bool() {
|
||||
return BOOL(1);
|
||||
}
|
||||
|
||||
// Skip windows that have an owner (they're child dialogs, not popups)
|
||||
if GetWindow(hwnd, GW_OWNER).0 != 0 {
|
||||
return BOOL(1);
|
||||
}
|
||||
|
||||
// Get window title
|
||||
let mut title_buf = [0u16; 512];
|
||||
let title_len = GetWindowTextW(hwnd, &mut title_buf);
|
||||
let title_len = title_len.max(0) as usize;
|
||||
let title = String::from_utf16_lossy(&title_buf[..title_len]);
|
||||
|
||||
// Skip windows with empty titles
|
||||
if title.is_empty() {
|
||||
return BOOL(1);
|
||||
}
|
||||
|
||||
// Get class name
|
||||
let mut class_buf = [0u16; 256];
|
||||
let class_len = GetClassNameW(hwnd, &mut class_buf);
|
||||
let class_len = class_len.max(0) as usize;
|
||||
let class_name = String::from_utf16_lossy(&class_buf[..class_len]);
|
||||
|
||||
// Get process name from PID
|
||||
let mut pid: u32 = 0;
|
||||
GetWindowThreadProcessId(hwnd, Some(&mut pid));
|
||||
let process_name = get_process_name(pid);
|
||||
|
||||
// Recover the ScanContext from LPARAM
|
||||
let ctx = &mut *(lparam.0 as *mut ScanContext);
|
||||
|
||||
// Check against each rule
|
||||
for rule in &ctx.rules {
|
||||
if rule.rule_type != "block" {
|
||||
continue;
|
||||
}
|
||||
|
||||
let matches = rule_matches(rule, &title, &class_name, &process_name);
|
||||
if matches {
|
||||
let _ = PostMessageW(hwnd, WM_CLOSE, WPARAM(0), LPARAM(0));
|
||||
ctx.blocked_count += 1;
|
||||
info!(
|
||||
"Blocked popup: title='{}' class='{}' process='{}' (rule_id={})",
|
||||
title, class_name, process_name, rule.id
|
||||
);
|
||||
break; // One match is enough per window
|
||||
}
|
||||
}
|
||||
|
||||
BOOL(1) // Continue enumeration
|
||||
}
|
||||
|
||||
fn rule_matches(rule: &PopupRule, title: &str, class_name: &str, process_name: &str) -> bool {
|
||||
let title_match = match &rule.window_title {
|
||||
Some(pattern) => pattern_match(pattern, title),
|
||||
None => true, // No title filter = match all
|
||||
};
|
||||
|
||||
let class_match = match &rule.window_class {
|
||||
Some(pattern) => pattern_match(pattern, class_name),
|
||||
None => true,
|
||||
};
|
||||
|
||||
let process_match = match &rule.process_name {
|
||||
Some(pattern) => pattern_match(pattern, process_name),
|
||||
None => true,
|
||||
};
|
||||
|
||||
title_match && class_match && process_match
|
||||
}
|
||||
|
||||
/// Simple case-insensitive wildcard pattern matching.
|
||||
/// Supports `*` as wildcard (matches any characters).
|
||||
fn pattern_match(pattern: &str, text: &str) -> bool {
|
||||
let p = pattern.to_lowercase();
|
||||
let t = text.to_lowercase();
|
||||
|
||||
if !p.contains('*') {
|
||||
return t.contains(&p);
|
||||
}
|
||||
|
||||
let parts: Vec<&str> = p.split('*').collect();
|
||||
if parts.is_empty() {
|
||||
return true;
|
||||
}
|
||||
|
||||
let mut pos = 0usize;
|
||||
let mut matched_any = false;
|
||||
|
||||
for (i, part) in parts.iter().enumerate() {
|
||||
if part.is_empty() {
|
||||
// Leading empty = pattern starts with * → no start anchor
|
||||
// Trailing empty = pattern ends with * → no end anchor
|
||||
continue;
|
||||
}
|
||||
|
||||
matched_any = true;
|
||||
|
||||
if i == 0 && !parts[0].is_empty() {
|
||||
// Pattern starts with literal → must match at start
|
||||
if !t.starts_with(part) {
|
||||
return false;
|
||||
}
|
||||
pos = part.len();
|
||||
} else {
|
||||
// Find this segment anywhere after current position
|
||||
match t[pos..].find(part) {
|
||||
Some(idx) => pos += idx + part.len(),
|
||||
None => return false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If pattern ends with literal (no trailing *), must match at end
|
||||
if matched_any && !parts.last().map_or(true, |p| p.is_empty()) {
|
||||
return t.ends_with(parts.last().unwrap());
|
||||
}
|
||||
true
|
||||
}
|
||||
|
||||
#[cfg(target_os = "windows")]
|
||||
fn get_process_name(pid: u32) -> String {
|
||||
use windows::Win32::System::Diagnostics::ToolHelp::*;
|
||||
use windows::Win32::Foundation::CloseHandle;
|
||||
|
||||
unsafe {
|
||||
let snapshot = match CreateToolhelp32Snapshot(TH32CS_SNAPPROCESS, 0) {
|
||||
Ok(s) => s,
|
||||
Err(_) => return format!("pid:{}", pid),
|
||||
};
|
||||
|
||||
let mut entry = PROCESSENTRY32W {
|
||||
dwSize: std::mem::size_of::<PROCESSENTRY32W>() as u32,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
if Process32FirstW(snapshot, &mut entry).is_ok() {
|
||||
loop {
|
||||
if entry.th32ProcessID == pid {
|
||||
let name = String::from_utf16_lossy(
|
||||
&entry.szExeFile.iter().take_while(|&&c| c != 0).copied().collect::<Vec<u16>>()
|
||||
);
|
||||
let _ = CloseHandle(snapshot);
|
||||
return name;
|
||||
}
|
||||
entry.dwSize = std::mem::size_of::<PROCESSENTRY32W>() as u32;
|
||||
if Process32NextW(snapshot, &mut entry).is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
let _ = CloseHandle(snapshot);
|
||||
format!("pid:{}", pid)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
fn get_process_name(pid: u32) -> String {
|
||||
format!("pid:{}", pid)
|
||||
}
|
||||
254
crates/client/src/software_blocker/mod.rs
Normal file
254
crates/client/src/software_blocker/mod.rs
Normal file
@@ -0,0 +1,254 @@
|
||||
use tokio::sync::watch;
|
||||
use tracing::{info, warn};
|
||||
use csm_protocol::{Frame, MessageType, SoftwareViolationReport};
|
||||
use serde::Deserialize;
|
||||
|
||||
/// System-critical processes that must never be killed regardless of server rules.
|
||||
/// Killing any of these would cause system instability or a BSOD.
|
||||
const PROTECTED_PROCESSES: &[&str] = &[
|
||||
"system",
|
||||
"system idle process",
|
||||
"svchost.exe",
|
||||
"lsass.exe",
|
||||
"csrss.exe",
|
||||
"wininit.exe",
|
||||
"winlogon.exe",
|
||||
"services.exe",
|
||||
"dwm.exe",
|
||||
"explorer.exe",
|
||||
"taskhostw.exe",
|
||||
"registry",
|
||||
"smss.exe",
|
||||
"conhost.exe",
|
||||
];
|
||||
|
||||
/// Software blacklist entry from server
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct BlacklistEntry {
|
||||
pub id: i64,
|
||||
pub name_pattern: String,
|
||||
pub action: String,
|
||||
}
|
||||
|
||||
/// Software blocker configuration
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct SoftwareBlockerConfig {
|
||||
pub enabled: bool,
|
||||
pub blacklist: Vec<BlacklistEntry>,
|
||||
}
|
||||
|
||||
/// Start software blocker plugin.
|
||||
/// Periodically scans running processes against the blacklist.
|
||||
pub async fn start(
|
||||
mut config_rx: watch::Receiver<SoftwareBlockerConfig>,
|
||||
data_tx: tokio::sync::mpsc::Sender<Frame>,
|
||||
device_uid: String,
|
||||
) {
|
||||
info!("Software blocker plugin started");
|
||||
let mut config = SoftwareBlockerConfig::default();
|
||||
let mut scan_interval = tokio::time::interval(std::time::Duration::from_secs(10));
|
||||
scan_interval.tick().await;
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
result = config_rx.changed() => {
|
||||
if result.is_err() {
|
||||
break;
|
||||
}
|
||||
let new_config = config_rx.borrow_and_update().clone();
|
||||
info!("Software blocker config updated: enabled={}, blacklist={}", new_config.enabled, new_config.blacklist.len());
|
||||
config = new_config;
|
||||
}
|
||||
_ = scan_interval.tick() => {
|
||||
if !config.enabled || config.blacklist.is_empty() {
|
||||
continue;
|
||||
}
|
||||
scan_processes(&config.blacklist, &data_tx, &device_uid).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn scan_processes(
|
||||
blacklist: &[BlacklistEntry],
|
||||
data_tx: &tokio::sync::mpsc::Sender<Frame>,
|
||||
device_uid: &str,
|
||||
) {
|
||||
let running = get_running_processes_with_pids();
|
||||
|
||||
for entry in blacklist {
|
||||
for (process_name, pid) in &running {
|
||||
if pattern_matches(&entry.name_pattern, process_name) {
|
||||
// Never kill system-critical processes
|
||||
if is_protected_process(process_name) {
|
||||
warn!("Blacklisted match '{}' skipped — system-critical process (pid={})", process_name, pid);
|
||||
continue;
|
||||
}
|
||||
|
||||
warn!("Blacklisted software detected: {} (action: {})", process_name, entry.action);
|
||||
|
||||
// Report violation to server
|
||||
// Map action to DB-compatible values: "block" -> "blocked_install", "alert" -> "alerted"
|
||||
let action_taken = match entry.action.as_str() {
|
||||
"block" => "blocked_install",
|
||||
"alert" => "alerted",
|
||||
other => other,
|
||||
};
|
||||
let violation = SoftwareViolationReport {
|
||||
device_uid: device_uid.to_string(),
|
||||
software_name: process_name.clone(),
|
||||
action_taken: action_taken.to_string(),
|
||||
timestamp: chrono::Utc::now().to_rfc3339(),
|
||||
};
|
||||
if let Ok(frame) = Frame::new_json(MessageType::SoftwareViolation, &violation) {
|
||||
let _ = data_tx.send(frame).await;
|
||||
}
|
||||
|
||||
// Kill the process directly by captured PID (avoids TOCTOU race)
|
||||
if entry.action == "block" {
|
||||
kill_process_by_pid(*pid, process_name);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn pattern_matches(pattern: &str, name: &str) -> bool {
|
||||
let pattern_lower = pattern.to_lowercase();
|
||||
let name_lower = name.to_lowercase();
|
||||
// Support wildcard patterns
|
||||
if pattern_lower.contains('*') {
|
||||
let parts: Vec<&str> = pattern_lower.split('*').collect();
|
||||
let mut pos = 0;
|
||||
|
||||
for (i, part) in parts.iter().enumerate() {
|
||||
if part.is_empty() {
|
||||
continue;
|
||||
}
|
||||
if i == 0 && !parts[0].is_empty() {
|
||||
// Pattern starts with literal → must match at start
|
||||
if !name_lower.starts_with(part) {
|
||||
return false;
|
||||
}
|
||||
pos = part.len();
|
||||
} else {
|
||||
match name_lower[pos..].find(part) {
|
||||
Some(idx) => pos += idx + part.len(),
|
||||
None => return false,
|
||||
}
|
||||
}
|
||||
}
|
||||
// If pattern ends with literal (no trailing *), must match at end
|
||||
if !parts.last().map_or(true, |p| p.is_empty()) {
|
||||
return name_lower.ends_with(parts.last().unwrap());
|
||||
}
|
||||
true
|
||||
} else {
|
||||
name_lower.contains(&pattern_lower)
|
||||
}
|
||||
}
|
||||
|
||||
/// Get all running processes with their PIDs (single snapshot, no TOCTOU)
|
||||
fn get_running_processes_with_pids() -> Vec<(String, u32)> {
|
||||
#[cfg(target_os = "windows")]
|
||||
{
|
||||
use windows::Win32::System::Diagnostics::ToolHelp::*;
|
||||
use windows::Win32::Foundation::CloseHandle;
|
||||
|
||||
let mut procs = Vec::new();
|
||||
unsafe {
|
||||
let snapshot = match CreateToolhelp32Snapshot(TH32CS_SNAPPROCESS, 0) {
|
||||
Ok(s) => s,
|
||||
Err(_) => return procs,
|
||||
};
|
||||
|
||||
let mut entry = PROCESSENTRY32W {
|
||||
dwSize: std::mem::size_of::<PROCESSENTRY32W>() as u32,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
if Process32FirstW(snapshot, &mut entry).is_ok() {
|
||||
loop {
|
||||
let name = String::from_utf16_lossy(
|
||||
&entry.szExeFile.iter().take_while(|&&c| c != 0).copied().collect::<Vec<u16>>()
|
||||
);
|
||||
procs.push((name, entry.th32ProcessID));
|
||||
entry.dwSize = std::mem::size_of::<PROCESSENTRY32W>() as u32;
|
||||
if !Process32NextW(snapshot, &mut entry).is_ok() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
let _ = CloseHandle(snapshot);
|
||||
}
|
||||
procs
|
||||
}
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
{
|
||||
Vec::new()
|
||||
}
|
||||
}
|
||||
|
||||
fn kill_process_by_pid(pid: u32, expected_name: &str) {
|
||||
#[cfg(target_os = "windows")]
|
||||
{
|
||||
use windows::Win32::System::Threading::{OpenProcess, TerminateProcess, PROCESS_TERMINATE};
|
||||
use windows::Win32::Foundation::CloseHandle;
|
||||
use windows::Win32::System::Diagnostics::ToolHelp::*;
|
||||
|
||||
unsafe {
|
||||
// Verify the PID still belongs to the expected process (prevent PID reuse kills)
|
||||
let snapshot = match CreateToolhelp32Snapshot(TH32CS_SNAPPROCESS, 0) {
|
||||
Ok(s) => s,
|
||||
Err(_) => return,
|
||||
};
|
||||
let mut entry = PROCESSENTRY32W {
|
||||
dwSize: std::mem::size_of::<PROCESSENTRY32W>() as u32,
|
||||
..Default::default()
|
||||
};
|
||||
let mut name_matches = false;
|
||||
if Process32FirstW(snapshot, &mut entry).is_ok() {
|
||||
loop {
|
||||
if entry.th32ProcessID == pid {
|
||||
let current_name = String::from_utf16_lossy(
|
||||
&entry.szExeFile.iter().take_while(|&&c| c != 0).copied().collect::<Vec<u16>>()
|
||||
);
|
||||
name_matches = current_name.to_lowercase() == expected_name.to_lowercase();
|
||||
break;
|
||||
}
|
||||
entry.dwSize = std::mem::size_of::<PROCESSENTRY32W>() as u32;
|
||||
if !Process32NextW(snapshot, &mut entry).is_ok() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
let _ = CloseHandle(snapshot);
|
||||
|
||||
if !name_matches {
|
||||
warn!("Skipping kill: PID {} no longer matches expected process '{}'", pid, expected_name);
|
||||
return;
|
||||
}
|
||||
|
||||
if let Ok(handle) = OpenProcess(PROCESS_TERMINATE, false, pid) {
|
||||
let terminated = TerminateProcess(handle, 1).is_ok();
|
||||
let _ = CloseHandle(handle);
|
||||
if terminated {
|
||||
warn!("Killed process: {} (pid={})", expected_name, pid);
|
||||
} else {
|
||||
warn!("TerminateProcess failed for: {} (pid={})", expected_name, pid);
|
||||
}
|
||||
} else {
|
||||
warn!("OpenProcess failed for: {} (pid={}) — insufficient privileges?", expected_name, pid);
|
||||
}
|
||||
}
|
||||
}
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
{
|
||||
let _ = (pid, expected_name);
|
||||
}
|
||||
}
|
||||
|
||||
fn is_protected_process(name: &str) -> bool {
|
||||
let lower = name.to_lowercase();
|
||||
PROTECTED_PROCESSES.iter().any(|p| lower == **p || lower.ends_with(&format!("/{}", p).replace('/', "\\")))
|
||||
}
|
||||
197
crates/client/src/usage_timer/mod.rs
Normal file
197
crates/client/src/usage_timer/mod.rs
Normal file
@@ -0,0 +1,197 @@
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio::sync::watch;
|
||||
use tracing::{info, debug};
|
||||
use csm_protocol::{Frame, MessageType, UsageDailyReport, AppUsageEntry};
|
||||
|
||||
/// Usage tracking configuration pushed from server
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct UsageConfig {
|
||||
pub enabled: bool,
|
||||
pub idle_threshold_secs: u64,
|
||||
pub report_interval_secs: u64,
|
||||
}
|
||||
|
||||
/// Start the usage timer plugin.
|
||||
/// Tracks active/idle time and foreground application usage.
|
||||
pub async fn start(
|
||||
mut config_rx: watch::Receiver<UsageConfig>,
|
||||
data_tx: tokio::sync::mpsc::Sender<Frame>,
|
||||
device_uid: String,
|
||||
) {
|
||||
info!("Usage timer plugin started");
|
||||
|
||||
let mut config = UsageConfig {
|
||||
enabled: false,
|
||||
idle_threshold_secs: 300, // 5 minutes default
|
||||
report_interval_secs: 300,
|
||||
};
|
||||
|
||||
let mut report_interval = tokio::time::interval(Duration::from_secs(60));
|
||||
report_interval.tick().await;
|
||||
|
||||
let mut active_secs: u64 = 0;
|
||||
let mut idle_secs: u64 = 0;
|
||||
let mut app_usage: std::collections::HashMap<String, u64> = std::collections::HashMap::new();
|
||||
let mut last_tick = Instant::now();
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
result = config_rx.changed() => {
|
||||
if result.is_err() {
|
||||
break;
|
||||
}
|
||||
let new_config = config_rx.borrow_and_update().clone();
|
||||
if new_config.enabled != config.enabled {
|
||||
info!("Usage timer enabled: {}", new_config.enabled);
|
||||
}
|
||||
config = new_config;
|
||||
if config.enabled {
|
||||
report_interval = tokio::time::interval(Duration::from_secs(config.report_interval_secs));
|
||||
report_interval.tick().await;
|
||||
last_tick = Instant::now();
|
||||
}
|
||||
}
|
||||
_ = report_interval.tick() => {
|
||||
if !config.enabled {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Measure actual elapsed time since last tick
|
||||
let now = Instant::now();
|
||||
let elapsed_secs = now.duration_since(last_tick).as_secs();
|
||||
last_tick = now;
|
||||
|
||||
let idle_ms = get_idle_millis();
|
||||
let is_idle = idle_ms as u64 > config.idle_threshold_secs * 1000;
|
||||
|
||||
if is_idle {
|
||||
idle_secs += elapsed_secs;
|
||||
} else {
|
||||
active_secs += elapsed_secs;
|
||||
}
|
||||
|
||||
// Track foreground app
|
||||
if let Some(app) = get_foreground_app_name() {
|
||||
*app_usage.entry(app).or_insert(0) += elapsed_secs;
|
||||
}
|
||||
|
||||
// Report usage to server
|
||||
if active_secs > 0 || idle_secs > 0 {
|
||||
let report = UsageDailyReport {
|
||||
device_uid: device_uid.clone(),
|
||||
date: chrono::Local::now().format("%Y-%m-%d").to_string(),
|
||||
total_active_minutes: (active_secs / 60) as u32,
|
||||
total_idle_minutes: (idle_secs / 60) as u32,
|
||||
first_active_at: None,
|
||||
last_active_at: Some(chrono::Utc::now().to_rfc3339()),
|
||||
};
|
||||
if let Ok(frame) = Frame::new_json(MessageType::UsageReport, &report) {
|
||||
if data_tx.send(frame).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Report per-app usage
|
||||
for (app, secs) in &app_usage {
|
||||
let entry = AppUsageEntry {
|
||||
device_uid: device_uid.clone(),
|
||||
date: chrono::Local::now().format("%Y-%m-%d").to_string(),
|
||||
app_name: app.clone(),
|
||||
usage_minutes: (secs / 60) as u32,
|
||||
};
|
||||
if let Ok(frame) = Frame::new_json(MessageType::AppUsageReport, &entry) {
|
||||
let _ = data_tx.send(frame).await;
|
||||
}
|
||||
}
|
||||
|
||||
// Reset counters after reporting
|
||||
active_secs = 0;
|
||||
idle_secs = 0;
|
||||
app_usage.clear();
|
||||
}
|
||||
|
||||
debug!("Usage report sent (idle_ms={}, elapsed={}s)", idle_ms, elapsed_secs);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Get system idle time in milliseconds using GetLastInputInfo + GetTickCount64.
|
||||
/// Both use the same time base. LASTINPUTINFO.dwTime is u32 (GetTickCount legacy),
|
||||
/// so we take the low 32 bits of GetTickCount64 for correct wrapping comparison.
|
||||
fn get_idle_millis() -> u32 {
|
||||
#[cfg(target_os = "windows")]
|
||||
{
|
||||
use windows::Win32::UI::Input::KeyboardAndMouse::{GetLastInputInfo, LASTINPUTINFO};
|
||||
|
||||
unsafe {
|
||||
let mut lii = LASTINPUTINFO {
|
||||
cbSize: std::mem::size_of::<LASTINPUTINFO>() as u32,
|
||||
dwTime: 0,
|
||||
};
|
||||
if GetLastInputInfo(&mut lii).as_bool() {
|
||||
// GetTickCount64 returns u64, but dwTime is u32.
|
||||
// Take low 32 bits so wrapping_sub produces correct idle delta.
|
||||
let tick_low32 = (windows::Win32::System::SystemInformation::GetTickCount64() & 0xFFFFFFFF) as u32;
|
||||
return tick_low32.wrapping_sub(lii.dwTime);
|
||||
}
|
||||
}
|
||||
0
|
||||
}
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
{
|
||||
0
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the foreground window's process name using CreateToolhelp32Snapshot
|
||||
fn get_foreground_app_name() -> Option<String> {
|
||||
#[cfg(target_os = "windows")]
|
||||
{
|
||||
use windows::Win32::UI::WindowsAndMessaging::{GetForegroundWindow, GetWindowThreadProcessId};
|
||||
use windows::Win32::System::Diagnostics::ToolHelp::*;
|
||||
use windows::Win32::Foundation::CloseHandle;
|
||||
|
||||
unsafe {
|
||||
let hwnd = GetForegroundWindow();
|
||||
if hwnd.0 == 0 {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut pid: u32 = 0;
|
||||
GetWindowThreadProcessId(hwnd, Some(&mut pid));
|
||||
if pid == 0 {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Get process name via CreateToolhelp32Snapshot
|
||||
let snapshot = CreateToolhelp32Snapshot(TH32CS_SNAPPROCESS, 0).ok()?;
|
||||
let mut entry = PROCESSENTRY32W {
|
||||
dwSize: std::mem::size_of::<PROCESSENTRY32W>() as u32,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
if Process32FirstW(snapshot, &mut entry).is_ok() {
|
||||
loop {
|
||||
if entry.th32ProcessID == pid {
|
||||
let name = String::from_utf16_lossy(
|
||||
&entry.szExeFile.iter().take_while(|&&c| c != 0).copied().collect::<Vec<u16>>()
|
||||
);
|
||||
let _ = CloseHandle(snapshot);
|
||||
return Some(name);
|
||||
}
|
||||
entry.dwSize = std::mem::size_of::<PROCESSENTRY32W>() as u32;
|
||||
if !Process32NextW(snapshot, &mut entry).is_ok() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
let _ = CloseHandle(snapshot);
|
||||
None
|
||||
}
|
||||
}
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
{
|
||||
None
|
||||
}
|
||||
}
|
||||
249
crates/client/src/usb/mod.rs
Normal file
249
crates/client/src/usb/mod.rs
Normal file
@@ -0,0 +1,249 @@
|
||||
use csm_protocol::{Frame, MessageType, UsbEvent, UsbEventType, UsbPolicyPayload, UsbDeviceRule};
|
||||
use std::time::Duration;
|
||||
use tokio::sync::{mpsc::Sender, watch};
|
||||
use tracing::{info, warn};
|
||||
|
||||
/// Start USB monitoring with policy enforcement.
|
||||
/// Monitors for removable drive insertions/removals and enforces USB policies.
|
||||
pub async fn start_monitoring(
|
||||
tx: Sender<Frame>,
|
||||
device_uid: String,
|
||||
mut policy_rx: watch::Receiver<Option<UsbPolicyPayload>>,
|
||||
) {
|
||||
info!("USB monitoring started for {}", device_uid);
|
||||
|
||||
let mut current_policy: Option<UsbPolicyPayload> = None;
|
||||
let mut known_drives: Vec<String> = Vec::new();
|
||||
let interval = Duration::from_secs(10);
|
||||
|
||||
loop {
|
||||
// Check for policy updates (non-blocking)
|
||||
if policy_rx.has_changed().unwrap_or(false) {
|
||||
let new_policy = policy_rx.borrow_and_update().clone();
|
||||
let policy_desc = new_policy.as_ref()
|
||||
.map(|p| format!("type={}, enabled={}", p.policy_type, p.enabled))
|
||||
.unwrap_or_else(|| "None".to_string());
|
||||
info!("USB policy updated: {}", policy_desc);
|
||||
current_policy = new_policy;
|
||||
|
||||
// Re-check all currently mounted drives against new policy
|
||||
let drives = detect_removable_drives();
|
||||
for drive in &drives {
|
||||
if should_block_drive(drive, ¤t_policy) {
|
||||
warn!("USB device {} blocked by policy, ejecting...", drive);
|
||||
if let Err(e) = eject_drive(drive) {
|
||||
warn!("Failed to eject {}: {}", drive, e);
|
||||
} else {
|
||||
info!("Successfully ejected {}", drive);
|
||||
}
|
||||
// Report blocked event
|
||||
send_usb_event(&tx, &device_uid, UsbEventType::Blocked, drive).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let current_drives = detect_removable_drives();
|
||||
|
||||
// Detect new drives (inserted)
|
||||
for drive in ¤t_drives {
|
||||
if !known_drives.iter().any(|d: &String| d == drive) {
|
||||
info!("USB device inserted: {}", drive);
|
||||
|
||||
// Check against policy before reporting
|
||||
if should_block_drive(drive, ¤t_policy) {
|
||||
warn!("USB device {} blocked by policy, ejecting...", drive);
|
||||
if let Err(e) = eject_drive(drive) {
|
||||
warn!("Failed to eject {}: {}", drive, e);
|
||||
} else {
|
||||
info!("Successfully ejected {}", drive);
|
||||
}
|
||||
send_usb_event(&tx, &device_uid, UsbEventType::Blocked, drive).await;
|
||||
continue;
|
||||
}
|
||||
|
||||
send_usb_event(&tx, &device_uid, UsbEventType::Inserted, drive).await;
|
||||
}
|
||||
}
|
||||
|
||||
// Detect removed drives
|
||||
for drive in &known_drives {
|
||||
if !current_drives.iter().any(|d| d == drive) {
|
||||
info!("USB device removed: {}", drive);
|
||||
send_usb_event(&tx, &device_uid, UsbEventType::Removed, drive).await;
|
||||
}
|
||||
}
|
||||
|
||||
known_drives = current_drives;
|
||||
tokio::time::sleep(interval).await;
|
||||
}
|
||||
}
|
||||
|
||||
async fn send_usb_event(
|
||||
tx: &Sender<Frame>,
|
||||
device_uid: &str,
|
||||
event_type: UsbEventType,
|
||||
drive: &str,
|
||||
) {
|
||||
let event = UsbEvent {
|
||||
device_uid: device_uid.to_string(),
|
||||
event_type,
|
||||
vendor_id: None,
|
||||
product_id: None,
|
||||
serial: None,
|
||||
device_name: Some(drive.to_string()),
|
||||
};
|
||||
|
||||
if let Ok(frame) = Frame::new_json(MessageType::UsbEvent, &event) {
|
||||
tx.send(frame).await.ok();
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if a drive should be blocked based on the current policy.
|
||||
fn should_block_drive(drive: &str, policy: &Option<UsbPolicyPayload>) -> bool {
|
||||
let policy = match policy {
|
||||
Some(p) if p.enabled => p,
|
||||
_ => return false,
|
||||
};
|
||||
|
||||
match policy.policy_type.as_str() {
|
||||
"all_block" => true,
|
||||
"blacklist" => {
|
||||
// Block if any rule matches
|
||||
policy.rules.iter().any(|rule| device_matches_rule(drive, rule))
|
||||
}
|
||||
"whitelist" => {
|
||||
// Block if NOT in whitelist (empty whitelist = block all)
|
||||
if policy.rules.is_empty() {
|
||||
return true;
|
||||
}
|
||||
!policy.rules.iter().any(|rule| device_matches_rule(drive, rule))
|
||||
}
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if a device matches a rule pattern.
|
||||
/// Currently matches by device_name (drive letter path) since that's what we detect.
|
||||
fn device_matches_rule(drive: &str, rule: &UsbDeviceRule) -> bool {
|
||||
// Match by device name (drive root path like "E:\")
|
||||
if let Some(ref name) = rule.device_name {
|
||||
if drive.eq_ignore_ascii_case(name) || drive.eq_ignore_ascii_case(&name.replace('\\', "")) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
// vendor_id, product_id, serial matching would require WMI or SetupDi queries
|
||||
// For now, these are placeholder checks
|
||||
let _ = (drive, rule);
|
||||
false
|
||||
}
|
||||
|
||||
/// DRIVE_REMOVABLE = 2 in Windows API
|
||||
const DRIVE_REMOVABLE: u32 = 2;
|
||||
|
||||
fn detect_removable_drives() -> Vec<String> {
|
||||
let mut drives = Vec::new();
|
||||
|
||||
#[cfg(target_os = "windows")]
|
||||
{
|
||||
for letter in b'A'..=b'Z' {
|
||||
let root = format!("{}:\\", letter as char);
|
||||
let root_wide: Vec<u16> = root.encode_utf16().chain(std::iter::once(0)).collect();
|
||||
unsafe {
|
||||
let drive_type = windows::Win32::Storage::FileSystem::GetDriveTypeW(
|
||||
windows::core::PCWSTR(root_wide.as_ptr()),
|
||||
);
|
||||
if drive_type == DRIVE_REMOVABLE {
|
||||
drives.push(root);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
{
|
||||
let _ = &drives; // Suppress unused warning on non-Windows
|
||||
}
|
||||
|
||||
drives
|
||||
}
|
||||
|
||||
/// Eject a removable drive using Windows API.
|
||||
/// Opens the volume handle, dismounts the filesystem, and ejects the media.
|
||||
#[cfg(target_os = "windows")]
|
||||
fn eject_drive(drive: &str) -> Result<(), anyhow::Error> {
|
||||
use windows::Win32::Storage::FileSystem::*;
|
||||
use windows::Win32::System::IO::DeviceIoControl;
|
||||
use windows::Win32::Foundation::*;
|
||||
use windows::core::PCWSTR;
|
||||
|
||||
// IOCTL control codes (not exposed directly in windows crate)
|
||||
const FSCTL_DISMOUNT_VOLUME: u32 = 0x00090020;
|
||||
const IOCTL_STORAGE_EJECT_MEDIA: u32 = 0x002D4808;
|
||||
|
||||
// Extract drive letter from path like "E:\"
|
||||
let letter = drive.chars().next().unwrap_or('A');
|
||||
let path = format!("\\\\.\\{}:\0", letter);
|
||||
let path_wide: Vec<u16> = path.encode_utf16().collect();
|
||||
|
||||
unsafe {
|
||||
let handle = CreateFileW(
|
||||
PCWSTR(path_wide.as_ptr()),
|
||||
GENERIC_READ.0,
|
||||
FILE_SHARE_READ | FILE_SHARE_WRITE,
|
||||
None,
|
||||
OPEN_EXISTING,
|
||||
FILE_ATTRIBUTE_NORMAL,
|
||||
None,
|
||||
)?;
|
||||
|
||||
if handle.is_invalid() {
|
||||
return Err(anyhow::anyhow!("Failed to open volume handle for {}", drive));
|
||||
}
|
||||
|
||||
// Dismount the filesystem
|
||||
let mut bytes_returned = 0u32;
|
||||
let dismount_ok = DeviceIoControl(
|
||||
handle,
|
||||
FSCTL_DISMOUNT_VOLUME,
|
||||
None,
|
||||
0,
|
||||
None,
|
||||
0,
|
||||
Some(&mut bytes_returned),
|
||||
None,
|
||||
).is_ok();
|
||||
|
||||
if !dismount_ok {
|
||||
warn!("FSCTL_DISMOUNT_VOLUME failed for {}", drive);
|
||||
}
|
||||
|
||||
// Eject the media
|
||||
let eject_ok = DeviceIoControl(
|
||||
handle,
|
||||
IOCTL_STORAGE_EJECT_MEDIA,
|
||||
None,
|
||||
0,
|
||||
None,
|
||||
0,
|
||||
Some(&mut bytes_returned),
|
||||
None,
|
||||
).is_ok();
|
||||
|
||||
if !eject_ok {
|
||||
warn!("IOCTL_STORAGE_EJECT_MEDIA failed for {}", drive);
|
||||
}
|
||||
|
||||
let _ = CloseHandle(handle);
|
||||
|
||||
if !dismount_ok && !eject_ok {
|
||||
Err(anyhow::anyhow!("Failed to eject {}", drive))
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
fn eject_drive(_drive: &str) -> Result<(), anyhow::Error> {
|
||||
Ok(())
|
||||
}
|
||||
190
crates/client/src/usb_audit/mod.rs
Normal file
190
crates/client/src/usb_audit/mod.rs
Normal file
@@ -0,0 +1,190 @@
|
||||
use tokio::sync::watch;
|
||||
use tracing::{info, debug, warn};
|
||||
use csm_protocol::{Frame, MessageType, UsbFileOpEntry};
|
||||
use std::collections::{HashMap, HashSet};
|
||||
|
||||
/// USB file audit configuration
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct UsbAuditConfig {
|
||||
pub enabled: bool,
|
||||
pub monitored_extensions: Vec<String>,
|
||||
}
|
||||
|
||||
/// Start USB file audit plugin.
|
||||
/// Detects removable drives and monitors file changes via periodic scanning.
|
||||
pub async fn start(
|
||||
mut config_rx: watch::Receiver<UsbAuditConfig>,
|
||||
data_tx: tokio::sync::mpsc::Sender<Frame>,
|
||||
device_uid: String,
|
||||
) {
|
||||
info!("USB file audit plugin started");
|
||||
let mut config = UsbAuditConfig::default();
|
||||
let mut active_drives: HashSet<String> = HashSet::new();
|
||||
// Track file listings per drive to detect changes
|
||||
let mut drive_snapshots: HashMap<String, HashSet<String>> = HashMap::new();
|
||||
|
||||
let mut interval = tokio::time::interval(std::time::Duration::from_secs(5));
|
||||
interval.tick().await;
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
result = config_rx.changed() => {
|
||||
if result.is_err() {
|
||||
break;
|
||||
}
|
||||
let new_config = config_rx.borrow_and_update().clone();
|
||||
info!("USB audit config updated: enabled={}", new_config.enabled);
|
||||
config = new_config;
|
||||
}
|
||||
_ = interval.tick() => {
|
||||
if !config.enabled {
|
||||
continue;
|
||||
}
|
||||
|
||||
let drives = detect_removable_drives();
|
||||
|
||||
// Detect new drives
|
||||
for drive in &drives {
|
||||
if !active_drives.contains(drive) {
|
||||
info!("New removable drive detected: {}", drive);
|
||||
// Take initial snapshot in blocking thread to avoid freezing
|
||||
let drive_clone = drive.clone();
|
||||
let files = tokio::task::spawn_blocking(move || scan_drive_files(&drive_clone)).await;
|
||||
match files {
|
||||
Ok(files) => { drive_snapshots.insert(drive.clone(), files); }
|
||||
Err(e) => { warn!("Failed to scan drive {}: {}", drive, e); }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Scan existing drives for changes (each in a blocking thread)
|
||||
for drive in &drives {
|
||||
let drive_clone = drive.clone();
|
||||
let current_files = match tokio::task::spawn_blocking(move || scan_drive_files(&drive_clone)).await {
|
||||
Ok(files) => files,
|
||||
Err(e) => { warn!("Drive scan failed for {}: {}", drive, e); continue; }
|
||||
};
|
||||
if let Some(prev_files) = drive_snapshots.get_mut(drive) {
|
||||
// Find new files (created)
|
||||
for file in ¤t_files {
|
||||
if !prev_files.contains(file) {
|
||||
report_file_op(&data_tx, &device_uid, drive, file, "create", &config.monitored_extensions).await;
|
||||
}
|
||||
}
|
||||
// Find deleted files
|
||||
for file in prev_files.iter() {
|
||||
if !current_files.contains(file) {
|
||||
report_file_op(&data_tx, &device_uid, drive, file, "delete", &config.monitored_extensions).await;
|
||||
}
|
||||
}
|
||||
*prev_files = current_files;
|
||||
} else {
|
||||
drive_snapshots.insert(drive.clone(), current_files);
|
||||
}
|
||||
}
|
||||
|
||||
// Clean up removed drives
|
||||
active_drives.retain(|d| drives.contains(d));
|
||||
drive_snapshots.retain(|d, _| drives.contains(d));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn report_file_op(
|
||||
data_tx: &tokio::sync::mpsc::Sender<Frame>,
|
||||
device_uid: &str,
|
||||
drive: &str,
|
||||
file_path: &str,
|
||||
operation: &str,
|
||||
ext_filter: &[String],
|
||||
) {
|
||||
// Check extension filter
|
||||
let should_report = if ext_filter.is_empty() {
|
||||
true
|
||||
} else {
|
||||
let file_ext = std::path::Path::new(file_path)
|
||||
.extension()
|
||||
.and_then(|e| e.to_str())
|
||||
.map(|e| e.to_lowercase());
|
||||
file_ext.map_or(true, |ext| {
|
||||
ext_filter.iter().any(|f| f.to_lowercase() == ext)
|
||||
})
|
||||
};
|
||||
|
||||
if should_report {
|
||||
let entry = UsbFileOpEntry {
|
||||
device_uid: device_uid.to_string(),
|
||||
usb_serial: None,
|
||||
drive_letter: Some(drive.to_string()),
|
||||
operation: operation.to_string(),
|
||||
file_path: file_path.to_string(),
|
||||
file_size: None,
|
||||
timestamp: chrono::Utc::now().to_rfc3339(),
|
||||
};
|
||||
|
||||
if let Ok(frame) = Frame::new_json(MessageType::UsbFileOp, &entry) {
|
||||
let _ = data_tx.send(frame).await;
|
||||
}
|
||||
debug!("USB file op: {} {}", operation, file_path);
|
||||
}
|
||||
}
|
||||
|
||||
/// Recursively scan a drive and return all file paths
|
||||
fn scan_drive_files(drive: &str) -> HashSet<String> {
|
||||
let mut files = HashSet::new();
|
||||
let max_depth = 3; // Limit recursion depth for performance
|
||||
scan_dir_recursive(drive, &mut files, 0, max_depth);
|
||||
files
|
||||
}
|
||||
|
||||
fn scan_dir_recursive(dir: &str, files: &mut HashSet<String>, depth: usize, max_depth: usize) {
|
||||
if depth >= max_depth {
|
||||
return;
|
||||
}
|
||||
if let Ok(entries) = std::fs::read_dir(dir) {
|
||||
for entry in entries.flatten() {
|
||||
let path = entry.path();
|
||||
let path_str = path.to_string_lossy().to_string();
|
||||
if path.is_dir() {
|
||||
scan_dir_recursive(&path_str, files, depth + 1, max_depth);
|
||||
} else {
|
||||
files.insert(path_str);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn detect_removable_drives() -> Vec<String> {
|
||||
#[cfg(target_os = "windows")]
|
||||
{
|
||||
use windows::Win32::Storage::FileSystem::GetDriveTypeW;
|
||||
use windows::core::PCWSTR;
|
||||
|
||||
let mut drives = Vec::new();
|
||||
let mask = unsafe { windows::Win32::Storage::FileSystem::GetLogicalDrives() };
|
||||
let mut mask = mask as u32;
|
||||
let mut letter = b'A';
|
||||
|
||||
while mask != 0 {
|
||||
if mask & 1 != 0 {
|
||||
let drive_letter = format!("{}:\\", letter as char);
|
||||
let drive_wide: Vec<u16> = drive_letter.encode_utf16().chain(std::iter::once(0)).collect();
|
||||
let drive_type = unsafe {
|
||||
GetDriveTypeW(PCWSTR(drive_wide.as_ptr()))
|
||||
};
|
||||
// DRIVE_REMOVABLE = 2
|
||||
if drive_type == 2 {
|
||||
drives.push(drive_letter);
|
||||
}
|
||||
}
|
||||
mask >>= 1;
|
||||
letter += 1;
|
||||
}
|
||||
drives
|
||||
}
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
{
|
||||
Vec::new()
|
||||
}
|
||||
}
|
||||
313
crates/client/src/watermark/mod.rs
Normal file
313
crates/client/src/watermark/mod.rs
Normal file
@@ -0,0 +1,313 @@
|
||||
use std::sync::{Arc, Mutex};
|
||||
use tokio::sync::watch;
|
||||
use tracing::{info, warn, error};
|
||||
use csm_protocol::WatermarkConfigPayload;
|
||||
|
||||
/// Watermark overlay state
|
||||
struct WatermarkState {
|
||||
enabled: bool,
|
||||
content: String,
|
||||
font_size: u32,
|
||||
opacity: f64,
|
||||
color: String,
|
||||
angle: i32,
|
||||
hwnd: Option<isize>,
|
||||
}
|
||||
|
||||
impl Default for WatermarkState {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
enabled: false,
|
||||
content: String::new(),
|
||||
font_size: 14,
|
||||
opacity: 0.15,
|
||||
color: "#808080".to_string(),
|
||||
angle: -30,
|
||||
hwnd: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static WATERMARK_STATE: std::sync::OnceLock<Arc<Mutex<WatermarkState>>> = std::sync::OnceLock::new();
|
||||
|
||||
const WM_USER_UPDATE: u32 = 0x0401;
|
||||
const OVERLAY_CLASS_NAME: &str = "CSM_WatermarkOverlay";
|
||||
|
||||
/// Start the watermark plugin, listening for config updates.
|
||||
pub async fn start(mut rx: watch::Receiver<Option<WatermarkConfigPayload>>) {
|
||||
info!("Watermark plugin started");
|
||||
|
||||
let state = WATERMARK_STATE.get_or_init(|| Arc::new(Mutex::new(WatermarkState::default())));
|
||||
let state_clone = state.clone();
|
||||
|
||||
// Spawn a dedicated thread for the Win32 message loop
|
||||
std::thread::spawn(move || {
|
||||
#[cfg(target_os = "windows")]
|
||||
run_overlay_message_loop(state_clone);
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
{
|
||||
let _ = state_clone;
|
||||
info!("Watermark overlay not supported on this platform");
|
||||
}
|
||||
});
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
result = rx.changed() => {
|
||||
if result.is_err() {
|
||||
warn!("Watermark config channel closed");
|
||||
break;
|
||||
}
|
||||
let config = rx.borrow_and_update().clone();
|
||||
if let Some(cfg) = config {
|
||||
info!("Watermark config updated: enabled={}, content='{}'", cfg.enabled, cfg.content);
|
||||
|
||||
{
|
||||
let mut s = state.lock().unwrap_or_else(|e| e.into_inner());
|
||||
s.enabled = cfg.enabled;
|
||||
s.content = cfg.content.clone();
|
||||
s.font_size = cfg.font_size;
|
||||
s.opacity = cfg.opacity;
|
||||
s.color = cfg.color.clone();
|
||||
s.angle = cfg.angle;
|
||||
}
|
||||
|
||||
// Post message to overlay window thread
|
||||
#[cfg(target_os = "windows")]
|
||||
post_overlay_update();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(target_os = "windows")]
|
||||
fn post_overlay_update() {
|
||||
use windows::Win32::UI::WindowsAndMessaging::{FindWindowA, PostMessageW};
|
||||
use windows::Win32::Foundation::{WPARAM, LPARAM};
|
||||
use windows::core::PCSTR;
|
||||
|
||||
unsafe {
|
||||
let class_name = format!("{}\0", OVERLAY_CLASS_NAME);
|
||||
let hwnd = FindWindowA(PCSTR(class_name.as_ptr()), PCSTR::null());
|
||||
if hwnd.0 != 0 {
|
||||
let _ = PostMessageW(hwnd, WM_USER_UPDATE, WPARAM(0), LPARAM(0));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(target_os = "windows")]
|
||||
fn run_overlay_message_loop(state: Arc<Mutex<WatermarkState>>) {
|
||||
use windows::Win32::UI::WindowsAndMessaging::*;
|
||||
use windows::Win32::Graphics::Gdi::*;
|
||||
use windows::Win32::Foundation::*;
|
||||
use windows::Win32::System::LibraryLoader::GetModuleHandleA;
|
||||
use windows::core::PCSTR;
|
||||
|
||||
// Register window class
|
||||
let class_name = format!("{}\0", OVERLAY_CLASS_NAME);
|
||||
let instance = unsafe { GetModuleHandleA(PCSTR::null()) }.unwrap_or_default();
|
||||
let instance: HINSTANCE = instance.into();
|
||||
|
||||
let wc = WNDCLASSA {
|
||||
style: CS_HREDRAW | CS_VREDRAW,
|
||||
lpfnWndProc: Some(watermark_wnd_proc),
|
||||
hInstance: instance,
|
||||
lpszClassName: PCSTR(class_name.as_ptr()),
|
||||
hbrBackground: unsafe { CreateSolidBrush(COLORREF(0x00000000)) },
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
unsafe { RegisterClassA(&wc); }
|
||||
|
||||
let screen_w = unsafe { GetSystemMetrics(SM_CXSCREEN) };
|
||||
let screen_h = unsafe { GetSystemMetrics(SM_CYSCREEN) };
|
||||
|
||||
// WS_EX_LAYERED | WS_EX_TRANSPARENT | WS_EX_TOPMOST | WS_EX_TOOLWINDOW
|
||||
let ex_style = WS_EX_LAYERED | WS_EX_TRANSPARENT | WS_EX_TOPMOST | WS_EX_TOOLWINDOW;
|
||||
let style = WS_POPUP;
|
||||
|
||||
let hwnd = unsafe {
|
||||
CreateWindowExA(
|
||||
ex_style,
|
||||
PCSTR(class_name.as_ptr()),
|
||||
PCSTR("CSM Watermark\0".as_ptr()),
|
||||
style,
|
||||
0, 0, screen_w, screen_h,
|
||||
HWND::default(),
|
||||
HMENU::default(),
|
||||
instance,
|
||||
None,
|
||||
)
|
||||
};
|
||||
|
||||
if hwnd.0 == 0 {
|
||||
error!("Failed to create watermark overlay window");
|
||||
return;
|
||||
}
|
||||
|
||||
{
|
||||
let mut s = state.lock().unwrap_or_else(|e| e.into_inner());
|
||||
s.hwnd = Some(hwnd.0);
|
||||
}
|
||||
|
||||
info!("Watermark overlay window created (hwnd={})", hwnd.0);
|
||||
|
||||
// Post initial update to self — ensures overlay shows even if config
|
||||
// arrived before window creation completed (race between threads)
|
||||
unsafe {
|
||||
let _ = PostMessageW(hwnd, WM_USER_UPDATE, WPARAM(0), LPARAM(0));
|
||||
}
|
||||
|
||||
// Message loop
|
||||
let mut msg = MSG::default();
|
||||
unsafe {
|
||||
loop {
|
||||
match GetMessageA(&mut msg, HWND::default(), 0, 0) {
|
||||
BOOL(0) | BOOL(-1) => break,
|
||||
_ => {
|
||||
TranslateMessage(&msg);
|
||||
DispatchMessageA(&msg);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(target_os = "windows")]
|
||||
unsafe extern "system" fn watermark_wnd_proc(
|
||||
hwnd: windows::Win32::Foundation::HWND,
|
||||
msg: u32,
|
||||
wparam: windows::Win32::Foundation::WPARAM,
|
||||
lparam: windows::Win32::Foundation::LPARAM,
|
||||
) -> windows::Win32::Foundation::LRESULT {
|
||||
use windows::Win32::UI::WindowsAndMessaging::*;
|
||||
use windows::Win32::Graphics::Gdi::*;
|
||||
use windows::Win32::Foundation::*;
|
||||
|
||||
match msg {
|
||||
WM_PAINT => {
|
||||
if let Some(state) = WATERMARK_STATE.get() {
|
||||
let s = state.lock().unwrap_or_else(|e| e.into_inner());
|
||||
if s.enabled && !s.content.is_empty() {
|
||||
paint_watermark(hwnd, &s);
|
||||
}
|
||||
}
|
||||
LRESULT(0)
|
||||
}
|
||||
WM_ERASEBKGND => {
|
||||
// Fill with black (will be color-keyed to transparent)
|
||||
let hdc = HDC(wparam.0 as isize);
|
||||
unsafe {
|
||||
let rect = {
|
||||
let mut r = RECT::default();
|
||||
let _ = GetClientRect(hwnd, &mut r);
|
||||
r
|
||||
};
|
||||
let brush = CreateSolidBrush(COLORREF(0x00000000));
|
||||
let _ = FillRect(hdc, &rect, brush);
|
||||
let _ = DeleteObject(brush);
|
||||
}
|
||||
LRESULT(1)
|
||||
}
|
||||
m if m == WM_USER_UPDATE => {
|
||||
if let Some(state) = WATERMARK_STATE.get() {
|
||||
let s = state.lock().unwrap_or_else(|e| e.into_inner());
|
||||
if s.enabled {
|
||||
let _ = SetWindowPos(
|
||||
hwnd, HWND_TOPMOST,
|
||||
0, 0,
|
||||
GetSystemMetrics(SM_CXSCREEN),
|
||||
GetSystemMetrics(SM_CYSCREEN),
|
||||
SWP_SHOWWINDOW,
|
||||
);
|
||||
let alpha = (s.opacity * 255.0).clamp(0.0, 255.0) as u8;
|
||||
// Color key black background to transparent, apply alpha to text
|
||||
let _ = SetLayeredWindowAttributes(hwnd, COLORREF(0), alpha, LWA_COLORKEY | LWA_ALPHA);
|
||||
let _ = InvalidateRect(hwnd, None, true);
|
||||
} else {
|
||||
let _ = ShowWindow(hwnd, SW_HIDE);
|
||||
}
|
||||
}
|
||||
LRESULT(0)
|
||||
}
|
||||
WM_DESTROY => {
|
||||
PostQuitMessage(0);
|
||||
LRESULT(0)
|
||||
}
|
||||
_ => DefWindowProcA(hwnd, msg, wparam, lparam),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(target_os = "windows")]
|
||||
fn paint_watermark(hwnd: windows::Win32::Foundation::HWND, state: &WatermarkState) {
|
||||
use windows::Win32::Graphics::Gdi::*;
|
||||
use windows::Win32::UI::WindowsAndMessaging::*;
|
||||
use windows::core::PCSTR;
|
||||
|
||||
unsafe {
|
||||
let mut ps = PAINTSTRUCT::default();
|
||||
let hdc = BeginPaint(hwnd, &mut ps);
|
||||
|
||||
let color = parse_color(&state.color);
|
||||
let font_size = state.font_size.max(1);
|
||||
|
||||
// Create font with rotation
|
||||
let font = CreateFontA(
|
||||
(font_size as i32) * 2,
|
||||
0,
|
||||
(state.angle as i32) * 10,
|
||||
0,
|
||||
FW_NORMAL.0 as i32,
|
||||
0, 0, 0,
|
||||
DEFAULT_CHARSET.0 as u32,
|
||||
OUT_DEFAULT_PRECIS.0 as u32,
|
||||
CLIP_DEFAULT_PRECIS.0 as u32,
|
||||
DEFAULT_QUALITY.0 as u32,
|
||||
DEFAULT_PITCH.0 as u32 | FF_DONTCARE.0 as u32,
|
||||
PCSTR("Arial\0".as_ptr()),
|
||||
);
|
||||
|
||||
let old_font = SelectObject(hdc, font);
|
||||
|
||||
let _ = SetBkMode(hdc, TRANSPARENT);
|
||||
let _ = SetTextColor(hdc, color);
|
||||
|
||||
// Draw tiled watermark text
|
||||
let screen_w = GetSystemMetrics(SM_CXSCREEN);
|
||||
let screen_h = GetSystemMetrics(SM_CYSCREEN);
|
||||
|
||||
let content_bytes: Vec<u8> = state.content.bytes().chain(std::iter::once(0)).collect();
|
||||
let text_slice = &content_bytes[..content_bytes.len().saturating_sub(1)];
|
||||
|
||||
let spacing_x = 400i32;
|
||||
let spacing_y = 200i32;
|
||||
|
||||
let mut y = -100i32;
|
||||
while y < screen_h + 100 {
|
||||
let mut x = -200i32;
|
||||
while x < screen_w + 200 {
|
||||
let _ = TextOutA(hdc, x, y, text_slice);
|
||||
x += spacing_x;
|
||||
}
|
||||
y += spacing_y;
|
||||
}
|
||||
|
||||
SelectObject(hdc, old_font);
|
||||
let _ = DeleteObject(font);
|
||||
EndPaint(hwnd, &ps);
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_color(hex: &str) -> windows::Win32::Foundation::COLORREF {
|
||||
use windows::Win32::Foundation::COLORREF;
|
||||
let hex = hex.trim_start_matches('#');
|
||||
if hex.len() == 6 {
|
||||
let r = u8::from_str_radix(&hex[0..2], 16).unwrap_or(128);
|
||||
let g = u8::from_str_radix(&hex[2..4], 16).unwrap_or(128);
|
||||
let b = u8::from_str_radix(&hex[4..6], 16).unwrap_or(128);
|
||||
COLORREF((b as u32) << 16 | (g as u32) << 8 | (r as u32))
|
||||
} else {
|
||||
COLORREF(0x00808080)
|
||||
}
|
||||
}
|
||||
169
crates/client/src/web_filter/mod.rs
Normal file
169
crates/client/src/web_filter/mod.rs
Normal file
@@ -0,0 +1,169 @@
|
||||
use tokio::sync::watch;
|
||||
use tracing::{info, warn, debug};
|
||||
use serde::Deserialize;
|
||||
use std::io;
|
||||
|
||||
/// Web filter rule from server
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct WebFilterRule {
|
||||
pub id: i64,
|
||||
pub rule_type: String,
|
||||
pub pattern: String,
|
||||
}
|
||||
|
||||
/// Web filter configuration
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct WebFilterConfig {
|
||||
pub enabled: bool,
|
||||
pub rules: Vec<WebFilterRule>,
|
||||
}
|
||||
|
||||
/// Start web filter plugin.
|
||||
/// Manages the hosts file to block/allow URLs based on server rules.
|
||||
pub async fn start(
|
||||
mut config_rx: watch::Receiver<WebFilterConfig>,
|
||||
) {
|
||||
info!("Web filter plugin started");
|
||||
let mut _config = WebFilterConfig::default();
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
result = config_rx.changed() => {
|
||||
if result.is_err() {
|
||||
break;
|
||||
}
|
||||
let new_config = config_rx.borrow_and_update().clone();
|
||||
info!("Web filter config updated: enabled={}, rules={}", new_config.enabled, new_config.rules.len());
|
||||
|
||||
if new_config.enabled {
|
||||
match apply_hosts_rules(&new_config.rules) {
|
||||
Ok(()) => info!("Web filter hosts rules applied ({} rules)", new_config.rules.len()),
|
||||
Err(e) => warn!("Failed to apply hosts rules: {}", e),
|
||||
}
|
||||
} else {
|
||||
match clear_hosts_rules() {
|
||||
Ok(()) => info!("Web filter hosts rules cleared"),
|
||||
Err(e) => warn!("Failed to clear hosts rules: {}", e),
|
||||
}
|
||||
}
|
||||
_config = new_config;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const HOSTS_MARKER_START: &str = "# CSM_WEB_FILTER_START";
|
||||
const HOSTS_MARKER_END: &str = "# CSM_WEB_FILTER_END";
|
||||
|
||||
fn apply_hosts_rules(rules: &[WebFilterRule]) -> io::Result<()> {
|
||||
let hosts_path = get_hosts_path();
|
||||
let original = std::fs::read_to_string(&hosts_path)?;
|
||||
|
||||
// Remove existing CSM block
|
||||
let clean = remove_csm_block(&original);
|
||||
|
||||
// Build new block with block rules
|
||||
let block_rules: Vec<&WebFilterRule> = rules.iter()
|
||||
.filter(|r| r.rule_type == "block")
|
||||
.filter(|r| is_valid_hosts_entry(&r.pattern))
|
||||
.collect();
|
||||
|
||||
if block_rules.is_empty() {
|
||||
atomic_write(&hosts_path, &clean)?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let mut new_block = format!("{}\n", HOSTS_MARKER_START);
|
||||
for rule in &block_rules {
|
||||
// Redirect blocked domains to 127.0.0.1
|
||||
new_block.push_str(&format!("127.0.0.1 {}\n", rule.pattern));
|
||||
}
|
||||
new_block.push_str(HOSTS_MARKER_END);
|
||||
new_block.push('\n');
|
||||
|
||||
let new_content = format!("{}{}", clean, new_block);
|
||||
atomic_write(&hosts_path, &new_content)?;
|
||||
|
||||
debug!("Applied {} web filter rules to hosts file", block_rules.len());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn clear_hosts_rules() -> io::Result<()> {
|
||||
let hosts_path = get_hosts_path();
|
||||
let original = std::fs::read_to_string(&hosts_path)?;
|
||||
let clean = remove_csm_block(&original);
|
||||
atomic_write(&hosts_path, &clean)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Write content to a file atomically.
|
||||
/// On Windows, hosts file may be locked by the DNS cache service,
|
||||
/// so we use direct overwrite with truncation instead of rename.
|
||||
fn atomic_write(path: &str, content: &str) -> io::Result<()> {
|
||||
use std::io::Write;
|
||||
let mut file = std::fs::OpenOptions::new()
|
||||
.write(true)
|
||||
.truncate(true)
|
||||
.create(true)
|
||||
.open(path)?;
|
||||
file.write_all(content.as_bytes())?;
|
||||
file.sync_data()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn remove_csm_block(content: &str) -> String {
|
||||
// Handle paired markers (normal case)
|
||||
let start_idx = content.find(HOSTS_MARKER_START);
|
||||
let end_idx = content.find(HOSTS_MARKER_END);
|
||||
|
||||
match (start_idx, end_idx) {
|
||||
(Some(s), Some(e)) if e > s => {
|
||||
let mut result = String::new();
|
||||
result.push_str(&content[..s]);
|
||||
result.push_str(&content[e + HOSTS_MARKER_END.len()..]);
|
||||
return result.trim_end().to_string() + "\n";
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
// Handle orphan markers: remove any lone START or END marker line
|
||||
let lines: Vec<&str> = content.lines().collect();
|
||||
let cleaned: Vec<&str> = lines.iter()
|
||||
.filter(|line| {
|
||||
let trimmed = line.trim();
|
||||
trimmed != HOSTS_MARKER_START && trimmed != HOSTS_MARKER_END
|
||||
})
|
||||
.copied()
|
||||
.collect();
|
||||
|
||||
let mut result = cleaned.join("\n");
|
||||
if !result.ends_with('\n') {
|
||||
result.push('\n');
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
fn get_hosts_path() -> String {
|
||||
#[cfg(target_os = "windows")]
|
||||
{
|
||||
r"C:\Windows\System32\drivers\etc\hosts".to_string()
|
||||
}
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
{
|
||||
"/etc/hosts".to_string()
|
||||
}
|
||||
}
|
||||
|
||||
/// Validate that a pattern is safe to write to the hosts file.
|
||||
/// Rejects patterns with whitespace, control chars, or comment markers.
|
||||
fn is_valid_hosts_entry(pattern: &str) -> bool {
|
||||
if pattern.is_empty() {
|
||||
return false;
|
||||
}
|
||||
// Reject if contains whitespace, control chars, or comment markers
|
||||
if pattern.chars().any(|c| c.is_whitespace() || c.is_control() || c == '#') {
|
||||
return false;
|
||||
}
|
||||
// Must look like a hostname (alphanumeric, dots, hyphens, underscores, asterisks)
|
||||
pattern.chars().all(|c| c.is_alphanumeric() || c == '.' || c == '-' || c == '_' || c == '*')
|
||||
}
|
||||
Reference in New Issue
Block a user