feat: 初始化项目基础架构和核心功能

- 添加项目基础结构:Cargo.toml、.gitignore、设备UID和密钥文件
- 实现前端Vue3项目结构:路由、登录页面、设备管理页面
- 添加核心协议定义(crates/protocol):设备状态、资产、USB事件等
- 实现客户端监控模块:系统状态收集、资产收集
- 实现服务端基础API和插件系统
- 添加数据库迁移脚本:设备管理、资产跟踪、告警系统等
- 实现前端设备状态展示和基本交互
- 添加使用时长统计和水印功能插件
This commit is contained in:
iven
2026-04-05 00:57:51 +08:00
commit fd6fb5cca0
87 changed files with 19576 additions and 0 deletions

48
crates/client/Cargo.toml Normal file
View File

@@ -0,0 +1,48 @@
# Manifest for the endpoint client binary ("csm-client").
[package]
name = "csm-client"
# Version and edition are inherited from the workspace root.
version.workspace = true
edition.workspace = true
[dependencies]
# Shared wire-protocol types (Frame, MessageType, payload structs).
csm-protocol = { path = "../protocol" }
tokio = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
tracing = { workspace = true }
tracing-subscriber = { workspace = true }
anyhow = { workspace = true }
uuid = { workspace = true }
chrono = { workspace = true }
thiserror = { workspace = true }
# System metrics (CPU, memory, processes). 0.30 reports memory in bytes.
sysinfo = "0.30"
# TLS client stack: rustls with the "ring" crypto provider + tokio adapter.
tokio-rustls = "0.26"
rustls = { version = "0.23", default-features = false, features = ["ring", "std"] }
rustls-pki-types = "1"
# Bundled Mozilla root certificates used as the default trust store.
webpki-roots = "0.26"
rustls-pemfile = "2"
# HMAC-SHA256 heartbeat authentication (hmac + sha2 + hex encoding).
hmac = "0.12"
sha2 = "0.10"
hex = "0.4"
# Windows-only: Win32 API bindings used by the monitoring/blocking plugins.
[target.'cfg(target_os = "windows")'.dependencies]
windows = { version = "0.54", features = [
"Win32_Foundation",
"Win32_System_SystemInformation",
"Win32_System_Registry",
"Win32_System_IO",
"Win32_Security",
"Win32_NetworkManagement_IpHelper",
"Win32_Storage_FileSystem",
"Win32_UI_WindowsAndMessaging",
"Win32_UI_Input_KeyboardAndMouse",
"Win32_System_Threading",
"Win32_System_Diagnostics_ToolHelp",
"Win32_System_LibraryLoader",
"Win32_System_Performance",
"Win32_Graphics_Gdi",
] }
windows-service = "0.7"
hostname = "0.4"
# Non-Windows builds still need hostname for the registration request.
[target.'cfg(not(target_os = "windows"))'.dependencies]
hostname = "0.4"

View File

@@ -0,0 +1,54 @@
use csm_protocol::{Frame, MessageType, HardwareAsset};
use std::time::Duration;
use tokio::sync::mpsc::Sender;
use tracing::{info, error};
use sysinfo::System;
/// Periodically collect hardware asset information and forward it as frames.
///
/// Collects once immediately at startup, then repeats once per day.
/// Collection failures are logged and never terminate the loop.
pub async fn start_collecting(tx: Sender<Frame>, device_uid: String) {
    let period = Duration::from_secs(86400); // Once per day
    let mut first_run = true;
    loop {
        if let Err(e) = collect_and_send(&tx, &device_uid).await {
            if first_run {
                error!("Initial asset collection failed: {}", e);
            } else {
                error!("Asset collection failed: {}", e);
            }
        }
        first_run = false;
        tokio::time::sleep(period).await;
    }
}
/// Gather the current hardware inventory and push it onto the outbound channel.
async fn collect_and_send(tx: &Sender<Frame>, device_uid: &str) -> anyhow::Result<()> {
    let report = collect_hardware(device_uid)?;
    let frame = Frame::new_json(MessageType::AssetReport, &report)?;
    if let Err(e) = tx.send(frame).await {
        return Err(anyhow::anyhow!("Channel send failed: {}", e));
    }
    info!("Asset report sent for {}", device_uid);
    Ok(())
}
/// Snapshot basic hardware facts via sysinfo.
///
/// Disk, GPU, motherboard and serial fields are placeholders until
/// platform-specific collection is implemented.
fn collect_hardware(device_uid: &str) -> anyhow::Result<HardwareAsset> {
    let mut system = System::new_all();
    system.refresh_all();
    let cpus = system.cpus();
    let cpu_model = match cpus.first() {
        Some(cpu) => cpu.brand().to_string(),
        None => "Unknown".to_string(),
    };
    let cpu_cores = cpus.len() as u32;
    // sysinfo 0.30 reports memory in bytes; convert to MB.
    let memory_total_mb = system.total_memory() / (1024 * 1024);
    Ok(HardwareAsset {
        device_uid: device_uid.to_string(),
        cpu_model,
        cpu_cores,
        memory_total_mb,
        disk_model: "Unknown".to_string(),
        disk_total_mb: 0,
        gpu_model: None,
        motherboard: None,
        serial_number: None,
    })
}

206
crates/client/src/main.rs Normal file
View File

@@ -0,0 +1,206 @@
use anyhow::Result;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::Duration;
use tracing::{info, error, warn};
use csm_protocol::{Frame, ClientConfig, UsbPolicyPayload};
mod monitor;
mod asset;
mod usb;
mod network;
mod watermark;
mod usage_timer;
mod usb_audit;
mod popup_blocker;
mod software_blocker;
mod web_filter;
/// Shared shutdown flag: set by the Ctrl+C handler spawned in `run`,
/// checked at the top of the reconnect loop to exit gracefully.
static SHUTDOWN: AtomicBool = AtomicBool::new(false);
/// Runtime client configuration and identity, assembled once in `main`.
struct ClientState {
    /// Stable unique identifier for this device (persisted in device_uid.txt).
    device_uid: String,
    /// Server "host:port"; from CSM_SERVER, defaults to 127.0.0.1:9999.
    server_addr: String,
    /// Protocol-level client configuration (e.g. heartbeat interval).
    config: ClientConfig,
    /// HMAC key issued by the server at registration; None until received.
    device_secret: Option<String>,
    /// Registration token from CSM_REGISTRATION_TOKEN; may be empty.
    registration_token: String,
    /// Whether to use TLS when connecting to the server
    use_tls: bool,
}
#[tokio::main]
async fn main() -> Result<()> {
    // Log filtering is fixed to this crate at info level.
    tracing_subscriber::fmt()
        .with_env_filter("csm_client=info")
        .init();
    info!("CSM Client starting...");
    // Establish a stable identity for this machine before anything else.
    let device_uid = load_or_create_device_uid()?;
    info!("Device UID: {}", device_uid);
    // Resolve connection settings from the environment.
    let server_addr = match std::env::var("CSM_SERVER") {
        Ok(addr) => addr,
        Err(_) => "127.0.0.1:9999".to_string(),
    };
    let use_tls = matches!(std::env::var("CSM_USE_TLS").as_deref(), Ok("true"));
    let registration_token = std::env::var("CSM_REGISTRATION_TOKEN").unwrap_or_default();
    let state = ClientState {
        device_uid,
        server_addr,
        config: ClientConfig::default(),
        device_secret: load_device_secret(),
        registration_token,
        use_tls,
    };
    // TODO: Register as Windows Service on Windows
    // For development, run directly
    run(state).await
}
/// Main client run loop: wires up plugin config channels, spawns the
/// monitoring and plugin tasks, then keeps a server connection alive,
/// reconnecting with exponential backoff on errors.
async fn run(state: ClientState) -> Result<()> {
    let (data_tx, mut data_rx) = tokio::sync::mpsc::channel::<Frame>(1024);
    // Spawn Ctrl+C handler
    tokio::spawn(async {
        tokio::signal::ctrl_c().await.ok();
        info!("Received Ctrl+C, initiating graceful shutdown...");
        SHUTDOWN.store(true, Ordering::SeqCst);
    });
    // Create plugin config channels (watch channels: each plugin task always
    // observes the most recently pushed config)
    let (watermark_tx, watermark_rx) = tokio::sync::watch::channel(None);
    let (web_filter_tx, web_filter_rx) = tokio::sync::watch::channel(web_filter::WebFilterConfig::default());
    let (software_blocker_tx, software_blocker_rx) = tokio::sync::watch::channel(software_blocker::SoftwareBlockerConfig::default());
    let (popup_blocker_tx, popup_blocker_rx) = tokio::sync::watch::channel(popup_blocker::PopupBlockerConfig::default());
    let (usb_audit_tx, usb_audit_rx) = tokio::sync::watch::channel(usb_audit::UsbAuditConfig::default());
    let (usage_timer_tx, usage_timer_rx) = tokio::sync::watch::channel(usage_timer::UsageConfig::default());
    let (usb_policy_tx, usb_policy_rx) = tokio::sync::watch::channel(None::<UsbPolicyPayload>);
    let plugins = network::PluginChannels {
        watermark_tx,
        web_filter_tx,
        software_blocker_tx,
        popup_blocker_tx,
        usb_audit_tx,
        usage_timer_tx,
        usb_policy_tx,
    };
    // Spawn core monitoring tasks
    let monitor_tx = data_tx.clone();
    let uid = state.device_uid.clone();
    tokio::spawn(async move {
        monitor::start_collecting(monitor_tx, uid).await;
    });
    let asset_tx = data_tx.clone();
    let uid = state.device_uid.clone();
    tokio::spawn(async move {
        asset::start_collecting(asset_tx, uid).await;
    });
    let usb_tx = data_tx.clone();
    let uid = state.device_uid.clone();
    tokio::spawn(async move {
        usb::start_monitoring(usb_tx, uid, usb_policy_rx).await;
    });
    // Spawn plugin tasks
    tokio::spawn(async move {
        watermark::start(watermark_rx).await;
    });
    let usage_data_tx = data_tx.clone();
    let usage_uid = state.device_uid.clone();
    tokio::spawn(async move {
        usage_timer::start(usage_timer_rx, usage_data_tx, usage_uid).await;
    });
    let audit_data_tx = data_tx.clone();
    let audit_uid = state.device_uid.clone();
    tokio::spawn(async move {
        usb_audit::start(usb_audit_rx, audit_data_tx, audit_uid).await;
    });
    tokio::spawn(async move {
        popup_blocker::start(popup_blocker_rx).await;
    });
    let sw_data_tx = data_tx.clone();
    let sw_uid = state.device_uid.clone();
    tokio::spawn(async move {
        software_blocker::start(software_blocker_rx, sw_data_tx, sw_uid).await;
    });
    tokio::spawn(async move {
        web_filter::start(web_filter_rx).await;
    });
    // Connect to server with reconnect
    let mut backoff = Duration::from_secs(1);
    let max_backoff = Duration::from_secs(60);
    loop {
        if SHUTDOWN.load(Ordering::SeqCst) {
            info!("Shutting down gracefully...");
            break Ok(());
        }
        match network::connect_and_run(&state, &mut data_rx, &plugins).await {
            Ok(()) => {
                warn!("Disconnected from server, reconnecting...");
                // A clean disconnect means the connection was healthy: reset
                // the error backoff so a later failure starts small again.
                backoff = Duration::from_secs(1);
                // Use a short fixed delay for clean disconnects (server-initiated),
                // but don't reset to zero to prevent connection storms
                tokio::time::sleep(Duration::from_secs(2)).await;
            }
            Err(e) => {
                error!("Connection error: {}, reconnecting...", e);
                tokio::time::sleep(backoff).await;
                backoff = (backoff * 2).min(max_backoff);
            }
        }
        // Drain stale frames that accumulated during disconnection so we don't
        // flood the server with outdated reports on reconnect.
        let mut count = 0usize;
        while data_rx.try_recv().is_ok() {
            count += 1;
        }
        if count > 0 {
            warn!("Drained {} stale frames from channel", count);
        }
    }
}
fn load_or_create_device_uid() -> Result<String> {
// In production, store in Windows Credential Store or local config
// For now, use a simple file
let uid_file = "device_uid.txt";
if std::path::Path::new(uid_file).exists() {
let uid = std::fs::read_to_string(uid_file)?;
Ok(uid.trim().to_string())
} else {
let uid = uuid::Uuid::new_v4().to_string();
std::fs::write(uid_file, &uid)?;
Ok(uid)
}
}
/// Load persisted device_secret from disk (if available).
///
/// Returns None when the file is missing, unreadable, or contains only
/// whitespace; otherwise the trimmed secret.
pub fn load_device_secret() -> Option<String> {
    match std::fs::read_to_string("device_secret.txt") {
        Ok(raw) => {
            let trimmed = raw.trim();
            if trimmed.is_empty() {
                None
            } else {
                Some(trimmed.to_string())
            }
        }
        Err(_) => None,
    }
}
/// Persist device_secret to disk.
///
/// Best-effort: a write failure is logged but not propagated, since the
/// secret will be re-issued on the next successful registration.
pub fn save_device_secret(secret: &str) {
    match std::fs::write("device_secret.txt", secret) {
        Ok(()) => {}
        Err(e) => warn!("Failed to persist device_secret: {}", e),
    }
}

View File

@@ -0,0 +1,86 @@
use anyhow::Result;
use csm_protocol::{Frame, MessageType, DeviceStatus, ProcessInfo};
use std::time::Duration;
use tokio::sync::mpsc::Sender;
use tracing::{info, error, debug};
use sysinfo::System;
/// Collect system status once a minute and forward it to the network task.
///
/// sysinfo sampling blocks (it sleeps ~200ms to measure CPU usage), so it is
/// moved onto the blocking thread pool. The loop exits only when the
/// receiving end of `tx` has been dropped.
pub async fn start_collecting(tx: Sender<Frame>, device_uid: String) {
    let interval = Duration::from_secs(60);
    loop {
        // Run blocking sysinfo collection on a dedicated thread
        let uid_clone = device_uid.clone();
        let result = tokio::task::spawn_blocking(move || {
            collect_system_status(&uid_clone)
        }).await;
        match result {
            Ok(Ok(status)) => {
                if let Ok(frame) = Frame::new_json(MessageType::StatusReport, &status) {
                    debug!("Sending status report: cpu={:.1}%, mem={:.1}%", status.cpu_usage, status.memory_usage);
                    // A closed channel means the network side is gone for good.
                    if tx.send(frame).await.is_err() {
                        info!("Monitor channel closed, exiting");
                        break;
                    }
                }
            }
            Ok(Err(e)) => {
                // Collection itself failed; skip this cycle and retry later.
                error!("Failed to collect system status: {}", e);
            }
            Err(e) => {
                // The spawn_blocking task panicked or was cancelled.
                error!("Monitor task join error: {}", e);
            }
        }
        tokio::time::sleep(interval).await;
    }
}
/// Take one blocking snapshot of CPU, memory, and top-process usage.
///
/// CPU usage requires two refreshes separated by a short delay (sysinfo
/// computes usage as a delta between refreshes). Runs on a blocking thread —
/// see `start_collecting`. Disk and network fields are placeholders for now.
fn collect_system_status(device_uid: &str) -> Result<DeviceStatus> {
    let mut sys = System::new_all();
    sys.refresh_all();
    // Brief wait for CPU usage to stabilize
    std::thread::sleep(Duration::from_millis(200));
    sys.refresh_all();
    let cpu_usage = sys.global_cpu_info().cpu_usage() as f64;
    let total_memory = sys.total_memory() / 1024 / 1024; // Convert bytes to MB (sysinfo 0.30 returns bytes)
    let used_memory = sys.used_memory() / 1024 / 1024;
    // Guard against division by zero on a pathological zero-memory report.
    let memory_usage = if total_memory > 0 {
        (used_memory as f64 / total_memory as f64) * 100.0
    } else {
        0.0
    };
    // Top processes by CPU
    let mut processes: Vec<ProcessInfo> = sys.processes()
        .iter()
        .map(|(_, p)| {
            ProcessInfo {
                name: p.name().to_string(),
                pid: p.pid().as_u32(),
                cpu_usage: p.cpu_usage() as f64,
                memory_mb: p.memory() / 1024 / 1024, // bytes to MB (sysinfo 0.30)
            }
        })
        .collect();
    // NaN-safe descending sort, then keep only the 10 heaviest processes.
    processes.sort_by(|a, b| b.cpu_usage.partial_cmp(&a.cpu_usage).unwrap_or(std::cmp::Ordering::Equal));
    processes.truncate(10);
    Ok(DeviceStatus {
        device_uid: device_uid.to_string(),
        cpu_usage,
        memory_usage,
        memory_total_mb: total_memory as u64,
        disk_usage: 0.0, // TODO: implement disk usage via Windows API
        disk_total_mb: 0,
        network_rx_rate: 0,
        network_tx_rate: 0,
        running_procs: sys.processes().len() as u32,
        top_processes: processes,
    })
}

View File

@@ -0,0 +1,380 @@
use anyhow::Result;
use std::time::Duration;
use tokio::net::TcpStream;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tracing::{info, debug, warn};
use csm_protocol::{Frame, MessageType, RegisterRequest, RegisterResponse, HeartbeatPayload, WatermarkConfigPayload, UsbPolicyPayload};
use hmac::{Hmac, Mac};
use sha2::Sha256;
use crate::ClientState;
/// Holds senders for all plugin config channels.
///
/// Each field is the sending half of a `watch` channel whose receiver lives
/// in the corresponding plugin task spawned in `run`; pushing a value here
/// replaces whatever config that plugin currently observes.
pub struct PluginChannels {
    /// `None` clears the watermark overlay; `Some` applies a new config.
    pub watermark_tx: tokio::sync::watch::Sender<Option<WatermarkConfigPayload>>,
    pub web_filter_tx: tokio::sync::watch::Sender<crate::web_filter::WebFilterConfig>,
    pub software_blocker_tx: tokio::sync::watch::Sender<crate::software_blocker::SoftwareBlockerConfig>,
    pub popup_blocker_tx: tokio::sync::watch::Sender<crate::popup_blocker::PopupBlockerConfig>,
    pub usb_audit_tx: tokio::sync::watch::Sender<crate::usb_audit::UsbAuditConfig>,
    pub usage_timer_tx: tokio::sync::watch::Sender<crate::usage_timer::UsageConfig>,
    /// `None` clears the USB policy; `Some` applies a new one.
    pub usb_policy_tx: tokio::sync::watch::Sender<Option<UsbPolicyPayload>>,
}
/// Connect to the server and run the main communication loop, optionally
/// upgrading the TCP socket to TLS first (per `state.use_tls`).
pub async fn connect_and_run(
    state: &ClientState,
    data_rx: &mut tokio::sync::mpsc::Receiver<Frame>,
    plugins: &PluginChannels,
) -> Result<()> {
    let socket = TcpStream::connect(&state.server_addr).await?;
    info!("TCP connected to {}", state.server_addr);
    if !state.use_tls {
        return run_comm_loop(socket, state, data_rx, plugins).await;
    }
    let secured = wrap_tls(socket, &state.server_addr).await?;
    run_comm_loop(secured, state, data_rx, plugins).await
}
/// Wrap a TCP stream with TLS.
/// Supports custom CA certificate via CSM_TLS_CA_CERT env var (path to PEM file).
/// Supports skipping verification via CSM_TLS_SKIP_VERIFY=true (development only).
async fn wrap_tls(stream: TcpStream, server_addr: &str) -> Result<tokio_rustls::client::TlsStream<TcpStream>> {
    let mut root_store = rustls::RootCertStore::empty();
    // Load custom CA certificate if specified
    if let Ok(ca_path) = std::env::var("CSM_TLS_CA_CERT") {
        let ca_pem = std::fs::read(&ca_path)
            .map_err(|e| anyhow::anyhow!("Failed to read CA cert {}: {}", ca_path, e))?;
        let certs = rustls_pemfile::certs(&mut &ca_pem[..])
            .collect::<Result<Vec<_>, _>>()
            .map_err(|e| anyhow::anyhow!("Failed to parse CA cert: {:?}", e))?;
        for cert in certs {
            root_store.add(cert)?;
        }
        info!("Loaded custom CA certificates from {}", ca_path);
    }
    // Always include system roots as fallback
    // (webpki-roots bundles Mozilla's root program rather than reading the OS store)
    root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
    // Either a normal verifying config, or the dangerous no-verify config.
    let config = if std::env::var("CSM_TLS_SKIP_VERIFY").as_deref() == Ok("true") {
        warn!("TLS certificate verification DISABLED — do not use in production!");
        rustls::ClientConfig::builder()
            .dangerous()
            .with_custom_certificate_verifier(std::sync::Arc::new(NoVerifier))
            .with_no_client_auth()
    } else {
        rustls::ClientConfig::builder()
            .with_root_certificates(root_store)
            .with_no_client_auth()
    };
    let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(config));
    // Extract hostname from server_addr (strip port)
    // NOTE(review): assumes "host:port" form — a bare IPv6 address containing
    // colons would be truncated at the first ':'; confirm callers' addr format.
    let domain = server_addr.split(':').next().unwrap_or("localhost").to_string();
    let server_name = rustls_pki_types::ServerName::try_from(domain.clone())
        .map_err(|e| anyhow::anyhow!("Invalid TLS server name '{}': {:?}", domain, e))?;
    let tls_stream = connector.connect(server_name, stream).await?;
    info!("TLS handshake completed with {}", domain);
    Ok(tls_stream)
}
/// A no-op certificate verifier for development use (CSM_TLS_SKIP_VERIFY=true).
///
/// Accepts any server certificate and any handshake signature — this gives
/// encryption without authentication and must never be enabled in production.
#[derive(Debug)]
struct NoVerifier;
impl rustls::client::danger::ServerCertVerifier for NoVerifier {
    // Accept every certificate chain without inspection.
    fn verify_server_cert(
        &self,
        _end_entity: &rustls_pki_types::CertificateDer,
        _intermediates: &[rustls_pki_types::CertificateDer],
        _server_name: &rustls_pki_types::ServerName,
        _ocsp_response: &[u8],
        _now: rustls_pki_types::UnixTime,
    ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
        Ok(rustls::client::danger::ServerCertVerified::assertion())
    }
    // Accept every TLS 1.2 handshake signature without checking it.
    fn verify_tls12_signature(
        &self,
        _message: &[u8],
        _cert: &rustls_pki_types::CertificateDer,
        _dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
    }
    // Accept every TLS 1.3 handshake signature without checking it.
    fn verify_tls13_signature(
        &self,
        _message: &[u8],
        _cert: &rustls_pki_types::CertificateDer,
        _dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
    }
    // Advertise the common signature schemes so handshakes can proceed.
    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
        vec![
            rustls::SignatureScheme::RSA_PKCS1_SHA256,
            rustls::SignatureScheme::RSA_PKCS1_SHA384,
            rustls::SignatureScheme::RSA_PKCS1_SHA512,
            rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
            rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
            rustls::SignatureScheme::RSA_PSS_SHA256,
            rustls::SignatureScheme::RSA_PSS_SHA384,
            rustls::SignatureScheme::RSA_PSS_SHA512,
            rustls::SignatureScheme::ED25519,
        ]
    }
}
/// Main communication loop over any read+write stream.
///
/// Sends the registration request, then multiplexes three event sources:
/// inbound server frames, queued outbound data frames, and a periodic
/// (optionally HMAC-authenticated) heartbeat. Returns `Err` on any I/O or
/// protocol failure so the caller can reconnect with backoff.
async fn run_comm_loop<S>(
    mut stream: S,
    state: &ClientState,
    data_rx: &mut tokio::sync::mpsc::Receiver<Frame>,
    plugins: &PluginChannels,
) -> Result<()>
where
    S: AsyncReadExt + AsyncWriteExt + Unpin,
{
    // Send registration
    let register = RegisterRequest {
        device_uid: state.device_uid.clone(),
        hostname: hostname::get()
            .map(|h| h.to_string_lossy().to_string())
            .unwrap_or_else(|_| "unknown".to_string()),
        registration_token: state.registration_token.clone(),
        os_version: get_os_info(),
        mac_address: None,
    };
    let frame = Frame::new_json(MessageType::Register, &register)?;
    stream.write_all(&frame.encode()).await?;
    info!("Registration request sent");
    // `buffer` is the per-read scratch space; `read_buf` accumulates bytes
    // until at least one complete frame can be decoded.
    let mut buffer = vec![0u8; 65536];
    let mut read_buf = Vec::with_capacity(65536);
    // Clamp heartbeat interval to sane range [5, 3600] to prevent CPU spin or effective disable
    let heartbeat_secs = state.config.heartbeat_interval_secs.clamp(5, 3600);
    let mut heartbeat_interval = tokio::time::interval(Duration::from_secs(heartbeat_secs));
    heartbeat_interval.tick().await; // Skip first tick
    // HMAC key — set after receiving RegisterResponse
    let mut device_secret: Option<String> = state.device_secret.clone();
    loop {
        tokio::select! {
            // Read from server
            result = stream.read(&mut buffer) => {
                let n = result?;
                if n == 0 {
                    return Err(anyhow::anyhow!("Server closed connection"));
                }
                read_buf.extend_from_slice(&buffer[..n]);
                // Guard against unbounded buffer growth from a malicious server
                if read_buf.len() > 1_048_576 {
                    return Err(anyhow::anyhow!("Read buffer exceeded 1MB, server may be malicious"));
                }
                // Process complete frames
                loop {
                    match Frame::decode(&read_buf)? {
                        Some(frame) => {
                            let consumed = frame.encoded_size();
                            read_buf.drain(..consumed);
                            // Capture device_secret from registration response
                            // (done here, before dispatch, so subsequent
                            // heartbeats in this loop are authenticated)
                            if frame.msg_type == MessageType::RegisterResponse {
                                if let Ok(resp) = frame.decode_payload::<RegisterResponse>() {
                                    device_secret = Some(resp.device_secret.clone());
                                    crate::save_device_secret(&resp.device_secret);
                                    info!("Device secret received and persisted, HMAC enabled for heartbeats");
                                }
                            }
                            handle_server_message(frame, plugins)?;
                        }
                        None => break, // Incomplete frame, wait for more data
                    }
                }
            }
            // Send queued data
            frame = data_rx.recv() => {
                let frame = frame.ok_or_else(|| anyhow::anyhow!("Channel closed"))?;
                stream.write_all(&frame.encode()).await?;
            }
            // Heartbeat
            _ = heartbeat_interval.tick() => {
                let timestamp = chrono::Utc::now().to_rfc3339();
                // Empty string when no secret yet — heartbeat goes out unauthenticated.
                let hmac_value = compute_hmac(device_secret.as_deref(), &state.device_uid, &timestamp);
                let heartbeat = HeartbeatPayload {
                    device_uid: state.device_uid.clone(),
                    timestamp,
                    hmac: hmac_value,
                };
                let frame = Frame::new_json(MessageType::Heartbeat, &heartbeat)?;
                stream.write_all(&frame.encode()).await?;
                debug!("Heartbeat sent (hmac={})", !heartbeat.hmac.is_empty());
            }
        }
    }
}
/// Dispatch a server-originated frame to the matching handler or plugin channel.
///
/// Returns `Err` when a payload for a known type fails to parse or a plugin
/// watch channel has no live receiver; unknown message types are logged at
/// debug level and ignored.
fn handle_server_message(frame: Frame, plugins: &PluginChannels) -> Result<()> {
    match frame.msg_type {
        MessageType::RegisterResponse => {
            // device_secret was already captured in run_comm_loop; here we
            // only validate and log the response.
            let resp: RegisterResponse = frame.decode_payload()
                .map_err(|e| anyhow::anyhow!("Invalid registration response: {}", e))?;
            info!("Registration accepted by server (server version: {})", resp.config.server_version);
        }
        MessageType::PolicyUpdate => {
            // Generic policy payload; currently only logged.
            let policy: serde_json::Value = frame.decode_payload()
                .map_err(|e| anyhow::anyhow!("Invalid policy update: {}", e))?;
            info!("Received policy update: {}", policy);
        }
        MessageType::ConfigUpdate => {
            info!("Received config update");
        }
        MessageType::TaskExecute => {
            warn!("Task execution requested (not yet implemented)");
        }
        MessageType::WatermarkConfig => {
            let config: WatermarkConfigPayload = frame.decode_payload()
                .map_err(|e| anyhow::anyhow!("Invalid watermark config: {}", e))?;
            info!("Received watermark config: enabled={}", config.enabled);
            plugins.watermark_tx.send(Some(config))?;
        }
        MessageType::UsbPolicyUpdate => {
            let policy: UsbPolicyPayload = frame.decode_payload()
                .map_err(|e| anyhow::anyhow!("Invalid USB policy: {}", e))?;
            info!("Received USB policy: type={}, enabled={}", policy.policy_type, policy.enabled);
            plugins.usb_policy_tx.send(Some(policy))?;
        }
        MessageType::WebFilterRuleUpdate => {
            let payload: serde_json::Value = frame.decode_payload()
                .map_err(|e| anyhow::anyhow!("Invalid web filter update: {}", e))?;
            info!("Received web filter rules update");
            // A missing/malformed "rules" field degrades to an empty list.
            let rules: Vec<crate::web_filter::WebFilterRule> = payload.get("rules")
                .and_then(|r| serde_json::from_value(r.clone()).ok())
                .unwrap_or_default();
            let config = crate::web_filter::WebFilterConfig { enabled: true, rules };
            plugins.web_filter_tx.send(config)?;
        }
        MessageType::SoftwareBlacklist => {
            let payload: serde_json::Value = frame.decode_payload()
                .map_err(|e| anyhow::anyhow!("Invalid software blacklist: {}", e))?;
            info!("Received software blacklist update");
            // A missing/malformed "blacklist" field degrades to an empty list.
            let blacklist: Vec<crate::software_blocker::BlacklistEntry> = payload.get("blacklist")
                .and_then(|r| serde_json::from_value(r.clone()).ok())
                .unwrap_or_default();
            let config = crate::software_blocker::SoftwareBlockerConfig { enabled: true, blacklist };
            plugins.software_blocker_tx.send(config)?;
        }
        MessageType::PopupRules => {
            let payload: serde_json::Value = frame.decode_payload()
                .map_err(|e| anyhow::anyhow!("Invalid popup rules: {}", e))?;
            info!("Received popup blocker rules update");
            // A missing/malformed "rules" field degrades to an empty list.
            let rules: Vec<crate::popup_blocker::PopupRule> = payload.get("rules")
                .and_then(|r| serde_json::from_value(r.clone()).ok())
                .unwrap_or_default();
            let config = crate::popup_blocker::PopupBlockerConfig { enabled: true, rules };
            plugins.popup_blocker_tx.send(config)?;
        }
        MessageType::PluginEnable => {
            let payload: csm_protocol::PluginControlPayload = frame.decode_payload()
                .map_err(|e| anyhow::anyhow!("Invalid plugin enable: {}", e))?;
            info!("Plugin enabled: {}", payload.plugin_name);
            // Route to appropriate plugin channel based on plugin_name
            handle_plugin_control(&payload, plugins, true)?;
        }
        MessageType::PluginDisable => {
            let payload: csm_protocol::PluginControlPayload = frame.decode_payload()
                .map_err(|e| anyhow::anyhow!("Invalid plugin disable: {}", e))?;
            info!("Plugin disabled: {}", payload.plugin_name);
            handle_plugin_control(&payload, plugins, false)?;
        }
        _ => {
            debug!("Unhandled message type: {:?}", frame.msg_type);
        }
    }
    Ok(())
}
/// Apply a plugin enable/disable command from the server.
///
/// On disable, each known plugin receives a cleared/disabled config so it
/// tears down its effect (overlay, rules, blacklist, ...). On enable nothing
/// is sent here — the server pushes the plugin's actual config in a
/// follow-up message. Unknown plugin names are logged either way.
fn handle_plugin_control(
    payload: &csm_protocol::PluginControlPayload,
    plugins: &PluginChannels,
    enabled: bool,
) -> Result<()> {
    let name = payload.plugin_name.as_str();
    if enabled {
        match name {
            "watermark" | "web_filter" | "software_blocker"
            | "popup_blocker" | "usb_audit" | "usage_timer" => {
                // Config arrives in a subsequent server push; nothing to do yet.
            }
            _ => warn!("Unknown plugin: {}", payload.plugin_name),
        }
        return Ok(());
    }
    match name {
        "watermark" => {
            // Clearing the config removes the overlay.
            plugins.watermark_tx.send(None)?;
        }
        "web_filter" => {
            plugins.web_filter_tx.send(crate::web_filter::WebFilterConfig { enabled: false, rules: vec![] })?;
        }
        "software_blocker" => {
            plugins.software_blocker_tx.send(crate::software_blocker::SoftwareBlockerConfig { enabled: false, blacklist: vec![] })?;
        }
        "popup_blocker" => {
            plugins.popup_blocker_tx.send(crate::popup_blocker::PopupBlockerConfig { enabled: false, rules: vec![] })?;
        }
        "usb_audit" => {
            plugins.usb_audit_tx.send(crate::usb_audit::UsbAuditConfig { enabled: false, monitored_extensions: vec![] })?;
        }
        "usage_timer" => {
            plugins.usage_timer_tx.send(crate::usage_timer::UsageConfig { enabled: false, ..Default::default() })?;
        }
        _ => warn!("Unknown plugin: {}", payload.plugin_name),
    }
    Ok(())
}
/// Compute HMAC-SHA256 for heartbeat verification.
/// Format: HMAC-SHA256(device_secret, "{device_uid}\n{timestamp}")
///
/// Returns an empty string when no usable secret is available, in which case
/// the heartbeat is sent without an HMAC.
fn compute_hmac(secret: Option<&str>, device_uid: &str, timestamp: &str) -> String {
    type HmacSha256 = Hmac<Sha256>;
    let key = match secret {
        Some(s) if !s.is_empty() => s,
        _ => return String::new(),
    };
    let mut mac = match HmacSha256::new_from_slice(key.as_bytes()) {
        Ok(m) => m,
        // HMAC accepts keys of any length; fail closed regardless.
        Err(_) => return String::new(),
    };
    // Feed the message incrementally — identical bytes to the
    // "{device_uid}\n{timestamp}" concatenation.
    mac.update(device_uid.as_bytes());
    mac.update(b"\n");
    mac.update(timestamp.as_bytes());
    hex::encode(mac.finalize().into_bytes())
}
/// OS name and version joined with a space; "Unknown" for missing parts.
fn get_os_info() -> String {
    use sysinfo::System;
    let name = System::name();
    let version = System::os_version();
    format!(
        "{} {}",
        name.unwrap_or_else(|| "Unknown".to_string()),
        version.unwrap_or_else(|| "Unknown".to_string())
    )
}

View File

