feat: 添加新插件支持及多项功能改进

- 新增磁盘加密、打印审计和剪贴板管控插件支持
- 优化水印插件显示效果,支持中文及更多Unicode字符
- 改进硬件资产收集逻辑,更准确获取磁盘和显卡信息
- 增强API错误处理,添加详细日志记录
- 完善前端界面,新增插件管理页面
- 修复多个UI问题,优化页面过渡效果
- 添加环境变量覆盖配置功能
- 实现插件状态管理API
- 更新文档和变更日志
- 添加安装程序脚本支持
This commit is contained in:
iven
2026-04-10 22:21:05 +08:00
parent 3d39f0e426
commit b5333d8c93
101 changed files with 4487 additions and 661 deletions

View File

@@ -41,11 +41,13 @@ tracing = { workspace = true }
tracing-subscriber = { workspace = true }
anyhow = { workspace = true }
# Static file embedding
include_dir = "0.7"
# Utilities
uuid = { workspace = true }
chrono = { workspace = true }
thiserror = { workspace = true }
include_dir = "0.7"
hmac = "0.12"
sha2 = "0.10"
hex = "0.4"

View File

@@ -63,7 +63,9 @@ pub async fn cleanup_task(state: AppState) {
}
}
/// Send email notification
/// Send email notification.
/// TODO: Wire up email notifications to alert rules.
#[allow(dead_code)]
pub async fn send_email(
smtp_config: &crate::config::SmtpConfig,
to: &str,
@@ -97,8 +99,10 @@ pub async fn send_email(
/// Shared HTTP client for webhook notifications.
/// Lazily initialized once and reused across calls to benefit from connection pooling.
#[allow(dead_code)]
static WEBHOOK_CLIENT: std::sync::OnceLock<reqwest::Client> = std::sync::OnceLock::new();
#[allow(dead_code)]
fn webhook_client() -> &'static reqwest::Client {
WEBHOOK_CLIENT.get_or_init(|| {
reqwest::Client::builder()
@@ -108,7 +112,9 @@ fn webhook_client() -> &'static reqwest::Client {
})
}
/// Send webhook notification
/// Send webhook notification.
/// TODO: Wire up webhook notifications to alert rules.
#[allow(dead_code)]
pub async fn send_webhook(url: &str, payload: &serde_json::Value) -> anyhow::Result<()> {
webhook_client().post(url)
.json(payload)

View File

@@ -2,7 +2,7 @@ use axum::{extract::{State, Query}, Json};
use serde::Deserialize;
use sqlx::Row;
use crate::AppState;
use super::{ApiResponse, Pagination};
use super::ApiResponse;
#[derive(Debug, Deserialize)]
pub struct AssetListParams {

View File

@@ -1,4 +1,4 @@
use axum::{extract::State, Json, http::StatusCode, extract::Request, middleware::Next, response::Response, Extension};
use axum::{extract::State, Json, http::StatusCode, extract::Request, middleware::Next, response::Response};
use serde::{Deserialize, Serialize};
use jsonwebtoken::{encode, decode, Header, EncodingKey, DecodingKey, Validation};
use std::sync::Arc;

View File

@@ -1,4 +1,4 @@
use axum::{routing::{get, post, put, delete}, Router, Json, extract::State, middleware};
use axum::{routing::{get, post, put, delete}, Router, Json, middleware};
use serde::{Deserialize, Serialize};
use crate::AppState;

View File

@@ -0,0 +1,247 @@
use axum::{extract::{State, Path, Json}, http::StatusCode};
use serde::Deserialize;
use sqlx::Row;
use csm_protocol::MessageType;
use crate::AppState;
use crate::api::ApiResponse;
use crate::tcp::push_to_targets;
/// Request body for `POST /api/plugins/clipboard-control/rules`.
/// Every field is optional; `create_rule` supplies defaults
/// (rule_type="block", direction="out", target_type="global", enabled=true).
#[derive(Debug, Deserialize)]
pub struct CreateRuleRequest {
pub rule_type: Option<String>, // "block" | "allow"
pub direction: Option<String>, // "out" | "in" | "both"
// Optional process-name filters; NULL means "match any process".
pub source_process: Option<String>,
pub target_process: Option<String>,
// Optional pattern matched against clipboard content by the client plugin.
pub content_pattern: Option<String>,
// Rule scope: "global" | "device" | "group". target_id names the
// device/group when scope is not global.
pub target_type: Option<String>,
pub target_id: Option<String>,
pub enabled: Option<bool>, // defaults to true on create
}
/// GET handler: return all clipboard rules (enabled or not), newest-updated
/// first, serialized as `{"rules": [...]}` inside the standard ApiResponse
/// envelope. Database errors are mapped to an internal-error response.
pub async fn list_rules(State(state): State<AppState>) -> Json<ApiResponse<serde_json::Value>> {
match sqlx::query(
"SELECT id, target_type, target_id, rule_type, direction, source_process, target_process, content_pattern, enabled, updated_at \
FROM clipboard_rules ORDER BY updated_at DESC"
)
.fetch_all(&state.db)
.await
{
// Project each row into a JSON object; nullable columns decode as Option.
Ok(rows) => Json(ApiResponse::ok(serde_json::json!({
"rules": rows.iter().map(|r| serde_json::json!({
"id": r.get::<i64, _>("id"),
"target_type": r.get::<String, _>("target_type"),
"target_id": r.get::<Option<String>, _>("target_id"),
"rule_type": r.get::<String, _>("rule_type"),
"direction": r.get::<String, _>("direction"),
"source_process": r.get::<Option<String>, _>("source_process"),
"target_process": r.get::<Option<String>, _>("target_process"),
"content_pattern": r.get::<Option<String>, _>("content_pattern"),
"enabled": r.get::<bool, _>("enabled"),
"updated_at": r.get::<String, _>("updated_at"),
})).collect::<Vec<_>>()
}))),
Err(e) => Json(ApiResponse::internal_error("query clipboard rules", e)),
}
}
/// POST handler: create a clipboard rule, then push the refreshed rule set
/// to the clients matched by the rule's scope.
///
/// Returns 201 with `{"id": ...}` on success, 400 on invalid enum-like
/// fields, 500 on database failure.
pub async fn create_rule(
State(state): State<AppState>,
Json(req): Json<CreateRuleRequest>,
) -> (StatusCode, Json<ApiResponse<serde_json::Value>>) {
// Apply defaults for the optional enum-like fields before validation.
let target_type = req.target_type.unwrap_or_else(|| "global".to_string());
let rule_type = req.rule_type.unwrap_or_else(|| "block".to_string());
let direction = req.direction.unwrap_or_else(|| "out".to_string());
// Validate inputs
if !matches!(rule_type.as_str(), "block" | "allow") {
return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("rule_type must be 'block' or 'allow'")));
}
if !matches!(direction.as_str(), "out" | "in" | "both") {
return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("direction must be 'out', 'in' or 'both'")));
}
if !matches!(target_type.as_str(), "global" | "device" | "group") {
return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("invalid target_type")));
}
let enabled = req.enabled.unwrap_or(true);
match sqlx::query(
"INSERT INTO clipboard_rules (target_type, target_id, rule_type, direction, source_process, target_process, content_pattern, enabled) \
VALUES (?, ?, ?, ?, ?, ?, ?, ?)"
)
.bind(&target_type)
.bind(&req.target_id)
.bind(&rule_type)
.bind(&direction)
.bind(&req.source_process)
.bind(&req.target_process)
.bind(&req.content_pattern)
.bind(enabled)
.execute(&state.db)
.await
{
Ok(r) => {
let new_id = r.last_insert_rowid();
// Recompute the effective rule set for the affected scope and push it
// to matching online clients over TCP.
let rules = fetch_clipboard_rules_for_push(&state.db, &target_type, req.target_id.as_deref()).await;
push_to_targets(
&state.db, &state.clients, MessageType::ClipboardRules,
&serde_json::json!({"rules": rules}),
&target_type, req.target_id.as_deref(),
).await;
(StatusCode::CREATED, Json(ApiResponse::ok(serde_json::json!({"id": new_id}))))
}
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::internal_error("create clipboard rule", e))),
}
}
/// Request body for `PUT /api/plugins/clipboard-control/rules/:id`.
/// All fields are optional: a missing field keeps the existing stored value
/// (partial update). Note the rule's scope (target_type/target_id) is not
/// updatable through this request.
#[derive(Debug, Deserialize)]
pub struct UpdateRuleRequest {
pub rule_type: Option<String>,
pub direction: Option<String>,
pub source_process: Option<String>,
pub target_process: Option<String>,
pub content_pattern: Option<String>,
pub enabled: Option<bool>,
}
/// PUT handler: partially update a clipboard rule, then push the refreshed
/// rule set for the rule's original scope.
///
/// Semantics: each None field in the body falls back to the currently stored
/// value, so a nullable column (source_process, target_process,
/// content_pattern) cannot be cleared back to NULL via this endpoint —
/// omitting it preserves the existing value.
pub async fn update_rule(
State(state): State<AppState>,
Path(id): Path<i64>,
Json(body): Json<UpdateRuleRequest>,
) -> Json<ApiResponse<()>> {
// Load the current row first so unspecified fields can be carried over.
let existing = sqlx::query("SELECT * FROM clipboard_rules WHERE id = ?")
.bind(id)
.fetch_optional(&state.db)
.await;
let existing = match existing {
Ok(Some(row)) => row,
Ok(None) => return Json(ApiResponse::error("Not found")),
Err(e) => return Json(ApiResponse::internal_error("query clipboard rule", e)),
};
// Merge request fields over the stored row (request wins when present).
let rule_type = body.rule_type.or_else(|| Some(existing.get::<String, _>("rule_type")));
let direction = body.direction.or_else(|| Some(existing.get::<String, _>("direction")));
let source_process = body.source_process.or_else(|| existing.get::<Option<String>, _>("source_process"));
let target_process = body.target_process.or_else(|| existing.get::<Option<String>, _>("target_process"));
let content_pattern = body.content_pattern.or_else(|| existing.get::<Option<String>, _>("content_pattern"));
let enabled = body.enabled.unwrap_or_else(|| existing.get::<bool, _>("enabled"));
// NOTE(review): unlike create_rule, merged rule_type/direction values are
// not re-validated here — confirm whether that is intentional.
let result = sqlx::query(
"UPDATE clipboard_rules SET rule_type = ?, direction = ?, source_process = ?, target_process = ?, content_pattern = ?, enabled = ?, updated_at = datetime('now') WHERE id = ?"
)
.bind(&rule_type)
.bind(&direction)
.bind(&source_process)
.bind(&target_process)
.bind(&content_pattern)
.bind(enabled)
.bind(id)
.execute(&state.db)
.await;
match result {
Ok(r) if r.rows_affected() > 0 => {
// Re-push the effective rule set for the rule's (unchanged) scope.
let target_type_val: String = existing.get("target_type");
let target_id_val: Option<String> = existing.get("target_id");
let rules = fetch_clipboard_rules_for_push(&state.db, &target_type_val, target_id_val.as_deref()).await;
push_to_targets(
&state.db, &state.clients, MessageType::ClipboardRules,
&serde_json::json!({"rules": rules}),
&target_type_val, target_id_val.as_deref(),
).await;
Json(ApiResponse::ok(()))
}
Ok(_) => Json(ApiResponse::error("Not found")),
Err(e) => Json(ApiResponse::internal_error("update clipboard rule", e)),
}
}
/// DELETE handler: remove a clipboard rule and push the refreshed rule set
/// for the deleted rule's scope to matching clients.
pub async fn delete_rule(
State(state): State<AppState>,
Path(id): Path<i64>,
) -> Json<ApiResponse<()>> {
// Capture the scope before deleting so we know which clients to re-push.
let existing = sqlx::query("SELECT target_type, target_id FROM clipboard_rules WHERE id = ?")
.bind(id)
.fetch_optional(&state.db)
.await;
let (target_type, target_id) = match existing {
Ok(Some(row)) => (row.get::<String, _>("target_type"), row.get::<Option<String>, _>("target_id")),
// NOTE(review): the catch-all arm maps DB errors to "Not found" as well
// as genuinely missing rows — confirm whether errors should instead
// surface as internal_error.
_ => return Json(ApiResponse::error("Not found")),
};
match sqlx::query("DELETE FROM clipboard_rules WHERE id = ?")
.bind(id)
.execute(&state.db)
.await
{
Ok(r) if r.rows_affected() > 0 => {
let rules = fetch_clipboard_rules_for_push(&state.db, &target_type, target_id.as_deref()).await;
push_to_targets(
&state.db, &state.clients, MessageType::ClipboardRules,
&serde_json::json!({"rules": rules}),
&target_type, target_id.as_deref(),
).await;
Json(ApiResponse::ok(()))
}
// Row vanished between the SELECT and DELETE, or the DELETE failed.
_ => Json(ApiResponse::error("Not found")),
}
}
/// GET handler: return the 200 most recently reported clipboard violations
/// (newest first) as `{"violations": [...]}`. The result is capped, not
/// paginated.
pub async fn list_violations(State(state): State<AppState>) -> Json<ApiResponse<serde_json::Value>> {
match sqlx::query(
"SELECT id, device_uid, source_process, target_process, content_preview, action_taken, timestamp, reported_at \
FROM clipboard_violations ORDER BY reported_at DESC LIMIT 200"
)
.fetch_all(&state.db)
.await
{
Ok(rows) => Json(ApiResponse::ok(serde_json::json!({
"violations": rows.iter().map(|r| serde_json::json!({
"id": r.get::<i64, _>("id"),
"device_uid": r.get::<String, _>("device_uid"),
"source_process": r.get::<Option<String>, _>("source_process"),
"target_process": r.get::<Option<String>, _>("target_process"),
"content_preview": r.get::<Option<String>, _>("content_preview"),
"action_taken": r.get::<String, _>("action_taken"),
"timestamp": r.get::<String, _>("timestamp"),
"reported_at": r.get::<String, _>("reported_at"),
})).collect::<Vec<_>>()
}))),
Err(e) => Json(ApiResponse::internal_error("query clipboard violations", e)),
}
}
/// Build the effective (enabled-only) clipboard rule list to push to clients
/// in the given scope.
///
/// Scoping: "device" and "group" scopes include global rules plus rules
/// bound to the given target_id; any other target_type returns global rules
/// only. Query errors degrade to an empty list (best-effort push).
async fn fetch_clipboard_rules_for_push(
db: &sqlx::SqlitePool,
target_type: &str,
target_id: Option<&str>,
) -> Vec<serde_json::Value> {
let query = match target_type {
"device" => sqlx::query(
"SELECT id, rule_type, direction, source_process, target_process, content_pattern \
FROM clipboard_rules WHERE enabled = 1 AND (target_type = 'global' OR (target_type = 'device' AND target_id = ?))"
)
.bind(target_id),
"group" => sqlx::query(
"SELECT id, rule_type, direction, source_process, target_process, content_pattern \
FROM clipboard_rules WHERE enabled = 1 AND (target_type = 'global' OR (target_type = 'group' AND target_id = ?))"
)
.bind(target_id),
// "global" (and anything unexpected) falls back to global rules only.
_ => sqlx::query(
"SELECT id, rule_type, direction, source_process, target_process, content_pattern \
FROM clipboard_rules WHERE enabled = 1 AND target_type = 'global'"
),
};
query.fetch_all(db).await
.map(|rows| rows.iter().map(|r| serde_json::json!({
"id": r.get::<i64, _>("id"),
"rule_type": r.get::<String, _>("rule_type"),
"direction": r.get::<String, _>("direction"),
"source_process": r.get::<Option<String>, _>("source_process"),
"target_process": r.get::<Option<String>, _>("target_process"),
"content_pattern": r.get::<Option<String>, _>("content_pattern"),
})).collect())
.unwrap_or_default()
}

