feat: 初始化项目基础架构和核心功能
- 添加项目基础结构:Cargo.toml、.gitignore、设备UID和密钥文件 - 实现前端Vue3项目结构:路由、登录页面、设备管理页面 - 添加核心协议定义(crates/protocol):设备状态、资产、USB事件等 - 实现客户端监控模块:系统状态收集、资产收集 - 实现服务端基础API和插件系统 - 添加数据库迁移脚本:设备管理、资产跟踪、告警系统等 - 实现前端设备状态展示和基本交互 - 添加使用时长统计和水印功能插件
This commit is contained in:
48
crates/client/Cargo.toml
Normal file
48
crates/client/Cargo.toml
Normal file
@@ -0,0 +1,48 @@
|
||||
[package]
|
||||
name = "csm-client"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
|
||||
[dependencies]
|
||||
csm-protocol = { path = "../protocol" }
|
||||
tokio = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
tracing-subscriber = { workspace = true }
|
||||
anyhow = { workspace = true }
|
||||
uuid = { workspace = true }
|
||||
chrono = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
sysinfo = "0.30"
|
||||
tokio-rustls = "0.26"
|
||||
rustls = { version = "0.23", default-features = false, features = ["ring", "std"] }
|
||||
rustls-pki-types = "1"
|
||||
webpki-roots = "0.26"
|
||||
rustls-pemfile = "2"
|
||||
hmac = "0.12"
|
||||
sha2 = "0.10"
|
||||
hex = "0.4"
|
||||
|
||||
[target.'cfg(target_os = "windows")'.dependencies]
|
||||
windows = { version = "0.54", features = [
|
||||
"Win32_Foundation",
|
||||
"Win32_System_SystemInformation",
|
||||
"Win32_System_Registry",
|
||||
"Win32_System_IO",
|
||||
"Win32_Security",
|
||||
"Win32_NetworkManagement_IpHelper",
|
||||
"Win32_Storage_FileSystem",
|
||||
"Win32_UI_WindowsAndMessaging",
|
||||
"Win32_UI_Input_KeyboardAndMouse",
|
||||
"Win32_System_Threading",
|
||||
"Win32_System_Diagnostics_ToolHelp",
|
||||
"Win32_System_LibraryLoader",
|
||||
"Win32_System_Performance",
|
||||
"Win32_Graphics_Gdi",
|
||||
] }
|
||||
windows-service = "0.7"
|
||||
hostname = "0.4"
|
||||
|
||||
[target.'cfg(not(target_os = "windows"))'.dependencies]
|
||||
hostname = "0.4"
|
||||
54
crates/client/src/asset/mod.rs
Normal file
54
crates/client/src/asset/mod.rs
Normal file
@@ -0,0 +1,54 @@
|
||||
use csm_protocol::{Frame, MessageType, HardwareAsset};
|
||||
use std::time::Duration;
|
||||
use tokio::sync::mpsc::Sender;
|
||||
use tracing::{info, error};
|
||||
use sysinfo::System;
|
||||
|
||||
pub async fn start_collecting(tx: Sender<Frame>, device_uid: String) {
|
||||
let interval = Duration::from_secs(86400); // Once per day
|
||||
|
||||
// Initial collection on startup
|
||||
if let Err(e) = collect_and_send(&tx, &device_uid).await {
|
||||
error!("Initial asset collection failed: {}", e);
|
||||
}
|
||||
|
||||
loop {
|
||||
tokio::time::sleep(interval).await;
|
||||
|
||||
if let Err(e) = collect_and_send(&tx, &device_uid).await {
|
||||
error!("Asset collection failed: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn collect_and_send(tx: &Sender<Frame>, device_uid: &str) -> anyhow::Result<()> {
|
||||
let hardware = collect_hardware(device_uid)?;
|
||||
let frame = Frame::new_json(MessageType::AssetReport, &hardware)?;
|
||||
tx.send(frame).await.map_err(|e| anyhow::anyhow!("Channel send failed: {}", e))?;
|
||||
info!("Asset report sent for {}", device_uid);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn collect_hardware(device_uid: &str) -> anyhow::Result<HardwareAsset> {
|
||||
let mut sys = System::new_all();
|
||||
sys.refresh_all();
|
||||
|
||||
let cpu_model = sys.cpus().first()
|
||||
.map(|c| c.brand().to_string())
|
||||
.unwrap_or_else(|| "Unknown".to_string());
|
||||
|
||||
let cpu_cores = sys.cpus().len() as u32;
|
||||
let memory_total_mb = sys.total_memory() / 1024 / 1024; // bytes to MB (sysinfo 0.30)
|
||||
|
||||
Ok(HardwareAsset {
|
||||
device_uid: device_uid.to_string(),
|
||||
cpu_model,
|
||||
cpu_cores,
|
||||
memory_total_mb: memory_total_mb as u64,
|
||||
disk_model: "Unknown".to_string(),
|
||||
disk_total_mb: 0,
|
||||
gpu_model: None,
|
||||
motherboard: None,
|
||||
serial_number: None,
|
||||
})
|
||||
}
|
||||
206
crates/client/src/main.rs
Normal file
206
crates/client/src/main.rs
Normal file
@@ -0,0 +1,206 @@
|
||||
use anyhow::Result;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::time::Duration;
|
||||
use tracing::{info, error, warn};
|
||||
use csm_protocol::{Frame, ClientConfig, UsbPolicyPayload};
|
||||
|
||||
mod monitor;
|
||||
mod asset;
|
||||
mod usb;
|
||||
mod network;
|
||||
mod watermark;
|
||||
mod usage_timer;
|
||||
mod usb_audit;
|
||||
mod popup_blocker;
|
||||
mod software_blocker;
|
||||
mod web_filter;
|
||||
|
||||
/// Shared shutdown flag
|
||||
static SHUTDOWN: AtomicBool = AtomicBool::new(false);
|
||||
|
||||
/// Client configuration
|
||||
struct ClientState {
|
||||
device_uid: String,
|
||||
server_addr: String,
|
||||
config: ClientConfig,
|
||||
device_secret: Option<String>,
|
||||
registration_token: String,
|
||||
/// Whether to use TLS when connecting to the server
|
||||
use_tls: bool,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
tracing_subscriber::fmt()
|
||||
.with_env_filter("csm_client=info")
|
||||
.init();
|
||||
|
||||
info!("CSM Client starting...");
|
||||
|
||||
// Load or generate device identity
|
||||
let device_uid = load_or_create_device_uid()?;
|
||||
info!("Device UID: {}", device_uid);
|
||||
|
||||
// Load server address
|
||||
let server_addr = std::env::var("CSM_SERVER")
|
||||
.unwrap_or_else(|_| "127.0.0.1:9999".to_string());
|
||||
|
||||
let state = ClientState {
|
||||
device_uid,
|
||||
server_addr,
|
||||
config: ClientConfig::default(),
|
||||
device_secret: load_device_secret(),
|
||||
registration_token: std::env::var("CSM_REGISTRATION_TOKEN").unwrap_or_default(),
|
||||
use_tls: std::env::var("CSM_USE_TLS").as_deref() == Ok("true"),
|
||||
};
|
||||
|
||||
// TODO: Register as Windows Service on Windows
|
||||
// For development, run directly
|
||||
|
||||
// Main event loop
|
||||
run(state).await
|
||||
}
|
||||
|
||||
async fn run(state: ClientState) -> Result<()> {
|
||||
let (data_tx, mut data_rx) = tokio::sync::mpsc::channel::<Frame>(1024);
|
||||
|
||||
// Spawn Ctrl+C handler
|
||||
tokio::spawn(async {
|
||||
tokio::signal::ctrl_c().await.ok();
|
||||
info!("Received Ctrl+C, initiating graceful shutdown...");
|
||||
SHUTDOWN.store(true, Ordering::SeqCst);
|
||||
});
|
||||
|
||||
// Create plugin config channels
|
||||
let (watermark_tx, watermark_rx) = tokio::sync::watch::channel(None);
|
||||
let (web_filter_tx, web_filter_rx) = tokio::sync::watch::channel(web_filter::WebFilterConfig::default());
|
||||
let (software_blocker_tx, software_blocker_rx) = tokio::sync::watch::channel(software_blocker::SoftwareBlockerConfig::default());
|
||||
let (popup_blocker_tx, popup_blocker_rx) = tokio::sync::watch::channel(popup_blocker::PopupBlockerConfig::default());
|
||||
let (usb_audit_tx, usb_audit_rx) = tokio::sync::watch::channel(usb_audit::UsbAuditConfig::default());
|
||||
let (usage_timer_tx, usage_timer_rx) = tokio::sync::watch::channel(usage_timer::UsageConfig::default());
|
||||
let (usb_policy_tx, usb_policy_rx) = tokio::sync::watch::channel(None::<UsbPolicyPayload>);
|
||||
|
||||
let plugins = network::PluginChannels {
|
||||
watermark_tx,
|
||||
web_filter_tx,
|
||||
software_blocker_tx,
|
||||
popup_blocker_tx,
|
||||
usb_audit_tx,
|
||||
usage_timer_tx,
|
||||
usb_policy_tx,
|
||||
};
|
||||
|
||||
// Spawn core monitoring tasks
|
||||
let monitor_tx = data_tx.clone();
|
||||
let uid = state.device_uid.clone();
|
||||
tokio::spawn(async move {
|
||||
monitor::start_collecting(monitor_tx, uid).await;
|
||||
});
|
||||
|
||||
let asset_tx = data_tx.clone();
|
||||
let uid = state.device_uid.clone();
|
||||
tokio::spawn(async move {
|
||||
asset::start_collecting(asset_tx, uid).await;
|
||||
});
|
||||
|
||||
let usb_tx = data_tx.clone();
|
||||
let uid = state.device_uid.clone();
|
||||
tokio::spawn(async move {
|
||||
usb::start_monitoring(usb_tx, uid, usb_policy_rx).await;
|
||||
});
|
||||
|
||||
// Spawn plugin tasks
|
||||
tokio::spawn(async move {
|
||||
watermark::start(watermark_rx).await;
|
||||
});
|
||||
|
||||
let usage_data_tx = data_tx.clone();
|
||||
let usage_uid = state.device_uid.clone();
|
||||
tokio::spawn(async move {
|
||||
usage_timer::start(usage_timer_rx, usage_data_tx, usage_uid).await;
|
||||
});
|
||||
|
||||
let audit_data_tx = data_tx.clone();
|
||||
let audit_uid = state.device_uid.clone();
|
||||
tokio::spawn(async move {
|
||||
usb_audit::start(usb_audit_rx, audit_data_tx, audit_uid).await;
|
||||
});
|
||||
|
||||
tokio::spawn(async move {
|
||||
popup_blocker::start(popup_blocker_rx).await;
|
||||
});
|
||||
|
||||
let sw_data_tx = data_tx.clone();
|
||||
let sw_uid = state.device_uid.clone();
|
||||
tokio::spawn(async move {
|
||||
software_blocker::start(software_blocker_rx, sw_data_tx, sw_uid).await;
|
||||
});
|
||||
|
||||
tokio::spawn(async move {
|
||||
web_filter::start(web_filter_rx).await;
|
||||
});
|
||||
|
||||
// Connect to server with reconnect
|
||||
let mut backoff = Duration::from_secs(1);
|
||||
let max_backoff = Duration::from_secs(60);
|
||||
|
||||
loop {
|
||||
if SHUTDOWN.load(Ordering::SeqCst) {
|
||||
info!("Shutting down gracefully...");
|
||||
break Ok(());
|
||||
}
|
||||
|
||||
match network::connect_and_run(&state, &mut data_rx, &plugins).await {
|
||||
Ok(()) => {
|
||||
warn!("Disconnected from server, reconnecting...");
|
||||
// Use a short fixed delay for clean disconnects (server-initiated),
|
||||
// but don't reset to zero to prevent connection storms
|
||||
tokio::time::sleep(Duration::from_secs(2)).await;
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Connection error: {}, reconnecting...", e);
|
||||
tokio::time::sleep(backoff).await;
|
||||
backoff = (backoff * 2).min(max_backoff);
|
||||
}
|
||||
}
|
||||
|
||||
// Drain stale frames that accumulated during disconnection
|
||||
let drained = data_rx.try_recv().ok().map(|_| 1).unwrap_or(0);
|
||||
if drained > 0 {
|
||||
let mut count = drained;
|
||||
while data_rx.try_recv().is_ok() {
|
||||
count += 1;
|
||||
}
|
||||
warn!("Drained {} stale frames from channel", count);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn load_or_create_device_uid() -> Result<String> {
|
||||
// In production, store in Windows Credential Store or local config
|
||||
// For now, use a simple file
|
||||
let uid_file = "device_uid.txt";
|
||||
if std::path::Path::new(uid_file).exists() {
|
||||
let uid = std::fs::read_to_string(uid_file)?;
|
||||
Ok(uid.trim().to_string())
|
||||
} else {
|
||||
let uid = uuid::Uuid::new_v4().to_string();
|
||||
std::fs::write(uid_file, &uid)?;
|
||||
Ok(uid)
|
||||
}
|
||||
}
|
||||
|
||||
/// Load persisted device_secret from disk (if available)
|
||||
pub fn load_device_secret() -> Option<String> {
|
||||
let secret_file = "device_secret.txt";
|
||||
let secret = std::fs::read_to_string(secret_file).ok()?;
|
||||
let trimmed = secret.trim().to_string();
|
||||
if trimmed.is_empty() { None } else { Some(trimmed) }
|
||||
}
|
||||
|
||||
/// Persist device_secret to disk
|
||||
pub fn save_device_secret(secret: &str) {
|
||||
if let Err(e) = std::fs::write("device_secret.txt", secret) {
|
||||
warn!("Failed to persist device_secret: {}", e);
|
||||
}
|
||||
}
|
||||
86
crates/client/src/monitor/mod.rs
Normal file
86
crates/client/src/monitor/mod.rs
Normal file
@@ -0,0 +1,86 @@
|
||||
use anyhow::Result;
|
||||
use csm_protocol::{Frame, MessageType, DeviceStatus, ProcessInfo};
|
||||
use std::time::Duration;
|
||||
use tokio::sync::mpsc::Sender;
|
||||
use tracing::{info, error, debug};
|
||||
use sysinfo::System;
|
||||
|
||||
pub async fn start_collecting(tx: Sender<Frame>, device_uid: String) {
|
||||
let interval = Duration::from_secs(60);
|
||||
|
||||
loop {
|
||||
// Run blocking sysinfo collection on a dedicated thread
|
||||
let uid_clone = device_uid.clone();
|
||||
let result = tokio::task::spawn_blocking(move || {
|
||||
collect_system_status(&uid_clone)
|
||||
}).await;
|
||||
|
||||
match result {
|
||||
Ok(Ok(status)) => {
|
||||
if let Ok(frame) = Frame::new_json(MessageType::StatusReport, &status) {
|
||||
debug!("Sending status report: cpu={:.1}%, mem={:.1}%", status.cpu_usage, status.memory_usage);
|
||||
if tx.send(frame).await.is_err() {
|
||||
info!("Monitor channel closed, exiting");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(Err(e)) => {
|
||||
error!("Failed to collect system status: {}", e);
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Monitor task join error: {}", e);
|
||||
}
|
||||
}
|
||||
|
||||
tokio::time::sleep(interval).await;
|
||||
}
|
||||
}
|
||||
|
||||
fn collect_system_status(device_uid: &str) -> Result<DeviceStatus> {
|
||||
let mut sys = System::new_all();
|
||||
sys.refresh_all();
|
||||
|
||||
// Brief wait for CPU usage to stabilize
|
||||
std::thread::sleep(Duration::from_millis(200));
|
||||
sys.refresh_all();
|
||||
|
||||
let cpu_usage = sys.global_cpu_info().cpu_usage() as f64;
|
||||
|
||||
let total_memory = sys.total_memory() / 1024 / 1024; // Convert bytes to MB (sysinfo 0.30 returns bytes)
|
||||
let used_memory = sys.used_memory() / 1024 / 1024;
|
||||
let memory_usage = if total_memory > 0 {
|
||||
(used_memory as f64 / total_memory as f64) * 100.0
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
// Top processes by CPU
|
||||
let mut processes: Vec<ProcessInfo> = sys.processes()
|
||||
.iter()
|
||||
.map(|(_, p)| {
|
||||
ProcessInfo {
|
||||
name: p.name().to_string(),
|
||||
pid: p.pid().as_u32(),
|
||||
cpu_usage: p.cpu_usage() as f64,
|
||||
memory_mb: p.memory() / 1024 / 1024, // bytes to MB (sysinfo 0.30)
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
processes.sort_by(|a, b| b.cpu_usage.partial_cmp(&a.cpu_usage).unwrap_or(std::cmp::Ordering::Equal));
|
||||
processes.truncate(10);
|
||||
|
||||
Ok(DeviceStatus {
|
||||
device_uid: device_uid.to_string(),
|
||||
cpu_usage,
|
||||
memory_usage,
|
||||
memory_total_mb: total_memory as u64,
|
||||
disk_usage: 0.0, // TODO: implement disk usage via Windows API
|
||||
disk_total_mb: 0,
|
||||
network_rx_rate: 0,
|
||||
network_tx_rate: 0,
|
||||
running_procs: sys.processes().len() as u32,
|
||||
top_processes: processes,
|
||||
})
|
||||
}
|
||||
380
crates/client/src/network/mod.rs
Normal file
380
crates/client/src/network/mod.rs
Normal file
@@ -0,0 +1,380 @@
|
||||
use anyhow::Result;
|
||||
use std::time::Duration;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tracing::{info, debug, warn};
|
||||
use csm_protocol::{Frame, MessageType, RegisterRequest, RegisterResponse, HeartbeatPayload, WatermarkConfigPayload, UsbPolicyPayload};
|
||||
use hmac::{Hmac, Mac};
|
||||
use sha2::Sha256;
|
||||
|
||||
use crate::ClientState;
|
||||
|
||||
/// Holds senders for all plugin config channels
|
||||
pub struct PluginChannels {
|
||||
pub watermark_tx: tokio::sync::watch::Sender<Option<WatermarkConfigPayload>>,
|
||||
pub web_filter_tx: tokio::sync::watch::Sender<crate::web_filter::WebFilterConfig>,
|
||||
pub software_blocker_tx: tokio::sync::watch::Sender<crate::software_blocker::SoftwareBlockerConfig>,
|
||||
pub popup_blocker_tx: tokio::sync::watch::Sender<crate::popup_blocker::PopupBlockerConfig>,
|
||||
pub usb_audit_tx: tokio::sync::watch::Sender<crate::usb_audit::UsbAuditConfig>,
|
||||
pub usage_timer_tx: tokio::sync::watch::Sender<crate::usage_timer::UsageConfig>,
|
||||
pub usb_policy_tx: tokio::sync::watch::Sender<Option<UsbPolicyPayload>>,
|
||||
}
|
||||
|
||||
/// Connect to server and run the main communication loop
|
||||
pub async fn connect_and_run(
|
||||
state: &ClientState,
|
||||
data_rx: &mut tokio::sync::mpsc::Receiver<Frame>,
|
||||
plugins: &PluginChannels,
|
||||
) -> Result<()> {
|
||||
let tcp_stream = TcpStream::connect(&state.server_addr).await?;
|
||||
info!("TCP connected to {}", state.server_addr);
|
||||
|
||||
if state.use_tls {
|
||||
let tls_stream = wrap_tls(tcp_stream, &state.server_addr).await?;
|
||||
run_comm_loop(tls_stream, state, data_rx, plugins).await
|
||||
} else {
|
||||
run_comm_loop(tcp_stream, state, data_rx, plugins).await
|
||||
}
|
||||
}
|
||||
|
||||
/// Wrap a TCP stream with TLS.
|
||||
/// Supports custom CA certificate via CSM_TLS_CA_CERT env var (path to PEM file).
|
||||
/// Supports skipping verification via CSM_TLS_SKIP_VERIFY=true (development only).
|
||||
async fn wrap_tls(stream: TcpStream, server_addr: &str) -> Result<tokio_rustls::client::TlsStream<TcpStream>> {
|
||||
let mut root_store = rustls::RootCertStore::empty();
|
||||
|
||||
// Load custom CA certificate if specified
|
||||
if let Ok(ca_path) = std::env::var("CSM_TLS_CA_CERT") {
|
||||
let ca_pem = std::fs::read(&ca_path)
|
||||
.map_err(|e| anyhow::anyhow!("Failed to read CA cert {}: {}", ca_path, e))?;
|
||||
let certs = rustls_pemfile::certs(&mut &ca_pem[..])
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
.map_err(|e| anyhow::anyhow!("Failed to parse CA cert: {:?}", e))?;
|
||||
for cert in certs {
|
||||
root_store.add(cert)?;
|
||||
}
|
||||
info!("Loaded custom CA certificates from {}", ca_path);
|
||||
}
|
||||
|
||||
// Always include system roots as fallback
|
||||
root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
|
||||
|
||||
let config = if std::env::var("CSM_TLS_SKIP_VERIFY").as_deref() == Ok("true") {
|
||||
warn!("TLS certificate verification DISABLED — do not use in production!");
|
||||
rustls::ClientConfig::builder()
|
||||
.dangerous()
|
||||
.with_custom_certificate_verifier(std::sync::Arc::new(NoVerifier))
|
||||
.with_no_client_auth()
|
||||
} else {
|
||||
rustls::ClientConfig::builder()
|
||||
.with_root_certificates(root_store)
|
||||
.with_no_client_auth()
|
||||
};
|
||||
|
||||
let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(config));
|
||||
|
||||
// Extract hostname from server_addr (strip port)
|
||||
let domain = server_addr.split(':').next().unwrap_or("localhost").to_string();
|
||||
let server_name = rustls_pki_types::ServerName::try_from(domain.clone())
|
||||
.map_err(|e| anyhow::anyhow!("Invalid TLS server name '{}': {:?}", domain, e))?;
|
||||
|
||||
let tls_stream = connector.connect(server_name, stream).await?;
|
||||
info!("TLS handshake completed with {}", domain);
|
||||
Ok(tls_stream)
|
||||
}
|
||||
|
||||
/// A no-op certificate verifier for development use (CSM_TLS_SKIP_VERIFY=true).
|
||||
#[derive(Debug)]
|
||||
struct NoVerifier;
|
||||
|
||||
impl rustls::client::danger::ServerCertVerifier for NoVerifier {
|
||||
fn verify_server_cert(
|
||||
&self,
|
||||
_end_entity: &rustls_pki_types::CertificateDer,
|
||||
_intermediates: &[rustls_pki_types::CertificateDer],
|
||||
_server_name: &rustls_pki_types::ServerName,
|
||||
_ocsp_response: &[u8],
|
||||
_now: rustls_pki_types::UnixTime,
|
||||
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
|
||||
Ok(rustls::client::danger::ServerCertVerified::assertion())
|
||||
}
|
||||
|
||||
fn verify_tls12_signature(
|
||||
&self,
|
||||
_message: &[u8],
|
||||
_cert: &rustls_pki_types::CertificateDer,
|
||||
_dss: &rustls::DigitallySignedStruct,
|
||||
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
|
||||
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
|
||||
}
|
||||
|
||||
fn verify_tls13_signature(
|
||||
&self,
|
||||
_message: &[u8],
|
||||
_cert: &rustls_pki_types::CertificateDer,
|
||||
_dss: &rustls::DigitallySignedStruct,
|
||||
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
|
||||
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
|
||||
}
|
||||
|
||||
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
|
||||
vec![
|
||||
rustls::SignatureScheme::RSA_PKCS1_SHA256,
|
||||
rustls::SignatureScheme::RSA_PKCS1_SHA384,
|
||||
rustls::SignatureScheme::RSA_PKCS1_SHA512,
|
||||
rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
|
||||
rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
|
||||
rustls::SignatureScheme::RSA_PSS_SHA256,
|
||||
rustls::SignatureScheme::RSA_PSS_SHA384,
|
||||
rustls::SignatureScheme::RSA_PSS_SHA512,
|
||||
rustls::SignatureScheme::ED25519,
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
/// Main communication loop over any read+write stream
|
||||
async fn run_comm_loop<S>(
|
||||
mut stream: S,
|
||||
state: &ClientState,
|
||||
data_rx: &mut tokio::sync::mpsc::Receiver<Frame>,
|
||||
plugins: &PluginChannels,
|
||||
) -> Result<()>
|
||||
where
|
||||
S: AsyncReadExt + AsyncWriteExt + Unpin,
|
||||
{
|
||||
// Send registration
|
||||
let register = RegisterRequest {
|
||||
device_uid: state.device_uid.clone(),
|
||||
hostname: hostname::get()
|
||||
.map(|h| h.to_string_lossy().to_string())
|
||||
.unwrap_or_else(|_| "unknown".to_string()),
|
||||
registration_token: state.registration_token.clone(),
|
||||
os_version: get_os_info(),
|
||||
mac_address: None,
|
||||
};
|
||||
|
||||
let frame = Frame::new_json(MessageType::Register, ®ister)?;
|
||||
stream.write_all(&frame.encode()).await?;
|
||||
info!("Registration request sent");
|
||||
|
||||
let mut buffer = vec![0u8; 65536];
|
||||
let mut read_buf = Vec::with_capacity(65536);
|
||||
// Clamp heartbeat interval to sane range [5, 3600] to prevent CPU spin or effective disable
|
||||
let heartbeat_secs = state.config.heartbeat_interval_secs.clamp(5, 3600);
|
||||
let mut heartbeat_interval = tokio::time::interval(Duration::from_secs(heartbeat_secs));
|
||||
heartbeat_interval.tick().await; // Skip first tick
|
||||
|
||||
// HMAC key — set after receiving RegisterResponse
|
||||
let mut device_secret: Option<String> = state.device_secret.clone();
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
// Read from server
|
||||
result = stream.read(&mut buffer) => {
|
||||
let n = result?;
|
||||
if n == 0 {
|
||||
return Err(anyhow::anyhow!("Server closed connection"));
|
||||
}
|
||||
read_buf.extend_from_slice(&buffer[..n]);
|
||||
|
||||
// Guard against unbounded buffer growth from a malicious server
|
||||
if read_buf.len() > 1_048_576 {
|
||||
return Err(anyhow::anyhow!("Read buffer exceeded 1MB, server may be malicious"));
|
||||
}
|
||||
|
||||
// Process complete frames
|
||||
loop {
|
||||
match Frame::decode(&read_buf)? {
|
||||
Some(frame) => {
|
||||
let consumed = frame.encoded_size();
|
||||
read_buf.drain(..consumed);
|
||||
// Capture device_secret from registration response
|
||||
if frame.msg_type == MessageType::RegisterResponse {
|
||||
if let Ok(resp) = frame.decode_payload::<RegisterResponse>() {
|
||||
device_secret = Some(resp.device_secret.clone());
|
||||
crate::save_device_secret(&resp.device_secret);
|
||||
info!("Device secret received and persisted, HMAC enabled for heartbeats");
|
||||
}
|
||||
}
|
||||
handle_server_message(frame, plugins)?;
|
||||
}
|
||||
None => break, // Incomplete frame, wait for more data
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Send queued data
|
||||
frame = data_rx.recv() => {
|
||||
let frame = frame.ok_or_else(|| anyhow::anyhow!("Channel closed"))?;
|
||||
stream.write_all(&frame.encode()).await?;
|
||||
}
|
||||
|
||||
// Heartbeat
|
||||
_ = heartbeat_interval.tick() => {
|
||||
let timestamp = chrono::Utc::now().to_rfc3339();
|
||||
let hmac_value = compute_hmac(device_secret.as_deref(), &state.device_uid, ×tamp);
|
||||
let heartbeat = HeartbeatPayload {
|
||||
device_uid: state.device_uid.clone(),
|
||||
timestamp,
|
||||
hmac: hmac_value,
|
||||
};
|
||||
let frame = Frame::new_json(MessageType::Heartbeat, &heartbeat)?;
|
||||
stream.write_all(&frame.encode()).await?;
|
||||
debug!("Heartbeat sent (hmac={})", !heartbeat.hmac.is_empty());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_server_message(frame: Frame, plugins: &PluginChannels) -> Result<()> {
|
||||
match frame.msg_type {
|
||||
MessageType::RegisterResponse => {
|
||||
let resp: RegisterResponse = frame.decode_payload()
|
||||
.map_err(|e| anyhow::anyhow!("Invalid registration response: {}", e))?;
|
||||
info!("Registration accepted by server (server version: {})", resp.config.server_version);
|
||||
}
|
||||
MessageType::PolicyUpdate => {
|
||||
let policy: serde_json::Value = frame.decode_payload()
|
||||
.map_err(|e| anyhow::anyhow!("Invalid policy update: {}", e))?;
|
||||
info!("Received policy update: {}", policy);
|
||||
}
|
||||
MessageType::ConfigUpdate => {
|
||||
info!("Received config update");
|
||||
}
|
||||
MessageType::TaskExecute => {
|
||||
warn!("Task execution requested (not yet implemented)");
|
||||
}
|
||||
MessageType::WatermarkConfig => {
|
||||
let config: WatermarkConfigPayload = frame.decode_payload()
|
||||
.map_err(|e| anyhow::anyhow!("Invalid watermark config: {}", e))?;
|
||||
info!("Received watermark config: enabled={}", config.enabled);
|
||||
plugins.watermark_tx.send(Some(config))?;
|
||||
}
|
||||
MessageType::UsbPolicyUpdate => {
|
||||
let policy: UsbPolicyPayload = frame.decode_payload()
|
||||
.map_err(|e| anyhow::anyhow!("Invalid USB policy: {}", e))?;
|
||||
info!("Received USB policy: type={}, enabled={}", policy.policy_type, policy.enabled);
|
||||
plugins.usb_policy_tx.send(Some(policy))?;
|
||||
}
|
||||
MessageType::WebFilterRuleUpdate => {
|
||||
let payload: serde_json::Value = frame.decode_payload()
|
||||
.map_err(|e| anyhow::anyhow!("Invalid web filter update: {}", e))?;
|
||||
info!("Received web filter rules update");
|
||||
let rules: Vec<crate::web_filter::WebFilterRule> = payload.get("rules")
|
||||
.and_then(|r| serde_json::from_value(r.clone()).ok())
|
||||
.unwrap_or_default();
|
||||
let config = crate::web_filter::WebFilterConfig { enabled: true, rules };
|
||||
plugins.web_filter_tx.send(config)?;
|
||||
}
|
||||
MessageType::SoftwareBlacklist => {
|
||||
let payload: serde_json::Value = frame.decode_payload()
|
||||
.map_err(|e| anyhow::anyhow!("Invalid software blacklist: {}", e))?;
|
||||
info!("Received software blacklist update");
|
||||
let blacklist: Vec<crate::software_blocker::BlacklistEntry> = payload.get("blacklist")
|
||||
.and_then(|r| serde_json::from_value(r.clone()).ok())
|
||||
.unwrap_or_default();
|
||||
let config = crate::software_blocker::SoftwareBlockerConfig { enabled: true, blacklist };
|
||||
plugins.software_blocker_tx.send(config)?;
|
||||
}
|
||||
MessageType::PopupRules => {
|
||||
let payload: serde_json::Value = frame.decode_payload()
|
||||
.map_err(|e| anyhow::anyhow!("Invalid popup rules: {}", e))?;
|
||||
info!("Received popup blocker rules update");
|
||||
let rules: Vec<crate::popup_blocker::PopupRule> = payload.get("rules")
|
||||
.and_then(|r| serde_json::from_value(r.clone()).ok())
|
||||
.unwrap_or_default();
|
||||
let config = crate::popup_blocker::PopupBlockerConfig { enabled: true, rules };
|
||||
plugins.popup_blocker_tx.send(config)?;
|
||||
}
|
||||
MessageType::PluginEnable => {
|
||||
let payload: csm_protocol::PluginControlPayload = frame.decode_payload()
|
||||
.map_err(|e| anyhow::anyhow!("Invalid plugin enable: {}", e))?;
|
||||
info!("Plugin enabled: {}", payload.plugin_name);
|
||||
// Route to appropriate plugin channel based on plugin_name
|
||||
handle_plugin_control(&payload, plugins, true)?;
|
||||
}
|
||||
MessageType::PluginDisable => {
|
||||
let payload: csm_protocol::PluginControlPayload = frame.decode_payload()
|
||||
.map_err(|e| anyhow::anyhow!("Invalid plugin disable: {}", e))?;
|
||||
info!("Plugin disabled: {}", payload.plugin_name);
|
||||
handle_plugin_control(&payload, plugins, false)?;
|
||||
}
|
||||
_ => {
|
||||
debug!("Unhandled message type: {:?}", frame.msg_type);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn handle_plugin_control(
|
||||
payload: &csm_protocol::PluginControlPayload,
|
||||
plugins: &PluginChannels,
|
||||
enabled: bool,
|
||||
) -> Result<()> {
|
||||
match payload.plugin_name.as_str() {
|
||||
"watermark" => {
|
||||
if !enabled {
|
||||
// Send disabled config to remove overlay
|
||||
plugins.watermark_tx.send(None)?;
|
||||
}
|
||||
// When enabling, server will push the actual config next
|
||||
}
|
||||
"web_filter" => {
|
||||
if !enabled {
|
||||
// Clear hosts rules on disable
|
||||
plugins.web_filter_tx.send(crate::web_filter::WebFilterConfig { enabled: false, rules: vec![] })?;
|
||||
}
|
||||
// When enabling, server will push rules
|
||||
}
|
||||
"software_blocker" => {
|
||||
if !enabled {
|
||||
plugins.software_blocker_tx.send(crate::software_blocker::SoftwareBlockerConfig { enabled: false, blacklist: vec![] })?;
|
||||
}
|
||||
}
|
||||
"popup_blocker" => {
|
||||
if !enabled {
|
||||
plugins.popup_blocker_tx.send(crate::popup_blocker::PopupBlockerConfig { enabled: false, rules: vec![] })?;
|
||||
}
|
||||
}
|
||||
"usb_audit" => {
|
||||
if !enabled {
|
||||
plugins.usb_audit_tx.send(crate::usb_audit::UsbAuditConfig { enabled: false, monitored_extensions: vec![] })?;
|
||||
}
|
||||
}
|
||||
"usage_timer" => {
|
||||
if !enabled {
|
||||
plugins.usage_timer_tx.send(crate::usage_timer::UsageConfig { enabled: false, ..Default::default() })?;
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
warn!("Unknown plugin: {}", payload.plugin_name);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Compute HMAC-SHA256 for heartbeat verification.
|
||||
/// Format: HMAC-SHA256(device_secret, "{device_uid}\n{timestamp}")
|
||||
fn compute_hmac(secret: Option<&str>, device_uid: &str, timestamp: &str) -> String {
|
||||
let secret = match secret {
|
||||
Some(s) if !s.is_empty() => s,
|
||||
_ => return String::new(),
|
||||
};
|
||||
|
||||
type HmacSha256 = Hmac<Sha256>;
|
||||
|
||||
let message = format!("{}\n{}", device_uid, timestamp);
|
||||
let mut mac = match HmacSha256::new_from_slice(secret.as_bytes()) {
|
||||
Ok(m) => m,
|
||||
Err(_) => return String::new(),
|
||||
};
|
||||
mac.update(message.as_bytes());
|
||||
hex::encode(mac.finalize().into_bytes())
|
||||
}
|
||||
|
||||
fn get_os_info() -> String {
|
||||
use sysinfo::System;
|
||||
let name = System::name().unwrap_or_else(|| "Unknown".to_string());
|
||||
let version = System::os_version().unwrap_or_else(|| "Unknown".to_string());
|
||||
format!("{} {}", name, version)
|
||||
}
|
||||
370
crates/client/src/network/mod.rs.tmp.575580.1775308681874
Normal file
370
crates/client/src/network/mod.rs.tmp.575580.1775308681874
Normal file
@@ -0,0 +1,370 @@
|
||||
use anyhow::Result;
|
||||
use std::time::Duration;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tracing::{info, debug, warn};
|
||||
use csm_protocol::{Frame, MessageType, RegisterRequest, RegisterResponse, HeartbeatPayload, WatermarkConfigPayload};
|
||||
use hmac::{Hmac, Mac};
|
||||
use sha2::Sha256;
|
||||
|
||||
use crate::ClientState;
|
||||
|
||||
/// Maximum accumulated read buffer size per connection (8 MB)
|
||||
const MAX_READ_BUF_SIZE: usize = 8 * 1024 * 1024;
|
||||
|
||||
/// Holds senders for all plugin config channels
|
||||
pub struct PluginChannels {
|
||||
pub watermark_tx: tokio::sync::watch::Sender<Option<WatermarkConfigPayload>>,
|
||||
pub web_filter_tx: tokio::sync::watch::Sender<crate::web_filter::WebFilterConfig>,
|
||||
pub software_blocker_tx: tokio::sync::watch::Sender<crate::software_blocker::SoftwareBlockerConfig>,
|
||||
pub popup_blocker_tx: tokio::sync::watch::Sender<crate::popup_blocker::PopupBlockerConfig>,
|
||||
pub usb_audit_tx: tokio::sync::watch::Sender<crate::usb_audit::UsbAuditConfig>,
|
||||
pub usage_timer_tx: tokio::sync::watch::Sender<crate::usage_timer::UsageConfig>,
|
||||
}
|
||||
|
||||
/// Connect to server and run the main communication loop
|
||||
pub async fn connect_and_run(
|
||||
state: &ClientState,
|
||||
data_rx: &mut tokio::sync::mpsc::Receiver<Frame>,
|
||||
plugins: &PluginChannels,
|
||||
) -> Result<()> {
|
||||
let tcp_stream = TcpStream::connect(&state.server_addr).await?;
|
||||
info!("TCP connected to {}", state.server_addr);
|
||||
|
||||
if state.use_tls {
|
||||
let tls_stream = wrap_tls(tcp_stream, &state.server_addr).await?;
|
||||
run_comm_loop(tls_stream, state, data_rx, plugins).await
|
||||
} else {
|
||||
run_comm_loop(tcp_stream, state, data_rx, plugins).await
|
||||
}
|
||||
}
|
||||
|
||||
/// Wrap a TCP stream with TLS.
|
||||
/// Supports custom CA certificate via CSM_TLS_CA_CERT env var (path to PEM file).
|
||||
/// Supports skipping verification via CSM_TLS_SKIP_VERIFY=true (development only).
|
||||
async fn wrap_tls(stream: TcpStream, server_addr: &str) -> Result<tokio_rustls::client::TlsStream<TcpStream>> {
|
||||
let mut root_store = rustls::RootCertStore::empty();
|
||||
|
||||
// Load custom CA certificate if specified
|
||||
if let Ok(ca_path) = std::env::var("CSM_TLS_CA_CERT") {
|
||||
let ca_pem = std::fs::read(&ca_path)
|
||||
.map_err(|e| anyhow::anyhow!("Failed to read CA cert {}: {}", ca_path, e))?;
|
||||
let certs = rustls_pemfile::certs(&mut &ca_pem[..])
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
.map_err(|e| anyhow::anyhow!("Failed to parse CA cert: {:?}", e))?;
|
||||
for cert in certs {
|
||||
root_store.add(cert)?;
|
||||
}
|
||||
info!("Loaded custom CA certificates from {}", ca_path);
|
||||
}
|
||||
|
||||
// Always include system roots as fallback
|
||||
root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
|
||||
|
||||
let config = if std::env::var("CSM_TLS_SKIP_VERIFY").as_deref() == Ok("true") {
|
||||
warn!("TLS certificate verification DISABLED — do not use in production!");
|
||||
rustls::ClientConfig::builder()
|
||||
.dangerous()
|
||||
.with_custom_certificate_verifier(std::sync::Arc::new(NoVerifier))
|
||||
.with_no_client_auth()
|
||||
} else {
|
||||
rustls::ClientConfig::builder()
|
||||
.with_root_certificates(root_store)
|
||||
.with_no_client_auth()
|
||||
};
|
||||
|
||||
let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(config));
|
||||
|
||||
// Extract hostname from server_addr (strip port)
|
||||
let domain = server_addr.split(':').next().unwrap_or("localhost").to_string();
|
||||
let server_name = rustls_pki_types::ServerName::try_from(domain.clone())
|
||||
.map_err(|e| anyhow::anyhow!("Invalid TLS server name '{}': {:?}", domain, e))?;
|
||||
|
||||
let tls_stream = connector.connect(server_name, stream).await?;
|
||||
info!("TLS handshake completed with {}", domain);
|
||||
Ok(tls_stream)
|
||||
}
|
||||
|
||||
/// A no-op certificate verifier for development use (CSM_TLS_SKIP_VERIFY=true).
|
||||
#[derive(Debug)]
|
||||
struct NoVerifier;
|
||||
|
||||
impl rustls::client::danger::ServerCertVerifier for NoVerifier {
|
||||
fn verify_server_cert(
|
||||
&self,
|
||||
_end_entity: &rustls_pki_types::CertificateDer,
|
||||
_intermediates: &[rustls_pki_types::CertificateDer],
|
||||
_server_name: &rustls_pki_types::ServerName,
|
||||
_ocsp_response: &[u8],
|
||||
_now: rustls_pki_types::UnixTime,
|
||||
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
|
||||
Ok(rustls::client::danger::ServerCertVerified::assertion())
|
||||
}
|
||||
|
||||
fn verify_tls12_signature(
|
||||
&self,
|
||||
_message: &[u8],
|
||||
_cert: &rustls_pki_types::CertificateDer,
|
||||
_dss: &rustls::DigitallySignedStruct,
|
||||
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
|
||||
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
|
||||
}
|
||||
|
||||
fn verify_tls13_signature(
|
||||
&self,
|
||||
_message: &[u8],
|
||||
_cert: &rustls_pki_types::CertificateDer,
|
||||
_dss: &rustls::DigitallySignedStruct,
|
||||
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
|
||||
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
|
||||
}
|
||||
|
||||
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
|
||||
vec![
|
||||
rustls::SignatureScheme::RSA_PKCS1_SHA256,
|
||||
rustls::SignatureScheme::RSA_PKCS1_SHA384,
|
||||
rustls::SignatureScheme::RSA_PKCS1_SHA512,
|
||||
rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
|
||||
rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
|
||||
rustls::SignatureScheme::RSA_PSS_SHA256,
|
||||
rustls::SignatureScheme::RSA_PSS_SHA384,
|
||||
rustls::SignatureScheme::RSA_PSS_SHA512,
|
||||
rustls::SignatureScheme::ED25519,
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
/// Main communication loop over any read+write stream
|
||||
async fn run_comm_loop<S>(
|
||||
mut stream: S,
|
||||
state: &ClientState,
|
||||
data_rx: &mut tokio::sync::mpsc::Receiver<Frame>,
|
||||
plugins: &PluginChannels,
|
||||
) -> Result<()>
|
||||
where
|
||||
S: AsyncReadExt + AsyncWriteExt + Unpin,
|
||||
{
|
||||
// Send registration
|
||||
let register = RegisterRequest {
|
||||
device_uid: state.device_uid.clone(),
|
||||
hostname: hostname::get()
|
||||
.map(|h| h.to_string_lossy().to_string())
|
||||
.unwrap_or_else(|_| "unknown".to_string()),
|
||||
registration_token: state.registration_token.clone(),
|
||||
os_version: get_os_info(),
|
||||
mac_address: None,
|
||||
};
|
||||
|
||||
let frame = Frame::new_json(MessageType::Register, ®ister)?;
|
||||
stream.write_all(&frame.encode()).await?;
|
||||
info!("Registration request sent");
|
||||
|
||||
let mut buffer = vec![0u8; 65536];
|
||||
let mut read_buf = Vec::with_capacity(65536);
|
||||
let heartbeat_secs = state.config.heartbeat_interval_secs;
|
||||
let mut heartbeat_interval = tokio::time::interval(Duration::from_secs(heartbeat_secs));
|
||||
heartbeat_interval.tick().await; // Skip first tick
|
||||
|
||||
// HMAC key — set after receiving RegisterResponse
|
||||
let mut device_secret: Option<String> = state.device_secret.clone();
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
// Read from server
|
||||
result = stream.read(&mut buffer) => {
|
||||
let n = result?;
|
||||
if n == 0 {
|
||||
return Err(anyhow::anyhow!("Server closed connection"));
|
||||
}
|
||||
read_buf.extend_from_slice(&buffer[..n]);
|
||||
|
||||
// Process complete frames
|
||||
loop {
|
||||
match Frame::decode(&read_buf)? {
|
||||
Some(frame) => {
|
||||
let consumed = frame.encoded_size();
|
||||
read_buf.drain(..consumed);
|
||||
// Capture device_secret from registration response
|
||||
if frame.msg_type == MessageType::RegisterResponse {
|
||||
if let Ok(resp) = frame.decode_payload::<RegisterResponse>() {
|
||||
device_secret = Some(resp.device_secret.clone());
|
||||
crate::save_device_secret(&resp.device_secret);
|
||||
info!("Device secret received and persisted, HMAC enabled for heartbeats");
|
||||
}
|
||||
}
|
||||
handle_server_message(frame, plugins)?;
|
||||
}
|
||||
None => break, // Incomplete frame, wait for more data
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Send queued data
|
||||
frame = data_rx.recv() => {
|
||||
let frame = frame.ok_or_else(|| anyhow::anyhow!("Channel closed"))?;
|
||||
stream.write_all(&frame.encode()).await?;
|
||||
}
|
||||
|
||||
// Heartbeat
|
||||
_ = heartbeat_interval.tick() => {
|
||||
let timestamp = chrono::Utc::now().to_rfc3339();
|
||||
let hmac_value = compute_hmac(device_secret.as_deref(), &state.device_uid, ×tamp);
|
||||
let heartbeat = HeartbeatPayload {
|
||||
device_uid: state.device_uid.clone(),
|
||||
timestamp,
|
||||
hmac: hmac_value,
|
||||
};
|
||||
let frame = Frame::new_json(MessageType::Heartbeat, &heartbeat)?;
|
||||
stream.write_all(&frame.encode()).await?;
|
||||
debug!("Heartbeat sent (hmac={})", !heartbeat.hmac.is_empty());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_server_message(frame: Frame, plugins: &PluginChannels) -> Result<()> {
|
||||
match frame.msg_type {
|
||||
MessageType::RegisterResponse => {
|
||||
let resp: RegisterResponse = frame.decode_payload()
|
||||
.map_err(|e| anyhow::anyhow!("Invalid registration response: {}", e))?;
|
||||
info!("Registration accepted by server (server version: {})", resp.config.server_version);
|
||||
}
|
||||
MessageType::PolicyUpdate => {
|
||||
let policy: serde_json::Value = frame.decode_payload()
|
||||
.map_err(|e| anyhow::anyhow!("Invalid policy update: {}", e))?;
|
||||
info!("Received policy update: {}", policy);
|
||||
}
|
||||
MessageType::ConfigUpdate => {
|
||||
info!("Received config update");
|
||||
}
|
||||
MessageType::TaskExecute => {
|
||||
warn!("Task execution requested (not yet implemented)");
|
||||
}
|
||||
MessageType::WatermarkConfig => {
|
||||
let config: WatermarkConfigPayload = frame.decode_payload()
|
||||
.map_err(|e| anyhow::anyhow!("Invalid watermark config: {}", e))?;
|
||||
info!("Received watermark config: enabled={}", config.enabled);
|
||||
plugins.watermark_tx.send(Some(config))?;
|
||||
}
|
||||
MessageType::WebFilterRuleUpdate => {
|
||||
let payload: serde_json::Value = frame.decode_payload()
|
||||
.map_err(|e| anyhow::anyhow!("Invalid web filter update: {}", e))?;
|
||||
info!("Received web filter rules update");
|
||||
let rules: Vec<crate::web_filter::WebFilterRule> = payload.get("rules")
|
||||
.and_then(|r| serde_json::from_value(r.clone()).ok())
|
||||
.unwrap_or_default();
|
||||
let config = crate::web_filter::WebFilterConfig { enabled: true, rules };
|
||||
plugins.web_filter_tx.send(config)?;
|
||||
}
|
||||
MessageType::SoftwareBlacklist => {
|
||||
let payload: serde_json::Value = frame.decode_payload()
|
||||
.map_err(|e| anyhow::anyhow!("Invalid software blacklist: {}", e))?;
|
||||
info!("Received software blacklist update");
|
||||
let blacklist: Vec<crate::software_blocker::BlacklistEntry> = payload.get("blacklist")
|
||||
.and_then(|r| serde_json::from_value(r.clone()).ok())
|
||||
.unwrap_or_default();
|
||||
let config = crate::software_blocker::SoftwareBlockerConfig { enabled: true, blacklist };
|
||||
plugins.software_blocker_tx.send(config)?;
|
||||
}
|
||||
MessageType::PopupRules => {
|
||||
let payload: serde_json::Value = frame.decode_payload()
|
||||
.map_err(|e| anyhow::anyhow!("Invalid popup rules: {}", e))?;
|
||||
info!("Received popup blocker rules update");
|
||||
let rules: Vec<crate::popup_blocker::PopupRule> = payload.get("rules")
|
||||
.and_then(|r| serde_json::from_value(r.clone()).ok())
|
||||
.unwrap_or_default();
|
||||
let config = crate::popup_blocker::PopupBlockerConfig { enabled: true, rules };
|
||||
plugins.popup_blocker_tx.send(config)?;
|
||||
}
|
||||
MessageType::PluginEnable => {
|
||||
let payload: csm_protocol::PluginControlPayload = frame.decode_payload()
|
||||
.map_err(|e| anyhow::anyhow!("Invalid plugin enable: {}", e))?;
|
||||
info!("Plugin enabled: {}", payload.plugin_name);
|
||||
// Route to appropriate plugin channel based on plugin_name
|
||||
handle_plugin_control(&payload, plugins, true)?;
|
||||
}
|
||||
MessageType::PluginDisable => {
|
||||
let payload: csm_protocol::PluginControlPayload = frame.decode_payload()
|
||||
.map_err(|e| anyhow::anyhow!("Invalid plugin disable: {}", e))?;
|
||||
info!("Plugin disabled: {}", payload.plugin_name);
|
||||
handle_plugin_control(&payload, plugins, false)?;
|
||||
}
|
||||
_ => {
|
||||
debug!("Unhandled message type: {:?}", frame.msg_type);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn handle_plugin_control(
|
||||
payload: &csm_protocol::PluginControlPayload,
|
||||
plugins: &PluginChannels,
|
||||
enabled: bool,
|
||||
) -> Result<()> {
|
||||
match payload.plugin_name.as_str() {
|
||||
"watermark" => {
|
||||
if !enabled {
|
||||
// Send disabled config to remove overlay
|
||||
plugins.watermark_tx.send(None)?;
|
||||
}
|
||||
// When enabling, server will push the actual config next
|
||||
}
|
||||
"web_filter" => {
|
||||
if !enabled {
|
||||
// Clear hosts rules on disable
|
||||
plugins.web_filter_tx.send(crate::web_filter::WebFilterConfig { enabled: false, rules: vec![] })?;
|
||||
}
|
||||
// When enabling, server will push rules
|
||||
}
|
||||
"software_blocker" => {
|
||||
if !enabled {
|
||||
plugins.software_blocker_tx.send(crate::software_blocker::SoftwareBlockerConfig { enabled: false, blacklist: vec![] })?;
|
||||
}
|
||||
}
|
||||
"popup_blocker" => {
|
||||
if !enabled {
|
||||
plugins.popup_blocker_tx.send(crate::popup_blocker::PopupBlockerConfig { enabled: false, rules: vec![] })?;
|
||||
}
|
||||
}
|
||||
"usb_audit" => {
|
||||
if !enabled {
|
||||
plugins.usb_audit_tx.send(crate::usb_audit::UsbAuditConfig { enabled: false, monitored_extensions: vec![] })?;
|
||||
}
|
||||
}
|
||||
"usage_timer" => {
|
||||
if !enabled {
|
||||
plugins.usage_timer_tx.send(crate::usage_timer::UsageConfig { enabled: false, ..Default::default() })?;
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
warn!("Unknown plugin: {}", payload.plugin_name);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Compute HMAC-SHA256 for heartbeat verification.
|
||||
/// Format: HMAC-SHA256(device_secret, "{device_uid}\n{timestamp}")
|
||||
fn compute_hmac(secret: Option<&str>, device_uid: &str, timestamp: &str) -> String {
|
||||
let secret = match secret {
|
||||
Some(s) if !s.is_empty() => s,
|
||||
_ => return String::new(),
|
||||
};
|
||||
|
||||
type HmacSha256 = Hmac<Sha256>;
|
||||
|
||||
let message = format!("{}\n{}", device_uid, timestamp);
|
||||
let mut mac = match HmacSha256::new_from_slice(secret.as_bytes()) {
|
||||
Ok(m) => m,
|
||||
Err(_) => return String::new(),
|
||||
};
|
||||
mac.update(message.as_bytes());
|
||||
hex::encode(mac.finalize().into_bytes())
|
||||
}
|
||||
|
||||
fn get_os_info() -> String {
|
||||
use sysinfo::System;
|
||||
let name = System::name().unwrap_or_else(|| "Unknown".to_string());
|
||||
let version = System::os_version().unwrap_or_else(|| "Unknown".to_string());
|
||||
format!("{} {}", name, version)
|
||||
}
|
||||
254
crates/client/src/popup_blocker/mod.rs
Normal file
254
crates/client/src/popup_blocker/mod.rs
Normal file
@@ -0,0 +1,254 @@
|
||||
use tokio::sync::watch;
|
||||
use tracing::{info, debug};
|
||||
use serde::Deserialize;
|
||||
|
||||
/// Popup blocker rule from server
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct PopupRule {
|
||||
pub id: i64,
|
||||
pub rule_type: String,
|
||||
pub window_title: Option<String>,
|
||||
pub window_class: Option<String>,
|
||||
pub process_name: Option<String>,
|
||||
}
|
||||
|
||||
/// Popup blocker configuration
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct PopupBlockerConfig {
|
||||
pub enabled: bool,
|
||||
pub rules: Vec<PopupRule>,
|
||||
}
|
||||
|
||||
/// Context passed to EnumWindows callback via LPARAM
|
||||
struct ScanContext {
|
||||
rules: Vec<PopupRule>,
|
||||
blocked_count: u32,
|
||||
}
|
||||
|
||||
/// Start popup blocker plugin.
|
||||
/// Periodically enumerates windows and closes those matching rules.
|
||||
pub async fn start(mut config_rx: watch::Receiver<PopupBlockerConfig>) {
|
||||
info!("Popup blocker plugin started");
|
||||
let mut config = PopupBlockerConfig::default();
|
||||
let mut scan_interval = tokio::time::interval(std::time::Duration::from_secs(2));
|
||||
scan_interval.tick().await;
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
result = config_rx.changed() => {
|
||||
if result.is_err() {
|
||||
break;
|
||||
}
|
||||
let new_config = config_rx.borrow_and_update().clone();
|
||||
info!("Popup blocker config updated: enabled={}, rules={}", new_config.enabled, new_config.rules.len());
|
||||
config = new_config;
|
||||
}
|
||||
_ = scan_interval.tick() => {
|
||||
if !config.enabled || config.rules.is_empty() {
|
||||
continue;
|
||||
}
|
||||
scan_and_block(&config.rules);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn scan_and_block(rules: &[PopupRule]) {
|
||||
#[cfg(target_os = "windows")]
|
||||
{
|
||||
use windows::Win32::UI::WindowsAndMessaging::EnumWindows;
|
||||
use windows::Win32::Foundation::LPARAM;
|
||||
|
||||
let mut ctx = ScanContext {
|
||||
rules: rules.to_vec(),
|
||||
blocked_count: 0,
|
||||
};
|
||||
|
||||
unsafe {
|
||||
let _ = EnumWindows(
|
||||
Some(enum_windows_callback),
|
||||
LPARAM(&mut ctx as *mut ScanContext as isize),
|
||||
);
|
||||
}
|
||||
if ctx.blocked_count > 0 {
|
||||
debug!("Popup scan blocked {} windows", ctx.blocked_count);
|
||||
}
|
||||
}
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
{
|
||||
let _ = rules;
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(target_os = "windows")]
|
||||
unsafe extern "system" fn enum_windows_callback(
|
||||
hwnd: windows::Win32::Foundation::HWND,
|
||||
lparam: windows::Win32::Foundation::LPARAM,
|
||||
) -> windows::Win32::Foundation::BOOL {
|
||||
use windows::Win32::UI::WindowsAndMessaging::*;
|
||||
use windows::Win32::Foundation::*;
|
||||
|
||||
// Only consider visible, top-level windows without an owner
|
||||
if !IsWindowVisible(hwnd).as_bool() {
|
||||
return BOOL(1);
|
||||
}
|
||||
|
||||
// Skip windows that have an owner (they're child dialogs, not popups)
|
||||
if GetWindow(hwnd, GW_OWNER).0 != 0 {
|
||||
return BOOL(1);
|
||||
}
|
||||
|
||||
// Get window title
|
||||
let mut title_buf = [0u16; 512];
|
||||
let title_len = GetWindowTextW(hwnd, &mut title_buf);
|
||||
let title_len = title_len.max(0) as usize;
|
||||
let title = String::from_utf16_lossy(&title_buf[..title_len]);
|
||||
|
||||
// Skip windows with empty titles
|
||||
if title.is_empty() {
|
||||
return BOOL(1);
|
||||
}
|
||||
|
||||
// Get class name
|
||||
let mut class_buf = [0u16; 256];
|
||||
let class_len = GetClassNameW(hwnd, &mut class_buf);
|
||||
let class_len = class_len.max(0) as usize;
|
||||
let class_name = String::from_utf16_lossy(&class_buf[..class_len]);
|
||||
|
||||
// Get process name from PID
|
||||
let mut pid: u32 = 0;
|
||||
GetWindowThreadProcessId(hwnd, Some(&mut pid));
|
||||
let process_name = get_process_name(pid);
|
||||
|
||||
// Recover the ScanContext from LPARAM
|
||||
let ctx = &mut *(lparam.0 as *mut ScanContext);
|
||||
|
||||
// Check against each rule
|
||||
for rule in &ctx.rules {
|
||||
if rule.rule_type != "block" {
|
||||
continue;
|
||||
}
|
||||
|
||||
let matches = rule_matches(rule, &title, &class_name, &process_name);
|
||||
if matches {
|
||||
let _ = PostMessageW(hwnd, WM_CLOSE, WPARAM(0), LPARAM(0));
|
||||
ctx.blocked_count += 1;
|
||||
info!(
|
||||
"Blocked popup: title='{}' class='{}' process='{}' (rule_id={})",
|
||||
title, class_name, process_name, rule.id
|
||||
);
|
||||
break; // One match is enough per window
|
||||
}
|
||||
}
|
||||
|
||||
BOOL(1) // Continue enumeration
|
||||
}
|
||||
|
||||
fn rule_matches(rule: &PopupRule, title: &str, class_name: &str, process_name: &str) -> bool {
|
||||
let title_match = match &rule.window_title {
|
||||
Some(pattern) => pattern_match(pattern, title),
|
||||
None => true, // No title filter = match all
|
||||
};
|
||||
|
||||
let class_match = match &rule.window_class {
|
||||
Some(pattern) => pattern_match(pattern, class_name),
|
||||
None => true,
|
||||
};
|
||||
|
||||
let process_match = match &rule.process_name {
|
||||
Some(pattern) => pattern_match(pattern, process_name),
|
||||
None => true,
|
||||
};
|
||||
|
||||
title_match && class_match && process_match
|
||||
}
|
||||
|
||||
/// Simple case-insensitive wildcard pattern matching.
|
||||
/// Supports `*` as wildcard (matches any characters).
|
||||
fn pattern_match(pattern: &str, text: &str) -> bool {
|
||||
let p = pattern.to_lowercase();
|
||||
let t = text.to_lowercase();
|
||||
|
||||
if !p.contains('*') {
|
||||
return t.contains(&p);
|
||||
}
|
||||
|
||||
let parts: Vec<&str> = p.split('*').collect();
|
||||
if parts.is_empty() {
|
||||
return true;
|
||||
}
|
||||
|
||||
let mut pos = 0usize;
|
||||
let mut matched_any = false;
|
||||
|
||||
for (i, part) in parts.iter().enumerate() {
|
||||
if part.is_empty() {
|
||||
// Leading empty = pattern starts with * → no start anchor
|
||||
// Trailing empty = pattern ends with * → no end anchor
|
||||
continue;
|
||||
}
|
||||
|
||||
matched_any = true;
|
||||
|
||||
if i == 0 && !parts[0].is_empty() {
|
||||
// Pattern starts with literal → must match at start
|
||||
if !t.starts_with(part) {
|
||||
return false;
|
||||
}
|
||||
pos = part.len();
|
||||
} else {
|
||||
// Find this segment anywhere after current position
|
||||
match t[pos..].find(part) {
|
||||
Some(idx) => pos += idx + part.len(),
|
||||
None => return false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If pattern ends with literal (no trailing *), must match at end
|
||||
if matched_any && !parts.last().map_or(true, |p| p.is_empty()) {
|
||||
return t.ends_with(parts.last().unwrap());
|
||||
}
|
||||
true
|
||||
}
|
||||
|
||||
#[cfg(target_os = "windows")]
|
||||
fn get_process_name(pid: u32) -> String {
|
||||
use windows::Win32::System::Diagnostics::ToolHelp::*;
|
||||
use windows::Win32::Foundation::CloseHandle;
|
||||
|
||||
unsafe {
|
||||
let snapshot = match CreateToolhelp32Snapshot(TH32CS_SNAPPROCESS, 0) {
|
||||
Ok(s) => s,
|
||||
Err(_) => return format!("pid:{}", pid),
|
||||
};
|
||||
|
||||
let mut entry = PROCESSENTRY32W {
|
||||
dwSize: std::mem::size_of::<PROCESSENTRY32W>() as u32,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
if Process32FirstW(snapshot, &mut entry).is_ok() {
|
||||
loop {
|
||||
if entry.th32ProcessID == pid {
|
||||
let name = String::from_utf16_lossy(
|
||||
&entry.szExeFile.iter().take_while(|&&c| c != 0).copied().collect::<Vec<u16>>()
|
||||
);
|
||||
let _ = CloseHandle(snapshot);
|
||||
return name;
|
||||
}
|
||||
entry.dwSize = std::mem::size_of::<PROCESSENTRY32W>() as u32;
|
||||
if Process32NextW(snapshot, &mut entry).is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
let _ = CloseHandle(snapshot);
|
||||
format!("pid:{}", pid)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
fn get_process_name(pid: u32) -> String {
|
||||
format!("pid:{}", pid)
|
||||
}
|
||||
254
crates/client/src/software_blocker/mod.rs
Normal file
254
crates/client/src/software_blocker/mod.rs
Normal file
@@ -0,0 +1,254 @@
|
||||
use tokio::sync::watch;
|
||||
use tracing::{info, warn};
|
||||
use csm_protocol::{Frame, MessageType, SoftwareViolationReport};
|
||||
use serde::Deserialize;
|
||||
|
||||
/// System-critical processes that must never be killed regardless of server rules.
|
||||
/// Killing any of these would cause system instability or a BSOD.
|
||||
const PROTECTED_PROCESSES: &[&str] = &[
|
||||
"system",
|
||||
"system idle process",
|
||||
"svchost.exe",
|
||||
"lsass.exe",
|
||||
"csrss.exe",
|
||||
"wininit.exe",
|
||||
"winlogon.exe",
|
||||
"services.exe",
|
||||
"dwm.exe",
|
||||
"explorer.exe",
|
||||
"taskhostw.exe",
|
||||
"registry",
|
||||
"smss.exe",
|
||||
"conhost.exe",
|
||||
];
|
||||
|
||||
/// Software blacklist entry from server
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct BlacklistEntry {
|
||||
pub id: i64,
|
||||
pub name_pattern: String,
|
||||
pub action: String,
|
||||
}
|
||||
|
||||
/// Software blocker configuration
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct SoftwareBlockerConfig {
|
||||
pub enabled: bool,
|
||||
pub blacklist: Vec<BlacklistEntry>,
|
||||
}
|
||||
|
||||
/// Start software blocker plugin.
|
||||
/// Periodically scans running processes against the blacklist.
|
||||
pub async fn start(
|
||||
mut config_rx: watch::Receiver<SoftwareBlockerConfig>,
|
||||
data_tx: tokio::sync::mpsc::Sender<Frame>,
|
||||
device_uid: String,
|
||||
) {
|
||||
info!("Software blocker plugin started");
|
||||
let mut config = SoftwareBlockerConfig::default();
|
||||
let mut scan_interval = tokio::time::interval(std::time::Duration::from_secs(10));
|
||||
scan_interval.tick().await;
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
result = config_rx.changed() => {
|
||||
if result.is_err() {
|
||||
break;
|
||||
}
|
||||
let new_config = config_rx.borrow_and_update().clone();
|
||||
info!("Software blocker config updated: enabled={}, blacklist={}", new_config.enabled, new_config.blacklist.len());
|
||||
config = new_config;
|
||||
}
|
||||
_ = scan_interval.tick() => {
|
||||
if !config.enabled || config.blacklist.is_empty() {
|
||||
continue;
|
||||
}
|
||||
scan_processes(&config.blacklist, &data_tx, &device_uid).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn scan_processes(
|
||||
blacklist: &[BlacklistEntry],
|
||||
data_tx: &tokio::sync::mpsc::Sender<Frame>,
|
||||
device_uid: &str,
|
||||
) {
|
||||
let running = get_running_processes_with_pids();
|
||||
|
||||
for entry in blacklist {
|
||||
for (process_name, pid) in &running {
|
||||
if pattern_matches(&entry.name_pattern, process_name) {
|
||||
// Never kill system-critical processes
|
||||
if is_protected_process(process_name) {
|
||||
warn!("Blacklisted match '{}' skipped — system-critical process (pid={})", process_name, pid);
|
||||
continue;
|
||||
}
|
||||
|
||||
warn!("Blacklisted software detected: {} (action: {})", process_name, entry.action);
|
||||
|
||||
// Report violation to server
|
||||
// Map action to DB-compatible values: "block" -> "blocked_install", "alert" -> "alerted"
|
||||
let action_taken = match entry.action.as_str() {
|
||||
"block" => "blocked_install",
|
||||
"alert" => "alerted",
|
||||
other => other,
|
||||
};
|
||||
let violation = SoftwareViolationReport {
|
||||
device_uid: device_uid.to_string(),
|
||||
software_name: process_name.clone(),
|
||||
action_taken: action_taken.to_string(),
|
||||
timestamp: chrono::Utc::now().to_rfc3339(),
|
||||
};
|
||||
if let Ok(frame) = Frame::new_json(MessageType::SoftwareViolation, &violation) {
|
||||
let _ = data_tx.send(frame).await;
|
||||
}
|
||||
|
||||
// Kill the process directly by captured PID (avoids TOCTOU race)
|
||||
if entry.action == "block" {
|
||||
kill_process_by_pid(*pid, process_name);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn pattern_matches(pattern: &str, name: &str) -> bool {
|
||||
let pattern_lower = pattern.to_lowercase();
|
||||
let name_lower = name.to_lowercase();
|
||||
// Support wildcard patterns
|
||||
if pattern_lower.contains('*') {
|
||||
let parts: Vec<&str> = pattern_lower.split('*').collect();
|
||||
let mut pos = 0;
|
||||
|
||||
for (i, part) in parts.iter().enumerate() {
|
||||
if part.is_empty() {
|
||||
continue;
|
||||
}
|
||||
if i == 0 && !parts[0].is_empty() {
|
||||
// Pattern starts with literal → must match at start
|
||||
if !name_lower.starts_with(part) {
|
||||
return false;
|
||||
}
|
||||
pos = part.len();
|
||||
} else {
|
||||
match name_lower[pos..].find(part) {
|
||||
Some(idx) => pos += idx + part.len(),
|
||||
None => return false,
|
||||
}
|
||||
}
|
||||
}
|
||||
// If pattern ends with literal (no trailing *), must match at end
|
||||
if !parts.last().map_or(true, |p| p.is_empty()) {
|
||||
return name_lower.ends_with(parts.last().unwrap());
|
||||
}
|
||||
true
|
||||
} else {
|
||||
name_lower.contains(&pattern_lower)
|
||||
}
|
||||
}
|
||||
|
||||
/// Get all running processes with their PIDs (single snapshot, no TOCTOU)
|
||||
fn get_running_processes_with_pids() -> Vec<(String, u32)> {
|
||||
#[cfg(target_os = "windows")]
|
||||
{
|
||||
use windows::Win32::System::Diagnostics::ToolHelp::*;
|
||||
use windows::Win32::Foundation::CloseHandle;
|
||||
|
||||
let mut procs = Vec::new();
|
||||
unsafe {
|
||||
let snapshot = match CreateToolhelp32Snapshot(TH32CS_SNAPPROCESS, 0) {
|
||||
Ok(s) => s,
|
||||
Err(_) => return procs,
|
||||
};
|
||||
|
||||
let mut entry = PROCESSENTRY32W {
|
||||
dwSize: std::mem::size_of::<PROCESSENTRY32W>() as u32,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
if Process32FirstW(snapshot, &mut entry).is_ok() {
|
||||
loop {
|
||||
let name = String::from_utf16_lossy(
|
||||
&entry.szExeFile.iter().take_while(|&&c| c != 0).copied().collect::<Vec<u16>>()
|
||||
);
|
||||
procs.push((name, entry.th32ProcessID));
|
||||
entry.dwSize = std::mem::size_of::<PROCESSENTRY32W>() as u32;
|
||||
if !Process32NextW(snapshot, &mut entry).is_ok() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
let _ = CloseHandle(snapshot);
|
||||
}
|
||||
procs
|
||||
}
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
{
|
||||
Vec::new()
|
||||
}
|
||||
}
|
||||
|
||||
fn kill_process_by_pid(pid: u32, expected_name: &str) {
|
||||
#[cfg(target_os = "windows")]
|
||||
{
|
||||
use windows::Win32::System::Threading::{OpenProcess, TerminateProcess, PROCESS_TERMINATE};
|
||||
use windows::Win32::Foundation::CloseHandle;
|
||||
use windows::Win32::System::Diagnostics::ToolHelp::*;
|
||||
|
||||
unsafe {
|
||||
// Verify the PID still belongs to the expected process (prevent PID reuse kills)
|
||||
let snapshot = match CreateToolhelp32Snapshot(TH32CS_SNAPPROCESS, 0) {
|
||||
Ok(s) => s,
|
||||
Err(_) => return,
|
||||
};
|
||||
let mut entry = PROCESSENTRY32W {
|
||||
dwSize: std::mem::size_of::<PROCESSENTRY32W>() as u32,
|
||||
..Default::default()
|
||||
};
|
||||
let mut name_matches = false;
|
||||
if Process32FirstW(snapshot, &mut entry).is_ok() {
|
||||
loop {
|
||||
if entry.th32ProcessID == pid {
|
||||
let current_name = String::from_utf16_lossy(
|
||||
&entry.szExeFile.iter().take_while(|&&c| c != 0).copied().collect::<Vec<u16>>()
|
||||
);
|
||||
name_matches = current_name.to_lowercase() == expected_name.to_lowercase();
|
||||
break;
|
||||
}
|
||||
entry.dwSize = std::mem::size_of::<PROCESSENTRY32W>() as u32;
|
||||
if !Process32NextW(snapshot, &mut entry).is_ok() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
let _ = CloseHandle(snapshot);
|
||||
|
||||
if !name_matches {
|
||||
warn!("Skipping kill: PID {} no longer matches expected process '{}'", pid, expected_name);
|
||||
return;
|
||||
}
|
||||
|
||||
if let Ok(handle) = OpenProcess(PROCESS_TERMINATE, false, pid) {
|
||||
let terminated = TerminateProcess(handle, 1).is_ok();
|
||||
let _ = CloseHandle(handle);
|
||||
if terminated {
|
||||
warn!("Killed process: {} (pid={})", expected_name, pid);
|
||||
} else {
|
||||
warn!("TerminateProcess failed for: {} (pid={})", expected_name, pid);
|
||||
}
|
||||
} else {
|
||||
warn!("OpenProcess failed for: {} (pid={}) — insufficient privileges?", expected_name, pid);
|
||||
}
|
||||
}
|
||||
}
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
{
|
||||
let _ = (pid, expected_name);
|
||||
}
|
||||
}
|
||||
|
||||
fn is_protected_process(name: &str) -> bool {
|
||||
let lower = name.to_lowercase();
|
||||
PROTECTED_PROCESSES.iter().any(|p| lower == **p || lower.ends_with(&format!("/{}", p).replace('/', "\\")))
|
||||
}
|
||||
197
crates/client/src/usage_timer/mod.rs
Normal file
197
crates/client/src/usage_timer/mod.rs
Normal file
@@ -0,0 +1,197 @@
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio::sync::watch;
|
||||
use tracing::{info, debug};
|
||||
use csm_protocol::{Frame, MessageType, UsageDailyReport, AppUsageEntry};
|
||||
|
||||
/// Usage tracking configuration pushed from server
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct UsageConfig {
|
||||
pub enabled: bool,
|
||||
pub idle_threshold_secs: u64,
|
||||
pub report_interval_secs: u64,
|
||||
}
|
||||
|
||||
/// Start the usage timer plugin.
|
||||
/// Tracks active/idle time and foreground application usage.
|
||||
pub async fn start(
|
||||
mut config_rx: watch::Receiver<UsageConfig>,
|
||||
data_tx: tokio::sync::mpsc::Sender<Frame>,
|
||||
device_uid: String,
|
||||
) {
|
||||
info!("Usage timer plugin started");
|
||||
|
||||
let mut config = UsageConfig {
|
||||
enabled: false,
|
||||
idle_threshold_secs: 300, // 5 minutes default
|
||||
report_interval_secs: 300,
|
||||
};
|
||||
|
||||
let mut report_interval = tokio::time::interval(Duration::from_secs(60));
|
||||
report_interval.tick().await;
|
||||
|
||||
let mut active_secs: u64 = 0;
|
||||
let mut idle_secs: u64 = 0;
|
||||
let mut app_usage: std::collections::HashMap<String, u64> = std::collections::HashMap::new();
|
||||
let mut last_tick = Instant::now();
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
result = config_rx.changed() => {
|
||||
if result.is_err() {
|
||||
break;
|
||||
}
|
||||
let new_config = config_rx.borrow_and_update().clone();
|
||||
if new_config.enabled != config.enabled {
|
||||
info!("Usage timer enabled: {}", new_config.enabled);
|
||||
}
|
||||
config = new_config;
|
||||
if config.enabled {
|
||||
report_interval = tokio::time::interval(Duration::from_secs(config.report_interval_secs));
|
||||
report_interval.tick().await;
|
||||
last_tick = Instant::now();
|
||||
}
|
||||
}
|
||||
_ = report_interval.tick() => {
|
||||
if !config.enabled {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Measure actual elapsed time since last tick
|
||||
let now = Instant::now();
|
||||
let elapsed_secs = now.duration_since(last_tick).as_secs();
|
||||
last_tick = now;
|
||||
|
||||
let idle_ms = get_idle_millis();
|
||||
let is_idle = idle_ms as u64 > config.idle_threshold_secs * 1000;
|
||||
|
||||
if is_idle {
|
||||
idle_secs += elapsed_secs;
|
||||
} else {
|
||||
active_secs += elapsed_secs;
|
||||
}
|
||||
|
||||
// Track foreground app
|
||||
if let Some(app) = get_foreground_app_name() {
|
||||
*app_usage.entry(app).or_insert(0) += elapsed_secs;
|
||||
}
|
||||
|
||||
// Report usage to server
|
||||
if active_secs > 0 || idle_secs > 0 {
|
||||
let report = UsageDailyReport {
|
||||
device_uid: device_uid.clone(),
|
||||
date: chrono::Local::now().format("%Y-%m-%d").to_string(),
|
||||
total_active_minutes: (active_secs / 60) as u32,
|
||||
total_idle_minutes: (idle_secs / 60) as u32,
|
||||
first_active_at: None,
|
||||
last_active_at: Some(chrono::Utc::now().to_rfc3339()),
|
||||
};
|
||||
if let Ok(frame) = Frame::new_json(MessageType::UsageReport, &report) {
|
||||
if data_tx.send(frame).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Report per-app usage
|
||||
for (app, secs) in &app_usage {
|
||||
let entry = AppUsageEntry {
|
||||
device_uid: device_uid.clone(),
|
||||
date: chrono::Local::now().format("%Y-%m-%d").to_string(),
|
||||
app_name: app.clone(),
|
||||
usage_minutes: (secs / 60) as u32,
|
||||
};
|
||||
if let Ok(frame) = Frame::new_json(MessageType::AppUsageReport, &entry) {
|
||||
let _ = data_tx.send(frame).await;
|
||||
}
|
||||
}
|
||||
|
||||
// Reset counters after reporting
|
||||
active_secs = 0;
|
||||
idle_secs = 0;
|
||||
app_usage.clear();
|
||||
}
|
||||
|
||||
debug!("Usage report sent (idle_ms={}, elapsed={}s)", idle_ms, elapsed_secs);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Get system idle time in milliseconds using GetLastInputInfo + GetTickCount64.
|
||||
/// Both use the same time base. LASTINPUTINFO.dwTime is u32 (GetTickCount legacy),
|
||||
/// so we take the low 32 bits of GetTickCount64 for correct wrapping comparison.
|
||||
fn get_idle_millis() -> u32 {
|
||||
#[cfg(target_os = "windows")]
|
||||
{
|
||||
use windows::Win32::UI::Input::KeyboardAndMouse::{GetLastInputInfo, LASTINPUTINFO};
|
||||
|
||||
unsafe {
|
||||
let mut lii = LASTINPUTINFO {
|
||||
cbSize: std::mem::size_of::<LASTINPUTINFO>() as u32,
|
||||
dwTime: 0,
|
||||
};
|
||||
if GetLastInputInfo(&mut lii).as_bool() {
|
||||
// GetTickCount64 returns u64, but dwTime is u32.
|
||||
// Take low 32 bits so wrapping_sub produces correct idle delta.
|
||||
let tick_low32 = (windows::Win32::System::SystemInformation::GetTickCount64() & 0xFFFFFFFF) as u32;
|
||||
return tick_low32.wrapping_sub(lii.dwTime);
|
||||
}
|
||||
}
|
||||
0
|
||||
}
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
{
|
||||
0
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the foreground window's process name using CreateToolhelp32Snapshot
|
||||
fn get_foreground_app_name() -> Option<String> {
|
||||
#[cfg(target_os = "windows")]
|
||||
{
|
||||
use windows::Win32::UI::WindowsAndMessaging::{GetForegroundWindow, GetWindowThreadProcessId};
|
||||
use windows::Win32::System::Diagnostics::ToolHelp::*;
|
||||
use windows::Win32::Foundation::CloseHandle;
|
||||
|
||||
unsafe {
|
||||
let hwnd = GetForegroundWindow();
|
||||
if hwnd.0 == 0 {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut pid: u32 = 0;
|
||||
GetWindowThreadProcessId(hwnd, Some(&mut pid));
|
||||
if pid == 0 {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Get process name via CreateToolhelp32Snapshot
|
||||
let snapshot = CreateToolhelp32Snapshot(TH32CS_SNAPPROCESS, 0).ok()?;
|
||||
let mut entry = PROCESSENTRY32W {
|
||||
dwSize: std::mem::size_of::<PROCESSENTRY32W>() as u32,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
if Process32FirstW(snapshot, &mut entry).is_ok() {
|
||||
loop {
|
||||
if entry.th32ProcessID == pid {
|
||||
let name = String::from_utf16_lossy(
|
||||
&entry.szExeFile.iter().take_while(|&&c| c != 0).copied().collect::<Vec<u16>>()
|
||||
);
|
||||
let _ = CloseHandle(snapshot);
|
||||
return Some(name);
|
||||
}
|
||||
entry.dwSize = std::mem::size_of::<PROCESSENTRY32W>() as u32;
|
||||
if !Process32NextW(snapshot, &mut entry).is_ok() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
let _ = CloseHandle(snapshot);
|
||||
None
|
||||
}
|
||||
}
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
{
|
||||
None
|
||||
}
|
||||
}
|
||||
249
crates/client/src/usb/mod.rs
Normal file
249
crates/client/src/usb/mod.rs
Normal file
@@ -0,0 +1,249 @@
|
||||
use csm_protocol::{Frame, MessageType, UsbEvent, UsbEventType, UsbPolicyPayload, UsbDeviceRule};
|
||||
use std::time::Duration;
|
||||
use tokio::sync::{mpsc::Sender, watch};
|
||||
use tracing::{info, warn};
|
||||
|
||||
/// Start USB monitoring with policy enforcement.
|
||||
/// Monitors for removable drive insertions/removals and enforces USB policies.
|
||||
pub async fn start_monitoring(
|
||||
tx: Sender<Frame>,
|
||||
device_uid: String,
|
||||
mut policy_rx: watch::Receiver<Option<UsbPolicyPayload>>,
|
||||
) {
|
||||
info!("USB monitoring started for {}", device_uid);
|
||||
|
||||
let mut current_policy: Option<UsbPolicyPayload> = None;
|
||||
let mut known_drives: Vec<String> = Vec::new();
|
||||
let interval = Duration::from_secs(10);
|
||||
|
||||
loop {
|
||||
// Check for policy updates (non-blocking)
|
||||
if policy_rx.has_changed().unwrap_or(false) {
|
||||
let new_policy = policy_rx.borrow_and_update().clone();
|
||||
let policy_desc = new_policy.as_ref()
|
||||
.map(|p| format!("type={}, enabled={}", p.policy_type, p.enabled))
|
||||
.unwrap_or_else(|| "None".to_string());
|
||||
info!("USB policy updated: {}", policy_desc);
|
||||
current_policy = new_policy;
|
||||
|
||||
// Re-check all currently mounted drives against new policy
|
||||
let drives = detect_removable_drives();
|
||||
for drive in &drives {
|
||||
if should_block_drive(drive, ¤t_policy) {
|
||||
warn!("USB device {} blocked by policy, ejecting...", drive);
|
||||
if let Err(e) = eject_drive(drive) {
|
||||
warn!("Failed to eject {}: {}", drive, e);
|
||||
} else {
|
||||
info!("Successfully ejected {}", drive);
|
||||
}
|
||||
// Report blocked event
|
||||
send_usb_event(&tx, &device_uid, UsbEventType::Blocked, drive).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let current_drives = detect_removable_drives();
|
||||
|
||||
// Detect new drives (inserted)
|
||||
for drive in ¤t_drives {
|
||||
if !known_drives.iter().any(|d: &String| d == drive) {
|
||||
info!("USB device inserted: {}", drive);
|
||||
|
||||
// Check against policy before reporting
|
||||
if should_block_drive(drive, ¤t_policy) {
|
||||
warn!("USB device {} blocked by policy, ejecting...", drive);
|
||||
if let Err(e) = eject_drive(drive) {
|
||||
warn!("Failed to eject {}: {}", drive, e);
|
||||
} else {
|
||||
info!("Successfully ejected {}", drive);
|
||||
}
|
||||
send_usb_event(&tx, &device_uid, UsbEventType::Blocked, drive).await;
|
||||
continue;
|
||||
}
|
||||
|
||||
send_usb_event(&tx, &device_uid, UsbEventType::Inserted, drive).await;
|
||||
}
|
||||
}
|
||||
|
||||
// Detect removed drives
|
||||
for drive in &known_drives {
|
||||
if !current_drives.iter().any(|d| d == drive) {
|
||||
info!("USB device removed: {}", drive);
|
||||
send_usb_event(&tx, &device_uid, UsbEventType::Removed, drive).await;
|
||||
}
|
||||
}
|
||||
|
||||
known_drives = current_drives;
|
||||
tokio::time::sleep(interval).await;
|
||||
}
|
||||
}
|
||||
|
||||
async fn send_usb_event(
|
||||
tx: &Sender<Frame>,
|
||||
device_uid: &str,
|
||||
event_type: UsbEventType,
|
||||
drive: &str,
|
||||
) {
|
||||
let event = UsbEvent {
|
||||
device_uid: device_uid.to_string(),
|
||||
event_type,
|
||||
vendor_id: None,
|
||||
product_id: None,
|
||||
serial: None,
|
||||
device_name: Some(drive.to_string()),
|
||||
};
|
||||
|
||||
if let Ok(frame) = Frame::new_json(MessageType::UsbEvent, &event) {
|
||||
tx.send(frame).await.ok();
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if a drive should be blocked based on the current policy.
|
||||
fn should_block_drive(drive: &str, policy: &Option<UsbPolicyPayload>) -> bool {
|
||||
let policy = match policy {
|
||||
Some(p) if p.enabled => p,
|
||||
_ => return false,
|
||||
};
|
||||
|
||||
match policy.policy_type.as_str() {
|
||||
"all_block" => true,
|
||||
"blacklist" => {
|
||||
// Block if any rule matches
|
||||
policy.rules.iter().any(|rule| device_matches_rule(drive, rule))
|
||||
}
|
||||
"whitelist" => {
|
||||
// Block if NOT in whitelist (empty whitelist = block all)
|
||||
if policy.rules.is_empty() {
|
||||
return true;
|
||||
}
|
||||
!policy.rules.iter().any(|rule| device_matches_rule(drive, rule))
|
||||
}
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if a device matches a rule pattern.
|
||||
/// Currently matches by device_name (drive letter path) since that's what we detect.
|
||||
fn device_matches_rule(drive: &str, rule: &UsbDeviceRule) -> bool {
|
||||
// Match by device name (drive root path like "E:\")
|
||||
if let Some(ref name) = rule.device_name {
|
||||
if drive.eq_ignore_ascii_case(name) || drive.eq_ignore_ascii_case(&name.replace('\\', "")) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
// vendor_id, product_id, serial matching would require WMI or SetupDi queries
|
||||
// For now, these are placeholder checks
|
||||
let _ = (drive, rule);
|
||||
false
|
||||
}
|
||||
|
||||
/// DRIVE_REMOVABLE = 2 in Windows API
|
||||
const DRIVE_REMOVABLE: u32 = 2;
|
||||
|
||||
fn detect_removable_drives() -> Vec<String> {
|
||||
let mut drives = Vec::new();
|
||||
|
||||
#[cfg(target_os = "windows")]
|
||||
{
|
||||
for letter in b'A'..=b'Z' {
|
||||
let root = format!("{}:\\", letter as char);
|
||||
let root_wide: Vec<u16> = root.encode_utf16().chain(std::iter::once(0)).collect();
|
||||
unsafe {
|
||||
let drive_type = windows::Win32::Storage::FileSystem::GetDriveTypeW(
|
||||
windows::core::PCWSTR(root_wide.as_ptr()),
|
||||
);
|
||||
if drive_type == DRIVE_REMOVABLE {
|
||||
drives.push(root);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
{
|
||||
let _ = &drives; // Suppress unused warning on non-Windows
|
||||
}
|
||||
|
||||
drives
|
||||
}
|
||||
|
||||
/// Eject a removable drive using Windows API.
|
||||
/// Opens the volume handle, dismounts the filesystem, and ejects the media.
|
||||
#[cfg(target_os = "windows")]
|
||||
fn eject_drive(drive: &str) -> Result<(), anyhow::Error> {
|
||||
use windows::Win32::Storage::FileSystem::*;
|
||||
use windows::Win32::System::IO::DeviceIoControl;
|
||||
use windows::Win32::Foundation::*;
|
||||
use windows::core::PCWSTR;
|
||||
|
||||
// IOCTL control codes (not exposed directly in windows crate)
|
||||
const FSCTL_DISMOUNT_VOLUME: u32 = 0x00090020;
|
||||
const IOCTL_STORAGE_EJECT_MEDIA: u32 = 0x002D4808;
|
||||
|
||||
// Extract drive letter from path like "E:\"
|
||||
let letter = drive.chars().next().unwrap_or('A');
|
||||
let path = format!("\\\\.\\{}:\0", letter);
|
||||
let path_wide: Vec<u16> = path.encode_utf16().collect();
|
||||
|
||||
unsafe {
|
||||
let handle = CreateFileW(
|
||||
PCWSTR(path_wide.as_ptr()),
|
||||
GENERIC_READ.0,
|
||||
FILE_SHARE_READ | FILE_SHARE_WRITE,
|
||||
None,
|
||||
OPEN_EXISTING,
|
||||
FILE_ATTRIBUTE_NORMAL,
|
||||
None,
|
||||
)?;
|
||||
|
||||
if handle.is_invalid() {
|
||||
return Err(anyhow::anyhow!("Failed to open volume handle for {}", drive));
|
||||
}
|
||||
|
||||
// Dismount the filesystem
|
||||
let mut bytes_returned = 0u32;
|
||||
let dismount_ok = DeviceIoControl(
|
||||
handle,
|
||||
FSCTL_DISMOUNT_VOLUME,
|
||||
None,
|
||||
0,
|
||||
None,
|
||||
0,
|
||||
Some(&mut bytes_returned),
|
||||
None,
|
||||
).is_ok();
|
||||
|
||||
if !dismount_ok {
|
||||
warn!("FSCTL_DISMOUNT_VOLUME failed for {}", drive);
|
||||
}
|
||||
|
||||
// Eject the media
|
||||
let eject_ok = DeviceIoControl(
|
||||
handle,
|
||||
IOCTL_STORAGE_EJECT_MEDIA,
|
||||
None,
|
||||
0,
|
||||
None,
|
||||
0,
|
||||
Some(&mut bytes_returned),
|
||||
None,
|
||||
).is_ok();
|
||||
|
||||
if !eject_ok {
|
||||
warn!("IOCTL_STORAGE_EJECT_MEDIA failed for {}", drive);
|
||||
}
|
||||
|
||||
let _ = CloseHandle(handle);
|
||||
|
||||
if !dismount_ok && !eject_ok {
|
||||
Err(anyhow::anyhow!("Failed to eject {}", drive))
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
fn eject_drive(_drive: &str) -> Result<(), anyhow::Error> {
|
||||
Ok(())
|
||||
}
|
||||
190
crates/client/src/usb_audit/mod.rs
Normal file
190
crates/client/src/usb_audit/mod.rs
Normal file
@@ -0,0 +1,190 @@
|
||||
use tokio::sync::watch;
|
||||
use tracing::{info, debug, warn};
|
||||
use csm_protocol::{Frame, MessageType, UsbFileOpEntry};
|
||||
use std::collections::{HashMap, HashSet};
|
||||
|
||||
/// USB file audit configuration
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct UsbAuditConfig {
|
||||
pub enabled: bool,
|
||||
pub monitored_extensions: Vec<String>,
|
||||
}
|
||||
|
||||
/// Start USB file audit plugin.
|
||||
/// Detects removable drives and monitors file changes via periodic scanning.
|
||||
pub async fn start(
|
||||
mut config_rx: watch::Receiver<UsbAuditConfig>,
|
||||
data_tx: tokio::sync::mpsc::Sender<Frame>,
|
||||
device_uid: String,
|
||||
) {
|
||||
info!("USB file audit plugin started");
|
||||
let mut config = UsbAuditConfig::default();
|
||||
let mut active_drives: HashSet<String> = HashSet::new();
|
||||
// Track file listings per drive to detect changes
|
||||
let mut drive_snapshots: HashMap<String, HashSet<String>> = HashMap::new();
|
||||
|
||||
let mut interval = tokio::time::interval(std::time::Duration::from_secs(5));
|
||||
interval.tick().await;
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
result = config_rx.changed() => {
|
||||
if result.is_err() {
|
||||
break;
|
||||
}
|
||||
let new_config = config_rx.borrow_and_update().clone();
|
||||
info!("USB audit config updated: enabled={}", new_config.enabled);
|
||||
config = new_config;
|
||||
}
|
||||
_ = interval.tick() => {
|
||||
if !config.enabled {
|
||||
continue;
|
||||
}
|
||||
|
||||
let drives = detect_removable_drives();
|
||||
|
||||
// Detect new drives
|
||||
for drive in &drives {
|
||||
if !active_drives.contains(drive) {
|
||||
info!("New removable drive detected: {}", drive);
|
||||
// Take initial snapshot in blocking thread to avoid freezing
|
||||
let drive_clone = drive.clone();
|
||||
let files = tokio::task::spawn_blocking(move || scan_drive_files(&drive_clone)).await;
|
||||
match files {
|
||||
Ok(files) => { drive_snapshots.insert(drive.clone(), files); }
|
||||
Err(e) => { warn!("Failed to scan drive {}: {}", drive, e); }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Scan existing drives for changes (each in a blocking thread)
|
||||
for drive in &drives {
|
||||
let drive_clone = drive.clone();
|
||||
let current_files = match tokio::task::spawn_blocking(move || scan_drive_files(&drive_clone)).await {
|
||||
Ok(files) => files,
|
||||
Err(e) => { warn!("Drive scan failed for {}: {}", drive, e); continue; }
|
||||
};
|
||||
if let Some(prev_files) = drive_snapshots.get_mut(drive) {
|
||||
// Find new files (created)
|
||||
for file in ¤t_files {
|
||||
if !prev_files.contains(file) {
|
||||
report_file_op(&data_tx, &device_uid, drive, file, "create", &config.monitored_extensions).await;
|
||||
}
|
||||
}
|
||||
// Find deleted files
|
||||
for file in prev_files.iter() {
|
||||
if !current_files.contains(file) {
|
||||
report_file_op(&data_tx, &device_uid, drive, file, "delete", &config.monitored_extensions).await;
|
||||
}
|
||||
}
|
||||
*prev_files = current_files;
|
||||
} else {
|
||||
drive_snapshots.insert(drive.clone(), current_files);
|
||||
}
|
||||
}
|
||||
|
||||
// Clean up removed drives
|
||||
active_drives.retain(|d| drives.contains(d));
|
||||
drive_snapshots.retain(|d, _| drives.contains(d));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn report_file_op(
|
||||
data_tx: &tokio::sync::mpsc::Sender<Frame>,
|
||||
device_uid: &str,
|
||||
drive: &str,
|
||||
file_path: &str,
|
||||
operation: &str,
|
||||
ext_filter: &[String],
|
||||
) {
|
||||
// Check extension filter
|
||||
let should_report = if ext_filter.is_empty() {
|
||||
true
|
||||
} else {
|
||||
let file_ext = std::path::Path::new(file_path)
|
||||
.extension()
|
||||
.and_then(|e| e.to_str())
|
||||
.map(|e| e.to_lowercase());
|
||||
file_ext.map_or(true, |ext| {
|
||||
ext_filter.iter().any(|f| f.to_lowercase() == ext)
|
||||
})
|
||||
};
|
||||
|
||||
if should_report {
|
||||
let entry = UsbFileOpEntry {
|
||||
device_uid: device_uid.to_string(),
|
||||
usb_serial: None,
|
||||
drive_letter: Some(drive.to_string()),
|
||||
operation: operation.to_string(),
|
||||
file_path: file_path.to_string(),
|
||||
file_size: None,
|
||||
timestamp: chrono::Utc::now().to_rfc3339(),
|
||||
};
|
||||
|
||||
if let Ok(frame) = Frame::new_json(MessageType::UsbFileOp, &entry) {
|
||||
let _ = data_tx.send(frame).await;
|
||||
}
|
||||
debug!("USB file op: {} {}", operation, file_path);
|
||||
}
|
||||
}
|
||||
|
||||
/// Recursively scan a drive and return all file paths
|
||||
fn scan_drive_files(drive: &str) -> HashSet<String> {
|
||||
let mut files = HashSet::new();
|
||||
let max_depth = 3; // Limit recursion depth for performance
|
||||
scan_dir_recursive(drive, &mut files, 0, max_depth);
|
||||
files
|
||||
}
|
||||
|
||||
fn scan_dir_recursive(dir: &str, files: &mut HashSet<String>, depth: usize, max_depth: usize) {
|
||||
if depth >= max_depth {
|
||||
return;
|
||||
}
|
||||
if let Ok(entries) = std::fs::read_dir(dir) {
|
||||
for entry in entries.flatten() {
|
||||
let path = entry.path();
|
||||
let path_str = path.to_string_lossy().to_string();
|
||||
if path.is_dir() {
|
||||
scan_dir_recursive(&path_str, files, depth + 1, max_depth);
|
||||
} else {
|
||||
files.insert(path_str);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn detect_removable_drives() -> Vec<String> {
|
||||
#[cfg(target_os = "windows")]
|
||||
{
|
||||
use windows::Win32::Storage::FileSystem::GetDriveTypeW;
|
||||
use windows::core::PCWSTR;
|
||||
|
||||
let mut drives = Vec::new();
|
||||
let mask = unsafe { windows::Win32::Storage::FileSystem::GetLogicalDrives() };
|
||||
let mut mask = mask as u32;
|
||||
let mut letter = b'A';
|
||||
|
||||
while mask != 0 {
|
||||
if mask & 1 != 0 {
|
||||
let drive_letter = format!("{}:\\", letter as char);
|
||||
let drive_wide: Vec<u16> = drive_letter.encode_utf16().chain(std::iter::once(0)).collect();
|
||||
let drive_type = unsafe {
|
||||
GetDriveTypeW(PCWSTR(drive_wide.as_ptr()))
|
||||
};
|
||||
// DRIVE_REMOVABLE = 2
|
||||
if drive_type == 2 {
|
||||
drives.push(drive_letter);
|
||||
}
|
||||
}
|
||||
mask >>= 1;
|
||||
letter += 1;
|
||||
}
|
||||
drives
|
||||
}
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
{
|
||||
Vec::new()
|
||||
}
|
||||
}
|
||||
313
crates/client/src/watermark/mod.rs
Normal file
313
crates/client/src/watermark/mod.rs
Normal file
@@ -0,0 +1,313 @@
|
||||
use std::sync::{Arc, Mutex};
|
||||
use tokio::sync::watch;
|
||||
use tracing::{info, warn, error};
|
||||
use csm_protocol::WatermarkConfigPayload;
|
||||
|
||||
/// Watermark overlay state
|
||||
struct WatermarkState {
|
||||
enabled: bool,
|
||||
content: String,
|
||||
font_size: u32,
|
||||
opacity: f64,
|
||||
color: String,
|
||||
angle: i32,
|
||||
hwnd: Option<isize>,
|
||||
}
|
||||
|
||||
impl Default for WatermarkState {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
enabled: false,
|
||||
content: String::new(),
|
||||
font_size: 14,
|
||||
opacity: 0.15,
|
||||
color: "#808080".to_string(),
|
||||
angle: -30,
|
||||
hwnd: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static WATERMARK_STATE: std::sync::OnceLock<Arc<Mutex<WatermarkState>>> = std::sync::OnceLock::new();
|
||||
|
||||
const WM_USER_UPDATE: u32 = 0x0401;
|
||||
const OVERLAY_CLASS_NAME: &str = "CSM_WatermarkOverlay";
|
||||
|
||||
/// Start the watermark plugin, listening for config updates.
|
||||
pub async fn start(mut rx: watch::Receiver<Option<WatermarkConfigPayload>>) {
|
||||
info!("Watermark plugin started");
|
||||
|
||||
let state = WATERMARK_STATE.get_or_init(|| Arc::new(Mutex::new(WatermarkState::default())));
|
||||
let state_clone = state.clone();
|
||||
|
||||
// Spawn a dedicated thread for the Win32 message loop
|
||||
std::thread::spawn(move || {
|
||||
#[cfg(target_os = "windows")]
|
||||
run_overlay_message_loop(state_clone);
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
{
|
||||
let _ = state_clone;
|
||||
info!("Watermark overlay not supported on this platform");
|
||||
}
|
||||
});
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
result = rx.changed() => {
|
||||
if result.is_err() {
|
||||
warn!("Watermark config channel closed");
|
||||
break;
|
||||
}
|
||||
let config = rx.borrow_and_update().clone();
|
||||
if let Some(cfg) = config {
|
||||
info!("Watermark config updated: enabled={}, content='{}'", cfg.enabled, cfg.content);
|
||||
|
||||
{
|
||||
let mut s = state.lock().unwrap_or_else(|e| e.into_inner());
|
||||
s.enabled = cfg.enabled;
|
||||
s.content = cfg.content.clone();
|
||||
s.font_size = cfg.font_size;
|
||||
s.opacity = cfg.opacity;
|
||||
s.color = cfg.color.clone();
|
||||
s.angle = cfg.angle;
|
||||
}
|
||||
|
||||
// Post message to overlay window thread
|
||||
#[cfg(target_os = "windows")]
|
||||
post_overlay_update();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(target_os = "windows")]
|
||||
fn post_overlay_update() {
|
||||
use windows::Win32::UI::WindowsAndMessaging::{FindWindowA, PostMessageW};
|
||||
use windows::Win32::Foundation::{WPARAM, LPARAM};
|
||||
use windows::core::PCSTR;
|
||||
|
||||
unsafe {
|
||||
let class_name = format!("{}\0", OVERLAY_CLASS_NAME);
|
||||
let hwnd = FindWindowA(PCSTR(class_name.as_ptr()), PCSTR::null());
|
||||
if hwnd.0 != 0 {
|
||||
let _ = PostMessageW(hwnd, WM_USER_UPDATE, WPARAM(0), LPARAM(0));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(target_os = "windows")]
|
||||
fn run_overlay_message_loop(state: Arc<Mutex<WatermarkState>>) {
|
||||
use windows::Win32::UI::WindowsAndMessaging::*;
|
||||
use windows::Win32::Graphics::Gdi::*;
|
||||
use windows::Win32::Foundation::*;
|
||||
use windows::Win32::System::LibraryLoader::GetModuleHandleA;
|
||||
use windows::core::PCSTR;
|
||||
|
||||
// Register window class
|
||||
let class_name = format!("{}\0", OVERLAY_CLASS_NAME);
|
||||
let instance = unsafe { GetModuleHandleA(PCSTR::null()) }.unwrap_or_default();
|
||||
let instance: HINSTANCE = instance.into();
|
||||
|
||||
let wc = WNDCLASSA {
|
||||
style: CS_HREDRAW | CS_VREDRAW,
|
||||
lpfnWndProc: Some(watermark_wnd_proc),
|
||||
hInstance: instance,
|
||||
lpszClassName: PCSTR(class_name.as_ptr()),
|
||||
hbrBackground: unsafe { CreateSolidBrush(COLORREF(0x00000000)) },
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
unsafe { RegisterClassA(&wc); }
|
||||
|
||||
let screen_w = unsafe { GetSystemMetrics(SM_CXSCREEN) };
|
||||
let screen_h = unsafe { GetSystemMetrics(SM_CYSCREEN) };
|
||||
|
||||
// WS_EX_LAYERED | WS_EX_TRANSPARENT | WS_EX_TOPMOST | WS_EX_TOOLWINDOW
|
||||
let ex_style = WS_EX_LAYERED | WS_EX_TRANSPARENT | WS_EX_TOPMOST | WS_EX_TOOLWINDOW;
|
||||
let style = WS_POPUP;
|
||||
|
||||
let hwnd = unsafe {
|
||||
CreateWindowExA(
|
||||
ex_style,
|
||||
PCSTR(class_name.as_ptr()),
|
||||
PCSTR("CSM Watermark\0".as_ptr()),
|
||||
style,
|
||||
0, 0, screen_w, screen_h,
|
||||
HWND::default(),
|
||||
HMENU::default(),
|
||||
instance,
|
||||
None,
|
||||
)
|
||||
};
|
||||
|
||||
if hwnd.0 == 0 {
|
||||
error!("Failed to create watermark overlay window");
|
||||
return;
|
||||
}
|
||||
|
||||
{
|
||||
let mut s = state.lock().unwrap_or_else(|e| e.into_inner());
|
||||
s.hwnd = Some(hwnd.0);
|
||||
}
|
||||
|
||||
info!("Watermark overlay window created (hwnd={})", hwnd.0);
|
||||
|
||||
// Post initial update to self — ensures overlay shows even if config
|
||||
// arrived before window creation completed (race between threads)
|
||||
unsafe {
|
||||
let _ = PostMessageW(hwnd, WM_USER_UPDATE, WPARAM(0), LPARAM(0));
|
||||
}
|
||||
|
||||
// Message loop
|
||||
let mut msg = MSG::default();
|
||||
unsafe {
|
||||
loop {
|
||||
match GetMessageA(&mut msg, HWND::default(), 0, 0) {
|
||||
BOOL(0) | BOOL(-1) => break,
|
||||
_ => {
|
||||
TranslateMessage(&msg);
|
||||
DispatchMessageA(&msg);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(target_os = "windows")]
|
||||
unsafe extern "system" fn watermark_wnd_proc(
|
||||
hwnd: windows::Win32::Foundation::HWND,
|
||||
msg: u32,
|
||||
wparam: windows::Win32::Foundation::WPARAM,
|
||||
lparam: windows::Win32::Foundation::LPARAM,
|
||||
) -> windows::Win32::Foundation::LRESULT {
|
||||
use windows::Win32::UI::WindowsAndMessaging::*;
|
||||
use windows::Win32::Graphics::Gdi::*;
|
||||
use windows::Win32::Foundation::*;
|
||||
|
||||
match msg {
|
||||
WM_PAINT => {
|
||||
if let Some(state) = WATERMARK_STATE.get() {
|
||||
let s = state.lock().unwrap_or_else(|e| e.into_inner());
|
||||
if s.enabled && !s.content.is_empty() {
|
||||
paint_watermark(hwnd, &s);
|
||||
}
|
||||
}
|
||||
LRESULT(0)
|
||||
}
|
||||
WM_ERASEBKGND => {
|
||||
// Fill with black (will be color-keyed to transparent)
|
||||
let hdc = HDC(wparam.0 as isize);
|
||||
unsafe {
|
||||
let rect = {
|
||||
let mut r = RECT::default();
|
||||
let _ = GetClientRect(hwnd, &mut r);
|
||||
r
|
||||
};
|
||||
let brush = CreateSolidBrush(COLORREF(0x00000000));
|
||||
let _ = FillRect(hdc, &rect, brush);
|
||||
let _ = DeleteObject(brush);
|
||||
}
|
||||
LRESULT(1)
|
||||
}
|
||||
m if m == WM_USER_UPDATE => {
|
||||
if let Some(state) = WATERMARK_STATE.get() {
|
||||
let s = state.lock().unwrap_or_else(|e| e.into_inner());
|
||||
if s.enabled {
|
||||
let _ = SetWindowPos(
|
||||
hwnd, HWND_TOPMOST,
|
||||
0, 0,
|
||||
GetSystemMetrics(SM_CXSCREEN),
|
||||
GetSystemMetrics(SM_CYSCREEN),
|
||||
SWP_SHOWWINDOW,
|
||||
);
|
||||
let alpha = (s.opacity * 255.0).clamp(0.0, 255.0) as u8;
|
||||
// Color key black background to transparent, apply alpha to text
|
||||
let _ = SetLayeredWindowAttributes(hwnd, COLORREF(0), alpha, LWA_COLORKEY | LWA_ALPHA);
|
||||
let _ = InvalidateRect(hwnd, None, true);
|
||||
} else {
|
||||
let _ = ShowWindow(hwnd, SW_HIDE);
|
||||
}
|
||||
}
|
||||
LRESULT(0)
|
||||
}
|
||||
WM_DESTROY => {
|
||||
PostQuitMessage(0);
|
||||
LRESULT(0)
|
||||
}
|
||||
_ => DefWindowProcA(hwnd, msg, wparam, lparam),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(target_os = "windows")]
|
||||
fn paint_watermark(hwnd: windows::Win32::Foundation::HWND, state: &WatermarkState) {
|
||||
use windows::Win32::Graphics::Gdi::*;
|
||||
use windows::Win32::UI::WindowsAndMessaging::*;
|
||||
use windows::core::PCSTR;
|
||||
|
||||
unsafe {
|
||||
let mut ps = PAINTSTRUCT::default();
|
||||
let hdc = BeginPaint(hwnd, &mut ps);
|
||||
|
||||
let color = parse_color(&state.color);
|
||||
let font_size = state.font_size.max(1);
|
||||
|
||||
// Create font with rotation
|
||||
let font = CreateFontA(
|
||||
(font_size as i32) * 2,
|
||||
0,
|
||||
(state.angle as i32) * 10,
|
||||
0,
|
||||
FW_NORMAL.0 as i32,
|
||||
0, 0, 0,
|
||||
DEFAULT_CHARSET.0 as u32,
|
||||
OUT_DEFAULT_PRECIS.0 as u32,
|
||||
CLIP_DEFAULT_PRECIS.0 as u32,
|
||||
DEFAULT_QUALITY.0 as u32,
|
||||
DEFAULT_PITCH.0 as u32 | FF_DONTCARE.0 as u32,
|
||||
PCSTR("Arial\0".as_ptr()),
|
||||
);
|
||||
|
||||
let old_font = SelectObject(hdc, font);
|
||||
|
||||
let _ = SetBkMode(hdc, TRANSPARENT);
|
||||
let _ = SetTextColor(hdc, color);
|
||||
|
||||
// Draw tiled watermark text
|
||||
let screen_w = GetSystemMetrics(SM_CXSCREEN);
|
||||
let screen_h = GetSystemMetrics(SM_CYSCREEN);
|
||||
|
||||
let content_bytes: Vec<u8> = state.content.bytes().chain(std::iter::once(0)).collect();
|
||||
let text_slice = &content_bytes[..content_bytes.len().saturating_sub(1)];
|
||||
|
||||
let spacing_x = 400i32;
|
||||
let spacing_y = 200i32;
|
||||
|
||||
let mut y = -100i32;
|
||||
while y < screen_h + 100 {
|
||||
let mut x = -200i32;
|
||||
while x < screen_w + 200 {
|
||||
let _ = TextOutA(hdc, x, y, text_slice);
|
||||
x += spacing_x;
|
||||
}
|
||||
y += spacing_y;
|
||||
}
|
||||
|
||||
SelectObject(hdc, old_font);
|
||||
let _ = DeleteObject(font);
|
||||
EndPaint(hwnd, &ps);
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_color(hex: &str) -> windows::Win32::Foundation::COLORREF {
|
||||
use windows::Win32::Foundation::COLORREF;
|
||||
let hex = hex.trim_start_matches('#');
|
||||
if hex.len() == 6 {
|
||||
let r = u8::from_str_radix(&hex[0..2], 16).unwrap_or(128);
|
||||
let g = u8::from_str_radix(&hex[2..4], 16).unwrap_or(128);
|
||||
let b = u8::from_str_radix(&hex[4..6], 16).unwrap_or(128);
|
||||
COLORREF((b as u32) << 16 | (g as u32) << 8 | (r as u32))
|
||||
} else {
|
||||
COLORREF(0x00808080)
|
||||
}
|
||||
}
|
||||
169
crates/client/src/web_filter/mod.rs
Normal file
169
crates/client/src/web_filter/mod.rs
Normal file
@@ -0,0 +1,169 @@
|
||||
use tokio::sync::watch;
|
||||
use tracing::{info, warn, debug};
|
||||
use serde::Deserialize;
|
||||
use std::io;
|
||||
|
||||
/// Web filter rule from server
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct WebFilterRule {
|
||||
pub id: i64,
|
||||
pub rule_type: String,
|
||||
pub pattern: String,
|
||||
}
|
||||
|
||||
/// Web filter configuration
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct WebFilterConfig {
|
||||
pub enabled: bool,
|
||||
pub rules: Vec<WebFilterRule>,
|
||||
}
|
||||
|
||||
/// Start web filter plugin.
|
||||
/// Manages the hosts file to block/allow URLs based on server rules.
|
||||
pub async fn start(
|
||||
mut config_rx: watch::Receiver<WebFilterConfig>,
|
||||
) {
|
||||
info!("Web filter plugin started");
|
||||
let mut _config = WebFilterConfig::default();
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
result = config_rx.changed() => {
|
||||
if result.is_err() {
|
||||
break;
|
||||
}
|
||||
let new_config = config_rx.borrow_and_update().clone();
|
||||
info!("Web filter config updated: enabled={}, rules={}", new_config.enabled, new_config.rules.len());
|
||||
|
||||
if new_config.enabled {
|
||||
match apply_hosts_rules(&new_config.rules) {
|
||||
Ok(()) => info!("Web filter hosts rules applied ({} rules)", new_config.rules.len()),
|
||||
Err(e) => warn!("Failed to apply hosts rules: {}", e),
|
||||
}
|
||||
} else {
|
||||
match clear_hosts_rules() {
|
||||
Ok(()) => info!("Web filter hosts rules cleared"),
|
||||
Err(e) => warn!("Failed to clear hosts rules: {}", e),
|
||||
}
|
||||
}
|
||||
_config = new_config;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const HOSTS_MARKER_START: &str = "# CSM_WEB_FILTER_START";
|
||||
const HOSTS_MARKER_END: &str = "# CSM_WEB_FILTER_END";
|
||||
|
||||
fn apply_hosts_rules(rules: &[WebFilterRule]) -> io::Result<()> {
|
||||
let hosts_path = get_hosts_path();
|
||||
let original = std::fs::read_to_string(&hosts_path)?;
|
||||
|
||||
// Remove existing CSM block
|
||||
let clean = remove_csm_block(&original);
|
||||
|
||||
// Build new block with block rules
|
||||
let block_rules: Vec<&WebFilterRule> = rules.iter()
|
||||
.filter(|r| r.rule_type == "block")
|
||||
.filter(|r| is_valid_hosts_entry(&r.pattern))
|
||||
.collect();
|
||||
|
||||
if block_rules.is_empty() {
|
||||
atomic_write(&hosts_path, &clean)?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let mut new_block = format!("{}\n", HOSTS_MARKER_START);
|
||||
for rule in &block_rules {
|
||||
// Redirect blocked domains to 127.0.0.1
|
||||
new_block.push_str(&format!("127.0.0.1 {}\n", rule.pattern));
|
||||
}
|
||||
new_block.push_str(HOSTS_MARKER_END);
|
||||
new_block.push('\n');
|
||||
|
||||
let new_content = format!("{}{}", clean, new_block);
|
||||
atomic_write(&hosts_path, &new_content)?;
|
||||
|
||||
debug!("Applied {} web filter rules to hosts file", block_rules.len());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn clear_hosts_rules() -> io::Result<()> {
|
||||
let hosts_path = get_hosts_path();
|
||||
let original = std::fs::read_to_string(&hosts_path)?;
|
||||
let clean = remove_csm_block(&original);
|
||||
atomic_write(&hosts_path, &clean)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Write content to a file atomically.
|
||||
/// On Windows, hosts file may be locked by the DNS cache service,
|
||||
/// so we use direct overwrite with truncation instead of rename.
|
||||
fn atomic_write(path: &str, content: &str) -> io::Result<()> {
|
||||
use std::io::Write;
|
||||
let mut file = std::fs::OpenOptions::new()
|
||||
.write(true)
|
||||
.truncate(true)
|
||||
.create(true)
|
||||
.open(path)?;
|
||||
file.write_all(content.as_bytes())?;
|
||||
file.sync_data()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn remove_csm_block(content: &str) -> String {
|
||||
// Handle paired markers (normal case)
|
||||
let start_idx = content.find(HOSTS_MARKER_START);
|
||||
let end_idx = content.find(HOSTS_MARKER_END);
|
||||
|
||||
match (start_idx, end_idx) {
|
||||
(Some(s), Some(e)) if e > s => {
|
||||
let mut result = String::new();
|
||||
result.push_str(&content[..s]);
|
||||
result.push_str(&content[e + HOSTS_MARKER_END.len()..]);
|
||||
return result.trim_end().to_string() + "\n";
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
// Handle orphan markers: remove any lone START or END marker line
|
||||
let lines: Vec<&str> = content.lines().collect();
|
||||
let cleaned: Vec<&str> = lines.iter()
|
||||
.filter(|line| {
|
||||
let trimmed = line.trim();
|
||||
trimmed != HOSTS_MARKER_START && trimmed != HOSTS_MARKER_END
|
||||
})
|
||||
.copied()
|
||||
.collect();
|
||||
|
||||
let mut result = cleaned.join("\n");
|
||||
if !result.ends_with('\n') {
|
||||
result.push('\n');
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
fn get_hosts_path() -> String {
|
||||
#[cfg(target_os = "windows")]
|
||||
{
|
||||
r"C:\Windows\System32\drivers\etc\hosts".to_string()
|
||||
}
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
{
|
||||
"/etc/hosts".to_string()
|
||||
}
|
||||
}
|
||||
|
||||
/// Validate that a pattern is safe to write to the hosts file.
|
||||
/// Rejects patterns with whitespace, control chars, or comment markers.
|
||||
fn is_valid_hosts_entry(pattern: &str) -> bool {
|
||||
if pattern.is_empty() {
|
||||
return false;
|
||||
}
|
||||
// Reject if contains whitespace, control chars, or comment markers
|
||||
if pattern.chars().any(|c| c.is_whitespace() || c.is_control() || c == '#') {
|
||||
return false;
|
||||
}
|
||||
// Must look like a hostname (alphanumeric, dots, hyphens, underscores, asterisks)
|
||||
pattern.chars().all(|c| c.is_alphanumeric() || c == '.' || c == '-' || c == '_' || c == '*')
|
||||
}
|
||||
12
crates/protocol/Cargo.toml
Normal file
12
crates/protocol/Cargo.toml
Normal file
@@ -0,0 +1,12 @@
|
||||
[package]
|
||||
name = "csm-protocol"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
|
||||
[dependencies]
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
uuid = { workspace = true }
|
||||
chrono = { workspace = true }
|
||||
anyhow = { workspace = true }
|
||||
108
crates/protocol/src/device.rs
Normal file
108
crates/protocol/src/device.rs
Normal file
@@ -0,0 +1,108 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Real-time device status report
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct DeviceStatus {
|
||||
pub device_uid: String,
|
||||
pub cpu_usage: f64,
|
||||
pub memory_usage: f64,
|
||||
pub memory_total_mb: u64,
|
||||
pub disk_usage: f64,
|
||||
pub disk_total_mb: u64,
|
||||
pub network_rx_rate: u64,
|
||||
pub network_tx_rate: u64,
|
||||
pub running_procs: u32,
|
||||
pub top_processes: Vec<ProcessInfo>,
|
||||
}
|
||||
|
||||
/// Top process information
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct ProcessInfo {
|
||||
pub name: String,
|
||||
pub pid: u32,
|
||||
pub cpu_usage: f64,
|
||||
pub memory_mb: u64,
|
||||
}
|
||||
|
||||
/// Hardware asset information
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct HardwareAsset {
|
||||
pub device_uid: String,
|
||||
pub cpu_model: String,
|
||||
pub cpu_cores: u32,
|
||||
pub memory_total_mb: u64,
|
||||
pub disk_model: String,
|
||||
pub disk_total_mb: u64,
|
||||
pub gpu_model: Option<String>,
|
||||
pub motherboard: Option<String>,
|
||||
pub serial_number: Option<String>,
|
||||
}
|
||||
|
||||
/// Software asset information
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct SoftwareAsset {
|
||||
pub device_uid: String,
|
||||
pub name: String,
|
||||
pub version: Option<String>,
|
||||
pub publisher: Option<String>,
|
||||
pub install_date: Option<String>,
|
||||
pub install_path: Option<String>,
|
||||
}
|
||||
|
||||
/// USB device event
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct UsbEvent {
|
||||
pub device_uid: String,
|
||||
pub event_type: UsbEventType,
|
||||
pub vendor_id: Option<String>,
|
||||
pub product_id: Option<String>,
|
||||
pub serial: Option<String>,
|
||||
pub device_name: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum UsbEventType {
|
||||
Inserted,
|
||||
Removed,
|
||||
Blocked,
|
||||
}
|
||||
|
||||
/// Asset change event
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct AssetChange {
|
||||
pub device_uid: String,
|
||||
pub change_type: AssetChangeType,
|
||||
pub change_detail: serde_json::Value,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum AssetChangeType {
|
||||
Hardware,
|
||||
SoftwareAdded,
|
||||
SoftwareRemoved,
|
||||
}
|
||||
|
||||
/// USB policy (Server → Client)
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct UsbPolicy {
|
||||
pub policy_id: i64,
|
||||
pub policy_type: UsbPolicyType,
|
||||
pub allowed_devices: Vec<UsbDevicePattern>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum UsbPolicyType {
|
||||
AllBlock,
|
||||
Whitelist,
|
||||
Blacklist,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct UsbDevicePattern {
|
||||
pub vendor_id: Option<String>,
|
||||
pub product_id: Option<String>,
|
||||
pub serial: Option<String>,
|
||||
}
|
||||
27
crates/protocol/src/lib.rs
Normal file
27
crates/protocol/src/lib.rs
Normal file
@@ -0,0 +1,27 @@
|
||||
pub mod message;
|
||||
pub mod device;
|
||||
|
||||
// Re-export constants from message module
|
||||
pub use message::{MAGIC, PROTOCOL_VERSION, FRAME_HEADER_SIZE, MAX_PAYLOAD_SIZE};
|
||||
|
||||
// Core frame & message types
|
||||
pub use message::{
|
||||
Frame, FrameError, MessageType,
|
||||
RegisterRequest, RegisterResponse, ClientConfig,
|
||||
HeartbeatPayload, TaskExecutePayload, ConfigUpdateType,
|
||||
};
|
||||
|
||||
// Device status & asset types
|
||||
pub use device::{
|
||||
DeviceStatus, ProcessInfo, HardwareAsset, SoftwareAsset,
|
||||
UsbEvent, UsbEventType, AssetChange, AssetChangeType,
|
||||
UsbPolicy, UsbPolicyType, UsbDevicePattern,
|
||||
};
|
||||
|
||||
// Plugin message payloads
|
||||
pub use message::{
|
||||
WebAccessLogEntry, UsageDailyReport, AppUsageEntry,
|
||||
SoftwareViolationReport, UsbFileOpEntry,
|
||||
WatermarkConfigPayload, PluginControlPayload,
|
||||
UsbPolicyPayload, UsbDeviceRule,
|
||||
};
|
||||
417
crates/protocol/src/message.rs
Normal file
417
crates/protocol/src/message.rs
Normal file
@@ -0,0 +1,417 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Protocol magic bytes: "CSM\0"
|
||||
pub const MAGIC: [u8; 4] = [0x43, 0x53, 0x4D, 0x00];
|
||||
|
||||
/// Current protocol version
|
||||
pub const PROTOCOL_VERSION: u8 = 0x01;
|
||||
|
||||
/// Frame header size: magic(4) + version(1) + type(1) + length(4)
|
||||
pub const FRAME_HEADER_SIZE: usize = 10;
|
||||
|
||||
/// Maximum payload size: 4 MB — prevents memory exhaustion from malicious frames
|
||||
pub const MAX_PAYLOAD_SIZE: usize = 4 * 1024 * 1024;
|
||||
|
||||
/// Binary message types for client-server communication
|
||||
#[repr(u8)]
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum MessageType {
|
||||
// Client → Server (Core)
|
||||
Heartbeat = 0x01,
|
||||
Register = 0x02,
|
||||
StatusReport = 0x03,
|
||||
AssetReport = 0x04,
|
||||
AssetChange = 0x05,
|
||||
UsbEvent = 0x06,
|
||||
AlertAck = 0x07,
|
||||
|
||||
// Server → Client (Core)
|
||||
RegisterResponse = 0x08,
|
||||
PolicyUpdate = 0x10,
|
||||
ConfigUpdate = 0x11,
|
||||
TaskExecute = 0x12,
|
||||
|
||||
// Plugin: Web Filter (上网拦截)
|
||||
WebFilterRuleUpdate = 0x20,
|
||||
WebAccessLog = 0x21,
|
||||
|
||||
// Plugin: Usage Timer (时长记录)
|
||||
UsageReport = 0x30,
|
||||
AppUsageReport = 0x31,
|
||||
|
||||
// Plugin: Software Blocker (软件禁止安装)
|
||||
SoftwareBlacklist = 0x40,
|
||||
SoftwareViolation = 0x41,
|
||||
|
||||
// Plugin: Popup Blocker (弹窗拦截)
|
||||
PopupRules = 0x50,
|
||||
|
||||
// Plugin: USB File Audit (U盘文件操作记录)
|
||||
UsbFileOp = 0x60,
|
||||
|
||||
// Plugin: Screen Watermark (水印管理)
|
||||
WatermarkConfig = 0x70,
|
||||
|
||||
// Plugin: USB Policy (U盘管控策略)
|
||||
UsbPolicyUpdate = 0x71,
|
||||
|
||||
// Plugin control
|
||||
PluginEnable = 0x80,
|
||||
PluginDisable = 0x81,
|
||||
}
|
||||
|
||||
impl TryFrom<u8> for MessageType {
|
||||
type Error = String;
|
||||
|
||||
fn try_from(value: u8) -> Result<Self, Self::Error> {
|
||||
match value {
|
||||
0x01 => Ok(Self::Heartbeat),
|
||||
0x02 => Ok(Self::Register),
|
||||
0x03 => Ok(Self::StatusReport),
|
||||
0x04 => Ok(Self::AssetReport),
|
||||
0x05 => Ok(Self::AssetChange),
|
||||
0x06 => Ok(Self::UsbEvent),
|
||||
0x07 => Ok(Self::AlertAck),
|
||||
0x08 => Ok(Self::RegisterResponse),
|
||||
0x10 => Ok(Self::PolicyUpdate),
|
||||
0x11 => Ok(Self::ConfigUpdate),
|
||||
0x12 => Ok(Self::TaskExecute),
|
||||
0x20 => Ok(Self::WebFilterRuleUpdate),
|
||||
0x21 => Ok(Self::WebAccessLog),
|
||||
0x30 => Ok(Self::UsageReport),
|
||||
0x31 => Ok(Self::AppUsageReport),
|
||||
0x40 => Ok(Self::SoftwareBlacklist),
|
||||
0x41 => Ok(Self::SoftwareViolation),
|
||||
0x50 => Ok(Self::PopupRules),
|
||||
0x60 => Ok(Self::UsbFileOp),
|
||||
0x70 => Ok(Self::WatermarkConfig),
|
||||
0x71 => Ok(Self::UsbPolicyUpdate),
|
||||
0x80 => Ok(Self::PluginEnable),
|
||||
0x81 => Ok(Self::PluginDisable),
|
||||
_ => Err(format!("Unknown message type: 0x{:02X}", value)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A wire-format frame for transmission over TCP
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Frame {
|
||||
pub version: u8,
|
||||
pub msg_type: MessageType,
|
||||
pub payload: Vec<u8>,
|
||||
}
|
||||
|
||||
impl Frame {
|
||||
/// Create a new frame with the current protocol version
|
||||
pub fn new(msg_type: MessageType, payload: Vec<u8>) -> Self {
|
||||
Self {
|
||||
version: PROTOCOL_VERSION,
|
||||
msg_type,
|
||||
payload,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new frame with JSON-serialized payload
|
||||
pub fn new_json<T: Serialize>(msg_type: MessageType, data: &T) -> anyhow::Result<Self> {
|
||||
let payload = serde_json::to_vec(data)?;
|
||||
Ok(Self::new(msg_type, payload))
|
||||
}
|
||||
|
||||
/// Encode frame to bytes for transmission
|
||||
/// Format: MAGIC(4) + VERSION(1) + TYPE(1) + LENGTH(4) + PAYLOAD(var)
|
||||
pub fn encode(&self) -> Vec<u8> {
|
||||
let mut buf = Vec::with_capacity(FRAME_HEADER_SIZE + self.payload.len());
|
||||
buf.extend_from_slice(&MAGIC);
|
||||
buf.push(self.version);
|
||||
buf.push(self.msg_type as u8);
|
||||
buf.extend_from_slice(&(self.payload.len() as u32).to_be_bytes());
|
||||
buf.extend_from_slice(&self.payload);
|
||||
buf
|
||||
}
|
||||
|
||||
/// Decode frame from bytes. Returns Ok(Some(frame)) when a complete frame is available.
|
||||
pub fn decode(data: &[u8]) -> Result<Option<Frame>, FrameError> {
|
||||
if data.len() < FRAME_HEADER_SIZE {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
if data[0..4] != MAGIC {
|
||||
return Err(FrameError::InvalidMagic);
|
||||
}
|
||||
|
||||
let version = data[4];
|
||||
let msg_type_byte = data[5];
|
||||
let payload_len = u32::from_be_bytes([data[6], data[7], data[8], data[9]]) as usize;
|
||||
|
||||
if payload_len > MAX_PAYLOAD_SIZE {
|
||||
return Err(FrameError::PayloadTooLarge(payload_len));
|
||||
}
|
||||
|
||||
if data.len() < FRAME_HEADER_SIZE + payload_len {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let msg_type = MessageType::try_from(msg_type_byte)
|
||||
.map_err(|e| FrameError::UnknownMessageType(msg_type_byte, e))?;
|
||||
|
||||
let payload = data[FRAME_HEADER_SIZE..FRAME_HEADER_SIZE + payload_len].to_vec();
|
||||
|
||||
Ok(Some(Frame {
|
||||
version,
|
||||
msg_type,
|
||||
payload,
|
||||
}))
|
||||
}
|
||||
|
||||
/// Deserialize the payload as JSON
|
||||
pub fn decode_payload<T: for<'de> Deserialize<'de>>(&self) -> Result<T, serde_json::Error> {
|
||||
serde_json::from_slice(&self.payload)
|
||||
}
|
||||
|
||||
/// Total encoded size of this frame
|
||||
pub fn encoded_size(&self) -> usize {
|
||||
FRAME_HEADER_SIZE + self.payload.len()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum FrameError {
|
||||
#[error("Invalid magic bytes in frame header")]
|
||||
InvalidMagic,
|
||||
#[error("Unknown message type: 0x{0:02X} - {1}")]
|
||||
UnknownMessageType(u8, String),
|
||||
#[error("Payload too large: {0} bytes (max {})", MAX_PAYLOAD_SIZE)]
|
||||
PayloadTooLarge(usize),
|
||||
#[error("IO error: {0}")]
|
||||
Io(#[from] std::io::Error),
|
||||
}
|
||||
|
||||
// ==================== Core Message Payloads ====================
|
||||
|
||||
/// Registration request payload
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct RegisterRequest {
|
||||
pub device_uid: String,
|
||||
pub hostname: String,
|
||||
pub registration_token: String,
|
||||
pub os_version: String,
|
||||
pub mac_address: Option<String>,
|
||||
}
|
||||
|
||||
/// Registration response payload
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct RegisterResponse {
|
||||
pub device_secret: String,
|
||||
pub config: ClientConfig,
|
||||
}
|
||||
|
||||
/// Server-pushed client configuration
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct ClientConfig {
|
||||
pub heartbeat_interval_secs: u64,
|
||||
pub status_report_interval_secs: u64,
|
||||
pub asset_report_interval_secs: u64,
|
||||
pub server_version: String,
|
||||
}
|
||||
|
||||
impl Default for ClientConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
heartbeat_interval_secs: 30,
|
||||
status_report_interval_secs: 60,
|
||||
asset_report_interval_secs: 86400,
|
||||
server_version: env!("CARGO_PKG_VERSION").to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Heartbeat payload (minimal)
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct HeartbeatPayload {
|
||||
pub device_uid: String,
|
||||
pub timestamp: String,
|
||||
pub hmac: String,
|
||||
}
|
||||
|
||||
/// Task execution request (Server → Client)
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct TaskExecutePayload {
|
||||
pub task_type: String,
|
||||
pub params: serde_json::Value,
|
||||
}
|
||||
|
||||
/// Config update types (Server → Client)
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub enum ConfigUpdateType {
|
||||
UpdateIntervals { heartbeat: u64, status: u64, asset: u64 },
|
||||
TlsCertRotate,
|
||||
SelfDestruct,
|
||||
}
|
||||
|
||||
// ==================== Plugin Message Payloads ====================
|
||||
|
||||
/// Plugin: Web Access Log entry (Client → Server)
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct WebAccessLogEntry {
|
||||
pub device_uid: String,
|
||||
pub url: String,
|
||||
pub action: String, // "allowed" | "blocked"
|
||||
pub timestamp: String,
|
||||
}
|
||||
|
||||
/// Plugin: Daily Usage Report (Client → Server)
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct UsageDailyReport {
|
||||
pub device_uid: String,
|
||||
pub date: String,
|
||||
pub total_active_minutes: u32,
|
||||
pub total_idle_minutes: u32,
|
||||
pub first_active_at: Option<String>,
|
||||
pub last_active_at: Option<String>,
|
||||
}
|
||||
|
||||
/// Plugin: App Usage Report (Client → Server)
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct AppUsageEntry {
|
||||
pub device_uid: String,
|
||||
pub date: String,
|
||||
pub app_name: String,
|
||||
pub usage_minutes: u32,
|
||||
}
|
||||
|
||||
/// Plugin: Software Violation (Client → Server)
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct SoftwareViolationReport {
|
||||
pub device_uid: String,
|
||||
pub software_name: String,
|
||||
pub action_taken: String, // "blocked_install" | "auto_uninstalled" | "alerted"
|
||||
pub timestamp: String,
|
||||
}
|
||||
|
||||
/// Plugin: USB File Operation (Client → Server)
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct UsbFileOpEntry {
|
||||
pub device_uid: String,
|
||||
pub usb_serial: Option<String>,
|
||||
pub drive_letter: Option<String>,
|
||||
pub operation: String, // "create" | "delete" | "rename" | "modify"
|
||||
pub file_path: String,
|
||||
pub file_size: Option<u64>,
|
||||
pub timestamp: String,
|
||||
}
|
||||
|
||||
/// Plugin: Watermark Config (Server → Client)
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct WatermarkConfigPayload {
|
||||
pub content: String,
|
||||
pub font_size: u32,
|
||||
pub opacity: f64,
|
||||
pub color: String,
|
||||
pub angle: i32,
|
||||
pub enabled: bool,
|
||||
}
|
||||
|
||||
/// Plugin enable/disable command (Server → Client)
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct PluginControlPayload {
|
||||
pub plugin_name: String,
|
||||
pub enabled: bool,
|
||||
}
|
||||
|
||||
/// Plugin: USB Policy Config (Server → Client)
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct UsbPolicyPayload {
|
||||
pub policy_type: String, // "all_block" | "whitelist" | "blacklist"
|
||||
pub enabled: bool,
|
||||
pub rules: Vec<UsbDeviceRule>,
|
||||
}
|
||||
|
||||
/// A single USB device rule for whitelist/blacklist matching
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct UsbDeviceRule {
|
||||
pub vendor_id: Option<String>,
|
||||
pub product_id: Option<String>,
|
||||
pub serial: Option<String>,
|
||||
pub device_name: Option<String>,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_frame_encode_decode_roundtrip() {
|
||||
let original = Frame::new(MessageType::Heartbeat, b"test payload".to_vec());
|
||||
let encoded = original.encode();
|
||||
let decoded = Frame::decode(&encoded).unwrap().unwrap();
|
||||
|
||||
assert_eq!(decoded.version, PROTOCOL_VERSION);
|
||||
assert_eq!(decoded.msg_type, MessageType::Heartbeat);
|
||||
assert_eq!(decoded.payload, b"test payload");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_frame_decode_incomplete_data() {
|
||||
let data = [0x43, 0x53, 0x4D, 0x01, 0x01];
|
||||
let result = Frame::decode(&data).unwrap();
|
||||
assert!(result.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_frame_decode_invalid_magic() {
|
||||
let data = [0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x00, 0x00, 0x00, 0x00];
|
||||
let result = Frame::decode(&data);
|
||||
assert!(matches!(result, Err(FrameError::InvalidMagic)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_json_frame_roundtrip() {
|
||||
let heartbeat = HeartbeatPayload {
|
||||
device_uid: "test-uid".to_string(),
|
||||
timestamp: "2026-04-03T12:00:00Z".to_string(),
|
||||
hmac: "abc123".to_string(),
|
||||
};
|
||||
|
||||
let frame = Frame::new_json(MessageType::Heartbeat, &heartbeat).unwrap();
|
||||
let encoded = frame.encode();
|
||||
let decoded = Frame::decode(&encoded).unwrap().unwrap();
|
||||
let parsed: HeartbeatPayload = decoded.decode_payload().unwrap();
|
||||
|
||||
assert_eq!(parsed.device_uid, "test-uid");
|
||||
assert_eq!(parsed.hmac, "abc123");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_plugin_message_types_roundtrip() {
|
||||
let types = [
|
||||
MessageType::WebAccessLog,
|
||||
MessageType::UsageReport,
|
||||
MessageType::AppUsageReport,
|
||||
MessageType::SoftwareViolation,
|
||||
MessageType::UsbFileOp,
|
||||
MessageType::WatermarkConfig,
|
||||
MessageType::PluginEnable,
|
||||
MessageType::PluginDisable,
|
||||
];
|
||||
|
||||
for mt in types {
|
||||
let frame = Frame::new(mt, vec![1, 2, 3]);
|
||||
let encoded = frame.encode();
|
||||
let decoded = Frame::decode(&encoded).unwrap().unwrap();
|
||||
assert_eq!(decoded.msg_type, mt);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_frame_decode_payload_too_large() {
|
||||
// Craft a header that claims a 10 MB payload
|
||||
let mut data = Vec::with_capacity(FRAME_HEADER_SIZE);
|
||||
data.extend_from_slice(&MAGIC);
|
||||
data.push(PROTOCOL_VERSION);
|
||||
data.push(MessageType::Heartbeat as u8);
|
||||
data.extend_from_slice(&(10 * 1024 * 1024u32).to_be_bytes());
|
||||
// Don't actually include the payload — the size check should reject first
|
||||
let result = Frame::decode(&data);
|
||||
assert!(matches!(result, Err(FrameError::PayloadTooLarge(_))));
|
||||
}
|
||||
}
|
||||
51
crates/server/Cargo.toml
Normal file
51
crates/server/Cargo.toml
Normal file
@@ -0,0 +1,51 @@
|
||||
[package]
|
||||
name = "csm-server"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
|
||||
[dependencies]
|
||||
csm-protocol = { path = "../protocol" }
|
||||
|
||||
# Async runtime
|
||||
tokio = { workspace = true }
|
||||
|
||||
# Web framework
|
||||
axum = { version = "0.7", features = ["ws"] }
|
||||
tower-http = { version = "0.5", features = ["cors", "fs", "trace", "compression-gzip", "set-header"] }
|
||||
tower = "0.4"
|
||||
|
||||
# Database
|
||||
sqlx = { version = "0.8", features = ["runtime-tokio", "sqlite"] }
|
||||
|
||||
# TLS
|
||||
rustls = "0.23"
|
||||
tokio-rustls = "0.26"
|
||||
rustls-pemfile = "2"
|
||||
rustls-pki-types = "1"
|
||||
|
||||
# Serialization
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
|
||||
# Auth
|
||||
jsonwebtoken = "9"
|
||||
bcrypt = "0.15"
|
||||
|
||||
# Notifications
|
||||
lettre = { version = "0.11", default-features = false, features = ["tokio1-rustls-tls", "smtp-transport", "builder", "hostname"] }
|
||||
reqwest = { version = "0.12", features = ["json", "rustls-tls"] }
|
||||
|
||||
# Config & logging
|
||||
toml = "0.8"
|
||||
tracing = { workspace = true }
|
||||
tracing-subscriber = { workspace = true }
|
||||
anyhow = { workspace = true }
|
||||
|
||||
# Utilities
|
||||
uuid = { workspace = true }
|
||||
chrono = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
include_dir = "0.7"
|
||||
hmac = "0.12"
|
||||
sha2 = "0.10"
|
||||
hex = "0.4"
|
||||
118
crates/server/src/alert.rs
Normal file
118
crates/server/src/alert.rs
Normal file
@@ -0,0 +1,118 @@
|
||||
use crate::AppState;
|
||||
use tracing::{info, warn, error};
|
||||
|
||||
/// Background task for data cleanup and alert processing
|
||||
pub async fn cleanup_task(state: AppState) {
|
||||
let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(3600));
|
||||
|
||||
loop {
|
||||
interval.tick().await;
|
||||
|
||||
// Cleanup old status history
|
||||
if let Err(e) = sqlx::query(
|
||||
"DELETE FROM device_status_history WHERE reported_at < datetime('now', ?)"
|
||||
)
|
||||
.bind(format!("-{} days", state.config.retention.status_history_days))
|
||||
.execute(&state.db)
|
||||
.await
|
||||
{
|
||||
error!("Failed to cleanup status history: {}", e);
|
||||
}
|
||||
|
||||
// Cleanup old USB events
|
||||
if let Err(e) = sqlx::query(
|
||||
"DELETE FROM usb_events WHERE event_time < datetime('now', ?)"
|
||||
)
|
||||
.bind(format!("-{} days", state.config.retention.usb_events_days))
|
||||
.execute(&state.db)
|
||||
.await
|
||||
{
|
||||
error!("Failed to cleanup USB events: {}", e);
|
||||
}
|
||||
|
||||
// Cleanup handled alert records
|
||||
if let Err(e) = sqlx::query(
|
||||
"DELETE FROM alert_records WHERE handled = 1 AND triggered_at < datetime('now', ?)"
|
||||
)
|
||||
.bind(format!("-{} days", state.config.retention.alert_records_days))
|
||||
.execute(&state.db)
|
||||
.await
|
||||
{
|
||||
error!("Failed to cleanup alert records: {}", e);
|
||||
}
|
||||
|
||||
// Mark devices as offline if no heartbeat for 2 minutes
|
||||
if let Err(e) = sqlx::query(
|
||||
"UPDATE devices SET status = 'offline' WHERE status = 'online' AND last_heartbeat < datetime('now', '-2 minutes')"
|
||||
)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
{
|
||||
error!("Failed to mark stale devices offline: {}", e);
|
||||
}
|
||||
|
||||
// SQLite WAL checkpoint
|
||||
if let Err(e) = sqlx::query("PRAGMA wal_checkpoint(TRUNCATE)")
|
||||
.execute(&state.db)
|
||||
.await
|
||||
{
|
||||
warn!("WAL checkpoint failed: {}", e);
|
||||
}
|
||||
|
||||
info!("Cleanup cycle completed");
|
||||
}
|
||||
}
|
||||
|
||||
/// Send email notification
|
||||
pub async fn send_email(
|
||||
smtp_config: &crate::config::SmtpConfig,
|
||||
to: &str,
|
||||
subject: &str,
|
||||
body: &str,
|
||||
) -> anyhow::Result<()> {
|
||||
use lettre::message::header::ContentType;
|
||||
use lettre::{Message, SmtpTransport, Transport};
|
||||
use lettre::transport::smtp::authentication::Credentials;
|
||||
|
||||
let email = Message::builder()
|
||||
.from(smtp_config.from.parse()?)
|
||||
.to(to.parse()?)
|
||||
.subject(subject)
|
||||
.header(ContentType::TEXT_HTML)
|
||||
.body(body.to_string())?;
|
||||
|
||||
let creds = Credentials::new(
|
||||
smtp_config.username.clone(),
|
||||
smtp_config.password.clone(),
|
||||
);
|
||||
|
||||
let mailer = SmtpTransport::starttls_relay(&smtp_config.host)?
|
||||
.port(smtp_config.port)
|
||||
.credentials(creds)
|
||||
.build();
|
||||
|
||||
mailer.send(&email)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Shared HTTP client for webhook notifications.
|
||||
/// Lazily initialized once and reused across calls to benefit from connection pooling.
|
||||
static WEBHOOK_CLIENT: std::sync::OnceLock<reqwest::Client> = std::sync::OnceLock::new();
|
||||
|
||||
fn webhook_client() -> &'static reqwest::Client {
|
||||
WEBHOOK_CLIENT.get_or_init(|| {
|
||||
reqwest::Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(10))
|
||||
.build()
|
||||
.unwrap_or_else(|_| reqwest::Client::new())
|
||||
})
|
||||
}
|
||||
|
||||
/// Send webhook notification
|
||||
pub async fn send_webhook(url: &str, payload: &serde_json::Value) -> anyhow::Result<()> {
|
||||
webhook_client().post(url)
|
||||
.json(payload)
|
||||
.send()
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
243
crates/server/src/api/alerts.rs
Normal file
243
crates/server/src/api/alerts.rs
Normal file
@@ -0,0 +1,243 @@
|
||||
use axum::{extract::{State, Path, Query, Json}, http::StatusCode};
|
||||
use serde::Deserialize;
|
||||
use sqlx::Row;
|
||||
|
||||
use crate::AppState;
|
||||
use super::ApiResponse;
|
||||
use super::auth::Claims;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct AlertRecordListParams {
|
||||
pub device_uid: Option<String>,
|
||||
pub alert_type: Option<String>,
|
||||
pub severity: Option<String>,
|
||||
pub handled: Option<i32>,
|
||||
pub page: Option<u32>,
|
||||
pub page_size: Option<u32>,
|
||||
}
|
||||
|
||||
pub async fn list_rules(
|
||||
State(state): State<AppState>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let rows = sqlx::query(
|
||||
"SELECT id, name, rule_type, condition, severity, enabled, notify_email, notify_webhook, created_at, updated_at
|
||||
FROM alert_rules ORDER BY created_at DESC"
|
||||
)
|
||||
.fetch_all(&state.db)
|
||||
.await;
|
||||
|
||||
match rows {
|
||||
Ok(records) => {
|
||||
let items: Vec<serde_json::Value> = records.iter().map(|r| serde_json::json!({
|
||||
"id": r.get::<i64, _>("id"),
|
||||
"name": r.get::<String, _>("name"),
|
||||
"rule_type": r.get::<String, _>("rule_type"),
|
||||
"condition": r.get::<String, _>("condition"),
|
||||
"severity": r.get::<String, _>("severity"),
|
||||
"enabled": r.get::<i32, _>("enabled"),
|
||||
"notify_email": r.get::<Option<String>, _>("notify_email"),
|
||||
"notify_webhook": r.get::<Option<String>, _>("notify_webhook"),
|
||||
"created_at": r.get::<String, _>("created_at"),
|
||||
"updated_at": r.get::<String, _>("updated_at"),
|
||||
})).collect();
|
||||
Json(ApiResponse::ok(serde_json::json!({
|
||||
"rules": items,
|
||||
})))
|
||||
}
|
||||
Err(e) => Json(ApiResponse::internal_error("query alert rules", e)),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn list_records(
|
||||
State(state): State<AppState>,
|
||||
Query(params): Query<AlertRecordListParams>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let limit = params.page_size.unwrap_or(20).min(100);
|
||||
let offset = params.page.unwrap_or(1).saturating_sub(1) * limit;
|
||||
|
||||
// Normalize empty strings to None (Axum deserializes `key=` as Some(""))
|
||||
let device_uid = params.device_uid.as_deref().filter(|s| !s.is_empty()).map(String::from);
|
||||
let alert_type = params.alert_type.as_deref().filter(|s| !s.is_empty()).map(String::from);
|
||||
let severity = params.severity.as_deref().filter(|s| !s.is_empty()).map(String::from);
|
||||
let handled = params.handled;
|
||||
|
||||
let rows = sqlx::query(
|
||||
"SELECT id, rule_id, device_uid, alert_type, severity, detail, handled, handled_by, handled_at, triggered_at
|
||||
FROM alert_records WHERE 1=1
|
||||
AND (? IS NULL OR device_uid = ?)
|
||||
AND (? IS NULL OR alert_type = ?)
|
||||
AND (? IS NULL OR severity = ?)
|
||||
AND (? IS NULL OR handled = ?)
|
||||
ORDER BY triggered_at DESC LIMIT ? OFFSET ?"
|
||||
)
|
||||
.bind(&device_uid).bind(&device_uid)
|
||||
.bind(&alert_type).bind(&alert_type)
|
||||
.bind(&severity).bind(&severity)
|
||||
.bind(&handled).bind(&handled)
|
||||
.bind(limit).bind(offset)
|
||||
.fetch_all(&state.db)
|
||||
.await;
|
||||
|
||||
match rows {
|
||||
Ok(records) => {
|
||||
let items: Vec<serde_json::Value> = records.iter().map(|r| serde_json::json!({
|
||||
"id": r.get::<i64, _>("id"),
|
||||
"rule_id": r.get::<Option<i64>, _>("rule_id"),
|
||||
"device_uid": r.get::<Option<String>, _>("device_uid"),
|
||||
"alert_type": r.get::<String, _>("alert_type"),
|
||||
"severity": r.get::<String, _>("severity"),
|
||||
"detail": r.get::<String, _>("detail"),
|
||||
"handled": r.get::<i32, _>("handled"),
|
||||
"handled_by": r.get::<Option<String>, _>("handled_by"),
|
||||
"handled_at": r.get::<Option<String>, _>("handled_at"),
|
||||
"triggered_at": r.get::<String, _>("triggered_at"),
|
||||
})).collect();
|
||||
Json(ApiResponse::ok(serde_json::json!({
|
||||
"records": items,
|
||||
"page": params.page.unwrap_or(1),
|
||||
"page_size": limit,
|
||||
})))
|
||||
}
|
||||
Err(e) => Json(ApiResponse::internal_error("query alert records", e)),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct CreateRuleRequest {
|
||||
pub name: String,
|
||||
pub rule_type: String,
|
||||
pub condition: String,
|
||||
pub severity: Option<String>,
|
||||
pub notify_email: Option<String>,
|
||||
pub notify_webhook: Option<String>,
|
||||
}
|
||||
|
||||
pub async fn create_rule(
|
||||
State(state): State<AppState>,
|
||||
Json(body): Json<CreateRuleRequest>,
|
||||
) -> (StatusCode, Json<ApiResponse<serde_json::Value>>) {
|
||||
let severity = body.severity.unwrap_or_else(|| "medium".to_string());
|
||||
|
||||
let result = sqlx::query(
|
||||
"INSERT INTO alert_rules (name, rule_type, condition, severity, notify_email, notify_webhook)
|
||||
VALUES (?, ?, ?, ?, ?, ?)"
|
||||
)
|
||||
.bind(&body.name)
|
||||
.bind(&body.rule_type)
|
||||
.bind(&body.condition)
|
||||
.bind(&severity)
|
||||
.bind(&body.notify_email)
|
||||
.bind(&body.notify_webhook)
|
||||
.execute(&state.db)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(r) => (StatusCode::CREATED, Json(ApiResponse::ok(serde_json::json!({
|
||||
"id": r.last_insert_rowid(),
|
||||
})))),
|
||||
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::internal_error("create alert rule", e))),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct UpdateRuleRequest {
|
||||
pub name: Option<String>,
|
||||
pub rule_type: Option<String>,
|
||||
pub condition: Option<String>,
|
||||
pub severity: Option<String>,
|
||||
pub enabled: Option<i32>,
|
||||
pub notify_email: Option<String>,
|
||||
pub notify_webhook: Option<String>,
|
||||
}
|
||||
|
||||
pub async fn update_rule(
|
||||
State(state): State<AppState>,
|
||||
Path(id): Path<i64>,
|
||||
Json(body): Json<UpdateRuleRequest>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let existing = sqlx::query("SELECT * FROM alert_rules WHERE id = ?")
|
||||
.bind(id)
|
||||
.fetch_optional(&state.db)
|
||||
.await;
|
||||
|
||||
let existing = match existing {
|
||||
Ok(Some(row)) => row,
|
||||
Ok(None) => return Json(ApiResponse::error("Rule not found")),
|
||||
Err(e) => return Json(ApiResponse::internal_error("query alert rule", e)),
|
||||
};
|
||||
|
||||
let name = body.name.unwrap_or_else(|| existing.get::<String, _>("name"));
|
||||
let rule_type = body.rule_type.unwrap_or_else(|| existing.get::<String, _>("rule_type"));
|
||||
let condition = body.condition.unwrap_or_else(|| existing.get::<String, _>("condition"));
|
||||
let severity = body.severity.unwrap_or_else(|| existing.get::<String, _>("severity"));
|
||||
let enabled = body.enabled.unwrap_or_else(|| existing.get::<i32, _>("enabled"));
|
||||
let notify_email = body.notify_email.or_else(|| existing.get::<Option<String>, _>("notify_email"));
|
||||
let notify_webhook = body.notify_webhook.or_else(|| existing.get::<Option<String>, _>("notify_webhook"));
|
||||
|
||||
let result = sqlx::query(
|
||||
"UPDATE alert_rules SET name = ?, rule_type = ?, condition = ?, severity = ?, enabled = ?,
|
||||
notify_email = ?, notify_webhook = ?, updated_at = datetime('now') WHERE id = ?"
|
||||
)
|
||||
.bind(&name)
|
||||
.bind(&rule_type)
|
||||
.bind(&condition)
|
||||
.bind(&severity)
|
||||
.bind(enabled)
|
||||
.bind(¬ify_email)
|
||||
.bind(¬ify_webhook)
|
||||
.bind(id)
|
||||
.execute(&state.db)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(_) => Json(ApiResponse::ok(serde_json::json!({"updated": true}))),
|
||||
Err(e) => Json(ApiResponse::internal_error("update alert rule", e)),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn delete_rule(
|
||||
State(state): State<AppState>,
|
||||
Path(id): Path<i64>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let result = sqlx::query("DELETE FROM alert_rules WHERE id = ?")
|
||||
.bind(id)
|
||||
.execute(&state.db)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(r) => {
|
||||
if r.rows_affected() > 0 {
|
||||
Json(ApiResponse::ok(serde_json::json!({"deleted": true})))
|
||||
} else {
|
||||
Json(ApiResponse::error("Rule not found"))
|
||||
}
|
||||
}
|
||||
Err(e) => Json(ApiResponse::internal_error("delete alert rule", e)),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn handle_record(
|
||||
State(state): State<AppState>,
|
||||
Path(id): Path<i64>,
|
||||
claims: axum::Extension<Claims>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let handled_by = &claims.username;
|
||||
let result = sqlx::query(
|
||||
"UPDATE alert_records SET handled = 1, handled_by = ?, handled_at = datetime('now') WHERE id = ?"
|
||||
)
|
||||
.bind(handled_by)
|
||||
.bind(id)
|
||||
.execute(&state.db)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(r) => {
|
||||
if r.rows_affected() > 0 {
|
||||
Json(ApiResponse::ok(serde_json::json!({"handled": true})))
|
||||
} else {
|
||||
Json(ApiResponse::error("Alert record not found"))
|
||||
}
|
||||
}
|
||||
Err(e) => Json(ApiResponse::internal_error("handle alert record", e)),
|
||||
}
|
||||
}
|
||||
143
crates/server/src/api/assets.rs
Normal file
143
crates/server/src/api/assets.rs
Normal file
@@ -0,0 +1,143 @@
|
||||
use axum::{extract::{State, Query}, Json};
|
||||
use serde::Deserialize;
|
||||
use sqlx::Row;
|
||||
use crate::AppState;
|
||||
use super::{ApiResponse, Pagination};
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct AssetListParams {
|
||||
pub device_uid: Option<String>,
|
||||
pub search: Option<String>,
|
||||
pub page: Option<u32>,
|
||||
pub page_size: Option<u32>,
|
||||
}
|
||||
|
||||
pub async fn list_hardware(
|
||||
State(state): State<AppState>,
|
||||
Query(params): Query<AssetListParams>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let limit = params.page_size.unwrap_or(20).min(100);
|
||||
let offset = params.page.unwrap_or(1).saturating_sub(1) * limit;
|
||||
|
||||
// Normalize empty strings to None
|
||||
let device_uid = params.device_uid.as_deref().filter(|s| !s.is_empty()).map(String::from);
|
||||
let search = params.search.as_deref().filter(|s| !s.is_empty()).map(String::from);
|
||||
|
||||
let rows = sqlx::query(
|
||||
"SELECT id, device_uid, cpu_model, cpu_cores, memory_total_mb, disk_model, disk_total_mb,
|
||||
gpu_model, motherboard, serial_number, reported_at
|
||||
FROM hardware_assets WHERE 1=1
|
||||
AND (? IS NULL OR device_uid = ?)
|
||||
AND (? IS NULL OR cpu_model LIKE '%' || ? || '%' OR gpu_model LIKE '%' || ? || '%')
|
||||
ORDER BY reported_at DESC LIMIT ? OFFSET ?"
|
||||
)
|
||||
.bind(&device_uid).bind(&device_uid)
|
||||
.bind(&search).bind(&search).bind(&search)
|
||||
.bind(limit).bind(offset)
|
||||
.fetch_all(&state.db)
|
||||
.await;
|
||||
|
||||
match rows {
|
||||
Ok(records) => {
|
||||
let items: Vec<serde_json::Value> = records.iter().map(|r| serde_json::json!({
|
||||
"id": r.get::<i64, _>("id"),
|
||||
"device_uid": r.get::<String, _>("device_uid"),
|
||||
"cpu_model": r.get::<String, _>("cpu_model"),
|
||||
"cpu_cores": r.get::<i32, _>("cpu_cores"),
|
||||
"memory_total_mb": r.get::<i64, _>("memory_total_mb"),
|
||||
"disk_model": r.get::<String, _>("disk_model"),
|
||||
"disk_total_mb": r.get::<i64, _>("disk_total_mb"),
|
||||
"gpu_model": r.get::<Option<String>, _>("gpu_model"),
|
||||
"motherboard": r.get::<Option<String>, _>("motherboard"),
|
||||
"serial_number": r.get::<Option<String>, _>("serial_number"),
|
||||
"reported_at": r.get::<String, _>("reported_at"),
|
||||
})).collect();
|
||||
Json(ApiResponse::ok(serde_json::json!({
|
||||
"hardware": items,
|
||||
"page": params.page.unwrap_or(1),
|
||||
"page_size": limit,
|
||||
})))
|
||||
}
|
||||
Err(e) => Json(ApiResponse::internal_error("query hardware assets", e)),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn list_software(
|
||||
State(state): State<AppState>,
|
||||
Query(params): Query<AssetListParams>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let limit = params.page_size.unwrap_or(20).min(100);
|
||||
let offset = params.page.unwrap_or(1).saturating_sub(1) * limit;
|
||||
|
||||
// Normalize empty strings to None
|
||||
let device_uid = params.device_uid.as_deref().filter(|s| !s.is_empty()).map(String::from);
|
||||
let search = params.search.as_deref().filter(|s| !s.is_empty()).map(String::from);
|
||||
|
||||
let rows = sqlx::query(
|
||||
"SELECT id, device_uid, name, version, publisher, install_date, install_path
|
||||
FROM software_assets WHERE 1=1
|
||||
AND (? IS NULL OR device_uid = ?)
|
||||
AND (? IS NULL OR name LIKE '%' || ? || '%' OR publisher LIKE '%' || ? || '%')
|
||||
ORDER BY name ASC LIMIT ? OFFSET ?"
|
||||
)
|
||||
.bind(&device_uid).bind(&device_uid)
|
||||
.bind(&search).bind(&search).bind(&search)
|
||||
.bind(limit).bind(offset)
|
||||
.fetch_all(&state.db)
|
||||
.await;
|
||||
|
||||
match rows {
|
||||
Ok(records) => {
|
||||
let items: Vec<serde_json::Value> = records.iter().map(|r| serde_json::json!({
|
||||
"id": r.get::<i64, _>("id"),
|
||||
"device_uid": r.get::<String, _>("device_uid"),
|
||||
"name": r.get::<String, _>("name"),
|
||||
"version": r.get::<Option<String>, _>("version"),
|
||||
"publisher": r.get::<Option<String>, _>("publisher"),
|
||||
"install_date": r.get::<Option<String>, _>("install_date"),
|
||||
"install_path": r.get::<Option<String>, _>("install_path"),
|
||||
})).collect();
|
||||
Json(ApiResponse::ok(serde_json::json!({
|
||||
"software": items,
|
||||
"page": params.page.unwrap_or(1),
|
||||
"page_size": limit,
|
||||
})))
|
||||
}
|
||||
Err(e) => Json(ApiResponse::internal_error("query software assets", e)),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn list_changes(
|
||||
State(state): State<AppState>,
|
||||
Query(page): Query<Pagination>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let offset = page.offset();
|
||||
let limit = page.limit();
|
||||
|
||||
let rows = sqlx::query(
|
||||
"SELECT id, device_uid, change_type, change_detail, detected_at
|
||||
FROM asset_changes ORDER BY detected_at DESC LIMIT ? OFFSET ?"
|
||||
)
|
||||
.bind(limit)
|
||||
.bind(offset)
|
||||
.fetch_all(&state.db)
|
||||
.await;
|
||||
|
||||
match rows {
|
||||
Ok(records) => {
|
||||
let items: Vec<serde_json::Value> = records.iter().map(|r| serde_json::json!({
|
||||
"id": r.get::<i64, _>("id"),
|
||||
"device_uid": r.get::<String, _>("device_uid"),
|
||||
"change_type": r.get::<String, _>("change_type"),
|
||||
"change_detail": r.get::<String, _>("change_detail"),
|
||||
"detected_at": r.get::<String, _>("detected_at"),
|
||||
})).collect();
|
||||
Json(ApiResponse::ok(serde_json::json!({
|
||||
"changes": items,
|
||||
"page": page.page.unwrap_or(1),
|
||||
"page_size": limit,
|
||||
})))
|
||||
}
|
||||
Err(e) => Json(ApiResponse::internal_error("query asset changes", e)),
|
||||
}
|
||||
}
|
||||
295
crates/server/src/api/auth.rs
Normal file
295
crates/server/src/api/auth.rs
Normal file
@@ -0,0 +1,295 @@
|
||||
use axum::{extract::State, Json, http::StatusCode, extract::Request, middleware::Next, response::Response};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use jsonwebtoken::{encode, decode, Header, EncodingKey, DecodingKey, Validation};
|
||||
use std::sync::Arc;
|
||||
use std::collections::HashMap;
|
||||
use std::time::Instant;
|
||||
use tokio::sync::Mutex;
|
||||
use crate::AppState;
|
||||
use super::ApiResponse;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct Claims {
|
||||
pub sub: i64,
|
||||
pub username: String,
|
||||
pub role: String,
|
||||
pub exp: u64,
|
||||
pub iat: u64,
|
||||
pub token_type: String,
|
||||
/// Random family ID for refresh token rotation detection
|
||||
pub family: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct LoginRequest {
|
||||
pub username: String,
|
||||
pub password: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct LoginResponse {
|
||||
pub access_token: String,
|
||||
pub refresh_token: String,
|
||||
pub user: UserInfo,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, sqlx::FromRow)]
|
||||
pub struct UserInfo {
|
||||
pub id: i64,
|
||||
pub username: String,
|
||||
pub role: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct RefreshRequest {
|
||||
pub refresh_token: String,
|
||||
}
|
||||
|
||||
/// In-memory rate limiter for login attempts
|
||||
#[derive(Clone, Default)]
|
||||
pub struct LoginRateLimiter {
|
||||
attempts: Arc<Mutex<HashMap<String, (Instant, u32)>>>,
|
||||
}
|
||||
|
||||
impl LoginRateLimiter {
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
/// Returns true if the request should be rate-limited
|
||||
pub async fn is_limited(&self, key: &str) -> bool {
|
||||
let mut attempts = self.attempts.lock().await;
|
||||
let now = Instant::now();
|
||||
let window = std::time::Duration::from_secs(300); // 5-minute window
|
||||
let max_attempts = 10u32;
|
||||
|
||||
if let Some((first_attempt, count)) = attempts.get_mut(key) {
|
||||
if now.duration_since(*first_attempt) > window {
|
||||
// Window expired, reset
|
||||
*first_attempt = now;
|
||||
*count = 1;
|
||||
false
|
||||
} else if *count >= max_attempts {
|
||||
true // Rate limited
|
||||
} else {
|
||||
*count += 1;
|
||||
false
|
||||
}
|
||||
} else {
|
||||
attempts.insert(key.to_string(), (now, 1));
|
||||
// Cleanup old entries periodically
|
||||
if attempts.len() > 1000 {
|
||||
let cutoff = now - window;
|
||||
attempts.retain(|_, (t, _)| *t > cutoff);
|
||||
}
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn login(
|
||||
State(state): State<AppState>,
|
||||
Json(req): Json<LoginRequest>,
|
||||
) -> Result<(StatusCode, Json<ApiResponse<LoginResponse>>), StatusCode> {
|
||||
// Rate limit check
|
||||
if state.login_limiter.is_limited(&req.username).await {
|
||||
return Ok((StatusCode::TOO_MANY_REQUESTS, Json(ApiResponse::error("Too many login attempts. Try again later."))));
|
||||
}
|
||||
|
||||
let user: Option<UserInfo> = sqlx::query_as::<_, UserInfo>(
|
||||
"SELECT id, username, role FROM users WHERE username = ?"
|
||||
)
|
||||
.bind(&req.username)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
let user = match user {
|
||||
Some(u) => u,
|
||||
None => return Ok((StatusCode::UNAUTHORIZED, Json(ApiResponse::error("Invalid credentials")))),
|
||||
};
|
||||
|
||||
let hash: String = sqlx::query_scalar::<_, String>(
|
||||
"SELECT password FROM users WHERE id = ?"
|
||||
)
|
||||
.bind(user.id)
|
||||
.fetch_one(&state.db)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
if !bcrypt::verify(&req.password, &hash).unwrap_or(false) {
|
||||
return Ok((StatusCode::UNAUTHORIZED, Json(ApiResponse::error("Invalid credentials"))));
|
||||
}
|
||||
|
||||
let now = chrono::Utc::now().timestamp() as u64;
|
||||
let family = uuid::Uuid::new_v4().to_string();
|
||||
let access_token = create_token(&user, "access", state.config.auth.access_token_ttl_secs, now, &state.config.auth.jwt_secret, &family)?;
|
||||
let refresh_token = create_token(&user, "refresh", state.config.auth.refresh_token_ttl_secs, now, &state.config.auth.jwt_secret, &family)?;
|
||||
|
||||
// Audit log
|
||||
let _ = sqlx::query(
|
||||
"INSERT INTO admin_audit_log (user_id, action, detail) VALUES (?, 'login', ?)"
|
||||
)
|
||||
.bind(user.id)
|
||||
.bind(format!("User {} logged in", user.username))
|
||||
.execute(&state.db)
|
||||
.await;
|
||||
|
||||
Ok((StatusCode::OK, Json(ApiResponse::ok(LoginResponse {
|
||||
access_token,
|
||||
refresh_token,
|
||||
user,
|
||||
}))))
|
||||
}
|
||||
|
||||
pub async fn refresh(
|
||||
State(state): State<AppState>,
|
||||
Json(req): Json<RefreshRequest>,
|
||||
) -> Result<(StatusCode, Json<ApiResponse<LoginResponse>>), StatusCode> {
|
||||
let claims = decode::<Claims>(
|
||||
&req.refresh_token,
|
||||
&DecodingKey::from_secret(state.config.auth.jwt_secret.as_bytes()),
|
||||
&Validation::default(),
|
||||
)
|
||||
.map_err(|_| StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
if claims.claims.token_type != "refresh" {
|
||||
return Ok((StatusCode::UNAUTHORIZED, Json(ApiResponse::error("Invalid token type"))));
|
||||
}
|
||||
|
||||
// Check if this refresh token family has been revoked (reuse detection)
|
||||
let revoked: bool = sqlx::query_scalar::<_, i64>(
|
||||
"SELECT COUNT(*) FROM revoked_token_families WHERE family = ?"
|
||||
)
|
||||
.bind(&claims.claims.family)
|
||||
.fetch_one(&state.db)
|
||||
.await
|
||||
.unwrap_or(0) > 0;
|
||||
|
||||
if revoked {
|
||||
// Token reuse detected — revoke entire family and force re-login
|
||||
tracing::warn!("Refresh token reuse detected for user {} family {}", claims.claims.sub, claims.claims.family);
|
||||
let _ = sqlx::query("DELETE FROM refresh_tokens WHERE user_id = ?")
|
||||
.bind(claims.claims.sub)
|
||||
.execute(&state.db)
|
||||
.await;
|
||||
return Ok((StatusCode::UNAUTHORIZED, Json(ApiResponse::error("Token reuse detected. Please log in again."))));
|
||||
}
|
||||
|
||||
let user = UserInfo {
|
||||
id: claims.claims.sub,
|
||||
username: claims.claims.username,
|
||||
role: claims.claims.role,
|
||||
};
|
||||
|
||||
// Rotate: new family for each refresh
|
||||
let new_family = uuid::Uuid::new_v4().to_string();
|
||||
let now = chrono::Utc::now().timestamp() as u64;
|
||||
let access_token = create_token(&user, "access", state.config.auth.access_token_ttl_secs, now, &state.config.auth.jwt_secret, &new_family)?;
|
||||
let refresh_token = create_token(&user, "refresh", state.config.auth.refresh_token_ttl_secs, now, &state.config.auth.jwt_secret, &new_family)?;
|
||||
|
||||
// Revoke old family
|
||||
let _ = sqlx::query("INSERT OR IGNORE INTO revoked_token_families (family, user_id, revoked_at) VALUES (?, ?, datetime('now'))")
|
||||
.bind(&claims.claims.family)
|
||||
.bind(claims.claims.sub)
|
||||
.execute(&state.db)
|
||||
.await;
|
||||
|
||||
Ok((StatusCode::OK, Json(ApiResponse::ok(LoginResponse {
|
||||
access_token,
|
||||
refresh_token,
|
||||
user,
|
||||
}))))
|
||||
}
|
||||
|
||||
fn create_token(user: &UserInfo, token_type: &str, ttl: u64, now: u64, secret: &str, family: &str) -> Result<String, StatusCode> {
|
||||
let claims = Claims {
|
||||
sub: user.id,
|
||||
username: user.username.clone(),
|
||||
role: user.role.clone(),
|
||||
exp: now + ttl,
|
||||
iat: now,
|
||||
token_type: token_type.to_string(),
|
||||
family: family.to_string(),
|
||||
};
|
||||
|
||||
encode(
|
||||
&Header::default(),
|
||||
&claims,
|
||||
&EncodingKey::from_secret(secret.as_bytes()),
|
||||
)
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
|
||||
}
|
||||
|
||||
/// Axum middleware: require valid JWT access token
|
||||
pub async fn require_auth(
|
||||
State(state): State<AppState>,
|
||||
mut request: Request,
|
||||
next: Next,
|
||||
) -> Result<Response, StatusCode> {
|
||||
let auth_header = request.headers()
|
||||
.get("Authorization")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.and_then(|v| v.strip_prefix("Bearer "));
|
||||
|
||||
let token = match auth_header {
|
||||
Some(t) => t,
|
||||
None => return Err(StatusCode::UNAUTHORIZED),
|
||||
};
|
||||
|
||||
let claims = decode::<Claims>(
|
||||
token,
|
||||
&DecodingKey::from_secret(state.config.auth.jwt_secret.as_bytes()),
|
||||
&Validation::default(),
|
||||
)
|
||||
.map_err(|_| StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
if claims.claims.token_type != "access" {
|
||||
return Err(StatusCode::UNAUTHORIZED);
|
||||
}
|
||||
|
||||
// Inject claims into request extensions for handlers to use
|
||||
request.extensions_mut().insert(claims.claims);
|
||||
|
||||
Ok(next.run(request).await)
|
||||
}
|
||||
|
||||
/// Axum middleware: require admin role for write operations + audit log
|
||||
pub async fn require_admin(
|
||||
State(state): State<AppState>,
|
||||
request: Request,
|
||||
next: Next,
|
||||
) -> Result<Response, StatusCode> {
|
||||
let claims = request.extensions()
|
||||
.get::<Claims>()
|
||||
.ok_or(StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
if claims.role != "admin" {
|
||||
return Err(StatusCode::FORBIDDEN);
|
||||
}
|
||||
|
||||
// Capture audit info before running handler
|
||||
let method = request.method().clone();
|
||||
let path = request.uri().path().to_string();
|
||||
let user_id = claims.sub;
|
||||
let username = claims.username.clone();
|
||||
|
||||
let response = next.run(request).await;
|
||||
|
||||
// Record admin action to audit log (fire and forget — don't block response)
|
||||
let status = response.status();
|
||||
if status.is_success() {
|
||||
let action = format!("{} {}", method, path);
|
||||
let detail = format!("by {}", username);
|
||||
let _ = sqlx::query(
|
||||
"INSERT INTO admin_audit_log (user_id, action, detail) VALUES (?, ?, ?)"
|
||||
)
|
||||
.bind(user_id)
|
||||
.bind(&action)
|
||||
.bind(&detail)
|
||||
.execute(&state.db)
|
||||
.await;
|
||||
}
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
263
crates/server/src/api/devices.rs
Normal file
263
crates/server/src/api/devices.rs
Normal file
@@ -0,0 +1,263 @@
|
||||
use axum::{extract::{State, Path, Query}, Json};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::Row;
|
||||
use crate::AppState;
|
||||
use super::{ApiResponse, Pagination};
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct DeviceListParams {
|
||||
pub status: Option<String>,
|
||||
pub group: Option<String>,
|
||||
pub search: Option<String>,
|
||||
pub page: Option<u32>,
|
||||
pub page_size: Option<u32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, sqlx::FromRow)]
|
||||
pub struct DeviceRow {
|
||||
pub id: i64,
|
||||
pub device_uid: String,
|
||||
pub hostname: String,
|
||||
pub ip_address: String,
|
||||
pub mac_address: Option<String>,
|
||||
pub os_version: Option<String>,
|
||||
pub client_version: Option<String>,
|
||||
pub status: String,
|
||||
pub last_heartbeat: Option<String>,
|
||||
pub registered_at: Option<String>,
|
||||
pub group_name: Option<String>,
|
||||
}
|
||||
|
||||
pub async fn list(
|
||||
State(state): State<AppState>,
|
||||
Query(params): Query<DeviceListParams>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let limit = params.page_size.unwrap_or(20).min(100);
|
||||
let offset = params.page.unwrap_or(1).saturating_sub(1) * limit;
|
||||
|
||||
// Normalize empty strings to None (Axum deserializes `status=` as Some(""))
|
||||
let status = params.status.as_deref().filter(|s| !s.is_empty()).map(String::from);
|
||||
let group = params.group.as_deref().filter(|s| !s.is_empty()).map(String::from);
|
||||
let search = params.search.as_deref().filter(|s| !s.is_empty()).map(String::from);
|
||||
|
||||
let devices = sqlx::query_as::<_, DeviceRow>(
|
||||
"SELECT id, device_uid, hostname, ip_address, mac_address, os_version, client_version,
|
||||
status, last_heartbeat, registered_at, group_name
|
||||
FROM devices WHERE 1=1
|
||||
AND (? IS NULL OR status = ?)
|
||||
AND (? IS NULL OR group_name = ?)
|
||||
AND (? IS NULL OR hostname LIKE '%' || ? || '%' OR ip_address LIKE '%' || ? || '%')
|
||||
ORDER BY registered_at DESC LIMIT ? OFFSET ?"
|
||||
)
|
||||
.bind(&status).bind(&status)
|
||||
.bind(&group).bind(&group)
|
||||
.bind(&search).bind(&search).bind(&search)
|
||||
.bind(limit).bind(offset)
|
||||
.fetch_all(&state.db)
|
||||
.await;
|
||||
|
||||
let total: i64 = sqlx::query_scalar(
|
||||
"SELECT COUNT(*) FROM devices WHERE 1=1
|
||||
AND (? IS NULL OR status = ?)
|
||||
AND (? IS NULL OR group_name = ?)
|
||||
AND (? IS NULL OR hostname LIKE '%' || ? || '%' OR ip_address LIKE '%' || ? || '%')"
|
||||
)
|
||||
.bind(&status).bind(&status)
|
||||
.bind(&group).bind(&group)
|
||||
.bind(&search).bind(&search).bind(&search)
|
||||
.fetch_one(&state.db)
|
||||
.await
|
||||
.unwrap_or(0);
|
||||
|
||||
match devices {
|
||||
Ok(rows) => Json(ApiResponse::ok(serde_json::json!({
|
||||
"devices": rows,
|
||||
"total": total,
|
||||
"page": params.page.unwrap_or(1),
|
||||
"page_size": limit,
|
||||
}))),
|
||||
Err(e) => Json(ApiResponse::internal_error("query devices", e)),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_detail(
|
||||
State(state): State<AppState>,
|
||||
Path(uid): Path<String>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let device = sqlx::query_as::<_, DeviceRow>(
|
||||
"SELECT id, device_uid, hostname, ip_address, mac_address, os_version, client_version,
|
||||
status, last_heartbeat, registered_at, group_name
|
||||
FROM devices WHERE device_uid = ?"
|
||||
)
|
||||
.bind(&uid)
|
||||
.fetch_optional(&state.db)
|
||||
.await;
|
||||
|
||||
match device {
|
||||
Ok(Some(d)) => Json(ApiResponse::ok(serde_json::to_value(d).unwrap_or_default())),
|
||||
Ok(None) => Json(ApiResponse::error("Device not found")),
|
||||
Err(e) => Json(ApiResponse::internal_error("query devices", e)),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, sqlx::FromRow)]
|
||||
struct StatusRow {
|
||||
pub cpu_usage: f64,
|
||||
pub memory_usage: f64,
|
||||
pub memory_total_mb: i64,
|
||||
pub disk_usage: f64,
|
||||
pub disk_total_mb: i64,
|
||||
pub network_rx_rate: i64,
|
||||
pub network_tx_rate: i64,
|
||||
pub running_procs: i32,
|
||||
pub top_processes: Option<String>,
|
||||
pub reported_at: String,
|
||||
}
|
||||
|
||||
pub async fn get_status(
|
||||
State(state): State<AppState>,
|
||||
Path(uid): Path<String>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let status = sqlx::query_as::<_, StatusRow>(
|
||||
"SELECT cpu_usage, memory_usage, memory_total_mb, disk_usage, disk_total_mb,
|
||||
network_rx_rate, network_tx_rate, running_procs, top_processes, reported_at
|
||||
FROM device_status WHERE device_uid = ?"
|
||||
)
|
||||
.bind(&uid)
|
||||
.fetch_optional(&state.db)
|
||||
.await;
|
||||
|
||||
match status {
|
||||
Ok(Some(s)) => {
|
||||
let mut val = serde_json::to_value(&s).unwrap_or_default();
|
||||
// Parse top_processes JSON string back to array
|
||||
if let Some(tp_str) = &s.top_processes {
|
||||
if let Ok(tp) = serde_json::from_str::<serde_json::Value>(tp_str) {
|
||||
val["top_processes"] = tp;
|
||||
}
|
||||
}
|
||||
Json(ApiResponse::ok(val))
|
||||
}
|
||||
Ok(None) => Json(ApiResponse::error("No status data found")),
|
||||
Err(e) => Json(ApiResponse::internal_error("query devices", e)),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_history(
|
||||
State(state): State<AppState>,
|
||||
Path(uid): Path<String>,
|
||||
Query(page): Query<Pagination>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let offset = page.offset();
|
||||
let limit = page.limit();
|
||||
|
||||
let rows = sqlx::query(
|
||||
"SELECT cpu_usage, memory_usage, disk_usage, running_procs, reported_at
|
||||
FROM device_status_history WHERE device_uid = ?
|
||||
ORDER BY reported_at DESC LIMIT ? OFFSET ?"
|
||||
)
|
||||
.bind(&uid)
|
||||
.bind(limit)
|
||||
.bind(offset)
|
||||
.fetch_all(&state.db)
|
||||
.await;
|
||||
|
||||
match rows {
|
||||
Ok(records) => {
|
||||
let items: Vec<serde_json::Value> = records.iter().map(|r| {
|
||||
serde_json::json!({
|
||||
"cpu_usage": r.get::<f64, _>("cpu_usage"),
|
||||
"memory_usage": r.get::<f64, _>("memory_usage"),
|
||||
"disk_usage": r.get::<f64, _>("disk_usage"),
|
||||
"running_procs": r.get::<i32, _>("running_procs"),
|
||||
"reported_at": r.get::<String, _>("reported_at"),
|
||||
})
|
||||
}).collect();
|
||||
Json(ApiResponse::ok(serde_json::json!({
|
||||
"history": items,
|
||||
"page": page.page.unwrap_or(1),
|
||||
"page_size": limit,
|
||||
})))
|
||||
}
|
||||
Err(e) => Json(ApiResponse::internal_error("query devices", e)),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn remove(
|
||||
State(state): State<AppState>,
|
||||
Path(uid): Path<String>,
|
||||
) -> Json<ApiResponse<()>> {
|
||||
// If client is connected, send self-destruct command
|
||||
let frame = csm_protocol::Frame::new_json(
|
||||
csm_protocol::MessageType::ConfigUpdate,
|
||||
&serde_json::json!({"type": "SelfDestruct"}),
|
||||
).ok();
|
||||
|
||||
if let Some(frame) = frame {
|
||||
state.clients.send_to(&uid, frame.encode()).await;
|
||||
}
|
||||
|
||||
// Delete device and all associated data in a transaction
|
||||
let mut tx = match state.db.begin().await {
|
||||
Ok(tx) => tx,
|
||||
Err(e) => return Json(ApiResponse::internal_error("begin transaction", e)),
|
||||
};
|
||||
|
||||
// Delete status history
|
||||
if let Err(e) = sqlx::query("DELETE FROM device_status_history WHERE device_uid = ?")
|
||||
.bind(&uid)
|
||||
.execute(&mut *tx)
|
||||
.await
|
||||
{
|
||||
return Json(ApiResponse::internal_error("remove device history", e));
|
||||
}
|
||||
|
||||
// Delete current status
|
||||
if let Err(e) = sqlx::query("DELETE FROM device_status WHERE device_uid = ?")
|
||||
.bind(&uid)
|
||||
.execute(&mut *tx)
|
||||
.await
|
||||
{
|
||||
return Json(ApiResponse::internal_error("remove device status", e));
|
||||
}
|
||||
|
||||
// Delete plugin-related data
|
||||
let cleanup_tables = [
|
||||
"hardware_assets",
|
||||
"usb_events",
|
||||
"usb_file_operations",
|
||||
"usage_daily",
|
||||
"app_usage_daily",
|
||||
"software_violations",
|
||||
"web_access_log",
|
||||
"popup_block_stats",
|
||||
];
|
||||
for table in &cleanup_tables {
|
||||
if let Err(e) = sqlx::query(&format!("DELETE FROM {} WHERE device_uid = ?", table))
|
||||
.bind(&uid)
|
||||
.execute(&mut *tx)
|
||||
.await
|
||||
{
|
||||
tracing::warn!("Failed to clean {} for device {}: {}", table, uid, e);
|
||||
}
|
||||
}
|
||||
|
||||
// Finally delete the device itself
|
||||
let delete_result = sqlx::query("DELETE FROM devices WHERE device_uid = ?")
|
||||
.bind(&uid)
|
||||
.execute(&mut *tx)
|
||||
.await;
|
||||
|
||||
match delete_result {
|
||||
Ok(r) if r.rows_affected() > 0 => {
|
||||
if let Err(e) = tx.commit().await {
|
||||
return Json(ApiResponse::internal_error("commit device deletion", e));
|
||||
}
|
||||
state.clients.unregister(&uid).await;
|
||||
tracing::info!(device_uid = %uid, "Device and all associated data deleted");
|
||||
Json(ApiResponse::ok(()))
|
||||
}
|
||||
Ok(_) => Json(ApiResponse::error("Device not found")),
|
||||
Err(e) => Json(ApiResponse::internal_error("remove device", e)),
|
||||
}
|
||||
}
|
||||
120
crates/server/src/api/mod.rs
Normal file
120
crates/server/src/api/mod.rs
Normal file
@@ -0,0 +1,120 @@
|
||||
use axum::{routing::{get, post, put, delete}, Router, Json, extract::State, middleware};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use crate::AppState;
|
||||
|
||||
pub mod auth;
|
||||
pub mod devices;
|
||||
pub mod assets;
|
||||
pub mod usb;
|
||||
pub mod alerts;
|
||||
pub mod plugins;
|
||||
|
||||
pub fn routes(state: AppState) -> Router<AppState> {
|
||||
let public = Router::new()
|
||||
.route("/api/auth/login", post(auth::login))
|
||||
.route("/api/auth/refresh", post(auth::refresh))
|
||||
.route("/health", get(health_check))
|
||||
.with_state(state.clone());
|
||||
|
||||
// Read-only routes (any authenticated user)
|
||||
let read_routes = Router::new()
|
||||
// Devices
|
||||
.route("/api/devices", get(devices::list))
|
||||
.route("/api/devices/:uid", get(devices::get_detail))
|
||||
.route("/api/devices/:uid/status", get(devices::get_status))
|
||||
.route("/api/devices/:uid/history", get(devices::get_history))
|
||||
// Assets
|
||||
.route("/api/assets/hardware", get(assets::list_hardware))
|
||||
.route("/api/assets/software", get(assets::list_software))
|
||||
.route("/api/assets/changes", get(assets::list_changes))
|
||||
// USB (read)
|
||||
.route("/api/usb/events", get(usb::list_events))
|
||||
.route("/api/usb/policies", get(usb::list_policies))
|
||||
// Alerts (read)
|
||||
.route("/api/alerts/rules", get(alerts::list_rules))
|
||||
.route("/api/alerts/records", get(alerts::list_records))
|
||||
// Plugin read routes
|
||||
.merge(plugins::read_routes())
|
||||
.layer(middleware::from_fn_with_state(state.clone(), auth::require_auth));
|
||||
|
||||
// Write routes (admin only)
|
||||
let write_routes = Router::new()
|
||||
// Devices
|
||||
.route("/api/devices/:uid", delete(devices::remove))
|
||||
// USB (write)
|
||||
.route("/api/usb/policies", post(usb::create_policy))
|
||||
.route("/api/usb/policies/:id", put(usb::update_policy).delete(usb::delete_policy))
|
||||
// Alerts (write)
|
||||
.route("/api/alerts/rules", post(alerts::create_rule))
|
||||
.route("/api/alerts/rules/:id", put(alerts::update_rule).delete(alerts::delete_rule))
|
||||
.route("/api/alerts/records/:id/handle", put(alerts::handle_record))
|
||||
// Plugin write routes (already has require_admin layer internally)
|
||||
.merge(plugins::write_routes())
|
||||
// Layer order: outer (require_admin) runs AFTER inner (require_auth)
|
||||
// so require_auth sets Claims extension first, then require_admin checks it
|
||||
.layer(middleware::from_fn_with_state(state.clone(), auth::require_admin))
|
||||
.layer(middleware::from_fn_with_state(state.clone(), auth::require_auth));
|
||||
|
||||
// WebSocket has its own JWT auth via query parameter
|
||||
let ws_router = Router::new()
|
||||
.route("/ws", get(crate::ws::ws_handler))
|
||||
.with_state(state.clone());
|
||||
|
||||
Router::new()
|
||||
.merge(public)
|
||||
.merge(read_routes)
|
||||
.merge(write_routes)
|
||||
.merge(ws_router)
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct HealthResponse {
|
||||
status: &'static str,
|
||||
}
|
||||
|
||||
async fn health_check() -> Json<HealthResponse> {
|
||||
Json(HealthResponse {
|
||||
status: "ok",
|
||||
})
|
||||
}
|
||||
|
||||
/// Standard API response envelope
|
||||
#[derive(Serialize)]
|
||||
pub struct ApiResponse<T: Serialize> {
|
||||
pub success: bool,
|
||||
pub data: Option<T>,
|
||||
pub error: Option<String>,
|
||||
}
|
||||
|
||||
impl<T: Serialize> ApiResponse<T> {
|
||||
pub fn ok(data: T) -> Self {
|
||||
Self { success: true, data: Some(data), error: None }
|
||||
}
|
||||
|
||||
pub fn error(msg: impl Into<String>) -> Self {
|
||||
Self { success: false, data: None, error: Some(msg.into()) }
|
||||
}
|
||||
|
||||
/// Log internal error and return sanitized message to client
|
||||
pub fn internal_error(context: &str, e: impl std::fmt::Display) -> Self {
|
||||
tracing::error!("{}: {}", context, e);
|
||||
Self { success: false, data: None, error: Some("Internal server error".to_string()) }
|
||||
}
|
||||
}
|
||||
|
||||
/// Pagination parameters
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct Pagination {
|
||||
pub page: Option<u32>,
|
||||
pub page_size: Option<u32>,
|
||||
}
|
||||
|
||||
impl Pagination {
|
||||
pub fn offset(&self) -> u32 {
|
||||
self.page.unwrap_or(1).saturating_sub(1) * self.limit()
|
||||
}
|
||||
|
||||
pub fn limit(&self) -> u32 {
|
||||
self.page_size.unwrap_or(20).min(100)
|
||||
}
|
||||
}
|
||||
49
crates/server/src/api/plugins/mod.rs
Normal file
49
crates/server/src/api/plugins/mod.rs
Normal file
@@ -0,0 +1,49 @@
|
||||
pub mod web_filter;
|
||||
pub mod usage_timer;
|
||||
pub mod software_blocker;
|
||||
pub mod popup_blocker;
|
||||
pub mod usb_file_audit;
|
||||
pub mod watermark;
|
||||
|
||||
use axum::{Router, routing::{get, post, put}};
|
||||
use crate::AppState;
|
||||
|
||||
/// Read-only plugin routes (accessible by admin + viewer)
|
||||
pub fn read_routes() -> Router<AppState> {
|
||||
Router::new()
|
||||
// Web Filter
|
||||
.route("/api/plugins/web-filter/rules", get(web_filter::list_rules))
|
||||
.route("/api/plugins/web-filter/log", get(web_filter::list_access_log))
|
||||
// Usage Timer
|
||||
.route("/api/plugins/usage-timer/daily", get(usage_timer::list_daily))
|
||||
.route("/api/plugins/usage-timer/app-usage", get(usage_timer::list_app_usage))
|
||||
.route("/api/plugins/usage-timer/leaderboard", get(usage_timer::leaderboard))
|
||||
// Software Blocker
|
||||
.route("/api/plugins/software-blocker/blacklist", get(software_blocker::list_blacklist))
|
||||
.route("/api/plugins/software-blocker/violations", get(software_blocker::list_violations))
|
||||
// Popup Blocker
|
||||
.route("/api/plugins/popup-blocker/rules", get(popup_blocker::list_rules))
|
||||
.route("/api/plugins/popup-blocker/stats", get(popup_blocker::list_stats))
|
||||
// USB File Audit
|
||||
.route("/api/plugins/usb-file-audit/log", get(usb_file_audit::list_operations))
|
||||
.route("/api/plugins/usb-file-audit/summary", get(usb_file_audit::summary))
|
||||
// Watermark
|
||||
.route("/api/plugins/watermark/config", get(watermark::get_config_list))
|
||||
}
|
||||
|
||||
/// Write plugin routes (admin only — require_admin middleware applied by caller)
|
||||
pub fn write_routes() -> Router<AppState> {
|
||||
Router::new()
|
||||
// Web Filter
|
||||
.route("/api/plugins/web-filter/rules", post(web_filter::create_rule))
|
||||
.route("/api/plugins/web-filter/rules/:id", put(web_filter::update_rule).delete(web_filter::delete_rule))
|
||||
// Software Blocker
|
||||
.route("/api/plugins/software-blocker/blacklist", post(software_blocker::add_to_blacklist))
|
||||
.route("/api/plugins/software-blocker/blacklist/:id", put(software_blocker::update_blacklist).delete(software_blocker::remove_from_blacklist))
|
||||
// Popup Blocker
|
||||
.route("/api/plugins/popup-blocker/rules", post(popup_blocker::create_rule))
|
||||
.route("/api/plugins/popup-blocker/rules/:id", put(popup_blocker::update_rule).delete(popup_blocker::delete_rule))
|
||||
// Watermark
|
||||
.route("/api/plugins/watermark/config", post(watermark::create_config))
|
||||
.route("/api/plugins/watermark/config/:id", put(watermark::update_config).delete(watermark::delete_config))
|
||||
}
|
||||
155
crates/server/src/api/plugins/popup_blocker.rs
Normal file
155
crates/server/src/api/plugins/popup_blocker.rs
Normal file
@@ -0,0 +1,155 @@
|
||||
use axum::{extract::{State, Path, Json}, http::StatusCode};
|
||||
use serde::Deserialize;
|
||||
use sqlx::Row;
|
||||
use csm_protocol::MessageType;
|
||||
use crate::AppState;
|
||||
use crate::api::ApiResponse;
|
||||
use crate::tcp::push_to_targets;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct CreateRuleRequest {
|
||||
pub rule_type: String, // "block" | "allow"
|
||||
pub window_title: Option<String>,
|
||||
pub window_class: Option<String>,
|
||||
pub process_name: Option<String>,
|
||||
pub target_type: Option<String>,
|
||||
pub target_id: Option<String>,
|
||||
}
|
||||
|
||||
pub async fn list_rules(State(state): State<AppState>) -> Json<ApiResponse<serde_json::Value>> {
|
||||
match sqlx::query("SELECT id, rule_type, window_title, window_class, process_name, target_type, target_id, enabled, created_at FROM popup_filter_rules ORDER BY created_at DESC")
|
||||
.fetch_all(&state.db).await {
|
||||
Ok(rows) => Json(ApiResponse::ok(serde_json::json!({"rules": rows.iter().map(|r| serde_json::json!({
|
||||
"id": r.get::<i64,_>("id"), "rule_type": r.get::<String,_>("rule_type"),
|
||||
"window_title": r.get::<Option<String>,_>("window_title"),
|
||||
"window_class": r.get::<Option<String>,_>("window_class"),
|
||||
"process_name": r.get::<Option<String>,_>("process_name"),
|
||||
"target_type": r.get::<String,_>("target_type"), "target_id": r.get::<Option<String>,_>("target_id"),
|
||||
"enabled": r.get::<bool,_>("enabled"), "created_at": r.get::<String,_>("created_at")
|
||||
})).collect::<Vec<_>>()}))),
|
||||
Err(e) => Json(ApiResponse::internal_error("query popup filter rules", e)),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn create_rule(State(state): State<AppState>, Json(req): Json<CreateRuleRequest>) -> (StatusCode, Json<ApiResponse<serde_json::Value>>) {
|
||||
let target_type = req.target_type.unwrap_or_else(|| "global".to_string());
|
||||
|
||||
// Validate inputs
|
||||
if !matches!(req.rule_type.as_str(), "block" | "allow") {
|
||||
return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("rule_type must be 'block' or 'allow'")));
|
||||
}
|
||||
if !matches!(target_type.as_str(), "global" | "device" | "group") {
|
||||
return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("invalid target_type")));
|
||||
}
|
||||
let has_filter = req.window_title.as_ref().map_or(false, |s| !s.is_empty())
|
||||
|| req.window_class.as_ref().map_or(false, |s| !s.is_empty())
|
||||
|| req.process_name.as_ref().map_or(false, |s| !s.is_empty());
|
||||
if !has_filter {
|
||||
return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("at least one filter (window_title/window_class/process_name) required")));
|
||||
}
|
||||
|
||||
match sqlx::query("INSERT INTO popup_filter_rules (rule_type, window_title, window_class, process_name, target_type, target_id) VALUES (?,?,?,?,?,?)")
|
||||
.bind(&req.rule_type).bind(&req.window_title).bind(&req.window_class).bind(&req.process_name).bind(&target_type).bind(&req.target_id)
|
||||
.execute(&state.db).await {
|
||||
Ok(r) => {
|
||||
let new_id = r.last_insert_rowid();
|
||||
let rules = fetch_popup_rules_for_push(&state.db, &target_type, req.target_id.as_deref()).await;
|
||||
push_to_targets(&state.db, &state.clients, MessageType::PopupRules, &serde_json::json!({"rules": rules}), &target_type, req.target_id.as_deref()).await;
|
||||
(StatusCode::CREATED, Json(ApiResponse::ok(serde_json::json!({"id": new_id}))))
|
||||
}
|
||||
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::internal_error("create popup filter rule", e))),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct UpdateRuleRequest { pub window_title: Option<String>, pub window_class: Option<String>, pub process_name: Option<String>, pub enabled: Option<bool> }
|
||||
|
||||
pub async fn update_rule(State(state): State<AppState>, Path(id): Path<i64>, Json(body): Json<UpdateRuleRequest>) -> Json<ApiResponse<()>> {
|
||||
let existing = sqlx::query("SELECT * FROM popup_filter_rules WHERE id = ?")
|
||||
.bind(id)
|
||||
.fetch_optional(&state.db)
|
||||
.await;
|
||||
|
||||
let existing = match existing {
|
||||
Ok(Some(row)) => row,
|
||||
Ok(None) => return Json(ApiResponse::error("Not found")),
|
||||
Err(e) => return Json(ApiResponse::internal_error("query popup filter rule", e)),
|
||||
};
|
||||
|
||||
let window_title = body.window_title.or_else(|| existing.get::<Option<String>, _>("window_title"));
|
||||
let window_class = body.window_class.or_else(|| existing.get::<Option<String>, _>("window_class"));
|
||||
let process_name = body.process_name.or_else(|| existing.get::<Option<String>, _>("process_name"));
|
||||
let enabled = body.enabled.unwrap_or_else(|| existing.get::<bool, _>("enabled"));
|
||||
|
||||
let result = sqlx::query("UPDATE popup_filter_rules SET window_title = ?, window_class = ?, process_name = ?, enabled = ? WHERE id = ?")
|
||||
.bind(&window_title)
|
||||
.bind(&window_class)
|
||||
.bind(&process_name)
|
||||
.bind(enabled)
|
||||
.bind(id)
|
||||
.execute(&state.db)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(r) if r.rows_affected() > 0 => {
|
||||
let target_type_val: String = existing.get("target_type");
|
||||
let target_id_val: Option<String> = existing.get("target_id");
|
||||
let rules = fetch_popup_rules_for_push(&state.db, &target_type_val, target_id_val.as_deref()).await;
|
||||
push_to_targets(&state.db, &state.clients, MessageType::PopupRules, &serde_json::json!({"rules": rules}), &target_type_val, target_id_val.as_deref()).await;
|
||||
Json(ApiResponse::ok(()))
|
||||
}
|
||||
Ok(_) => Json(ApiResponse::error("Not found")),
|
||||
Err(e) => Json(ApiResponse::internal_error("update popup filter rule", e)),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn delete_rule(State(state): State<AppState>, Path(id): Path<i64>) -> Json<ApiResponse<()>> {
|
||||
let existing = sqlx::query("SELECT target_type, target_id FROM popup_filter_rules WHERE id = ?")
|
||||
.bind(id).fetch_optional(&state.db).await;
|
||||
let (target_type, target_id) = match existing {
|
||||
Ok(Some(row)) => (row.get::<String, _>("target_type"), row.get::<Option<String>, _>("target_id")),
|
||||
_ => return Json(ApiResponse::error("Not found")),
|
||||
};
|
||||
match sqlx::query("DELETE FROM popup_filter_rules WHERE id=?").bind(id).execute(&state.db).await {
|
||||
Ok(r) if r.rows_affected() > 0 => {
|
||||
let rules = fetch_popup_rules_for_push(&state.db, &target_type, target_id.as_deref()).await;
|
||||
push_to_targets(&state.db, &state.clients, MessageType::PopupRules, &serde_json::json!({"rules": rules}), &target_type, target_id.as_deref()).await;
|
||||
Json(ApiResponse::ok(()))
|
||||
}
|
||||
_ => Json(ApiResponse::error("Not found")),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn list_stats(State(state): State<AppState>) -> Json<ApiResponse<serde_json::Value>> {
|
||||
match sqlx::query("SELECT device_uid, blocked_count, date FROM popup_block_stats ORDER BY date DESC LIMIT 30")
|
||||
.fetch_all(&state.db).await {
|
||||
Ok(rows) => Json(ApiResponse::ok(serde_json::json!({"stats": rows.iter().map(|r| serde_json::json!({
|
||||
"device_uid": r.get::<String,_>("device_uid"), "blocked_count": r.get::<i32,_>("blocked_count"),
|
||||
"date": r.get::<String,_>("date")
|
||||
})).collect::<Vec<_>>()}))),
|
||||
Err(e) => Json(ApiResponse::internal_error("query popup block stats", e)),
|
||||
}
|
||||
}
|
||||
|
||||
async fn fetch_popup_rules_for_push(
|
||||
db: &sqlx::SqlitePool,
|
||||
target_type: &str,
|
||||
target_id: Option<&str>,
|
||||
) -> Vec<serde_json::Value> {
|
||||
let query = match target_type {
|
||||
"device" => sqlx::query(
|
||||
"SELECT id, rule_type, window_title, window_class, process_name FROM popup_filter_rules WHERE enabled = 1 AND (target_type = 'global' OR (target_type = 'device' AND target_id = ?))"
|
||||
).bind(target_id),
|
||||
_ => sqlx::query(
|
||||
"SELECT id, rule_type, window_title, window_class, process_name FROM popup_filter_rules WHERE enabled = 1 AND target_type = 'global'"
|
||||
),
|
||||
};
|
||||
query.fetch_all(db).await
|
||||
.map(|rows| rows.iter().map(|r| serde_json::json!({
|
||||
"id": r.get::<i64,_>("id"), "rule_type": r.get::<String,_>("rule_type"),
|
||||
"window_title": r.get::<Option<String>,_>("window_title"),
|
||||
"window_class": r.get::<Option<String>,_>("window_class"),
|
||||
"process_name": r.get::<Option<String>,_>("process_name"),
|
||||
})).collect())
|
||||
.unwrap_or_default()
|
||||
}
|
||||
155
crates/server/src/api/plugins/software_blocker.rs
Normal file
155
crates/server/src/api/plugins/software_blocker.rs
Normal file
@@ -0,0 +1,155 @@
|
||||
use axum::{extract::{State, Path, Query, Json}, http::StatusCode};
|
||||
use serde::Deserialize;
|
||||
use sqlx::Row;
|
||||
use csm_protocol::MessageType;
|
||||
use crate::AppState;
|
||||
use crate::api::ApiResponse;
|
||||
use crate::tcp::push_to_targets;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct CreateBlacklistRequest {
|
||||
pub name_pattern: String,
|
||||
pub category: Option<String>,
|
||||
pub action: Option<String>,
|
||||
pub target_type: Option<String>,
|
||||
pub target_id: Option<String>,
|
||||
}
|
||||
|
||||
pub async fn list_blacklist(State(state): State<AppState>) -> Json<ApiResponse<serde_json::Value>> {
|
||||
match sqlx::query("SELECT id, name_pattern, category, action, target_type, target_id, enabled, created_at FROM software_blacklist ORDER BY created_at DESC")
|
||||
.fetch_all(&state.db).await {
|
||||
Ok(rows) => Json(ApiResponse::ok(serde_json::json!({"blacklist": rows.iter().map(|r| serde_json::json!({
|
||||
"id": r.get::<i64,_>("id"), "name_pattern": r.get::<String,_>("name_pattern"),
|
||||
"category": r.get::<Option<String>,_>("category"), "action": r.get::<String,_>("action"),
|
||||
"target_type": r.get::<String,_>("target_type"), "target_id": r.get::<Option<String>,_>("target_id"),
|
||||
"enabled": r.get::<bool,_>("enabled"), "created_at": r.get::<String,_>("created_at")
|
||||
})).collect::<Vec<_>>()}))),
|
||||
Err(e) => Json(ApiResponse::internal_error("query software blacklist", e)),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn add_to_blacklist(State(state): State<AppState>, Json(req): Json<CreateBlacklistRequest>) -> (StatusCode, Json<ApiResponse<serde_json::Value>>) {
|
||||
let action = req.action.unwrap_or_else(|| "block".to_string());
|
||||
let target_type = req.target_type.unwrap_or_else(|| "global".to_string());
|
||||
|
||||
// Validate inputs
|
||||
if req.name_pattern.trim().is_empty() || req.name_pattern.len() > 255 {
|
||||
return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("name_pattern must be 1-255 chars")));
|
||||
}
|
||||
if !matches!(action.as_str(), "block" | "alert") {
|
||||
return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("action must be 'block' or 'alert'")));
|
||||
}
|
||||
if let Some(ref cat) = req.category {
|
||||
if !matches!(cat.as_str(), "game" | "social" | "vpn" | "mining" | "custom") {
|
||||
return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("invalid category")));
|
||||
}
|
||||
}
|
||||
if !matches!(target_type.as_str(), "global" | "device" | "group") {
|
||||
return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("invalid target_type")));
|
||||
}
|
||||
|
||||
match sqlx::query("INSERT INTO software_blacklist (name_pattern, category, action, target_type, target_id) VALUES (?,?,?,?,?)")
|
||||
.bind(&req.name_pattern).bind(&req.category).bind(&action).bind(&target_type).bind(&req.target_id)
|
||||
.execute(&state.db).await {
|
||||
Ok(r) => {
|
||||
let new_id = r.last_insert_rowid();
|
||||
let blacklist = fetch_blacklist_for_push(&state.db, &target_type, req.target_id.as_deref()).await;
|
||||
push_to_targets(&state.db, &state.clients, MessageType::SoftwareBlacklist, &serde_json::json!({"blacklist": blacklist}), &target_type, req.target_id.as_deref()).await;
|
||||
(StatusCode::CREATED, Json(ApiResponse::ok(serde_json::json!({"id": new_id}))))
|
||||
}
|
||||
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::internal_error("add software blacklist entry", e))),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct UpdateBlacklistRequest { pub name_pattern: Option<String>, pub action: Option<String>, pub enabled: Option<bool> }
|
||||
|
||||
pub async fn update_blacklist(State(state): State<AppState>, Path(id): Path<i64>, Json(body): Json<UpdateBlacklistRequest>) -> Json<ApiResponse<()>> {
|
||||
let existing = sqlx::query("SELECT * FROM software_blacklist WHERE id = ?")
|
||||
.bind(id)
|
||||
.fetch_optional(&state.db)
|
||||
.await;
|
||||
|
||||
let existing = match existing {
|
||||
Ok(Some(row)) => row,
|
||||
Ok(None) => return Json(ApiResponse::error("Not found")),
|
||||
Err(e) => return Json(ApiResponse::internal_error("query software blacklist", e)),
|
||||
};
|
||||
|
||||
let name_pattern = body.name_pattern.unwrap_or_else(|| existing.get::<String, _>("name_pattern"));
|
||||
let action = body.action.unwrap_or_else(|| existing.get::<String, _>("action"));
|
||||
let enabled = body.enabled.unwrap_or_else(|| existing.get::<bool, _>("enabled"));
|
||||
|
||||
let result = sqlx::query("UPDATE software_blacklist SET name_pattern = ?, action = ?, enabled = ? WHERE id = ?")
|
||||
.bind(&name_pattern)
|
||||
.bind(&action)
|
||||
.bind(enabled)
|
||||
.bind(id)
|
||||
.execute(&state.db)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(r) if r.rows_affected() > 0 => {
|
||||
let target_type_val: String = existing.get("target_type");
|
||||
let target_id_val: Option<String> = existing.get("target_id");
|
||||
let blacklist = fetch_blacklist_for_push(&state.db, &target_type_val, target_id_val.as_deref()).await;
|
||||
push_to_targets(&state.db, &state.clients, MessageType::SoftwareBlacklist, &serde_json::json!({"blacklist": blacklist}), &target_type_val, target_id_val.as_deref()).await;
|
||||
Json(ApiResponse::ok(()))
|
||||
}
|
||||
Ok(_) => Json(ApiResponse::error("Not found")),
|
||||
Err(e) => Json(ApiResponse::internal_error("update software blacklist", e)),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn remove_from_blacklist(State(state): State<AppState>, Path(id): Path<i64>) -> Json<ApiResponse<()>> {
|
||||
let existing = sqlx::query("SELECT target_type, target_id FROM software_blacklist WHERE id = ?")
|
||||
.bind(id).fetch_optional(&state.db).await;
|
||||
let (target_type, target_id) = match existing {
|
||||
Ok(Some(row)) => (row.get::<String, _>("target_type"), row.get::<Option<String>, _>("target_id")),
|
||||
_ => return Json(ApiResponse::error("Not found")),
|
||||
};
|
||||
match sqlx::query("DELETE FROM software_blacklist WHERE id=?").bind(id).execute(&state.db).await {
|
||||
Ok(r) if r.rows_affected() > 0 => {
|
||||
let blacklist = fetch_blacklist_for_push(&state.db, &target_type, target_id.as_deref()).await;
|
||||
push_to_targets(&state.db, &state.clients, MessageType::SoftwareBlacklist, &serde_json::json!({"blacklist": blacklist}), &target_type, target_id.as_deref()).await;
|
||||
Json(ApiResponse::ok(()))
|
||||
}
|
||||
_ => Json(ApiResponse::error("Not found")),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct ViolationFilters { pub device_uid: Option<String> }
|
||||
|
||||
pub async fn list_violations(State(state): State<AppState>, Query(f): Query<ViolationFilters>) -> Json<ApiResponse<serde_json::Value>> {
|
||||
match sqlx::query("SELECT id, device_uid, software_name, action_taken, timestamp FROM software_violations WHERE (? IS NULL OR device_uid=?) ORDER BY timestamp DESC LIMIT 200")
|
||||
.bind(&f.device_uid).bind(&f.device_uid)
|
||||
.fetch_all(&state.db).await {
|
||||
Ok(rows) => Json(ApiResponse::ok(serde_json::json!({"violations": rows.iter().map(|r| serde_json::json!({
|
||||
"id": r.get::<i64,_>("id"), "device_uid": r.get::<String,_>("device_uid"),
|
||||
"software_name": r.get::<String,_>("software_name"), "action_taken": r.get::<String,_>("action_taken"),
|
||||
"timestamp": r.get::<String,_>("timestamp")
|
||||
})).collect::<Vec<_>>()}))),
|
||||
Err(e) => Json(ApiResponse::internal_error("query software violations", e)),
|
||||
}
|
||||
}
|
||||
|
||||
async fn fetch_blacklist_for_push(
|
||||
db: &sqlx::SqlitePool,
|
||||
target_type: &str,
|
||||
target_id: Option<&str>,
|
||||
) -> Vec<serde_json::Value> {
|
||||
let query = match target_type {
|
||||
"device" => sqlx::query(
|
||||
"SELECT id, name_pattern, action FROM software_blacklist WHERE enabled = 1 AND (target_type = 'global' OR (target_type = 'device' AND target_id = ?))"
|
||||
).bind(target_id),
|
||||
_ => sqlx::query(
|
||||
"SELECT id, name_pattern, action FROM software_blacklist WHERE enabled = 1 AND target_type = 'global'"
|
||||
),
|
||||
};
|
||||
query.fetch_all(db).await
|
||||
.map(|rows| rows.iter().map(|r| serde_json::json!({
|
||||
"id": r.get::<i64,_>("id"), "name_pattern": r.get::<String,_>("name_pattern"), "action": r.get::<String,_>("action")
|
||||
})).collect())
|
||||
.unwrap_or_default()
|
||||
}
|
||||
60
crates/server/src/api/plugins/usage_timer.rs
Normal file
60
crates/server/src/api/plugins/usage_timer.rs
Normal file
@@ -0,0 +1,60 @@
|
||||
use axum::{extract::{State, Query}, Json};
|
||||
use serde::Deserialize;
|
||||
use sqlx::Row;
|
||||
use crate::AppState;
|
||||
use crate::api::ApiResponse;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct DailyFilters { pub device_uid: Option<String>, pub start_date: Option<String>, pub end_date: Option<String> }
|
||||
|
||||
pub async fn list_daily(State(state): State<AppState>, Query(f): Query<DailyFilters>) -> Json<ApiResponse<serde_json::Value>> {
|
||||
match sqlx::query(
|
||||
"SELECT id, device_uid, date, total_active_minutes, total_idle_minutes, first_active_at, last_active_at
|
||||
FROM usage_daily WHERE (? IS NULL OR device_uid=?) AND (? IS NULL OR date>=?) AND (? IS NULL OR date<=?)
|
||||
ORDER BY date DESC LIMIT 90")
|
||||
.bind(&f.device_uid).bind(&f.device_uid)
|
||||
.bind(&f.start_date).bind(&f.start_date)
|
||||
.bind(&f.end_date).bind(&f.end_date)
|
||||
.fetch_all(&state.db).await {
|
||||
Ok(rows) => Json(ApiResponse::ok(serde_json::json!({"daily": rows.iter().map(|r| serde_json::json!({
|
||||
"id": r.get::<i64,_>("id"), "device_uid": r.get::<String,_>("device_uid"),
|
||||
"date": r.get::<String,_>("date"), "total_active_minutes": r.get::<i32,_>("total_active_minutes"),
|
||||
"total_idle_minutes": r.get::<i32,_>("total_idle_minutes"),
|
||||
"first_active_at": r.get::<Option<String>,_>("first_active_at"),
|
||||
"last_active_at": r.get::<Option<String>,_>("last_active_at")
|
||||
})).collect::<Vec<_>>()}))),
|
||||
Err(e) => Json(ApiResponse::internal_error("query daily usage", e)),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct AppUsageFilters { pub device_uid: Option<String>, pub date: Option<String> }
|
||||
|
||||
pub async fn list_app_usage(State(state): State<AppState>, Query(f): Query<AppUsageFilters>) -> Json<ApiResponse<serde_json::Value>> {
|
||||
match sqlx::query(
|
||||
"SELECT id, device_uid, date, app_name, usage_minutes FROM app_usage_daily
|
||||
WHERE (? IS NULL OR device_uid=?) AND (? IS NULL OR date=?)
|
||||
ORDER BY usage_minutes DESC LIMIT 100")
|
||||
.bind(&f.device_uid).bind(&f.device_uid)
|
||||
.bind(&f.date).bind(&f.date)
|
||||
.fetch_all(&state.db).await {
|
||||
Ok(rows) => Json(ApiResponse::ok(serde_json::json!({"app_usage": rows.iter().map(|r| serde_json::json!({
|
||||
"id": r.get::<i64,_>("id"), "device_uid": r.get::<String,_>("device_uid"),
|
||||
"date": r.get::<String,_>("date"), "app_name": r.get::<String,_>("app_name"),
|
||||
"usage_minutes": r.get::<i32,_>("usage_minutes")
|
||||
})).collect::<Vec<_>>()}))),
|
||||
Err(e) => Json(ApiResponse::internal_error("query app usage", e)),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn leaderboard(State(state): State<AppState>) -> Json<ApiResponse<serde_json::Value>> {
|
||||
match sqlx::query(
|
||||
"SELECT device_uid, SUM(total_active_minutes) as total_minutes FROM usage_daily
|
||||
WHERE date >= date('now', '-7 days') GROUP BY device_uid ORDER BY total_minutes DESC LIMIT 20")
|
||||
.fetch_all(&state.db).await {
|
||||
Ok(rows) => Json(ApiResponse::ok(serde_json::json!({"leaderboard": rows.iter().map(|r| serde_json::json!({
|
||||
"device_uid": r.get::<String,_>("device_uid"), "total_minutes": r.get::<i64,_>("total_minutes")
|
||||
})).collect::<Vec<_>>()}))),
|
||||
Err(e) => Json(ApiResponse::internal_error("query usage leaderboard", e)),
|
||||
}
|
||||
}
|
||||
47
crates/server/src/api/plugins/usb_file_audit.rs
Normal file
47
crates/server/src/api/plugins/usb_file_audit.rs
Normal file
@@ -0,0 +1,47 @@
|
||||
use axum::{extract::{State, Query}, Json};
|
||||
use serde::Deserialize;
|
||||
use sqlx::Row;
|
||||
use crate::AppState;
|
||||
use crate::api::ApiResponse;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct LogFilters {
|
||||
pub device_uid: Option<String>,
|
||||
pub operation: Option<String>,
|
||||
pub usb_serial: Option<String>,
|
||||
}
|
||||
|
||||
pub async fn list_operations(State(state): State<AppState>, Query(f): Query<LogFilters>) -> Json<ApiResponse<serde_json::Value>> {
|
||||
match sqlx::query(
|
||||
"SELECT id, device_uid, usb_serial, drive_letter, operation, file_path, file_size, timestamp
|
||||
FROM usb_file_operations WHERE (? IS NULL OR device_uid=?) AND (? IS NULL OR operation=?) AND (? IS NULL OR usb_serial=?)
|
||||
ORDER BY timestamp DESC LIMIT 200")
|
||||
.bind(&f.device_uid).bind(&f.device_uid)
|
||||
.bind(&f.operation).bind(&f.operation)
|
||||
.bind(&f.usb_serial).bind(&f.usb_serial)
|
||||
.fetch_all(&state.db).await {
|
||||
Ok(rows) => Json(ApiResponse::ok(serde_json::json!({"operations": rows.iter().map(|r| serde_json::json!({
|
||||
"id": r.get::<i64,_>("id"), "device_uid": r.get::<String,_>("device_uid"),
|
||||
"usb_serial": r.get::<Option<String>,_>("usb_serial"), "drive_letter": r.get::<Option<String>,_>("drive_letter"),
|
||||
"operation": r.get::<String,_>("operation"), "file_path": r.get::<String,_>("file_path"),
|
||||
"file_size": r.get::<Option<i64>,_>("file_size"), "timestamp": r.get::<String,_>("timestamp")
|
||||
})).collect::<Vec<_>>()}))),
|
||||
Err(e) => Json(ApiResponse::internal_error("query USB file operations", e)),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn summary(State(state): State<AppState>) -> Json<ApiResponse<serde_json::Value>> {
|
||||
match sqlx::query(
|
||||
"SELECT device_uid, COUNT(*) as op_count, COUNT(DISTINCT usb_serial) as usb_count,
|
||||
MIN(timestamp) as first_op, MAX(timestamp) as last_op
|
||||
FROM usb_file_operations WHERE timestamp >= datetime('now', '-7 days')
|
||||
GROUP BY device_uid ORDER BY op_count DESC LIMIT 50")
|
||||
.fetch_all(&state.db).await {
|
||||
Ok(rows) => Json(ApiResponse::ok(serde_json::json!({"summary": rows.iter().map(|r| serde_json::json!({
|
||||
"device_uid": r.get::<String,_>("device_uid"), "op_count": r.get::<i64,_>("op_count"),
|
||||
"usb_count": r.get::<i64,_>("usb_count"), "first_op": r.get::<String,_>("first_op"),
|
||||
"last_op": r.get::<String,_>("last_op")
|
||||
})).collect::<Vec<_>>()}))),
|
||||
Err(e) => Json(ApiResponse::internal_error("query USB file audit summary", e)),
|
||||
}
|
||||
}
|
||||
186
crates/server/src/api/plugins/watermark.rs
Normal file
186
crates/server/src/api/plugins/watermark.rs
Normal file
@@ -0,0 +1,186 @@
|
||||
use axum::{extract::{State, Path, Json}, http::StatusCode};
|
||||
use serde::Deserialize;
|
||||
use sqlx::Row;
|
||||
use csm_protocol::{MessageType, WatermarkConfigPayload};
|
||||
use crate::AppState;
|
||||
use crate::api::ApiResponse;
|
||||
use crate::tcp::push_to_targets;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct CreateConfigRequest {
|
||||
pub target_type: Option<String>,
|
||||
pub target_id: Option<String>,
|
||||
pub content: Option<String>,
|
||||
pub font_size: Option<u32>,
|
||||
pub opacity: Option<f64>,
|
||||
pub color: Option<String>,
|
||||
pub angle: Option<i32>,
|
||||
pub enabled: Option<bool>,
|
||||
}
|
||||
|
||||
pub async fn get_config_list(State(state): State<AppState>) -> Json<ApiResponse<serde_json::Value>> {
|
||||
match sqlx::query("SELECT id, target_type, target_id, content, font_size, opacity, color, angle, enabled, updated_at FROM watermark_config ORDER BY id")
|
||||
.fetch_all(&state.db).await {
|
||||
Ok(rows) => Json(ApiResponse::ok(serde_json::json!({"configs": rows.iter().map(|r| serde_json::json!({
|
||||
"id": r.get::<i64,_>("id"), "target_type": r.get::<String,_>("target_type"),
|
||||
"target_id": r.get::<Option<String>,_>("target_id"), "content": r.get::<String,_>("content"),
|
||||
"font_size": r.get::<i32,_>("font_size"), "opacity": r.get::<f64,_>("opacity"),
|
||||
"color": r.get::<String,_>("color"), "angle": r.get::<i32,_>("angle"),
|
||||
"enabled": r.get::<bool,_>("enabled"), "updated_at": r.get::<String,_>("updated_at")
|
||||
})).collect::<Vec<_>>()}))),
|
||||
Err(e) => Json(ApiResponse::internal_error("query watermark configs", e)),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn create_config(
|
||||
State(state): State<AppState>,
|
||||
Json(req): Json<CreateConfigRequest>,
|
||||
) -> (StatusCode, Json<ApiResponse<serde_json::Value>>) {
|
||||
let target_type = req.target_type.unwrap_or_else(|| "global".to_string());
|
||||
let content = req.content.unwrap_or_else(|| "{company} | {username} | {date}".to_string());
|
||||
let font_size = req.font_size.unwrap_or(14).clamp(8, 72) as i32;
|
||||
let opacity = req.opacity.unwrap_or(0.15).clamp(0.01, 1.0);
|
||||
let color = req.color.unwrap_or_else(|| "#808080".to_string());
|
||||
let angle = req.angle.unwrap_or(-30);
|
||||
let enabled = req.enabled.unwrap_or(true);
|
||||
|
||||
// Validate inputs
|
||||
if content.len() > 200 {
|
||||
return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("content too long (max 200 chars)")));
|
||||
}
|
||||
if !is_valid_hex_color(&color) {
|
||||
return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("invalid color format (expected #RRGGBB)")));
|
||||
}
|
||||
if !matches!(target_type.as_str(), "global" | "device" | "group") {
|
||||
return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("invalid target_type")));
|
||||
}
|
||||
|
||||
match sqlx::query("INSERT INTO watermark_config (target_type, target_id, content, font_size, opacity, color, angle, enabled) VALUES (?,?,?,?,?,?,?,?)")
|
||||
.bind(&target_type).bind(&req.target_id).bind(&content).bind(font_size).bind(opacity).bind(&color).bind(angle).bind(enabled)
|
||||
.execute(&state.db).await {
|
||||
Ok(r) => {
|
||||
let new_id = r.last_insert_rowid();
|
||||
// Push to online clients
|
||||
let config = WatermarkConfigPayload {
|
||||
content: content.clone(),
|
||||
font_size: font_size as u32,
|
||||
opacity,
|
||||
color: color.clone(),
|
||||
angle,
|
||||
enabled,
|
||||
};
|
||||
push_to_targets(&state.db, &state.clients, MessageType::WatermarkConfig, &config, &target_type, req.target_id.as_deref()).await;
|
||||
(StatusCode::CREATED, Json(ApiResponse::ok(serde_json::json!({"id": new_id}))))
|
||||
}
|
||||
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::internal_error("create watermark config", e))),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct UpdateConfigRequest {
|
||||
pub content: Option<String>, pub font_size: Option<u32>, pub opacity: Option<f64>,
|
||||
pub color: Option<String>, pub angle: Option<i32>, pub enabled: Option<bool>,
|
||||
}
|
||||
|
||||
pub async fn update_config(
|
||||
State(state): State<AppState>,
|
||||
Path(id): Path<i64>,
|
||||
Json(body): Json<UpdateConfigRequest>,
|
||||
) -> Json<ApiResponse<()>> {
|
||||
let existing = sqlx::query("SELECT * FROM watermark_config WHERE id = ?")
|
||||
.bind(id)
|
||||
.fetch_optional(&state.db)
|
||||
.await;
|
||||
|
||||
let existing = match existing {
|
||||
Ok(Some(row)) => row,
|
||||
Ok(None) => return Json(ApiResponse::error("Not found")),
|
||||
Err(e) => return Json(ApiResponse::internal_error("query watermark config", e)),
|
||||
};
|
||||
|
||||
let content = body.content.unwrap_or_else(|| existing.get::<String, _>("content"));
|
||||
let font_size = body.font_size.map(|v| v.clamp(8, 72) as i32).unwrap_or_else(|| existing.get::<i32, _>("font_size"));
|
||||
let opacity = body.opacity.map(|v| v.clamp(0.01, 1.0)).unwrap_or_else(|| existing.get::<f64, _>("opacity"));
|
||||
let color = body.color.unwrap_or_else(|| existing.get::<String, _>("color"));
|
||||
let angle = body.angle.unwrap_or_else(|| existing.get::<i32, _>("angle"));
|
||||
let enabled = body.enabled.unwrap_or_else(|| existing.get::<bool, _>("enabled"));
|
||||
|
||||
// Validate inputs
|
||||
if content.len() > 200 {
|
||||
return Json(ApiResponse::error("content too long (max 200 chars)"));
|
||||
}
|
||||
if !is_valid_hex_color(&color) {
|
||||
return Json(ApiResponse::error("invalid color format (expected #RRGGBB)"));
|
||||
}
|
||||
|
||||
let result = sqlx::query(
|
||||
"UPDATE watermark_config SET content = ?, font_size = ?, opacity = ?, color = ?, angle = ?, enabled = ?, updated_at = datetime('now') WHERE id = ?"
|
||||
)
|
||||
.bind(&content)
|
||||
.bind(font_size)
|
||||
.bind(opacity)
|
||||
.bind(&color)
|
||||
.bind(angle)
|
||||
.bind(enabled)
|
||||
.bind(id)
|
||||
.execute(&state.db)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(r) if r.rows_affected() > 0 => {
|
||||
// Push updated config to online clients
|
||||
let config = WatermarkConfigPayload {
|
||||
content: content.clone(),
|
||||
font_size: font_size as u32,
|
||||
opacity,
|
||||
color: color.clone(),
|
||||
angle,
|
||||
enabled,
|
||||
};
|
||||
let target_type_val: String = existing.get("target_type");
|
||||
let target_id_val: Option<String> = existing.get("target_id");
|
||||
push_to_targets(&state.db, &state.clients, MessageType::WatermarkConfig, &config, &target_type_val, target_id_val.as_deref()).await;
|
||||
Json(ApiResponse::ok(()))
|
||||
}
|
||||
Ok(_) => Json(ApiResponse::error("Not found")),
|
||||
Err(e) => Json(ApiResponse::internal_error("update watermark config", e)),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn delete_config(State(state): State<AppState>, Path(id): Path<i64>) -> Json<ApiResponse<()>> {
|
||||
// Fetch existing config to get target info for push
|
||||
let existing = sqlx::query("SELECT target_type, target_id FROM watermark_config WHERE id = ?")
|
||||
.bind(id)
|
||||
.fetch_optional(&state.db)
|
||||
.await;
|
||||
|
||||
let (target_type, target_id) = match existing {
|
||||
Ok(Some(row)) => (row.get::<String, _>("target_type"), row.get::<Option<String>, _>("target_id")),
|
||||
_ => return Json(ApiResponse::error("Not found")),
|
||||
};
|
||||
|
||||
match sqlx::query("DELETE FROM watermark_config WHERE id=?").bind(id).execute(&state.db).await {
|
||||
Ok(r) if r.rows_affected() > 0 => {
|
||||
// Push disabled watermark to clients
|
||||
let disabled = WatermarkConfigPayload {
|
||||
content: String::new(),
|
||||
font_size: 0,
|
||||
opacity: 0.0,
|
||||
color: String::new(),
|
||||
angle: 0,
|
||||
enabled: false,
|
||||
};
|
||||
push_to_targets(&state.db, &state.clients, MessageType::WatermarkConfig, &disabled, &target_type, target_id.as_deref()).await;
|
||||
Json(ApiResponse::ok(()))
|
||||
}
|
||||
_ => Json(ApiResponse::error("Not found")),
|
||||
}
|
||||
}
|
||||
|
||||
/// Validate a hex color string (#RRGGBB format)
|
||||
fn is_valid_hex_color(color: &str) -> bool {
|
||||
if color.len() != 7 || !color.starts_with('#') {
|
||||
return false;
|
||||
}
|
||||
color[1..].chars().all(|c| c.is_ascii_hexdigit())
|
||||
}
|
||||
156
crates/server/src/api/plugins/web_filter.rs
Normal file
156
crates/server/src/api/plugins/web_filter.rs
Normal file
@@ -0,0 +1,156 @@
|
||||
use axum::{extract::{State, Path, Query, Json}, http::StatusCode};
|
||||
use serde::Deserialize;
|
||||
use sqlx::Row;
|
||||
use csm_protocol::MessageType;
|
||||
use crate::AppState;
|
||||
use crate::api::ApiResponse;
|
||||
use crate::tcp::push_to_targets;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct RuleFilters { pub rule_type: Option<String>, pub enabled: Option<bool> }
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct CreateRuleRequest {
|
||||
pub rule_type: String,
|
||||
pub pattern: String,
|
||||
pub target_type: Option<String>,
|
||||
pub target_id: Option<String>,
|
||||
pub enabled: Option<bool>,
|
||||
}
|
||||
|
||||
pub async fn list_rules(State(state): State<AppState>) -> Json<ApiResponse<serde_json::Value>> {
|
||||
match sqlx::query("SELECT id, rule_type, pattern, target_type, target_id, enabled, created_at FROM web_filter_rules ORDER BY created_at DESC")
|
||||
.fetch_all(&state.db).await {
|
||||
Ok(rows) => Json(ApiResponse::ok(serde_json::json!({ "rules": rows.iter().map(|r| serde_json::json!({
|
||||
"id": r.get::<i64,_>("id"), "rule_type": r.get::<String,_>("rule_type"),
|
||||
"pattern": r.get::<String,_>("pattern"), "target_type": r.get::<String,_>("target_type"),
|
||||
"target_id": r.get::<Option<String>,_>("target_id"), "enabled": r.get::<bool,_>("enabled"),
|
||||
"created_at": r.get::<String,_>("created_at")
|
||||
})).collect::<Vec<_>>() }))),
|
||||
Err(e) => Json(ApiResponse::internal_error("query web filter rules", e)),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn create_rule(State(state): State<AppState>, Json(req): Json<CreateRuleRequest>) -> (StatusCode, Json<ApiResponse<serde_json::Value>>) {
|
||||
let enabled = req.enabled.unwrap_or(true);
|
||||
let target_type = req.target_type.unwrap_or_else(|| "global".to_string());
|
||||
|
||||
// Validate inputs
|
||||
if !matches!(req.rule_type.as_str(), "blacklist" | "whitelist" | "category") {
|
||||
return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("invalid rule_type (expected blacklist|whitelist|category)")));
|
||||
}
|
||||
if req.pattern.trim().is_empty() || req.pattern.len() > 255 {
|
||||
return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("pattern must be 1-255 chars")));
|
||||
}
|
||||
if !matches!(target_type.as_str(), "global" | "device" | "group") {
|
||||
return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("invalid target_type")));
|
||||
}
|
||||
|
||||
match sqlx::query("INSERT INTO web_filter_rules (rule_type, pattern, target_type, target_id, enabled) VALUES (?,?,?,?,?)")
|
||||
.bind(&req.rule_type).bind(&req.pattern).bind(&target_type).bind(&req.target_id).bind(enabled)
|
||||
.execute(&state.db).await {
|
||||
Ok(r) => {
|
||||
let new_id = r.last_insert_rowid();
|
||||
let rules = fetch_rules_for_push(&state.db, &target_type, req.target_id.as_deref()).await;
|
||||
push_to_targets(&state.db, &state.clients, MessageType::WebFilterRuleUpdate, &serde_json::json!({"rules": rules}), &target_type, req.target_id.as_deref()).await;
|
||||
(StatusCode::CREATED, Json(ApiResponse::ok(serde_json::json!({"id": new_id}))))
|
||||
}
|
||||
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::internal_error("create web filter rule", e))),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct UpdateRuleRequest { pub rule_type: Option<String>, pub pattern: Option<String>, pub enabled: Option<bool> }
|
||||
|
||||
pub async fn update_rule(State(state): State<AppState>, Path(id): Path<i64>, Json(body): Json<UpdateRuleRequest>) -> Json<ApiResponse<()>> {
|
||||
let existing = sqlx::query("SELECT * FROM web_filter_rules WHERE id = ?")
|
||||
.bind(id)
|
||||
.fetch_optional(&state.db)
|
||||
.await;
|
||||
|
||||
let existing = match existing {
|
||||
Ok(Some(row)) => row,
|
||||
Ok(None) => return Json(ApiResponse::error("Not found")),
|
||||
Err(e) => return Json(ApiResponse::internal_error("query web filter rule", e)),
|
||||
};
|
||||
|
||||
let rule_type = body.rule_type.unwrap_or_else(|| existing.get::<String, _>("rule_type"));
|
||||
let pattern = body.pattern.unwrap_or_else(|| existing.get::<String, _>("pattern"));
|
||||
let enabled = body.enabled.unwrap_or_else(|| existing.get::<bool, _>("enabled"));
|
||||
|
||||
let result = sqlx::query("UPDATE web_filter_rules SET rule_type = ?, pattern = ?, enabled = ? WHERE id = ?")
|
||||
.bind(&rule_type)
|
||||
.bind(&pattern)
|
||||
.bind(enabled)
|
||||
.bind(id)
|
||||
.execute(&state.db)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(r) if r.rows_affected() > 0 => {
|
||||
let target_type_val: String = existing.get("target_type");
|
||||
let target_id_val: Option<String> = existing.get("target_id");
|
||||
let rules = fetch_rules_for_push(&state.db, &target_type_val, target_id_val.as_deref()).await;
|
||||
push_to_targets(&state.db, &state.clients, MessageType::WebFilterRuleUpdate, &serde_json::json!({"rules": rules}), &target_type_val, target_id_val.as_deref()).await;
|
||||
Json(ApiResponse::ok(()))
|
||||
}
|
||||
Ok(_) => Json(ApiResponse::error("Not found")),
|
||||
Err(e) => Json(ApiResponse::internal_error("update web filter rule", e)),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn delete_rule(State(state): State<AppState>, Path(id): Path<i64>) -> Json<ApiResponse<()>> {
|
||||
let existing = sqlx::query("SELECT target_type, target_id FROM web_filter_rules WHERE id = ?")
|
||||
.bind(id).fetch_optional(&state.db).await;
|
||||
let (target_type, target_id) = match existing {
|
||||
Ok(Some(row)) => (row.get::<String, _>("target_type"), row.get::<Option<String>, _>("target_id")),
|
||||
_ => return Json(ApiResponse::error("Not found")),
|
||||
};
|
||||
match sqlx::query("DELETE FROM web_filter_rules WHERE id=?").bind(id).execute(&state.db).await {
|
||||
Ok(r) if r.rows_affected() > 0 => {
|
||||
let rules = fetch_rules_for_push(&state.db, &target_type, target_id.as_deref()).await;
|
||||
push_to_targets(&state.db, &state.clients, MessageType::WebFilterRuleUpdate, &serde_json::json!({"rules": rules}), &target_type, target_id.as_deref()).await;
|
||||
Json(ApiResponse::ok(()))
|
||||
}
|
||||
_ => Json(ApiResponse::error("Not found")),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct LogFilters { pub device_uid: Option<String>, pub action: Option<String> }
|
||||
|
||||
pub async fn list_access_log(State(state): State<AppState>, Query(f): Query<LogFilters>) -> Json<ApiResponse<serde_json::Value>> {
|
||||
match sqlx::query("SELECT id, device_uid, url, action, timestamp FROM web_access_log WHERE (? IS NULL OR device_uid=?) AND (? IS NULL OR action=?) ORDER BY timestamp DESC LIMIT 200")
|
||||
.bind(&f.device_uid).bind(&f.device_uid).bind(&f.action).bind(&f.action)
|
||||
.fetch_all(&state.db).await {
|
||||
Ok(rows) => Json(ApiResponse::ok(serde_json::json!({"log": rows.iter().map(|r| serde_json::json!({
|
||||
"id": r.get::<i64,_>("id"), "device_uid": r.get::<String,_>("device_uid"),
|
||||
"url": r.get::<String,_>("url"), "action": r.get::<String,_>("action"),
|
||||
"timestamp": r.get::<String,_>("timestamp")
|
||||
})).collect::<Vec<_>>() }))),
|
||||
Err(e) => Json(ApiResponse::internal_error("query web access log", e)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Fetch enabled web filter rules applicable to a given target scope.
|
||||
/// For "device" targets, includes both global rules and device-specific rules
|
||||
/// (matching the logic used during registration push in tcp.rs).
|
||||
async fn fetch_rules_for_push(
|
||||
db: &sqlx::SqlitePool,
|
||||
target_type: &str,
|
||||
target_id: Option<&str>,
|
||||
) -> Vec<serde_json::Value> {
|
||||
let query = match target_type {
|
||||
"device" => sqlx::query(
|
||||
"SELECT id, rule_type, pattern FROM web_filter_rules WHERE enabled = 1 AND (target_type = 'global' OR (target_type = 'device' AND target_id = ?))"
|
||||
).bind(target_id),
|
||||
_ => sqlx::query(
|
||||
"SELECT id, rule_type, pattern FROM web_filter_rules WHERE enabled = 1 AND target_type = 'global'"
|
||||
),
|
||||
};
|
||||
query.fetch_all(db).await
|
||||
.map(|rows| rows.iter().map(|r| serde_json::json!({
|
||||
"id": r.get::<i64,_>("id"), "rule_type": r.get::<String,_>("rule_type"), "pattern": r.get::<String,_>("pattern")
|
||||
})).collect())
|
||||
.unwrap_or_default()
|
||||
}
|
||||
246
crates/server/src/api/usb.rs
Normal file
246
crates/server/src/api/usb.rs
Normal file
@@ -0,0 +1,246 @@
|
||||
use axum::{extract::{State, Path, Query, Json}, http::StatusCode};
|
||||
use serde::Deserialize;
|
||||
use sqlx::Row;
|
||||
|
||||
use crate::AppState;
|
||||
use super::ApiResponse;
|
||||
use crate::tcp::push_to_targets;
|
||||
use csm_protocol::{MessageType, UsbPolicyPayload, UsbDeviceRule};
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct UsbEventListParams {
|
||||
pub device_uid: Option<String>,
|
||||
pub event_type: Option<String>,
|
||||
pub page: Option<u32>,
|
||||
pub page_size: Option<u32>,
|
||||
}
|
||||
|
||||
pub async fn list_events(
|
||||
State(state): State<AppState>,
|
||||
Query(params): Query<UsbEventListParams>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let limit = params.page_size.unwrap_or(20).min(100);
|
||||
let offset = params.page.unwrap_or(1).saturating_sub(1) * limit;
|
||||
|
||||
// Normalize empty strings to None
|
||||
let device_uid = params.device_uid.as_deref().filter(|s| !s.is_empty()).map(String::from);
|
||||
let event_type = params.event_type.as_deref().filter(|s| !s.is_empty()).map(String::from);
|
||||
|
||||
let rows = sqlx::query(
|
||||
"SELECT id, device_uid, vendor_id, product_id, serial_number, device_name, event_type, event_time
|
||||
FROM usb_events WHERE 1=1
|
||||
AND (? IS NULL OR device_uid = ?)
|
||||
AND (? IS NULL OR event_type = ?)
|
||||
ORDER BY event_time DESC LIMIT ? OFFSET ?"
|
||||
)
|
||||
.bind(&device_uid).bind(&device_uid)
|
||||
.bind(&event_type).bind(&event_type)
|
||||
.bind(limit).bind(offset)
|
||||
.fetch_all(&state.db)
|
||||
.await;
|
||||
|
||||
match rows {
|
||||
Ok(records) => {
|
||||
let items: Vec<serde_json::Value> = records.iter().map(|r| serde_json::json!({
|
||||
"id": r.get::<i64, _>("id"),
|
||||
"device_uid": r.get::<String, _>("device_uid"),
|
||||
"vendor_id": r.get::<Option<String>, _>("vendor_id"),
|
||||
"product_id": r.get::<Option<String>, _>("product_id"),
|
||||
"serial_number": r.get::<Option<String>, _>("serial_number"),
|
||||
"device_name": r.get::<Option<String>, _>("device_name"),
|
||||
"event_type": r.get::<String, _>("event_type"),
|
||||
"event_time": r.get::<String, _>("event_time"),
|
||||
})).collect();
|
||||
Json(ApiResponse::ok(serde_json::json!({
|
||||
"events": items,
|
||||
"page": params.page.unwrap_or(1),
|
||||
"page_size": limit,
|
||||
})))
|
||||
}
|
||||
Err(e) => Json(ApiResponse::internal_error("query usb events", e)),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn list_policies(
|
||||
State(state): State<AppState>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let rows = sqlx::query(
|
||||
"SELECT id, name, policy_type, target_group, rules, enabled, created_at, updated_at
|
||||
FROM usb_policies ORDER BY created_at DESC"
|
||||
)
|
||||
.fetch_all(&state.db)
|
||||
.await;
|
||||
|
||||
match rows {
|
||||
Ok(records) => {
|
||||
let items: Vec<serde_json::Value> = records.iter().map(|r| serde_json::json!({
|
||||
"id": r.get::<i64, _>("id"),
|
||||
"name": r.get::<String, _>("name"),
|
||||
"policy_type": r.get::<String, _>("policy_type"),
|
||||
"target_group": r.get::<Option<String>, _>("target_group"),
|
||||
"rules": r.get::<String, _>("rules"),
|
||||
"enabled": r.get::<i32, _>("enabled"),
|
||||
"created_at": r.get::<String, _>("created_at"),
|
||||
"updated_at": r.get::<String, _>("updated_at"),
|
||||
})).collect();
|
||||
Json(ApiResponse::ok(serde_json::json!({
|
||||
"policies": items,
|
||||
})))
|
||||
}
|
||||
Err(e) => Json(ApiResponse::internal_error("query usb policies", e)),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct CreatePolicyRequest {
|
||||
pub name: String,
|
||||
pub policy_type: String,
|
||||
pub target_group: Option<String>,
|
||||
pub rules: String,
|
||||
pub enabled: Option<i32>,
|
||||
}
|
||||
|
||||
pub async fn create_policy(
|
||||
State(state): State<AppState>,
|
||||
Json(body): Json<CreatePolicyRequest>,
|
||||
) -> (StatusCode, Json<ApiResponse<serde_json::Value>>) {
|
||||
let enabled = body.enabled.unwrap_or(1);
|
||||
|
||||
let result = sqlx::query(
|
||||
"INSERT INTO usb_policies (name, policy_type, target_group, rules, enabled) VALUES (?, ?, ?, ?, ?)"
|
||||
)
|
||||
.bind(&body.name)
|
||||
.bind(&body.policy_type)
|
||||
.bind(&body.target_group)
|
||||
.bind(&body.rules)
|
||||
.bind(enabled)
|
||||
.execute(&state.db)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(r) => {
|
||||
let new_id = r.last_insert_rowid();
|
||||
// Push USB policy to matching online clients
|
||||
if enabled == 1 {
|
||||
let payload = build_usb_policy_payload(&body.policy_type, true, &body.rules);
|
||||
let target_group = body.target_group.as_deref();
|
||||
push_to_targets(&state.db, &state.clients, MessageType::UsbPolicyUpdate, &payload, "group", target_group).await;
|
||||
}
|
||||
(StatusCode::CREATED, Json(ApiResponse::ok(serde_json::json!({
|
||||
"id": new_id,
|
||||
}))))
|
||||
}
|
||||
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::internal_error("create usb policy", e))),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct UpdatePolicyRequest {
|
||||
pub name: Option<String>,
|
||||
pub policy_type: Option<String>,
|
||||
pub target_group: Option<String>,
|
||||
pub rules: Option<String>,
|
||||
pub enabled: Option<i32>,
|
||||
}
|
||||
|
||||
pub async fn update_policy(
|
||||
State(state): State<AppState>,
|
||||
Path(id): Path<i64>,
|
||||
Json(body): Json<UpdatePolicyRequest>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
// Fetch existing policy
|
||||
let existing = sqlx::query("SELECT * FROM usb_policies WHERE id = ?")
|
||||
.bind(id)
|
||||
.fetch_optional(&state.db)
|
||||
.await;
|
||||
|
||||
let existing = match existing {
|
||||
Ok(Some(row)) => row,
|
||||
Ok(None) => return Json(ApiResponse::error("Policy not found")),
|
||||
Err(e) => return Json(ApiResponse::internal_error("query usb policy", e)),
|
||||
};
|
||||
|
||||
let name = body.name.unwrap_or_else(|| existing.get::<String, _>("name"));
|
||||
let policy_type = body.policy_type.unwrap_or_else(|| existing.get::<String, _>("policy_type"));
|
||||
let target_group = body.target_group.or_else(|| existing.get::<Option<String>, _>("target_group"));
|
||||
let rules = body.rules.unwrap_or_else(|| existing.get::<String, _>("rules"));
|
||||
let enabled = body.enabled.unwrap_or_else(|| existing.get::<i32, _>("enabled"));
|
||||
|
||||
let result = sqlx::query(
|
||||
"UPDATE usb_policies SET name = ?, policy_type = ?, target_group = ?, rules = ?, enabled = ?, updated_at = datetime('now') WHERE id = ?"
|
||||
)
|
||||
.bind(&name)
|
||||
.bind(&policy_type)
|
||||
.bind(&target_group)
|
||||
.bind(&rules)
|
||||
.bind(enabled)
|
||||
.bind(id)
|
||||
.execute(&state.db)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(_) => {
|
||||
// Push updated USB policy to matching online clients
|
||||
let payload = build_usb_policy_payload(&policy_type, enabled == 1, &rules);
|
||||
let target_group = target_group.as_deref();
|
||||
push_to_targets(&state.db, &state.clients, MessageType::UsbPolicyUpdate, &payload, "group", target_group).await;
|
||||
Json(ApiResponse::ok(serde_json::json!({"updated": true})))
|
||||
}
|
||||
Err(e) => Json(ApiResponse::internal_error("update usb policy", e)),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn delete_policy(
|
||||
State(state): State<AppState>,
|
||||
Path(id): Path<i64>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
// Fetch existing policy to get target info for push
|
||||
let existing = sqlx::query("SELECT target_group FROM usb_policies WHERE id = ?")
|
||||
.bind(id)
|
||||
.fetch_optional(&state.db)
|
||||
.await;
|
||||
|
||||
let target_group = match existing {
|
||||
Ok(Some(row)) => row.get::<Option<String>, _>("target_group"),
|
||||
_ => return Json(ApiResponse::error("Policy not found")),
|
||||
};
|
||||
|
||||
let result = sqlx::query("DELETE FROM usb_policies WHERE id = ?")
|
||||
.bind(id)
|
||||
.execute(&state.db)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(r) => {
|
||||
if r.rows_affected() > 0 {
|
||||
// Push disabled policy to clients
|
||||
let disabled = UsbPolicyPayload {
|
||||
policy_type: String::new(),
|
||||
enabled: false,
|
||||
rules: vec![],
|
||||
};
|
||||
push_to_targets(&state.db, &state.clients, MessageType::UsbPolicyUpdate, &disabled, "group", target_group.as_deref()).await;
|
||||
Json(ApiResponse::ok(serde_json::json!({"deleted": true})))
|
||||
} else {
|
||||
Json(ApiResponse::error("Policy not found"))
|
||||
}
|
||||
}
|
||||
Err(e) => Json(ApiResponse::internal_error("delete usb policy", e)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Build a UsbPolicyPayload from raw policy fields
|
||||
fn build_usb_policy_payload(policy_type: &str, enabled: bool, rules_json: &str) -> UsbPolicyPayload {
|
||||
let raw_rules: Vec<serde_json::Value> = serde_json::from_str(rules_json).unwrap_or_default();
|
||||
let rules: Vec<UsbDeviceRule> = raw_rules.iter().map(|r| UsbDeviceRule {
|
||||
vendor_id: r.get("vendor_id").and_then(|v| v.as_str().map(String::from)),
|
||||
product_id: r.get("product_id").and_then(|v| v.as_str().map(String::from)),
|
||||
serial: r.get("serial").and_then(|v| v.as_str().map(String::from)),
|
||||
device_name: r.get("device_name").and_then(|v| v.as_str().map(String::from)),
|
||||
}).collect();
|
||||
UsbPolicyPayload {
|
||||
policy_type: policy_type.to_string(),
|
||||
enabled,
|
||||
rules,
|
||||
}
|
||||
}
|
||||
28
crates/server/src/audit.rs
Normal file
28
crates/server/src/audit.rs
Normal file
@@ -0,0 +1,28 @@
|
||||
use sqlx::SqlitePool;
|
||||
use tracing::debug;
|
||||
|
||||
/// Record an admin audit log entry.
|
||||
pub async fn audit_log(
|
||||
db: &SqlitePool,
|
||||
user_id: i64,
|
||||
action: &str,
|
||||
target_type: Option<&str>,
|
||||
target_id: Option<&str>,
|
||||
detail: Option<&str>,
|
||||
) {
|
||||
let result = sqlx::query(
|
||||
"INSERT INTO admin_audit_log (user_id, action, target_type, target_id, detail) VALUES (?, ?, ?, ?, ?)"
|
||||
)
|
||||
.bind(user_id)
|
||||
.bind(action)
|
||||
.bind(target_type)
|
||||
.bind(target_id)
|
||||
.bind(detail)
|
||||
.execute(db)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(_) => debug!("Audit: user={} action={} target={}/{}", user_id, action, target_type.unwrap_or("-"), target_id.unwrap_or("-")),
|
||||
Err(e) => tracing::warn!("Failed to write audit log: {}", e),
|
||||
}
|
||||
}
|
||||
134
crates/server/src/config.rs
Normal file
134
crates/server/src/config.rs
Normal file
@@ -0,0 +1,134 @@
|
||||
use anyhow::Result;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::Path;
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize, Clone)]
|
||||
pub struct AppConfig {
|
||||
pub server: ServerConfig,
|
||||
pub database: DatabaseConfig,
|
||||
pub auth: AuthConfig,
|
||||
pub retention: RetentionConfig,
|
||||
#[serde(default)]
|
||||
pub notify: NotifyConfig,
|
||||
/// Token required for device registration. Empty = any token accepted.
|
||||
#[serde(default)]
|
||||
pub registration_token: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize, Clone)]
|
||||
pub struct ServerConfig {
|
||||
pub http_addr: String,
|
||||
pub tcp_addr: String,
|
||||
/// Allowed CORS origins. Empty = same-origin only (no CORS headers).
|
||||
#[serde(default)]
|
||||
pub cors_origins: Vec<String>,
|
||||
/// Optional TLS configuration for the TCP listener.
|
||||
#[serde(default)]
|
||||
pub tls: Option<TlsConfig>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize, Clone)]
|
||||
pub struct TlsConfig {
|
||||
/// Path to the server certificate (PEM format)
|
||||
pub cert_path: String,
|
||||
/// Path to the server private key (PEM format)
|
||||
pub key_path: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize, Clone)]
|
||||
pub struct DatabaseConfig {
|
||||
pub path: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize, Clone)]
|
||||
pub struct AuthConfig {
|
||||
pub jwt_secret: String,
|
||||
#[serde(default = "default_access_ttl")]
|
||||
pub access_token_ttl_secs: u64,
|
||||
#[serde(default = "default_refresh_ttl")]
|
||||
pub refresh_token_ttl_secs: u64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize, Clone)]
|
||||
pub struct RetentionConfig {
|
||||
#[serde(default = "default_status_history_days")]
|
||||
pub status_history_days: u32,
|
||||
#[serde(default = "default_usb_events_days")]
|
||||
pub usb_events_days: u32,
|
||||
#[serde(default = "default_asset_changes_days")]
|
||||
pub asset_changes_days: u32,
|
||||
#[serde(default = "default_alert_records_days")]
|
||||
pub alert_records_days: u32,
|
||||
#[serde(default = "default_audit_log_days")]
|
||||
pub audit_log_days: u32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize, Clone, Default)]
|
||||
pub struct NotifyConfig {
|
||||
#[serde(default)]
|
||||
pub smtp: Option<SmtpConfig>,
|
||||
#[serde(default)]
|
||||
pub webhook_urls: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize, Clone)]
|
||||
pub struct SmtpConfig {
|
||||
pub host: String,
|
||||
pub port: u16,
|
||||
pub username: String,
|
||||
pub password: String,
|
||||
pub from: String,
|
||||
}
|
||||
|
||||
impl AppConfig {
|
||||
pub async fn load(path: &str) -> Result<Self> {
|
||||
if Path::new(path).exists() {
|
||||
let content = tokio::fs::read_to_string(path).await?;
|
||||
let config: AppConfig = toml::from_str(&content)?;
|
||||
Ok(config)
|
||||
} else {
|
||||
let config = default_config();
|
||||
// Write default config for reference
|
||||
let toml_str = toml::to_string_pretty(&config)?;
|
||||
tokio::fs::write(path, &toml_str).await?;
|
||||
tracing::warn!("Created default config at {}", path);
|
||||
Ok(config)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn default_access_ttl() -> u64 { 1800 } // 30 minutes
|
||||
fn default_refresh_ttl() -> u64 { 604800 } // 7 days
|
||||
fn default_status_history_days() -> u32 { 7 }
|
||||
fn default_usb_events_days() -> u32 { 90 }
|
||||
fn default_asset_changes_days() -> u32 { 365 }
|
||||
fn default_alert_records_days() -> u32 { 90 }
|
||||
fn default_audit_log_days() -> u32 { 365 }
|
||||
|
||||
pub fn default_config() -> AppConfig {
|
||||
AppConfig {
|
||||
server: ServerConfig {
|
||||
http_addr: "0.0.0.0:8080".into(),
|
||||
tcp_addr: "0.0.0.0:9999".into(),
|
||||
cors_origins: vec![],
|
||||
tls: None,
|
||||
},
|
||||
database: DatabaseConfig {
|
||||
path: "./csm.db".into(),
|
||||
},
|
||||
auth: AuthConfig {
|
||||
jwt_secret: uuid::Uuid::new_v4().to_string(),
|
||||
access_token_ttl_secs: default_access_ttl(),
|
||||
refresh_token_ttl_secs: default_refresh_ttl(),
|
||||
},
|
||||
retention: RetentionConfig {
|
||||
status_history_days: default_status_history_days(),
|
||||
usb_events_days: default_usb_events_days(),
|
||||
asset_changes_days: default_asset_changes_days(),
|
||||
alert_records_days: default_alert_records_days(),
|
||||
audit_log_days: default_audit_log_days(),
|
||||
},
|
||||
notify: NotifyConfig::default(),
|
||||
registration_token: uuid::Uuid::new_v4().to_string(),
|
||||
}
|
||||
}
|
||||
118
crates/server/src/db.rs
Normal file
118
crates/server/src/db.rs
Normal file
@@ -0,0 +1,118 @@
|
||||
use sqlx::SqlitePool;
|
||||
use anyhow::Result;
|
||||
|
||||
/// Database repository for device operations
|
||||
pub struct DeviceRepo;
|
||||
|
||||
impl DeviceRepo {
|
||||
pub async fn upsert_status(pool: &SqlitePool, device_uid: &str, status: &csm_protocol::DeviceStatus) -> Result<()> {
|
||||
let top_procs_json = serde_json::to_string(&status.top_processes)?;
|
||||
|
||||
// Update latest snapshot
|
||||
sqlx::query(
|
||||
"INSERT INTO device_status (device_uid, cpu_usage, memory_usage, memory_total_mb, disk_usage, disk_total_mb, network_rx_rate, network_tx_rate, running_procs, top_processes, reported_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, datetime('now'))
|
||||
ON CONFLICT(device_uid) DO UPDATE SET
|
||||
cpu_usage = excluded.cpu_usage,
|
||||
memory_usage = excluded.memory_usage,
|
||||
memory_total_mb = excluded.memory_total_mb,
|
||||
disk_usage = excluded.disk_usage,
|
||||
disk_total_mb = excluded.disk_total_mb,
|
||||
network_rx_rate = excluded.network_rx_rate,
|
||||
network_tx_rate = excluded.network_tx_rate,
|
||||
running_procs = excluded.running_procs,
|
||||
top_processes = excluded.top_processes,
|
||||
reported_at = datetime('now'),
|
||||
updated_at = datetime('now')"
|
||||
)
|
||||
.bind(device_uid)
|
||||
.bind(status.cpu_usage)
|
||||
.bind(status.memory_usage)
|
||||
.bind(status.memory_total_mb as i64)
|
||||
.bind(status.disk_usage)
|
||||
.bind(status.disk_total_mb as i64)
|
||||
.bind(status.network_rx_rate as i64)
|
||||
.bind(status.network_tx_rate as i64)
|
||||
.bind(status.running_procs as i32)
|
||||
.bind(&top_procs_json)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
// Insert into history
|
||||
sqlx::query(
|
||||
"INSERT INTO device_status_history (device_uid, cpu_usage, memory_usage, disk_usage, network_rx_rate, network_tx_rate, running_procs, reported_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, datetime('now'))"
|
||||
)
|
||||
.bind(device_uid)
|
||||
.bind(status.cpu_usage)
|
||||
.bind(status.memory_usage)
|
||||
.bind(status.disk_usage)
|
||||
.bind(status.network_rx_rate as i64)
|
||||
.bind(status.network_tx_rate as i64)
|
||||
.bind(status.running_procs as i32)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
// Update device heartbeat
|
||||
sqlx::query(
|
||||
"UPDATE devices SET status = 'online', last_heartbeat = datetime('now') WHERE device_uid = ?"
|
||||
)
|
||||
.bind(device_uid)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn insert_usb_event(pool: &SqlitePool, event: &csm_protocol::UsbEvent) -> Result<i64> {
|
||||
let result = sqlx::query(
|
||||
"INSERT INTO usb_events (device_uid, vendor_id, product_id, serial_number, device_name, event_type)
|
||||
VALUES (?, ?, ?, ?, ?, ?)"
|
||||
)
|
||||
.bind(&event.device_uid)
|
||||
.bind(&event.vendor_id)
|
||||
.bind(&event.product_id)
|
||||
.bind(&event.serial)
|
||||
.bind(&event.device_name)
|
||||
.bind(match event.event_type {
|
||||
csm_protocol::UsbEventType::Inserted => "inserted",
|
||||
csm_protocol::UsbEventType::Removed => "removed",
|
||||
csm_protocol::UsbEventType::Blocked => "blocked",
|
||||
})
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
Ok(result.last_insert_rowid())
|
||||
}
|
||||
|
||||
pub async fn upsert_hardware(pool: &SqlitePool, asset: &csm_protocol::HardwareAsset) -> Result<()> {
|
||||
sqlx::query(
|
||||
"INSERT INTO hardware_assets (device_uid, cpu_model, cpu_cores, memory_total_mb, disk_model, disk_total_mb, gpu_model, motherboard, serial_number, reported_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, datetime('now'))
|
||||
ON CONFLICT(device_uid) DO UPDATE SET
|
||||
cpu_model = excluded.cpu_model,
|
||||
cpu_cores = excluded.cpu_cores,
|
||||
memory_total_mb = excluded.memory_total_mb,
|
||||
disk_model = excluded.disk_model,
|
||||
disk_total_mb = excluded.disk_total_mb,
|
||||
gpu_model = excluded.gpu_model,
|
||||
motherboard = excluded.motherboard,
|
||||
serial_number = excluded.serial_number,
|
||||
reported_at = datetime('now'),
|
||||
updated_at = datetime('now')"
|
||||
)
|
||||
.bind(&asset.device_uid)
|
||||
.bind(&asset.cpu_model)
|
||||
.bind(asset.cpu_cores as i32)
|
||||
.bind(asset.memory_total_mb as i64)
|
||||
.bind(&asset.disk_model)
|
||||
.bind(asset.disk_total_mb as i64)
|
||||
.bind(&asset.gpu_model)
|
||||
.bind(&asset.motherboard)
|
||||
.bind(&asset.serial_number)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
264
crates/server/src/main.rs
Normal file
264
crates/server/src/main.rs
Normal file
@@ -0,0 +1,264 @@
|
||||
use anyhow::Result;
|
||||
use axum::Router;
|
||||
use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions, SqliteJournalMode};
|
||||
use std::path::Path;
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
use tokio::net::TcpListener;
|
||||
use tower_http::cors::{CorsLayer, Any};
|
||||
use tower_http::trace::TraceLayer;
|
||||
use tower_http::compression::CompressionLayer;
|
||||
use tower_http::set_header::SetResponseHeaderLayer;
|
||||
use tracing::{info, warn, error};
|
||||
|
||||
mod api;
|
||||
mod audit;
|
||||
mod config;
|
||||
mod db;
|
||||
mod tcp;
|
||||
mod ws;
|
||||
mod alert;
|
||||
|
||||
use config::AppConfig;
|
||||
|
||||
/// Application shared state
|
||||
#[derive(Clone)]
|
||||
pub struct AppState {
|
||||
pub db: sqlx::SqlitePool,
|
||||
pub config: Arc<AppConfig>,
|
||||
pub clients: Arc<tcp::ClientRegistry>,
|
||||
pub ws_hub: Arc<ws::WsHub>,
|
||||
pub login_limiter: Arc<api::auth::LoginRateLimiter>,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
// Initialize tracing
|
||||
tracing_subscriber::fmt()
|
||||
.with_env_filter(
|
||||
tracing_subscriber::EnvFilter::try_from_default_env()
|
||||
.unwrap_or_else(|_| "csm_server=info,tower_http=info".into()),
|
||||
)
|
||||
.json()
|
||||
.init();
|
||||
|
||||
info!("CSM Server starting...");
|
||||
|
||||
// Load configuration
|
||||
let config = AppConfig::load("config.toml").await?;
|
||||
let config = Arc::new(config);
|
||||
|
||||
// Initialize database
|
||||
let db = init_database(&config.database.path).await?;
|
||||
run_migrations(&db).await?;
|
||||
info!("Database initialized at {}", config.database.path);
|
||||
|
||||
// Ensure default admin exists
|
||||
ensure_default_admin(&db).await?;
|
||||
|
||||
// Initialize shared state
|
||||
let clients = Arc::new(tcp::ClientRegistry::new());
|
||||
let ws_hub = Arc::new(ws::WsHub::new());
|
||||
|
||||
let state = AppState {
|
||||
db: db.clone(),
|
||||
config: config.clone(),
|
||||
clients: clients.clone(),
|
||||
ws_hub: ws_hub.clone(),
|
||||
login_limiter: Arc::new(api::auth::LoginRateLimiter::new()),
|
||||
};
|
||||
|
||||
// Start background tasks
|
||||
let cleanup_state = state.clone();
|
||||
tokio::spawn(async move {
|
||||
alert::cleanup_task(cleanup_state).await;
|
||||
});
|
||||
|
||||
// Start TCP listener for client connections
|
||||
let tcp_state = state.clone();
|
||||
let tcp_addr = config.server.tcp_addr.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = tcp::start_tcp_server(tcp_addr, tcp_state).await {
|
||||
error!("TCP server error: {}", e);
|
||||
}
|
||||
});
|
||||
|
||||
// Build HTTP router
|
||||
let app = Router::new()
|
||||
.merge(api::routes(state.clone()))
|
||||
.layer(
|
||||
build_cors_layer(&config.server.cors_origins),
|
||||
)
|
||||
.layer(TraceLayer::new_for_http())
|
||||
.layer(CompressionLayer::new())
|
||||
// Security headers
|
||||
.layer(SetResponseHeaderLayer::if_not_present(
|
||||
axum::http::header::X_CONTENT_TYPE_OPTIONS,
|
||||
axum::http::HeaderValue::from_static("nosniff"),
|
||||
))
|
||||
.layer(SetResponseHeaderLayer::if_not_present(
|
||||
axum::http::header::X_FRAME_OPTIONS,
|
||||
axum::http::HeaderValue::from_static("DENY"),
|
||||
))
|
||||
.layer(SetResponseHeaderLayer::if_not_present(
|
||||
axum::http::header::HeaderName::from_static("x-xss-protection"),
|
||||
axum::http::HeaderValue::from_static("1; mode=block"),
|
||||
))
|
||||
.layer(SetResponseHeaderLayer::if_not_present(
|
||||
axum::http::header::HeaderName::from_static("referrer-policy"),
|
||||
axum::http::HeaderValue::from_static("strict-origin-when-cross-origin"),
|
||||
))
|
||||
.layer(SetResponseHeaderLayer::if_not_present(
|
||||
axum::http::header::HeaderName::from_static("content-security-policy"),
|
||||
axum::http::HeaderValue::from_static("default-src 'self'; script-src 'self' 'unsafe-inline' 'unsafe-eval'; style-src 'self' 'unsafe-inline'; img-src 'self' data:; connect-src 'self' ws: wss:"),
|
||||
))
|
||||
.with_state(state);
|
||||
|
||||
// Start HTTP server
|
||||
let http_addr = &config.server.http_addr;
|
||||
info!("HTTP server listening on {}", http_addr);
|
||||
let listener = TcpListener::bind(http_addr).await?;
|
||||
axum::serve(listener, app).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn init_database(db_path: &str) -> Result<sqlx::SqlitePool> {
|
||||
// Ensure parent directory exists for file-based databases
|
||||
// Strip sqlite: prefix if present for directory creation
|
||||
let file_path = db_path.strip_prefix("sqlite:").unwrap_or(db_path);
|
||||
// Strip query parameters
|
||||
let file_path = file_path.split('?').next().unwrap_or(file_path);
|
||||
if let Some(parent) = Path::new(file_path).parent() {
|
||||
if !parent.as_os_str().is_empty() {
|
||||
tokio::fs::create_dir_all(parent).await?;
|
||||
}
|
||||
}
|
||||
|
||||
let options = SqliteConnectOptions::from_str(db_path)?
|
||||
.journal_mode(SqliteJournalMode::Wal)
|
||||
.synchronous(sqlx::sqlite::SqliteSynchronous::Normal)
|
||||
.busy_timeout(std::time::Duration::from_secs(5))
|
||||
.foreign_keys(true);
|
||||
|
||||
let pool = SqlitePoolOptions::new()
|
||||
.max_connections(8)
|
||||
.connect_with(options)
|
||||
.await?;
|
||||
|
||||
// Set pragmas on each connection
|
||||
sqlx::query("PRAGMA cache_size = -64000")
|
||||
.execute(&pool)
|
||||
.await?;
|
||||
sqlx::query("PRAGMA wal_autocheckpoint = 1000")
|
||||
.execute(&pool)
|
||||
.await?;
|
||||
|
||||
Ok(pool)
|
||||
}
|
||||
|
||||
async fn run_migrations(pool: &sqlx::SqlitePool) -> Result<()> {
|
||||
// Embedded migrations - run in order
|
||||
let migrations = [
|
||||
include_str!("../../../migrations/001_init.sql"),
|
||||
include_str!("../../../migrations/002_assets.sql"),
|
||||
include_str!("../../../migrations/003_usb.sql"),
|
||||
include_str!("../../../migrations/004_alerts.sql"),
|
||||
include_str!("../../../migrations/005_plugins_web_filter.sql"),
|
||||
include_str!("../../../migrations/006_plugins_usage_timer.sql"),
|
||||
include_str!("../../../migrations/007_plugins_software_blocker.sql"),
|
||||
include_str!("../../../migrations/008_plugins_popup_blocker.sql"),
|
||||
include_str!("../../../migrations/009_plugins_usb_file_audit.sql"),
|
||||
include_str!("../../../migrations/010_plugins_watermark.sql"),
|
||||
include_str!("../../../migrations/011_token_security.sql"),
|
||||
];
|
||||
|
||||
// Create migrations tracking table
|
||||
sqlx::query(
|
||||
"CREATE TABLE IF NOT EXISTS _migrations (id INTEGER PRIMARY KEY, name TEXT NOT NULL UNIQUE, applied_at TEXT NOT NULL DEFAULT (datetime('now')))"
|
||||
)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
for (i, migration_sql) in migrations.iter().enumerate() {
|
||||
let name = format!("{:03}", i + 1);
|
||||
let exists: bool = sqlx::query_scalar::<_, i64>(
|
||||
"SELECT COUNT(*) FROM _migrations WHERE name = ?"
|
||||
)
|
||||
.bind(&name)
|
||||
.fetch_one(pool)
|
||||
.await? > 0;
|
||||
|
||||
if !exists {
|
||||
info!("Running migration: {}", name);
|
||||
sqlx::query(migration_sql)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
sqlx::query("INSERT INTO _migrations (name) VALUES (?)")
|
||||
.bind(&name)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn ensure_default_admin(pool: &sqlx::SqlitePool) -> Result<()> {
|
||||
let count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM users")
|
||||
.fetch_one(pool)
|
||||
.await?;
|
||||
|
||||
if count == 0 {
|
||||
// Generate a random 16-character alphanumeric password
|
||||
let random_password: String = {
|
||||
use std::fmt::Write;
|
||||
let bytes = uuid::Uuid::new_v4();
|
||||
let mut s = String::with_capacity(16);
|
||||
for byte in bytes.as_bytes().iter().take(16) {
|
||||
write!(s, "{:02x}", byte).unwrap();
|
||||
}
|
||||
s
|
||||
};
|
||||
|
||||
let password_hash = bcrypt::hash(&random_password, 12)?;
|
||||
sqlx::query(
|
||||
"INSERT INTO users (username, password, role) VALUES (?, ?, 'admin')"
|
||||
)
|
||||
.bind("admin")
|
||||
.bind(&password_hash)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
warn!("Created default admin user (username: admin)");
|
||||
// Print password directly to stderr — bypasses tracing JSON formatter
|
||||
eprintln!("============================================================");
|
||||
eprintln!(" Generated admin password: {}", random_password);
|
||||
eprintln!(" *** Save this password now — it will NOT be shown again! ***");
|
||||
eprintln!("============================================================");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Build CORS layer from configured origins.
|
||||
/// If cors_origins is empty, no CORS headers are sent (same-origin only).
|
||||
/// If origins are specified, only those are allowed.
|
||||
fn build_cors_layer(origins: &[String]) -> CorsLayer {
|
||||
use axum::http::HeaderValue;
|
||||
|
||||
let allowed_origins: Vec<HeaderValue> = origins.iter()
|
||||
.filter_map(|o| o.parse::<HeaderValue>().ok())
|
||||
.collect();
|
||||
|
||||
if allowed_origins.is_empty() {
|
||||
// No CORS — production safe by default
|
||||
CorsLayer::new()
|
||||
} else {
|
||||
CorsLayer::new()
|
||||
.allow_origin(tower_http::cors::AllowOrigin::list(allowed_origins))
|
||||
.allow_methods(Any)
|
||||
.allow_headers(Any)
|
||||
.max_age(std::time::Duration::from_secs(3600))
|
||||
}
|
||||
}
|
||||
844
crates/server/src/tcp.rs
Normal file
844
crates/server/src/tcp.rs
Normal file
@@ -0,0 +1,844 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
use tokio::sync::RwLock;
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
use tracing::{info, warn, debug};
|
||||
use hmac::{Hmac, Mac};
|
||||
use sha2::Sha256;
|
||||
use csm_protocol::{Frame, MessageType, PROTOCOL_VERSION};
|
||||
use crate::AppState;
|
||||
|
||||
/// Maximum frames per second per connection before rate-limiting kicks in
|
||||
const RATE_LIMIT_WINDOW_SECS: u64 = 5;
|
||||
const RATE_LIMIT_MAX_FRAMES: usize = 100;
|
||||
|
||||
/// Per-connection rate limiter using a sliding window of frame timestamps
|
||||
struct RateLimiter {
|
||||
timestamps: Vec<Instant>,
|
||||
}
|
||||
|
||||
impl RateLimiter {
|
||||
fn new() -> Self {
|
||||
Self { timestamps: Vec::with_capacity(RATE_LIMIT_MAX_FRAMES) }
|
||||
}
|
||||
|
||||
/// Returns false if the connection is rate-limited
|
||||
fn check(&mut self) -> bool {
|
||||
let now = Instant::now();
|
||||
let cutoff = now - std::time::Duration::from_secs(RATE_LIMIT_WINDOW_SECS);
|
||||
|
||||
// Evict timestamps outside the window
|
||||
self.timestamps.retain(|t| *t > cutoff);
|
||||
|
||||
if self.timestamps.len() >= RATE_LIMIT_MAX_FRAMES {
|
||||
return false;
|
||||
}
|
||||
|
||||
self.timestamps.push(now);
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
/// Push a plugin config frame to all online clients matching the target scope.
|
||||
/// target_type: "global" | "device" | "group"
|
||||
/// target_id: device_uid or group_name (None for global)
|
||||
pub async fn push_to_targets(
|
||||
db: &sqlx::SqlitePool,
|
||||
clients: &crate::tcp::ClientRegistry,
|
||||
msg_type: MessageType,
|
||||
payload: &impl serde::Serialize,
|
||||
target_type: &str,
|
||||
target_id: Option<&str>,
|
||||
) {
|
||||
let frame = match Frame::new_json(msg_type, payload) {
|
||||
Ok(f) => f.encode(),
|
||||
Err(e) => {
|
||||
warn!("Failed to encode plugin push frame: {}", e);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let online = clients.list_online().await;
|
||||
let mut pushed_count = 0usize;
|
||||
|
||||
// For group targeting, resolve group members from DB once
|
||||
let group_members: Option<Vec<String>> = if target_type == "group" {
|
||||
if let Some(group_name) = target_id {
|
||||
sqlx::query_scalar::<_, String>(
|
||||
"SELECT device_uid FROM devices WHERE group_name = ?"
|
||||
)
|
||||
.bind(group_name)
|
||||
.fetch_all(db)
|
||||
.await
|
||||
.ok()
|
||||
.into()
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
for uid in &online {
|
||||
let should_push = match target_type {
|
||||
"global" => true,
|
||||
"device" => target_id.map_or(false, |id| id == uid),
|
||||
"group" => {
|
||||
if let Some(members) = &group_members {
|
||||
members.contains(uid)
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
other => {
|
||||
warn!("Unknown target_type '{}', skipping push", other);
|
||||
false
|
||||
}
|
||||
};
|
||||
if should_push {
|
||||
if clients.send_to(uid, frame.clone()).await {
|
||||
pushed_count += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
debug!("Pushed {:?} to {}/{} online clients (target={})", msg_type, pushed_count, online.len(), target_type);
|
||||
}
|
||||
|
||||
/// Push all active plugin configs to a newly registered client.
|
||||
pub async fn push_all_plugin_configs(
|
||||
db: &sqlx::SqlitePool,
|
||||
clients: &crate::tcp::ClientRegistry,
|
||||
device_uid: &str,
|
||||
) {
|
||||
use sqlx::Row;
|
||||
|
||||
// Watermark configs — only push the highest-priority enabled config (device > group > global)
|
||||
if let Ok(rows) = sqlx::query(
|
||||
"SELECT content, font_size, opacity, color, angle, enabled FROM watermark_config WHERE enabled = 1 AND (target_type = 'global' OR (target_type = 'device' AND target_id = ?) OR (target_type = 'group' AND target_id = (SELECT group_name FROM devices WHERE device_uid = ?))) ORDER BY CASE WHEN target_type = 'device' THEN 0 WHEN target_type = 'group' THEN 1 ELSE 2 END LIMIT 1"
|
||||
)
|
||||
.bind(device_uid)
|
||||
.bind(device_uid)
|
||||
.bind(device_uid)
|
||||
.fetch_all(db).await
|
||||
{
|
||||
if let Some(row) = rows.first() {
|
||||
let config = csm_protocol::WatermarkConfigPayload {
|
||||
content: row.get("content"),
|
||||
font_size: row.get::<i32, _>("font_size") as u32,
|
||||
opacity: row.get("opacity"),
|
||||
color: row.get("color"),
|
||||
angle: row.get::<i32, _>("angle"),
|
||||
enabled: row.get("enabled"),
|
||||
};
|
||||
if let Ok(frame) = Frame::new_json(MessageType::WatermarkConfig, &config) {
|
||||
clients.send_to(device_uid, frame.encode()).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Web filter rules
|
||||
if let Ok(rows) = sqlx::query(
|
||||
"SELECT id, rule_type, pattern FROM web_filter_rules WHERE enabled = 1 AND (target_type = 'global' OR (target_type = 'device' AND target_id = ?) OR (target_type = 'group' AND target_id = (SELECT group_name FROM devices WHERE device_uid = ?)))"
|
||||
)
|
||||
.bind(device_uid)
|
||||
.bind(device_uid)
|
||||
.fetch_all(db).await
|
||||
{
|
||||
let rules: Vec<serde_json::Value> = rows.iter().map(|r| serde_json::json!({
|
||||
"id": r.get::<i64, _>("id"),
|
||||
"rule_type": r.get::<String, _>("rule_type"),
|
||||
"pattern": r.get::<String, _>("pattern"),
|
||||
})).collect();
|
||||
if !rules.is_empty() {
|
||||
if let Ok(frame) = Frame::new_json(MessageType::WebFilterRuleUpdate, &serde_json::json!({"rules": rules})) {
|
||||
clients.send_to(device_uid, frame.encode()).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Software blacklist
|
||||
if let Ok(rows) = sqlx::query(
|
||||
"SELECT id, name_pattern, action FROM software_blacklist WHERE enabled = 1 AND (target_type = 'global' OR (target_type = 'device' AND target_id = ?) OR (target_type = 'group' AND target_id = (SELECT group_name FROM devices WHERE device_uid = ?)))"
|
||||
)
|
||||
.bind(device_uid)
|
||||
.bind(device_uid)
|
||||
.fetch_all(db).await
|
||||
{
|
||||
let entries: Vec<serde_json::Value> = rows.iter().map(|r| serde_json::json!({
|
||||
"id": r.get::<i64, _>("id"),
|
||||
"name_pattern": r.get::<String, _>("name_pattern"),
|
||||
"action": r.get::<String, _>("action"),
|
||||
})).collect();
|
||||
if !entries.is_empty() {
|
||||
if let Ok(frame) = Frame::new_json(MessageType::SoftwareBlacklist, &serde_json::json!({"blacklist": entries})) {
|
||||
clients.send_to(device_uid, frame.encode()).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Popup blocker rules
|
||||
if let Ok(rows) = sqlx::query(
|
||||
"SELECT id, rule_type, window_title, window_class, process_name FROM popup_filter_rules WHERE enabled = 1 AND (target_type = 'global' OR (target_type = 'device' AND target_id = ?) OR (target_type = 'group' AND target_id = (SELECT group_name FROM devices WHERE device_uid = ?)))"
|
||||
)
|
||||
.bind(device_uid)
|
||||
.bind(device_uid)
|
||||
.fetch_all(db).await
|
||||
{
|
||||
let rules: Vec<serde_json::Value> = rows.iter().map(|r| serde_json::json!({
|
||||
"id": r.get::<i64, _>("id"),
|
||||
"rule_type": r.get::<String, _>("rule_type"),
|
||||
"window_title": r.get::<Option<String>, _>("window_title"),
|
||||
"window_class": r.get::<Option<String>, _>("window_class"),
|
||||
"process_name": r.get::<Option<String>, _>("process_name"),
|
||||
})).collect();
|
||||
if !rules.is_empty() {
|
||||
if let Ok(frame) = Frame::new_json(MessageType::PopupRules, &serde_json::json!({"rules": rules})) {
|
||||
clients.send_to(device_uid, frame.encode()).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// USB policies — push highest-priority enabled policy for the device's group
|
||||
if let Ok(rows) = sqlx::query(
|
||||
"SELECT policy_type, rules, enabled FROM usb_policies WHERE enabled = 1 AND target_group = (SELECT group_name FROM devices WHERE device_uid = ?) ORDER BY CASE WHEN policy_type = 'all_block' THEN 0 WHEN policy_type = 'blacklist' THEN 1 ELSE 2 END LIMIT 1"
|
||||
)
|
||||
.bind(device_uid)
|
||||
.fetch_all(db).await
|
||||
{
|
||||
if let Some(row) = rows.first() {
|
||||
let policy_type: String = row.get("policy_type");
|
||||
let rules_json: String = row.get("rules");
|
||||
let rules: Vec<serde_json::Value> = serde_json::from_str(&rules_json).unwrap_or_default();
|
||||
let payload = csm_protocol::UsbPolicyPayload {
|
||||
policy_type,
|
||||
enabled: true,
|
||||
rules: rules.iter().map(|r| csm_protocol::UsbDeviceRule {
|
||||
vendor_id: r.get("vendor_id").and_then(|v| v.as_str().map(String::from)),
|
||||
product_id: r.get("product_id").and_then(|v| v.as_str().map(String::from)),
|
||||
serial: r.get("serial").and_then(|v| v.as_str().map(String::from)),
|
||||
device_name: r.get("device_name").and_then(|v| v.as_str().map(String::from)),
|
||||
}).collect(),
|
||||
};
|
||||
if let Ok(frame) = Frame::new_json(MessageType::UsbPolicyUpdate, &payload) {
|
||||
clients.send_to(device_uid, frame.encode()).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
info!("Pushed all plugin configs to newly registered device {}", device_uid);
|
||||
}
|
||||
|
||||
/// Maximum accumulated read buffer size per connection (8 MB)
|
||||
const MAX_READ_BUF_SIZE: usize = 8 * 1024 * 1024;
|
||||
|
||||
/// Registry of all connected client sessions
|
||||
#[derive(Clone, Default)]
|
||||
pub struct ClientRegistry {
|
||||
sessions: Arc<RwLock<HashMap<String, Arc<tokio::sync::mpsc::Sender<Vec<u8>>>>>>,
|
||||
}
|
||||
|
||||
impl ClientRegistry {
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
pub async fn register(&self, device_uid: String, tx: Arc<tokio::sync::mpsc::Sender<Vec<u8>>>) {
|
||||
self.sessions.write().await.insert(device_uid, tx);
|
||||
}
|
||||
|
||||
pub async fn unregister(&self, device_uid: &str) {
|
||||
self.sessions.write().await.remove(device_uid);
|
||||
}
|
||||
|
||||
pub async fn send_to(&self, device_uid: &str, data: Vec<u8>) -> bool {
|
||||
if let Some(tx) = self.sessions.read().await.get(device_uid) {
|
||||
tx.send(data).await.is_ok()
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn count(&self) -> usize {
|
||||
self.sessions.read().await.len()
|
||||
}
|
||||
|
||||
pub async fn list_online(&self) -> Vec<String> {
|
||||
self.sessions.read().await.keys().cloned().collect()
|
||||
}
|
||||
}
|
||||
|
||||
/// Start the TCP server for client connections (optionally with TLS)
|
||||
pub async fn start_tcp_server(addr: String, state: AppState) -> anyhow::Result<()> {
|
||||
let listener = TcpListener::bind(&addr).await?;
|
||||
|
||||
// Build TLS acceptor if configured
|
||||
let tls_acceptor = build_tls_acceptor(&state.config.server.tls)?;
|
||||
|
||||
if tls_acceptor.is_some() {
|
||||
info!("TCP server listening on {} (TLS enabled)", addr);
|
||||
} else {
|
||||
info!("TCP server listening on {} (plaintext)", addr);
|
||||
}
|
||||
|
||||
loop {
|
||||
let (stream, peer_addr) = listener.accept().await?;
|
||||
let state = state.clone();
|
||||
let acceptor = tls_acceptor.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
debug!("New TCP connection from {}", peer_addr);
|
||||
match acceptor {
|
||||
Some(acceptor) => {
|
||||
match acceptor.accept(stream).await {
|
||||
Ok(tls_stream) => {
|
||||
if let Err(e) = handle_client_tls(tls_stream, state).await {
|
||||
warn!("Client {} TLS error: {}", peer_addr, e);
|
||||
}
|
||||
}
|
||||
Err(e) => warn!("TLS handshake failed for {}: {}", peer_addr, e),
|
||||
}
|
||||
}
|
||||
None => {
|
||||
if let Err(e) = handle_client(stream, state).await {
|
||||
warn!("Client {} error: {}", peer_addr, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
fn build_tls_acceptor(
|
||||
tls_config: &Option<crate::config::TlsConfig>,
|
||||
) -> anyhow::Result<Option<tokio_rustls::TlsAcceptor>> {
|
||||
let config = match tls_config {
|
||||
Some(c) => c,
|
||||
None => return Ok(None),
|
||||
};
|
||||
|
||||
let cert_pem = std::fs::read(&config.cert_path)
|
||||
.map_err(|e| anyhow::anyhow!("Failed to read TLS cert {}: {}", config.cert_path, e))?;
|
||||
let key_pem = std::fs::read(&config.key_path)
|
||||
.map_err(|e| anyhow::anyhow!("Failed to read TLS key {}: {}", config.key_path, e))?;
|
||||
|
||||
let certs: Vec<rustls_pki_types::CertificateDer> = rustls_pemfile::certs(&mut &cert_pem[..])
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
.map_err(|e| anyhow::anyhow!("Failed to parse TLS cert: {:?}", e))?
|
||||
.into_iter()
|
||||
.map(|c| c.into())
|
||||
.collect();
|
||||
|
||||
let key = rustls_pemfile::private_key(&mut &key_pem[..])
|
||||
.map_err(|e| anyhow::anyhow!("Failed to parse TLS key: {:?}", e))?
|
||||
.ok_or_else(|| anyhow::anyhow!("No private key found in {}", config.key_path))?;
|
||||
|
||||
let server_config = rustls::ServerConfig::builder()
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(certs, key)
|
||||
.map_err(|e| anyhow::anyhow!("Failed to build TLS config: {}", e))?;
|
||||
|
||||
Ok(Some(tokio_rustls::TlsAcceptor::from(Arc::new(server_config))))
|
||||
}
|
||||
|
||||
/// Cleanup on client disconnect: unregister from client map, mark offline, notify WS.
|
||||
async fn cleanup_on_disconnect(state: &AppState, device_uid: &Option<String>) {
|
||||
if let Some(uid) = device_uid {
|
||||
state.clients.unregister(uid).await;
|
||||
sqlx::query("UPDATE devices SET status = 'offline' WHERE device_uid = ?")
|
||||
.bind(uid)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
.ok();
|
||||
|
||||
state.ws_hub.broadcast(serde_json::json!({
|
||||
"type": "device_state",
|
||||
"device_uid": uid,
|
||||
"status": "offline"
|
||||
}).to_string()).await;
|
||||
|
||||
info!("Device disconnected: {}", uid);
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute HMAC-SHA256 for heartbeat verification.
|
||||
/// Format: HMAC-SHA256(device_secret, "{device_uid}\n{timestamp}") → hex-encoded
|
||||
fn compute_hmac(secret: &str, device_uid: &str, timestamp: &str) -> String {
|
||||
type HmacSha256 = Hmac<Sha256>;
|
||||
|
||||
let message = format!("{}\n{}", device_uid, timestamp);
|
||||
let mut mac = match HmacSha256::new_from_slice(secret.as_bytes()) {
|
||||
Ok(m) => m,
|
||||
Err(_) => return String::new(),
|
||||
};
|
||||
mac.update(message.as_bytes());
|
||||
hex::encode(mac.finalize().into_bytes())
|
||||
}
|
||||
|
||||
/// Verify that a frame sender is a registered device and that the claimed device_uid
|
||||
/// matches the one registered on this connection. Returns true if valid.
|
||||
fn verify_device_uid(device_uid: &Option<String>, msg_type: &str, claimed_uid: &str) -> bool {
|
||||
match device_uid {
|
||||
Some(uid) if *uid == claimed_uid => true,
|
||||
Some(uid) => {
|
||||
warn!("{} device_uid mismatch: expected {:?}, got {}", msg_type, uid, claimed_uid);
|
||||
false
|
||||
}
|
||||
None => {
|
||||
warn!("{} from unregistered connection", msg_type);
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Process a single decoded frame. Shared by both plaintext and TLS handlers.
|
||||
async fn process_frame(
|
||||
frame: Frame,
|
||||
state: &AppState,
|
||||
device_uid: &mut Option<String>,
|
||||
tx: &Arc<tokio::sync::mpsc::Sender<Vec<u8>>>,
|
||||
) -> anyhow::Result<()> {
|
||||
match frame.msg_type {
|
||||
MessageType::Register => {
|
||||
let req: csm_protocol::RegisterRequest = frame.decode_payload()
|
||||
.map_err(|e| anyhow::anyhow!("Invalid registration payload: {}", e))?;
|
||||
|
||||
info!("Device registration attempt: {} ({})", req.hostname, req.device_uid);
|
||||
|
||||
// Validate registration token against configured token
|
||||
let expected_token = &state.config.registration_token;
|
||||
if !expected_token.is_empty() {
|
||||
if req.registration_token.is_empty() || req.registration_token != *expected_token {
|
||||
warn!("Registration rejected for {}: invalid token", req.device_uid);
|
||||
let err_frame = Frame::new_json(MessageType::RegisterResponse,
|
||||
&serde_json::json!({"error": "invalid_registration_token"}))?;
|
||||
tx.send(err_frame.encode()).await.ok();
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
||||
// Check if device already exists with a secret (reconnection scenario)
|
||||
let existing_secret: Option<String> = sqlx::query_scalar(
|
||||
"SELECT device_secret FROM devices WHERE device_uid = ?"
|
||||
)
|
||||
.bind(&req.device_uid)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
.ok()
|
||||
.flatten();
|
||||
|
||||
let device_secret = match existing_secret {
|
||||
// Existing device — keep the same secret, don't rotate
|
||||
Some(secret) if !secret.is_empty() => secret,
|
||||
// New device — generate a fresh secret
|
||||
_ => uuid::Uuid::new_v4().to_string(),
|
||||
};
|
||||
|
||||
sqlx::query(
|
||||
"INSERT INTO devices (device_uid, hostname, ip_address, mac_address, os_version, device_secret, status) \
|
||||
VALUES (?, ?, '0.0.0.0', ?, ?, ?, 'online') \
|
||||
ON CONFLICT(device_uid) DO UPDATE SET hostname=excluded.hostname, os_version=excluded.os_version, \
|
||||
mac_address=excluded.mac_address, status='online'"
|
||||
)
|
||||
.bind(&req.device_uid)
|
||||
.bind(&req.hostname)
|
||||
.bind(&req.mac_address)
|
||||
.bind(&req.os_version)
|
||||
.bind(&device_secret)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("DB error during registration: {}", e))?;
|
||||
|
||||
*device_uid = Some(req.device_uid.clone());
|
||||
// If this device was already connected on a different session, evict the old one
|
||||
// The new register() call will replace it in the hashmap
|
||||
state.clients.register(req.device_uid.clone(), tx.clone()).await;
|
||||
|
||||
// Send registration response
|
||||
let config = csm_protocol::ClientConfig::default();
|
||||
let response = csm_protocol::RegisterResponse {
|
||||
device_secret,
|
||||
config,
|
||||
};
|
||||
let resp_frame = Frame::new_json(MessageType::RegisterResponse, &response)?;
|
||||
tx.send(resp_frame.encode()).await?;
|
||||
|
||||
info!("Device registered successfully: {} ({})", req.hostname, req.device_uid);
|
||||
|
||||
// Push all active plugin configs to newly registered client
|
||||
push_all_plugin_configs(&state.db, &state.clients, &req.device_uid).await;
|
||||
}
|
||||
|
||||
MessageType::Heartbeat => {
|
||||
let heartbeat: csm_protocol::HeartbeatPayload = frame.decode_payload()
|
||||
.map_err(|e| anyhow::anyhow!("Invalid heartbeat: {}", e))?;
|
||||
|
||||
if !verify_device_uid(device_uid, "Heartbeat", &heartbeat.device_uid) {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Verify HMAC — reject if secret exists but HMAC is missing or wrong
|
||||
let secret: Option<String> = sqlx::query_scalar(
|
||||
"SELECT device_secret FROM devices WHERE device_uid = ?"
|
||||
)
|
||||
.bind(&heartbeat.device_uid)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
warn!("DB error fetching device_secret for {}: {}", heartbeat.device_uid, e);
|
||||
anyhow::anyhow!("DB error during HMAC verification")
|
||||
})?;
|
||||
|
||||
if let Some(ref secret) = secret {
|
||||
if !secret.is_empty() {
|
||||
if heartbeat.hmac.is_empty() {
|
||||
warn!("Heartbeat missing HMAC for device {}", heartbeat.device_uid);
|
||||
return Ok(());
|
||||
}
|
||||
// Constant-time HMAC verification using hmac::Mac::verify_slice
|
||||
let message = format!("{}\n{}", heartbeat.device_uid, heartbeat.timestamp);
|
||||
let mut mac = Hmac::<Sha256>::new_from_slice(secret.as_bytes())
|
||||
.map_err(|_| anyhow::anyhow!("HMAC key error"))?;
|
||||
mac.update(message.as_bytes());
|
||||
let provided_bytes = hex::decode(&heartbeat.hmac).unwrap_or_default();
|
||||
if mac.verify_slice(&provided_bytes).is_err() {
|
||||
warn!("Heartbeat HMAC mismatch for device {}", heartbeat.device_uid);
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
debug!("Heartbeat from {} (hmac verified)", heartbeat.device_uid);
|
||||
|
||||
// Update device status in DB
|
||||
sqlx::query("UPDATE devices SET status = 'online', last_heartbeat = datetime('now') WHERE device_uid = ?")
|
||||
.bind(&heartbeat.device_uid)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
.ok();
|
||||
|
||||
// Push to WebSocket subscribers
|
||||
state.ws_hub.broadcast(serde_json::json!({
|
||||
"type": "device_state",
|
||||
"device_uid": heartbeat.device_uid,
|
||||
"status": "online"
|
||||
}).to_string()).await;
|
||||
}
|
||||
|
||||
MessageType::StatusReport => {
|
||||
let status: csm_protocol::DeviceStatus = frame.decode_payload()
|
||||
.map_err(|e| anyhow::anyhow!("Invalid status report: {}", e))?;
|
||||
|
||||
if !verify_device_uid(device_uid, "StatusReport", &status.device_uid) {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
crate::db::DeviceRepo::upsert_status(&state.db, &status.device_uid, &status).await?;
|
||||
|
||||
// Push to WebSocket subscribers
|
||||
state.ws_hub.broadcast(serde_json::json!({
|
||||
"type": "device_status",
|
||||
"device_uid": status.device_uid,
|
||||
"cpu": status.cpu_usage,
|
||||
"memory": status.memory_usage,
|
||||
"disk": status.disk_usage
|
||||
}).to_string()).await;
|
||||
}
|
||||
|
||||
MessageType::UsbEvent => {
|
||||
let event: csm_protocol::UsbEvent = frame.decode_payload()
|
||||
.map_err(|e| anyhow::anyhow!("Invalid USB event: {}", e))?;
|
||||
|
||||
if !verify_device_uid(device_uid, "UsbEvent", &event.device_uid) {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
crate::db::DeviceRepo::insert_usb_event(&state.db, &event).await?;
|
||||
|
||||
state.ws_hub.broadcast(serde_json::json!({
|
||||
"type": "usb_event",
|
||||
"device_uid": event.device_uid,
|
||||
"event": event.event_type,
|
||||
"usb_name": event.device_name
|
||||
}).to_string()).await;
|
||||
}
|
||||
|
||||
MessageType::AssetReport => {
|
||||
let asset: csm_protocol::HardwareAsset = frame.decode_payload()
|
||||
.map_err(|e| anyhow::anyhow!("Invalid asset report: {}", e))?;
|
||||
|
||||
if !verify_device_uid(device_uid, "AssetReport", &asset.device_uid) {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
crate::db::DeviceRepo::upsert_hardware(&state.db, &asset).await?;
|
||||
}
|
||||
|
||||
MessageType::UsageReport => {
|
||||
let report: csm_protocol::UsageDailyReport = frame.decode_payload()
|
||||
.map_err(|e| anyhow::anyhow!("Invalid usage report: {}", e))?;
|
||||
|
||||
if !verify_device_uid(device_uid, "UsageReport", &report.device_uid) {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
sqlx::query(
|
||||
"INSERT INTO usage_daily (device_uid, date, total_active_minutes, total_idle_minutes, first_active_at, last_active_at) \
|
||||
VALUES (?, ?, ?, ?, ?, ?) \
|
||||
ON CONFLICT(device_uid, date) DO UPDATE SET \
|
||||
total_active_minutes = excluded.total_active_minutes, \
|
||||
total_idle_minutes = excluded.total_idle_minutes, \
|
||||
first_active_at = excluded.first_active_at, \
|
||||
last_active_at = excluded.last_active_at"
|
||||
)
|
||||
.bind(&report.device_uid)
|
||||
.bind(&report.date)
|
||||
.bind(report.total_active_minutes as i32)
|
||||
.bind(report.total_idle_minutes as i32)
|
||||
.bind(&report.first_active_at)
|
||||
.bind(&report.last_active_at)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("DB error inserting usage report: {}", e))?;
|
||||
|
||||
debug!("Usage report saved for device {}", report.device_uid);
|
||||
}
|
||||
|
||||
MessageType::AppUsageReport => {
|
||||
let report: csm_protocol::AppUsageEntry = frame.decode_payload()
|
||||
.map_err(|e| anyhow::anyhow!("Invalid app usage report: {}", e))?;
|
||||
|
||||
if !verify_device_uid(device_uid, "AppUsageReport", &report.device_uid) {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
sqlx::query(
|
||||
"INSERT INTO app_usage_daily (device_uid, date, app_name, usage_minutes) \
|
||||
VALUES (?, ?, ?, ?) \
|
||||
ON CONFLICT(device_uid, date, app_name) DO UPDATE SET \
|
||||
usage_minutes = MAX(usage_minutes, excluded.usage_minutes)"
|
||||
)
|
||||
.bind(&report.device_uid)
|
||||
.bind(&report.date)
|
||||
.bind(&report.app_name)
|
||||
.bind(report.usage_minutes as i32)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("DB error inserting app usage: {}", e))?;
|
||||
|
||||
debug!("App usage saved: {} -> {} ({} min)", report.device_uid, report.app_name, report.usage_minutes);
|
||||
}
|
||||
|
||||
MessageType::SoftwareViolation => {
|
||||
let report: csm_protocol::SoftwareViolationReport = frame.decode_payload()
|
||||
.map_err(|e| anyhow::anyhow!("Invalid software violation: {}", e))?;
|
||||
|
||||
if !verify_device_uid(device_uid, "SoftwareViolation", &report.device_uid) {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
sqlx::query(
|
||||
"INSERT INTO software_violations (device_uid, software_name, action_taken, timestamp) VALUES (?, ?, ?, ?)"
|
||||
)
|
||||
.bind(&report.device_uid)
|
||||
.bind(&report.software_name)
|
||||
.bind(&report.action_taken)
|
||||
.bind(&report.timestamp)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("DB error inserting software violation: {}", e))?;
|
||||
|
||||
info!("Software violation: {} tried to run {} -> {}", report.device_uid, report.software_name, report.action_taken);
|
||||
|
||||
state.ws_hub.broadcast(serde_json::json!({
|
||||
"type": "software_violation",
|
||||
"device_uid": report.device_uid,
|
||||
"software_name": report.software_name,
|
||||
"action_taken": report.action_taken
|
||||
}).to_string()).await;
|
||||
}
|
||||
|
||||
MessageType::UsbFileOp => {
|
||||
let entry: csm_protocol::UsbFileOpEntry = frame.decode_payload()
|
||||
.map_err(|e| anyhow::anyhow!("Invalid USB file op: {}", e))?;
|
||||
|
||||
if !verify_device_uid(device_uid, "UsbFileOp", &entry.device_uid) {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
sqlx::query(
|
||||
"INSERT INTO usb_file_operations (device_uid, usb_serial, drive_letter, operation, file_path, file_size, timestamp) VALUES (?, ?, ?, ?, ?, ?, ?)"
|
||||
)
|
||||
.bind(&entry.device_uid)
|
||||
.bind(&entry.usb_serial)
|
||||
.bind(&entry.drive_letter)
|
||||
.bind(&entry.operation)
|
||||
.bind(&entry.file_path)
|
||||
.bind(entry.file_size.map(|s| s as i64))
|
||||
.bind(&entry.timestamp)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("DB error inserting USB file op: {}", e))?;
|
||||
|
||||
debug!("USB file op: {} {} on {}", entry.operation, entry.file_path, entry.device_uid);
|
||||
}
|
||||
|
||||
MessageType::WebAccessLog => {
|
||||
let entry: csm_protocol::WebAccessLogEntry = frame.decode_payload()
|
||||
.map_err(|e| anyhow::anyhow!("Invalid web access log: {}", e))?;
|
||||
|
||||
if !verify_device_uid(device_uid, "WebAccessLog", &entry.device_uid) {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
sqlx::query(
|
||||
"INSERT INTO web_access_log (device_uid, url, action, timestamp) VALUES (?, ?, ?, ?)"
|
||||
)
|
||||
.bind(&entry.device_uid)
|
||||
.bind(&entry.url)
|
||||
.bind(&entry.action)
|
||||
.bind(&entry.timestamp)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("DB error inserting web access log: {}", e))?;
|
||||
|
||||
debug!("Web access log: {} {} {}", entry.device_uid, entry.action, entry.url);
|
||||
}
|
||||
|
||||
_ => {
|
||||
debug!("Unhandled message type: {:?}", frame.msg_type);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handle a single client TCP connection
|
||||
async fn handle_client(stream: TcpStream, state: AppState) -> anyhow::Result<()> {
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
|
||||
// Set read timeout to detect stale connections
|
||||
let _ = stream.set_nodelay(true);
|
||||
|
||||
let (mut reader, mut writer) = stream.into_split();
|
||||
let (tx, mut rx) = tokio::sync::mpsc::channel::<Vec<u8>>(256);
|
||||
let tx = Arc::new(tx);
|
||||
|
||||
let mut buffer = vec![0u8; 65536];
|
||||
let mut read_buf = Vec::with_capacity(65536);
|
||||
let mut device_uid: Option<String> = None;
|
||||
let mut rate_limiter = RateLimiter::new();
|
||||
|
||||
// Writer task: forwards messages from channel to TCP stream
|
||||
let write_task = tokio::spawn(async move {
|
||||
while let Some(data) = rx.recv().await {
|
||||
if writer.write_all(&data).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Reader loop
|
||||
'reader: loop {
|
||||
let n = reader.read(&mut buffer).await?;
|
||||
if n == 0 {
|
||||
break; // Connection closed
|
||||
}
|
||||
read_buf.extend_from_slice(&buffer[..n]);
|
||||
|
||||
// Guard against unbounded buffer growth
|
||||
if read_buf.len() > MAX_READ_BUF_SIZE {
|
||||
warn!("Connection exceeded max buffer size, dropping");
|
||||
break;
|
||||
}
|
||||
|
||||
// Process complete frames
|
||||
while let Some(frame) = Frame::decode(&read_buf)? {
|
||||
let frame_size = frame.encoded_size();
|
||||
// Remove consumed bytes without reallocating
|
||||
read_buf.drain(..frame_size);
|
||||
|
||||
// Rate limit check
|
||||
if !rate_limiter.check() {
|
||||
warn!("Rate limit exceeded for device {:?}, dropping connection", device_uid);
|
||||
break 'reader;
|
||||
}
|
||||
|
||||
// Verify protocol version
|
||||
if frame.version != PROTOCOL_VERSION {
|
||||
warn!("Unsupported protocol version: 0x{:02X}", frame.version);
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Err(e) = process_frame(frame, &state, &mut device_uid, &tx).await {
|
||||
warn!("Frame processing error: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
cleanup_on_disconnect(&state, &device_uid).await;
|
||||
write_task.abort();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handle a TLS-wrapped client connection
|
||||
async fn handle_client_tls(
|
||||
stream: tokio_rustls::server::TlsStream<TcpStream>,
|
||||
state: AppState,
|
||||
) -> anyhow::Result<()> {
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
|
||||
let (mut reader, mut writer) = tokio::io::split(stream);
|
||||
let (tx, mut rx) = tokio::sync::mpsc::channel::<Vec<u8>>(256);
|
||||
let tx = Arc::new(tx);
|
||||
|
||||
let mut buffer = vec![0u8; 65536];
|
||||
let mut read_buf = Vec::with_capacity(65536);
|
||||
let mut device_uid: Option<String> = None;
|
||||
let mut rate_limiter = RateLimiter::new();
|
||||
|
||||
let write_task = tokio::spawn(async move {
|
||||
while let Some(data) = rx.recv().await {
|
||||
if writer.write_all(&data).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Reader loop — same logic as plaintext handler
|
||||
'reader: loop {
|
||||
let n = reader.read(&mut buffer).await?;
|
||||
if n == 0 {
|
||||
break;
|
||||
}
|
||||
read_buf.extend_from_slice(&buffer[..n]);
|
||||
|
||||
if read_buf.len() > MAX_READ_BUF_SIZE {
|
||||
warn!("TLS connection exceeded max buffer size, dropping");
|
||||
break;
|
||||
}
|
||||
|
||||
while let Some(frame) = Frame::decode(&read_buf)? {
|
||||
let frame_size = frame.encoded_size();
|
||||
read_buf.drain(..frame_size);
|
||||
|
||||
if frame.version != PROTOCOL_VERSION {
|
||||
warn!("Unsupported protocol version: 0x{:02X}", frame.version);
|
||||
continue;
|
||||
}
|
||||
|
||||
if !rate_limiter.check() {
|
||||
warn!("Rate limit exceeded for TLS device {:?}, dropping connection", device_uid);
|
||||
break 'reader;
|
||||
}
|
||||
|
||||
if let Err(e) = process_frame(frame, &state, &mut device_uid, &tx).await {
|
||||
warn!("Frame processing error: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
cleanup_on_disconnect(&state, &device_uid).await;
|
||||
write_task.abort();
|
||||
Ok(())
|
||||
}
|
||||
125
crates/server/src/ws.rs
Normal file
125
crates/server/src/ws.rs
Normal file
@@ -0,0 +1,125 @@
|
||||
use axum::extract::ws::{WebSocket, WebSocketUpgrade, Message};
|
||||
use axum::response::IntoResponse;
|
||||
use axum::extract::Query;
|
||||
use jsonwebtoken::{decode, Validation, DecodingKey};
|
||||
use serde::Deserialize;
|
||||
use tokio::sync::broadcast;
|
||||
use std::sync::Arc;
|
||||
use tracing::{debug, warn};
|
||||
use crate::api::auth::Claims;
|
||||
use crate::AppState;
|
||||
|
||||
/// WebSocket hub for broadcasting real-time events to admin browsers
|
||||
#[derive(Clone)]
|
||||
pub struct WsHub {
|
||||
tx: broadcast::Sender<String>,
|
||||
}
|
||||
|
||||
impl WsHub {
|
||||
pub fn new() -> Self {
|
||||
let (tx, _) = broadcast::channel(1024);
|
||||
Self { tx }
|
||||
}
|
||||
|
||||
pub async fn broadcast(&self, message: String) {
|
||||
if self.tx.send(message).is_err() {
|
||||
debug!("No WebSocket subscribers to receive broadcast");
|
||||
}
|
||||
}
|
||||
|
||||
pub fn subscribe(&self) -> broadcast::Receiver<String> {
|
||||
self.tx.subscribe()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct WsAuthParams {
|
||||
pub token: Option<String>,
|
||||
}
|
||||
|
||||
/// HTTP upgrade handler for WebSocket connections
|
||||
/// Validates JWT token from query parameter before upgrading
|
||||
pub async fn ws_handler(
|
||||
ws: WebSocketUpgrade,
|
||||
Query(params): Query<WsAuthParams>,
|
||||
axum::extract::State(state): axum::extract::State<AppState>,
|
||||
) -> impl IntoResponse {
|
||||
let token = match params.token {
|
||||
Some(t) => t,
|
||||
None => {
|
||||
warn!("WebSocket connection rejected: no token provided");
|
||||
return (axum::http::StatusCode::UNAUTHORIZED, "Missing token").into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let claims = match decode::<Claims>(
|
||||
&token,
|
||||
&DecodingKey::from_secret(state.config.auth.jwt_secret.as_bytes()),
|
||||
&Validation::default(),
|
||||
) {
|
||||
Ok(c) => c.claims,
|
||||
Err(e) => {
|
||||
warn!("WebSocket connection rejected: invalid token - {}", e);
|
||||
return (axum::http::StatusCode::UNAUTHORIZED, "Invalid token").into_response();
|
||||
}
|
||||
};
|
||||
|
||||
if claims.token_type != "access" {
|
||||
warn!("WebSocket connection rejected: not an access token");
|
||||
return (axum::http::StatusCode::UNAUTHORIZED, "Invalid token type").into_response();
|
||||
}
|
||||
|
||||
let hub = state.ws_hub.clone();
|
||||
ws.on_upgrade(move |socket| handle_socket(socket, claims, hub))
|
||||
}
|
||||
|
||||
async fn handle_socket(mut socket: WebSocket, claims: Claims, hub: Arc<WsHub>) {
|
||||
debug!("WebSocket client connected: user={}", claims.username);
|
||||
|
||||
let welcome = serde_json::json!({
|
||||
"type": "connected",
|
||||
"message": "CSM real-time feed active",
|
||||
"user": claims.username
|
||||
});
|
||||
if socket.send(Message::Text(welcome.to_string())).await.is_err() {
|
||||
return;
|
||||
}
|
||||
|
||||
// Subscribe to broadcast hub for real-time events
|
||||
let mut rx = hub.subscribe();
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
// Forward broadcast messages to WebSocket client
|
||||
msg = rx.recv() => {
|
||||
match msg {
|
||||
Ok(text) => {
|
||||
if socket.send(Message::Text(text)).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(broadcast::error::RecvError::Lagged(n)) => {
|
||||
debug!("WebSocket client lagged {} messages, continuing", n);
|
||||
}
|
||||
Err(broadcast::error::RecvError::Closed) => break,
|
||||
}
|
||||
}
|
||||
// Handle incoming WebSocket messages (ping/close)
|
||||
msg = socket.recv() => {
|
||||
match msg {
|
||||
Some(Ok(Message::Ping(data))) => {
|
||||
if socket.send(Message::Pong(data)).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Some(Ok(Message::Close(_))) => break,
|
||||
Some(Err(_)) => break,
|
||||
None => break,
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
debug!("WebSocket client disconnected: user={}", claims.username);
|
||||
}
|
||||
Reference in New Issue
Block a user