@@ -0,0 +1,370 @@
use anyhow::Result;
use std::time::Duration;
use tokio::net::TcpStream;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tracing::{info, debug, warn};
use csm_protocol::{Frame, MessageType, RegisterRequest, RegisterResponse, HeartbeatPayload, WatermarkConfigPayload};
use hmac::{Hmac, Mac};
use sha2::Sha256;
use crate::ClientState;
/// Maximum accumulated read buffer size per connection (8 MB).
/// Guards against unbounded memory growth from a misbehaving peer.
const MAX_READ_BUF_SIZE: usize = 8 * 1024 * 1024;
/// Holds senders for all plugin config channels.
///
/// Each field is the sending half of a `watch` channel whose receiver lives
/// in the corresponding plugin task; sending replaces the config that the
/// plugin currently observes.
pub struct PluginChannels {
    /// `None` clears the watermark overlay; `Some` applies a new config.
    pub watermark_tx: tokio::sync::watch::Sender<Option<WatermarkConfigPayload>>,
    pub web_filter_tx: tokio::sync::watch::Sender<crate::web_filter::WebFilterConfig>,
    pub software_blocker_tx: tokio::sync::watch::Sender<crate::software_blocker::SoftwareBlockerConfig>,
    pub popup_blocker_tx: tokio::sync::watch::Sender<crate::popup_blocker::PopupBlockerConfig>,
    pub usb_audit_tx: tokio::sync::watch::Sender<crate::usb_audit::UsbAuditConfig>,
    pub usage_timer_tx: tokio::sync::watch::Sender<crate::usage_timer::UsageConfig>,
}
/// Connect to the server and run the main communication loop, optionally
/// upgrading the TCP socket to TLS first (per `state.use_tls`).
pub async fn connect_and_run(
    state: &ClientState,
    data_rx: &mut tokio::sync::mpsc::Receiver<Frame>,
    plugins: &PluginChannels,
) -> Result<()> {
    let socket = TcpStream::connect(&state.server_addr).await?;
    info!("TCP connected to {}", state.server_addr);
    if !state.use_tls {
        return run_comm_loop(socket, state, data_rx, plugins).await;
    }
    let secured = wrap_tls(socket, &state.server_addr).await?;
    run_comm_loop(secured, state, data_rx, plugins).await
}
/// Wrap a TCP stream with TLS.
/// Supports custom CA certificate via CSM_TLS_CA_CERT env var (path to PEM file).
/// Supports skipping verification via CSM_TLS_SKIP_VERIFY=true (development only).
async fn wrap_tls(stream: TcpStream, server_addr: &str) -> Result<tokio_rustls::client::TlsStream<TcpStream>> {
    let mut root_store = rustls::RootCertStore::empty();
    // Load custom CA certificate if specified
    if let Ok(ca_path) = std::env::var("CSM_TLS_CA_CERT") {
        let ca_pem = std::fs::read(&ca_path)
            .map_err(|e| anyhow::anyhow!("Failed to read CA cert {}: {}", ca_path, e))?;
        let certs = rustls_pemfile::certs(&mut &ca_pem[..])
            .collect::<Result<Vec<_>, _>>()
            .map_err(|e| anyhow::anyhow!("Failed to parse CA cert: {:?}", e))?;
        for cert in certs {
            root_store.add(cert)?;
        }
        info!("Loaded custom CA certificates from {}", ca_path);
    }
    // Always include system roots as fallback
    // (webpki-roots bundles Mozilla's root program rather than reading the OS store)
    root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
    // Either a normal verifying config, or the dangerous no-verify config.
    let config = if std::env::var("CSM_TLS_SKIP_VERIFY").as_deref() == Ok("true") {
        warn!("TLS certificate verification DISABLED — do not use in production!");
        rustls::ClientConfig::builder()
            .dangerous()
            .with_custom_certificate_verifier(std::sync::Arc::new(NoVerifier))
            .with_no_client_auth()
    } else {
        rustls::ClientConfig::builder()
            .with_root_certificates(root_store)
            .with_no_client_auth()
    };
    let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(config));
    // Extract hostname from server_addr (strip port)
    // NOTE(review): assumes "host:port" form — a bare IPv6 address containing
    // colons would be truncated at the first ':'; confirm callers' addr format.
    let domain = server_addr.split(':').next().unwrap_or("localhost").to_string();
    let server_name = rustls_pki_types::ServerName::try_from(domain.clone())
        .map_err(|e| anyhow::anyhow!("Invalid TLS server name '{}': {:?}", domain, e))?;
    let tls_stream = connector.connect(server_name, stream).await?;
    info!("TLS handshake completed with {}", domain);
    Ok(tls_stream)
}
/// A no-op certificate verifier for development use (CSM_TLS_SKIP_VERIFY=true).
///
/// Accepts any server certificate and any handshake signature — this gives
/// encryption without authentication and must never be enabled in production.
#[derive(Debug)]
struct NoVerifier;
impl rustls::client::danger::ServerCertVerifier for NoVerifier {
    // Accept every certificate chain without inspection.
    fn verify_server_cert(
        &self,
        _end_entity: &rustls_pki_types::CertificateDer,
        _intermediates: &[rustls_pki_types::CertificateDer],
        _server_name: &rustls_pki_types::ServerName,
        _ocsp_response: &[u8],
        _now: rustls_pki_types::UnixTime,
    ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
        Ok(rustls::client::danger::ServerCertVerified::assertion())
    }
    // Accept every TLS 1.2 handshake signature without checking it.
    fn verify_tls12_signature(
        &self,
        _message: &[u8],
        _cert: &rustls_pki_types::CertificateDer,
        _dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
    }
    // Accept every TLS 1.3 handshake signature without checking it.
    fn verify_tls13_signature(
        &self,
        _message: &[u8],
        _cert: &rustls_pki_types::CertificateDer,
        _dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
    }
    // Advertise the common signature schemes so handshakes can proceed.
    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
        vec![
            rustls::SignatureScheme::RSA_PKCS1_SHA256,
            rustls::SignatureScheme::RSA_PKCS1_SHA384,
            rustls::SignatureScheme::RSA_PKCS1_SHA512,
            rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
            rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
            rustls::SignatureScheme::RSA_PSS_SHA256,
            rustls::SignatureScheme::RSA_PSS_SHA384,
            rustls::SignatureScheme::RSA_PSS_SHA512,
            rustls::SignatureScheme::ED25519,
        ]
    }
}
/// Main communication loop over any read+write stream.
///
/// Sends the registration request, then multiplexes three event sources:
/// inbound server frames, queued outbound data frames, and a periodic
/// (optionally HMAC-authenticated) heartbeat. Returns `Err` on any I/O or
/// protocol failure so the caller can reconnect with backoff.
async fn run_comm_loop<S>(
    mut stream: S,
    state: &ClientState,
    data_rx: &mut tokio::sync::mpsc::Receiver<Frame>,
    plugins: &PluginChannels,
) -> Result<()>
where
    S: AsyncReadExt + AsyncWriteExt + Unpin,
{
    // Send registration
    let register = RegisterRequest {
        device_uid: state.device_uid.clone(),
        hostname: hostname::get()
            .map(|h| h.to_string_lossy().to_string())
            .unwrap_or_else(|_| "unknown".to_string()),
        registration_token: state.registration_token.clone(),
        os_version: get_os_info(),
        mac_address: None,
    };
    let frame = Frame::new_json(MessageType::Register, &register)?;
    stream.write_all(&frame.encode()).await?;
    info!("Registration request sent");
    // `buffer` is the per-read scratch space; `read_buf` accumulates bytes
    // until at least one complete frame can be decoded.
    let mut buffer = vec![0u8; 65536];
    let mut read_buf = Vec::with_capacity(65536);
    // Clamp heartbeat interval to a sane range [5, 3600]: a zero period would
    // panic in tokio::time::interval, and tiny values would spin the CPU.
    let heartbeat_secs = state.config.heartbeat_interval_secs.clamp(5, 3600);
    let mut heartbeat_interval = tokio::time::interval(Duration::from_secs(heartbeat_secs));
    heartbeat_interval.tick().await; // Skip first tick
    // HMAC key — set after receiving RegisterResponse
    let mut device_secret: Option<String> = state.device_secret.clone();
    loop {
        tokio::select! {
            // Read from server
            result = stream.read(&mut buffer) => {
                let n = result?;
                if n == 0 {
                    return Err(anyhow::anyhow!("Server closed connection"));
                }
                read_buf.extend_from_slice(&buffer[..n]);
                // Enforce the per-connection cap (previously declared but never
                // checked) so a misbehaving server cannot grow the buffer
                // without bound.
                if read_buf.len() > MAX_READ_BUF_SIZE {
                    return Err(anyhow::anyhow!("Read buffer exceeded {} bytes", MAX_READ_BUF_SIZE));
                }
                // Process complete frames
                loop {
                    match Frame::decode(&read_buf)? {
                        Some(frame) => {
                            let consumed = frame.encoded_size();
                            read_buf.drain(..consumed);
                            // Capture device_secret from registration response
                            // (done here, before dispatch, so subsequent
                            // heartbeats in this loop are authenticated)
                            if frame.msg_type == MessageType::RegisterResponse {
                                if let Ok(resp) = frame.decode_payload::<RegisterResponse>() {
                                    device_secret = Some(resp.device_secret.clone());
                                    crate::save_device_secret(&resp.device_secret);
                                    info!("Device secret received and persisted, HMAC enabled for heartbeats");
                                }
                            }
                            handle_server_message(frame, plugins)?;
                        }
                        None => break, // Incomplete frame, wait for more data
                    }
                }
            }
            // Send queued data
            frame = data_rx.recv() => {
                let frame = frame.ok_or_else(|| anyhow::anyhow!("Channel closed"))?;
                stream.write_all(&frame.encode()).await?;
            }
            // Heartbeat
            _ = heartbeat_interval.tick() => {
                let timestamp = chrono::Utc::now().to_rfc3339();
                // Empty string when no secret yet — heartbeat goes out unauthenticated.
                let hmac_value = compute_hmac(device_secret.as_deref(), &state.device_uid, &timestamp);
                let heartbeat = HeartbeatPayload {
                    device_uid: state.device_uid.clone(),
                    timestamp,
                    hmac: hmac_value,
                };
                let frame = Frame::new_json(MessageType::Heartbeat, &heartbeat)?;
                stream.write_all(&frame.encode()).await?;
                debug!("Heartbeat sent (hmac={})", !heartbeat.hmac.is_empty());
            }
        }
    }
}
/// Dispatch one server-originated frame to the matching handler or plugin channel.
///
/// Payloads are JSON-decoded per `msg_type`. Rule-bearing messages
/// (`WebFilterRuleUpdate`, `SoftwareBlacklist`, `PopupRules`) tolerate a
/// missing or malformed rule list by degrading to an empty set, while an
/// undecodable envelope is returned as an error to the caller.
fn handle_server_message(frame: Frame, plugins: &PluginChannels) -> Result<()> {
    match frame.msg_type {
        MessageType::RegisterResponse => {
            // The device_secret in this response is captured by the
            // connection loop before this dispatcher runs; here we only log.
            let resp: RegisterResponse = frame.decode_payload()
                .map_err(|e| anyhow::anyhow!("Invalid registration response: {}", e))?;
            info!("Registration accepted by server (server version: {})", resp.config.server_version);
        }
        MessageType::PolicyUpdate => {
            // Policy payloads are currently logged verbatim, not enforced here.
            let policy: serde_json::Value = frame.decode_payload()
                .map_err(|e| anyhow::anyhow!("Invalid policy update: {}", e))?;
            info!("Received policy update: {}", policy);
        }
        MessageType::ConfigUpdate => {
            info!("Received config update");
        }
        MessageType::TaskExecute => {
            warn!("Task execution requested (not yet implemented)");
        }
        MessageType::WatermarkConfig => {
            let config: WatermarkConfigPayload = frame.decode_payload()
                .map_err(|e| anyhow::anyhow!("Invalid watermark config: {}", e))?;
            info!("Received watermark config: enabled={}", config.enabled);
            plugins.watermark_tx.send(Some(config))?;
        }
        MessageType::WebFilterRuleUpdate => {
            let payload: serde_json::Value = frame.decode_payload()
                .map_err(|e| anyhow::anyhow!("Invalid web filter update: {}", e))?;
            info!("Received web filter rules update");
            // Missing/malformed "rules" key degrades to an empty rule set.
            let rules: Vec<crate::web_filter::WebFilterRule> = payload.get("rules")
                .and_then(|r| serde_json::from_value(r.clone()).ok())
                .unwrap_or_default();
            let config = crate::web_filter::WebFilterConfig { enabled: true, rules };
            plugins.web_filter_tx.send(config)?;
        }
        MessageType::SoftwareBlacklist => {
            let payload: serde_json::Value = frame.decode_payload()
                .map_err(|e| anyhow::anyhow!("Invalid software blacklist: {}", e))?;
            info!("Received software blacklist update");
            // Missing/malformed "blacklist" key degrades to an empty list.
            let blacklist: Vec<crate::software_blocker::BlacklistEntry> = payload.get("blacklist")
                .and_then(|r| serde_json::from_value(r.clone()).ok())
                .unwrap_or_default();
            let config = crate::software_blocker::SoftwareBlockerConfig { enabled: true, blacklist };
            plugins.software_blocker_tx.send(config)?;
        }
        MessageType::PopupRules => {
            let payload: serde_json::Value = frame.decode_payload()
                .map_err(|e| anyhow::anyhow!("Invalid popup rules: {}", e))?;
            info!("Received popup blocker rules update");
            // Missing/malformed "rules" key degrades to an empty rule set.
            let rules: Vec<crate::popup_blocker::PopupRule> = payload.get("rules")
                .and_then(|r| serde_json::from_value(r.clone()).ok())
                .unwrap_or_default();
            let config = crate::popup_blocker::PopupBlockerConfig { enabled: true, rules };
            plugins.popup_blocker_tx.send(config)?;
        }
        MessageType::PluginEnable => {
            let payload: csm_protocol::PluginControlPayload = frame.decode_payload()
                .map_err(|e| anyhow::anyhow!("Invalid plugin enable: {}", e))?;
            info!("Plugin enabled: {}", payload.plugin_name);
            // Route to appropriate plugin channel based on plugin_name
            handle_plugin_control(&payload, plugins, true)?;
        }
        MessageType::PluginDisable => {
            let payload: csm_protocol::PluginControlPayload = frame.decode_payload()
                .map_err(|e| anyhow::anyhow!("Invalid plugin disable: {}", e))?;
            info!("Plugin disabled: {}", payload.plugin_name);
            handle_plugin_control(&payload, plugins, false)?;
        }
        _ => {
            debug!("Unhandled message type: {:?}", frame.msg_type);
        }
    }
    Ok(())
}
/// Apply a plugin enable/disable command from the server.
///
/// Disabling pushes an inert configuration into the plugin's channel so it
/// tears down its current effects; enabling requires no local action because
/// the server follows up with the plugin's actual config/rules.
fn handle_plugin_control(
    payload: &csm_protocol::PluginControlPayload,
    plugins: &PluginChannels,
    enabled: bool,
) -> Result<()> {
    let name = payload.plugin_name.as_str();
    let known = matches!(
        name,
        "watermark" | "web_filter" | "software_blocker" | "popup_blocker" | "usb_audit" | "usage_timer"
    );
    if !known {
        warn!("Unknown plugin: {}", payload.plugin_name);
        return Ok(());
    }
    if enabled {
        // When enabling, the server pushes the concrete config next.
        return Ok(());
    }
    match name {
        // Disabled watermark: remove the overlay.
        "watermark" => plugins.watermark_tx.send(None)?,
        // Disabled web filter: clear hosts rules.
        "web_filter" => plugins.web_filter_tx.send(crate::web_filter::WebFilterConfig { enabled: false, rules: vec![] })?,
        "software_blocker" => plugins.software_blocker_tx.send(crate::software_blocker::SoftwareBlockerConfig { enabled: false, blacklist: vec![] })?,
        "popup_blocker" => plugins.popup_blocker_tx.send(crate::popup_blocker::PopupBlockerConfig { enabled: false, rules: vec![] })?,
        "usb_audit" => plugins.usb_audit_tx.send(crate::usb_audit::UsbAuditConfig { enabled: false, monitored_extensions: vec![] })?,
        "usage_timer" => plugins.usage_timer_tx.send(crate::usage_timer::UsageConfig { enabled: false, ..Default::default() })?,
        // Guarded by the `known` check above.
        _ => unreachable!(),
    }
    Ok(())
}
/// Compute HMAC-SHA256 for heartbeat verification.
/// Format: HMAC-SHA256(device_secret, "{device_uid}\n{timestamp}")
///
/// Returns an empty string when no non-empty secret is available yet
/// (i.e. before the registration response delivered one).
fn compute_hmac(secret: Option<&str>, device_uid: &str, timestamp: &str) -> String {
    type HmacSha256 = Hmac<Sha256>;
    let Some(key) = secret.filter(|s| !s.is_empty()) else {
        return String::new();
    };
    let message = format!("{}\n{}", device_uid, timestamp);
    match HmacSha256::new_from_slice(key.as_bytes()) {
        Ok(mut mac) => {
            mac.update(message.as_bytes());
            hex::encode(mac.finalize().into_bytes())
        }
        // HMAC accepts any key length, but fail closed just in case.
        Err(_) => String::new(),
    }
}
/// Render the host OS as "<name> <version>", with "Unknown" placeholders
/// for fields sysinfo cannot determine.
fn get_os_info() -> String {
    use sysinfo::System;
    let unknown = || "Unknown".to_string();
    format!(
        "{} {}",
        System::name().unwrap_or_else(unknown),
        System::os_version().unwrap_or_else(unknown)
    )
}

View File

@@ -0,0 +1,254 @@
use tokio::sync::watch;
use tracing::{info, debug};
use serde::Deserialize;
/// Popup blocker rule pushed from the server.
#[derive(Debug, Clone, Deserialize)]
pub struct PopupRule {
    // Server-side rule id; echoed in block logs for traceability.
    pub id: i64,
    // Only rules with rule_type == "block" are enforced by the scanner.
    pub rule_type: String,
    // Optional filters: a None filter matches any value; present filters
    // use case-insensitive '*' wildcard patterns (see pattern_match).
    pub window_title: Option<String>,
    pub window_class: Option<String>,
    pub process_name: Option<String>,
}
/// Popup blocker configuration, delivered over a watch channel.
#[derive(Debug, Clone, Default)]
pub struct PopupBlockerConfig {
    // Scanning is skipped entirely while disabled.
    pub enabled: bool,
    pub rules: Vec<PopupRule>,
}
/// Context passed to the EnumWindows callback via LPARAM (as a raw pointer).
struct ScanContext {
    // Snapshot of the rules for the current scan pass.
    rules: Vec<PopupRule>,
    // Number of windows sent WM_CLOSE during this pass.
    blocked_count: u32,
}
/// Start popup blocker plugin.
/// Periodically enumerates windows and closes those matching rules.
///
/// Exits when the config sender side is dropped.
pub async fn start(mut config_rx: watch::Receiver<PopupBlockerConfig>) {
    info!("Popup blocker plugin started");
    let mut config = PopupBlockerConfig::default();
    let mut scan_interval = tokio::time::interval(std::time::Duration::from_secs(2));
    // tokio intervals fire immediately; consume that first tick.
    scan_interval.tick().await;
    loop {
        tokio::select! {
            changed = config_rx.changed() => {
                if changed.is_err() {
                    // Config sender dropped — shut the plugin down.
                    break;
                }
                config = config_rx.borrow_and_update().clone();
                info!("Popup blocker config updated: enabled={}, rules={}", config.enabled, config.rules.len());
            }
            _ = scan_interval.tick() => {
                let armed = config.enabled && !config.rules.is_empty();
                if armed {
                    scan_and_block(&config.rules);
                }
            }
        }
    }
}
/// Enumerate all top-level windows once and close any that match `rules`.
/// No-op on non-Windows targets.
fn scan_and_block(rules: &[PopupRule]) {
    #[cfg(target_os = "windows")]
    {
        use windows::Win32::UI::WindowsAndMessaging::EnumWindows;
        use windows::Win32::Foundation::LPARAM;
        let mut ctx = ScanContext {
            rules: rules.to_vec(),
            blocked_count: 0,
        };
        // SAFETY: EnumWindows invokes the callback synchronously before
        // returning, so the raw pointer to `ctx` (alive on this stack
        // frame) remains valid for every callback invocation.
        unsafe {
            let _ = EnumWindows(
                Some(enum_windows_callback),
                LPARAM(&mut ctx as *mut ScanContext as isize),
            );
        }
        if ctx.blocked_count > 0 {
            debug!("Popup scan blocked {} windows", ctx.blocked_count);
        }
    }
    #[cfg(not(target_os = "windows"))]
    {
        let _ = rules;
    }
}
/// EnumWindows callback: inspects one window and posts WM_CLOSE if a
/// "block" rule matches its title/class/process. Always returns BOOL(1)
/// (TRUE) so enumeration continues over all windows.
///
/// SAFETY: `lparam` must carry a valid `*mut ScanContext`, as arranged by
/// `scan_and_block`, and the callback must run on the thread that owns it.
#[cfg(target_os = "windows")]
unsafe extern "system" fn enum_windows_callback(
    hwnd: windows::Win32::Foundation::HWND,
    lparam: windows::Win32::Foundation::LPARAM,
) -> windows::Win32::Foundation::BOOL {
    use windows::Win32::UI::WindowsAndMessaging::*;
    use windows::Win32::Foundation::*;
    // Only consider visible, top-level windows without an owner
    if !IsWindowVisible(hwnd).as_bool() {
        return BOOL(1);
    }
    // Skip windows that have an owner (they're child dialogs, not popups)
    if GetWindow(hwnd, GW_OWNER).0 != 0 {
        return BOOL(1);
    }
    // Get window title (UTF-16, truncated at 512 chars)
    let mut title_buf = [0u16; 512];
    let title_len = GetWindowTextW(hwnd, &mut title_buf);
    let title_len = title_len.max(0) as usize;
    let title = String::from_utf16_lossy(&title_buf[..title_len]);
    // Skip windows with empty titles
    if title.is_empty() {
        return BOOL(1);
    }
    // Get class name
    let mut class_buf = [0u16; 256];
    let class_len = GetClassNameW(hwnd, &mut class_buf);
    let class_len = class_len.max(0) as usize;
    let class_name = String::from_utf16_lossy(&class_buf[..class_len]);
    // Get process name from PID
    let mut pid: u32 = 0;
    GetWindowThreadProcessId(hwnd, Some(&mut pid));
    let process_name = get_process_name(pid);
    // Recover the ScanContext from LPARAM
    let ctx = &mut *(lparam.0 as *mut ScanContext);
    // Check against each rule
    for rule in &ctx.rules {
        if rule.rule_type != "block" {
            continue;
        }
        let matches = rule_matches(rule, &title, &class_name, &process_name);
        if matches {
            // Post (not send) WM_CLOSE: asks the app to close without
            // blocking this enumeration on the target's message loop.
            let _ = PostMessageW(hwnd, WM_CLOSE, WPARAM(0), LPARAM(0));
            ctx.blocked_count += 1;
            info!(
                "Blocked popup: title='{}' class='{}' process='{}' (rule_id={})",
                title, class_name, process_name, rule.id
            );
            break; // One match is enough per window
        }
    }
    BOOL(1) // Continue enumeration
}
/// True when every filter present on `rule` matches the corresponding
/// window attribute; absent (None) filters match anything.
fn rule_matches(rule: &PopupRule, title: &str, class_name: &str, process_name: &str) -> bool {
    let filters = [
        (rule.window_title.as_deref(), title),
        (rule.window_class.as_deref(), class_name),
        (rule.process_name.as_deref(), process_name),
    ];
    filters
        .iter()
        .all(|(pattern, value)| pattern.map_or(true, |p| pattern_match(p, value)))
}
/// Case-insensitive wildcard matching; `*` matches any run of characters.
/// A pattern without `*` is a plain substring test. A literal before the
/// first `*` anchors at the start; a literal after the last `*` anchors
/// at the end; segments between wildcards must appear in order.
fn pattern_match(pattern: &str, text: &str) -> bool {
    let pat = pattern.to_lowercase();
    let txt = text.to_lowercase();
    // No wildcard: substring containment.
    if !pat.contains('*') {
        return txt.contains(&pat);
    }
    // `split('*')` always yields at least two parts here.
    let parts: Vec<&str> = pat.split('*').collect();
    // Leading literal must anchor at the start (empty prefix always matches).
    let head = parts.first().copied().unwrap_or("");
    if !txt.starts_with(head) {
        return false;
    }
    let mut pos = head.len();
    // Each remaining non-empty segment must occur, in order, after `pos`.
    for seg in parts.iter().skip(1).filter(|s| !s.is_empty()) {
        match txt[pos..].find(*seg) {
            Some(idx) => pos += idx + seg.len(),
            None => return false,
        }
    }
    // Trailing literal must anchor at the end (empty suffix = trailing '*').
    let tail = parts.last().copied().unwrap_or("");
    tail.is_empty() || txt.ends_with(tail)
}
/// Resolve a PID to its executable image name via a ToolHelp process
/// snapshot (Windows only). Falls back to the label "pid:<n>" when the
/// snapshot fails or the PID is not present in it.
#[cfg(target_os = "windows")]
fn get_process_name(pid: u32) -> String {
    use windows::Win32::System::Diagnostics::ToolHelp::*;
    use windows::Win32::Foundation::CloseHandle;
    unsafe {
        let snapshot = match CreateToolhelp32Snapshot(TH32CS_SNAPPROCESS, 0) {
            Ok(s) => s,
            Err(_) => return format!("pid:{}", pid),
        };
        // dwSize must be initialized before Process32FirstW.
        let mut entry = PROCESSENTRY32W {
            dwSize: std::mem::size_of::<PROCESSENTRY32W>() as u32,
            ..Default::default()
        };
        if Process32FirstW(snapshot, &mut entry).is_ok() {
            loop {
                if entry.th32ProcessID == pid {
                    // szExeFile is a NUL-terminated UTF-16 buffer.
                    let name = String::from_utf16_lossy(
                        &entry.szExeFile.iter().take_while(|&&c| c != 0).copied().collect::<Vec<u16>>()
                    );
                    let _ = CloseHandle(snapshot);
                    return name;
                }
                // Reset dwSize before each Process32NextW call.
                entry.dwSize = std::mem::size_of::<PROCESSENTRY32W>() as u32;
                if Process32NextW(snapshot, &mut entry).is_err() {
                    break;
                }
            }
        }
        let _ = CloseHandle(snapshot);
        format!("pid:{}", pid)
    }
}
/// Non-Windows fallback: process names are not looked up, so return a
/// "pid:<n>" label instead.
#[cfg(not(target_os = "windows"))]
fn get_process_name(pid: u32) -> String {
    let mut label = String::from("pid:");
    label.push_str(&pid.to_string());
    label
}

View File

@@ -0,0 +1,254 @@
use tokio::sync::watch;
use tracing::{info, warn};
use csm_protocol::{Frame, MessageType, SoftwareViolationReport};
use serde::Deserialize;
/// System-critical processes that must never be killed regardless of server rules.
/// Killing any of these would cause system instability or a BSOD.
///
/// Entries are lowercase image names; `is_protected_process` compares
/// case-insensitively and also matches full paths ending in "\<name>".
const PROTECTED_PROCESSES: &[&str] = &[
    "system",
    "system idle process",
    "svchost.exe",
    "lsass.exe",
    "csrss.exe",
    "wininit.exe",
    "winlogon.exe",
    "services.exe",
    "dwm.exe",
    "explorer.exe",
    "taskhostw.exe",
    "registry",
    "smss.exe",
    "conhost.exe",
];
/// Software blacklist entry pushed from the server.
#[derive(Debug, Clone, Deserialize)]
pub struct BlacklistEntry {
    // Server-side row id.
    pub id: i64,
    // Process-name pattern; supports '*' wildcards, case-insensitive
    // (see pattern_matches).
    pub name_pattern: String,
    // "block" kills matching processes; "alert" only reports them.
    pub action: String,
}
/// Software blocker configuration, delivered over a watch channel.
#[derive(Debug, Clone, Default)]
pub struct SoftwareBlockerConfig {
    // Scanning is skipped entirely while disabled.
    pub enabled: bool,
    pub blacklist: Vec<BlacklistEntry>,
}
/// Start software blocker plugin.
/// Periodically scans running processes against the blacklist.
///
/// Exits when the config sender side is dropped.
pub async fn start(
    mut config_rx: watch::Receiver<SoftwareBlockerConfig>,
    data_tx: tokio::sync::mpsc::Sender<Frame>,
    device_uid: String,
) {
    info!("Software blocker plugin started");
    let mut config = SoftwareBlockerConfig::default();
    let mut scan_interval = tokio::time::interval(std::time::Duration::from_secs(10));
    // tokio intervals fire immediately; consume that first tick.
    scan_interval.tick().await;
    loop {
        tokio::select! {
            changed = config_rx.changed() => {
                if changed.is_err() {
                    // Config sender dropped — shut the plugin down.
                    break;
                }
                config = config_rx.borrow_and_update().clone();
                info!("Software blocker config updated: enabled={}, blacklist={}", config.enabled, config.blacklist.len());
            }
            _ = scan_interval.tick() => {
                let armed = config.enabled && !config.blacklist.is_empty();
                if armed {
                    scan_processes(&config.blacklist, &data_tx, &device_uid).await;
                }
            }
        }
    }
}
/// Scan the current process list against the blacklist: report every match
/// to the server and, for "block" entries, kill the process by the PID
/// captured in the same snapshot. Protected system processes are never
/// killed, only logged.
async fn scan_processes(
    blacklist: &[BlacklistEntry],
    data_tx: &tokio::sync::mpsc::Sender<Frame>,
    device_uid: &str,
) {
    // One snapshot per scan; names and PIDs come from the same moment.
    let running = get_running_processes_with_pids();
    for entry in blacklist {
        for (process_name, pid) in &running {
            if pattern_matches(&entry.name_pattern, process_name) {
                // Never kill system-critical processes
                if is_protected_process(process_name) {
                    warn!("Blacklisted match '{}' skipped — system-critical process (pid={})", process_name, pid);
                    continue;
                }
                warn!("Blacklisted software detected: {} (action: {})", process_name, entry.action);
                // Report violation to server
                // Map action to DB-compatible values: "block" -> "blocked_install", "alert" -> "alerted"
                let action_taken = match entry.action.as_str() {
                    "block" => "blocked_install",
                    "alert" => "alerted",
                    other => other,
                };
                let violation = SoftwareViolationReport {
                    device_uid: device_uid.to_string(),
                    software_name: process_name.clone(),
                    action_taken: action_taken.to_string(),
                    timestamp: chrono::Utc::now().to_rfc3339(),
                };
                // Best-effort send; a full/closed channel drops the report.
                if let Ok(frame) = Frame::new_json(MessageType::SoftwareViolation, &violation) {
                    let _ = data_tx.send(frame).await;
                }
                // Kill the process directly by captured PID (avoids TOCTOU race)
                if entry.action == "block" {
                    kill_process_by_pid(*pid, process_name);
                }
            }
        }
    }
}
/// Case-insensitive blacklist-name matching with `*` wildcards.
/// Without a wildcard this is a plain substring test; with wildcards a
/// leading literal anchors at the start, a trailing literal anchors at
/// the end, and middle segments must appear in order.
fn pattern_matches(pattern: &str, name: &str) -> bool {
    let pat = pattern.to_lowercase();
    let target = name.to_lowercase();
    if !pat.contains('*') {
        return target.contains(&pat);
    }
    // `split('*')` always yields at least two parts here.
    let parts: Vec<&str> = pat.split('*').collect();
    // Leading literal must anchor at the start (empty prefix always matches).
    let head = parts.first().copied().unwrap_or("");
    if !target.starts_with(head) {
        return false;
    }
    let mut pos = head.len();
    // Remaining non-empty segments must occur, in order, after `pos`.
    for seg in parts.iter().skip(1).filter(|s| !s.is_empty()) {
        match target[pos..].find(*seg) {
            Some(idx) => pos += idx + seg.len(),
            None => return false,
        }
    }
    // Trailing literal must anchor at the end (empty suffix = trailing '*').
    let tail = parts.last().copied().unwrap_or("");
    tail.is_empty() || target.ends_with(tail)
}
/// Get all running processes with their PIDs from a single ToolHelp
/// snapshot (no TOCTOU between name and PID). Returns an empty list on
/// non-Windows targets or if the snapshot cannot be created.
fn get_running_processes_with_pids() -> Vec<(String, u32)> {
    #[cfg(target_os = "windows")]
    {
        use windows::Win32::System::Diagnostics::ToolHelp::*;
        use windows::Win32::Foundation::CloseHandle;
        let mut procs = Vec::new();
        unsafe {
            let snapshot = match CreateToolhelp32Snapshot(TH32CS_SNAPPROCESS, 0) {
                Ok(s) => s,
                Err(_) => return procs,
            };
            // dwSize must be initialized before Process32FirstW.
            let mut entry = PROCESSENTRY32W {
                dwSize: std::mem::size_of::<PROCESSENTRY32W>() as u32,
                ..Default::default()
            };
            if Process32FirstW(snapshot, &mut entry).is_ok() {
                loop {
                    // szExeFile is a NUL-terminated UTF-16 buffer.
                    let name = String::from_utf16_lossy(
                        &entry.szExeFile.iter().take_while(|&&c| c != 0).copied().collect::<Vec<u16>>()
                    );
                    procs.push((name, entry.th32ProcessID));
                    // Reset dwSize before each Process32NextW call.
                    entry.dwSize = std::mem::size_of::<PROCESSENTRY32W>() as u32;
                    if !Process32NextW(snapshot, &mut entry).is_ok() {
                        break;
                    }
                }
            }
            let _ = CloseHandle(snapshot);
        }
        procs
    }
    #[cfg(not(target_os = "windows"))]
    {
        Vec::new()
    }
}
/// Terminate `pid`, but only after re-verifying via a fresh snapshot that
/// the PID still belongs to `expected_name` — this guards against the OS
/// reusing the PID between detection and kill. No-op on non-Windows.
fn kill_process_by_pid(pid: u32, expected_name: &str) {
    #[cfg(target_os = "windows")]
    {
        use windows::Win32::System::Threading::{OpenProcess, TerminateProcess, PROCESS_TERMINATE};
        use windows::Win32::Foundation::CloseHandle;
        use windows::Win32::System::Diagnostics::ToolHelp::*;
        unsafe {
            // Verify the PID still belongs to the expected process (prevent PID reuse kills)
            let snapshot = match CreateToolhelp32Snapshot(TH32CS_SNAPPROCESS, 0) {
                Ok(s) => s,
                Err(_) => return,
            };
            let mut entry = PROCESSENTRY32W {
                dwSize: std::mem::size_of::<PROCESSENTRY32W>() as u32,
                ..Default::default()
            };
            let mut name_matches = false;
            if Process32FirstW(snapshot, &mut entry).is_ok() {
                loop {
                    if entry.th32ProcessID == pid {
                        let current_name = String::from_utf16_lossy(
                            &entry.szExeFile.iter().take_while(|&&c| c != 0).copied().collect::<Vec<u16>>()
                        );
                        // Case-insensitive: ToolHelp reports the on-disk casing.
                        name_matches = current_name.to_lowercase() == expected_name.to_lowercase();
                        break;
                    }
                    // Reset dwSize before each Process32NextW call.
                    entry.dwSize = std::mem::size_of::<PROCESSENTRY32W>() as u32;
                    if !Process32NextW(snapshot, &mut entry).is_ok() {
                        break;
                    }
                }
            }
            let _ = CloseHandle(snapshot);
            if !name_matches {
                warn!("Skipping kill: PID {} no longer matches expected process '{}'", pid, expected_name);
                return;
            }
            // Request the narrowest right needed (PROCESS_TERMINATE only).
            if let Ok(handle) = OpenProcess(PROCESS_TERMINATE, false, pid) {
                let terminated = TerminateProcess(handle, 1).is_ok();
                let _ = CloseHandle(handle);
                if terminated {
                    warn!("Killed process: {} (pid={})", expected_name, pid);
                } else {
                    warn!("TerminateProcess failed for: {} (pid={})", expected_name, pid);
                }
            } else {
                warn!("OpenProcess failed for: {} (pid={}) — insufficient privileges?", expected_name, pid);
            }
        }
    }
    #[cfg(not(target_os = "windows"))]
    {
        let _ = (pid, expected_name);
    }
}
/// True if `name` is a system-critical process that must never be killed,
/// compared case-insensitively either as the bare image name or as a full
/// path ending in "\<name>".
fn is_protected_process(name: &str) -> bool {
    let lower = name.to_lowercase();
    PROTECTED_PROCESSES.iter().any(|protected| {
        lower == **protected || lower.ends_with(&format!("\\{}", protected))
    })
}

View File

@@ -0,0 +1,197 @@
use std::time::{Duration, Instant};
use tokio::sync::watch;
use tracing::{info, debug};
use csm_protocol::{Frame, MessageType, UsageDailyReport, AppUsageEntry};
/// Usage tracking configuration pushed from server.
#[derive(Debug, Clone, Default)]
pub struct UsageConfig {
    // Accounting and reporting are skipped entirely while disabled.
    pub enabled: bool,
    // System input inactivity beyond this many seconds classifies the
    // elapsed interval as idle rather than active.
    pub idle_threshold_secs: u64,
    // How often accumulated usage is flushed to the server.
    pub report_interval_secs: u64,
}
/// Start the usage timer plugin.
/// Tracks active/idle time and foreground application usage.
///
/// While enabled, each report tick classifies the elapsed interval as
/// active or idle (by comparing system idle time to `idle_threshold_secs`),
/// attributes it to the current foreground app, and flushes a
/// `UsageReport` plus per-app `AppUsageReport` frames to `data_tx`.
/// Counters reset after each report, so every report covers one interval.
/// Exits when the config sender or data channel closes.
pub async fn start(
    mut config_rx: watch::Receiver<UsageConfig>,
    data_tx: tokio::sync::mpsc::Sender<Frame>,
    device_uid: String,
) {
    info!("Usage timer plugin started");
    let mut config = UsageConfig {
        enabled: false,
        idle_threshold_secs: 300, // 5 minutes default
        report_interval_secs: 300,
    };
    let mut report_interval = tokio::time::interval(Duration::from_secs(60));
    report_interval.tick().await; // consume the immediate first tick
    let mut active_secs: u64 = 0;
    let mut idle_secs: u64 = 0;
    let mut app_usage: std::collections::HashMap<String, u64> = std::collections::HashMap::new();
    let mut last_tick = Instant::now();
    loop {
        tokio::select! {
            result = config_rx.changed() => {
                if result.is_err() {
                    break;
                }
                let new_config = config_rx.borrow_and_update().clone();
                if new_config.enabled != config.enabled {
                    info!("Usage timer enabled: {}", new_config.enabled);
                }
                config = new_config;
                if config.enabled {
                    // Clamp to >= 1s: tokio::time::interval panics if the
                    // period is zero, and a server-pushed config could
                    // otherwise carry report_interval_secs == 0.
                    let secs = config.report_interval_secs.max(1);
                    report_interval = tokio::time::interval(Duration::from_secs(secs));
                    report_interval.tick().await;
                    last_tick = Instant::now();
                }
            }
            _ = report_interval.tick() => {
                if !config.enabled {
                    continue;
                }
                // Measure actual elapsed time since last tick (ticks may be
                // delayed; don't assume exactly one interval passed).
                let now = Instant::now();
                let elapsed_secs = now.duration_since(last_tick).as_secs();
                last_tick = now;
                let idle_ms = get_idle_millis();
                let is_idle = idle_ms as u64 > config.idle_threshold_secs * 1000;
                if is_idle {
                    idle_secs += elapsed_secs;
                } else {
                    active_secs += elapsed_secs;
                }
                // Attribute the whole interval to the current foreground app.
                if let Some(app) = get_foreground_app_name() {
                    *app_usage.entry(app).or_insert(0) += elapsed_secs;
                }
                // Report usage to server
                if active_secs > 0 || idle_secs > 0 {
                    let report = UsageDailyReport {
                        device_uid: device_uid.clone(),
                        date: chrono::Local::now().format("%Y-%m-%d").to_string(),
                        total_active_minutes: (active_secs / 60) as u32,
                        total_idle_minutes: (idle_secs / 60) as u32,
                        first_active_at: None,
                        last_active_at: Some(chrono::Utc::now().to_rfc3339()),
                    };
                    if let Ok(frame) = Frame::new_json(MessageType::UsageReport, &report) {
                        // A closed data channel means the client is shutting down.
                        if data_tx.send(frame).await.is_err() {
                            break;
                        }
                    }
                    // Report per-app usage
                    for (app, secs) in &app_usage {
                        let entry = AppUsageEntry {
                            device_uid: device_uid.clone(),
                            date: chrono::Local::now().format("%Y-%m-%d").to_string(),
                            app_name: app.clone(),
                            usage_minutes: (secs / 60) as u32,
                        };
                        if let Ok(frame) = Frame::new_json(MessageType::AppUsageReport, &entry) {
                            let _ = data_tx.send(frame).await;
                        }
                    }
                    // Reset counters after reporting
                    active_secs = 0;
                    idle_secs = 0;
                    app_usage.clear();
                }
                debug!("Usage report sent (idle_ms={}, elapsed={}s)", idle_ms, elapsed_secs);
            }
        }
    }
}
/// Get system idle time in milliseconds using GetLastInputInfo + GetTickCount64.
/// Both use the same time base. LASTINPUTINFO.dwTime is u32 (GetTickCount legacy),
/// so we take the low 32 bits of GetTickCount64 for correct wrapping comparison.
///
/// Returns 0 on failure and on non-Windows targets, which the caller
/// treats as "never idle".
fn get_idle_millis() -> u32 {
    #[cfg(target_os = "windows")]
    {
        use windows::Win32::UI::Input::KeyboardAndMouse::{GetLastInputInfo, LASTINPUTINFO};
        unsafe {
            let mut lii = LASTINPUTINFO {
                cbSize: std::mem::size_of::<LASTINPUTINFO>() as u32,
                dwTime: 0,
            };
            if GetLastInputInfo(&mut lii).as_bool() {
                // GetTickCount64 returns u64, but dwTime is u32.
                // Take low 32 bits so wrapping_sub produces correct idle delta.
                let tick_low32 = (windows::Win32::System::SystemInformation::GetTickCount64() & 0xFFFFFFFF) as u32;
                return tick_low32.wrapping_sub(lii.dwTime);
            }
        }
        0
    }
    #[cfg(not(target_os = "windows"))]
    {
        0
    }
}
/// Get the foreground window's process name using CreateToolhelp32Snapshot.
///
/// Returns None when there is no foreground window, the PID cannot be
/// resolved, or on non-Windows targets.
fn get_foreground_app_name() -> Option<String> {
    #[cfg(target_os = "windows")]
    {
        use windows::Win32::UI::WindowsAndMessaging::{GetForegroundWindow, GetWindowThreadProcessId};
        use windows::Win32::System::Diagnostics::ToolHelp::*;
        use windows::Win32::Foundation::CloseHandle;
        unsafe {
            let hwnd = GetForegroundWindow();
            // Null HWND: no window currently has focus.
            if hwnd.0 == 0 {
                return None;
            }
            let mut pid: u32 = 0;
            GetWindowThreadProcessId(hwnd, Some(&mut pid));
            if pid == 0 {
                return None;
            }
            // Get process name via CreateToolhelp32Snapshot
            let snapshot = CreateToolhelp32Snapshot(TH32CS_SNAPPROCESS, 0).ok()?;
            let mut entry = PROCESSENTRY32W {
                dwSize: std::mem::size_of::<PROCESSENTRY32W>() as u32,
                ..Default::default()
            };
            if Process32FirstW(snapshot, &mut entry).is_ok() {
                loop {
                    if entry.th32ProcessID == pid {
                        // szExeFile is a NUL-terminated UTF-16 buffer.
                        let name = String::from_utf16_lossy(
                            &entry.szExeFile.iter().take_while(|&&c| c != 0).copied().collect::<Vec<u16>>()
                        );
                        let _ = CloseHandle(snapshot);
                        return Some(name);
                    }
                    // Reset dwSize before each Process32NextW call.
                    entry.dwSize = std::mem::size_of::<PROCESSENTRY32W>() as u32;
                    if !Process32NextW(snapshot, &mut entry).is_ok() {
                        break;
                    }
                }
            }
            let _ = CloseHandle(snapshot);
            None
        }
    }
    #[cfg(not(target_os = "windows"))]
    {
        None
    }
}