View File

@@ -0,0 +1,97 @@
use axum::{extract::{State, Path, Query}, Json};
use serde::Deserialize;
use sqlx::Row;
use crate::AppState;
use crate::api::ApiResponse;
/// Query-string filter for `list_status`.
#[derive(Debug, Deserialize)]
pub struct StatusFilter {
// When present, restrict results to a single device.
pub device_uid: Option<String>,
}
/// GET handler: list per-volume disk-encryption status, joined with the
/// device hostname, optionally filtered to a single device via
/// `?device_uid=...`. Returns `{"entries": [...]}`.
pub async fn list_status(
State(state): State<AppState>,
Query(filter): Query<StatusFilter>,
) -> Json<ApiResponse<serde_json::Value>> {
// Two query shapes: filtered by device, or all devices ordered by device
// then drive letter. LEFT JOIN so status rows survive a missing device row.
let result = if let Some(uid) = &filter.device_uid {
sqlx::query(
"SELECT s.id, s.device_uid, s.drive_letter, s.volume_name, s.encryption_method, \
s.protection_status, s.encryption_percentage, s.lock_status, s.reported_at, s.updated_at, \
d.hostname FROM disk_encryption_status s LEFT JOIN devices d ON s.device_uid = d.device_uid \
WHERE s.device_uid = ? ORDER BY s.drive_letter"
)
.bind(uid)
.fetch_all(&state.db)
.await
} else {
sqlx::query(
"SELECT s.id, s.device_uid, s.drive_letter, s.volume_name, s.encryption_method, \
s.protection_status, s.encryption_percentage, s.lock_status, s.reported_at, s.updated_at, \
d.hostname FROM disk_encryption_status s LEFT JOIN devices d ON s.device_uid = d.device_uid \
ORDER BY s.device_uid, s.drive_letter"
)
.fetch_all(&state.db)
.await
};
match result {
Ok(rows) => Json(ApiResponse::ok(serde_json::json!({
"entries": rows.iter().map(|r| serde_json::json!({
"id": r.get::<i64, _>("id"),
"device_uid": r.get::<String, _>("device_uid"),
"hostname": r.get::<Option<String>, _>("hostname"),
"drive_letter": r.get::<String, _>("drive_letter"),
"volume_name": r.get::<Option<String>, _>("volume_name"),
"encryption_method": r.get::<Option<String>, _>("encryption_method"),
"protection_status": r.get::<String, _>("protection_status"),
// NOTE(review): decoded as non-nullable f64 — if the schema allows
// NULL here this will panic on decode; confirm the column is NOT NULL.
"encryption_percentage": r.get::<f64, _>("encryption_percentage"),
"lock_status": r.get::<String, _>("lock_status"),
"reported_at": r.get::<String, _>("reported_at"),
"updated_at": r.get::<String, _>("updated_at"),
})).collect::<Vec<_>>()
}))),
Err(e) => Json(ApiResponse::internal_error("query disk encryption status", e)),
}
}
/// GET handler: list all encryption alerts (open and resolved), newest
/// first, joined with the device hostname. Returns `{"alerts": [...]}`.
/// No LIMIT is applied.
pub async fn list_alerts(State(state): State<AppState>) -> Json<ApiResponse<serde_json::Value>> {
match sqlx::query(
"SELECT a.id, a.device_uid, a.drive_letter, a.alert_type, a.status, a.created_at, a.resolved_at, \
d.hostname FROM encryption_alerts a LEFT JOIN devices d ON a.device_uid = d.device_uid \
ORDER BY a.created_at DESC"
)
.fetch_all(&state.db)
.await
{
Ok(rows) => Json(ApiResponse::ok(serde_json::json!({
"alerts": rows.iter().map(|r| serde_json::json!({
"id": r.get::<i64, _>("id"),
"device_uid": r.get::<String, _>("device_uid"),
"hostname": r.get::<Option<String>, _>("hostname"),
"drive_letter": r.get::<String, _>("drive_letter"),
"alert_type": r.get::<String, _>("alert_type"),
"status": r.get::<String, _>("status"),
"created_at": r.get::<String, _>("created_at"),
"resolved_at": r.get::<Option<String>, _>("resolved_at"),
})).collect::<Vec<_>>()
}))),
Err(e) => Json(ApiResponse::internal_error("query encryption alerts", e)),
}
}
pub async fn acknowledge_alert(
State(state): State<AppState>,
Path(id): Path<i64>,
) -> Json<ApiResponse<()>> {
match sqlx::query(
"UPDATE encryption_alerts SET status = 'acknowledged', resolved_at = datetime('now') WHERE id = ? AND status = 'open'"
)
.bind(id)
.execute(&state.db)
.await
{
Ok(r) if r.rows_affected() > 0 => Json(ApiResponse::ok(())),
Ok(_) => Json(ApiResponse::error("Alert not found or already acknowledged")),
Err(e) => Json(ApiResponse::internal_error("acknowledge encryption alert", e)),
}
}

