feat: 初始化项目基础架构和核心功能

- 添加项目基础结构:Cargo.toml、.gitignore、设备UID和密钥文件
- 实现前端Vue3项目结构:路由、登录页面、设备管理页面
- 添加核心协议定义(crates/protocol):设备状态、资产、USB事件等
- 实现客户端监控模块:系统状态收集、资产收集
- 实现服务端基础API和插件系统
- 添加数据库迁移脚本:设备管理、资产跟踪、告警系统等
- 实现前端设备状态展示和基本交互
- 添加使用时长统计和水印功能插件
This commit is contained in:
iven
2026-04-05 00:57:51 +08:00
commit fd6fb5cca0
87 changed files with 19576 additions and 0 deletions

36
.gitignore vendored Normal file
View File

@@ -0,0 +1,36 @@
# Rust build artifacts
target/
# Database files
*.db
*.db-journal
*.db-wal
*.db-shm
# Configuration with secrets
config.toml
# Environment variables
.env
.env.*
# Logs
*.log
# Frontend
web/node_modules/
web/dist/
# IDE
.vscode/
.idea/
*.swp
*.swo
*~
# OS
.DS_Store
Thumbs.db
# Plans (development artifacts)
plans/

3923
Cargo.lock generated Normal file

File diff suppressed because it is too large Load Diff

29
Cargo.toml Normal file
View File

@@ -0,0 +1,29 @@
[workspace]
resolver = "2"
members = [
"crates/protocol",
"crates/server",
"crates/client",
]
[workspace.package]
version = "0.1.0"
edition = "2021"
license = "MIT"
[workspace.dependencies]
serde = { version = "1", features = ["derive"] }
serde_json = "1"
thiserror = "1"
tokio = { version = "1", features = ["full"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter", "json"] }
anyhow = "1"
uuid = { version = "1", features = ["v4", "serde"] }
chrono = { version = "0.4", features = ["serde"] }
[profile.release]
lto = true
strip = true
codegen-units = 1
opt-level = "s"

48
crates/client/Cargo.toml Normal file
View File

@@ -0,0 +1,48 @@
[package]
name = "csm-client"
version.workspace = true
edition.workspace = true
[dependencies]
csm-protocol = { path = "../protocol" }
tokio = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
tracing = { workspace = true }
tracing-subscriber = { workspace = true }
anyhow = { workspace = true }
uuid = { workspace = true }
chrono = { workspace = true }
thiserror = { workspace = true }
sysinfo = "0.30"
tokio-rustls = "0.26"
rustls = { version = "0.23", default-features = false, features = ["ring", "std"] }
rustls-pki-types = "1"
webpki-roots = "0.26"
rustls-pemfile = "2"
hmac = "0.12"
sha2 = "0.10"
hex = "0.4"
[target.'cfg(target_os = "windows")'.dependencies]
windows = { version = "0.54", features = [
"Win32_Foundation",
"Win32_System_SystemInformation",
"Win32_System_Registry",
"Win32_System_IO",
"Win32_Security",
"Win32_NetworkManagement_IpHelper",
"Win32_Storage_FileSystem",
"Win32_UI_WindowsAndMessaging",
"Win32_UI_Input_KeyboardAndMouse",
"Win32_System_Threading",
"Win32_System_Diagnostics_ToolHelp",
"Win32_System_LibraryLoader",
"Win32_System_Performance",
"Win32_Graphics_Gdi",
] }
windows-service = "0.7"
hostname = "0.4"
[target.'cfg(not(target_os = "windows"))'.dependencies]
hostname = "0.4"

View File

@@ -0,0 +1,54 @@
use csm_protocol::{Frame, MessageType, HardwareAsset};
use std::time::Duration;
use tokio::sync::mpsc::Sender;
use tracing::{info, error};
use sysinfo::System;
pub async fn start_collecting(tx: Sender<Frame>, device_uid: String) {
let interval = Duration::from_secs(86400); // Once per day
// Initial collection on startup
if let Err(e) = collect_and_send(&tx, &device_uid).await {
error!("Initial asset collection failed: {}", e);
}
loop {
tokio::time::sleep(interval).await;
if let Err(e) = collect_and_send(&tx, &device_uid).await {
error!("Asset collection failed: {}", e);
}
}
}
async fn collect_and_send(tx: &Sender<Frame>, device_uid: &str) -> anyhow::Result<()> {
let hardware = collect_hardware(device_uid)?;
let frame = Frame::new_json(MessageType::AssetReport, &hardware)?;
tx.send(frame).await.map_err(|e| anyhow::anyhow!("Channel send failed: {}", e))?;
info!("Asset report sent for {}", device_uid);
Ok(())
}
fn collect_hardware(device_uid: &str) -> anyhow::Result<HardwareAsset> {
let mut sys = System::new_all();
sys.refresh_all();
let cpu_model = sys.cpus().first()
.map(|c| c.brand().to_string())
.unwrap_or_else(|| "Unknown".to_string());
let cpu_cores = sys.cpus().len() as u32;
let memory_total_mb = sys.total_memory() / 1024 / 1024; // bytes to MB (sysinfo 0.30)
Ok(HardwareAsset {
device_uid: device_uid.to_string(),
cpu_model,
cpu_cores,
memory_total_mb: memory_total_mb as u64,
disk_model: "Unknown".to_string(),
disk_total_mb: 0,
gpu_model: None,
motherboard: None,
serial_number: None,
})
}

206
crates/client/src/main.rs Normal file
View File

@@ -0,0 +1,206 @@
use anyhow::Result;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::Duration;
use tracing::{info, error, warn};
use csm_protocol::{Frame, ClientConfig, UsbPolicyPayload};
mod monitor;
mod asset;
mod usb;
mod network;
mod watermark;
mod usage_timer;
mod usb_audit;
mod popup_blocker;
mod software_blocker;
mod web_filter;
/// Shared shutdown flag
static SHUTDOWN: AtomicBool = AtomicBool::new(false);
/// Client configuration
struct ClientState {
device_uid: String,
server_addr: String,
config: ClientConfig,
device_secret: Option<String>,
registration_token: String,
/// Whether to use TLS when connecting to the server
use_tls: bool,
}
#[tokio::main]
async fn main() -> Result<()> {
tracing_subscriber::fmt()
.with_env_filter("csm_client=info")
.init();
info!("CSM Client starting...");
// Load or generate device identity
let device_uid = load_or_create_device_uid()?;
info!("Device UID: {}", device_uid);
// Load server address
let server_addr = std::env::var("CSM_SERVER")
.unwrap_or_else(|_| "127.0.0.1:9999".to_string());
let state = ClientState {
device_uid,
server_addr,
config: ClientConfig::default(),
device_secret: load_device_secret(),
registration_token: std::env::var("CSM_REGISTRATION_TOKEN").unwrap_or_default(),
use_tls: std::env::var("CSM_USE_TLS").as_deref() == Ok("true"),
};
// TODO: Register as Windows Service on Windows
// For development, run directly
// Main event loop
run(state).await
}
async fn run(state: ClientState) -> Result<()> {
let (data_tx, mut data_rx) = tokio::sync::mpsc::channel::<Frame>(1024);
// Spawn Ctrl+C handler
tokio::spawn(async {
tokio::signal::ctrl_c().await.ok();
info!("Received Ctrl+C, initiating graceful shutdown...");
SHUTDOWN.store(true, Ordering::SeqCst);
});
// Create plugin config channels
let (watermark_tx, watermark_rx) = tokio::sync::watch::channel(None);
let (web_filter_tx, web_filter_rx) = tokio::sync::watch::channel(web_filter::WebFilterConfig::default());
let (software_blocker_tx, software_blocker_rx) = tokio::sync::watch::channel(software_blocker::SoftwareBlockerConfig::default());
let (popup_blocker_tx, popup_blocker_rx) = tokio::sync::watch::channel(popup_blocker::PopupBlockerConfig::default());
let (usb_audit_tx, usb_audit_rx) = tokio::sync::watch::channel(usb_audit::UsbAuditConfig::default());
let (usage_timer_tx, usage_timer_rx) = tokio::sync::watch::channel(usage_timer::UsageConfig::default());
let (usb_policy_tx, usb_policy_rx) = tokio::sync::watch::channel(None::<UsbPolicyPayload>);
let plugins = network::PluginChannels {
watermark_tx,
web_filter_tx,
software_blocker_tx,
popup_blocker_tx,
usb_audit_tx,
usage_timer_tx,
usb_policy_tx,
};
// Spawn core monitoring tasks
let monitor_tx = data_tx.clone();
let uid = state.device_uid.clone();
tokio::spawn(async move {
monitor::start_collecting(monitor_tx, uid).await;
});
let asset_tx = data_tx.clone();
let uid = state.device_uid.clone();
tokio::spawn(async move {
asset::start_collecting(asset_tx, uid).await;
});
let usb_tx = data_tx.clone();
let uid = state.device_uid.clone();
tokio::spawn(async move {
usb::start_monitoring(usb_tx, uid, usb_policy_rx).await;
});
// Spawn plugin tasks
tokio::spawn(async move {
watermark::start(watermark_rx).await;
});
let usage_data_tx = data_tx.clone();
let usage_uid = state.device_uid.clone();
tokio::spawn(async move {
usage_timer::start(usage_timer_rx, usage_data_tx, usage_uid).await;
});
let audit_data_tx = data_tx.clone();
let audit_uid = state.device_uid.clone();
tokio::spawn(async move {
usb_audit::start(usb_audit_rx, audit_data_tx, audit_uid).await;
});
tokio::spawn(async move {
popup_blocker::start(popup_blocker_rx).await;
});
let sw_data_tx = data_tx.clone();
let sw_uid = state.device_uid.clone();
tokio::spawn(async move {
software_blocker::start(software_blocker_rx, sw_data_tx, sw_uid).await;
});
tokio::spawn(async move {
web_filter::start(web_filter_rx).await;
});
// Connect to server with reconnect
let mut backoff = Duration::from_secs(1);
let max_backoff = Duration::from_secs(60);
loop {
if SHUTDOWN.load(Ordering::SeqCst) {
info!("Shutting down gracefully...");
break Ok(());
}
match network::connect_and_run(&state, &mut data_rx, &plugins).await {
Ok(()) => {
warn!("Disconnected from server, reconnecting...");
// Use a short fixed delay for clean disconnects (server-initiated),
// but don't reset to zero to prevent connection storms
tokio::time::sleep(Duration::from_secs(2)).await;
}
Err(e) => {
error!("Connection error: {}, reconnecting...", e);
tokio::time::sleep(backoff).await;
backoff = (backoff * 2).min(max_backoff);
}
}
// Drain stale frames that accumulated during disconnection
let drained = data_rx.try_recv().ok().map(|_| 1).unwrap_or(0);
if drained > 0 {
let mut count = drained;
while data_rx.try_recv().is_ok() {
count += 1;
}
warn!("Drained {} stale frames from channel", count);
}
}
}
fn load_or_create_device_uid() -> Result<String> {
// In production, store in Windows Credential Store or local config
// For now, use a simple file
let uid_file = "device_uid.txt";
if std::path::Path::new(uid_file).exists() {
let uid = std::fs::read_to_string(uid_file)?;
Ok(uid.trim().to_string())
} else {
let uid = uuid::Uuid::new_v4().to_string();
std::fs::write(uid_file, &uid)?;
Ok(uid)
}
}
/// Load persisted device_secret from disk (if available)
pub fn load_device_secret() -> Option<String> {
let secret_file = "device_secret.txt";
let secret = std::fs::read_to_string(secret_file).ok()?;
let trimmed = secret.trim().to_string();
if trimmed.is_empty() { None } else { Some(trimmed) }
}
/// Persist device_secret to disk
pub fn save_device_secret(secret: &str) {
if let Err(e) = std::fs::write("device_secret.txt", secret) {
warn!("Failed to persist device_secret: {}", e);
}
}

View File

@@ -0,0 +1,86 @@
use anyhow::Result;
use csm_protocol::{Frame, MessageType, DeviceStatus, ProcessInfo};
use std::time::Duration;
use tokio::sync::mpsc::Sender;
use tracing::{info, error, debug};
use sysinfo::System;
pub async fn start_collecting(tx: Sender<Frame>, device_uid: String) {
let interval = Duration::from_secs(60);
loop {
// Run blocking sysinfo collection on a dedicated thread
let uid_clone = device_uid.clone();
let result = tokio::task::spawn_blocking(move || {
collect_system_status(&uid_clone)
}).await;
match result {
Ok(Ok(status)) => {
if let Ok(frame) = Frame::new_json(MessageType::StatusReport, &status) {
debug!("Sending status report: cpu={:.1}%, mem={:.1}%", status.cpu_usage, status.memory_usage);
if tx.send(frame).await.is_err() {
info!("Monitor channel closed, exiting");
break;
}
}
}
Ok(Err(e)) => {
error!("Failed to collect system status: {}", e);
}
Err(e) => {
error!("Monitor task join error: {}", e);
}
}
tokio::time::sleep(interval).await;
}
}
fn collect_system_status(device_uid: &str) -> Result<DeviceStatus> {
let mut sys = System::new_all();
sys.refresh_all();
// Brief wait for CPU usage to stabilize
std::thread::sleep(Duration::from_millis(200));
sys.refresh_all();
let cpu_usage = sys.global_cpu_info().cpu_usage() as f64;
let total_memory = sys.total_memory() / 1024 / 1024; // Convert bytes to MB (sysinfo 0.30 returns bytes)
let used_memory = sys.used_memory() / 1024 / 1024;
let memory_usage = if total_memory > 0 {
(used_memory as f64 / total_memory as f64) * 100.0
} else {
0.0
};
// Top processes by CPU
let mut processes: Vec<ProcessInfo> = sys.processes()
.iter()
.map(|(_, p)| {
ProcessInfo {
name: p.name().to_string(),
pid: p.pid().as_u32(),
cpu_usage: p.cpu_usage() as f64,
memory_mb: p.memory() / 1024 / 1024, // bytes to MB (sysinfo 0.30)
}
})
.collect();
processes.sort_by(|a, b| b.cpu_usage.partial_cmp(&a.cpu_usage).unwrap_or(std::cmp::Ordering::Equal));
processes.truncate(10);
Ok(DeviceStatus {
device_uid: device_uid.to_string(),
cpu_usage,
memory_usage,
memory_total_mb: total_memory as u64,
disk_usage: 0.0, // TODO: implement disk usage via Windows API
disk_total_mb: 0,
network_rx_rate: 0,
network_tx_rate: 0,
running_procs: sys.processes().len() as u32,
top_processes: processes,
})
}

View File

@@ -0,0 +1,380 @@
use anyhow::Result;
use std::time::Duration;
use tokio::net::TcpStream;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tracing::{info, debug, warn};
use csm_protocol::{Frame, MessageType, RegisterRequest, RegisterResponse, HeartbeatPayload, WatermarkConfigPayload, UsbPolicyPayload};
use hmac::{Hmac, Mac};
use sha2::Sha256;
use crate::ClientState;
/// Holds senders for all plugin config channels
pub struct PluginChannels {
pub watermark_tx: tokio::sync::watch::Sender<Option<WatermarkConfigPayload>>,
pub web_filter_tx: tokio::sync::watch::Sender<crate::web_filter::WebFilterConfig>,
pub software_blocker_tx: tokio::sync::watch::Sender<crate::software_blocker::SoftwareBlockerConfig>,
pub popup_blocker_tx: tokio::sync::watch::Sender<crate::popup_blocker::PopupBlockerConfig>,
pub usb_audit_tx: tokio::sync::watch::Sender<crate::usb_audit::UsbAuditConfig>,
pub usage_timer_tx: tokio::sync::watch::Sender<crate::usage_timer::UsageConfig>,
pub usb_policy_tx: tokio::sync::watch::Sender<Option<UsbPolicyPayload>>,
}
/// Connect to server and run the main communication loop
pub async fn connect_and_run(
state: &ClientState,
data_rx: &mut tokio::sync::mpsc::Receiver<Frame>,
plugins: &PluginChannels,
) -> Result<()> {
let tcp_stream = TcpStream::connect(&state.server_addr).await?;
info!("TCP connected to {}", state.server_addr);
if state.use_tls {
let tls_stream = wrap_tls(tcp_stream, &state.server_addr).await?;
run_comm_loop(tls_stream, state, data_rx, plugins).await
} else {
run_comm_loop(tcp_stream, state, data_rx, plugins).await
}
}
/// Wrap a TCP stream with TLS.
/// Supports custom CA certificate via CSM_TLS_CA_CERT env var (path to PEM file).
/// Supports skipping verification via CSM_TLS_SKIP_VERIFY=true (development only).
async fn wrap_tls(stream: TcpStream, server_addr: &str) -> Result<tokio_rustls::client::TlsStream<TcpStream>> {
let mut root_store = rustls::RootCertStore::empty();
// Load custom CA certificate if specified
if let Ok(ca_path) = std::env::var("CSM_TLS_CA_CERT") {
let ca_pem = std::fs::read(&ca_path)
.map_err(|e| anyhow::anyhow!("Failed to read CA cert {}: {}", ca_path, e))?;
let certs = rustls_pemfile::certs(&mut &ca_pem[..])
.collect::<Result<Vec<_>, _>>()
.map_err(|e| anyhow::anyhow!("Failed to parse CA cert: {:?}", e))?;
for cert in certs {
root_store.add(cert)?;
}
info!("Loaded custom CA certificates from {}", ca_path);
}
// Always include system roots as fallback
root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
let config = if std::env::var("CSM_TLS_SKIP_VERIFY").as_deref() == Ok("true") {
warn!("TLS certificate verification DISABLED — do not use in production!");
rustls::ClientConfig::builder()
.dangerous()
.with_custom_certificate_verifier(std::sync::Arc::new(NoVerifier))
.with_no_client_auth()
} else {
rustls::ClientConfig::builder()
.with_root_certificates(root_store)
.with_no_client_auth()
};
let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(config));
// Extract hostname from server_addr (strip port)
let domain = server_addr.split(':').next().unwrap_or("localhost").to_string();
let server_name = rustls_pki_types::ServerName::try_from(domain.clone())
.map_err(|e| anyhow::anyhow!("Invalid TLS server name '{}': {:?}", domain, e))?;
let tls_stream = connector.connect(server_name, stream).await?;
info!("TLS handshake completed with {}", domain);
Ok(tls_stream)
}
/// A no-op certificate verifier for development use (CSM_TLS_SKIP_VERIFY=true).
#[derive(Debug)]
struct NoVerifier;
impl rustls::client::danger::ServerCertVerifier for NoVerifier {
fn verify_server_cert(
&self,
_end_entity: &rustls_pki_types::CertificateDer,
_intermediates: &[rustls_pki_types::CertificateDer],
_server_name: &rustls_pki_types::ServerName,
_ocsp_response: &[u8],
_now: rustls_pki_types::UnixTime,
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
Ok(rustls::client::danger::ServerCertVerified::assertion())
}
fn verify_tls12_signature(
&self,
_message: &[u8],
_cert: &rustls_pki_types::CertificateDer,
_dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}
fn verify_tls13_signature(
&self,
_message: &[u8],
_cert: &rustls_pki_types::CertificateDer,
_dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
vec![
rustls::SignatureScheme::RSA_PKCS1_SHA256,
rustls::SignatureScheme::RSA_PKCS1_SHA384,
rustls::SignatureScheme::RSA_PKCS1_SHA512,
rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
rustls::SignatureScheme::RSA_PSS_SHA256,
rustls::SignatureScheme::RSA_PSS_SHA384,
rustls::SignatureScheme::RSA_PSS_SHA512,
rustls::SignatureScheme::ED25519,
]
}
}
/// Main communication loop over any read+write stream
async fn run_comm_loop<S>(
mut stream: S,
state: &ClientState,
data_rx: &mut tokio::sync::mpsc::Receiver<Frame>,
plugins: &PluginChannels,
) -> Result<()>
where
S: AsyncReadExt + AsyncWriteExt + Unpin,
{
// Send registration
let register = RegisterRequest {
device_uid: state.device_uid.clone(),
hostname: hostname::get()
.map(|h| h.to_string_lossy().to_string())
.unwrap_or_else(|_| "unknown".to_string()),
registration_token: state.registration_token.clone(),
os_version: get_os_info(),
mac_address: None,
};
let frame = Frame::new_json(MessageType::Register, &register)?;
stream.write_all(&frame.encode()).await?;
info!("Registration request sent");
let mut buffer = vec![0u8; 65536];
let mut read_buf = Vec::with_capacity(65536);
// Clamp heartbeat interval to sane range [5, 3600] to prevent CPU spin or effective disable
let heartbeat_secs = state.config.heartbeat_interval_secs.clamp(5, 3600);
let mut heartbeat_interval = tokio::time::interval(Duration::from_secs(heartbeat_secs));
heartbeat_interval.tick().await; // Skip first tick
// HMAC key — set after receiving RegisterResponse
let mut device_secret: Option<String> = state.device_secret.clone();
loop {
tokio::select! {
// Read from server
result = stream.read(&mut buffer) => {
let n = result?;
if n == 0 {
return Err(anyhow::anyhow!("Server closed connection"));
}
read_buf.extend_from_slice(&buffer[..n]);
// Guard against unbounded buffer growth from a malicious server
if read_buf.len() > 1_048_576 {
return Err(anyhow::anyhow!("Read buffer exceeded 1MB, server may be malicious"));
}
// Process complete frames
loop {
match Frame::decode(&read_buf)? {
Some(frame) => {
let consumed = frame.encoded_size();
read_buf.drain(..consumed);
// Capture device_secret from registration response
if frame.msg_type == MessageType::RegisterResponse {
if let Ok(resp) = frame.decode_payload::<RegisterResponse>() {
device_secret = Some(resp.device_secret.clone());
crate::save_device_secret(&resp.device_secret);
info!("Device secret received and persisted, HMAC enabled for heartbeats");
}
}
handle_server_message(frame, plugins)?;
}
None => break, // Incomplete frame, wait for more data
}
}
}
// Send queued data
frame = data_rx.recv() => {
let frame = frame.ok_or_else(|| anyhow::anyhow!("Channel closed"))?;
stream.write_all(&frame.encode()).await?;
}
// Heartbeat
_ = heartbeat_interval.tick() => {
let timestamp = chrono::Utc::now().to_rfc3339();
let hmac_value = compute_hmac(device_secret.as_deref(), &state.device_uid, &timestamp);
let heartbeat = HeartbeatPayload {
device_uid: state.device_uid.clone(),
timestamp,
hmac: hmac_value,
};
let frame = Frame::new_json(MessageType::Heartbeat, &heartbeat)?;
stream.write_all(&frame.encode()).await?;
debug!("Heartbeat sent (hmac={})", !heartbeat.hmac.is_empty());
}
}
}
}
fn handle_server_message(frame: Frame, plugins: &PluginChannels) -> Result<()> {
match frame.msg_type {
MessageType::RegisterResponse => {
let resp: RegisterResponse = frame.decode_payload()
.map_err(|e| anyhow::anyhow!("Invalid registration response: {}", e))?;
info!("Registration accepted by server (server version: {})", resp.config.server_version);
}
MessageType::PolicyUpdate => {
let policy: serde_json::Value = frame.decode_payload()
.map_err(|e| anyhow::anyhow!("Invalid policy update: {}", e))?;
info!("Received policy update: {}", policy);
}
MessageType::ConfigUpdate => {
info!("Received config update");
}
MessageType::TaskExecute => {
warn!("Task execution requested (not yet implemented)");
}
MessageType::WatermarkConfig => {
let config: WatermarkConfigPayload = frame.decode_payload()
.map_err(|e| anyhow::anyhow!("Invalid watermark config: {}", e))?;
info!("Received watermark config: enabled={}", config.enabled);
plugins.watermark_tx.send(Some(config))?;
}
MessageType::UsbPolicyUpdate => {
let policy: UsbPolicyPayload = frame.decode_payload()
.map_err(|e| anyhow::anyhow!("Invalid USB policy: {}", e))?;
info!("Received USB policy: type={}, enabled={}", policy.policy_type, policy.enabled);
plugins.usb_policy_tx.send(Some(policy))?;
}
MessageType::WebFilterRuleUpdate => {
let payload: serde_json::Value = frame.decode_payload()
.map_err(|e| anyhow::anyhow!("Invalid web filter update: {}", e))?;
info!("Received web filter rules update");
let rules: Vec<crate::web_filter::WebFilterRule> = payload.get("rules")
.and_then(|r| serde_json::from_value(r.clone()).ok())
.unwrap_or_default();
let config = crate::web_filter::WebFilterConfig { enabled: true, rules };
plugins.web_filter_tx.send(config)?;
}
MessageType::SoftwareBlacklist => {
let payload: serde_json::Value = frame.decode_payload()
.map_err(|e| anyhow::anyhow!("Invalid software blacklist: {}", e))?;
info!("Received software blacklist update");
let blacklist: Vec<crate::software_blocker::BlacklistEntry> = payload.get("blacklist")
.and_then(|r| serde_json::from_value(r.clone()).ok())
.unwrap_or_default();
let config = crate::software_blocker::SoftwareBlockerConfig { enabled: true, blacklist };
plugins.software_blocker_tx.send(config)?;
}
MessageType::PopupRules => {
let payload: serde_json::Value = frame.decode_payload()
.map_err(|e| anyhow::anyhow!("Invalid popup rules: {}", e))?;
info!("Received popup blocker rules update");
let rules: Vec<crate::popup_blocker::PopupRule> = payload.get("rules")
.and_then(|r| serde_json::from_value(r.clone()).ok())
.unwrap_or_default();
let config = crate::popup_blocker::PopupBlockerConfig { enabled: true, rules };
plugins.popup_blocker_tx.send(config)?;
}
MessageType::PluginEnable => {
let payload: csm_protocol::PluginControlPayload = frame.decode_payload()
.map_err(|e| anyhow::anyhow!("Invalid plugin enable: {}", e))?;
info!("Plugin enabled: {}", payload.plugin_name);
// Route to appropriate plugin channel based on plugin_name
handle_plugin_control(&payload, plugins, true)?;
}
MessageType::PluginDisable => {
let payload: csm_protocol::PluginControlPayload = frame.decode_payload()
.map_err(|e| anyhow::anyhow!("Invalid plugin disable: {}", e))?;
info!("Plugin disabled: {}", payload.plugin_name);
handle_plugin_control(&payload, plugins, false)?;
}
_ => {
debug!("Unhandled message type: {:?}", frame.msg_type);
}
}
Ok(())
}
fn handle_plugin_control(
payload: &csm_protocol::PluginControlPayload,
plugins: &PluginChannels,
enabled: bool,
) -> Result<()> {
match payload.plugin_name.as_str() {
"watermark" => {
if !enabled {
// Send disabled config to remove overlay
plugins.watermark_tx.send(None)?;
}
// When enabling, server will push the actual config next
}
"web_filter" => {
if !enabled {
// Clear hosts rules on disable
plugins.web_filter_tx.send(crate::web_filter::WebFilterConfig { enabled: false, rules: vec![] })?;
}
// When enabling, server will push rules
}
"software_blocker" => {
if !enabled {
plugins.software_blocker_tx.send(crate::software_blocker::SoftwareBlockerConfig { enabled: false, blacklist: vec![] })?;
}
}
"popup_blocker" => {
if !enabled {
plugins.popup_blocker_tx.send(crate::popup_blocker::PopupBlockerConfig { enabled: false, rules: vec![] })?;
}
}
"usb_audit" => {
if !enabled {
plugins.usb_audit_tx.send(crate::usb_audit::UsbAuditConfig { enabled: false, monitored_extensions: vec![] })?;
}
}
"usage_timer" => {
if !enabled {
plugins.usage_timer_tx.send(crate::usage_timer::UsageConfig { enabled: false, ..Default::default() })?;
}
}
_ => {
warn!("Unknown plugin: {}", payload.plugin_name);
}
}
Ok(())
}
/// Compute HMAC-SHA256 for heartbeat verification.
/// Format: HMAC-SHA256(device_secret, "{device_uid}\n{timestamp}")
fn compute_hmac(secret: Option<&str>, device_uid: &str, timestamp: &str) -> String {
let secret = match secret {
Some(s) if !s.is_empty() => s,
_ => return String::new(),
};
type HmacSha256 = Hmac<Sha256>;
let message = format!("{}\n{}", device_uid, timestamp);
let mut mac = match HmacSha256::new_from_slice(secret.as_bytes()) {
Ok(m) => m,
Err(_) => return String::new(),
};
mac.update(message.as_bytes());
hex::encode(mac.finalize().into_bytes())
}
fn get_os_info() -> String {
use sysinfo::System;
let name = System::name().unwrap_or_else(|| "Unknown".to_string());
let version = System::os_version().unwrap_or_else(|| "Unknown".to_string());
format!("{} {}", name, version)
}

View File

@@ -0,0 +1,370 @@
use anyhow::Result;
use std::time::Duration;
use tokio::net::TcpStream;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tracing::{info, debug, warn};
use csm_protocol::{Frame, MessageType, RegisterRequest, RegisterResponse, HeartbeatPayload, WatermarkConfigPayload};
use hmac::{Hmac, Mac};
use sha2::Sha256;
use crate::ClientState;
/// Maximum accumulated read buffer size per connection (8 MB)
const MAX_READ_BUF_SIZE: usize = 8 * 1024 * 1024;
/// Holds senders for all plugin config channels
pub struct PluginChannels {
pub watermark_tx: tokio::sync::watch::Sender<Option<WatermarkConfigPayload>>,
pub web_filter_tx: tokio::sync::watch::Sender<crate::web_filter::WebFilterConfig>,
pub software_blocker_tx: tokio::sync::watch::Sender<crate::software_blocker::SoftwareBlockerConfig>,
pub popup_blocker_tx: tokio::sync::watch::Sender<crate::popup_blocker::PopupBlockerConfig>,
pub usb_audit_tx: tokio::sync::watch::Sender<crate::usb_audit::UsbAuditConfig>,
pub usage_timer_tx: tokio::sync::watch::Sender<crate::usage_timer::UsageConfig>,
}
/// Connect to server and run the main communication loop
pub async fn connect_and_run(
state: &ClientState,
data_rx: &mut tokio::sync::mpsc::Receiver<Frame>,
plugins: &PluginChannels,
) -> Result<()> {
let tcp_stream = TcpStream::connect(&state.server_addr).await?;
info!("TCP connected to {}", state.server_addr);
if state.use_tls {
let tls_stream = wrap_tls(tcp_stream, &state.server_addr).await?;
run_comm_loop(tls_stream, state, data_rx, plugins).await
} else {
run_comm_loop(tcp_stream, state, data_rx, plugins).await
}
}
/// Wrap a TCP stream with TLS.
/// Supports custom CA certificate via CSM_TLS_CA_CERT env var (path to PEM file).
/// Supports skipping verification via CSM_TLS_SKIP_VERIFY=true (development only).
async fn wrap_tls(stream: TcpStream, server_addr: &str) -> Result<tokio_rustls::client::TlsStream<TcpStream>> {
let mut root_store = rustls::RootCertStore::empty();
// Load custom CA certificate if specified
if let Ok(ca_path) = std::env::var("CSM_TLS_CA_CERT") {
let ca_pem = std::fs::read(&ca_path)
.map_err(|e| anyhow::anyhow!("Failed to read CA cert {}: {}", ca_path, e))?;
let certs = rustls_pemfile::certs(&mut &ca_pem[..])
.collect::<Result<Vec<_>, _>>()
.map_err(|e| anyhow::anyhow!("Failed to parse CA cert: {:?}", e))?;
for cert in certs {
root_store.add(cert)?;
}
info!("Loaded custom CA certificates from {}", ca_path);
}
// Always include system roots as fallback
root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
let config = if std::env::var("CSM_TLS_SKIP_VERIFY").as_deref() == Ok("true") {
warn!("TLS certificate verification DISABLED — do not use in production!");
rustls::ClientConfig::builder()
.dangerous()
.with_custom_certificate_verifier(std::sync::Arc::new(NoVerifier))
.with_no_client_auth()
} else {
rustls::ClientConfig::builder()
.with_root_certificates(root_store)
.with_no_client_auth()
};
let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(config));
// Extract hostname from server_addr (strip port)
let domain = server_addr.split(':').next().unwrap_or("localhost").to_string();
let server_name = rustls_pki_types::ServerName::try_from(domain.clone())
.map_err(|e| anyhow::anyhow!("Invalid TLS server name '{}': {:?}", domain, e))?;
let tls_stream = connector.connect(server_name, stream).await?;
info!("TLS handshake completed with {}", domain);
Ok(tls_stream)
}
/// A no-op certificate verifier for development use (CSM_TLS_SKIP_VERIFY=true).
#[derive(Debug)]
struct NoVerifier;
impl rustls::client::danger::ServerCertVerifier for NoVerifier {
fn verify_server_cert(
&self,
_end_entity: &rustls_pki_types::CertificateDer,
_intermediates: &[rustls_pki_types::CertificateDer],
_server_name: &rustls_pki_types::ServerName,
_ocsp_response: &[u8],
_now: rustls_pki_types::UnixTime,
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
Ok(rustls::client::danger::ServerCertVerified::assertion())
}
fn verify_tls12_signature(
&self,
_message: &[u8],
_cert: &rustls_pki_types::CertificateDer,
_dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}
fn verify_tls13_signature(
&self,
_message: &[u8],
_cert: &rustls_pki_types::CertificateDer,
_dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
vec![
rustls::SignatureScheme::RSA_PKCS1_SHA256,
rustls::SignatureScheme::RSA_PKCS1_SHA384,
rustls::SignatureScheme::RSA_PKCS1_SHA512,
rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
rustls::SignatureScheme::RSA_PSS_SHA256,
rustls::SignatureScheme::RSA_PSS_SHA384,
rustls::SignatureScheme::RSA_PSS_SHA512,
rustls::SignatureScheme::ED25519,
]
}
}
/// Main communication loop over any read+write stream
async fn run_comm_loop<S>(
mut stream: S,
state: &ClientState,
data_rx: &mut tokio::sync::mpsc::Receiver<Frame>,
plugins: &PluginChannels,
) -> Result<()>
where
S: AsyncReadExt + AsyncWriteExt + Unpin,
{
// Send registration
let register = RegisterRequest {
device_uid: state.device_uid.clone(),
hostname: hostname::get()
.map(|h| h.to_string_lossy().to_string())
.unwrap_or_else(|_| "unknown".to_string()),
registration_token: state.registration_token.clone(),
os_version: get_os_info(),
mac_address: None,
};
let frame = Frame::new_json(MessageType::Register, &register)?;
stream.write_all(&frame.encode()).await?;
info!("Registration request sent");
let mut buffer = vec![0u8; 65536];
let mut read_buf = Vec::with_capacity(65536);
let heartbeat_secs = state.config.heartbeat_interval_secs;
let mut heartbeat_interval = tokio::time::interval(Duration::from_secs(heartbeat_secs));
heartbeat_interval.tick().await; // Skip first tick
// HMAC key — set after receiving RegisterResponse
let mut device_secret: Option<String> = state.device_secret.clone();
loop {
tokio::select! {
// Read from server
result = stream.read(&mut buffer) => {
let n = result?;
if n == 0 {
return Err(anyhow::anyhow!("Server closed connection"));
}
read_buf.extend_from_slice(&buffer[..n]);
// Process complete frames
loop {
match Frame::decode(&read_buf)? {
Some(frame) => {
let consumed = frame.encoded_size();
read_buf.drain(..consumed);
// Capture device_secret from registration response
if frame.msg_type == MessageType::RegisterResponse {
if let Ok(resp) = frame.decode_payload::<RegisterResponse>() {
device_secret = Some(resp.device_secret.clone());
crate::save_device_secret(&resp.device_secret);
info!("Device secret received and persisted, HMAC enabled for heartbeats");
}
}
handle_server_message(frame, plugins)?;
}
None => break, // Incomplete frame, wait for more data
}
}
}
// Send queued data
frame = data_rx.recv() => {
let frame = frame.ok_or_else(|| anyhow::anyhow!("Channel closed"))?;
stream.write_all(&frame.encode()).await?;
}
// Heartbeat
_ = heartbeat_interval.tick() => {
let timestamp = chrono::Utc::now().to_rfc3339();
let hmac_value = compute_hmac(device_secret.as_deref(), &state.device_uid, &timestamp);
let heartbeat = HeartbeatPayload {
device_uid: state.device_uid.clone(),
timestamp,
hmac: hmac_value,
};
let frame = Frame::new_json(MessageType::Heartbeat, &heartbeat)?;
stream.write_all(&frame.encode()).await?;
debug!("Heartbeat sent (hmac={})", !heartbeat.hmac.is_empty());
}
}
}
}
fn handle_server_message(frame: Frame, plugins: &PluginChannels) -> Result<()> {
match frame.msg_type {
MessageType::RegisterResponse => {
let resp: RegisterResponse = frame.decode_payload()
.map_err(|e| anyhow::anyhow!("Invalid registration response: {}", e))?;
info!("Registration accepted by server (server version: {})", resp.config.server_version);
}
MessageType::PolicyUpdate => {
let policy: serde_json::Value = frame.decode_payload()
.map_err(|e| anyhow::anyhow!("Invalid policy update: {}", e))?;
info!("Received policy update: {}", policy);
}
MessageType::ConfigUpdate => {
info!("Received config update");
}
MessageType::TaskExecute => {
warn!("Task execution requested (not yet implemented)");
}
MessageType::WatermarkConfig => {
let config: WatermarkConfigPayload = frame.decode_payload()
.map_err(|e| anyhow::anyhow!("Invalid watermark config: {}", e))?;
info!("Received watermark config: enabled={}", config.enabled);
plugins.watermark_tx.send(Some(config))?;
}
MessageType::WebFilterRuleUpdate => {
let payload: serde_json::Value = frame.decode_payload()
.map_err(|e| anyhow::anyhow!("Invalid web filter update: {}", e))?;
info!("Received web filter rules update");
let rules: Vec<crate::web_filter::WebFilterRule> = payload.get("rules")
.and_then(|r| serde_json::from_value(r.clone()).ok())
.unwrap_or_default();
let config = crate::web_filter::WebFilterConfig { enabled: true, rules };
plugins.web_filter_tx.send(config)?;
}
MessageType::SoftwareBlacklist => {
let payload: serde_json::Value = frame.decode_payload()
.map_err(|e| anyhow::anyhow!("Invalid software blacklist: {}", e))?;
info!("Received software blacklist update");
let blacklist: Vec<crate::software_blocker::BlacklistEntry> = payload.get("blacklist")
.and_then(|r| serde_json::from_value(r.clone()).ok())
.unwrap_or_default();
let config = crate::software_blocker::SoftwareBlockerConfig { enabled: true, blacklist };
plugins.software_blocker_tx.send(config)?;
}
MessageType::PopupRules => {
let payload: serde_json::Value = frame.decode_payload()
.map_err(|e| anyhow::anyhow!("Invalid popup rules: {}", e))?;
info!("Received popup blocker rules update");
let rules: Vec<crate::popup_blocker::PopupRule> = payload.get("rules")
.and_then(|r| serde_json::from_value(r.clone()).ok())
.unwrap_or_default();
let config = crate::popup_blocker::PopupBlockerConfig { enabled: true, rules };
plugins.popup_blocker_tx.send(config)?;
}
MessageType::PluginEnable => {
let payload: csm_protocol::PluginControlPayload = frame.decode_payload()
.map_err(|e| anyhow::anyhow!("Invalid plugin enable: {}", e))?;
info!("Plugin enabled: {}", payload.plugin_name);
// Route to appropriate plugin channel based on plugin_name
handle_plugin_control(&payload, plugins, true)?;
}
MessageType::PluginDisable => {
let payload: csm_protocol::PluginControlPayload = frame.decode_payload()
.map_err(|e| anyhow::anyhow!("Invalid plugin disable: {}", e))?;
info!("Plugin disabled: {}", payload.plugin_name);
handle_plugin_control(&payload, plugins, false)?;
}
_ => {
debug!("Unhandled message type: {:?}", frame.msg_type);
}
}
Ok(())
}
fn handle_plugin_control(
payload: &csm_protocol::PluginControlPayload,
plugins: &PluginChannels,
enabled: bool,
) -> Result<()> {
match payload.plugin_name.as_str() {
"watermark" => {
if !enabled {
// Send disabled config to remove overlay
plugins.watermark_tx.send(None)?;
}
// When enabling, server will push the actual config next
}
"web_filter" => {
if !enabled {
// Clear hosts rules on disable
plugins.web_filter_tx.send(crate::web_filter::WebFilterConfig { enabled: false, rules: vec![] })?;
}
// When enabling, server will push rules
}
"software_blocker" => {
if !enabled {
plugins.software_blocker_tx.send(crate::software_blocker::SoftwareBlockerConfig { enabled: false, blacklist: vec![] })?;
}
}
"popup_blocker" => {
if !enabled {
plugins.popup_blocker_tx.send(crate::popup_blocker::PopupBlockerConfig { enabled: false, rules: vec![] })?;
}
}
"usb_audit" => {
if !enabled {
plugins.usb_audit_tx.send(crate::usb_audit::UsbAuditConfig { enabled: false, monitored_extensions: vec![] })?;
}
}
"usage_timer" => {
if !enabled {
plugins.usage_timer_tx.send(crate::usage_timer::UsageConfig { enabled: false, ..Default::default() })?;
}
}
_ => {
warn!("Unknown plugin: {}", payload.plugin_name);
}
}
Ok(())
}
/// Compute HMAC-SHA256 for heartbeat verification.
/// Format: HMAC-SHA256(device_secret, "{device_uid}\n{timestamp}")
fn compute_hmac(secret: Option<&str>, device_uid: &str, timestamp: &str) -> String {
let secret = match secret {
Some(s) if !s.is_empty() => s,
_ => return String::new(),
};
type HmacSha256 = Hmac<Sha256>;
let message = format!("{}\n{}", device_uid, timestamp);
let mut mac = match HmacSha256::new_from_slice(secret.as_bytes()) {
Ok(m) => m,
Err(_) => return String::new(),
};
mac.update(message.as_bytes());
hex::encode(mac.finalize().into_bytes())
}
fn get_os_info() -> String {
use sysinfo::System;
let name = System::name().unwrap_or_else(|| "Unknown".to_string());
let version = System::os_version().unwrap_or_else(|| "Unknown".to_string());
format!("{} {}", name, version)
}

View File

@@ -0,0 +1,254 @@
use tokio::sync::watch;
use tracing::{info, debug};
use serde::Deserialize;
/// Popup blocker rule from server
#[derive(Debug, Clone, Deserialize)]
pub struct PopupRule {
pub id: i64,
pub rule_type: String,
pub window_title: Option<String>,
pub window_class: Option<String>,
pub process_name: Option<String>,
}
/// Popup blocker configuration
#[derive(Debug, Clone, Default)]
pub struct PopupBlockerConfig {
pub enabled: bool,
pub rules: Vec<PopupRule>,
}
/// Context passed to EnumWindows callback via LPARAM
struct ScanContext {
rules: Vec<PopupRule>,
blocked_count: u32,
}
/// Start popup blocker plugin.
/// Periodically enumerates windows and closes those matching rules.
pub async fn start(mut config_rx: watch::Receiver<PopupBlockerConfig>) {
info!("Popup blocker plugin started");
let mut config = PopupBlockerConfig::default();
let mut scan_interval = tokio::time::interval(std::time::Duration::from_secs(2));
scan_interval.tick().await;
loop {
tokio::select! {
result = config_rx.changed() => {
if result.is_err() {
break;
}
let new_config = config_rx.borrow_and_update().clone();
info!("Popup blocker config updated: enabled={}, rules={}", new_config.enabled, new_config.rules.len());
config = new_config;
}
_ = scan_interval.tick() => {
if !config.enabled || config.rules.is_empty() {
continue;
}
scan_and_block(&config.rules);
}
}
}
}
fn scan_and_block(rules: &[PopupRule]) {
#[cfg(target_os = "windows")]
{
use windows::Win32::UI::WindowsAndMessaging::EnumWindows;
use windows::Win32::Foundation::LPARAM;
let mut ctx = ScanContext {
rules: rules.to_vec(),
blocked_count: 0,
};
unsafe {
let _ = EnumWindows(
Some(enum_windows_callback),
LPARAM(&mut ctx as *mut ScanContext as isize),
);
}
if ctx.blocked_count > 0 {
debug!("Popup scan blocked {} windows", ctx.blocked_count);
}
}
#[cfg(not(target_os = "windows"))]
{
let _ = rules;
}
}
#[cfg(target_os = "windows")]
unsafe extern "system" fn enum_windows_callback(
hwnd: windows::Win32::Foundation::HWND,
lparam: windows::Win32::Foundation::LPARAM,
) -> windows::Win32::Foundation::BOOL {
use windows::Win32::UI::WindowsAndMessaging::*;
use windows::Win32::Foundation::*;
// Only consider visible, top-level windows without an owner
if !IsWindowVisible(hwnd).as_bool() {
return BOOL(1);
}
// Skip windows that have an owner (they're child dialogs, not popups)
if GetWindow(hwnd, GW_OWNER).0 != 0 {
return BOOL(1);
}
// Get window title
let mut title_buf = [0u16; 512];
let title_len = GetWindowTextW(hwnd, &mut title_buf);
let title_len = title_len.max(0) as usize;
let title = String::from_utf16_lossy(&title_buf[..title_len]);
// Skip windows with empty titles
if title.is_empty() {
return BOOL(1);
}
// Get class name
let mut class_buf = [0u16; 256];
let class_len = GetClassNameW(hwnd, &mut class_buf);
let class_len = class_len.max(0) as usize;
let class_name = String::from_utf16_lossy(&class_buf[..class_len]);
// Get process name from PID
let mut pid: u32 = 0;
GetWindowThreadProcessId(hwnd, Some(&mut pid));
let process_name = get_process_name(pid);
// Recover the ScanContext from LPARAM
let ctx = &mut *(lparam.0 as *mut ScanContext);
// Check against each rule
for rule in &ctx.rules {
if rule.rule_type != "block" {
continue;
}
let matches = rule_matches(rule, &title, &class_name, &process_name);
if matches {
let _ = PostMessageW(hwnd, WM_CLOSE, WPARAM(0), LPARAM(0));
ctx.blocked_count += 1;
info!(
"Blocked popup: title='{}' class='{}' process='{}' (rule_id={})",
title, class_name, process_name, rule.id
);
break; // One match is enough per window
}
}
BOOL(1) // Continue enumeration
}
fn rule_matches(rule: &PopupRule, title: &str, class_name: &str, process_name: &str) -> bool {
let title_match = match &rule.window_title {
Some(pattern) => pattern_match(pattern, title),
None => true, // No title filter = match all
};
let class_match = match &rule.window_class {
Some(pattern) => pattern_match(pattern, class_name),
None => true,
};
let process_match = match &rule.process_name {
Some(pattern) => pattern_match(pattern, process_name),
None => true,
};
title_match && class_match && process_match
}
/// Simple case-insensitive wildcard pattern matching.
/// Supports `*` as wildcard (matches any characters).
fn pattern_match(pattern: &str, text: &str) -> bool {
let p = pattern.to_lowercase();
let t = text.to_lowercase();
if !p.contains('*') {
return t.contains(&p);
}
let parts: Vec<&str> = p.split('*').collect();
if parts.is_empty() {
return true;
}
let mut pos = 0usize;
let mut matched_any = false;
for (i, part) in parts.iter().enumerate() {
if part.is_empty() {
// Leading empty = pattern starts with * → no start anchor
// Trailing empty = pattern ends with * → no end anchor
continue;
}
matched_any = true;
if i == 0 && !parts[0].is_empty() {
// Pattern starts with literal → must match at start
if !t.starts_with(part) {
return false;
}
pos = part.len();
} else {
// Find this segment anywhere after current position
match t[pos..].find(part) {
Some(idx) => pos += idx + part.len(),
None => return false,
}
}
}
// If pattern ends with literal (no trailing *), must match at end
if matched_any && !parts.last().map_or(true, |p| p.is_empty()) {
return t.ends_with(parts.last().unwrap());
}
true
}
#[cfg(target_os = "windows")]
fn get_process_name(pid: u32) -> String {
use windows::Win32::System::Diagnostics::ToolHelp::*;
use windows::Win32::Foundation::CloseHandle;
unsafe {
let snapshot = match CreateToolhelp32Snapshot(TH32CS_SNAPPROCESS, 0) {
Ok(s) => s,
Err(_) => return format!("pid:{}", pid),
};
let mut entry = PROCESSENTRY32W {
dwSize: std::mem::size_of::<PROCESSENTRY32W>() as u32,
..Default::default()
};
if Process32FirstW(snapshot, &mut entry).is_ok() {
loop {
if entry.th32ProcessID == pid {
let name = String::from_utf16_lossy(
&entry.szExeFile.iter().take_while(|&&c| c != 0).copied().collect::<Vec<u16>>()
);
let _ = CloseHandle(snapshot);
return name;
}
entry.dwSize = std::mem::size_of::<PROCESSENTRY32W>() as u32;
if Process32NextW(snapshot, &mut entry).is_err() {
break;
}
}
}
let _ = CloseHandle(snapshot);
format!("pid:{}", pid)
}
}
#[cfg(not(target_os = "windows"))]
fn get_process_name(pid: u32) -> String {
format!("pid:{}", pid)
}

View File

@@ -0,0 +1,254 @@
use tokio::sync::watch;
use tracing::{info, warn};
use csm_protocol::{Frame, MessageType, SoftwareViolationReport};
use serde::Deserialize;
/// System-critical processes that must never be killed regardless of server rules.
/// Killing any of these would cause system instability or a BSOD.
const PROTECTED_PROCESSES: &[&str] = &[
"system",
"system idle process",
"svchost.exe",
"lsass.exe",
"csrss.exe",
"wininit.exe",
"winlogon.exe",
"services.exe",
"dwm.exe",
"explorer.exe",
"taskhostw.exe",
"registry",
"smss.exe",
"conhost.exe",
];
/// Software blacklist entry from server
#[derive(Debug, Clone, Deserialize)]
pub struct BlacklistEntry {
pub id: i64,
pub name_pattern: String,
pub action: String,
}
/// Software blocker configuration
#[derive(Debug, Clone, Default)]
pub struct SoftwareBlockerConfig {
pub enabled: bool,
pub blacklist: Vec<BlacklistEntry>,
}
/// Start software blocker plugin.
/// Periodically scans running processes against the blacklist.
pub async fn start(
mut config_rx: watch::Receiver<SoftwareBlockerConfig>,
data_tx: tokio::sync::mpsc::Sender<Frame>,
device_uid: String,
) {
info!("Software blocker plugin started");
let mut config = SoftwareBlockerConfig::default();
let mut scan_interval = tokio::time::interval(std::time::Duration::from_secs(10));
scan_interval.tick().await;
loop {
tokio::select! {
result = config_rx.changed() => {
if result.is_err() {
break;
}
let new_config = config_rx.borrow_and_update().clone();
info!("Software blocker config updated: enabled={}, blacklist={}", new_config.enabled, new_config.blacklist.len());
config = new_config;
}
_ = scan_interval.tick() => {
if !config.enabled || config.blacklist.is_empty() {
continue;
}
scan_processes(&config.blacklist, &data_tx, &device_uid).await;
}
}
}
}
async fn scan_processes(
blacklist: &[BlacklistEntry],
data_tx: &tokio::sync::mpsc::Sender<Frame>,
device_uid: &str,
) {
let running = get_running_processes_with_pids();
for entry in blacklist {
for (process_name, pid) in &running {
if pattern_matches(&entry.name_pattern, process_name) {
// Never kill system-critical processes
if is_protected_process(process_name) {
warn!("Blacklisted match '{}' skipped — system-critical process (pid={})", process_name, pid);
continue;
}
warn!("Blacklisted software detected: {} (action: {})", process_name, entry.action);
// Report violation to server
// Map action to DB-compatible values: "block" -> "blocked_install", "alert" -> "alerted"
let action_taken = match entry.action.as_str() {
"block" => "blocked_install",
"alert" => "alerted",
other => other,
};
let violation = SoftwareViolationReport {
device_uid: device_uid.to_string(),
software_name: process_name.clone(),
action_taken: action_taken.to_string(),
timestamp: chrono::Utc::now().to_rfc3339(),
};
if let Ok(frame) = Frame::new_json(MessageType::SoftwareViolation, &violation) {
let _ = data_tx.send(frame).await;
}
// Kill the process directly by captured PID (avoids TOCTOU race)
if entry.action == "block" {
kill_process_by_pid(*pid, process_name);
}
}
}
}
}
fn pattern_matches(pattern: &str, name: &str) -> bool {
let pattern_lower = pattern.to_lowercase();
let name_lower = name.to_lowercase();
// Support wildcard patterns
if pattern_lower.contains('*') {
let parts: Vec<&str> = pattern_lower.split('*').collect();
let mut pos = 0;
for (i, part) in parts.iter().enumerate() {
if part.is_empty() {
continue;
}
if i == 0 && !parts[0].is_empty() {
// Pattern starts with literal → must match at start
if !name_lower.starts_with(part) {
return false;
}
pos = part.len();
} else {
match name_lower[pos..].find(part) {
Some(idx) => pos += idx + part.len(),
None => return false,
}
}
}
// If pattern ends with literal (no trailing *), must match at end
if !parts.last().map_or(true, |p| p.is_empty()) {
return name_lower.ends_with(parts.last().unwrap());
}
true
} else {
name_lower.contains(&pattern_lower)
}
}
/// Get all running processes with their PIDs (single snapshot, no TOCTOU)
fn get_running_processes_with_pids() -> Vec<(String, u32)> {
#[cfg(target_os = "windows")]
{
use windows::Win32::System::Diagnostics::ToolHelp::*;
use windows::Win32::Foundation::CloseHandle;
let mut procs = Vec::new();
unsafe {
let snapshot = match CreateToolhelp32Snapshot(TH32CS_SNAPPROCESS, 0) {
Ok(s) => s,
Err(_) => return procs,
};
let mut entry = PROCESSENTRY32W {
dwSize: std::mem::size_of::<PROCESSENTRY32W>() as u32,
..Default::default()
};
if Process32FirstW(snapshot, &mut entry).is_ok() {
loop {
let name = String::from_utf16_lossy(
&entry.szExeFile.iter().take_while(|&&c| c != 0).copied().collect::<Vec<u16>>()
);
procs.push((name, entry.th32ProcessID));
entry.dwSize = std::mem::size_of::<PROCESSENTRY32W>() as u32;
if !Process32NextW(snapshot, &mut entry).is_ok() {
break;
}
}
}
let _ = CloseHandle(snapshot);
}
procs
}
#[cfg(not(target_os = "windows"))]
{
Vec::new()
}
}
fn kill_process_by_pid(pid: u32, expected_name: &str) {
#[cfg(target_os = "windows")]
{
use windows::Win32::System::Threading::{OpenProcess, TerminateProcess, PROCESS_TERMINATE};
use windows::Win32::Foundation::CloseHandle;
use windows::Win32::System::Diagnostics::ToolHelp::*;
unsafe {
// Verify the PID still belongs to the expected process (prevent PID reuse kills)
let snapshot = match CreateToolhelp32Snapshot(TH32CS_SNAPPROCESS, 0) {
Ok(s) => s,
Err(_) => return,
};
let mut entry = PROCESSENTRY32W {
dwSize: std::mem::size_of::<PROCESSENTRY32W>() as u32,
..Default::default()
};
let mut name_matches = false;
if Process32FirstW(snapshot, &mut entry).is_ok() {
loop {
if entry.th32ProcessID == pid {
let current_name = String::from_utf16_lossy(
&entry.szExeFile.iter().take_while(|&&c| c != 0).copied().collect::<Vec<u16>>()
);
name_matches = current_name.to_lowercase() == expected_name.to_lowercase();
break;
}
entry.dwSize = std::mem::size_of::<PROCESSENTRY32W>() as u32;
if !Process32NextW(snapshot, &mut entry).is_ok() {
break;
}
}
}
let _ = CloseHandle(snapshot);
if !name_matches {
warn!("Skipping kill: PID {} no longer matches expected process '{}'", pid, expected_name);
return;
}
if let Ok(handle) = OpenProcess(PROCESS_TERMINATE, false, pid) {
let terminated = TerminateProcess(handle, 1).is_ok();
let _ = CloseHandle(handle);
if terminated {
warn!("Killed process: {} (pid={})", expected_name, pid);
} else {
warn!("TerminateProcess failed for: {} (pid={})", expected_name, pid);
}
} else {
warn!("OpenProcess failed for: {} (pid={}) — insufficient privileges?", expected_name, pid);
}
}
}
#[cfg(not(target_os = "windows"))]
{
let _ = (pid, expected_name);
}
}
fn is_protected_process(name: &str) -> bool {
let lower = name.to_lowercase();
PROTECTED_PROCESSES.iter().any(|p| lower == **p || lower.ends_with(&format!("/{}", p).replace('/', "\\")))
}

View File

@@ -0,0 +1,197 @@
use std::time::{Duration, Instant};
use tokio::sync::watch;
use tracing::{info, debug};
use csm_protocol::{Frame, MessageType, UsageDailyReport, AppUsageEntry};
/// Usage tracking configuration pushed from server
#[derive(Debug, Clone, Default)]
pub struct UsageConfig {
pub enabled: bool,
pub idle_threshold_secs: u64,
pub report_interval_secs: u64,
}
/// Start the usage timer plugin.
/// Tracks active/idle time and foreground application usage.
pub async fn start(
mut config_rx: watch::Receiver<UsageConfig>,
data_tx: tokio::sync::mpsc::Sender<Frame>,
device_uid: String,
) {
info!("Usage timer plugin started");
let mut config = UsageConfig {
enabled: false,
idle_threshold_secs: 300, // 5 minutes default
report_interval_secs: 300,
};
let mut report_interval = tokio::time::interval(Duration::from_secs(60));
report_interval.tick().await;
let mut active_secs: u64 = 0;
let mut idle_secs: u64 = 0;
let mut app_usage: std::collections::HashMap<String, u64> = std::collections::HashMap::new();
let mut last_tick = Instant::now();
loop {
tokio::select! {
result = config_rx.changed() => {
if result.is_err() {
break;
}
let new_config = config_rx.borrow_and_update().clone();
if new_config.enabled != config.enabled {
info!("Usage timer enabled: {}", new_config.enabled);
}
config = new_config;
if config.enabled {
report_interval = tokio::time::interval(Duration::from_secs(config.report_interval_secs));
report_interval.tick().await;
last_tick = Instant::now();
}
}
_ = report_interval.tick() => {
if !config.enabled {
continue;
}
// Measure actual elapsed time since last tick
let now = Instant::now();
let elapsed_secs = now.duration_since(last_tick).as_secs();
last_tick = now;
let idle_ms = get_idle_millis();
let is_idle = idle_ms as u64 > config.idle_threshold_secs * 1000;
if is_idle {
idle_secs += elapsed_secs;
} else {
active_secs += elapsed_secs;
}
// Track foreground app
if let Some(app) = get_foreground_app_name() {
*app_usage.entry(app).or_insert(0) += elapsed_secs;
}
// Report usage to server
if active_secs > 0 || idle_secs > 0 {
let report = UsageDailyReport {
device_uid: device_uid.clone(),
date: chrono::Local::now().format("%Y-%m-%d").to_string(),
total_active_minutes: (active_secs / 60) as u32,
total_idle_minutes: (idle_secs / 60) as u32,
first_active_at: None,
last_active_at: Some(chrono::Utc::now().to_rfc3339()),
};
if let Ok(frame) = Frame::new_json(MessageType::UsageReport, &report) {
if data_tx.send(frame).await.is_err() {
break;
}
}
// Report per-app usage
for (app, secs) in &app_usage {
let entry = AppUsageEntry {
device_uid: device_uid.clone(),
date: chrono::Local::now().format("%Y-%m-%d").to_string(),
app_name: app.clone(),
usage_minutes: (secs / 60) as u32,
};
if let Ok(frame) = Frame::new_json(MessageType::AppUsageReport, &entry) {
let _ = data_tx.send(frame).await;
}
}
// Reset counters after reporting
active_secs = 0;
idle_secs = 0;
app_usage.clear();
}
debug!("Usage report sent (idle_ms={}, elapsed={}s)", idle_ms, elapsed_secs);
}
}
}
}
/// Get system idle time in milliseconds using GetLastInputInfo + GetTickCount64.
/// Both use the same time base. LASTINPUTINFO.dwTime is u32 (GetTickCount legacy),
/// so we take the low 32 bits of GetTickCount64 for correct wrapping comparison.
fn get_idle_millis() -> u32 {
#[cfg(target_os = "windows")]
{
use windows::Win32::UI::Input::KeyboardAndMouse::{GetLastInputInfo, LASTINPUTINFO};
unsafe {
let mut lii = LASTINPUTINFO {
cbSize: std::mem::size_of::<LASTINPUTINFO>() as u32,
dwTime: 0,
};
if GetLastInputInfo(&mut lii).as_bool() {
// GetTickCount64 returns u64, but dwTime is u32.
// Take low 32 bits so wrapping_sub produces correct idle delta.
let tick_low32 = (windows::Win32::System::SystemInformation::GetTickCount64() & 0xFFFFFFFF) as u32;
return tick_low32.wrapping_sub(lii.dwTime);
}
}
0
}
#[cfg(not(target_os = "windows"))]
{
0
}
}
/// Get the foreground window's process name using CreateToolhelp32Snapshot
fn get_foreground_app_name() -> Option<String> {
#[cfg(target_os = "windows")]
{
use windows::Win32::UI::WindowsAndMessaging::{GetForegroundWindow, GetWindowThreadProcessId};
use windows::Win32::System::Diagnostics::ToolHelp::*;
use windows::Win32::Foundation::CloseHandle;
unsafe {
let hwnd = GetForegroundWindow();
if hwnd.0 == 0 {
return None;
}
let mut pid: u32 = 0;
GetWindowThreadProcessId(hwnd, Some(&mut pid));
if pid == 0 {
return None;
}
// Get process name via CreateToolhelp32Snapshot
let snapshot = CreateToolhelp32Snapshot(TH32CS_SNAPPROCESS, 0).ok()?;
let mut entry = PROCESSENTRY32W {
dwSize: std::mem::size_of::<PROCESSENTRY32W>() as u32,
..Default::default()
};
if Process32FirstW(snapshot, &mut entry).is_ok() {
loop {
if entry.th32ProcessID == pid {
let name = String::from_utf16_lossy(
&entry.szExeFile.iter().take_while(|&&c| c != 0).copied().collect::<Vec<u16>>()
);
let _ = CloseHandle(snapshot);
return Some(name);
}
entry.dwSize = std::mem::size_of::<PROCESSENTRY32W>() as u32;
if !Process32NextW(snapshot, &mut entry).is_ok() {
break;
}
}
}
let _ = CloseHandle(snapshot);
None
}
}
#[cfg(not(target_os = "windows"))]
{
None
}
}

View File

@@ -0,0 +1,249 @@
use csm_protocol::{Frame, MessageType, UsbEvent, UsbEventType, UsbPolicyPayload, UsbDeviceRule};
use std::time::Duration;
use tokio::sync::{mpsc::Sender, watch};
use tracing::{info, warn};
/// Start USB monitoring with policy enforcement.
/// Monitors for removable drive insertions/removals and enforces USB policies.
pub async fn start_monitoring(
tx: Sender<Frame>,
device_uid: String,
mut policy_rx: watch::Receiver<Option<UsbPolicyPayload>>,
) {
info!("USB monitoring started for {}", device_uid);
let mut current_policy: Option<UsbPolicyPayload> = None;
let mut known_drives: Vec<String> = Vec::new();
let interval = Duration::from_secs(10);
loop {
// Check for policy updates (non-blocking)
if policy_rx.has_changed().unwrap_or(false) {
let new_policy = policy_rx.borrow_and_update().clone();
let policy_desc = new_policy.as_ref()
.map(|p| format!("type={}, enabled={}", p.policy_type, p.enabled))
.unwrap_or_else(|| "None".to_string());
info!("USB policy updated: {}", policy_desc);
current_policy = new_policy;
// Re-check all currently mounted drives against new policy
let drives = detect_removable_drives();
for drive in &drives {
if should_block_drive(drive, &current_policy) {
warn!("USB device {} blocked by policy, ejecting...", drive);
if let Err(e) = eject_drive(drive) {
warn!("Failed to eject {}: {}", drive, e);
} else {
info!("Successfully ejected {}", drive);
}
// Report blocked event
send_usb_event(&tx, &device_uid, UsbEventType::Blocked, drive).await;
}
}
}
let current_drives = detect_removable_drives();
// Detect new drives (inserted)
for drive in &current_drives {
if !known_drives.iter().any(|d: &String| d == drive) {
info!("USB device inserted: {}", drive);
// Check against policy before reporting
if should_block_drive(drive, &current_policy) {
warn!("USB device {} blocked by policy, ejecting...", drive);
if let Err(e) = eject_drive(drive) {
warn!("Failed to eject {}: {}", drive, e);
} else {
info!("Successfully ejected {}", drive);
}
send_usb_event(&tx, &device_uid, UsbEventType::Blocked, drive).await;
continue;
}
send_usb_event(&tx, &device_uid, UsbEventType::Inserted, drive).await;
}
}
// Detect removed drives
for drive in &known_drives {
if !current_drives.iter().any(|d| d == drive) {
info!("USB device removed: {}", drive);
send_usb_event(&tx, &device_uid, UsbEventType::Removed, drive).await;
}
}
known_drives = current_drives;
tokio::time::sleep(interval).await;
}
}
async fn send_usb_event(
tx: &Sender<Frame>,
device_uid: &str,
event_type: UsbEventType,
drive: &str,
) {
let event = UsbEvent {
device_uid: device_uid.to_string(),
event_type,
vendor_id: None,
product_id: None,
serial: None,
device_name: Some(drive.to_string()),
};
if let Ok(frame) = Frame::new_json(MessageType::UsbEvent, &event) {
tx.send(frame).await.ok();
}
}
/// Check if a drive should be blocked based on the current policy.
fn should_block_drive(drive: &str, policy: &Option<UsbPolicyPayload>) -> bool {
let policy = match policy {
Some(p) if p.enabled => p,
_ => return false,
};
match policy.policy_type.as_str() {
"all_block" => true,
"blacklist" => {
// Block if any rule matches
policy.rules.iter().any(|rule| device_matches_rule(drive, rule))
}
"whitelist" => {
// Block if NOT in whitelist (empty whitelist = block all)
if policy.rules.is_empty() {
return true;
}
!policy.rules.iter().any(|rule| device_matches_rule(drive, rule))
}
_ => false,
}
}
/// Check if a device matches a rule pattern.
/// Currently matches by device_name (drive letter path) since that's what we detect.
fn device_matches_rule(drive: &str, rule: &UsbDeviceRule) -> bool {
// Match by device name (drive root path like "E:\")
if let Some(ref name) = rule.device_name {
if drive.eq_ignore_ascii_case(name) || drive.eq_ignore_ascii_case(&name.replace('\\', "")) {
return true;
}
}
// vendor_id, product_id, serial matching would require WMI or SetupDi queries
// For now, these are placeholder checks
let _ = (drive, rule);
false
}
/// DRIVE_REMOVABLE = 2 in Windows API
const DRIVE_REMOVABLE: u32 = 2;
fn detect_removable_drives() -> Vec<String> {
let mut drives = Vec::new();
#[cfg(target_os = "windows")]
{
for letter in b'A'..=b'Z' {
let root = format!("{}:\\", letter as char);
let root_wide: Vec<u16> = root.encode_utf16().chain(std::iter::once(0)).collect();
unsafe {
let drive_type = windows::Win32::Storage::FileSystem::GetDriveTypeW(
windows::core::PCWSTR(root_wide.as_ptr()),
);
if drive_type == DRIVE_REMOVABLE {
drives.push(root);
}
}
}
}
#[cfg(not(target_os = "windows"))]
{
let _ = &drives; // Suppress unused warning on non-Windows
}
drives
}
/// Eject a removable drive using Windows API.
/// Opens the volume handle, dismounts the filesystem, and ejects the media.
#[cfg(target_os = "windows")]
fn eject_drive(drive: &str) -> Result<(), anyhow::Error> {
use windows::Win32::Storage::FileSystem::*;
use windows::Win32::System::IO::DeviceIoControl;
use windows::Win32::Foundation::*;
use windows::core::PCWSTR;
// IOCTL control codes (not exposed directly in windows crate)
const FSCTL_DISMOUNT_VOLUME: u32 = 0x00090020;
const IOCTL_STORAGE_EJECT_MEDIA: u32 = 0x002D4808;
// Extract drive letter from path like "E:\"
let letter = drive.chars().next().unwrap_or('A');
let path = format!("\\\\.\\{}:\0", letter);
let path_wide: Vec<u16> = path.encode_utf16().collect();
unsafe {
let handle = CreateFileW(
PCWSTR(path_wide.as_ptr()),
GENERIC_READ.0,
FILE_SHARE_READ | FILE_SHARE_WRITE,
None,
OPEN_EXISTING,
FILE_ATTRIBUTE_NORMAL,
None,
)?;
if handle.is_invalid() {
return Err(anyhow::anyhow!("Failed to open volume handle for {}", drive));
}
// Dismount the filesystem
let mut bytes_returned = 0u32;
let dismount_ok = DeviceIoControl(
handle,
FSCTL_DISMOUNT_VOLUME,
None,
0,
None,
0,
Some(&mut bytes_returned),
None,
).is_ok();
if !dismount_ok {
warn!("FSCTL_DISMOUNT_VOLUME failed for {}", drive);
}
// Eject the media
let eject_ok = DeviceIoControl(
handle,
IOCTL_STORAGE_EJECT_MEDIA,
None,
0,
None,
0,
Some(&mut bytes_returned),
None,
).is_ok();
if !eject_ok {
warn!("IOCTL_STORAGE_EJECT_MEDIA failed for {}", drive);
}
let _ = CloseHandle(handle);
if !dismount_ok && !eject_ok {
Err(anyhow::anyhow!("Failed to eject {}", drive))
} else {
Ok(())
}
}
}
#[cfg(not(target_os = "windows"))]
fn eject_drive(_drive: &str) -> Result<(), anyhow::Error> {
Ok(())
}

View File

@@ -0,0 +1,190 @@
use tokio::sync::watch;
use tracing::{info, debug, warn};
use csm_protocol::{Frame, MessageType, UsbFileOpEntry};
use std::collections::{HashMap, HashSet};
/// USB file audit configuration
#[derive(Debug, Clone, Default)]
pub struct UsbAuditConfig {
pub enabled: bool,
pub monitored_extensions: Vec<String>,
}
/// Start USB file audit plugin.
/// Detects removable drives and monitors file changes via periodic scanning.
pub async fn start(
mut config_rx: watch::Receiver<UsbAuditConfig>,
data_tx: tokio::sync::mpsc::Sender<Frame>,
device_uid: String,
) {
info!("USB file audit plugin started");
let mut config = UsbAuditConfig::default();
let mut active_drives: HashSet<String> = HashSet::new();
// Track file listings per drive to detect changes
let mut drive_snapshots: HashMap<String, HashSet<String>> = HashMap::new();
let mut interval = tokio::time::interval(std::time::Duration::from_secs(5));
interval.tick().await;
loop {
tokio::select! {
result = config_rx.changed() => {
if result.is_err() {
break;
}
let new_config = config_rx.borrow_and_update().clone();
info!("USB audit config updated: enabled={}", new_config.enabled);
config = new_config;
}
_ = interval.tick() => {
if !config.enabled {
continue;
}
let drives = detect_removable_drives();
// Detect new drives
for drive in &drives {
if !active_drives.contains(drive) {
info!("New removable drive detected: {}", drive);
// Take initial snapshot in blocking thread to avoid freezing
let drive_clone = drive.clone();
let files = tokio::task::spawn_blocking(move || scan_drive_files(&drive_clone)).await;
match files {
Ok(files) => { drive_snapshots.insert(drive.clone(), files); }
Err(e) => { warn!("Failed to scan drive {}: {}", drive, e); }
}
}
}
// Scan existing drives for changes (each in a blocking thread)
for drive in &drives {
let drive_clone = drive.clone();
let current_files = match tokio::task::spawn_blocking(move || scan_drive_files(&drive_clone)).await {
Ok(files) => files,
Err(e) => { warn!("Drive scan failed for {}: {}", drive, e); continue; }
};
if let Some(prev_files) = drive_snapshots.get_mut(drive) {
// Find new files (created)
for file in &current_files {
if !prev_files.contains(file) {
report_file_op(&data_tx, &device_uid, drive, file, "create", &config.monitored_extensions).await;
}
}
// Find deleted files
for file in prev_files.iter() {
if !current_files.contains(file) {
report_file_op(&data_tx, &device_uid, drive, file, "delete", &config.monitored_extensions).await;
}
}
*prev_files = current_files;
} else {
drive_snapshots.insert(drive.clone(), current_files);
}
}
// Clean up removed drives
active_drives.retain(|d| drives.contains(d));
drive_snapshots.retain(|d, _| drives.contains(d));
}
}
}
}
async fn report_file_op(
data_tx: &tokio::sync::mpsc::Sender<Frame>,
device_uid: &str,
drive: &str,
file_path: &str,
operation: &str,
ext_filter: &[String],
) {
// Check extension filter
let should_report = if ext_filter.is_empty() {
true
} else {
let file_ext = std::path::Path::new(file_path)
.extension()
.and_then(|e| e.to_str())
.map(|e| e.to_lowercase());
file_ext.map_or(true, |ext| {
ext_filter.iter().any(|f| f.to_lowercase() == ext)
})
};
if should_report {
let entry = UsbFileOpEntry {
device_uid: device_uid.to_string(),
usb_serial: None,
drive_letter: Some(drive.to_string()),
operation: operation.to_string(),
file_path: file_path.to_string(),
file_size: None,
timestamp: chrono::Utc::now().to_rfc3339(),
};
if let Ok(frame) = Frame::new_json(MessageType::UsbFileOp, &entry) {
let _ = data_tx.send(frame).await;
}
debug!("USB file op: {} {}", operation, file_path);
}
}
/// Recursively scan a drive and return all file paths
fn scan_drive_files(drive: &str) -> HashSet<String> {
let mut files = HashSet::new();
let max_depth = 3; // Limit recursion depth for performance
scan_dir_recursive(drive, &mut files, 0, max_depth);
files
}
fn scan_dir_recursive(dir: &str, files: &mut HashSet<String>, depth: usize, max_depth: usize) {
if depth >= max_depth {
return;
}
if let Ok(entries) = std::fs::read_dir(dir) {
for entry in entries.flatten() {
let path = entry.path();
let path_str = path.to_string_lossy().to_string();
if path.is_dir() {
scan_dir_recursive(&path_str, files, depth + 1, max_depth);
} else {
files.insert(path_str);
}
}
}
}
fn detect_removable_drives() -> Vec<String> {
#[cfg(target_os = "windows")]
{
use windows::Win32::Storage::FileSystem::GetDriveTypeW;
use windows::core::PCWSTR;
let mut drives = Vec::new();
let mask = unsafe { windows::Win32::Storage::FileSystem::GetLogicalDrives() };
let mut mask = mask as u32;
let mut letter = b'A';
while mask != 0 {
if mask & 1 != 0 {
let drive_letter = format!("{}:\\", letter as char);
let drive_wide: Vec<u16> = drive_letter.encode_utf16().chain(std::iter::once(0)).collect();
let drive_type = unsafe {
GetDriveTypeW(PCWSTR(drive_wide.as_ptr()))
};
// DRIVE_REMOVABLE = 2
if drive_type == 2 {
drives.push(drive_letter);
}
}
mask >>= 1;
letter += 1;
}
drives
}
#[cfg(not(target_os = "windows"))]
{
Vec::new()
}
}

View File

@@ -0,0 +1,313 @@
use std::sync::{Arc, Mutex};
use tokio::sync::watch;
use tracing::{info, warn, error};
use csm_protocol::WatermarkConfigPayload;
/// Watermark overlay state
struct WatermarkState {
enabled: bool,
content: String,
font_size: u32,
opacity: f64,
color: String,
angle: i32,
hwnd: Option<isize>,
}
impl Default for WatermarkState {
fn default() -> Self {
Self {
enabled: false,
content: String::new(),
font_size: 14,
opacity: 0.15,
color: "#808080".to_string(),
angle: -30,
hwnd: None,
}
}
}
static WATERMARK_STATE: std::sync::OnceLock<Arc<Mutex<WatermarkState>>> = std::sync::OnceLock::new();
const WM_USER_UPDATE: u32 = 0x0401;
const OVERLAY_CLASS_NAME: &str = "CSM_WatermarkOverlay";
/// Start the watermark plugin, listening for config updates.
pub async fn start(mut rx: watch::Receiver<Option<WatermarkConfigPayload>>) {
info!("Watermark plugin started");
let state = WATERMARK_STATE.get_or_init(|| Arc::new(Mutex::new(WatermarkState::default())));
let state_clone = state.clone();
// Spawn a dedicated thread for the Win32 message loop
std::thread::spawn(move || {
#[cfg(target_os = "windows")]
run_overlay_message_loop(state_clone);
#[cfg(not(target_os = "windows"))]
{
let _ = state_clone;
info!("Watermark overlay not supported on this platform");
}
});
loop {
tokio::select! {
result = rx.changed() => {
if result.is_err() {
warn!("Watermark config channel closed");
break;
}
let config = rx.borrow_and_update().clone();
if let Some(cfg) = config {
info!("Watermark config updated: enabled={}, content='{}'", cfg.enabled, cfg.content);
{
let mut s = state.lock().unwrap_or_else(|e| e.into_inner());
s.enabled = cfg.enabled;
s.content = cfg.content.clone();
s.font_size = cfg.font_size;
s.opacity = cfg.opacity;
s.color = cfg.color.clone();
s.angle = cfg.angle;
}
// Post message to overlay window thread
#[cfg(target_os = "windows")]
post_overlay_update();
}
}
}
}
}
#[cfg(target_os = "windows")]
fn post_overlay_update() {
use windows::Win32::UI::WindowsAndMessaging::{FindWindowA, PostMessageW};
use windows::Win32::Foundation::{WPARAM, LPARAM};
use windows::core::PCSTR;
unsafe {
let class_name = format!("{}\0", OVERLAY_CLASS_NAME);
let hwnd = FindWindowA(PCSTR(class_name.as_ptr()), PCSTR::null());
if hwnd.0 != 0 {
let _ = PostMessageW(hwnd, WM_USER_UPDATE, WPARAM(0), LPARAM(0));
}
}
}
#[cfg(target_os = "windows")]
fn run_overlay_message_loop(state: Arc<Mutex<WatermarkState>>) {
use windows::Win32::UI::WindowsAndMessaging::*;
use windows::Win32::Graphics::Gdi::*;
use windows::Win32::Foundation::*;
use windows::Win32::System::LibraryLoader::GetModuleHandleA;
use windows::core::PCSTR;
// Register window class
let class_name = format!("{}\0", OVERLAY_CLASS_NAME);
let instance = unsafe { GetModuleHandleA(PCSTR::null()) }.unwrap_or_default();
let instance: HINSTANCE = instance.into();
let wc = WNDCLASSA {
style: CS_HREDRAW | CS_VREDRAW,
lpfnWndProc: Some(watermark_wnd_proc),
hInstance: instance,
lpszClassName: PCSTR(class_name.as_ptr()),
hbrBackground: unsafe { CreateSolidBrush(COLORREF(0x00000000)) },
..Default::default()
};
unsafe { RegisterClassA(&wc); }
let screen_w = unsafe { GetSystemMetrics(SM_CXSCREEN) };
let screen_h = unsafe { GetSystemMetrics(SM_CYSCREEN) };
// WS_EX_LAYERED | WS_EX_TRANSPARENT | WS_EX_TOPMOST | WS_EX_TOOLWINDOW
let ex_style = WS_EX_LAYERED | WS_EX_TRANSPARENT | WS_EX_TOPMOST | WS_EX_TOOLWINDOW;
let style = WS_POPUP;
let hwnd = unsafe {
CreateWindowExA(
ex_style,
PCSTR(class_name.as_ptr()),
PCSTR("CSM Watermark\0".as_ptr()),
style,
0, 0, screen_w, screen_h,
HWND::default(),
HMENU::default(),
instance,
None,
)
};
if hwnd.0 == 0 {
error!("Failed to create watermark overlay window");
return;
}
{
let mut s = state.lock().unwrap_or_else(|e| e.into_inner());
s.hwnd = Some(hwnd.0);
}
info!("Watermark overlay window created (hwnd={})", hwnd.0);
// Post initial update to self — ensures overlay shows even if config
// arrived before window creation completed (race between threads)
unsafe {
let _ = PostMessageW(hwnd, WM_USER_UPDATE, WPARAM(0), LPARAM(0));
}
// Message loop
let mut msg = MSG::default();
unsafe {
loop {
match GetMessageA(&mut msg, HWND::default(), 0, 0) {
BOOL(0) | BOOL(-1) => break,
_ => {
TranslateMessage(&msg);
DispatchMessageA(&msg);
}
}
}
}
}
#[cfg(target_os = "windows")]
unsafe extern "system" fn watermark_wnd_proc(
hwnd: windows::Win32::Foundation::HWND,
msg: u32,
wparam: windows::Win32::Foundation::WPARAM,
lparam: windows::Win32::Foundation::LPARAM,
) -> windows::Win32::Foundation::LRESULT {
use windows::Win32::UI::WindowsAndMessaging::*;
use windows::Win32::Graphics::Gdi::*;
use windows::Win32::Foundation::*;
match msg {
WM_PAINT => {
if let Some(state) = WATERMARK_STATE.get() {
let s = state.lock().unwrap_or_else(|e| e.into_inner());
if s.enabled && !s.content.is_empty() {
paint_watermark(hwnd, &s);
}
}
LRESULT(0)
}
WM_ERASEBKGND => {
// Fill with black (will be color-keyed to transparent)
let hdc = HDC(wparam.0 as isize);
unsafe {
let rect = {
let mut r = RECT::default();
let _ = GetClientRect(hwnd, &mut r);
r
};
let brush = CreateSolidBrush(COLORREF(0x00000000));
let _ = FillRect(hdc, &rect, brush);
let _ = DeleteObject(brush);
}
LRESULT(1)
}
m if m == WM_USER_UPDATE => {
if let Some(state) = WATERMARK_STATE.get() {
let s = state.lock().unwrap_or_else(|e| e.into_inner());
if s.enabled {
let _ = SetWindowPos(
hwnd, HWND_TOPMOST,
0, 0,
GetSystemMetrics(SM_CXSCREEN),
GetSystemMetrics(SM_CYSCREEN),
SWP_SHOWWINDOW,
);
let alpha = (s.opacity * 255.0).clamp(0.0, 255.0) as u8;
// Color key black background to transparent, apply alpha to text
let _ = SetLayeredWindowAttributes(hwnd, COLORREF(0), alpha, LWA_COLORKEY | LWA_ALPHA);
let _ = InvalidateRect(hwnd, None, true);
} else {
let _ = ShowWindow(hwnd, SW_HIDE);
}
}
LRESULT(0)
}
WM_DESTROY => {
PostQuitMessage(0);
LRESULT(0)
}
_ => DefWindowProcA(hwnd, msg, wparam, lparam),
}
}
#[cfg(target_os = "windows")]
fn paint_watermark(hwnd: windows::Win32::Foundation::HWND, state: &WatermarkState) {
use windows::Win32::Graphics::Gdi::*;
use windows::Win32::UI::WindowsAndMessaging::*;
use windows::core::PCSTR;
unsafe {
let mut ps = PAINTSTRUCT::default();
let hdc = BeginPaint(hwnd, &mut ps);
let color = parse_color(&state.color);
let font_size = state.font_size.max(1);
// Create font with rotation
let font = CreateFontA(
(font_size as i32) * 2,
0,
(state.angle as i32) * 10,
0,
FW_NORMAL.0 as i32,
0, 0, 0,
DEFAULT_CHARSET.0 as u32,
OUT_DEFAULT_PRECIS.0 as u32,
CLIP_DEFAULT_PRECIS.0 as u32,
DEFAULT_QUALITY.0 as u32,
DEFAULT_PITCH.0 as u32 | FF_DONTCARE.0 as u32,
PCSTR("Arial\0".as_ptr()),
);
let old_font = SelectObject(hdc, font);
let _ = SetBkMode(hdc, TRANSPARENT);
let _ = SetTextColor(hdc, color);
// Draw tiled watermark text
let screen_w = GetSystemMetrics(SM_CXSCREEN);
let screen_h = GetSystemMetrics(SM_CYSCREEN);
let content_bytes: Vec<u8> = state.content.bytes().chain(std::iter::once(0)).collect();
let text_slice = &content_bytes[..content_bytes.len().saturating_sub(1)];
let spacing_x = 400i32;
let spacing_y = 200i32;
let mut y = -100i32;
while y < screen_h + 100 {
let mut x = -200i32;
while x < screen_w + 200 {
let _ = TextOutA(hdc, x, y, text_slice);
x += spacing_x;
}
y += spacing_y;
}
SelectObject(hdc, old_font);
let _ = DeleteObject(font);
EndPaint(hwnd, &ps);
}
}
fn parse_color(hex: &str) -> windows::Win32::Foundation::COLORREF {
use windows::Win32::Foundation::COLORREF;
let hex = hex.trim_start_matches('#');
if hex.len() == 6 {
let r = u8::from_str_radix(&hex[0..2], 16).unwrap_or(128);
let g = u8::from_str_radix(&hex[2..4], 16).unwrap_or(128);
let b = u8::from_str_radix(&hex[4..6], 16).unwrap_or(128);
COLORREF((b as u32) << 16 | (g as u32) << 8 | (r as u32))
} else {
COLORREF(0x00808080)
}
}

View File

@@ -0,0 +1,169 @@
use tokio::sync::watch;
use tracing::{info, warn, debug};
use serde::Deserialize;
use std::io;
/// Web filter rule from server
#[derive(Debug, Clone, Deserialize)]
pub struct WebFilterRule {
pub id: i64,
pub rule_type: String,
pub pattern: String,
}
/// Web filter configuration
#[derive(Debug, Clone, Default)]
pub struct WebFilterConfig {
pub enabled: bool,
pub rules: Vec<WebFilterRule>,
}
/// Start web filter plugin.
/// Manages the hosts file to block/allow URLs based on server rules.
pub async fn start(
mut config_rx: watch::Receiver<WebFilterConfig>,
) {
info!("Web filter plugin started");
let mut _config = WebFilterConfig::default();
loop {
tokio::select! {
result = config_rx.changed() => {
if result.is_err() {
break;
}
let new_config = config_rx.borrow_and_update().clone();
info!("Web filter config updated: enabled={}, rules={}", new_config.enabled, new_config.rules.len());
if new_config.enabled {
match apply_hosts_rules(&new_config.rules) {
Ok(()) => info!("Web filter hosts rules applied ({} rules)", new_config.rules.len()),
Err(e) => warn!("Failed to apply hosts rules: {}", e),
}
} else {
match clear_hosts_rules() {
Ok(()) => info!("Web filter hosts rules cleared"),
Err(e) => warn!("Failed to clear hosts rules: {}", e),
}
}
_config = new_config;
}
}
}
}
const HOSTS_MARKER_START: &str = "# CSM_WEB_FILTER_START";
const HOSTS_MARKER_END: &str = "# CSM_WEB_FILTER_END";
fn apply_hosts_rules(rules: &[WebFilterRule]) -> io::Result<()> {
let hosts_path = get_hosts_path();
let original = std::fs::read_to_string(&hosts_path)?;
// Remove existing CSM block
let clean = remove_csm_block(&original);
// Build new block with block rules
let block_rules: Vec<&WebFilterRule> = rules.iter()
.filter(|r| r.rule_type == "block")
.filter(|r| is_valid_hosts_entry(&r.pattern))
.collect();
if block_rules.is_empty() {
atomic_write(&hosts_path, &clean)?;
return Ok(());
}
let mut new_block = format!("{}\n", HOSTS_MARKER_START);
for rule in &block_rules {
// Redirect blocked domains to 127.0.0.1
new_block.push_str(&format!("127.0.0.1 {}\n", rule.pattern));
}
new_block.push_str(HOSTS_MARKER_END);
new_block.push('\n');
let new_content = format!("{}{}", clean, new_block);
atomic_write(&hosts_path, &new_content)?;
debug!("Applied {} web filter rules to hosts file", block_rules.len());
Ok(())
}
fn clear_hosts_rules() -> io::Result<()> {
let hosts_path = get_hosts_path();
let original = std::fs::read_to_string(&hosts_path)?;
let clean = remove_csm_block(&original);
atomic_write(&hosts_path, &clean)?;
Ok(())
}
/// Write content to a file atomically.
/// On Windows, hosts file may be locked by the DNS cache service,
/// so we use direct overwrite with truncation instead of rename.
fn atomic_write(path: &str, content: &str) -> io::Result<()> {
use std::io::Write;
let mut file = std::fs::OpenOptions::new()
.write(true)
.truncate(true)
.create(true)
.open(path)?;
file.write_all(content.as_bytes())?;
file.sync_data()?;
Ok(())
}
fn remove_csm_block(content: &str) -> String {
// Handle paired markers (normal case)
let start_idx = content.find(HOSTS_MARKER_START);
let end_idx = content.find(HOSTS_MARKER_END);
match (start_idx, end_idx) {
(Some(s), Some(e)) if e > s => {
let mut result = String::new();
result.push_str(&content[..s]);
result.push_str(&content[e + HOSTS_MARKER_END.len()..]);
return result.trim_end().to_string() + "\n";
}
_ => {}
}
// Handle orphan markers: remove any lone START or END marker line
let lines: Vec<&str> = content.lines().collect();
let cleaned: Vec<&str> = lines.iter()
.filter(|line| {
let trimmed = line.trim();
trimmed != HOSTS_MARKER_START && trimmed != HOSTS_MARKER_END
})
.copied()
.collect();
let mut result = cleaned.join("\n");
if !result.ends_with('\n') {
result.push('\n');
}
result
}
fn get_hosts_path() -> String {
#[cfg(target_os = "windows")]
{
r"C:\Windows\System32\drivers\etc\hosts".to_string()
}
#[cfg(not(target_os = "windows"))]
{
"/etc/hosts".to_string()
}
}
/// Validate that a pattern is safe to write to the hosts file.
/// Rejects patterns with whitespace, control chars, or comment markers.
fn is_valid_hosts_entry(pattern: &str) -> bool {
if pattern.is_empty() {
return false;
}
// Reject if contains whitespace, control chars, or comment markers
if pattern.chars().any(|c| c.is_whitespace() || c.is_control() || c == '#') {
return false;
}
// Must look like a hostname (alphanumeric, dots, hyphens, underscores, asterisks)
pattern.chars().all(|c| c.is_alphanumeric() || c == '.' || c == '-' || c == '_' || c == '*')
}

View File

@@ -0,0 +1,12 @@
[package]
name = "csm-protocol"
version.workspace = true
edition.workspace = true
[dependencies]
serde = { workspace = true }
serde_json = { workspace = true }
thiserror = { workspace = true }
uuid = { workspace = true }
chrono = { workspace = true }
anyhow = { workspace = true }

View File

@@ -0,0 +1,108 @@
use serde::{Deserialize, Serialize};
/// Real-time device status report
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct DeviceStatus {
pub device_uid: String,
pub cpu_usage: f64,
pub memory_usage: f64,
pub memory_total_mb: u64,
pub disk_usage: f64,
pub disk_total_mb: u64,
pub network_rx_rate: u64,
pub network_tx_rate: u64,
pub running_procs: u32,
pub top_processes: Vec<ProcessInfo>,
}
/// Top process information
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ProcessInfo {
pub name: String,
pub pid: u32,
pub cpu_usage: f64,
pub memory_mb: u64,
}
/// Hardware asset information
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct HardwareAsset {
pub device_uid: String,
pub cpu_model: String,
pub cpu_cores: u32,
pub memory_total_mb: u64,
pub disk_model: String,
pub disk_total_mb: u64,
pub gpu_model: Option<String>,
pub motherboard: Option<String>,
pub serial_number: Option<String>,
}
/// Software asset information
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct SoftwareAsset {
pub device_uid: String,
pub name: String,
pub version: Option<String>,
pub publisher: Option<String>,
pub install_date: Option<String>,
pub install_path: Option<String>,
}
/// USB device event
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct UsbEvent {
pub device_uid: String,
pub event_type: UsbEventType,
pub vendor_id: Option<String>,
pub product_id: Option<String>,
pub serial: Option<String>,
pub device_name: Option<String>,
}
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum UsbEventType {
Inserted,
Removed,
Blocked,
}
/// Asset change event
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct AssetChange {
pub device_uid: String,
pub change_type: AssetChangeType,
pub change_detail: serde_json::Value,
}
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum AssetChangeType {
Hardware,
SoftwareAdded,
SoftwareRemoved,
}
/// USB policy (Server → Client)
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct UsbPolicy {
pub policy_id: i64,
pub policy_type: UsbPolicyType,
pub allowed_devices: Vec<UsbDevicePattern>,
}
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum UsbPolicyType {
AllBlock,
Whitelist,
Blacklist,
}
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct UsbDevicePattern {
pub vendor_id: Option<String>,
pub product_id: Option<String>,
pub serial: Option<String>,
}

View File

@@ -0,0 +1,27 @@
pub mod message;
pub mod device;
// Re-export constants from message module
pub use message::{MAGIC, PROTOCOL_VERSION, FRAME_HEADER_SIZE, MAX_PAYLOAD_SIZE};
// Core frame & message types
pub use message::{
Frame, FrameError, MessageType,
RegisterRequest, RegisterResponse, ClientConfig,
HeartbeatPayload, TaskExecutePayload, ConfigUpdateType,
};
// Device status & asset types
pub use device::{
DeviceStatus, ProcessInfo, HardwareAsset, SoftwareAsset,
UsbEvent, UsbEventType, AssetChange, AssetChangeType,
UsbPolicy, UsbPolicyType, UsbDevicePattern,
};
// Plugin message payloads
pub use message::{
WebAccessLogEntry, UsageDailyReport, AppUsageEntry,
SoftwareViolationReport, UsbFileOpEntry,
WatermarkConfigPayload, PluginControlPayload,
UsbPolicyPayload, UsbDeviceRule,
};

View File

@@ -0,0 +1,417 @@
use serde::{Deserialize, Serialize};
/// Protocol magic bytes: "CSM\0"
pub const MAGIC: [u8; 4] = [0x43, 0x53, 0x4D, 0x00];
/// Current protocol version
pub const PROTOCOL_VERSION: u8 = 0x01;
/// Frame header size: magic(4) + version(1) + type(1) + length(4)
pub const FRAME_HEADER_SIZE: usize = 10;
/// Maximum payload size: 4 MB — prevents memory exhaustion from malicious frames
pub const MAX_PAYLOAD_SIZE: usize = 4 * 1024 * 1024;
/// Binary message types for client-server communication
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum MessageType {
// Client → Server (Core)
Heartbeat = 0x01,
Register = 0x02,
StatusReport = 0x03,
AssetReport = 0x04,
AssetChange = 0x05,
UsbEvent = 0x06,
AlertAck = 0x07,
// Server → Client (Core)
RegisterResponse = 0x08,
PolicyUpdate = 0x10,
ConfigUpdate = 0x11,
TaskExecute = 0x12,
// Plugin: Web Filter (上网拦截)
WebFilterRuleUpdate = 0x20,
WebAccessLog = 0x21,
// Plugin: Usage Timer (时长记录)
UsageReport = 0x30,
AppUsageReport = 0x31,
// Plugin: Software Blocker (软件禁止安装)
SoftwareBlacklist = 0x40,
SoftwareViolation = 0x41,
// Plugin: Popup Blocker (弹窗拦截)
PopupRules = 0x50,
// Plugin: USB File Audit (U盘文件操作记录)
UsbFileOp = 0x60,
// Plugin: Screen Watermark (水印管理)
WatermarkConfig = 0x70,
// Plugin: USB Policy (U盘管控策略)
UsbPolicyUpdate = 0x71,
// Plugin control
PluginEnable = 0x80,
PluginDisable = 0x81,
}
impl TryFrom<u8> for MessageType {
type Error = String;
fn try_from(value: u8) -> Result<Self, Self::Error> {
match value {
0x01 => Ok(Self::Heartbeat),
0x02 => Ok(Self::Register),
0x03 => Ok(Self::StatusReport),
0x04 => Ok(Self::AssetReport),
0x05 => Ok(Self::AssetChange),
0x06 => Ok(Self::UsbEvent),
0x07 => Ok(Self::AlertAck),
0x08 => Ok(Self::RegisterResponse),
0x10 => Ok(Self::PolicyUpdate),
0x11 => Ok(Self::ConfigUpdate),
0x12 => Ok(Self::TaskExecute),
0x20 => Ok(Self::WebFilterRuleUpdate),
0x21 => Ok(Self::WebAccessLog),
0x30 => Ok(Self::UsageReport),
0x31 => Ok(Self::AppUsageReport),
0x40 => Ok(Self::SoftwareBlacklist),
0x41 => Ok(Self::SoftwareViolation),
0x50 => Ok(Self::PopupRules),
0x60 => Ok(Self::UsbFileOp),
0x70 => Ok(Self::WatermarkConfig),
0x71 => Ok(Self::UsbPolicyUpdate),
0x80 => Ok(Self::PluginEnable),
0x81 => Ok(Self::PluginDisable),
_ => Err(format!("Unknown message type: 0x{:02X}", value)),
}
}
}
/// A wire-format frame for transmission over TCP
#[derive(Debug, Clone)]
pub struct Frame {
pub version: u8,
pub msg_type: MessageType,
pub payload: Vec<u8>,
}
impl Frame {
/// Create a new frame with the current protocol version
pub fn new(msg_type: MessageType, payload: Vec<u8>) -> Self {
Self {
version: PROTOCOL_VERSION,
msg_type,
payload,
}
}
/// Create a new frame with JSON-serialized payload
pub fn new_json<T: Serialize>(msg_type: MessageType, data: &T) -> anyhow::Result<Self> {
let payload = serde_json::to_vec(data)?;
Ok(Self::new(msg_type, payload))
}
/// Encode frame to bytes for transmission
/// Format: MAGIC(4) + VERSION(1) + TYPE(1) + LENGTH(4) + PAYLOAD(var)
pub fn encode(&self) -> Vec<u8> {
let mut buf = Vec::with_capacity(FRAME_HEADER_SIZE + self.payload.len());
buf.extend_from_slice(&MAGIC);
buf.push(self.version);
buf.push(self.msg_type as u8);
buf.extend_from_slice(&(self.payload.len() as u32).to_be_bytes());
buf.extend_from_slice(&self.payload);
buf
}
/// Decode frame from bytes. Returns Ok(Some(frame)) when a complete frame is available.
pub fn decode(data: &[u8]) -> Result<Option<Frame>, FrameError> {
if data.len() < FRAME_HEADER_SIZE {
return Ok(None);
}
if data[0..4] != MAGIC {
return Err(FrameError::InvalidMagic);
}
let version = data[4];
let msg_type_byte = data[5];
let payload_len = u32::from_be_bytes([data[6], data[7], data[8], data[9]]) as usize;
if payload_len > MAX_PAYLOAD_SIZE {
return Err(FrameError::PayloadTooLarge(payload_len));
}
if data.len() < FRAME_HEADER_SIZE + payload_len {
return Ok(None);
}
let msg_type = MessageType::try_from(msg_type_byte)
.map_err(|e| FrameError::UnknownMessageType(msg_type_byte, e))?;
let payload = data[FRAME_HEADER_SIZE..FRAME_HEADER_SIZE + payload_len].to_vec();
Ok(Some(Frame {
version,
msg_type,
payload,
}))
}
/// Deserialize the payload as JSON
pub fn decode_payload<T: for<'de> Deserialize<'de>>(&self) -> Result<T, serde_json::Error> {
serde_json::from_slice(&self.payload)
}
/// Total encoded size of this frame
pub fn encoded_size(&self) -> usize {
FRAME_HEADER_SIZE + self.payload.len()
}
}
#[derive(Debug, thiserror::Error)]
pub enum FrameError {
#[error("Invalid magic bytes in frame header")]
InvalidMagic,
#[error("Unknown message type: 0x{0:02X} - {1}")]
UnknownMessageType(u8, String),
#[error("Payload too large: {0} bytes (max {})", MAX_PAYLOAD_SIZE)]
PayloadTooLarge(usize),
#[error("IO error: {0}")]
Io(#[from] std::io::Error),
}
// ==================== Core Message Payloads ====================
/// Registration request payload
#[derive(Debug, Serialize, Deserialize)]
pub struct RegisterRequest {
pub device_uid: String,
pub hostname: String,
pub registration_token: String,
pub os_version: String,
pub mac_address: Option<String>,
}
/// Registration response payload
#[derive(Debug, Serialize, Deserialize)]
pub struct RegisterResponse {
pub device_secret: String,
pub config: ClientConfig,
}
/// Server-pushed client configuration
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ClientConfig {
pub heartbeat_interval_secs: u64,
pub status_report_interval_secs: u64,
pub asset_report_interval_secs: u64,
pub server_version: String,
}
impl Default for ClientConfig {
fn default() -> Self {
Self {
heartbeat_interval_secs: 30,
status_report_interval_secs: 60,
asset_report_interval_secs: 86400,
server_version: env!("CARGO_PKG_VERSION").to_string(),
}
}
}
/// Heartbeat payload (minimal)
#[derive(Debug, Serialize, Deserialize)]
pub struct HeartbeatPayload {
pub device_uid: String,
pub timestamp: String,
pub hmac: String,
}
/// Task execution request (Server → Client)
#[derive(Debug, Serialize, Deserialize)]
pub struct TaskExecutePayload {
pub task_type: String,
pub params: serde_json::Value,
}
/// Config update types (Server → Client)
#[derive(Debug, Serialize, Deserialize)]
pub enum ConfigUpdateType {
UpdateIntervals { heartbeat: u64, status: u64, asset: u64 },
TlsCertRotate,
SelfDestruct,
}
// ==================== Plugin Message Payloads ====================
/// Plugin: Web Access Log entry (Client → Server)
#[derive(Debug, Serialize, Deserialize)]
pub struct WebAccessLogEntry {
pub device_uid: String,
pub url: String,
pub action: String, // "allowed" | "blocked"
pub timestamp: String,
}
/// Plugin: Daily Usage Report (Client → Server)
#[derive(Debug, Serialize, Deserialize)]
pub struct UsageDailyReport {
pub device_uid: String,
pub date: String,
pub total_active_minutes: u32,
pub total_idle_minutes: u32,
pub first_active_at: Option<String>,
pub last_active_at: Option<String>,
}
/// Plugin: App Usage Report (Client → Server)
#[derive(Debug, Serialize, Deserialize)]
pub struct AppUsageEntry {
pub device_uid: String,
pub date: String,
pub app_name: String,
pub usage_minutes: u32,
}
/// Plugin: Software Violation (Client → Server)
#[derive(Debug, Serialize, Deserialize)]
pub struct SoftwareViolationReport {
pub device_uid: String,
pub software_name: String,
pub action_taken: String, // "blocked_install" | "auto_uninstalled" | "alerted"
pub timestamp: String,
}
/// Plugin: USB File Operation (Client → Server)
#[derive(Debug, Serialize, Deserialize)]
pub struct UsbFileOpEntry {
pub device_uid: String,
pub usb_serial: Option<String>,
pub drive_letter: Option<String>,
pub operation: String, // "create" | "delete" | "rename" | "modify"
pub file_path: String,
pub file_size: Option<u64>,
pub timestamp: String,
}
/// Plugin: Watermark Config (Server → Client)
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct WatermarkConfigPayload {
pub content: String,
pub font_size: u32,
pub opacity: f64,
pub color: String,
pub angle: i32,
pub enabled: bool,
}
/// Plugin enable/disable command (Server → Client)
#[derive(Debug, Serialize, Deserialize)]
pub struct PluginControlPayload {
pub plugin_name: String,
pub enabled: bool,
}
/// Plugin: USB Policy Config (Server → Client)
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct UsbPolicyPayload {
pub policy_type: String, // "all_block" | "whitelist" | "blacklist"
pub enabled: bool,
pub rules: Vec<UsbDeviceRule>,
}
/// A single USB device rule for whitelist/blacklist matching
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct UsbDeviceRule {
pub vendor_id: Option<String>,
pub product_id: Option<String>,
pub serial: Option<String>,
pub device_name: Option<String>,
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_frame_encode_decode_roundtrip() {
let original = Frame::new(MessageType::Heartbeat, b"test payload".to_vec());
let encoded = original.encode();
let decoded = Frame::decode(&encoded).unwrap().unwrap();
assert_eq!(decoded.version, PROTOCOL_VERSION);
assert_eq!(decoded.msg_type, MessageType::Heartbeat);
assert_eq!(decoded.payload, b"test payload");
}
#[test]
fn test_frame_decode_incomplete_data() {
let data = [0x43, 0x53, 0x4D, 0x01, 0x01];
let result = Frame::decode(&data).unwrap();
assert!(result.is_none());
}
#[test]
fn test_frame_decode_invalid_magic() {
let data = [0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x00, 0x00, 0x00, 0x00];
let result = Frame::decode(&data);
assert!(matches!(result, Err(FrameError::InvalidMagic)));
}
#[test]
fn test_json_frame_roundtrip() {
let heartbeat = HeartbeatPayload {
device_uid: "test-uid".to_string(),
timestamp: "2026-04-03T12:00:00Z".to_string(),
hmac: "abc123".to_string(),
};
let frame = Frame::new_json(MessageType::Heartbeat, &heartbeat).unwrap();
let encoded = frame.encode();
let decoded = Frame::decode(&encoded).unwrap().unwrap();
let parsed: HeartbeatPayload = decoded.decode_payload().unwrap();
assert_eq!(parsed.device_uid, "test-uid");
assert_eq!(parsed.hmac, "abc123");
}
#[test]
fn test_plugin_message_types_roundtrip() {
let types = [
MessageType::WebAccessLog,
MessageType::UsageReport,
MessageType::AppUsageReport,
MessageType::SoftwareViolation,
MessageType::UsbFileOp,
MessageType::WatermarkConfig,
MessageType::PluginEnable,
MessageType::PluginDisable,
];
for mt in types {
let frame = Frame::new(mt, vec![1, 2, 3]);
let encoded = frame.encode();
let decoded = Frame::decode(&encoded).unwrap().unwrap();
assert_eq!(decoded.msg_type, mt);
}
}
#[test]
fn test_frame_decode_payload_too_large() {
// Craft a header that claims a 10 MB payload
let mut data = Vec::with_capacity(FRAME_HEADER_SIZE);
data.extend_from_slice(&MAGIC);
data.push(PROTOCOL_VERSION);
data.push(MessageType::Heartbeat as u8);
data.extend_from_slice(&(10 * 1024 * 1024u32).to_be_bytes());
// Don't actually include the payload — the size check should reject first
let result = Frame::decode(&data);
assert!(matches!(result, Err(FrameError::PayloadTooLarge(_))));
}
}

51
crates/server/Cargo.toml Normal file
View File

@@ -0,0 +1,51 @@
[package]
name = "csm-server"
version.workspace = true
edition.workspace = true
[dependencies]
csm-protocol = { path = "../protocol" }
# Async runtime
tokio = { workspace = true }
# Web framework
axum = { version = "0.7", features = ["ws"] }
tower-http = { version = "0.5", features = ["cors", "fs", "trace", "compression-gzip", "set-header"] }
tower = "0.4"
# Database
sqlx = { version = "0.8", features = ["runtime-tokio", "sqlite"] }
# TLS
rustls = "0.23"
tokio-rustls = "0.26"
rustls-pemfile = "2"
rustls-pki-types = "1"
# Serialization
serde = { workspace = true }
serde_json = { workspace = true }
# Auth
jsonwebtoken = "9"
bcrypt = "0.15"
# Notifications
lettre = { version = "0.11", default-features = false, features = ["tokio1-rustls-tls", "smtp-transport", "builder", "hostname"] }
reqwest = { version = "0.12", features = ["json", "rustls-tls"] }
# Config & logging
toml = "0.8"
tracing = { workspace = true }
tracing-subscriber = { workspace = true }
anyhow = { workspace = true }
# Utilities
uuid = { workspace = true }
chrono = { workspace = true }
thiserror = { workspace = true }
include_dir = "0.7"
hmac = "0.12"
sha2 = "0.10"
hex = "0.4"

118
crates/server/src/alert.rs Normal file
View File

@@ -0,0 +1,118 @@
use crate::AppState;
use tracing::{info, warn, error};
/// Background task for data cleanup and alert processing
pub async fn cleanup_task(state: AppState) {
let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(3600));
loop {
interval.tick().await;
// Cleanup old status history
if let Err(e) = sqlx::query(
"DELETE FROM device_status_history WHERE reported_at < datetime('now', ?)"
)
.bind(format!("-{} days", state.config.retention.status_history_days))
.execute(&state.db)
.await
{
error!("Failed to cleanup status history: {}", e);
}
// Cleanup old USB events
if let Err(e) = sqlx::query(
"DELETE FROM usb_events WHERE event_time < datetime('now', ?)"
)
.bind(format!("-{} days", state.config.retention.usb_events_days))
.execute(&state.db)
.await
{
error!("Failed to cleanup USB events: {}", e);
}
// Cleanup handled alert records
if let Err(e) = sqlx::query(
"DELETE FROM alert_records WHERE handled = 1 AND triggered_at < datetime('now', ?)"
)
.bind(format!("-{} days", state.config.retention.alert_records_days))
.execute(&state.db)
.await
{
error!("Failed to cleanup alert records: {}", e);
}
// Mark devices as offline if no heartbeat for 2 minutes
if let Err(e) = sqlx::query(
"UPDATE devices SET status = 'offline' WHERE status = 'online' AND last_heartbeat < datetime('now', '-2 minutes')"
)
.execute(&state.db)
.await
{
error!("Failed to mark stale devices offline: {}", e);
}
// SQLite WAL checkpoint
if let Err(e) = sqlx::query("PRAGMA wal_checkpoint(TRUNCATE)")
.execute(&state.db)
.await
{
warn!("WAL checkpoint failed: {}", e);
}
info!("Cleanup cycle completed");
}
}
/// Send email notification
pub async fn send_email(
smtp_config: &crate::config::SmtpConfig,
to: &str,
subject: &str,
body: &str,
) -> anyhow::Result<()> {
use lettre::message::header::ContentType;
use lettre::{Message, SmtpTransport, Transport};
use lettre::transport::smtp::authentication::Credentials;
let email = Message::builder()
.from(smtp_config.from.parse()?)
.to(to.parse()?)
.subject(subject)
.header(ContentType::TEXT_HTML)
.body(body.to_string())?;
let creds = Credentials::new(
smtp_config.username.clone(),
smtp_config.password.clone(),
);
let mailer = SmtpTransport::starttls_relay(&smtp_config.host)?
.port(smtp_config.port)
.credentials(creds)
.build();
mailer.send(&email)?;
Ok(())
}
/// Shared HTTP client for webhook notifications.
/// Lazily initialized once and reused across calls to benefit from connection pooling.
static WEBHOOK_CLIENT: std::sync::OnceLock<reqwest::Client> = std::sync::OnceLock::new();
fn webhook_client() -> &'static reqwest::Client {
WEBHOOK_CLIENT.get_or_init(|| {
reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(10))
.build()
.unwrap_or_else(|_| reqwest::Client::new())
})
}
/// Send webhook notification
pub async fn send_webhook(url: &str, payload: &serde_json::Value) -> anyhow::Result<()> {
webhook_client().post(url)
.json(payload)
.send()
.await?;
Ok(())
}

View File

@@ -0,0 +1,243 @@
use axum::{extract::{State, Path, Query, Json}, http::StatusCode};
use serde::Deserialize;
use sqlx::Row;
use crate::AppState;
use super::ApiResponse;
use super::auth::Claims;
#[derive(Debug, Deserialize)]
pub struct AlertRecordListParams {
pub device_uid: Option<String>,
pub alert_type: Option<String>,
pub severity: Option<String>,
pub handled: Option<i32>,
pub page: Option<u32>,
pub page_size: Option<u32>,
}
pub async fn list_rules(
State(state): State<AppState>,
) -> Json<ApiResponse<serde_json::Value>> {
let rows = sqlx::query(
"SELECT id, name, rule_type, condition, severity, enabled, notify_email, notify_webhook, created_at, updated_at
FROM alert_rules ORDER BY created_at DESC"
)
.fetch_all(&state.db)
.await;
match rows {
Ok(records) => {
let items: Vec<serde_json::Value> = records.iter().map(|r| serde_json::json!({
"id": r.get::<i64, _>("id"),
"name": r.get::<String, _>("name"),
"rule_type": r.get::<String, _>("rule_type"),
"condition": r.get::<String, _>("condition"),
"severity": r.get::<String, _>("severity"),
"enabled": r.get::<i32, _>("enabled"),
"notify_email": r.get::<Option<String>, _>("notify_email"),
"notify_webhook": r.get::<Option<String>, _>("notify_webhook"),
"created_at": r.get::<String, _>("created_at"),
"updated_at": r.get::<String, _>("updated_at"),
})).collect();
Json(ApiResponse::ok(serde_json::json!({
"rules": items,
})))
}
Err(e) => Json(ApiResponse::internal_error("query alert rules", e)),
}
}
pub async fn list_records(
State(state): State<AppState>,
Query(params): Query<AlertRecordListParams>,
) -> Json<ApiResponse<serde_json::Value>> {
let limit = params.page_size.unwrap_or(20).min(100);
let offset = params.page.unwrap_or(1).saturating_sub(1) * limit;
// Normalize empty strings to None (Axum deserializes `key=` as Some(""))
let device_uid = params.device_uid.as_deref().filter(|s| !s.is_empty()).map(String::from);
let alert_type = params.alert_type.as_deref().filter(|s| !s.is_empty()).map(String::from);
let severity = params.severity.as_deref().filter(|s| !s.is_empty()).map(String::from);
let handled = params.handled;
let rows = sqlx::query(
"SELECT id, rule_id, device_uid, alert_type, severity, detail, handled, handled_by, handled_at, triggered_at
FROM alert_records WHERE 1=1
AND (? IS NULL OR device_uid = ?)
AND (? IS NULL OR alert_type = ?)
AND (? IS NULL OR severity = ?)
AND (? IS NULL OR handled = ?)
ORDER BY triggered_at DESC LIMIT ? OFFSET ?"
)
.bind(&device_uid).bind(&device_uid)
.bind(&alert_type).bind(&alert_type)
.bind(&severity).bind(&severity)
.bind(&handled).bind(&handled)
.bind(limit).bind(offset)
.fetch_all(&state.db)
.await;
match rows {
Ok(records) => {
let items: Vec<serde_json::Value> = records.iter().map(|r| serde_json::json!({
"id": r.get::<i64, _>("id"),
"rule_id": r.get::<Option<i64>, _>("rule_id"),
"device_uid": r.get::<Option<String>, _>("device_uid"),
"alert_type": r.get::<String, _>("alert_type"),
"severity": r.get::<String, _>("severity"),
"detail": r.get::<String, _>("detail"),
"handled": r.get::<i32, _>("handled"),
"handled_by": r.get::<Option<String>, _>("handled_by"),
"handled_at": r.get::<Option<String>, _>("handled_at"),
"triggered_at": r.get::<String, _>("triggered_at"),
})).collect();
Json(ApiResponse::ok(serde_json::json!({
"records": items,
"page": params.page.unwrap_or(1),
"page_size": limit,
})))
}
Err(e) => Json(ApiResponse::internal_error("query alert records", e)),
}
}
#[derive(Debug, Deserialize)]
pub struct CreateRuleRequest {
pub name: String,
pub rule_type: String,
pub condition: String,
pub severity: Option<String>,
pub notify_email: Option<String>,
pub notify_webhook: Option<String>,
}
pub async fn create_rule(
State(state): State<AppState>,
Json(body): Json<CreateRuleRequest>,
) -> (StatusCode, Json<ApiResponse<serde_json::Value>>) {
let severity = body.severity.unwrap_or_else(|| "medium".to_string());
let result = sqlx::query(
"INSERT INTO alert_rules (name, rule_type, condition, severity, notify_email, notify_webhook)
VALUES (?, ?, ?, ?, ?, ?)"
)
.bind(&body.name)
.bind(&body.rule_type)
.bind(&body.condition)
.bind(&severity)
.bind(&body.notify_email)
.bind(&body.notify_webhook)
.execute(&state.db)
.await;
match result {
Ok(r) => (StatusCode::CREATED, Json(ApiResponse::ok(serde_json::json!({
"id": r.last_insert_rowid(),
})))),
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::internal_error("create alert rule", e))),
}
}
#[derive(Debug, Deserialize)]
pub struct UpdateRuleRequest {
pub name: Option<String>,
pub rule_type: Option<String>,
pub condition: Option<String>,
pub severity: Option<String>,
pub enabled: Option<i32>,
pub notify_email: Option<String>,
pub notify_webhook: Option<String>,
}
pub async fn update_rule(
State(state): State<AppState>,
Path(id): Path<i64>,
Json(body): Json<UpdateRuleRequest>,
) -> Json<ApiResponse<serde_json::Value>> {
let existing = sqlx::query("SELECT * FROM alert_rules WHERE id = ?")
.bind(id)
.fetch_optional(&state.db)
.await;
let existing = match existing {
Ok(Some(row)) => row,
Ok(None) => return Json(ApiResponse::error("Rule not found")),
Err(e) => return Json(ApiResponse::internal_error("query alert rule", e)),
};
let name = body.name.unwrap_or_else(|| existing.get::<String, _>("name"));
let rule_type = body.rule_type.unwrap_or_else(|| existing.get::<String, _>("rule_type"));
let condition = body.condition.unwrap_or_else(|| existing.get::<String, _>("condition"));
let severity = body.severity.unwrap_or_else(|| existing.get::<String, _>("severity"));
let enabled = body.enabled.unwrap_or_else(|| existing.get::<i32, _>("enabled"));
let notify_email = body.notify_email.or_else(|| existing.get::<Option<String>, _>("notify_email"));
let notify_webhook = body.notify_webhook.or_else(|| existing.get::<Option<String>, _>("notify_webhook"));
let result = sqlx::query(
"UPDATE alert_rules SET name = ?, rule_type = ?, condition = ?, severity = ?, enabled = ?,
notify_email = ?, notify_webhook = ?, updated_at = datetime('now') WHERE id = ?"
)
.bind(&name)
.bind(&rule_type)
.bind(&condition)
.bind(&severity)
.bind(enabled)
.bind(&notify_email)
.bind(&notify_webhook)
.bind(id)
.execute(&state.db)
.await;
match result {
Ok(_) => Json(ApiResponse::ok(serde_json::json!({"updated": true}))),
Err(e) => Json(ApiResponse::internal_error("update alert rule", e)),
}
}
pub async fn delete_rule(
State(state): State<AppState>,
Path(id): Path<i64>,
) -> Json<ApiResponse<serde_json::Value>> {
let result = sqlx::query("DELETE FROM alert_rules WHERE id = ?")
.bind(id)
.execute(&state.db)
.await;
match result {
Ok(r) => {
if r.rows_affected() > 0 {
Json(ApiResponse::ok(serde_json::json!({"deleted": true})))
} else {
Json(ApiResponse::error("Rule not found"))
}
}
Err(e) => Json(ApiResponse::internal_error("delete alert rule", e)),
}
}
pub async fn handle_record(
State(state): State<AppState>,
Path(id): Path<i64>,
claims: axum::Extension<Claims>,
) -> Json<ApiResponse<serde_json::Value>> {
let handled_by = &claims.username;
let result = sqlx::query(
"UPDATE alert_records SET handled = 1, handled_by = ?, handled_at = datetime('now') WHERE id = ?"
)
.bind(handled_by)
.bind(id)
.execute(&state.db)
.await;
match result {
Ok(r) => {
if r.rows_affected() > 0 {
Json(ApiResponse::ok(serde_json::json!({"handled": true})))
} else {
Json(ApiResponse::error("Alert record not found"))
}
}
Err(e) => Json(ApiResponse::internal_error("handle alert record", e)),
}
}

View File

@@ -0,0 +1,143 @@
use axum::{extract::{State, Query}, Json};
use serde::Deserialize;
use sqlx::Row;
use crate::AppState;
use super::{ApiResponse, Pagination};
#[derive(Debug, Deserialize)]
pub struct AssetListParams {
pub device_uid: Option<String>,
pub search: Option<String>,
pub page: Option<u32>,
pub page_size: Option<u32>,
}
pub async fn list_hardware(
State(state): State<AppState>,
Query(params): Query<AssetListParams>,
) -> Json<ApiResponse<serde_json::Value>> {
let limit = params.page_size.unwrap_or(20).min(100);
let offset = params.page.unwrap_or(1).saturating_sub(1) * limit;
// Normalize empty strings to None
let device_uid = params.device_uid.as_deref().filter(|s| !s.is_empty()).map(String::from);
let search = params.search.as_deref().filter(|s| !s.is_empty()).map(String::from);
let rows = sqlx::query(
"SELECT id, device_uid, cpu_model, cpu_cores, memory_total_mb, disk_model, disk_total_mb,
gpu_model, motherboard, serial_number, reported_at
FROM hardware_assets WHERE 1=1
AND (? IS NULL OR device_uid = ?)
AND (? IS NULL OR cpu_model LIKE '%' || ? || '%' OR gpu_model LIKE '%' || ? || '%')
ORDER BY reported_at DESC LIMIT ? OFFSET ?"
)
.bind(&device_uid).bind(&device_uid)
.bind(&search).bind(&search).bind(&search)
.bind(limit).bind(offset)
.fetch_all(&state.db)
.await;
match rows {
Ok(records) => {
let items: Vec<serde_json::Value> = records.iter().map(|r| serde_json::json!({
"id": r.get::<i64, _>("id"),
"device_uid": r.get::<String, _>("device_uid"),
"cpu_model": r.get::<String, _>("cpu_model"),
"cpu_cores": r.get::<i32, _>("cpu_cores"),
"memory_total_mb": r.get::<i64, _>("memory_total_mb"),
"disk_model": r.get::<String, _>("disk_model"),
"disk_total_mb": r.get::<i64, _>("disk_total_mb"),
"gpu_model": r.get::<Option<String>, _>("gpu_model"),
"motherboard": r.get::<Option<String>, _>("motherboard"),
"serial_number": r.get::<Option<String>, _>("serial_number"),
"reported_at": r.get::<String, _>("reported_at"),
})).collect();
Json(ApiResponse::ok(serde_json::json!({
"hardware": items,
"page": params.page.unwrap_or(1),
"page_size": limit,
})))
}
Err(e) => Json(ApiResponse::internal_error("query hardware assets", e)),
}
}
pub async fn list_software(
State(state): State<AppState>,
Query(params): Query<AssetListParams>,
) -> Json<ApiResponse<serde_json::Value>> {
let limit = params.page_size.unwrap_or(20).min(100);
let offset = params.page.unwrap_or(1).saturating_sub(1) * limit;
// Normalize empty strings to None
let device_uid = params.device_uid.as_deref().filter(|s| !s.is_empty()).map(String::from);
let search = params.search.as_deref().filter(|s| !s.is_empty()).map(String::from);
let rows = sqlx::query(
"SELECT id, device_uid, name, version, publisher, install_date, install_path
FROM software_assets WHERE 1=1
AND (? IS NULL OR device_uid = ?)
AND (? IS NULL OR name LIKE '%' || ? || '%' OR publisher LIKE '%' || ? || '%')
ORDER BY name ASC LIMIT ? OFFSET ?"
)
.bind(&device_uid).bind(&device_uid)
.bind(&search).bind(&search).bind(&search)
.bind(limit).bind(offset)
.fetch_all(&state.db)
.await;
match rows {
Ok(records) => {
let items: Vec<serde_json::Value> = records.iter().map(|r| serde_json::json!({
"id": r.get::<i64, _>("id"),
"device_uid": r.get::<String, _>("device_uid"),
"name": r.get::<String, _>("name"),
"version": r.get::<Option<String>, _>("version"),
"publisher": r.get::<Option<String>, _>("publisher"),
"install_date": r.get::<Option<String>, _>("install_date"),
"install_path": r.get::<Option<String>, _>("install_path"),
})).collect();
Json(ApiResponse::ok(serde_json::json!({
"software": items,
"page": params.page.unwrap_or(1),
"page_size": limit,
})))
}
Err(e) => Json(ApiResponse::internal_error("query software assets", e)),
}
}
pub async fn list_changes(
State(state): State<AppState>,
Query(page): Query<Pagination>,
) -> Json<ApiResponse<serde_json::Value>> {
let offset = page.offset();
let limit = page.limit();
let rows = sqlx::query(
"SELECT id, device_uid, change_type, change_detail, detected_at
FROM asset_changes ORDER BY detected_at DESC LIMIT ? OFFSET ?"
)
.bind(limit)
.bind(offset)
.fetch_all(&state.db)
.await;
match rows {
Ok(records) => {
let items: Vec<serde_json::Value> = records.iter().map(|r| serde_json::json!({
"id": r.get::<i64, _>("id"),
"device_uid": r.get::<String, _>("device_uid"),
"change_type": r.get::<String, _>("change_type"),
"change_detail": r.get::<String, _>("change_detail"),
"detected_at": r.get::<String, _>("detected_at"),
})).collect();
Json(ApiResponse::ok(serde_json::json!({
"changes": items,
"page": page.page.unwrap_or(1),
"page_size": limit,
})))
}
Err(e) => Json(ApiResponse::internal_error("query asset changes", e)),
}
}

View File

@@ -0,0 +1,295 @@
use axum::{extract::State, Json, http::StatusCode, extract::Request, middleware::Next, response::Response};
use serde::{Deserialize, Serialize};
use jsonwebtoken::{encode, decode, Header, EncodingKey, DecodingKey, Validation};
use std::sync::Arc;
use std::collections::HashMap;
use std::time::Instant;
use tokio::sync::Mutex;
use crate::AppState;
use super::ApiResponse;
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Claims {
pub sub: i64,
pub username: String,
pub role: String,
pub exp: u64,
pub iat: u64,
pub token_type: String,
/// Random family ID for refresh token rotation detection
pub family: String,
}
#[derive(Debug, Deserialize)]
pub struct LoginRequest {
pub username: String,
pub password: String,
}
#[derive(Debug, Serialize)]
pub struct LoginResponse {
pub access_token: String,
pub refresh_token: String,
pub user: UserInfo,
}
#[derive(Debug, Serialize, sqlx::FromRow)]
pub struct UserInfo {
pub id: i64,
pub username: String,
pub role: String,
}
#[derive(Debug, Deserialize)]
pub struct RefreshRequest {
pub refresh_token: String,
}
/// In-memory rate limiter for login attempts
#[derive(Clone, Default)]
pub struct LoginRateLimiter {
attempts: Arc<Mutex<HashMap<String, (Instant, u32)>>>,
}
impl LoginRateLimiter {
pub fn new() -> Self {
Self::default()
}
/// Returns true if the request should be rate-limited
pub async fn is_limited(&self, key: &str) -> bool {
let mut attempts = self.attempts.lock().await;
let now = Instant::now();
let window = std::time::Duration::from_secs(300); // 5-minute window
let max_attempts = 10u32;
if let Some((first_attempt, count)) = attempts.get_mut(key) {
if now.duration_since(*first_attempt) > window {
// Window expired, reset
*first_attempt = now;
*count = 1;
false
} else if *count >= max_attempts {
true // Rate limited
} else {
*count += 1;
false
}
} else {
attempts.insert(key.to_string(), (now, 1));
// Cleanup old entries periodically
if attempts.len() > 1000 {
let cutoff = now - window;
attempts.retain(|_, (t, _)| *t > cutoff);
}
false
}
}
}
pub async fn login(
State(state): State<AppState>,
Json(req): Json<LoginRequest>,
) -> Result<(StatusCode, Json<ApiResponse<LoginResponse>>), StatusCode> {
// Rate limit check
if state.login_limiter.is_limited(&req.username).await {
return Ok((StatusCode::TOO_MANY_REQUESTS, Json(ApiResponse::error("Too many login attempts. Try again later."))));
}
let user: Option<UserInfo> = sqlx::query_as::<_, UserInfo>(
"SELECT id, username, role FROM users WHERE username = ?"
)
.bind(&req.username)
.fetch_optional(&state.db)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
let user = match user {
Some(u) => u,
None => return Ok((StatusCode::UNAUTHORIZED, Json(ApiResponse::error("Invalid credentials")))),
};
let hash: String = sqlx::query_scalar::<_, String>(
"SELECT password FROM users WHERE id = ?"
)
.bind(user.id)
.fetch_one(&state.db)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
if !bcrypt::verify(&req.password, &hash).unwrap_or(false) {
return Ok((StatusCode::UNAUTHORIZED, Json(ApiResponse::error("Invalid credentials"))));
}
let now = chrono::Utc::now().timestamp() as u64;
let family = uuid::Uuid::new_v4().to_string();
let access_token = create_token(&user, "access", state.config.auth.access_token_ttl_secs, now, &state.config.auth.jwt_secret, &family)?;
let refresh_token = create_token(&user, "refresh", state.config.auth.refresh_token_ttl_secs, now, &state.config.auth.jwt_secret, &family)?;
// Audit log
let _ = sqlx::query(
"INSERT INTO admin_audit_log (user_id, action, detail) VALUES (?, 'login', ?)"
)
.bind(user.id)
.bind(format!("User {} logged in", user.username))
.execute(&state.db)
.await;
Ok((StatusCode::OK, Json(ApiResponse::ok(LoginResponse {
access_token,
refresh_token,
user,
}))))
}
pub async fn refresh(
State(state): State<AppState>,
Json(req): Json<RefreshRequest>,
) -> Result<(StatusCode, Json<ApiResponse<LoginResponse>>), StatusCode> {
let claims = decode::<Claims>(
&req.refresh_token,
&DecodingKey::from_secret(state.config.auth.jwt_secret.as_bytes()),
&Validation::default(),
)
.map_err(|_| StatusCode::UNAUTHORIZED)?;
if claims.claims.token_type != "refresh" {
return Ok((StatusCode::UNAUTHORIZED, Json(ApiResponse::error("Invalid token type"))));
}
// Check if this refresh token family has been revoked (reuse detection)
let revoked: bool = sqlx::query_scalar::<_, i64>(
"SELECT COUNT(*) FROM revoked_token_families WHERE family = ?"
)
.bind(&claims.claims.family)
.fetch_one(&state.db)
.await
.unwrap_or(0) > 0;
if revoked {
// Token reuse detected — revoke entire family and force re-login
tracing::warn!("Refresh token reuse detected for user {} family {}", claims.claims.sub, claims.claims.family);
let _ = sqlx::query("DELETE FROM refresh_tokens WHERE user_id = ?")
.bind(claims.claims.sub)
.execute(&state.db)
.await;
return Ok((StatusCode::UNAUTHORIZED, Json(ApiResponse::error("Token reuse detected. Please log in again."))));
}
let user = UserInfo {
id: claims.claims.sub,
username: claims.claims.username,
role: claims.claims.role,
};
// Rotate: new family for each refresh
let new_family = uuid::Uuid::new_v4().to_string();
let now = chrono::Utc::now().timestamp() as u64;
let access_token = create_token(&user, "access", state.config.auth.access_token_ttl_secs, now, &state.config.auth.jwt_secret, &new_family)?;
let refresh_token = create_token(&user, "refresh", state.config.auth.refresh_token_ttl_secs, now, &state.config.auth.jwt_secret, &new_family)?;
// Revoke old family
let _ = sqlx::query("INSERT OR IGNORE INTO revoked_token_families (family, user_id, revoked_at) VALUES (?, ?, datetime('now'))")
.bind(&claims.claims.family)
.bind(claims.claims.sub)
.execute(&state.db)
.await;
Ok((StatusCode::OK, Json(ApiResponse::ok(LoginResponse {
access_token,
refresh_token,
user,
}))))
}
fn create_token(user: &UserInfo, token_type: &str, ttl: u64, now: u64, secret: &str, family: &str) -> Result<String, StatusCode> {
let claims = Claims {
sub: user.id,
username: user.username.clone(),
role: user.role.clone(),
exp: now + ttl,
iat: now,
token_type: token_type.to_string(),
family: family.to_string(),
};
encode(
&Header::default(),
&claims,
&EncodingKey::from_secret(secret.as_bytes()),
)
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
}
/// Axum middleware: require valid JWT access token
pub async fn require_auth(
State(state): State<AppState>,
mut request: Request,
next: Next,
) -> Result<Response, StatusCode> {
let auth_header = request.headers()
.get("Authorization")
.and_then(|v| v.to_str().ok())
.and_then(|v| v.strip_prefix("Bearer "));
let token = match auth_header {
Some(t) => t,
None => return Err(StatusCode::UNAUTHORIZED),
};
let claims = decode::<Claims>(
token,
&DecodingKey::from_secret(state.config.auth.jwt_secret.as_bytes()),
&Validation::default(),
)
.map_err(|_| StatusCode::UNAUTHORIZED)?;
if claims.claims.token_type != "access" {
return Err(StatusCode::UNAUTHORIZED);
}
// Inject claims into request extensions for handlers to use
request.extensions_mut().insert(claims.claims);
Ok(next.run(request).await)
}
/// Axum middleware: require admin role for write operations + audit log
pub async fn require_admin(
State(state): State<AppState>,
request: Request,
next: Next,
) -> Result<Response, StatusCode> {
let claims = request.extensions()
.get::<Claims>()
.ok_or(StatusCode::UNAUTHORIZED)?;
if claims.role != "admin" {
return Err(StatusCode::FORBIDDEN);
}
// Capture audit info before running handler
let method = request.method().clone();
let path = request.uri().path().to_string();
let user_id = claims.sub;
let username = claims.username.clone();
let response = next.run(request).await;
// Record admin action to audit log (fire and forget — don't block response)
let status = response.status();
if status.is_success() {
let action = format!("{} {}", method, path);
let detail = format!("by {}", username);
let _ = sqlx::query(
"INSERT INTO admin_audit_log (user_id, action, detail) VALUES (?, ?, ?)"
)
.bind(user_id)
.bind(&action)
.bind(&detail)
.execute(&state.db)
.await;
}
Ok(response)
}

View File

@@ -0,0 +1,263 @@
use axum::{extract::{State, Path, Query}, Json};
use serde::{Deserialize, Serialize};
use sqlx::Row;
use crate::AppState;
use super::{ApiResponse, Pagination};
#[derive(Debug, Deserialize)]
pub struct DeviceListParams {
pub status: Option<String>,
pub group: Option<String>,
pub search: Option<String>,
pub page: Option<u32>,
pub page_size: Option<u32>,
}
#[derive(Debug, Serialize, sqlx::FromRow)]
pub struct DeviceRow {
pub id: i64,
pub device_uid: String,
pub hostname: String,
pub ip_address: String,
pub mac_address: Option<String>,
pub os_version: Option<String>,
pub client_version: Option<String>,
pub status: String,
pub last_heartbeat: Option<String>,
pub registered_at: Option<String>,
pub group_name: Option<String>,
}
pub async fn list(
State(state): State<AppState>,
Query(params): Query<DeviceListParams>,
) -> Json<ApiResponse<serde_json::Value>> {
let limit = params.page_size.unwrap_or(20).min(100);
let offset = params.page.unwrap_or(1).saturating_sub(1) * limit;
// Normalize empty strings to None (Axum deserializes `status=` as Some(""))
let status = params.status.as_deref().filter(|s| !s.is_empty()).map(String::from);
let group = params.group.as_deref().filter(|s| !s.is_empty()).map(String::from);
let search = params.search.as_deref().filter(|s| !s.is_empty()).map(String::from);
let devices = sqlx::query_as::<_, DeviceRow>(
"SELECT id, device_uid, hostname, ip_address, mac_address, os_version, client_version,
status, last_heartbeat, registered_at, group_name
FROM devices WHERE 1=1
AND (? IS NULL OR status = ?)
AND (? IS NULL OR group_name = ?)
AND (? IS NULL OR hostname LIKE '%' || ? || '%' OR ip_address LIKE '%' || ? || '%')
ORDER BY registered_at DESC LIMIT ? OFFSET ?"
)
.bind(&status).bind(&status)
.bind(&group).bind(&group)
.bind(&search).bind(&search).bind(&search)
.bind(limit).bind(offset)
.fetch_all(&state.db)
.await;
let total: i64 = sqlx::query_scalar(
"SELECT COUNT(*) FROM devices WHERE 1=1
AND (? IS NULL OR status = ?)
AND (? IS NULL OR group_name = ?)
AND (? IS NULL OR hostname LIKE '%' || ? || '%' OR ip_address LIKE '%' || ? || '%')"
)
.bind(&status).bind(&status)
.bind(&group).bind(&group)
.bind(&search).bind(&search).bind(&search)
.fetch_one(&state.db)
.await
.unwrap_or(0);
match devices {
Ok(rows) => Json(ApiResponse::ok(serde_json::json!({
"devices": rows,
"total": total,
"page": params.page.unwrap_or(1),
"page_size": limit,
}))),
Err(e) => Json(ApiResponse::internal_error("query devices", e)),
}
}
pub async fn get_detail(
State(state): State<AppState>,
Path(uid): Path<String>,
) -> Json<ApiResponse<serde_json::Value>> {
let device = sqlx::query_as::<_, DeviceRow>(
"SELECT id, device_uid, hostname, ip_address, mac_address, os_version, client_version,
status, last_heartbeat, registered_at, group_name
FROM devices WHERE device_uid = ?"
)
.bind(&uid)
.fetch_optional(&state.db)
.await;
match device {
Ok(Some(d)) => Json(ApiResponse::ok(serde_json::to_value(d).unwrap_or_default())),
Ok(None) => Json(ApiResponse::error("Device not found")),
Err(e) => Json(ApiResponse::internal_error("query devices", e)),
}
}
#[derive(Debug, Serialize, sqlx::FromRow)]
struct StatusRow {
pub cpu_usage: f64,
pub memory_usage: f64,
pub memory_total_mb: i64,
pub disk_usage: f64,
pub disk_total_mb: i64,
pub network_rx_rate: i64,
pub network_tx_rate: i64,
pub running_procs: i32,
pub top_processes: Option<String>,
pub reported_at: String,
}
pub async fn get_status(
State(state): State<AppState>,
Path(uid): Path<String>,
) -> Json<ApiResponse<serde_json::Value>> {
let status = sqlx::query_as::<_, StatusRow>(
"SELECT cpu_usage, memory_usage, memory_total_mb, disk_usage, disk_total_mb,
network_rx_rate, network_tx_rate, running_procs, top_processes, reported_at
FROM device_status WHERE device_uid = ?"
)
.bind(&uid)
.fetch_optional(&state.db)
.await;
match status {
Ok(Some(s)) => {
let mut val = serde_json::to_value(&s).unwrap_or_default();
// Parse top_processes JSON string back to array
if let Some(tp_str) = &s.top_processes {
if let Ok(tp) = serde_json::from_str::<serde_json::Value>(tp_str) {
val["top_processes"] = tp;
}
}
Json(ApiResponse::ok(val))
}
Ok(None) => Json(ApiResponse::error("No status data found")),
Err(e) => Json(ApiResponse::internal_error("query devices", e)),
}
}
pub async fn get_history(
State(state): State<AppState>,
Path(uid): Path<String>,
Query(page): Query<Pagination>,
) -> Json<ApiResponse<serde_json::Value>> {
let offset = page.offset();
let limit = page.limit();
let rows = sqlx::query(
"SELECT cpu_usage, memory_usage, disk_usage, running_procs, reported_at
FROM device_status_history WHERE device_uid = ?
ORDER BY reported_at DESC LIMIT ? OFFSET ?"
)
.bind(&uid)
.bind(limit)
.bind(offset)
.fetch_all(&state.db)
.await;
match rows {
Ok(records) => {
let items: Vec<serde_json::Value> = records.iter().map(|r| {
serde_json::json!({
"cpu_usage": r.get::<f64, _>("cpu_usage"),
"memory_usage": r.get::<f64, _>("memory_usage"),
"disk_usage": r.get::<f64, _>("disk_usage"),
"running_procs": r.get::<i32, _>("running_procs"),
"reported_at": r.get::<String, _>("reported_at"),
})
}).collect();
Json(ApiResponse::ok(serde_json::json!({
"history": items,
"page": page.page.unwrap_or(1),
"page_size": limit,
})))
}
Err(e) => Json(ApiResponse::internal_error("query devices", e)),
}
}
pub async fn remove(
State(state): State<AppState>,
Path(uid): Path<String>,
) -> Json<ApiResponse<()>> {
// If client is connected, send self-destruct command
let frame = csm_protocol::Frame::new_json(
csm_protocol::MessageType::ConfigUpdate,
&serde_json::json!({"type": "SelfDestruct"}),
).ok();
if let Some(frame) = frame {
state.clients.send_to(&uid, frame.encode()).await;
}
// Delete device and all associated data in a transaction
let mut tx = match state.db.begin().await {
Ok(tx) => tx,
Err(e) => return Json(ApiResponse::internal_error("begin transaction", e)),
};
// Delete status history
if let Err(e) = sqlx::query("DELETE FROM device_status_history WHERE device_uid = ?")
.bind(&uid)
.execute(&mut *tx)
.await
{
return Json(ApiResponse::internal_error("remove device history", e));
}
// Delete current status
if let Err(e) = sqlx::query("DELETE FROM device_status WHERE device_uid = ?")
.bind(&uid)
.execute(&mut *tx)
.await
{
return Json(ApiResponse::internal_error("remove device status", e));
}
// Delete plugin-related data
let cleanup_tables = [
"hardware_assets",
"usb_events",
"usb_file_operations",
"usage_daily",
"app_usage_daily",
"software_violations",
"web_access_log",
"popup_block_stats",
];
for table in &cleanup_tables {
if let Err(e) = sqlx::query(&format!("DELETE FROM {} WHERE device_uid = ?", table))
.bind(&uid)
.execute(&mut *tx)
.await
{
tracing::warn!("Failed to clean {} for device {}: {}", table, uid, e);
}
}
// Finally delete the device itself
let delete_result = sqlx::query("DELETE FROM devices WHERE device_uid = ?")
.bind(&uid)
.execute(&mut *tx)
.await;
match delete_result {
Ok(r) if r.rows_affected() > 0 => {
if let Err(e) = tx.commit().await {
return Json(ApiResponse::internal_error("commit device deletion", e));
}
state.clients.unregister(&uid).await;
tracing::info!(device_uid = %uid, "Device and all associated data deleted");
Json(ApiResponse::ok(()))
}
Ok(_) => Json(ApiResponse::error("Device not found")),
Err(e) => Json(ApiResponse::internal_error("remove device", e)),
}
}

View File

@@ -0,0 +1,120 @@
use axum::{routing::{get, post, put, delete}, Router, Json, extract::State, middleware};
use serde::{Deserialize, Serialize};
use crate::AppState;
pub mod auth;
pub mod devices;
pub mod assets;
pub mod usb;
pub mod alerts;
pub mod plugins;
pub fn routes(state: AppState) -> Router<AppState> {
let public = Router::new()
.route("/api/auth/login", post(auth::login))
.route("/api/auth/refresh", post(auth::refresh))
.route("/health", get(health_check))
.with_state(state.clone());
// Read-only routes (any authenticated user)
let read_routes = Router::new()
// Devices
.route("/api/devices", get(devices::list))
.route("/api/devices/:uid", get(devices::get_detail))
.route("/api/devices/:uid/status", get(devices::get_status))
.route("/api/devices/:uid/history", get(devices::get_history))
// Assets
.route("/api/assets/hardware", get(assets::list_hardware))
.route("/api/assets/software", get(assets::list_software))
.route("/api/assets/changes", get(assets::list_changes))
// USB (read)
.route("/api/usb/events", get(usb::list_events))
.route("/api/usb/policies", get(usb::list_policies))
// Alerts (read)
.route("/api/alerts/rules", get(alerts::list_rules))
.route("/api/alerts/records", get(alerts::list_records))
// Plugin read routes
.merge(plugins::read_routes())
.layer(middleware::from_fn_with_state(state.clone(), auth::require_auth));
// Write routes (admin only)
let write_routes = Router::new()
// Devices
.route("/api/devices/:uid", delete(devices::remove))
// USB (write)
.route("/api/usb/policies", post(usb::create_policy))
.route("/api/usb/policies/:id", put(usb::update_policy).delete(usb::delete_policy))
// Alerts (write)
.route("/api/alerts/rules", post(alerts::create_rule))
.route("/api/alerts/rules/:id", put(alerts::update_rule).delete(alerts::delete_rule))
.route("/api/alerts/records/:id/handle", put(alerts::handle_record))
// Plugin write routes (already has require_admin layer internally)
.merge(plugins::write_routes())
// Layer order: outer (require_admin) runs AFTER inner (require_auth)
// so require_auth sets Claims extension first, then require_admin checks it
.layer(middleware::from_fn_with_state(state.clone(), auth::require_admin))
.layer(middleware::from_fn_with_state(state.clone(), auth::require_auth));
// WebSocket has its own JWT auth via query parameter
let ws_router = Router::new()
.route("/ws", get(crate::ws::ws_handler))
.with_state(state.clone());
Router::new()
.merge(public)
.merge(read_routes)
.merge(write_routes)
.merge(ws_router)
}
#[derive(Serialize)]
struct HealthResponse {
status: &'static str,
}
async fn health_check() -> Json<HealthResponse> {
Json(HealthResponse {
status: "ok",
})
}
/// Standard API response envelope
#[derive(Serialize)]
pub struct ApiResponse<T: Serialize> {
pub success: bool,
pub data: Option<T>,
pub error: Option<String>,
}
impl<T: Serialize> ApiResponse<T> {
pub fn ok(data: T) -> Self {
Self { success: true, data: Some(data), error: None }
}
pub fn error(msg: impl Into<String>) -> Self {
Self { success: false, data: None, error: Some(msg.into()) }
}
/// Log internal error and return sanitized message to client
pub fn internal_error(context: &str, e: impl std::fmt::Display) -> Self {
tracing::error!("{}: {}", context, e);
Self { success: false, data: None, error: Some("Internal server error".to_string()) }
}
}
/// Pagination parameters
#[derive(Debug, Deserialize)]
pub struct Pagination {
pub page: Option<u32>,
pub page_size: Option<u32>,
}
impl Pagination {
pub fn offset(&self) -> u32 {
self.page.unwrap_or(1).saturating_sub(1) * self.limit()
}
pub fn limit(&self) -> u32 {
self.page_size.unwrap_or(20).min(100)
}
}

View File

@@ -0,0 +1,49 @@
pub mod web_filter;
pub mod usage_timer;
pub mod software_blocker;
pub mod popup_blocker;
pub mod usb_file_audit;
pub mod watermark;
use axum::{Router, routing::{get, post, put}};
use crate::AppState;
/// Read-only plugin routes (accessible by admin + viewer)
pub fn read_routes() -> Router<AppState> {
Router::new()
// Web Filter
.route("/api/plugins/web-filter/rules", get(web_filter::list_rules))
.route("/api/plugins/web-filter/log", get(web_filter::list_access_log))
// Usage Timer
.route("/api/plugins/usage-timer/daily", get(usage_timer::list_daily))
.route("/api/plugins/usage-timer/app-usage", get(usage_timer::list_app_usage))
.route("/api/plugins/usage-timer/leaderboard", get(usage_timer::leaderboard))
// Software Blocker
.route("/api/plugins/software-blocker/blacklist", get(software_blocker::list_blacklist))
.route("/api/plugins/software-blocker/violations", get(software_blocker::list_violations))
// Popup Blocker
.route("/api/plugins/popup-blocker/rules", get(popup_blocker::list_rules))
.route("/api/plugins/popup-blocker/stats", get(popup_blocker::list_stats))
// USB File Audit
.route("/api/plugins/usb-file-audit/log", get(usb_file_audit::list_operations))
.route("/api/plugins/usb-file-audit/summary", get(usb_file_audit::summary))
// Watermark
.route("/api/plugins/watermark/config", get(watermark::get_config_list))
}
/// Write plugin routes (admin only — require_admin middleware applied by caller)
pub fn write_routes() -> Router<AppState> {
Router::new()
// Web Filter
.route("/api/plugins/web-filter/rules", post(web_filter::create_rule))
.route("/api/plugins/web-filter/rules/:id", put(web_filter::update_rule).delete(web_filter::delete_rule))
// Software Blocker
.route("/api/plugins/software-blocker/blacklist", post(software_blocker::add_to_blacklist))
.route("/api/plugins/software-blocker/blacklist/:id", put(software_blocker::update_blacklist).delete(software_blocker::remove_from_blacklist))
// Popup Blocker
.route("/api/plugins/popup-blocker/rules", post(popup_blocker::create_rule))
.route("/api/plugins/popup-blocker/rules/:id", put(popup_blocker::update_rule).delete(popup_blocker::delete_rule))
// Watermark
.route("/api/plugins/watermark/config", post(watermark::create_config))
.route("/api/plugins/watermark/config/:id", put(watermark::update_config).delete(watermark::delete_config))
}

View File

@@ -0,0 +1,155 @@
use axum::{extract::{State, Path, Json}, http::StatusCode};
use serde::Deserialize;
use sqlx::Row;
use csm_protocol::MessageType;
use crate::AppState;
use crate::api::ApiResponse;
use crate::tcp::push_to_targets;
#[derive(Debug, Deserialize)]
pub struct CreateRuleRequest {
pub rule_type: String, // "block" | "allow"
pub window_title: Option<String>,
pub window_class: Option<String>,
pub process_name: Option<String>,
pub target_type: Option<String>,
pub target_id: Option<String>,
}
pub async fn list_rules(State(state): State<AppState>) -> Json<ApiResponse<serde_json::Value>> {
match sqlx::query("SELECT id, rule_type, window_title, window_class, process_name, target_type, target_id, enabled, created_at FROM popup_filter_rules ORDER BY created_at DESC")
.fetch_all(&state.db).await {
Ok(rows) => Json(ApiResponse::ok(serde_json::json!({"rules": rows.iter().map(|r| serde_json::json!({
"id": r.get::<i64,_>("id"), "rule_type": r.get::<String,_>("rule_type"),
"window_title": r.get::<Option<String>,_>("window_title"),
"window_class": r.get::<Option<String>,_>("window_class"),
"process_name": r.get::<Option<String>,_>("process_name"),
"target_type": r.get::<String,_>("target_type"), "target_id": r.get::<Option<String>,_>("target_id"),
"enabled": r.get::<bool,_>("enabled"), "created_at": r.get::<String,_>("created_at")
})).collect::<Vec<_>>()}))),
Err(e) => Json(ApiResponse::internal_error("query popup filter rules", e)),
}
}
pub async fn create_rule(State(state): State<AppState>, Json(req): Json<CreateRuleRequest>) -> (StatusCode, Json<ApiResponse<serde_json::Value>>) {
let target_type = req.target_type.unwrap_or_else(|| "global".to_string());
// Validate inputs
if !matches!(req.rule_type.as_str(), "block" | "allow") {
return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("rule_type must be 'block' or 'allow'")));
}
if !matches!(target_type.as_str(), "global" | "device" | "group") {
return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("invalid target_type")));
}
let has_filter = req.window_title.as_ref().map_or(false, |s| !s.is_empty())
|| req.window_class.as_ref().map_or(false, |s| !s.is_empty())
|| req.process_name.as_ref().map_or(false, |s| !s.is_empty());
if !has_filter {
return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("at least one filter (window_title/window_class/process_name) required")));
}
match sqlx::query("INSERT INTO popup_filter_rules (rule_type, window_title, window_class, process_name, target_type, target_id) VALUES (?,?,?,?,?,?)")
.bind(&req.rule_type).bind(&req.window_title).bind(&req.window_class).bind(&req.process_name).bind(&target_type).bind(&req.target_id)
.execute(&state.db).await {
Ok(r) => {
let new_id = r.last_insert_rowid();
let rules = fetch_popup_rules_for_push(&state.db, &target_type, req.target_id.as_deref()).await;
push_to_targets(&state.db, &state.clients, MessageType::PopupRules, &serde_json::json!({"rules": rules}), &target_type, req.target_id.as_deref()).await;
(StatusCode::CREATED, Json(ApiResponse::ok(serde_json::json!({"id": new_id}))))
}
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::internal_error("create popup filter rule", e))),
}
}
#[derive(Debug, Deserialize)]
pub struct UpdateRuleRequest { pub window_title: Option<String>, pub window_class: Option<String>, pub process_name: Option<String>, pub enabled: Option<bool> }
pub async fn update_rule(State(state): State<AppState>, Path(id): Path<i64>, Json(body): Json<UpdateRuleRequest>) -> Json<ApiResponse<()>> {
let existing = sqlx::query("SELECT * FROM popup_filter_rules WHERE id = ?")
.bind(id)
.fetch_optional(&state.db)
.await;
let existing = match existing {
Ok(Some(row)) => row,
Ok(None) => return Json(ApiResponse::error("Not found")),
Err(e) => return Json(ApiResponse::internal_error("query popup filter rule", e)),
};
let window_title = body.window_title.or_else(|| existing.get::<Option<String>, _>("window_title"));
let window_class = body.window_class.or_else(|| existing.get::<Option<String>, _>("window_class"));
let process_name = body.process_name.or_else(|| existing.get::<Option<String>, _>("process_name"));
let enabled = body.enabled.unwrap_or_else(|| existing.get::<bool, _>("enabled"));
let result = sqlx::query("UPDATE popup_filter_rules SET window_title = ?, window_class = ?, process_name = ?, enabled = ? WHERE id = ?")
.bind(&window_title)
.bind(&window_class)
.bind(&process_name)
.bind(enabled)
.bind(id)
.execute(&state.db)
.await;
match result {
Ok(r) if r.rows_affected() > 0 => {
let target_type_val: String = existing.get("target_type");
let target_id_val: Option<String> = existing.get("target_id");
let rules = fetch_popup_rules_for_push(&state.db, &target_type_val, target_id_val.as_deref()).await;
push_to_targets(&state.db, &state.clients, MessageType::PopupRules, &serde_json::json!({"rules": rules}), &target_type_val, target_id_val.as_deref()).await;
Json(ApiResponse::ok(()))
}
Ok(_) => Json(ApiResponse::error("Not found")),
Err(e) => Json(ApiResponse::internal_error("update popup filter rule", e)),
}
}
pub async fn delete_rule(State(state): State<AppState>, Path(id): Path<i64>) -> Json<ApiResponse<()>> {
let existing = sqlx::query("SELECT target_type, target_id FROM popup_filter_rules WHERE id = ?")
.bind(id).fetch_optional(&state.db).await;
let (target_type, target_id) = match existing {
Ok(Some(row)) => (row.get::<String, _>("target_type"), row.get::<Option<String>, _>("target_id")),
_ => return Json(ApiResponse::error("Not found")),
};
match sqlx::query("DELETE FROM popup_filter_rules WHERE id=?").bind(id).execute(&state.db).await {
Ok(r) if r.rows_affected() > 0 => {
let rules = fetch_popup_rules_for_push(&state.db, &target_type, target_id.as_deref()).await;
push_to_targets(&state.db, &state.clients, MessageType::PopupRules, &serde_json::json!({"rules": rules}), &target_type, target_id.as_deref()).await;
Json(ApiResponse::ok(()))
}
_ => Json(ApiResponse::error("Not found")),
}
}
pub async fn list_stats(State(state): State<AppState>) -> Json<ApiResponse<serde_json::Value>> {
match sqlx::query("SELECT device_uid, blocked_count, date FROM popup_block_stats ORDER BY date DESC LIMIT 30")
.fetch_all(&state.db).await {
Ok(rows) => Json(ApiResponse::ok(serde_json::json!({"stats": rows.iter().map(|r| serde_json::json!({
"device_uid": r.get::<String,_>("device_uid"), "blocked_count": r.get::<i32,_>("blocked_count"),
"date": r.get::<String,_>("date")
})).collect::<Vec<_>>()}))),
Err(e) => Json(ApiResponse::internal_error("query popup block stats", e)),
}
}
async fn fetch_popup_rules_for_push(
db: &sqlx::SqlitePool,
target_type: &str,
target_id: Option<&str>,
) -> Vec<serde_json::Value> {
let query = match target_type {
"device" => sqlx::query(
"SELECT id, rule_type, window_title, window_class, process_name FROM popup_filter_rules WHERE enabled = 1 AND (target_type = 'global' OR (target_type = 'device' AND target_id = ?))"
).bind(target_id),
_ => sqlx::query(
"SELECT id, rule_type, window_title, window_class, process_name FROM popup_filter_rules WHERE enabled = 1 AND target_type = 'global'"
),
};
query.fetch_all(db).await
.map(|rows| rows.iter().map(|r| serde_json::json!({
"id": r.get::<i64,_>("id"), "rule_type": r.get::<String,_>("rule_type"),
"window_title": r.get::<Option<String>,_>("window_title"),
"window_class": r.get::<Option<String>,_>("window_class"),
"process_name": r.get::<Option<String>,_>("process_name"),
})).collect())
.unwrap_or_default()
}

View File

@@ -0,0 +1,155 @@
use axum::{extract::{State, Path, Query, Json}, http::StatusCode};
use serde::Deserialize;
use sqlx::Row;
use csm_protocol::MessageType;
use crate::AppState;
use crate::api::ApiResponse;
use crate::tcp::push_to_targets;
#[derive(Debug, Deserialize)]
pub struct CreateBlacklistRequest {
pub name_pattern: String,
pub category: Option<String>,
pub action: Option<String>,
pub target_type: Option<String>,
pub target_id: Option<String>,
}
pub async fn list_blacklist(State(state): State<AppState>) -> Json<ApiResponse<serde_json::Value>> {
match sqlx::query("SELECT id, name_pattern, category, action, target_type, target_id, enabled, created_at FROM software_blacklist ORDER BY created_at DESC")
.fetch_all(&state.db).await {
Ok(rows) => Json(ApiResponse::ok(serde_json::json!({"blacklist": rows.iter().map(|r| serde_json::json!({
"id": r.get::<i64,_>("id"), "name_pattern": r.get::<String,_>("name_pattern"),
"category": r.get::<Option<String>,_>("category"), "action": r.get::<String,_>("action"),
"target_type": r.get::<String,_>("target_type"), "target_id": r.get::<Option<String>,_>("target_id"),
"enabled": r.get::<bool,_>("enabled"), "created_at": r.get::<String,_>("created_at")
})).collect::<Vec<_>>()}))),
Err(e) => Json(ApiResponse::internal_error("query software blacklist", e)),
}
}
pub async fn add_to_blacklist(State(state): State<AppState>, Json(req): Json<CreateBlacklistRequest>) -> (StatusCode, Json<ApiResponse<serde_json::Value>>) {
let action = req.action.unwrap_or_else(|| "block".to_string());
let target_type = req.target_type.unwrap_or_else(|| "global".to_string());
// Validate inputs
if req.name_pattern.trim().is_empty() || req.name_pattern.len() > 255 {
return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("name_pattern must be 1-255 chars")));
}
if !matches!(action.as_str(), "block" | "alert") {
return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("action must be 'block' or 'alert'")));
}
if let Some(ref cat) = req.category {
if !matches!(cat.as_str(), "game" | "social" | "vpn" | "mining" | "custom") {
return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("invalid category")));
}
}
if !matches!(target_type.as_str(), "global" | "device" | "group") {
return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("invalid target_type")));
}
match sqlx::query("INSERT INTO software_blacklist (name_pattern, category, action, target_type, target_id) VALUES (?,?,?,?,?)")
.bind(&req.name_pattern).bind(&req.category).bind(&action).bind(&target_type).bind(&req.target_id)
.execute(&state.db).await {
Ok(r) => {
let new_id = r.last_insert_rowid();
let blacklist = fetch_blacklist_for_push(&state.db, &target_type, req.target_id.as_deref()).await;
push_to_targets(&state.db, &state.clients, MessageType::SoftwareBlacklist, &serde_json::json!({"blacklist": blacklist}), &target_type, req.target_id.as_deref()).await;
(StatusCode::CREATED, Json(ApiResponse::ok(serde_json::json!({"id": new_id}))))
}
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::internal_error("add software blacklist entry", e))),
}
}
#[derive(Debug, Deserialize)]
pub struct UpdateBlacklistRequest { pub name_pattern: Option<String>, pub action: Option<String>, pub enabled: Option<bool> }
pub async fn update_blacklist(State(state): State<AppState>, Path(id): Path<i64>, Json(body): Json<UpdateBlacklistRequest>) -> Json<ApiResponse<()>> {
let existing = sqlx::query("SELECT * FROM software_blacklist WHERE id = ?")
.bind(id)
.fetch_optional(&state.db)
.await;
let existing = match existing {
Ok(Some(row)) => row,
Ok(None) => return Json(ApiResponse::error("Not found")),
Err(e) => return Json(ApiResponse::internal_error("query software blacklist", e)),
};
let name_pattern = body.name_pattern.unwrap_or_else(|| existing.get::<String, _>("name_pattern"));
let action = body.action.unwrap_or_else(|| existing.get::<String, _>("action"));
let enabled = body.enabled.unwrap_or_else(|| existing.get::<bool, _>("enabled"));
let result = sqlx::query("UPDATE software_blacklist SET name_pattern = ?, action = ?, enabled = ? WHERE id = ?")
.bind(&name_pattern)
.bind(&action)
.bind(enabled)
.bind(id)
.execute(&state.db)
.await;
match result {
Ok(r) if r.rows_affected() > 0 => {
let target_type_val: String = existing.get("target_type");
let target_id_val: Option<String> = existing.get("target_id");
let blacklist = fetch_blacklist_for_push(&state.db, &target_type_val, target_id_val.as_deref()).await;
push_to_targets(&state.db, &state.clients, MessageType::SoftwareBlacklist, &serde_json::json!({"blacklist": blacklist}), &target_type_val, target_id_val.as_deref()).await;
Json(ApiResponse::ok(()))
}
Ok(_) => Json(ApiResponse::error("Not found")),
Err(e) => Json(ApiResponse::internal_error("update software blacklist", e)),
}
}
pub async fn remove_from_blacklist(State(state): State<AppState>, Path(id): Path<i64>) -> Json<ApiResponse<()>> {
let existing = sqlx::query("SELECT target_type, target_id FROM software_blacklist WHERE id = ?")
.bind(id).fetch_optional(&state.db).await;
let (target_type, target_id) = match existing {
Ok(Some(row)) => (row.get::<String, _>("target_type"), row.get::<Option<String>, _>("target_id")),
_ => return Json(ApiResponse::error("Not found")),
};
match sqlx::query("DELETE FROM software_blacklist WHERE id=?").bind(id).execute(&state.db).await {
Ok(r) if r.rows_affected() > 0 => {
let blacklist = fetch_blacklist_for_push(&state.db, &target_type, target_id.as_deref()).await;
push_to_targets(&state.db, &state.clients, MessageType::SoftwareBlacklist, &serde_json::json!({"blacklist": blacklist}), &target_type, target_id.as_deref()).await;
Json(ApiResponse::ok(()))
}
_ => Json(ApiResponse::error("Not found")),
}
}
#[derive(Debug, Deserialize)]
pub struct ViolationFilters { pub device_uid: Option<String> }
pub async fn list_violations(State(state): State<AppState>, Query(f): Query<ViolationFilters>) -> Json<ApiResponse<serde_json::Value>> {
match sqlx::query("SELECT id, device_uid, software_name, action_taken, timestamp FROM software_violations WHERE (? IS NULL OR device_uid=?) ORDER BY timestamp DESC LIMIT 200")
.bind(&f.device_uid).bind(&f.device_uid)
.fetch_all(&state.db).await {
Ok(rows) => Json(ApiResponse::ok(serde_json::json!({"violations": rows.iter().map(|r| serde_json::json!({
"id": r.get::<i64,_>("id"), "device_uid": r.get::<String,_>("device_uid"),
"software_name": r.get::<String,_>("software_name"), "action_taken": r.get::<String,_>("action_taken"),
"timestamp": r.get::<String,_>("timestamp")
})).collect::<Vec<_>>()}))),
Err(e) => Json(ApiResponse::internal_error("query software violations", e)),
}
}
async fn fetch_blacklist_for_push(
db: &sqlx::SqlitePool,
target_type: &str,
target_id: Option<&str>,
) -> Vec<serde_json::Value> {
let query = match target_type {
"device" => sqlx::query(
"SELECT id, name_pattern, action FROM software_blacklist WHERE enabled = 1 AND (target_type = 'global' OR (target_type = 'device' AND target_id = ?))"
).bind(target_id),
_ => sqlx::query(
"SELECT id, name_pattern, action FROM software_blacklist WHERE enabled = 1 AND target_type = 'global'"
),
};
query.fetch_all(db).await
.map(|rows| rows.iter().map(|r| serde_json::json!({
"id": r.get::<i64,_>("id"), "name_pattern": r.get::<String,_>("name_pattern"), "action": r.get::<String,_>("action")
})).collect())
.unwrap_or_default()
}

View File

@@ -0,0 +1,60 @@
use axum::{extract::{State, Query}, Json};
use serde::Deserialize;
use sqlx::Row;
use crate::AppState;
use crate::api::ApiResponse;
#[derive(Debug, Deserialize)]
pub struct DailyFilters { pub device_uid: Option<String>, pub start_date: Option<String>, pub end_date: Option<String> }
pub async fn list_daily(State(state): State<AppState>, Query(f): Query<DailyFilters>) -> Json<ApiResponse<serde_json::Value>> {
match sqlx::query(
"SELECT id, device_uid, date, total_active_minutes, total_idle_minutes, first_active_at, last_active_at
FROM usage_daily WHERE (? IS NULL OR device_uid=?) AND (? IS NULL OR date>=?) AND (? IS NULL OR date<=?)
ORDER BY date DESC LIMIT 90")
.bind(&f.device_uid).bind(&f.device_uid)
.bind(&f.start_date).bind(&f.start_date)
.bind(&f.end_date).bind(&f.end_date)
.fetch_all(&state.db).await {
Ok(rows) => Json(ApiResponse::ok(serde_json::json!({"daily": rows.iter().map(|r| serde_json::json!({
"id": r.get::<i64,_>("id"), "device_uid": r.get::<String,_>("device_uid"),
"date": r.get::<String,_>("date"), "total_active_minutes": r.get::<i32,_>("total_active_minutes"),
"total_idle_minutes": r.get::<i32,_>("total_idle_minutes"),
"first_active_at": r.get::<Option<String>,_>("first_active_at"),
"last_active_at": r.get::<Option<String>,_>("last_active_at")
})).collect::<Vec<_>>()}))),
Err(e) => Json(ApiResponse::internal_error("query daily usage", e)),
}
}
#[derive(Debug, Deserialize)]
pub struct AppUsageFilters { pub device_uid: Option<String>, pub date: Option<String> }
pub async fn list_app_usage(State(state): State<AppState>, Query(f): Query<AppUsageFilters>) -> Json<ApiResponse<serde_json::Value>> {
match sqlx::query(
"SELECT id, device_uid, date, app_name, usage_minutes FROM app_usage_daily
WHERE (? IS NULL OR device_uid=?) AND (? IS NULL OR date=?)
ORDER BY usage_minutes DESC LIMIT 100")
.bind(&f.device_uid).bind(&f.device_uid)
.bind(&f.date).bind(&f.date)
.fetch_all(&state.db).await {
Ok(rows) => Json(ApiResponse::ok(serde_json::json!({"app_usage": rows.iter().map(|r| serde_json::json!({
"id": r.get::<i64,_>("id"), "device_uid": r.get::<String,_>("device_uid"),
"date": r.get::<String,_>("date"), "app_name": r.get::<String,_>("app_name"),
"usage_minutes": r.get::<i32,_>("usage_minutes")
})).collect::<Vec<_>>()}))),
Err(e) => Json(ApiResponse::internal_error("query app usage", e)),
}
}
pub async fn leaderboard(State(state): State<AppState>) -> Json<ApiResponse<serde_json::Value>> {
match sqlx::query(
"SELECT device_uid, SUM(total_active_minutes) as total_minutes FROM usage_daily
WHERE date >= date('now', '-7 days') GROUP BY device_uid ORDER BY total_minutes DESC LIMIT 20")
.fetch_all(&state.db).await {
Ok(rows) => Json(ApiResponse::ok(serde_json::json!({"leaderboard": rows.iter().map(|r| serde_json::json!({
"device_uid": r.get::<String,_>("device_uid"), "total_minutes": r.get::<i64,_>("total_minutes")
})).collect::<Vec<_>>()}))),
Err(e) => Json(ApiResponse::internal_error("query usage leaderboard", e)),
}
}

View File

@@ -0,0 +1,47 @@
use axum::{extract::{State, Query}, Json};
use serde::Deserialize;
use sqlx::Row;
use crate::AppState;
use crate::api::ApiResponse;
#[derive(Debug, Deserialize)]
pub struct LogFilters {
pub device_uid: Option<String>,
pub operation: Option<String>,
pub usb_serial: Option<String>,
}
pub async fn list_operations(State(state): State<AppState>, Query(f): Query<LogFilters>) -> Json<ApiResponse<serde_json::Value>> {
match sqlx::query(
"SELECT id, device_uid, usb_serial, drive_letter, operation, file_path, file_size, timestamp
FROM usb_file_operations WHERE (? IS NULL OR device_uid=?) AND (? IS NULL OR operation=?) AND (? IS NULL OR usb_serial=?)
ORDER BY timestamp DESC LIMIT 200")
.bind(&f.device_uid).bind(&f.device_uid)
.bind(&f.operation).bind(&f.operation)
.bind(&f.usb_serial).bind(&f.usb_serial)
.fetch_all(&state.db).await {
Ok(rows) => Json(ApiResponse::ok(serde_json::json!({"operations": rows.iter().map(|r| serde_json::json!({
"id": r.get::<i64,_>("id"), "device_uid": r.get::<String,_>("device_uid"),
"usb_serial": r.get::<Option<String>,_>("usb_serial"), "drive_letter": r.get::<Option<String>,_>("drive_letter"),
"operation": r.get::<String,_>("operation"), "file_path": r.get::<String,_>("file_path"),
"file_size": r.get::<Option<i64>,_>("file_size"), "timestamp": r.get::<String,_>("timestamp")
})).collect::<Vec<_>>()}))),
Err(e) => Json(ApiResponse::internal_error("query USB file operations", e)),
}
}
pub async fn summary(State(state): State<AppState>) -> Json<ApiResponse<serde_json::Value>> {
match sqlx::query(
"SELECT device_uid, COUNT(*) as op_count, COUNT(DISTINCT usb_serial) as usb_count,
MIN(timestamp) as first_op, MAX(timestamp) as last_op
FROM usb_file_operations WHERE timestamp >= datetime('now', '-7 days')
GROUP BY device_uid ORDER BY op_count DESC LIMIT 50")
.fetch_all(&state.db).await {
Ok(rows) => Json(ApiResponse::ok(serde_json::json!({"summary": rows.iter().map(|r| serde_json::json!({
"device_uid": r.get::<String,_>("device_uid"), "op_count": r.get::<i64,_>("op_count"),
"usb_count": r.get::<i64,_>("usb_count"), "first_op": r.get::<String,_>("first_op"),
"last_op": r.get::<String,_>("last_op")
})).collect::<Vec<_>>()}))),
Err(e) => Json(ApiResponse::internal_error("query USB file audit summary", e)),
}
}

View File

@@ -0,0 +1,186 @@
use axum::{extract::{State, Path, Json}, http::StatusCode};
use serde::Deserialize;
use sqlx::Row;
use csm_protocol::{MessageType, WatermarkConfigPayload};
use crate::AppState;
use crate::api::ApiResponse;
use crate::tcp::push_to_targets;
#[derive(Debug, Deserialize)]
pub struct CreateConfigRequest {
pub target_type: Option<String>,
pub target_id: Option<String>,
pub content: Option<String>,
pub font_size: Option<u32>,
pub opacity: Option<f64>,
pub color: Option<String>,
pub angle: Option<i32>,
pub enabled: Option<bool>,
}
pub async fn get_config_list(State(state): State<AppState>) -> Json<ApiResponse<serde_json::Value>> {
match sqlx::query("SELECT id, target_type, target_id, content, font_size, opacity, color, angle, enabled, updated_at FROM watermark_config ORDER BY id")
.fetch_all(&state.db).await {
Ok(rows) => Json(ApiResponse::ok(serde_json::json!({"configs": rows.iter().map(|r| serde_json::json!({
"id": r.get::<i64,_>("id"), "target_type": r.get::<String,_>("target_type"),
"target_id": r.get::<Option<String>,_>("target_id"), "content": r.get::<String,_>("content"),
"font_size": r.get::<i32,_>("font_size"), "opacity": r.get::<f64,_>("opacity"),
"color": r.get::<String,_>("color"), "angle": r.get::<i32,_>("angle"),
"enabled": r.get::<bool,_>("enabled"), "updated_at": r.get::<String,_>("updated_at")
})).collect::<Vec<_>>()}))),
Err(e) => Json(ApiResponse::internal_error("query watermark configs", e)),
}
}
pub async fn create_config(
State(state): State<AppState>,
Json(req): Json<CreateConfigRequest>,
) -> (StatusCode, Json<ApiResponse<serde_json::Value>>) {
let target_type = req.target_type.unwrap_or_else(|| "global".to_string());
let content = req.content.unwrap_or_else(|| "{company} | {username} | {date}".to_string());
let font_size = req.font_size.unwrap_or(14).clamp(8, 72) as i32;
let opacity = req.opacity.unwrap_or(0.15).clamp(0.01, 1.0);
let color = req.color.unwrap_or_else(|| "#808080".to_string());
let angle = req.angle.unwrap_or(-30);
let enabled = req.enabled.unwrap_or(true);
// Validate inputs
if content.len() > 200 {
return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("content too long (max 200 chars)")));
}
if !is_valid_hex_color(&color) {
return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("invalid color format (expected #RRGGBB)")));
}
if !matches!(target_type.as_str(), "global" | "device" | "group") {
return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("invalid target_type")));
}
match sqlx::query("INSERT INTO watermark_config (target_type, target_id, content, font_size, opacity, color, angle, enabled) VALUES (?,?,?,?,?,?,?,?)")
.bind(&target_type).bind(&req.target_id).bind(&content).bind(font_size).bind(opacity).bind(&color).bind(angle).bind(enabled)
.execute(&state.db).await {
Ok(r) => {
let new_id = r.last_insert_rowid();
// Push to online clients
let config = WatermarkConfigPayload {
content: content.clone(),
font_size: font_size as u32,
opacity,
color: color.clone(),
angle,
enabled,
};
push_to_targets(&state.db, &state.clients, MessageType::WatermarkConfig, &config, &target_type, req.target_id.as_deref()).await;
(StatusCode::CREATED, Json(ApiResponse::ok(serde_json::json!({"id": new_id}))))
}
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::internal_error("create watermark config", e))),
}
}
#[derive(Debug, Deserialize)]
pub struct UpdateConfigRequest {
pub content: Option<String>, pub font_size: Option<u32>, pub opacity: Option<f64>,
pub color: Option<String>, pub angle: Option<i32>, pub enabled: Option<bool>,
}
pub async fn update_config(
State(state): State<AppState>,
Path(id): Path<i64>,
Json(body): Json<UpdateConfigRequest>,
) -> Json<ApiResponse<()>> {
let existing = sqlx::query("SELECT * FROM watermark_config WHERE id = ?")
.bind(id)
.fetch_optional(&state.db)
.await;
let existing = match existing {
Ok(Some(row)) => row,
Ok(None) => return Json(ApiResponse::error("Not found")),
Err(e) => return Json(ApiResponse::internal_error("query watermark config", e)),
};
let content = body.content.unwrap_or_else(|| existing.get::<String, _>("content"));
let font_size = body.font_size.map(|v| v.clamp(8, 72) as i32).unwrap_or_else(|| existing.get::<i32, _>("font_size"));
let opacity = body.opacity.map(|v| v.clamp(0.01, 1.0)).unwrap_or_else(|| existing.get::<f64, _>("opacity"));
let color = body.color.unwrap_or_else(|| existing.get::<String, _>("color"));
let angle = body.angle.unwrap_or_else(|| existing.get::<i32, _>("angle"));
let enabled = body.enabled.unwrap_or_else(|| existing.get::<bool, _>("enabled"));
// Validate inputs
if content.len() > 200 {
return Json(ApiResponse::error("content too long (max 200 chars)"));
}
if !is_valid_hex_color(&color) {
return Json(ApiResponse::error("invalid color format (expected #RRGGBB)"));
}
let result = sqlx::query(
"UPDATE watermark_config SET content = ?, font_size = ?, opacity = ?, color = ?, angle = ?, enabled = ?, updated_at = datetime('now') WHERE id = ?"
)
.bind(&content)
.bind(font_size)
.bind(opacity)
.bind(&color)
.bind(angle)
.bind(enabled)
.bind(id)
.execute(&state.db)
.await;
match result {
Ok(r) if r.rows_affected() > 0 => {
// Push updated config to online clients
let config = WatermarkConfigPayload {
content: content.clone(),
font_size: font_size as u32,
opacity,
color: color.clone(),
angle,
enabled,
};
let target_type_val: String = existing.get("target_type");
let target_id_val: Option<String> = existing.get("target_id");
push_to_targets(&state.db, &state.clients, MessageType::WatermarkConfig, &config, &target_type_val, target_id_val.as_deref()).await;
Json(ApiResponse::ok(()))
}
Ok(_) => Json(ApiResponse::error("Not found")),
Err(e) => Json(ApiResponse::internal_error("update watermark config", e)),
}
}
pub async fn delete_config(State(state): State<AppState>, Path(id): Path<i64>) -> Json<ApiResponse<()>> {
// Fetch existing config to get target info for push
let existing = sqlx::query("SELECT target_type, target_id FROM watermark_config WHERE id = ?")
.bind(id)
.fetch_optional(&state.db)
.await;
let (target_type, target_id) = match existing {
Ok(Some(row)) => (row.get::<String, _>("target_type"), row.get::<Option<String>, _>("target_id")),
_ => return Json(ApiResponse::error("Not found")),
};
match sqlx::query("DELETE FROM watermark_config WHERE id=?").bind(id).execute(&state.db).await {
Ok(r) if r.rows_affected() > 0 => {
// Push disabled watermark to clients
let disabled = WatermarkConfigPayload {
content: String::new(),
font_size: 0,
opacity: 0.0,
color: String::new(),
angle: 0,
enabled: false,
};
push_to_targets(&state.db, &state.clients, MessageType::WatermarkConfig, &disabled, &target_type, target_id.as_deref()).await;
Json(ApiResponse::ok(()))
}
_ => Json(ApiResponse::error("Not found")),
}
}
/// Validate a hex color string (#RRGGBB format)
fn is_valid_hex_color(color: &str) -> bool {
if color.len() != 7 || !color.starts_with('#') {
return false;
}
color[1..].chars().all(|c| c.is_ascii_hexdigit())
}

View File

@@ -0,0 +1,156 @@
use axum::{extract::{State, Path, Query, Json}, http::StatusCode};
use serde::Deserialize;
use sqlx::Row;
use csm_protocol::MessageType;
use crate::AppState;
use crate::api::ApiResponse;
use crate::tcp::push_to_targets;
#[derive(Debug, Deserialize)]
pub struct RuleFilters { pub rule_type: Option<String>, pub enabled: Option<bool> }
#[derive(Debug, Deserialize)]
pub struct CreateRuleRequest {
pub rule_type: String,
pub pattern: String,
pub target_type: Option<String>,
pub target_id: Option<String>,
pub enabled: Option<bool>,
}
pub async fn list_rules(State(state): State<AppState>) -> Json<ApiResponse<serde_json::Value>> {
match sqlx::query("SELECT id, rule_type, pattern, target_type, target_id, enabled, created_at FROM web_filter_rules ORDER BY created_at DESC")
.fetch_all(&state.db).await {
Ok(rows) => Json(ApiResponse::ok(serde_json::json!({ "rules": rows.iter().map(|r| serde_json::json!({
"id": r.get::<i64,_>("id"), "rule_type": r.get::<String,_>("rule_type"),
"pattern": r.get::<String,_>("pattern"), "target_type": r.get::<String,_>("target_type"),
"target_id": r.get::<Option<String>,_>("target_id"), "enabled": r.get::<bool,_>("enabled"),
"created_at": r.get::<String,_>("created_at")
})).collect::<Vec<_>>() }))),
Err(e) => Json(ApiResponse::internal_error("query web filter rules", e)),
}
}
pub async fn create_rule(State(state): State<AppState>, Json(req): Json<CreateRuleRequest>) -> (StatusCode, Json<ApiResponse<serde_json::Value>>) {
let enabled = req.enabled.unwrap_or(true);
let target_type = req.target_type.unwrap_or_else(|| "global".to_string());
// Validate inputs
if !matches!(req.rule_type.as_str(), "blacklist" | "whitelist" | "category") {
return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("invalid rule_type (expected blacklist|whitelist|category)")));
}
if req.pattern.trim().is_empty() || req.pattern.len() > 255 {
return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("pattern must be 1-255 chars")));
}
if !matches!(target_type.as_str(), "global" | "device" | "group") {
return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("invalid target_type")));
}
match sqlx::query("INSERT INTO web_filter_rules (rule_type, pattern, target_type, target_id, enabled) VALUES (?,?,?,?,?)")
.bind(&req.rule_type).bind(&req.pattern).bind(&target_type).bind(&req.target_id).bind(enabled)
.execute(&state.db).await {
Ok(r) => {
let new_id = r.last_insert_rowid();
let rules = fetch_rules_for_push(&state.db, &target_type, req.target_id.as_deref()).await;
push_to_targets(&state.db, &state.clients, MessageType::WebFilterRuleUpdate, &serde_json::json!({"rules": rules}), &target_type, req.target_id.as_deref()).await;
(StatusCode::CREATED, Json(ApiResponse::ok(serde_json::json!({"id": new_id}))))
}
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::internal_error("create web filter rule", e))),
}
}
#[derive(Debug, Deserialize)]
pub struct UpdateRuleRequest { pub rule_type: Option<String>, pub pattern: Option<String>, pub enabled: Option<bool> }
pub async fn update_rule(State(state): State<AppState>, Path(id): Path<i64>, Json(body): Json<UpdateRuleRequest>) -> Json<ApiResponse<()>> {
let existing = sqlx::query("SELECT * FROM web_filter_rules WHERE id = ?")
.bind(id)
.fetch_optional(&state.db)
.await;
let existing = match existing {
Ok(Some(row)) => row,
Ok(None) => return Json(ApiResponse::error("Not found")),
Err(e) => return Json(ApiResponse::internal_error("query web filter rule", e)),
};
let rule_type = body.rule_type.unwrap_or_else(|| existing.get::<String, _>("rule_type"));
let pattern = body.pattern.unwrap_or_else(|| existing.get::<String, _>("pattern"));
let enabled = body.enabled.unwrap_or_else(|| existing.get::<bool, _>("enabled"));
let result = sqlx::query("UPDATE web_filter_rules SET rule_type = ?, pattern = ?, enabled = ? WHERE id = ?")
.bind(&rule_type)
.bind(&pattern)
.bind(enabled)
.bind(id)
.execute(&state.db)
.await;
match result {
Ok(r) if r.rows_affected() > 0 => {
let target_type_val: String = existing.get("target_type");
let target_id_val: Option<String> = existing.get("target_id");
let rules = fetch_rules_for_push(&state.db, &target_type_val, target_id_val.as_deref()).await;
push_to_targets(&state.db, &state.clients, MessageType::WebFilterRuleUpdate, &serde_json::json!({"rules": rules}), &target_type_val, target_id_val.as_deref()).await;
Json(ApiResponse::ok(()))
}
Ok(_) => Json(ApiResponse::error("Not found")),
Err(e) => Json(ApiResponse::internal_error("update web filter rule", e)),
}
}
pub async fn delete_rule(State(state): State<AppState>, Path(id): Path<i64>) -> Json<ApiResponse<()>> {
let existing = sqlx::query("SELECT target_type, target_id FROM web_filter_rules WHERE id = ?")
.bind(id).fetch_optional(&state.db).await;
let (target_type, target_id) = match existing {
Ok(Some(row)) => (row.get::<String, _>("target_type"), row.get::<Option<String>, _>("target_id")),
_ => return Json(ApiResponse::error("Not found")),
};
match sqlx::query("DELETE FROM web_filter_rules WHERE id=?").bind(id).execute(&state.db).await {
Ok(r) if r.rows_affected() > 0 => {
let rules = fetch_rules_for_push(&state.db, &target_type, target_id.as_deref()).await;
push_to_targets(&state.db, &state.clients, MessageType::WebFilterRuleUpdate, &serde_json::json!({"rules": rules}), &target_type, target_id.as_deref()).await;
Json(ApiResponse::ok(()))
}
_ => Json(ApiResponse::error("Not found")),
}
}
#[derive(Debug, Deserialize)]
pub struct LogFilters { pub device_uid: Option<String>, pub action: Option<String> }
pub async fn list_access_log(State(state): State<AppState>, Query(f): Query<LogFilters>) -> Json<ApiResponse<serde_json::Value>> {
match sqlx::query("SELECT id, device_uid, url, action, timestamp FROM web_access_log WHERE (? IS NULL OR device_uid=?) AND (? IS NULL OR action=?) ORDER BY timestamp DESC LIMIT 200")
.bind(&f.device_uid).bind(&f.device_uid).bind(&f.action).bind(&f.action)
.fetch_all(&state.db).await {
Ok(rows) => Json(ApiResponse::ok(serde_json::json!({"log": rows.iter().map(|r| serde_json::json!({
"id": r.get::<i64,_>("id"), "device_uid": r.get::<String,_>("device_uid"),
"url": r.get::<String,_>("url"), "action": r.get::<String,_>("action"),
"timestamp": r.get::<String,_>("timestamp")
})).collect::<Vec<_>>() }))),
Err(e) => Json(ApiResponse::internal_error("query web access log", e)),
}
}
/// Fetch enabled web filter rules applicable to a given target scope.
/// For "device" targets, includes both global rules and device-specific rules
/// (matching the logic used during registration push in tcp.rs).
async fn fetch_rules_for_push(
db: &sqlx::SqlitePool,
target_type: &str,
target_id: Option<&str>,
) -> Vec<serde_json::Value> {
let query = match target_type {
"device" => sqlx::query(
"SELECT id, rule_type, pattern FROM web_filter_rules WHERE enabled = 1 AND (target_type = 'global' OR (target_type = 'device' AND target_id = ?))"
).bind(target_id),
_ => sqlx::query(
"SELECT id, rule_type, pattern FROM web_filter_rules WHERE enabled = 1 AND target_type = 'global'"
),
};
query.fetch_all(db).await
.map(|rows| rows.iter().map(|r| serde_json::json!({
"id": r.get::<i64,_>("id"), "rule_type": r.get::<String,_>("rule_type"), "pattern": r.get::<String,_>("pattern")
})).collect())
.unwrap_or_default()
}

View File

@@ -0,0 +1,246 @@
use axum::{extract::{State, Path, Query, Json}, http::StatusCode};
use serde::Deserialize;
use sqlx::Row;
use crate::AppState;
use super::ApiResponse;
use crate::tcp::push_to_targets;
use csm_protocol::{MessageType, UsbPolicyPayload, UsbDeviceRule};
#[derive(Debug, Deserialize)]
pub struct UsbEventListParams {
pub device_uid: Option<String>,
pub event_type: Option<String>,
pub page: Option<u32>,
pub page_size: Option<u32>,
}
pub async fn list_events(
State(state): State<AppState>,
Query(params): Query<UsbEventListParams>,
) -> Json<ApiResponse<serde_json::Value>> {
let limit = params.page_size.unwrap_or(20).min(100);
let offset = params.page.unwrap_or(1).saturating_sub(1) * limit;
// Normalize empty strings to None
let device_uid = params.device_uid.as_deref().filter(|s| !s.is_empty()).map(String::from);
let event_type = params.event_type.as_deref().filter(|s| !s.is_empty()).map(String::from);
let rows = sqlx::query(
"SELECT id, device_uid, vendor_id, product_id, serial_number, device_name, event_type, event_time
FROM usb_events WHERE 1=1
AND (? IS NULL OR device_uid = ?)
AND (? IS NULL OR event_type = ?)
ORDER BY event_time DESC LIMIT ? OFFSET ?"
)
.bind(&device_uid).bind(&device_uid)
.bind(&event_type).bind(&event_type)
.bind(limit).bind(offset)
.fetch_all(&state.db)
.await;
match rows {
Ok(records) => {
let items: Vec<serde_json::Value> = records.iter().map(|r| serde_json::json!({
"id": r.get::<i64, _>("id"),
"device_uid": r.get::<String, _>("device_uid"),
"vendor_id": r.get::<Option<String>, _>("vendor_id"),
"product_id": r.get::<Option<String>, _>("product_id"),
"serial_number": r.get::<Option<String>, _>("serial_number"),
"device_name": r.get::<Option<String>, _>("device_name"),
"event_type": r.get::<String, _>("event_type"),
"event_time": r.get::<String, _>("event_time"),
})).collect();
Json(ApiResponse::ok(serde_json::json!({
"events": items,
"page": params.page.unwrap_or(1),
"page_size": limit,
})))
}
Err(e) => Json(ApiResponse::internal_error("query usb events", e)),
}
}
pub async fn list_policies(
State(state): State<AppState>,
) -> Json<ApiResponse<serde_json::Value>> {
let rows = sqlx::query(
"SELECT id, name, policy_type, target_group, rules, enabled, created_at, updated_at
FROM usb_policies ORDER BY created_at DESC"
)
.fetch_all(&state.db)
.await;
match rows {
Ok(records) => {
let items: Vec<serde_json::Value> = records.iter().map(|r| serde_json::json!({
"id": r.get::<i64, _>("id"),
"name": r.get::<String, _>("name"),
"policy_type": r.get::<String, _>("policy_type"),
"target_group": r.get::<Option<String>, _>("target_group"),
"rules": r.get::<String, _>("rules"),
"enabled": r.get::<i32, _>("enabled"),
"created_at": r.get::<String, _>("created_at"),
"updated_at": r.get::<String, _>("updated_at"),
})).collect();
Json(ApiResponse::ok(serde_json::json!({
"policies": items,
})))
}
Err(e) => Json(ApiResponse::internal_error("query usb policies", e)),
}
}
#[derive(Debug, Deserialize)]
pub struct CreatePolicyRequest {
pub name: String,
pub policy_type: String,
pub target_group: Option<String>,
pub rules: String,
pub enabled: Option<i32>,
}
pub async fn create_policy(
State(state): State<AppState>,
Json(body): Json<CreatePolicyRequest>,
) -> (StatusCode, Json<ApiResponse<serde_json::Value>>) {
let enabled = body.enabled.unwrap_or(1);
let result = sqlx::query(
"INSERT INTO usb_policies (name, policy_type, target_group, rules, enabled) VALUES (?, ?, ?, ?, ?)"
)
.bind(&body.name)
.bind(&body.policy_type)
.bind(&body.target_group)
.bind(&body.rules)
.bind(enabled)
.execute(&state.db)
.await;
match result {
Ok(r) => {
let new_id = r.last_insert_rowid();
// Push USB policy to matching online clients
if enabled == 1 {
let payload = build_usb_policy_payload(&body.policy_type, true, &body.rules);
let target_group = body.target_group.as_deref();
push_to_targets(&state.db, &state.clients, MessageType::UsbPolicyUpdate, &payload, "group", target_group).await;
}
(StatusCode::CREATED, Json(ApiResponse::ok(serde_json::json!({
"id": new_id,
}))))
}
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::internal_error("create usb policy", e))),
}
}
#[derive(Debug, Deserialize)]
pub struct UpdatePolicyRequest {
pub name: Option<String>,
pub policy_type: Option<String>,
pub target_group: Option<String>,
pub rules: Option<String>,
pub enabled: Option<i32>,
}
pub async fn update_policy(
State(state): State<AppState>,
Path(id): Path<i64>,
Json(body): Json<UpdatePolicyRequest>,
) -> Json<ApiResponse<serde_json::Value>> {
// Fetch existing policy
let existing = sqlx::query("SELECT * FROM usb_policies WHERE id = ?")
.bind(id)
.fetch_optional(&state.db)
.await;
let existing = match existing {
Ok(Some(row)) => row,
Ok(None) => return Json(ApiResponse::error("Policy not found")),
Err(e) => return Json(ApiResponse::internal_error("query usb policy", e)),
};
let name = body.name.unwrap_or_else(|| existing.get::<String, _>("name"));
let policy_type = body.policy_type.unwrap_or_else(|| existing.get::<String, _>("policy_type"));
let target_group = body.target_group.or_else(|| existing.get::<Option<String>, _>("target_group"));
let rules = body.rules.unwrap_or_else(|| existing.get::<String, _>("rules"));
let enabled = body.enabled.unwrap_or_else(|| existing.get::<i32, _>("enabled"));
let result = sqlx::query(
"UPDATE usb_policies SET name = ?, policy_type = ?, target_group = ?, rules = ?, enabled = ?, updated_at = datetime('now') WHERE id = ?"
)
.bind(&name)
.bind(&policy_type)
.bind(&target_group)
.bind(&rules)
.bind(enabled)
.bind(id)
.execute(&state.db)
.await;
match result {
Ok(_) => {
// Push updated USB policy to matching online clients
let payload = build_usb_policy_payload(&policy_type, enabled == 1, &rules);
let target_group = target_group.as_deref();
push_to_targets(&state.db, &state.clients, MessageType::UsbPolicyUpdate, &payload, "group", target_group).await;
Json(ApiResponse::ok(serde_json::json!({"updated": true})))
}
Err(e) => Json(ApiResponse::internal_error("update usb policy", e)),
}
}
pub async fn delete_policy(
State(state): State<AppState>,
Path(id): Path<i64>,
) -> Json<ApiResponse<serde_json::Value>> {
// Fetch existing policy to get target info for push
let existing = sqlx::query("SELECT target_group FROM usb_policies WHERE id = ?")
.bind(id)
.fetch_optional(&state.db)
.await;
let target_group = match existing {
Ok(Some(row)) => row.get::<Option<String>, _>("target_group"),
_ => return Json(ApiResponse::error("Policy not found")),
};
let result = sqlx::query("DELETE FROM usb_policies WHERE id = ?")
.bind(id)
.execute(&state.db)
.await;
match result {
Ok(r) => {
if r.rows_affected() > 0 {
// Push disabled policy to clients
let disabled = UsbPolicyPayload {
policy_type: String::new(),
enabled: false,
rules: vec![],
};
push_to_targets(&state.db, &state.clients, MessageType::UsbPolicyUpdate, &disabled, "group", target_group.as_deref()).await;
Json(ApiResponse::ok(serde_json::json!({"deleted": true})))
} else {
Json(ApiResponse::error("Policy not found"))
}
}
Err(e) => Json(ApiResponse::internal_error("delete usb policy", e)),
}
}
/// Build a UsbPolicyPayload from raw policy fields
fn build_usb_policy_payload(policy_type: &str, enabled: bool, rules_json: &str) -> UsbPolicyPayload {
let raw_rules: Vec<serde_json::Value> = serde_json::from_str(rules_json).unwrap_or_default();
let rules: Vec<UsbDeviceRule> = raw_rules.iter().map(|r| UsbDeviceRule {
vendor_id: r.get("vendor_id").and_then(|v| v.as_str().map(String::from)),
product_id: r.get("product_id").and_then(|v| v.as_str().map(String::from)),
serial: r.get("serial").and_then(|v| v.as_str().map(String::from)),
device_name: r.get("device_name").and_then(|v| v.as_str().map(String::from)),
}).collect();
UsbPolicyPayload {
policy_type: policy_type.to_string(),
enabled,
rules,
}
}

View File

@@ -0,0 +1,28 @@
use sqlx::SqlitePool;
use tracing::debug;
/// Record an admin audit log entry.
pub async fn audit_log(
db: &SqlitePool,
user_id: i64,
action: &str,
target_type: Option<&str>,
target_id: Option<&str>,
detail: Option<&str>,
) {
let result = sqlx::query(
"INSERT INTO admin_audit_log (user_id, action, target_type, target_id, detail) VALUES (?, ?, ?, ?, ?)"
)
.bind(user_id)
.bind(action)
.bind(target_type)
.bind(target_id)
.bind(detail)
.execute(db)
.await;
match result {
Ok(_) => debug!("Audit: user={} action={} target={}/{}", user_id, action, target_type.unwrap_or("-"), target_id.unwrap_or("-")),
Err(e) => tracing::warn!("Failed to write audit log: {}", e),
}
}

134
crates/server/src/config.rs Normal file
View File

@@ -0,0 +1,134 @@
use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::path::Path;
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct AppConfig {
pub server: ServerConfig,
pub database: DatabaseConfig,
pub auth: AuthConfig,
pub retention: RetentionConfig,
#[serde(default)]
pub notify: NotifyConfig,
/// Token required for device registration. Empty = any token accepted.
#[serde(default)]
pub registration_token: String,
}
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ServerConfig {
pub http_addr: String,
pub tcp_addr: String,
/// Allowed CORS origins. Empty = same-origin only (no CORS headers).
#[serde(default)]
pub cors_origins: Vec<String>,
/// Optional TLS configuration for the TCP listener.
#[serde(default)]
pub tls: Option<TlsConfig>,
}
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct TlsConfig {
/// Path to the server certificate (PEM format)
pub cert_path: String,
/// Path to the server private key (PEM format)
pub key_path: String,
}
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct DatabaseConfig {
pub path: String,
}
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct AuthConfig {
pub jwt_secret: String,
#[serde(default = "default_access_ttl")]
pub access_token_ttl_secs: u64,
#[serde(default = "default_refresh_ttl")]
pub refresh_token_ttl_secs: u64,
}
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct RetentionConfig {
#[serde(default = "default_status_history_days")]
pub status_history_days: u32,
#[serde(default = "default_usb_events_days")]
pub usb_events_days: u32,
#[serde(default = "default_asset_changes_days")]
pub asset_changes_days: u32,
#[serde(default = "default_alert_records_days")]
pub alert_records_days: u32,
#[serde(default = "default_audit_log_days")]
pub audit_log_days: u32,
}
#[derive(Debug, Deserialize, Serialize, Clone, Default)]
pub struct NotifyConfig {
#[serde(default)]
pub smtp: Option<SmtpConfig>,
#[serde(default)]
pub webhook_urls: Vec<String>,
}
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct SmtpConfig {
pub host: String,
pub port: u16,
pub username: String,
pub password: String,
pub from: String,
}
impl AppConfig {
pub async fn load(path: &str) -> Result<Self> {
if Path::new(path).exists() {
let content = tokio::fs::read_to_string(path).await?;
let config: AppConfig = toml::from_str(&content)?;
Ok(config)
} else {
let config = default_config();
// Write default config for reference
let toml_str = toml::to_string_pretty(&config)?;
tokio::fs::write(path, &toml_str).await?;
tracing::warn!("Created default config at {}", path);
Ok(config)
}
}
}
fn default_access_ttl() -> u64 { 1800 } // 30 minutes
fn default_refresh_ttl() -> u64 { 604800 } // 7 days
fn default_status_history_days() -> u32 { 7 }
fn default_usb_events_days() -> u32 { 90 }
fn default_asset_changes_days() -> u32 { 365 }
fn default_alert_records_days() -> u32 { 90 }
fn default_audit_log_days() -> u32 { 365 }
pub fn default_config() -> AppConfig {
AppConfig {
server: ServerConfig {
http_addr: "0.0.0.0:8080".into(),
tcp_addr: "0.0.0.0:9999".into(),
cors_origins: vec![],
tls: None,
},
database: DatabaseConfig {
path: "./csm.db".into(),
},
auth: AuthConfig {
jwt_secret: uuid::Uuid::new_v4().to_string(),
access_token_ttl_secs: default_access_ttl(),
refresh_token_ttl_secs: default_refresh_ttl(),
},
retention: RetentionConfig {
status_history_days: default_status_history_days(),
usb_events_days: default_usb_events_days(),
asset_changes_days: default_asset_changes_days(),
alert_records_days: default_alert_records_days(),
audit_log_days: default_audit_log_days(),
},
notify: NotifyConfig::default(),
registration_token: uuid::Uuid::new_v4().to_string(),
}
}

118
crates/server/src/db.rs Normal file
View File

@@ -0,0 +1,118 @@
use sqlx::SqlitePool;
use anyhow::Result;
/// Database repository for device operations
pub struct DeviceRepo;
impl DeviceRepo {
pub async fn upsert_status(pool: &SqlitePool, device_uid: &str, status: &csm_protocol::DeviceStatus) -> Result<()> {
let top_procs_json = serde_json::to_string(&status.top_processes)?;
// Update latest snapshot
sqlx::query(
"INSERT INTO device_status (device_uid, cpu_usage, memory_usage, memory_total_mb, disk_usage, disk_total_mb, network_rx_rate, network_tx_rate, running_procs, top_processes, reported_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, datetime('now'))
ON CONFLICT(device_uid) DO UPDATE SET
cpu_usage = excluded.cpu_usage,
memory_usage = excluded.memory_usage,
memory_total_mb = excluded.memory_total_mb,
disk_usage = excluded.disk_usage,
disk_total_mb = excluded.disk_total_mb,
network_rx_rate = excluded.network_rx_rate,
network_tx_rate = excluded.network_tx_rate,
running_procs = excluded.running_procs,
top_processes = excluded.top_processes,
reported_at = datetime('now'),
updated_at = datetime('now')"
)
.bind(device_uid)
.bind(status.cpu_usage)
.bind(status.memory_usage)
.bind(status.memory_total_mb as i64)
.bind(status.disk_usage)
.bind(status.disk_total_mb as i64)
.bind(status.network_rx_rate as i64)
.bind(status.network_tx_rate as i64)
.bind(status.running_procs as i32)
.bind(&top_procs_json)
.execute(pool)
.await?;
// Insert into history
sqlx::query(
"INSERT INTO device_status_history (device_uid, cpu_usage, memory_usage, disk_usage, network_rx_rate, network_tx_rate, running_procs, reported_at)
VALUES (?, ?, ?, ?, ?, ?, ?, datetime('now'))"
)
.bind(device_uid)
.bind(status.cpu_usage)
.bind(status.memory_usage)
.bind(status.disk_usage)
.bind(status.network_rx_rate as i64)
.bind(status.network_tx_rate as i64)
.bind(status.running_procs as i32)
.execute(pool)
.await?;
// Update device heartbeat
sqlx::query(
"UPDATE devices SET status = 'online', last_heartbeat = datetime('now') WHERE device_uid = ?"
)
.bind(device_uid)
.execute(pool)
.await?;
Ok(())
}
pub async fn insert_usb_event(pool: &SqlitePool, event: &csm_protocol::UsbEvent) -> Result<i64> {
let result = sqlx::query(
"INSERT INTO usb_events (device_uid, vendor_id, product_id, serial_number, device_name, event_type)
VALUES (?, ?, ?, ?, ?, ?)"
)
.bind(&event.device_uid)
.bind(&event.vendor_id)
.bind(&event.product_id)
.bind(&event.serial)
.bind(&event.device_name)
.bind(match event.event_type {
csm_protocol::UsbEventType::Inserted => "inserted",
csm_protocol::UsbEventType::Removed => "removed",
csm_protocol::UsbEventType::Blocked => "blocked",
})
.execute(pool)
.await?;
Ok(result.last_insert_rowid())
}
pub async fn upsert_hardware(pool: &SqlitePool, asset: &csm_protocol::HardwareAsset) -> Result<()> {
sqlx::query(
"INSERT INTO hardware_assets (device_uid, cpu_model, cpu_cores, memory_total_mb, disk_model, disk_total_mb, gpu_model, motherboard, serial_number, reported_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, datetime('now'))
ON CONFLICT(device_uid) DO UPDATE SET
cpu_model = excluded.cpu_model,
cpu_cores = excluded.cpu_cores,
memory_total_mb = excluded.memory_total_mb,
disk_model = excluded.disk_model,
disk_total_mb = excluded.disk_total_mb,
gpu_model = excluded.gpu_model,
motherboard = excluded.motherboard,
serial_number = excluded.serial_number,
reported_at = datetime('now'),
updated_at = datetime('now')"
)
.bind(&asset.device_uid)
.bind(&asset.cpu_model)
.bind(asset.cpu_cores as i32)
.bind(asset.memory_total_mb as i64)
.bind(&asset.disk_model)
.bind(asset.disk_total_mb as i64)
.bind(&asset.gpu_model)
.bind(&asset.motherboard)
.bind(&asset.serial_number)
.execute(pool)
.await?;
Ok(())
}
}

264
crates/server/src/main.rs Normal file
View File

@@ -0,0 +1,264 @@
use anyhow::Result;
use axum::Router;
use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions, SqliteJournalMode};
use std::path::Path;
use std::str::FromStr;
use std::sync::Arc;
use tokio::net::TcpListener;
use tower_http::cors::{CorsLayer, Any};
use tower_http::trace::TraceLayer;
use tower_http::compression::CompressionLayer;
use tower_http::set_header::SetResponseHeaderLayer;
use tracing::{info, warn, error};
mod api;
mod audit;
mod config;
mod db;
mod tcp;
mod ws;
mod alert;
use config::AppConfig;
/// Application shared state
#[derive(Clone)]
pub struct AppState {
pub db: sqlx::SqlitePool,
pub config: Arc<AppConfig>,
pub clients: Arc<tcp::ClientRegistry>,
pub ws_hub: Arc<ws::WsHub>,
pub login_limiter: Arc<api::auth::LoginRateLimiter>,
}
#[tokio::main]
async fn main() -> Result<()> {
// Initialize tracing
tracing_subscriber::fmt()
.with_env_filter(
tracing_subscriber::EnvFilter::try_from_default_env()
.unwrap_or_else(|_| "csm_server=info,tower_http=info".into()),
)
.json()
.init();
info!("CSM Server starting...");
// Load configuration
let config = AppConfig::load("config.toml").await?;
let config = Arc::new(config);
// Initialize database
let db = init_database(&config.database.path).await?;
run_migrations(&db).await?;
info!("Database initialized at {}", config.database.path);
// Ensure default admin exists
ensure_default_admin(&db).await?;
// Initialize shared state
let clients = Arc::new(tcp::ClientRegistry::new());
let ws_hub = Arc::new(ws::WsHub::new());
let state = AppState {
db: db.clone(),
config: config.clone(),
clients: clients.clone(),
ws_hub: ws_hub.clone(),
login_limiter: Arc::new(api::auth::LoginRateLimiter::new()),
};
// Start background tasks
let cleanup_state = state.clone();
tokio::spawn(async move {
alert::cleanup_task(cleanup_state).await;
});
// Start TCP listener for client connections
let tcp_state = state.clone();
let tcp_addr = config.server.tcp_addr.clone();
tokio::spawn(async move {
if let Err(e) = tcp::start_tcp_server(tcp_addr, tcp_state).await {
error!("TCP server error: {}", e);
}
});
// Build HTTP router
let app = Router::new()
.merge(api::routes(state.clone()))
.layer(
build_cors_layer(&config.server.cors_origins),
)
.layer(TraceLayer::new_for_http())
.layer(CompressionLayer::new())
// Security headers
.layer(SetResponseHeaderLayer::if_not_present(
axum::http::header::X_CONTENT_TYPE_OPTIONS,
axum::http::HeaderValue::from_static("nosniff"),
))
.layer(SetResponseHeaderLayer::if_not_present(
axum::http::header::X_FRAME_OPTIONS,
axum::http::HeaderValue::from_static("DENY"),
))
.layer(SetResponseHeaderLayer::if_not_present(
axum::http::header::HeaderName::from_static("x-xss-protection"),
axum::http::HeaderValue::from_static("1; mode=block"),
))
.layer(SetResponseHeaderLayer::if_not_present(
axum::http::header::HeaderName::from_static("referrer-policy"),
axum::http::HeaderValue::from_static("strict-origin-when-cross-origin"),
))
.layer(SetResponseHeaderLayer::if_not_present(
axum::http::header::HeaderName::from_static("content-security-policy"),
axum::http::HeaderValue::from_static("default-src 'self'; script-src 'self' 'unsafe-inline' 'unsafe-eval'; style-src 'self' 'unsafe-inline'; img-src 'self' data:; connect-src 'self' ws: wss:"),
))
.with_state(state);
// Start HTTP server
let http_addr = &config.server.http_addr;
info!("HTTP server listening on {}", http_addr);
let listener = TcpListener::bind(http_addr).await?;
axum::serve(listener, app).await?;
Ok(())
}
async fn init_database(db_path: &str) -> Result<sqlx::SqlitePool> {
// Ensure parent directory exists for file-based databases
// Strip sqlite: prefix if present for directory creation
let file_path = db_path.strip_prefix("sqlite:").unwrap_or(db_path);
// Strip query parameters
let file_path = file_path.split('?').next().unwrap_or(file_path);
if let Some(parent) = Path::new(file_path).parent() {
if !parent.as_os_str().is_empty() {
tokio::fs::create_dir_all(parent).await?;
}
}
let options = SqliteConnectOptions::from_str(db_path)?
.journal_mode(SqliteJournalMode::Wal)
.synchronous(sqlx::sqlite::SqliteSynchronous::Normal)
.busy_timeout(std::time::Duration::from_secs(5))
.foreign_keys(true);
let pool = SqlitePoolOptions::new()
.max_connections(8)
.connect_with(options)
.await?;
// Set pragmas on each connection
sqlx::query("PRAGMA cache_size = -64000")
.execute(&pool)
.await?;
sqlx::query("PRAGMA wal_autocheckpoint = 1000")
.execute(&pool)
.await?;
Ok(pool)
}
async fn run_migrations(pool: &sqlx::SqlitePool) -> Result<()> {
// Embedded migrations - run in order
let migrations = [
include_str!("../../../migrations/001_init.sql"),
include_str!("../../../migrations/002_assets.sql"),
include_str!("../../../migrations/003_usb.sql"),
include_str!("../../../migrations/004_alerts.sql"),
include_str!("../../../migrations/005_plugins_web_filter.sql"),
include_str!("../../../migrations/006_plugins_usage_timer.sql"),
include_str!("../../../migrations/007_plugins_software_blocker.sql"),
include_str!("../../../migrations/008_plugins_popup_blocker.sql"),
include_str!("../../../migrations/009_plugins_usb_file_audit.sql"),
include_str!("../../../migrations/010_plugins_watermark.sql"),
include_str!("../../../migrations/011_token_security.sql"),
];
// Create migrations tracking table
sqlx::query(
"CREATE TABLE IF NOT EXISTS _migrations (id INTEGER PRIMARY KEY, name TEXT NOT NULL UNIQUE, applied_at TEXT NOT NULL DEFAULT (datetime('now')))"
)
.execute(pool)
.await?;
for (i, migration_sql) in migrations.iter().enumerate() {
let name = format!("{:03}", i + 1);
let exists: bool = sqlx::query_scalar::<_, i64>(
"SELECT COUNT(*) FROM _migrations WHERE name = ?"
)
.bind(&name)
.fetch_one(pool)
.await? > 0;
if !exists {
info!("Running migration: {}", name);
sqlx::query(migration_sql)
.execute(pool)
.await?;
sqlx::query("INSERT INTO _migrations (name) VALUES (?)")
.bind(&name)
.execute(pool)
.await?;
}
}
Ok(())
}
async fn ensure_default_admin(pool: &sqlx::SqlitePool) -> Result<()> {
let count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM users")
.fetch_one(pool)
.await?;
if count == 0 {
// Generate a random 16-character alphanumeric password
let random_password: String = {
use std::fmt::Write;
let bytes = uuid::Uuid::new_v4();
let mut s = String::with_capacity(16);
for byte in bytes.as_bytes().iter().take(16) {
write!(s, "{:02x}", byte).unwrap();
}
s
};
let password_hash = bcrypt::hash(&random_password, 12)?;
sqlx::query(
"INSERT INTO users (username, password, role) VALUES (?, ?, 'admin')"
)
.bind("admin")
.bind(&password_hash)
.execute(pool)
.await?;
warn!("Created default admin user (username: admin)");
// Print password directly to stderr — bypasses tracing JSON formatter
eprintln!("============================================================");
eprintln!(" Generated admin password: {}", random_password);
eprintln!(" *** Save this password now — it will NOT be shown again! ***");
eprintln!("============================================================");
}
Ok(())
}
/// Build CORS layer from configured origins.
/// If cors_origins is empty, no CORS headers are sent (same-origin only).
/// If origins are specified, only those are allowed.
fn build_cors_layer(origins: &[String]) -> CorsLayer {
use axum::http::HeaderValue;
let allowed_origins: Vec<HeaderValue> = origins.iter()
.filter_map(|o| o.parse::<HeaderValue>().ok())
.collect();
if allowed_origins.is_empty() {
// No CORS — production safe by default
CorsLayer::new()
} else {
CorsLayer::new()
.allow_origin(tower_http::cors::AllowOrigin::list(allowed_origins))
.allow_methods(Any)
.allow_headers(Any)
.max_age(std::time::Duration::from_secs(3600))
}
}

844
crates/server/src/tcp.rs Normal file
View File

@@ -0,0 +1,844 @@
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::RwLock;
use tokio::net::{TcpListener, TcpStream};
use tracing::{info, warn, debug};
use hmac::{Hmac, Mac};
use sha2::Sha256;
use csm_protocol::{Frame, MessageType, PROTOCOL_VERSION};
use crate::AppState;
/// Maximum frames per second per connection before rate-limiting kicks in
const RATE_LIMIT_WINDOW_SECS: u64 = 5;
const RATE_LIMIT_MAX_FRAMES: usize = 100;
/// Per-connection rate limiter using a sliding window of frame timestamps
struct RateLimiter {
timestamps: Vec<Instant>,
}
impl RateLimiter {
fn new() -> Self {
Self { timestamps: Vec::with_capacity(RATE_LIMIT_MAX_FRAMES) }
}
/// Returns false if the connection is rate-limited
fn check(&mut self) -> bool {
let now = Instant::now();
let cutoff = now - std::time::Duration::from_secs(RATE_LIMIT_WINDOW_SECS);
// Evict timestamps outside the window
self.timestamps.retain(|t| *t > cutoff);
if self.timestamps.len() >= RATE_LIMIT_MAX_FRAMES {
return false;
}
self.timestamps.push(now);
true
}
}
/// Push a plugin config frame to all online clients matching the target scope.
/// target_type: "global" | "device" | "group"
/// target_id: device_uid or group_name (None for global)
pub async fn push_to_targets(
db: &sqlx::SqlitePool,
clients: &crate::tcp::ClientRegistry,
msg_type: MessageType,
payload: &impl serde::Serialize,
target_type: &str,
target_id: Option<&str>,
) {
let frame = match Frame::new_json(msg_type, payload) {
Ok(f) => f.encode(),
Err(e) => {
warn!("Failed to encode plugin push frame: {}", e);
return;
}
};
let online = clients.list_online().await;
let mut pushed_count = 0usize;
// For group targeting, resolve group members from DB once
let group_members: Option<Vec<String>> = if target_type == "group" {
if let Some(group_name) = target_id {
sqlx::query_scalar::<_, String>(
"SELECT device_uid FROM devices WHERE group_name = ?"
)
.bind(group_name)
.fetch_all(db)
.await
.ok()
.into()
} else {
None
}
} else {
None
};
for uid in &online {
let should_push = match target_type {
"global" => true,
"device" => target_id.map_or(false, |id| id == uid),
"group" => {
if let Some(members) = &group_members {
members.contains(uid)
} else {
false
}
}
other => {
warn!("Unknown target_type '{}', skipping push", other);
false
}
};
if should_push {
if clients.send_to(uid, frame.clone()).await {
pushed_count += 1;
}
}
}
debug!("Pushed {:?} to {}/{} online clients (target={})", msg_type, pushed_count, online.len(), target_type);
}
/// Push all active plugin configs to a newly registered client.
pub async fn push_all_plugin_configs(
db: &sqlx::SqlitePool,
clients: &crate::tcp::ClientRegistry,
device_uid: &str,
) {
use sqlx::Row;
// Watermark configs — only push the highest-priority enabled config (device > group > global)
if let Ok(rows) = sqlx::query(
"SELECT content, font_size, opacity, color, angle, enabled FROM watermark_config WHERE enabled = 1 AND (target_type = 'global' OR (target_type = 'device' AND target_id = ?) OR (target_type = 'group' AND target_id = (SELECT group_name FROM devices WHERE device_uid = ?))) ORDER BY CASE WHEN target_type = 'device' THEN 0 WHEN target_type = 'group' THEN 1 ELSE 2 END LIMIT 1"
)
.bind(device_uid)
.bind(device_uid)
.bind(device_uid)
.fetch_all(db).await
{
if let Some(row) = rows.first() {
let config = csm_protocol::WatermarkConfigPayload {
content: row.get("content"),
font_size: row.get::<i32, _>("font_size") as u32,
opacity: row.get("opacity"),
color: row.get("color"),
angle: row.get::<i32, _>("angle"),
enabled: row.get("enabled"),
};
if let Ok(frame) = Frame::new_json(MessageType::WatermarkConfig, &config) {
clients.send_to(device_uid, frame.encode()).await;
}
}
}
// Web filter rules
if let Ok(rows) = sqlx::query(
"SELECT id, rule_type, pattern FROM web_filter_rules WHERE enabled = 1 AND (target_type = 'global' OR (target_type = 'device' AND target_id = ?) OR (target_type = 'group' AND target_id = (SELECT group_name FROM devices WHERE device_uid = ?)))"
)
.bind(device_uid)
.bind(device_uid)
.fetch_all(db).await
{
let rules: Vec<serde_json::Value> = rows.iter().map(|r| serde_json::json!({
"id": r.get::<i64, _>("id"),
"rule_type": r.get::<String, _>("rule_type"),
"pattern": r.get::<String, _>("pattern"),
})).collect();
if !rules.is_empty() {
if let Ok(frame) = Frame::new_json(MessageType::WebFilterRuleUpdate, &serde_json::json!({"rules": rules})) {
clients.send_to(device_uid, frame.encode()).await;
}
}
}
// Software blacklist
if let Ok(rows) = sqlx::query(
"SELECT id, name_pattern, action FROM software_blacklist WHERE enabled = 1 AND (target_type = 'global' OR (target_type = 'device' AND target_id = ?) OR (target_type = 'group' AND target_id = (SELECT group_name FROM devices WHERE device_uid = ?)))"
)
.bind(device_uid)
.bind(device_uid)
.fetch_all(db).await
{
let entries: Vec<serde_json::Value> = rows.iter().map(|r| serde_json::json!({
"id": r.get::<i64, _>("id"),
"name_pattern": r.get::<String, _>("name_pattern"),
"action": r.get::<String, _>("action"),
})).collect();
if !entries.is_empty() {
if let Ok(frame) = Frame::new_json(MessageType::SoftwareBlacklist, &serde_json::json!({"blacklist": entries})) {
clients.send_to(device_uid, frame.encode()).await;
}
}
}
// Popup blocker rules
if let Ok(rows) = sqlx::query(
"SELECT id, rule_type, window_title, window_class, process_name FROM popup_filter_rules WHERE enabled = 1 AND (target_type = 'global' OR (target_type = 'device' AND target_id = ?) OR (target_type = 'group' AND target_id = (SELECT group_name FROM devices WHERE device_uid = ?)))"
)
.bind(device_uid)
.bind(device_uid)
.fetch_all(db).await
{
let rules: Vec<serde_json::Value> = rows.iter().map(|r| serde_json::json!({
"id": r.get::<i64, _>("id"),
"rule_type": r.get::<String, _>("rule_type"),
"window_title": r.get::<Option<String>, _>("window_title"),
"window_class": r.get::<Option<String>, _>("window_class"),
"process_name": r.get::<Option<String>, _>("process_name"),
})).collect();
if !rules.is_empty() {
if let Ok(frame) = Frame::new_json(MessageType::PopupRules, &serde_json::json!({"rules": rules})) {
clients.send_to(device_uid, frame.encode()).await;
}
}
}
// USB policies — push highest-priority enabled policy for the device's group
if let Ok(rows) = sqlx::query(
"SELECT policy_type, rules, enabled FROM usb_policies WHERE enabled = 1 AND target_group = (SELECT group_name FROM devices WHERE device_uid = ?) ORDER BY CASE WHEN policy_type = 'all_block' THEN 0 WHEN policy_type = 'blacklist' THEN 1 ELSE 2 END LIMIT 1"
)
.bind(device_uid)
.fetch_all(db).await
{
if let Some(row) = rows.first() {
let policy_type: String = row.get("policy_type");
let rules_json: String = row.get("rules");
let rules: Vec<serde_json::Value> = serde_json::from_str(&rules_json).unwrap_or_default();
let payload = csm_protocol::UsbPolicyPayload {
policy_type,
enabled: true,
rules: rules.iter().map(|r| csm_protocol::UsbDeviceRule {
vendor_id: r.get("vendor_id").and_then(|v| v.as_str().map(String::from)),
product_id: r.get("product_id").and_then(|v| v.as_str().map(String::from)),
serial: r.get("serial").and_then(|v| v.as_str().map(String::from)),
device_name: r.get("device_name").and_then(|v| v.as_str().map(String::from)),
}).collect(),
};
if let Ok(frame) = Frame::new_json(MessageType::UsbPolicyUpdate, &payload) {
clients.send_to(device_uid, frame.encode()).await;
}
}
}
info!("Pushed all plugin configs to newly registered device {}", device_uid);
}
/// Maximum accumulated read buffer size per connection (8 MB)
const MAX_READ_BUF_SIZE: usize = 8 * 1024 * 1024;
/// Registry of all connected client sessions
#[derive(Clone, Default)]
pub struct ClientRegistry {
sessions: Arc<RwLock<HashMap<String, Arc<tokio::sync::mpsc::Sender<Vec<u8>>>>>>,
}
impl ClientRegistry {
pub fn new() -> Self {
Self::default()
}
pub async fn register(&self, device_uid: String, tx: Arc<tokio::sync::mpsc::Sender<Vec<u8>>>) {
self.sessions.write().await.insert(device_uid, tx);
}
pub async fn unregister(&self, device_uid: &str) {
self.sessions.write().await.remove(device_uid);
}
pub async fn send_to(&self, device_uid: &str, data: Vec<u8>) -> bool {
if let Some(tx) = self.sessions.read().await.get(device_uid) {
tx.send(data).await.is_ok()
} else {
false
}
}
pub async fn count(&self) -> usize {
self.sessions.read().await.len()
}
pub async fn list_online(&self) -> Vec<String> {
self.sessions.read().await.keys().cloned().collect()
}
}
/// Start the TCP server for client connections (optionally with TLS)
pub async fn start_tcp_server(addr: String, state: AppState) -> anyhow::Result<()> {
let listener = TcpListener::bind(&addr).await?;
// Build TLS acceptor if configured
let tls_acceptor = build_tls_acceptor(&state.config.server.tls)?;
if tls_acceptor.is_some() {
info!("TCP server listening on {} (TLS enabled)", addr);
} else {
info!("TCP server listening on {} (plaintext)", addr);
}
loop {
let (stream, peer_addr) = listener.accept().await?;
let state = state.clone();
let acceptor = tls_acceptor.clone();
tokio::spawn(async move {
debug!("New TCP connection from {}", peer_addr);
match acceptor {
Some(acceptor) => {
match acceptor.accept(stream).await {
Ok(tls_stream) => {
if let Err(e) = handle_client_tls(tls_stream, state).await {
warn!("Client {} TLS error: {}", peer_addr, e);
}
}
Err(e) => warn!("TLS handshake failed for {}: {}", peer_addr, e),
}
}
None => {
if let Err(e) = handle_client(stream, state).await {
warn!("Client {} error: {}", peer_addr, e);
}
}
}
});
}
}
fn build_tls_acceptor(
tls_config: &Option<crate::config::TlsConfig>,
) -> anyhow::Result<Option<tokio_rustls::TlsAcceptor>> {
let config = match tls_config {
Some(c) => c,
None => return Ok(None),
};
let cert_pem = std::fs::read(&config.cert_path)
.map_err(|e| anyhow::anyhow!("Failed to read TLS cert {}: {}", config.cert_path, e))?;
let key_pem = std::fs::read(&config.key_path)
.map_err(|e| anyhow::anyhow!("Failed to read TLS key {}: {}", config.key_path, e))?;
let certs: Vec<rustls_pki_types::CertificateDer> = rustls_pemfile::certs(&mut &cert_pem[..])
.collect::<Result<Vec<_>, _>>()
.map_err(|e| anyhow::anyhow!("Failed to parse TLS cert: {:?}", e))?
.into_iter()
.map(|c| c.into())
.collect();
let key = rustls_pemfile::private_key(&mut &key_pem[..])
.map_err(|e| anyhow::anyhow!("Failed to parse TLS key: {:?}", e))?
.ok_or_else(|| anyhow::anyhow!("No private key found in {}", config.key_path))?;
let server_config = rustls::ServerConfig::builder()
.with_no_client_auth()
.with_single_cert(certs, key)
.map_err(|e| anyhow::anyhow!("Failed to build TLS config: {}", e))?;
Ok(Some(tokio_rustls::TlsAcceptor::from(Arc::new(server_config))))
}
/// Cleanup on client disconnect: unregister from client map, mark offline, notify WS.
async fn cleanup_on_disconnect(state: &AppState, device_uid: &Option<String>) {
if let Some(uid) = device_uid {
state.clients.unregister(uid).await;
sqlx::query("UPDATE devices SET status = 'offline' WHERE device_uid = ?")
.bind(uid)
.execute(&state.db)
.await
.ok();
state.ws_hub.broadcast(serde_json::json!({
"type": "device_state",
"device_uid": uid,
"status": "offline"
}).to_string()).await;
info!("Device disconnected: {}", uid);
}
}
/// Compute HMAC-SHA256 for heartbeat verification.
/// Format: HMAC-SHA256(device_secret, "{device_uid}\n{timestamp}") → hex-encoded
fn compute_hmac(secret: &str, device_uid: &str, timestamp: &str) -> String {
type HmacSha256 = Hmac<Sha256>;
let message = format!("{}\n{}", device_uid, timestamp);
let mut mac = match HmacSha256::new_from_slice(secret.as_bytes()) {
Ok(m) => m,
Err(_) => return String::new(),
};
mac.update(message.as_bytes());
hex::encode(mac.finalize().into_bytes())
}
/// Verify that a frame sender is a registered device and that the claimed device_uid
/// matches the one registered on this connection. Returns true if valid.
fn verify_device_uid(device_uid: &Option<String>, msg_type: &str, claimed_uid: &str) -> bool {
match device_uid {
Some(uid) if *uid == claimed_uid => true,
Some(uid) => {
warn!("{} device_uid mismatch: expected {:?}, got {}", msg_type, uid, claimed_uid);
false
}
None => {
warn!("{} from unregistered connection", msg_type);
false
}
}
}
/// Process a single decoded frame. Shared by both plaintext and TLS handlers.
async fn process_frame(
frame: Frame,
state: &AppState,
device_uid: &mut Option<String>,
tx: &Arc<tokio::sync::mpsc::Sender<Vec<u8>>>,
) -> anyhow::Result<()> {
match frame.msg_type {
MessageType::Register => {
let req: csm_protocol::RegisterRequest = frame.decode_payload()
.map_err(|e| anyhow::anyhow!("Invalid registration payload: {}", e))?;
info!("Device registration attempt: {} ({})", req.hostname, req.device_uid);
// Validate registration token against configured token
let expected_token = &state.config.registration_token;
if !expected_token.is_empty() {
if req.registration_token.is_empty() || req.registration_token != *expected_token {
warn!("Registration rejected for {}: invalid token", req.device_uid);
let err_frame = Frame::new_json(MessageType::RegisterResponse,
&serde_json::json!({"error": "invalid_registration_token"}))?;
tx.send(err_frame.encode()).await.ok();
return Ok(());
}
}
// Check if device already exists with a secret (reconnection scenario)
let existing_secret: Option<String> = sqlx::query_scalar(
"SELECT device_secret FROM devices WHERE device_uid = ?"
)
.bind(&req.device_uid)
.fetch_optional(&state.db)
.await
.ok()
.flatten();
let device_secret = match existing_secret {
// Existing device — keep the same secret, don't rotate
Some(secret) if !secret.is_empty() => secret,
// New device — generate a fresh secret
_ => uuid::Uuid::new_v4().to_string(),
};
sqlx::query(
"INSERT INTO devices (device_uid, hostname, ip_address, mac_address, os_version, device_secret, status) \
VALUES (?, ?, '0.0.0.0', ?, ?, ?, 'online') \
ON CONFLICT(device_uid) DO UPDATE SET hostname=excluded.hostname, os_version=excluded.os_version, \
mac_address=excluded.mac_address, status='online'"
)
.bind(&req.device_uid)
.bind(&req.hostname)
.bind(&req.mac_address)
.bind(&req.os_version)
.bind(&device_secret)
.execute(&state.db)
.await
.map_err(|e| anyhow::anyhow!("DB error during registration: {}", e))?;
*device_uid = Some(req.device_uid.clone());
// If this device was already connected on a different session, evict the old one
// The new register() call will replace it in the hashmap
state.clients.register(req.device_uid.clone(), tx.clone()).await;
// Send registration response
let config = csm_protocol::ClientConfig::default();
let response = csm_protocol::RegisterResponse {
device_secret,
config,
};
let resp_frame = Frame::new_json(MessageType::RegisterResponse, &response)?;
tx.send(resp_frame.encode()).await?;
info!("Device registered successfully: {} ({})", req.hostname, req.device_uid);
// Push all active plugin configs to newly registered client
push_all_plugin_configs(&state.db, &state.clients, &req.device_uid).await;
}
MessageType::Heartbeat => {
let heartbeat: csm_protocol::HeartbeatPayload = frame.decode_payload()
.map_err(|e| anyhow::anyhow!("Invalid heartbeat: {}", e))?;
if !verify_device_uid(device_uid, "Heartbeat", &heartbeat.device_uid) {
return Ok(());
}
// Verify HMAC — reject if secret exists but HMAC is missing or wrong
let secret: Option<String> = sqlx::query_scalar(
"SELECT device_secret FROM devices WHERE device_uid = ?"
)
.bind(&heartbeat.device_uid)
.fetch_optional(&state.db)
.await
.map_err(|e| {
warn!("DB error fetching device_secret for {}: {}", heartbeat.device_uid, e);
anyhow::anyhow!("DB error during HMAC verification")
})?;
if let Some(ref secret) = secret {
if !secret.is_empty() {
if heartbeat.hmac.is_empty() {
warn!("Heartbeat missing HMAC for device {}", heartbeat.device_uid);
return Ok(());
}
// Constant-time HMAC verification using hmac::Mac::verify_slice
let message = format!("{}\n{}", heartbeat.device_uid, heartbeat.timestamp);
let mut mac = Hmac::<Sha256>::new_from_slice(secret.as_bytes())
.map_err(|_| anyhow::anyhow!("HMAC key error"))?;
mac.update(message.as_bytes());
let provided_bytes = hex::decode(&heartbeat.hmac).unwrap_or_default();
if mac.verify_slice(&provided_bytes).is_err() {
warn!("Heartbeat HMAC mismatch for device {}", heartbeat.device_uid);
return Ok(());
}
}
}
debug!("Heartbeat from {} (hmac verified)", heartbeat.device_uid);
// Update device status in DB
sqlx::query("UPDATE devices SET status = 'online', last_heartbeat = datetime('now') WHERE device_uid = ?")
.bind(&heartbeat.device_uid)
.execute(&state.db)
.await
.ok();
// Push to WebSocket subscribers
state.ws_hub.broadcast(serde_json::json!({
"type": "device_state",
"device_uid": heartbeat.device_uid,
"status": "online"
}).to_string()).await;
}
MessageType::StatusReport => {
let status: csm_protocol::DeviceStatus = frame.decode_payload()
.map_err(|e| anyhow::anyhow!("Invalid status report: {}", e))?;
if !verify_device_uid(device_uid, "StatusReport", &status.device_uid) {
return Ok(());
}
crate::db::DeviceRepo::upsert_status(&state.db, &status.device_uid, &status).await?;
// Push to WebSocket subscribers
state.ws_hub.broadcast(serde_json::json!({
"type": "device_status",
"device_uid": status.device_uid,
"cpu": status.cpu_usage,
"memory": status.memory_usage,
"disk": status.disk_usage
}).to_string()).await;
}
MessageType::UsbEvent => {
let event: csm_protocol::UsbEvent = frame.decode_payload()
.map_err(|e| anyhow::anyhow!("Invalid USB event: {}", e))?;
if !verify_device_uid(device_uid, "UsbEvent", &event.device_uid) {
return Ok(());
}
crate::db::DeviceRepo::insert_usb_event(&state.db, &event).await?;
state.ws_hub.broadcast(serde_json::json!({
"type": "usb_event",
"device_uid": event.device_uid,
"event": event.event_type,
"usb_name": event.device_name
}).to_string()).await;
}
MessageType::AssetReport => {
let asset: csm_protocol::HardwareAsset = frame.decode_payload()
.map_err(|e| anyhow::anyhow!("Invalid asset report: {}", e))?;
if !verify_device_uid(device_uid, "AssetReport", &asset.device_uid) {
return Ok(());
}
crate::db::DeviceRepo::upsert_hardware(&state.db, &asset).await?;
}
MessageType::UsageReport => {
let report: csm_protocol::UsageDailyReport = frame.decode_payload()
.map_err(|e| anyhow::anyhow!("Invalid usage report: {}", e))?;
if !verify_device_uid(device_uid, "UsageReport", &report.device_uid) {
return Ok(());
}
sqlx::query(
"INSERT INTO usage_daily (device_uid, date, total_active_minutes, total_idle_minutes, first_active_at, last_active_at) \
VALUES (?, ?, ?, ?, ?, ?) \
ON CONFLICT(device_uid, date) DO UPDATE SET \
total_active_minutes = excluded.total_active_minutes, \
total_idle_minutes = excluded.total_idle_minutes, \
first_active_at = excluded.first_active_at, \
last_active_at = excluded.last_active_at"
)
.bind(&report.device_uid)
.bind(&report.date)
.bind(report.total_active_minutes as i32)
.bind(report.total_idle_minutes as i32)
.bind(&report.first_active_at)
.bind(&report.last_active_at)
.execute(&state.db)
.await
.map_err(|e| anyhow::anyhow!("DB error inserting usage report: {}", e))?;
debug!("Usage report saved for device {}", report.device_uid);
}
MessageType::AppUsageReport => {
let report: csm_protocol::AppUsageEntry = frame.decode_payload()
.map_err(|e| anyhow::anyhow!("Invalid app usage report: {}", e))?;
if !verify_device_uid(device_uid, "AppUsageReport", &report.device_uid) {
return Ok(());
}
sqlx::query(
"INSERT INTO app_usage_daily (device_uid, date, app_name, usage_minutes) \
VALUES (?, ?, ?, ?) \
ON CONFLICT(device_uid, date, app_name) DO UPDATE SET \
usage_minutes = MAX(usage_minutes, excluded.usage_minutes)"
)
.bind(&report.device_uid)
.bind(&report.date)
.bind(&report.app_name)
.bind(report.usage_minutes as i32)
.execute(&state.db)
.await
.map_err(|e| anyhow::anyhow!("DB error inserting app usage: {}", e))?;
debug!("App usage saved: {} -> {} ({} min)", report.device_uid, report.app_name, report.usage_minutes);
}
MessageType::SoftwareViolation => {
let report: csm_protocol::SoftwareViolationReport = frame.decode_payload()
.map_err(|e| anyhow::anyhow!("Invalid software violation: {}", e))?;
if !verify_device_uid(device_uid, "SoftwareViolation", &report.device_uid) {
return Ok(());
}
sqlx::query(
"INSERT INTO software_violations (device_uid, software_name, action_taken, timestamp) VALUES (?, ?, ?, ?)"
)
.bind(&report.device_uid)
.bind(&report.software_name)
.bind(&report.action_taken)
.bind(&report.timestamp)
.execute(&state.db)
.await
.map_err(|e| anyhow::anyhow!("DB error inserting software violation: {}", e))?;
info!("Software violation: {} tried to run {} -> {}", report.device_uid, report.software_name, report.action_taken);
state.ws_hub.broadcast(serde_json::json!({
"type": "software_violation",
"device_uid": report.device_uid,
"software_name": report.software_name,
"action_taken": report.action_taken
}).to_string()).await;
}
MessageType::UsbFileOp => {
let entry: csm_protocol::UsbFileOpEntry = frame.decode_payload()
.map_err(|e| anyhow::anyhow!("Invalid USB file op: {}", e))?;
if !verify_device_uid(device_uid, "UsbFileOp", &entry.device_uid) {
return Ok(());
}
sqlx::query(
"INSERT INTO usb_file_operations (device_uid, usb_serial, drive_letter, operation, file_path, file_size, timestamp) VALUES (?, ?, ?, ?, ?, ?, ?)"
)
.bind(&entry.device_uid)
.bind(&entry.usb_serial)
.bind(&entry.drive_letter)
.bind(&entry.operation)
.bind(&entry.file_path)
.bind(entry.file_size.map(|s| s as i64))
.bind(&entry.timestamp)
.execute(&state.db)
.await
.map_err(|e| anyhow::anyhow!("DB error inserting USB file op: {}", e))?;
debug!("USB file op: {} {} on {}", entry.operation, entry.file_path, entry.device_uid);
}
MessageType::WebAccessLog => {
let entry: csm_protocol::WebAccessLogEntry = frame.decode_payload()
.map_err(|e| anyhow::anyhow!("Invalid web access log: {}", e))?;
if !verify_device_uid(device_uid, "WebAccessLog", &entry.device_uid) {
return Ok(());
}
sqlx::query(
"INSERT INTO web_access_log (device_uid, url, action, timestamp) VALUES (?, ?, ?, ?)"
)
.bind(&entry.device_uid)
.bind(&entry.url)
.bind(&entry.action)
.bind(&entry.timestamp)
.execute(&state.db)
.await
.map_err(|e| anyhow::anyhow!("DB error inserting web access log: {}", e))?;
debug!("Web access log: {} {} {}", entry.device_uid, entry.action, entry.url);
}
_ => {
debug!("Unhandled message type: {:?}", frame.msg_type);
}
}
Ok(())
}
/// Handle a single client TCP connection
async fn handle_client(stream: TcpStream, state: AppState) -> anyhow::Result<()> {
use tokio::io::{AsyncReadExt, AsyncWriteExt};
// Set read timeout to detect stale connections
let _ = stream.set_nodelay(true);
let (mut reader, mut writer) = stream.into_split();
let (tx, mut rx) = tokio::sync::mpsc::channel::<Vec<u8>>(256);
let tx = Arc::new(tx);
let mut buffer = vec![0u8; 65536];
let mut read_buf = Vec::with_capacity(65536);
let mut device_uid: Option<String> = None;
let mut rate_limiter = RateLimiter::new();
// Writer task: forwards messages from channel to TCP stream
let write_task = tokio::spawn(async move {
while let Some(data) = rx.recv().await {
if writer.write_all(&data).await.is_err() {
break;
}
}
});
// Reader loop
'reader: loop {
let n = reader.read(&mut buffer).await?;
if n == 0 {
break; // Connection closed
}
read_buf.extend_from_slice(&buffer[..n]);
// Guard against unbounded buffer growth
if read_buf.len() > MAX_READ_BUF_SIZE {
warn!("Connection exceeded max buffer size, dropping");
break;
}
// Process complete frames
while let Some(frame) = Frame::decode(&read_buf)? {
let frame_size = frame.encoded_size();
// Remove consumed bytes without reallocating
read_buf.drain(..frame_size);
// Rate limit check
if !rate_limiter.check() {
warn!("Rate limit exceeded for device {:?}, dropping connection", device_uid);
break 'reader;
}
// Verify protocol version
if frame.version != PROTOCOL_VERSION {
warn!("Unsupported protocol version: 0x{:02X}", frame.version);
continue;
}
if let Err(e) = process_frame(frame, &state, &mut device_uid, &tx).await {
warn!("Frame processing error: {}", e);
}
}
}
cleanup_on_disconnect(&state, &device_uid).await;
write_task.abort();
Ok(())
}
/// Handle a TLS-wrapped client connection
async fn handle_client_tls(
stream: tokio_rustls::server::TlsStream<TcpStream>,
state: AppState,
) -> anyhow::Result<()> {
use tokio::io::{AsyncReadExt, AsyncWriteExt};
let (mut reader, mut writer) = tokio::io::split(stream);
let (tx, mut rx) = tokio::sync::mpsc::channel::<Vec<u8>>(256);
let tx = Arc::new(tx);
let mut buffer = vec![0u8; 65536];
let mut read_buf = Vec::with_capacity(65536);
let mut device_uid: Option<String> = None;
let mut rate_limiter = RateLimiter::new();
let write_task = tokio::spawn(async move {
while let Some(data) = rx.recv().await {
if writer.write_all(&data).await.is_err() {
break;
}
}
});
// Reader loop — same logic as plaintext handler
'reader: loop {
let n = reader.read(&mut buffer).await?;
if n == 0 {
break;
}
read_buf.extend_from_slice(&buffer[..n]);
if read_buf.len() > MAX_READ_BUF_SIZE {
warn!("TLS connection exceeded max buffer size, dropping");
break;
}
while let Some(frame) = Frame::decode(&read_buf)? {
let frame_size = frame.encoded_size();
read_buf.drain(..frame_size);
if frame.version != PROTOCOL_VERSION {
warn!("Unsupported protocol version: 0x{:02X}", frame.version);
continue;
}
if !rate_limiter.check() {
warn!("Rate limit exceeded for TLS device {:?}, dropping connection", device_uid);
break 'reader;
}
if let Err(e) = process_frame(frame, &state, &mut device_uid, &tx).await {
warn!("Frame processing error: {}", e);
}
}
}
cleanup_on_disconnect(&state, &device_uid).await;
write_task.abort();
Ok(())
}

125
crates/server/src/ws.rs Normal file
View File

@@ -0,0 +1,125 @@
use axum::extract::ws::{WebSocket, WebSocketUpgrade, Message};
use axum::response::IntoResponse;
use axum::extract::Query;
use jsonwebtoken::{decode, Validation, DecodingKey};
use serde::Deserialize;
use tokio::sync::broadcast;
use std::sync::Arc;
use tracing::{debug, warn};
use crate::api::auth::Claims;
use crate::AppState;
/// WebSocket hub for broadcasting real-time events to admin browsers
#[derive(Clone)]
pub struct WsHub {
tx: broadcast::Sender<String>,
}
impl WsHub {
pub fn new() -> Self {
let (tx, _) = broadcast::channel(1024);
Self { tx }
}
pub async fn broadcast(&self, message: String) {
if self.tx.send(message).is_err() {
debug!("No WebSocket subscribers to receive broadcast");
}
}
pub fn subscribe(&self) -> broadcast::Receiver<String> {
self.tx.subscribe()
}
}
#[derive(Debug, Deserialize)]
pub struct WsAuthParams {
pub token: Option<String>,
}
/// HTTP upgrade handler for WebSocket connections
/// Validates JWT token from query parameter before upgrading
pub async fn ws_handler(
ws: WebSocketUpgrade,
Query(params): Query<WsAuthParams>,
axum::extract::State(state): axum::extract::State<AppState>,
) -> impl IntoResponse {
let token = match params.token {
Some(t) => t,
None => {
warn!("WebSocket connection rejected: no token provided");
return (axum::http::StatusCode::UNAUTHORIZED, "Missing token").into_response();
}
};
let claims = match decode::<Claims>(
&token,
&DecodingKey::from_secret(state.config.auth.jwt_secret.as_bytes()),
&Validation::default(),
) {
Ok(c) => c.claims,
Err(e) => {
warn!("WebSocket connection rejected: invalid token - {}", e);
return (axum::http::StatusCode::UNAUTHORIZED, "Invalid token").into_response();
}
};
if claims.token_type != "access" {
warn!("WebSocket connection rejected: not an access token");
return (axum::http::StatusCode::UNAUTHORIZED, "Invalid token type").into_response();
}
let hub = state.ws_hub.clone();
ws.on_upgrade(move |socket| handle_socket(socket, claims, hub))
}
async fn handle_socket(mut socket: WebSocket, claims: Claims, hub: Arc<WsHub>) {
debug!("WebSocket client connected: user={}", claims.username);
let welcome = serde_json::json!({
"type": "connected",
"message": "CSM real-time feed active",
"user": claims.username
});
if socket.send(Message::Text(welcome.to_string())).await.is_err() {
return;
}
// Subscribe to broadcast hub for real-time events
let mut rx = hub.subscribe();
loop {
tokio::select! {
// Forward broadcast messages to WebSocket client
msg = rx.recv() => {
match msg {
Ok(text) => {
if socket.send(Message::Text(text)).await.is_err() {
break;
}
}
Err(broadcast::error::RecvError::Lagged(n)) => {
debug!("WebSocket client lagged {} messages, continuing", n);
}
Err(broadcast::error::RecvError::Closed) => break,
}
}
// Handle incoming WebSocket messages (ping/close)
msg = socket.recv() => {
match msg {
Some(Ok(Message::Ping(data))) => {
if socket.send(Message::Pong(data)).await.is_err() {
break;
}
}
Some(Ok(Message::Close(_))) => break,
Some(Err(_)) => break,
None => break,
_ => {}
}
}
}
}
debug!("WebSocket client disconnected: user={}", claims.username);
}

1
device_secret.txt Normal file
View File

@@ -0,0 +1 @@
6958a73b-ccd7-4790-82b4-1e88863d84ef

1
device_uid.txt Normal file
View File

@@ -0,0 +1 @@
a9e9f62a-c682-48fb-b1f4-b1429236ea92

File diff suppressed because it is too large Load Diff

1
login_resp.json Normal file
View File

@@ -0,0 +1 @@
{"success":true,"data":{"access_token":"eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzdWIiOjEsInVzZXJuYW1lIjoiYWRtaW4iLCJyb2xlIjoiYWRtaW4iLCJleHAiOjE3NzUyNjIxMzYsImlhdCI6MTc3NTI2MDMzNiwidG9rZW5fdHlwZSI6ImFjY2VzcyIsImZhbWlseSI6ImUxMTdjNDU0LTgxMGUtNDYxOC1hNjg5LWFkZGUyODI3MTI0MiJ9.DiZPv622vMCgkVszVjq41EIz19Yi0LMhAEiPDs7J5MY","refresh_token":"eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzdWIiOjEsInVzZXJuYW1lIjoiYWRtaW4iLCJyb2xlIjoiYWRtaW4iLCJleHAiOjE3NzU4NjUxMzYsImlhdCI6MTc3NTI2MDMzNiwidG9rZW5fdHlwZSI6InJlZnJlc2giLCJmYW1pbHkiOiJlMTE3YzQ1NC04MTBlLTQ2MTgtYTY4OS1hZGRlMjgyNzEyNDIifQ.LvLQK2qmdxXrTxPUgoRbKFsvTCbeJwNjisdMenPrSuM","user":{"id":1,"username":"admin","role":"admin"}},"error":null}

70
migrations/001_init.sql Normal file
View File

@@ -0,0 +1,70 @@
-- 001_init.sql: Core tables (users, devices, device_status)
CREATE TABLE IF NOT EXISTS users (
id INTEGER PRIMARY KEY AUTOINCREMENT,
username TEXT NOT NULL UNIQUE,
password TEXT NOT NULL,
role TEXT NOT NULL DEFAULT 'admin' CHECK(role IN ('admin', 'viewer')),
created_at TEXT NOT NULL DEFAULT (datetime('now'))
);
CREATE TABLE IF NOT EXISTS devices (
id INTEGER PRIMARY KEY AUTOINCREMENT,
device_uid TEXT NOT NULL UNIQUE,
hostname TEXT NOT NULL,
ip_address TEXT NOT NULL,
mac_address TEXT,
os_version TEXT,
client_version TEXT,
device_secret TEXT, -- HMAC key for message authentication
status TEXT NOT NULL DEFAULT 'offline' CHECK(status IN ('online', 'offline')),
last_heartbeat TEXT,
registered_at TEXT NOT NULL DEFAULT (datetime('now')),
group_name TEXT DEFAULT 'default'
);
CREATE TABLE IF NOT EXISTS device_status (
id INTEGER PRIMARY KEY AUTOINCREMENT,
device_uid TEXT NOT NULL REFERENCES devices(device_uid) ON DELETE CASCADE,
cpu_usage REAL,
memory_usage REAL,
memory_total_mb INTEGER,
disk_usage REAL,
disk_total_mb INTEGER,
network_rx_rate INTEGER,
network_tx_rate INTEGER,
running_procs INTEGER,
top_processes TEXT,
reported_at TEXT NOT NULL DEFAULT (datetime('now')),
updated_at TEXT NOT NULL DEFAULT (datetime('now')),
UNIQUE(device_uid)
);
CREATE TABLE IF NOT EXISTS device_status_history (
id INTEGER PRIMARY KEY AUTOINCREMENT,
device_uid TEXT NOT NULL REFERENCES devices(device_uid) ON DELETE CASCADE,
cpu_usage REAL,
memory_usage REAL,
disk_usage REAL,
network_rx_rate INTEGER,
network_tx_rate INTEGER,
running_procs INTEGER,
reported_at TEXT NOT NULL DEFAULT (datetime('now'))
);
CREATE TABLE IF NOT EXISTS device_groups (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL UNIQUE,
description TEXT,
parent_id INTEGER REFERENCES device_groups(id),
created_at TEXT NOT NULL DEFAULT (datetime('now'))
);
-- Insert default group
INSERT OR IGNORE INTO device_groups (name, description) VALUES ('default', 'Default device group');
-- Indexes
CREATE INDEX IF NOT EXISTS idx_devices_status ON devices(status);
CREATE INDEX IF NOT EXISTS idx_device_status_uid ON device_status(device_uid);
CREATE INDEX IF NOT EXISTS idx_status_history_device_time ON device_status_history(device_uid, reported_at);
CREATE INDEX IF NOT EXISTS idx_status_history_time ON device_status_history(reported_at);

42
migrations/002_assets.sql Normal file
View File

@@ -0,0 +1,42 @@
-- 002_assets.sql: Asset management tables
CREATE TABLE IF NOT EXISTS hardware_assets (
id INTEGER PRIMARY KEY AUTOINCREMENT,
device_uid TEXT NOT NULL REFERENCES devices(device_uid) ON DELETE CASCADE,
cpu_model TEXT,
cpu_cores INTEGER,
memory_total_mb INTEGER,
disk_model TEXT,
disk_total_mb INTEGER,
gpu_model TEXT,
motherboard TEXT,
serial_number TEXT,
reported_at TEXT NOT NULL DEFAULT (datetime('now')),
updated_at TEXT NOT NULL DEFAULT (datetime('now')),
UNIQUE(device_uid)
);
CREATE TABLE IF NOT EXISTS software_assets (
id INTEGER PRIMARY KEY AUTOINCREMENT,
device_uid TEXT NOT NULL REFERENCES devices(device_uid) ON DELETE CASCADE,
name TEXT NOT NULL,
version TEXT,
publisher TEXT,
install_date TEXT,
install_path TEXT,
updated_at TEXT NOT NULL DEFAULT (datetime('now')),
UNIQUE(device_uid, name, version)
);
CREATE TABLE IF NOT EXISTS asset_changes (
id INTEGER PRIMARY KEY AUTOINCREMENT,
device_uid TEXT NOT NULL REFERENCES devices(device_uid) ON DELETE CASCADE,
change_type TEXT NOT NULL CHECK(change_type IN ('hardware', 'software_added', 'software_removed')),
change_detail TEXT NOT NULL,
detected_at TEXT NOT NULL DEFAULT (datetime('now'))
);
-- Indexes
CREATE INDEX IF NOT EXISTS idx_software_device ON software_assets(device_uid);
CREATE INDEX IF NOT EXISTS idx_asset_changes_time ON asset_changes(detected_at);
CREATE INDEX IF NOT EXISTS idx_asset_changes_device ON asset_changes(device_uid);

28
migrations/003_usb.sql Normal file
View File

@@ -0,0 +1,28 @@
-- 003_usb.sql: USB control tables
CREATE TABLE IF NOT EXISTS usb_events (
id INTEGER PRIMARY KEY AUTOINCREMENT,
device_uid TEXT NOT NULL REFERENCES devices(device_uid) ON DELETE CASCADE,
vendor_id TEXT,
product_id TEXT,
serial_number TEXT,
device_name TEXT,
event_type TEXT NOT NULL CHECK(event_type IN ('inserted', 'removed', 'blocked')),
event_time TEXT NOT NULL DEFAULT (datetime('now'))
);
CREATE TABLE IF NOT EXISTS usb_policies (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL,
policy_type TEXT NOT NULL CHECK(policy_type IN ('all_block', 'whitelist', 'blacklist')),
target_group TEXT,
rules TEXT NOT NULL,
enabled INTEGER NOT NULL DEFAULT 1,
created_at TEXT NOT NULL DEFAULT (datetime('now')),
updated_at TEXT NOT NULL DEFAULT (datetime('now'))
);
-- Indexes
CREATE INDEX IF NOT EXISTS idx_usb_events_device_time ON usb_events(device_uid, event_time);
CREATE INDEX IF NOT EXISTS idx_usb_events_time ON usb_events(event_time);
CREATE INDEX IF NOT EXISTS idx_usb_policies_target ON usb_policies(target_group);

46
migrations/004_alerts.sql Normal file
View File

@@ -0,0 +1,46 @@
-- 004_alerts.sql: Alert system tables
CREATE TABLE IF NOT EXISTS alert_rules (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL,
rule_type TEXT NOT NULL CHECK(rule_type IN ('device_offline', 'cpu_high', 'memory_high', 'disk_high', 'usb_unauthorized', 'usb_unauth', 'asset_change')),
condition TEXT NOT NULL,
severity TEXT NOT NULL DEFAULT 'medium' CHECK(severity IN ('low', 'medium', 'high', 'critical')),
enabled INTEGER NOT NULL DEFAULT 1,
notify_email TEXT,
notify_webhook TEXT,
created_at TEXT NOT NULL DEFAULT (datetime('now')),
updated_at TEXT NOT NULL DEFAULT (datetime('now'))
);
CREATE TABLE IF NOT EXISTS alert_records (
id INTEGER PRIMARY KEY AUTOINCREMENT,
rule_id INTEGER REFERENCES alert_rules(id),
device_uid TEXT REFERENCES devices(device_uid) ON DELETE SET NULL,
alert_type TEXT NOT NULL,
severity TEXT NOT NULL DEFAULT 'medium' CHECK(severity IN ('low', 'medium', 'high', 'critical')),
detail TEXT NOT NULL,
handled INTEGER NOT NULL DEFAULT 0,
handled_by TEXT,
handled_at TEXT,
triggered_at TEXT NOT NULL DEFAULT (datetime('now'))
);
-- Admin audit log
CREATE TABLE IF NOT EXISTS admin_audit_log (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL REFERENCES users(id),
action TEXT NOT NULL,
target_type TEXT,
target_id TEXT,
detail TEXT,
ip_address TEXT,
created_at TEXT NOT NULL DEFAULT (datetime('now'))
);
-- Indexes
CREATE INDEX IF NOT EXISTS idx_alert_records_time ON alert_records(triggered_at);
CREATE INDEX IF NOT EXISTS idx_alert_records_device ON alert_records(device_uid, triggered_at);
CREATE INDEX IF NOT EXISTS idx_alert_records_unhandled ON alert_records(handled) WHERE handled = 0;
CREATE INDEX IF NOT EXISTS idx_audit_log_user_time ON admin_audit_log(user_id, created_at);
CREATE INDEX IF NOT EXISTS idx_audit_log_time ON admin_audit_log(created_at);

View File

@@ -0,0 +1,23 @@
-- 005_plugins_web_filter.sql: Web Filter plugin (上网拦截)
CREATE TABLE IF NOT EXISTS web_filter_rules (
id INTEGER PRIMARY KEY AUTOINCREMENT,
rule_type TEXT NOT NULL CHECK(rule_type IN ('blacklist', 'whitelist', 'category')),
pattern TEXT NOT NULL,
target_type TEXT NOT NULL DEFAULT 'global' CHECK(target_type IN ('global', 'group', 'device')),
target_id TEXT,
enabled INTEGER NOT NULL DEFAULT 1,
created_at TEXT NOT NULL DEFAULT (datetime('now'))
);
CREATE TABLE IF NOT EXISTS web_access_log (
id INTEGER PRIMARY KEY AUTOINCREMENT,
device_uid TEXT NOT NULL REFERENCES devices(device_uid) ON DELETE CASCADE,
url TEXT NOT NULL,
action TEXT NOT NULL CHECK(action IN ('allowed', 'blocked')),
timestamp TEXT NOT NULL DEFAULT (datetime('now'))
);
CREATE INDEX IF NOT EXISTS idx_web_filter_rules_type ON web_filter_rules(rule_type, enabled);
CREATE INDEX IF NOT EXISTS idx_web_access_log_device_time ON web_access_log(device_uid, timestamp);
CREATE INDEX IF NOT EXISTS idx_web_access_log_time ON web_access_log(timestamp);

View File

@@ -0,0 +1,24 @@
-- 006_plugins_usage_timer.sql: Usage Timer plugin (时长记录)
CREATE TABLE IF NOT EXISTS usage_daily (
id INTEGER PRIMARY KEY AUTOINCREMENT,
device_uid TEXT NOT NULL REFERENCES devices(device_uid) ON DELETE CASCADE,
date TEXT NOT NULL,
total_active_minutes INTEGER NOT NULL DEFAULT 0,
total_idle_minutes INTEGER NOT NULL DEFAULT 0,
first_active_at TEXT,
last_active_at TEXT,
UNIQUE(device_uid, date)
);
CREATE TABLE IF NOT EXISTS app_usage_daily (
id INTEGER PRIMARY KEY AUTOINCREMENT,
device_uid TEXT NOT NULL REFERENCES devices(device_uid) ON DELETE CASCADE,
date TEXT NOT NULL,
app_name TEXT NOT NULL,
usage_minutes INTEGER NOT NULL DEFAULT 0,
UNIQUE(device_uid, date, app_name)
);
CREATE INDEX IF NOT EXISTS idx_usage_daily_date ON usage_daily(date);
CREATE INDEX IF NOT EXISTS idx_app_usage_daily_date ON app_usage_daily(date);

View File

@@ -0,0 +1,24 @@
-- 007_plugins_software_blocker.sql: Software Blocker plugin (软件禁止安装)
CREATE TABLE IF NOT EXISTS software_blacklist (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name_pattern TEXT NOT NULL,
category TEXT CHECK(category IN ('game', 'social', 'vpn', 'mining', 'custom')),
action TEXT NOT NULL DEFAULT 'block' CHECK(action IN ('block', 'alert')),
target_type TEXT NOT NULL DEFAULT 'global' CHECK(target_type IN ('global', 'group', 'device')),
target_id TEXT,
enabled INTEGER NOT NULL DEFAULT 1,
created_at TEXT NOT NULL DEFAULT (datetime('now'))
);
CREATE TABLE IF NOT EXISTS software_violations (
id INTEGER PRIMARY KEY AUTOINCREMENT,
device_uid TEXT NOT NULL REFERENCES devices(device_uid) ON DELETE CASCADE,
software_name TEXT NOT NULL,
action_taken TEXT NOT NULL CHECK(action_taken IN ('blocked_install', 'auto_uninstalled', 'alerted')),
timestamp TEXT NOT NULL DEFAULT (datetime('now'))
);
CREATE INDEX IF NOT EXISTS idx_software_blacklist_enabled ON software_blacklist(enabled);
CREATE INDEX IF NOT EXISTS idx_software_violations_device ON software_violations(device_uid, timestamp);
CREATE INDEX IF NOT EXISTS idx_software_violations_time ON software_violations(timestamp);

View File

@@ -0,0 +1,23 @@
-- 008_plugins_popup_blocker.sql: Popup Blocker plugin (弹窗拦截)
CREATE TABLE IF NOT EXISTS popup_filter_rules (
id INTEGER PRIMARY KEY AUTOINCREMENT,
rule_type TEXT NOT NULL CHECK(rule_type IN ('block', 'allow')),
window_title TEXT,
window_class TEXT,
process_name TEXT,
target_type TEXT NOT NULL DEFAULT 'global' CHECK(target_type IN ('global', 'group', 'device')),
target_id TEXT,
enabled INTEGER NOT NULL DEFAULT 1,
created_at TEXT NOT NULL DEFAULT (datetime('now'))
);
CREATE TABLE IF NOT EXISTS popup_block_stats (
id INTEGER PRIMARY KEY AUTOINCREMENT,
device_uid TEXT NOT NULL REFERENCES devices(device_uid) ON DELETE CASCADE,
blocked_count INTEGER NOT NULL DEFAULT 0,
date TEXT NOT NULL,
UNIQUE(device_uid, date)
);
CREATE INDEX IF NOT EXISTS idx_popup_rules_enabled ON popup_filter_rules(rule_type, enabled);

View File

@@ -0,0 +1,16 @@
-- 009_plugins_usb_file_audit.sql: USB File Audit plugin (U盘文件操作记录)
CREATE TABLE IF NOT EXISTS usb_file_operations (
id INTEGER PRIMARY KEY AUTOINCREMENT,
device_uid TEXT NOT NULL REFERENCES devices(device_uid) ON DELETE CASCADE,
usb_serial TEXT,
drive_letter TEXT,
operation TEXT NOT NULL CHECK(operation IN ('create', 'delete', 'rename', 'modify')),
file_path TEXT NOT NULL,
file_size INTEGER,
timestamp TEXT NOT NULL DEFAULT (datetime('now'))
);
CREATE INDEX IF NOT EXISTS idx_usb_file_ops_device ON usb_file_operations(device_uid, timestamp);
CREATE INDEX IF NOT EXISTS idx_usb_file_ops_time ON usb_file_operations(timestamp);
CREATE INDEX IF NOT EXISTS idx_usb_file_ops_usb ON usb_file_operations(usb_serial, timestamp);

View File

@@ -0,0 +1,27 @@
-- 010_plugins_watermark.sql: Screen Watermark plugin (水印管理)
CREATE TABLE IF NOT EXISTS watermark_config (
id INTEGER PRIMARY KEY AUTOINCREMENT,
target_type TEXT NOT NULL DEFAULT 'global' CHECK(target_type IN ('global', 'group', 'device')),
target_id TEXT,
content TEXT NOT NULL DEFAULT '公司名称 | {username} | {date}',
font_size INTEGER NOT NULL DEFAULT 14,
opacity REAL NOT NULL DEFAULT 0.15,
color TEXT NOT NULL DEFAULT '#808080',
angle INTEGER NOT NULL DEFAULT -30,
enabled INTEGER NOT NULL DEFAULT 1,
updated_at TEXT NOT NULL DEFAULT (datetime('now'))
);
-- Plugin enable/disable state per device group
CREATE TABLE IF NOT EXISTS plugin_state (
id INTEGER PRIMARY KEY AUTOINCREMENT,
plugin_name TEXT NOT NULL UNIQUE CHECK(plugin_name IN (
'web_filter', 'usage_timer', 'software_blocker',
'popup_blocker', 'usb_file_audit', 'watermark'
)),
enabled INTEGER NOT NULL DEFAULT 0,
target_type TEXT NOT NULL DEFAULT 'global',
target_id TEXT,
updated_at TEXT NOT NULL DEFAULT (datetime('now'))
);

View File

@@ -0,0 +1,18 @@
-- 011_token_security.sql: Token rotation and revocation tracking
CREATE TABLE IF NOT EXISTS revoked_token_families (
family TEXT NOT NULL PRIMARY KEY,
user_id INTEGER NOT NULL,
revoked_at TEXT NOT NULL DEFAULT (datetime('now'))
);
CREATE TABLE IF NOT EXISTS refresh_tokens (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL,
family TEXT NOT NULL,
created_at TEXT NOT NULL DEFAULT (datetime('now')),
expires_at TEXT NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_revoked_families_user ON revoked_token_families(user_id);
CREATE INDEX IF NOT EXISTS idx_refresh_tokens_user ON refresh_tokens(user_id);

1646
plan.md Normal file

File diff suppressed because it is too large Load Diff

27
test_route.rs Normal file
View File

@@ -0,0 +1,27 @@
use axum::{Router, routing::get,};
async fn hello() -> &'static str' {
"Hello"
}
async fn hello_id() -> &'static str' {
"Hello ID"
}
#[tokio::main]
async fn main() {
let read = Router::new()
.route("/test", get(hello))
.route("/test/{id}", get(hello_id));
let write = Router::new()
.route("/test/{id}", get(hello_id));
let app = Router::new()
.merge(read)
.merge(write);
let listener = tokio::net::TcpListener::bind("127.0.0.1:9999").await.unwrap();
axum::serve(listener, app).await;
println!("Server running on 9999");
}

9
web/auto-imports.d.ts vendored Normal file
View File

@@ -0,0 +1,9 @@
/* eslint-disable */
/* prettier-ignore */
// @ts-nocheck
// noinspection JSUnusedGlobalSymbols
// Generated by unplugin-auto-import
export {}
declare global {
}

53
web/components.d.ts vendored Normal file
View File

@@ -0,0 +1,53 @@
/* eslint-disable */
/* prettier-ignore */
// @ts-nocheck
// Generated by unplugin-vue-components
// Read more: https://github.com/vuejs/core/pull/3399
export {}
declare module 'vue' {
export interface GlobalComponents {
ElAside: typeof import('element-plus/es')['ElAside']
ElBadge: typeof import('element-plus/es')['ElBadge']
ElButton: typeof import('element-plus/es')['ElButton']
ElCard: typeof import('element-plus/es')['ElCard']
ElCol: typeof import('element-plus/es')['ElCol']
ElColorPicker: typeof import('element-plus/es')['ElColorPicker']
ElContainer: typeof import('element-plus/es')['ElContainer']
ElDescriptions: typeof import('element-plus/es')['ElDescriptions']
ElDescriptionsItem: typeof import('element-plus/es')['ElDescriptionsItem']
ElDialog: typeof import('element-plus/es')['ElDialog']
ElDropdown: typeof import('element-plus/es')['ElDropdown']
ElDropdownItem: typeof import('element-plus/es')['ElDropdownItem']
ElDropdownMenu: typeof import('element-plus/es')['ElDropdownMenu']
ElEmpty: typeof import('element-plus/es')['ElEmpty']
ElForm: typeof import('element-plus/es')['ElForm']
ElFormItem: typeof import('element-plus/es')['ElFormItem']
ElHeader: typeof import('element-plus/es')['ElHeader']
ElIcon: typeof import('element-plus/es')['ElIcon']
ElInput: typeof import('element-plus/es')['ElInput']
ElInputNumber: typeof import('element-plus/es')['ElInputNumber']
ElMain: typeof import('element-plus/es')['ElMain']
ElMenu: typeof import('element-plus/es')['ElMenu']
ElMenuItem: typeof import('element-plus/es')['ElMenuItem']
ElOption: typeof import('element-plus/es')['ElOption']
ElPageHeader: typeof import('element-plus/es')['ElPageHeader']
ElPagination: typeof import('element-plus/es')['ElPagination']
ElProgress: typeof import('element-plus/es')['ElProgress']
ElRow: typeof import('element-plus/es')['ElRow']
ElSelect: typeof import('element-plus/es')['ElSelect']
ElSlider: typeof import('element-plus/es')['ElSlider']
ElSubMenu: typeof import('element-plus/es')['ElSubMenu']
ElSwitch: typeof import('element-plus/es')['ElSwitch']
ElTable: typeof import('element-plus/es')['ElTable']
ElTableColumn: typeof import('element-plus/es')['ElTableColumn']
ElTabPane: typeof import('element-plus/es')['ElTabPane']
ElTabs: typeof import('element-plus/es')['ElTabs']
ElTag: typeof import('element-plus/es')['ElTag']
RouterLink: typeof import('vue-router')['RouterLink']
RouterView: typeof import('vue-router')['RouterView']
}
export interface ComponentCustomProperties {
vLoading: typeof import('element-plus/es')['ElLoadingDirective']
}
}

13
web/index.html Normal file
View File

@@ -0,0 +1,13 @@
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>CSM - 企业终端管理系统</title>
<link rel="icon" type="image/svg+xml" href="/favicon.svg" />
</head>
<body>
<div id="app"></div>
<script type="module" src="/src/main.ts"></script>
</body>
</html>

2759
web/package-lock.json generated Normal file

File diff suppressed because it is too large Load Diff

30
web/package.json Normal file
View File

@@ -0,0 +1,30 @@
{
"name": "csm-web",
"version": "0.1.0",
"private": true,
"type": "module",
"scripts": {
"dev": "vite",
"build": "vue-tsc && vite build",
"preview": "vite preview",
"type-check": "vue-tsc --noEmit"
},
"dependencies": {
"@vueuse/core": "^10.7.2",
"axios": "^1.6.7",
"dayjs": "^1.11.10",
"echarts": "^5.5.0",
"element-plus": "^2.5.6",
"pinia": "^2.1.7",
"vue": "^3.4.21",
"vue-router": "^4.2.5"
},
"devDependencies": {
"@vitejs/plugin-vue": "^5.0.3",
"typescript": "^5.3.3",
"unplugin-auto-import": "^0.17.5",
"unplugin-vue-components": "^0.26.0",
"vite": "^5.0.12",
"vue-tsc": "^3.2.6"
}
}

8
web/src/App.vue Normal file
View File

@@ -0,0 +1,8 @@
<template>
<router-view />
</template>
<style>
* { margin: 0; padding: 0; box-sizing: border-box; }
html, body, #app { height: 100%; }
</style>

126
web/src/lib/api.ts Normal file
View File

@@ -0,0 +1,126 @@
/**
* Shared API client with authentication and error handling
*/
const API_BASE = import.meta.env.VITE_API_BASE || ''
export interface ApiResult<T> {
success: boolean
data?: T
error?: string
}
export class ApiError extends Error {
constructor(
public status: number,
public code: string,
message: string,
) {
super(message)
this.name = 'ApiError'
}
}
function getToken(): string | null {
const token = localStorage.getItem('token')
if (!token || token.trim() === '') return null
return token
}
function clearAuth() {
localStorage.removeItem('token')
localStorage.removeItem('refresh_token')
window.location.href = '/login'
}
async function request<T>(
path: string,
options: RequestInit = {},
): Promise<T> {
const token = getToken()
const headers = new Headers(options.headers || {})
if (token) {
headers.set('Authorization', `Bearer ${token}`)
}
if (options.body && typeof options.body === 'string') {
headers.set('Content-Type', 'application/json')
}
const response = await fetch(`${API_BASE}${path}`, {
...options,
headers,
})
// Handle 401 - token expired or invalid
if (response.status === 401) {
clearAuth()
throw new ApiError(401, 'UNAUTHORIZED', 'Session expired')
}
// Handle 403 - insufficient permissions
if (response.status === 403) {
throw new ApiError(403, 'FORBIDDEN', 'Insufficient permissions')
}
// Handle non-JSON responses (502, 503, etc.)
const contentType = response.headers.get('content-type')
if (!contentType || !contentType.includes('application/json')) {
throw new ApiError(response.status, 'NON_JSON_RESPONSE', `Server returned ${response.status}`)
}
const result: ApiResult<T> = await response.json()
if (!result.success) {
throw new ApiError(response.status, 'API_ERROR', result.error || 'Unknown error')
}
return result.data as T
}
export const api = {
get<T>(path: string): Promise<T> {
return request<T>(path)
},
post<T>(path: string, body?: unknown): Promise<T> {
return request<T>(path, {
method: 'POST',
body: body ? JSON.stringify(body) : undefined,
})
},
put<T>(path: string, body?: unknown): Promise<T> {
return request<T>(path, {
method: 'PUT',
body: body ? JSON.stringify(body) : undefined,
})
},
delete<T = void>(path: string): Promise<T> {
return request<T>(path, { method: 'DELETE' })
},
/** Login doesn't use the auth header */
async login(username: string, password: string): Promise<{ access_token: string; refresh_token: string; user: { id: number; username: string; role: string } }> {
const response = await fetch(`${API_BASE}/api/auth/login`, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ username, password }),
})
const result = await response.json()
if (!result.success) {
throw new ApiError(response.status, 'LOGIN_FAILED', result.error || 'Login failed')
}
localStorage.setItem('token', result.data.access_token)
localStorage.setItem('refresh_token', result.data.refresh_token)
return result.data
},
logout() {
clearAuth()
},
}

12
web/src/main.ts Normal file
View File

@@ -0,0 +1,12 @@
import { createApp } from 'vue'
import { createPinia } from 'pinia'
import 'element-plus/dist/index.css'
import App from './App.vue'
import router from './router'
const app = createApp(App)
app.use(createPinia())
app.use(router)
app.mount('#app')

63
web/src/router/index.ts Normal file
View File

@@ -0,0 +1,63 @@
import { createRouter, createWebHistory } from 'vue-router'
import AppLayout from '../views/Layout.vue'
const router = createRouter({
history: createWebHistory(),
routes: [
{ path: '/login', name: 'Login', component: () => import('../views/Login.vue') },
{
path: '/',
component: AppLayout,
redirect: '/dashboard',
children: [
{ path: 'dashboard', name: 'Dashboard', component: () => import('../views/Dashboard.vue') },
{ path: 'devices', name: 'Devices', component: () => import('../views/Devices.vue') },
{ path: 'devices/:uid', name: 'DeviceDetail', component: () => import('../views/DeviceDetail.vue') },
{ path: 'assets', name: 'Assets', component: () => import('../views/Assets.vue') },
{ path: 'usb', name: 'UsbPolicy', component: () => import('../views/UsbPolicy.vue') },
{ path: 'alerts', name: 'Alerts', component: () => import('../views/Alerts.vue') },
{ path: 'settings', name: 'Settings', component: () => import('../views/Settings.vue') },
// Phase 2: Plugin pages
{ path: 'plugins/web-filter', name: 'WebFilter', component: () => import('../views/plugins/WebFilter.vue') },
{ path: 'plugins/usage-timer', name: 'UsageTimer', component: () => import('../views/plugins/UsageTimer.vue') },
{ path: 'plugins/software-blocker', name: 'SoftwareBlocker', component: () => import('../views/plugins/SoftwareBlocker.vue') },
{ path: 'plugins/popup-blocker', name: 'PopupBlocker', component: () => import('../views/plugins/PopupBlocker.vue') },
{ path: 'plugins/usb-file-audit', name: 'UsbFileAudit', component: () => import('../views/plugins/UsbFileAudit.vue') },
{ path: 'plugins/watermark', name: 'Watermark', component: () => import('../views/plugins/Watermark.vue') },
],
},
],
})
/** Check if a JWT token is structurally valid and not expired */
function isTokenValid(token: string): boolean {
if (!token || token.trim() === '') return false
try {
const parts = token.split('.')
if (parts.length !== 3) return false
const payload = JSON.parse(atob(parts[1]))
if (!payload.exp) return false
// Reject if token expires within 30 seconds
return payload.exp * 1000 > Date.now() + 30_000
} catch {
return false
}
}
router.beforeEach((to, _from, next) => {
if (to.path === '/login') {
next()
return
}
const token = localStorage.getItem('token')
if (!token || !isTokenValid(token)) {
localStorage.removeItem('token')
localStorage.removeItem('refresh_token')
next('/login')
} else {
next()
}
})
export default router

87
web/src/stores/devices.ts Normal file
View File

@@ -0,0 +1,87 @@
import { defineStore } from 'pinia'
import { ref } from 'vue'
import axios from 'axios'
const api = axios.create({
baseURL: '/api',
headers: {
'Content-Type': 'application/json',
},
})
// Add auth token to requests
api.interceptors.request.use((config) => {
const token = localStorage.getItem('token')
if (token) {
config.headers.Authorization = `Bearer ${token}`
}
return config
})
export interface Device {
id: number
device_uid: string
hostname: string
ip_address: string
mac_address: string | null
os_version: string | null
client_version: string | null
status: 'online' | 'offline'
last_heartbeat: string | null
registered_at: string
group_name: string
}
export interface DeviceStatusDetail {
cpu_usage: number
memory_usage: number
memory_total: number
disk_usage: number
disk_total: number
running_procs: number
top_processes: Array<{ name: string; pid: number; cpu_usage: number; memory_mb: number }>
}
export const useDeviceStore = defineStore('devices', () => {
const devices = ref<Device[]>([])
const loading = ref(false)
const total = ref(0)
async function fetchDevices(params?: Record<string, string>) {
loading.value = true
try {
const { data } = await api.get('/devices', { params })
if (data.success) {
devices.value = data.data.devices
total.value = data.data.total ?? devices.value.length
}
} finally {
loading.value = false
}
}
async function fetchDeviceStatus(uid: string): Promise<DeviceStatusDetail | null> {
const { data } = await api.get(`/devices/${uid}/status`)
return data.success ? data.data : null
}
async function fetchDeviceHistory(uid: string, params?: Record<string, string>) {
const { data } = await api.get(`/devices/${uid}/history`, { params })
return data.success ? data.data : null
}
async function removeDevice(uid: string) {
await api.delete(`/devices/${uid}`)
devices.value = devices.value.filter((d) => d.device_uid !== uid)
}
return {
devices,
loading,
total,
fetchDevices,
fetchDeviceStatus,
fetchDeviceHistory,
removeDevice,
}
})

216
web/src/views/Alerts.vue Normal file
View File

@@ -0,0 +1,216 @@
<template>
<div class="alerts-page">
<el-tabs v-model="activeTab">
<el-tab-pane label="告警记录" name="records">
<div class="toolbar">
<el-select v-model="severityFilter" placeholder="严重程度" clearable style="width: 140px" @change="fetchRecords">
<el-option label="Critical" value="critical" />
<el-option label="High" value="high" />
<el-option label="Medium" value="medium" />
<el-option label="Low" value="low" />
</el-select>
<el-select v-model="handledFilter" placeholder="处理状态" clearable style="width: 140px" @change="fetchRecords">
<el-option label="待处理" value="false" />
<el-option label="已处理" value="true" />
</el-select>
</div>
<el-table :data="records" v-loading="recLoading" stripe size="small">
<el-table-column label="严重程度" width="100">
<template #default="{ row }">
<el-tag :type="severityTag(row.severity)" size="small">{{ row.severity }}</el-tag>
</template>
</el-table-column>
<el-table-column prop="alert_type" label="告警类型" width="130" />
<el-table-column prop="detail" label="详情" min-width="250" show-overflow-tooltip />
<el-table-column prop="device_uid" label="终端" width="150" show-overflow-tooltip />
<el-table-column prop="triggered_at" label="触发时间" width="170" />
<el-table-column label="状态" width="80">
<template #default="{ row }">
<el-tag :type="row.handled ? 'success' : 'warning'" size="small">
{{ row.handled ? '已处理' : '待处理' }}
</el-tag>
</template>
</el-table-column>
<el-table-column label="操作" width="100" fixed="right">
<template #default="{ row }">
<el-button v-if="!row.handled" link type="primary" size="small" @click="handleRecord(row.id)">处理</el-button>
</template>
</el-table-column>
</el-table>
</el-tab-pane>
<el-tab-pane label="告警规则" name="rules">
<div class="toolbar">
<el-button type="primary" @click="showRuleDialog()">新建规则</el-button>
</div>
<el-table :data="rules" v-loading="ruleLoading" stripe size="small">
<el-table-column prop="name" label="规则名称" width="180" />
<el-table-column prop="rule_type" label="规则类型" width="140" />
<el-table-column prop="severity" label="严重程度" width="100">
<template #default="{ row }">
<el-tag :type="severityTag(row.severity)" size="small">{{ row.severity }}</el-tag>
</template>
</el-table-column>
<el-table-column prop="condition" label="条件" min-width="200" show-overflow-tooltip />
<el-table-column prop="enabled" label="启用" width="80">
<template #default="{ row }">
<el-switch :model-value="row.enabled" @change="toggleRule(row)" size="small" />
</template>
</el-table-column>
<el-table-column label="操作" width="140" fixed="right">
<template #default="{ row }">
<el-button link type="primary" size="small" @click="showRuleDialog(row)">编辑</el-button>
<el-button link type="danger" size="small" @click="deleteRule(row.id)">删除</el-button>
</template>
</el-table-column>
</el-table>
</el-tab-pane>
</el-tabs>
<el-dialog v-model="ruleDialogVisible" :title="editingRule ? '编辑规则' : '新建规则'" width="500px">
<el-form :model="ruleForm" label-width="100px">
<el-form-item label="规则名称">
<el-input v-model="ruleForm.name" />
</el-form-item>
<el-form-item label="规则类型">
<el-select v-model="ruleForm.rule_type" style="width: 100%">
<el-option label="CPU过高" value="cpu_high" />
<el-option label="内存过高" value="memory_high" />
<el-option label="未授权USB" value="usb_unauth" />
<el-option label="终端离线" value="device_offline" />
</el-select>
</el-form-item>
<el-form-item label="条件">
<el-input v-model="ruleForm.condition" type="textarea" :rows="2" placeholder='{"threshold":90,"duration_secs":300}' />
</el-form-item>
<el-form-item label="严重程度">
<el-select v-model="ruleForm.severity" style="width: 100%">
<el-option label="Critical" value="critical" />
<el-option label="High" value="high" />
<el-option label="Medium" value="medium" />
<el-option label="Low" value="low" />
</el-select>
</el-form-item>
<el-form-item label="通知邮箱">
<el-input v-model="ruleForm.notify_email" placeholder="可选" />
</el-form-item>
<el-form-item label="Webhook">
<el-input v-model="ruleForm.notify_webhook" placeholder="可选" />
</el-form-item>
</el-form>
<template #footer>
<el-button @click="ruleDialogVisible = false">取消</el-button>
<el-button type="primary" @click="saveRule">保存</el-button>
</template>
</el-dialog>
</div>
</template>
<script setup lang="ts">
import { ref, reactive, onMounted } from 'vue'
import { ElMessage, ElMessageBox } from 'element-plus'
import { api } from '@/lib/api'
const activeTab = ref('records')
// Records
const records = ref<any[]>([])
const recLoading = ref(false)
const severityFilter = ref('')
const handledFilter = ref('')
async function fetchRecords() {
recLoading.value = true
try {
const params = new URLSearchParams()
if (severityFilter.value) params.set('severity', severityFilter.value)
if (handledFilter.value) params.set('handled', handledFilter.value)
const data = await api.get<any>(`/api/alerts/records?${params}`)
records.value = data.records || []
} catch { /* api.ts handles 401 */ } finally { recLoading.value = false }
}
async function handleRecord(id: number) {
try {
await api.put(`/api/alerts/records/${id}/handle`)
ElMessage.success('已标记处理')
fetchRecords()
} catch (e: any) { ElMessage.error(e.message || '操作失败') }
}
// Rules
const rules = ref<any[]>([])
const ruleLoading = ref(false)
const ruleDialogVisible = ref(false)
const editingRule = ref<any>(null)
const ruleForm = reactive({ name: '', rule_type: 'cpu_high', condition: '{"threshold":90}', severity: 'high', notify_email: '', notify_webhook: '' })
async function fetchRules() {
ruleLoading.value = true
try {
const data = await api.get<any>('/api/alerts/rules')
rules.value = data.rules || []
} catch { /* api.ts handles 401 */ } finally { ruleLoading.value = false }
}
function showRuleDialog(row?: any) {
if (row) {
editingRule.value = row
ruleForm.name = row.name
ruleForm.rule_type = row.rule_type
ruleForm.condition = row.condition
ruleForm.severity = row.severity
ruleForm.notify_email = row.notify_email || ''
ruleForm.notify_webhook = row.notify_webhook || ''
} else {
editingRule.value = null
Object.assign(ruleForm, { name: '', rule_type: 'cpu_high', condition: '{"threshold":90}', severity: 'high', notify_email: '', notify_webhook: '' })
}
ruleDialogVisible.value = true
}
async function saveRule() {
try {
if (editingRule.value) {
await api.put(`/api/alerts/rules/${editingRule.value.id}`, ruleForm)
ElMessage.success('规则已更新')
} else {
await api.post('/api/alerts/rules', ruleForm)
ElMessage.success('规则已创建')
}
ruleDialogVisible.value = false
fetchRules()
} catch (e: any) { ElMessage.error(e.message || '操作失败') }
}
async function toggleRule(row: any) {
try {
await api.put(`/api/alerts/rules/${row.id}`, { enabled: !row.enabled ? 1 : 0 })
fetchRules()
} catch { /* ignore */ }
}
async function deleteRule(id: number) {
await ElMessageBox.confirm('确定删除该规则?', '确认', { type: 'warning' })
try {
await api.delete(`/api/alerts/rules/${id}`)
ElMessage.success('规则已删除')
fetchRules()
} catch (e: any) { ElMessage.error(e.message || '删除失败') }
}
function severityTag(s: string) {
const map: Record<string, string> = { critical: 'danger', high: 'warning', medium: '', low: 'info' }
return map[s] || 'info'
}
onMounted(() => {
fetchRecords()
fetchRules()
})
</script>
<style scoped>
.alerts-page { padding: 20px; }
.toolbar { display: flex; gap: 12px; margin-bottom: 16px; }
</style>

123
web/src/views/Assets.vue Normal file
View File

@@ -0,0 +1,123 @@
<template>
<div class="assets-page">
<el-tabs v-model="activeTab">
<el-tab-pane label="硬件资产" name="hardware">
<div class="toolbar">
<el-input v-model="hwSearch" placeholder="搜索CPU/GPU型号" style="width: 300px" clearable @input="fetchHardware" />
</div>
<el-table :data="hardware" v-loading="hwLoading" stripe size="small">
<el-table-column prop="device_uid" label="终端UID" width="160" show-overflow-tooltip />
<el-table-column prop="cpu_model" label="CPU型号" min-width="180" />
<el-table-column prop="cpu_cores" label="核心数" width="80" />
<el-table-column label="内存" width="100">
<template #default="{ row }">{{ formatMB(row.memory_total_mb) }}</template>
</el-table-column>
<el-table-column prop="gpu_model" label="GPU" min-width="150" />
<el-table-column prop="reported_at" label="上报时间" width="170" />
</el-table>
</el-tab-pane>
<el-tab-pane label="软件资产" name="software">
<div class="toolbar">
<el-input v-model="swSearch" placeholder="搜索软件名称/发行商" style="width: 300px" clearable @input="fetchSoftware" />
</div>
<el-table :data="software" v-loading="swLoading" stripe size="small">
<el-table-column prop="name" label="软件名称" min-width="200" />
<el-table-column prop="version" label="版本" width="120" />
<el-table-column prop="publisher" label="发行商" min-width="150" />
<el-table-column prop="install_date" label="安装日期" width="120" />
<el-table-column prop="device_uid" label="终端UID" width="160" show-overflow-tooltip />
</el-table>
</el-tab-pane>
<el-tab-pane label="变更记录" name="changes">
<el-table :data="changes" v-loading="chLoading" stripe size="small">
<el-table-column prop="device_uid" label="终端UID" width="160" show-overflow-tooltip />
<el-table-column prop="change_type" label="变更类型" width="120">
<template #default="{ row }">
<el-tag :type="changeTag(row.change_type)" size="small">{{ row.change_type }}</el-tag>
</template>
</el-table-column>
<el-table-column prop="change_detail" label="详情" min-width="300" show-overflow-tooltip />
<el-table-column prop="detected_at" label="检测时间" width="170" />
</el-table>
</el-tab-pane>
</el-tabs>
</div>
</template>
<script setup lang="ts">
import { ref, onMounted, watch } from 'vue'
import { api } from '@/lib/api'
const activeTab = ref('hardware')
// Hardware
const hardware = ref<any[]>([])
const hwLoading = ref(false)
const hwSearch = ref('')
async function fetchHardware() {
hwLoading.value = true
try {
const params = new URLSearchParams()
if (hwSearch.value) params.set('search', hwSearch.value)
const data = await api.get<any>(`/api/assets/hardware?${params}`)
hardware.value = data.hardware || []
} catch { /* api.ts handles 401 */ } finally { hwLoading.value = false }
}
// Software
const software = ref<any[]>([])
const swLoading = ref(false)
const swSearch = ref('')
async function fetchSoftware() {
swLoading.value = true
try {
const params = new URLSearchParams()
if (swSearch.value) params.set('search', swSearch.value)
const data = await api.get<any>(`/api/assets/software?${params}`)
software.value = data.software || []
} catch { /* api.ts handles 401 */ } finally { swLoading.value = false }
}
// Changes
const changes = ref<any[]>([])
const chLoading = ref(false)
async function fetchChanges() {
chLoading.value = true
try {
const data = await api.get<any>('/api/assets/changes')
changes.value = data.changes || []
} catch { /* api.ts handles 401 */ } finally { chLoading.value = false }
}
function formatMB(mb: number) {
if (mb >= 1024) return `${(mb / 1024).toFixed(1)} GB`
return `${mb} MB`
}
function changeTag(type: string) {
const map: Record<string, string> = { hardware: 'warning', software_added: 'success', software_removed: 'danger' }
return map[type] || 'info'
}
onMounted(() => {
fetchHardware()
fetchSoftware()
fetchChanges()
})
watch(activeTab, () => {
if (activeTab.value === 'hardware') fetchHardware()
else if (activeTab.value === 'software') fetchSoftware()
else fetchChanges()
})
</script>
<style scoped>
.assets-page { padding: 20px; }
.toolbar { display: flex; gap: 12px; margin-bottom: 16px; }
</style>

264
web/src/views/Dashboard.vue Normal file
View File

@@ -0,0 +1,264 @@
<template>
<div class="dashboard">
<el-row :gutter="20" class="stat-cards">
<el-col :span="6">
<el-card shadow="hover" class="stat-card">
<div class="stat-icon online"><el-icon :size="28"><Monitor /></el-icon></div>
<div class="stat-info">
<div class="stat-value">{{ stats.online }}</div>
<div class="stat-label">在线终端</div>
</div>
</el-card>
</el-col>
<el-col :span="6">
<el-card shadow="hover" class="stat-card">
<div class="stat-icon offline"><el-icon :size="28"><Platform /></el-icon></div>
<div class="stat-info">
<div class="stat-value">{{ stats.offline }}</div>
<div class="stat-label">离线终端</div>
</div>
</el-card>
</el-col>
<el-col :span="6">
<el-card shadow="hover" class="stat-card">
<div class="stat-icon warning"><el-icon :size="28"><Bell /></el-icon></div>
<div class="stat-info">
<div class="stat-value">{{ stats.alerts }}</div>
<div class="stat-label">待处理告警</div>
</div>
</el-card>
</el-col>
<el-col :span="6">
<el-card shadow="hover" class="stat-card">
<div class="stat-icon usb"><el-icon :size="28"><Connection /></el-icon></div>
<div class="stat-info">
<div class="stat-value">{{ stats.usbEvents }}</div>
<div class="stat-label">USB事件(24h)</div>
</div>
</el-card>
</el-col>
</el-row>
<el-row :gutter="20" style="margin-top: 20px">
<el-col :span="16">
<el-card shadow="hover">
<template #header>
<span class="card-title">终端状态总览</span>
</template>
<div ref="cpuChartRef" style="height: 320px"></div>
</el-card>
</el-col>
<el-col :span="8">
<el-card shadow="hover">
<template #header>
<span class="card-title">最近告警</span>
</template>
<div class="alert-list">
<div v-for="alert in recentAlerts" :key="alert.id" class="alert-item">
<el-tag :type="severityTag(alert.severity)" size="small">{{ alert.severity }}</el-tag>
<span class="alert-detail">{{ alert.detail }}</span>
<span class="alert-time">{{ alert.triggered_at }}</span>
</div>
<el-empty v-if="recentAlerts.length === 0" description="暂无告警" :image-size="60" />
</div>
</el-card>
</el-col>
</el-row>
<el-row :gutter="20" style="margin-top: 20px">
<el-col :span="12">
<el-card shadow="hover">
<template #header>
<span class="card-title">最近USB事件</span>
</template>
<el-table :data="recentUsbEvents" size="small" max-height="240">
<el-table-column prop="device_name" label="设备" width="120" />
<el-table-column label="类型" width="80">
<template #default="{ row }">
<el-tag :type="row.event_type === 'Inserted' ? 'success' : row.event_type === 'Blocked' ? 'danger' : 'info'" size="small">
{{ eventTypeLabel(row.event_type) }}
</el-tag>
</template>
</el-table-column>
<el-table-column prop="device_uid" label="终端" show-overflow-tooltip />
<el-table-column prop="event_time" label="时间" width="160" />
</el-table>
</el-card>
</el-col>
<el-col :span="12">
<el-card shadow="hover">
<template #header>
<span class="card-title">Top 5 高负载终端</span>
</template>
<el-table :data="topDevices" size="small" max-height="240">
<el-table-column prop="hostname" label="主机名" />
<el-table-column label="CPU" width="140">
<template #default="{ row }">
<el-progress :percentage="Math.round(row.cpu_usage)" :stroke-width="6" :color="progressColor(row.cpu_usage)" />
</template>
</el-table-column>
<el-table-column label="内存" width="140">
<template #default="{ row }">
<el-progress :percentage="Math.round(row.memory_usage)" :stroke-width="6" :color="progressColor(row.memory_usage)" />
</template>
</el-table-column>
<el-table-column prop="status" label="状态" width="80">
<template #default="{ row }">
<el-tag :type="row.status === 'online' ? 'success' : 'info'" size="small">{{ row.status }}</el-tag>
</template>
</el-table-column>
</el-table>
</el-card>
</el-col>
</el-row>
</div>
</template>
<script setup lang="ts">
import { ref, onMounted, onUnmounted } from 'vue'
import { Monitor, Platform, Bell, Connection } from '@element-plus/icons-vue'
import * as echarts from 'echarts'
import { api } from '@/lib/api'
const stats = ref({ online: 0, offline: 0, alerts: 0, usbEvents: 0 })
const recentAlerts = ref<Array<{ id: number; severity: string; detail: string; triggered_at: string }>>([])
const recentUsbEvents = ref<Array<{ device_name: string; event_type: string; device_uid: string; event_time: string }>>([])
const topDevices = ref<Array<{ hostname: string; cpu_usage: number; memory_usage: number; status: string }>>([])
const cpuChartRef = ref<HTMLElement>()
let chart: echarts.ECharts | null = null
let timer: ReturnType<typeof setInterval> | null = null
let resizeHandler: (() => void) | null = null
async function fetchDashboard() {
try {
const [devicesData, alertsData, usbData] = await Promise.all([
api.get<any>('/api/devices'),
api.get<any>('/api/alerts/records?handled=0&page_size=10'),
api.get<any>('/api/usb/events?page_size=10'),
])
const devices = devicesData.devices || []
stats.value.online = devices.filter((d: any) => d.status === 'online').length
stats.value.offline = devices.filter((d: any) => d.status === 'offline').length
topDevices.value = devices
.filter((d: any) => d.cpu_usage !== undefined)
.sort((a: any, b: any) => (b.cpu_usage || 0) - (a.cpu_usage || 0))
.slice(0, 5)
updateChart(devices)
const records = alertsData.records || []
stats.value.alerts = records.length
recentAlerts.value = records.slice(0, 8)
const events = usbData.events || []
stats.value.usbEvents = events.length
recentUsbEvents.value = events.slice(0, 8)
} catch {
// Silently fail - dashboard gracefully shows zeros
}
}
function initChart() {
if (!cpuChartRef.value) return
chart = echarts.init(cpuChartRef.value)
chart.setOption({
tooltip: { trigger: 'axis' },
legend: { data: ['CPU%', '内存%'] },
grid: { left: 50, right: 20, bottom: 30, top: 40 },
xAxis: { type: 'category', data: [] },
yAxis: { type: 'value', max: 100, axisLabel: { formatter: '{value}%' } },
series: [
{ name: 'CPU%', type: 'bar', data: [], itemStyle: { color: '#409EFF' } },
{ name: '内存%', type: 'bar', data: [], itemStyle: { color: '#67C23A' } },
],
})
}
function updateChart(devices: any[]) {
if (!chart) return
const top = devices
.filter((d: any) => d.cpu_usage !== undefined)
.sort((a: any, b: any) => (b.cpu_usage || 0) - (a.cpu_usage || 0))
.slice(0, 10)
chart.setOption({
xAxis: { data: top.map((d: any) => d.hostname || d.device_uid) },
series: [
{ data: top.map((d: any) => d.cpu_usage?.toFixed(1) || 0) },
{ data: top.map((d: any) => d.memory_usage?.toFixed(1) || 0) },
],
})
}
function severityTag(severity: string) {
const map: Record<string, string> = { critical: 'danger', high: 'warning', medium: '', low: 'info' }
return map[severity] || 'info'
}
function eventTypeLabel(type: string) {
const map: Record<string, string> = { Inserted: '插入', Removed: '拔出', Blocked: '拦截' }
return map[type] || type
}
function progressColor(value: number) {
if (value > 90) return '#F56C6C'
if (value > 70) return '#E6A23C'
return '#67C23A'
}
onMounted(() => {
fetchDashboard()
initChart()
timer = setInterval(fetchDashboard, 30000)
resizeHandler = () => chart?.resize()
window.addEventListener('resize', resizeHandler)
})
onUnmounted(() => {
if (timer) clearInterval(timer)
if (resizeHandler) window.removeEventListener('resize', resizeHandler)
chart?.dispose()
})
</script>
<style scoped>
.dashboard { padding: 20px; }
.stat-cards .stat-card {
display: flex;
align-items: center;
padding: 20px;
}
.stat-icon {
width: 56px;
height: 56px;
border-radius: 12px;
display: flex;
align-items: center;
justify-content: center;
margin-right: 16px;
color: #fff;
}
.stat-icon.online { background: linear-gradient(135deg, #67C23A, #409EFF); }
.stat-icon.offline { background: linear-gradient(135deg, #909399, #606266); }
.stat-icon.warning { background: linear-gradient(135deg, #E6A23C, #F56C6C); }
.stat-icon.usb { background: linear-gradient(135deg, #409EFF, #7C3AED); }
.stat-value { font-size: 28px; font-weight: 700; color: #303133; }
.stat-label { font-size: 13px; color: #909399; margin-top: 4px; }
.card-title { font-weight: 600; font-size: 15px; }
.alert-list { max-height: 320px; overflow-y: auto; }
.alert-item {
display: flex;
align-items: center;
gap: 8px;
padding: 8px 0;
border-bottom: 1px solid #f0f0f0;
}
.alert-detail { flex: 1; font-size: 13px; overflow: hidden; text-overflow: ellipsis; white-space: nowrap; }
.alert-time { font-size: 12px; color: #C0C4CC; white-space: nowrap; }
</style>

View File

@@ -0,0 +1,131 @@
<template>
<div class="device-detail" v-loading="loading">
<el-page-header @back="$router.back()" :title="'返回'">
<template #content>
<span>{{ device?.hostname || deviceUid }}</span>
<el-tag v-if="device" :type="device.status === 'online' ? 'success' : 'info'" size="small" style="margin-left: 8px">
{{ device.status === 'online' ? '在线' : '离线' }}
</el-tag>
</template>
</el-page-header>
<el-row :gutter="20" style="margin-top: 20px">
<el-col :span="8">
<el-card shadow="hover">
<template #header><span class="card-title">基本信息</span></template>
<el-descriptions :column="1" size="small" border>
<el-descriptions-item label="设备UID">{{ device?.device_uid }}</el-descriptions-item>
<el-descriptions-item label="主机名">{{ device?.hostname }}</el-descriptions-item>
<el-descriptions-item label="IP地址">{{ device?.ip_address }}</el-descriptions-item>
<el-descriptions-item label="MAC地址">{{ device?.mac_address || '-' }}</el-descriptions-item>
<el-descriptions-item label="操作系统">{{ device?.os_version || '-' }}</el-descriptions-item>
<el-descriptions-item label="客户端版本">{{ device?.client_version || '-' }}</el-descriptions-item>
<el-descriptions-item label="分组">{{ device?.group_name || '-' }}</el-descriptions-item>
<el-descriptions-item label="注册时间">{{ device?.registered_at || '-' }}</el-descriptions-item>
<el-descriptions-item label="最后心跳">{{ device?.last_heartbeat || '-' }}</el-descriptions-item>
</el-descriptions>
</el-card>
</el-col>
<el-col :span="16">
<el-card shadow="hover" style="margin-bottom: 20px">
<template #header><span class="card-title">实时状态</span></template>
<el-row :gutter="20" v-if="status">
<el-col :span="6">
<div class="metric">
<div class="metric-label">CPU</div>
<el-progress type="dashboard" :percentage="Math.round(status.cpu_usage)" :width="100"
:color="progressColor(status.cpu_usage)" />
</div>
</el-col>
<el-col :span="6">
<div class="metric">
<div class="metric-label">内存</div>
<el-progress type="dashboard" :percentage="Math.round(status.memory_usage)" :width="100"
:color="progressColor(status.memory_usage)" />
<div class="metric-sub">{{ formatMB(status.memory_total_mb) }}</div>
</div>
</el-col>
<el-col :span="6">
<div class="metric">
<div class="metric-label">磁盘</div>
<el-progress type="dashboard" :percentage="Math.round(status.disk_usage)" :width="100"
:color="progressColor(status.disk_usage)" />
<div class="metric-sub">{{ formatMB(status.disk_total_mb) }}</div>
</div>
</el-col>
<el-col :span="6">
<div class="metric">
<div class="metric-label">进程</div>
<div class="metric-value">{{ status.running_procs }}</div>
</div>
</el-col>
</el-row>
<el-empty v-else description="暂无状态数据" :image-size="60" />
</el-card>
<el-card shadow="hover">
<template #header><span class="card-title">Top 进程</span></template>
<el-table :data="status?.top_processes || []" size="small" max-height="200">
<el-table-column prop="name" label="进程名" />
<el-table-column prop="pid" label="PID" width="80" />
<el-table-column label="CPU" width="120">
<template #default="{ row }">
<el-progress :percentage="Math.min(Math.round(row.cpu_usage), 100)" :stroke-width="6" :color="progressColor(row.cpu_usage)" />
</template>
</el-table-column>
<el-table-column label="内存" width="100">
<template #default="{ row }">{{ formatMB(row.memory_mb) }}</template>
</el-table-column>
</el-table>
</el-card>
</el-col>
</el-row>
</div>
</template>
<script setup lang="ts">
import { ref, onMounted } from 'vue'
import { useRoute } from 'vue-router'
import { api } from '@/lib/api'
const route = useRoute()
const deviceUid = route.params.uid as string
const loading = ref(true)
const device = ref<any>(null)
const status = ref<any>(null)
function progressColor(value: number) {
if (value > 90) return '#F56C6C'
if (value > 70) return '#E6A23C'
return '#67C23A'
}
function formatMB(mb: number) {
if (mb >= 1024) return `${(mb / 1024).toFixed(1)} GB`
return `${mb} MB`
}
onMounted(async () => {
try {
const [devData, statData] = await Promise.all([
api.get<any>(`/api/devices/${deviceUid}`),
api.get<any>(`/api/devices/${deviceUid}/status`),
])
device.value = devData
status.value = statData
} catch { /* api.ts handles 401 */ } finally {
loading.value = false
}
})
</script>
<style scoped>
.device-detail { padding: 20px; }
.card-title { font-weight: 600; font-size: 15px; }
.metric { text-align: center; padding: 10px 0; }
.metric-label { font-size: 13px; color: #909399; margin-bottom: 8px; }
.metric-value { font-size: 32px; font-weight: 700; color: #303133; margin-top: 16px; }
.metric-sub { font-size: 12px; color: #909399; margin-top: 4px; }
</style>

108
web/src/views/Devices.vue Normal file
View File

@@ -0,0 +1,108 @@
<template>
<div class="devices-page">
<div class="toolbar">
<el-input v-model="search" placeholder="搜索主机名/IP" style="width: 300px" clearable @input="handleSearch" />
<el-select v-model="statusFilter" placeholder="状态" clearable style="width: 120px" @change="handleSearch">
<el-option label="在线" value="online" />
<el-option label="离线" value="offline" />
</el-select>
<el-select v-model="groupFilter" placeholder="分组" clearable style="width: 150px" @change="handleSearch">
<el-option label="默认组" value="default" />
</el-select>
</div>
<el-table :data="deviceStore.devices" v-loading="deviceStore.loading" stripe @row-click="handleRowClick">
<el-table-column prop="hostname" label="主机名" min-width="150" />
<el-table-column prop="ip_address" label="IP地址" width="150" />
<el-table-column prop="group_name" label="分组" width="120" />
<el-table-column label="状态" width="100">
<template #default="{ row }">
<el-tag :type="row.status === 'online' ? 'success' : 'info'" size="small">
{{ row.status === 'online' ? '在线' : '离线' }}
</el-tag>
</template>
</el-table-column>
<el-table-column label="CPU" width="120">
<template #default="{ row }">
<el-progress :percentage="row.cpu_usage ?? 0" :stroke-width="6" :color="getProgressColor(row.cpu_usage)" />
</template>
</el-table-column>
<el-table-column label="内存" width="120">
<template #default="{ row }">
<el-progress :percentage="row.memory_usage ?? 0" :stroke-width="6" :color="getProgressColor(row.memory_usage)" />
</template>
</el-table-column>
<el-table-column prop="last_heartbeat" label="最后心跳" width="180" />
<el-table-column label="操作" width="100" fixed="right">
<template #default="{ row }">
<el-button type="danger" link size="small" @click.stop="handleDelete(row)">移除</el-button>
</template>
</el-table-column>
</el-table>
<el-pagination
style="margin-top: 20px; justify-content: flex-end"
:total="deviceStore.total"
:page-size="20"
layout="total, prev, pager, next"
@current-change="handlePageChange"
/>
</div>
</template>
<script setup lang="ts">
import { ref, onMounted } from 'vue'
import { useRouter } from 'vue-router'
import { ElMessage, ElMessageBox } from 'element-plus'
import { useDeviceStore } from '../stores/devices'
import type { Device } from '../stores/devices'
const router = useRouter()
const deviceStore = useDeviceStore()
const search = ref('')
const statusFilter = ref('')
const groupFilter = ref('')
const currentPage = ref(1)
onMounted(() => {
deviceStore.fetchDevices()
})
function handleSearch() {
currentPage.value = 1
deviceStore.fetchDevices({
search: search.value,
status: statusFilter.value,
group: groupFilter.value,
page: '1',
})
}
function handlePageChange(page: number) {
currentPage.value = page
deviceStore.fetchDevices({ page: String(page) })
}
function handleRowClick(row: Device) {
router.push(`/devices/${row.device_uid}`)
}
async function handleDelete(row: Device) {
await ElMessageBox.confirm(`确定移除设备 ${row.hostname}?`, '确认', { type: 'warning' })
await deviceStore.removeDevice(row.device_uid)
ElMessage.success('设备已移除')
}
function getProgressColor(value?: number): string {
if (!value) return '#67C23A'
if (value > 90) return '#F56C6C'
if (value > 70) return '#E6A23C'
return '#67C23A'
}
</script>
<style scoped>
.devices-page { padding: 20px; }
.toolbar { display: flex; gap: 12px; margin-bottom: 20px; }
</style>

189
web/src/views/Layout.vue Normal file
View File

@@ -0,0 +1,189 @@
<template>
<el-container class="app-container">
<el-aside width="220px" class="sidebar">
<div class="logo">
<h2>CSM</h2>
<span>终端管理系统</span>
</div>
<el-menu
:default-active="currentRoute"
router
background-color="#1d1e2c"
text-color="#a0a3bd"
active-text-color="#409eff"
>
<el-menu-item index="/dashboard">
<el-icon><Monitor /></el-icon>
<span>仪表盘</span>
</el-menu-item>
<el-menu-item index="/devices">
<el-icon><Platform /></el-icon>
<span>设备管理</span>
</el-menu-item>
<el-menu-item index="/assets">
<el-icon><Box /></el-icon>
<span>资产管理</span>
</el-menu-item>
<el-menu-item index="/usb">
<el-icon><Connection /></el-icon>
<span>U盘管控</span>
</el-menu-item>
<el-menu-item index="/alerts">
<el-icon><Bell /></el-icon>
<span>告警中心</span>
</el-menu-item>
<el-sub-menu index="plugins">
<template #title>
<el-icon><Grid /></el-icon>
<span>安全插件</span>
</template>
<el-menu-item index="/plugins/web-filter">上网拦截</el-menu-item>
<el-menu-item index="/plugins/usage-timer">时长记录</el-menu-item>
<el-menu-item index="/plugins/software-blocker">软件管控</el-menu-item>
<el-menu-item index="/plugins/popup-blocker">弹窗拦截</el-menu-item>
<el-menu-item index="/plugins/usb-file-audit">U盘审计</el-menu-item>
<el-menu-item index="/plugins/watermark">水印管理</el-menu-item>
</el-sub-menu>
<el-menu-item index="/settings">
<el-icon><Setting /></el-icon>
<span>系统设置</span>
</el-menu-item>
</el-menu>
</el-aside>
<el-container>
<el-header class="app-header">
<div class="header-left">
<span class="page-title">{{ pageTitle }}</span>
</div>
<div class="header-right">
<el-badge :value="unreadAlerts" :hidden="unreadAlerts === 0">
<el-icon :size="20"><Bell /></el-icon>
</el-badge>
<el-dropdown>
<span class="user-info">
{{ username }} <el-icon><ArrowDown /></el-icon>
</span>
<template #dropdown>
<el-dropdown-menu>
<el-dropdown-item @click="handleLogout">退出登录</el-dropdown-item>
</el-dropdown-menu>
</template>
</el-dropdown>
</div>
</el-header>
<el-main>
<router-view />
</el-main>
</el-container>
</el-container>
</template>
<script setup lang="ts">
import { computed, ref, onMounted } from 'vue'
import { useRoute, useRouter } from 'vue-router'
import {
Monitor, Platform, Box, Connection, Bell, Setting, ArrowDown, Grid
} from '@element-plus/icons-vue'
import { api } from '@/lib/api'
const route = useRoute()
const router = useRouter()
const currentRoute = computed(() => route.path)
const unreadAlerts = ref(0)
const username = ref('')
function decodeUsername(): string {
try {
const token = localStorage.getItem('token')
if (!token) return ''
const payload = JSON.parse(atob(token.split('.')[1]))
return payload.username || ''
} catch {
return ''
}
}
async function fetchUnreadAlerts() {
try {
const data = await api.get<any>('/api/alerts/records?handled=0&page_size=1')
unreadAlerts.value = data.records?.length || 0
} catch {
// Silently fail
}
}
const pageTitles: Record<string, string> = {
'/dashboard': '仪表盘',
'/devices': '设备管理',
'/assets': '资产管理',
'/usb': 'U盘管控',
'/alerts': '告警中心',
'/settings': '系统设置',
'/plugins/web-filter': '上网拦截',
'/plugins/usage-timer': '时长记录',
'/plugins/software-blocker': '软件管控',
'/plugins/popup-blocker': '弹窗拦截',
'/plugins/usb-file-audit': 'U盘审计',
'/plugins/watermark': '水印管理',
}
const pageTitle = computed(() => pageTitles[route.path] || '仪表盘')
onMounted(() => {
username.value = decodeUsername()
fetchUnreadAlerts()
})
function handleLogout() {
localStorage.removeItem('token')
localStorage.removeItem('refresh_token')
router.push('/login')
}
</script>
<style scoped>
.app-container { height: 100vh; }
.sidebar {
background-color: #1d1e2c;
overflow-y: auto;
}
.logo {
padding: 20px;
text-align: center;
color: #fff;
border-bottom: 1px solid #2d2e3e;
}
.logo h2 { font-size: 24px; margin-bottom: 4px; }
.logo span { font-size: 12px; color: #a0a3bd; }
.app-header {
display: flex;
align-items: center;
justify-content: space-between;
border-bottom: 1px solid #e4e7ed;
background: #fff;
}
.page-title { font-size: 18px; font-weight: 600; }
.header-right {
display: flex;
align-items: center;
gap: 20px;
}
.user-info {
display: flex;
align-items: center;
gap: 4px;
cursor: pointer;
color: #606266;
}
</style>

119
web/src/views/Login.vue Normal file
View File

@@ -0,0 +1,119 @@
<template>
<div class="login-container">
<div class="login-card">
<div class="login-header">
<h2>CSM</h2>
<p>终端管理系统</p>
</div>
<el-form ref="formRef" :model="form" :rules="rules" @submit.prevent="handleLogin">
<el-form-item prop="username">
<el-input
v-model="form.username"
placeholder="用户名"
:prefix-icon="User"
size="large"
@keyup.enter="handleLogin"
/>
</el-form-item>
<el-form-item prop="password">
<el-input
v-model="form.password"
type="password"
placeholder="密码"
:prefix-icon="Lock"
size="large"
show-password
@keyup.enter="handleLogin"
/>
</el-form-item>
<el-form-item>
<el-button
type="primary"
size="large"
:loading="loading"
style="width: 100%"
@click="handleLogin"
>
登录
</el-button>
</el-form-item>
</el-form>
</div>
</div>
</template>
<script setup lang="ts">
import { ref, reactive } from 'vue'
import { useRouter } from 'vue-router'
import { ElMessage } from 'element-plus'
import { User, Lock } from '@element-plus/icons-vue'
import { api, ApiError } from '../lib/api'
const router = useRouter()
const formRef = ref()
const loading = ref(false)
const form = reactive({
username: '',
password: '',
})
const rules = {
username: [{ required: true, message: '请输入用户名', trigger: 'blur' }],
password: [{ required: true, message: '请输入密码', trigger: 'blur' }],
}
async function handleLogin() {
const valid = await formRef.value?.validate().catch(() => false)
if (!valid) return
loading.value = true
try {
const data = await api.login(form.username, form.password)
ElMessage.success(`欢迎, ${data.user.username}`)
router.push('/dashboard')
} catch (e) {
if (e instanceof ApiError) {
ElMessage.error(e.message || '登录失败')
} else {
ElMessage.error('网络错误,请检查连接')
}
} finally {
loading.value = false
}
}
</script>
<style scoped>
.login-container {
height: 100vh;
display: flex;
align-items: center;
justify-content: center;
background: linear-gradient(135deg, #1d1e2c 0%, #2d3a4a 100%);
}
.login-card {
width: 400px;
padding: 40px;
background: #fff;
border-radius: 12px;
box-shadow: 0 8px 32px rgba(0, 0, 0, 0.15);
}
.login-header {
text-align: center;
margin-bottom: 32px;
}
.login-header h2 {
font-size: 32px;
color: #303133;
margin-bottom: 8px;
}
.login-header p {
font-size: 14px;
color: #909399;
}
</style>

115
web/src/views/Settings.vue Normal file
View File

@@ -0,0 +1,115 @@
<template>
<div class="settings-page">
<el-row :gutter="20">
<el-col :span="12">
<el-card shadow="hover">
<template #header><span class="card-title">系统信息</span></template>
<el-descriptions :column="1" border size="small">
<el-descriptions-item label="系统版本">v{{ version }}</el-descriptions-item>
<el-descriptions-item label="数据库">{{ dbInfo }}</el-descriptions-item>
<el-descriptions-item label="在线终端">{{ health.connected_clients }}</el-descriptions-item>
</el-descriptions>
</el-card>
<el-card shadow="hover" style="margin-top: 20px">
<template #header><span class="card-title">修改密码</span></template>
<el-form :model="pwdForm" label-width="100px" size="small">
<el-form-item label="当前密码">
<el-input v-model="pwdForm.oldPassword" type="password" show-password />
</el-form-item>
<el-form-item label="新密码">
<el-input v-model="pwdForm.newPassword" type="password" show-password />
</el-form-item>
<el-form-item label="确认密码">
<el-input v-model="pwdForm.confirmPassword" type="password" show-password />
</el-form-item>
<el-form-item>
<el-button type="primary" @click="changePassword">修改密码</el-button>
</el-form-item>
</el-form>
</el-card>
</el-col>
<el-col :span="12">
<el-card shadow="hover">
<template #header><span class="card-title">数据维护</span></template>
<el-form label-width="100px" size="small">
<el-form-item label="历史数据">
<el-button @click="showRetentionInfo">查看保留策略</el-button>
</el-form-item>
<el-form-item label="数据库">
<el-button type="warning" @click="manualCleanup">手动清理</el-button>
</el-form-item>
</el-form>
</el-card>
<el-card shadow="hover" style="margin-top: 20px">
<template #header><span class="card-title">当前用户</span></template>
<el-descriptions :column="1" border size="small">
<el-descriptions-item label="用户名">{{ user.username }}</el-descriptions-item>
<el-descriptions-item label="角色">{{ user.role }}</el-descriptions-item>
</el-descriptions>
</el-card>
</el-col>
</el-row>
</div>
</template>
<script setup lang="ts">
import { ref, reactive, onMounted } from 'vue'
import { ElMessage } from 'element-plus'
import { api } from '@/lib/api'
const version = ref('0.1.0')
const dbInfo = ref('SQLite (WAL mode)')
const health = reactive({ connected_clients: 0, db_size_bytes: 0 })
const user = reactive({ username: 'admin', role: 'admin' })
const pwdForm = reactive({ oldPassword: '', newPassword: '', confirmPassword: '' })
onMounted(() => {
// Decode username from JWT token
try {
const token = localStorage.getItem('token')
if (token) {
const payload = JSON.parse(atob(token.split('.')[1]))
user.username = payload.username || 'admin'
user.role = payload.role || 'admin'
}
} catch { /* ignore */ }
api.get<any>('/health')
.then((data: any) => {
if (data.version) version.value = data.version
health.connected_clients = data.connected_clients || 0
const bytes = data.db_size_bytes || 0
dbInfo.value = `SQLite (WAL mode) - ${(bytes / 1024 / 1024).toFixed(2)} MB`
})
.catch(() => { /* ignore */ })
})
function changePassword() {
if (pwdForm.newPassword !== pwdForm.confirmPassword) {
ElMessage.error('两次输入的密码不一致')
return
}
if (pwdForm.newPassword.length < 6) {
ElMessage.error('密码至少6位')
return
}
ElMessage.success('密码修改功能待实现')
}
function showRetentionInfo() {
ElMessage.info('数据保留策略在 config.toml 中配置')
}
function manualCleanup() {
ElMessage.warning('手动清理功能需通过服务器配置触发')
}
</script>
<style scoped>
.settings-page { padding: 20px; }
.card-title { font-weight: 600; font-size: 15px; }
</style>

192
web/src/views/UsbPolicy.vue Normal file
View File

@@ -0,0 +1,192 @@
<template>
<div class="usb-page">
<el-tabs v-model="activeTab">
<el-tab-pane label="策略管理" name="policies">
<div class="toolbar">
<el-button type="primary" @click="showPolicyDialog()">新建策略</el-button>
</div>
<el-table :data="policies" v-loading="loading" stripe size="small">
<el-table-column prop="name" label="策略名称" width="180" />
<el-table-column prop="policy_type" label="策略类型" width="120">
<template #default="{ row }">
<el-tag :type="policyTypeTag(row.policy_type)" size="small">{{ policyTypeLabel(row.policy_type) }}</el-tag>
</template>
</el-table-column>
<el-table-column prop="target_group" label="目标分组" width="120" />
<el-table-column prop="enabled" label="启用" width="80">
<template #default="{ row }">
<el-switch :model-value="row.enabled" @change="togglePolicy(row)" size="small" />
</template>
</el-table-column>
<el-table-column prop="created_at" label="创建时间" width="170" />
<el-table-column label="操作" width="140" fixed="right">
<template #default="{ row }">
<el-button link type="primary" size="small" @click="showPolicyDialog(row)">编辑</el-button>
<el-button link type="danger" size="small" @click="deletePolicy(row.id)">删除</el-button>
</template>
</el-table-column>
</el-table>
</el-tab-pane>
<el-tab-pane label="事件日志" name="events">
<div class="toolbar">
<el-select v-model="eventFilter" placeholder="事件类型" clearable style="width: 150px" @change="fetchEvents">
<el-option label="插入" value="Inserted" />
<el-option label="拔出" value="Removed" />
<el-option label="拦截" value="Blocked" />
</el-select>
</div>
<el-table :data="events" v-loading="evLoading" stripe size="small">
<el-table-column prop="device_name" label="USB设备" width="150" />
<el-table-column label="事件类型" width="100">
<template #default="{ row }">
<el-tag :type="row.event_type === 'Inserted' ? 'success' : row.event_type === 'Blocked' ? 'danger' : 'info'" size="small">
{{ eventTypeLabel(row.event_type) }}
</el-tag>
</template>
</el-table-column>
<el-table-column prop="vendor_id" label="VID" width="100" />
<el-table-column prop="product_id" label="PID" width="100" />
<el-table-column prop="serial_number" label="序列号" width="160" />
<el-table-column prop="device_uid" label="终端UID" min-width="160" show-overflow-tooltip />
<el-table-column prop="event_time" label="时间" width="170" />
</el-table>
</el-tab-pane>
</el-tabs>
<el-dialog v-model="policyDialogVisible" :title="editingPolicy ? '编辑策略' : '新建策略'" width="500px">
<el-form :model="policyForm" label-width="100px">
<el-form-item label="策略名称">
<el-input v-model="policyForm.name" />
</el-form-item>
<el-form-item label="策略类型">
<el-select v-model="policyForm.policy_type" style="width: 100%">
<el-option label="全部拦截" value="all_block" />
<el-option label="白名单" value="whitelist" />
<el-option label="黑名单" value="blacklist" />
</el-select>
</el-form-item>
<el-form-item label="目标分组">
<el-input v-model="policyForm.target_group" placeholder="留空表示全部终端" />
</el-form-item>
<el-form-item label="设备规则">
<el-input v-model="policyForm.rules" type="textarea" :rows="3" placeholder='[{"vendor_id":"1234","product_id":"5678"}]' />
</el-form-item>
</el-form>
<template #footer>
<el-button @click="policyDialogVisible = false">取消</el-button>
<el-button type="primary" @click="savePolicy">保存</el-button>
</template>
</el-dialog>
</div>
</template>
<script setup lang="ts">
import { ref, reactive, onMounted } from 'vue'
import { ElMessage, ElMessageBox } from 'element-plus'
import { api } from '@/lib/api'
const activeTab = ref('policies')
// Policies
const policies = ref<any[]>([])
const loading = ref(false)
const policyDialogVisible = ref(false)
const editingPolicy = ref<any>(null)
const policyForm = reactive({ name: '', policy_type: 'all_block', target_group: '', rules: '[]' })
async function fetchPolicies() {
loading.value = true
try {
const data = await api.get<any>('/api/usb/policies')
policies.value = data.policies || []
} catch { /* api.ts handles 401 */ } finally { loading.value = false }
}
function showPolicyDialog(row?: any) {
if (row) {
editingPolicy.value = row
policyForm.name = row.name
policyForm.policy_type = row.policy_type
policyForm.target_group = row.target_group || ''
policyForm.rules = row.rules || '[]'
} else {
editingPolicy.value = null
policyForm.name = ''
policyForm.policy_type = 'all_block'
policyForm.target_group = ''
policyForm.rules = '[]'
}
policyDialogVisible.value = true
}
async function savePolicy() {
try {
if (editingPolicy.value) {
await api.put(`/api/usb/policies/${editingPolicy.value.id}`, policyForm)
ElMessage.success('策略已更新')
} else {
await api.post('/api/usb/policies', policyForm)
ElMessage.success('策略已创建')
}
policyDialogVisible.value = false
fetchPolicies()
} catch (e: any) { ElMessage.error(e.message || '操作失败') }
}
async function togglePolicy(row: any) {
try {
await api.put(`/api/usb/policies/${row.id}`, { enabled: !row.enabled ? 1 : 0 })
fetchPolicies()
} catch { /* ignore */ }
}
async function deletePolicy(id: number) {
await ElMessageBox.confirm('确定删除该策略?', '确认', { type: 'warning' })
try {
await api.delete(`/api/usb/policies/${id}`)
ElMessage.success('策略已删除')
fetchPolicies()
} catch (e: any) { ElMessage.error(e.message || '删除失败') }
}
function policyTypeTag(type: string) {
const map: Record<string, string> = { all_block: 'danger', whitelist: 'success', blacklist: 'warning' }
return map[type] || 'info'
}
function policyTypeLabel(type: string) {
const map: Record<string, string> = { all_block: '全部拦截', whitelist: '白名单', blacklist: '黑名单' }
return map[type] || type
}
// Events
const events = ref<any[]>([])
const evLoading = ref(false)
const eventFilter = ref('')
async function fetchEvents() {
evLoading.value = true
try {
const params = new URLSearchParams()
if (eventFilter.value) params.set('event_type', eventFilter.value)
const data = await api.get<any>(`/api/usb/events?${params}`)
events.value = data.events || []
} catch { /* api.ts handles 401 */ } finally { evLoading.value = false }
}
function eventTypeLabel(type: string) {
const map: Record<string, string> = { Inserted: '插入', Removed: '拔出', Blocked: '拦截' }
return map[type] || type
}
onMounted(() => {
fetchPolicies()
fetchEvents()
})
</script>
<style scoped>
.usb-page { padding: 20px; }
.toolbar { display: flex; gap: 12px; margin-bottom: 16px; }
</style>

View File

@@ -0,0 +1,56 @@
<template>
<div class="plugin-page">
<el-tabs v-model="activeTab">
<el-tab-pane label="拦截规则" name="rules">
<div class="toolbar"><el-button type="primary" @click="showDialog()">新建规则</el-button></div>
<el-table :data="rules" v-loading="loading" stripe size="small">
<el-table-column prop="rule_type" label="类型" width="80">
<template #default="{ row }"><el-tag :type="row.rule_type==='block'?'danger':'success'" size="small">{{ row.rule_type==='block'?'拦截':'放行' }}</el-tag></template>
</el-table-column>
<el-table-column prop="window_title" label="窗口标题" min-width="180" show-overflow-tooltip />
<el-table-column prop="window_class" label="窗口类" width="140" show-overflow-tooltip />
<el-table-column prop="process_name" label="进程" width="140" show-overflow-tooltip />
<el-table-column prop="enabled" label="启用" width="70">
<template #default="{ row }"><el-tag :type="row.enabled?'success':'info'" size="small">{{ row.enabled?'是':'否' }}</el-tag></template>
</el-table-column>
<el-table-column label="操作" width="120" fixed="right">
<template #default="{ row }">
<el-button link type="primary" size="small" @click="showDialog(row)">编辑</el-button>
<el-button link type="danger" size="small" @click="remove(row.id)">删除</el-button>
</template>
</el-table-column>
</el-table>
</el-tab-pane>
<el-tab-pane label="拦截统计" name="stats">
<el-table :data="stats" v-loading="sLoading" stripe size="small">
<el-table-column prop="device_uid" label="终端" min-width="160" show-overflow-tooltip />
<el-table-column prop="blocked_count" label="拦截次数" width="120" />
<el-table-column prop="date" label="日期" width="120" />
</el-table>
</el-tab-pane>
</el-tabs>
<el-dialog v-model="visible" :title="editing?'编辑规则':'新建规则'" width="480px">
<el-form :model="form" label-width="80px">
<el-form-item label="类型"><el-select v-model="form.rule_type"><el-option label="拦截" value="block" /><el-option label="放行" value="allow" /></el-select></el-form-item>
<el-form-item label="窗口标题"><el-input v-model="form.window_title" placeholder="匹配模式(支持*通配符)" /></el-form-item>
<el-form-item label="窗口类"><el-input v-model="form.window_class" /></el-form-item>
<el-form-item label="进程名"><el-input v-model="form.process_name" /></el-form-item>
</el-form>
<template #footer><el-button @click="visible=false">取消</el-button><el-button type="primary" @click="save">保存</el-button></template>
</el-dialog>
</div>
</template>
<script setup lang="ts">
import { ref, reactive, onMounted } from 'vue'
import { ElMessage, ElMessageBox } from 'element-plus'
const activeTab=ref('rules'),auth=()=>({headers:{Authorization:`Bearer ${localStorage.getItem('token')}`,'Content-Type':'application/json'}})
const rules=ref<any[]>([]),loading=ref(false),stats=ref<any[]>([]),sLoading=ref(false),visible=ref(false),editing=ref<any>(null)
const form=reactive({rule_type:'block',window_title:'',window_class:'',process_name:''})
async function fetchRules(){loading.value=true;try{const r=await fetch('/api/plugins/popup-blocker/rules',auth()).then(r=>r.json());if(r.success)rules.value=r.data.rules||[]}finally{loading.value=false}}
async function fetchStats(){sLoading.value=true;try{const r=await fetch('/api/plugins/popup-blocker/stats',auth()).then(r=>r.json());if(r.success)stats.value=r.data.stats||[]}finally{sLoading.value=false}}
function showDialog(row?:any){if(row){editing.value=row;Object.assign(form,{rule_type:row.rule_type,window_title:row.window_title||'',window_class:row.window_class||'',process_name:row.process_name||''})}else{editing.value=null;Object.assign(form,{rule_type:'block',window_title:'',window_class:'',process_name:''})}visible.value=true}
async function save(){const url=editing.value?`/api/plugins/popup-blocker/rules/${editing.value.id}`:'/api/plugins/popup-blocker/rules';const m=editing.value?'PUT':'POST';const r=await fetch(url,{method:m,...auth(),body:JSON.stringify(form)}).then(r=>r.json());if(r.success){ElMessage.success('已保存');visible.value=false;fetchRules()}else{ElMessage.error(r.error)}}
async function remove(id:number){await ElMessageBox.confirm('确定删除?','确认',{type:'warning'});await fetch(`/api/plugins/popup-blocker/rules/${id}`,{method:'DELETE',...auth()});ElMessage.success('已删除');fetchRules()}
onMounted(()=>{fetchRules();fetchStats()})
</script>
<style scoped>.plugin-page{padding:20px}.toolbar{display:flex;gap:12px;margin-bottom:16px}</style>

View File

@@ -0,0 +1,71 @@
<template>
<div class="plugin-page">
<el-tabs v-model="activeTab">
<el-tab-pane label="软件黑名单" name="blacklist">
<div class="toolbar">
<el-button type="primary" @click="showDialog()">添加规则</el-button>
</div>
<el-table :data="blacklist" v-loading="loading" stripe size="small">
<el-table-column prop="name_pattern" label="软件名称匹配" min-width="200" />
<el-table-column prop="category" label="分类" width="100">
<template #default="{ row }"><el-tag size="small">{{ catLabel(row.category) }}</el-tag></template>
</el-table-column>
<el-table-column prop="action" label="动作" width="100">
<template #default="{ row }"><el-tag :type="row.action==='block'?'danger':'warning'" size="small">{{ row.action==='block'?'阻止':'告警' }}</el-tag></template>
</el-table-column>
<el-table-column prop="target_type" label="范围" width="80" />
<el-table-column prop="enabled" label="启用" width="70">
<template #default="{ row }"><el-tag :type="row.enabled?'success':'info'" size="small">{{ row.enabled?'是':'否' }}</el-tag></template>
</el-table-column>
<el-table-column label="操作" width="120" fixed="right">
<template #default="{ row }">
<el-button link type="danger" size="small" @click="remove(row.id)">删除</el-button>
</template>
</el-table-column>
</el-table>
</el-tab-pane>
<el-tab-pane label="违规记录" name="violations">
<div class="toolbar"><el-input v-model="vFilter" placeholder="终端UID" style="width:200px" clearable @input="fetchViolations" /></div>
<el-table :data="violations" v-loading="vLoading" stripe size="small">
<el-table-column prop="device_uid" label="终端" width="160" show-overflow-tooltip />
<el-table-column prop="software_name" label="软件" min-width="200" />
<el-table-column prop="action_taken" label="处理动作" width="150">
<template #default="{ row }"><el-tag :type="actionTag(row.action_taken)" size="small">{{ actionLabel(row.action_taken) }}</el-tag></template>
</el-table-column>
<el-table-column prop="timestamp" label="时间" width="170" />
</el-table>
</el-tab-pane>
</el-tabs>
<el-dialog v-model="dialogVisible" title="添加黑名单规则" width="480px">
<el-form :model="form" label-width="100px">
<el-form-item label="软件名称"><el-input v-model="form.name_pattern" placeholder="*游戏* / *.exe" /></el-form-item>
<el-form-item label="分类"><el-select v-model="form.category"><el-option label="游戏" value="game" /><el-option label="社交" value="social" /><el-option label="VPN" value="vpn" /><el-option label="挖矿" value="mining" /><el-option label="自定义" value="custom" /></el-select></el-form-item>
<el-form-item label="动作"><el-select v-model="form.action"><el-option label="阻止安装" value="block" /><el-option label="仅告警" value="alert" /></el-select></el-form-item>
</el-form>
<template #footer><el-button @click="dialogVisible=false">取消</el-button><el-button type="primary" @click="save">保存</el-button></template>
</el-dialog>
</div>
</template>
<script setup lang="ts">
import { ref, reactive, onMounted } from 'vue'
import { ElMessage, ElMessageBox } from 'element-plus'
const activeTab = ref('blacklist')
const auth = () => ({ headers: { Authorization: `Bearer ${localStorage.getItem('token')}`, 'Content-Type': 'application/json' } })
const blacklist = ref<any[]>([])
const loading = ref(false)
const violations = ref<any[]>([])
const vLoading = ref(false)
const vFilter = ref('')
const dialogVisible = ref(false)
const form = reactive({ name_pattern: '', category: 'custom', action: 'block' })
function catLabel(c: string) { return { game:'游戏',social:'社交',vpn:'VPN',mining:'挖矿',custom:'自定义' }[c]||c }
function actionLabel(a: string) { return { blocked_install:'已阻止',auto_uninstalled:'已卸载',alerted:'已告警' }[a]||a }
function actionTag(a: string) { return { blocked_install:'danger',auto_uninstalled:'warning',alerted:'' }[a]||'info' }
async function fetchBlacklist() { loading.value=true; try{const r=await fetch('/api/plugins/software-blocker/blacklist',auth()).then(r=>r.json());if(r.success)blacklist.value=r.data.blacklist||[]}finally{loading.value=false} }
async function fetchViolations() { vLoading.value=true; try{const p=new URLSearchParams();if(vFilter.value)p.set('device_uid',vFilter.value);const r=await fetch(`/api/plugins/software-blocker/violations?${p}`,auth()).then(r=>r.json());if(r.success)violations.value=r.data.violations||[]}finally{vLoading.value=false} }
function showDialog(){Object.assign(form,{name_pattern:'',category:'custom',action:'block'});dialogVisible.value=true}
async function save(){const r=await fetch('/api/plugins/software-blocker/blacklist',{method:'POST',...auth(),body:JSON.stringify(form)}).then(r=>r.json());if(r.success){ElMessage.success('已添加');dialogVisible.value=false;fetchBlacklist()}else{ElMessage.error(r.error)}}
async function remove(id:number){await ElMessageBox.confirm('确定删除?','确认',{type:'warning'});await fetch(`/api/plugins/software-blocker/blacklist/${id}`,{method:'DELETE',...auth()});ElMessage.success('已删除');fetchBlacklist()}
onMounted(()=>{fetchBlacklist();fetchViolations()})
</script>
<style scoped>.plugin-page{padding:20px}.toolbar{display:flex;gap:12px;margin-bottom:16px}</style>

View File

@@ -0,0 +1,74 @@
<template>
<div class="plugin-page">
<el-tabs v-model="activeTab">
<el-tab-pane label="每日使用统计" name="daily">
<div class="toolbar">
<el-input v-model="uidFilter" placeholder="终端UID" style="width:200px" clearable @input="fetchDaily" />
</div>
<el-table :data="dailyData" v-loading="loading" stripe size="small">
<el-table-column prop="device_uid" label="终端" width="160" show-overflow-tooltip />
<el-table-column prop="date" label="日期" width="120" />
<el-table-column label="活跃时间" width="120">
<template #default="{ row }">{{ formatMinutes(row.total_active_minutes) }}</template>
</el-table-column>
<el-table-column label="空闲时间" width="120">
<template #default="{ row }">{{ formatMinutes(row.total_idle_minutes) }}</template>
</el-table-column>
<el-table-column prop="first_active_at" label="首次活跃" width="170" />
<el-table-column prop="last_active_at" label="最后活跃" width="170" />
</el-table>
</el-tab-pane>
<el-tab-pane label="应用使用详情" name="apps">
<div class="toolbar"><el-input v-model="appUid" placeholder="终端UID" style="width:200px" clearable @input="fetchApps" /></div>
<el-table :data="appData" v-loading="appLoading" stripe size="small">
<el-table-column prop="app_name" label="应用名称" min-width="200" />
<el-table-column prop="date" label="日期" width="120" />
<el-table-column label="使用时长" width="120">
<template #default="{ row }">{{ formatMinutes(row.usage_minutes) }}</template>
</el-table-column>
</el-table>
</el-tab-pane>
<el-tab-pane label="使用排行" name="leaderboard">
<el-table :data="board" v-loading="boardLoading" stripe size="small">
<el-table-column type="index" label="#" width="60" />
<el-table-column prop="device_uid" label="终端" min-width="200" />
<el-table-column label="7天总时长" width="140">
<template #default="{ row }">{{ formatMinutes(row.total_minutes) }}</template>
</el-table-column>
</el-table>
</el-tab-pane>
</el-tabs>
</div>
</template>
<script setup lang="ts">
import { ref, onMounted } from 'vue'
const activeTab = ref('daily')
const auth = () => ({ headers: { Authorization: `Bearer ${localStorage.getItem('token')}` } })
const uidFilter = ref('')
const dailyData = ref<any[]>([])
const loading = ref(false)
const appUid = ref('')
const appData = ref<any[]>([])
const appLoading = ref(false)
const board = ref<any[]>([])
const boardLoading = ref(false)
function formatMinutes(m: number) { if(m>=60) return `${Math.floor(m/60)}h${m%60}m`; return `${m}m` }
async function fetchDaily() {
loading.value=true
try{const params=new URLSearchParams();if(uidFilter.value)params.set('device_uid',uidFilter.value)
const r=await fetch(`/api/plugins/usage-timer/daily?${params}`,auth()).then(r=>r.json());if(r.success)dailyData.value=r.data.daily||[]}finally{loading.value=false}
}
async function fetchApps() {
appLoading.value=true
try{const params=new URLSearchParams();if(appUid.value)params.set('device_uid',appUid.value)
const r=await fetch(`/api/plugins/usage-timer/app-usage?${params}`,auth()).then(r=>r.json());if(r.success)appData.value=r.data.app_usage||[]}finally{appLoading.value=false}
}
async function fetchBoard() {
boardLoading.value=true
try{const r=await fetch('/api/plugins/usage-timer/leaderboard',auth()).then(r=>r.json());if(r.success)board.value=r.data.leaderboard||[]}finally{boardLoading.value=false}
}
onMounted(()=>{fetchDaily();fetchApps();fetchBoard()})
</script>
<style scoped>.plugin-page{padding:20px}.toolbar{display:flex;gap:12px;margin-bottom:16px}</style>

View File

@@ -0,0 +1,53 @@
<template>
<div class="plugin-page">
<el-tabs v-model="activeTab">
<el-tab-pane label="文件操作日志" name="log">
<div class="toolbar">
<el-input v-model="filters.device_uid" placeholder="终端UID" style="width:200px" clearable @input="fetchLog" />
<el-select v-model="filters.operation" placeholder="操作类型" clearable style="width:130px" @change="fetchLog">
<el-option label="创建" value="create" /><el-option label="删除" value="delete" /><el-option label="修改" value="modify" /><el-option label="重命名" value="rename" />
</el-select>
</div>
<el-table :data="log" v-loading="loading" stripe size="small">
<el-table-column prop="device_uid" label="终端" width="140" show-overflow-tooltip />
<el-table-column prop="usb_serial" label="U盘序列号" width="140" show-overflow-tooltip />
<el-table-column prop="drive_letter" label="盘符" width="70" />
<el-table-column prop="operation" label="操作" width="80">
<template #default="{ row }"><el-tag :type="opTag(row.operation)" size="small">{{ opLabel(row.operation) }}</el-tag></template>
</el-table-column>
<el-table-column prop="file_path" label="文件路径" min-width="250" show-overflow-tooltip />
<el-table-column label="大小" width="100">
<template #default="{ row }">{{ row.file_size ? formatSize(row.file_size) : '-' }}</template>
</el-table-column>
<el-table-column prop="timestamp" label="时间" width="170" />
</el-table>
</el-tab-pane>
<el-tab-pane label="设备汇总" name="summary">
<el-table :data="summaryData" v-loading="sLoading" stripe size="small">
<el-table-column prop="device_uid" label="终端" width="160" show-overflow-tooltip />
<el-table-column prop="op_count" label="操作次数(7天)" width="130" />
<el-table-column prop="usb_count" label="U盘数量" width="110" />
<el-table-column prop="first_op" label="最早操作" width="170" />
<el-table-column prop="last_op" label="最近操作" width="170" />
</el-table>
</el-tab-pane>
</el-tabs>
</div>
</template>
<script setup lang="ts">
import { ref, reactive, onMounted } from 'vue'
const activeTab = ref('log')
const auth = () => ({ headers: { Authorization: `Bearer ${localStorage.getItem('token')}` } })
const filters = reactive({ device_uid: '', operation: '' })
const log = ref<any[]>([])
const loading = ref(false)
const summaryData = ref<any[]>([])
const sLoading = ref(false)
function opTag(o:string){return{create:'success',delete:'danger',modify:'warning',rename:'info'}[o]||'info'}
function opLabel(o:string){return{create:'创建',delete:'删除',modify:'修改',rename:'重命名'}[o]||o}
function formatSize(b:number){if(b>=1073741824)return`${(b/1073741824).toFixed(1)}GB`;if(b>=1048576)return`${(b/1048576).toFixed(1)}MB`;if(b>=1024)return`${(b/1024).toFixed(1)}KB`;return`${b}B`}
async function fetchLog(){loading.value=true;try{const p=new URLSearchParams();if(filters.device_uid)p.set('device_uid',filters.device_uid);if(filters.operation)p.set('operation',filters.operation);const r=await fetch(`/api/plugins/usb-file-audit/log?${p}`,auth()).then(r=>r.json());if(r.success)log.value=r.data.operations||[]}finally{loading.value=false}}
async function fetchSummary(){sLoading.value=true;try{const r=await fetch('/api/plugins/usb-file-audit/summary',auth()).then(r=>r.json());if(r.success)summaryData.value=r.data.summary||[]}finally{sLoading.value=false}}
onMounted(()=>{fetchLog();fetchSummary()})
</script>
<style scoped>.plugin-page{padding:20px}.toolbar{display:flex;gap:12px;margin-bottom:16px}</style>

View File

@@ -0,0 +1,112 @@
<template>
<div class="plugin-page">
<el-card shadow="hover">
<template #header>
<div style="display:flex;justify-content:space-between;align-items:center">
<span class="card-title">水印配置</span>
<el-button type="primary" size="small" @click="showDialog()">新建配置</el-button>
</div>
</template>
<el-table :data="configs" v-loading="loading" stripe size="small">
<el-table-column prop="target_type" label="应用范围" width="100" />
<el-table-column prop="target_id" label="目标" width="140" show-overflow-tooltip />
<el-table-column prop="content" label="水印内容" min-width="250" show-overflow-tooltip />
<el-table-column prop="font_size" label="字号" width="70" />
<el-table-column label="透明度" width="100">
<template #default="{ row }">{{ (row.opacity * 100).toFixed(0) }}%</template>
</el-table-column>
<el-table-column prop="color" label="颜色" width="80" />
<el-table-column prop="angle" label="角度" width="70" />
<el-table-column prop="enabled" label="启用" width="70">
<template #default="{ row }"><el-tag :type="row.enabled?'success':'info'" size="small">{{ row.enabled?'是':'否' }}</el-tag></template>
</el-table-column>
<el-table-column label="操作" width="120" fixed="right">
<template #default="{ row }">
<el-button link type="primary" size="small" @click="showDialog(row)">编辑</el-button>
<el-button link type="danger" size="small" @click="remove(row.id)">删除</el-button>
</template>
</el-table-column>
</el-table>
</el-card>
<el-card shadow="hover" style="margin-top:20px">
<template #header><span class="card-title">水印预览</span></template>
<div class="preview-area">
<div class="watermark-overlay" :style="watermarkStyle">
<span v-for="i in 12" :key="i" class="wm-text">{{ previewContent }}</span>
</div>
<div class="preview-content">
<p style="color:#606266;font-size:14px">此区域模拟用户桌面效果</p>
<p style="color:#909399;font-size:12px">水印内容会以设定角度和透明度覆盖整个屏幕</p>
</div>
</div>
</el-card>
<el-dialog v-model="visible" :title="editing?'编辑配置':'新建配置'" width="520px">
<el-form :model="form" label-width="80px">
<el-form-item label="应用范围">
<el-select v-model="form.target_type"><el-option label="全局" value="global" /><el-option label="分组" value="group" /><el-option label="指定设备" value="device" /></el-select>
</el-form-item>
<el-form-item label="目标ID" v-if="form.target_type!=='global'"><el-input v-model="form.target_id" /></el-form-item>
<el-form-item label="水印内容"><el-input v-model="form.content" type="textarea" :rows="2" placeholder="支持变量: {company} {username} {hostname} {date} {time}" /></el-form-item>
<el-form-item label="字号"><el-input-number v-model="form.font_size" :min="8" :max="48" /></el-form-item>
<el-form-item label="透明度"><el-slider v-model="form.opacity" :min="5" :max="50" :step="1" :format-tooltip="(v:number)=>`${v}%`" /></el-form-item>
<el-form-item label="颜色"><el-color-picker v-model="form.color" /></el-form-item>
<el-form-item label="角度"><el-input-number v-model="form.angle" :min="-90" :max="90" /></el-form-item>
<el-form-item label="启用"><el-switch v-model="form.enabled" /></el-form-item>
</el-form>
<template #footer><el-button @click="visible=false">取消</el-button><el-button type="primary" @click="save">保存</el-button></template>
</el-dialog>
</div>
</template>
<script setup lang="ts">
import { ref, reactive, computed, onMounted } from 'vue'
import { ElMessage, ElMessageBox } from 'element-plus'
const auth = () => ({ headers: { Authorization: `Bearer ${localStorage.getItem('token')}`, 'Content-Type': 'application/json' } })
const configs = ref<any[]>([])
const loading = ref(false)
const visible = ref(false)
const editing = ref<any>(null)
const form = reactive({ target_type: 'global', target_id: '', content: '{company} | {username} | {date}', font_size: 14, opacity: 15, color: '#808080', angle: -30, enabled: true })
const previewContent = computed(() => form.content
.replace('{company}', 'CSM Corp')
.replace('{username}', 'admin')
.replace('{hostname}', 'DESKTOP-01')
.replace('{date}', new Date().toLocaleDateString())
.replace('{time}', new Date().toLocaleTimeString()))
const watermarkStyle = computed(() => ({
transform: `rotate(${form.angle}deg)`,
opacity: form.opacity / 100,
fontSize: `${form.font_size}px`,
color: form.color,
}))
async function fetchConfigs() {
loading.value = true
try { const r = await fetch('/api/plugins/watermark/config', auth()).then(r => r.json()); if (r.success) configs.value = r.data.configs || [] } finally { loading.value = false }
}
function showDialog(row?: any) {
if (row) { editing.value = row; Object.assign(form, { target_type: row.target_type, target_id: row.target_id || '', content: row.content, font_size: row.font_size, opacity: Math.round(row.opacity * 100), color: row.color, angle: row.angle, enabled: row.enabled }) }
else { editing.value = null; Object.assign(form, { target_type: 'global', target_id: '', content: '{company} | {username} | {date}', font_size: 14, opacity: 15, color: '#808080', angle: -30, enabled: true }) }
visible.value = true
}
async function save() {
const body = { ...form, opacity: form.opacity / 100 }
const url = editing.value ? `/api/plugins/watermark/config/${editing.value.id}` : '/api/plugins/watermark/config'
const method = editing.value ? 'PUT' : 'POST'
const r = await fetch(url, { method, ...auth(), body: JSON.stringify(body) }).then(r => r.json())
if (r.success) { ElMessage.success('已保存'); visible.value = false; fetchConfigs() } else { ElMessage.error(r.error) }
}
async function remove(id: number) { await ElMessageBox.confirm('确定删除?', '确认', { type: 'warning' }); await fetch(`/api/plugins/watermark/config/${id}`, { method: 'DELETE', ...auth() }); ElMessage.success('已删除'); fetchConfigs() }
onMounted(() => fetchConfigs())
</script>
<style scoped>
.plugin-page{padding:20px}
.card-title{font-weight:600;font-size:15px}
.preview-area{position:relative;height:200px;border:1px solid #e4e7ed;border-radius:8px;overflow:hidden;background:#f5f7fa}
.watermark-overlay{position:absolute;inset:0;display:flex;flex-wrap:wrap;align-items:center;justify-content:center;gap:80px;pointer-events:none}
.wm-text{white-space:nowrap;user-select:none}
.preview-content{position:absolute;inset:0;display:flex;flex-direction:column;align-items:center;justify-content:center}
</style>

View File

@@ -0,0 +1,105 @@
<template>
<div class="plugin-page">
<el-tabs v-model="activeTab">
<el-tab-pane label="过滤规则" name="rules">
<div class="toolbar">
<el-button type="primary" @click="showRuleDialog()">新建规则</el-button>
</div>
<el-table :data="rules" v-loading="loading" stripe size="small">
<el-table-column prop="rule_type" label="类型" width="100">
<template #default="{ row }">
<el-tag :type="row.rule_type==='blacklist'?'danger':row.rule_type==='whitelist'?'success':'info'" size="small">
{{ ruleTypeLabel(row.rule_type) }}
</el-tag>
</template>
</el-table-column>
<el-table-column prop="pattern" label="匹配模式" min-width="200" />
<el-table-column prop="target_type" label="应用范围" width="100" />
<el-table-column prop="enabled" label="启用" width="80">
<template #default="{ row }">
<el-tag :type="row.enabled?'success':'info'" size="small">{{ row.enabled?'是':'否' }}</el-tag>
</template>
</el-table-column>
<el-table-column prop="created_at" label="创建时间" width="170" />
<el-table-column label="操作" width="120" fixed="right">
<template #default="{ row }">
<el-button link type="primary" size="small" @click="showRuleDialog(row)">编辑</el-button>
<el-button link type="danger" size="small" @click="deleteRule(row.id)">删除</el-button>
</template>
</el-table-column>
</el-table>
</el-tab-pane>
<el-tab-pane label="访问日志" name="log">
<el-table :data="accessLog" v-loading="logLoading" stripe size="small">
<el-table-column prop="device_uid" label="终端" width="160" show-overflow-tooltip />
<el-table-column prop="url" label="URL" min-width="300" show-overflow-tooltip />
<el-table-column label="动作" width="80">
<template #default="{ row }">
<el-tag :type="row.action==='blocked'?'danger':'success'" size="small">{{ row.action==='blocked'?'拦截':'放行' }}</el-tag>
</template>
</el-table-column>
<el-table-column prop="timestamp" label="时间" width="170" />
</el-table>
</el-tab-pane>
</el-tabs>
<el-dialog v-model="dialogVisible" :title="editing?'编辑规则':'新建规则'" width="480px">
<el-form :model="form" label-width="80px">
<el-form-item label="规则类型">
<el-select v-model="form.rule_type"><el-option label="黑名单" value="blacklist" /><el-option label="白名单" value="whitelist" /><el-option label="分类" value="category" /></el-select>
</el-form-item>
<el-form-item label="匹配模式"><el-input v-model="form.pattern" placeholder="*.example.com" /></el-form-item>
<el-form-item label="应用范围">
<el-select v-model="form.target_type"><el-option label="全局" value="global" /><el-option label="分组" value="group" /><el-option label="指定设备" value="device" /></el-select>
</el-form-item>
<el-form-item label="启用"><el-switch v-model="form.enabled" /></el-form-item>
</el-form>
<template #footer><el-button @click="dialogVisible=false">取消</el-button><el-button type="primary" @click="saveRule">保存</el-button></template>
</el-dialog>
</div>
</template>
<script setup lang="ts">
import { ref, reactive, onMounted } from 'vue'
import { ElMessage, ElMessageBox } from 'element-plus'
const activeTab = ref('rules')
const auth = () => ({ headers: { Authorization: `Bearer ${localStorage.getItem('token')}` } })
const rules = ref<any[]>([])
const loading = ref(false)
const accessLog = ref<any[]>([])
const logLoading = ref(false)
const dialogVisible = ref(false)
const editing = ref<any>(null)
const form = reactive({ rule_type: 'blacklist', pattern: '', target_type: 'global', target_id: '', enabled: true })
function ruleTypeLabel(t: string) { return { blacklist: '黑名单', whitelist: '白名单', category: '分类' }[t] || t }
async function fetchRules() {
loading.value = true
try { const r = await fetch('/api/plugins/web-filter/rules', auth()).then(r=>r.json()); if(r.success) rules.value = r.data.rules||[] } finally { loading.value = false }
}
async function fetchLog() {
logLoading.value = true
try { const r = await fetch('/api/plugins/web-filter/log', auth()).then(r=>r.json()); if(r.success) accessLog.value = r.data.log||[] } finally { logLoading.value = false }
}
function showRuleDialog(row?: any) {
if(row){ editing.value=row; Object.assign(form,{rule_type:row.rule_type,pattern:row.pattern,target_type:row.target_type,target_id:row.target_id||'',enabled:row.enabled}) }
else{ editing.value=null; Object.assign(form,{rule_type:'blacklist',pattern:'',target_type:'global',target_id:'',enabled:true}) }
dialogVisible.value=true
}
async function saveRule() {
const url = editing.value ? `/api/plugins/web-filter/rules/${editing.value.id}` : '/api/plugins/web-filter/rules'
const method = editing.value ? 'PUT' : 'POST'
const res = await fetch(url,{method,...auth(),headers:{...auth().headers,'Content-Type':'application/json'},body:JSON.stringify(form)}).then(r=>r.json())
if(res.success){ElMessage.success('已保存');dialogVisible.value=false;fetchRules()}else{ElMessage.error(res.error)}
}
async function deleteRule(id: number) {
await ElMessageBox.confirm('确定删除?','确认',{type:'warning'})
await fetch(`/api/plugins/web-filter/rules/${id}`,{method:'DELETE',...auth()})
ElMessage.success('已删除'); fetchRules()
}
onMounted(()=>{fetchRules();fetchLog()})
</script>
<style scoped>.plugin-page{padding:20px}.toolbar{display:flex;gap:12px;margin-bottom:16px}</style>

1
web/src/vite-env.d.ts vendored Normal file
View File

@@ -0,0 +1 @@
/// <reference types="vite/client" />

25
web/tsconfig.json Normal file
View File

@@ -0,0 +1,25 @@
{
"compilerOptions": {
"target": "ES2020",
"useDefineForClassFields": true,
"module": "ESNext",
"lib": ["ES2020", "DOM", "DOM.Iterable"],
"skipLibCheck": true,
"moduleResolution": "bundler",
"allowImportingTsExtensions": true,
"resolveJsonModule": true,
"isolatedModules": true,
"noEmit": true,
"jsx": "preserve",
"strict": true,
"noUnusedLocals": true,
"noUnusedParameters": true,
"noFallthroughCasesInSwitch": true,
"baseUrl": ".",
"paths": {
"@/*": ["src/*"]
}
},
"include": ["src/**/*.ts", "src/**/*.d.ts", "src/**/*.tsx", "src/**/*.vue"],
"references": [{ "path": "./tsconfig.node.json" }]
}

10
web/tsconfig.node.json Normal file
View File

@@ -0,0 +1,10 @@
{
"compilerOptions": {
"composite": true,
"skipLibCheck": true,
"module": "ESNext",
"moduleResolution": "bundler",
"allowSyntheticDefaultImports": true
},
"include": ["vite.config.ts"]
}

62
web/vite.config.ts Normal file
View File

@@ -0,0 +1,62 @@
import { defineConfig } from 'vite'
import vue from '@vitejs/plugin-vue'
import AutoImport from 'unplugin-auto-import/vite'
import Components from 'unplugin-vue-components/vite'
import { ElementPlusResolver } from 'unplugin-vue-components/resolvers'
import { resolve } from 'path'
export default defineConfig({
plugins: [
vue(),
AutoImport({
resolvers: [ElementPlusResolver()],
}),
Components({
resolvers: [ElementPlusResolver()],
}),
],
resolve: {
alias: {
'@': resolve(__dirname, 'src'),
},
},
optimizeDeps: {
include: [
'element-plus',
'@element-plus/icons-vue',
'echarts',
],
},
server: {
port: 3000,
proxy: {
'/api': {
target: 'http://localhost:8080',
changeOrigin: true,
},
'/ws': {
target: 'ws://localhost:8080',
ws: true,
},
'/health': {
target: 'http://localhost:8080',
changeOrigin: true,
},
},
},
build: {
outDir: 'dist',
assetsDir: 'assets',
sourcemap: false,
chunkSizeWarningLimit: 500,
rollupOptions: {
output: {
manualChunks: {
'element-plus': ['element-plus'],
'echarts': ['echarts'],
'vendor': ['vue', 'vue-router', 'pinia'],
},
},
},
},
})