View File

@@ -0,0 +1,249 @@
use csm_protocol::{Frame, MessageType, UsbEvent, UsbEventType, UsbPolicyPayload, UsbDeviceRule};
use std::time::Duration;
use tokio::sync::{mpsc::Sender, watch};
use tracing::{info, warn};
/// Start USB monitoring with policy enforcement.
/// Monitors for removable drive insertions/removals and enforces USB policies.
///
/// Polls the removable-drive set every 10 seconds; policy updates are
/// picked up non-blockingly from `policy_rx` and immediately re-applied
/// to drives that are already mounted.
///
/// NOTE(review): drives present when this loop first runs are reported as
/// insertions (`known_drives` starts empty) — confirm this is intended.
pub async fn start_monitoring(
    tx: Sender<Frame>,
    device_uid: String,
    mut policy_rx: watch::Receiver<Option<UsbPolicyPayload>>,
) {
    info!("USB monitoring started for {}", device_uid);
    let mut current_policy: Option<UsbPolicyPayload> = None;
    let mut known_drives: Vec<String> = Vec::new();
    let interval = Duration::from_secs(10);
    loop {
        // Check for policy updates (non-blocking)
        if policy_rx.has_changed().unwrap_or(false) {
            let new_policy = policy_rx.borrow_and_update().clone();
            let policy_desc = new_policy.as_ref()
                .map(|p| format!("type={}, enabled={}", p.policy_type, p.enabled))
                .unwrap_or_else(|| "None".to_string());
            info!("USB policy updated: {}", policy_desc);
            current_policy = new_policy;
            // Re-check all currently mounted drives against new policy
            let drives = detect_removable_drives();
            for drive in &drives {
                if should_block_drive(drive, &current_policy) {
                    warn!("USB device {} blocked by policy, ejecting...", drive);
                    if let Err(e) = eject_drive(drive) {
                        warn!("Failed to eject {}: {}", drive, e);
                    } else {
                        info!("Successfully ejected {}", drive);
                    }
                    // Report blocked event
                    send_usb_event(&tx, &device_uid, UsbEventType::Blocked, drive).await;
                }
            }
        }
        let current_drives = detect_removable_drives();
        // Detect new drives (inserted)
        for drive in &current_drives {
            if !known_drives.iter().any(|d: &String| d == drive) {
                info!("USB device inserted: {}", drive);
                // Check against policy before reporting
                if should_block_drive(drive, &current_policy) {
                    warn!("USB device {} blocked by policy, ejecting...", drive);
                    if let Err(e) = eject_drive(drive) {
                        warn!("Failed to eject {}: {}", drive, e);
                    } else {
                        info!("Successfully ejected {}", drive);
                    }
                    // A blocked drive is reported as Blocked, not Inserted.
                    send_usb_event(&tx, &device_uid, UsbEventType::Blocked, drive).await;
                    continue;
                }
                send_usb_event(&tx, &device_uid, UsbEventType::Inserted, drive).await;
            }
        }
        // Detect removed drives
        for drive in &known_drives {
            if !current_drives.iter().any(|d| d == drive) {
                info!("USB device removed: {}", drive);
                send_usb_event(&tx, &device_uid, UsbEventType::Removed, drive).await;
            }
        }
        known_drives = current_drives;
        tokio::time::sleep(interval).await;
    }
}
/// Build a `UsbEvent` for `drive` and queue it for delivery.
///
/// Only the drive path is known at this layer; hardware identifiers
/// (VID/PID/serial) are left unset. Delivery is best-effort — a closed
/// channel is silently ignored.
async fn send_usb_event(
    tx: &Sender<Frame>,
    device_uid: &str,
    event_type: UsbEventType,
    drive: &str,
) {
    let event = UsbEvent {
        device_uid: device_uid.to_owned(),
        event_type,
        vendor_id: None,
        product_id: None,
        serial: None,
        device_name: Some(drive.to_owned()),
    };
    let Ok(frame) = Frame::new_json(MessageType::UsbEvent, &event) else {
        return;
    };
    let _ = tx.send(frame).await;
}
/// Check whether `drive` should be blocked under the current policy.
///
/// A missing or disabled policy blocks nothing. "all_block" blocks every
/// drive; "blacklist" blocks drives matching any rule; "whitelist" blocks
/// drives matching no rule (an empty whitelist blocks everything).
/// Unknown policy types block nothing.
fn should_block_drive(drive: &str, policy: &Option<UsbPolicyPayload>) -> bool {
    let Some(policy) = policy.as_ref().filter(|p| p.enabled) else {
        return false;
    };
    let matches_any = || policy.rules.iter().any(|rule| device_matches_rule(drive, rule));
    match policy.policy_type.as_str() {
        "all_block" => true,
        "blacklist" => matches_any(),
        "whitelist" => policy.rules.is_empty() || !matches_any(),
        _ => false,
    }
}
/// Check if a device matches a rule pattern.
/// Currently matches by device_name (drive letter path) since that's what we detect.
fn device_matches_rule(drive: &str, rule: &UsbDeviceRule) -> bool {
// Match by device name (drive root path like "E:\")
if let Some(ref name) = rule.device_name {
if drive.eq_ignore_ascii_case(name) || drive.eq_ignore_ascii_case(&name.replace('\\', "")) {
return true;
}
}
// vendor_id, product_id, serial matching would require WMI or SetupDi queries
// For now, these are placeholder checks
let _ = (drive, rule);
false
}
/// DRIVE_REMOVABLE = 2 in Windows API (GetDriveTypeW return value).
const DRIVE_REMOVABLE: u32 = 2;
/// Enumerate removable drive roots ("A:\" .. "Z:\") via GetDriveTypeW.
/// Always returns an empty list on non-Windows targets.
fn detect_removable_drives() -> Vec<String> {
    let mut drives = Vec::new();
    #[cfg(target_os = "windows")]
    {
        for letter in b'A'..=b'Z' {
            let root = format!("{}:\\", letter as char);
            // GetDriveTypeW needs a NUL-terminated UTF-16 root path.
            let root_wide: Vec<u16> = root.encode_utf16().chain(std::iter::once(0)).collect();
            unsafe {
                let drive_type = windows::Win32::Storage::FileSystem::GetDriveTypeW(
                    windows::core::PCWSTR(root_wide.as_ptr()),
                );
                if drive_type == DRIVE_REMOVABLE {
                    drives.push(root);
                }
            }
        }
    }
    #[cfg(not(target_os = "windows"))]
    {
        let _ = &drives; // Suppress unused warning on non-Windows
    }
    drives
}
/// Eject a removable drive using Windows API.
/// Opens the volume handle, dismounts the filesystem, and ejects the media.
///
/// Succeeds if either the dismount or the eject IOCTL succeeds; fails only
/// when both do (or the volume handle cannot be opened).
#[cfg(target_os = "windows")]
fn eject_drive(drive: &str) -> Result<(), anyhow::Error> {
    use windows::Win32::Storage::FileSystem::*;
    use windows::Win32::System::IO::DeviceIoControl;
    use windows::Win32::Foundation::*;
    use windows::core::PCWSTR;
    // IOCTL control codes (not exposed directly in windows crate)
    const FSCTL_DISMOUNT_VOLUME: u32 = 0x00090020;
    const IOCTL_STORAGE_EJECT_MEDIA: u32 = 0x002D4808;
    // Extract drive letter from path like "E:\"
    let letter = drive.chars().next().unwrap_or('A');
    // The embedded "\0" makes the UTF-16 buffer NUL-terminated for PCWSTR.
    let path = format!("\\\\.\\{}:\0", letter);
    let path_wide: Vec<u16> = path.encode_utf16().collect();
    unsafe {
        let handle = CreateFileW(
            PCWSTR(path_wide.as_ptr()),
            GENERIC_READ.0,
            FILE_SHARE_READ | FILE_SHARE_WRITE,
            None,
            OPEN_EXISTING,
            FILE_ATTRIBUTE_NORMAL,
            None,
        )?;
        // Defensive re-check; CreateFileW's Err path normally covers this.
        if handle.is_invalid() {
            return Err(anyhow::anyhow!("Failed to open volume handle for {}", drive));
        }
        // Dismount the filesystem
        let mut bytes_returned = 0u32;
        let dismount_ok = DeviceIoControl(
            handle,
            FSCTL_DISMOUNT_VOLUME,
            None,
            0,
            None,
            0,
            Some(&mut bytes_returned),
            None,
        ).is_ok();
        if !dismount_ok {
            warn!("FSCTL_DISMOUNT_VOLUME failed for {}", drive);
        }
        // Eject the media
        let eject_ok = DeviceIoControl(
            handle,
            IOCTL_STORAGE_EJECT_MEDIA,
            None,
            0,
            None,
            0,
            Some(&mut bytes_returned),
            None,
        ).is_ok();
        if !eject_ok {
            warn!("IOCTL_STORAGE_EJECT_MEDIA failed for {}", drive);
        }
        let _ = CloseHandle(handle);
        if !dismount_ok && !eject_ok {
            Err(anyhow::anyhow!("Failed to eject {}", drive))
        } else {
            Ok(())
        }
    }
}
/// Ejecting media is Windows-only; on other platforms this is a no-op
/// that reports success so callers don't log spurious failures.
#[cfg(not(target_os = "windows"))]
fn eject_drive(_drive: &str) -> Result<(), anyhow::Error> {
    Ok(())
}

View File

@@ -0,0 +1,190 @@
use tokio::sync::watch;
use tracing::{info, debug, warn};
use csm_protocol::{Frame, MessageType, UsbFileOpEntry};
use std::collections::{HashMap, HashSet};
/// USB file audit configuration, delivered over a watch channel.
#[derive(Debug, Clone, Default)]
pub struct UsbAuditConfig {
    // Drive scanning is skipped entirely while disabled.
    pub enabled: bool,
    // File extensions (without dot) to report; an empty list reports all
    // files. NOTE(review): files with no extension are always reported —
    // see report_file_op; confirm this is intended.
    pub monitored_extensions: Vec<String>,
}
/// Start USB file audit plugin.
/// Detects removable drives and monitors file changes via periodic scanning.
///
/// Every 5 seconds (while enabled) each removable drive is re-scanned in a
/// blocking thread and diffed against the previous snapshot: new paths are
/// reported as "create" and vanished paths as "delete" operations.
/// Exits when the config sender side is dropped.
pub async fn start(
    mut config_rx: watch::Receiver<UsbAuditConfig>,
    data_tx: tokio::sync::mpsc::Sender<Frame>,
    device_uid: String,
) {
    info!("USB file audit plugin started");
    let mut config = UsbAuditConfig::default();
    let mut active_drives: HashSet<String> = HashSet::new();
    // Track file listings per drive to detect changes
    let mut drive_snapshots: HashMap<String, HashSet<String>> = HashMap::new();
    let mut interval = tokio::time::interval(std::time::Duration::from_secs(5));
    interval.tick().await;
    loop {
        tokio::select! {
            result = config_rx.changed() => {
                if result.is_err() {
                    break;
                }
                let new_config = config_rx.borrow_and_update().clone();
                info!("USB audit config updated: enabled={}", new_config.enabled);
                config = new_config;
            }
            _ = interval.tick() => {
                if !config.enabled {
                    continue;
                }
                let drives = detect_removable_drives();
                // Detect new drives and take their baseline snapshot once.
                for drive in &drives {
                    if !active_drives.contains(drive) {
                        info!("New removable drive detected: {}", drive);
                        // FIX: record the drive as active. Previously nothing was
                        // ever inserted into `active_drives`, so every tick
                        // re-logged the drive and re-took the baseline snapshot,
                        // which also masked all cross-tick file-change detection.
                        active_drives.insert(drive.clone());
                        // Take initial snapshot in blocking thread to avoid freezing
                        let drive_clone = drive.clone();
                        let files = tokio::task::spawn_blocking(move || scan_drive_files(&drive_clone)).await;
                        match files {
                            Ok(files) => { drive_snapshots.insert(drive.clone(), files); }
                            Err(e) => { warn!("Failed to scan drive {}: {}", drive, e); }
                        }
                    }
                }
                // Scan existing drives for changes (each in a blocking thread)
                for drive in &drives {
                    let drive_clone = drive.clone();
                    let current_files = match tokio::task::spawn_blocking(move || scan_drive_files(&drive_clone)).await {
                        Ok(files) => files,
                        Err(e) => { warn!("Drive scan failed for {}: {}", drive, e); continue; }
                    };
                    if let Some(prev_files) = drive_snapshots.get_mut(drive) {
                        // Find new files (created)
                        for file in &current_files {
                            if !prev_files.contains(file) {
                                report_file_op(&data_tx, &device_uid, drive, file, "create", &config.monitored_extensions).await;
                            }
                        }
                        // Find deleted files
                        for file in prev_files.iter() {
                            if !current_files.contains(file) {
                                report_file_op(&data_tx, &device_uid, drive, file, "delete", &config.monitored_extensions).await;
                            }
                        }
                        *prev_files = current_files;
                    } else {
                        drive_snapshots.insert(drive.clone(), current_files);
                    }
                }
                // Drop state for drives that were unplugged.
                active_drives.retain(|d| drives.contains(d));
                drive_snapshots.retain(|d, _| drives.contains(d));
            }
        }
    }
}
/// Send a single USB file-operation record to the server, honoring the
/// configured extension filter.
///
/// `operation` is stored verbatim in `UsbFileOpEntry::operation`
/// (the scanner in this file produces only "create" / "delete").
/// Sending is best-effort: a closed channel silently drops the event.
async fn report_file_op(
    data_tx: &tokio::sync::mpsc::Sender<Frame>,
    device_uid: &str,
    drive: &str,
    file_path: &str,
    operation: &str,
    ext_filter: &[String],
) {
    // Check extension filter
    let should_report = if ext_filter.is_empty() {
        // No filter configured: report every file.
        true
    } else {
        let file_ext = std::path::Path::new(file_path)
            .extension()
            .and_then(|e| e.to_str())
            .map(|e| e.to_lowercase());
        // NOTE(review): `map_or(true, ...)` means files WITHOUT an extension
        // are always reported even when a filter is set — confirm intended.
        file_ext.map_or(true, |ext| {
            ext_filter.iter().any(|f| f.to_lowercase() == ext)
        })
    };
    if should_report {
        let entry = UsbFileOpEntry {
            device_uid: device_uid.to_string(),
            usb_serial: None, // serial is not resolved by this scanner; only the drive letter is known
            drive_letter: Some(drive.to_string()),
            operation: operation.to_string(),
            file_path: file_path.to_string(),
            file_size: None, // size is not collected (and is unavailable for deletes)
            timestamp: chrono::Utc::now().to_rfc3339(),
        };
        if let Ok(frame) = Frame::new_json(MessageType::UsbFileOp, &entry) {
            let _ = data_tx.send(frame).await;
        }
        debug!("USB file op: {} {}", operation, file_path);
    }
}
/// Take a snapshot of a drive: every file path reachable within the
/// depth cap, collected into a set for cheap membership diffing.
fn scan_drive_files(drive: &str) -> HashSet<String> {
    // Depth is capped so a huge drive cannot stall the scan thread for long.
    const MAX_DEPTH: usize = 3;
    let mut collected = HashSet::new();
    scan_dir_recursive(drive, &mut collected, 0, MAX_DEPTH);
    collected
}
/// Depth-limited recursive directory walk: inserts every file path under
/// `dir` into `files`, recursing into subdirectories until `max_depth`.
/// Unreadable directories are skipped silently (best-effort audit scan).
fn scan_dir_recursive(dir: &str, files: &mut HashSet<String>, depth: usize, max_depth: usize) {
    if depth >= max_depth {
        return;
    }
    let entries = match std::fs::read_dir(dir) {
        Ok(entries) => entries,
        Err(_) => return,
    };
    for entry in entries.filter_map(Result::ok) {
        let path = entry.path();
        let rendered = path.to_string_lossy().into_owned();
        if path.is_dir() {
            scan_dir_recursive(&rendered, files, depth + 1, max_depth);
        } else {
            files.insert(rendered);
        }
    }
}
/// Enumerate removable drives as "X:\" strings.
///
/// Windows: walks the `GetLogicalDrives` bitmask (bit 0 = A:, bit 1 = B:, …)
/// and keeps drives whose `GetDriveTypeW` is `DRIVE_REMOVABLE`.
/// Other platforms: always returns an empty list (audit is Windows-only).
fn detect_removable_drives() -> Vec<String> {
    #[cfg(target_os = "windows")]
    {
        use windows::Win32::Storage::FileSystem::GetDriveTypeW;
        use windows::core::PCWSTR;
        let mut drives = Vec::new();
        let mask = unsafe { windows::Win32::Storage::FileSystem::GetLogicalDrives() };
        let mut mask = mask as u32;
        let mut letter = b'A';
        while mask != 0 {
            if mask & 1 != 0 {
                // Build "X:\" as a NUL-terminated UTF-16 string for the Win32 call.
                let drive_letter = format!("{}:\\", letter as char);
                let drive_wide: Vec<u16> = drive_letter.encode_utf16().chain(std::iter::once(0)).collect();
                let drive_type = unsafe {
                    GetDriveTypeW(PCWSTR(drive_wide.as_ptr()))
                };
                // DRIVE_REMOVABLE = 2
                if drive_type == 2 {
                    drives.push(drive_letter);
                }
            }
            mask >>= 1;
            letter += 1;
        }
        drives
    }
    #[cfg(not(target_os = "windows"))]
    {
        Vec::new()
    }
}

View File

@@ -0,0 +1,313 @@
use std::sync::{Arc, Mutex};
use tokio::sync::watch;
use tracing::{info, warn, error};
use csm_protocol::WatermarkConfigPayload;
/// Watermark overlay state, shared between the tokio config task and the
/// dedicated Win32 message-loop thread (hence it lives behind Arc<Mutex>).
struct WatermarkState {
    enabled: bool,
    // Text tiled repeatedly across the screen.
    content: String,
    // Base size; doubled when the GDI font is created at paint time.
    font_size: u32,
    // 0.0..=1.0, mapped to the layered-window alpha byte (0-255).
    opacity: f64,
    // "#RRGGBB" hex string, parsed by `parse_color`.
    color: String,
    // Text rotation in degrees; multiplied by 10 for CreateFont escapement.
    angle: i32,
    // Raw HWND of the overlay window once created (stored as isize to be Send).
    hwnd: Option<isize>,
}
impl Default for WatermarkState {
    /// Disabled state with conservative display defaults
    /// (gray, 15% opacity, -30° tilt); no overlay window yet.
    fn default() -> Self {
        Self {
            enabled: false,
            content: String::new(),
            font_size: 14,
            opacity: 0.15,
            color: "#808080".to_string(),
            angle: -30,
            hwnd: None,
        }
    }
}
// Global state: the Win32 window procedure cannot capture variables, so it
// reads the shared watermark state through this process-wide handle.
static WATERMARK_STATE: std::sync::OnceLock<Arc<Mutex<WatermarkState>>> = std::sync::OnceLock::new();
// Custom window message posted to the overlay to re-read state and repaint.
const WM_USER_UPDATE: u32 = 0x0401;
// Win32 window-class name, used both to create and to later find the overlay.
const OVERLAY_CLASS_NAME: &str = "CSM_WatermarkOverlay";
/// Start the watermark plugin, listening for config updates.
///
/// Spawns a dedicated OS thread for the Win32 overlay message loop
/// (windowing APIs are thread-affine), then forwards every config change
/// from `rx` into the shared state and nudges the overlay window to repaint.
pub async fn start(mut rx: watch::Receiver<Option<WatermarkConfigPayload>>) {
    info!("Watermark plugin started");
    let state = WATERMARK_STATE.get_or_init(|| Arc::new(Mutex::new(WatermarkState::default())));
    let state_clone = state.clone();
    // Spawn a dedicated thread for the Win32 message loop
    std::thread::spawn(move || {
        #[cfg(target_os = "windows")]
        run_overlay_message_loop(state_clone);
        #[cfg(not(target_os = "windows"))]
        {
            let _ = state_clone;
            info!("Watermark overlay not supported on this platform");
        }
    });
    loop {
        tokio::select! {
            result = rx.changed() => {
                if result.is_err() {
                    warn!("Watermark config channel closed");
                    break;
                }
                let config = rx.borrow_and_update().clone();
                if let Some(cfg) = config {
                    info!("Watermark config updated: enabled={}, content='{}'", cfg.enabled, cfg.content);
                    {
                        // Copy the payload into shared state; keep the lock scope
                        // minimal so the paint thread is never blocked for long.
                        let mut s = state.lock().unwrap_or_else(|e| e.into_inner());
                        s.enabled = cfg.enabled;
                        s.content = cfg.content.clone();
                        s.font_size = cfg.font_size;
                        s.opacity = cfg.opacity;
                        s.color = cfg.color.clone();
                        s.angle = cfg.angle;
                    }
                    // Post message to overlay window thread
                    #[cfg(target_os = "windows")]
                    post_overlay_update();
                }
            }
        }
    }
}
/// Locate the overlay window by class name and post `WM_USER_UPDATE` so the
/// message-loop thread re-reads the shared state. No-op if the window does
/// not exist yet (the window self-posts an update right after creation).
#[cfg(target_os = "windows")]
fn post_overlay_update() {
    use windows::Win32::UI::WindowsAndMessaging::{FindWindowA, PostMessageW};
    use windows::Win32::Foundation::{WPARAM, LPARAM};
    use windows::core::PCSTR;
    unsafe {
        // FindWindowA needs a NUL-terminated ANSI class name.
        let class_name = format!("{}\0", OVERLAY_CLASS_NAME);
        let hwnd = FindWindowA(PCSTR(class_name.as_ptr()), PCSTR::null());
        if hwnd.0 != 0 {
            let _ = PostMessageW(hwnd, WM_USER_UPDATE, WPARAM(0), LPARAM(0));
        }
    }
}
/// Create the full-screen, click-through, layered overlay window and run its
/// Win32 message loop until WM_QUIT. Must run on a dedicated thread: the
/// loop blocks in `GetMessageA`, and the window belongs to this thread.
#[cfg(target_os = "windows")]
fn run_overlay_message_loop(state: Arc<Mutex<WatermarkState>>) {
    use windows::Win32::UI::WindowsAndMessaging::*;
    use windows::Win32::Graphics::Gdi::*;
    use windows::Win32::Foundation::*;
    use windows::Win32::System::LibraryLoader::GetModuleHandleA;
    use windows::core::PCSTR;
    // Register window class
    let class_name = format!("{}\0", OVERLAY_CLASS_NAME);
    let instance = unsafe { GetModuleHandleA(PCSTR::null()) }.unwrap_or_default();
    let instance: HINSTANCE = instance.into();
    let wc = WNDCLASSA {
        style: CS_HREDRAW | CS_VREDRAW,
        lpfnWndProc: Some(watermark_wnd_proc),
        hInstance: instance,
        lpszClassName: PCSTR(class_name.as_ptr()),
        // Black background: later color-keyed to transparent via LWA_COLORKEY.
        hbrBackground: unsafe { CreateSolidBrush(COLORREF(0x00000000)) },
        ..Default::default()
    };
    // Registration failure is ignored; CreateWindowExA below would then fail
    // and we bail out on the hwnd check.
    unsafe { RegisterClassA(&wc); }
    let screen_w = unsafe { GetSystemMetrics(SM_CXSCREEN) };
    let screen_h = unsafe { GetSystemMetrics(SM_CYSCREEN) };
    // WS_EX_LAYERED | WS_EX_TRANSPARENT | WS_EX_TOPMOST | WS_EX_TOOLWINDOW
    let ex_style = WS_EX_LAYERED | WS_EX_TRANSPARENT | WS_EX_TOPMOST | WS_EX_TOOLWINDOW;
    let style = WS_POPUP;
    let hwnd = unsafe {
        CreateWindowExA(
            ex_style,
            PCSTR(class_name.as_ptr()),
            PCSTR("CSM Watermark\0".as_ptr()),
            style,
            0, 0, screen_w, screen_h,
            HWND::default(),
            HMENU::default(),
            instance,
            None,
        )
    };
    if hwnd.0 == 0 {
        error!("Failed to create watermark overlay window");
        return;
    }
    {
        // Publish the handle so other code can see the window exists.
        let mut s = state.lock().unwrap_or_else(|e| e.into_inner());
        s.hwnd = Some(hwnd.0);
    }
    info!("Watermark overlay window created (hwnd={})", hwnd.0);
    // Post initial update to self — ensures overlay shows even if config
    // arrived before window creation completed (race between threads)
    unsafe {
        let _ = PostMessageW(hwnd, WM_USER_UPDATE, WPARAM(0), LPARAM(0));
    }
    // Message loop
    let mut msg = MSG::default();
    unsafe {
        loop {
            // GetMessageA returns 0 on WM_QUIT and -1 on error; both end the loop.
            match GetMessageA(&mut msg, HWND::default(), 0, 0) {
                BOOL(0) | BOOL(-1) => break,
                _ => {
                    TranslateMessage(&msg);
                    DispatchMessageA(&msg);
                }
            }
        }
    }
}
/// Window procedure for the overlay window.
///
/// WM_PAINT draws the tiled watermark; WM_USER_UPDATE re-applies topmost
/// position, layered alpha/color-key and triggers a repaint; WM_ERASEBKGND
/// fills black (the color keyed to transparent); WM_DESTROY ends the loop.
#[cfg(target_os = "windows")]
unsafe extern "system" fn watermark_wnd_proc(
    hwnd: windows::Win32::Foundation::HWND,
    msg: u32,
    wparam: windows::Win32::Foundation::WPARAM,
    lparam: windows::Win32::Foundation::LPARAM,
) -> windows::Win32::Foundation::LRESULT {
    use windows::Win32::UI::WindowsAndMessaging::*;
    use windows::Win32::Graphics::Gdi::*;
    use windows::Win32::Foundation::*;
    match msg {
        WM_PAINT => {
            let mut painted = false;
            if let Some(state) = WATERMARK_STATE.get() {
                let s = state.lock().unwrap_or_else(|e| e.into_inner());
                if s.enabled && !s.content.is_empty() {
                    // paint_watermark brackets the drawing in BeginPaint/EndPaint,
                    // which validates the update region.
                    paint_watermark(hwnd, &s);
                    painted = true;
                }
            }
            if !painted {
                // BUG FIX: when nothing is drawn, the update region must still be
                // validated — otherwise Windows keeps re-posting WM_PAINT,
                // spinning this thread in a busy loop.
                let _ = ValidateRect(hwnd, None);
            }
            LRESULT(0)
        }
        WM_ERASEBKGND => {
            // Fill with black (will be color-keyed to transparent)
            let hdc = HDC(wparam.0 as isize);
            unsafe {
                let rect = {
                    let mut r = RECT::default();
                    let _ = GetClientRect(hwnd, &mut r);
                    r
                };
                let brush = CreateSolidBrush(COLORREF(0x00000000));
                let _ = FillRect(hdc, &rect, brush);
                let _ = DeleteObject(brush);
            }
            LRESULT(1)
        }
        m if m == WM_USER_UPDATE => {
            if let Some(state) = WATERMARK_STATE.get() {
                let s = state.lock().unwrap_or_else(|e| e.into_inner());
                if s.enabled {
                    let _ = SetWindowPos(
                        hwnd, HWND_TOPMOST,
                        0, 0,
                        GetSystemMetrics(SM_CXSCREEN),
                        GetSystemMetrics(SM_CYSCREEN),
                        SWP_SHOWWINDOW,
                    );
                    let alpha = (s.opacity * 255.0).clamp(0.0, 255.0) as u8;
                    // Color key black background to transparent, apply alpha to text
                    let _ = SetLayeredWindowAttributes(hwnd, COLORREF(0), alpha, LWA_COLORKEY | LWA_ALPHA);
                    let _ = InvalidateRect(hwnd, None, true);
                } else {
                    let _ = ShowWindow(hwnd, SW_HIDE);
                }
            }
            LRESULT(0)
        }
        WM_DESTROY => {
            PostQuitMessage(0);
            LRESULT(0)
        }
        _ => DefWindowProcA(hwnd, msg, wparam, lparam),
    }
}
/// Paint the watermark text tiled across the whole screen.
///
/// Creates a rotated GDI font (escapement = angle × 10, in tenths of a
/// degree), draws the text on a grid, and restores/releases GDI objects.
/// Must only be called from WM_PAINT: it brackets drawing in
/// BeginPaint/EndPaint, which validates the update region.
#[cfg(target_os = "windows")]
fn paint_watermark(hwnd: windows::Win32::Foundation::HWND, state: &WatermarkState) {
    use windows::Win32::Graphics::Gdi::*;
    use windows::Win32::UI::WindowsAndMessaging::*;
    use windows::core::PCSTR;
    unsafe {
        let mut ps = PAINTSTRUCT::default();
        let hdc = BeginPaint(hwnd, &mut ps);
        let color = parse_color(&state.color);
        let font_size = state.font_size.max(1);
        // Create font with rotation
        let font = CreateFontA(
            (font_size as i32) * 2, // height doubled for readability at screen scale
            0,
            (state.angle as i32) * 10, // escapement is in tenths of a degree
            0,
            FW_NORMAL.0 as i32,
            0, 0, 0,
            DEFAULT_CHARSET.0 as u32,
            OUT_DEFAULT_PRECIS.0 as u32,
            CLIP_DEFAULT_PRECIS.0 as u32,
            DEFAULT_QUALITY.0 as u32,
            DEFAULT_PITCH.0 as u32 | FF_DONTCARE.0 as u32,
            PCSTR("Arial\0".as_ptr()),
        );
        let old_font = SelectObject(hdc, font);
        let _ = SetBkMode(hdc, TRANSPARENT);
        let _ = SetTextColor(hdc, color);
        // Draw tiled watermark text
        let screen_w = GetSystemMetrics(SM_CXSCREEN);
        let screen_h = GetSystemMetrics(SM_CYSCREEN);
        // NOTE(review): content is passed to TextOutA as raw UTF-8 bytes (the
        // trailing NUL is sliced off); non-ASCII text may render incorrectly
        // through the ANSI API — confirm with CJK watermark content.
        let content_bytes: Vec<u8> = state.content.bytes().chain(std::iter::once(0)).collect();
        let text_slice = &content_bytes[..content_bytes.len().saturating_sub(1)];
        let spacing_x = 400i32;
        let spacing_y = 200i32;
        // Start off-screen so rotated text also covers the edges.
        let mut y = -100i32;
        while y < screen_h + 100 {
            let mut x = -200i32;
            while x < screen_w + 200 {
                let _ = TextOutA(hdc, x, y, text_slice);
                x += spacing_x;
            }
            y += spacing_y;
        }
        SelectObject(hdc, old_font);
        let _ = DeleteObject(font);
        EndPaint(hwnd, &ps);
    }
}
/// Parse a "#RRGGBB" hex string into a Win32 COLORREF (0x00BBGGRR layout).
/// Any malformed input falls back to mid-gray (0x808080).
fn parse_color(hex: &str) -> windows::Win32::Foundation::COLORREF {
    use windows::Win32::Foundation::COLORREF;
    let hex = hex.trim_start_matches('#');
    // BUG FIX: byte-slicing a non-ASCII string of byte-length 6 (e.g. "aébcd")
    // could split a UTF-8 code point and panic; require ASCII before indexing.
    if hex.len() == 6 && hex.is_ascii() {
        let r = u8::from_str_radix(&hex[0..2], 16).unwrap_or(128);
        let g = u8::from_str_radix(&hex[2..4], 16).unwrap_or(128);
        let b = u8::from_str_radix(&hex[4..6], 16).unwrap_or(128);
        // COLORREF packs bytes as 0x00BBGGRR.
        COLORREF((b as u32) << 16 | (g as u32) << 8 | (r as u32))
    } else {
        COLORREF(0x00808080)
    }
}

View File

@@ -0,0 +1,169 @@
use tokio::sync::watch;
use tracing::{info, warn, debug};
use serde::Deserialize;
use std::io;
/// Web filter rule from server
#[derive(Debug, Clone, Deserialize)]
pub struct WebFilterRule {
    pub id: i64,
    // Only "block" rules are written to the hosts file (see apply_hosts_rules);
    // other values are currently ignored.
    pub rule_type: String,
    // Hostname pattern; validated by `is_valid_hosts_entry` before use.
    pub pattern: String,
}
/// Web filter configuration
///
/// When `enabled` flips to false, all rules previously written to the
/// hosts file are removed.
#[derive(Debug, Clone, Default)]
pub struct WebFilterConfig {
    pub enabled: bool,
    pub rules: Vec<WebFilterRule>,
}
/// Start web filter plugin.
/// Manages the hosts file to block/allow URLs based on server rules.
pub async fn start(
mut config_rx: watch::Receiver<WebFilterConfig>,
) {
info!("Web filter plugin started");
let mut _config = WebFilterConfig::default();
loop {
tokio::select! {
result = config_rx.changed() => {
if result.is_err() {
break;
}
let new_config = config_rx.borrow_and_update().clone();
info!("Web filter config updated: enabled={}, rules={}", new_config.enabled, new_config.rules.len());
if new_config.enabled {
match apply_hosts_rules(&new_config.rules) {
Ok(()) => info!("Web filter hosts rules applied ({} rules)", new_config.rules.len()),
Err(e) => warn!("Failed to apply hosts rules: {}", e),
}
} else {
match clear_hosts_rules() {
Ok(()) => info!("Web filter hosts rules cleared"),
Err(e) => warn!("Failed to clear hosts rules: {}", e),
}
}
_config = new_config;
}
}
}
}
// Sentinel comment lines bracketing the block of hosts-file entries this
// plugin owns; everything between them is rewritten wholesale on each update.
const HOSTS_MARKER_START: &str = "# CSM_WEB_FILTER_START";
const HOSTS_MARKER_END: &str = "# CSM_WEB_FILTER_END";
/// Rewrite the CSM-owned section of the hosts file from the given rules.
///
/// Valid "block" rules become `127.0.0.1 <pattern>` lines between the CSM
/// markers; any previous CSM section is removed first. With no applicable
/// rules, only the cleaned file (markers stripped) is written back.
fn apply_hosts_rules(rules: &[WebFilterRule]) -> io::Result<()> {
    let hosts_path = get_hosts_path();
    let current = std::fs::read_to_string(&hosts_path)?;
    // Strip any stale CSM section before appending the fresh one.
    let base = remove_csm_block(&current);
    // Keep only block rules whose pattern is safe to write.
    let entries: Vec<String> = rules
        .iter()
        .filter(|r| r.rule_type == "block" && is_valid_hosts_entry(&r.pattern))
        .map(|r| format!("127.0.0.1 {}", r.pattern))
        .collect();
    if entries.is_empty() {
        return atomic_write(&hosts_path, &base);
    }
    let mut content = base;
    content.push_str(HOSTS_MARKER_START);
    content.push('\n');
    for line in &entries {
        content.push_str(line);
        content.push('\n');
    }
    content.push_str(HOSTS_MARKER_END);
    content.push('\n');
    atomic_write(&hosts_path, &content)?;
    debug!("Applied {} web filter rules to hosts file", entries.len());
    Ok(())
}
/// Remove the CSM-owned section from the hosts file, leaving the rest intact.
fn clear_hosts_rules() -> io::Result<()> {
    let path = get_hosts_path();
    let cleaned = remove_csm_block(&std::fs::read_to_string(&path)?);
    atomic_write(&path, &cleaned)
}
/// Overwrite `path` with `content` and flush data to disk.
///
/// Despite the name this is NOT a rename-based atomic write: on Windows the
/// hosts file may be held open by the DNS cache service, so the file is
/// truncated and rewritten in place instead.
fn atomic_write(path: &str, content: &str) -> io::Result<()> {
    use std::io::Write;
    // File::create == write + create + truncate.
    let mut file = std::fs::File::create(path)?;
    file.write_all(content.as_bytes())?;
    file.sync_data()
}
/// Strip the CSM marker section from hosts-file content.
///
/// With a well-formed START..END pair, everything from START through END
/// (inclusive) is cut out and trailing whitespace is normalized to a single
/// newline. Otherwise any orphaned marker lines are dropped individually.
fn remove_csm_block(content: &str) -> String {
    // Normal case: a paired START..END block.
    if let (Some(start), Some(end)) = (content.find(HOSTS_MARKER_START), content.find(HOSTS_MARKER_END)) {
        if end > start {
            let mut kept = String::with_capacity(content.len());
            kept.push_str(&content[..start]);
            kept.push_str(&content[end + HOSTS_MARKER_END.len()..]);
            let mut out = kept.trim_end().to_string();
            out.push('\n');
            return out;
        }
    }
    // Degenerate case: lone START or END markers — drop those lines only.
    let mut out = content
        .lines()
        .filter(|line| {
            let trimmed = line.trim();
            trimmed != HOSTS_MARKER_START && trimmed != HOSTS_MARKER_END
        })
        .collect::<Vec<_>>()
        .join("\n");
    if !out.ends_with('\n') {
        out.push('\n');
    }
    out
}
/// Platform-specific path of the system hosts file.
fn get_hosts_path() -> String {
    if cfg!(target_os = "windows") {
        r"C:\Windows\System32\drivers\etc\hosts".to_string()
    } else {
        "/etc/hosts".to_string()
    }
}
/// Validate that a pattern is safe to write to the hosts file.
///
/// Accepts non-empty strings built only from alphanumerics, '.', '-', '_'
/// and '*'. This allow-list inherently rejects whitespace, control
/// characters, '#' and anything else that could inject extra hosts entries.
fn is_valid_hosts_entry(pattern: &str) -> bool {
    !pattern.is_empty()
        && pattern
            .chars()
            .all(|c| c.is_alphanumeric() || matches!(c, '.' | '-' | '_' | '*'))
}