View File

@@ -4,6 +4,10 @@ pub mod software_blocker;
pub mod popup_blocker;
pub mod usb_file_audit;
pub mod watermark;
pub mod disk_encryption;
pub mod print_audit;
pub mod clipboard_control;
pub mod plugin_control;
use axum::{Router, routing::{get, post, put}};
use crate::AppState;
@@ -29,6 +33,18 @@ pub fn read_routes() -> Router<AppState> {
.route("/api/plugins/usb-file-audit/summary", get(usb_file_audit::summary))
// Watermark
.route("/api/plugins/watermark/config", get(watermark::get_config_list))
// Disk Encryption
.route("/api/plugins/disk-encryption/status", get(disk_encryption::list_status))
.route("/api/plugins/disk-encryption/alerts", get(disk_encryption::list_alerts))
.route("/api/plugins/disk-encryption/alerts/:id/acknowledge", put(disk_encryption::acknowledge_alert))
// Print Audit
.route("/api/plugins/print-audit/events", get(print_audit::list_events))
.route("/api/plugins/print-audit/events/:id", get(print_audit::get_event))
// Clipboard Control
.route("/api/plugins/clipboard-control/rules", get(clipboard_control::list_rules))
.route("/api/plugins/clipboard-control/violations", get(clipboard_control::list_violations))
// Plugin Control
.route("/api/plugins/control", get(plugin_control::list_plugins))
}
/// Write plugin routes (admin only — require_admin middleware applied by caller)
@@ -46,4 +62,9 @@ pub fn write_routes() -> Router<AppState> {
// Watermark
.route("/api/plugins/watermark/config", post(watermark::create_config))
.route("/api/plugins/watermark/config/:id", put(watermark::update_config).delete(watermark::delete_config))
// Clipboard Control
.route("/api/plugins/clipboard-control/rules", post(clipboard_control::create_rule))
.route("/api/plugins/clipboard-control/rules/:id", put(clipboard_control::update_rule).delete(clipboard_control::delete_rule))
// Plugin Control (enable/disable)
.route("/api/plugins/control/:plugin_name", put(plugin_control::set_plugin_state))
}

View File

