feat: 添加新插件支持及多项功能改进
- 新增磁盘加密、打印审计和剪贴板管控插件支持 - 优化水印插件显示效果,支持中文及更多Unicode字符 - 改进硬件资产收集逻辑,更准确获取磁盘和显卡信息 - 增强API错误处理,添加详细日志记录 - 完善前端界面,新增插件管理页面 - 修复多个UI问题,优化页面过渡效果 - 添加环境变量覆盖配置功能 - 实现插件状态管理API - 更新文档和变更日志 - 添加安装程序脚本支持
This commit is contained in:
@@ -4,7 +4,7 @@ use tokio::sync::mpsc::Sender;
|
||||
use tracing::{info, error};
|
||||
|
||||
pub async fn start_collecting(tx: Sender<Frame>, device_uid: String) {
|
||||
let interval = Duration::from_secs(86400);
|
||||
let interval = Duration::from_secs(43200);
|
||||
|
||||
if let Err(e) = collect_and_send(&tx, &device_uid).await {
|
||||
error!("Initial asset collection failed: {}", e);
|
||||
@@ -50,18 +50,14 @@ fn collect_hardware(device_uid: &str) -> anyhow::Result<HardwareAsset> {
|
||||
// Memory
|
||||
let memory_total_mb = sys.total_memory() / 1024 / 1024;
|
||||
|
||||
// Disk — pick the largest non-removable disk
|
||||
// Disk — use PowerShell for real hardware model, sysinfo for total capacity
|
||||
let disks = sysinfo::Disks::new_with_refreshed_list();
|
||||
let (disk_model, disk_total_mb) = disks.iter()
|
||||
.filter(|d| d.kind() == sysinfo::DiskKind::HDD || d.kind() == sysinfo::DiskKind::SSD)
|
||||
.max_by_key(|d| d.total_space())
|
||||
.map(|d| {
|
||||
let total = d.total_space() / 1024 / 1024;
|
||||
let name = d.name().to_string_lossy().to_string();
|
||||
let model = if name.is_empty() { "Unknown".to_string() } else { name };
|
||||
(model, total)
|
||||
})
|
||||
.unwrap_or_else(|| ("Unknown".to_string(), 0));
|
||||
let disk_total_mb: u64 = disks.iter()
|
||||
.map(|d| d.total_space() / 1024 / 1024)
|
||||
.sum::<u64>()
|
||||
.max(1)
|
||||
.saturating_sub(1); // avoid reporting 0 if no disks
|
||||
let disk_model = collect_disk_model().unwrap_or_else(|| "Unknown".to_string());
|
||||
|
||||
// GPU, motherboard, serial — Windows-specific via PowerShell
|
||||
let (gpu_model, motherboard, serial_number) = collect_system_details();
|
||||
@@ -81,12 +77,12 @@ fn collect_hardware(device_uid: &str) -> anyhow::Result<HardwareAsset> {
|
||||
|
||||
#[cfg(target_os = "windows")]
|
||||
fn collect_system_details() -> (Option<String>, Option<String>, Option<String>) {
|
||||
// GPU: query all controllers, filter out virtual/IDDDriver devices, prefer real GPU
|
||||
// GPU: query all controllers, only exclude explicit virtual/placeholder devices
|
||||
let gpu = {
|
||||
let gpus = powershell_lines(
|
||||
"Get-CimInstance Win32_VideoController | Where-Object { $_.Name -notmatch 'IddDriver|Virtual|Basic Render|Microsoft Basic Display|Remote Desktop|Mirror Driver' } | Select-Object -ExpandProperty Name"
|
||||
"Get-CimInstance Win32_VideoController | Where-Object { $_.Name -notmatch 'IddDriver|Virtual Display|Basic Render|Microsoft Basic Display Adapter|Mirror Driver' } | Select-Object -ExpandProperty Name"
|
||||
);
|
||||
// Prefer NVIDIA/AMD/Intel, fallback to first non-virtual
|
||||
info!("Detected GPUs: {:?}", gpus);
|
||||
gpus.into_iter().next()
|
||||
};
|
||||
let mb_manufacturer = powershell_first("Get-CimInstance Win32_BaseBoard | Select-Object -ExpandProperty Manufacturer");
|
||||
@@ -102,6 +98,20 @@ fn collect_system_details() -> (Option<String>, Option<String>, Option<String>)
|
||||
(gpu, motherboard, serial_number)
|
||||
}
|
||||
|
||||
/// Get real disk hardware model via PowerShell Get-PhysicalDisk.
|
||||
#[cfg(target_os = "windows")]
|
||||
fn collect_disk_model() -> Option<String> {
|
||||
let models = powershell_lines(
|
||||
"Get-PhysicalDisk | Select-Object -ExpandProperty FriendlyName"
|
||||
);
|
||||
models.into_iter().next()
|
||||
}
|
||||
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
fn collect_disk_model() -> Option<String> {
|
||||
None
|
||||
}
|
||||
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
fn collect_system_details() -> (Option<String>, Option<String>, Option<String>) {
|
||||
(None, None, None)
|
||||
@@ -150,6 +160,7 @@ fn collect_windows_software(device_uid: &str) -> Vec<SoftwareAsset> {
|
||||
use std::process::Command;
|
||||
|
||||
let ps_cmd = r#"
|
||||
[Console]::OutputEncoding = [System.Text.Encoding]::UTF8
|
||||
$paths = @(
|
||||
"HKLM:\SOFTWARE\Microsoft\Windows\CurrentVersion\Uninstall\*",
|
||||
"HKLM:\SOFTWARE\WOW6432Node\Microsoft\Windows\CurrentVersion\Uninstall\*",
|
||||
|
||||
212
crates/client/src/clipboard_control/mod.rs
Normal file
212
crates/client/src/clipboard_control/mod.rs
Normal file
@@ -0,0 +1,212 @@
|
||||
use std::time::Duration;
|
||||
use tokio::sync::watch;
|
||||
use tracing::{info, warn};
|
||||
use csm_protocol::{Frame, MessageType, ClipboardRule, ClipboardViolationPayload};
|
||||
|
||||
/// Clipboard control configuration pushed from server
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct ClipboardControlConfig {
|
||||
pub enabled: bool,
|
||||
pub rules: Vec<ClipboardRule>,
|
||||
}
|
||||
|
||||
/// Start the clipboard control plugin.
|
||||
/// Periodically checks clipboard content against rules and reports violations.
|
||||
pub async fn start(
|
||||
mut config_rx: watch::Receiver<ClipboardControlConfig>,
|
||||
data_tx: tokio::sync::mpsc::Sender<Frame>,
|
||||
device_uid: String,
|
||||
) {
|
||||
info!("Clipboard control plugin started");
|
||||
let mut config = ClipboardControlConfig::default();
|
||||
let mut check_interval = tokio::time::interval(Duration::from_secs(2));
|
||||
check_interval.tick().await;
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
result = config_rx.changed() => {
|
||||
if result.is_err() {
|
||||
break;
|
||||
}
|
||||
let new_config = config_rx.borrow_and_update().clone();
|
||||
info!(
|
||||
"Clipboard control config updated: enabled={}, rules={}",
|
||||
new_config.enabled,
|
||||
new_config.rules.len()
|
||||
);
|
||||
config = new_config;
|
||||
}
|
||||
_ = check_interval.tick() => {
|
||||
if !config.enabled || config.rules.is_empty() {
|
||||
continue;
|
||||
}
|
||||
let uid = device_uid.clone();
|
||||
let rules = config.rules.clone();
|
||||
let result = tokio::task::spawn_blocking(move || check_clipboard(&uid, &rules)).await;
|
||||
match result {
|
||||
Ok(Some(payload)) => {
|
||||
if let Ok(frame) = Frame::new_json(MessageType::ClipboardViolation, &payload) {
|
||||
if data_tx.send(frame).await.is_err() {
|
||||
warn!("Failed to send clipboard violation: channel closed");
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(None) => {}
|
||||
Err(e) => warn!("Clipboard check task failed: {}", e),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Check clipboard content against rules. Returns a violation payload if a rule matched.
|
||||
fn check_clipboard(device_uid: &str, rules: &[ClipboardRule]) -> Option<ClipboardViolationPayload> {
|
||||
#[cfg(target_os = "windows")]
|
||||
{
|
||||
let clipboard_text = get_clipboard_text();
|
||||
let foreground_process = get_foreground_process();
|
||||
|
||||
for rule in rules {
|
||||
if rule.rule_type != "block" {
|
||||
continue;
|
||||
}
|
||||
// Check direction — only interested in "out" or "both"
|
||||
if !matches!(rule.direction.as_str(), "out" | "both") {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check source process filter
|
||||
if let Some(ref src_pattern) = rule.source_process {
|
||||
if let Some(ref fg_proc) = foreground_process {
|
||||
if !pattern_match(src_pattern, fg_proc) {
|
||||
continue;
|
||||
}
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// Check content pattern
|
||||
if let Some(ref content_pattern) = rule.content_pattern {
|
||||
if let Some(ref text) = clipboard_text {
|
||||
if !content_matches(content_pattern, text) {
|
||||
continue;
|
||||
}
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// Rule matched — generate violation (never send raw content)
|
||||
let preview = clipboard_text.as_ref().map(|t| format!("[{} chars]", t.len()));
|
||||
|
||||
// Clear clipboard to enforce block
|
||||
let _ = std::process::Command::new("powershell")
|
||||
.args(["-NoProfile", "-NonInteractive", "-Command", "Set-Clipboard -Value ''"])
|
||||
.output();
|
||||
|
||||
info!("Clipboard blocked: rule_id={}", rule.id);
|
||||
return Some(ClipboardViolationPayload {
|
||||
device_uid: device_uid.to_string(),
|
||||
source_process: foreground_process,
|
||||
target_process: None,
|
||||
content_preview: preview,
|
||||
action_taken: "blocked".to_string(),
|
||||
timestamp: chrono::Utc::now().to_rfc3339(),
|
||||
});
|
||||
}
|
||||
None
|
||||
}
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
{
|
||||
let _ = (device_uid, rules);
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(target_os = "windows")]
|
||||
fn get_clipboard_text() -> Option<String> {
|
||||
let output = std::process::Command::new("powershell")
|
||||
.args(["-NoProfile", "-NonInteractive", "-Command", "Get-Clipboard -Raw"])
|
||||
.output()
|
||||
.ok()?;
|
||||
if !output.status.success() {
|
||||
return None;
|
||||
}
|
||||
let text = String::from_utf8_lossy(&output.stdout).trim().to_string();
|
||||
if text.is_empty() { None } else { Some(text) }
|
||||
}
|
||||
|
||||
#[cfg(target_os = "windows")]
|
||||
fn get_foreground_process() -> Option<String> {
|
||||
let output = std::process::Command::new("powershell")
|
||||
.args([
|
||||
"-NoProfile",
|
||||
"-NonInteractive",
|
||||
"-Command",
|
||||
r#"Add-Type @"
|
||||
using System;
|
||||
using System.Runtime.InteropServices;
|
||||
public class WinAPI {
|
||||
[DllImport("user32.dll")] public static extern IntPtr GetForegroundWindow();
|
||||
[DllImport("user32.dll")] public static extern uint GetWindowThreadProcessId(IntPtr hWnd, out uint lpdwProcessId);
|
||||
}
|
||||
"@
|
||||
$hwnd = [WinAPI]::GetForegroundWindow()
|
||||
$pid = 0
|
||||
[WinAPI]::GetWindowThreadProcessId($hwnd, [ref]$pid) | Out-Null
|
||||
if ($pid -gt 0) { (Get-Process -Id $pid -ErrorAction SilentlyContinue).ProcessName } else { "" }"#,
|
||||
])
|
||||
.output()
|
||||
.ok()?;
|
||||
if !output.status.success() {
|
||||
return None;
|
||||
}
|
||||
let name = String::from_utf8_lossy(&output.stdout).trim().to_string();
|
||||
if name.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(name)
|
||||
}
|
||||
}
|
||||
|
||||
/// Simple case-insensitive wildcard pattern matching. Supports `*` as wildcard.
|
||||
fn pattern_match(pattern: &str, text: &str) -> bool {
|
||||
let p = pattern.to_lowercase();
|
||||
let t = text.to_lowercase();
|
||||
if !p.contains('*') {
|
||||
return t.contains(&p);
|
||||
}
|
||||
let parts: Vec<&str> = p.split('*').collect();
|
||||
if parts.is_empty() {
|
||||
return true;
|
||||
}
|
||||
let mut pos = 0usize;
|
||||
let mut matched_any = false;
|
||||
for (i, part) in parts.iter().enumerate() {
|
||||
if part.is_empty() {
|
||||
continue;
|
||||
}
|
||||
matched_any = true;
|
||||
if i == 0 && !parts[0].is_empty() {
|
||||
if !t.starts_with(part) {
|
||||
return false;
|
||||
}
|
||||
pos = part.len();
|
||||
} else {
|
||||
match t[pos..].find(part) {
|
||||
Some(idx) => pos += idx + part.len(),
|
||||
None => return false,
|
||||
}
|
||||
}
|
||||
}
|
||||
if matched_any && !parts.last().map_or(true, |p| p.is_empty()) {
|
||||
return t.ends_with(parts.last().unwrap());
|
||||
}
|
||||
true
|
||||
}
|
||||
|
||||
fn content_matches(pattern: &str, text: &str) -> bool {
|
||||
text.to_lowercase().contains(&pattern.to_lowercase())
|
||||
}
|
||||
200
crates/client/src/disk_encryption/mod.rs
Normal file
200
crates/client/src/disk_encryption/mod.rs
Normal file
@@ -0,0 +1,200 @@
|
||||
use std::time::Duration;
|
||||
use tokio::sync::watch;
|
||||
use tracing::{info, debug, warn};
|
||||
use csm_protocol::{Frame, MessageType, DiskEncryptionStatusPayload, DriveEncryptionInfo, DiskEncryptionConfigPayload};
|
||||
|
||||
/// Disk encryption configuration pushed from server
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct DiskEncryptionConfig {
|
||||
pub enabled: bool,
|
||||
pub report_interval_secs: u64,
|
||||
}
|
||||
|
||||
impl From<DiskEncryptionConfigPayload> for DiskEncryptionConfig {
|
||||
fn from(payload: DiskEncryptionConfigPayload) -> Self {
|
||||
Self {
|
||||
enabled: payload.enabled,
|
||||
report_interval_secs: payload.report_interval_secs,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Start the disk encryption detection plugin.
|
||||
/// On startup and periodically, collects BitLocker volume status via PowerShell
|
||||
/// and sends results to the server.
|
||||
pub async fn start(
|
||||
mut config_rx: watch::Receiver<DiskEncryptionConfig>,
|
||||
data_tx: tokio::sync::mpsc::Sender<Frame>,
|
||||
device_uid: String,
|
||||
) {
|
||||
info!("Disk encryption plugin started");
|
||||
|
||||
let mut config = DiskEncryptionConfig::default();
|
||||
let default_interval_secs: u64 = 3600;
|
||||
let mut report_interval = tokio::time::interval(Duration::from_secs(default_interval_secs));
|
||||
report_interval.tick().await;
|
||||
|
||||
// Collect and report once on startup if enabled
|
||||
if config.enabled {
|
||||
collect_and_report(&data_tx, &device_uid).await;
|
||||
}
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
result = config_rx.changed() => {
|
||||
if result.is_err() {
|
||||
break;
|
||||
}
|
||||
let new_config = config_rx.borrow_and_update().clone();
|
||||
if new_config.enabled != config.enabled {
|
||||
info!("Disk encryption enabled: {}", new_config.enabled);
|
||||
}
|
||||
config = new_config;
|
||||
if config.enabled {
|
||||
let secs = if config.report_interval_secs > 0 {
|
||||
config.report_interval_secs
|
||||
} else {
|
||||
default_interval_secs
|
||||
};
|
||||
report_interval = tokio::time::interval(Duration::from_secs(secs));
|
||||
report_interval.tick().await;
|
||||
}
|
||||
}
|
||||
_ = report_interval.tick() => {
|
||||
if !config.enabled {
|
||||
continue;
|
||||
}
|
||||
collect_and_report(&data_tx, &device_uid).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn collect_and_report(
|
||||
data_tx: &tokio::sync::mpsc::Sender<Frame>,
|
||||
device_uid: &str,
|
||||
) {
|
||||
let uid = device_uid.to_string();
|
||||
match tokio::task::spawn_blocking(move || collect_bitlocker_status()).await {
|
||||
Ok(drives) => {
|
||||
if drives.is_empty() {
|
||||
debug!("No BitLocker volumes found for device {}", uid);
|
||||
return;
|
||||
}
|
||||
let payload = DiskEncryptionStatusPayload {
|
||||
device_uid: uid,
|
||||
drives,
|
||||
};
|
||||
if let Ok(frame) = Frame::new_json(MessageType::DiskEncryptionStatus, &payload) {
|
||||
if data_tx.send(frame).await.is_err() {
|
||||
warn!("Failed to send disk encryption status: channel closed");
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to collect disk encryption status: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Collect BitLocker volume information via PowerShell.
|
||||
/// Runs: Get-BitLockerVolume | ConvertTo-Json
|
||||
fn collect_bitlocker_status() -> Vec<DriveEncryptionInfo> {
|
||||
#[cfg(target_os = "windows")]
|
||||
{
|
||||
let output = std::process::Command::new("powershell")
|
||||
.args([
|
||||
"-NoProfile",
|
||||
"-NonInteractive",
|
||||
"-Command",
|
||||
"Get-BitLockerVolume | Select-Object MountPoint, VolumeName, EncryptionMethod, ProtectionStatus, EncryptionPercentage, LockStatus | ConvertTo-Json -Compress",
|
||||
])
|
||||
.output();
|
||||
|
||||
match output {
|
||||
Ok(out) if out.status.success() => {
|
||||
let stdout = String::from_utf8_lossy(&out.stdout);
|
||||
let trimmed = stdout.trim();
|
||||
if trimmed.is_empty() {
|
||||
return Vec::new();
|
||||
}
|
||||
// PowerShell returns a single object (not array) when there is exactly one volume
|
||||
let json_str = if trimmed.starts_with('{') {
|
||||
format!("[{}]", trimmed)
|
||||
} else {
|
||||
trimmed.to_string()
|
||||
};
|
||||
match serde_json::from_str::<Vec<serde_json::Value>>(&json_str) {
|
||||
Ok(entries) => entries.into_iter().map(|e| parse_bitlocker_entry(&e)).collect(),
|
||||
Err(e) => {
|
||||
warn!("Failed to parse BitLocker output: {}", e);
|
||||
Vec::new()
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(out) => {
|
||||
let stderr = String::from_utf8_lossy(&out.stderr);
|
||||
warn!("PowerShell BitLocker query failed: {}", stderr);
|
||||
Vec::new()
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to run PowerShell for BitLocker status: {}", e);
|
||||
Vec::new()
|
||||
}
|
||||
}
|
||||
}
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
{
|
||||
Vec::new()
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_bitlocker_entry(entry: &serde_json::Value) -> DriveEncryptionInfo {
|
||||
let mount_point = entry.get("MountPoint")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("Unknown:")
|
||||
.to_string();
|
||||
|
||||
let volume_name = entry.get("VolumeName")
|
||||
.and_then(|v| v.as_str())
|
||||
.filter(|s| !s.is_empty())
|
||||
.map(String::from);
|
||||
|
||||
let encryption_method = entry.get("EncryptionMethod")
|
||||
.and_then(|v| v.as_str())
|
||||
.filter(|s| !s.is_empty() && *s != "None")
|
||||
.map(String::from);
|
||||
|
||||
let protection_status = match entry.get("ProtectionStatus") {
|
||||
Some(v) if v.is_number() => match v.as_i64().unwrap_or(0) {
|
||||
1 => "On".to_string(),
|
||||
0 => "Off".to_string(),
|
||||
_ => "Unknown".to_string(),
|
||||
},
|
||||
Some(v) if v.is_string() => v.as_str().unwrap_or("Unknown").to_string(),
|
||||
_ => "Unknown".to_string(),
|
||||
};
|
||||
|
||||
let encryption_percentage = entry.get("EncryptionPercentage")
|
||||
.and_then(|v| v.as_f64())
|
||||
.unwrap_or(0.0);
|
||||
|
||||
let lock_status = match entry.get("LockStatus") {
|
||||
Some(v) if v.is_number() => match v.as_i64().unwrap_or(0) {
|
||||
1 => "Locked".to_string(),
|
||||
0 => "Unlocked".to_string(),
|
||||
_ => "Unknown".to_string(),
|
||||
},
|
||||
Some(v) if v.is_string() => v.as_str().unwrap_or("Unknown").to_string(),
|
||||
_ => "Unknown".to_string(),
|
||||
};
|
||||
|
||||
DriveEncryptionInfo {
|
||||
drive_letter: mount_point,
|
||||
volume_name,
|
||||
encryption_method,
|
||||
protection_status,
|
||||
encryption_percentage,
|
||||
lock_status,
|
||||
}
|
||||
}
|
||||
@@ -14,6 +14,9 @@ mod usb_audit;
|
||||
mod popup_blocker;
|
||||
mod software_blocker;
|
||||
mod web_filter;
|
||||
mod disk_encryption;
|
||||
mod clipboard_control;
|
||||
mod print_audit;
|
||||
#[cfg(target_os = "windows")]
|
||||
mod service;
|
||||
|
||||
@@ -91,7 +94,11 @@ pub async fn run(state: ClientState) -> Result<()> {
|
||||
let (usb_audit_tx, usb_audit_rx) = tokio::sync::watch::channel(usb_audit::UsbAuditConfig::default());
|
||||
let (usage_timer_tx, usage_timer_rx) = tokio::sync::watch::channel(usage_timer::UsageConfig::default());
|
||||
let (usb_policy_tx, usb_policy_rx) = tokio::sync::watch::channel(None::<UsbPolicyPayload>);
|
||||
let (disk_encryption_tx, disk_encryption_rx) = tokio::sync::watch::channel(disk_encryption::DiskEncryptionConfig::default());
|
||||
let (print_audit_tx, print_audit_rx) = tokio::sync::watch::channel(print_audit::PrintAuditConfig::default());
|
||||
let (clipboard_control_tx, clipboard_control_rx) = tokio::sync::watch::channel(clipboard_control::ClipboardControlConfig::default());
|
||||
|
||||
// Build plugin channels struct
|
||||
let plugins = network::PluginChannels {
|
||||
watermark_tx,
|
||||
web_filter_tx,
|
||||
@@ -100,6 +107,9 @@ pub async fn run(state: ClientState) -> Result<()> {
|
||||
usb_audit_tx,
|
||||
usage_timer_tx,
|
||||
usb_policy_tx,
|
||||
disk_encryption_tx,
|
||||
print_audit_tx,
|
||||
clipboard_control_tx,
|
||||
};
|
||||
|
||||
// Spawn core monitoring tasks
|
||||
@@ -138,8 +148,10 @@ pub async fn run(state: ClientState) -> Result<()> {
|
||||
usb_audit::start(usb_audit_rx, audit_data_tx, audit_uid).await;
|
||||
});
|
||||
|
||||
let pb_data_tx = data_tx.clone();
|
||||
let pb_uid = state.device_uid.clone();
|
||||
tokio::spawn(async move {
|
||||
popup_blocker::start(popup_blocker_rx).await;
|
||||
popup_blocker::start(popup_blocker_rx, pb_data_tx, pb_uid).await;
|
||||
});
|
||||
|
||||
let sw_data_tx = data_tx.clone();
|
||||
@@ -152,6 +164,24 @@ pub async fn run(state: ClientState) -> Result<()> {
|
||||
web_filter::start(web_filter_rx).await;
|
||||
});
|
||||
|
||||
let de_data_tx = data_tx.clone();
|
||||
let de_uid = state.device_uid.clone();
|
||||
tokio::spawn(async move {
|
||||
disk_encryption::start(disk_encryption_rx, de_data_tx, de_uid).await;
|
||||
});
|
||||
|
||||
let pa_data_tx = data_tx.clone();
|
||||
let pa_uid = state.device_uid.clone();
|
||||
tokio::spawn(async move {
|
||||
print_audit::start(print_audit_rx, pa_data_tx, pa_uid).await;
|
||||
});
|
||||
|
||||
let cc_data_tx = data_tx.clone();
|
||||
let cc_uid = state.device_uid.clone();
|
||||
tokio::spawn(async move {
|
||||
clipboard_control::start(clipboard_control_rx, cc_data_tx, cc_uid).await;
|
||||
});
|
||||
|
||||
// Connect to server with reconnect
|
||||
let mut backoff = Duration::from_secs(1);
|
||||
let max_backoff = Duration::from_secs(60);
|
||||
@@ -163,6 +193,7 @@ pub async fn run(state: ClientState) -> Result<()> {
|
||||
}
|
||||
|
||||
match network::connect_and_run(&state, &mut data_rx, &plugins).await {
|
||||
// Plugin channels moved into plugins struct — watchers are already cloned per-task
|
||||
Ok(()) => {
|
||||
warn!("Disconnected from server, reconnecting...");
|
||||
tokio::time::sleep(Duration::from_secs(2)).await;
|
||||
@@ -186,29 +217,58 @@ pub async fn run(state: ClientState) -> Result<()> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Get directory for storing persistent files (next to the executable)
|
||||
fn data_dir() -> std::path::PathBuf {
|
||||
std::env::current_exe()
|
||||
.ok()
|
||||
.and_then(|p| p.parent().map(|p| p.to_path_buf()))
|
||||
.unwrap_or_else(|| std::path::PathBuf::from("."))
|
||||
}
|
||||
|
||||
fn load_or_create_device_uid() -> Result<String> {
|
||||
let uid_file = "device_uid.txt";
|
||||
if std::path::Path::new(uid_file).exists() {
|
||||
let uid = std::fs::read_to_string(uid_file)?;
|
||||
let uid_file = data_dir().join("device_uid.txt");
|
||||
if uid_file.exists() {
|
||||
let uid = std::fs::read_to_string(&uid_file)?;
|
||||
Ok(uid.trim().to_string())
|
||||
} else {
|
||||
let uid = uuid::Uuid::new_v4().to_string();
|
||||
std::fs::write(uid_file, &uid)?;
|
||||
write_restricted_file(&uid_file, &uid)?;
|
||||
Ok(uid)
|
||||
}
|
||||
}
|
||||
|
||||
/// Load persisted device_secret from disk (if available)
|
||||
pub fn load_device_secret() -> Option<String> {
|
||||
let secret_file = "device_secret.txt";
|
||||
let secret = std::fs::read_to_string(secret_file).ok()?;
|
||||
let secret_file = data_dir().join("device_secret.txt");
|
||||
let secret = std::fs::read_to_string(&secret_file).ok()?;
|
||||
let trimmed = secret.trim().to_string();
|
||||
if trimmed.is_empty() { None } else { Some(trimmed) }
|
||||
}
|
||||
|
||||
/// Persist device_secret to disk
|
||||
/// Persist device_secret to disk with restricted permissions
|
||||
pub fn save_device_secret(secret: &str) {
|
||||
if let Err(e) = std::fs::write("device_secret.txt", secret) {
|
||||
let secret_file = data_dir().join("device_secret.txt");
|
||||
if let Err(e) = write_restricted_file(&secret_file, secret) {
|
||||
warn!("Failed to persist device_secret: {}", e);
|
||||
}
|
||||
}
|
||||
|
||||
/// Write a file with owner-only permissions (0o600 on Unix).
|
||||
/// On Windows, the file inherits the directory's ACL — consider setting
|
||||
/// explicit ACLs via PowerShell for production deployments.
|
||||
#[cfg(unix)]
|
||||
fn write_restricted_file(path: &std::path::Path, content: &str) -> std::io::Result<()> {
|
||||
use std::os::unix::fs::OpenOptionsExt;
|
||||
std::fs::OpenOptions::new()
|
||||
.write(true)
|
||||
.create(true)
|
||||
.truncate(true)
|
||||
.mode(0o600)
|
||||
.open(path)
|
||||
.and_then(|mut f| std::io::Write::write_all(&mut f, content.as_bytes()))
|
||||
}
|
||||
|
||||
#[cfg(not(unix))]
|
||||
fn write_restricted_file(path: &std::path::Path, content: &str) -> std::io::Result<()> {
|
||||
std::fs::write(path, content)
|
||||
}
|
||||
|
||||
@@ -1,24 +1,33 @@
|
||||
use anyhow::Result;
|
||||
use csm_protocol::{Frame, MessageType, DeviceStatus, ProcessInfo};
|
||||
use std::time::Duration;
|
||||
use tokio::sync::mpsc::Sender;
|
||||
use tracing::{info, error, debug};
|
||||
use sysinfo::System;
|
||||
use sysinfo::Disks;
|
||||
use sysinfo::Networks;
|
||||
|
||||
pub async fn start_collecting(tx: Sender<Frame>, device_uid: String) {
|
||||
let interval = Duration::from_secs(60);
|
||||
let mut prev_rx: Option<u64> = None;
|
||||
let mut prev_tx: Option<u64> = None;
|
||||
|
||||
loop {
|
||||
// Run blocking sysinfo collection on a dedicated thread
|
||||
let uid_clone = device_uid.clone();
|
||||
let result = tokio::task::spawn_blocking(move || {
|
||||
collect_system_status(&uid_clone)
|
||||
}).await;
|
||||
let prev_rx_c = prev_rx;
|
||||
let prev_tx_c = prev_tx;
|
||||
|
||||
let result = tokio::task::spawn_blocking(move || {
|
||||
collect_system_status(&uid_clone, prev_rx_c, prev_tx_c)
|
||||
}).await;
|
||||
match result {
|
||||
Ok(Ok(status)) => {
|
||||
Ok(Ok((status, new_rx, new_tx))) => {
|
||||
prev_rx = Some(new_rx);
|
||||
prev_tx = Some(new_tx);
|
||||
if let Ok(frame) = Frame::new_json(MessageType::StatusReport, &status) {
|
||||
debug!("Sending status report: cpu={:.1}%, mem={:.1}%", status.cpu_usage, status.memory_usage);
|
||||
debug!(
|
||||
"Sending status: cpu={:.1}%, mem={:.1}%, disk={:.1}%",
|
||||
status.cpu_usage, status.memory_usage, status.disk_usage
|
||||
);
|
||||
if tx.send(frame).await.is_err() {
|
||||
info!("Monitor channel closed, exiting");
|
||||
break;
|
||||
@@ -37,25 +46,68 @@ pub async fn start_collecting(tx: Sender<Frame>, device_uid: String) {
|
||||
}
|
||||
}
|
||||
|
||||
fn collect_system_status(device_uid: &str) -> Result<DeviceStatus> {
|
||||
fn collect_system_status(
|
||||
device_uid: &str,
|
||||
prev_rx: Option<u64>,
|
||||
prev_tx: Option<u64>,
|
||||
) -> anyhow::Result<(DeviceStatus, u64, u64)> {
|
||||
let mut sys = System::new_all();
|
||||
sys.refresh_all();
|
||||
let disks = Disks::new_with_refreshed_list();
|
||||
let networks = Networks::new_with_refreshed_list();
|
||||
|
||||
// Brief wait for CPU usage to stabilize
|
||||
sys.refresh_all();
|
||||
std::thread::sleep(Duration::from_millis(200));
|
||||
sys.refresh_all();
|
||||
sys.refresh_cpu_usage();
|
||||
|
||||
let cpu_usage = sys.global_cpu_info().cpu_usage() as f64;
|
||||
|
||||
let total_memory = sys.total_memory() / 1024 / 1024; // Convert bytes to MB (sysinfo 0.30 returns bytes)
|
||||
let used_memory = sys.used_memory() / 1024 / 1024;
|
||||
let memory_usage = if total_memory > 0 {
|
||||
let cpu_usage = sys.global_cpu_info().cpu_usage() as f64;
|
||||
let total_memory = sys.total_memory() / 1024 / 1024;
|
||||
// Convert bytes to MB
|
||||
let used_memory = sys.used_memory() / 1024 / 1024;
|
||||
let memory_usage = if total_memory > 0 {
|
||||
(used_memory as f64 / total_memory as f64) * 100.0
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
// Disk usage
|
||||
let (disk_usage, disk_total_mb) = {
|
||||
let mut total_space: u64 = 0;
|
||||
let mut total_available: u64 = 0;
|
||||
for disk in disks.list() {
|
||||
let total = disk.total_space() / 1024 / 1024;
|
||||
// MB
|
||||
let available = disk.available_space() / 1024 / 1024;
|
||||
// MB
|
||||
total_space += total;
|
||||
total_available += available;
|
||||
}
|
||||
let used_mb = total_space.saturating_sub(total_available);
|
||||
let usage_pct = if total_space > 0 {
|
||||
(used_mb as f64 / total_space as f64) * 100.0
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
(usage_pct, total_space)
|
||||
};
|
||||
// Network rate
|
||||
let (network_rx_rate, network_tx_rate, current_rx, current_tx) = {
|
||||
let mut cur_rx: u64 = 0;
|
||||
let mut cur_tx: u64 = 0;
|
||||
for (_, data) in networks.iter() {
|
||||
cur_rx += data.received();
|
||||
cur_tx += data.transmitted();
|
||||
}
|
||||
let rx_rate = match prev_rx {
|
||||
Some(prev) => cur_rx.saturating_sub(prev) / 60, // bytes/sec (60s interval)
|
||||
None => 0,
|
||||
};
|
||||
let tx_rate = match prev_tx {
|
||||
Some(prev) => cur_tx.saturating_sub(prev) / 60,
|
||||
None => 0,
|
||||
};
|
||||
(rx_rate, tx_rate, cur_rx, cur_tx)
|
||||
};
|
||||
|
||||
// Top processes by CPU
|
||||
// Top processes by CPU
|
||||
let mut processes: Vec<ProcessInfo> = sys.processes()
|
||||
.iter()
|
||||
.map(|(_, p)| {
|
||||
@@ -63,24 +115,24 @@ fn collect_system_status(device_uid: &str) -> Result<DeviceStatus> {
|
||||
name: p.name().to_string(),
|
||||
pid: p.pid().as_u32(),
|
||||
cpu_usage: p.cpu_usage() as f64,
|
||||
memory_mb: p.memory() / 1024 / 1024, // bytes to MB (sysinfo 0.30)
|
||||
memory_mb: p.memory() / 1024 / 1024,
|
||||
// bytes to MB
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
processes.sort_by(|a, b| b.cpu_usage.partial_cmp(&a.cpu_usage).unwrap_or(std::cmp::Ordering::Equal));
|
||||
processes.truncate(10);
|
||||
|
||||
Ok(DeviceStatus {
|
||||
processes.truncate(10);
|
||||
let status = DeviceStatus {
|
||||
device_uid: device_uid.to_string(),
|
||||
cpu_usage,
|
||||
memory_usage,
|
||||
memory_total_mb: total_memory as u64,
|
||||
disk_usage: 0.0, // TODO: implement disk usage via Windows API
|
||||
disk_total_mb: 0,
|
||||
network_rx_rate: 0,
|
||||
network_tx_rate: 0,
|
||||
disk_usage,
|
||||
disk_total_mb,
|
||||
network_rx_rate: network_rx_rate,
|
||||
network_tx_rate: network_tx_rate,
|
||||
running_procs: sys.processes().len() as u32,
|
||||
top_processes: processes,
|
||||
})
|
||||
};
|
||||
Ok((status, current_rx, current_tx))
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@ use std::time::Duration;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tracing::{info, debug, warn};
|
||||
use csm_protocol::{Frame, MessageType, RegisterRequest, RegisterResponse, HeartbeatPayload, WatermarkConfigPayload, UsbPolicyPayload};
|
||||
use csm_protocol::{Frame, MessageType, RegisterRequest, RegisterResponse, HeartbeatPayload, WatermarkConfigPayload, UsbPolicyPayload, DiskEncryptionConfigPayload};
|
||||
use hmac::{Hmac, Mac};
|
||||
use sha2::Sha256;
|
||||
|
||||
@@ -18,6 +18,9 @@ pub struct PluginChannels {
|
||||
pub usb_audit_tx: tokio::sync::watch::Sender<crate::usb_audit::UsbAuditConfig>,
|
||||
pub usage_timer_tx: tokio::sync::watch::Sender<crate::usage_timer::UsageConfig>,
|
||||
pub usb_policy_tx: tokio::sync::watch::Sender<Option<UsbPolicyPayload>>,
|
||||
pub disk_encryption_tx: tokio::sync::watch::Sender<crate::disk_encryption::DiskEncryptionConfig>,
|
||||
pub print_audit_tx: tokio::sync::watch::Sender<crate::print_audit::PrintAuditConfig>,
|
||||
pub clipboard_control_tx: tokio::sync::watch::Sender<crate::clipboard_control::ClipboardControlConfig>,
|
||||
}
|
||||
|
||||
/// Connect to server and run the main communication loop
|
||||
@@ -286,6 +289,16 @@ fn handle_server_message(frame: Frame, plugins: &PluginChannels) -> Result<()> {
|
||||
let config = crate::popup_blocker::PopupBlockerConfig { enabled: true, rules };
|
||||
plugins.popup_blocker_tx.send(config)?;
|
||||
}
|
||||
MessageType::DiskEncryptionConfig => {
|
||||
let config: DiskEncryptionConfigPayload = frame.decode_payload()
|
||||
.map_err(|e| anyhow::anyhow!("Invalid disk encryption config: {}", e))?;
|
||||
info!("Received disk encryption config: enabled={}, interval={}s", config.enabled, config.report_interval_secs);
|
||||
let plugin_config = crate::disk_encryption::DiskEncryptionConfig {
|
||||
enabled: config.enabled,
|
||||
report_interval_secs: config.report_interval_secs,
|
||||
};
|
||||
plugins.disk_encryption_tx.send(plugin_config)?;
|
||||
}
|
||||
MessageType::PluginEnable => {
|
||||
let payload: csm_protocol::PluginControlPayload = frame.decode_payload()
|
||||
.map_err(|e| anyhow::anyhow!("Invalid plugin enable: {}", e))?;
|
||||
@@ -299,6 +312,16 @@ fn handle_server_message(frame: Frame, plugins: &PluginChannels) -> Result<()> {
|
||||
info!("Plugin disabled: {}", payload.plugin_name);
|
||||
handle_plugin_control(&payload, plugins, false)?;
|
||||
}
|
||||
MessageType::ClipboardRules => {
|
||||
let payload: csm_protocol::ClipboardRulesPayload = frame.decode_payload()
|
||||
.map_err(|e| anyhow::anyhow!("Invalid clipboard rules: {}", e))?;
|
||||
info!("Received clipboard rules update: {} rules", payload.rules.len());
|
||||
let config = crate::clipboard_control::ClipboardControlConfig {
|
||||
enabled: true,
|
||||
rules: payload.rules,
|
||||
};
|
||||
plugins.clipboard_control_tx.send(config)?;
|
||||
}
|
||||
_ => {
|
||||
debug!("Unhandled message type: {:?}", frame.msg_type);
|
||||
}
|
||||
@@ -346,6 +369,21 @@ fn handle_plugin_control(
|
||||
plugins.usage_timer_tx.send(crate::usage_timer::UsageConfig { enabled: false, ..Default::default() })?;
|
||||
}
|
||||
}
|
||||
"disk_encryption" => {
|
||||
if !enabled {
|
||||
plugins.disk_encryption_tx.send(crate::disk_encryption::DiskEncryptionConfig { enabled: false, ..Default::default() })?;
|
||||
}
|
||||
}
|
||||
"print_audit" => {
|
||||
if !enabled {
|
||||
plugins.print_audit_tx.send(crate::print_audit::PrintAuditConfig { enabled: false, ..Default::default() })?;
|
||||
}
|
||||
}
|
||||
"clipboard_control" => {
|
||||
if !enabled {
|
||||
plugins.clipboard_control_tx.send(crate::clipboard_control::ClipboardControlConfig { enabled: false, ..Default::default() })?;
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
warn!("Unknown plugin: {}", payload.plugin_name);
|
||||
}
|
||||
|
||||
@@ -1,370 +0,0 @@
|
||||
use anyhow::Result;
|
||||
use std::time::Duration;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tracing::{info, debug, warn};
|
||||
use csm_protocol::{Frame, MessageType, RegisterRequest, RegisterResponse, HeartbeatPayload, WatermarkConfigPayload};
|
||||
use hmac::{Hmac, Mac};
|
||||
use sha2::Sha256;
|
||||
|
||||
use crate::ClientState;
|
||||
|
||||
/// Maximum accumulated read buffer size per connection (8 MB)
|
||||
const MAX_READ_BUF_SIZE: usize = 8 * 1024 * 1024;
|
||||
|
||||
/// Holds senders for all plugin config channels
|
||||
pub struct PluginChannels {
|
||||
pub watermark_tx: tokio::sync::watch::Sender<Option<WatermarkConfigPayload>>,
|
||||
pub web_filter_tx: tokio::sync::watch::Sender<crate::web_filter::WebFilterConfig>,
|
||||
pub software_blocker_tx: tokio::sync::watch::Sender<crate::software_blocker::SoftwareBlockerConfig>,
|
||||
pub popup_blocker_tx: tokio::sync::watch::Sender<crate::popup_blocker::PopupBlockerConfig>,
|
||||
pub usb_audit_tx: tokio::sync::watch::Sender<crate::usb_audit::UsbAuditConfig>,
|
||||
pub usage_timer_tx: tokio::sync::watch::Sender<crate::usage_timer::UsageConfig>,
|
||||
}
|
||||
|
||||
/// Connect to server and run the main communication loop
|
||||
pub async fn connect_and_run(
|
||||
state: &ClientState,
|
||||
data_rx: &mut tokio::sync::mpsc::Receiver<Frame>,
|
||||
plugins: &PluginChannels,
|
||||
) -> Result<()> {
|
||||
let tcp_stream = TcpStream::connect(&state.server_addr).await?;
|
||||
info!("TCP connected to {}", state.server_addr);
|
||||
|
||||
if state.use_tls {
|
||||
let tls_stream = wrap_tls(tcp_stream, &state.server_addr).await?;
|
||||
run_comm_loop(tls_stream, state, data_rx, plugins).await
|
||||
} else {
|
||||
run_comm_loop(tcp_stream, state, data_rx, plugins).await
|
||||
}
|
||||
}
|
||||
|
||||
/// Wrap a TCP stream with TLS.
|
||||
/// Supports custom CA certificate via CSM_TLS_CA_CERT env var (path to PEM file).
|
||||
/// Supports skipping verification via CSM_TLS_SKIP_VERIFY=true (development only).
|
||||
async fn wrap_tls(stream: TcpStream, server_addr: &str) -> Result<tokio_rustls::client::TlsStream<TcpStream>> {
|
||||
let mut root_store = rustls::RootCertStore::empty();
|
||||
|
||||
// Load custom CA certificate if specified
|
||||
if let Ok(ca_path) = std::env::var("CSM_TLS_CA_CERT") {
|
||||
let ca_pem = std::fs::read(&ca_path)
|
||||
.map_err(|e| anyhow::anyhow!("Failed to read CA cert {}: {}", ca_path, e))?;
|
||||
let certs = rustls_pemfile::certs(&mut &ca_pem[..])
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
.map_err(|e| anyhow::anyhow!("Failed to parse CA cert: {:?}", e))?;
|
||||
for cert in certs {
|
||||
root_store.add(cert)?;
|
||||
}
|
||||
info!("Loaded custom CA certificates from {}", ca_path);
|
||||
}
|
||||
|
||||
// Always include system roots as fallback
|
||||
root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
|
||||
|
||||
let config = if std::env::var("CSM_TLS_SKIP_VERIFY").as_deref() == Ok("true") {
|
||||
warn!("TLS certificate verification DISABLED — do not use in production!");
|
||||
rustls::ClientConfig::builder()
|
||||
.dangerous()
|
||||
.with_custom_certificate_verifier(std::sync::Arc::new(NoVerifier))
|
||||
.with_no_client_auth()
|
||||
} else {
|
||||
rustls::ClientConfig::builder()
|
||||
.with_root_certificates(root_store)
|
||||
.with_no_client_auth()
|
||||
};
|
||||
|
||||
let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(config));
|
||||
|
||||
// Extract hostname from server_addr (strip port)
|
||||
let domain = server_addr.split(':').next().unwrap_or("localhost").to_string();
|
||||
let server_name = rustls_pki_types::ServerName::try_from(domain.clone())
|
||||
.map_err(|e| anyhow::anyhow!("Invalid TLS server name '{}': {:?}", domain, e))?;
|
||||
|
||||
let tls_stream = connector.connect(server_name, stream).await?;
|
||||
info!("TLS handshake completed with {}", domain);
|
||||
Ok(tls_stream)
|
||||
}
|
||||
|
||||
/// A no-op certificate verifier for development use (CSM_TLS_SKIP_VERIFY=true).
|
||||
#[derive(Debug)]
|
||||
struct NoVerifier;
|
||||
|
||||
impl rustls::client::danger::ServerCertVerifier for NoVerifier {
|
||||
fn verify_server_cert(
|
||||
&self,
|
||||
_end_entity: &rustls_pki_types::CertificateDer,
|
||||
_intermediates: &[rustls_pki_types::CertificateDer],
|
||||
_server_name: &rustls_pki_types::ServerName,
|
||||
_ocsp_response: &[u8],
|
||||
_now: rustls_pki_types::UnixTime,
|
||||
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
|
||||
Ok(rustls::client::danger::ServerCertVerified::assertion())
|
||||
}
|
||||
|
||||
fn verify_tls12_signature(
|
||||
&self,
|
||||
_message: &[u8],
|
||||
_cert: &rustls_pki_types::CertificateDer,
|
||||
_dss: &rustls::DigitallySignedStruct,
|
||||
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
|
||||
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
|
||||
}
|
||||
|
||||
fn verify_tls13_signature(
|
||||
&self,
|
||||
_message: &[u8],
|
||||
_cert: &rustls_pki_types::CertificateDer,
|
||||
_dss: &rustls::DigitallySignedStruct,
|
||||
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
|
||||
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
|
||||
}
|
||||
|
||||
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
|
||||
vec![
|
||||
rustls::SignatureScheme::RSA_PKCS1_SHA256,
|
||||
rustls::SignatureScheme::RSA_PKCS1_SHA384,
|
||||
rustls::SignatureScheme::RSA_PKCS1_SHA512,
|
||||
rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
|
||||
rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
|
||||
rustls::SignatureScheme::RSA_PSS_SHA256,
|
||||
rustls::SignatureScheme::RSA_PSS_SHA384,
|
||||
rustls::SignatureScheme::RSA_PSS_SHA512,
|
||||
rustls::SignatureScheme::ED25519,
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
/// Main communication loop over any read+write stream
|
||||
async fn run_comm_loop<S>(
|
||||
mut stream: S,
|
||||
state: &ClientState,
|
||||
data_rx: &mut tokio::sync::mpsc::Receiver<Frame>,
|
||||
plugins: &PluginChannels,
|
||||
) -> Result<()>
|
||||
where
|
||||
S: AsyncReadExt + AsyncWriteExt + Unpin,
|
||||
{
|
||||
// Send registration
|
||||
let register = RegisterRequest {
|
||||
device_uid: state.device_uid.clone(),
|
||||
hostname: hostname::get()
|
||||
.map(|h| h.to_string_lossy().to_string())
|
||||
.unwrap_or_else(|_| "unknown".to_string()),
|
||||
registration_token: state.registration_token.clone(),
|
||||
os_version: get_os_info(),
|
||||
mac_address: None,
|
||||
};
|
||||
|
||||
let frame = Frame::new_json(MessageType::Register, ®ister)?;
|
||||
stream.write_all(&frame.encode()).await?;
|
||||
info!("Registration request sent");
|
||||
|
||||
let mut buffer = vec![0u8; 65536];
|
||||
let mut read_buf = Vec::with_capacity(65536);
|
||||
let heartbeat_secs = state.config.heartbeat_interval_secs;
|
||||
let mut heartbeat_interval = tokio::time::interval(Duration::from_secs(heartbeat_secs));
|
||||
heartbeat_interval.tick().await; // Skip first tick
|
||||
|
||||
// HMAC key — set after receiving RegisterResponse
|
||||
let mut device_secret: Option<String> = state.device_secret.clone();
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
// Read from server
|
||||
result = stream.read(&mut buffer) => {
|
||||
let n = result?;
|
||||
if n == 0 {
|
||||
return Err(anyhow::anyhow!("Server closed connection"));
|
||||
}
|
||||
read_buf.extend_from_slice(&buffer[..n]);
|
||||
|
||||
// Process complete frames
|
||||
loop {
|
||||
match Frame::decode(&read_buf)? {
|
||||
Some(frame) => {
|
||||
let consumed = frame.encoded_size();
|
||||
read_buf.drain(..consumed);
|
||||
// Capture device_secret from registration response
|
||||
if frame.msg_type == MessageType::RegisterResponse {
|
||||
if let Ok(resp) = frame.decode_payload::<RegisterResponse>() {
|
||||
device_secret = Some(resp.device_secret.clone());
|
||||
crate::save_device_secret(&resp.device_secret);
|
||||
info!("Device secret received and persisted, HMAC enabled for heartbeats");
|
||||
}
|
||||
}
|
||||
handle_server_message(frame, plugins)?;
|
||||
}
|
||||
None => break, // Incomplete frame, wait for more data
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Send queued data
|
||||
frame = data_rx.recv() => {
|
||||
let frame = frame.ok_or_else(|| anyhow::anyhow!("Channel closed"))?;
|
||||
stream.write_all(&frame.encode()).await?;
|
||||
}
|
||||
|
||||
// Heartbeat
|
||||
_ = heartbeat_interval.tick() => {
|
||||
let timestamp = chrono::Utc::now().to_rfc3339();
|
||||
let hmac_value = compute_hmac(device_secret.as_deref(), &state.device_uid, ×tamp);
|
||||
let heartbeat = HeartbeatPayload {
|
||||
device_uid: state.device_uid.clone(),
|
||||
timestamp,
|
||||
hmac: hmac_value,
|
||||
};
|
||||
let frame = Frame::new_json(MessageType::Heartbeat, &heartbeat)?;
|
||||
stream.write_all(&frame.encode()).await?;
|
||||
debug!("Heartbeat sent (hmac={})", !heartbeat.hmac.is_empty());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_server_message(frame: Frame, plugins: &PluginChannels) -> Result<()> {
|
||||
match frame.msg_type {
|
||||
MessageType::RegisterResponse => {
|
||||
let resp: RegisterResponse = frame.decode_payload()
|
||||
.map_err(|e| anyhow::anyhow!("Invalid registration response: {}", e))?;
|
||||
info!("Registration accepted by server (server version: {})", resp.config.server_version);
|
||||
}
|
||||
MessageType::PolicyUpdate => {
|
||||
let policy: serde_json::Value = frame.decode_payload()
|
||||
.map_err(|e| anyhow::anyhow!("Invalid policy update: {}", e))?;
|
||||
info!("Received policy update: {}", policy);
|
||||
}
|
||||
MessageType::ConfigUpdate => {
|
||||
info!("Received config update");
|
||||
}
|
||||
MessageType::TaskExecute => {
|
||||
warn!("Task execution requested (not yet implemented)");
|
||||
}
|
||||
MessageType::WatermarkConfig => {
|
||||
let config: WatermarkConfigPayload = frame.decode_payload()
|
||||
.map_err(|e| anyhow::anyhow!("Invalid watermark config: {}", e))?;
|
||||
info!("Received watermark config: enabled={}", config.enabled);
|
||||
plugins.watermark_tx.send(Some(config))?;
|
||||
}
|
||||
MessageType::WebFilterRuleUpdate => {
|
||||
let payload: serde_json::Value = frame.decode_payload()
|
||||
.map_err(|e| anyhow::anyhow!("Invalid web filter update: {}", e))?;
|
||||
info!("Received web filter rules update");
|
||||
let rules: Vec<crate::web_filter::WebFilterRule> = payload.get("rules")
|
||||
.and_then(|r| serde_json::from_value(r.clone()).ok())
|
||||
.unwrap_or_default();
|
||||
let config = crate::web_filter::WebFilterConfig { enabled: true, rules };
|
||||
plugins.web_filter_tx.send(config)?;
|
||||
}
|
||||
MessageType::SoftwareBlacklist => {
|
||||
let payload: serde_json::Value = frame.decode_payload()
|
||||
.map_err(|e| anyhow::anyhow!("Invalid software blacklist: {}", e))?;
|
||||
info!("Received software blacklist update");
|
||||
let blacklist: Vec<crate::software_blocker::BlacklistEntry> = payload.get("blacklist")
|
||||
.and_then(|r| serde_json::from_value(r.clone()).ok())
|
||||
.unwrap_or_default();
|
||||
let config = crate::software_blocker::SoftwareBlockerConfig { enabled: true, blacklist };
|
||||
plugins.software_blocker_tx.send(config)?;
|
||||
}
|
||||
MessageType::PopupRules => {
|
||||
let payload: serde_json::Value = frame.decode_payload()
|
||||
.map_err(|e| anyhow::anyhow!("Invalid popup rules: {}", e))?;
|
||||
info!("Received popup blocker rules update");
|
||||
let rules: Vec<crate::popup_blocker::PopupRule> = payload.get("rules")
|
||||
.and_then(|r| serde_json::from_value(r.clone()).ok())
|
||||
.unwrap_or_default();
|
||||
let config = crate::popup_blocker::PopupBlockerConfig { enabled: true, rules };
|
||||
plugins.popup_blocker_tx.send(config)?;
|
||||
}
|
||||
MessageType::PluginEnable => {
|
||||
let payload: csm_protocol::PluginControlPayload = frame.decode_payload()
|
||||
.map_err(|e| anyhow::anyhow!("Invalid plugin enable: {}", e))?;
|
||||
info!("Plugin enabled: {}", payload.plugin_name);
|
||||
// Route to appropriate plugin channel based on plugin_name
|
||||
handle_plugin_control(&payload, plugins, true)?;
|
||||
}
|
||||
MessageType::PluginDisable => {
|
||||
let payload: csm_protocol::PluginControlPayload = frame.decode_payload()
|
||||
.map_err(|e| anyhow::anyhow!("Invalid plugin disable: {}", e))?;
|
||||
info!("Plugin disabled: {}", payload.plugin_name);
|
||||
handle_plugin_control(&payload, plugins, false)?;
|
||||
}
|
||||
_ => {
|
||||
debug!("Unhandled message type: {:?}", frame.msg_type);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn handle_plugin_control(
|
||||
payload: &csm_protocol::PluginControlPayload,
|
||||
plugins: &PluginChannels,
|
||||
enabled: bool,
|
||||
) -> Result<()> {
|
||||
match payload.plugin_name.as_str() {
|
||||
"watermark" => {
|
||||
if !enabled {
|
||||
// Send disabled config to remove overlay
|
||||
plugins.watermark_tx.send(None)?;
|
||||
}
|
||||
// When enabling, server will push the actual config next
|
||||
}
|
||||
"web_filter" => {
|
||||
if !enabled {
|
||||
// Clear hosts rules on disable
|
||||
plugins.web_filter_tx.send(crate::web_filter::WebFilterConfig { enabled: false, rules: vec![] })?;
|
||||
}
|
||||
// When enabling, server will push rules
|
||||
}
|
||||
"software_blocker" => {
|
||||
if !enabled {
|
||||
plugins.software_blocker_tx.send(crate::software_blocker::SoftwareBlockerConfig { enabled: false, blacklist: vec![] })?;
|
||||
}
|
||||
}
|
||||
"popup_blocker" => {
|
||||
if !enabled {
|
||||
plugins.popup_blocker_tx.send(crate::popup_blocker::PopupBlockerConfig { enabled: false, rules: vec![] })?;
|
||||
}
|
||||
}
|
||||
"usb_audit" => {
|
||||
if !enabled {
|
||||
plugins.usb_audit_tx.send(crate::usb_audit::UsbAuditConfig { enabled: false, monitored_extensions: vec![] })?;
|
||||
}
|
||||
}
|
||||
"usage_timer" => {
|
||||
if !enabled {
|
||||
plugins.usage_timer_tx.send(crate::usage_timer::UsageConfig { enabled: false, ..Default::default() })?;
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
warn!("Unknown plugin: {}", payload.plugin_name);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Compute HMAC-SHA256 for heartbeat verification.
|
||||
/// Format: HMAC-SHA256(device_secret, "{device_uid}\n{timestamp}")
|
||||
fn compute_hmac(secret: Option<&str>, device_uid: &str, timestamp: &str) -> String {
|
||||
let secret = match secret {
|
||||
Some(s) if !s.is_empty() => s,
|
||||
_ => return String::new(),
|
||||
};
|
||||
|
||||
type HmacSha256 = Hmac<Sha256>;
|
||||
|
||||
let message = format!("{}\n{}", device_uid, timestamp);
|
||||
let mut mac = match HmacSha256::new_from_slice(secret.as_bytes()) {
|
||||
Ok(m) => m,
|
||||
Err(_) => return String::new(),
|
||||
};
|
||||
mac.update(message.as_bytes());
|
||||
hex::encode(mac.finalize().into_bytes())
|
||||
}
|
||||
|
||||
fn get_os_info() -> String {
|
||||
use sysinfo::System;
|
||||
let name = System::name().unwrap_or_else(|| "Unknown".to_string());
|
||||
let version = System::os_version().unwrap_or_else(|| "Unknown".to_string());
|
||||
format!("{} {}", name, version)
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
use tokio::sync::watch;
|
||||
use tokio::sync::{watch, mpsc};
|
||||
use tracing::{info, debug};
|
||||
use serde::Deserialize;
|
||||
use csm_protocol::{Frame, MessageType, PopupBlockStatsPayload, PopupRuleStat};
|
||||
|
||||
/// Popup blocker rule from server
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
@@ -23,15 +24,27 @@ pub struct PopupBlockerConfig {
|
||||
struct ScanContext {
|
||||
rules: Vec<PopupRule>,
|
||||
blocked_count: u32,
|
||||
rule_hits: std::collections::HashMap<i64, u32>,
|
||||
}
|
||||
|
||||
/// Start popup blocker plugin.
|
||||
/// Periodically enumerates windows and closes those matching rules.
|
||||
pub async fn start(mut config_rx: watch::Receiver<PopupBlockerConfig>) {
|
||||
/// Reports statistics to server every 60 seconds.
|
||||
pub async fn start(
|
||||
mut config_rx: watch::Receiver<PopupBlockerConfig>,
|
||||
data_tx: mpsc::Sender<Frame>,
|
||||
device_uid: String,
|
||||
) {
|
||||
info!("Popup blocker plugin started");
|
||||
let mut config = PopupBlockerConfig::default();
|
||||
let mut scan_interval = tokio::time::interval(std::time::Duration::from_secs(2));
|
||||
scan_interval.tick().await;
|
||||
let mut stats_interval = tokio::time::interval(std::time::Duration::from_secs(60));
|
||||
stats_interval.tick().await;
|
||||
|
||||
// Accumulated stats
|
||||
let mut total_blocked: u32 = 0;
|
||||
let mut rule_hits: std::collections::HashMap<i64, u32> = std::collections::HashMap::new();
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
@@ -47,13 +60,34 @@ pub async fn start(mut config_rx: watch::Receiver<PopupBlockerConfig>) {
|
||||
if !config.enabled || config.rules.is_empty() {
|
||||
continue;
|
||||
}
|
||||
scan_and_block(&config.rules);
|
||||
let ctx = scan_and_block(&config.rules);
|
||||
total_blocked += ctx.blocked_count;
|
||||
for (rule_id, hits) in ctx.rule_hits {
|
||||
*rule_hits.entry(rule_id).or_insert(0) += hits;
|
||||
}
|
||||
}
|
||||
_ = stats_interval.tick() => {
|
||||
if total_blocked > 0 {
|
||||
let stats = PopupBlockStatsPayload {
|
||||
device_uid: device_uid.clone(),
|
||||
blocked_count: total_blocked,
|
||||
rule_stats: rule_hits.iter().map(|(&id, &hits)| PopupRuleStat { rule_id: id, hits }).collect(),
|
||||
period_secs: 60,
|
||||
};
|
||||
if let Ok(frame) = Frame::new_json(MessageType::PopupBlockStats, &stats) {
|
||||
if data_tx.send(frame).await.is_err() {
|
||||
debug!("Failed to send popup block stats: channel closed");
|
||||
}
|
||||
}
|
||||
total_blocked = 0;
|
||||
rule_hits.clear();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn scan_and_block(rules: &[PopupRule]) {
|
||||
fn scan_and_block(rules: &[PopupRule]) -> ScanContext {
|
||||
#[cfg(target_os = "windows")]
|
||||
{
|
||||
use windows::Win32::UI::WindowsAndMessaging::EnumWindows;
|
||||
@@ -62,6 +96,7 @@ fn scan_and_block(rules: &[PopupRule]) {
|
||||
let mut ctx = ScanContext {
|
||||
rules: rules.to_vec(),
|
||||
blocked_count: 0,
|
||||
rule_hits: std::collections::HashMap::new(),
|
||||
};
|
||||
|
||||
unsafe {
|
||||
@@ -73,10 +108,12 @@ fn scan_and_block(rules: &[PopupRule]) {
|
||||
if ctx.blocked_count > 0 {
|
||||
debug!("Popup scan blocked {} windows", ctx.blocked_count);
|
||||
}
|
||||
ctx
|
||||
}
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
{
|
||||
let _ = rules;
|
||||
ScanContext { rules: vec![], blocked_count: 0, rule_hits: std::collections::HashMap::new() }
|
||||
}
|
||||
}
|
||||
|
||||
@@ -133,6 +170,7 @@ unsafe extern "system" fn enum_windows_callback(
|
||||
if matches {
|
||||
let _ = PostMessageW(hwnd, WM_CLOSE, WPARAM(0), LPARAM(0));
|
||||
ctx.blocked_count += 1;
|
||||
*ctx.rule_hits.entry(rule.id).or_insert(0) += 1;
|
||||
info!(
|
||||
"Blocked popup: title='{}' class='{}' process='{}' (rule_id={})",
|
||||
title, class_name, process_name, rule.id
|
||||
|
||||
249
crates/client/src/print_audit/mod.rs
Normal file
249
crates/client/src/print_audit/mod.rs
Normal file
@@ -0,0 +1,249 @@
|
||||
use std::collections::HashSet;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::watch;
|
||||
use tracing::{info, warn};
|
||||
use csm_protocol::{Frame, MessageType, PrintEventPayload};
|
||||
|
||||
/// Print audit configuration pushed from server
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct PrintAuditConfig {
|
||||
pub enabled: bool,
|
||||
pub report_interval_secs: u64,
|
||||
}
|
||||
|
||||
/// Start the print audit plugin.
|
||||
/// On startup and periodically, queries Windows print spooler for recent
|
||||
/// print jobs via WMI and sends new events to the server.
|
||||
pub async fn start(
|
||||
mut config_rx: watch::Receiver<PrintAuditConfig>,
|
||||
data_tx: tokio::sync::mpsc::Sender<Frame>,
|
||||
device_uid: String,
|
||||
) {
|
||||
info!("Print audit plugin started");
|
||||
|
||||
let mut config = PrintAuditConfig::default();
|
||||
let default_interval_secs: u64 = 300;
|
||||
let mut report_interval = tokio::time::interval(Duration::from_secs(default_interval_secs));
|
||||
report_interval.tick().await;
|
||||
|
||||
// Track seen print job IDs to avoid duplicates
|
||||
let mut seen_jobs: HashSet<String> = HashSet::new();
|
||||
|
||||
// Collect and report once on startup if enabled
|
||||
if config.enabled {
|
||||
collect_and_report(&data_tx, &device_uid, &mut seen_jobs).await;
|
||||
}
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
result = config_rx.changed() => {
|
||||
if result.is_err() {
|
||||
break;
|
||||
}
|
||||
let new_config = config_rx.borrow_and_update().clone();
|
||||
if new_config.enabled != config.enabled {
|
||||
info!("Print audit enabled: {}", new_config.enabled);
|
||||
}
|
||||
config = new_config;
|
||||
if config.enabled {
|
||||
let secs = if config.report_interval_secs > 0 {
|
||||
config.report_interval_secs
|
||||
} else {
|
||||
default_interval_secs
|
||||
};
|
||||
report_interval = tokio::time::interval(Duration::from_secs(secs));
|
||||
report_interval.tick().await;
|
||||
}
|
||||
}
|
||||
_ = report_interval.tick() => {
|
||||
if !config.enabled {
|
||||
continue;
|
||||
}
|
||||
collect_and_report(&data_tx, &device_uid, &mut seen_jobs).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn collect_and_report(
|
||||
data_tx: &tokio::sync::mpsc::Sender<Frame>,
|
||||
device_uid: &str,
|
||||
seen_jobs: &mut HashSet<String>,
|
||||
) {
|
||||
let uid = device_uid.to_string();
|
||||
match tokio::task::spawn_blocking(move || collect_print_jobs()).await {
|
||||
Ok(jobs) => {
|
||||
for job in jobs {
|
||||
let job_key = format!("{}|{}|{}", job.document_name.as_deref().unwrap_or(""), job.printer_name.as_deref().unwrap_or(""), &job.timestamp);
|
||||
if seen_jobs.contains(&job_key) {
|
||||
continue;
|
||||
}
|
||||
seen_jobs.insert(job_key.clone());
|
||||
|
||||
let payload = PrintEventPayload {
|
||||
device_uid: uid.clone(),
|
||||
document_name: job.document_name,
|
||||
printer_name: job.printer_name,
|
||||
pages: job.pages,
|
||||
copies: job.copies,
|
||||
user_name: job.user_name,
|
||||
file_size_bytes: job.file_size_bytes,
|
||||
timestamp: job.timestamp,
|
||||
};
|
||||
if let Ok(frame) = Frame::new_json(MessageType::PrintEvent, &payload) {
|
||||
if data_tx.send(frame).await.is_err() {
|
||||
warn!("Failed to send print event: channel closed");
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
// Keep seen_jobs bounded — evict entries older than what we'd reasonably see
|
||||
if seen_jobs.len() > 10000 {
|
||||
seen_jobs.clear();
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to collect print jobs: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct PrintJob {
|
||||
document_name: Option<String>,
|
||||
printer_name: Option<String>,
|
||||
pages: Option<i32>,
|
||||
copies: Option<i32>,
|
||||
user_name: Option<String>,
|
||||
file_size_bytes: Option<i64>,
|
||||
timestamp: String,
|
||||
}
|
||||
|
||||
/// Collect recent print jobs via WMI on Windows.
|
||||
/// Queries Win32_PrintJob for jobs that completed in the recent period.
|
||||
fn collect_print_jobs() -> Vec<PrintJob> {
|
||||
#[cfg(target_os = "windows")]
|
||||
{
|
||||
let output = std::process::Command::new("powershell")
|
||||
.args([
|
||||
"-NoProfile",
|
||||
"-NonInteractive",
|
||||
"-Command",
|
||||
"Get-WinEvent -FilterHashtable @{LogName='Microsoft-Windows-PrintService/Operational'; ID=307} -MaxEvents 50 -ErrorAction SilentlyContinue | Select-Object TimeCreated, Message | ConvertTo-Json -Compress",
|
||||
])
|
||||
.output();
|
||||
|
||||
match output {
|
||||
Ok(out) if out.status.success() => {
|
||||
let stdout = String::from_utf8_lossy(&out.stdout);
|
||||
let trimmed = stdout.trim();
|
||||
if trimmed.is_empty() {
|
||||
return Vec::new();
|
||||
}
|
||||
// PowerShell may return single object or array
|
||||
let json_str = if trimmed.starts_with('{') {
|
||||
format!("[{}]", trimmed)
|
||||
} else {
|
||||
trimmed.to_string()
|
||||
};
|
||||
match serde_json::from_str::<Vec<serde_json::Value>>(&json_str) {
|
||||
Ok(entries) => entries.into_iter().filter_map(|e| parse_print_event(&e)).collect(),
|
||||
Err(e) => {
|
||||
warn!("Failed to parse print event output: {}", e);
|
||||
Vec::new()
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(_) => {
|
||||
// No print events or error — not logged as warning since this is normal
|
||||
Vec::new()
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to run PowerShell for print events: {}", e);
|
||||
Vec::new()
|
||||
}
|
||||
}
|
||||
}
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
{
|
||||
Vec::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse a Windows Event Log entry for print event (Event ID 307).
|
||||
/// The Message field contains: "Document N, owner owned by USER was printed on PRINTER via port PORT. Size in bytes: SIZE. Pages printed: PAGES. No pages for the client."
|
||||
fn parse_print_event(entry: &serde_json::Value) -> Option<PrintJob> {
|
||||
let timestamp = entry.get("TimeCreated")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
|
||||
if timestamp.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let message = entry.get("Message")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("");
|
||||
|
||||
let (document_name, printer_name, user_name, pages, file_size_bytes) = parse_print_message(message);
|
||||
|
||||
Some(PrintJob {
|
||||
document_name: if document_name.is_empty() { None } else { Some(document_name) },
|
||||
printer_name: if printer_name.is_empty() { None } else { Some(printer_name) },
|
||||
pages,
|
||||
copies: Some(1),
|
||||
user_name: if user_name.is_empty() { None } else { Some(user_name) },
|
||||
file_size_bytes,
|
||||
timestamp,
|
||||
})
|
||||
}
|
||||
|
||||
/// Parse the print event message text to extract fields.
|
||||
/// Example: "Document 10, Test Page owned by JOHN was printed on HP LaserJet via port PORT. Size in bytes: 12345. Pages printed: 1."
|
||||
fn parse_print_message(msg: &str) -> (String, String, String, Option<i32>, Option<i64>) {
|
||||
let mut document_name = String::new();
|
||||
let mut printer_name = String::new();
|
||||
let mut user_name = String::new();
|
||||
let mut pages: Option<i32> = None;
|
||||
let mut file_size_bytes: Option<i64> = None;
|
||||
|
||||
// Extract document name: "Document N, <name> owned by"
|
||||
if let Some(start) = msg.find("Document ") {
|
||||
let rest = &msg[start + "Document ".len()..];
|
||||
// Skip job number and comma
|
||||
if let Some(comma_pos) = rest.find(", ") {
|
||||
let after_comma = &rest[comma_pos + 2..];
|
||||
if let Some(owned_pos) = after_comma.find(" owned by ") {
|
||||
document_name = after_comma[..owned_pos].trim().to_string();
|
||||
let after_owned = &after_comma[owned_pos + " owned by ".len()..];
|
||||
if let Some(was_pos) = after_owned.find(" was printed on ") {
|
||||
user_name = after_owned[..was_pos].trim().to_string();
|
||||
let after_printer = &after_owned[was_pos + " was printed on ".len()..];
|
||||
if let Some(via_pos) = after_printer.find(" via port") {
|
||||
printer_name = after_printer[..via_pos].trim().to_string();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Extract pages: "Pages printed: N."
|
||||
if let Some(pages_start) = msg.find("Pages printed: ") {
|
||||
let rest = &msg[pages_start + "Pages printed: ".len()..];
|
||||
let num_str: String = rest.chars().take_while(|c| c.is_ascii_digit()).collect();
|
||||
if !num_str.is_empty() {
|
||||
pages = num_str.parse().ok();
|
||||
}
|
||||
}
|
||||
|
||||
// Extract file size: "Size in bytes: N."
|
||||
if let Some(size_start) = msg.find("Size in bytes: ") {
|
||||
let rest = &msg[size_start + "Size in bytes: ".len()..];
|
||||
let num_str: String = rest.chars().take_while(|c| c.is_ascii_digit()).collect();
|
||||
if !num_str.is_empty() {
|
||||
file_size_bytes = num_str.parse().ok();
|
||||
}
|
||||
}
|
||||
|
||||
(document_name, printer_name, user_name, pages, file_size_bytes)
|
||||
}
|
||||
@@ -25,6 +25,7 @@ const PROTECTED_PROCESSES: &[&str] = &[
|
||||
/// Software blacklist entry from server
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct BlacklistEntry {
|
||||
#[allow(dead_code)]
|
||||
pub id: i64,
|
||||
pub name_pattern: String,
|
||||
pub action: String,
|
||||
|
||||
@@ -221,9 +221,10 @@ unsafe extern "system" fn watermark_wnd_proc(
|
||||
GetSystemMetrics(SM_CYSCREEN),
|
||||
SWP_SHOWWINDOW,
|
||||
);
|
||||
let alpha = (s.opacity * 255.0).clamp(0.0, 255.0) as u8;
|
||||
// Color key black background to transparent, apply alpha to text
|
||||
let _ = SetLayeredWindowAttributes(hwnd, COLORREF(0), alpha, LWA_COLORKEY | LWA_ALPHA);
|
||||
let alpha = (s.opacity * 255.0).clamp(1.0, 255.0) as u8;
|
||||
// Use only LWA_COLORKEY: black background becomes fully transparent.
|
||||
// Text is drawn with the actual color, no additional alpha dimming.
|
||||
let _ = SetLayeredWindowAttributes(hwnd, COLORREF(0), alpha, LWA_COLORKEY);
|
||||
let _ = InvalidateRect(hwnd, None, true);
|
||||
} else {
|
||||
let _ = ShowWindow(hwnd, SW_HIDE);
|
||||
@@ -243,7 +244,7 @@ unsafe extern "system" fn watermark_wnd_proc(
|
||||
fn paint_watermark(hwnd: windows::Win32::Foundation::HWND, state: &WatermarkState) {
|
||||
use windows::Win32::Graphics::Gdi::*;
|
||||
use windows::Win32::UI::WindowsAndMessaging::*;
|
||||
use windows::core::PCSTR;
|
||||
use windows::core::PCWSTR;
|
||||
|
||||
unsafe {
|
||||
let mut ps = PAINTSTRUCT::default();
|
||||
@@ -252,8 +253,11 @@ fn paint_watermark(hwnd: windows::Win32::Foundation::HWND, state: &WatermarkStat
|
||||
let color = parse_color(&state.color);
|
||||
let font_size = state.font_size.max(1);
|
||||
|
||||
// Create font with rotation
|
||||
let font = CreateFontA(
|
||||
// Create wide font name for CreateFontW (supports CJK characters)
|
||||
let font_name: Vec<u16> = "Microsoft YaHei\0".encode_utf16().collect();
|
||||
|
||||
// Create font with rotation using CreateFontW for proper Unicode support
|
||||
let font = CreateFontW(
|
||||
(font_size as i32) * 2,
|
||||
0,
|
||||
(state.angle as i32) * 10,
|
||||
@@ -265,7 +269,7 @@ fn paint_watermark(hwnd: windows::Win32::Foundation::HWND, state: &WatermarkStat
|
||||
CLIP_DEFAULT_PRECIS.0 as u32,
|
||||
DEFAULT_QUALITY.0 as u32,
|
||||
DEFAULT_PITCH.0 as u32 | FF_DONTCARE.0 as u32,
|
||||
PCSTR("Arial\0".as_ptr()),
|
||||
PCWSTR(font_name.as_ptr()),
|
||||
);
|
||||
|
||||
let old_font = SelectObject(hdc, font);
|
||||
@@ -273,12 +277,13 @@ fn paint_watermark(hwnd: windows::Win32::Foundation::HWND, state: &WatermarkStat
|
||||
let _ = SetBkMode(hdc, TRANSPARENT);
|
||||
let _ = SetTextColor(hdc, color);
|
||||
|
||||
// Draw tiled watermark text
|
||||
// Draw tiled watermark text using TextOutW with UTF-16 encoding
|
||||
let screen_w = GetSystemMetrics(SM_CXSCREEN);
|
||||
let screen_h = GetSystemMetrics(SM_CYSCREEN);
|
||||
|
||||
let content_bytes: Vec<u8> = state.content.bytes().chain(std::iter::once(0)).collect();
|
||||
let text_slice = &content_bytes[..content_bytes.len().saturating_sub(1)];
|
||||
// Encode content as UTF-16 for TextOutW (supports Chinese and all Unicode)
|
||||
let wide_content: Vec<u16> = state.content.encode_utf16().collect();
|
||||
let text_slice = wide_content.as_slice();
|
||||
|
||||
let spacing_x = 400i32;
|
||||
let spacing_y = 200i32;
|
||||
@@ -287,7 +292,7 @@ fn paint_watermark(hwnd: windows::Win32::Foundation::HWND, state: &WatermarkStat
|
||||
while y < screen_h + 100 {
|
||||
let mut x = -200i32;
|
||||
while x < screen_w + 200 {
|
||||
let _ = TextOutA(hdc, x, y, text_slice);
|
||||
let _ = TextOutW(hdc, x, y, text_slice);
|
||||
x += spacing_x;
|
||||
}
|
||||
y += spacing_y;
|
||||
|
||||
@@ -6,6 +6,7 @@ use std::io;
|
||||
/// Web filter rule from server
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct WebFilterRule {
|
||||
#[allow(dead_code)]
|
||||
pub id: i64,
|
||||
pub rule_type: String,
|
||||
pub pattern: String,
|
||||
|
||||
@@ -24,4 +24,9 @@ pub use message::{
|
||||
SoftwareViolationReport, UsbFileOpEntry,
|
||||
WatermarkConfigPayload, PluginControlPayload,
|
||||
UsbPolicyPayload, UsbDeviceRule,
|
||||
DiskEncryptionStatusPayload, DriveEncryptionInfo,
|
||||
DiskEncryptionConfigPayload,
|
||||
PrintEventPayload,
|
||||
ClipboardRulesPayload, ClipboardRule, ClipboardViolationPayload,
|
||||
PopupBlockStatsPayload, PopupRuleStat,
|
||||
};
|
||||
|
||||
@@ -46,6 +46,7 @@ pub enum MessageType {
|
||||
|
||||
// Plugin: Popup Blocker (弹窗拦截)
|
||||
PopupRules = 0x50,
|
||||
PopupBlockStats = 0x51,
|
||||
|
||||
// Plugin: USB File Audit (U盘文件操作记录)
|
||||
UsbFileOp = 0x60,
|
||||
@@ -59,6 +60,17 @@ pub enum MessageType {
|
||||
// Plugin control
|
||||
PluginEnable = 0x80,
|
||||
PluginDisable = 0x81,
|
||||
|
||||
// Plugin: Disk Encryption (磁盘加密检测)
|
||||
DiskEncryptionStatus = 0x90,
|
||||
DiskEncryptionConfig = 0x93,
|
||||
|
||||
// Plugin: Print Audit (打印审计)
|
||||
PrintEvent = 0x91,
|
||||
|
||||
// Plugin: Clipboard Control (剪贴板管控)
|
||||
ClipboardRules = 0x94,
|
||||
ClipboardViolation = 0x95,
|
||||
}
|
||||
|
||||
impl TryFrom<u8> for MessageType {
|
||||
@@ -85,11 +97,17 @@ impl TryFrom<u8> for MessageType {
|
||||
0x40 => Ok(Self::SoftwareBlacklist),
|
||||
0x41 => Ok(Self::SoftwareViolation),
|
||||
0x50 => Ok(Self::PopupRules),
|
||||
0x51 => Ok(Self::PopupBlockStats),
|
||||
0x60 => Ok(Self::UsbFileOp),
|
||||
0x70 => Ok(Self::WatermarkConfig),
|
||||
0x71 => Ok(Self::UsbPolicyUpdate),
|
||||
0x80 => Ok(Self::PluginEnable),
|
||||
0x81 => Ok(Self::PluginDisable),
|
||||
0x90 => Ok(Self::DiskEncryptionStatus),
|
||||
0x93 => Ok(Self::DiskEncryptionConfig),
|
||||
0x91 => Ok(Self::PrintEvent),
|
||||
0x94 => Ok(Self::ClipboardRules),
|
||||
0x95 => Ok(Self::ClipboardViolation),
|
||||
_ => Err(format!("Unknown message type: 0x{:02X}", value)),
|
||||
}
|
||||
}
|
||||
@@ -337,6 +355,93 @@ pub struct UsbDeviceRule {
|
||||
pub device_name: Option<String>,
|
||||
}
|
||||
|
||||
/// Plugin: Disk Encryption Status (Client → Server)
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct DiskEncryptionStatusPayload {
|
||||
pub device_uid: String,
|
||||
pub drives: Vec<DriveEncryptionInfo>,
|
||||
}
|
||||
|
||||
/// Information about a single drive's encryption status.
|
||||
/// Field names and types match the migration 012 disk_encryption_status table.
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct DriveEncryptionInfo {
|
||||
pub drive_letter: String,
|
||||
pub volume_name: Option<String>,
|
||||
pub encryption_method: Option<String>,
|
||||
pub protection_status: String, // "On", "Off", "Unknown"
|
||||
pub encryption_percentage: f64,
|
||||
pub lock_status: String, // "Locked", "Unlocked", "Unknown"
|
||||
}
|
||||
|
||||
/// Plugin: Disk Encryption Config (Server → Client)
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct DiskEncryptionConfigPayload {
|
||||
pub enabled: bool,
|
||||
pub report_interval_secs: u64,
|
||||
}
|
||||
|
||||
/// Plugin: Print Event (Client → Server)
|
||||
/// Field names and types match the migration 013 print_events table.
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct PrintEventPayload {
|
||||
pub device_uid: String,
|
||||
pub document_name: Option<String>,
|
||||
pub printer_name: Option<String>,
|
||||
pub pages: Option<i32>,
|
||||
pub copies: Option<i32>,
|
||||
pub user_name: Option<String>,
|
||||
pub file_size_bytes: Option<i64>,
|
||||
pub timestamp: String,
|
||||
}
|
||||
|
||||
/// Plugin: Clipboard Rules (Server → Client)
|
||||
/// Pushed from server to client to define clipboard operation policies.
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct ClipboardRulesPayload {
|
||||
pub rules: Vec<ClipboardRule>,
|
||||
}
|
||||
|
||||
/// A single clipboard control rule.
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct ClipboardRule {
|
||||
pub id: i64,
|
||||
pub rule_type: String, // "block" | "allow"
|
||||
pub direction: String, // "out" | "in" | "both"
|
||||
pub source_process: Option<String>,
|
||||
pub target_process: Option<String>,
|
||||
pub content_pattern: Option<String>,
|
||||
}
|
||||
|
||||
/// Plugin: Clipboard Violation (Client → Server)
|
||||
/// Field names and types match the migration 014 clipboard_violations table.
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct ClipboardViolationPayload {
|
||||
pub device_uid: String,
|
||||
pub source_process: Option<String>,
|
||||
pub target_process: Option<String>,
|
||||
pub content_preview: Option<String>,
|
||||
pub action_taken: String, // "blocked" | "allowed"
|
||||
pub timestamp: String,
|
||||
}
|
||||
|
||||
/// Plugin: Popup Block Stats (Client → Server)
|
||||
/// Periodic statistics from the popup blocker plugin.
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct PopupBlockStatsPayload {
|
||||
pub device_uid: String,
|
||||
pub blocked_count: u32,
|
||||
pub rule_stats: Vec<PopupRuleStat>,
|
||||
pub period_secs: u64,
|
||||
}
|
||||
|
||||
/// Statistics for a single popup blocker rule.
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct PopupRuleStat {
|
||||
pub rule_id: i64,
|
||||
pub hits: u32,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
@@ -41,11 +41,13 @@ tracing = { workspace = true }
|
||||
tracing-subscriber = { workspace = true }
|
||||
anyhow = { workspace = true }
|
||||
|
||||
# Static file embedding
|
||||
include_dir = "0.7"
|
||||
|
||||
# Utilities
|
||||
uuid = { workspace = true }
|
||||
chrono = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
include_dir = "0.7"
|
||||
hmac = "0.12"
|
||||
sha2 = "0.10"
|
||||
hex = "0.4"
|
||||
|
||||
@@ -63,7 +63,9 @@ pub async fn cleanup_task(state: AppState) {
|
||||
}
|
||||
}
|
||||
|
||||
/// Send email notification
|
||||
/// Send email notification.
|
||||
/// TODO: Wire up email notifications to alert rules.
|
||||
#[allow(dead_code)]
|
||||
pub async fn send_email(
|
||||
smtp_config: &crate::config::SmtpConfig,
|
||||
to: &str,
|
||||
@@ -97,8 +99,10 @@ pub async fn send_email(
|
||||
|
||||
/// Shared HTTP client for webhook notifications.
|
||||
/// Lazily initialized once and reused across calls to benefit from connection pooling.
|
||||
#[allow(dead_code)]
|
||||
static WEBHOOK_CLIENT: std::sync::OnceLock<reqwest::Client> = std::sync::OnceLock::new();
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn webhook_client() -> &'static reqwest::Client {
|
||||
WEBHOOK_CLIENT.get_or_init(|| {
|
||||
reqwest::Client::builder()
|
||||
@@ -108,7 +112,9 @@ fn webhook_client() -> &'static reqwest::Client {
|
||||
})
|
||||
}
|
||||
|
||||
/// Send webhook notification
|
||||
/// Send webhook notification.
|
||||
/// TODO: Wire up webhook notifications to alert rules.
|
||||
#[allow(dead_code)]
|
||||
pub async fn send_webhook(url: &str, payload: &serde_json::Value) -> anyhow::Result<()> {
|
||||
webhook_client().post(url)
|
||||
.json(payload)
|
||||
|
||||
@@ -2,7 +2,7 @@ use axum::{extract::{State, Query}, Json};
|
||||
use serde::Deserialize;
|
||||
use sqlx::Row;
|
||||
use crate::AppState;
|
||||
use super::{ApiResponse, Pagination};
|
||||
use super::ApiResponse;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct AssetListParams {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use axum::{extract::State, Json, http::StatusCode, extract::Request, middleware::Next, response::Response, Extension};
|
||||
use axum::{extract::State, Json, http::StatusCode, extract::Request, middleware::Next, response::Response};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use jsonwebtoken::{encode, decode, Header, EncodingKey, DecodingKey, Validation};
|
||||
use std::sync::Arc;
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use axum::{routing::{get, post, put, delete}, Router, Json, extract::State, middleware};
|
||||
use axum::{routing::{get, post, put, delete}, Router, Json, middleware};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use crate::AppState;
|
||||
|
||||
|
||||
247
crates/server/src/api/plugins/clipboard_control.rs
Normal file
247
crates/server/src/api/plugins/clipboard_control.rs
Normal file
@@ -0,0 +1,247 @@
|
||||
use axum::{extract::{State, Path, Json}, http::StatusCode};
|
||||
use serde::Deserialize;
|
||||
use sqlx::Row;
|
||||
use csm_protocol::MessageType;
|
||||
use crate::AppState;
|
||||
use crate::api::ApiResponse;
|
||||
use crate::tcp::push_to_targets;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct CreateRuleRequest {
|
||||
pub rule_type: Option<String>, // "block" | "allow"
|
||||
pub direction: Option<String>, // "out" | "in" | "both"
|
||||
pub source_process: Option<String>,
|
||||
pub target_process: Option<String>,
|
||||
pub content_pattern: Option<String>,
|
||||
pub target_type: Option<String>,
|
||||
pub target_id: Option<String>,
|
||||
pub enabled: Option<bool>,
|
||||
}
|
||||
|
||||
pub async fn list_rules(State(state): State<AppState>) -> Json<ApiResponse<serde_json::Value>> {
|
||||
match sqlx::query(
|
||||
"SELECT id, target_type, target_id, rule_type, direction, source_process, target_process, content_pattern, enabled, updated_at \
|
||||
FROM clipboard_rules ORDER BY updated_at DESC"
|
||||
)
|
||||
.fetch_all(&state.db)
|
||||
.await
|
||||
{
|
||||
Ok(rows) => Json(ApiResponse::ok(serde_json::json!({
|
||||
"rules": rows.iter().map(|r| serde_json::json!({
|
||||
"id": r.get::<i64, _>("id"),
|
||||
"target_type": r.get::<String, _>("target_type"),
|
||||
"target_id": r.get::<Option<String>, _>("target_id"),
|
||||
"rule_type": r.get::<String, _>("rule_type"),
|
||||
"direction": r.get::<String, _>("direction"),
|
||||
"source_process": r.get::<Option<String>, _>("source_process"),
|
||||
"target_process": r.get::<Option<String>, _>("target_process"),
|
||||
"content_pattern": r.get::<Option<String>, _>("content_pattern"),
|
||||
"enabled": r.get::<bool, _>("enabled"),
|
||||
"updated_at": r.get::<String, _>("updated_at"),
|
||||
})).collect::<Vec<_>>()
|
||||
}))),
|
||||
Err(e) => Json(ApiResponse::internal_error("query clipboard rules", e)),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn create_rule(
|
||||
State(state): State<AppState>,
|
||||
Json(req): Json<CreateRuleRequest>,
|
||||
) -> (StatusCode, Json<ApiResponse<serde_json::Value>>) {
|
||||
let target_type = req.target_type.unwrap_or_else(|| "global".to_string());
|
||||
let rule_type = req.rule_type.unwrap_or_else(|| "block".to_string());
|
||||
let direction = req.direction.unwrap_or_else(|| "out".to_string());
|
||||
|
||||
// Validate inputs
|
||||
if !matches!(rule_type.as_str(), "block" | "allow") {
|
||||
return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("rule_type must be 'block' or 'allow'")));
|
||||
}
|
||||
if !matches!(direction.as_str(), "out" | "in" | "both") {
|
||||
return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("direction must be 'out', 'in' or 'both'")));
|
||||
}
|
||||
if !matches!(target_type.as_str(), "global" | "device" | "group") {
|
||||
return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("invalid target_type")));
|
||||
}
|
||||
|
||||
let enabled = req.enabled.unwrap_or(true);
|
||||
|
||||
match sqlx::query(
|
||||
"INSERT INTO clipboard_rules (target_type, target_id, rule_type, direction, source_process, target_process, content_pattern, enabled) \
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)"
|
||||
)
|
||||
.bind(&target_type)
|
||||
.bind(&req.target_id)
|
||||
.bind(&rule_type)
|
||||
.bind(&direction)
|
||||
.bind(&req.source_process)
|
||||
.bind(&req.target_process)
|
||||
.bind(&req.content_pattern)
|
||||
.bind(enabled)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
{
|
||||
Ok(r) => {
|
||||
let new_id = r.last_insert_rowid();
|
||||
let rules = fetch_clipboard_rules_for_push(&state.db, &target_type, req.target_id.as_deref()).await;
|
||||
push_to_targets(
|
||||
&state.db, &state.clients, MessageType::ClipboardRules,
|
||||
&serde_json::json!({"rules": rules}),
|
||||
&target_type, req.target_id.as_deref(),
|
||||
).await;
|
||||
(StatusCode::CREATED, Json(ApiResponse::ok(serde_json::json!({"id": new_id}))))
|
||||
}
|
||||
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::internal_error("create clipboard rule", e))),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct UpdateRuleRequest {
|
||||
pub rule_type: Option<String>,
|
||||
pub direction: Option<String>,
|
||||
pub source_process: Option<String>,
|
||||
pub target_process: Option<String>,
|
||||
pub content_pattern: Option<String>,
|
||||
pub enabled: Option<bool>,
|
||||
}
|
||||
|
||||
pub async fn update_rule(
|
||||
State(state): State<AppState>,
|
||||
Path(id): Path<i64>,
|
||||
Json(body): Json<UpdateRuleRequest>,
|
||||
) -> Json<ApiResponse<()>> {
|
||||
let existing = sqlx::query("SELECT * FROM clipboard_rules WHERE id = ?")
|
||||
.bind(id)
|
||||
.fetch_optional(&state.db)
|
||||
.await;
|
||||
|
||||
let existing = match existing {
|
||||
Ok(Some(row)) => row,
|
||||
Ok(None) => return Json(ApiResponse::error("Not found")),
|
||||
Err(e) => return Json(ApiResponse::internal_error("query clipboard rule", e)),
|
||||
};
|
||||
|
||||
let rule_type = body.rule_type.or_else(|| Some(existing.get::<String, _>("rule_type")));
|
||||
let direction = body.direction.or_else(|| Some(existing.get::<String, _>("direction")));
|
||||
let source_process = body.source_process.or_else(|| existing.get::<Option<String>, _>("source_process"));
|
||||
let target_process = body.target_process.or_else(|| existing.get::<Option<String>, _>("target_process"));
|
||||
let content_pattern = body.content_pattern.or_else(|| existing.get::<Option<String>, _>("content_pattern"));
|
||||
let enabled = body.enabled.unwrap_or_else(|| existing.get::<bool, _>("enabled"));
|
||||
|
||||
let result = sqlx::query(
|
||||
"UPDATE clipboard_rules SET rule_type = ?, direction = ?, source_process = ?, target_process = ?, content_pattern = ?, enabled = ?, updated_at = datetime('now') WHERE id = ?"
|
||||
)
|
||||
.bind(&rule_type)
|
||||
.bind(&direction)
|
||||
.bind(&source_process)
|
||||
.bind(&target_process)
|
||||
.bind(&content_pattern)
|
||||
.bind(enabled)
|
||||
.bind(id)
|
||||
.execute(&state.db)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(r) if r.rows_affected() > 0 => {
|
||||
let target_type_val: String = existing.get("target_type");
|
||||
let target_id_val: Option<String> = existing.get("target_id");
|
||||
let rules = fetch_clipboard_rules_for_push(&state.db, &target_type_val, target_id_val.as_deref()).await;
|
||||
push_to_targets(
|
||||
&state.db, &state.clients, MessageType::ClipboardRules,
|
||||
&serde_json::json!({"rules": rules}),
|
||||
&target_type_val, target_id_val.as_deref(),
|
||||
).await;
|
||||
Json(ApiResponse::ok(()))
|
||||
}
|
||||
Ok(_) => Json(ApiResponse::error("Not found")),
|
||||
Err(e) => Json(ApiResponse::internal_error("update clipboard rule", e)),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn delete_rule(
|
||||
State(state): State<AppState>,
|
||||
Path(id): Path<i64>,
|
||||
) -> Json<ApiResponse<()>> {
|
||||
let existing = sqlx::query("SELECT target_type, target_id FROM clipboard_rules WHERE id = ?")
|
||||
.bind(id)
|
||||
.fetch_optional(&state.db)
|
||||
.await;
|
||||
|
||||
let (target_type, target_id) = match existing {
|
||||
Ok(Some(row)) => (row.get::<String, _>("target_type"), row.get::<Option<String>, _>("target_id")),
|
||||
_ => return Json(ApiResponse::error("Not found")),
|
||||
};
|
||||
|
||||
match sqlx::query("DELETE FROM clipboard_rules WHERE id = ?")
|
||||
.bind(id)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
{
|
||||
Ok(r) if r.rows_affected() > 0 => {
|
||||
let rules = fetch_clipboard_rules_for_push(&state.db, &target_type, target_id.as_deref()).await;
|
||||
push_to_targets(
|
||||
&state.db, &state.clients, MessageType::ClipboardRules,
|
||||
&serde_json::json!({"rules": rules}),
|
||||
&target_type, target_id.as_deref(),
|
||||
).await;
|
||||
Json(ApiResponse::ok(()))
|
||||
}
|
||||
_ => Json(ApiResponse::error("Not found")),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn list_violations(State(state): State<AppState>) -> Json<ApiResponse<serde_json::Value>> {
|
||||
match sqlx::query(
|
||||
"SELECT id, device_uid, source_process, target_process, content_preview, action_taken, timestamp, reported_at \
|
||||
FROM clipboard_violations ORDER BY reported_at DESC LIMIT 200"
|
||||
)
|
||||
.fetch_all(&state.db)
|
||||
.await
|
||||
{
|
||||
Ok(rows) => Json(ApiResponse::ok(serde_json::json!({
|
||||
"violations": rows.iter().map(|r| serde_json::json!({
|
||||
"id": r.get::<i64, _>("id"),
|
||||
"device_uid": r.get::<String, _>("device_uid"),
|
||||
"source_process": r.get::<Option<String>, _>("source_process"),
|
||||
"target_process": r.get::<Option<String>, _>("target_process"),
|
||||
"content_preview": r.get::<Option<String>, _>("content_preview"),
|
||||
"action_taken": r.get::<String, _>("action_taken"),
|
||||
"timestamp": r.get::<String, _>("timestamp"),
|
||||
"reported_at": r.get::<String, _>("reported_at"),
|
||||
})).collect::<Vec<_>>()
|
||||
}))),
|
||||
Err(e) => Json(ApiResponse::internal_error("query clipboard violations", e)),
|
||||
}
|
||||
}
|
||||
|
||||
async fn fetch_clipboard_rules_for_push(
|
||||
db: &sqlx::SqlitePool,
|
||||
target_type: &str,
|
||||
target_id: Option<&str>,
|
||||
) -> Vec<serde_json::Value> {
|
||||
let query = match target_type {
|
||||
"device" => sqlx::query(
|
||||
"SELECT id, rule_type, direction, source_process, target_process, content_pattern \
|
||||
FROM clipboard_rules WHERE enabled = 1 AND (target_type = 'global' OR (target_type = 'device' AND target_id = ?))"
|
||||
)
|
||||
.bind(target_id),
|
||||
"group" => sqlx::query(
|
||||
"SELECT id, rule_type, direction, source_process, target_process, content_pattern \
|
||||
FROM clipboard_rules WHERE enabled = 1 AND (target_type = 'global' OR (target_type = 'group' AND target_id = ?))"
|
||||
)
|
||||
.bind(target_id),
|
||||
_ => sqlx::query(
|
||||
"SELECT id, rule_type, direction, source_process, target_process, content_pattern \
|
||||
FROM clipboard_rules WHERE enabled = 1 AND target_type = 'global'"
|
||||
),
|
||||
};
|
||||
query.fetch_all(db).await
|
||||
.map(|rows| rows.iter().map(|r| serde_json::json!({
|
||||
"id": r.get::<i64, _>("id"),
|
||||
"rule_type": r.get::<String, _>("rule_type"),
|
||||
"direction": r.get::<String, _>("direction"),
|
||||
"source_process": r.get::<Option<String>, _>("source_process"),
|
||||
"target_process": r.get::<Option<String>, _>("target_process"),
|
||||
"content_pattern": r.get::<Option<String>, _>("content_pattern"),
|
||||
})).collect())
|
||||
.unwrap_or_default()
|
||||
}
|
||||
97
crates/server/src/api/plugins/disk_encryption.rs
Normal file
97
crates/server/src/api/plugins/disk_encryption.rs
Normal file
@@ -0,0 +1,97 @@
|
||||
use axum::{extract::{State, Path, Query}, Json};
|
||||
use serde::Deserialize;
|
||||
use sqlx::Row;
|
||||
use crate::AppState;
|
||||
use crate::api::ApiResponse;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct StatusFilter {
|
||||
pub device_uid: Option<String>,
|
||||
}
|
||||
|
||||
pub async fn list_status(
|
||||
State(state): State<AppState>,
|
||||
Query(filter): Query<StatusFilter>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let result = if let Some(uid) = &filter.device_uid {
|
||||
sqlx::query(
|
||||
"SELECT s.id, s.device_uid, s.drive_letter, s.volume_name, s.encryption_method, \
|
||||
s.protection_status, s.encryption_percentage, s.lock_status, s.reported_at, s.updated_at, \
|
||||
d.hostname FROM disk_encryption_status s LEFT JOIN devices d ON s.device_uid = d.device_uid \
|
||||
WHERE s.device_uid = ? ORDER BY s.drive_letter"
|
||||
)
|
||||
.bind(uid)
|
||||
.fetch_all(&state.db)
|
||||
.await
|
||||
} else {
|
||||
sqlx::query(
|
||||
"SELECT s.id, s.device_uid, s.drive_letter, s.volume_name, s.encryption_method, \
|
||||
s.protection_status, s.encryption_percentage, s.lock_status, s.reported_at, s.updated_at, \
|
||||
d.hostname FROM disk_encryption_status s LEFT JOIN devices d ON s.device_uid = d.device_uid \
|
||||
ORDER BY s.device_uid, s.drive_letter"
|
||||
)
|
||||
.fetch_all(&state.db)
|
||||
.await
|
||||
};
|
||||
|
||||
match result {
|
||||
Ok(rows) => Json(ApiResponse::ok(serde_json::json!({
|
||||
"entries": rows.iter().map(|r| serde_json::json!({
|
||||
"id": r.get::<i64, _>("id"),
|
||||
"device_uid": r.get::<String, _>("device_uid"),
|
||||
"hostname": r.get::<Option<String>, _>("hostname"),
|
||||
"drive_letter": r.get::<String, _>("drive_letter"),
|
||||
"volume_name": r.get::<Option<String>, _>("volume_name"),
|
||||
"encryption_method": r.get::<Option<String>, _>("encryption_method"),
|
||||
"protection_status": r.get::<String, _>("protection_status"),
|
||||
"encryption_percentage": r.get::<f64, _>("encryption_percentage"),
|
||||
"lock_status": r.get::<String, _>("lock_status"),
|
||||
"reported_at": r.get::<String, _>("reported_at"),
|
||||
"updated_at": r.get::<String, _>("updated_at"),
|
||||
})).collect::<Vec<_>>()
|
||||
}))),
|
||||
Err(e) => Json(ApiResponse::internal_error("query disk encryption status", e)),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn list_alerts(State(state): State<AppState>) -> Json<ApiResponse<serde_json::Value>> {
|
||||
match sqlx::query(
|
||||
"SELECT a.id, a.device_uid, a.drive_letter, a.alert_type, a.status, a.created_at, a.resolved_at, \
|
||||
d.hostname FROM encryption_alerts a LEFT JOIN devices d ON a.device_uid = d.device_uid \
|
||||
ORDER BY a.created_at DESC"
|
||||
)
|
||||
.fetch_all(&state.db)
|
||||
.await
|
||||
{
|
||||
Ok(rows) => Json(ApiResponse::ok(serde_json::json!({
|
||||
"alerts": rows.iter().map(|r| serde_json::json!({
|
||||
"id": r.get::<i64, _>("id"),
|
||||
"device_uid": r.get::<String, _>("device_uid"),
|
||||
"hostname": r.get::<Option<String>, _>("hostname"),
|
||||
"drive_letter": r.get::<String, _>("drive_letter"),
|
||||
"alert_type": r.get::<String, _>("alert_type"),
|
||||
"status": r.get::<String, _>("status"),
|
||||
"created_at": r.get::<String, _>("created_at"),
|
||||
"resolved_at": r.get::<Option<String>, _>("resolved_at"),
|
||||
})).collect::<Vec<_>>()
|
||||
}))),
|
||||
Err(e) => Json(ApiResponse::internal_error("query encryption alerts", e)),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn acknowledge_alert(
|
||||
State(state): State<AppState>,
|
||||
Path(id): Path<i64>,
|
||||
) -> Json<ApiResponse<()>> {
|
||||
match sqlx::query(
|
||||
"UPDATE encryption_alerts SET status = 'acknowledged', resolved_at = datetime('now') WHERE id = ? AND status = 'open'"
|
||||
)
|
||||
.bind(id)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
{
|
||||
Ok(r) if r.rows_affected() > 0 => Json(ApiResponse::ok(())),
|
||||
Ok(_) => Json(ApiResponse::error("Alert not found or already acknowledged")),
|
||||
Err(e) => Json(ApiResponse::internal_error("acknowledge encryption alert", e)),
|
||||
}
|
||||
}
|
||||
@@ -4,6 +4,10 @@ pub mod software_blocker;
|
||||
pub mod popup_blocker;
|
||||
pub mod usb_file_audit;
|
||||
pub mod watermark;
|
||||
pub mod disk_encryption;
|
||||
pub mod print_audit;
|
||||
pub mod clipboard_control;
|
||||
pub mod plugin_control;
|
||||
|
||||
use axum::{Router, routing::{get, post, put}};
|
||||
use crate::AppState;
|
||||
@@ -29,6 +33,18 @@ pub fn read_routes() -> Router<AppState> {
|
||||
.route("/api/plugins/usb-file-audit/summary", get(usb_file_audit::summary))
|
||||
// Watermark
|
||||
.route("/api/plugins/watermark/config", get(watermark::get_config_list))
|
||||
// Disk Encryption
|
||||
.route("/api/plugins/disk-encryption/status", get(disk_encryption::list_status))
|
||||
.route("/api/plugins/disk-encryption/alerts", get(disk_encryption::list_alerts))
|
||||
.route("/api/plugins/disk-encryption/alerts/:id/acknowledge", put(disk_encryption::acknowledge_alert))
|
||||
// Print Audit
|
||||
.route("/api/plugins/print-audit/events", get(print_audit::list_events))
|
||||
.route("/api/plugins/print-audit/events/:id", get(print_audit::get_event))
|
||||
// Clipboard Control
|
||||
.route("/api/plugins/clipboard-control/rules", get(clipboard_control::list_rules))
|
||||
.route("/api/plugins/clipboard-control/violations", get(clipboard_control::list_violations))
|
||||
// Plugin Control
|
||||
.route("/api/plugins/control", get(plugin_control::list_plugins))
|
||||
}
|
||||
|
||||
/// Write plugin routes (admin only — require_admin middleware applied by caller)
|
||||
@@ -46,4 +62,9 @@ pub fn write_routes() -> Router<AppState> {
|
||||
// Watermark
|
||||
.route("/api/plugins/watermark/config", post(watermark::create_config))
|
||||
.route("/api/plugins/watermark/config/:id", put(watermark::update_config).delete(watermark::delete_config))
|
||||
// Clipboard Control
|
||||
.route("/api/plugins/clipboard-control/rules", post(clipboard_control::create_rule))
|
||||
.route("/api/plugins/clipboard-control/rules/:id", put(clipboard_control::update_rule).delete(clipboard_control::delete_rule))
|
||||
// Plugin Control (enable/disable)
|
||||
.route("/api/plugins/control/:plugin_name", put(plugin_control::set_plugin_state))
|
||||
}
|
||||
|
||||
95
crates/server/src/api/plugins/plugin_control.rs
Normal file
95
crates/server/src/api/plugins/plugin_control.rs
Normal file
@@ -0,0 +1,95 @@
|
||||
use axum::{extract::{State, Path, Json}, http::StatusCode};
|
||||
use serde::Deserialize;
|
||||
use sqlx::Row;
|
||||
use crate::AppState;
|
||||
use crate::api::ApiResponse;
|
||||
use crate::tcp::push_to_targets;
|
||||
use csm_protocol::{MessageType, PluginControlPayload};
|
||||
|
||||
/// List all plugin states
|
||||
pub async fn list_plugins(State(state): State<AppState>) -> Json<ApiResponse<serde_json::Value>> {
|
||||
match sqlx::query(
|
||||
"SELECT id, plugin_name, enabled, target_type, target_id, updated_at FROM plugin_state ORDER BY plugin_name"
|
||||
)
|
||||
.fetch_all(&state.db)
|
||||
.await
|
||||
{
|
||||
Ok(rows) => Json(ApiResponse::ok(serde_json::json!({
|
||||
"plugins": rows.iter().map(|r| serde_json::json!({
|
||||
"id": r.get::<i64, _>("id"),
|
||||
"plugin_name": r.get::<String, _>("plugin_name"),
|
||||
"enabled": r.get::<bool, _>("enabled"),
|
||||
"target_type": r.get::<String, _>("target_type"),
|
||||
"target_id": r.get::<Option<String>, _>("target_id"),
|
||||
"updated_at": r.get::<String, _>("updated_at"),
|
||||
})).collect::<Vec<_>>()
|
||||
}))),
|
||||
Err(e) => Json(ApiResponse::internal_error("query plugin state", e)),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct SetPluginStateRequest {
|
||||
pub enabled: bool,
|
||||
pub target_type: Option<String>,
|
||||
pub target_id: Option<String>,
|
||||
}
|
||||
|
||||
/// Enable or disable a plugin. Pushes PluginEnable/PluginDisable to matching clients.
|
||||
pub async fn set_plugin_state(
|
||||
State(state): State<AppState>,
|
||||
Path(plugin_name): Path<String>,
|
||||
Json(req): Json<SetPluginStateRequest>,
|
||||
) -> (StatusCode, Json<ApiResponse<serde_json::Value>>) {
|
||||
let valid_plugins = [
|
||||
"web_filter", "usage_timer", "software_blocker",
|
||||
"popup_blocker", "usb_file_audit", "watermark",
|
||||
"disk_encryption", "usb_audit", "print_audit", "clipboard_control",
|
||||
];
|
||||
if !valid_plugins.contains(&plugin_name.as_str()) {
|
||||
return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("unknown plugin name")));
|
||||
}
|
||||
|
||||
let target_type = req.target_type.unwrap_or_else(|| "global".to_string());
|
||||
if !matches!(target_type.as_str(), "global" | "device" | "group") {
|
||||
return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("invalid target_type")));
|
||||
}
|
||||
|
||||
// Upsert plugin state
|
||||
match sqlx::query(
|
||||
"INSERT INTO plugin_state (plugin_name, enabled, target_type, target_id, updated_at) \
|
||||
VALUES (?, ?, ?, ?, datetime('now')) \
|
||||
ON CONFLICT(plugin_name) DO UPDATE SET enabled = excluded.enabled, target_type = excluded.target_type, \
|
||||
target_id = excluded.target_id, updated_at = datetime('now')"
|
||||
)
|
||||
.bind(&plugin_name)
|
||||
.bind(req.enabled)
|
||||
.bind(&target_type)
|
||||
.bind(&req.target_id)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
{
|
||||
Ok(_) => {
|
||||
// Push enable/disable to matching clients
|
||||
let payload = PluginControlPayload {
|
||||
plugin_name: plugin_name.clone(),
|
||||
enabled: req.enabled,
|
||||
};
|
||||
let msg_type = if req.enabled {
|
||||
MessageType::PluginEnable
|
||||
} else {
|
||||
MessageType::PluginDisable
|
||||
};
|
||||
push_to_targets(
|
||||
&state.db, &state.clients, msg_type, &payload,
|
||||
&target_type, req.target_id.as_deref(),
|
||||
).await;
|
||||
|
||||
(StatusCode::OK, Json(ApiResponse::ok(serde_json::json!({
|
||||
"plugin_name": plugin_name,
|
||||
"enabled": req.enabled,
|
||||
}))))
|
||||
}
|
||||
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::internal_error("set plugin state", e))),
|
||||
}
|
||||
}
|
||||
101
crates/server/src/api/plugins/print_audit.rs
Normal file
101
crates/server/src/api/plugins/print_audit.rs
Normal file
@@ -0,0 +1,101 @@
|
||||
use axum::{extract::{State, Query, Path}, Json};
|
||||
use serde::Deserialize;
|
||||
use sqlx::Row;
|
||||
use crate::AppState;
|
||||
use crate::api::ApiResponse;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct ListEventsParams {
|
||||
pub device_uid: Option<String>,
|
||||
pub page: Option<u32>,
|
||||
pub page_size: Option<u32>,
|
||||
}
|
||||
|
||||
pub async fn list_events(
|
||||
State(state): State<AppState>,
|
||||
Query(params): Query<ListEventsParams>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let page = params.page.unwrap_or(1).max(1);
|
||||
let page_size = params.page_size.unwrap_or(20).min(100);
|
||||
let offset = (page - 1) * page_size;
|
||||
|
||||
let (rows, total) = if let Some(ref device_uid) = params.device_uid {
|
||||
let rows = sqlx::query(
|
||||
"SELECT id, device_uid, document_name, printer_name, pages, copies, user_name, file_size_bytes, timestamp, reported_at \
|
||||
FROM print_events WHERE device_uid = ? ORDER BY timestamp DESC LIMIT ? OFFSET ?"
|
||||
)
|
||||
.bind(device_uid)
|
||||
.bind(page_size)
|
||||
.bind(offset)
|
||||
.fetch_all(&state.db).await;
|
||||
|
||||
let total: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM print_events WHERE device_uid = ?")
|
||||
.bind(device_uid)
|
||||
.fetch_one(&state.db).await.unwrap_or(0);
|
||||
|
||||
(rows, total)
|
||||
} else {
|
||||
let rows = sqlx::query(
|
||||
"SELECT id, device_uid, document_name, printer_name, pages, copies, user_name, file_size_bytes, timestamp, reported_at \
|
||||
FROM print_events ORDER BY timestamp DESC LIMIT ? OFFSET ?"
|
||||
)
|
||||
.bind(page_size)
|
||||
.bind(offset)
|
||||
.fetch_all(&state.db).await;
|
||||
|
||||
let total: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM print_events")
|
||||
.fetch_one(&state.db).await.unwrap_or(0);
|
||||
|
||||
(rows, total)
|
||||
};
|
||||
|
||||
match rows {
|
||||
Ok(rows) => Json(ApiResponse::ok(serde_json::json!({
|
||||
"events": rows.iter().map(|r| serde_json::json!({
|
||||
"id": r.get::<i64, _>("id"),
|
||||
"device_uid": r.get::<String, _>("device_uid"),
|
||||
"document_name": r.get::<Option<String>, _>("document_name"),
|
||||
"printer_name": r.get::<Option<String>, _>("printer_name"),
|
||||
"pages": r.get::<Option<i32>, _>("pages"),
|
||||
"copies": r.get::<Option<i32>, _>("copies"),
|
||||
"user_name": r.get::<Option<String>, _>("user_name"),
|
||||
"file_size_bytes": r.get::<Option<i64>, _>("file_size_bytes"),
|
||||
"timestamp": r.get::<String, _>("timestamp"),
|
||||
"reported_at": r.get::<String, _>("reported_at"),
|
||||
})).collect::<Vec<_>>(),
|
||||
"total": total,
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
}))),
|
||||
Err(e) => Json(ApiResponse::internal_error("query print events", e)),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_event(
|
||||
State(state): State<AppState>,
|
||||
Path(id): Path<i64>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
match sqlx::query(
|
||||
"SELECT id, device_uid, document_name, printer_name, pages, copies, user_name, file_size_bytes, timestamp, reported_at \
|
||||
FROM print_events WHERE id = ?"
|
||||
)
|
||||
.bind(id)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
{
|
||||
Ok(Some(row)) => Json(ApiResponse::ok(serde_json::json!({
|
||||
"id": row.get::<i64, _>("id"),
|
||||
"device_uid": row.get::<String, _>("device_uid"),
|
||||
"document_name": row.get::<Option<String>, _>("document_name"),
|
||||
"printer_name": row.get::<Option<String>, _>("printer_name"),
|
||||
"pages": row.get::<Option<i32>, _>("pages"),
|
||||
"copies": row.get::<Option<i32>, _>("copies"),
|
||||
"user_name": row.get::<Option<String>, _>("user_name"),
|
||||
"file_size_bytes": row.get::<Option<i64>, _>("file_size_bytes"),
|
||||
"timestamp": row.get::<String, _>("timestamp"),
|
||||
"reported_at": row.get::<String, _>("reported_at"),
|
||||
}))),
|
||||
Ok(None) => Json(ApiResponse::error("Print event not found")),
|
||||
Err(e) => Json(ApiResponse::internal_error("query print event", e)),
|
||||
}
|
||||
}
|
||||
@@ -143,6 +143,9 @@ async fn fetch_blacklist_for_push(
|
||||
"device" => sqlx::query(
|
||||
"SELECT id, name_pattern, action FROM software_blacklist WHERE enabled = 1 AND (target_type = 'global' OR (target_type = 'device' AND target_id = ?))"
|
||||
).bind(target_id),
|
||||
"group" => sqlx::query(
|
||||
"SELECT id, name_pattern, action FROM software_blacklist WHERE enabled = 1 AND (target_type = 'global' OR (target_type = 'group' AND target_id = ?))"
|
||||
).bind(target_id),
|
||||
_ => sqlx::query(
|
||||
"SELECT id, name_pattern, action FROM software_blacklist WHERE enabled = 1 AND target_type = 'global'"
|
||||
),
|
||||
|
||||
@@ -6,9 +6,6 @@ use crate::AppState;
|
||||
use crate::api::ApiResponse;
|
||||
use crate::tcp::push_to_targets;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct RuleFilters { pub rule_type: Option<String>, pub enabled: Option<bool> }
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct CreateRuleRequest {
|
||||
pub rule_type: String,
|
||||
@@ -144,6 +141,9 @@ async fn fetch_rules_for_push(
|
||||
"device" => sqlx::query(
|
||||
"SELECT id, rule_type, pattern FROM web_filter_rules WHERE enabled = 1 AND (target_type = 'global' OR (target_type = 'device' AND target_id = ?))"
|
||||
).bind(target_id),
|
||||
"group" => sqlx::query(
|
||||
"SELECT id, rule_type, pattern FROM web_filter_rules WHERE enabled = 1 AND (target_type = 'global' OR (target_type = 'group' AND target_id = ?))"
|
||||
).bind(target_id),
|
||||
_ => sqlx::query(
|
||||
"SELECT id, rule_type, pattern FROM web_filter_rules WHERE enabled = 1 AND target_type = 'global'"
|
||||
),
|
||||
|
||||
@@ -2,6 +2,8 @@ use sqlx::SqlitePool;
|
||||
use tracing::debug;
|
||||
|
||||
/// Record an admin audit log entry.
|
||||
/// TODO: Wire up audit logging to all admin API handlers.
|
||||
#[allow(dead_code)]
|
||||
pub async fn audit_log(
|
||||
db: &SqlitePool,
|
||||
user_id: i64,
|
||||
|
||||
@@ -82,18 +82,32 @@ pub struct SmtpConfig {
|
||||
|
||||
impl AppConfig {
|
||||
pub async fn load(path: &str) -> Result<Self> {
|
||||
if Path::new(path).exists() {
|
||||
let mut config = if Path::new(path).exists() {
|
||||
let content = tokio::fs::read_to_string(path).await?;
|
||||
let config: AppConfig = toml::from_str(&content)?;
|
||||
Ok(config)
|
||||
toml::from_str(&content)?
|
||||
} else {
|
||||
let config = default_config();
|
||||
// Write default config for reference
|
||||
let toml_str = toml::to_string_pretty(&config)?;
|
||||
tokio::fs::write(path, &toml_str).await?;
|
||||
tracing::warn!("Created default config at {}", path);
|
||||
Ok(config)
|
||||
config
|
||||
};
|
||||
|
||||
// Environment variable overrides (take precedence over config file)
|
||||
if let Ok(val) = std::env::var("CSM_JWT_SECRET") {
|
||||
if !val.is_empty() {
|
||||
tracing::info!("JWT secret loaded from CSM_JWT_SECRET env var");
|
||||
config.auth.jwt_secret = val;
|
||||
}
|
||||
}
|
||||
if let Ok(val) = std::env::var("CSM_REGISTRATION_TOKEN") {
|
||||
if !val.is_empty() {
|
||||
tracing::info!("Registration token loaded from CSM_REGISTRATION_TOKEN env var");
|
||||
config.registration_token = val;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(config)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -117,12 +117,15 @@ impl DeviceRepo {
|
||||
}
|
||||
|
||||
pub async fn upsert_software(pool: &SqlitePool, asset: &csm_protocol::SoftwareAsset) -> Result<()> {
|
||||
// Use INSERT OR REPLACE to handle the UNIQUE(device_uid, name, version) constraint
|
||||
// where version can be NULL (treated as distinct by SQLite)
|
||||
let version = asset.version.as_deref().unwrap_or("");
|
||||
sqlx::query(
|
||||
"INSERT OR REPLACE INTO software_assets (device_uid, name, version, publisher, install_date, install_path, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, datetime('now'))"
|
||||
"INSERT INTO software_assets (device_uid, name, version, publisher, install_date, install_path, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, datetime('now'))
|
||||
ON CONFLICT(device_uid, name, version) DO UPDATE SET
|
||||
publisher = excluded.publisher,
|
||||
install_date = excluded.install_date,
|
||||
install_path = excluded.install_path,
|
||||
updated_at = datetime('now')"
|
||||
)
|
||||
.bind(&asset.device_uid)
|
||||
.bind(&asset.name)
|
||||
|
||||
@@ -1,15 +1,20 @@
|
||||
use anyhow::Result;
|
||||
use axum::Router;
|
||||
use axum::body::Body;
|
||||
use axum::http::{Request, Response, StatusCode, header};
|
||||
use axum::middleware::Next;
|
||||
use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions, SqliteJournalMode};
|
||||
use std::path::Path;
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
use tokio::net::TcpListener;
|
||||
use tower_http::cors::{CorsLayer, Any};
|
||||
use axum::http::Method as HttpMethod;
|
||||
use tower_http::cors::CorsLayer;
|
||||
use tower_http::trace::TraceLayer;
|
||||
use tower_http::compression::CompressionLayer;
|
||||
use tower_http::set_header::SetResponseHeaderLayer;
|
||||
use tracing::{info, warn, error};
|
||||
use include_dir::{include_dir, Dir};
|
||||
|
||||
mod api;
|
||||
mod audit;
|
||||
@@ -21,6 +26,10 @@ mod alert;
|
||||
|
||||
use config::AppConfig;
|
||||
|
||||
/// Embedded frontend assets from web/dist/ (compiled into the binary at build time).
|
||||
/// Falls back gracefully at runtime if the directory is empty (dev mode).
|
||||
static FRONTEND_DIR: Dir = include_dir!("$CARGO_MANIFEST_DIR/../../web/dist");
|
||||
|
||||
/// Application shared state
|
||||
#[derive(Clone)]
|
||||
pub struct AppState {
|
||||
@@ -46,6 +55,15 @@ async fn main() -> Result<()> {
|
||||
|
||||
// Load configuration
|
||||
let config = AppConfig::load("config.toml").await?;
|
||||
|
||||
// Security checks
|
||||
if config.registration_token.is_empty() {
|
||||
warn!("SECURITY: registration_token is empty — any device can register!");
|
||||
}
|
||||
if config.auth.jwt_secret.len() < 32 {
|
||||
warn!("SECURITY: jwt_secret is too short ({} chars) — consider using a 32+ byte key from CSM_JWT_SECRET env var", config.auth.jwt_secret.len());
|
||||
}
|
||||
|
||||
let config = Arc::new(config);
|
||||
|
||||
// Initialize database
|
||||
@@ -86,6 +104,9 @@ async fn main() -> Result<()> {
|
||||
// Build HTTP router
|
||||
let app = Router::new()
|
||||
.merge(api::routes(state.clone()))
|
||||
// SPA fallback: serve embedded frontend for non-API routes
|
||||
.fallback(spa_fallback)
|
||||
.layer(axum::middleware::from_fn(json_rejection_handler))
|
||||
.layer(
|
||||
build_cors_layer(&config.server.cors_origins),
|
||||
)
|
||||
@@ -171,6 +192,11 @@ async fn run_migrations(pool: &sqlx::SqlitePool) -> Result<()> {
|
||||
include_str!("../../../migrations/009_plugins_usb_file_audit.sql"),
|
||||
include_str!("../../../migrations/010_plugins_watermark.sql"),
|
||||
include_str!("../../../migrations/011_token_security.sql"),
|
||||
include_str!("../../../migrations/012_disk_encryption.sql"),
|
||||
include_str!("../../../migrations/013_print_audit.sql"),
|
||||
include_str!("../../../migrations/014_clipboard_control.sql"),
|
||||
include_str!("../../../migrations/015_plugin_control.sql"),
|
||||
include_str!("../../../migrations/016_encryption_alerts_unique.sql"),
|
||||
];
|
||||
|
||||
// Create migrations tracking table
|
||||
@@ -257,8 +283,102 @@ fn build_cors_layer(origins: &[String]) -> CorsLayer {
|
||||
} else {
|
||||
CorsLayer::new()
|
||||
.allow_origin(tower_http::cors::AllowOrigin::list(allowed_origins))
|
||||
.allow_methods(Any)
|
||||
.allow_headers(Any)
|
||||
.allow_methods([HttpMethod::GET, HttpMethod::POST, HttpMethod::PUT, HttpMethod::DELETE])
|
||||
.allow_headers([axum::http::header::AUTHORIZATION, axum::http::header::CONTENT_TYPE])
|
||||
.max_age(std::time::Duration::from_secs(3600))
|
||||
}
|
||||
}
|
||||
|
||||
/// Middleware to convert axum's default 422 text/plain rejection responses
|
||||
/// (e.g., JSON deserialization errors) into proper JSON ApiResponse format.
|
||||
async fn json_rejection_handler(
|
||||
req: Request<Body>,
|
||||
next: Next,
|
||||
) -> Response<Body> {
|
||||
let response = next.run(req).await;
|
||||
let status = response.status();
|
||||
|
||||
if status == StatusCode::UNPROCESSABLE_ENTITY {
|
||||
let ct = response.headers()
|
||||
.get(axum::http::header::CONTENT_TYPE)
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.unwrap_or("");
|
||||
|
||||
if ct.starts_with("text/plain") {
|
||||
// Convert to JSON error response
|
||||
let body = serde_json::json!({
|
||||
"success": false,
|
||||
"data": null,
|
||||
"error": "Invalid request body"
|
||||
});
|
||||
return Response::builder()
|
||||
.status(StatusCode::UNPROCESSABLE_ENTITY)
|
||||
.header(axum::http::header::CONTENT_TYPE, "application/json")
|
||||
.body(Body::from(serde_json::to_string(&body).unwrap_or_default()))
|
||||
.unwrap_or(response);
|
||||
}
|
||||
}
|
||||
|
||||
response
|
||||
}
|
||||
|
||||
/// SPA fallback handler: serves embedded frontend static files.
|
||||
/// For known asset paths (JS/CSS/images), returns the file with correct MIME type.
|
||||
/// For all other paths, returns index.html (SPA client-side routing).
|
||||
async fn spa_fallback(req: Request<Body>) -> Response<Body> {
|
||||
let path = req.uri().path().trim_start_matches('/');
|
||||
|
||||
// Try to serve the exact file first (e.g., assets/index-xxx.js)
|
||||
if !path.is_empty() {
|
||||
if let Some(file) = FRONTEND_DIR.get_file(path) {
|
||||
let content_type = guess_content_type(path);
|
||||
return Response::builder()
|
||||
.status(StatusCode::OK)
|
||||
.header(header::CONTENT_TYPE, content_type)
|
||||
.header(header::CACHE_CONTROL, "public, max-age=31536000".to_string())
|
||||
.body(Body::from(file.contents().to_vec()))
|
||||
.unwrap_or_else(|_| Response::new(Body::from("Internal error")));
|
||||
}
|
||||
}
|
||||
|
||||
// SPA fallback: return index.html for all unmatched routes
|
||||
match FRONTEND_DIR.get_file("index.html") {
|
||||
Some(file) => Response::builder()
|
||||
.status(StatusCode::OK)
|
||||
.header(header::CONTENT_TYPE, "text/html; charset=utf-8")
|
||||
.header(header::CACHE_CONTROL, "no-cache".to_string())
|
||||
.body(Body::from(file.contents().to_vec()))
|
||||
.unwrap_or_else(|_| Response::new(Body::from("Internal error"))),
|
||||
None => Response::builder()
|
||||
.status(StatusCode::NOT_FOUND)
|
||||
.body(Body::from("Frontend not built. Run: cd web && npm run build"))
|
||||
.unwrap_or_else(|_| Response::new(Body::from("Not found"))),
|
||||
}
|
||||
}
|
||||
|
||||
/// Guess MIME type from file extension.
|
||||
fn guess_content_type(path: &str) -> &'static str {
|
||||
if path.ends_with(".js") {
|
||||
"application/javascript; charset=utf-8"
|
||||
} else if path.ends_with(".css") {
|
||||
"text/css; charset=utf-8"
|
||||
} else if path.ends_with(".html") {
|
||||
"text/html; charset=utf-8"
|
||||
} else if path.ends_with(".json") {
|
||||
"application/json"
|
||||
} else if path.ends_with(".png") {
|
||||
"image/png"
|
||||
} else if path.ends_with(".jpg") || path.ends_with(".jpeg") {
|
||||
"image/jpeg"
|
||||
} else if path.ends_with(".svg") {
|
||||
"image/svg+xml"
|
||||
} else if path.ends_with(".ico") {
|
||||
"image/x-icon"
|
||||
} else if path.ends_with(".woff") || path.ends_with(".woff2") {
|
||||
"font/woff2"
|
||||
} else if path.ends_with(".ttf") {
|
||||
"font/ttf"
|
||||
} else {
|
||||
"application/octet-stream"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicU32, Ordering};
|
||||
use std::time::Instant;
|
||||
use tokio::sync::RwLock;
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
@@ -13,6 +14,15 @@ use crate::AppState;
|
||||
const RATE_LIMIT_WINDOW_SECS: u64 = 5;
|
||||
const RATE_LIMIT_MAX_FRAMES: usize = 100;
|
||||
|
||||
/// Maximum concurrent TCP connections
|
||||
const MAX_CONNECTIONS: usize = 500;
|
||||
|
||||
/// Maximum consecutive HMAC failures before disconnecting
|
||||
const MAX_HMAC_FAILURES: u32 = 3;
|
||||
|
||||
/// Idle timeout for TCP connections (seconds) — disconnect if no data received
|
||||
const IDLE_TIMEOUT_SECS: u64 = 180;
|
||||
|
||||
/// Per-connection rate limiter using a sliding window of frame timestamps
|
||||
struct RateLimiter {
|
||||
timestamps: Vec<Instant>,
|
||||
@@ -226,6 +236,61 @@ pub async fn push_all_plugin_configs(
|
||||
}
|
||||
}
|
||||
|
||||
// Clipboard control rules
|
||||
if let Ok(rows) = sqlx::query(
|
||||
"SELECT id, rule_type, direction, source_process, target_process, content_pattern \
|
||||
FROM clipboard_rules WHERE enabled = 1 AND (target_type = 'global' OR (target_type = 'device' AND target_id = ?) OR (target_type = 'group' AND target_id = (SELECT group_name FROM devices WHERE device_uid = ?)))"
|
||||
)
|
||||
.bind(device_uid)
|
||||
.bind(device_uid)
|
||||
.fetch_all(db).await
|
||||
{
|
||||
let rules: Vec<csm_protocol::ClipboardRule> = rows.iter().map(|r| csm_protocol::ClipboardRule {
|
||||
id: r.get::<i64, _>("id"),
|
||||
rule_type: r.get::<String, _>("rule_type"),
|
||||
direction: r.get::<String, _>("direction"),
|
||||
source_process: r.get::<Option<String>, _>("source_process"),
|
||||
target_process: r.get::<Option<String>, _>("target_process"),
|
||||
content_pattern: r.get::<Option<String>, _>("content_pattern"),
|
||||
}).collect();
|
||||
if !rules.is_empty() {
|
||||
let payload = csm_protocol::ClipboardRulesPayload { rules };
|
||||
if let Ok(frame) = Frame::new_json(MessageType::ClipboardRules, &payload) {
|
||||
clients.send_to(device_uid, frame.encode()).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Disk encryption config — push default reporting interval (no dedicated config table)
|
||||
{
|
||||
let config = csm_protocol::DiskEncryptionConfigPayload {
|
||||
enabled: true,
|
||||
report_interval_secs: 3600,
|
||||
};
|
||||
if let Ok(frame) = Frame::new_json(MessageType::DiskEncryptionConfig, &config) {
|
||||
clients.send_to(device_uid, frame.encode()).await;
|
||||
}
|
||||
}
|
||||
|
||||
// Push plugin enable/disable state — disable any plugins that admin has turned off
|
||||
if let Ok(rows) = sqlx::query(
|
||||
"SELECT plugin_name FROM plugin_state WHERE enabled = 0"
|
||||
)
|
||||
.fetch_all(db).await
|
||||
{
|
||||
for row in &rows {
|
||||
let plugin_name: String = row.get("plugin_name");
|
||||
let payload = csm_protocol::PluginControlPayload {
|
||||
plugin_name: plugin_name.clone(),
|
||||
enabled: false,
|
||||
};
|
||||
if let Ok(frame) = Frame::new_json(MessageType::PluginDisable, &payload) {
|
||||
clients.send_to(device_uid, frame.encode()).await;
|
||||
}
|
||||
debug!("Pushed PluginDisable for {} to device {}", plugin_name, device_uid);
|
||||
}
|
||||
}
|
||||
|
||||
info!("Pushed all plugin configs to newly registered device {}", device_uid);
|
||||
}
|
||||
|
||||
@@ -283,6 +348,15 @@ pub async fn start_tcp_server(addr: String, state: AppState) -> anyhow::Result<(
|
||||
|
||||
loop {
|
||||
let (stream, peer_addr) = listener.accept().await?;
|
||||
|
||||
// Enforce maximum connection limit
|
||||
let current_count = state.clients.count().await;
|
||||
if current_count >= MAX_CONNECTIONS {
|
||||
warn!("Rejecting connection from {}: limit reached ({}/{})", peer_addr, current_count, MAX_CONNECTIONS);
|
||||
drop(stream);
|
||||
continue;
|
||||
}
|
||||
|
||||
let state = state.clone();
|
||||
let acceptor = tls_acceptor.clone();
|
||||
|
||||
@@ -361,20 +435,6 @@ async fn cleanup_on_disconnect(state: &AppState, device_uid: &Option<String>) {
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute HMAC-SHA256 for heartbeat verification.
|
||||
/// Format: HMAC-SHA256(device_secret, "{device_uid}\n{timestamp}") → hex-encoded
|
||||
fn compute_hmac(secret: &str, device_uid: &str, timestamp: &str) -> String {
|
||||
type HmacSha256 = Hmac<Sha256>;
|
||||
|
||||
let message = format!("{}\n{}", device_uid, timestamp);
|
||||
let mut mac = match HmacSha256::new_from_slice(secret.as_bytes()) {
|
||||
Ok(m) => m,
|
||||
Err(_) => return String::new(),
|
||||
};
|
||||
mac.update(message.as_bytes());
|
||||
hex::encode(mac.finalize().into_bytes())
|
||||
}
|
||||
|
||||
/// Verify that a frame sender is a registered device and that the claimed device_uid
|
||||
/// matches the one registered on this connection. Returns true if valid.
|
||||
fn verify_device_uid(device_uid: &Option<String>, msg_type: &str, claimed_uid: &str) -> bool {
|
||||
@@ -392,11 +452,13 @@ fn verify_device_uid(device_uid: &Option<String>, msg_type: &str, claimed_uid: &
|
||||
}
|
||||
|
||||
/// Process a single decoded frame. Shared by both plaintext and TLS handlers.
|
||||
/// `hmac_fail_count` tracks consecutive HMAC failures; caller checks it for disconnect threshold.
|
||||
async fn process_frame(
|
||||
frame: Frame,
|
||||
state: &AppState,
|
||||
device_uid: &mut Option<String>,
|
||||
tx: &Arc<tokio::sync::mpsc::Sender<Vec<u8>>>,
|
||||
hmac_fail_count: &Arc<AtomicU32>,
|
||||
) -> anyhow::Result<()> {
|
||||
match frame.msg_type {
|
||||
MessageType::Register => {
|
||||
@@ -438,7 +500,7 @@ async fn process_frame(
|
||||
"INSERT INTO devices (device_uid, hostname, ip_address, mac_address, os_version, device_secret, status) \
|
||||
VALUES (?, ?, '0.0.0.0', ?, ?, ?, 'online') \
|
||||
ON CONFLICT(device_uid) DO UPDATE SET hostname=excluded.hostname, os_version=excluded.os_version, \
|
||||
mac_address=excluded.mac_address, status='online'"
|
||||
mac_address=excluded.mac_address, status='online', last_heartbeat=datetime('now')"
|
||||
)
|
||||
.bind(&req.device_uid)
|
||||
.bind(&req.hostname)
|
||||
@@ -493,6 +555,7 @@ async fn process_frame(
|
||||
if !secret.is_empty() {
|
||||
if heartbeat.hmac.is_empty() {
|
||||
warn!("Heartbeat missing HMAC for device {}", heartbeat.device_uid);
|
||||
hmac_fail_count.fetch_add(1, Ordering::Relaxed);
|
||||
return Ok(());
|
||||
}
|
||||
// Constant-time HMAC verification using hmac::Mac::verify_slice
|
||||
@@ -502,9 +565,11 @@ async fn process_frame(
|
||||
mac.update(message.as_bytes());
|
||||
let provided_bytes = hex::decode(&heartbeat.hmac).unwrap_or_default();
|
||||
if mac.verify_slice(&provided_bytes).is_err() {
|
||||
warn!("Heartbeat HMAC mismatch for device {}", heartbeat.device_uid);
|
||||
warn!("Heartbeat HMAC mismatch for device {} (fail #{})", heartbeat.device_uid, hmac_fail_count.fetch_add(1, Ordering::Relaxed) + 1);
|
||||
return Ok(());
|
||||
}
|
||||
// Successful verification — reset failure counter
|
||||
hmac_fail_count.store(0, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -600,7 +665,8 @@ async fn process_frame(
|
||||
total_active_minutes = excluded.total_active_minutes, \
|
||||
total_idle_minutes = excluded.total_idle_minutes, \
|
||||
first_active_at = excluded.first_active_at, \
|
||||
last_active_at = excluded.last_active_at"
|
||||
last_active_at = excluded.last_active_at, \
|
||||
updated_at = datetime('now')"
|
||||
)
|
||||
.bind(&report.device_uid)
|
||||
.bind(&report.date)
|
||||
@@ -627,7 +693,8 @@ async fn process_frame(
|
||||
"INSERT INTO app_usage_daily (device_uid, date, app_name, usage_minutes) \
|
||||
VALUES (?, ?, ?, ?) \
|
||||
ON CONFLICT(device_uid, date, app_name) DO UPDATE SET \
|
||||
usage_minutes = MAX(usage_minutes, excluded.usage_minutes)"
|
||||
usage_minutes = MAX(usage_minutes, excluded.usage_minutes), \
|
||||
updated_at = datetime('now')"
|
||||
)
|
||||
.bind(&report.device_uid)
|
||||
.bind(&report.date)
|
||||
@@ -716,6 +783,150 @@ async fn process_frame(
|
||||
debug!("Web access log: {} {} {}", entry.device_uid, entry.action, entry.url);
|
||||
}
|
||||
|
||||
MessageType::DiskEncryptionStatus => {
|
||||
let payload: csm_protocol::DiskEncryptionStatusPayload = frame.decode_payload()
|
||||
.map_err(|e| anyhow::anyhow!("Invalid disk encryption status: {}", e))?;
|
||||
|
||||
if !verify_device_uid(device_uid, "DiskEncryptionStatus", &payload.device_uid) {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
for drive in &payload.drives {
|
||||
sqlx::query(
|
||||
"INSERT INTO disk_encryption_status (device_uid, drive_letter, volume_name, encryption_method, protection_status, encryption_percentage, lock_status, reported_at, updated_at) \
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, datetime('now'), datetime('now')) \
|
||||
ON CONFLICT(device_uid, drive_letter) DO UPDATE SET \
|
||||
volume_name=excluded.volume_name, encryption_method=excluded.encryption_method, \
|
||||
protection_status=excluded.protection_status, encryption_percentage=excluded.encryption_percentage, \
|
||||
lock_status=excluded.lock_status, updated_at=datetime('now')"
|
||||
)
|
||||
.bind(&payload.device_uid)
|
||||
.bind(&drive.drive_letter)
|
||||
.bind(&drive.volume_name)
|
||||
.bind(&drive.encryption_method)
|
||||
.bind(&drive.protection_status)
|
||||
.bind(drive.encryption_percentage)
|
||||
.bind(&drive.lock_status)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("DB error inserting disk encryption status: {}", e))?;
|
||||
|
||||
// Generate alert for unencrypted drives
|
||||
if drive.protection_status == "Off" {
|
||||
sqlx::query(
|
||||
"INSERT INTO encryption_alerts (device_uid, drive_letter, alert_type, status) \
|
||||
VALUES (?, ?, 'not_encrypted', 'open') \
|
||||
ON CONFLICT(device_uid, drive_letter, alert_type, status) DO NOTHING"
|
||||
)
|
||||
.bind(&payload.device_uid)
|
||||
.bind(&drive.drive_letter)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
.ok();
|
||||
}
|
||||
}
|
||||
|
||||
info!("Disk encryption status reported: {} ({} drives)", payload.device_uid, payload.drives.len());
|
||||
|
||||
state.ws_hub.broadcast(serde_json::json!({
|
||||
"type": "disk_encryption_status",
|
||||
"device_uid": payload.device_uid,
|
||||
"drive_count": payload.drives.len()
|
||||
}).to_string()).await;
|
||||
}
|
||||
|
||||
MessageType::PrintEvent => {
|
||||
let event: csm_protocol::PrintEventPayload = frame.decode_payload()
|
||||
.map_err(|e| anyhow::anyhow!("Invalid print event: {}", e))?;
|
||||
|
||||
if !verify_device_uid(device_uid, "PrintEvent", &event.device_uid) {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
sqlx::query(
|
||||
"INSERT INTO print_events (device_uid, document_name, printer_name, pages, copies, user_name, file_size_bytes, timestamp) \
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)"
|
||||
)
|
||||
.bind(&event.device_uid)
|
||||
.bind(&event.document_name)
|
||||
.bind(&event.printer_name)
|
||||
.bind(event.pages)
|
||||
.bind(event.copies)
|
||||
.bind(&event.user_name)
|
||||
.bind(event.file_size_bytes)
|
||||
.bind(&event.timestamp)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("DB error inserting print event: {}", e))?;
|
||||
|
||||
debug!("Print event: {} doc={:?} printer={:?} pages={:?}",
|
||||
event.device_uid, event.document_name, event.printer_name, event.pages);
|
||||
|
||||
state.ws_hub.broadcast(serde_json::json!({
|
||||
"type": "print_event",
|
||||
"device_uid": event.device_uid,
|
||||
"document_name": event.document_name,
|
||||
"printer_name": event.printer_name
|
||||
}).to_string()).await;
|
||||
}
|
||||
|
||||
MessageType::ClipboardViolation => {
|
||||
let violation: csm_protocol::ClipboardViolationPayload = frame.decode_payload()
|
||||
.map_err(|e| anyhow::anyhow!("Invalid clipboard violation: {}", e))?;
|
||||
|
||||
if !verify_device_uid(device_uid, "ClipboardViolation", &violation.device_uid) {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
sqlx::query(
|
||||
"INSERT INTO clipboard_violations (device_uid, source_process, target_process, content_preview, action_taken, timestamp) \
|
||||
VALUES (?, ?, ?, ?, ?, ?)"
|
||||
)
|
||||
.bind(&violation.device_uid)
|
||||
.bind(&violation.source_process)
|
||||
.bind(&violation.target_process)
|
||||
.bind(&violation.content_preview)
|
||||
.bind(&violation.action_taken)
|
||||
.bind(&violation.timestamp)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("DB error inserting clipboard violation: {}", e))?;
|
||||
|
||||
debug!("Clipboard violation: {} action={}", violation.device_uid, violation.action_taken);
|
||||
|
||||
state.ws_hub.broadcast(serde_json::json!({
|
||||
"type": "clipboard_violation",
|
||||
"device_uid": violation.device_uid,
|
||||
"source_process": violation.source_process,
|
||||
"action_taken": violation.action_taken
|
||||
}).to_string()).await;
|
||||
}
|
||||
|
||||
MessageType::PopupBlockStats => {
|
||||
let stats: csm_protocol::PopupBlockStatsPayload = frame.decode_payload()
|
||||
.map_err(|e| anyhow::anyhow!("Invalid popup block stats: {}", e))?;
|
||||
|
||||
if !verify_device_uid(device_uid, "PopupBlockStats", &stats.device_uid) {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
for rule_stat in &stats.rule_stats {
|
||||
sqlx::query(
|
||||
"INSERT INTO popup_block_stats (device_uid, rule_id, blocked_count, period_secs, reported_at) \
|
||||
VALUES (?, ?, ?, ?, datetime('now'))"
|
||||
)
|
||||
.bind(&stats.device_uid)
|
||||
.bind(rule_stat.rule_id)
|
||||
.bind(rule_stat.hits as i32)
|
||||
.bind(stats.period_secs as i32)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
.ok();
|
||||
}
|
||||
|
||||
debug!("Popup block stats: {} blocked {} windows in {}s", stats.device_uid, stats.blocked_count, stats.period_secs);
|
||||
}
|
||||
|
||||
_ => {
|
||||
debug!("Unhandled message type: {:?}", frame.msg_type);
|
||||
}
|
||||
@@ -728,7 +939,6 @@ async fn process_frame(
|
||||
async fn handle_client(stream: TcpStream, state: AppState) -> anyhow::Result<()> {
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
|
||||
// Set read timeout to detect stale connections
|
||||
let _ = stream.set_nodelay(true);
|
||||
|
||||
let (mut reader, mut writer) = stream.into_split();
|
||||
@@ -739,6 +949,7 @@ async fn handle_client(stream: TcpStream, state: AppState) -> anyhow::Result<()>
|
||||
let mut read_buf = Vec::with_capacity(65536);
|
||||
let mut device_uid: Option<String> = None;
|
||||
let mut rate_limiter = RateLimiter::new();
|
||||
let hmac_fail_count = Arc::new(AtomicU32::new(0));
|
||||
|
||||
// Writer task: forwards messages from channel to TCP stream
|
||||
let write_task = tokio::spawn(async move {
|
||||
@@ -749,12 +960,22 @@ async fn handle_client(stream: TcpStream, state: AppState) -> anyhow::Result<()>
|
||||
}
|
||||
});
|
||||
|
||||
// Reader loop
|
||||
// Reader loop with idle timeout
|
||||
'reader: loop {
|
||||
let n = reader.read(&mut buffer).await?;
|
||||
if n == 0 {
|
||||
break; // Connection closed
|
||||
}
|
||||
let read_result = tokio::time::timeout(
|
||||
std::time::Duration::from_secs(IDLE_TIMEOUT_SECS),
|
||||
reader.read(&mut buffer),
|
||||
).await;
|
||||
|
||||
let n = match read_result {
|
||||
Ok(Ok(0)) => break, // Connection closed
|
||||
Ok(Ok(n)) => n,
|
||||
Ok(Err(e)) => return Err(e.into()),
|
||||
Err(_) => {
|
||||
warn!("Idle timeout for device {:?}, disconnecting", device_uid);
|
||||
break;
|
||||
}
|
||||
};
|
||||
read_buf.extend_from_slice(&buffer[..n]);
|
||||
|
||||
// Guard against unbounded buffer growth
|
||||
@@ -766,7 +987,6 @@ async fn handle_client(stream: TcpStream, state: AppState) -> anyhow::Result<()>
|
||||
// Process complete frames
|
||||
while let Some(frame) = Frame::decode(&read_buf)? {
|
||||
let frame_size = frame.encoded_size();
|
||||
// Remove consumed bytes without reallocating
|
||||
read_buf.drain(..frame_size);
|
||||
|
||||
// Rate limit check
|
||||
@@ -781,9 +1001,15 @@ async fn handle_client(stream: TcpStream, state: AppState) -> anyhow::Result<()>
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Err(e) = process_frame(frame, &state, &mut device_uid, &tx).await {
|
||||
if let Err(e) = process_frame(frame, &state, &mut device_uid, &tx, &hmac_fail_count).await {
|
||||
warn!("Frame processing error: {}", e);
|
||||
}
|
||||
|
||||
// Disconnect if too many consecutive HMAC failures
|
||||
if hmac_fail_count.load(Ordering::Relaxed) >= MAX_HMAC_FAILURES {
|
||||
warn!("Too many HMAC failures for device {:?}, disconnecting", device_uid);
|
||||
break 'reader;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -807,6 +1033,7 @@ async fn handle_client_tls(
|
||||
let mut read_buf = Vec::with_capacity(65536);
|
||||
let mut device_uid: Option<String> = None;
|
||||
let mut rate_limiter = RateLimiter::new();
|
||||
let hmac_fail_count = Arc::new(AtomicU32::new(0));
|
||||
|
||||
let write_task = tokio::spawn(async move {
|
||||
while let Some(data) = rx.recv().await {
|
||||
@@ -816,12 +1043,22 @@ async fn handle_client_tls(
|
||||
}
|
||||
});
|
||||
|
||||
// Reader loop — same logic as plaintext handler
|
||||
// Reader loop with idle timeout
|
||||
'reader: loop {
|
||||
let n = reader.read(&mut buffer).await?;
|
||||
if n == 0 {
|
||||
break;
|
||||
}
|
||||
let read_result = tokio::time::timeout(
|
||||
std::time::Duration::from_secs(IDLE_TIMEOUT_SECS),
|
||||
reader.read(&mut buffer),
|
||||
).await;
|
||||
|
||||
let n = match read_result {
|
||||
Ok(Ok(0)) => break,
|
||||
Ok(Ok(n)) => n,
|
||||
Ok(Err(e)) => return Err(e.into()),
|
||||
Err(_) => {
|
||||
warn!("Idle timeout for TLS device {:?}, disconnecting", device_uid);
|
||||
break;
|
||||
}
|
||||
};
|
||||
read_buf.extend_from_slice(&buffer[..n]);
|
||||
|
||||
if read_buf.len() > MAX_READ_BUF_SIZE {
|
||||
@@ -843,9 +1080,15 @@ async fn handle_client_tls(
|
||||
break 'reader;
|
||||
}
|
||||
|
||||
if let Err(e) = process_frame(frame, &state, &mut device_uid, &tx).await {
|
||||
if let Err(e) = process_frame(frame, &state, &mut device_uid, &tx, &hmac_fail_count).await {
|
||||
warn!("Frame processing error: {}", e);
|
||||
}
|
||||
|
||||
// Disconnect if too many consecutive HMAC failures
|
||||
if hmac_fail_count.load(Ordering::Relaxed) >= MAX_HMAC_FAILURES {
|
||||
warn!("Too many HMAC failures for TLS device {:?}, disconnecting", device_uid);
|
||||
break 'reader;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user