View File

@@ -0,0 +1,12 @@
[package]
name = "csm-protocol"
version.workspace = true
edition.workspace = true
[dependencies]
serde = { workspace = true }
serde_json = { workspace = true }
thiserror = { workspace = true }
uuid = { workspace = true }
chrono = { workspace = true }
anyhow = { workspace = true }

View File

@@ -0,0 +1,108 @@
use serde::{Deserialize, Serialize};
/// Real-time device status report
// NOTE(review): units for the f64 usage fields (percent vs. fraction) and
// the network rates (bytes/sec?) are not established in this file — confirm
// against the client-side collector before relying on them.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct DeviceStatus {
    pub device_uid: String,
    pub cpu_usage: f64,
    pub memory_usage: f64,
    pub memory_total_mb: u64,
    pub disk_usage: f64,
    pub disk_total_mb: u64,
    pub network_rx_rate: u64,
    pub network_tx_rate: u64,
    pub running_procs: u32,
    // Heaviest processes at report time; see `ProcessInfo`.
    pub top_processes: Vec<ProcessInfo>,
}
/// Top process information (one entry in `DeviceStatus::top_processes`)
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ProcessInfo {
    pub name: String,
    pub pid: u32,
    pub cpu_usage: f64,
    // Memory in megabytes (resident vs. virtual not specified here).
    pub memory_mb: u64,
}
/// Hardware asset information (inventory snapshot)
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct HardwareAsset {
    pub device_uid: String,
    pub cpu_model: String,
    pub cpu_cores: u32,
    pub memory_total_mb: u64,
    // NOTE(review): a single disk model/total — presumably the primary disk;
    // confirm how multi-disk machines are reported by the collector.
    pub disk_model: String,
    pub disk_total_mb: u64,
    // Optional: not every platform/collector can resolve these.
    pub gpu_model: Option<String>,
    pub motherboard: Option<String>,
    pub serial_number: Option<String>,
}
/// Software asset information (one installed application)
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct SoftwareAsset {
    pub device_uid: String,
    pub name: String,
    // Optional: registry/installer metadata is frequently incomplete.
    pub version: Option<String>,
    pub publisher: Option<String>,
    pub install_date: Option<String>,
    pub install_path: Option<String>,
}
/// USB device event (client → server)
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct UsbEvent {
    pub device_uid: String,
    pub event_type: UsbEventType,
    // Identifiers optional: not every device/platform exposes them.
    pub vendor_id: Option<String>,
    pub product_id: Option<String>,
    pub serial: Option<String>,
    pub device_name: Option<String>,
}
/// USB event kind; serialized lowercase ("inserted" / "removed" / "blocked").
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum UsbEventType {
    Inserted,
    Removed,
    // Presumably emitted when a device is denied by the USB policy — confirm emitter.
    Blocked,
}
/// Asset change event (client → server, delta against last inventory)
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct AssetChange {
    pub device_uid: String,
    pub change_type: AssetChangeType,
    // Free-form JSON describing the change; schema depends on `change_type`.
    pub change_detail: serde_json::Value,
}
/// Kind of asset change; serialized snake_case
/// ("hardware" / "software_added" / "software_removed").
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum AssetChangeType {
    Hardware,
    SoftwareAdded,
    SoftwareRemoved,
}
/// USB policy (Server → Client)
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct UsbPolicy {
    pub policy_id: i64,
    pub policy_type: UsbPolicyType,
    // Device patterns interpreted per `policy_type` (despite the field name,
    // used for both whitelist and blacklist matching per the enum below).
    pub allowed_devices: Vec<UsbDevicePattern>,
}
/// USB policy mode; serialized snake_case
/// ("all_block" / "whitelist" / "blacklist").
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum UsbPolicyType {
    AllBlock,
    Whitelist,
    Blacklist,
}
/// Matching pattern for a USB device; a `None` field matches anything.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct UsbDevicePattern {
    pub vendor_id: Option<String>,
    pub product_id: Option<String>,
    pub serial: Option<String>,
}

View File

@@ -0,0 +1,27 @@
pub mod message;
pub mod device;
// Re-export constants from message module
pub use message::{MAGIC, PROTOCOL_VERSION, FRAME_HEADER_SIZE, MAX_PAYLOAD_SIZE};
// Core frame & message types
pub use message::{
Frame, FrameError, MessageType,
RegisterRequest, RegisterResponse, ClientConfig,
HeartbeatPayload, TaskExecutePayload, ConfigUpdateType,
};
// Device status & asset types
pub use device::{
DeviceStatus, ProcessInfo, HardwareAsset, SoftwareAsset,
UsbEvent, UsbEventType, AssetChange, AssetChangeType,
UsbPolicy, UsbPolicyType, UsbDevicePattern,
};
// Plugin message payloads
pub use message::{
WebAccessLogEntry, UsageDailyReport, AppUsageEntry,
SoftwareViolationReport, UsbFileOpEntry,
WatermarkConfigPayload, PluginControlPayload,
UsbPolicyPayload, UsbDeviceRule,
};

View File

@@ -0,0 +1,417 @@
use serde::{Deserialize, Serialize};
/// Protocol magic bytes: "CSM\0"
pub const MAGIC: [u8; 4] = [0x43, 0x53, 0x4D, 0x00];
/// Current protocol version
pub const PROTOCOL_VERSION: u8 = 0x01;
/// Frame header size: magic(4) + version(1) + type(1) + length(4)
pub const FRAME_HEADER_SIZE: usize = 10;
/// Maximum payload size: 4 MB — prevents memory exhaustion from malicious frames
pub const MAX_PAYLOAD_SIZE: usize = 4 * 1024 * 1024;
/// Binary message types for client-server communication
///
/// The discriminant is the on-wire type byte. Any variant added here must
/// also be added to the `TryFrom<u8>` impl below, or decoding will reject it.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum MessageType {
    // Client → Server (Core)
    Heartbeat = 0x01,
    Register = 0x02,
    StatusReport = 0x03,
    AssetReport = 0x04,
    AssetChange = 0x05,
    UsbEvent = 0x06,
    AlertAck = 0x07,
    // Server → Client (Core)
    RegisterResponse = 0x08,
    PolicyUpdate = 0x10,
    ConfigUpdate = 0x11,
    TaskExecute = 0x12,
    // Plugin: Web Filter (web access blocking)
    WebFilterRuleUpdate = 0x20,
    WebAccessLog = 0x21,
    // Plugin: Usage Timer (usage-time tracking)
    UsageReport = 0x30,
    AppUsageReport = 0x31,
    // Plugin: Software Blocker (forbidden software installs)
    SoftwareBlacklist = 0x40,
    SoftwareViolation = 0x41,
    // Plugin: Popup Blocker (popup interception)
    PopupRules = 0x50,
    // Plugin: USB File Audit (USB-drive file operation logging)
    UsbFileOp = 0x60,
    // Plugin: Screen Watermark (watermark management)
    WatermarkConfig = 0x70,
    // Plugin: USB Policy (USB-drive control policy)
    UsbPolicyUpdate = 0x71,
    // Plugin control
    PluginEnable = 0x80,
    PluginDisable = 0x81,
}
impl TryFrom<u8> for MessageType {
    type Error = String;
    /// Map a wire type byte back to a `MessageType`.
    /// Must stay in sync with the enum discriminants above; unknown bytes
    /// produce a descriptive error string (wrapped by `FrameError`).
    fn try_from(value: u8) -> Result<Self, Self::Error> {
        match value {
            0x01 => Ok(Self::Heartbeat),
            0x02 => Ok(Self::Register),
            0x03 => Ok(Self::StatusReport),
            0x04 => Ok(Self::AssetReport),
            0x05 => Ok(Self::AssetChange),
            0x06 => Ok(Self::UsbEvent),
            0x07 => Ok(Self::AlertAck),
            0x08 => Ok(Self::RegisterResponse),
            0x10 => Ok(Self::PolicyUpdate),
            0x11 => Ok(Self::ConfigUpdate),
            0x12 => Ok(Self::TaskExecute),
            0x20 => Ok(Self::WebFilterRuleUpdate),
            0x21 => Ok(Self::WebAccessLog),
            0x30 => Ok(Self::UsageReport),
            0x31 => Ok(Self::AppUsageReport),
            0x40 => Ok(Self::SoftwareBlacklist),
            0x41 => Ok(Self::SoftwareViolation),
            0x50 => Ok(Self::PopupRules),
            0x60 => Ok(Self::UsbFileOp),
            0x70 => Ok(Self::WatermarkConfig),
            0x71 => Ok(Self::UsbPolicyUpdate),
            0x80 => Ok(Self::PluginEnable),
            0x81 => Ok(Self::PluginDisable),
            _ => Err(format!("Unknown message type: 0x{:02X}", value)),
        }
    }
}
/// A wire-format frame for transmission over TCP
///
/// Encoded layout: MAGIC(4) | version(1) | msg_type(1) | payload_len(4, BE) | payload.
#[derive(Debug, Clone)]
pub struct Frame {
    pub version: u8,
    pub msg_type: MessageType,
    pub payload: Vec<u8>,
}
impl Frame {
    /// Create a new frame with the current protocol version
    pub fn new(msg_type: MessageType, payload: Vec<u8>) -> Self {
        Self {
            version: PROTOCOL_VERSION,
            msg_type,
            payload,
        }
    }
    /// Create a new frame with JSON-serialized payload
    pub fn new_json<T: Serialize>(msg_type: MessageType, data: &T) -> anyhow::Result<Self> {
        let payload = serde_json::to_vec(data)?;
        Ok(Self::new(msg_type, payload))
    }
    /// Encode frame to bytes for transmission
    /// Format: MAGIC(4) + VERSION(1) + TYPE(1) + LENGTH(4) + PAYLOAD(var)
    /// The length field is big-endian and covers the payload only.
    pub fn encode(&self) -> Vec<u8> {
        let mut buf = Vec::with_capacity(FRAME_HEADER_SIZE + self.payload.len());
        buf.extend_from_slice(&MAGIC);
        buf.push(self.version);
        buf.push(self.msg_type as u8);
        buf.extend_from_slice(&(self.payload.len() as u32).to_be_bytes());
        buf.extend_from_slice(&self.payload);
        buf
    }
    /// Decode frame from bytes. Returns Ok(Some(frame)) when a complete frame is available.
    /// Ok(None) means "need more bytes" (streaming-friendly); an Err is
    /// unrecoverable for this byte stream.
    pub fn decode(data: &[u8]) -> Result<Option<Frame>, FrameError> {
        if data.len() < FRAME_HEADER_SIZE {
            return Ok(None);
        }
        if data[0..4] != MAGIC {
            return Err(FrameError::InvalidMagic);
        }
        // NOTE(review): the version byte is accepted without being compared
        // to PROTOCOL_VERSION — confirm callers handle version mismatches.
        let version = data[4];
        let msg_type_byte = data[5];
        let payload_len = u32::from_be_bytes([data[6], data[7], data[8], data[9]]) as usize;
        // Reject oversized length claims before buffering/allocating the body.
        if payload_len > MAX_PAYLOAD_SIZE {
            return Err(FrameError::PayloadTooLarge(payload_len));
        }
        if data.len() < FRAME_HEADER_SIZE + payload_len {
            return Ok(None);
        }
        let msg_type = MessageType::try_from(msg_type_byte)
            .map_err(|e| FrameError::UnknownMessageType(msg_type_byte, e))?;
        let payload = data[FRAME_HEADER_SIZE..FRAME_HEADER_SIZE + payload_len].to_vec();
        Ok(Some(Frame {
            version,
            msg_type,
            payload,
        }))
    }
    /// Deserialize the payload as JSON
    pub fn decode_payload<T: for<'de> Deserialize<'de>>(&self) -> Result<T, serde_json::Error> {
        serde_json::from_slice(&self.payload)
    }
    /// Total encoded size of this frame
    pub fn encoded_size(&self) -> usize {
        FRAME_HEADER_SIZE + self.payload.len()
    }
}
/// Errors produced while decoding frames off the wire.
#[derive(Debug, thiserror::Error)]
pub enum FrameError {
    #[error("Invalid magic bytes in frame header")]
    InvalidMagic,
    // Carries the raw byte plus the message from `MessageType::try_from`.
    #[error("Unknown message type: 0x{0:02X} - {1}")]
    UnknownMessageType(u8, String),
    // Length field exceeded MAX_PAYLOAD_SIZE; rejected before allocation.
    #[error("Payload too large: {0} bytes (max {})", MAX_PAYLOAD_SIZE)]
    PayloadTooLarge(usize),
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),
}
// ==================== Core Message Payloads ====================
/// Registration request payload (Client → Server, first connection)
#[derive(Debug, Serialize, Deserialize)]
pub struct RegisterRequest {
    pub device_uid: String,
    pub hostname: String,
    // Shared enrollment token proving the client is allowed to register.
    pub registration_token: String,
    pub os_version: String,
    pub mac_address: Option<String>,
}
/// Registration response payload (Server → Client)
#[derive(Debug, Serialize, Deserialize)]
pub struct RegisterResponse {
    // Per-device secret issued at registration; used by the client afterwards
    // (e.g. the heartbeat carries an HMAC — see `HeartbeatPayload`).
    pub device_secret: String,
    pub config: ClientConfig,
}
/// Server-pushed client configuration (reporting cadence, in seconds)
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ClientConfig {
    pub heartbeat_interval_secs: u64,
    pub status_report_interval_secs: u64,
    pub asset_report_interval_secs: u64,
    pub server_version: String,
}
impl Default for ClientConfig {
    /// Defaults: heartbeat 30 s, status 60 s, asset inventory once per day.
    fn default() -> Self {
        Self {
            heartbeat_interval_secs: 30,
            status_report_interval_secs: 60,
            asset_report_interval_secs: 86400,
            // NOTE(review): env! expands at compile time of THIS crate, so this
            // is csm-protocol's version, not necessarily the server binary's —
            // confirm that is the intended semantics.
            server_version: env!("CARGO_PKG_VERSION").to_string(),
        }
    }
}
/// Heartbeat payload (minimal)
#[derive(Debug, Serialize, Deserialize)]
pub struct HeartbeatPayload {
    pub device_uid: String,
    pub timestamp: String,
    // Keyed digest authenticating the heartbeat; the exact input/key scheme
    // is defined by the client/server implementations, not in this crate.
    pub hmac: String,
}
/// Task execution request (Server → Client)
#[derive(Debug, Serialize, Deserialize)]
pub struct TaskExecutePayload {
    // Task discriminator string; interpreted by the client task runner.
    pub task_type: String,
    // Free-form parameters whose schema depends on `task_type`.
    pub params: serde_json::Value,
}
/// Config update types (Server → Client)
#[derive(Debug, Serialize, Deserialize)]
pub enum ConfigUpdateType {
    // New reporting intervals (values presumably in seconds, matching
    // `ClientConfig`'s *_secs fields — confirm at the consumer).
    UpdateIntervals { heartbeat: u64, status: u64, asset: u64 },
    TlsCertRotate,
    SelfDestruct,
}
// ==================== Plugin Message Payloads ====================
/// Plugin: Web Access Log entry (Client → Server)
#[derive(Debug, Serialize, Deserialize)]
pub struct WebAccessLogEntry {
    pub device_uid: String,
    pub url: String,
    pub action: String, // "allowed" | "blocked"
    pub timestamp: String,
}
/// Plugin: Daily Usage Report (Client → Server) — one row per device per day.
#[derive(Debug, Serialize, Deserialize)]
pub struct UsageDailyReport {
    pub device_uid: String,
    pub date: String,
    pub total_active_minutes: u32,
    pub total_idle_minutes: u32,
    // None when no activity was recorded for the day.
    pub first_active_at: Option<String>,
    pub last_active_at: Option<String>,
}
/// Plugin: App Usage Report (Client → Server) — per-application daily minutes.
#[derive(Debug, Serialize, Deserialize)]
pub struct AppUsageEntry {
    pub device_uid: String,
    pub date: String,
    pub app_name: String,
    pub usage_minutes: u32,
}
/// Plugin: Software Violation (Client → Server) — blacklisted software seen.
#[derive(Debug, Serialize, Deserialize)]
pub struct SoftwareViolationReport {
    pub device_uid: String,
    pub software_name: String,
    pub action_taken: String, // "blocked_install" | "auto_uninstalled" | "alerted"
    pub timestamp: String,
}
/// Plugin: USB File Operation (Client → Server)
#[derive(Debug, Serialize, Deserialize)]
pub struct UsbFileOpEntry {
    pub device_uid: String,
    // Optional: the polling-based scanner only knows the drive letter.
    pub usb_serial: Option<String>,
    pub drive_letter: Option<String>,
    pub operation: String, // "create" | "delete" | "rename" | "modify"
    pub file_path: String,
    pub file_size: Option<u64>,
    pub timestamp: String,
}
/// Plugin: Watermark Config (Server → Client)
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct WatermarkConfigPayload {
    // Text tiled across the screen.
    pub content: String,
    pub font_size: u32,
    // 0.0..=1.0 window alpha.
    pub opacity: f64,
    // "#RRGGBB" hex string.
    pub color: String,
    // Rotation in degrees.
    pub angle: i32,
    pub enabled: bool,
}
/// Plugin enable/disable command (Server → Client)
#[derive(Debug, Serialize, Deserialize)]
pub struct PluginControlPayload {
    // Name of the plugin to toggle; matched by the client's plugin registry.
    pub plugin_name: String,
    pub enabled: bool,
}
/// Plugin: USB Policy Config (Server → Client)
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct UsbPolicyPayload {
    pub policy_type: String, // "all_block" | "whitelist" | "blacklist"
    pub enabled: bool,
    // Device patterns interpreted according to `policy_type`.
    pub rules: Vec<UsbDeviceRule>,
}
/// A single USB device rule for whitelist/blacklist matching
/// (a `None` field matches any value).
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct UsbDeviceRule {
    pub vendor_id: Option<String>,
    pub product_id: Option<String>,
    pub serial: Option<String>,
    pub device_name: Option<String>,
}
#[cfg(test)]
mod tests {
    use super::*;
    // Round-trip a raw frame: encode then decode, all fields preserved.
    #[test]
    fn test_frame_encode_decode_roundtrip() {
        let original = Frame::new(MessageType::Heartbeat, b"test payload".to_vec());
        let encoded = original.encode();
        let decoded = Frame::decode(&encoded).unwrap().unwrap();
        assert_eq!(decoded.version, PROTOCOL_VERSION);
        assert_eq!(decoded.msg_type, MessageType::Heartbeat);
        assert_eq!(decoded.payload, b"test payload");
    }
    // Fewer bytes than a header: decode must ask for more data (None), not error.
    #[test]
    fn test_frame_decode_incomplete_data() {
        let data = [0x43, 0x53, 0x4D, 0x01, 0x01];
        let result = Frame::decode(&data).unwrap();
        assert!(result.is_none());
    }
    // Wrong magic bytes are rejected immediately.
    #[test]
    fn test_frame_decode_invalid_magic() {
        let data = [0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x00, 0x00, 0x00, 0x00];
        let result = Frame::decode(&data);
        assert!(matches!(result, Err(FrameError::InvalidMagic)));
    }
    // JSON payload survives an encode/decode round trip.
    #[test]
    fn test_json_frame_roundtrip() {
        let heartbeat = HeartbeatPayload {
            device_uid: "test-uid".to_string(),
            timestamp: "2026-04-03T12:00:00Z".to_string(),
            hmac: "abc123".to_string(),
        };
        let frame = Frame::new_json(MessageType::Heartbeat, &heartbeat).unwrap();
        let encoded = frame.encode();
        let decoded = Frame::decode(&encoded).unwrap().unwrap();
        let parsed: HeartbeatPayload = decoded.decode_payload().unwrap();
        assert_eq!(parsed.device_uid, "test-uid");
        assert_eq!(parsed.hmac, "abc123");
    }
    // Every plugin message type survives the u8 round trip through TryFrom.
    #[test]
    fn test_plugin_message_types_roundtrip() {
        let types = [
            MessageType::WebAccessLog,
            MessageType::UsageReport,
            MessageType::AppUsageReport,
            MessageType::SoftwareViolation,
            MessageType::UsbFileOp,
            MessageType::WatermarkConfig,
            MessageType::PluginEnable,
            MessageType::PluginDisable,
        ];
        for mt in types {
            let frame = Frame::new(mt, vec![1, 2, 3]);
            let encoded = frame.encode();
            let decoded = Frame::decode(&encoded).unwrap().unwrap();
            assert_eq!(decoded.msg_type, mt);
        }
    }
    // Oversized length claims are rejected from the header alone.
    #[test]
    fn test_frame_decode_payload_too_large() {
        // Craft a header that claims a 10 MB payload
        let mut data = Vec::with_capacity(FRAME_HEADER_SIZE);
        data.extend_from_slice(&MAGIC);
        data.push(PROTOCOL_VERSION);
        data.push(MessageType::Heartbeat as u8);
        data.extend_from_slice(&(10 * 1024 * 1024u32).to_be_bytes());
        // Don't actually include the payload — the size check should reject first
        let result = Frame::decode(&data);
        assert!(matches!(result, Err(FrameError::PayloadTooLarge(_))));
    }
}

51
crates/server/Cargo.toml Normal file
View File

@@ -0,0 +1,51 @@
[package]
name = "csm-server"
version.workspace = true
edition.workspace = true
[dependencies]
csm-protocol = { path = "../protocol" }
# Async runtime
tokio = { workspace = true }
# Web framework
axum = { version = "0.7", features = ["ws"] }
tower-http = { version = "0.5", features = ["cors", "fs", "trace", "compression-gzip", "set-header"] }
tower = "0.4"
# Database
sqlx = { version = "0.8", features = ["runtime-tokio", "sqlite"] }
# TLS
rustls = "0.23"
tokio-rustls = "0.26"
rustls-pemfile = "2"
rustls-pki-types = "1"
# Serialization
serde = { workspace = true }
serde_json = { workspace = true }
# Auth
jsonwebtoken = "9"
bcrypt = "0.15"
# Notifications
lettre = { version = "0.11", default-features = false, features = ["tokio1-rustls-tls", "smtp-transport", "builder", "hostname"] }
reqwest = { version = "0.12", features = ["json", "rustls-tls"] }
# Config & logging
toml = "0.8"
tracing = { workspace = true }
tracing-subscriber = { workspace = true }
anyhow = { workspace = true }
# Utilities
uuid = { workspace = true }
chrono = { workspace = true }
thiserror = { workspace = true }
include_dir = "0.7"
hmac = "0.12"
sha2 = "0.10"
hex = "0.4"

118
crates/server/src/alert.rs Normal file
View File

@@ -0,0 +1,118 @@
use crate::AppState;
use tracing::{info, warn, error};
/// Background task for data cleanup and alert processing
///
/// Runs forever on an hourly tick: prunes aged rows according to the
/// configured retention windows, flips stale 'online' devices to 'offline',
/// and truncates the SQLite WAL. Each step logs-and-continues on failure so
/// one broken query cannot block the others.
pub async fn cleanup_task(state: AppState) {
    let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(3600));
    loop {
        interval.tick().await;
        // Cleanup old status history
        // The "-N days" modifier string is passed as a bound parameter to
        // SQLite's datetime() function.
        if let Err(e) = sqlx::query(
            "DELETE FROM device_status_history WHERE reported_at < datetime('now', ?)"
        )
        .bind(format!("-{} days", state.config.retention.status_history_days))
        .execute(&state.db)
        .await
        {
            error!("Failed to cleanup status history: {}", e);
        }
        // Cleanup old USB events
        if let Err(e) = sqlx::query(
            "DELETE FROM usb_events WHERE event_time < datetime('now', ?)"
        )
        .bind(format!("-{} days", state.config.retention.usb_events_days))
        .execute(&state.db)
        .await
        {
            error!("Failed to cleanup USB events: {}", e);
        }
        // Cleanup handled alert records
        // Unhandled alerts are kept indefinitely; only handled ones expire.
        if let Err(e) = sqlx::query(
            "DELETE FROM alert_records WHERE handled = 1 AND triggered_at < datetime('now', ?)"
        )
        .bind(format!("-{} days", state.config.retention.alert_records_days))
        .execute(&state.db)
        .await
        {
            error!("Failed to cleanup alert records: {}", e);
        }
        // Mark devices as offline if no heartbeat for 2 minutes
        if let Err(e) = sqlx::query(
            "UPDATE devices SET status = 'offline' WHERE status = 'online' AND last_heartbeat < datetime('now', '-2 minutes')"
        )
        .execute(&state.db)
        .await
        {
            error!("Failed to mark stale devices offline: {}", e);
        }
        // SQLite WAL checkpoint
        // TRUNCATE reclaims the WAL file; failure is non-fatal (warn only).
        if let Err(e) = sqlx::query("PRAGMA wal_checkpoint(TRUNCATE)")
            .execute(&state.db)
            .await
        {
            warn!("WAL checkpoint failed: {}", e);
        }
        info!("Cleanup cycle completed");
    }
}
/// Send email notification
///
/// Builds an HTML message and delivers it synchronously through a
/// STARTTLS SMTP relay using the configured host, port and credentials.
pub async fn send_email(
    smtp_config: &crate::config::SmtpConfig,
    to: &str,
    subject: &str,
    body: &str,
) -> anyhow::Result<()> {
    use lettre::message::header::ContentType;
    use lettre::transport::smtp::authentication::Credentials;
    use lettre::{Message, SmtpTransport, Transport};
    // Assemble the message; `from`/`to` parsing fails on malformed addresses.
    let message = Message::builder()
        .from(smtp_config.from.parse()?)
        .to(to.parse()?)
        .subject(subject)
        .header(ContentType::TEXT_HTML)
        .body(body.to_string())?;
    let auth = Credentials::new(
        smtp_config.username.clone(),
        smtp_config.password.clone(),
    );
    // STARTTLS relay on the configured port.
    let transport = SmtpTransport::starttls_relay(&smtp_config.host)?
        .port(smtp_config.port)
        .credentials(auth)
        .build();
    transport.send(&message)?;
    Ok(())
}
/// Shared HTTP client for webhook notifications.
/// Initialized lazily on first use and reused afterwards so all webhook
/// deliveries share one connection pool.
static WEBHOOK_CLIENT: std::sync::OnceLock<reqwest::Client> = std::sync::OnceLock::new();

/// Return the process-wide webhook HTTP client (10 s request timeout).
fn webhook_client() -> &'static reqwest::Client {
    WEBHOOK_CLIENT.get_or_init(|| {
        let built = reqwest::Client::builder()
            .timeout(std::time::Duration::from_secs(10))
            .build();
        // Fall back to a default client if the builder somehow fails.
        match built {
            Ok(client) => client,
            Err(_) => reqwest::Client::new(),
        }
    })
}
/// Send webhook notification
///
/// POSTs `payload` as JSON to `url` with the shared client; returns an error
/// on connection failure or timeout (HTTP status is not inspected here).
pub async fn send_webhook(url: &str, payload: &serde_json::Value) -> anyhow::Result<()> {
    let client = webhook_client();
    let request = client.post(url).json(payload);
    request.send().await?;
    Ok(())
}

View File

@@ -0,0 +1,243 @@
use axum::{extract::{State, Path, Query, Json}, http::StatusCode};
use serde::Deserialize;
use sqlx::Row;
use crate::AppState;
use super::ApiResponse;
use super::auth::Claims;
/// Query-string filters for `GET /api/alerts/records`.
/// All filters are optional; empty strings are normalized to None in the handler.
#[derive(Debug, Deserialize)]
pub struct AlertRecordListParams {
    pub device_uid: Option<String>,
    pub alert_type: Option<String>,
    pub severity: Option<String>,
    // 0/1 flag matching the SQLite `handled` column
    pub handled: Option<i32>,
    // 1-based page number; defaults to 1
    pub page: Option<u32>,
    // capped at 100 by the handler
    pub page_size: Option<u32>,
}
/// `GET /api/alerts/rules` — list all alert rules, newest first.
///
/// Returns `{"rules": [...]}`; no pagination (rules are expected to be few).
pub async fn list_rules(
    State(state): State<AppState>,
) -> Json<ApiResponse<serde_json::Value>> {
    let rows = sqlx::query(
        "SELECT id, name, rule_type, condition, severity, enabled, notify_email, notify_webhook, created_at, updated_at
        FROM alert_rules ORDER BY created_at DESC"
    )
    .fetch_all(&state.db)
    .await;
    match rows {
        Ok(records) => {
            // Manually project each row into JSON (no FromRow struct for this shape).
            let items: Vec<serde_json::Value> = records.iter().map(|r| serde_json::json!({
                "id": r.get::<i64, _>("id"),
                "name": r.get::<String, _>("name"),
                "rule_type": r.get::<String, _>("rule_type"),
                "condition": r.get::<String, _>("condition"),
                "severity": r.get::<String, _>("severity"),
                "enabled": r.get::<i32, _>("enabled"),
                "notify_email": r.get::<Option<String>, _>("notify_email"),
                "notify_webhook": r.get::<Option<String>, _>("notify_webhook"),
                "created_at": r.get::<String, _>("created_at"),
                "updated_at": r.get::<String, _>("updated_at"),
            })).collect();
            Json(ApiResponse::ok(serde_json::json!({
                "rules": items,
            })))
        }
        Err(e) => Json(ApiResponse::internal_error("query alert rules", e)),
    }
}
/// `GET /api/alerts/records` — paged alert-record listing with optional filters.
///
/// Each filter uses the `(? IS NULL OR col = ?)` pattern, so every filter
/// value is bound twice and the positional bind order below must match the
/// SQL placeholder order exactly.
pub async fn list_records(
    State(state): State<AppState>,
    Query(params): Query<AlertRecordListParams>,
) -> Json<ApiResponse<serde_json::Value>> {
    // Page size capped at 100; page is 1-based.
    let limit = params.page_size.unwrap_or(20).min(100);
    let offset = params.page.unwrap_or(1).saturating_sub(1) * limit;
    // Normalize empty strings to None (Axum deserializes `key=` as Some(""))
    let device_uid = params.device_uid.as_deref().filter(|s| !s.is_empty()).map(String::from);
    let alert_type = params.alert_type.as_deref().filter(|s| !s.is_empty()).map(String::from);
    let severity = params.severity.as_deref().filter(|s| !s.is_empty()).map(String::from);
    let handled = params.handled;
    let rows = sqlx::query(
        "SELECT id, rule_id, device_uid, alert_type, severity, detail, handled, handled_by, handled_at, triggered_at
        FROM alert_records WHERE 1=1
        AND (? IS NULL OR device_uid = ?)
        AND (? IS NULL OR alert_type = ?)
        AND (? IS NULL OR severity = ?)
        AND (? IS NULL OR handled = ?)
        ORDER BY triggered_at DESC LIMIT ? OFFSET ?"
    )
    .bind(&device_uid).bind(&device_uid)
    .bind(&alert_type).bind(&alert_type)
    .bind(&severity).bind(&severity)
    .bind(&handled).bind(&handled)
    .bind(limit).bind(offset)
    .fetch_all(&state.db)
    .await;
    match rows {
        Ok(records) => {
            let items: Vec<serde_json::Value> = records.iter().map(|r| serde_json::json!({
                "id": r.get::<i64, _>("id"),
                "rule_id": r.get::<Option<i64>, _>("rule_id"),
                "device_uid": r.get::<Option<String>, _>("device_uid"),
                "alert_type": r.get::<String, _>("alert_type"),
                "severity": r.get::<String, _>("severity"),
                "detail": r.get::<String, _>("detail"),
                "handled": r.get::<i32, _>("handled"),
                "handled_by": r.get::<Option<String>, _>("handled_by"),
                "handled_at": r.get::<Option<String>, _>("handled_at"),
                "triggered_at": r.get::<String, _>("triggered_at"),
            })).collect();
            // NOTE(review): no total count is returned, unlike the devices list — confirm intended.
            Json(ApiResponse::ok(serde_json::json!({
                "records": items,
                "page": params.page.unwrap_or(1),
                "page_size": limit,
            })))
        }
        Err(e) => Json(ApiResponse::internal_error("query alert records", e)),
    }
}
/// JSON body for `POST /api/alerts/rules`.
#[derive(Debug, Deserialize)]
pub struct CreateRuleRequest {
    pub name: String,
    pub rule_type: String,
    // rule condition, stored verbatim in the `condition` column
    pub condition: String,
    // defaults to "medium" when absent
    pub severity: Option<String>,
    pub notify_email: Option<String>,
    pub notify_webhook: Option<String>,
}
/// `POST /api/alerts/rules` — create a new alert rule.
///
/// Returns 201 with the new row id on success, 500 on database failure.
pub async fn create_rule(
    State(state): State<AppState>,
    Json(body): Json<CreateRuleRequest>,
) -> (StatusCode, Json<ApiResponse<serde_json::Value>>) {
    // Default severity when the client omits it.
    let severity = body.severity.unwrap_or_else(|| "medium".to_string());
    let result = sqlx::query(
        "INSERT INTO alert_rules (name, rule_type, condition, severity, notify_email, notify_webhook)
        VALUES (?, ?, ?, ?, ?, ?)"
    )
    .bind(&body.name)
    .bind(&body.rule_type)
    .bind(&body.condition)
    .bind(&severity)
    .bind(&body.notify_email)
    .bind(&body.notify_webhook)
    .execute(&state.db)
    .await;
    match result {
        Ok(r) => (StatusCode::CREATED, Json(ApiResponse::ok(serde_json::json!({
            "id": r.last_insert_rowid(),
        })))),
        Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::internal_error("create alert rule", e))),
    }
}
/// JSON body for `PUT /api/alerts/rules/:id`.
/// All fields optional; omitted fields keep their current database values.
#[derive(Debug, Deserialize)]
pub struct UpdateRuleRequest {
    pub name: Option<String>,
    pub rule_type: Option<String>,
    pub condition: Option<String>,
    pub severity: Option<String>,
    // 0/1 toggle
    pub enabled: Option<i32>,
    pub notify_email: Option<String>,
    pub notify_webhook: Option<String>,
}
/// `PUT /api/alerts/rules/:id` — partial update of an alert rule.
///
/// Reads the existing row first, merges in any fields present in the body,
/// then writes the full row back and bumps `updated_at`.
pub async fn update_rule(
    State(state): State<AppState>,
    Path(id): Path<i64>,
    Json(body): Json<UpdateRuleRequest>,
) -> Json<ApiResponse<serde_json::Value>> {
    let existing = sqlx::query("SELECT * FROM alert_rules WHERE id = ?")
        .bind(id)
        .fetch_optional(&state.db)
        .await;
    let existing = match existing {
        Ok(Some(row)) => row,
        Ok(None) => return Json(ApiResponse::error("Rule not found")),
        Err(e) => return Json(ApiResponse::internal_error("query alert rule", e)),
    };
    // Merge: body value wins, otherwise keep the stored value.
    let name = body.name.unwrap_or_else(|| existing.get::<String, _>("name"));
    let rule_type = body.rule_type.unwrap_or_else(|| existing.get::<String, _>("rule_type"));
    let condition = body.condition.unwrap_or_else(|| existing.get::<String, _>("condition"));
    let severity = body.severity.unwrap_or_else(|| existing.get::<String, _>("severity"));
    let enabled = body.enabled.unwrap_or_else(|| existing.get::<i32, _>("enabled"));
    // NOTE(review): because `None` falls back to the stored value, clients cannot
    // clear notify_email/notify_webhook via this endpoint — confirm this is intended.
    let notify_email = body.notify_email.or_else(|| existing.get::<Option<String>, _>("notify_email"));
    let notify_webhook = body.notify_webhook.or_else(|| existing.get::<Option<String>, _>("notify_webhook"));
    let result = sqlx::query(
        "UPDATE alert_rules SET name = ?, rule_type = ?, condition = ?, severity = ?, enabled = ?,
        notify_email = ?, notify_webhook = ?, updated_at = datetime('now') WHERE id = ?"
    )
    .bind(&name)
    .bind(&rule_type)
    .bind(&condition)
    .bind(&severity)
    .bind(enabled)
    .bind(&notify_email)
    .bind(&notify_webhook)
    .bind(id)
    .execute(&state.db)
    .await;
    match result {
        Ok(_) => Json(ApiResponse::ok(serde_json::json!({"updated": true}))),
        Err(e) => Json(ApiResponse::internal_error("update alert rule", e)),
    }
}
pub async fn delete_rule(
State(state): State<AppState>,
Path(id): Path<i64>,
) -> Json<ApiResponse<serde_json::Value>> {
let result = sqlx::query("DELETE FROM alert_rules WHERE id = ?")
.bind(id)
.execute(&state.db)
.await;
match result {
Ok(r) => {
if r.rows_affected() > 0 {
Json(ApiResponse::ok(serde_json::json!({"deleted": true})))
} else {
Json(ApiResponse::error("Rule not found"))
}
}
Err(e) => Json(ApiResponse::internal_error("delete alert rule", e)),
}
}
/// `PUT /api/alerts/records/:id/handle` — mark an alert record as handled.
///
/// Records the acting admin's username (from the JWT claims injected by the
/// auth middleware) and the handling timestamp.
pub async fn handle_record(
    State(state): State<AppState>,
    Path(id): Path<i64>,
    claims: axum::Extension<Claims>,
) -> Json<ApiResponse<serde_json::Value>> {
    let handled_by = &claims.username;
    let result = sqlx::query(
        "UPDATE alert_records SET handled = 1, handled_by = ?, handled_at = datetime('now') WHERE id = ?"
    )
    .bind(handled_by)
    .bind(id)
    .execute(&state.db)
    .await;
    match result {
        Ok(r) => {
            // rows_affected == 0 means the record id does not exist.
            if r.rows_affected() > 0 {
                Json(ApiResponse::ok(serde_json::json!({"handled": true})))
            } else {
                Json(ApiResponse::error("Alert record not found"))
            }
        }
        Err(e) => Json(ApiResponse::internal_error("handle alert record", e)),
    }
}

View File