@@ -0,0 +1,95 @@
use axum::{extract::{State, Path, Json}, http::StatusCode};
use serde::Deserialize;
use sqlx::Row;
use crate::AppState;
use crate::api::ApiResponse;
use crate::tcp::push_to_targets;
use csm_protocol::{MessageType, PluginControlPayload};
/// List all plugin states
pub async fn list_plugins(State(state): State<AppState>) -> Json<ApiResponse<serde_json::Value>> {
match sqlx::query(
"SELECT id, plugin_name, enabled, target_type, target_id, updated_at FROM plugin_state ORDER BY plugin_name"
)
.fetch_all(&state.db)
.await
{
Ok(rows) => Json(ApiResponse::ok(serde_json::json!({
"plugins": rows.iter().map(|r| serde_json::json!({
"id": r.get::<i64, _>("id"),
"plugin_name": r.get::<String, _>("plugin_name"),
"enabled": r.get::<bool, _>("enabled"),
"target_type": r.get::<String, _>("target_type"),
"target_id": r.get::<Option<String>, _>("target_id"),
"updated_at": r.get::<String, _>("updated_at"),
})).collect::<Vec<_>>()
}))),
Err(e) => Json(ApiResponse::internal_error("query plugin state", e)),
}
}
/// Request body for `PUT /api/plugins/control/:plugin_name`.
#[derive(Debug, Deserialize)]
pub struct SetPluginStateRequest {
// Desired state; required.
pub enabled: bool,
// Scope: "global" (default) | "device" | "group".
pub target_type: Option<String>,
// Device/group identifier when target_type is not "global".
pub target_id: Option<String>,
}
/// Enable or disable a plugin. Pushes PluginEnable/PluginDisable to matching clients.
pub async fn set_plugin_state(
State(state): State<AppState>,
Path(plugin_name): Path<String>,
Json(req): Json<SetPluginStateRequest>,
) -> (StatusCode, Json<ApiResponse<serde_json::Value>>) {
let valid_plugins = [
"web_filter", "usage_timer", "software_blocker",
"popup_blocker", "usb_file_audit", "watermark",
"disk_encryption", "usb_audit", "print_audit", "clipboard_control",
];
if !valid_plugins.contains(&plugin_name.as_str()) {
return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("unknown plugin name")));
}
let target_type = req.target_type.unwrap_or_else(|| "global".to_string());
if !matches!(target_type.as_str(), "global" | "device" | "group") {
return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("invalid target_type")));
}
// Upsert plugin state
match sqlx::query(
"INSERT INTO plugin_state (plugin_name, enabled, target_type, target_id, updated_at) \
VALUES (?, ?, ?, ?, datetime('now')) \
ON CONFLICT(plugin_name) DO UPDATE SET enabled = excluded.enabled, target_type = excluded.target_type, \
target_id = excluded.target_id, updated_at = datetime('now')"
)
.bind(&plugin_name)
.bind(req.enabled)
.bind(&target_type)
.bind(&req.target_id)
.execute(&state.db)
.await
{
Ok(_) => {
// Push enable/disable to matching clients
let payload = PluginControlPayload {
plugin_name: plugin_name.clone(),
enabled: req.enabled,
};
let msg_type = if req.enabled {
MessageType::PluginEnable
} else {
MessageType::PluginDisable
};
push_to_targets(
&state.db, &state.clients, msg_type, &payload,
&target_type, req.target_id.as_deref(),
).await;
(StatusCode::OK, Json(ApiResponse::ok(serde_json::json!({
"plugin_name": plugin_name,
"enabled": req.enabled,
}))))
}
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::internal_error("set plugin state", e))),
}
}

View File

@@ -0,0 +1,101 @@
use axum::{extract::{State, Query, Path}, Json};
use serde::Deserialize;
use sqlx::Row;
use crate::AppState;
use crate::api::ApiResponse;
/// Query parameters for `list_events` (print-audit event listing).
#[derive(Debug, Deserialize)]
pub struct ListEventsParams {
// Optional exact-match device filter.
pub device_uid: Option<String>,
// 1-based page number; defaults to 1.
pub page: Option<u32>,
// Page size; defaults to 20, capped at 100.
pub page_size: Option<u32>,
}
/// GET handler: paginated list of print events, newest first, optionally
/// filtered by `device_uid`.
///
/// Query parameters: `page` (1-based, default 1), `page_size` (default 20,
/// clamped to 1..=100), `device_uid` (optional exact match).
/// Returns `{"events": [...], "total": N, "page": P, "page_size": S}`.
pub async fn list_events(
    State(state): State<AppState>,
    Query(params): Query<ListEventsParams>,
) -> Json<ApiResponse<serde_json::Value>> {
    let page = params.page.unwrap_or(1).max(1);
    // Clamp to a sane range; the lower bound also guards against page_size=0.
    let page_size = params.page_size.unwrap_or(20).clamp(1, 100);
    // Compute the offset in i64 so a huge `page` value cannot overflow u32
    // arithmetic (which panics in debug builds and wraps in release).
    let offset = (i64::from(page) - 1) * i64::from(page_size);
    let (rows, total) = if let Some(ref device_uid) = params.device_uid {
        let rows = sqlx::query(
            "SELECT id, device_uid, document_name, printer_name, pages, copies, user_name, file_size_bytes, timestamp, reported_at \
             FROM print_events WHERE device_uid = ? ORDER BY timestamp DESC LIMIT ? OFFSET ?"
        )
        .bind(device_uid)
        .bind(page_size)
        .bind(offset)
        .fetch_all(&state.db).await;
        // Best-effort count: a count failure degrades to 0 rather than
        // failing the whole request.
        let total: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM print_events WHERE device_uid = ?")
            .bind(device_uid)
            .fetch_one(&state.db).await.unwrap_or(0);
        (rows, total)
    } else {
        let rows = sqlx::query(
            "SELECT id, device_uid, document_name, printer_name, pages, copies, user_name, file_size_bytes, timestamp, reported_at \
             FROM print_events ORDER BY timestamp DESC LIMIT ? OFFSET ?"
        )
        .bind(page_size)
        .bind(offset)
        .fetch_all(&state.db).await;
        let total: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM print_events")
            .fetch_one(&state.db).await.unwrap_or(0);
        (rows, total)
    };
    match rows {
        Ok(rows) => Json(ApiResponse::ok(serde_json::json!({
            "events": rows.iter().map(|r| serde_json::json!({
                "id": r.get::<i64, _>("id"),
                "device_uid": r.get::<String, _>("device_uid"),
                "document_name": r.get::<Option<String>, _>("document_name"),
                "printer_name": r.get::<Option<String>, _>("printer_name"),
                "pages": r.get::<Option<i32>, _>("pages"),
                "copies": r.get::<Option<i32>, _>("copies"),
                "user_name": r.get::<Option<String>, _>("user_name"),
                "file_size_bytes": r.get::<Option<i64>, _>("file_size_bytes"),
                "timestamp": r.get::<String, _>("timestamp"),
                "reported_at": r.get::<String, _>("reported_at"),
            })).collect::<Vec<_>>(),
            "total": total,
            "page": page,
            "page_size": page_size,
        }))),
        Err(e) => Json(ApiResponse::internal_error("query print events", e)),
    }
}
/// GET handler: fetch a single print event by id.
/// Returns the event object, an error response when the id is unknown, or
/// an internal error on query failure.
pub async fn get_event(
State(state): State<AppState>,
Path(id): Path<i64>,
) -> Json<ApiResponse<serde_json::Value>> {
match sqlx::query(
"SELECT id, device_uid, document_name, printer_name, pages, copies, user_name, file_size_bytes, timestamp, reported_at \
FROM print_events WHERE id = ?"
)
.bind(id)
.fetch_optional(&state.db)
.await
{
Ok(Some(row)) => Json(ApiResponse::ok(serde_json::json!({
"id": row.get::<i64, _>("id"),
"device_uid": row.get::<String, _>("device_uid"),
"document_name": row.get::<Option<String>, _>("document_name"),
"printer_name": row.get::<Option<String>, _>("printer_name"),
"pages": row.get::<Option<i32>, _>("pages"),
"copies": row.get::<Option<i32>, _>("copies"),
"user_name": row.get::<Option<String>, _>("user_name"),
"file_size_bytes": row.get::<Option<i64>, _>("file_size_bytes"),
"timestamp": row.get::<String, _>("timestamp"),
"reported_at": row.get::<String, _>("reported_at"),
}))),
Ok(None) => Json(ApiResponse::error("Print event not found")),
Err(e) => Json(ApiResponse::internal_error("query print event", e)),
}
}

