feat: 添加新插件支持及多项功能改进

- 新增磁盘加密、打印审计和剪贴板管控插件支持
- 优化水印插件显示效果,支持中文及更多Unicode字符
- 改进硬件资产收集逻辑,更准确获取磁盘和显卡信息
- 增强API错误处理,添加详细日志记录
- 完善前端界面,新增插件管理页面
- 修复多个UI问题,优化页面过渡效果
- 添加环境变量覆盖配置功能
- 实现插件状态管理API
- 更新文档和变更日志
- 添加安装程序脚本支持
This commit is contained in:
iven
2026-04-10 22:21:05 +08:00
parent 3d39f0e426
commit b5333d8c93
101 changed files with 4487 additions and 661 deletions

View File

@@ -4,7 +4,7 @@ use tokio::sync::mpsc::Sender;
use tracing::{info, error};
pub async fn start_collecting(tx: Sender<Frame>, device_uid: String) {
let interval = Duration::from_secs(86400);
let interval = Duration::from_secs(43200);
if let Err(e) = collect_and_send(&tx, &device_uid).await {
error!("Initial asset collection failed: {}", e);
@@ -50,18 +50,14 @@ fn collect_hardware(device_uid: &str) -> anyhow::Result<HardwareAsset> {
// Memory
let memory_total_mb = sys.total_memory() / 1024 / 1024;
// Disk — pick the largest non-removable disk
// Disk — use PowerShell for real hardware model, sysinfo for total capacity
let disks = sysinfo::Disks::new_with_refreshed_list();
let (disk_model, disk_total_mb) = disks.iter()
.filter(|d| d.kind() == sysinfo::DiskKind::HDD || d.kind() == sysinfo::DiskKind::SSD)
.max_by_key(|d| d.total_space())
.map(|d| {
let total = d.total_space() / 1024 / 1024;
let name = d.name().to_string_lossy().to_string();
let model = if name.is_empty() { "Unknown".to_string() } else { name };
(model, total)
})
.unwrap_or_else(|| ("Unknown".to_string(), 0));
let disk_total_mb: u64 = disks.iter()
.map(|d| d.total_space() / 1024 / 1024)
.sum::<u64>()
.max(1)
.saturating_sub(1); // avoid reporting 0 if no disks
let disk_model = collect_disk_model().unwrap_or_else(|| "Unknown".to_string());
// GPU, motherboard, serial — Windows-specific via PowerShell
let (gpu_model, motherboard, serial_number) = collect_system_details();
@@ -81,12 +77,12 @@ fn collect_hardware(device_uid: &str) -> anyhow::Result<HardwareAsset> {
#[cfg(target_os = "windows")]
fn collect_system_details() -> (Option<String>, Option<String>, Option<String>) {
// GPU: query all controllers, filter out virtual/IDDDriver devices, prefer real GPU
// GPU: query all controllers, only exclude explicit virtual/placeholder devices
let gpu = {
let gpus = powershell_lines(
"Get-CimInstance Win32_VideoController | Where-Object { $_.Name -notmatch 'IddDriver|Virtual|Basic Render|Microsoft Basic Display|Remote Desktop|Mirror Driver' } | Select-Object -ExpandProperty Name"
"Get-CimInstance Win32_VideoController | Where-Object { $_.Name -notmatch 'IddDriver|Virtual Display|Basic Render|Microsoft Basic Display Adapter|Mirror Driver' } | Select-Object -ExpandProperty Name"
);
// Prefer NVIDIA/AMD/Intel, fallback to first non-virtual
info!("Detected GPUs: {:?}", gpus);
gpus.into_iter().next()
};
let mb_manufacturer = powershell_first("Get-CimInstance Win32_BaseBoard | Select-Object -ExpandProperty Manufacturer");
@@ -102,6 +98,20 @@ fn collect_system_details() -> (Option<String>, Option<String>, Option<String>)
(gpu, motherboard, serial_number)
}
/// Get real disk hardware model via PowerShell Get-PhysicalDisk.
#[cfg(target_os = "windows")]
fn collect_disk_model() -> Option<String> {
let models = powershell_lines(
"Get-PhysicalDisk | Select-Object -ExpandProperty FriendlyName"
);
models.into_iter().next()
}
#[cfg(not(target_os = "windows"))]
fn collect_disk_model() -> Option<String> {
None
}
#[cfg(not(target_os = "windows"))]
fn collect_system_details() -> (Option<String>, Option<String>, Option<String>) {
(None, None, None)
@@ -150,6 +160,7 @@ fn collect_windows_software(device_uid: &str) -> Vec<SoftwareAsset> {
use std::process::Command;
let ps_cmd = r#"
[Console]::OutputEncoding = [System.Text.Encoding]::UTF8
$paths = @(
"HKLM:\SOFTWARE\Microsoft\Windows\CurrentVersion\Uninstall\*",
"HKLM:\SOFTWARE\WOW6432Node\Microsoft\Windows\CurrentVersion\Uninstall\*",

View File

@@ -0,0 +1,212 @@
use std::time::Duration;
use tokio::sync::watch;
use tracing::{info, warn};
use csm_protocol::{Frame, MessageType, ClipboardRule, ClipboardViolationPayload};
/// Clipboard control configuration pushed from server
#[derive(Debug, Clone, Default)]
pub struct ClipboardControlConfig {
pub enabled: bool,
pub rules: Vec<ClipboardRule>,
}
/// Start the clipboard control plugin.
/// Periodically checks clipboard content against rules and reports violations.
pub async fn start(
mut config_rx: watch::Receiver<ClipboardControlConfig>,
data_tx: tokio::sync::mpsc::Sender<Frame>,
device_uid: String,
) {
info!("Clipboard control plugin started");
let mut config = ClipboardControlConfig::default();
let mut check_interval = tokio::time::interval(Duration::from_secs(2));
check_interval.tick().await;
loop {
tokio::select! {
result = config_rx.changed() => {
if result.is_err() {
break;
}
let new_config = config_rx.borrow_and_update().clone();
info!(
"Clipboard control config updated: enabled={}, rules={}",
new_config.enabled,
new_config.rules.len()
);
config = new_config;
}
_ = check_interval.tick() => {
if !config.enabled || config.rules.is_empty() {
continue;
}
let uid = device_uid.clone();
let rules = config.rules.clone();
let result = tokio::task::spawn_blocking(move || check_clipboard(&uid, &rules)).await;
match result {
Ok(Some(payload)) => {
if let Ok(frame) = Frame::new_json(MessageType::ClipboardViolation, &payload) {
if data_tx.send(frame).await.is_err() {
warn!("Failed to send clipboard violation: channel closed");
return;
}
}
}
Ok(None) => {}
Err(e) => warn!("Clipboard check task failed: {}", e),
}
}
}
}
}
/// Check clipboard content against rules. Returns a violation payload if a rule matched.
fn check_clipboard(device_uid: &str, rules: &[ClipboardRule]) -> Option<ClipboardViolationPayload> {
#[cfg(target_os = "windows")]
{
let clipboard_text = get_clipboard_text();
let foreground_process = get_foreground_process();
for rule in rules {
if rule.rule_type != "block" {
continue;
}
// Check direction — only interested in "out" or "both"
if !matches!(rule.direction.as_str(), "out" | "both") {
continue;
}
// Check source process filter
if let Some(ref src_pattern) = rule.source_process {
if let Some(ref fg_proc) = foreground_process {
if !pattern_match(src_pattern, fg_proc) {
continue;
}
} else {
continue;
}
}
// Check content pattern
if let Some(ref content_pattern) = rule.content_pattern {
if let Some(ref text) = clipboard_text {
if !content_matches(content_pattern, text) {
continue;
}
} else {
continue;
}
}
// Rule matched — generate violation (never send raw content)
let preview = clipboard_text.as_ref().map(|t| format!("[{} chars]", t.len()));
// Clear clipboard to enforce block
let _ = std::process::Command::new("powershell")
.args(["-NoProfile", "-NonInteractive", "-Command", "Set-Clipboard -Value ''"])
.output();
info!("Clipboard blocked: rule_id={}", rule.id);
return Some(ClipboardViolationPayload {
device_uid: device_uid.to_string(),
source_process: foreground_process,
target_process: None,
content_preview: preview,
action_taken: "blocked".to_string(),
timestamp: chrono::Utc::now().to_rfc3339(),
});
}
None
}
#[cfg(not(target_os = "windows"))]
{
let _ = (device_uid, rules);
None
}
}
#[cfg(target_os = "windows")]
fn get_clipboard_text() -> Option<String> {
let output = std::process::Command::new("powershell")
.args(["-NoProfile", "-NonInteractive", "-Command", "Get-Clipboard -Raw"])
.output()
.ok()?;
if !output.status.success() {
return None;
}
let text = String::from_utf8_lossy(&output.stdout).trim().to_string();
if text.is_empty() { None } else { Some(text) }
}
#[cfg(target_os = "windows")]
fn get_foreground_process() -> Option<String> {
let output = std::process::Command::new("powershell")
.args([
"-NoProfile",
"-NonInteractive",
"-Command",
r#"Add-Type @"
using System;
using System.Runtime.InteropServices;
public class WinAPI {
[DllImport("user32.dll")] public static extern IntPtr GetForegroundWindow();
[DllImport("user32.dll")] public static extern uint GetWindowThreadProcessId(IntPtr hWnd, out uint lpdwProcessId);
}
"@
$hwnd = [WinAPI]::GetForegroundWindow()
$pid = 0
[WinAPI]::GetWindowThreadProcessId($hwnd, [ref]$pid) | Out-Null
if ($pid -gt 0) { (Get-Process -Id $pid -ErrorAction SilentlyContinue).ProcessName } else { "" }"#,
])
.output()
.ok()?;
if !output.status.success() {
return None;
}
let name = String::from_utf8_lossy(&output.stdout).trim().to_string();
if name.is_empty() {
None
} else {
Some(name)
}
}
/// Simple case-insensitive wildcard pattern matching. Supports `*` as wildcard.
fn pattern_match(pattern: &str, text: &str) -> bool {
let p = pattern.to_lowercase();
let t = text.to_lowercase();
if !p.contains('*') {
return t.contains(&p);
}
let parts: Vec<&str> = p.split('*').collect();
if parts.is_empty() {
return true;
}
let mut pos = 0usize;
let mut matched_any = false;
for (i, part) in parts.iter().enumerate() {
if part.is_empty() {
continue;
}
matched_any = true;
if i == 0 && !parts[0].is_empty() {
if !t.starts_with(part) {
return false;
}
pos = part.len();
} else {
match t[pos..].find(part) {
Some(idx) => pos += idx + part.len(),
None => return false,
}
}
}
if matched_any && !parts.last().map_or(true, |p| p.is_empty()) {
return t.ends_with(parts.last().unwrap());
}
true
}
fn content_matches(pattern: &str, text: &str) -> bool {
text.to_lowercase().contains(&pattern.to_lowercase())
}

View File

@@ -0,0 +1,200 @@
use std::time::Duration;
use tokio::sync::watch;
use tracing::{info, debug, warn};
use csm_protocol::{Frame, MessageType, DiskEncryptionStatusPayload, DriveEncryptionInfo, DiskEncryptionConfigPayload};
/// Disk encryption configuration pushed from server
#[derive(Debug, Clone, Default)]
pub struct DiskEncryptionConfig {
pub enabled: bool,
pub report_interval_secs: u64,
}
impl From<DiskEncryptionConfigPayload> for DiskEncryptionConfig {
fn from(payload: DiskEncryptionConfigPayload) -> Self {
Self {
enabled: payload.enabled,
report_interval_secs: payload.report_interval_secs,
}
}
}
/// Start the disk encryption detection plugin.
/// On startup and periodically, collects BitLocker volume status via PowerShell
/// and sends results to the server.
pub async fn start(
mut config_rx: watch::Receiver<DiskEncryptionConfig>,
data_tx: tokio::sync::mpsc::Sender<Frame>,
device_uid: String,
) {
info!("Disk encryption plugin started");
let mut config = DiskEncryptionConfig::default();
let default_interval_secs: u64 = 3600;
let mut report_interval = tokio::time::interval(Duration::from_secs(default_interval_secs));
report_interval.tick().await;
// Collect and report once on startup if enabled
if config.enabled {
collect_and_report(&data_tx, &device_uid).await;
}
loop {
tokio::select! {
result = config_rx.changed() => {
if result.is_err() {
break;
}
let new_config = config_rx.borrow_and_update().clone();
if new_config.enabled != config.enabled {
info!("Disk encryption enabled: {}", new_config.enabled);
}
config = new_config;
if config.enabled {
let secs = if config.report_interval_secs > 0 {
config.report_interval_secs
} else {
default_interval_secs
};
report_interval = tokio::time::interval(Duration::from_secs(secs));
report_interval.tick().await;
}
}
_ = report_interval.tick() => {
if !config.enabled {
continue;
}
collect_and_report(&data_tx, &device_uid).await;
}
}
}
}
async fn collect_and_report(
data_tx: &tokio::sync::mpsc::Sender<Frame>,
device_uid: &str,
) {
let uid = device_uid.to_string();
match tokio::task::spawn_blocking(move || collect_bitlocker_status()).await {
Ok(drives) => {
if drives.is_empty() {
debug!("No BitLocker volumes found for device {}", uid);
return;
}
let payload = DiskEncryptionStatusPayload {
device_uid: uid,
drives,
};
if let Ok(frame) = Frame::new_json(MessageType::DiskEncryptionStatus, &payload) {
if data_tx.send(frame).await.is_err() {
warn!("Failed to send disk encryption status: channel closed");
}
}
}
Err(e) => {
warn!("Failed to collect disk encryption status: {}", e);
}
}
}
/// Collect BitLocker volume information via PowerShell.
/// Runs: Get-BitLockerVolume | ConvertTo-Json
fn collect_bitlocker_status() -> Vec<DriveEncryptionInfo> {
#[cfg(target_os = "windows")]
{
let output = std::process::Command::new("powershell")
.args([
"-NoProfile",
"-NonInteractive",
"-Command",
"Get-BitLockerVolume | Select-Object MountPoint, VolumeName, EncryptionMethod, ProtectionStatus, EncryptionPercentage, LockStatus | ConvertTo-Json -Compress",
])
.output();
match output {
Ok(out) if out.status.success() => {
let stdout = String::from_utf8_lossy(&out.stdout);
let trimmed = stdout.trim();
if trimmed.is_empty() {
return Vec::new();
}
// PowerShell returns a single object (not array) when there is exactly one volume
let json_str = if trimmed.starts_with('{') {
format!("[{}]", trimmed)
} else {
trimmed.to_string()
};
match serde_json::from_str::<Vec<serde_json::Value>>(&json_str) {
Ok(entries) => entries.into_iter().map(|e| parse_bitlocker_entry(&e)).collect(),
Err(e) => {
warn!("Failed to parse BitLocker output: {}", e);
Vec::new()
}
}
}
Ok(out) => {
let stderr = String::from_utf8_lossy(&out.stderr);
warn!("PowerShell BitLocker query failed: {}", stderr);
Vec::new()
}
Err(e) => {
warn!("Failed to run PowerShell for BitLocker status: {}", e);
Vec::new()
}
}
}
#[cfg(not(target_os = "windows"))]
{
Vec::new()
}
}
fn parse_bitlocker_entry(entry: &serde_json::Value) -> DriveEncryptionInfo {
let mount_point = entry.get("MountPoint")
.and_then(|v| v.as_str())
.unwrap_or("Unknown:")
.to_string();
let volume_name = entry.get("VolumeName")
.and_then(|v| v.as_str())
.filter(|s| !s.is_empty())
.map(String::from);
let encryption_method = entry.get("EncryptionMethod")
.and_then(|v| v.as_str())
.filter(|s| !s.is_empty() && *s != "None")
.map(String::from);
let protection_status = match entry.get("ProtectionStatus") {
Some(v) if v.is_number() => match v.as_i64().unwrap_or(0) {
1 => "On".to_string(),
0 => "Off".to_string(),
_ => "Unknown".to_string(),
},
Some(v) if v.is_string() => v.as_str().unwrap_or("Unknown").to_string(),
_ => "Unknown".to_string(),
};
let encryption_percentage = entry.get("EncryptionPercentage")
.and_then(|v| v.as_f64())
.unwrap_or(0.0);
let lock_status = match entry.get("LockStatus") {
Some(v) if v.is_number() => match v.as_i64().unwrap_or(0) {
1 => "Locked".to_string(),
0 => "Unlocked".to_string(),
_ => "Unknown".to_string(),
},
Some(v) if v.is_string() => v.as_str().unwrap_or("Unknown").to_string(),
_ => "Unknown".to_string(),
};
DriveEncryptionInfo {
drive_letter: mount_point,
volume_name,
encryption_method,
protection_status,
encryption_percentage,
lock_status,
}
}

View File

@@ -14,6 +14,9 @@ mod usb_audit;
mod popup_blocker;
mod software_blocker;
mod web_filter;
mod disk_encryption;
mod clipboard_control;
mod print_audit;
#[cfg(target_os = "windows")]
mod service;
@@ -91,7 +94,11 @@ pub async fn run(state: ClientState) -> Result<()> {
let (usb_audit_tx, usb_audit_rx) = tokio::sync::watch::channel(usb_audit::UsbAuditConfig::default());
let (usage_timer_tx, usage_timer_rx) = tokio::sync::watch::channel(usage_timer::UsageConfig::default());
let (usb_policy_tx, usb_policy_rx) = tokio::sync::watch::channel(None::<UsbPolicyPayload>);
let (disk_encryption_tx, disk_encryption_rx) = tokio::sync::watch::channel(disk_encryption::DiskEncryptionConfig::default());
let (print_audit_tx, print_audit_rx) = tokio::sync::watch::channel(print_audit::PrintAuditConfig::default());
let (clipboard_control_tx, clipboard_control_rx) = tokio::sync::watch::channel(clipboard_control::ClipboardControlConfig::default());
// Build plugin channels struct
let plugins = network::PluginChannels {
watermark_tx,
web_filter_tx,
@@ -100,6 +107,9 @@ pub async fn run(state: ClientState) -> Result<()> {
usb_audit_tx,
usage_timer_tx,
usb_policy_tx,
disk_encryption_tx,
print_audit_tx,
clipboard_control_tx,
};
// Spawn core monitoring tasks
@@ -138,8 +148,10 @@ pub async fn run(state: ClientState) -> Result<()> {
usb_audit::start(usb_audit_rx, audit_data_tx, audit_uid).await;
});
let pb_data_tx = data_tx.clone();
let pb_uid = state.device_uid.clone();
tokio::spawn(async move {
popup_blocker::start(popup_blocker_rx).await;
popup_blocker::start(popup_blocker_rx, pb_data_tx, pb_uid).await;
});
let sw_data_tx = data_tx.clone();
@@ -152,6 +164,24 @@ pub async fn run(state: ClientState) -> Result<()> {
web_filter::start(web_filter_rx).await;
});
let de_data_tx = data_tx.clone();
let de_uid = state.device_uid.clone();
tokio::spawn(async move {
disk_encryption::start(disk_encryption_rx, de_data_tx, de_uid).await;
});
let pa_data_tx = data_tx.clone();
let pa_uid = state.device_uid.clone();
tokio::spawn(async move {
print_audit::start(print_audit_rx, pa_data_tx, pa_uid).await;
});
let cc_data_tx = data_tx.clone();
let cc_uid = state.device_uid.clone();
tokio::spawn(async move {
clipboard_control::start(clipboard_control_rx, cc_data_tx, cc_uid).await;
});
// Connect to server with reconnect
let mut backoff = Duration::from_secs(1);
let max_backoff = Duration::from_secs(60);
@@ -163,6 +193,7 @@ pub async fn run(state: ClientState) -> Result<()> {
}
match network::connect_and_run(&state, &mut data_rx, &plugins).await {
// Plugin channels moved into plugins struct — watchers are already cloned per-task
Ok(()) => {
warn!("Disconnected from server, reconnecting...");
tokio::time::sleep(Duration::from_secs(2)).await;
@@ -186,29 +217,58 @@ pub async fn run(state: ClientState) -> Result<()> {
}
}
/// Get directory for storing persistent files (next to the executable)
fn data_dir() -> std::path::PathBuf {
std::env::current_exe()
.ok()
.and_then(|p| p.parent().map(|p| p.to_path_buf()))
.unwrap_or_else(|| std::path::PathBuf::from("."))
}
fn load_or_create_device_uid() -> Result<String> {
let uid_file = "device_uid.txt";
if std::path::Path::new(uid_file).exists() {
let uid = std::fs::read_to_string(uid_file)?;
let uid_file = data_dir().join("device_uid.txt");
if uid_file.exists() {
let uid = std::fs::read_to_string(&uid_file)?;
Ok(uid.trim().to_string())
} else {
let uid = uuid::Uuid::new_v4().to_string();
std::fs::write(uid_file, &uid)?;
write_restricted_file(&uid_file, &uid)?;
Ok(uid)
}
}
/// Load persisted device_secret from disk (if available)
pub fn load_device_secret() -> Option<String> {
let secret_file = "device_secret.txt";
let secret = std::fs::read_to_string(secret_file).ok()?;
let secret_file = data_dir().join("device_secret.txt");
let secret = std::fs::read_to_string(&secret_file).ok()?;
let trimmed = secret.trim().to_string();
if trimmed.is_empty() { None } else { Some(trimmed) }
}
/// Persist device_secret to disk
/// Persist device_secret to disk with restricted permissions
pub fn save_device_secret(secret: &str) {
if let Err(e) = std::fs::write("device_secret.txt", secret) {
let secret_file = data_dir().join("device_secret.txt");
if let Err(e) = write_restricted_file(&secret_file, secret) {
warn!("Failed to persist device_secret: {}", e);
}
}
/// Write a file with owner-only permissions (0o600 on Unix).
/// On Windows, the file inherits the directory's ACL — consider setting
/// explicit ACLs via PowerShell for production deployments.
#[cfg(unix)]
fn write_restricted_file(path: &std::path::Path, content: &str) -> std::io::Result<()> {
use std::os::unix::fs::OpenOptionsExt;
std::fs::OpenOptions::new()
.write(true)
.create(true)
.truncate(true)
.mode(0o600)
.open(path)
.and_then(|mut f| std::io::Write::write_all(&mut f, content.as_bytes()))
}
#[cfg(not(unix))]
fn write_restricted_file(path: &std::path::Path, content: &str) -> std::io::Result<()> {
std::fs::write(path, content)
}

View File

@@ -1,24 +1,33 @@
use anyhow::Result;
use csm_protocol::{Frame, MessageType, DeviceStatus, ProcessInfo};
use std::time::Duration;
use tokio::sync::mpsc::Sender;
use tracing::{info, error, debug};
use sysinfo::System;
use sysinfo::Disks;
use sysinfo::Networks;
pub async fn start_collecting(tx: Sender<Frame>, device_uid: String) {
let interval = Duration::from_secs(60);
let mut prev_rx: Option<u64> = None;
let mut prev_tx: Option<u64> = None;
loop {
// Run blocking sysinfo collection on a dedicated thread
let uid_clone = device_uid.clone();
let result = tokio::task::spawn_blocking(move || {
collect_system_status(&uid_clone)
}).await;
let prev_rx_c = prev_rx;
let prev_tx_c = prev_tx;
let result = tokio::task::spawn_blocking(move || {
collect_system_status(&uid_clone, prev_rx_c, prev_tx_c)
}).await;
match result {
Ok(Ok(status)) => {
Ok(Ok((status, new_rx, new_tx))) => {
prev_rx = Some(new_rx);
prev_tx = Some(new_tx);
if let Ok(frame) = Frame::new_json(MessageType::StatusReport, &status) {
debug!("Sending status report: cpu={:.1}%, mem={:.1}%", status.cpu_usage, status.memory_usage);
debug!(
"Sending status: cpu={:.1}%, mem={:.1}%, disk={:.1}%",
status.cpu_usage, status.memory_usage, status.disk_usage
);
if tx.send(frame).await.is_err() {
info!("Monitor channel closed, exiting");
break;
@@ -37,25 +46,68 @@ pub async fn start_collecting(tx: Sender<Frame>, device_uid: String) {
}
}
fn collect_system_status(device_uid: &str) -> Result<DeviceStatus> {
fn collect_system_status(
device_uid: &str,
prev_rx: Option<u64>,
prev_tx: Option<u64>,
) -> anyhow::Result<(DeviceStatus, u64, u64)> {
let mut sys = System::new_all();
sys.refresh_all();
let disks = Disks::new_with_refreshed_list();
let networks = Networks::new_with_refreshed_list();
// Brief wait for CPU usage to stabilize
sys.refresh_all();
std::thread::sleep(Duration::from_millis(200));
sys.refresh_all();
sys.refresh_cpu_usage();
let cpu_usage = sys.global_cpu_info().cpu_usage() as f64;
let total_memory = sys.total_memory() / 1024 / 1024; // Convert bytes to MB (sysinfo 0.30 returns bytes)
let used_memory = sys.used_memory() / 1024 / 1024;
let memory_usage = if total_memory > 0 {
let cpu_usage = sys.global_cpu_info().cpu_usage() as f64;
let total_memory = sys.total_memory() / 1024 / 1024;
// Convert bytes to MB
let used_memory = sys.used_memory() / 1024 / 1024;
let memory_usage = if total_memory > 0 {
(used_memory as f64 / total_memory as f64) * 100.0
} else {
0.0
};
// Disk usage
let (disk_usage, disk_total_mb) = {
let mut total_space: u64 = 0;
let mut total_available: u64 = 0;
for disk in disks.list() {
let total = disk.total_space() / 1024 / 1024;
// MB
let available = disk.available_space() / 1024 / 1024;
// MB
total_space += total;
total_available += available;
}
let used_mb = total_space.saturating_sub(total_available);
let usage_pct = if total_space > 0 {
(used_mb as f64 / total_space as f64) * 100.0
} else {
0.0
};
(usage_pct, total_space)
};
// Network rate
let (network_rx_rate, network_tx_rate, current_rx, current_tx) = {
let mut cur_rx: u64 = 0;
let mut cur_tx: u64 = 0;
for (_, data) in networks.iter() {
cur_rx += data.received();
cur_tx += data.transmitted();
}
let rx_rate = match prev_rx {
Some(prev) => cur_rx.saturating_sub(prev) / 60, // bytes/sec (60s interval)
None => 0,
};
let tx_rate = match prev_tx {
Some(prev) => cur_tx.saturating_sub(prev) / 60,
None => 0,
};
(rx_rate, tx_rate, cur_rx, cur_tx)
};
// Top processes by CPU
// Top processes by CPU
let mut processes: Vec<ProcessInfo> = sys.processes()
.iter()
.map(|(_, p)| {
@@ -63,24 +115,24 @@ fn collect_system_status(device_uid: &str) -> Result<DeviceStatus> {
name: p.name().to_string(),
pid: p.pid().as_u32(),
cpu_usage: p.cpu_usage() as f64,
memory_mb: p.memory() / 1024 / 1024, // bytes to MB (sysinfo 0.30)
memory_mb: p.memory() / 1024 / 1024,
// bytes to MB
}
})
.collect();
processes.sort_by(|a, b| b.cpu_usage.partial_cmp(&a.cpu_usage).unwrap_or(std::cmp::Ordering::Equal));
processes.truncate(10);
Ok(DeviceStatus {
processes.truncate(10);
let status = DeviceStatus {
device_uid: device_uid.to_string(),
cpu_usage,
memory_usage,
memory_total_mb: total_memory as u64,
disk_usage: 0.0, // TODO: implement disk usage via Windows API
disk_total_mb: 0,
network_rx_rate: 0,
network_tx_rate: 0,
disk_usage,
disk_total_mb,
network_rx_rate: network_rx_rate,
network_tx_rate: network_tx_rate,
running_procs: sys.processes().len() as u32,
top_processes: processes,
})
};
Ok((status, current_rx, current_tx))
}

View File

@@ -3,7 +3,7 @@ use std::time::Duration;
use tokio::net::TcpStream;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tracing::{info, debug, warn};
use csm_protocol::{Frame, MessageType, RegisterRequest, RegisterResponse, HeartbeatPayload, WatermarkConfigPayload, UsbPolicyPayload};
use csm_protocol::{Frame, MessageType, RegisterRequest, RegisterResponse, HeartbeatPayload, WatermarkConfigPayload, UsbPolicyPayload, DiskEncryptionConfigPayload};
use hmac::{Hmac, Mac};
use sha2::Sha256;
@@ -18,6 +18,9 @@ pub struct PluginChannels {
pub usb_audit_tx: tokio::sync::watch::Sender<crate::usb_audit::UsbAuditConfig>,
pub usage_timer_tx: tokio::sync::watch::Sender<crate::usage_timer::UsageConfig>,
pub usb_policy_tx: tokio::sync::watch::Sender<Option<UsbPolicyPayload>>,
pub disk_encryption_tx: tokio::sync::watch::Sender<crate::disk_encryption::DiskEncryptionConfig>,
pub print_audit_tx: tokio::sync::watch::Sender<crate::print_audit::PrintAuditConfig>,
pub clipboard_control_tx: tokio::sync::watch::Sender<crate::clipboard_control::ClipboardControlConfig>,
}
/// Connect to server and run the main communication loop
@@ -286,6 +289,16 @@ fn handle_server_message(frame: Frame, plugins: &PluginChannels) -> Result<()> {
let config = crate::popup_blocker::PopupBlockerConfig { enabled: true, rules };
plugins.popup_blocker_tx.send(config)?;
}
MessageType::DiskEncryptionConfig => {
let config: DiskEncryptionConfigPayload = frame.decode_payload()
.map_err(|e| anyhow::anyhow!("Invalid disk encryption config: {}", e))?;
info!("Received disk encryption config: enabled={}, interval={}s", config.enabled, config.report_interval_secs);
let plugin_config = crate::disk_encryption::DiskEncryptionConfig {
enabled: config.enabled,
report_interval_secs: config.report_interval_secs,
};
plugins.disk_encryption_tx.send(plugin_config)?;
}
MessageType::PluginEnable => {
let payload: csm_protocol::PluginControlPayload = frame.decode_payload()
.map_err(|e| anyhow::anyhow!("Invalid plugin enable: {}", e))?;
@@ -299,6 +312,16 @@ fn handle_server_message(frame: Frame, plugins: &PluginChannels) -> Result<()> {
info!("Plugin disabled: {}", payload.plugin_name);
handle_plugin_control(&payload, plugins, false)?;
}
MessageType::ClipboardRules => {
let payload: csm_protocol::ClipboardRulesPayload = frame.decode_payload()
.map_err(|e| anyhow::anyhow!("Invalid clipboard rules: {}", e))?;
info!("Received clipboard rules update: {} rules", payload.rules.len());
let config = crate::clipboard_control::ClipboardControlConfig {
enabled: true,
rules: payload.rules,
};
plugins.clipboard_control_tx.send(config)?;
}
_ => {
debug!("Unhandled message type: {:?}", frame.msg_type);
}
@@ -346,6 +369,21 @@ fn handle_plugin_control(
plugins.usage_timer_tx.send(crate::usage_timer::UsageConfig { enabled: false, ..Default::default() })?;
}
}
"disk_encryption" => {
if !enabled {
plugins.disk_encryption_tx.send(crate::disk_encryption::DiskEncryptionConfig { enabled: false, ..Default::default() })?;
}
}
"print_audit" => {
if !enabled {
plugins.print_audit_tx.send(crate::print_audit::PrintAuditConfig { enabled: false, ..Default::default() })?;
}
}
"clipboard_control" => {
if !enabled {
plugins.clipboard_control_tx.send(crate::clipboard_control::ClipboardControlConfig { enabled: false, ..Default::default() })?;
}
}
_ => {
warn!("Unknown plugin: {}", payload.plugin_name);
}

View File

@@ -1,370 +0,0 @@
use anyhow::Result;
use std::time::Duration;
use tokio::net::TcpStream;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tracing::{info, debug, warn};
use csm_protocol::{Frame, MessageType, RegisterRequest, RegisterResponse, HeartbeatPayload, WatermarkConfigPayload};
use hmac::{Hmac, Mac};
use sha2::Sha256;
use crate::ClientState;
/// Maximum accumulated read buffer size per connection (8 MB)
const MAX_READ_BUF_SIZE: usize = 8 * 1024 * 1024;
/// Holds senders for all plugin config channels
pub struct PluginChannels {
pub watermark_tx: tokio::sync::watch::Sender<Option<WatermarkConfigPayload>>,
pub web_filter_tx: tokio::sync::watch::Sender<crate::web_filter::WebFilterConfig>,
pub software_blocker_tx: tokio::sync::watch::Sender<crate::software_blocker::SoftwareBlockerConfig>,
pub popup_blocker_tx: tokio::sync::watch::Sender<crate::popup_blocker::PopupBlockerConfig>,
pub usb_audit_tx: tokio::sync::watch::Sender<crate::usb_audit::UsbAuditConfig>,
pub usage_timer_tx: tokio::sync::watch::Sender<crate::usage_timer::UsageConfig>,
}
/// Connect to server and run the main communication loop
pub async fn connect_and_run(
state: &ClientState,
data_rx: &mut tokio::sync::mpsc::Receiver<Frame>,
plugins: &PluginChannels,
) -> Result<()> {
let tcp_stream = TcpStream::connect(&state.server_addr).await?;
info!("TCP connected to {}", state.server_addr);
if state.use_tls {
let tls_stream = wrap_tls(tcp_stream, &state.server_addr).await?;
run_comm_loop(tls_stream, state, data_rx, plugins).await
} else {
run_comm_loop(tcp_stream, state, data_rx, plugins).await
}
}
/// Wrap a TCP stream with TLS.
/// Supports custom CA certificate via CSM_TLS_CA_CERT env var (path to PEM file).
/// Supports skipping verification via CSM_TLS_SKIP_VERIFY=true (development only).
async fn wrap_tls(stream: TcpStream, server_addr: &str) -> Result<tokio_rustls::client::TlsStream<TcpStream>> {
let mut root_store = rustls::RootCertStore::empty();
// Load custom CA certificate if specified
if let Ok(ca_path) = std::env::var("CSM_TLS_CA_CERT") {
let ca_pem = std::fs::read(&ca_path)
.map_err(|e| anyhow::anyhow!("Failed to read CA cert {}: {}", ca_path, e))?;
let certs = rustls_pemfile::certs(&mut &ca_pem[..])
.collect::<Result<Vec<_>, _>>()
.map_err(|e| anyhow::anyhow!("Failed to parse CA cert: {:?}", e))?;
for cert in certs {
root_store.add(cert)?;
}
info!("Loaded custom CA certificates from {}", ca_path);
}
// Always include system roots as fallback
root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
let config = if std::env::var("CSM_TLS_SKIP_VERIFY").as_deref() == Ok("true") {
warn!("TLS certificate verification DISABLED — do not use in production!");
rustls::ClientConfig::builder()
.dangerous()
.with_custom_certificate_verifier(std::sync::Arc::new(NoVerifier))
.with_no_client_auth()
} else {
rustls::ClientConfig::builder()
.with_root_certificates(root_store)
.with_no_client_auth()
};
let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(config));
// Extract hostname from server_addr (strip port)
let domain = server_addr.split(':').next().unwrap_or("localhost").to_string();
let server_name = rustls_pki_types::ServerName::try_from(domain.clone())
.map_err(|e| anyhow::anyhow!("Invalid TLS server name '{}': {:?}", domain, e))?;
let tls_stream = connector.connect(server_name, stream).await?;
info!("TLS handshake completed with {}", domain);
Ok(tls_stream)
}
/// A no-op certificate verifier for development use (CSM_TLS_SKIP_VERIFY=true).
#[derive(Debug)]
struct NoVerifier;
impl rustls::client::danger::ServerCertVerifier for NoVerifier {
fn verify_server_cert(
&self,
_end_entity: &rustls_pki_types::CertificateDer,
_intermediates: &[rustls_pki_types::CertificateDer],
_server_name: &rustls_pki_types::ServerName,
_ocsp_response: &[u8],
_now: rustls_pki_types::UnixTime,
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
Ok(rustls::client::danger::ServerCertVerified::assertion())
}
fn verify_tls12_signature(
&self,
_message: &[u8],
_cert: &rustls_pki_types::CertificateDer,
_dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}
fn verify_tls13_signature(
&self,
_message: &[u8],
_cert: &rustls_pki_types::CertificateDer,
_dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
vec![
rustls::SignatureScheme::RSA_PKCS1_SHA256,
rustls::SignatureScheme::RSA_PKCS1_SHA384,
rustls::SignatureScheme::RSA_PKCS1_SHA512,
rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
rustls::SignatureScheme::RSA_PSS_SHA256,
rustls::SignatureScheme::RSA_PSS_SHA384,
rustls::SignatureScheme::RSA_PSS_SHA512,
rustls::SignatureScheme::ED25519,
]
}
}
/// Main communication loop over any read+write stream
async fn run_comm_loop<S>(
mut stream: S,
state: &ClientState,
data_rx: &mut tokio::sync::mpsc::Receiver<Frame>,
plugins: &PluginChannels,
) -> Result<()>
where
S: AsyncReadExt + AsyncWriteExt + Unpin,
{
// Send registration
let register = RegisterRequest {
device_uid: state.device_uid.clone(),
hostname: hostname::get()
.map(|h| h.to_string_lossy().to_string())
.unwrap_or_else(|_| "unknown".to_string()),
registration_token: state.registration_token.clone(),
os_version: get_os_info(),
mac_address: None,
};
let frame = Frame::new_json(MessageType::Register, &register)?;
stream.write_all(&frame.encode()).await?;
info!("Registration request sent");
let mut buffer = vec![0u8; 65536];
let mut read_buf = Vec::with_capacity(65536);
let heartbeat_secs = state.config.heartbeat_interval_secs;
let mut heartbeat_interval = tokio::time::interval(Duration::from_secs(heartbeat_secs));
heartbeat_interval.tick().await; // Skip first tick
// HMAC key — set after receiving RegisterResponse
let mut device_secret: Option<String> = state.device_secret.clone();
loop {
tokio::select! {
// Read from server
result = stream.read(&mut buffer) => {
let n = result?;
if n == 0 {
return Err(anyhow::anyhow!("Server closed connection"));
}
read_buf.extend_from_slice(&buffer[..n]);
// Process complete frames
loop {
match Frame::decode(&read_buf)? {
Some(frame) => {
let consumed = frame.encoded_size();
read_buf.drain(..consumed);
// Capture device_secret from registration response
if frame.msg_type == MessageType::RegisterResponse {
if let Ok(resp) = frame.decode_payload::<RegisterResponse>() {
device_secret = Some(resp.device_secret.clone());
crate::save_device_secret(&resp.device_secret);
info!("Device secret received and persisted, HMAC enabled for heartbeats");
}
}
handle_server_message(frame, plugins)?;
}
None => break, // Incomplete frame, wait for more data
}
}
}
// Send queued data
frame = data_rx.recv() => {
let frame = frame.ok_or_else(|| anyhow::anyhow!("Channel closed"))?;
stream.write_all(&frame.encode()).await?;
}
// Heartbeat
_ = heartbeat_interval.tick() => {
let timestamp = chrono::Utc::now().to_rfc3339();
let hmac_value = compute_hmac(device_secret.as_deref(), &state.device_uid, &timestamp);
let heartbeat = HeartbeatPayload {
device_uid: state.device_uid.clone(),
timestamp,
hmac: hmac_value,
};
let frame = Frame::new_json(MessageType::Heartbeat, &heartbeat)?;
stream.write_all(&frame.encode()).await?;
debug!("Heartbeat sent (hmac={})", !heartbeat.hmac.is_empty());
}
}
}
}
fn handle_server_message(frame: Frame, plugins: &PluginChannels) -> Result<()> {
match frame.msg_type {
MessageType::RegisterResponse => {
let resp: RegisterResponse = frame.decode_payload()
.map_err(|e| anyhow::anyhow!("Invalid registration response: {}", e))?;
info!("Registration accepted by server (server version: {})", resp.config.server_version);
}
MessageType::PolicyUpdate => {
let policy: serde_json::Value = frame.decode_payload()
.map_err(|e| anyhow::anyhow!("Invalid policy update: {}", e))?;
info!("Received policy update: {}", policy);
}
MessageType::ConfigUpdate => {
info!("Received config update");
}
MessageType::TaskExecute => {
warn!("Task execution requested (not yet implemented)");
}
MessageType::WatermarkConfig => {
let config: WatermarkConfigPayload = frame.decode_payload()
.map_err(|e| anyhow::anyhow!("Invalid watermark config: {}", e))?;
info!("Received watermark config: enabled={}", config.enabled);
plugins.watermark_tx.send(Some(config))?;
}
MessageType::WebFilterRuleUpdate => {
let payload: serde_json::Value = frame.decode_payload()
.map_err(|e| anyhow::anyhow!("Invalid web filter update: {}", e))?;
info!("Received web filter rules update");
let rules: Vec<crate::web_filter::WebFilterRule> = payload.get("rules")
.and_then(|r| serde_json::from_value(r.clone()).ok())
.unwrap_or_default();
let config = crate::web_filter::WebFilterConfig { enabled: true, rules };
plugins.web_filter_tx.send(config)?;
}
MessageType::SoftwareBlacklist => {
let payload: serde_json::Value = frame.decode_payload()
.map_err(|e| anyhow::anyhow!("Invalid software blacklist: {}", e))?;
info!("Received software blacklist update");
let blacklist: Vec<crate::software_blocker::BlacklistEntry> = payload.get("blacklist")
.and_then(|r| serde_json::from_value(r.clone()).ok())
.unwrap_or_default();
let config = crate::software_blocker::SoftwareBlockerConfig { enabled: true, blacklist };
plugins.software_blocker_tx.send(config)?;
}
MessageType::PopupRules => {
let payload: serde_json::Value = frame.decode_payload()
.map_err(|e| anyhow::anyhow!("Invalid popup rules: {}", e))?;
info!("Received popup blocker rules update");
let rules: Vec<crate::popup_blocker::PopupRule> = payload.get("rules")
.and_then(|r| serde_json::from_value(r.clone()).ok())
.unwrap_or_default();
let config = crate::popup_blocker::PopupBlockerConfig { enabled: true, rules };
plugins.popup_blocker_tx.send(config)?;
}
MessageType::PluginEnable => {
let payload: csm_protocol::PluginControlPayload = frame.decode_payload()
.map_err(|e| anyhow::anyhow!("Invalid plugin enable: {}", e))?;
info!("Plugin enabled: {}", payload.plugin_name);
// Route to appropriate plugin channel based on plugin_name
handle_plugin_control(&payload, plugins, true)?;
}
MessageType::PluginDisable => {
let payload: csm_protocol::PluginControlPayload = frame.decode_payload()
.map_err(|e| anyhow::anyhow!("Invalid plugin disable: {}", e))?;
info!("Plugin disabled: {}", payload.plugin_name);
handle_plugin_control(&payload, plugins, false)?;
}
_ => {
debug!("Unhandled message type: {:?}", frame.msg_type);
}
}
Ok(())
}
fn handle_plugin_control(
payload: &csm_protocol::PluginControlPayload,
plugins: &PluginChannels,
enabled: bool,
) -> Result<()> {
match payload.plugin_name.as_str() {
"watermark" => {
if !enabled {
// Send disabled config to remove overlay
plugins.watermark_tx.send(None)?;
}
// When enabling, server will push the actual config next
}
"web_filter" => {
if !enabled {
// Clear hosts rules on disable
plugins.web_filter_tx.send(crate::web_filter::WebFilterConfig { enabled: false, rules: vec![] })?;
}
// When enabling, server will push rules
}
"software_blocker" => {
if !enabled {
plugins.software_blocker_tx.send(crate::software_blocker::SoftwareBlockerConfig { enabled: false, blacklist: vec![] })?;
}
}
"popup_blocker" => {
if !enabled {
plugins.popup_blocker_tx.send(crate::popup_blocker::PopupBlockerConfig { enabled: false, rules: vec![] })?;
}
}
"usb_audit" => {
if !enabled {
plugins.usb_audit_tx.send(crate::usb_audit::UsbAuditConfig { enabled: false, monitored_extensions: vec![] })?;
}
}
"usage_timer" => {
if !enabled {
plugins.usage_timer_tx.send(crate::usage_timer::UsageConfig { enabled: false, ..Default::default() })?;
}
}
_ => {
warn!("Unknown plugin: {}", payload.plugin_name);
}
}
Ok(())
}
/// Compute HMAC-SHA256 for heartbeat verification.
/// Format: HMAC-SHA256(device_secret, "{device_uid}\n{timestamp}")
fn compute_hmac(secret: Option<&str>, device_uid: &str, timestamp: &str) -> String {
let secret = match secret {
Some(s) if !s.is_empty() => s,
_ => return String::new(),
};
type HmacSha256 = Hmac<Sha256>;
let message = format!("{}\n{}", device_uid, timestamp);
let mut mac = match HmacSha256::new_from_slice(secret.as_bytes()) {
Ok(m) => m,
Err(_) => return String::new(),
};
mac.update(message.as_bytes());
hex::encode(mac.finalize().into_bytes())
}
fn get_os_info() -> String {
use sysinfo::System;
let name = System::name().unwrap_or_else(|| "Unknown".to_string());
let version = System::os_version().unwrap_or_else(|| "Unknown".to_string());
format!("{} {}", name, version)
}

View File

@@ -1,6 +1,7 @@
use tokio::sync::watch;
use tokio::sync::{watch, mpsc};
use tracing::{info, debug};
use serde::Deserialize;
use csm_protocol::{Frame, MessageType, PopupBlockStatsPayload, PopupRuleStat};
/// Popup blocker rule from server
#[derive(Debug, Clone, Deserialize)]
@@ -23,15 +24,27 @@ pub struct PopupBlockerConfig {
struct ScanContext {
rules: Vec<PopupRule>,
blocked_count: u32,
rule_hits: std::collections::HashMap<i64, u32>,
}
/// Start popup blocker plugin.
/// Periodically enumerates windows and closes those matching rules.
pub async fn start(mut config_rx: watch::Receiver<PopupBlockerConfig>) {
/// Reports statistics to server every 60 seconds.
pub async fn start(
mut config_rx: watch::Receiver<PopupBlockerConfig>,
data_tx: mpsc::Sender<Frame>,
device_uid: String,
) {
info!("Popup blocker plugin started");
let mut config = PopupBlockerConfig::default();
let mut scan_interval = tokio::time::interval(std::time::Duration::from_secs(2));
scan_interval.tick().await;
let mut stats_interval = tokio::time::interval(std::time::Duration::from_secs(60));
stats_interval.tick().await;
// Accumulated stats
let mut total_blocked: u32 = 0;
let mut rule_hits: std::collections::HashMap<i64, u32> = std::collections::HashMap::new();
loop {
tokio::select! {
@@ -47,13 +60,34 @@ pub async fn start(mut config_rx: watch::Receiver<PopupBlockerConfig>) {
if !config.enabled || config.rules.is_empty() {
continue;
}
scan_and_block(&config.rules);
let ctx = scan_and_block(&config.rules);
total_blocked += ctx.blocked_count;
for (rule_id, hits) in ctx.rule_hits {
*rule_hits.entry(rule_id).or_insert(0) += hits;
}
}
_ = stats_interval.tick() => {
if total_blocked > 0 {
let stats = PopupBlockStatsPayload {
device_uid: device_uid.clone(),
blocked_count: total_blocked,
rule_stats: rule_hits.iter().map(|(&id, &hits)| PopupRuleStat { rule_id: id, hits }).collect(),
period_secs: 60,
};
if let Ok(frame) = Frame::new_json(MessageType::PopupBlockStats, &stats) {
if data_tx.send(frame).await.is_err() {
debug!("Failed to send popup block stats: channel closed");
}
}
total_blocked = 0;
rule_hits.clear();
}
}
}
}
}
fn scan_and_block(rules: &[PopupRule]) {
fn scan_and_block(rules: &[PopupRule]) -> ScanContext {
#[cfg(target_os = "windows")]
{
use windows::Win32::UI::WindowsAndMessaging::EnumWindows;
@@ -62,6 +96,7 @@ fn scan_and_block(rules: &[PopupRule]) {
let mut ctx = ScanContext {
rules: rules.to_vec(),
blocked_count: 0,
rule_hits: std::collections::HashMap::new(),
};
unsafe {
@@ -73,10 +108,12 @@ fn scan_and_block(rules: &[PopupRule]) {
if ctx.blocked_count > 0 {
debug!("Popup scan blocked {} windows", ctx.blocked_count);
}
ctx
}
#[cfg(not(target_os = "windows"))]
{
let _ = rules;
ScanContext { rules: vec![], blocked_count: 0, rule_hits: std::collections::HashMap::new() }
}
}
@@ -133,6 +170,7 @@ unsafe extern "system" fn enum_windows_callback(
if matches {
let _ = PostMessageW(hwnd, WM_CLOSE, WPARAM(0), LPARAM(0));
ctx.blocked_count += 1;
*ctx.rule_hits.entry(rule.id).or_insert(0) += 1;
info!(
"Blocked popup: title='{}' class='{}' process='{}' (rule_id={})",
title, class_name, process_name, rule.id

View File

@@ -0,0 +1,249 @@
use std::collections::HashSet;
use std::time::Duration;
use tokio::sync::watch;
use tracing::{info, warn};
use csm_protocol::{Frame, MessageType, PrintEventPayload};
/// Print audit configuration pushed from server
#[derive(Debug, Clone, Default)]
pub struct PrintAuditConfig {
pub enabled: bool,
pub report_interval_secs: u64,
}
/// Start the print audit plugin.
/// On startup and periodically, queries Windows print spooler for recent
/// print jobs via WMI and sends new events to the server.
pub async fn start(
mut config_rx: watch::Receiver<PrintAuditConfig>,
data_tx: tokio::sync::mpsc::Sender<Frame>,
device_uid: String,
) {
info!("Print audit plugin started");
let mut config = PrintAuditConfig::default();
let default_interval_secs: u64 = 300;
let mut report_interval = tokio::time::interval(Duration::from_secs(default_interval_secs));
report_interval.tick().await;
// Track seen print job IDs to avoid duplicates
let mut seen_jobs: HashSet<String> = HashSet::new();
// Collect and report once on startup if enabled
if config.enabled {
collect_and_report(&data_tx, &device_uid, &mut seen_jobs).await;
}
loop {
tokio::select! {
result = config_rx.changed() => {
if result.is_err() {
break;
}
let new_config = config_rx.borrow_and_update().clone();
if new_config.enabled != config.enabled {
info!("Print audit enabled: {}", new_config.enabled);
}
config = new_config;
if config.enabled {
let secs = if config.report_interval_secs > 0 {
config.report_interval_secs
} else {
default_interval_secs
};
report_interval = tokio::time::interval(Duration::from_secs(secs));
report_interval.tick().await;
}
}
_ = report_interval.tick() => {
if !config.enabled {
continue;
}
collect_and_report(&data_tx, &device_uid, &mut seen_jobs).await;
}
}
}
}
async fn collect_and_report(
data_tx: &tokio::sync::mpsc::Sender<Frame>,
device_uid: &str,
seen_jobs: &mut HashSet<String>,
) {
let uid = device_uid.to_string();
match tokio::task::spawn_blocking(move || collect_print_jobs()).await {
Ok(jobs) => {
for job in jobs {
let job_key = format!("{}|{}|{}", job.document_name.as_deref().unwrap_or(""), job.printer_name.as_deref().unwrap_or(""), &job.timestamp);
if seen_jobs.contains(&job_key) {
continue;
}
seen_jobs.insert(job_key.clone());
let payload = PrintEventPayload {
device_uid: uid.clone(),
document_name: job.document_name,
printer_name: job.printer_name,
pages: job.pages,
copies: job.copies,
user_name: job.user_name,
file_size_bytes: job.file_size_bytes,
timestamp: job.timestamp,
};
if let Ok(frame) = Frame::new_json(MessageType::PrintEvent, &payload) {
if data_tx.send(frame).await.is_err() {
warn!("Failed to send print event: channel closed");
return;
}
}
}
// Keep seen_jobs bounded — evict entries older than what we'd reasonably see
if seen_jobs.len() > 10000 {
seen_jobs.clear();
}
}
Err(e) => {
warn!("Failed to collect print jobs: {}", e);
}
}
}
struct PrintJob {
document_name: Option<String>,
printer_name: Option<String>,
pages: Option<i32>,
copies: Option<i32>,
user_name: Option<String>,
file_size_bytes: Option<i64>,
timestamp: String,
}
/// Collect recent print jobs via WMI on Windows.
/// Queries Win32_PrintJob for jobs that completed in the recent period.
fn collect_print_jobs() -> Vec<PrintJob> {
#[cfg(target_os = "windows")]
{
let output = std::process::Command::new("powershell")
.args([
"-NoProfile",
"-NonInteractive",
"-Command",
"Get-WinEvent -FilterHashtable @{LogName='Microsoft-Windows-PrintService/Operational'; ID=307} -MaxEvents 50 -ErrorAction SilentlyContinue | Select-Object TimeCreated, Message | ConvertTo-Json -Compress",
])
.output();
match output {
Ok(out) if out.status.success() => {
let stdout = String::from_utf8_lossy(&out.stdout);
let trimmed = stdout.trim();
if trimmed.is_empty() {
return Vec::new();
}
// PowerShell may return single object or array
let json_str = if trimmed.starts_with('{') {
format!("[{}]", trimmed)
} else {
trimmed.to_string()
};
match serde_json::from_str::<Vec<serde_json::Value>>(&json_str) {
Ok(entries) => entries.into_iter().filter_map(|e| parse_print_event(&e)).collect(),
Err(e) => {
warn!("Failed to parse print event output: {}", e);
Vec::new()
}
}
}
Ok(_) => {
// No print events or error — not logged as warning since this is normal
Vec::new()
}
Err(e) => {
warn!("Failed to run PowerShell for print events: {}", e);
Vec::new()
}
}
}
#[cfg(not(target_os = "windows"))]
{
Vec::new()
}
}
/// Parse a Windows Event Log entry for print event (Event ID 307).
/// The Message field contains: "Document N, owner owned by USER was printed on PRINTER via port PORT. Size in bytes: SIZE. Pages printed: PAGES. No pages for the client."
fn parse_print_event(entry: &serde_json::Value) -> Option<PrintJob> {
let timestamp = entry.get("TimeCreated")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
if timestamp.is_empty() {
return None;
}
let message = entry.get("Message")
.and_then(|v| v.as_str())
.unwrap_or("");
let (document_name, printer_name, user_name, pages, file_size_bytes) = parse_print_message(message);
Some(PrintJob {
document_name: if document_name.is_empty() { None } else { Some(document_name) },
printer_name: if printer_name.is_empty() { None } else { Some(printer_name) },
pages,
copies: Some(1),
user_name: if user_name.is_empty() { None } else { Some(user_name) },
file_size_bytes,
timestamp,
})
}
/// Parse the print event message text to extract fields.
/// Example: "Document 10, Test Page owned by JOHN was printed on HP LaserJet via port PORT. Size in bytes: 12345. Pages printed: 1."
fn parse_print_message(msg: &str) -> (String, String, String, Option<i32>, Option<i64>) {
let mut document_name = String::new();
let mut printer_name = String::new();
let mut user_name = String::new();
let mut pages: Option<i32> = None;
let mut file_size_bytes: Option<i64> = None;
// Extract document name: "Document N, <name> owned by"
if let Some(start) = msg.find("Document ") {
let rest = &msg[start + "Document ".len()..];
// Skip job number and comma
if let Some(comma_pos) = rest.find(", ") {
let after_comma = &rest[comma_pos + 2..];
if let Some(owned_pos) = after_comma.find(" owned by ") {
document_name = after_comma[..owned_pos].trim().to_string();
let after_owned = &after_comma[owned_pos + " owned by ".len()..];
if let Some(was_pos) = after_owned.find(" was printed on ") {
user_name = after_owned[..was_pos].trim().to_string();
let after_printer = &after_owned[was_pos + " was printed on ".len()..];
if let Some(via_pos) = after_printer.find(" via port") {
printer_name = after_printer[..via_pos].trim().to_string();
}
}
}
}
}
// Extract pages: "Pages printed: N."
if let Some(pages_start) = msg.find("Pages printed: ") {
let rest = &msg[pages_start + "Pages printed: ".len()..];
let num_str: String = rest.chars().take_while(|c| c.is_ascii_digit()).collect();
if !num_str.is_empty() {
pages = num_str.parse().ok();
}
}
// Extract file size: "Size in bytes: N."
if let Some(size_start) = msg.find("Size in bytes: ") {
let rest = &msg[size_start + "Size in bytes: ".len()..];
let num_str: String = rest.chars().take_while(|c| c.is_ascii_digit()).collect();
if !num_str.is_empty() {
file_size_bytes = num_str.parse().ok();
}
}
(document_name, printer_name, user_name, pages, file_size_bytes)
}

View File

@@ -25,6 +25,7 @@ const PROTECTED_PROCESSES: &[&str] = &[
/// Software blacklist entry from server
#[derive(Debug, Clone, Deserialize)]
pub struct BlacklistEntry {
#[allow(dead_code)]
pub id: i64,
pub name_pattern: String,
pub action: String,

View File

@@ -221,9 +221,10 @@ unsafe extern "system" fn watermark_wnd_proc(
GetSystemMetrics(SM_CYSCREEN),
SWP_SHOWWINDOW,
);
let alpha = (s.opacity * 255.0).clamp(0.0, 255.0) as u8;
// Color key black background to transparent, apply alpha to text
let _ = SetLayeredWindowAttributes(hwnd, COLORREF(0), alpha, LWA_COLORKEY | LWA_ALPHA);
let alpha = (s.opacity * 255.0).clamp(1.0, 255.0) as u8;
// Use only LWA_COLORKEY: black background becomes fully transparent.
// Text is drawn with the actual color, no additional alpha dimming.
let _ = SetLayeredWindowAttributes(hwnd, COLORREF(0), alpha, LWA_COLORKEY);
let _ = InvalidateRect(hwnd, None, true);
} else {
let _ = ShowWindow(hwnd, SW_HIDE);
@@ -243,7 +244,7 @@ unsafe extern "system" fn watermark_wnd_proc(
fn paint_watermark(hwnd: windows::Win32::Foundation::HWND, state: &WatermarkState) {
use windows::Win32::Graphics::Gdi::*;
use windows::Win32::UI::WindowsAndMessaging::*;
use windows::core::PCSTR;
use windows::core::PCWSTR;
unsafe {
let mut ps = PAINTSTRUCT::default();
@@ -252,8 +253,11 @@ fn paint_watermark(hwnd: windows::Win32::Foundation::HWND, state: &WatermarkStat
let color = parse_color(&state.color);
let font_size = state.font_size.max(1);
// Create font with rotation
let font = CreateFontA(
// Create wide font name for CreateFontW (supports CJK characters)
let font_name: Vec<u16> = "Microsoft YaHei\0".encode_utf16().collect();
// Create font with rotation using CreateFontW for proper Unicode support
let font = CreateFontW(
(font_size as i32) * 2,
0,
(state.angle as i32) * 10,
@@ -265,7 +269,7 @@ fn paint_watermark(hwnd: windows::Win32::Foundation::HWND, state: &WatermarkStat
CLIP_DEFAULT_PRECIS.0 as u32,
DEFAULT_QUALITY.0 as u32,
DEFAULT_PITCH.0 as u32 | FF_DONTCARE.0 as u32,
PCSTR("Arial\0".as_ptr()),
PCWSTR(font_name.as_ptr()),
);
let old_font = SelectObject(hdc, font);
@@ -273,12 +277,13 @@ fn paint_watermark(hwnd: windows::Win32::Foundation::HWND, state: &WatermarkStat
let _ = SetBkMode(hdc, TRANSPARENT);
let _ = SetTextColor(hdc, color);
// Draw tiled watermark text
// Draw tiled watermark text using TextOutW with UTF-16 encoding
let screen_w = GetSystemMetrics(SM_CXSCREEN);
let screen_h = GetSystemMetrics(SM_CYSCREEN);
let content_bytes: Vec<u8> = state.content.bytes().chain(std::iter::once(0)).collect();
let text_slice = &content_bytes[..content_bytes.len().saturating_sub(1)];
// Encode content as UTF-16 for TextOutW (supports Chinese and all Unicode)
let wide_content: Vec<u16> = state.content.encode_utf16().collect();
let text_slice = wide_content.as_slice();
let spacing_x = 400i32;
let spacing_y = 200i32;
@@ -287,7 +292,7 @@ fn paint_watermark(hwnd: windows::Win32::Foundation::HWND, state: &WatermarkStat
while y < screen_h + 100 {
let mut x = -200i32;
while x < screen_w + 200 {
let _ = TextOutA(hdc, x, y, text_slice);
let _ = TextOutW(hdc, x, y, text_slice);
x += spacing_x;
}
y += spacing_y;

View File

@@ -6,6 +6,7 @@ use std::io;
/// Web filter rule from server
#[derive(Debug, Clone, Deserialize)]
pub struct WebFilterRule {
#[allow(dead_code)]
pub id: i64,
pub rule_type: String,
pub pattern: String,