@@ -0,0 +1,143 @@
use axum::{extract::{State, Query}, Json};
use serde::Deserialize;
use sqlx::Row;
use crate::AppState;
use super::{ApiResponse, Pagination};
/// Query-string filters shared by the hardware and software asset listings.
#[derive(Debug, Deserialize)]
pub struct AssetListParams {
    pub device_uid: Option<String>,
    // free-text search; matched with LIKE against model/name columns
    pub search: Option<String>,
    pub page: Option<u32>,
    pub page_size: Option<u32>,
}
/// `GET /api/assets/hardware` — paged hardware-asset listing.
///
/// Optional device filter and a free-text search over cpu_model/gpu_model
/// (the search value is bound three times: once for the NULL check, twice
/// for the LIKE comparisons).
pub async fn list_hardware(
    State(state): State<AppState>,
    Query(params): Query<AssetListParams>,
) -> Json<ApiResponse<serde_json::Value>> {
    let limit = params.page_size.unwrap_or(20).min(100);
    let offset = params.page.unwrap_or(1).saturating_sub(1) * limit;
    // Normalize empty strings to None
    let device_uid = params.device_uid.as_deref().filter(|s| !s.is_empty()).map(String::from);
    let search = params.search.as_deref().filter(|s| !s.is_empty()).map(String::from);
    let rows = sqlx::query(
        "SELECT id, device_uid, cpu_model, cpu_cores, memory_total_mb, disk_model, disk_total_mb,
        gpu_model, motherboard, serial_number, reported_at
        FROM hardware_assets WHERE 1=1
        AND (? IS NULL OR device_uid = ?)
        AND (? IS NULL OR cpu_model LIKE '%' || ? || '%' OR gpu_model LIKE '%' || ? || '%')
        ORDER BY reported_at DESC LIMIT ? OFFSET ?"
    )
    .bind(&device_uid).bind(&device_uid)
    .bind(&search).bind(&search).bind(&search)
    .bind(limit).bind(offset)
    .fetch_all(&state.db)
    .await;
    match rows {
        Ok(records) => {
            let items: Vec<serde_json::Value> = records.iter().map(|r| serde_json::json!({
                "id": r.get::<i64, _>("id"),
                "device_uid": r.get::<String, _>("device_uid"),
                "cpu_model": r.get::<String, _>("cpu_model"),
                "cpu_cores": r.get::<i32, _>("cpu_cores"),
                "memory_total_mb": r.get::<i64, _>("memory_total_mb"),
                "disk_model": r.get::<String, _>("disk_model"),
                "disk_total_mb": r.get::<i64, _>("disk_total_mb"),
                "gpu_model": r.get::<Option<String>, _>("gpu_model"),
                "motherboard": r.get::<Option<String>, _>("motherboard"),
                "serial_number": r.get::<Option<String>, _>("serial_number"),
                "reported_at": r.get::<String, _>("reported_at"),
            })).collect();
            Json(ApiResponse::ok(serde_json::json!({
                "hardware": items,
                "page": params.page.unwrap_or(1),
                "page_size": limit,
            })))
        }
        Err(e) => Json(ApiResponse::internal_error("query hardware assets", e)),
    }
}
/// `GET /api/assets/software` — paged software-asset listing.
///
/// Same filter pattern as `list_hardware`, but searches name/publisher and
/// orders alphabetically by name.
pub async fn list_software(
    State(state): State<AppState>,
    Query(params): Query<AssetListParams>,
) -> Json<ApiResponse<serde_json::Value>> {
    let limit = params.page_size.unwrap_or(20).min(100);
    let offset = params.page.unwrap_or(1).saturating_sub(1) * limit;
    // Normalize empty strings to None
    let device_uid = params.device_uid.as_deref().filter(|s| !s.is_empty()).map(String::from);
    let search = params.search.as_deref().filter(|s| !s.is_empty()).map(String::from);
    let rows = sqlx::query(
        "SELECT id, device_uid, name, version, publisher, install_date, install_path
        FROM software_assets WHERE 1=1
        AND (? IS NULL OR device_uid = ?)
        AND (? IS NULL OR name LIKE '%' || ? || '%' OR publisher LIKE '%' || ? || '%')
        ORDER BY name ASC LIMIT ? OFFSET ?"
    )
    .bind(&device_uid).bind(&device_uid)
    .bind(&search).bind(&search).bind(&search)
    .bind(limit).bind(offset)
    .fetch_all(&state.db)
    .await;
    match rows {
        Ok(records) => {
            let items: Vec<serde_json::Value> = records.iter().map(|r| serde_json::json!({
                "id": r.get::<i64, _>("id"),
                "device_uid": r.get::<String, _>("device_uid"),
                "name": r.get::<String, _>("name"),
                "version": r.get::<Option<String>, _>("version"),
                "publisher": r.get::<Option<String>, _>("publisher"),
                "install_date": r.get::<Option<String>, _>("install_date"),
                "install_path": r.get::<Option<String>, _>("install_path"),
            })).collect();
            Json(ApiResponse::ok(serde_json::json!({
                "software": items,
                "page": params.page.unwrap_or(1),
                "page_size": limit,
            })))
        }
        Err(e) => Json(ApiResponse::internal_error("query software assets", e)),
    }
}
/// `GET /api/assets/changes` — paged asset-change log across all devices,
/// newest first. Uses the shared `Pagination` extractor rather than the
/// asset-specific params (no filters supported here).
pub async fn list_changes(
    State(state): State<AppState>,
    Query(page): Query<Pagination>,
) -> Json<ApiResponse<serde_json::Value>> {
    let offset = page.offset();
    let limit = page.limit();
    let rows = sqlx::query(
        "SELECT id, device_uid, change_type, change_detail, detected_at
        FROM asset_changes ORDER BY detected_at DESC LIMIT ? OFFSET ?"
    )
    .bind(limit)
    .bind(offset)
    .fetch_all(&state.db)
    .await;
    match rows {
        Ok(records) => {
            let items: Vec<serde_json::Value> = records.iter().map(|r| serde_json::json!({
                "id": r.get::<i64, _>("id"),
                "device_uid": r.get::<String, _>("device_uid"),
                "change_type": r.get::<String, _>("change_type"),
                "change_detail": r.get::<String, _>("change_detail"),
                "detected_at": r.get::<String, _>("detected_at"),
            })).collect();
            Json(ApiResponse::ok(serde_json::json!({
                "changes": items,
                "page": page.page.unwrap_or(1),
                "page_size": limit,
            })))
        }
        Err(e) => Json(ApiResponse::internal_error("query asset changes", e)),
    }
}

View File

@@ -0,0 +1,295 @@
use axum::{extract::State, Json, http::StatusCode, extract::Request, middleware::Next, response::Response};
use serde::{Deserialize, Serialize};
use jsonwebtoken::{encode, decode, Header, EncodingKey, DecodingKey, Validation};
use std::sync::Arc;
use std::collections::HashMap;
use std::time::Instant;
use tokio::sync::Mutex;
use crate::AppState;
use super::ApiResponse;
/// JWT claims used for both access and refresh tokens.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Claims {
    // user id (subject)
    pub sub: i64,
    pub username: String,
    // "admin" or a read-only role; checked by require_admin
    pub role: String,
    // expiry, seconds since epoch
    pub exp: u64,
    // issued-at, seconds since epoch
    pub iat: u64,
    // "access" or "refresh"; handlers reject the wrong kind
    pub token_type: String,
    /// Random family ID for refresh token rotation detection
    pub family: String,
}
/// Body of `POST /api/auth/login`.
#[derive(Debug, Deserialize)]
pub struct LoginRequest {
    pub username: String,
    pub password: String,
}
/// Successful login/refresh payload: a fresh token pair plus basic user info.
#[derive(Debug, Serialize)]
pub struct LoginResponse {
    pub access_token: String,
    pub refresh_token: String,
    pub user: UserInfo,
}
/// Public subset of the `users` row (never includes the password hash).
#[derive(Debug, Serialize, sqlx::FromRow)]
pub struct UserInfo {
    pub id: i64,
    pub username: String,
    pub role: String,
}
/// Body of `POST /api/auth/refresh`.
#[derive(Debug, Deserialize)]
pub struct RefreshRequest {
    pub refresh_token: String,
}
/// In-memory rate limiter for login attempts
///
/// Tracks, per key (username), the start of the current window and the number
/// of attempts within it. State is process-local and lost on restart.
#[derive(Clone, Default)]
pub struct LoginRateLimiter {
    // key -> (window start, attempt count); shared across handler clones
    attempts: Arc<Mutex<HashMap<String, (Instant, u32)>>>,
}
impl LoginRateLimiter {
    pub fn new() -> Self {
        Self::default()
    }
    /// Returns true if the request should be rate-limited
    ///
    /// Allows up to 10 attempts per 5-minute fixed window. Every call that is
    /// not limited counts as an attempt (including ones that later succeed).
    pub async fn is_limited(&self, key: &str) -> bool {
        let mut attempts = self.attempts.lock().await;
        let now = Instant::now();
        let window = std::time::Duration::from_secs(300); // 5-minute window
        let max_attempts = 10u32;
        if let Some((first_attempt, count)) = attempts.get_mut(key) {
            if now.duration_since(*first_attempt) > window {
                // Window expired, reset
                *first_attempt = now;
                *count = 1;
                false
            } else if *count >= max_attempts {
                true // Rate limited
            } else {
                *count += 1;
                false
            }
        } else {
            attempts.insert(key.to_string(), (now, 1));
            // Cleanup old entries periodically — only triggered on the insert
            // path, so the map is bounded at roughly 1000 live keys.
            if attempts.len() > 1000 {
                let cutoff = now - window;
                attempts.retain(|_, (t, _)| *t > cutoff);
            }
            false
        }
    }
}
/// `POST /api/auth/login` — verify credentials and issue a token pair.
///
/// Flow: rate-limit by username → load user row → verify bcrypt hash →
/// mint access + refresh tokens sharing a fresh rotation family →
/// best-effort audit-log entry.
pub async fn login(
    State(state): State<AppState>,
    Json(req): Json<LoginRequest>,
) -> Result<(StatusCode, Json<ApiResponse<LoginResponse>>), StatusCode> {
    // Rate limit check
    if state.login_limiter.is_limited(&req.username).await {
        return Ok((StatusCode::TOO_MANY_REQUESTS, Json(ApiResponse::error("Too many login attempts. Try again later."))));
    }
    let user: Option<UserInfo> = sqlx::query_as::<_, UserInfo>(
        "SELECT id, username, role FROM users WHERE username = ?"
    )
    .bind(&req.username)
    .fetch_optional(&state.db)
    .await
    .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
    // NOTE(review): unknown users return before the bcrypt check runs, which
    // makes the two failure paths timing-distinguishable (user enumeration) —
    // consider hashing a dummy password in the None branch.
    let user = match user {
        Some(u) => u,
        None => return Ok((StatusCode::UNAUTHORIZED, Json(ApiResponse::error("Invalid credentials")))),
    };
    // Fetch the hash separately so UserInfo never carries the password.
    let hash: String = sqlx::query_scalar::<_, String>(
        "SELECT password FROM users WHERE id = ?"
    )
    .bind(user.id)
    .fetch_one(&state.db)
    .await
    .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
    if !bcrypt::verify(&req.password, &hash).unwrap_or(false) {
        return Ok((StatusCode::UNAUTHORIZED, Json(ApiResponse::error("Invalid credentials"))));
    }
    let now = chrono::Utc::now().timestamp() as u64;
    // Both tokens share one rotation family; refresh() revokes it on use.
    let family = uuid::Uuid::new_v4().to_string();
    let access_token = create_token(&user, "access", state.config.auth.access_token_ttl_secs, now, &state.config.auth.jwt_secret, &family)?;
    let refresh_token = create_token(&user, "refresh", state.config.auth.refresh_token_ttl_secs, now, &state.config.auth.jwt_secret, &family)?;
    // Audit log (best-effort: a failure here must not block login)
    let _ = sqlx::query(
        "INSERT INTO admin_audit_log (user_id, action, detail) VALUES (?, 'login', ?)"
    )
    .bind(user.id)
    .bind(format!("User {} logged in", user.username))
    .execute(&state.db)
    .await;
    Ok((StatusCode::OK, Json(ApiResponse::ok(LoginResponse {
        access_token,
        refresh_token,
        user,
    }))))
}
/// `POST /api/auth/refresh` — rotate a refresh token into a new token pair.
///
/// Implements refresh-token rotation with reuse detection: each refresh
/// revokes the presented token's family, so presenting it a second time hits
/// the revoked-families table and forces a full re-login.
pub async fn refresh(
    State(state): State<AppState>,
    Json(req): Json<RefreshRequest>,
) -> Result<(StatusCode, Json<ApiResponse<LoginResponse>>), StatusCode> {
    // Signature + expiry validation; reject outright if the JWT is invalid.
    let claims = decode::<Claims>(
        &req.refresh_token,
        &DecodingKey::from_secret(state.config.auth.jwt_secret.as_bytes()),
        &Validation::default(),
    )
    .map_err(|_| StatusCode::UNAUTHORIZED)?;
    // Access tokens must not be usable as refresh tokens.
    if claims.claims.token_type != "refresh" {
        return Ok((StatusCode::UNAUTHORIZED, Json(ApiResponse::error("Invalid token type"))));
    }
    // Check if this refresh token family has been revoked (reuse detection)
    let revoked: bool = sqlx::query_scalar::<_, i64>(
        "SELECT COUNT(*) FROM revoked_token_families WHERE family = ?"
    )
    .bind(&claims.claims.family)
    .fetch_one(&state.db)
    .await
    .unwrap_or(0) > 0;
    if revoked {
        // Token reuse detected — revoke entire family and force re-login
        tracing::warn!("Refresh token reuse detected for user {} family {}", claims.claims.sub, claims.claims.family);
        let _ = sqlx::query("DELETE FROM refresh_tokens WHERE user_id = ?")
            .bind(claims.claims.sub)
            .execute(&state.db)
            .await;
        return Ok((StatusCode::UNAUTHORIZED, Json(ApiResponse::error("Token reuse detected. Please log in again."))));
    }
    // User info comes straight from the verified claims; no DB lookup here.
    let user = UserInfo {
        id: claims.claims.sub,
        username: claims.claims.username,
        role: claims.claims.role,
    };
    // Rotate: new family for each refresh
    let new_family = uuid::Uuid::new_v4().to_string();
    let now = chrono::Utc::now().timestamp() as u64;
    let access_token = create_token(&user, "access", state.config.auth.access_token_ttl_secs, now, &state.config.auth.jwt_secret, &new_family)?;
    let refresh_token = create_token(&user, "refresh", state.config.auth.refresh_token_ttl_secs, now, &state.config.auth.jwt_secret, &new_family)?;
    // Revoke old family (best-effort; INSERT OR IGNORE tolerates races)
    let _ = sqlx::query("INSERT OR IGNORE INTO revoked_token_families (family, user_id, revoked_at) VALUES (?, ?, datetime('now'))")
        .bind(&claims.claims.family)
        .bind(claims.claims.sub)
        .execute(&state.db)
        .await;
    Ok((StatusCode::OK, Json(ApiResponse::ok(LoginResponse {
        access_token,
        refresh_token,
        user,
    }))))
}
/// Build and HMAC-sign a JWT for `user`.
///
/// `token_type` is "access" or "refresh"; `exp` is `now + ttl` (epoch
/// seconds); `family` links the access/refresh pair for rotation tracking.
/// Signing failure is mapped to a 500 so callers can `?` it.
fn create_token(user: &UserInfo, token_type: &str, ttl: u64, now: u64, secret: &str, family: &str) -> Result<String, StatusCode> {
    let payload = Claims {
        sub: user.id,
        username: user.username.clone(),
        role: user.role.clone(),
        iat: now,
        exp: now + ttl,
        token_type: token_type.to_owned(),
        family: family.to_owned(),
    };
    let key = EncodingKey::from_secret(secret.as_bytes());
    match encode(&Header::default(), &payload, &key) {
        Ok(token) => Ok(token),
        Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR),
    }
}
/// Axum middleware: require valid JWT access token
///
/// Extracts `Authorization: Bearer <token>`, validates signature/expiry,
/// rejects non-"access" tokens, and injects the decoded `Claims` into the
/// request extensions for downstream handlers and `require_admin`.
pub async fn require_auth(
    State(state): State<AppState>,
    mut request: Request,
    next: Next,
) -> Result<Response, StatusCode> {
    let auth_header = request.headers()
        .get("Authorization")
        .and_then(|v| v.to_str().ok())
        .and_then(|v| v.strip_prefix("Bearer "));
    let token = match auth_header {
        Some(t) => t,
        None => return Err(StatusCode::UNAUTHORIZED),
    };
    let claims = decode::<Claims>(
        token,
        &DecodingKey::from_secret(state.config.auth.jwt_secret.as_bytes()),
        &Validation::default(),
    )
    .map_err(|_| StatusCode::UNAUTHORIZED)?;
    // Refresh tokens must not grant API access.
    if claims.claims.token_type != "access" {
        return Err(StatusCode::UNAUTHORIZED);
    }
    // Inject claims into request extensions for handlers to use
    request.extensions_mut().insert(claims.claims);
    Ok(next.run(request).await)
}
/// Axum middleware: require admin role for write operations + audit log
///
/// Must run AFTER `require_auth` (it reads the `Claims` extension that
/// require_auth inserts). Successful admin requests are recorded in
/// `admin_audit_log`; the write is best-effort and does not fail the response.
pub async fn require_admin(
    State(state): State<AppState>,
    request: Request,
    next: Next,
) -> Result<Response, StatusCode> {
    // Missing Claims means require_auth did not run → treat as unauthorized.
    let claims = request.extensions()
        .get::<Claims>()
        .ok_or(StatusCode::UNAUTHORIZED)?;
    if claims.role != "admin" {
        return Err(StatusCode::FORBIDDEN);
    }
    // Capture audit info before running handler (request is consumed by next.run)
    let method = request.method().clone();
    let path = request.uri().path().to_string();
    let user_id = claims.sub;
    let username = claims.username.clone();
    let response = next.run(request).await;
    // Record admin action to audit log (fire and forget — don't block response)
    let status = response.status();
    if status.is_success() {
        let action = format!("{} {}", method, path);
        let detail = format!("by {}", username);
        let _ = sqlx::query(
            "INSERT INTO admin_audit_log (user_id, action, detail) VALUES (?, ?, ?)"
        )
        .bind(user_id)
        .bind(&action)
        .bind(&detail)
        .execute(&state.db)
        .await;
    }
    Ok(response)
}

View File

@@ -0,0 +1,263 @@
use axum::{extract::{State, Path, Query}, Json};
use serde::{Deserialize, Serialize};
use sqlx::Row;
use crate::AppState;
use super::{ApiResponse, Pagination};
/// Query-string filters for `GET /api/devices`.
#[derive(Debug, Deserialize)]
pub struct DeviceListParams {
    // e.g. "online"/"offline"
    pub status: Option<String>,
    // matches devices.group_name
    pub group: Option<String>,
    // LIKE search over hostname and ip_address
    pub search: Option<String>,
    pub page: Option<u32>,
    pub page_size: Option<u32>,
}
/// Row shape shared by the device list and device detail endpoints,
/// mapped straight from the `devices` table.
#[derive(Debug, Serialize, sqlx::FromRow)]
pub struct DeviceRow {
    pub id: i64,
    pub device_uid: String,
    pub hostname: String,
    pub ip_address: String,
    pub mac_address: Option<String>,
    pub os_version: Option<String>,
    pub client_version: Option<String>,
    // "online"/"offline" (see cleanup task for the offline transition)
    pub status: String,
    pub last_heartbeat: Option<String>,
    pub registered_at: Option<String>,
    pub group_name: Option<String>,
}
/// `GET /api/devices` — paged device listing with status/group/search filters
/// plus a total count for pagination UI.
pub async fn list(
    State(state): State<AppState>,
    Query(params): Query<DeviceListParams>,
) -> Json<ApiResponse<serde_json::Value>> {
    let limit = params.page_size.unwrap_or(20).min(100);
    let offset = params.page.unwrap_or(1).saturating_sub(1) * limit;
    // Normalize empty strings to None (Axum deserializes `status=` as Some(""))
    let status = params.status.as_deref().filter(|s| !s.is_empty()).map(String::from);
    let group = params.group.as_deref().filter(|s| !s.is_empty()).map(String::from);
    let search = params.search.as_deref().filter(|s| !s.is_empty()).map(String::from);
    // Filters use the `(? IS NULL OR …)` pattern; bind order must mirror the SQL.
    let devices = sqlx::query_as::<_, DeviceRow>(
        "SELECT id, device_uid, hostname, ip_address, mac_address, os_version, client_version,
        status, last_heartbeat, registered_at, group_name
        FROM devices WHERE 1=1
        AND (? IS NULL OR status = ?)
        AND (? IS NULL OR group_name = ?)
        AND (? IS NULL OR hostname LIKE '%' || ? || '%' OR ip_address LIKE '%' || ? || '%')
        ORDER BY registered_at DESC LIMIT ? OFFSET ?"
    )
    .bind(&status).bind(&status)
    .bind(&group).bind(&group)
    .bind(&search).bind(&search).bind(&search)
    .bind(limit).bind(offset)
    .fetch_all(&state.db)
    .await;
    // Total with the same filters; errors collapse to 0.
    // NOTE(review): this count runs even when the main query failed — could be
    // moved inside the Ok arm to save a round trip on error.
    let total: i64 = sqlx::query_scalar(
        "SELECT COUNT(*) FROM devices WHERE 1=1
        AND (? IS NULL OR status = ?)
        AND (? IS NULL OR group_name = ?)
        AND (? IS NULL OR hostname LIKE '%' || ? || '%' OR ip_address LIKE '%' || ? || '%')"
    )
    .bind(&status).bind(&status)
    .bind(&group).bind(&group)
    .bind(&search).bind(&search).bind(&search)
    .fetch_one(&state.db)
    .await
    .unwrap_or(0);
    match devices {
        Ok(rows) => Json(ApiResponse::ok(serde_json::json!({
            "devices": rows,
            "total": total,
            "page": params.page.unwrap_or(1),
            "page_size": limit,
        }))),
        Err(e) => Json(ApiResponse::internal_error("query devices", e)),
    }
}
/// `GET /api/devices/:uid` — single device row by its unique client UID.
pub async fn get_detail(
    State(state): State<AppState>,
    Path(uid): Path<String>,
) -> Json<ApiResponse<serde_json::Value>> {
    let device = sqlx::query_as::<_, DeviceRow>(
        "SELECT id, device_uid, hostname, ip_address, mac_address, os_version, client_version,
        status, last_heartbeat, registered_at, group_name
        FROM devices WHERE device_uid = ?"
    )
    .bind(&uid)
    .fetch_optional(&state.db)
    .await;
    match device {
        // Serialization of DeviceRow cannot realistically fail; default = null on the off chance.
        Ok(Some(d)) => Json(ApiResponse::ok(serde_json::to_value(d).unwrap_or_default())),
        Ok(None) => Json(ApiResponse::error("Device not found")),
        Err(e) => Json(ApiResponse::internal_error("query devices", e)),
    }
}
/// Latest telemetry snapshot for one device (`device_status` table).
#[derive(Debug, Serialize, sqlx::FromRow)]
struct StatusRow {
    // percentages
    pub cpu_usage: f64,
    pub memory_usage: f64,
    pub memory_total_mb: i64,
    pub disk_usage: f64,
    pub disk_total_mb: i64,
    // bytes/sec as reported by the client — TODO confirm units
    pub network_rx_rate: i64,
    pub network_tx_rate: i64,
    pub running_procs: i32,
    // JSON-encoded array stored as TEXT; re-parsed in get_status
    pub top_processes: Option<String>,
    pub reported_at: String,
}
/// `GET /api/devices/:uid/status` — latest telemetry snapshot for a device.
///
/// The `top_processes` column stores a JSON array as TEXT; it is re-parsed
/// here so the client receives a real array instead of a string.
pub async fn get_status(
    State(state): State<AppState>,
    Path(uid): Path<String>,
) -> Json<ApiResponse<serde_json::Value>> {
    let status = sqlx::query_as::<_, StatusRow>(
        "SELECT cpu_usage, memory_usage, memory_total_mb, disk_usage, disk_total_mb,
        network_rx_rate, network_tx_rate, running_procs, top_processes, reported_at
        FROM device_status WHERE device_uid = ?"
    )
    .bind(&uid)
    .fetch_optional(&state.db)
    .await;
    match status {
        Ok(Some(s)) => {
            let mut val = serde_json::to_value(&s).unwrap_or_default();
            // Parse top_processes JSON string back to array; leave the raw
            // string in place if it fails to parse.
            if let Some(tp_str) = &s.top_processes {
                if let Ok(tp) = serde_json::from_str::<serde_json::Value>(tp_str) {
                    val["top_processes"] = tp;
                }
            }
            Json(ApiResponse::ok(val))
        }
        Ok(None) => Json(ApiResponse::error("No status data found")),
        // Fixed log context: was the copy-pasted "query devices".
        Err(e) => Json(ApiResponse::internal_error("query device status", e)),
    }
}
/// `GET /api/devices/:uid/history` — paged telemetry history for a device,
/// newest first, using the shared `Pagination` extractor.
pub async fn get_history(
    State(state): State<AppState>,
    Path(uid): Path<String>,
    Query(page): Query<Pagination>,
) -> Json<ApiResponse<serde_json::Value>> {
    let offset = page.offset();
    let limit = page.limit();
    let rows = sqlx::query(
        "SELECT cpu_usage, memory_usage, disk_usage, running_procs, reported_at
        FROM device_status_history WHERE device_uid = ?
        ORDER BY reported_at DESC LIMIT ? OFFSET ?"
    )
    .bind(&uid)
    .bind(limit)
    .bind(offset)
    .fetch_all(&state.db)
    .await;
    match rows {
        Ok(records) => {
            let items: Vec<serde_json::Value> = records.iter().map(|r| {
                serde_json::json!({
                    "cpu_usage": r.get::<f64, _>("cpu_usage"),
                    "memory_usage": r.get::<f64, _>("memory_usage"),
                    "disk_usage": r.get::<f64, _>("disk_usage"),
                    "running_procs": r.get::<i32, _>("running_procs"),
                    "reported_at": r.get::<String, _>("reported_at"),
                })
            }).collect();
            Json(ApiResponse::ok(serde_json::json!({
                "history": items,
                "page": page.page.unwrap_or(1),
                "page_size": limit,
            })))
        }
        // Fixed log context: was the copy-pasted "query devices".
        Err(e) => Json(ApiResponse::internal_error("query device status history", e)),
    }
}
/// `DELETE /api/devices/:uid` — remove a device and all of its data.
///
/// Best-effort sends a SelfDestruct command to the live client, then deletes
/// status, history, plugin tables and finally the device row inside a single
/// transaction. Early returns before commit roll the transaction back
/// (dropped without commit). Statement order matters; do not reorder.
pub async fn remove(
    State(state): State<AppState>,
    Path(uid): Path<String>,
) -> Json<ApiResponse<()>> {
    // If client is connected, send self-destruct command
    let frame = csm_protocol::Frame::new_json(
        csm_protocol::MessageType::ConfigUpdate,
        &serde_json::json!({"type": "SelfDestruct"}),
    ).ok();
    if let Some(frame) = frame {
        state.clients.send_to(&uid, frame.encode()).await;
    }
    // Delete device and all associated data in a transaction
    let mut tx = match state.db.begin().await {
        Ok(tx) => tx,
        Err(e) => return Json(ApiResponse::internal_error("begin transaction", e)),
    };
    // Delete status history
    if let Err(e) = sqlx::query("DELETE FROM device_status_history WHERE device_uid = ?")
        .bind(&uid)
        .execute(&mut *tx)
        .await
    {
        return Json(ApiResponse::internal_error("remove device history", e));
    }
    // Delete current status
    if let Err(e) = sqlx::query("DELETE FROM device_status WHERE device_uid = ?")
        .bind(&uid)
        .execute(&mut *tx)
        .await
    {
        return Json(ApiResponse::internal_error("remove device status", e));
    }
    // Delete plugin-related data. Table names come from this fixed list, never
    // from user input, so the format! below is not an injection risk.
    let cleanup_tables = [
        "hardware_assets",
        "usb_events",
        "usb_file_operations",
        "usage_daily",
        "app_usage_daily",
        "software_violations",
        "web_access_log",
        "popup_block_stats",
    ];
    for table in &cleanup_tables {
        // Plugin-table failures are logged but do not abort the deletion.
        if let Err(e) = sqlx::query(&format!("DELETE FROM {} WHERE device_uid = ?", table))
            .bind(&uid)
            .execute(&mut *tx)
            .await
        {
            tracing::warn!("Failed to clean {} for device {}: {}", table, uid, e);
        }
    }
    // Finally delete the device itself
    let delete_result = sqlx::query("DELETE FROM devices WHERE device_uid = ?")
        .bind(&uid)
        .execute(&mut *tx)
        .await;
    match delete_result {
        Ok(r) if r.rows_affected() > 0 => {
            if let Err(e) = tx.commit().await {
                return Json(ApiResponse::internal_error("commit device deletion", e));
            }
            // Drop the live connection entry only after the DB commit succeeded.
            state.clients.unregister(&uid).await;
            tracing::info!(device_uid = %uid, "Device and all associated data deleted");
            Json(ApiResponse::ok(()))
        }
        Ok(_) => Json(ApiResponse::error("Device not found")),
        Err(e) => Json(ApiResponse::internal_error("remove device", e)),
    }
}

View File

@@ -0,0 +1,120 @@
use axum::{routing::{get, post, put, delete}, Router, Json, extract::State, middleware};
use serde::{Deserialize, Serialize};
use crate::AppState;
pub mod auth;
pub mod devices;
pub mod assets;
pub mod usb;
pub mod alerts;
pub mod plugins;
/// Assemble the full HTTP router: public endpoints, authenticated read-only
/// endpoints, admin-only write endpoints, and the WebSocket upgrade route.
///
/// Middleware layering is deliberate — see the comment on the write routes.
pub fn routes(state: AppState) -> Router<AppState> {
    // No auth required: login, token refresh, liveness probe.
    let public = Router::new()
        .route("/api/auth/login", post(auth::login))
        .route("/api/auth/refresh", post(auth::refresh))
        .route("/health", get(health_check))
        .with_state(state.clone());
    // Read-only routes (any authenticated user)
    let read_routes = Router::new()
        // Devices
        .route("/api/devices", get(devices::list))
        .route("/api/devices/:uid", get(devices::get_detail))
        .route("/api/devices/:uid/status", get(devices::get_status))
        .route("/api/devices/:uid/history", get(devices::get_history))
        // Assets
        .route("/api/assets/hardware", get(assets::list_hardware))
        .route("/api/assets/software", get(assets::list_software))
        .route("/api/assets/changes", get(assets::list_changes))
        // USB (read)
        .route("/api/usb/events", get(usb::list_events))
        .route("/api/usb/policies", get(usb::list_policies))
        // Alerts (read)
        .route("/api/alerts/rules", get(alerts::list_rules))
        .route("/api/alerts/records", get(alerts::list_records))
        // Plugin read routes
        .merge(plugins::read_routes())
        .layer(middleware::from_fn_with_state(state.clone(), auth::require_auth));
    // Write routes (admin only)
    let write_routes = Router::new()
        // Devices
        .route("/api/devices/:uid", delete(devices::remove))
        // USB (write)
        .route("/api/usb/policies", post(usb::create_policy))
        .route("/api/usb/policies/:id", put(usb::update_policy).delete(usb::delete_policy))
        // Alerts (write)
        .route("/api/alerts/rules", post(alerts::create_rule))
        .route("/api/alerts/rules/:id", put(alerts::update_rule).delete(alerts::delete_rule))
        .route("/api/alerts/records/:id/handle", put(alerts::handle_record))
        // Plugin write routes (already has require_admin layer internally)
        .merge(plugins::write_routes())
        // Layer order: outer (require_admin) runs AFTER inner (require_auth)
        // so require_auth sets Claims extension first, then require_admin checks it
        .layer(middleware::from_fn_with_state(state.clone(), auth::require_admin))
        .layer(middleware::from_fn_with_state(state.clone(), auth::require_auth));
    // WebSocket has its own JWT auth via query parameter
    let ws_router = Router::new()
        .route("/ws", get(crate::ws::ws_handler))
        .with_state(state.clone());
    Router::new()
        .merge(public)
        .merge(read_routes)
        .merge(write_routes)
        .merge(ws_router)
}
/// Body of the `/health` liveness probe.
#[derive(Serialize)]
struct HealthResponse {
    status: &'static str,
}
/// `GET /health` — unauthenticated liveness probe; always returns {"status":"ok"}.
async fn health_check() -> Json<HealthResponse> {
    Json(HealthResponse {
        status: "ok",
    })
}
/// Standard API response envelope
#[derive(Serialize)]
pub struct ApiResponse<T: Serialize> {
    pub success: bool,         // true on success, false on error
    pub data: Option<T>,       // payload; Some only when success == true
    pub error: Option<String>, // message; Some only when success == false
}
impl<T: Serialize> ApiResponse<T> {
pub fn ok(data: T) -> Self {
Self { success: true, data: Some(data), error: None }
}
pub fn error(msg: impl Into<String>) -> Self {
Self { success: false, data: None, error: Some(msg.into()) }
}
/// Log internal error and return sanitized message to client
pub fn internal_error(context: &str, e: impl std::fmt::Display) -> Self {
tracing::error!("{}: {}", context, e);
Self { success: false, data: None, error: Some("Internal server error".to_string()) }
}
}
/// Pagination parameters
#[derive(Debug, Deserialize)]
pub struct Pagination {
    pub page: Option<u32>,      // 1-based page number; defaults to 1
    pub page_size: Option<u32>, // rows per page; defaults to 20, capped at 100
}
impl Pagination {
    /// Row offset for the current page: `(page - 1) * limit`, with 1-based pages.
    /// Uses saturating arithmetic throughout — a huge client-supplied `page`
    /// previously could overflow the u32 multiplication (a panic in debug builds).
    pub fn offset(&self) -> u32 {
        self.page
            .unwrap_or(1)
            .saturating_sub(1)
            .saturating_mul(self.limit())
    }
    /// Effective page size: defaults to 20, capped at 100 rows.
    pub fn limit(&self) -> u32 {
        self.page_size.unwrap_or(20).min(100)
    }
}

View File

@@ -0,0 +1,49 @@
pub mod web_filter;
pub mod usage_timer;
pub mod software_blocker;
pub mod popup_blocker;
pub mod usb_file_audit;
pub mod watermark;
use axum::{Router, routing::{get, post, put}};
use crate::AppState;
/// Read-only plugin routes (accessible by admin + viewer)
///
/// Merged into the read router by the caller, which applies the
/// `require_auth` middleware; no auth layering happens here.
pub fn read_routes() -> Router<AppState> {
    Router::new()
        // Web Filter
        .route("/api/plugins/web-filter/rules", get(web_filter::list_rules))
        .route("/api/plugins/web-filter/log", get(web_filter::list_access_log))
        // Usage Timer
        .route("/api/plugins/usage-timer/daily", get(usage_timer::list_daily))
        .route("/api/plugins/usage-timer/app-usage", get(usage_timer::list_app_usage))
        .route("/api/plugins/usage-timer/leaderboard", get(usage_timer::leaderboard))
        // Software Blocker
        .route("/api/plugins/software-blocker/blacklist", get(software_blocker::list_blacklist))
        .route("/api/plugins/software-blocker/violations", get(software_blocker::list_violations))
        // Popup Blocker
        .route("/api/plugins/popup-blocker/rules", get(popup_blocker::list_rules))
        .route("/api/plugins/popup-blocker/stats", get(popup_blocker::list_stats))
        // USB File Audit
        .route("/api/plugins/usb-file-audit/log", get(usb_file_audit::list_operations))
        .route("/api/plugins/usb-file-audit/summary", get(usb_file_audit::summary))
        // Watermark
        .route("/api/plugins/watermark/config", get(watermark::get_config_list))
}
/// Write plugin routes (admin only — require_admin middleware applied by caller)
///
/// Create/update/delete endpoints for plugin configuration; the caller layers
/// `require_auth` + `require_admin` on the merged router.
pub fn write_routes() -> Router<AppState> {
    Router::new()
        // Web Filter
        .route("/api/plugins/web-filter/rules", post(web_filter::create_rule))
        .route("/api/plugins/web-filter/rules/:id", put(web_filter::update_rule).delete(web_filter::delete_rule))
        // Software Blocker
        .route("/api/plugins/software-blocker/blacklist", post(software_blocker::add_to_blacklist))
        .route("/api/plugins/software-blocker/blacklist/:id", put(software_blocker::update_blacklist).delete(software_blocker::remove_from_blacklist))
        // Popup Blocker
        .route("/api/plugins/popup-blocker/rules", post(popup_blocker::create_rule))
        .route("/api/plugins/popup-blocker/rules/:id", put(popup_blocker::update_rule).delete(popup_blocker::delete_rule))
        // Watermark
        .route("/api/plugins/watermark/config", post(watermark::create_config))
        .route("/api/plugins/watermark/config/:id", put(watermark::update_config).delete(watermark::delete_config))
}

View File