View File

@@ -143,6 +143,9 @@ async fn fetch_blacklist_for_push(
"device" => sqlx::query(
"SELECT id, name_pattern, action FROM software_blacklist WHERE enabled = 1 AND (target_type = 'global' OR (target_type = 'device' AND target_id = ?))"
).bind(target_id),
"group" => sqlx::query(
"SELECT id, name_pattern, action FROM software_blacklist WHERE enabled = 1 AND (target_type = 'global' OR (target_type = 'group' AND target_id = ?))"
).bind(target_id),
_ => sqlx::query(
"SELECT id, name_pattern, action FROM software_blacklist WHERE enabled = 1 AND target_type = 'global'"
),

View File

@@ -6,9 +6,6 @@ use crate::AppState;
use crate::api::ApiResponse;
use crate::tcp::push_to_targets;
#[derive(Debug, Deserialize)]
pub struct RuleFilters { pub rule_type: Option<String>, pub enabled: Option<bool> }
#[derive(Debug, Deserialize)]
pub struct CreateRuleRequest {
pub rule_type: String,
@@ -144,6 +141,9 @@ async fn fetch_rules_for_push(
"device" => sqlx::query(
"SELECT id, rule_type, pattern FROM web_filter_rules WHERE enabled = 1 AND (target_type = 'global' OR (target_type = 'device' AND target_id = ?))"
).bind(target_id),
"group" => sqlx::query(
"SELECT id, rule_type, pattern FROM web_filter_rules WHERE enabled = 1 AND (target_type = 'global' OR (target_type = 'group' AND target_id = ?))"
).bind(target_id),
_ => sqlx::query(
"SELECT id, rule_type, pattern FROM web_filter_rules WHERE enabled = 1 AND target_type = 'global'"
),

View File

@@ -2,6 +2,8 @@ use sqlx::SqlitePool;
use tracing::debug;
/// Record an admin audit log entry.
/// TODO: Wire up audit logging to all admin API handlers.
#[allow(dead_code)]
pub async fn audit_log(
db: &SqlitePool,
user_id: i64,

View File

@@ -82,18 +82,32 @@ pub struct SmtpConfig {
impl AppConfig {
pub async fn load(path: &str) -> Result<Self> {
if Path::new(path).exists() {
let mut config = if Path::new(path).exists() {
let content = tokio::fs::read_to_string(path).await?;
let config: AppConfig = toml::from_str(&content)?;
Ok(config)
toml::from_str(&content)?
} else {
let config = default_config();
// Write default config for reference
let toml_str = toml::to_string_pretty(&config)?;
tokio::fs::write(path, &toml_str).await?;
tracing::warn!("Created default config at {}", path);
Ok(config)
config
};
// Environment variable overrides (take precedence over config file)
if let Ok(val) = std::env::var("CSM_JWT_SECRET") {
if !val.is_empty() {
tracing::info!("JWT secret loaded from CSM_JWT_SECRET env var");
config.auth.jwt_secret = val;
}
}
if let Ok(val) = std::env::var("CSM_REGISTRATION_TOKEN") {
if !val.is_empty() {
tracing::info!("Registration token loaded from CSM_REGISTRATION_TOKEN env var");
config.registration_token = val;
}
}
Ok(config)
}
}

View File

@@ -117,12 +117,15 @@ impl DeviceRepo {
}
pub async fn upsert_software(pool: &SqlitePool, asset: &csm_protocol::SoftwareAsset) -> Result<()> {
// Use INSERT OR REPLACE to handle the UNIQUE(device_uid, name, version) constraint
// where version can be NULL (treated as distinct by SQLite)
let version = asset.version.as_deref().unwrap_or("");
sqlx::query(
"INSERT OR REPLACE INTO software_assets (device_uid, name, version, publisher, install_date, install_path, updated_at)
VALUES (?, ?, ?, ?, ?, ?, datetime('now'))"
"INSERT INTO software_assets (device_uid, name, version, publisher, install_date, install_path, updated_at)
VALUES (?, ?, ?, ?, ?, ?, datetime('now'))
ON CONFLICT(device_uid, name, version) DO UPDATE SET
publisher = excluded.publisher,
install_date = excluded.install_date,
install_path = excluded.install_path,
updated_at = datetime('now')"
)
.bind(&asset.device_uid)
.bind(&asset.name)

View File

@@ -1,15 +1,20 @@
use anyhow::Result;
use axum::Router;
use axum::body::Body;
use axum::http::{Request, Response, StatusCode, header};
use axum::middleware::Next;
use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions, SqliteJournalMode};
use std::path::Path;
use std::str::FromStr;
use std::sync::Arc;
use tokio::net::TcpListener;
use tower_http::cors::{CorsLayer, Any};
use axum::http::Method as HttpMethod;
use tower_http::cors::CorsLayer;
use tower_http::trace::TraceLayer;
use tower_http::compression::CompressionLayer;
use tower_http::set_header::SetResponseHeaderLayer;
use tracing::{info, warn, error};
use include_dir::{include_dir, Dir};
mod api;
mod audit;
@@ -21,6 +26,10 @@ mod alert;
use config::AppConfig;
/// Embedded frontend assets from web/dist/ (compiled into the binary at build time).
/// Falls back gracefully at runtime if the directory is empty (dev mode).
static FRONTEND_DIR: Dir = include_dir!("$CARGO_MANIFEST_DIR/../../web/dist");
/// Application shared state
#[derive(Clone)]
pub struct AppState {
@@ -46,6 +55,15 @@ async fn main() -> Result<()> {
// Load configuration
let config = AppConfig::load("config.toml").await?;
// Security checks
if config.registration_token.is_empty() {
warn!("SECURITY: registration_token is empty — any device can register!");
}
if config.auth.jwt_secret.len() < 32 {
warn!("SECURITY: jwt_secret is too short ({} chars) — consider using a 32+ byte key from CSM_JWT_SECRET env var", config.auth.jwt_secret.len());
}
let config = Arc::new(config);
// Initialize database
@@ -86,6 +104,9 @@ async fn main() -> Result<()> {
// Build HTTP router
let app = Router::new()
.merge(api::routes(state.clone()))
// SPA fallback: serve embedded frontend for non-API routes
.fallback(spa_fallback)
.layer(axum::middleware::from_fn(json_rejection_handler))
.layer(
build_cors_layer(&config.server.cors_origins),
)
@@ -171,6 +192,11 @@ async fn run_migrations(pool: &sqlx::SqlitePool) -> Result<()> {
include_str!("../../../migrations/009_plugins_usb_file_audit.sql"),
include_str!("../../../migrations/010_plugins_watermark.sql"),
include_str!("../../../migrations/011_token_security.sql"),
include_str!("../../../migrations/012_disk_encryption.sql"),
include_str!("../../../migrations/013_print_audit.sql"),
include_str!("../../../migrations/014_clipboard_control.sql"),
include_str!("../../../migrations/015_plugin_control.sql"),
include_str!("../../../migrations/016_encryption_alerts_unique.sql"),
];
// Create migrations tracking table
@@ -257,8 +283,102 @@ fn build_cors_layer(origins: &[String]) -> CorsLayer {
} else {
CorsLayer::new()
.allow_origin(tower_http::cors::AllowOrigin::list(allowed_origins))
.allow_methods(Any)
.allow_headers(Any)
.allow_methods([HttpMethod::GET, HttpMethod::POST, HttpMethod::PUT, HttpMethod::DELETE])
.allow_headers([axum::http::header::AUTHORIZATION, axum::http::header::CONTENT_TYPE])
.max_age(std::time::Duration::from_secs(3600))
}
}
/// Middleware to convert axum's default 422 text/plain rejection responses
/// (e.g., JSON deserialization errors) into proper JSON ApiResponse format.
/// Middleware that rewrites axum's default `422 Unprocessable Entity`
/// text/plain rejection responses (e.g. JSON deserialization errors) into
/// the project's JSON ApiResponse envelope. All other responses pass
/// through untouched.
async fn json_rejection_handler(
    req: Request<Body>,
    next: Next,
) -> Response<Body> {
    let upstream = next.run(req).await;

    // Only 422 responses are candidates for rewriting.
    if upstream.status() != StatusCode::UNPROCESSABLE_ENTITY {
        return upstream;
    }

    // Only rewrite axum's plain-text rejection bodies; a handler that already
    // produced JSON (or anything else) is left alone.
    let is_plain_text = upstream
        .headers()
        .get(axum::http::header::CONTENT_TYPE)
        .and_then(|v| v.to_str().ok())
        .map(|ct| ct.starts_with("text/plain"))
        .unwrap_or(false);
    if !is_plain_text {
        return upstream;
    }

    // Replace the plain-text body with the standard JSON error envelope.
    let payload = serde_json::json!({
        "success": false,
        "data": null,
        "error": "Invalid request body"
    });
    Response::builder()
        .status(StatusCode::UNPROCESSABLE_ENTITY)
        .header(axum::http::header::CONTENT_TYPE, "application/json")
        .body(Body::from(serde_json::to_string(&payload).unwrap_or_default()))
        // Builder failure is not expected here; fall back to the original
        // response rather than panicking.
        .unwrap_or(upstream)
}
/// SPA fallback handler: serves embedded frontend static files.
/// For known asset paths (JS/CSS/images), returns the file with correct MIME type.
/// For all other paths, returns index.html (SPA client-side routing).
/// SPA fallback handler: serves embedded frontend static files.
/// For known asset paths (JS/CSS/images), returns the file with correct MIME type.
/// For all other paths, returns index.html (SPA client-side routing).
///
/// Cache policy: exact asset hits get a 1-year public cache (paths are
/// expected to be content-hashed by the frontend build); index.html is
/// served with no-cache so clients pick up new deployments.
async fn spa_fallback(req: Request<Body>) -> Response<Body> {
let path = req.uri().path().trim_start_matches('/');
// Try to serve the exact file first (e.g., assets/index-xxx.js)
// Lookup is against the embedded directory only, so only files compiled
// into FRONTEND_DIR can ever be served.
if !path.is_empty() {
if let Some(file) = FRONTEND_DIR.get_file(path) {
let content_type = guess_content_type(path);
return Response::builder()
.status(StatusCode::OK)
.header(header::CONTENT_TYPE, content_type)
.header(header::CACHE_CONTROL, "public, max-age=31536000".to_string())
.body(Body::from(file.contents().to_vec()))
.unwrap_or_else(|_| Response::new(Body::from("Internal error")));
}
}
// SPA fallback: return index.html for all unmatched routes
match FRONTEND_DIR.get_file("index.html") {
Some(file) => Response::builder()
.status(StatusCode::OK)
.header(header::CONTENT_TYPE, "text/html; charset=utf-8")
.header(header::CACHE_CONTROL, "no-cache".to_string())
.body(Body::from(file.contents().to_vec()))
.unwrap_or_else(|_| Response::new(Body::from("Internal error"))),
// Dev mode: web/dist was empty at compile time, so nothing is embedded.
None => Response::builder()
.status(StatusCode::NOT_FOUND)
.body(Body::from("Frontend not built. Run: cd web && npm run build"))
.unwrap_or_else(|_| Response::new(Body::from("Not found"))),
}
}
/// Guess MIME type from a file path's extension (case-sensitive, matching
/// the build output's lowercase filenames).
///
/// Returns `application/octet-stream` for unknown extensions or paths with
/// no extension at all. Note `.woff` and `.woff2` are distinct formats and
/// get distinct MIME types (`font/woff` vs `font/woff2`).
fn guess_content_type(path: &str) -> &'static str {
    // Everything after the last '.' is the extension; a path with no dot
    // has no extension and falls through to the binary default.
    let ext = match path.rsplit_once('.') {
        Some((_, ext)) => ext,
        None => return "application/octet-stream",
    };
    match ext {
        "js" => "application/javascript; charset=utf-8",
        "css" => "text/css; charset=utf-8",
        "html" => "text/html; charset=utf-8",
        "json" => "application/json",
        "png" => "image/png",
        "jpg" | "jpeg" => "image/jpeg",
        "svg" => "image/svg+xml",
        "ico" => "image/x-icon",
        // Bug fix: .woff previously returned font/woff2.
        "woff" => "font/woff",
        "woff2" => "font/woff2",
        "ttf" => "font/ttf",
        _ => "application/octet-stream",
    }
}

View File

@@ -1,5 +1,6 @@
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicU32, Ordering};
use std::time::Instant;
use tokio::sync::RwLock;
use tokio::net::{TcpListener, TcpStream};
@@ -13,6 +14,15 @@ use crate::AppState;
const RATE_LIMIT_WINDOW_SECS: u64 = 5;
const RATE_LIMIT_MAX_FRAMES: usize = 100;
/// Maximum concurrent TCP connections
const MAX_CONNECTIONS: usize = 500;
/// Maximum consecutive HMAC failures before disconnecting
const MAX_HMAC_FAILURES: u32 = 3;
/// Idle timeout for TCP connections (seconds) — disconnect if no data received
const IDLE_TIMEOUT_SECS: u64 = 180;
/// Per-connection rate limiter using a sliding window of frame timestamps
struct RateLimiter {
timestamps: Vec<Instant>,
@@ -226,6 +236,61 @@ pub async fn push_all_plugin_configs(
}
}
// Clipboard control rules
if let Ok(rows) = sqlx::query(
"SELECT id, rule_type, direction, source_process, target_process, content_pattern \
FROM clipboard_rules WHERE enabled = 1 AND (target_type = 'global' OR (target_type = 'device' AND target_id = ?) OR (target_type = 'group' AND target_id = (SELECT group_name FROM devices WHERE device_uid = ?)))"
)
.bind(device_uid)
.bind(device_uid)
.fetch_all(db).await
{
let rules: Vec<csm_protocol::ClipboardRule> = rows.iter().map(|r| csm_protocol::ClipboardRule {
id: r.get::<i64, _>("id"),
rule_type: r.get::<String, _>("rule_type"),
direction: r.get::<String, _>("direction"),
source_process: r.get::<Option<String>, _>("source_process"),
target_process: r.get::<Option<String>, _>("target_process"),
content_pattern: r.get::<Option<String>, _>("content_pattern"),
}).collect();
if !rules.is_empty() {
let payload = csm_protocol::ClipboardRulesPayload { rules };
if let Ok(frame) = Frame::new_json(MessageType::ClipboardRules, &payload) {
clients.send_to(device_uid, frame.encode()).await;
}
}
}
// Disk encryption config — push default reporting interval (no dedicated config table)
{
let config = csm_protocol::DiskEncryptionConfigPayload {
enabled: true,
report_interval_secs: 3600,
};
if let Ok(frame) = Frame::new_json(MessageType::DiskEncryptionConfig, &config) {
clients.send_to(device_uid, frame.encode()).await;
}
}
// Push plugin enable/disable state — disable any plugins that admin has turned off
if let Ok(rows) = sqlx::query(
"SELECT plugin_name FROM plugin_state WHERE enabled = 0"
)
.fetch_all(db).await
{
for row in &rows {
let plugin_name: String = row.get("plugin_name");
let payload = csm_protocol::PluginControlPayload {
plugin_name: plugin_name.clone(),
enabled: false,
};
if let Ok(frame) = Frame::new_json(MessageType::PluginDisable, &payload) {
clients.send_to(device_uid, frame.encode()).await;
}
debug!("Pushed PluginDisable for {} to device {}", plugin_name, device_uid);
}
}
info!("Pushed all plugin configs to newly registered device {}", device_uid);
}
@@ -283,6 +348,15 @@ pub async fn start_tcp_server(addr: String, state: AppState) -> anyhow::Result<(
loop {
let (stream, peer_addr) = listener.accept().await?;
// Enforce maximum connection limit
let current_count = state.clients.count().await;
if current_count >= MAX_CONNECTIONS {
warn!("Rejecting connection from {}: limit reached ({}/{})", peer_addr, current_count, MAX_CONNECTIONS);
drop(stream);
continue;
}
let state = state.clone();
let acceptor = tls_acceptor.clone();
@@ -361,20 +435,6 @@ async fn cleanup_on_disconnect(state: &AppState, device_uid: &Option<String>) {
}
}
/// Compute HMAC-SHA256 for heartbeat verification.
/// Format: HMAC-SHA256(device_secret, "{device_uid}\n{timestamp}") → hex-encoded
///
/// Returns an empty string if the MAC cannot be keyed. HMAC-SHA256 accepts
/// keys of any length, so that branch should be unreachable in practice, but
/// failing closed (empty digest never matches) beats panicking in the server.
fn compute_hmac(secret: &str, device_uid: &str, timestamp: &str) -> String {
    type HmacSha256 = Hmac<Sha256>;
    // Newline separator prevents ambiguity between uid/timestamp boundaries.
    let message = format!("{}\n{}", device_uid, timestamp);
    let mut mac = match HmacSha256::new_from_slice(secret.as_bytes()) {
        Ok(m) => m,
        Err(_) => return String::new(),
    };
    mac.update(message.as_bytes());
    hex::encode(mac.finalize().into_bytes())
}
/// Verify that a frame sender is a registered device and that the claimed device_uid
/// matches the one registered on this connection. Returns true if valid.
fn verify_device_uid(device_uid: &Option<String>, msg_type: &str, claimed_uid: &str) -> bool {
@@ -392,11 +452,13 @@ fn verify_device_uid(device_uid: &Option<String>, msg_type: &str, claimed_uid: &
}
/// Process a single decoded frame. Shared by both plaintext and TLS handlers.
/// `hmac_fail_count` tracks consecutive HMAC failures; caller checks it for disconnect threshold.
async fn process_frame(
frame: Frame,
state: &AppState,
device_uid: &mut Option<String>,
tx: &Arc<tokio::sync::mpsc::Sender<Vec<u8>>>,
hmac_fail_count: &Arc<AtomicU32>,
) -> anyhow::Result<()> {
match frame.msg_type {
MessageType::Register => {
@@ -438,7 +500,7 @@ async fn process_frame(
"INSERT INTO devices (device_uid, hostname, ip_address, mac_address, os_version, device_secret, status) \
VALUES (?, ?, '0.0.0.0', ?, ?, ?, 'online') \
ON CONFLICT(device_uid) DO UPDATE SET hostname=excluded.hostname, os_version=excluded.os_version, \
mac_address=excluded.mac_address, status='online'"
mac_address=excluded.mac_address, status='online', last_heartbeat=datetime('now')"
)
.bind(&req.device_uid)
.bind(&req.hostname)
@@ -493,6 +555,7 @@ async fn process_frame(
if !secret.is_empty() {
if heartbeat.hmac.is_empty() {
warn!("Heartbeat missing HMAC for device {}", heartbeat.device_uid);
hmac_fail_count.fetch_add(1, Ordering::Relaxed);
return Ok(());
}
// Constant-time HMAC verification using hmac::Mac::verify_slice
@@ -502,9 +565,11 @@ async fn process_frame(
mac.update(message.as_bytes());
let provided_bytes = hex::decode(&heartbeat.hmac).unwrap_or_default();
if mac.verify_slice(&provided_bytes).is_err() {
warn!("Heartbeat HMAC mismatch for device {}", heartbeat.device_uid);
warn!("Heartbeat HMAC mismatch for device {} (fail #{})", heartbeat.device_uid, hmac_fail_count.fetch_add(1, Ordering::Relaxed) + 1);
return Ok(());
}
// Successful verification — reset failure counter
hmac_fail_count.store(0, Ordering::Relaxed);
}
}
@@ -600,7 +665,8 @@ async fn process_frame(
total_active_minutes = excluded.total_active_minutes, \
total_idle_minutes = excluded.total_idle_minutes, \
first_active_at = excluded.first_active_at, \
last_active_at = excluded.last_active_at"
last_active_at = excluded.last_active_at, \
updated_at = datetime('now')"
)
.bind(&report.device_uid)
.bind(&report.date)
@@ -627,7 +693,8 @@ async fn process_frame(
"INSERT INTO app_usage_daily (device_uid, date, app_name, usage_minutes) \
VALUES (?, ?, ?, ?) \
ON CONFLICT(device_uid, date, app_name) DO UPDATE SET \
usage_minutes = MAX(usage_minutes, excluded.usage_minutes)"
usage_minutes = MAX(usage_minutes, excluded.usage_minutes), \
updated_at = datetime('now')"
)
.bind(&report.device_uid)
.bind(&report.date)
@@ -716,6 +783,150 @@ async fn process_frame(
debug!("Web access log: {} {} {}", entry.device_uid, entry.action, entry.url);
}
MessageType::DiskEncryptionStatus => {
let payload: csm_protocol::DiskEncryptionStatusPayload = frame.decode_payload()
.map_err(|e| anyhow::anyhow!("Invalid disk encryption status: {}", e))?;
if !verify_device_uid(device_uid, "DiskEncryptionStatus", &payload.device_uid) {
return Ok(());
}
for drive in &payload.drives {
sqlx::query(
"INSERT INTO disk_encryption_status (device_uid, drive_letter, volume_name, encryption_method, protection_status, encryption_percentage, lock_status, reported_at, updated_at) \
VALUES (?, ?, ?, ?, ?, ?, ?, datetime('now'), datetime('now')) \
ON CONFLICT(device_uid, drive_letter) DO UPDATE SET \
volume_name=excluded.volume_name, encryption_method=excluded.encryption_method, \
protection_status=excluded.protection_status, encryption_percentage=excluded.encryption_percentage, \
lock_status=excluded.lock_status, updated_at=datetime('now')"
)
.bind(&payload.device_uid)
.bind(&drive.drive_letter)
.bind(&drive.volume_name)
.bind(&drive.encryption_method)
.bind(&drive.protection_status)
.bind(drive.encryption_percentage)
.bind(&drive.lock_status)
.execute(&state.db)
.await
.map_err(|e| anyhow::anyhow!("DB error inserting disk encryption status: {}", e))?;
// Generate alert for unencrypted drives
if drive.protection_status == "Off" {
sqlx::query(
"INSERT INTO encryption_alerts (device_uid, drive_letter, alert_type, status) \
VALUES (?, ?, 'not_encrypted', 'open') \
ON CONFLICT(device_uid, drive_letter, alert_type, status) DO NOTHING"
)
.bind(&payload.device_uid)
.bind(&drive.drive_letter)
.execute(&state.db)
.await
.ok();
}
}
info!("Disk encryption status reported: {} ({} drives)", payload.device_uid, payload.drives.len());
state.ws_hub.broadcast(serde_json::json!({
"type": "disk_encryption_status",
"device_uid": payload.device_uid,
"drive_count": payload.drives.len()
}).to_string()).await;
}
MessageType::PrintEvent => {
let event: csm_protocol::PrintEventPayload = frame.decode_payload()
.map_err(|e| anyhow::anyhow!("Invalid print event: {}", e))?;
if !verify_device_uid(device_uid, "PrintEvent", &event.device_uid) {
return Ok(());
}
sqlx::query(
"INSERT INTO print_events (device_uid, document_name, printer_name, pages, copies, user_name, file_size_bytes, timestamp) \
VALUES (?, ?, ?, ?, ?, ?, ?, ?)"
)
.bind(&event.device_uid)
.bind(&event.document_name)
.bind(&event.printer_name)
.bind(event.pages)
.bind(event.copies)
.bind(&event.user_name)
.bind(event.file_size_bytes)
.bind(&event.timestamp)
.execute(&state.db)
.await
.map_err(|e| anyhow::anyhow!("DB error inserting print event: {}", e))?;
debug!("Print event: {} doc={:?} printer={:?} pages={:?}",
event.device_uid, event.document_name, event.printer_name, event.pages);
state.ws_hub.broadcast(serde_json::json!({
"type": "print_event",
"device_uid": event.device_uid,
"document_name": event.document_name,
"printer_name": event.printer_name
}).to_string()).await;
}
MessageType::ClipboardViolation => {
let violation: csm_protocol::ClipboardViolationPayload = frame.decode_payload()
.map_err(|e| anyhow::anyhow!("Invalid clipboard violation: {}", e))?;
if !verify_device_uid(device_uid, "ClipboardViolation", &violation.device_uid) {
return Ok(());
}
sqlx::query(
"INSERT INTO clipboard_violations (device_uid, source_process, target_process, content_preview, action_taken, timestamp) \
VALUES (?, ?, ?, ?, ?, ?)"
)
.bind(&violation.device_uid)
.bind(&violation.source_process)
.bind(&violation.target_process)
.bind(&violation.content_preview)
.bind(&violation.action_taken)
.bind(&violation.timestamp)
.execute(&state.db)
.await
.map_err(|e| anyhow::anyhow!("DB error inserting clipboard violation: {}", e))?;
debug!("Clipboard violation: {} action={}", violation.device_uid, violation.action_taken);
state.ws_hub.broadcast(serde_json::json!({
"type": "clipboard_violation",
"device_uid": violation.device_uid,
"source_process": violation.source_process,
"action_taken": violation.action_taken
}).to_string()).await;
}
MessageType::PopupBlockStats => {
let stats: csm_protocol::PopupBlockStatsPayload = frame.decode_payload()
.map_err(|e| anyhow::anyhow!("Invalid popup block stats: {}", e))?;
if !verify_device_uid(device_uid, "PopupBlockStats", &stats.device_uid) {
return Ok(());
}
for rule_stat in &stats.rule_stats {
sqlx::query(
"INSERT INTO popup_block_stats (device_uid, rule_id, blocked_count, period_secs, reported_at) \
VALUES (?, ?, ?, ?, datetime('now'))"
)
.bind(&stats.device_uid)
.bind(rule_stat.rule_id)
.bind(rule_stat.hits as i32)
.bind(stats.period_secs as i32)
.execute(&state.db)
.await
.ok();
}
debug!("Popup block stats: {} blocked {} windows in {}s", stats.device_uid, stats.blocked_count, stats.period_secs);
}
_ => {
debug!("Unhandled message type: {:?}", frame.msg_type);
}
@@ -728,7 +939,6 @@ async fn process_frame(
async fn handle_client(stream: TcpStream, state: AppState) -> anyhow::Result<()> {
use tokio::io::{AsyncReadExt, AsyncWriteExt};
// Set read timeout to detect stale connections
let _ = stream.set_nodelay(true);
let (mut reader, mut writer) = stream.into_split();
@@ -739,6 +949,7 @@ async fn handle_client(stream: TcpStream, state: AppState) -> anyhow::Result<()>
let mut read_buf = Vec::with_capacity(65536);
let mut device_uid: Option<String> = None;
let mut rate_limiter = RateLimiter::new();
let hmac_fail_count = Arc::new(AtomicU32::new(0));
// Writer task: forwards messages from channel to TCP stream
let write_task = tokio::spawn(async move {
@@ -749,12 +960,22 @@ async fn handle_client(stream: TcpStream, state: AppState) -> anyhow::Result<()>
}
});
// Reader loop
// Reader loop with idle timeout
'reader: loop {
let n = reader.read(&mut buffer).await?;
if n == 0 {
break; // Connection closed
}
let read_result = tokio::time::timeout(
std::time::Duration::from_secs(IDLE_TIMEOUT_SECS),
reader.read(&mut buffer),
).await;
let n = match read_result {
Ok(Ok(0)) => break, // Connection closed
Ok(Ok(n)) => n,
Ok(Err(e)) => return Err(e.into()),
Err(_) => {
warn!("Idle timeout for device {:?}, disconnecting", device_uid);
break;
}
};
read_buf.extend_from_slice(&buffer[..n]);
// Guard against unbounded buffer growth
@@ -766,7 +987,6 @@ async fn handle_client(stream: TcpStream, state: AppState) -> anyhow::Result<()>
// Process complete frames
while let Some(frame) = Frame::decode(&read_buf)? {
let frame_size = frame.encoded_size();
// Remove consumed bytes without reallocating
read_buf.drain(..frame_size);
// Rate limit check
@@ -781,9 +1001,15 @@ async fn handle_client(stream: TcpStream, state: AppState) -> anyhow::Result<()>
continue;
}
if let Err(e) = process_frame(frame, &state, &mut device_uid, &tx).await {
if let Err(e) = process_frame(frame, &state, &mut device_uid, &tx, &hmac_fail_count).await {
warn!("Frame processing error: {}", e);
}
// Disconnect if too many consecutive HMAC failures
if hmac_fail_count.load(Ordering::Relaxed) >= MAX_HMAC_FAILURES {
warn!("Too many HMAC failures for device {:?}, disconnecting", device_uid);
break 'reader;
}
}
}
@@ -807,6 +1033,7 @@ async fn handle_client_tls(
let mut read_buf = Vec::with_capacity(65536);
let mut device_uid: Option<String> = None;
let mut rate_limiter = RateLimiter::new();
let hmac_fail_count = Arc::new(AtomicU32::new(0));
let write_task = tokio::spawn(async move {
while let Some(data) = rx.recv().await {
@@ -816,12 +1043,22 @@ async fn handle_client_tls(
}
});
// Reader loop — same logic as plaintext handler
// Reader loop with idle timeout
'reader: loop {
let n = reader.read(&mut buffer).await?;
if n == 0 {
break;
}
let read_result = tokio::time::timeout(
std::time::Duration::from_secs(IDLE_TIMEOUT_SECS),
reader.read(&mut buffer),
).await;
let n = match read_result {
Ok(Ok(0)) => break,
Ok(Ok(n)) => n,
Ok(Err(e)) => return Err(e.into()),
Err(_) => {
warn!("Idle timeout for TLS device {:?}, disconnecting", device_uid);
break;
}
};
read_buf.extend_from_slice(&buffer[..n]);
if read_buf.len() > MAX_READ_BUF_SIZE {
@@ -843,9 +1080,15 @@ async fn handle_client_tls(
break 'reader;
}
if let Err(e) = process_frame(frame, &state, &mut device_uid, &tx).await {
if let Err(e) = process_frame(frame, &state, &mut device_uid, &tx, &hmac_fail_count).await {
warn!("Frame processing error: {}", e);
}
// Disconnect if too many consecutive HMAC failures
if hmac_fail_count.load(Ordering::Relaxed) >= MAX_HMAC_FAILURES {
warn!("Too many HMAC failures for TLS device {:?}, disconnecting", device_uid);
break 'reader;
}
}
}