@@ -0,0 +1,155 @@
use axum::{extract::{State, Path, Json}, http::StatusCode};
use serde::Deserialize;
use sqlx::Row;
use csm_protocol::MessageType;
use crate::AppState;
use crate::api::ApiResponse;
use crate::tcp::push_to_targets;
/// Request body for creating a popup filter rule.
/// At least one of window_title/window_class/process_name must be non-empty.
#[derive(Debug, Deserialize)]
pub struct CreateRuleRequest {
    pub rule_type: String, // "block" | "allow"
    pub window_title: Option<String>,  // substring/pattern match on window title
    pub window_class: Option<String>,  // match on window class name
    pub process_name: Option<String>,  // match on owning process name
    pub target_type: Option<String>,   // "global" | "device" | "group"; defaults to "global"
    pub target_id: Option<String>,     // device uid / group id when target_type is not "global"
}
/// GET /api/plugins/popup-blocker/rules — all popup filter rules, newest first.
pub async fn list_rules(State(state): State<AppState>) -> Json<ApiResponse<serde_json::Value>> {
    match sqlx::query("SELECT id, rule_type, window_title, window_class, process_name, target_type, target_id, enabled, created_at FROM popup_filter_rules ORDER BY created_at DESC")
        .fetch_all(&state.db).await {
        Ok(rows) => Json(ApiResponse::ok(serde_json::json!({"rules": rows.iter().map(|r| serde_json::json!({
            "id": r.get::<i64,_>("id"), "rule_type": r.get::<String,_>("rule_type"),
            "window_title": r.get::<Option<String>,_>("window_title"),
            "window_class": r.get::<Option<String>,_>("window_class"),
            "process_name": r.get::<Option<String>,_>("process_name"),
            "target_type": r.get::<String,_>("target_type"), "target_id": r.get::<Option<String>,_>("target_id"),
            "enabled": r.get::<bool,_>("enabled"), "created_at": r.get::<String,_>("created_at")
        })).collect::<Vec<_>>()}))),
        Err(e) => Json(ApiResponse::internal_error("query popup filter rules", e)),
    }
}
/// POST /api/plugins/popup-blocker/rules — validate, insert, then push the
/// refreshed effective rule set to affected online clients.
pub async fn create_rule(State(state): State<AppState>, Json(req): Json<CreateRuleRequest>) -> (StatusCode, Json<ApiResponse<serde_json::Value>>) {
    let target_type = req.target_type.unwrap_or_else(|| "global".to_string());
    // Validate inputs
    if !matches!(req.rule_type.as_str(), "block" | "allow") {
        return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("rule_type must be 'block' or 'allow'")));
    }
    if !matches!(target_type.as_str(), "global" | "device" | "group") {
        return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("invalid target_type")));
    }
    // Reject rules that would match nothing: at least one non-empty filter.
    let has_filter = req.window_title.as_ref().map_or(false, |s| !s.is_empty())
        || req.window_class.as_ref().map_or(false, |s| !s.is_empty())
        || req.process_name.as_ref().map_or(false, |s| !s.is_empty());
    if !has_filter {
        return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("at least one filter (window_title/window_class/process_name) required")));
    }
    match sqlx::query("INSERT INTO popup_filter_rules (rule_type, window_title, window_class, process_name, target_type, target_id) VALUES (?,?,?,?,?,?)")
        .bind(&req.rule_type).bind(&req.window_title).bind(&req.window_class).bind(&req.process_name).bind(&target_type).bind(&req.target_id)
        .execute(&state.db).await {
        Ok(r) => {
            let new_id = r.last_insert_rowid();
            // Recompute the effective rule set for this scope and push so
            // online clients pick up the change immediately.
            let rules = fetch_popup_rules_for_push(&state.db, &target_type, req.target_id.as_deref()).await;
            push_to_targets(&state.db, &state.clients, MessageType::PopupRules, &serde_json::json!({"rules": rules}), &target_type, req.target_id.as_deref()).await;
            (StatusCode::CREATED, Json(ApiResponse::ok(serde_json::json!({"id": new_id}))))
        }
        Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::internal_error("create popup filter rule", e))),
    }
}
/// Partial-update body for a popup filter rule; omitted fields keep stored values.
#[derive(Debug, Deserialize)]
pub struct UpdateRuleRequest { pub window_title: Option<String>, pub window_class: Option<String>, pub process_name: Option<String>, pub enabled: Option<bool> }
/// PUT /api/plugins/popup-blocker/rules/:id — merge request fields over the
/// stored row (PATCH-like), update, then push the refreshed rule set.
pub async fn update_rule(State(state): State<AppState>, Path(id): Path<i64>, Json(body): Json<UpdateRuleRequest>) -> Json<ApiResponse<()>> {
    let existing = sqlx::query("SELECT * FROM popup_filter_rules WHERE id = ?")
        .bind(id)
        .fetch_optional(&state.db)
        .await;
    let existing = match existing {
        Ok(Some(row)) => row,
        Ok(None) => return Json(ApiResponse::error("Not found")),
        Err(e) => return Json(ApiResponse::internal_error("query popup filter rule", e)),
    };
    // NOTE(review): `or_else` keeps the stored value when a field is omitted,
    // which also means a client cannot clear a filter back to NULL — confirm intended.
    let window_title = body.window_title.or_else(|| existing.get::<Option<String>, _>("window_title"));
    let window_class = body.window_class.or_else(|| existing.get::<Option<String>, _>("window_class"));
    let process_name = body.process_name.or_else(|| existing.get::<Option<String>, _>("process_name"));
    let enabled = body.enabled.unwrap_or_else(|| existing.get::<bool, _>("enabled"));
    let result = sqlx::query("UPDATE popup_filter_rules SET window_title = ?, window_class = ?, process_name = ?, enabled = ? WHERE id = ?")
        .bind(&window_title)
        .bind(&window_class)
        .bind(&process_name)
        .bind(enabled)
        .bind(id)
        .execute(&state.db)
        .await;
    match result {
        Ok(r) if r.rows_affected() > 0 => {
            // Push the updated effective rule set to the rule's target scope.
            let target_type_val: String = existing.get("target_type");
            let target_id_val: Option<String> = existing.get("target_id");
            let rules = fetch_popup_rules_for_push(&state.db, &target_type_val, target_id_val.as_deref()).await;
            push_to_targets(&state.db, &state.clients, MessageType::PopupRules, &serde_json::json!({"rules": rules}), &target_type_val, target_id_val.as_deref()).await;
            Json(ApiResponse::ok(()))
        }
        Ok(_) => Json(ApiResponse::error("Not found")),
        Err(e) => Json(ApiResponse::internal_error("update popup filter rule", e)),
    }
}
pub async fn delete_rule(State(state): State<AppState>, Path(id): Path<i64>) -> Json<ApiResponse<()>> {
let existing = sqlx::query("SELECT target_type, target_id FROM popup_filter_rules WHERE id = ?")
.bind(id).fetch_optional(&state.db).await;
let (target_type, target_id) = match existing {
Ok(Some(row)) => (row.get::<String, _>("target_type"), row.get::<Option<String>, _>("target_id")),
_ => return Json(ApiResponse::error("Not found")),
};
match sqlx::query("DELETE FROM popup_filter_rules WHERE id=?").bind(id).execute(&state.db).await {
Ok(r) if r.rows_affected() > 0 => {
let rules = fetch_popup_rules_for_push(&state.db, &target_type, target_id.as_deref()).await;
push_to_targets(&state.db, &state.clients, MessageType::PopupRules, &serde_json::json!({"rules": rules}), &target_type, target_id.as_deref()).await;
Json(ApiResponse::ok(()))
}
_ => Json(ApiResponse::error("Not found")),
}
}
/// GET /api/plugins/popup-blocker/stats — latest 30 daily block-count rows.
pub async fn list_stats(State(state): State<AppState>) -> Json<ApiResponse<serde_json::Value>> {
    match sqlx::query("SELECT device_uid, blocked_count, date FROM popup_block_stats ORDER BY date DESC LIMIT 30")
        .fetch_all(&state.db).await {
        Ok(rows) => Json(ApiResponse::ok(serde_json::json!({"stats": rows.iter().map(|r| serde_json::json!({
            "device_uid": r.get::<String,_>("device_uid"), "blocked_count": r.get::<i32,_>("blocked_count"),
            "date": r.get::<String,_>("date")
        })).collect::<Vec<_>>()}))),
        Err(e) => Json(ApiResponse::internal_error("query popup block stats", e)),
    }
}
/// Build the enabled popup rule set to push for the given target scope.
/// "device" targets get global + device-specific rules; any other value
/// (including "group") falls back to global-only rules.
/// NOTE(review): group targets receive no group-specific rules here — confirm
/// whether group expansion is handled elsewhere (e.g. inside push_to_targets).
/// Query errors degrade to an empty rule list rather than failing the caller.
async fn fetch_popup_rules_for_push(
    db: &sqlx::SqlitePool,
    target_type: &str,
    target_id: Option<&str>,
) -> Vec<serde_json::Value> {
    let query = match target_type {
        "device" => sqlx::query(
            "SELECT id, rule_type, window_title, window_class, process_name FROM popup_filter_rules WHERE enabled = 1 AND (target_type = 'global' OR (target_type = 'device' AND target_id = ?))"
        ).bind(target_id),
        _ => sqlx::query(
            "SELECT id, rule_type, window_title, window_class, process_name FROM popup_filter_rules WHERE enabled = 1 AND target_type = 'global'"
        ),
    };
    query.fetch_all(db).await
        .map(|rows| rows.iter().map(|r| serde_json::json!({
            "id": r.get::<i64,_>("id"), "rule_type": r.get::<String,_>("rule_type"),
            "window_title": r.get::<Option<String>,_>("window_title"),
            "window_class": r.get::<Option<String>,_>("window_class"),
            "process_name": r.get::<Option<String>,_>("process_name"),
        })).collect())
        .unwrap_or_default()
}

View File

@@ -0,0 +1,155 @@
use axum::{extract::{State, Path, Query, Json}, http::StatusCode};
use serde::Deserialize;
use sqlx::Row;
use csm_protocol::MessageType;
use crate::AppState;
use crate::api::ApiResponse;
use crate::tcp::push_to_targets;
/// Request body for adding a software blacklist entry.
#[derive(Debug, Deserialize)]
pub struct CreateBlacklistRequest {
    pub name_pattern: String,        // process/software name pattern, 1-255 chars
    pub category: Option<String>,    // "game" | "social" | "vpn" | "mining" | "custom"
    pub action: Option<String>,      // "block" | "alert"; defaults to "block"
    pub target_type: Option<String>, // "global" | "device" | "group"; defaults to "global"
    pub target_id: Option<String>,   // device uid / group id when not global
}
/// GET /api/plugins/software-blocker/blacklist — all entries, newest first.
pub async fn list_blacklist(State(state): State<AppState>) -> Json<ApiResponse<serde_json::Value>> {
    match sqlx::query("SELECT id, name_pattern, category, action, target_type, target_id, enabled, created_at FROM software_blacklist ORDER BY created_at DESC")
        .fetch_all(&state.db).await {
        Ok(rows) => Json(ApiResponse::ok(serde_json::json!({"blacklist": rows.iter().map(|r| serde_json::json!({
            "id": r.get::<i64,_>("id"), "name_pattern": r.get::<String,_>("name_pattern"),
            "category": r.get::<Option<String>,_>("category"), "action": r.get::<String,_>("action"),
            "target_type": r.get::<String,_>("target_type"), "target_id": r.get::<Option<String>,_>("target_id"),
            "enabled": r.get::<bool,_>("enabled"), "created_at": r.get::<String,_>("created_at")
        })).collect::<Vec<_>>()}))),
        Err(e) => Json(ApiResponse::internal_error("query software blacklist", e)),
    }
}
/// POST /api/plugins/software-blocker/blacklist — validate, insert, then push
/// the refreshed blacklist to affected online clients.
pub async fn add_to_blacklist(State(state): State<AppState>, Json(req): Json<CreateBlacklistRequest>) -> (StatusCode, Json<ApiResponse<serde_json::Value>>) {
    let action = req.action.unwrap_or_else(|| "block".to_string());
    let target_type = req.target_type.unwrap_or_else(|| "global".to_string());
    // Validate inputs
    // Note: length check is in bytes (String::len), not characters.
    if req.name_pattern.trim().is_empty() || req.name_pattern.len() > 255 {
        return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("name_pattern must be 1-255 chars")));
    }
    if !matches!(action.as_str(), "block" | "alert") {
        return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("action must be 'block' or 'alert'")));
    }
    if let Some(ref cat) = req.category {
        if !matches!(cat.as_str(), "game" | "social" | "vpn" | "mining" | "custom") {
            return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("invalid category")));
        }
    }
    if !matches!(target_type.as_str(), "global" | "device" | "group") {
        return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("invalid target_type")));
    }
    match sqlx::query("INSERT INTO software_blacklist (name_pattern, category, action, target_type, target_id) VALUES (?,?,?,?,?)")
        .bind(&req.name_pattern).bind(&req.category).bind(&action).bind(&target_type).bind(&req.target_id)
        .execute(&state.db).await {
        Ok(r) => {
            let new_id = r.last_insert_rowid();
            // Recompute the effective blacklist for this scope and push it live.
            let blacklist = fetch_blacklist_for_push(&state.db, &target_type, req.target_id.as_deref()).await;
            push_to_targets(&state.db, &state.clients, MessageType::SoftwareBlacklist, &serde_json::json!({"blacklist": blacklist}), &target_type, req.target_id.as_deref()).await;
            (StatusCode::CREATED, Json(ApiResponse::ok(serde_json::json!({"id": new_id}))))
        }
        Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::internal_error("add software blacklist entry", e))),
    }
}
/// Partial-update body for a blacklist entry; omitted fields keep stored values.
#[derive(Debug, Deserialize)]
pub struct UpdateBlacklistRequest { pub name_pattern: Option<String>, pub action: Option<String>, pub enabled: Option<bool> }
/// PUT /api/plugins/software-blocker/blacklist/:id — merge request fields over
/// the stored row, update, then push the refreshed blacklist.
pub async fn update_blacklist(State(state): State<AppState>, Path(id): Path<i64>, Json(body): Json<UpdateBlacklistRequest>) -> Json<ApiResponse<()>> {
    let existing = sqlx::query("SELECT * FROM software_blacklist WHERE id = ?")
        .bind(id)
        .fetch_optional(&state.db)
        .await;
    let existing = match existing {
        Ok(Some(row)) => row,
        Ok(None) => return Json(ApiResponse::error("Not found")),
        Err(e) => return Json(ApiResponse::internal_error("query software blacklist", e)),
    };
    // Merge request fields over the existing row (PATCH-like semantics).
    // NOTE(review): unlike add_to_blacklist, the merged name_pattern/action are
    // not re-validated here — confirm whether that is intentional.
    let name_pattern = body.name_pattern.unwrap_or_else(|| existing.get::<String, _>("name_pattern"));
    let action = body.action.unwrap_or_else(|| existing.get::<String, _>("action"));
    let enabled = body.enabled.unwrap_or_else(|| existing.get::<bool, _>("enabled"));
    let result = sqlx::query("UPDATE software_blacklist SET name_pattern = ?, action = ?, enabled = ? WHERE id = ?")
        .bind(&name_pattern)
        .bind(&action)
        .bind(enabled)
        .bind(id)
        .execute(&state.db)
        .await;
    match result {
        Ok(r) if r.rows_affected() > 0 => {
            let target_type_val: String = existing.get("target_type");
            let target_id_val: Option<String> = existing.get("target_id");
            let blacklist = fetch_blacklist_for_push(&state.db, &target_type_val, target_id_val.as_deref()).await;
            push_to_targets(&state.db, &state.clients, MessageType::SoftwareBlacklist, &serde_json::json!({"blacklist": blacklist}), &target_type_val, target_id_val.as_deref()).await;
            Json(ApiResponse::ok(()))
        }
        Ok(_) => Json(ApiResponse::error("Not found")),
        Err(e) => Json(ApiResponse::internal_error("update software blacklist", e)),
    }
}
pub async fn remove_from_blacklist(State(state): State<AppState>, Path(id): Path<i64>) -> Json<ApiResponse<()>> {
let existing = sqlx::query("SELECT target_type, target_id FROM software_blacklist WHERE id = ?")
.bind(id).fetch_optional(&state.db).await;
let (target_type, target_id) = match existing {
Ok(Some(row)) => (row.get::<String, _>("target_type"), row.get::<Option<String>, _>("target_id")),
_ => return Json(ApiResponse::error("Not found")),
};
match sqlx::query("DELETE FROM software_blacklist WHERE id=?").bind(id).execute(&state.db).await {
Ok(r) if r.rows_affected() > 0 => {
let blacklist = fetch_blacklist_for_push(&state.db, &target_type, target_id.as_deref()).await;
push_to_targets(&state.db, &state.clients, MessageType::SoftwareBlacklist, &serde_json::json!({"blacklist": blacklist}), &target_type, target_id.as_deref()).await;
Json(ApiResponse::ok(()))
}
_ => Json(ApiResponse::error("Not found")),
}
}
/// Query-string filters for the violations listing.
#[derive(Debug, Deserialize)]
pub struct ViolationFilters { pub device_uid: Option<String> }
/// GET /api/plugins/software-blocker/violations — latest 200 violations,
/// optionally filtered by device (the `? IS NULL OR` idiom makes the filter optional).
pub async fn list_violations(State(state): State<AppState>, Query(f): Query<ViolationFilters>) -> Json<ApiResponse<serde_json::Value>> {
    match sqlx::query("SELECT id, device_uid, software_name, action_taken, timestamp FROM software_violations WHERE (? IS NULL OR device_uid=?) ORDER BY timestamp DESC LIMIT 200")
        .bind(&f.device_uid).bind(&f.device_uid)
        .fetch_all(&state.db).await {
        Ok(rows) => Json(ApiResponse::ok(serde_json::json!({"violations": rows.iter().map(|r| serde_json::json!({
            "id": r.get::<i64,_>("id"), "device_uid": r.get::<String,_>("device_uid"),
            "software_name": r.get::<String,_>("software_name"), "action_taken": r.get::<String,_>("action_taken"),
            "timestamp": r.get::<String,_>("timestamp")
        })).collect::<Vec<_>>()}))),
        Err(e) => Json(ApiResponse::internal_error("query software violations", e)),
    }
}
/// Build the enabled blacklist payload to push for the given target scope.
/// "device" targets get global + device-specific entries; any other value
/// (including "group") falls back to global-only entries.
/// NOTE(review): group targets receive no group-specific entries here — confirm
/// whether group expansion is handled elsewhere (e.g. inside push_to_targets).
/// Query errors degrade to an empty list rather than failing the caller.
async fn fetch_blacklist_for_push(
    db: &sqlx::SqlitePool,
    target_type: &str,
    target_id: Option<&str>,
) -> Vec<serde_json::Value> {
    let query = match target_type {
        "device" => sqlx::query(
            "SELECT id, name_pattern, action FROM software_blacklist WHERE enabled = 1 AND (target_type = 'global' OR (target_type = 'device' AND target_id = ?))"
        ).bind(target_id),
        _ => sqlx::query(
            "SELECT id, name_pattern, action FROM software_blacklist WHERE enabled = 1 AND target_type = 'global'"
        ),
    };
    query.fetch_all(db).await
        .map(|rows| rows.iter().map(|r| serde_json::json!({
            "id": r.get::<i64,_>("id"), "name_pattern": r.get::<String,_>("name_pattern"), "action": r.get::<String,_>("action")
        })).collect())
        .unwrap_or_default()
}

View File

@@ -0,0 +1,60 @@
use axum::{extract::{State, Query}, Json};
use serde::Deserialize;
use sqlx::Row;
use crate::AppState;
use crate::api::ApiResponse;
/// Query-string filters for daily usage (dates as stored strings, e.g. YYYY-MM-DD).
#[derive(Debug, Deserialize)]
pub struct DailyFilters { pub device_uid: Option<String>, pub start_date: Option<String>, pub end_date: Option<String> }
/// GET /api/plugins/usage-timer/daily — up to 90 daily usage rows, newest first.
/// Each filter is optional via the `? IS NULL OR column = ?` idiom.
pub async fn list_daily(State(state): State<AppState>, Query(f): Query<DailyFilters>) -> Json<ApiResponse<serde_json::Value>> {
    match sqlx::query(
        "SELECT id, device_uid, date, total_active_minutes, total_idle_minutes, first_active_at, last_active_at
         FROM usage_daily WHERE (? IS NULL OR device_uid=?) AND (? IS NULL OR date>=?) AND (? IS NULL OR date<=?)
         ORDER BY date DESC LIMIT 90")
        .bind(&f.device_uid).bind(&f.device_uid)
        .bind(&f.start_date).bind(&f.start_date)
        .bind(&f.end_date).bind(&f.end_date)
        .fetch_all(&state.db).await {
        Ok(rows) => Json(ApiResponse::ok(serde_json::json!({"daily": rows.iter().map(|r| serde_json::json!({
            "id": r.get::<i64,_>("id"), "device_uid": r.get::<String,_>("device_uid"),
            "date": r.get::<String,_>("date"), "total_active_minutes": r.get::<i32,_>("total_active_minutes"),
            "total_idle_minutes": r.get::<i32,_>("total_idle_minutes"),
            "first_active_at": r.get::<Option<String>,_>("first_active_at"),
            "last_active_at": r.get::<Option<String>,_>("last_active_at")
        })).collect::<Vec<_>>()}))),
        Err(e) => Json(ApiResponse::internal_error("query daily usage", e)),
    }
}
/// Query-string filters for per-application usage.
#[derive(Debug, Deserialize)]
pub struct AppUsageFilters { pub device_uid: Option<String>, pub date: Option<String> }
/// GET /api/plugins/usage-timer/app-usage — top 100 app-usage rows by minutes.
pub async fn list_app_usage(State(state): State<AppState>, Query(f): Query<AppUsageFilters>) -> Json<ApiResponse<serde_json::Value>> {
    match sqlx::query(
        "SELECT id, device_uid, date, app_name, usage_minutes FROM app_usage_daily
         WHERE (? IS NULL OR device_uid=?) AND (? IS NULL OR date=?)
         ORDER BY usage_minutes DESC LIMIT 100")
        .bind(&f.device_uid).bind(&f.device_uid)
        .bind(&f.date).bind(&f.date)
        .fetch_all(&state.db).await {
        Ok(rows) => Json(ApiResponse::ok(serde_json::json!({"app_usage": rows.iter().map(|r| serde_json::json!({
            "id": r.get::<i64,_>("id"), "device_uid": r.get::<String,_>("device_uid"),
            "date": r.get::<String,_>("date"), "app_name": r.get::<String,_>("app_name"),
            "usage_minutes": r.get::<i32,_>("usage_minutes")
        })).collect::<Vec<_>>()}))),
        Err(e) => Json(ApiResponse::internal_error("query app usage", e)),
    }
}
/// GET /api/plugins/usage-timer/leaderboard — top 20 devices by total active
/// minutes over the last 7 days.
pub async fn leaderboard(State(state): State<AppState>) -> Json<ApiResponse<serde_json::Value>> {
    match sqlx::query(
        "SELECT device_uid, SUM(total_active_minutes) as total_minutes FROM usage_daily
         WHERE date >= date('now', '-7 days') GROUP BY device_uid ORDER BY total_minutes DESC LIMIT 20")
        .fetch_all(&state.db).await {
        Ok(rows) => Json(ApiResponse::ok(serde_json::json!({"leaderboard": rows.iter().map(|r| serde_json::json!({
            "device_uid": r.get::<String,_>("device_uid"), "total_minutes": r.get::<i64,_>("total_minutes")
        })).collect::<Vec<_>>()}))),
        Err(e) => Json(ApiResponse::internal_error("query usage leaderboard", e)),
    }
}

View File

@@ -0,0 +1,47 @@
use axum::{extract::{State, Query}, Json};
use serde::Deserialize;
use sqlx::Row;
use crate::AppState;
use crate::api::ApiResponse;
/// Query-string filters for the USB file operation log.
#[derive(Debug, Deserialize)]
pub struct LogFilters {
    pub device_uid: Option<String>, // restrict to one device
    pub operation: Option<String>,  // restrict to one operation kind
    pub usb_serial: Option<String>, // restrict to one USB device serial
}
/// GET /api/plugins/usb-file-audit/log — latest 200 USB file operations,
/// each filter optional via the `? IS NULL OR column = ?` idiom.
pub async fn list_operations(State(state): State<AppState>, Query(f): Query<LogFilters>) -> Json<ApiResponse<serde_json::Value>> {
    match sqlx::query(
        "SELECT id, device_uid, usb_serial, drive_letter, operation, file_path, file_size, timestamp
         FROM usb_file_operations WHERE (? IS NULL OR device_uid=?) AND (? IS NULL OR operation=?) AND (? IS NULL OR usb_serial=?)
         ORDER BY timestamp DESC LIMIT 200")
        .bind(&f.device_uid).bind(&f.device_uid)
        .bind(&f.operation).bind(&f.operation)
        .bind(&f.usb_serial).bind(&f.usb_serial)
        .fetch_all(&state.db).await {
        Ok(rows) => Json(ApiResponse::ok(serde_json::json!({"operations": rows.iter().map(|r| serde_json::json!({
            "id": r.get::<i64,_>("id"), "device_uid": r.get::<String,_>("device_uid"),
            "usb_serial": r.get::<Option<String>,_>("usb_serial"), "drive_letter": r.get::<Option<String>,_>("drive_letter"),
            "operation": r.get::<String,_>("operation"), "file_path": r.get::<String,_>("file_path"),
            "file_size": r.get::<Option<i64>,_>("file_size"), "timestamp": r.get::<String,_>("timestamp")
        })).collect::<Vec<_>>()}))),
        Err(e) => Json(ApiResponse::internal_error("query USB file operations", e)),
    }
}
/// GET /api/plugins/usb-file-audit/summary — per-device operation counts over
/// the last 7 days (top 50 devices by operation count).
pub async fn summary(State(state): State<AppState>) -> Json<ApiResponse<serde_json::Value>> {
    match sqlx::query(
        "SELECT device_uid, COUNT(*) as op_count, COUNT(DISTINCT usb_serial) as usb_count,
         MIN(timestamp) as first_op, MAX(timestamp) as last_op
         FROM usb_file_operations WHERE timestamp >= datetime('now', '-7 days')
         GROUP BY device_uid ORDER BY op_count DESC LIMIT 50")
        .fetch_all(&state.db).await {
        Ok(rows) => Json(ApiResponse::ok(serde_json::json!({"summary": rows.iter().map(|r| serde_json::json!({
            "device_uid": r.get::<String,_>("device_uid"), "op_count": r.get::<i64,_>("op_count"),
            "usb_count": r.get::<i64,_>("usb_count"), "first_op": r.get::<String,_>("first_op"),
            "last_op": r.get::<String,_>("last_op")
        })).collect::<Vec<_>>()}))),
        Err(e) => Json(ApiResponse::internal_error("query USB file audit summary", e)),
    }
}

View File

@@ -0,0 +1,186 @@
use axum::{extract::{State, Path, Json}, http::StatusCode};
use serde::Deserialize;
use sqlx::Row;
use csm_protocol::{MessageType, WatermarkConfigPayload};
use crate::AppState;
use crate::api::ApiResponse;
use crate::tcp::push_to_targets;
/// Request body for creating a watermark config; every field is optional and
/// defaulted by the handler.
#[derive(Debug, Deserialize)]
pub struct CreateConfigRequest {
    pub target_type: Option<String>, // "global" | "device" | "group"; defaults to "global"
    pub target_id: Option<String>,   // device uid / group id when not global
    pub content: Option<String>,     // template text; supports {company}/{username}/{date} placeholders
    pub font_size: Option<u32>,      // clamped to 8..=72 by the handler
    pub opacity: Option<f64>,        // clamped to 0.01..=1.0 by the handler
    pub color: Option<String>,       // "#RRGGBB"
    pub angle: Option<i32>,          // rotation in degrees; default -30
    pub enabled: Option<bool>,       // default true
}
/// GET /api/plugins/watermark/config — all watermark configs ordered by id.
pub async fn get_config_list(State(state): State<AppState>) -> Json<ApiResponse<serde_json::Value>> {
    match sqlx::query("SELECT id, target_type, target_id, content, font_size, opacity, color, angle, enabled, updated_at FROM watermark_config ORDER BY id")
        .fetch_all(&state.db).await {
        Ok(rows) => Json(ApiResponse::ok(serde_json::json!({"configs": rows.iter().map(|r| serde_json::json!({
            "id": r.get::<i64,_>("id"), "target_type": r.get::<String,_>("target_type"),
            "target_id": r.get::<Option<String>,_>("target_id"), "content": r.get::<String,_>("content"),
            "font_size": r.get::<i32,_>("font_size"), "opacity": r.get::<f64,_>("opacity"),
            "color": r.get::<String,_>("color"), "angle": r.get::<i32,_>("angle"),
            "enabled": r.get::<bool,_>("enabled"), "updated_at": r.get::<String,_>("updated_at")
        })).collect::<Vec<_>>()}))),
        Err(e) => Json(ApiResponse::internal_error("query watermark configs", e)),
    }
}
/// POST /api/plugins/watermark/config — apply defaults, clamp numeric fields,
/// validate, insert, then push the new config to affected online clients.
pub async fn create_config(
    State(state): State<AppState>,
    Json(req): Json<CreateConfigRequest>,
) -> (StatusCode, Json<ApiResponse<serde_json::Value>>) {
    // Defaults + clamping for omitted / out-of-range fields.
    let target_type = req.target_type.unwrap_or_else(|| "global".to_string());
    let content = req.content.unwrap_or_else(|| "{company} | {username} | {date}".to_string());
    let font_size = req.font_size.unwrap_or(14).clamp(8, 72) as i32;
    let opacity = req.opacity.unwrap_or(0.15).clamp(0.01, 1.0);
    let color = req.color.unwrap_or_else(|| "#808080".to_string());
    let angle = req.angle.unwrap_or(-30);
    let enabled = req.enabled.unwrap_or(true);
    // Validate inputs
    // Note: length check is in bytes (String::len), not characters.
    if content.len() > 200 {
        return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("content too long (max 200 chars)")));
    }
    if !is_valid_hex_color(&color) {
        return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("invalid color format (expected #RRGGBB)")));
    }
    if !matches!(target_type.as_str(), "global" | "device" | "group") {
        return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("invalid target_type")));
    }
    match sqlx::query("INSERT INTO watermark_config (target_type, target_id, content, font_size, opacity, color, angle, enabled) VALUES (?,?,?,?,?,?,?,?)")
        .bind(&target_type).bind(&req.target_id).bind(&content).bind(font_size).bind(opacity).bind(&color).bind(angle).bind(enabled)
        .execute(&state.db).await {
        Ok(r) => {
            let new_id = r.last_insert_rowid();
            // Push to online clients
            let config = WatermarkConfigPayload {
                content: content.clone(),
                font_size: font_size as u32, // safe: clamped to 8..=72 above
                opacity,
                color: color.clone(),
                angle,
                enabled,
            };
            push_to_targets(&state.db, &state.clients, MessageType::WatermarkConfig, &config, &target_type, req.target_id.as_deref()).await;
            (StatusCode::CREATED, Json(ApiResponse::ok(serde_json::json!({"id": new_id}))))
        }
        Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::internal_error("create watermark config", e))),
    }
}
/// Partial-update body for a watermark config; omitted fields keep stored values.
#[derive(Debug, Deserialize)]
pub struct UpdateConfigRequest {
    pub content: Option<String>, pub font_size: Option<u32>, pub opacity: Option<f64>,
    pub color: Option<String>, pub angle: Option<i32>, pub enabled: Option<bool>,
}
/// PUT /api/plugins/watermark/config/:id — merge request fields over the
/// stored row (clamping numerics), validate, update, then push to clients.
pub async fn update_config(
    State(state): State<AppState>,
    Path(id): Path<i64>,
    Json(body): Json<UpdateConfigRequest>,
) -> Json<ApiResponse<()>> {
    let existing = sqlx::query("SELECT * FROM watermark_config WHERE id = ?")
        .bind(id)
        .fetch_optional(&state.db)
        .await;
    let existing = match existing {
        Ok(Some(row)) => row,
        Ok(None) => return Json(ApiResponse::error("Not found")),
        Err(e) => return Json(ApiResponse::internal_error("query watermark config", e)),
    };
    // Merge request fields over the existing row (PATCH-like semantics),
    // clamping incoming numeric values to the same ranges as create_config.
    let content = body.content.unwrap_or_else(|| existing.get::<String, _>("content"));
    let font_size = body.font_size.map(|v| v.clamp(8, 72) as i32).unwrap_or_else(|| existing.get::<i32, _>("font_size"));
    let opacity = body.opacity.map(|v| v.clamp(0.01, 1.0)).unwrap_or_else(|| existing.get::<f64, _>("opacity"));
    let color = body.color.unwrap_or_else(|| existing.get::<String, _>("color"));
    let angle = body.angle.unwrap_or_else(|| existing.get::<i32, _>("angle"));
    let enabled = body.enabled.unwrap_or_else(|| existing.get::<bool, _>("enabled"));
    // Validate inputs
    if content.len() > 200 {
        return Json(ApiResponse::error("content too long (max 200 chars)"));
    }
    if !is_valid_hex_color(&color) {
        return Json(ApiResponse::error("invalid color format (expected #RRGGBB)"));
    }
    let result = sqlx::query(
        "UPDATE watermark_config SET content = ?, font_size = ?, opacity = ?, color = ?, angle = ?, enabled = ?, updated_at = datetime('now') WHERE id = ?"
    )
    .bind(&content)
    .bind(font_size)
    .bind(opacity)
    .bind(&color)
    .bind(angle)
    .bind(enabled)
    .bind(id)
    .execute(&state.db)
    .await;
    match result {
        Ok(r) if r.rows_affected() > 0 => {
            // Push updated config to online clients
            let config = WatermarkConfigPayload {
                content: content.clone(),
                font_size: font_size as u32, // safe: stored/merged values are within 8..=72
                opacity,
                color: color.clone(),
                angle,
                enabled,
            };
            let target_type_val: String = existing.get("target_type");
            let target_id_val: Option<String> = existing.get("target_id");
            push_to_targets(&state.db, &state.clients, MessageType::WatermarkConfig, &config, &target_type_val, target_id_val.as_deref()).await;
            Json(ApiResponse::ok(()))
        }
        Ok(_) => Json(ApiResponse::error("Not found")),
        Err(e) => Json(ApiResponse::internal_error("update watermark config", e)),
    }
}
pub async fn delete_config(State(state): State<AppState>, Path(id): Path<i64>) -> Json<ApiResponse<()>> {
// Fetch existing config to get target info for push
let existing = sqlx::query("SELECT target_type, target_id FROM watermark_config WHERE id = ?")
.bind(id)
.fetch_optional(&state.db)
.await;
let (target_type, target_id) = match existing {
Ok(Some(row)) => (row.get::<String, _>("target_type"), row.get::<Option<String>, _>("target_id")),
_ => return Json(ApiResponse::error("Not found")),
};
match sqlx::query("DELETE FROM watermark_config WHERE id=?").bind(id).execute(&state.db).await {
Ok(r) if r.rows_affected() > 0 => {
// Push disabled watermark to clients
let disabled = WatermarkConfigPayload {
content: String::new(),
font_size: 0,
opacity: 0.0,
color: String::new(),
angle: 0,
enabled: false,
};
push_to_targets(&state.db, &state.clients, MessageType::WatermarkConfig, &disabled, &target_type, target_id.as_deref()).await;
Json(ApiResponse::ok(()))
}
_ => Json(ApiResponse::error("Not found")),
}
}
/// Validate a hex color string (#RRGGBB format): a '#' followed by exactly
/// six ASCII hex digits.
fn is_valid_hex_color(color: &str) -> bool {
    match color.strip_prefix('#') {
        Some(digits) => digits.len() == 6 && digits.chars().all(|c| c.is_ascii_hexdigit()),
        None => false,
    }
}

View File

@@ -0,0 +1,156 @@
use axum::{extract::{State, Path, Query, Json}, http::StatusCode};
use serde::Deserialize;
use sqlx::Row;
use csm_protocol::MessageType;
use crate::AppState;
use crate::api::ApiResponse;
use crate::tcp::push_to_targets;
/// Query-string filters for web filter rules.
/// NOTE(review): not referenced by any handler visible in this chunk — confirm
/// it is used (e.g. by list_access_log) or remove it.
#[derive(Debug, Deserialize)]
pub struct RuleFilters { pub rule_type: Option<String>, pub enabled: Option<bool> }
/// Request body for creating a web filter rule.
#[derive(Debug, Deserialize)]
pub struct CreateRuleRequest {
    pub rule_type: String,           // "blacklist" | "whitelist" | "category"
    pub pattern: String,             // URL/domain pattern, 1-255 chars
    pub target_type: Option<String>, // "global" | "device" | "group"; defaults to "global"
    pub target_id: Option<String>,   // device uid / group id when not global
    pub enabled: Option<bool>,       // default true
}
/// GET /api/plugins/web-filter/rules — all web filter rules, newest first.
pub async fn list_rules(State(state): State<AppState>) -> Json<ApiResponse<serde_json::Value>> {
    match sqlx::query("SELECT id, rule_type, pattern, target_type, target_id, enabled, created_at FROM web_filter_rules ORDER BY created_at DESC")
        .fetch_all(&state.db).await {
        Ok(rows) => Json(ApiResponse::ok(serde_json::json!({ "rules": rows.iter().map(|r| serde_json::json!({
            "id": r.get::<i64,_>("id"), "rule_type": r.get::<String,_>("rule_type"),
            "pattern": r.get::<String,_>("pattern"), "target_type": r.get::<String,_>("target_type"),
            "target_id": r.get::<Option<String>,_>("target_id"), "enabled": r.get::<bool,_>("enabled"),
            "created_at": r.get::<String,_>("created_at")
        })).collect::<Vec<_>>() }))),
        Err(e) => Json(ApiResponse::internal_error("query web filter rules", e)),
    }
}
/// POST /api/plugins/web-filter/rules — validate, insert, then push the
/// refreshed rule set to affected online clients.
pub async fn create_rule(State(state): State<AppState>, Json(req): Json<CreateRuleRequest>) -> (StatusCode, Json<ApiResponse<serde_json::Value>>) {
    let enabled = req.enabled.unwrap_or(true);
    let target_type = req.target_type.unwrap_or_else(|| "global".to_string());
    // Validate inputs
    if !matches!(req.rule_type.as_str(), "blacklist" | "whitelist" | "category") {
        return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("invalid rule_type (expected blacklist|whitelist|category)")));
    }
    // Note: length check is in bytes (String::len), not characters.
    if req.pattern.trim().is_empty() || req.pattern.len() > 255 {
        return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("pattern must be 1-255 chars")));
    }
    if !matches!(target_type.as_str(), "global" | "device" | "group") {
        return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("invalid target_type")));
    }
    match sqlx::query("INSERT INTO web_filter_rules (rule_type, pattern, target_type, target_id, enabled) VALUES (?,?,?,?,?)")
        .bind(&req.rule_type).bind(&req.pattern).bind(&target_type).bind(&req.target_id).bind(enabled)
        .execute(&state.db).await {
        Ok(r) => {
            let new_id = r.last_insert_rowid();
            // Recompute the effective rule set for this scope and push it live.
            let rules = fetch_rules_for_push(&state.db, &target_type, req.target_id.as_deref()).await;
            push_to_targets(&state.db, &state.clients, MessageType::WebFilterRuleUpdate, &serde_json::json!({"rules": rules}), &target_type, req.target_id.as_deref()).await;
            (StatusCode::CREATED, Json(ApiResponse::ok(serde_json::json!({"id": new_id}))))
        }
        Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::internal_error("create web filter rule", e))),
    }
}
/// Request body for partially updating a web-filter rule; omitted fields keep
/// their stored values. Note: scope (target_type/target_id) is not updatable.
#[derive(Debug, Deserialize)]
pub struct UpdateRuleRequest { pub rule_type: Option<String>, pub pattern: Option<String>, pub enabled: Option<bool> }
pub async fn update_rule(State(state): State<AppState>, Path(id): Path<i64>, Json(body): Json<UpdateRuleRequest>) -> Json<ApiResponse<()>> {
let existing = sqlx::query("SELECT * FROM web_filter_rules WHERE id = ?")
.bind(id)
.fetch_optional(&state.db)
.await;
let existing = match existing {
Ok(Some(row)) => row,
Ok(None) => return Json(ApiResponse::error("Not found")),
Err(e) => return Json(ApiResponse::internal_error("query web filter rule", e)),
};
let rule_type = body.rule_type.unwrap_or_else(|| existing.get::<String, _>("rule_type"));
let pattern = body.pattern.unwrap_or_else(|| existing.get::<String, _>("pattern"));
let enabled = body.enabled.unwrap_or_else(|| existing.get::<bool, _>("enabled"));
let result = sqlx::query("UPDATE web_filter_rules SET rule_type = ?, pattern = ?, enabled = ? WHERE id = ?")
.bind(&rule_type)
.bind(&pattern)
.bind(enabled)
.bind(id)
.execute(&state.db)
.await;
match result {
Ok(r) if r.rows_affected() > 0 => {
let target_type_val: String = existing.get("target_type");
let target_id_val: Option<String> = existing.get("target_id");
let rules = fetch_rules_for_push(&state.db, &target_type_val, target_id_val.as_deref()).await;
push_to_targets(&state.db, &state.clients, MessageType::WebFilterRuleUpdate, &serde_json::json!({"rules": rules}), &target_type_val, target_id_val.as_deref()).await;
Json(ApiResponse::ok(()))
}
Ok(_) => Json(ApiResponse::error("Not found")),
Err(e) => Json(ApiResponse::internal_error("update web filter rule", e)),
}
}
pub async fn delete_rule(State(state): State<AppState>, Path(id): Path<i64>) -> Json<ApiResponse<()>> {
let existing = sqlx::query("SELECT target_type, target_id FROM web_filter_rules WHERE id = ?")
.bind(id).fetch_optional(&state.db).await;
let (target_type, target_id) = match existing {
Ok(Some(row)) => (row.get::<String, _>("target_type"), row.get::<Option<String>, _>("target_id")),
_ => return Json(ApiResponse::error("Not found")),
};
match sqlx::query("DELETE FROM web_filter_rules WHERE id=?").bind(id).execute(&state.db).await {
Ok(r) if r.rows_affected() > 0 => {
let rules = fetch_rules_for_push(&state.db, &target_type, target_id.as_deref()).await;
push_to_targets(&state.db, &state.clients, MessageType::WebFilterRuleUpdate, &serde_json::json!({"rules": rules}), &target_type, target_id.as_deref()).await;
Json(ApiResponse::ok(()))
}
_ => Json(ApiResponse::error("Not found")),
}
}
/// Query-string filters for the web access log listing.
#[derive(Debug, Deserialize)]
pub struct LogFilters { pub device_uid: Option<String>, pub action: Option<String> }
/// GET handler: the 200 most recent web-access log entries, optionally
/// filtered by device and action.
pub async fn list_access_log(State(state): State<AppState>, Query(f): Query<LogFilters>) -> Json<ApiResponse<serde_json::Value>> {
    // Each optional filter is bound twice: once for the NULL check, once for
    // the equality comparison.
    let fetched = sqlx::query("SELECT id, device_uid, url, action, timestamp FROM web_access_log WHERE (? IS NULL OR device_uid=?) AND (? IS NULL OR action=?) ORDER BY timestamp DESC LIMIT 200")
        .bind(&f.device_uid).bind(&f.device_uid).bind(&f.action).bind(&f.action)
        .fetch_all(&state.db)
        .await;
    match fetched {
        Err(e) => Json(ApiResponse::internal_error("query web access log", e)),
        Ok(rows) => {
            let entries: Vec<serde_json::Value> = rows
                .iter()
                .map(|row| serde_json::json!({
                    "id": row.get::<i64, _>("id"),
                    "device_uid": row.get::<String, _>("device_uid"),
                    "url": row.get::<String, _>("url"),
                    "action": row.get::<String, _>("action"),
                    "timestamp": row.get::<String, _>("timestamp")
                }))
                .collect();
            Json(ApiResponse::ok(serde_json::json!({ "log": entries })))
        }
    }
}
/// Fetch enabled web filter rules applicable to a given target scope.
/// For "device" targets, includes global rules plus device-specific rules;
/// for "group" targets, includes global rules plus that group's rules
/// (matching the logic used during registration push in tcp.rs).
/// Query failures degrade to an empty rule set.
async fn fetch_rules_for_push(
    db: &sqlx::SqlitePool,
    target_type: &str,
    target_id: Option<&str>,
) -> Vec<serde_json::Value> {
    let query = match target_type {
        "device" => sqlx::query(
            "SELECT id, rule_type, pattern FROM web_filter_rules WHERE enabled = 1 AND (target_type = 'global' OR (target_type = 'device' AND target_id = ?))"
        ).bind(target_id),
        // Fix: "group" previously fell through to the global-only branch, so a
        // group rule's create/update/delete push omitted the group's own rules.
        "group" => sqlx::query(
            "SELECT id, rule_type, pattern FROM web_filter_rules WHERE enabled = 1 AND (target_type = 'global' OR (target_type = 'group' AND target_id = ?))"
        ).bind(target_id),
        _ => sqlx::query(
            "SELECT id, rule_type, pattern FROM web_filter_rules WHERE enabled = 1 AND target_type = 'global'"
        ),
    };
    query.fetch_all(db).await
        .map(|rows| rows.iter().map(|r| serde_json::json!({
            "id": r.get::<i64,_>("id"), "rule_type": r.get::<String,_>("rule_type"), "pattern": r.get::<String,_>("pattern")
        })).collect())
        .unwrap_or_default()
}

View File

@@ -0,0 +1,246 @@
use axum::{extract::{State, Path, Query, Json}, http::StatusCode};
use serde::Deserialize;
use sqlx::Row;
use crate::AppState;
use super::ApiResponse;
use crate::tcp::push_to_targets;
use csm_protocol::{MessageType, UsbPolicyPayload, UsbDeviceRule};
/// Pagination + filter parameters for the USB event listing.
#[derive(Debug, Deserialize)]
pub struct UsbEventListParams {
    // Filter by reporting device; empty string is treated as absent.
    pub device_uid: Option<String>,
    // Filter by event type; empty string is treated as absent.
    pub event_type: Option<String>,
    // 1-based page number; defaults to 1.
    pub page: Option<u32>,
    // Rows per page; defaults to 20, capped at 100.
    pub page_size: Option<u32>,
}
/// GET handler: paginated USB event log, filterable by device and event type.
/// `page` is 1-based; `page_size` defaults to 20 and is capped at 100.
pub async fn list_events(
    State(state): State<AppState>,
    Query(params): Query<UsbEventListParams>,
) -> Json<ApiResponse<serde_json::Value>> {
    let limit = params.page_size.unwrap_or(20).min(100);
    // saturating_mul: a huge caller-supplied `page` must not overflow u32
    // (overflow panics in debug builds); it just yields an offset past the end
    // of the table, producing an empty page.
    let offset = params.page.unwrap_or(1).saturating_sub(1).saturating_mul(limit);
    // Normalize empty strings to None
    let device_uid = params.device_uid.as_deref().filter(|s| !s.is_empty()).map(String::from);
    let event_type = params.event_type.as_deref().filter(|s| !s.is_empty()).map(String::from);
    let rows = sqlx::query(
        "SELECT id, device_uid, vendor_id, product_id, serial_number, device_name, event_type, event_time
        FROM usb_events WHERE 1=1
        AND (? IS NULL OR device_uid = ?)
        AND (? IS NULL OR event_type = ?)
        ORDER BY event_time DESC LIMIT ? OFFSET ?"
    )
    .bind(&device_uid).bind(&device_uid)
    .bind(&event_type).bind(&event_type)
    .bind(limit).bind(offset)
    .fetch_all(&state.db)
    .await;
    match rows {
        Ok(records) => {
            let items: Vec<serde_json::Value> = records.iter().map(|r| serde_json::json!({
                "id": r.get::<i64, _>("id"),
                "device_uid": r.get::<String, _>("device_uid"),
                "vendor_id": r.get::<Option<String>, _>("vendor_id"),
                "product_id": r.get::<Option<String>, _>("product_id"),
                "serial_number": r.get::<Option<String>, _>("serial_number"),
                "device_name": r.get::<Option<String>, _>("device_name"),
                "event_type": r.get::<String, _>("event_type"),
                "event_time": r.get::<String, _>("event_time"),
            })).collect();
            Json(ApiResponse::ok(serde_json::json!({
                "events": items,
                "page": params.page.unwrap_or(1),
                "page_size": limit,
            })))
        }
        Err(e) => Json(ApiResponse::internal_error("query usb events", e)),
    }
}
/// GET handler: list every USB policy, newest first.
pub async fn list_policies(
    State(state): State<AppState>,
) -> Json<ApiResponse<serde_json::Value>> {
    let fetched = sqlx::query(
        "SELECT id, name, policy_type, target_group, rules, enabled, created_at, updated_at
        FROM usb_policies ORDER BY created_at DESC"
    )
    .fetch_all(&state.db)
    .await;
    match fetched {
        Err(e) => Json(ApiResponse::internal_error("query usb policies", e)),
        Ok(records) => {
            let mut policies = Vec::with_capacity(records.len());
            for record in &records {
                policies.push(serde_json::json!({
                    "id": record.get::<i64, _>("id"),
                    "name": record.get::<String, _>("name"),
                    "policy_type": record.get::<String, _>("policy_type"),
                    "target_group": record.get::<Option<String>, _>("target_group"),
                    "rules": record.get::<String, _>("rules"),
                    "enabled": record.get::<i32, _>("enabled"),
                    "created_at": record.get::<String, _>("created_at"),
                    "updated_at": record.get::<String, _>("updated_at"),
                }));
            }
            Json(ApiResponse::ok(serde_json::json!({
                "policies": policies,
            })))
        }
    }
}
/// Request body for creating a USB policy.
#[derive(Debug, Deserialize)]
pub struct CreatePolicyRequest {
    pub name: String,
    pub policy_type: String,
    // Group the policy applies to; None means no group is targeted.
    pub target_group: Option<String>,
    // JSON-encoded array of device rules (parsed by build_usb_policy_payload).
    pub rules: String,
    // 1 = enabled (default), 0 = disabled.
    pub enabled: Option<i32>,
}
/// POST handler: insert a USB policy and, when enabled, immediately push it
/// to online clients in the target group.
pub async fn create_policy(
    State(state): State<AppState>,
    Json(body): Json<CreatePolicyRequest>,
) -> (StatusCode, Json<ApiResponse<serde_json::Value>>) {
    let enabled = body.enabled.unwrap_or(1);
    let inserted = sqlx::query(
        "INSERT INTO usb_policies (name, policy_type, target_group, rules, enabled) VALUES (?, ?, ?, ?, ?)"
    )
    .bind(&body.name)
    .bind(&body.policy_type)
    .bind(&body.target_group)
    .bind(&body.rules)
    .bind(enabled)
    .execute(&state.db)
    .await;
    match inserted {
        Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::internal_error("create usb policy", e))),
        Ok(r) => {
            // Push USB policy to matching online clients
            if enabled == 1 {
                let payload = build_usb_policy_payload(&body.policy_type, true, &body.rules);
                push_to_targets(&state.db, &state.clients, MessageType::UsbPolicyUpdate, &payload, "group", body.target_group.as_deref()).await;
            }
            let new_id = r.last_insert_rowid();
            (StatusCode::CREATED, Json(ApiResponse::ok(serde_json::json!({
                "id": new_id,
            }))))
        }
    }
}
/// Request body for partially updating a USB policy; omitted fields keep
/// their stored values.
#[derive(Debug, Deserialize)]
pub struct UpdatePolicyRequest {
    pub name: Option<String>,
    pub policy_type: Option<String>,
    pub target_group: Option<String>,
    pub rules: Option<String>,
    pub enabled: Option<i32>,
}
/// PUT handler: partially update a USB policy, then push the resulting policy
/// to online clients in its target group.
pub async fn update_policy(
    State(state): State<AppState>,
    Path(id): Path<i64>,
    Json(body): Json<UpdatePolicyRequest>,
) -> Json<ApiResponse<serde_json::Value>> {
    // Fetch existing policy
    let existing = sqlx::query("SELECT * FROM usb_policies WHERE id = ?")
        .bind(id)
        .fetch_optional(&state.db)
        .await;
    let existing = match existing {
        Ok(Some(row)) => row,
        Ok(None) => return Json(ApiResponse::error("Policy not found")),
        Err(e) => return Json(ApiResponse::internal_error("query usb policy", e)),
    };
    // Merge request fields over stored values; omitted fields are unchanged.
    let name = body.name.unwrap_or_else(|| existing.get::<String, _>("name"));
    let policy_type = body.policy_type.unwrap_or_else(|| existing.get::<String, _>("policy_type"));
    // NOTE(review): `or_else` means target_group can never be reset to NULL
    // through this endpoint — confirm that is intended.
    let target_group = body.target_group.or_else(|| existing.get::<Option<String>, _>("target_group"));
    let rules = body.rules.unwrap_or_else(|| existing.get::<String, _>("rules"));
    let enabled = body.enabled.unwrap_or_else(|| existing.get::<i32, _>("enabled"));
    let result = sqlx::query(
        "UPDATE usb_policies SET name = ?, policy_type = ?, target_group = ?, rules = ?, enabled = ?, updated_at = datetime('now') WHERE id = ?"
    )
    .bind(&name)
    .bind(&policy_type)
    .bind(&target_group)
    .bind(&rules)
    .bind(enabled)
    .bind(id)
    .execute(&state.db)
    .await;
    match result {
        Ok(_) => {
            // Push updated USB policy to matching online clients
            let payload = build_usb_policy_payload(&policy_type, enabled == 1, &rules);
            let target_group = target_group.as_deref();
            push_to_targets(&state.db, &state.clients, MessageType::UsbPolicyUpdate, &payload, "group", target_group).await;
            Json(ApiResponse::ok(serde_json::json!({"updated": true})))
        }
        Err(e) => Json(ApiResponse::internal_error("update usb policy", e)),
    }
}
/// DELETE handler: remove a USB policy and push a disabled (empty) policy to
/// clients that were in its target group.
pub async fn delete_policy(
    State(state): State<AppState>,
    Path(id): Path<i64>,
) -> Json<ApiResponse<serde_json::Value>> {
    // Fetch existing policy to get target info for push
    let existing = sqlx::query("SELECT target_group FROM usb_policies WHERE id = ?")
        .bind(id)
        .fetch_optional(&state.db)
        .await;
    let target_group = match existing {
        Ok(Some(row)) => row.get::<Option<String>, _>("target_group"),
        // Query errors also fall through here and are reported as not-found.
        _ => return Json(ApiResponse::error("Policy not found")),
    };
    let result = sqlx::query("DELETE FROM usb_policies WHERE id = ?")
        .bind(id)
        .execute(&state.db)
        .await;
    match result {
        Ok(r) => {
            if r.rows_affected() > 0 {
                // Push disabled policy to clients
                let disabled = UsbPolicyPayload {
                    policy_type: String::new(),
                    enabled: false,
                    rules: vec![],
                };
                // NOTE(review): when target_group is None, push_to_targets'
                // "group" branch resolves no members, so nothing is pushed for
                // group-less policies — confirm that is intended.
                push_to_targets(&state.db, &state.clients, MessageType::UsbPolicyUpdate, &disabled, "group", target_group.as_deref()).await;
                Json(ApiResponse::ok(serde_json::json!({"deleted": true})))
            } else {
                Json(ApiResponse::error("Policy not found"))
            }
        }
        Err(e) => Json(ApiResponse::internal_error("delete usb policy", e)),
    }
}
/// Build a UsbPolicyPayload from raw policy fields
/// Malformed `rules_json` degrades to an empty rule list rather than failing.
fn build_usb_policy_payload(policy_type: &str, enabled: bool, rules_json: &str) -> UsbPolicyPayload {
    let raw_rules: Vec<serde_json::Value> = serde_json::from_str(rules_json).unwrap_or_default();
    // Pull an optional string field out of one raw rule object.
    let string_field = |raw: &serde_json::Value, key: &str| -> Option<String> {
        raw.get(key).and_then(|v| v.as_str()).map(String::from)
    };
    let rules: Vec<UsbDeviceRule> = raw_rules
        .iter()
        .map(|raw| UsbDeviceRule {
            vendor_id: string_field(raw, "vendor_id"),
            product_id: string_field(raw, "product_id"),
            serial: string_field(raw, "serial"),
            device_name: string_field(raw, "device_name"),
        })
        .collect();
    UsbPolicyPayload {
        policy_type: policy_type.to_string(),
        enabled,
        rules,
    }
}

View File

@@ -0,0 +1,28 @@
use sqlx::SqlitePool;
use tracing::debug;
/// Record an admin audit log entry.
/// Failures are logged as warnings and never propagated — auditing must not
/// break the operation being audited.
pub async fn audit_log(
    db: &SqlitePool,
    user_id: i64,
    action: &str,
    target_type: Option<&str>,
    target_id: Option<&str>,
    detail: Option<&str>,
) {
    let insert = sqlx::query(
        "INSERT INTO admin_audit_log (user_id, action, target_type, target_id, detail) VALUES (?, ?, ?, ?, ?)"
    )
    .bind(user_id)
    .bind(action)
    .bind(target_type)
    .bind(target_id)
    .bind(detail)
    .execute(db)
    .await;
    if let Err(e) = insert {
        tracing::warn!("Failed to write audit log: {}", e);
    } else {
        debug!("Audit: user={} action={} target={}/{}", user_id, action, target_type.unwrap_or("-"), target_id.unwrap_or("-"));
    }
}

134
crates/server/src/config.rs Normal file
View File

@@ -0,0 +1,134 @@
use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::path::Path;
/// Top-level server configuration, deserialized from the TOML config file.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct AppConfig {
    pub server: ServerConfig,
    pub database: DatabaseConfig,
    pub auth: AuthConfig,
    pub retention: RetentionConfig,
    #[serde(default)]
    pub notify: NotifyConfig,
    /// Token required for device registration. Empty = any token accepted.
    #[serde(default)]
    pub registration_token: String,
}
/// Network listener settings.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ServerConfig {
    /// Bind address for the HTTP API (e.g. "0.0.0.0:8080").
    pub http_addr: String,
    /// Bind address for the client TCP protocol listener.
    pub tcp_addr: String,
    /// Allowed CORS origins. Empty = same-origin only (no CORS headers).
    #[serde(default)]
    pub cors_origins: Vec<String>,
    /// Optional TLS configuration for the TCP listener.
    #[serde(default)]
    pub tls: Option<TlsConfig>,
}
/// TLS certificate/key pair for the TCP listener.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct TlsConfig {
    /// Path to the server certificate (PEM format)
    pub cert_path: String,
    /// Path to the server private key (PEM format)
    pub key_path: String,
}
/// SQLite database settings.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct DatabaseConfig {
    /// Filesystem path (or sqlite: URL) of the database file.
    pub path: String,
}
/// JWT authentication settings.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct AuthConfig {
    /// Secret used to sign tokens; randomized on first run (see default_config).
    pub jwt_secret: String,
    /// Access-token lifetime in seconds (default 1800 = 30 min).
    #[serde(default = "default_access_ttl")]
    pub access_token_ttl_secs: u64,
    /// Refresh-token lifetime in seconds (default 604800 = 7 days).
    #[serde(default = "default_refresh_ttl")]
    pub refresh_token_ttl_secs: u64,
}
/// Retention windows, in days, for historical tables.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct RetentionConfig {
    #[serde(default = "default_status_history_days")]
    pub status_history_days: u32,
    #[serde(default = "default_usb_events_days")]
    pub usb_events_days: u32,
    #[serde(default = "default_asset_changes_days")]
    pub asset_changes_days: u32,
    #[serde(default = "default_alert_records_days")]
    pub alert_records_days: u32,
    #[serde(default = "default_audit_log_days")]
    pub audit_log_days: u32,
}
/// Outbound notification channels; everything optional/empty by default.
#[derive(Debug, Deserialize, Serialize, Clone, Default)]
pub struct NotifyConfig {
    /// SMTP settings for email notification; None disables email.
    #[serde(default)]
    pub smtp: Option<SmtpConfig>,
    /// Webhook endpoints to POST notifications to.
    #[serde(default)]
    pub webhook_urls: Vec<String>,
}
/// SMTP server credentials for email notifications.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct SmtpConfig {
    pub host: String,
    pub port: u16,
    pub username: String,
    pub password: String,
    /// Sender address placed in the From header.
    pub from: String,
}
impl AppConfig {
    /// Load configuration from `path`. When the file does not exist, generate
    /// the built-in defaults, persist them to `path` as a template for the
    /// operator, and return them.
    pub async fn load(path: &str) -> Result<Self> {
        if !Path::new(path).exists() {
            let config = default_config();
            // Write default config for reference
            let toml_str = toml::to_string_pretty(&config)?;
            tokio::fs::write(path, &toml_str).await?;
            tracing::warn!("Created default config at {}", path);
            return Ok(config);
        }
        let content = tokio::fs::read_to_string(path).await?;
        let config: AppConfig = toml::from_str(&content)?;
        Ok(config)
    }
}
// Serde default-value helpers referenced by the config structs above.
fn default_access_ttl() -> u64 { 1800 } // 30 minutes
fn default_refresh_ttl() -> u64 { 604800 } // 7 days
fn default_status_history_days() -> u32 { 7 }
fn default_usb_events_days() -> u32 { 90 }
fn default_asset_changes_days() -> u32 { 365 }
fn default_alert_records_days() -> u32 { 90 }
fn default_audit_log_days() -> u32 { 365 }
/// Build the built-in default configuration. Secrets (`jwt_secret`,
/// `registration_token`) are freshly randomized UUIDv4 strings so a new
/// install never ships shared credentials.
pub fn default_config() -> AppConfig {
    let server = ServerConfig {
        http_addr: "0.0.0.0:8080".into(),
        tcp_addr: "0.0.0.0:9999".into(),
        cors_origins: Vec::new(),
        tls: None,
    };
    let auth = AuthConfig {
        jwt_secret: uuid::Uuid::new_v4().to_string(),
        access_token_ttl_secs: default_access_ttl(),
        refresh_token_ttl_secs: default_refresh_ttl(),
    };
    let retention = RetentionConfig {
        status_history_days: default_status_history_days(),
        usb_events_days: default_usb_events_days(),
        asset_changes_days: default_asset_changes_days(),
        alert_records_days: default_alert_records_days(),
        audit_log_days: default_audit_log_days(),
    };
    AppConfig {
        server,
        database: DatabaseConfig { path: "./csm.db".into() },
        auth,
        retention,
        notify: NotifyConfig::default(),
        registration_token: uuid::Uuid::new_v4().to_string(),
    }
}

118
crates/server/src/db.rs Normal file
View File

@@ -0,0 +1,118 @@
use sqlx::SqlitePool;
use anyhow::Result;
/// Database repository for device operations
pub struct DeviceRepo;
impl DeviceRepo {
    /// Store the latest status snapshot for a device, append the sample to the
    /// history table, and mark the device online with a fresh heartbeat.
    /// All three writes run as separate statements (not a transaction).
    pub async fn upsert_status(pool: &SqlitePool, device_uid: &str, status: &csm_protocol::DeviceStatus) -> Result<()> {
        let top_procs_json = serde_json::to_string(&status.top_processes)?;
        // Update latest snapshot
        sqlx::query(
            "INSERT INTO device_status (device_uid, cpu_usage, memory_usage, memory_total_mb, disk_usage, disk_total_mb, network_rx_rate, network_tx_rate, running_procs, top_processes, reported_at)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, datetime('now'))
            ON CONFLICT(device_uid) DO UPDATE SET
            cpu_usage = excluded.cpu_usage,
            memory_usage = excluded.memory_usage,
            memory_total_mb = excluded.memory_total_mb,
            disk_usage = excluded.disk_usage,
            disk_total_mb = excluded.disk_total_mb,
            network_rx_rate = excluded.network_rx_rate,
            network_tx_rate = excluded.network_tx_rate,
            running_procs = excluded.running_procs,
            top_processes = excluded.top_processes,
            reported_at = datetime('now'),
            updated_at = datetime('now')"
        )
        .bind(device_uid)
        .bind(status.cpu_usage)
        .bind(status.memory_usage)
        // NOTE(review): `as i64`/`as i32` casts below assume the protocol
        // values fit the signed range — confirm against csm_protocol.
        .bind(status.memory_total_mb as i64)
        .bind(status.disk_usage)
        .bind(status.disk_total_mb as i64)
        .bind(status.network_rx_rate as i64)
        .bind(status.network_tx_rate as i64)
        .bind(status.running_procs as i32)
        .bind(&top_procs_json)
        .execute(pool)
        .await?;
        // Insert into history
        sqlx::query(
            "INSERT INTO device_status_history (device_uid, cpu_usage, memory_usage, disk_usage, network_rx_rate, network_tx_rate, running_procs, reported_at)
            VALUES (?, ?, ?, ?, ?, ?, ?, datetime('now'))"
        )
        .bind(device_uid)
        .bind(status.cpu_usage)
        .bind(status.memory_usage)
        .bind(status.disk_usage)
        .bind(status.network_rx_rate as i64)
        .bind(status.network_tx_rate as i64)
        .bind(status.running_procs as i32)
        .execute(pool)
        .await?;
        // Update device heartbeat
        sqlx::query(
            "UPDATE devices SET status = 'online', last_heartbeat = datetime('now') WHERE device_uid = ?"
        )
        .bind(device_uid)
        .execute(pool)
        .await?;
        Ok(())
    }
    /// Append one USB plug/unplug/block event; returns the new row id.
    pub async fn insert_usb_event(pool: &SqlitePool, event: &csm_protocol::UsbEvent) -> Result<i64> {
        let result = sqlx::query(
            "INSERT INTO usb_events (device_uid, vendor_id, product_id, serial_number, device_name, event_type)
            VALUES (?, ?, ?, ?, ?, ?)"
        )
        .bind(&event.device_uid)
        .bind(&event.vendor_id)
        .bind(&event.product_id)
        .bind(&event.serial)
        .bind(&event.device_name)
        // Map the protocol enum to the TEXT values stored in the table.
        .bind(match event.event_type {
            csm_protocol::UsbEventType::Inserted => "inserted",
            csm_protocol::UsbEventType::Removed => "removed",
            csm_protocol::UsbEventType::Blocked => "blocked",
        })
        .execute(pool)
        .await?;
        Ok(result.last_insert_rowid())
    }
    /// Insert or refresh the hardware inventory row for a device (one row per
    /// device_uid, overwritten on each report).
    pub async fn upsert_hardware(pool: &SqlitePool, asset: &csm_protocol::HardwareAsset) -> Result<()> {
        sqlx::query(
            "INSERT INTO hardware_assets (device_uid, cpu_model, cpu_cores, memory_total_mb, disk_model, disk_total_mb, gpu_model, motherboard, serial_number, reported_at)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, datetime('now'))
            ON CONFLICT(device_uid) DO UPDATE SET
            cpu_model = excluded.cpu_model,
            cpu_cores = excluded.cpu_cores,
            memory_total_mb = excluded.memory_total_mb,
            disk_model = excluded.disk_model,
            disk_total_mb = excluded.disk_total_mb,
            gpu_model = excluded.gpu_model,
            motherboard = excluded.motherboard,
            serial_number = excluded.serial_number,
            reported_at = datetime('now'),
            updated_at = datetime('now')"
        )
        .bind(&asset.device_uid)
        .bind(&asset.cpu_model)
        .bind(asset.cpu_cores as i32)
        .bind(asset.memory_total_mb as i64)
        .bind(&asset.disk_model)
        .bind(asset.disk_total_mb as i64)
        .bind(&asset.gpu_model)
        .bind(&asset.motherboard)
        .bind(&asset.serial_number)
        .execute(pool)
        .await?;
        Ok(())
    }
}

264
crates/server/src/main.rs Normal file
View File

@@ -0,0 +1,264 @@
use anyhow::Result;
use axum::Router;
use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions, SqliteJournalMode};
use std::path::Path;
use std::str::FromStr;
use std::sync::Arc;
use tokio::net::TcpListener;
use tower_http::cors::{CorsLayer, Any};
use tower_http::trace::TraceLayer;
use tower_http::compression::CompressionLayer;
use tower_http::set_header::SetResponseHeaderLayer;
use tracing::{info, warn, error};
mod api;
mod audit;
mod config;
mod db;
mod tcp;
mod ws;
mod alert;
use config::AppConfig;
/// Application shared state
/// Cloned per request/connection; all fields are cheap handles (pool or Arc).
#[derive(Clone)]
pub struct AppState {
    // SQLite connection pool.
    pub db: sqlx::SqlitePool,
    // Immutable configuration loaded at startup.
    pub config: Arc<AppConfig>,
    // Registry of connected TCP clients (devices).
    pub clients: Arc<tcp::ClientRegistry>,
    // Hub for pushing events to WebSocket subscribers.
    pub ws_hub: Arc<ws::WsHub>,
    // Rate limiter for login attempts.
    pub login_limiter: Arc<api::auth::LoginRateLimiter>,
}
/// Server entry point: set up logging, config, database, shared state, the
/// background cleanup task and TCP listener, then serve HTTP until shutdown.
#[tokio::main]
async fn main() -> Result<()> {
    // Initialize tracing: JSON logs, level from RUST_LOG with a sane default.
    tracing_subscriber::fmt()
        .with_env_filter(
            tracing_subscriber::EnvFilter::try_from_default_env()
                .unwrap_or_else(|_| "csm_server=info,tower_http=info".into()),
        )
        .json()
        .init();
    info!("CSM Server starting...");
    // Load configuration
    let config = AppConfig::load("config.toml").await?;
    let config = Arc::new(config);
    // Initialize database
    let db = init_database(&config.database.path).await?;
    run_migrations(&db).await?;
    info!("Database initialized at {}", config.database.path);
    // Ensure default admin exists
    ensure_default_admin(&db).await?;
    // Initialize shared state
    let clients = Arc::new(tcp::ClientRegistry::new());
    let ws_hub = Arc::new(ws::WsHub::new());
    let state = AppState {
        db: db.clone(),
        config: config.clone(),
        clients: clients.clone(),
        ws_hub: ws_hub.clone(),
        login_limiter: Arc::new(api::auth::LoginRateLimiter::new()),
    };
    // Start background tasks
    let cleanup_state = state.clone();
    tokio::spawn(async move {
        alert::cleanup_task(cleanup_state).await;
    });
    // Start TCP listener for client connections (runs concurrently with HTTP).
    let tcp_state = state.clone();
    let tcp_addr = config.server.tcp_addr.clone();
    tokio::spawn(async move {
        if let Err(e) = tcp::start_tcp_server(tcp_addr, tcp_state).await {
            error!("TCP server error: {}", e);
        }
    });
    // Build HTTP router
    let app = Router::new()
        .merge(api::routes(state.clone()))
        .layer(
            build_cors_layer(&config.server.cors_origins),
        )
        .layer(TraceLayer::new_for_http())
        .layer(CompressionLayer::new())
        // Security headers
        .layer(SetResponseHeaderLayer::if_not_present(
            axum::http::header::X_CONTENT_TYPE_OPTIONS,
            axum::http::HeaderValue::from_static("nosniff"),
        ))
        .layer(SetResponseHeaderLayer::if_not_present(
            axum::http::header::X_FRAME_OPTIONS,
            axum::http::HeaderValue::from_static("DENY"),
        ))
        .layer(SetResponseHeaderLayer::if_not_present(
            axum::http::header::HeaderName::from_static("x-xss-protection"),
            axum::http::HeaderValue::from_static("1; mode=block"),
        ))
        .layer(SetResponseHeaderLayer::if_not_present(
            axum::http::header::HeaderName::from_static("referrer-policy"),
            axum::http::HeaderValue::from_static("strict-origin-when-cross-origin"),
        ))
        .layer(SetResponseHeaderLayer::if_not_present(
            axum::http::header::HeaderName::from_static("content-security-policy"),
            axum::http::HeaderValue::from_static("default-src 'self'; script-src 'self' 'unsafe-inline' 'unsafe-eval'; style-src 'self' 'unsafe-inline'; img-src 'self' data:; connect-src 'self' ws: wss:"),
        ))
        .with_state(state);
    // Start HTTP server (blocks until the listener stops).
    let http_addr = &config.server.http_addr;
    info!("HTTP server listening on {}", http_addr);
    let listener = TcpListener::bind(http_addr).await?;
    axum::serve(listener, app).await?;
    Ok(())
}
/// Open (creating if needed) the SQLite pool used by the whole server,
/// enabling WAL mode, foreign keys, and a 5-second busy timeout.
async fn init_database(db_path: &str) -> Result<sqlx::SqlitePool> {
    // The configured path may look like "sqlite:foo.db?mode=rwc"; reduce it to
    // a plain filesystem path so the parent directory can be created.
    let file_path = db_path.strip_prefix("sqlite:").unwrap_or(db_path);
    let file_path = file_path.split('?').next().unwrap_or(file_path);
    match Path::new(file_path).parent() {
        Some(parent) if !parent.as_os_str().is_empty() => {
            tokio::fs::create_dir_all(parent).await?;
        }
        _ => {}
    }
    let options = SqliteConnectOptions::from_str(db_path)?
        .journal_mode(SqliteJournalMode::Wal)
        .synchronous(sqlx::sqlite::SqliteSynchronous::Normal)
        .busy_timeout(std::time::Duration::from_secs(5))
        .foreign_keys(true);
    let pool = SqlitePoolOptions::new()
        .max_connections(8)
        .connect_with(options)
        .await?;
    // Tuning pragmas: 64 MB page cache, WAL checkpoint every 1000 pages.
    for pragma in ["PRAGMA cache_size = -64000", "PRAGMA wal_autocheckpoint = 1000"] {
        sqlx::query(pragma).execute(&pool).await?;
    }
    Ok(pool)
}
/// Run the embedded SQL migrations in order, tracked in a `_migrations` table
/// so each script executes at most once per database. Migration names are the
/// zero-padded array index ("001", "002", ...), so the array order is the
/// migration order — append only, never reorder.
async fn run_migrations(pool: &sqlx::SqlitePool) -> Result<()> {
    // Embedded migrations - run in order
    let migrations = [
        include_str!("../../../migrations/001_init.sql"),
        include_str!("../../../migrations/002_assets.sql"),
        include_str!("../../../migrations/003_usb.sql"),
        include_str!("../../../migrations/004_alerts.sql"),
        include_str!("../../../migrations/005_plugins_web_filter.sql"),
        include_str!("../../../migrations/006_plugins_usage_timer.sql"),
        include_str!("../../../migrations/007_plugins_software_blocker.sql"),
        include_str!("../../../migrations/008_plugins_popup_blocker.sql"),
        include_str!("../../../migrations/009_plugins_usb_file_audit.sql"),
        include_str!("../../../migrations/010_plugins_watermark.sql"),
        include_str!("../../../migrations/011_token_security.sql"),
    ];
    // Create migrations tracking table
    sqlx::query(
        "CREATE TABLE IF NOT EXISTS _migrations (id INTEGER PRIMARY KEY, name TEXT NOT NULL UNIQUE, applied_at TEXT NOT NULL DEFAULT (datetime('now')))"
    )
    .execute(pool)
    .await?;
    for (i, migration_sql) in migrations.iter().enumerate() {
        let name = format!("{:03}", i + 1);
        let exists: bool = sqlx::query_scalar::<_, i64>(
            "SELECT COUNT(*) FROM _migrations WHERE name = ?"
        )
        .bind(&name)
        .fetch_one(pool)
        .await? > 0;
        if !exists {
            info!("Running migration: {}", name);
            // NOTE(review): each script runs as one sqlx query — confirm that
            // multi-statement migration files execute completely under sqlx's
            // SQLite driver.
            sqlx::query(migration_sql)
                .execute(pool)
                .await?;
            sqlx::query("INSERT INTO _migrations (name) VALUES (?)")
                .bind(&name)
                .execute(pool)
                .await?;
        }
    }
    Ok(())
}
/// Ensure at least one user exists: on an empty `users` table (first run),
/// create an "admin" account with a random password printed once to stderr.
///
/// # Errors
/// Propagates database and bcrypt failures.
async fn ensure_default_admin(pool: &sqlx::SqlitePool) -> Result<()> {
    let count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM users")
        .fetch_one(pool)
        .await?;
    if count == 0 {
        // Generate a random 32-character hex password from a UUIDv4's 16 bytes
        // (2 hex chars each). The buffer is sized for all 32 characters — the
        // previous capacity of 16 under-allocated and the old comment wrongly
        // claimed a 16-character password.
        let random_password: String = {
            use std::fmt::Write;
            let bytes = uuid::Uuid::new_v4();
            let mut s = String::with_capacity(32);
            for byte in bytes.as_bytes() {
                write!(s, "{:02x}", byte).unwrap();
            }
            s
        };
        let password_hash = bcrypt::hash(&random_password, 12)?;
        sqlx::query(
            "INSERT INTO users (username, password, role) VALUES (?, ?, 'admin')"
        )
        .bind("admin")
        .bind(&password_hash)
        .execute(pool)
        .await?;
        warn!("Created default admin user (username: admin)");
        // Print password directly to stderr — bypasses tracing JSON formatter
        eprintln!("============================================================");
        eprintln!(" Generated admin password: {}", random_password);
        eprintln!(" *** Save this password now — it will NOT be shown again! ***");
        eprintln!("============================================================");
    }
    Ok(())
}
/// Build CORS layer from configured origins.
/// If cors_origins is empty, no CORS headers are sent (same-origin only).
/// If origins are specified, only those are allowed.
fn build_cors_layer(origins: &[String]) -> CorsLayer {
    use axum::http::HeaderValue;
    // Entries that are not valid header values are silently dropped.
    let parsed: Vec<HeaderValue> = origins
        .iter()
        .filter_map(|origin| origin.parse::<HeaderValue>().ok())
        .collect();
    if parsed.is_empty() {
        // No CORS — production safe by default
        return CorsLayer::new();
    }
    CorsLayer::new()
        .allow_origin(tower_http::cors::AllowOrigin::list(parsed))
        .allow_methods(Any)
        .allow_headers(Any)
        .max_age(std::time::Duration::from_secs(3600))
}

844
crates/server/src/tcp.rs Normal file
View File

@@ -0,0 +1,844 @@
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::RwLock;
use tokio::net::{TcpListener, TcpStream};
use tracing::{info, warn, debug};
use hmac::{Hmac, Mac};
use sha2::Sha256;
use csm_protocol::{Frame, MessageType, PROTOCOL_VERSION};
use crate::AppState;
/// Frame-rate limit per connection: at most RATE_LIMIT_MAX_FRAMES frames per
/// RATE_LIMIT_WINDOW_SECS-second sliding window (i.e. 100 frames per 5 s —
/// the previous comment incorrectly described it as "per second").
const RATE_LIMIT_WINDOW_SECS: u64 = 5;
const RATE_LIMIT_MAX_FRAMES: usize = 100;
/// Per-connection rate limiter using a sliding window of frame timestamps
struct RateLimiter {
    // Arrival times of frames accepted within the current window.
    timestamps: Vec<Instant>,
}
impl RateLimiter {
    fn new() -> Self {
        Self { timestamps: Vec::with_capacity(RATE_LIMIT_MAX_FRAMES) }
    }
    /// Record one frame arrival. Returns false (recording nothing) if the
    /// connection already sent RATE_LIMIT_MAX_FRAMES frames within the window.
    fn check(&mut self) -> bool {
        let now = Instant::now();
        let cutoff = now - std::time::Duration::from_secs(RATE_LIMIT_WINDOW_SECS);
        // Evict timestamps outside the window
        self.timestamps.retain(|t| *t > cutoff);
        if self.timestamps.len() >= RATE_LIMIT_MAX_FRAMES {
            return false;
        }
        self.timestamps.push(now);
        true
    }
}
/// Push a plugin config frame to all online clients matching the target scope.
/// target_type: "global" | "device" | "group"
/// target_id: device_uid or group_name (None for global)
///
/// Best-effort: encoding failures and send failures are logged/counted but
/// never propagated to the calling API handler.
pub async fn push_to_targets(
    db: &sqlx::SqlitePool,
    clients: &crate::tcp::ClientRegistry,
    msg_type: MessageType,
    payload: &impl serde::Serialize,
    target_type: &str,
    target_id: Option<&str>,
) {
    // Serialize/encode once; the encoded bytes are cloned per recipient.
    let frame = match Frame::new_json(msg_type, payload) {
        Ok(f) => f.encode(),
        Err(e) => {
            warn!("Failed to encode plugin push frame: {}", e);
            return;
        }
    };
    let online = clients.list_online().await;
    let mut pushed_count = 0usize;
    // For group targeting, resolve group members from DB once
    let group_members: Option<Vec<String>> = if target_type == "group" {
        if let Some(group_name) = target_id {
            sqlx::query_scalar::<_, String>(
                "SELECT device_uid FROM devices WHERE group_name = ?"
            )
            .bind(group_name)
            .fetch_all(db)
            .await
            .ok()
            .into()
        } else {
            // "group" scope without a group name: no recipients.
            None
        }
    } else {
        None
    };
    for uid in &online {
        let should_push = match target_type {
            "global" => true,
            "device" => target_id.map_or(false, |id| id == uid),
            "group" => {
                // None here means no group name or the member query failed.
                if let Some(members) = &group_members {
                    members.contains(uid)
                } else {
                    false
                }
            }
            other => {
                warn!("Unknown target_type '{}', skipping push", other);
                false
            }
        };
        if should_push {
            if clients.send_to(uid, frame.clone()).await {
                pushed_count += 1;
            }
        }
    }
    debug!("Pushed {:?} to {}/{} online clients (target={})", msg_type, pushed_count, online.len(), target_type);
}
/// Push all active plugin configs to a newly registered client.
///
/// Runs right after a successful Register. For each plugin table, the rules
/// that apply to this device — directly, via its group, or globally — are
/// queried and sent as individual frames. DB errors are deliberately treated
/// as "nothing to push" so a missing plugin table cannot break registration.
pub async fn push_all_plugin_configs(
    db: &sqlx::SqlitePool,
    clients: &crate::tcp::ClientRegistry,
    device_uid: &str,
) {
    use sqlx::Row;
    // Watermark configs — only push the highest-priority enabled config (device > group > global)
    // Fix: the statement has exactly two '?' placeholders, so exactly two binds.
    // The previous third bind was out of range (SQLITE_RANGE), which made the
    // query fail and the watermark config was silently never pushed.
    if let Ok(rows) = sqlx::query(
        "SELECT content, font_size, opacity, color, angle, enabled FROM watermark_config WHERE enabled = 1 AND (target_type = 'global' OR (target_type = 'device' AND target_id = ?) OR (target_type = 'group' AND target_id = (SELECT group_name FROM devices WHERE device_uid = ?))) ORDER BY CASE WHEN target_type = 'device' THEN 0 WHEN target_type = 'group' THEN 1 ELSE 2 END LIMIT 1"
    )
    .bind(device_uid)
    .bind(device_uid)
    .fetch_all(db).await
    {
        if let Some(row) = rows.first() {
            let config = csm_protocol::WatermarkConfigPayload {
                content: row.get("content"),
                font_size: row.get::<i32, _>("font_size") as u32,
                opacity: row.get("opacity"),
                color: row.get("color"),
                angle: row.get::<i32, _>("angle"),
                enabled: row.get("enabled"),
            };
            if let Ok(frame) = Frame::new_json(MessageType::WatermarkConfig, &config) {
                clients.send_to(device_uid, frame.encode()).await;
            }
        }
    }
    // Web filter rules
    if let Ok(rows) = sqlx::query(
        "SELECT id, rule_type, pattern FROM web_filter_rules WHERE enabled = 1 AND (target_type = 'global' OR (target_type = 'device' AND target_id = ?) OR (target_type = 'group' AND target_id = (SELECT group_name FROM devices WHERE device_uid = ?)))"
    )
    .bind(device_uid)
    .bind(device_uid)
    .fetch_all(db).await
    {
        let rules: Vec<serde_json::Value> = rows.iter().map(|r| serde_json::json!({
            "id": r.get::<i64, _>("id"),
            "rule_type": r.get::<String, _>("rule_type"),
            "pattern": r.get::<String, _>("pattern"),
        })).collect();
        if !rules.is_empty() {
            if let Ok(frame) = Frame::new_json(MessageType::WebFilterRuleUpdate, &serde_json::json!({"rules": rules})) {
                clients.send_to(device_uid, frame.encode()).await;
            }
        }
    }
    // Software blacklist
    if let Ok(rows) = sqlx::query(
        "SELECT id, name_pattern, action FROM software_blacklist WHERE enabled = 1 AND (target_type = 'global' OR (target_type = 'device' AND target_id = ?) OR (target_type = 'group' AND target_id = (SELECT group_name FROM devices WHERE device_uid = ?)))"
    )
    .bind(device_uid)
    .bind(device_uid)
    .fetch_all(db).await
    {
        let entries: Vec<serde_json::Value> = rows.iter().map(|r| serde_json::json!({
            "id": r.get::<i64, _>("id"),
            "name_pattern": r.get::<String, _>("name_pattern"),
            "action": r.get::<String, _>("action"),
        })).collect();
        if !entries.is_empty() {
            if let Ok(frame) = Frame::new_json(MessageType::SoftwareBlacklist, &serde_json::json!({"blacklist": entries})) {
                clients.send_to(device_uid, frame.encode()).await;
            }
        }
    }
    // Popup blocker rules
    if let Ok(rows) = sqlx::query(
        "SELECT id, rule_type, window_title, window_class, process_name FROM popup_filter_rules WHERE enabled = 1 AND (target_type = 'global' OR (target_type = 'device' AND target_id = ?) OR (target_type = 'group' AND target_id = (SELECT group_name FROM devices WHERE device_uid = ?)))"
    )
    .bind(device_uid)
    .bind(device_uid)
    .fetch_all(db).await
    {
        let rules: Vec<serde_json::Value> = rows.iter().map(|r| serde_json::json!({
            "id": r.get::<i64, _>("id"),
            "rule_type": r.get::<String, _>("rule_type"),
            "window_title": r.get::<Option<String>, _>("window_title"),
            "window_class": r.get::<Option<String>, _>("window_class"),
            "process_name": r.get::<Option<String>, _>("process_name"),
        })).collect();
        if !rules.is_empty() {
            if let Ok(frame) = Frame::new_json(MessageType::PopupRules, &serde_json::json!({"rules": rules})) {
                clients.send_to(device_uid, frame.encode()).await;
            }
        }
    }
    // USB policies — push highest-priority enabled policy for the device's group
    // (all_block > blacklist > others, per the ORDER BY CASE below)
    if let Ok(rows) = sqlx::query(
        "SELECT policy_type, rules, enabled FROM usb_policies WHERE enabled = 1 AND target_group = (SELECT group_name FROM devices WHERE device_uid = ?) ORDER BY CASE WHEN policy_type = 'all_block' THEN 0 WHEN policy_type = 'blacklist' THEN 1 ELSE 2 END LIMIT 1"
    )
    .bind(device_uid)
    .fetch_all(db).await
    {
        if let Some(row) = rows.first() {
            let policy_type: String = row.get("policy_type");
            let rules_json: String = row.get("rules");
            // Malformed JSON in the rules column degrades to an empty rule list.
            let rules: Vec<serde_json::Value> = serde_json::from_str(&rules_json).unwrap_or_default();
            let payload = csm_protocol::UsbPolicyPayload {
                policy_type,
                enabled: true,
                rules: rules.iter().map(|r| csm_protocol::UsbDeviceRule {
                    vendor_id: r.get("vendor_id").and_then(|v| v.as_str().map(String::from)),
                    product_id: r.get("product_id").and_then(|v| v.as_str().map(String::from)),
                    serial: r.get("serial").and_then(|v| v.as_str().map(String::from)),
                    device_name: r.get("device_name").and_then(|v| v.as_str().map(String::from)),
                }).collect(),
            };
            if let Ok(frame) = Frame::new_json(MessageType::UsbPolicyUpdate, &payload) {
                clients.send_to(device_uid, frame.encode()).await;
            }
        }
    }
    info!("Pushed all plugin configs to newly registered device {}", device_uid);
}
/// Maximum accumulated read buffer size per connection (8 MB).
/// A connection whose pending, not-yet-decoded bytes exceed this is dropped
/// by the reader loops to bound per-connection memory use.
const MAX_READ_BUF_SIZE: usize = 8 * 1024 * 1024;
/// Registry of all connected client sessions.
///
/// Maps a device_uid to the outbound byte channel of its live connection.
/// Cloning the registry is cheap: the map is shared behind an `Arc<RwLock>`.
#[derive(Clone, Default)]
pub struct ClientRegistry {
    sessions: Arc<RwLock<HashMap<String, Arc<tokio::sync::mpsc::Sender<Vec<u8>>>>>>,
}
impl ClientRegistry {
    pub fn new() -> Self {
        Self::default()
    }
    /// Record a session's outbound channel, replacing any prior one for this uid.
    pub async fn register(&self, device_uid: String, tx: Arc<tokio::sync::mpsc::Sender<Vec<u8>>>) {
        let mut map = self.sessions.write().await;
        map.insert(device_uid, tx);
    }
    /// Forget a session (no-op if the uid is not present).
    pub async fn unregister(&self, device_uid: &str) {
        let mut map = self.sessions.write().await;
        map.remove(device_uid);
    }
    /// Queue raw frame bytes for one device.
    /// Returns false when the device is offline or its channel is closed.
    pub async fn send_to(&self, device_uid: &str, data: Vec<u8>) -> bool {
        let map = self.sessions.read().await;
        match map.get(device_uid) {
            Some(tx) => tx.send(data).await.is_ok(),
            None => false,
        }
    }
    /// Number of currently connected devices.
    pub async fn count(&self) -> usize {
        let map = self.sessions.read().await;
        map.len()
    }
    /// Snapshot of all connected device uids.
    pub async fn list_online(&self) -> Vec<String> {
        let map = self.sessions.read().await;
        map.keys().cloned().collect()
    }
}
/// Start the TCP server for client connections (optionally with TLS)
pub async fn start_tcp_server(addr: String, state: AppState) -> anyhow::Result<()> {
let listener = TcpListener::bind(&addr).await?;
// Build TLS acceptor if configured
let tls_acceptor = build_tls_acceptor(&state.config.server.tls)?;
if tls_acceptor.is_some() {
info!("TCP server listening on {} (TLS enabled)", addr);
} else {
info!("TCP server listening on {} (plaintext)", addr);
}
loop {
let (stream, peer_addr) = listener.accept().await?;
let state = state.clone();
let acceptor = tls_acceptor.clone();
tokio::spawn(async move {
debug!("New TCP connection from {}", peer_addr);
match acceptor {
Some(acceptor) => {
match acceptor.accept(stream).await {
Ok(tls_stream) => {
if let Err(e) = handle_client_tls(tls_stream, state).await {
warn!("Client {} TLS error: {}", peer_addr, e);
}
}
Err(e) => warn!("TLS handshake failed for {}: {}", peer_addr, e),
}
}
None => {
if let Err(e) = handle_client(stream, state).await {
warn!("Client {} error: {}", peer_addr, e);
}
}
}
});
}
}
/// Build a TLS acceptor from the configured certificate/key paths.
/// Returns `Ok(None)` when TLS is not configured (plaintext mode),
/// `Err` when files are missing or unparseable.
fn build_tls_acceptor(
    tls_config: &Option<crate::config::TlsConfig>,
) -> anyhow::Result<Option<tokio_rustls::TlsAcceptor>> {
    let Some(config) = tls_config else {
        return Ok(None);
    };
    let cert_pem = std::fs::read(&config.cert_path)
        .map_err(|e| anyhow::anyhow!("Failed to read TLS cert {}: {}", config.cert_path, e))?;
    let key_pem = std::fs::read(&config.key_path)
        .map_err(|e| anyhow::anyhow!("Failed to read TLS key {}: {}", config.key_path, e))?;
    // Parse every certificate in the PEM bundle (leaf plus any chain certs).
    let certs: Vec<rustls_pki_types::CertificateDer> = rustls_pemfile::certs(&mut &cert_pem[..])
        .collect::<Result<Vec<_>, _>>()
        .map_err(|e| anyhow::anyhow!("Failed to parse TLS cert: {:?}", e))?;
    // First private key of any supported format (PKCS#1/PKCS#8/SEC1).
    let key = rustls_pemfile::private_key(&mut &key_pem[..])
        .map_err(|e| anyhow::anyhow!("Failed to parse TLS key: {:?}", e))?
        .ok_or_else(|| anyhow::anyhow!("No private key found in {}", config.key_path))?;
    let server_config = rustls::ServerConfig::builder()
        .with_no_client_auth()
        .with_single_cert(certs, key)
        .map_err(|e| anyhow::anyhow!("Failed to build TLS config: {}", e))?;
    Ok(Some(tokio_rustls::TlsAcceptor::from(Arc::new(server_config))))
}
/// Cleanup on client disconnect: unregister from client map, mark offline, notify WS.
///
/// `device_uid` is `None` when the connection dropped before a successful
/// Register, in which case there is nothing to clean up.
async fn cleanup_on_disconnect(state: &AppState, device_uid: &Option<String>) {
    let Some(uid) = device_uid else { return };
    state.clients.unregister(uid).await;
    // Best-effort DB update; a failure here must not abort teardown.
    let _ = sqlx::query("UPDATE devices SET status = 'offline' WHERE device_uid = ?")
        .bind(uid)
        .execute(&state.db)
        .await;
    let event = serde_json::json!({
        "type": "device_state",
        "device_uid": uid,
        "status": "offline"
    });
    state.ws_hub.broadcast(event.to_string()).await;
    info!("Device disconnected: {}", uid);
}
/// Compute HMAC-SHA256 for heartbeat verification.
/// Format: HMAC-SHA256(device_secret, "{device_uid}\n{timestamp}") → hex-encoded
///
/// Returns an empty string if the MAC cannot be keyed — kept for parity with
/// prior behavior, although HMAC accepts keys of any length so that branch
/// should be unreachable.
fn compute_hmac(secret: &str, device_uid: &str, timestamp: &str) -> String {
    type HmacSha256 = Hmac<Sha256>;
    let message = format!("{}\n{}", device_uid, timestamp);
    HmacSha256::new_from_slice(secret.as_bytes())
        .map(|mut mac| {
            mac.update(message.as_bytes());
            hex::encode(mac.finalize().into_bytes())
        })
        .unwrap_or_default()
}
/// Verify that a frame sender is a registered device and that the claimed device_uid
/// matches the one registered on this connection. Returns true if valid.
///
/// Mismatches and unregistered senders are logged with the message type for
/// traceability; callers silently drop the offending frame.
fn verify_device_uid(device_uid: &Option<String>, msg_type: &str, claimed_uid: &str) -> bool {
    let Some(uid) = device_uid else {
        warn!("{} from unregistered connection", msg_type);
        return false;
    };
    if uid == claimed_uid {
        true
    } else {
        warn!("{} device_uid mismatch: expected {:?}, got {}", msg_type, uid, claimed_uid);
        false
    }
}
/// Process a single decoded frame. Shared by both plaintext and TLS handlers.
///
/// `device_uid` is per-connection authentication state: it is populated by a
/// successful Register frame, and every later frame must claim the same uid
/// (enforced via `verify_device_uid`). Responses are queued on `tx` and
/// written to the socket by the connection's writer task.
///
/// Returns `Err` only for malformed payloads or DB write failures; failed
/// auth/verification is logged and swallowed with `Ok(())` so the connection
/// itself stays up.
async fn process_frame(
    frame: Frame,
    state: &AppState,
    device_uid: &mut Option<String>,
    tx: &Arc<tokio::sync::mpsc::Sender<Vec<u8>>>,
) -> anyhow::Result<()> {
    match frame.msg_type {
        MessageType::Register => {
            let req: csm_protocol::RegisterRequest = frame.decode_payload()
                .map_err(|e| anyhow::anyhow!("Invalid registration payload: {}", e))?;
            info!("Device registration attempt: {} ({})", req.hostname, req.device_uid);
            // Validate registration token against configured token
            // (an empty configured token disables the check entirely)
            let expected_token = &state.config.registration_token;
            if !expected_token.is_empty() {
                if req.registration_token.is_empty() || req.registration_token != *expected_token {
                    warn!("Registration rejected for {}: invalid token", req.device_uid);
                    let err_frame = Frame::new_json(MessageType::RegisterResponse,
                        &serde_json::json!({"error": "invalid_registration_token"}))?;
                    tx.send(err_frame.encode()).await.ok();
                    return Ok(());
                }
            }
            // Check if device already exists with a secret (reconnection scenario)
            let existing_secret: Option<String> = sqlx::query_scalar(
                "SELECT device_secret FROM devices WHERE device_uid = ?"
            )
            .bind(&req.device_uid)
            .fetch_optional(&state.db)
            .await
            .ok()
            .flatten();
            let device_secret = match existing_secret {
                // Existing device — keep the same secret, don't rotate
                Some(secret) if !secret.is_empty() => secret,
                // New device — generate a fresh secret
                _ => uuid::Uuid::new_v4().to_string(),
            };
            // NOTE(review): ip_address is inserted as a '0.0.0.0' placeholder
            // and the upsert never updates it — confirm it is set elsewhere.
            sqlx::query(
                "INSERT INTO devices (device_uid, hostname, ip_address, mac_address, os_version, device_secret, status) \
                 VALUES (?, ?, '0.0.0.0', ?, ?, ?, 'online') \
                 ON CONFLICT(device_uid) DO UPDATE SET hostname=excluded.hostname, os_version=excluded.os_version, \
                 mac_address=excluded.mac_address, status='online'"
            )
            .bind(&req.device_uid)
            .bind(&req.hostname)
            .bind(&req.mac_address)
            .bind(&req.os_version)
            .bind(&device_secret)
            .execute(&state.db)
            .await
            .map_err(|e| anyhow::anyhow!("DB error during registration: {}", e))?;
            *device_uid = Some(req.device_uid.clone());
            // If this device was already connected on a different session, evict the old one
            // The new register() call will replace it in the hashmap
            state.clients.register(req.device_uid.clone(), tx.clone()).await;
            // Send registration response
            let config = csm_protocol::ClientConfig::default();
            let response = csm_protocol::RegisterResponse {
                device_secret,
                config,
            };
            let resp_frame = Frame::new_json(MessageType::RegisterResponse, &response)?;
            tx.send(resp_frame.encode()).await?;
            info!("Device registered successfully: {} ({})", req.hostname, req.device_uid);
            // Push all active plugin configs to newly registered client
            push_all_plugin_configs(&state.db, &state.clients, &req.device_uid).await;
        }
        MessageType::Heartbeat => {
            let heartbeat: csm_protocol::HeartbeatPayload = frame.decode_payload()
                .map_err(|e| anyhow::anyhow!("Invalid heartbeat: {}", e))?;
            if !verify_device_uid(device_uid, "Heartbeat", &heartbeat.device_uid) {
                return Ok(());
            }
            // Verify HMAC — reject if secret exists but HMAC is missing or wrong
            let secret: Option<String> = sqlx::query_scalar(
                "SELECT device_secret FROM devices WHERE device_uid = ?"
            )
            .bind(&heartbeat.device_uid)
            .fetch_optional(&state.db)
            .await
            .map_err(|e| {
                warn!("DB error fetching device_secret for {}: {}", heartbeat.device_uid, e);
                anyhow::anyhow!("DB error during HMAC verification")
            })?;
            if let Some(ref secret) = secret {
                if !secret.is_empty() {
                    if heartbeat.hmac.is_empty() {
                        warn!("Heartbeat missing HMAC for device {}", heartbeat.device_uid);
                        return Ok(());
                    }
                    // Constant-time HMAC verification using hmac::Mac::verify_slice
                    let message = format!("{}\n{}", heartbeat.device_uid, heartbeat.timestamp);
                    let mut mac = Hmac::<Sha256>::new_from_slice(secret.as_bytes())
                        .map_err(|_| anyhow::anyhow!("HMAC key error"))?;
                    mac.update(message.as_bytes());
                    // Invalid hex decodes to an empty Vec, which fails verification below
                    let provided_bytes = hex::decode(&heartbeat.hmac).unwrap_or_default();
                    if mac.verify_slice(&provided_bytes).is_err() {
                        warn!("Heartbeat HMAC mismatch for device {}", heartbeat.device_uid);
                        return Ok(());
                    }
                }
            }
            debug!("Heartbeat from {} (hmac verified)", heartbeat.device_uid);
            // Update device status in DB (best-effort — failure does not drop the frame)
            sqlx::query("UPDATE devices SET status = 'online', last_heartbeat = datetime('now') WHERE device_uid = ?")
                .bind(&heartbeat.device_uid)
                .execute(&state.db)
                .await
                .ok();
            // Push to WebSocket subscribers
            state.ws_hub.broadcast(serde_json::json!({
                "type": "device_state",
                "device_uid": heartbeat.device_uid,
                "status": "online"
            }).to_string()).await;
        }
        MessageType::StatusReport => {
            let status: csm_protocol::DeviceStatus = frame.decode_payload()
                .map_err(|e| anyhow::anyhow!("Invalid status report: {}", e))?;
            if !verify_device_uid(device_uid, "StatusReport", &status.device_uid) {
                return Ok(());
            }
            crate::db::DeviceRepo::upsert_status(&state.db, &status.device_uid, &status).await?;
            // Push to WebSocket subscribers
            state.ws_hub.broadcast(serde_json::json!({
                "type": "device_status",
                "device_uid": status.device_uid,
                "cpu": status.cpu_usage,
                "memory": status.memory_usage,
                "disk": status.disk_usage
            }).to_string()).await;
        }
        MessageType::UsbEvent => {
            let event: csm_protocol::UsbEvent = frame.decode_payload()
                .map_err(|e| anyhow::anyhow!("Invalid USB event: {}", e))?;
            if !verify_device_uid(device_uid, "UsbEvent", &event.device_uid) {
                return Ok(());
            }
            crate::db::DeviceRepo::insert_usb_event(&state.db, &event).await?;
            state.ws_hub.broadcast(serde_json::json!({
                "type": "usb_event",
                "device_uid": event.device_uid,
                "event": event.event_type,
                "usb_name": event.device_name
            }).to_string()).await;
        }
        MessageType::AssetReport => {
            let asset: csm_protocol::HardwareAsset = frame.decode_payload()
                .map_err(|e| anyhow::anyhow!("Invalid asset report: {}", e))?;
            if !verify_device_uid(device_uid, "AssetReport", &asset.device_uid) {
                return Ok(());
            }
            crate::db::DeviceRepo::upsert_hardware(&state.db, &asset).await?;
        }
        MessageType::UsageReport => {
            let report: csm_protocol::UsageDailyReport = frame.decode_payload()
                .map_err(|e| anyhow::anyhow!("Invalid usage report: {}", e))?;
            if !verify_device_uid(device_uid, "UsageReport", &report.device_uid) {
                return Ok(());
            }
            // Daily totals are idempotent upserts keyed on (device_uid, date)
            sqlx::query(
                "INSERT INTO usage_daily (device_uid, date, total_active_minutes, total_idle_minutes, first_active_at, last_active_at) \
                 VALUES (?, ?, ?, ?, ?, ?) \
                 ON CONFLICT(device_uid, date) DO UPDATE SET \
                 total_active_minutes = excluded.total_active_minutes, \
                 total_idle_minutes = excluded.total_idle_minutes, \
                 first_active_at = excluded.first_active_at, \
                 last_active_at = excluded.last_active_at"
            )
            .bind(&report.device_uid)
            .bind(&report.date)
            .bind(report.total_active_minutes as i32)
            .bind(report.total_idle_minutes as i32)
            .bind(&report.first_active_at)
            .bind(&report.last_active_at)
            .execute(&state.db)
            .await
            .map_err(|e| anyhow::anyhow!("DB error inserting usage report: {}", e))?;
            debug!("Usage report saved for device {}", report.device_uid);
        }
        MessageType::AppUsageReport => {
            let report: csm_protocol::AppUsageEntry = frame.decode_payload()
                .map_err(|e| anyhow::anyhow!("Invalid app usage report: {}", e))?;
            if !verify_device_uid(device_uid, "AppUsageReport", &report.device_uid) {
                return Ok(());
            }
            // MAX() keeps the largest value seen, so late/duplicate reports never shrink a total
            sqlx::query(
                "INSERT INTO app_usage_daily (device_uid, date, app_name, usage_minutes) \
                 VALUES (?, ?, ?, ?) \
                 ON CONFLICT(device_uid, date, app_name) DO UPDATE SET \
                 usage_minutes = MAX(usage_minutes, excluded.usage_minutes)"
            )
            .bind(&report.device_uid)
            .bind(&report.date)
            .bind(&report.app_name)
            .bind(report.usage_minutes as i32)
            .execute(&state.db)
            .await
            .map_err(|e| anyhow::anyhow!("DB error inserting app usage: {}", e))?;
            debug!("App usage saved: {} -> {} ({} min)", report.device_uid, report.app_name, report.usage_minutes);
        }
        MessageType::SoftwareViolation => {
            let report: csm_protocol::SoftwareViolationReport = frame.decode_payload()
                .map_err(|e| anyhow::anyhow!("Invalid software violation: {}", e))?;
            if !verify_device_uid(device_uid, "SoftwareViolation", &report.device_uid) {
                return Ok(());
            }
            sqlx::query(
                "INSERT INTO software_violations (device_uid, software_name, action_taken, timestamp) VALUES (?, ?, ?, ?)"
            )
            .bind(&report.device_uid)
            .bind(&report.software_name)
            .bind(&report.action_taken)
            .bind(&report.timestamp)
            .execute(&state.db)
            .await
            .map_err(|e| anyhow::anyhow!("DB error inserting software violation: {}", e))?;
            info!("Software violation: {} tried to run {} -> {}", report.device_uid, report.software_name, report.action_taken);
            state.ws_hub.broadcast(serde_json::json!({
                "type": "software_violation",
                "device_uid": report.device_uid,
                "software_name": report.software_name,
                "action_taken": report.action_taken
            }).to_string()).await;
        }
        MessageType::UsbFileOp => {
            let entry: csm_protocol::UsbFileOpEntry = frame.decode_payload()
                .map_err(|e| anyhow::anyhow!("Invalid USB file op: {}", e))?;
            if !verify_device_uid(device_uid, "UsbFileOp", &entry.device_uid) {
                return Ok(());
            }
            sqlx::query(
                "INSERT INTO usb_file_operations (device_uid, usb_serial, drive_letter, operation, file_path, file_size, timestamp) VALUES (?, ?, ?, ?, ?, ?, ?)"
            )
            .bind(&entry.device_uid)
            .bind(&entry.usb_serial)
            .bind(&entry.drive_letter)
            .bind(&entry.operation)
            .bind(&entry.file_path)
            .bind(entry.file_size.map(|s| s as i64))
            .bind(&entry.timestamp)
            .execute(&state.db)
            .await
            .map_err(|e| anyhow::anyhow!("DB error inserting USB file op: {}", e))?;
            debug!("USB file op: {} {} on {}", entry.operation, entry.file_path, entry.device_uid);
        }
        MessageType::WebAccessLog => {
            let entry: csm_protocol::WebAccessLogEntry = frame.decode_payload()
                .map_err(|e| anyhow::anyhow!("Invalid web access log: {}", e))?;
            if !verify_device_uid(device_uid, "WebAccessLog", &entry.device_uid) {
                return Ok(());
            }
            sqlx::query(
                "INSERT INTO web_access_log (device_uid, url, action, timestamp) VALUES (?, ?, ?, ?)"
            )
            .bind(&entry.device_uid)
            .bind(&entry.url)
            .bind(&entry.action)
            .bind(&entry.timestamp)
            .execute(&state.db)
            .await
            .map_err(|e| anyhow::anyhow!("DB error inserting web access log: {}", e))?;
            debug!("Web access log: {} {} {}", entry.device_uid, entry.action, entry.url);
        }
        _ => {
            debug!("Unhandled message type: {:?}", frame.msg_type);
        }
    }
    Ok(())
}
/// Handle a single client TCP connection.
///
/// Spawns a writer task that drains the per-connection mpsc channel onto the
/// socket, then loops reading bytes, accumulating them in `read_buf` and
/// dispatching each complete `Frame` to `process_frame`. Exits on EOF, read
/// error, oversized buffer, or rate-limit violation; every exit path marks
/// the device offline and aborts the writer task.
async fn handle_client(stream: TcpStream, state: AppState) -> anyhow::Result<()> {
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    // Disable Nagle's algorithm so small frames are sent promptly (best-effort)
    let _ = stream.set_nodelay(true);
    let (mut reader, mut writer) = stream.into_split();
    let (tx, mut rx) = tokio::sync::mpsc::channel::<Vec<u8>>(256);
    let tx = Arc::new(tx);
    let mut buffer = vec![0u8; 65536];
    let mut read_buf = Vec::with_capacity(65536);
    // Set on successful Register; used to authenticate every later frame
    let mut device_uid: Option<String> = None;
    let mut rate_limiter = RateLimiter::new();
    // Writer task: forwards messages from channel to TCP stream
    let write_task = tokio::spawn(async move {
        while let Some(data) = rx.recv().await {
            if writer.write_all(&data).await.is_err() {
                break;
            }
        }
    });
    // Reader loop
    'reader: loop {
        let n = reader.read(&mut buffer).await?;
        if n == 0 {
            break; // Connection closed
        }
        read_buf.extend_from_slice(&buffer[..n]);
        // Guard against unbounded buffer growth
        if read_buf.len() > MAX_READ_BUF_SIZE {
            warn!("Connection exceeded max buffer size, dropping");
            break;
        }
        // Process complete frames (partial tail bytes stay in read_buf for the next read)
        while let Some(frame) = Frame::decode(&read_buf)? {
            let frame_size = frame.encoded_size();
            // Remove consumed bytes without reallocating
            read_buf.drain(..frame_size);
            // Rate limit check — applied before the version check so every
            // complete frame counts against the quota
            if !rate_limiter.check() {
                warn!("Rate limit exceeded for device {:?}, dropping connection", device_uid);
                break 'reader;
            }
            // Verify protocol version
            if frame.version != PROTOCOL_VERSION {
                warn!("Unsupported protocol version: 0x{:02X}", frame.version);
                continue;
            }
            if let Err(e) = process_frame(frame, &state, &mut device_uid, &tx).await {
                warn!("Frame processing error: {}", e);
            }
        }
    }
    cleanup_on_disconnect(&state, &device_uid).await;
    write_task.abort();
    Ok(())
}
/// Handle a TLS-wrapped client connection.
///
/// Mirrors `handle_client`: a writer task drains the outbound channel while
/// this task reads, decodes and dispatches frames. Fix: the rate-limit check
/// now runs before the protocol-version check, matching the plaintext
/// handler — previously frames with an unsupported version byte bypassed the
/// limiter on TLS connections only.
async fn handle_client_tls(
    stream: tokio_rustls::server::TlsStream<TcpStream>,
    state: AppState,
) -> anyhow::Result<()> {
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    let (mut reader, mut writer) = tokio::io::split(stream);
    let (tx, mut rx) = tokio::sync::mpsc::channel::<Vec<u8>>(256);
    let tx = Arc::new(tx);
    let mut buffer = vec![0u8; 65536];
    let mut read_buf = Vec::with_capacity(65536);
    // Set on successful Register; used to authenticate every later frame
    let mut device_uid: Option<String> = None;
    let mut rate_limiter = RateLimiter::new();
    // Writer task: forwards queued messages from the channel to the TLS stream
    let write_task = tokio::spawn(async move {
        while let Some(data) = rx.recv().await {
            if writer.write_all(&data).await.is_err() {
                break;
            }
        }
    });
    // Reader loop — same logic as plaintext handler
    'reader: loop {
        let n = reader.read(&mut buffer).await?;
        if n == 0 {
            break;
        }
        read_buf.extend_from_slice(&buffer[..n]);
        // Guard against unbounded buffer growth
        if read_buf.len() > MAX_READ_BUF_SIZE {
            warn!("TLS connection exceeded max buffer size, dropping");
            break;
        }
        while let Some(frame) = Frame::decode(&read_buf)? {
            let frame_size = frame.encoded_size();
            read_buf.drain(..frame_size);
            // Rate limit first (consistent with handle_client): every complete
            // frame counts against the quota, including bad-version frames.
            if !rate_limiter.check() {
                warn!("Rate limit exceeded for TLS device {:?}, dropping connection", device_uid);
                break 'reader;
            }
            if frame.version != PROTOCOL_VERSION {
                warn!("Unsupported protocol version: 0x{:02X}", frame.version);
                continue;
            }
            if let Err(e) = process_frame(frame, &state, &mut device_uid, &tx).await {
                warn!("Frame processing error: {}", e);
            }
        }
    }
    cleanup_on_disconnect(&state, &device_uid).await;
    write_task.abort();
    Ok(())
}

125
crates/server/src/ws.rs Normal file
View File

@@ -0,0 +1,125 @@
use axum::extract::ws::{WebSocket, WebSocketUpgrade, Message};
use axum::response::IntoResponse;
use axum::extract::Query;
use jsonwebtoken::{decode, Validation, DecodingKey};
use serde::Deserialize;
use tokio::sync::broadcast;
use std::sync::Arc;
use tracing::{debug, warn};
use crate::api::auth::Claims;
use crate::AppState;
/// WebSocket hub for broadcasting real-time events to admin browsers.
///
/// Thin wrapper around a `tokio::sync::broadcast` channel; cloning the hub
/// shares the same underlying channel.
#[derive(Clone)]
pub struct WsHub {
    tx: broadcast::Sender<String>,
}
impl WsHub {
    pub fn new() -> Self {
        // Capacity 1024: slow subscribers past this point see Lagged errors.
        let (tx, _rx) = broadcast::channel(1024);
        Self { tx }
    }
    /// Fan a message out to every subscriber.
    /// A send error only means nobody is currently subscribed — not a failure.
    pub async fn broadcast(&self, message: String) {
        match self.tx.send(message) {
            Ok(_) => {}
            Err(_) => debug!("No WebSocket subscribers to receive broadcast"),
        }
    }
    /// Create a fresh receiver that sees all messages sent from now on.
    pub fn subscribe(&self) -> broadcast::Receiver<String> {
        self.tx.subscribe()
    }
}
/// Query-string parameters accepted on the WebSocket upgrade request.
#[derive(Debug, Deserialize)]
pub struct WsAuthParams {
    /// JWT access token, carried as `?token=...` because browsers cannot set
    /// custom headers on WebSocket upgrade requests.
    pub token: Option<String>,
}
/// HTTP upgrade handler for WebSocket connections.
/// Validates the JWT token from the query parameter before upgrading; only
/// tokens of type "access" are accepted.
pub async fn ws_handler(
    ws: WebSocketUpgrade,
    Query(params): Query<WsAuthParams>,
    axum::extract::State(state): axum::extract::State<AppState>,
) -> impl IntoResponse {
    // Reject outright when no token was supplied.
    let Some(token) = params.token else {
        warn!("WebSocket connection rejected: no token provided");
        return (axum::http::StatusCode::UNAUTHORIZED, "Missing token").into_response();
    };
    // Verify signature and standard claims against the configured secret.
    let decoded = decode::<Claims>(
        &token,
        &DecodingKey::from_secret(state.config.auth.jwt_secret.as_bytes()),
        &Validation::default(),
    );
    let claims = match decoded {
        Ok(data) => data.claims,
        Err(e) => {
            warn!("WebSocket connection rejected: invalid token - {}", e);
            return (axum::http::StatusCode::UNAUTHORIZED, "Invalid token").into_response();
        }
    };
    // Refresh tokens (or any other type) must not open a live feed.
    if claims.token_type != "access" {
        warn!("WebSocket connection rejected: not an access token");
        return (axum::http::StatusCode::UNAUTHORIZED, "Invalid token type").into_response();
    }
    let hub = state.ws_hub.clone();
    ws.on_upgrade(move |socket| handle_socket(socket, claims, hub))
}
/// Per-connection pump for an authenticated admin WebSocket.
///
/// Sends a welcome message, subscribes to the broadcast hub, then runs a
/// select loop: hub events are forwarded to the browser while inbound
/// Ping/Close frames are answered/honored. Any send failure, Close frame, or
/// closed hub channel ends the session.
async fn handle_socket(mut socket: WebSocket, claims: Claims, hub: Arc<WsHub>) {
    debug!("WebSocket client connected: user={}", claims.username);
    let welcome = serde_json::json!({
        "type": "connected",
        "message": "CSM real-time feed active",
        "user": claims.username
    });
    if socket.send(Message::Text(welcome.to_string())).await.is_err() {
        return;
    }
    // Subscribe to broadcast hub for real-time events
    let mut rx = hub.subscribe();
    loop {
        tokio::select! {
            // Forward broadcast messages to WebSocket client
            msg = rx.recv() => {
                match msg {
                    Ok(text) => {
                        if socket.send(Message::Text(text)).await.is_err() {
                            break;
                        }
                    }
                    // Slow consumer dropped messages — keep the session alive
                    Err(broadcast::error::RecvError::Lagged(n)) => {
                        debug!("WebSocket client lagged {} messages, continuing", n);
                    }
                    Err(broadcast::error::RecvError::Closed) => break,
                }
            }
            // Handle incoming WebSocket messages (ping/close)
            msg = socket.recv() => {
                match msg {
                    Some(Ok(Message::Ping(data))) => {
                        if socket.send(Message::Pong(data)).await.is_err() {
                            break;
                        }
                    }
                    Some(Ok(Message::Close(_))) => break,
                    Some(Err(_)) => break,
                    // None = socket closed by the peer without a Close frame
                    None => break,
                    // Text/Binary/Pong from the browser are ignored
                    _ => {}
                }
            }
        }
    }
    debug!("WebSocket client disconnected: user={}", claims.username);
}