feat: 初始化项目基础架构和核心功能
- 添加项目基础结构:Cargo.toml、.gitignore、设备UID和密钥文件 - 实现前端Vue3项目结构:路由、登录页面、设备管理页面 - 添加核心协议定义(crates/protocol):设备状态、资产、USB事件等 - 实现客户端监控模块:系统状态收集、资产收集 - 实现服务端基础API和插件系统 - 添加数据库迁移脚本:设备管理、资产跟踪、告警系统等 - 实现前端设备状态展示和基本交互 - 添加使用时长统计和水印功能插件
This commit is contained in:
51
crates/server/Cargo.toml
Normal file
51
crates/server/Cargo.toml
Normal file
@@ -0,0 +1,51 @@
|
||||
[package]
|
||||
name = "csm-server"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
|
||||
[dependencies]
|
||||
csm-protocol = { path = "../protocol" }
|
||||
|
||||
# Async runtime
|
||||
tokio = { workspace = true }
|
||||
|
||||
# Web framework
|
||||
axum = { version = "0.7", features = ["ws"] }
|
||||
tower-http = { version = "0.5", features = ["cors", "fs", "trace", "compression-gzip", "set-header"] }
|
||||
tower = "0.4"
|
||||
|
||||
# Database
|
||||
sqlx = { version = "0.8", features = ["runtime-tokio", "sqlite"] }
|
||||
|
||||
# TLS
|
||||
rustls = "0.23"
|
||||
tokio-rustls = "0.26"
|
||||
rustls-pemfile = "2"
|
||||
rustls-pki-types = "1"
|
||||
|
||||
# Serialization
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
|
||||
# Auth
|
||||
jsonwebtoken = "9"
|
||||
bcrypt = "0.15"
|
||||
|
||||
# Notifications
|
||||
lettre = { version = "0.11", default-features = false, features = ["tokio1-rustls-tls", "smtp-transport", "builder", "hostname"] }
|
||||
reqwest = { version = "0.12", features = ["json", "rustls-tls"] }
|
||||
|
||||
# Config & logging
|
||||
toml = "0.8"
|
||||
tracing = { workspace = true }
|
||||
tracing-subscriber = { workspace = true }
|
||||
anyhow = { workspace = true }
|
||||
|
||||
# Utilities
|
||||
uuid = { workspace = true }
|
||||
chrono = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
include_dir = "0.7"
|
||||
hmac = "0.12"
|
||||
sha2 = "0.10"
|
||||
hex = "0.4"
|
||||
118
crates/server/src/alert.rs
Normal file
118
crates/server/src/alert.rs
Normal file
@@ -0,0 +1,118 @@
|
||||
use crate::AppState;
|
||||
use tracing::{info, warn, error};
|
||||
|
||||
/// Background task for data cleanup and alert processing
|
||||
pub async fn cleanup_task(state: AppState) {
|
||||
let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(3600));
|
||||
|
||||
loop {
|
||||
interval.tick().await;
|
||||
|
||||
// Cleanup old status history
|
||||
if let Err(e) = sqlx::query(
|
||||
"DELETE FROM device_status_history WHERE reported_at < datetime('now', ?)"
|
||||
)
|
||||
.bind(format!("-{} days", state.config.retention.status_history_days))
|
||||
.execute(&state.db)
|
||||
.await
|
||||
{
|
||||
error!("Failed to cleanup status history: {}", e);
|
||||
}
|
||||
|
||||
// Cleanup old USB events
|
||||
if let Err(e) = sqlx::query(
|
||||
"DELETE FROM usb_events WHERE event_time < datetime('now', ?)"
|
||||
)
|
||||
.bind(format!("-{} days", state.config.retention.usb_events_days))
|
||||
.execute(&state.db)
|
||||
.await
|
||||
{
|
||||
error!("Failed to cleanup USB events: {}", e);
|
||||
}
|
||||
|
||||
// Cleanup handled alert records
|
||||
if let Err(e) = sqlx::query(
|
||||
"DELETE FROM alert_records WHERE handled = 1 AND triggered_at < datetime('now', ?)"
|
||||
)
|
||||
.bind(format!("-{} days", state.config.retention.alert_records_days))
|
||||
.execute(&state.db)
|
||||
.await
|
||||
{
|
||||
error!("Failed to cleanup alert records: {}", e);
|
||||
}
|
||||
|
||||
// Mark devices as offline if no heartbeat for 2 minutes
|
||||
if let Err(e) = sqlx::query(
|
||||
"UPDATE devices SET status = 'offline' WHERE status = 'online' AND last_heartbeat < datetime('now', '-2 minutes')"
|
||||
)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
{
|
||||
error!("Failed to mark stale devices offline: {}", e);
|
||||
}
|
||||
|
||||
// SQLite WAL checkpoint
|
||||
if let Err(e) = sqlx::query("PRAGMA wal_checkpoint(TRUNCATE)")
|
||||
.execute(&state.db)
|
||||
.await
|
||||
{
|
||||
warn!("WAL checkpoint failed: {}", e);
|
||||
}
|
||||
|
||||
info!("Cleanup cycle completed");
|
||||
}
|
||||
}
|
||||
|
||||
/// Send email notification
|
||||
pub async fn send_email(
|
||||
smtp_config: &crate::config::SmtpConfig,
|
||||
to: &str,
|
||||
subject: &str,
|
||||
body: &str,
|
||||
) -> anyhow::Result<()> {
|
||||
use lettre::message::header::ContentType;
|
||||
use lettre::{Message, SmtpTransport, Transport};
|
||||
use lettre::transport::smtp::authentication::Credentials;
|
||||
|
||||
let email = Message::builder()
|
||||
.from(smtp_config.from.parse()?)
|
||||
.to(to.parse()?)
|
||||
.subject(subject)
|
||||
.header(ContentType::TEXT_HTML)
|
||||
.body(body.to_string())?;
|
||||
|
||||
let creds = Credentials::new(
|
||||
smtp_config.username.clone(),
|
||||
smtp_config.password.clone(),
|
||||
);
|
||||
|
||||
let mailer = SmtpTransport::starttls_relay(&smtp_config.host)?
|
||||
.port(smtp_config.port)
|
||||
.credentials(creds)
|
||||
.build();
|
||||
|
||||
mailer.send(&email)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Shared HTTP client for webhook notifications.
|
||||
/// Lazily initialized once and reused across calls to benefit from connection pooling.
|
||||
static WEBHOOK_CLIENT: std::sync::OnceLock<reqwest::Client> = std::sync::OnceLock::new();
|
||||
|
||||
fn webhook_client() -> &'static reqwest::Client {
|
||||
WEBHOOK_CLIENT.get_or_init(|| {
|
||||
reqwest::Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(10))
|
||||
.build()
|
||||
.unwrap_or_else(|_| reqwest::Client::new())
|
||||
})
|
||||
}
|
||||
|
||||
/// Send webhook notification
|
||||
pub async fn send_webhook(url: &str, payload: &serde_json::Value) -> anyhow::Result<()> {
|
||||
webhook_client().post(url)
|
||||
.json(payload)
|
||||
.send()
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
243
crates/server/src/api/alerts.rs
Normal file
243
crates/server/src/api/alerts.rs
Normal file
@@ -0,0 +1,243 @@
|
||||
use axum::{extract::{State, Path, Query, Json}, http::StatusCode};
|
||||
use serde::Deserialize;
|
||||
use sqlx::Row;
|
||||
|
||||
use crate::AppState;
|
||||
use super::ApiResponse;
|
||||
use super::auth::Claims;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct AlertRecordListParams {
|
||||
pub device_uid: Option<String>,
|
||||
pub alert_type: Option<String>,
|
||||
pub severity: Option<String>,
|
||||
pub handled: Option<i32>,
|
||||
pub page: Option<u32>,
|
||||
pub page_size: Option<u32>,
|
||||
}
|
||||
|
||||
pub async fn list_rules(
|
||||
State(state): State<AppState>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let rows = sqlx::query(
|
||||
"SELECT id, name, rule_type, condition, severity, enabled, notify_email, notify_webhook, created_at, updated_at
|
||||
FROM alert_rules ORDER BY created_at DESC"
|
||||
)
|
||||
.fetch_all(&state.db)
|
||||
.await;
|
||||
|
||||
match rows {
|
||||
Ok(records) => {
|
||||
let items: Vec<serde_json::Value> = records.iter().map(|r| serde_json::json!({
|
||||
"id": r.get::<i64, _>("id"),
|
||||
"name": r.get::<String, _>("name"),
|
||||
"rule_type": r.get::<String, _>("rule_type"),
|
||||
"condition": r.get::<String, _>("condition"),
|
||||
"severity": r.get::<String, _>("severity"),
|
||||
"enabled": r.get::<i32, _>("enabled"),
|
||||
"notify_email": r.get::<Option<String>, _>("notify_email"),
|
||||
"notify_webhook": r.get::<Option<String>, _>("notify_webhook"),
|
||||
"created_at": r.get::<String, _>("created_at"),
|
||||
"updated_at": r.get::<String, _>("updated_at"),
|
||||
})).collect();
|
||||
Json(ApiResponse::ok(serde_json::json!({
|
||||
"rules": items,
|
||||
})))
|
||||
}
|
||||
Err(e) => Json(ApiResponse::internal_error("query alert rules", e)),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn list_records(
|
||||
State(state): State<AppState>,
|
||||
Query(params): Query<AlertRecordListParams>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let limit = params.page_size.unwrap_or(20).min(100);
|
||||
let offset = params.page.unwrap_or(1).saturating_sub(1) * limit;
|
||||
|
||||
// Normalize empty strings to None (Axum deserializes `key=` as Some(""))
|
||||
let device_uid = params.device_uid.as_deref().filter(|s| !s.is_empty()).map(String::from);
|
||||
let alert_type = params.alert_type.as_deref().filter(|s| !s.is_empty()).map(String::from);
|
||||
let severity = params.severity.as_deref().filter(|s| !s.is_empty()).map(String::from);
|
||||
let handled = params.handled;
|
||||
|
||||
let rows = sqlx::query(
|
||||
"SELECT id, rule_id, device_uid, alert_type, severity, detail, handled, handled_by, handled_at, triggered_at
|
||||
FROM alert_records WHERE 1=1
|
||||
AND (? IS NULL OR device_uid = ?)
|
||||
AND (? IS NULL OR alert_type = ?)
|
||||
AND (? IS NULL OR severity = ?)
|
||||
AND (? IS NULL OR handled = ?)
|
||||
ORDER BY triggered_at DESC LIMIT ? OFFSET ?"
|
||||
)
|
||||
.bind(&device_uid).bind(&device_uid)
|
||||
.bind(&alert_type).bind(&alert_type)
|
||||
.bind(&severity).bind(&severity)
|
||||
.bind(&handled).bind(&handled)
|
||||
.bind(limit).bind(offset)
|
||||
.fetch_all(&state.db)
|
||||
.await;
|
||||
|
||||
match rows {
|
||||
Ok(records) => {
|
||||
let items: Vec<serde_json::Value> = records.iter().map(|r| serde_json::json!({
|
||||
"id": r.get::<i64, _>("id"),
|
||||
"rule_id": r.get::<Option<i64>, _>("rule_id"),
|
||||
"device_uid": r.get::<Option<String>, _>("device_uid"),
|
||||
"alert_type": r.get::<String, _>("alert_type"),
|
||||
"severity": r.get::<String, _>("severity"),
|
||||
"detail": r.get::<String, _>("detail"),
|
||||
"handled": r.get::<i32, _>("handled"),
|
||||
"handled_by": r.get::<Option<String>, _>("handled_by"),
|
||||
"handled_at": r.get::<Option<String>, _>("handled_at"),
|
||||
"triggered_at": r.get::<String, _>("triggered_at"),
|
||||
})).collect();
|
||||
Json(ApiResponse::ok(serde_json::json!({
|
||||
"records": items,
|
||||
"page": params.page.unwrap_or(1),
|
||||
"page_size": limit,
|
||||
})))
|
||||
}
|
||||
Err(e) => Json(ApiResponse::internal_error("query alert records", e)),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct CreateRuleRequest {
|
||||
pub name: String,
|
||||
pub rule_type: String,
|
||||
pub condition: String,
|
||||
pub severity: Option<String>,
|
||||
pub notify_email: Option<String>,
|
||||
pub notify_webhook: Option<String>,
|
||||
}
|
||||
|
||||
pub async fn create_rule(
|
||||
State(state): State<AppState>,
|
||||
Json(body): Json<CreateRuleRequest>,
|
||||
) -> (StatusCode, Json<ApiResponse<serde_json::Value>>) {
|
||||
let severity = body.severity.unwrap_or_else(|| "medium".to_string());
|
||||
|
||||
let result = sqlx::query(
|
||||
"INSERT INTO alert_rules (name, rule_type, condition, severity, notify_email, notify_webhook)
|
||||
VALUES (?, ?, ?, ?, ?, ?)"
|
||||
)
|
||||
.bind(&body.name)
|
||||
.bind(&body.rule_type)
|
||||
.bind(&body.condition)
|
||||
.bind(&severity)
|
||||
.bind(&body.notify_email)
|
||||
.bind(&body.notify_webhook)
|
||||
.execute(&state.db)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(r) => (StatusCode::CREATED, Json(ApiResponse::ok(serde_json::json!({
|
||||
"id": r.last_insert_rowid(),
|
||||
})))),
|
||||
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::internal_error("create alert rule", e))),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct UpdateRuleRequest {
|
||||
pub name: Option<String>,
|
||||
pub rule_type: Option<String>,
|
||||
pub condition: Option<String>,
|
||||
pub severity: Option<String>,
|
||||
pub enabled: Option<i32>,
|
||||
pub notify_email: Option<String>,
|
||||
pub notify_webhook: Option<String>,
|
||||
}
|
||||
|
||||
pub async fn update_rule(
|
||||
State(state): State<AppState>,
|
||||
Path(id): Path<i64>,
|
||||
Json(body): Json<UpdateRuleRequest>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let existing = sqlx::query("SELECT * FROM alert_rules WHERE id = ?")
|
||||
.bind(id)
|
||||
.fetch_optional(&state.db)
|
||||
.await;
|
||||
|
||||
let existing = match existing {
|
||||
Ok(Some(row)) => row,
|
||||
Ok(None) => return Json(ApiResponse::error("Rule not found")),
|
||||
Err(e) => return Json(ApiResponse::internal_error("query alert rule", e)),
|
||||
};
|
||||
|
||||
let name = body.name.unwrap_or_else(|| existing.get::<String, _>("name"));
|
||||
let rule_type = body.rule_type.unwrap_or_else(|| existing.get::<String, _>("rule_type"));
|
||||
let condition = body.condition.unwrap_or_else(|| existing.get::<String, _>("condition"));
|
||||
let severity = body.severity.unwrap_or_else(|| existing.get::<String, _>("severity"));
|
||||
let enabled = body.enabled.unwrap_or_else(|| existing.get::<i32, _>("enabled"));
|
||||
let notify_email = body.notify_email.or_else(|| existing.get::<Option<String>, _>("notify_email"));
|
||||
let notify_webhook = body.notify_webhook.or_else(|| existing.get::<Option<String>, _>("notify_webhook"));
|
||||
|
||||
let result = sqlx::query(
|
||||
"UPDATE alert_rules SET name = ?, rule_type = ?, condition = ?, severity = ?, enabled = ?,
|
||||
notify_email = ?, notify_webhook = ?, updated_at = datetime('now') WHERE id = ?"
|
||||
)
|
||||
.bind(&name)
|
||||
.bind(&rule_type)
|
||||
.bind(&condition)
|
||||
.bind(&severity)
|
||||
.bind(enabled)
|
||||
.bind(¬ify_email)
|
||||
.bind(¬ify_webhook)
|
||||
.bind(id)
|
||||
.execute(&state.db)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(_) => Json(ApiResponse::ok(serde_json::json!({"updated": true}))),
|
||||
Err(e) => Json(ApiResponse::internal_error("update alert rule", e)),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn delete_rule(
|
||||
State(state): State<AppState>,
|
||||
Path(id): Path<i64>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let result = sqlx::query("DELETE FROM alert_rules WHERE id = ?")
|
||||
.bind(id)
|
||||
.execute(&state.db)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(r) => {
|
||||
if r.rows_affected() > 0 {
|
||||
Json(ApiResponse::ok(serde_json::json!({"deleted": true})))
|
||||
} else {
|
||||
Json(ApiResponse::error("Rule not found"))
|
||||
}
|
||||
}
|
||||
Err(e) => Json(ApiResponse::internal_error("delete alert rule", e)),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn handle_record(
|
||||
State(state): State<AppState>,
|
||||
Path(id): Path<i64>,
|
||||
claims: axum::Extension<Claims>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let handled_by = &claims.username;
|
||||
let result = sqlx::query(
|
||||
"UPDATE alert_records SET handled = 1, handled_by = ?, handled_at = datetime('now') WHERE id = ?"
|
||||
)
|
||||
.bind(handled_by)
|
||||
.bind(id)
|
||||
.execute(&state.db)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(r) => {
|
||||
if r.rows_affected() > 0 {
|
||||
Json(ApiResponse::ok(serde_json::json!({"handled": true})))
|
||||
} else {
|
||||
Json(ApiResponse::error("Alert record not found"))
|
||||
}
|
||||
}
|
||||
Err(e) => Json(ApiResponse::internal_error("handle alert record", e)),
|
||||
}
|
||||
}
|
||||
143
crates/server/src/api/assets.rs
Normal file
143
crates/server/src/api/assets.rs
Normal file
@@ -0,0 +1,143 @@
|
||||
use axum::{extract::{State, Query}, Json};
|
||||
use serde::Deserialize;
|
||||
use sqlx::Row;
|
||||
use crate::AppState;
|
||||
use super::{ApiResponse, Pagination};
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct AssetListParams {
|
||||
pub device_uid: Option<String>,
|
||||
pub search: Option<String>,
|
||||
pub page: Option<u32>,
|
||||
pub page_size: Option<u32>,
|
||||
}
|
||||
|
||||
pub async fn list_hardware(
|
||||
State(state): State<AppState>,
|
||||
Query(params): Query<AssetListParams>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let limit = params.page_size.unwrap_or(20).min(100);
|
||||
let offset = params.page.unwrap_or(1).saturating_sub(1) * limit;
|
||||
|
||||
// Normalize empty strings to None
|
||||
let device_uid = params.device_uid.as_deref().filter(|s| !s.is_empty()).map(String::from);
|
||||
let search = params.search.as_deref().filter(|s| !s.is_empty()).map(String::from);
|
||||
|
||||
let rows = sqlx::query(
|
||||
"SELECT id, device_uid, cpu_model, cpu_cores, memory_total_mb, disk_model, disk_total_mb,
|
||||
gpu_model, motherboard, serial_number, reported_at
|
||||
FROM hardware_assets WHERE 1=1
|
||||
AND (? IS NULL OR device_uid = ?)
|
||||
AND (? IS NULL OR cpu_model LIKE '%' || ? || '%' OR gpu_model LIKE '%' || ? || '%')
|
||||
ORDER BY reported_at DESC LIMIT ? OFFSET ?"
|
||||
)
|
||||
.bind(&device_uid).bind(&device_uid)
|
||||
.bind(&search).bind(&search).bind(&search)
|
||||
.bind(limit).bind(offset)
|
||||
.fetch_all(&state.db)
|
||||
.await;
|
||||
|
||||
match rows {
|
||||
Ok(records) => {
|
||||
let items: Vec<serde_json::Value> = records.iter().map(|r| serde_json::json!({
|
||||
"id": r.get::<i64, _>("id"),
|
||||
"device_uid": r.get::<String, _>("device_uid"),
|
||||
"cpu_model": r.get::<String, _>("cpu_model"),
|
||||
"cpu_cores": r.get::<i32, _>("cpu_cores"),
|
||||
"memory_total_mb": r.get::<i64, _>("memory_total_mb"),
|
||||
"disk_model": r.get::<String, _>("disk_model"),
|
||||
"disk_total_mb": r.get::<i64, _>("disk_total_mb"),
|
||||
"gpu_model": r.get::<Option<String>, _>("gpu_model"),
|
||||
"motherboard": r.get::<Option<String>, _>("motherboard"),
|
||||
"serial_number": r.get::<Option<String>, _>("serial_number"),
|
||||
"reported_at": r.get::<String, _>("reported_at"),
|
||||
})).collect();
|
||||
Json(ApiResponse::ok(serde_json::json!({
|
||||
"hardware": items,
|
||||
"page": params.page.unwrap_or(1),
|
||||
"page_size": limit,
|
||||
})))
|
||||
}
|
||||
Err(e) => Json(ApiResponse::internal_error("query hardware assets", e)),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn list_software(
|
||||
State(state): State<AppState>,
|
||||
Query(params): Query<AssetListParams>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let limit = params.page_size.unwrap_or(20).min(100);
|
||||
let offset = params.page.unwrap_or(1).saturating_sub(1) * limit;
|
||||
|
||||
// Normalize empty strings to None
|
||||
let device_uid = params.device_uid.as_deref().filter(|s| !s.is_empty()).map(String::from);
|
||||
let search = params.search.as_deref().filter(|s| !s.is_empty()).map(String::from);
|
||||
|
||||
let rows = sqlx::query(
|
||||
"SELECT id, device_uid, name, version, publisher, install_date, install_path
|
||||
FROM software_assets WHERE 1=1
|
||||
AND (? IS NULL OR device_uid = ?)
|
||||
AND (? IS NULL OR name LIKE '%' || ? || '%' OR publisher LIKE '%' || ? || '%')
|
||||
ORDER BY name ASC LIMIT ? OFFSET ?"
|
||||
)
|
||||
.bind(&device_uid).bind(&device_uid)
|
||||
.bind(&search).bind(&search).bind(&search)
|
||||
.bind(limit).bind(offset)
|
||||
.fetch_all(&state.db)
|
||||
.await;
|
||||
|
||||
match rows {
|
||||
Ok(records) => {
|
||||
let items: Vec<serde_json::Value> = records.iter().map(|r| serde_json::json!({
|
||||
"id": r.get::<i64, _>("id"),
|
||||
"device_uid": r.get::<String, _>("device_uid"),
|
||||
"name": r.get::<String, _>("name"),
|
||||
"version": r.get::<Option<String>, _>("version"),
|
||||
"publisher": r.get::<Option<String>, _>("publisher"),
|
||||
"install_date": r.get::<Option<String>, _>("install_date"),
|
||||
"install_path": r.get::<Option<String>, _>("install_path"),
|
||||
})).collect();
|
||||
Json(ApiResponse::ok(serde_json::json!({
|
||||
"software": items,
|
||||
"page": params.page.unwrap_or(1),
|
||||
"page_size": limit,
|
||||
})))
|
||||
}
|
||||
Err(e) => Json(ApiResponse::internal_error("query software assets", e)),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn list_changes(
|
||||
State(state): State<AppState>,
|
||||
Query(page): Query<Pagination>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let offset = page.offset();
|
||||
let limit = page.limit();
|
||||
|
||||
let rows = sqlx::query(
|
||||
"SELECT id, device_uid, change_type, change_detail, detected_at
|
||||
FROM asset_changes ORDER BY detected_at DESC LIMIT ? OFFSET ?"
|
||||
)
|
||||
.bind(limit)
|
||||
.bind(offset)
|
||||
.fetch_all(&state.db)
|
||||
.await;
|
||||
|
||||
match rows {
|
||||
Ok(records) => {
|
||||
let items: Vec<serde_json::Value> = records.iter().map(|r| serde_json::json!({
|
||||
"id": r.get::<i64, _>("id"),
|
||||
"device_uid": r.get::<String, _>("device_uid"),
|
||||
"change_type": r.get::<String, _>("change_type"),
|
||||
"change_detail": r.get::<String, _>("change_detail"),
|
||||
"detected_at": r.get::<String, _>("detected_at"),
|
||||
})).collect();
|
||||
Json(ApiResponse::ok(serde_json::json!({
|
||||
"changes": items,
|
||||
"page": page.page.unwrap_or(1),
|
||||
"page_size": limit,
|
||||
})))
|
||||
}
|
||||
Err(e) => Json(ApiResponse::internal_error("query asset changes", e)),
|
||||
}
|
||||
}
|
||||
295
crates/server/src/api/auth.rs
Normal file
295
crates/server/src/api/auth.rs
Normal file
@@ -0,0 +1,295 @@
|
||||
use axum::{extract::State, Json, http::StatusCode, extract::Request, middleware::Next, response::Response};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use jsonwebtoken::{encode, decode, Header, EncodingKey, DecodingKey, Validation};
|
||||
use std::sync::Arc;
|
||||
use std::collections::HashMap;
|
||||
use std::time::Instant;
|
||||
use tokio::sync::Mutex;
|
||||
use crate::AppState;
|
||||
use super::ApiResponse;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct Claims {
|
||||
pub sub: i64,
|
||||
pub username: String,
|
||||
pub role: String,
|
||||
pub exp: u64,
|
||||
pub iat: u64,
|
||||
pub token_type: String,
|
||||
/// Random family ID for refresh token rotation detection
|
||||
pub family: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct LoginRequest {
|
||||
pub username: String,
|
||||
pub password: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct LoginResponse {
|
||||
pub access_token: String,
|
||||
pub refresh_token: String,
|
||||
pub user: UserInfo,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, sqlx::FromRow)]
|
||||
pub struct UserInfo {
|
||||
pub id: i64,
|
||||
pub username: String,
|
||||
pub role: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct RefreshRequest {
|
||||
pub refresh_token: String,
|
||||
}
|
||||
|
||||
/// In-memory rate limiter for login attempts
|
||||
#[derive(Clone, Default)]
|
||||
pub struct LoginRateLimiter {
|
||||
attempts: Arc<Mutex<HashMap<String, (Instant, u32)>>>,
|
||||
}
|
||||
|
||||
impl LoginRateLimiter {
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
/// Returns true if the request should be rate-limited
|
||||
pub async fn is_limited(&self, key: &str) -> bool {
|
||||
let mut attempts = self.attempts.lock().await;
|
||||
let now = Instant::now();
|
||||
let window = std::time::Duration::from_secs(300); // 5-minute window
|
||||
let max_attempts = 10u32;
|
||||
|
||||
if let Some((first_attempt, count)) = attempts.get_mut(key) {
|
||||
if now.duration_since(*first_attempt) > window {
|
||||
// Window expired, reset
|
||||
*first_attempt = now;
|
||||
*count = 1;
|
||||
false
|
||||
} else if *count >= max_attempts {
|
||||
true // Rate limited
|
||||
} else {
|
||||
*count += 1;
|
||||
false
|
||||
}
|
||||
} else {
|
||||
attempts.insert(key.to_string(), (now, 1));
|
||||
// Cleanup old entries periodically
|
||||
if attempts.len() > 1000 {
|
||||
let cutoff = now - window;
|
||||
attempts.retain(|_, (t, _)| *t > cutoff);
|
||||
}
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn login(
|
||||
State(state): State<AppState>,
|
||||
Json(req): Json<LoginRequest>,
|
||||
) -> Result<(StatusCode, Json<ApiResponse<LoginResponse>>), StatusCode> {
|
||||
// Rate limit check
|
||||
if state.login_limiter.is_limited(&req.username).await {
|
||||
return Ok((StatusCode::TOO_MANY_REQUESTS, Json(ApiResponse::error("Too many login attempts. Try again later."))));
|
||||
}
|
||||
|
||||
let user: Option<UserInfo> = sqlx::query_as::<_, UserInfo>(
|
||||
"SELECT id, username, role FROM users WHERE username = ?"
|
||||
)
|
||||
.bind(&req.username)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
let user = match user {
|
||||
Some(u) => u,
|
||||
None => return Ok((StatusCode::UNAUTHORIZED, Json(ApiResponse::error("Invalid credentials")))),
|
||||
};
|
||||
|
||||
let hash: String = sqlx::query_scalar::<_, String>(
|
||||
"SELECT password FROM users WHERE id = ?"
|
||||
)
|
||||
.bind(user.id)
|
||||
.fetch_one(&state.db)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
if !bcrypt::verify(&req.password, &hash).unwrap_or(false) {
|
||||
return Ok((StatusCode::UNAUTHORIZED, Json(ApiResponse::error("Invalid credentials"))));
|
||||
}
|
||||
|
||||
let now = chrono::Utc::now().timestamp() as u64;
|
||||
let family = uuid::Uuid::new_v4().to_string();
|
||||
let access_token = create_token(&user, "access", state.config.auth.access_token_ttl_secs, now, &state.config.auth.jwt_secret, &family)?;
|
||||
let refresh_token = create_token(&user, "refresh", state.config.auth.refresh_token_ttl_secs, now, &state.config.auth.jwt_secret, &family)?;
|
||||
|
||||
// Audit log
|
||||
let _ = sqlx::query(
|
||||
"INSERT INTO admin_audit_log (user_id, action, detail) VALUES (?, 'login', ?)"
|
||||
)
|
||||
.bind(user.id)
|
||||
.bind(format!("User {} logged in", user.username))
|
||||
.execute(&state.db)
|
||||
.await;
|
||||
|
||||
Ok((StatusCode::OK, Json(ApiResponse::ok(LoginResponse {
|
||||
access_token,
|
||||
refresh_token,
|
||||
user,
|
||||
}))))
|
||||
}
|
||||
|
||||
pub async fn refresh(
|
||||
State(state): State<AppState>,
|
||||
Json(req): Json<RefreshRequest>,
|
||||
) -> Result<(StatusCode, Json<ApiResponse<LoginResponse>>), StatusCode> {
|
||||
let claims = decode::<Claims>(
|
||||
&req.refresh_token,
|
||||
&DecodingKey::from_secret(state.config.auth.jwt_secret.as_bytes()),
|
||||
&Validation::default(),
|
||||
)
|
||||
.map_err(|_| StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
if claims.claims.token_type != "refresh" {
|
||||
return Ok((StatusCode::UNAUTHORIZED, Json(ApiResponse::error("Invalid token type"))));
|
||||
}
|
||||
|
||||
// Check if this refresh token family has been revoked (reuse detection)
|
||||
let revoked: bool = sqlx::query_scalar::<_, i64>(
|
||||
"SELECT COUNT(*) FROM revoked_token_families WHERE family = ?"
|
||||
)
|
||||
.bind(&claims.claims.family)
|
||||
.fetch_one(&state.db)
|
||||
.await
|
||||
.unwrap_or(0) > 0;
|
||||
|
||||
if revoked {
|
||||
// Token reuse detected — revoke entire family and force re-login
|
||||
tracing::warn!("Refresh token reuse detected for user {} family {}", claims.claims.sub, claims.claims.family);
|
||||
let _ = sqlx::query("DELETE FROM refresh_tokens WHERE user_id = ?")
|
||||
.bind(claims.claims.sub)
|
||||
.execute(&state.db)
|
||||
.await;
|
||||
return Ok((StatusCode::UNAUTHORIZED, Json(ApiResponse::error("Token reuse detected. Please log in again."))));
|
||||
}
|
||||
|
||||
let user = UserInfo {
|
||||
id: claims.claims.sub,
|
||||
username: claims.claims.username,
|
||||
role: claims.claims.role,
|
||||
};
|
||||
|
||||
// Rotate: new family for each refresh
|
||||
let new_family = uuid::Uuid::new_v4().to_string();
|
||||
let now = chrono::Utc::now().timestamp() as u64;
|
||||
let access_token = create_token(&user, "access", state.config.auth.access_token_ttl_secs, now, &state.config.auth.jwt_secret, &new_family)?;
|
||||
let refresh_token = create_token(&user, "refresh", state.config.auth.refresh_token_ttl_secs, now, &state.config.auth.jwt_secret, &new_family)?;
|
||||
|
||||
// Revoke old family
|
||||
let _ = sqlx::query("INSERT OR IGNORE INTO revoked_token_families (family, user_id, revoked_at) VALUES (?, ?, datetime('now'))")
|
||||
.bind(&claims.claims.family)
|
||||
.bind(claims.claims.sub)
|
||||
.execute(&state.db)
|
||||
.await;
|
||||
|
||||
Ok((StatusCode::OK, Json(ApiResponse::ok(LoginResponse {
|
||||
access_token,
|
||||
refresh_token,
|
||||
user,
|
||||
}))))
|
||||
}
|
||||
|
||||
fn create_token(user: &UserInfo, token_type: &str, ttl: u64, now: u64, secret: &str, family: &str) -> Result<String, StatusCode> {
|
||||
let claims = Claims {
|
||||
sub: user.id,
|
||||
username: user.username.clone(),
|
||||
role: user.role.clone(),
|
||||
exp: now + ttl,
|
||||
iat: now,
|
||||
token_type: token_type.to_string(),
|
||||
family: family.to_string(),
|
||||
};
|
||||
|
||||
encode(
|
||||
&Header::default(),
|
||||
&claims,
|
||||
&EncodingKey::from_secret(secret.as_bytes()),
|
||||
)
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
|
||||
}
|
||||
|
||||
/// Axum middleware: require valid JWT access token
|
||||
pub async fn require_auth(
|
||||
State(state): State<AppState>,
|
||||
mut request: Request,
|
||||
next: Next,
|
||||
) -> Result<Response, StatusCode> {
|
||||
let auth_header = request.headers()
|
||||
.get("Authorization")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.and_then(|v| v.strip_prefix("Bearer "));
|
||||
|
||||
let token = match auth_header {
|
||||
Some(t) => t,
|
||||
None => return Err(StatusCode::UNAUTHORIZED),
|
||||
};
|
||||
|
||||
let claims = decode::<Claims>(
|
||||
token,
|
||||
&DecodingKey::from_secret(state.config.auth.jwt_secret.as_bytes()),
|
||||
&Validation::default(),
|
||||
)
|
||||
.map_err(|_| StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
if claims.claims.token_type != "access" {
|
||||
return Err(StatusCode::UNAUTHORIZED);
|
||||
}
|
||||
|
||||
// Inject claims into request extensions for handlers to use
|
||||
request.extensions_mut().insert(claims.claims);
|
||||
|
||||
Ok(next.run(request).await)
|
||||
}
|
||||
|
||||
/// Axum middleware: require admin role for write operations + audit log
|
||||
pub async fn require_admin(
|
||||
State(state): State<AppState>,
|
||||
request: Request,
|
||||
next: Next,
|
||||
) -> Result<Response, StatusCode> {
|
||||
let claims = request.extensions()
|
||||
.get::<Claims>()
|
||||
.ok_or(StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
if claims.role != "admin" {
|
||||
return Err(StatusCode::FORBIDDEN);
|
||||
}
|
||||
|
||||
// Capture audit info before running handler
|
||||
let method = request.method().clone();
|
||||
let path = request.uri().path().to_string();
|
||||
let user_id = claims.sub;
|
||||
let username = claims.username.clone();
|
||||
|
||||
let response = next.run(request).await;
|
||||
|
||||
// Record admin action to audit log (fire and forget — don't block response)
|
||||
let status = response.status();
|
||||
if status.is_success() {
|
||||
let action = format!("{} {}", method, path);
|
||||
let detail = format!("by {}", username);
|
||||
let _ = sqlx::query(
|
||||
"INSERT INTO admin_audit_log (user_id, action, detail) VALUES (?, ?, ?)"
|
||||
)
|
||||
.bind(user_id)
|
||||
.bind(&action)
|
||||
.bind(&detail)
|
||||
.execute(&state.db)
|
||||
.await;
|
||||
}
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
263
crates/server/src/api/devices.rs
Normal file
263
crates/server/src/api/devices.rs
Normal file
@@ -0,0 +1,263 @@
|
||||
use axum::{extract::{State, Path, Query}, Json};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::Row;
|
||||
use crate::AppState;
|
||||
use super::{ApiResponse, Pagination};
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct DeviceListParams {
|
||||
pub status: Option<String>,
|
||||
pub group: Option<String>,
|
||||
pub search: Option<String>,
|
||||
pub page: Option<u32>,
|
||||
pub page_size: Option<u32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, sqlx::FromRow)]
|
||||
pub struct DeviceRow {
|
||||
pub id: i64,
|
||||
pub device_uid: String,
|
||||
pub hostname: String,
|
||||
pub ip_address: String,
|
||||
pub mac_address: Option<String>,
|
||||
pub os_version: Option<String>,
|
||||
pub client_version: Option<String>,
|
||||
pub status: String,
|
||||
pub last_heartbeat: Option<String>,
|
||||
pub registered_at: Option<String>,
|
||||
pub group_name: Option<String>,
|
||||
}
|
||||
|
||||
pub async fn list(
|
||||
State(state): State<AppState>,
|
||||
Query(params): Query<DeviceListParams>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let limit = params.page_size.unwrap_or(20).min(100);
|
||||
let offset = params.page.unwrap_or(1).saturating_sub(1) * limit;
|
||||
|
||||
// Normalize empty strings to None (Axum deserializes `status=` as Some(""))
|
||||
let status = params.status.as_deref().filter(|s| !s.is_empty()).map(String::from);
|
||||
let group = params.group.as_deref().filter(|s| !s.is_empty()).map(String::from);
|
||||
let search = params.search.as_deref().filter(|s| !s.is_empty()).map(String::from);
|
||||
|
||||
let devices = sqlx::query_as::<_, DeviceRow>(
|
||||
"SELECT id, device_uid, hostname, ip_address, mac_address, os_version, client_version,
|
||||
status, last_heartbeat, registered_at, group_name
|
||||
FROM devices WHERE 1=1
|
||||
AND (? IS NULL OR status = ?)
|
||||
AND (? IS NULL OR group_name = ?)
|
||||
AND (? IS NULL OR hostname LIKE '%' || ? || '%' OR ip_address LIKE '%' || ? || '%')
|
||||
ORDER BY registered_at DESC LIMIT ? OFFSET ?"
|
||||
)
|
||||
.bind(&status).bind(&status)
|
||||
.bind(&group).bind(&group)
|
||||
.bind(&search).bind(&search).bind(&search)
|
||||
.bind(limit).bind(offset)
|
||||
.fetch_all(&state.db)
|
||||
.await;
|
||||
|
||||
let total: i64 = sqlx::query_scalar(
|
||||
"SELECT COUNT(*) FROM devices WHERE 1=1
|
||||
AND (? IS NULL OR status = ?)
|
||||
AND (? IS NULL OR group_name = ?)
|
||||
AND (? IS NULL OR hostname LIKE '%' || ? || '%' OR ip_address LIKE '%' || ? || '%')"
|
||||
)
|
||||
.bind(&status).bind(&status)
|
||||
.bind(&group).bind(&group)
|
||||
.bind(&search).bind(&search).bind(&search)
|
||||
.fetch_one(&state.db)
|
||||
.await
|
||||
.unwrap_or(0);
|
||||
|
||||
match devices {
|
||||
Ok(rows) => Json(ApiResponse::ok(serde_json::json!({
|
||||
"devices": rows,
|
||||
"total": total,
|
||||
"page": params.page.unwrap_or(1),
|
||||
"page_size": limit,
|
||||
}))),
|
||||
Err(e) => Json(ApiResponse::internal_error("query devices", e)),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_detail(
|
||||
State(state): State<AppState>,
|
||||
Path(uid): Path<String>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let device = sqlx::query_as::<_, DeviceRow>(
|
||||
"SELECT id, device_uid, hostname, ip_address, mac_address, os_version, client_version,
|
||||
status, last_heartbeat, registered_at, group_name
|
||||
FROM devices WHERE device_uid = ?"
|
||||
)
|
||||
.bind(&uid)
|
||||
.fetch_optional(&state.db)
|
||||
.await;
|
||||
|
||||
match device {
|
||||
Ok(Some(d)) => Json(ApiResponse::ok(serde_json::to_value(d).unwrap_or_default())),
|
||||
Ok(None) => Json(ApiResponse::error("Device not found")),
|
||||
Err(e) => Json(ApiResponse::internal_error("query devices", e)),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, sqlx::FromRow)]
|
||||
struct StatusRow {
|
||||
pub cpu_usage: f64,
|
||||
pub memory_usage: f64,
|
||||
pub memory_total_mb: i64,
|
||||
pub disk_usage: f64,
|
||||
pub disk_total_mb: i64,
|
||||
pub network_rx_rate: i64,
|
||||
pub network_tx_rate: i64,
|
||||
pub running_procs: i32,
|
||||
pub top_processes: Option<String>,
|
||||
pub reported_at: String,
|
||||
}
|
||||
|
||||
pub async fn get_status(
|
||||
State(state): State<AppState>,
|
||||
Path(uid): Path<String>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let status = sqlx::query_as::<_, StatusRow>(
|
||||
"SELECT cpu_usage, memory_usage, memory_total_mb, disk_usage, disk_total_mb,
|
||||
network_rx_rate, network_tx_rate, running_procs, top_processes, reported_at
|
||||
FROM device_status WHERE device_uid = ?"
|
||||
)
|
||||
.bind(&uid)
|
||||
.fetch_optional(&state.db)
|
||||
.await;
|
||||
|
||||
match status {
|
||||
Ok(Some(s)) => {
|
||||
let mut val = serde_json::to_value(&s).unwrap_or_default();
|
||||
// Parse top_processes JSON string back to array
|
||||
if let Some(tp_str) = &s.top_processes {
|
||||
if let Ok(tp) = serde_json::from_str::<serde_json::Value>(tp_str) {
|
||||
val["top_processes"] = tp;
|
||||
}
|
||||
}
|
||||
Json(ApiResponse::ok(val))
|
||||
}
|
||||
Ok(None) => Json(ApiResponse::error("No status data found")),
|
||||
Err(e) => Json(ApiResponse::internal_error("query devices", e)),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_history(
|
||||
State(state): State<AppState>,
|
||||
Path(uid): Path<String>,
|
||||
Query(page): Query<Pagination>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let offset = page.offset();
|
||||
let limit = page.limit();
|
||||
|
||||
let rows = sqlx::query(
|
||||
"SELECT cpu_usage, memory_usage, disk_usage, running_procs, reported_at
|
||||
FROM device_status_history WHERE device_uid = ?
|
||||
ORDER BY reported_at DESC LIMIT ? OFFSET ?"
|
||||
)
|
||||
.bind(&uid)
|
||||
.bind(limit)
|
||||
.bind(offset)
|
||||
.fetch_all(&state.db)
|
||||
.await;
|
||||
|
||||
match rows {
|
||||
Ok(records) => {
|
||||
let items: Vec<serde_json::Value> = records.iter().map(|r| {
|
||||
serde_json::json!({
|
||||
"cpu_usage": r.get::<f64, _>("cpu_usage"),
|
||||
"memory_usage": r.get::<f64, _>("memory_usage"),
|
||||
"disk_usage": r.get::<f64, _>("disk_usage"),
|
||||
"running_procs": r.get::<i32, _>("running_procs"),
|
||||
"reported_at": r.get::<String, _>("reported_at"),
|
||||
})
|
||||
}).collect();
|
||||
Json(ApiResponse::ok(serde_json::json!({
|
||||
"history": items,
|
||||
"page": page.page.unwrap_or(1),
|
||||
"page_size": limit,
|
||||
})))
|
||||
}
|
||||
Err(e) => Json(ApiResponse::internal_error("query devices", e)),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn remove(
|
||||
State(state): State<AppState>,
|
||||
Path(uid): Path<String>,
|
||||
) -> Json<ApiResponse<()>> {
|
||||
// If client is connected, send self-destruct command
|
||||
let frame = csm_protocol::Frame::new_json(
|
||||
csm_protocol::MessageType::ConfigUpdate,
|
||||
&serde_json::json!({"type": "SelfDestruct"}),
|
||||
).ok();
|
||||
|
||||
if let Some(frame) = frame {
|
||||
state.clients.send_to(&uid, frame.encode()).await;
|
||||
}
|
||||
|
||||
// Delete device and all associated data in a transaction
|
||||
let mut tx = match state.db.begin().await {
|
||||
Ok(tx) => tx,
|
||||
Err(e) => return Json(ApiResponse::internal_error("begin transaction", e)),
|
||||
};
|
||||
|
||||
// Delete status history
|
||||
if let Err(e) = sqlx::query("DELETE FROM device_status_history WHERE device_uid = ?")
|
||||
.bind(&uid)
|
||||
.execute(&mut *tx)
|
||||
.await
|
||||
{
|
||||
return Json(ApiResponse::internal_error("remove device history", e));
|
||||
}
|
||||
|
||||
// Delete current status
|
||||
if let Err(e) = sqlx::query("DELETE FROM device_status WHERE device_uid = ?")
|
||||
.bind(&uid)
|
||||
.execute(&mut *tx)
|
||||
.await
|
||||
{
|
||||
return Json(ApiResponse::internal_error("remove device status", e));
|
||||
}
|
||||
|
||||
// Delete plugin-related data
|
||||
let cleanup_tables = [
|
||||
"hardware_assets",
|
||||
"usb_events",
|
||||
"usb_file_operations",
|
||||
"usage_daily",
|
||||
"app_usage_daily",
|
||||
"software_violations",
|
||||
"web_access_log",
|
||||
"popup_block_stats",
|
||||
];
|
||||
for table in &cleanup_tables {
|
||||
if let Err(e) = sqlx::query(&format!("DELETE FROM {} WHERE device_uid = ?", table))
|
||||
.bind(&uid)
|
||||
.execute(&mut *tx)
|
||||
.await
|
||||
{
|
||||
tracing::warn!("Failed to clean {} for device {}: {}", table, uid, e);
|
||||
}
|
||||
}
|
||||
|
||||
// Finally delete the device itself
|
||||
let delete_result = sqlx::query("DELETE FROM devices WHERE device_uid = ?")
|
||||
.bind(&uid)
|
||||
.execute(&mut *tx)
|
||||
.await;
|
||||
|
||||
match delete_result {
|
||||
Ok(r) if r.rows_affected() > 0 => {
|
||||
if let Err(e) = tx.commit().await {
|
||||
return Json(ApiResponse::internal_error("commit device deletion", e));
|
||||
}
|
||||
state.clients.unregister(&uid).await;
|
||||
tracing::info!(device_uid = %uid, "Device and all associated data deleted");
|
||||
Json(ApiResponse::ok(()))
|
||||
}
|
||||
Ok(_) => Json(ApiResponse::error("Device not found")),
|
||||
Err(e) => Json(ApiResponse::internal_error("remove device", e)),
|
||||
}
|
||||
}
|
||||
120
crates/server/src/api/mod.rs
Normal file
120
crates/server/src/api/mod.rs
Normal file
@@ -0,0 +1,120 @@
|
||||
use axum::{routing::{get, post, put, delete}, Router, Json, extract::State, middleware};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use crate::AppState;
|
||||
|
||||
pub mod auth;
|
||||
pub mod devices;
|
||||
pub mod assets;
|
||||
pub mod usb;
|
||||
pub mod alerts;
|
||||
pub mod plugins;
|
||||
|
||||
pub fn routes(state: AppState) -> Router<AppState> {
|
||||
let public = Router::new()
|
||||
.route("/api/auth/login", post(auth::login))
|
||||
.route("/api/auth/refresh", post(auth::refresh))
|
||||
.route("/health", get(health_check))
|
||||
.with_state(state.clone());
|
||||
|
||||
// Read-only routes (any authenticated user)
|
||||
let read_routes = Router::new()
|
||||
// Devices
|
||||
.route("/api/devices", get(devices::list))
|
||||
.route("/api/devices/:uid", get(devices::get_detail))
|
||||
.route("/api/devices/:uid/status", get(devices::get_status))
|
||||
.route("/api/devices/:uid/history", get(devices::get_history))
|
||||
// Assets
|
||||
.route("/api/assets/hardware", get(assets::list_hardware))
|
||||
.route("/api/assets/software", get(assets::list_software))
|
||||
.route("/api/assets/changes", get(assets::list_changes))
|
||||
// USB (read)
|
||||
.route("/api/usb/events", get(usb::list_events))
|
||||
.route("/api/usb/policies", get(usb::list_policies))
|
||||
// Alerts (read)
|
||||
.route("/api/alerts/rules", get(alerts::list_rules))
|
||||
.route("/api/alerts/records", get(alerts::list_records))
|
||||
// Plugin read routes
|
||||
.merge(plugins::read_routes())
|
||||
.layer(middleware::from_fn_with_state(state.clone(), auth::require_auth));
|
||||
|
||||
// Write routes (admin only)
|
||||
let write_routes = Router::new()
|
||||
// Devices
|
||||
.route("/api/devices/:uid", delete(devices::remove))
|
||||
// USB (write)
|
||||
.route("/api/usb/policies", post(usb::create_policy))
|
||||
.route("/api/usb/policies/:id", put(usb::update_policy).delete(usb::delete_policy))
|
||||
// Alerts (write)
|
||||
.route("/api/alerts/rules", post(alerts::create_rule))
|
||||
.route("/api/alerts/rules/:id", put(alerts::update_rule).delete(alerts::delete_rule))
|
||||
.route("/api/alerts/records/:id/handle", put(alerts::handle_record))
|
||||
// Plugin write routes (already has require_admin layer internally)
|
||||
.merge(plugins::write_routes())
|
||||
// Layer order: outer (require_admin) runs AFTER inner (require_auth)
|
||||
// so require_auth sets Claims extension first, then require_admin checks it
|
||||
.layer(middleware::from_fn_with_state(state.clone(), auth::require_admin))
|
||||
.layer(middleware::from_fn_with_state(state.clone(), auth::require_auth));
|
||||
|
||||
// WebSocket has its own JWT auth via query parameter
|
||||
let ws_router = Router::new()
|
||||
.route("/ws", get(crate::ws::ws_handler))
|
||||
.with_state(state.clone());
|
||||
|
||||
Router::new()
|
||||
.merge(public)
|
||||
.merge(read_routes)
|
||||
.merge(write_routes)
|
||||
.merge(ws_router)
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct HealthResponse {
|
||||
status: &'static str,
|
||||
}
|
||||
|
||||
async fn health_check() -> Json<HealthResponse> {
|
||||
Json(HealthResponse {
|
||||
status: "ok",
|
||||
})
|
||||
}
|
||||
|
||||
/// Standard API response envelope
|
||||
#[derive(Serialize)]
|
||||
pub struct ApiResponse<T: Serialize> {
|
||||
pub success: bool,
|
||||
pub data: Option<T>,
|
||||
pub error: Option<String>,
|
||||
}
|
||||
|
||||
impl<T: Serialize> ApiResponse<T> {
|
||||
pub fn ok(data: T) -> Self {
|
||||
Self { success: true, data: Some(data), error: None }
|
||||
}
|
||||
|
||||
pub fn error(msg: impl Into<String>) -> Self {
|
||||
Self { success: false, data: None, error: Some(msg.into()) }
|
||||
}
|
||||
|
||||
/// Log internal error and return sanitized message to client
|
||||
pub fn internal_error(context: &str, e: impl std::fmt::Display) -> Self {
|
||||
tracing::error!("{}: {}", context, e);
|
||||
Self { success: false, data: None, error: Some("Internal server error".to_string()) }
|
||||
}
|
||||
}
|
||||
|
||||
/// Pagination parameters
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct Pagination {
|
||||
pub page: Option<u32>,
|
||||
pub page_size: Option<u32>,
|
||||
}
|
||||
|
||||
impl Pagination {
|
||||
pub fn offset(&self) -> u32 {
|
||||
self.page.unwrap_or(1).saturating_sub(1) * self.limit()
|
||||
}
|
||||
|
||||
pub fn limit(&self) -> u32 {
|
||||
self.page_size.unwrap_or(20).min(100)
|
||||
}
|
||||
}
|
||||
49
crates/server/src/api/plugins/mod.rs
Normal file
49
crates/server/src/api/plugins/mod.rs
Normal file
@@ -0,0 +1,49 @@
|
||||
pub mod web_filter;
|
||||
pub mod usage_timer;
|
||||
pub mod software_blocker;
|
||||
pub mod popup_blocker;
|
||||
pub mod usb_file_audit;
|
||||
pub mod watermark;
|
||||
|
||||
use axum::{Router, routing::{get, post, put}};
|
||||
use crate::AppState;
|
||||
|
||||
/// Read-only plugin routes (accessible by admin + viewer)
|
||||
pub fn read_routes() -> Router<AppState> {
|
||||
Router::new()
|
||||
// Web Filter
|
||||
.route("/api/plugins/web-filter/rules", get(web_filter::list_rules))
|
||||
.route("/api/plugins/web-filter/log", get(web_filter::list_access_log))
|
||||
// Usage Timer
|
||||
.route("/api/plugins/usage-timer/daily", get(usage_timer::list_daily))
|
||||
.route("/api/plugins/usage-timer/app-usage", get(usage_timer::list_app_usage))
|
||||
.route("/api/plugins/usage-timer/leaderboard", get(usage_timer::leaderboard))
|
||||
// Software Blocker
|
||||
.route("/api/plugins/software-blocker/blacklist", get(software_blocker::list_blacklist))
|
||||
.route("/api/plugins/software-blocker/violations", get(software_blocker::list_violations))
|
||||
// Popup Blocker
|
||||
.route("/api/plugins/popup-blocker/rules", get(popup_blocker::list_rules))
|
||||
.route("/api/plugins/popup-blocker/stats", get(popup_blocker::list_stats))
|
||||
// USB File Audit
|
||||
.route("/api/plugins/usb-file-audit/log", get(usb_file_audit::list_operations))
|
||||
.route("/api/plugins/usb-file-audit/summary", get(usb_file_audit::summary))
|
||||
// Watermark
|
||||
.route("/api/plugins/watermark/config", get(watermark::get_config_list))
|
||||
}
|
||||
|
||||
/// Write plugin routes (admin only — require_admin middleware applied by caller)
|
||||
pub fn write_routes() -> Router<AppState> {
|
||||
Router::new()
|
||||
// Web Filter
|
||||
.route("/api/plugins/web-filter/rules", post(web_filter::create_rule))
|
||||
.route("/api/plugins/web-filter/rules/:id", put(web_filter::update_rule).delete(web_filter::delete_rule))
|
||||
// Software Blocker
|
||||
.route("/api/plugins/software-blocker/blacklist", post(software_blocker::add_to_blacklist))
|
||||
.route("/api/plugins/software-blocker/blacklist/:id", put(software_blocker::update_blacklist).delete(software_blocker::remove_from_blacklist))
|
||||
// Popup Blocker
|
||||
.route("/api/plugins/popup-blocker/rules", post(popup_blocker::create_rule))
|
||||
.route("/api/plugins/popup-blocker/rules/:id", put(popup_blocker::update_rule).delete(popup_blocker::delete_rule))
|
||||
// Watermark
|
||||
.route("/api/plugins/watermark/config", post(watermark::create_config))
|
||||
.route("/api/plugins/watermark/config/:id", put(watermark::update_config).delete(watermark::delete_config))
|
||||
}
|
||||
155
crates/server/src/api/plugins/popup_blocker.rs
Normal file
155
crates/server/src/api/plugins/popup_blocker.rs
Normal file
@@ -0,0 +1,155 @@
|
||||
use axum::{extract::{State, Path, Json}, http::StatusCode};
|
||||
use serde::Deserialize;
|
||||
use sqlx::Row;
|
||||
use csm_protocol::MessageType;
|
||||
use crate::AppState;
|
||||
use crate::api::ApiResponse;
|
||||
use crate::tcp::push_to_targets;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct CreateRuleRequest {
|
||||
pub rule_type: String, // "block" | "allow"
|
||||
pub window_title: Option<String>,
|
||||
pub window_class: Option<String>,
|
||||
pub process_name: Option<String>,
|
||||
pub target_type: Option<String>,
|
||||
pub target_id: Option<String>,
|
||||
}
|
||||
|
||||
pub async fn list_rules(State(state): State<AppState>) -> Json<ApiResponse<serde_json::Value>> {
|
||||
match sqlx::query("SELECT id, rule_type, window_title, window_class, process_name, target_type, target_id, enabled, created_at FROM popup_filter_rules ORDER BY created_at DESC")
|
||||
.fetch_all(&state.db).await {
|
||||
Ok(rows) => Json(ApiResponse::ok(serde_json::json!({"rules": rows.iter().map(|r| serde_json::json!({
|
||||
"id": r.get::<i64,_>("id"), "rule_type": r.get::<String,_>("rule_type"),
|
||||
"window_title": r.get::<Option<String>,_>("window_title"),
|
||||
"window_class": r.get::<Option<String>,_>("window_class"),
|
||||
"process_name": r.get::<Option<String>,_>("process_name"),
|
||||
"target_type": r.get::<String,_>("target_type"), "target_id": r.get::<Option<String>,_>("target_id"),
|
||||
"enabled": r.get::<bool,_>("enabled"), "created_at": r.get::<String,_>("created_at")
|
||||
})).collect::<Vec<_>>()}))),
|
||||
Err(e) => Json(ApiResponse::internal_error("query popup filter rules", e)),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn create_rule(State(state): State<AppState>, Json(req): Json<CreateRuleRequest>) -> (StatusCode, Json<ApiResponse<serde_json::Value>>) {
|
||||
let target_type = req.target_type.unwrap_or_else(|| "global".to_string());
|
||||
|
||||
// Validate inputs
|
||||
if !matches!(req.rule_type.as_str(), "block" | "allow") {
|
||||
return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("rule_type must be 'block' or 'allow'")));
|
||||
}
|
||||
if !matches!(target_type.as_str(), "global" | "device" | "group") {
|
||||
return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("invalid target_type")));
|
||||
}
|
||||
let has_filter = req.window_title.as_ref().map_or(false, |s| !s.is_empty())
|
||||
|| req.window_class.as_ref().map_or(false, |s| !s.is_empty())
|
||||
|| req.process_name.as_ref().map_or(false, |s| !s.is_empty());
|
||||
if !has_filter {
|
||||
return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("at least one filter (window_title/window_class/process_name) required")));
|
||||
}
|
||||
|
||||
match sqlx::query("INSERT INTO popup_filter_rules (rule_type, window_title, window_class, process_name, target_type, target_id) VALUES (?,?,?,?,?,?)")
|
||||
.bind(&req.rule_type).bind(&req.window_title).bind(&req.window_class).bind(&req.process_name).bind(&target_type).bind(&req.target_id)
|
||||
.execute(&state.db).await {
|
||||
Ok(r) => {
|
||||
let new_id = r.last_insert_rowid();
|
||||
let rules = fetch_popup_rules_for_push(&state.db, &target_type, req.target_id.as_deref()).await;
|
||||
push_to_targets(&state.db, &state.clients, MessageType::PopupRules, &serde_json::json!({"rules": rules}), &target_type, req.target_id.as_deref()).await;
|
||||
(StatusCode::CREATED, Json(ApiResponse::ok(serde_json::json!({"id": new_id}))))
|
||||
}
|
||||
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::internal_error("create popup filter rule", e))),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct UpdateRuleRequest { pub window_title: Option<String>, pub window_class: Option<String>, pub process_name: Option<String>, pub enabled: Option<bool> }
|
||||
|
||||
pub async fn update_rule(State(state): State<AppState>, Path(id): Path<i64>, Json(body): Json<UpdateRuleRequest>) -> Json<ApiResponse<()>> {
|
||||
let existing = sqlx::query("SELECT * FROM popup_filter_rules WHERE id = ?")
|
||||
.bind(id)
|
||||
.fetch_optional(&state.db)
|
||||
.await;
|
||||
|
||||
let existing = match existing {
|
||||
Ok(Some(row)) => row,
|
||||
Ok(None) => return Json(ApiResponse::error("Not found")),
|
||||
Err(e) => return Json(ApiResponse::internal_error("query popup filter rule", e)),
|
||||
};
|
||||
|
||||
let window_title = body.window_title.or_else(|| existing.get::<Option<String>, _>("window_title"));
|
||||
let window_class = body.window_class.or_else(|| existing.get::<Option<String>, _>("window_class"));
|
||||
let process_name = body.process_name.or_else(|| existing.get::<Option<String>, _>("process_name"));
|
||||
let enabled = body.enabled.unwrap_or_else(|| existing.get::<bool, _>("enabled"));
|
||||
|
||||
let result = sqlx::query("UPDATE popup_filter_rules SET window_title = ?, window_class = ?, process_name = ?, enabled = ? WHERE id = ?")
|
||||
.bind(&window_title)
|
||||
.bind(&window_class)
|
||||
.bind(&process_name)
|
||||
.bind(enabled)
|
||||
.bind(id)
|
||||
.execute(&state.db)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(r) if r.rows_affected() > 0 => {
|
||||
let target_type_val: String = existing.get("target_type");
|
||||
let target_id_val: Option<String> = existing.get("target_id");
|
||||
let rules = fetch_popup_rules_for_push(&state.db, &target_type_val, target_id_val.as_deref()).await;
|
||||
push_to_targets(&state.db, &state.clients, MessageType::PopupRules, &serde_json::json!({"rules": rules}), &target_type_val, target_id_val.as_deref()).await;
|
||||
Json(ApiResponse::ok(()))
|
||||
}
|
||||
Ok(_) => Json(ApiResponse::error("Not found")),
|
||||
Err(e) => Json(ApiResponse::internal_error("update popup filter rule", e)),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn delete_rule(State(state): State<AppState>, Path(id): Path<i64>) -> Json<ApiResponse<()>> {
|
||||
let existing = sqlx::query("SELECT target_type, target_id FROM popup_filter_rules WHERE id = ?")
|
||||
.bind(id).fetch_optional(&state.db).await;
|
||||
let (target_type, target_id) = match existing {
|
||||
Ok(Some(row)) => (row.get::<String, _>("target_type"), row.get::<Option<String>, _>("target_id")),
|
||||
_ => return Json(ApiResponse::error("Not found")),
|
||||
};
|
||||
match sqlx::query("DELETE FROM popup_filter_rules WHERE id=?").bind(id).execute(&state.db).await {
|
||||
Ok(r) if r.rows_affected() > 0 => {
|
||||
let rules = fetch_popup_rules_for_push(&state.db, &target_type, target_id.as_deref()).await;
|
||||
push_to_targets(&state.db, &state.clients, MessageType::PopupRules, &serde_json::json!({"rules": rules}), &target_type, target_id.as_deref()).await;
|
||||
Json(ApiResponse::ok(()))
|
||||
}
|
||||
_ => Json(ApiResponse::error("Not found")),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn list_stats(State(state): State<AppState>) -> Json<ApiResponse<serde_json::Value>> {
|
||||
match sqlx::query("SELECT device_uid, blocked_count, date FROM popup_block_stats ORDER BY date DESC LIMIT 30")
|
||||
.fetch_all(&state.db).await {
|
||||
Ok(rows) => Json(ApiResponse::ok(serde_json::json!({"stats": rows.iter().map(|r| serde_json::json!({
|
||||
"device_uid": r.get::<String,_>("device_uid"), "blocked_count": r.get::<i32,_>("blocked_count"),
|
||||
"date": r.get::<String,_>("date")
|
||||
})).collect::<Vec<_>>()}))),
|
||||
Err(e) => Json(ApiResponse::internal_error("query popup block stats", e)),
|
||||
}
|
||||
}
|
||||
|
||||
async fn fetch_popup_rules_for_push(
|
||||
db: &sqlx::SqlitePool,
|
||||
target_type: &str,
|
||||
target_id: Option<&str>,
|
||||
) -> Vec<serde_json::Value> {
|
||||
let query = match target_type {
|
||||
"device" => sqlx::query(
|
||||
"SELECT id, rule_type, window_title, window_class, process_name FROM popup_filter_rules WHERE enabled = 1 AND (target_type = 'global' OR (target_type = 'device' AND target_id = ?))"
|
||||
).bind(target_id),
|
||||
_ => sqlx::query(
|
||||
"SELECT id, rule_type, window_title, window_class, process_name FROM popup_filter_rules WHERE enabled = 1 AND target_type = 'global'"
|
||||
),
|
||||
};
|
||||
query.fetch_all(db).await
|
||||
.map(|rows| rows.iter().map(|r| serde_json::json!({
|
||||
"id": r.get::<i64,_>("id"), "rule_type": r.get::<String,_>("rule_type"),
|
||||
"window_title": r.get::<Option<String>,_>("window_title"),
|
||||
"window_class": r.get::<Option<String>,_>("window_class"),
|
||||
"process_name": r.get::<Option<String>,_>("process_name"),
|
||||
})).collect())
|
||||
.unwrap_or_default()
|
||||
}
|
||||
155
crates/server/src/api/plugins/software_blocker.rs
Normal file
155
crates/server/src/api/plugins/software_blocker.rs
Normal file
@@ -0,0 +1,155 @@
|
||||
use axum::{extract::{State, Path, Query, Json}, http::StatusCode};
|
||||
use serde::Deserialize;
|
||||
use sqlx::Row;
|
||||
use csm_protocol::MessageType;
|
||||
use crate::AppState;
|
||||
use crate::api::ApiResponse;
|
||||
use crate::tcp::push_to_targets;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct CreateBlacklistRequest {
|
||||
pub name_pattern: String,
|
||||
pub category: Option<String>,
|
||||
pub action: Option<String>,
|
||||
pub target_type: Option<String>,
|
||||
pub target_id: Option<String>,
|
||||
}
|
||||
|
||||
pub async fn list_blacklist(State(state): State<AppState>) -> Json<ApiResponse<serde_json::Value>> {
|
||||
match sqlx::query("SELECT id, name_pattern, category, action, target_type, target_id, enabled, created_at FROM software_blacklist ORDER BY created_at DESC")
|
||||
.fetch_all(&state.db).await {
|
||||
Ok(rows) => Json(ApiResponse::ok(serde_json::json!({"blacklist": rows.iter().map(|r| serde_json::json!({
|
||||
"id": r.get::<i64,_>("id"), "name_pattern": r.get::<String,_>("name_pattern"),
|
||||
"category": r.get::<Option<String>,_>("category"), "action": r.get::<String,_>("action"),
|
||||
"target_type": r.get::<String,_>("target_type"), "target_id": r.get::<Option<String>,_>("target_id"),
|
||||
"enabled": r.get::<bool,_>("enabled"), "created_at": r.get::<String,_>("created_at")
|
||||
})).collect::<Vec<_>>()}))),
|
||||
Err(e) => Json(ApiResponse::internal_error("query software blacklist", e)),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn add_to_blacklist(State(state): State<AppState>, Json(req): Json<CreateBlacklistRequest>) -> (StatusCode, Json<ApiResponse<serde_json::Value>>) {
|
||||
let action = req.action.unwrap_or_else(|| "block".to_string());
|
||||
let target_type = req.target_type.unwrap_or_else(|| "global".to_string());
|
||||
|
||||
// Validate inputs
|
||||
if req.name_pattern.trim().is_empty() || req.name_pattern.len() > 255 {
|
||||
return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("name_pattern must be 1-255 chars")));
|
||||
}
|
||||
if !matches!(action.as_str(), "block" | "alert") {
|
||||
return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("action must be 'block' or 'alert'")));
|
||||
}
|
||||
if let Some(ref cat) = req.category {
|
||||
if !matches!(cat.as_str(), "game" | "social" | "vpn" | "mining" | "custom") {
|
||||
return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("invalid category")));
|
||||
}
|
||||
}
|
||||
if !matches!(target_type.as_str(), "global" | "device" | "group") {
|
||||
return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("invalid target_type")));
|
||||
}
|
||||
|
||||
match sqlx::query("INSERT INTO software_blacklist (name_pattern, category, action, target_type, target_id) VALUES (?,?,?,?,?)")
|
||||
.bind(&req.name_pattern).bind(&req.category).bind(&action).bind(&target_type).bind(&req.target_id)
|
||||
.execute(&state.db).await {
|
||||
Ok(r) => {
|
||||
let new_id = r.last_insert_rowid();
|
||||
let blacklist = fetch_blacklist_for_push(&state.db, &target_type, req.target_id.as_deref()).await;
|
||||
push_to_targets(&state.db, &state.clients, MessageType::SoftwareBlacklist, &serde_json::json!({"blacklist": blacklist}), &target_type, req.target_id.as_deref()).await;
|
||||
(StatusCode::CREATED, Json(ApiResponse::ok(serde_json::json!({"id": new_id}))))
|
||||
}
|
||||
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::internal_error("add software blacklist entry", e))),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct UpdateBlacklistRequest { pub name_pattern: Option<String>, pub action: Option<String>, pub enabled: Option<bool> }
|
||||
|
||||
pub async fn update_blacklist(State(state): State<AppState>, Path(id): Path<i64>, Json(body): Json<UpdateBlacklistRequest>) -> Json<ApiResponse<()>> {
|
||||
let existing = sqlx::query("SELECT * FROM software_blacklist WHERE id = ?")
|
||||
.bind(id)
|
||||
.fetch_optional(&state.db)
|
||||
.await;
|
||||
|
||||
let existing = match existing {
|
||||
Ok(Some(row)) => row,
|
||||
Ok(None) => return Json(ApiResponse::error("Not found")),
|
||||
Err(e) => return Json(ApiResponse::internal_error("query software blacklist", e)),
|
||||
};
|
||||
|
||||
let name_pattern = body.name_pattern.unwrap_or_else(|| existing.get::<String, _>("name_pattern"));
|
||||
let action = body.action.unwrap_or_else(|| existing.get::<String, _>("action"));
|
||||
let enabled = body.enabled.unwrap_or_else(|| existing.get::<bool, _>("enabled"));
|
||||
|
||||
let result = sqlx::query("UPDATE software_blacklist SET name_pattern = ?, action = ?, enabled = ? WHERE id = ?")
|
||||
.bind(&name_pattern)
|
||||
.bind(&action)
|
||||
.bind(enabled)
|
||||
.bind(id)
|
||||
.execute(&state.db)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(r) if r.rows_affected() > 0 => {
|
||||
let target_type_val: String = existing.get("target_type");
|
||||
let target_id_val: Option<String> = existing.get("target_id");
|
||||
let blacklist = fetch_blacklist_for_push(&state.db, &target_type_val, target_id_val.as_deref()).await;
|
||||
push_to_targets(&state.db, &state.clients, MessageType::SoftwareBlacklist, &serde_json::json!({"blacklist": blacklist}), &target_type_val, target_id_val.as_deref()).await;
|
||||
Json(ApiResponse::ok(()))
|
||||
}
|
||||
Ok(_) => Json(ApiResponse::error("Not found")),
|
||||
Err(e) => Json(ApiResponse::internal_error("update software blacklist", e)),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn remove_from_blacklist(State(state): State<AppState>, Path(id): Path<i64>) -> Json<ApiResponse<()>> {
|
||||
let existing = sqlx::query("SELECT target_type, target_id FROM software_blacklist WHERE id = ?")
|
||||
.bind(id).fetch_optional(&state.db).await;
|
||||
let (target_type, target_id) = match existing {
|
||||
Ok(Some(row)) => (row.get::<String, _>("target_type"), row.get::<Option<String>, _>("target_id")),
|
||||
_ => return Json(ApiResponse::error("Not found")),
|
||||
};
|
||||
match sqlx::query("DELETE FROM software_blacklist WHERE id=?").bind(id).execute(&state.db).await {
|
||||
Ok(r) if r.rows_affected() > 0 => {
|
||||
let blacklist = fetch_blacklist_for_push(&state.db, &target_type, target_id.as_deref()).await;
|
||||
push_to_targets(&state.db, &state.clients, MessageType::SoftwareBlacklist, &serde_json::json!({"blacklist": blacklist}), &target_type, target_id.as_deref()).await;
|
||||
Json(ApiResponse::ok(()))
|
||||
}
|
||||
_ => Json(ApiResponse::error("Not found")),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct ViolationFilters { pub device_uid: Option<String> }
|
||||
|
||||
pub async fn list_violations(State(state): State<AppState>, Query(f): Query<ViolationFilters>) -> Json<ApiResponse<serde_json::Value>> {
|
||||
match sqlx::query("SELECT id, device_uid, software_name, action_taken, timestamp FROM software_violations WHERE (? IS NULL OR device_uid=?) ORDER BY timestamp DESC LIMIT 200")
|
||||
.bind(&f.device_uid).bind(&f.device_uid)
|
||||
.fetch_all(&state.db).await {
|
||||
Ok(rows) => Json(ApiResponse::ok(serde_json::json!({"violations": rows.iter().map(|r| serde_json::json!({
|
||||
"id": r.get::<i64,_>("id"), "device_uid": r.get::<String,_>("device_uid"),
|
||||
"software_name": r.get::<String,_>("software_name"), "action_taken": r.get::<String,_>("action_taken"),
|
||||
"timestamp": r.get::<String,_>("timestamp")
|
||||
})).collect::<Vec<_>>()}))),
|
||||
Err(e) => Json(ApiResponse::internal_error("query software violations", e)),
|
||||
}
|
||||
}
|
||||
|
||||
async fn fetch_blacklist_for_push(
|
||||
db: &sqlx::SqlitePool,
|
||||
target_type: &str,
|
||||
target_id: Option<&str>,
|
||||
) -> Vec<serde_json::Value> {
|
||||
let query = match target_type {
|
||||
"device" => sqlx::query(
|
||||
"SELECT id, name_pattern, action FROM software_blacklist WHERE enabled = 1 AND (target_type = 'global' OR (target_type = 'device' AND target_id = ?))"
|
||||
).bind(target_id),
|
||||
_ => sqlx::query(
|
||||
"SELECT id, name_pattern, action FROM software_blacklist WHERE enabled = 1 AND target_type = 'global'"
|
||||
),
|
||||
};
|
||||
query.fetch_all(db).await
|
||||
.map(|rows| rows.iter().map(|r| serde_json::json!({
|
||||
"id": r.get::<i64,_>("id"), "name_pattern": r.get::<String,_>("name_pattern"), "action": r.get::<String,_>("action")
|
||||
})).collect())
|
||||
.unwrap_or_default()
|
||||
}
|
||||
60
crates/server/src/api/plugins/usage_timer.rs
Normal file
60
crates/server/src/api/plugins/usage_timer.rs
Normal file
@@ -0,0 +1,60 @@
|
||||
use axum::{extract::{State, Query}, Json};
|
||||
use serde::Deserialize;
|
||||
use sqlx::Row;
|
||||
use crate::AppState;
|
||||
use crate::api::ApiResponse;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct DailyFilters { pub device_uid: Option<String>, pub start_date: Option<String>, pub end_date: Option<String> }
|
||||
|
||||
pub async fn list_daily(State(state): State<AppState>, Query(f): Query<DailyFilters>) -> Json<ApiResponse<serde_json::Value>> {
|
||||
match sqlx::query(
|
||||
"SELECT id, device_uid, date, total_active_minutes, total_idle_minutes, first_active_at, last_active_at
|
||||
FROM usage_daily WHERE (? IS NULL OR device_uid=?) AND (? IS NULL OR date>=?) AND (? IS NULL OR date<=?)
|
||||
ORDER BY date DESC LIMIT 90")
|
||||
.bind(&f.device_uid).bind(&f.device_uid)
|
||||
.bind(&f.start_date).bind(&f.start_date)
|
||||
.bind(&f.end_date).bind(&f.end_date)
|
||||
.fetch_all(&state.db).await {
|
||||
Ok(rows) => Json(ApiResponse::ok(serde_json::json!({"daily": rows.iter().map(|r| serde_json::json!({
|
||||
"id": r.get::<i64,_>("id"), "device_uid": r.get::<String,_>("device_uid"),
|
||||
"date": r.get::<String,_>("date"), "total_active_minutes": r.get::<i32,_>("total_active_minutes"),
|
||||
"total_idle_minutes": r.get::<i32,_>("total_idle_minutes"),
|
||||
"first_active_at": r.get::<Option<String>,_>("first_active_at"),
|
||||
"last_active_at": r.get::<Option<String>,_>("last_active_at")
|
||||
})).collect::<Vec<_>>()}))),
|
||||
Err(e) => Json(ApiResponse::internal_error("query daily usage", e)),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct AppUsageFilters { pub device_uid: Option<String>, pub date: Option<String> }
|
||||
|
||||
pub async fn list_app_usage(State(state): State<AppState>, Query(f): Query<AppUsageFilters>) -> Json<ApiResponse<serde_json::Value>> {
|
||||
match sqlx::query(
|
||||
"SELECT id, device_uid, date, app_name, usage_minutes FROM app_usage_daily
|
||||
WHERE (? IS NULL OR device_uid=?) AND (? IS NULL OR date=?)
|
||||
ORDER BY usage_minutes DESC LIMIT 100")
|
||||
.bind(&f.device_uid).bind(&f.device_uid)
|
||||
.bind(&f.date).bind(&f.date)
|
||||
.fetch_all(&state.db).await {
|
||||
Ok(rows) => Json(ApiResponse::ok(serde_json::json!({"app_usage": rows.iter().map(|r| serde_json::json!({
|
||||
"id": r.get::<i64,_>("id"), "device_uid": r.get::<String,_>("device_uid"),
|
||||
"date": r.get::<String,_>("date"), "app_name": r.get::<String,_>("app_name"),
|
||||
"usage_minutes": r.get::<i32,_>("usage_minutes")
|
||||
})).collect::<Vec<_>>()}))),
|
||||
Err(e) => Json(ApiResponse::internal_error("query app usage", e)),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn leaderboard(State(state): State<AppState>) -> Json<ApiResponse<serde_json::Value>> {
|
||||
match sqlx::query(
|
||||
"SELECT device_uid, SUM(total_active_minutes) as total_minutes FROM usage_daily
|
||||
WHERE date >= date('now', '-7 days') GROUP BY device_uid ORDER BY total_minutes DESC LIMIT 20")
|
||||
.fetch_all(&state.db).await {
|
||||
Ok(rows) => Json(ApiResponse::ok(serde_json::json!({"leaderboard": rows.iter().map(|r| serde_json::json!({
|
||||
"device_uid": r.get::<String,_>("device_uid"), "total_minutes": r.get::<i64,_>("total_minutes")
|
||||
})).collect::<Vec<_>>()}))),
|
||||
Err(e) => Json(ApiResponse::internal_error("query usage leaderboard", e)),
|
||||
}
|
||||
}
|
||||
47
crates/server/src/api/plugins/usb_file_audit.rs
Normal file
47
crates/server/src/api/plugins/usb_file_audit.rs
Normal file
@@ -0,0 +1,47 @@
|
||||
use axum::{extract::{State, Query}, Json};
|
||||
use serde::Deserialize;
|
||||
use sqlx::Row;
|
||||
use crate::AppState;
|
||||
use crate::api::ApiResponse;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct LogFilters {
|
||||
pub device_uid: Option<String>,
|
||||
pub operation: Option<String>,
|
||||
pub usb_serial: Option<String>,
|
||||
}
|
||||
|
||||
pub async fn list_operations(State(state): State<AppState>, Query(f): Query<LogFilters>) -> Json<ApiResponse<serde_json::Value>> {
|
||||
match sqlx::query(
|
||||
"SELECT id, device_uid, usb_serial, drive_letter, operation, file_path, file_size, timestamp
|
||||
FROM usb_file_operations WHERE (? IS NULL OR device_uid=?) AND (? IS NULL OR operation=?) AND (? IS NULL OR usb_serial=?)
|
||||
ORDER BY timestamp DESC LIMIT 200")
|
||||
.bind(&f.device_uid).bind(&f.device_uid)
|
||||
.bind(&f.operation).bind(&f.operation)
|
||||
.bind(&f.usb_serial).bind(&f.usb_serial)
|
||||
.fetch_all(&state.db).await {
|
||||
Ok(rows) => Json(ApiResponse::ok(serde_json::json!({"operations": rows.iter().map(|r| serde_json::json!({
|
||||
"id": r.get::<i64,_>("id"), "device_uid": r.get::<String,_>("device_uid"),
|
||||
"usb_serial": r.get::<Option<String>,_>("usb_serial"), "drive_letter": r.get::<Option<String>,_>("drive_letter"),
|
||||
"operation": r.get::<String,_>("operation"), "file_path": r.get::<String,_>("file_path"),
|
||||
"file_size": r.get::<Option<i64>,_>("file_size"), "timestamp": r.get::<String,_>("timestamp")
|
||||
})).collect::<Vec<_>>()}))),
|
||||
Err(e) => Json(ApiResponse::internal_error("query USB file operations", e)),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn summary(State(state): State<AppState>) -> Json<ApiResponse<serde_json::Value>> {
|
||||
match sqlx::query(
|
||||
"SELECT device_uid, COUNT(*) as op_count, COUNT(DISTINCT usb_serial) as usb_count,
|
||||
MIN(timestamp) as first_op, MAX(timestamp) as last_op
|
||||
FROM usb_file_operations WHERE timestamp >= datetime('now', '-7 days')
|
||||
GROUP BY device_uid ORDER BY op_count DESC LIMIT 50")
|
||||
.fetch_all(&state.db).await {
|
||||
Ok(rows) => Json(ApiResponse::ok(serde_json::json!({"summary": rows.iter().map(|r| serde_json::json!({
|
||||
"device_uid": r.get::<String,_>("device_uid"), "op_count": r.get::<i64,_>("op_count"),
|
||||
"usb_count": r.get::<i64,_>("usb_count"), "first_op": r.get::<String,_>("first_op"),
|
||||
"last_op": r.get::<String,_>("last_op")
|
||||
})).collect::<Vec<_>>()}))),
|
||||
Err(e) => Json(ApiResponse::internal_error("query USB file audit summary", e)),
|
||||
}
|
||||
}
|
||||
186
crates/server/src/api/plugins/watermark.rs
Normal file
186
crates/server/src/api/plugins/watermark.rs
Normal file
@@ -0,0 +1,186 @@
|
||||
use axum::{extract::{State, Path, Json}, http::StatusCode};
|
||||
use serde::Deserialize;
|
||||
use sqlx::Row;
|
||||
use csm_protocol::{MessageType, WatermarkConfigPayload};
|
||||
use crate::AppState;
|
||||
use crate::api::ApiResponse;
|
||||
use crate::tcp::push_to_targets;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct CreateConfigRequest {
|
||||
pub target_type: Option<String>,
|
||||
pub target_id: Option<String>,
|
||||
pub content: Option<String>,
|
||||
pub font_size: Option<u32>,
|
||||
pub opacity: Option<f64>,
|
||||
pub color: Option<String>,
|
||||
pub angle: Option<i32>,
|
||||
pub enabled: Option<bool>,
|
||||
}
|
||||
|
||||
pub async fn get_config_list(State(state): State<AppState>) -> Json<ApiResponse<serde_json::Value>> {
|
||||
match sqlx::query("SELECT id, target_type, target_id, content, font_size, opacity, color, angle, enabled, updated_at FROM watermark_config ORDER BY id")
|
||||
.fetch_all(&state.db).await {
|
||||
Ok(rows) => Json(ApiResponse::ok(serde_json::json!({"configs": rows.iter().map(|r| serde_json::json!({
|
||||
"id": r.get::<i64,_>("id"), "target_type": r.get::<String,_>("target_type"),
|
||||
"target_id": r.get::<Option<String>,_>("target_id"), "content": r.get::<String,_>("content"),
|
||||
"font_size": r.get::<i32,_>("font_size"), "opacity": r.get::<f64,_>("opacity"),
|
||||
"color": r.get::<String,_>("color"), "angle": r.get::<i32,_>("angle"),
|
||||
"enabled": r.get::<bool,_>("enabled"), "updated_at": r.get::<String,_>("updated_at")
|
||||
})).collect::<Vec<_>>()}))),
|
||||
Err(e) => Json(ApiResponse::internal_error("query watermark configs", e)),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn create_config(
|
||||
State(state): State<AppState>,
|
||||
Json(req): Json<CreateConfigRequest>,
|
||||
) -> (StatusCode, Json<ApiResponse<serde_json::Value>>) {
|
||||
let target_type = req.target_type.unwrap_or_else(|| "global".to_string());
|
||||
let content = req.content.unwrap_or_else(|| "{company} | {username} | {date}".to_string());
|
||||
let font_size = req.font_size.unwrap_or(14).clamp(8, 72) as i32;
|
||||
let opacity = req.opacity.unwrap_or(0.15).clamp(0.01, 1.0);
|
||||
let color = req.color.unwrap_or_else(|| "#808080".to_string());
|
||||
let angle = req.angle.unwrap_or(-30);
|
||||
let enabled = req.enabled.unwrap_or(true);
|
||||
|
||||
// Validate inputs
|
||||
if content.len() > 200 {
|
||||
return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("content too long (max 200 chars)")));
|
||||
}
|
||||
if !is_valid_hex_color(&color) {
|
||||
return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("invalid color format (expected #RRGGBB)")));
|
||||
}
|
||||
if !matches!(target_type.as_str(), "global" | "device" | "group") {
|
||||
return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("invalid target_type")));
|
||||
}
|
||||
|
||||
match sqlx::query("INSERT INTO watermark_config (target_type, target_id, content, font_size, opacity, color, angle, enabled) VALUES (?,?,?,?,?,?,?,?)")
|
||||
.bind(&target_type).bind(&req.target_id).bind(&content).bind(font_size).bind(opacity).bind(&color).bind(angle).bind(enabled)
|
||||
.execute(&state.db).await {
|
||||
Ok(r) => {
|
||||
let new_id = r.last_insert_rowid();
|
||||
// Push to online clients
|
||||
let config = WatermarkConfigPayload {
|
||||
content: content.clone(),
|
||||
font_size: font_size as u32,
|
||||
opacity,
|
||||
color: color.clone(),
|
||||
angle,
|
||||
enabled,
|
||||
};
|
||||
push_to_targets(&state.db, &state.clients, MessageType::WatermarkConfig, &config, &target_type, req.target_id.as_deref()).await;
|
||||
(StatusCode::CREATED, Json(ApiResponse::ok(serde_json::json!({"id": new_id}))))
|
||||
}
|
||||
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::internal_error("create watermark config", e))),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct UpdateConfigRequest {
|
||||
pub content: Option<String>, pub font_size: Option<u32>, pub opacity: Option<f64>,
|
||||
pub color: Option<String>, pub angle: Option<i32>, pub enabled: Option<bool>,
|
||||
}
|
||||
|
||||
pub async fn update_config(
|
||||
State(state): State<AppState>,
|
||||
Path(id): Path<i64>,
|
||||
Json(body): Json<UpdateConfigRequest>,
|
||||
) -> Json<ApiResponse<()>> {
|
||||
let existing = sqlx::query("SELECT * FROM watermark_config WHERE id = ?")
|
||||
.bind(id)
|
||||
.fetch_optional(&state.db)
|
||||
.await;
|
||||
|
||||
let existing = match existing {
|
||||
Ok(Some(row)) => row,
|
||||
Ok(None) => return Json(ApiResponse::error("Not found")),
|
||||
Err(e) => return Json(ApiResponse::internal_error("query watermark config", e)),
|
||||
};
|
||||
|
||||
let content = body.content.unwrap_or_else(|| existing.get::<String, _>("content"));
|
||||
let font_size = body.font_size.map(|v| v.clamp(8, 72) as i32).unwrap_or_else(|| existing.get::<i32, _>("font_size"));
|
||||
let opacity = body.opacity.map(|v| v.clamp(0.01, 1.0)).unwrap_or_else(|| existing.get::<f64, _>("opacity"));
|
||||
let color = body.color.unwrap_or_else(|| existing.get::<String, _>("color"));
|
||||
let angle = body.angle.unwrap_or_else(|| existing.get::<i32, _>("angle"));
|
||||
let enabled = body.enabled.unwrap_or_else(|| existing.get::<bool, _>("enabled"));
|
||||
|
||||
// Validate inputs
|
||||
if content.len() > 200 {
|
||||
return Json(ApiResponse::error("content too long (max 200 chars)"));
|
||||
}
|
||||
if !is_valid_hex_color(&color) {
|
||||
return Json(ApiResponse::error("invalid color format (expected #RRGGBB)"));
|
||||
}
|
||||
|
||||
let result = sqlx::query(
|
||||
"UPDATE watermark_config SET content = ?, font_size = ?, opacity = ?, color = ?, angle = ?, enabled = ?, updated_at = datetime('now') WHERE id = ?"
|
||||
)
|
||||
.bind(&content)
|
||||
.bind(font_size)
|
||||
.bind(opacity)
|
||||
.bind(&color)
|
||||
.bind(angle)
|
||||
.bind(enabled)
|
||||
.bind(id)
|
||||
.execute(&state.db)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(r) if r.rows_affected() > 0 => {
|
||||
// Push updated config to online clients
|
||||
let config = WatermarkConfigPayload {
|
||||
content: content.clone(),
|
||||
font_size: font_size as u32,
|
||||
opacity,
|
||||
color: color.clone(),
|
||||
angle,
|
||||
enabled,
|
||||
};
|
||||
let target_type_val: String = existing.get("target_type");
|
||||
let target_id_val: Option<String> = existing.get("target_id");
|
||||
push_to_targets(&state.db, &state.clients, MessageType::WatermarkConfig, &config, &target_type_val, target_id_val.as_deref()).await;
|
||||
Json(ApiResponse::ok(()))
|
||||
}
|
||||
Ok(_) => Json(ApiResponse::error("Not found")),
|
||||
Err(e) => Json(ApiResponse::internal_error("update watermark config", e)),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn delete_config(State(state): State<AppState>, Path(id): Path<i64>) -> Json<ApiResponse<()>> {
|
||||
// Fetch existing config to get target info for push
|
||||
let existing = sqlx::query("SELECT target_type, target_id FROM watermark_config WHERE id = ?")
|
||||
.bind(id)
|
||||
.fetch_optional(&state.db)
|
||||
.await;
|
||||
|
||||
let (target_type, target_id) = match existing {
|
||||
Ok(Some(row)) => (row.get::<String, _>("target_type"), row.get::<Option<String>, _>("target_id")),
|
||||
_ => return Json(ApiResponse::error("Not found")),
|
||||
};
|
||||
|
||||
match sqlx::query("DELETE FROM watermark_config WHERE id=?").bind(id).execute(&state.db).await {
|
||||
Ok(r) if r.rows_affected() > 0 => {
|
||||
// Push disabled watermark to clients
|
||||
let disabled = WatermarkConfigPayload {
|
||||
content: String::new(),
|
||||
font_size: 0,
|
||||
opacity: 0.0,
|
||||
color: String::new(),
|
||||
angle: 0,
|
||||
enabled: false,
|
||||
};
|
||||
push_to_targets(&state.db, &state.clients, MessageType::WatermarkConfig, &disabled, &target_type, target_id.as_deref()).await;
|
||||
Json(ApiResponse::ok(()))
|
||||
}
|
||||
_ => Json(ApiResponse::error("Not found")),
|
||||
}
|
||||
}
|
||||
|
||||
/// Validate a hex color string (#RRGGBB format)
|
||||
fn is_valid_hex_color(color: &str) -> bool {
|
||||
if color.len() != 7 || !color.starts_with('#') {
|
||||
return false;
|
||||
}
|
||||
color[1..].chars().all(|c| c.is_ascii_hexdigit())
|
||||
}
|
||||
156
crates/server/src/api/plugins/web_filter.rs
Normal file
156
crates/server/src/api/plugins/web_filter.rs
Normal file
@@ -0,0 +1,156 @@
|
||||
use axum::{extract::{State, Path, Query, Json}, http::StatusCode};
|
||||
use serde::Deserialize;
|
||||
use sqlx::Row;
|
||||
use csm_protocol::MessageType;
|
||||
use crate::AppState;
|
||||
use crate::api::ApiResponse;
|
||||
use crate::tcp::push_to_targets;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct RuleFilters { pub rule_type: Option<String>, pub enabled: Option<bool> }
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct CreateRuleRequest {
|
||||
pub rule_type: String,
|
||||
pub pattern: String,
|
||||
pub target_type: Option<String>,
|
||||
pub target_id: Option<String>,
|
||||
pub enabled: Option<bool>,
|
||||
}
|
||||
|
||||
pub async fn list_rules(State(state): State<AppState>) -> Json<ApiResponse<serde_json::Value>> {
|
||||
match sqlx::query("SELECT id, rule_type, pattern, target_type, target_id, enabled, created_at FROM web_filter_rules ORDER BY created_at DESC")
|
||||
.fetch_all(&state.db).await {
|
||||
Ok(rows) => Json(ApiResponse::ok(serde_json::json!({ "rules": rows.iter().map(|r| serde_json::json!({
|
||||
"id": r.get::<i64,_>("id"), "rule_type": r.get::<String,_>("rule_type"),
|
||||
"pattern": r.get::<String,_>("pattern"), "target_type": r.get::<String,_>("target_type"),
|
||||
"target_id": r.get::<Option<String>,_>("target_id"), "enabled": r.get::<bool,_>("enabled"),
|
||||
"created_at": r.get::<String,_>("created_at")
|
||||
})).collect::<Vec<_>>() }))),
|
||||
Err(e) => Json(ApiResponse::internal_error("query web filter rules", e)),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn create_rule(State(state): State<AppState>, Json(req): Json<CreateRuleRequest>) -> (StatusCode, Json<ApiResponse<serde_json::Value>>) {
|
||||
let enabled = req.enabled.unwrap_or(true);
|
||||
let target_type = req.target_type.unwrap_or_else(|| "global".to_string());
|
||||
|
||||
// Validate inputs
|
||||
if !matches!(req.rule_type.as_str(), "blacklist" | "whitelist" | "category") {
|
||||
return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("invalid rule_type (expected blacklist|whitelist|category)")));
|
||||
}
|
||||
if req.pattern.trim().is_empty() || req.pattern.len() > 255 {
|
||||
return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("pattern must be 1-255 chars")));
|
||||
}
|
||||
if !matches!(target_type.as_str(), "global" | "device" | "group") {
|
||||
return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("invalid target_type")));
|
||||
}
|
||||
|
||||
match sqlx::query("INSERT INTO web_filter_rules (rule_type, pattern, target_type, target_id, enabled) VALUES (?,?,?,?,?)")
|
||||
.bind(&req.rule_type).bind(&req.pattern).bind(&target_type).bind(&req.target_id).bind(enabled)
|
||||
.execute(&state.db).await {
|
||||
Ok(r) => {
|
||||
let new_id = r.last_insert_rowid();
|
||||
let rules = fetch_rules_for_push(&state.db, &target_type, req.target_id.as_deref()).await;
|
||||
push_to_targets(&state.db, &state.clients, MessageType::WebFilterRuleUpdate, &serde_json::json!({"rules": rules}), &target_type, req.target_id.as_deref()).await;
|
||||
(StatusCode::CREATED, Json(ApiResponse::ok(serde_json::json!({"id": new_id}))))
|
||||
}
|
||||
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::internal_error("create web filter rule", e))),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct UpdateRuleRequest { pub rule_type: Option<String>, pub pattern: Option<String>, pub enabled: Option<bool> }
|
||||
|
||||
pub async fn update_rule(State(state): State<AppState>, Path(id): Path<i64>, Json(body): Json<UpdateRuleRequest>) -> Json<ApiResponse<()>> {
|
||||
let existing = sqlx::query("SELECT * FROM web_filter_rules WHERE id = ?")
|
||||
.bind(id)
|
||||
.fetch_optional(&state.db)
|
||||
.await;
|
||||
|
||||
let existing = match existing {
|
||||
Ok(Some(row)) => row,
|
||||
Ok(None) => return Json(ApiResponse::error("Not found")),
|
||||
Err(e) => return Json(ApiResponse::internal_error("query web filter rule", e)),
|
||||
};
|
||||
|
||||
let rule_type = body.rule_type.unwrap_or_else(|| existing.get::<String, _>("rule_type"));
|
||||
let pattern = body.pattern.unwrap_or_else(|| existing.get::<String, _>("pattern"));
|
||||
let enabled = body.enabled.unwrap_or_else(|| existing.get::<bool, _>("enabled"));
|
||||
|
||||
let result = sqlx::query("UPDATE web_filter_rules SET rule_type = ?, pattern = ?, enabled = ? WHERE id = ?")
|
||||
.bind(&rule_type)
|
||||
.bind(&pattern)
|
||||
.bind(enabled)
|
||||
.bind(id)
|
||||
.execute(&state.db)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(r) if r.rows_affected() > 0 => {
|
||||
let target_type_val: String = existing.get("target_type");
|
||||
let target_id_val: Option<String> = existing.get("target_id");
|
||||
let rules = fetch_rules_for_push(&state.db, &target_type_val, target_id_val.as_deref()).await;
|
||||
push_to_targets(&state.db, &state.clients, MessageType::WebFilterRuleUpdate, &serde_json::json!({"rules": rules}), &target_type_val, target_id_val.as_deref()).await;
|
||||
Json(ApiResponse::ok(()))
|
||||
}
|
||||
Ok(_) => Json(ApiResponse::error("Not found")),
|
||||
Err(e) => Json(ApiResponse::internal_error("update web filter rule", e)),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn delete_rule(State(state): State<AppState>, Path(id): Path<i64>) -> Json<ApiResponse<()>> {
|
||||
let existing = sqlx::query("SELECT target_type, target_id FROM web_filter_rules WHERE id = ?")
|
||||
.bind(id).fetch_optional(&state.db).await;
|
||||
let (target_type, target_id) = match existing {
|
||||
Ok(Some(row)) => (row.get::<String, _>("target_type"), row.get::<Option<String>, _>("target_id")),
|
||||
_ => return Json(ApiResponse::error("Not found")),
|
||||
};
|
||||
match sqlx::query("DELETE FROM web_filter_rules WHERE id=?").bind(id).execute(&state.db).await {
|
||||
Ok(r) if r.rows_affected() > 0 => {
|
||||
let rules = fetch_rules_for_push(&state.db, &target_type, target_id.as_deref()).await;
|
||||
push_to_targets(&state.db, &state.clients, MessageType::WebFilterRuleUpdate, &serde_json::json!({"rules": rules}), &target_type, target_id.as_deref()).await;
|
||||
Json(ApiResponse::ok(()))
|
||||
}
|
||||
_ => Json(ApiResponse::error("Not found")),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct LogFilters { pub device_uid: Option<String>, pub action: Option<String> }
|
||||
|
||||
pub async fn list_access_log(State(state): State<AppState>, Query(f): Query<LogFilters>) -> Json<ApiResponse<serde_json::Value>> {
|
||||
match sqlx::query("SELECT id, device_uid, url, action, timestamp FROM web_access_log WHERE (? IS NULL OR device_uid=?) AND (? IS NULL OR action=?) ORDER BY timestamp DESC LIMIT 200")
|
||||
.bind(&f.device_uid).bind(&f.device_uid).bind(&f.action).bind(&f.action)
|
||||
.fetch_all(&state.db).await {
|
||||
Ok(rows) => Json(ApiResponse::ok(serde_json::json!({"log": rows.iter().map(|r| serde_json::json!({
|
||||
"id": r.get::<i64,_>("id"), "device_uid": r.get::<String,_>("device_uid"),
|
||||
"url": r.get::<String,_>("url"), "action": r.get::<String,_>("action"),
|
||||
"timestamp": r.get::<String,_>("timestamp")
|
||||
})).collect::<Vec<_>>() }))),
|
||||
Err(e) => Json(ApiResponse::internal_error("query web access log", e)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Fetch enabled web filter rules applicable to a given target scope.
|
||||
/// For "device" targets, includes both global rules and device-specific rules
|
||||
/// (matching the logic used during registration push in tcp.rs).
|
||||
async fn fetch_rules_for_push(
|
||||
db: &sqlx::SqlitePool,
|
||||
target_type: &str,
|
||||
target_id: Option<&str>,
|
||||
) -> Vec<serde_json::Value> {
|
||||
let query = match target_type {
|
||||
"device" => sqlx::query(
|
||||
"SELECT id, rule_type, pattern FROM web_filter_rules WHERE enabled = 1 AND (target_type = 'global' OR (target_type = 'device' AND target_id = ?))"
|
||||
).bind(target_id),
|
||||
_ => sqlx::query(
|
||||
"SELECT id, rule_type, pattern FROM web_filter_rules WHERE enabled = 1 AND target_type = 'global'"
|
||||
),
|
||||
};
|
||||
query.fetch_all(db).await
|
||||
.map(|rows| rows.iter().map(|r| serde_json::json!({
|
||||
"id": r.get::<i64,_>("id"), "rule_type": r.get::<String,_>("rule_type"), "pattern": r.get::<String,_>("pattern")
|
||||
})).collect())
|
||||
.unwrap_or_default()
|
||||
}
|
||||
246
crates/server/src/api/usb.rs
Normal file
246
crates/server/src/api/usb.rs
Normal file
@@ -0,0 +1,246 @@
|
||||
use axum::{extract::{State, Path, Query, Json}, http::StatusCode};
|
||||
use serde::Deserialize;
|
||||
use sqlx::Row;
|
||||
|
||||
use crate::AppState;
|
||||
use super::ApiResponse;
|
||||
use crate::tcp::push_to_targets;
|
||||
use csm_protocol::{MessageType, UsbPolicyPayload, UsbDeviceRule};
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct UsbEventListParams {
|
||||
pub device_uid: Option<String>,
|
||||
pub event_type: Option<String>,
|
||||
pub page: Option<u32>,
|
||||
pub page_size: Option<u32>,
|
||||
}
|
||||
|
||||
pub async fn list_events(
|
||||
State(state): State<AppState>,
|
||||
Query(params): Query<UsbEventListParams>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let limit = params.page_size.unwrap_or(20).min(100);
|
||||
let offset = params.page.unwrap_or(1).saturating_sub(1) * limit;
|
||||
|
||||
// Normalize empty strings to None
|
||||
let device_uid = params.device_uid.as_deref().filter(|s| !s.is_empty()).map(String::from);
|
||||
let event_type = params.event_type.as_deref().filter(|s| !s.is_empty()).map(String::from);
|
||||
|
||||
let rows = sqlx::query(
|
||||
"SELECT id, device_uid, vendor_id, product_id, serial_number, device_name, event_type, event_time
|
||||
FROM usb_events WHERE 1=1
|
||||
AND (? IS NULL OR device_uid = ?)
|
||||
AND (? IS NULL OR event_type = ?)
|
||||
ORDER BY event_time DESC LIMIT ? OFFSET ?"
|
||||
)
|
||||
.bind(&device_uid).bind(&device_uid)
|
||||
.bind(&event_type).bind(&event_type)
|
||||
.bind(limit).bind(offset)
|
||||
.fetch_all(&state.db)
|
||||
.await;
|
||||
|
||||
match rows {
|
||||
Ok(records) => {
|
||||
let items: Vec<serde_json::Value> = records.iter().map(|r| serde_json::json!({
|
||||
"id": r.get::<i64, _>("id"),
|
||||
"device_uid": r.get::<String, _>("device_uid"),
|
||||
"vendor_id": r.get::<Option<String>, _>("vendor_id"),
|
||||
"product_id": r.get::<Option<String>, _>("product_id"),
|
||||
"serial_number": r.get::<Option<String>, _>("serial_number"),
|
||||
"device_name": r.get::<Option<String>, _>("device_name"),
|
||||
"event_type": r.get::<String, _>("event_type"),
|
||||
"event_time": r.get::<String, _>("event_time"),
|
||||
})).collect();
|
||||
Json(ApiResponse::ok(serde_json::json!({
|
||||
"events": items,
|
||||
"page": params.page.unwrap_or(1),
|
||||
"page_size": limit,
|
||||
})))
|
||||
}
|
||||
Err(e) => Json(ApiResponse::internal_error("query usb events", e)),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn list_policies(
|
||||
State(state): State<AppState>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let rows = sqlx::query(
|
||||
"SELECT id, name, policy_type, target_group, rules, enabled, created_at, updated_at
|
||||
FROM usb_policies ORDER BY created_at DESC"
|
||||
)
|
||||
.fetch_all(&state.db)
|
||||
.await;
|
||||
|
||||
match rows {
|
||||
Ok(records) => {
|
||||
let items: Vec<serde_json::Value> = records.iter().map(|r| serde_json::json!({
|
||||
"id": r.get::<i64, _>("id"),
|
||||
"name": r.get::<String, _>("name"),
|
||||
"policy_type": r.get::<String, _>("policy_type"),
|
||||
"target_group": r.get::<Option<String>, _>("target_group"),
|
||||
"rules": r.get::<String, _>("rules"),
|
||||
"enabled": r.get::<i32, _>("enabled"),
|
||||
"created_at": r.get::<String, _>("created_at"),
|
||||
"updated_at": r.get::<String, _>("updated_at"),
|
||||
})).collect();
|
||||
Json(ApiResponse::ok(serde_json::json!({
|
||||
"policies": items,
|
||||
})))
|
||||
}
|
||||
Err(e) => Json(ApiResponse::internal_error("query usb policies", e)),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct CreatePolicyRequest {
|
||||
pub name: String,
|
||||
pub policy_type: String,
|
||||
pub target_group: Option<String>,
|
||||
pub rules: String,
|
||||
pub enabled: Option<i32>,
|
||||
}
|
||||
|
||||
pub async fn create_policy(
|
||||
State(state): State<AppState>,
|
||||
Json(body): Json<CreatePolicyRequest>,
|
||||
) -> (StatusCode, Json<ApiResponse<serde_json::Value>>) {
|
||||
let enabled = body.enabled.unwrap_or(1);
|
||||
|
||||
let result = sqlx::query(
|
||||
"INSERT INTO usb_policies (name, policy_type, target_group, rules, enabled) VALUES (?, ?, ?, ?, ?)"
|
||||
)
|
||||
.bind(&body.name)
|
||||
.bind(&body.policy_type)
|
||||
.bind(&body.target_group)
|
||||
.bind(&body.rules)
|
||||
.bind(enabled)
|
||||
.execute(&state.db)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(r) => {
|
||||
let new_id = r.last_insert_rowid();
|
||||
// Push USB policy to matching online clients
|
||||
if enabled == 1 {
|
||||
let payload = build_usb_policy_payload(&body.policy_type, true, &body.rules);
|
||||
let target_group = body.target_group.as_deref();
|
||||
push_to_targets(&state.db, &state.clients, MessageType::UsbPolicyUpdate, &payload, "group", target_group).await;
|
||||
}
|
||||
(StatusCode::CREATED, Json(ApiResponse::ok(serde_json::json!({
|
||||
"id": new_id,
|
||||
}))))
|
||||
}
|
||||
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::internal_error("create usb policy", e))),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct UpdatePolicyRequest {
|
||||
pub name: Option<String>,
|
||||
pub policy_type: Option<String>,
|
||||
pub target_group: Option<String>,
|
||||
pub rules: Option<String>,
|
||||
pub enabled: Option<i32>,
|
||||
}
|
||||
|
||||
pub async fn update_policy(
|
||||
State(state): State<AppState>,
|
||||
Path(id): Path<i64>,
|
||||
Json(body): Json<UpdatePolicyRequest>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
// Fetch existing policy
|
||||
let existing = sqlx::query("SELECT * FROM usb_policies WHERE id = ?")
|
||||
.bind(id)
|
||||
.fetch_optional(&state.db)
|
||||
.await;
|
||||
|
||||
let existing = match existing {
|
||||
Ok(Some(row)) => row,
|
||||
Ok(None) => return Json(ApiResponse::error("Policy not found")),
|
||||
Err(e) => return Json(ApiResponse::internal_error("query usb policy", e)),
|
||||
};
|
||||
|
||||
let name = body.name.unwrap_or_else(|| existing.get::<String, _>("name"));
|
||||
let policy_type = body.policy_type.unwrap_or_else(|| existing.get::<String, _>("policy_type"));
|
||||
let target_group = body.target_group.or_else(|| existing.get::<Option<String>, _>("target_group"));
|
||||
let rules = body.rules.unwrap_or_else(|| existing.get::<String, _>("rules"));
|
||||
let enabled = body.enabled.unwrap_or_else(|| existing.get::<i32, _>("enabled"));
|
||||
|
||||
let result = sqlx::query(
|
||||
"UPDATE usb_policies SET name = ?, policy_type = ?, target_group = ?, rules = ?, enabled = ?, updated_at = datetime('now') WHERE id = ?"
|
||||
)
|
||||
.bind(&name)
|
||||
.bind(&policy_type)
|
||||
.bind(&target_group)
|
||||
.bind(&rules)
|
||||
.bind(enabled)
|
||||
.bind(id)
|
||||
.execute(&state.db)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(_) => {
|
||||
// Push updated USB policy to matching online clients
|
||||
let payload = build_usb_policy_payload(&policy_type, enabled == 1, &rules);
|
||||
let target_group = target_group.as_deref();
|
||||
push_to_targets(&state.db, &state.clients, MessageType::UsbPolicyUpdate, &payload, "group", target_group).await;
|
||||
Json(ApiResponse::ok(serde_json::json!({"updated": true})))
|
||||
}
|
||||
Err(e) => Json(ApiResponse::internal_error("update usb policy", e)),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn delete_policy(
|
||||
State(state): State<AppState>,
|
||||
Path(id): Path<i64>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
// Fetch existing policy to get target info for push
|
||||
let existing = sqlx::query("SELECT target_group FROM usb_policies WHERE id = ?")
|
||||
.bind(id)
|
||||
.fetch_optional(&state.db)
|
||||
.await;
|
||||
|
||||
let target_group = match existing {
|
||||
Ok(Some(row)) => row.get::<Option<String>, _>("target_group"),
|
||||
_ => return Json(ApiResponse::error("Policy not found")),
|
||||
};
|
||||
|
||||
let result = sqlx::query("DELETE FROM usb_policies WHERE id = ?")
|
||||
.bind(id)
|
||||
.execute(&state.db)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(r) => {
|
||||
if r.rows_affected() > 0 {
|
||||
// Push disabled policy to clients
|
||||
let disabled = UsbPolicyPayload {
|
||||
policy_type: String::new(),
|
||||
enabled: false,
|
||||
rules: vec![],
|
||||
};
|
||||
push_to_targets(&state.db, &state.clients, MessageType::UsbPolicyUpdate, &disabled, "group", target_group.as_deref()).await;
|
||||
Json(ApiResponse::ok(serde_json::json!({"deleted": true})))
|
||||
} else {
|
||||
Json(ApiResponse::error("Policy not found"))
|
||||
}
|
||||
}
|
||||
Err(e) => Json(ApiResponse::internal_error("delete usb policy", e)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Build a UsbPolicyPayload from raw policy fields
|
||||
fn build_usb_policy_payload(policy_type: &str, enabled: bool, rules_json: &str) -> UsbPolicyPayload {
|
||||
let raw_rules: Vec<serde_json::Value> = serde_json::from_str(rules_json).unwrap_or_default();
|
||||
let rules: Vec<UsbDeviceRule> = raw_rules.iter().map(|r| UsbDeviceRule {
|
||||
vendor_id: r.get("vendor_id").and_then(|v| v.as_str().map(String::from)),
|
||||
product_id: r.get("product_id").and_then(|v| v.as_str().map(String::from)),
|
||||
serial: r.get("serial").and_then(|v| v.as_str().map(String::from)),
|
||||
device_name: r.get("device_name").and_then(|v| v.as_str().map(String::from)),
|
||||
}).collect();
|
||||
UsbPolicyPayload {
|
||||
policy_type: policy_type.to_string(),
|
||||
enabled,
|
||||
rules,
|
||||
}
|
||||
}
|
||||
28
crates/server/src/audit.rs
Normal file
28
crates/server/src/audit.rs
Normal file
@@ -0,0 +1,28 @@
|
||||
use sqlx::SqlitePool;
|
||||
use tracing::debug;
|
||||
|
||||
/// Record an admin audit log entry.
|
||||
pub async fn audit_log(
|
||||
db: &SqlitePool,
|
||||
user_id: i64,
|
||||
action: &str,
|
||||
target_type: Option<&str>,
|
||||
target_id: Option<&str>,
|
||||
detail: Option<&str>,
|
||||
) {
|
||||
let result = sqlx::query(
|
||||
"INSERT INTO admin_audit_log (user_id, action, target_type, target_id, detail) VALUES (?, ?, ?, ?, ?)"
|
||||
)
|
||||
.bind(user_id)
|
||||
.bind(action)
|
||||
.bind(target_type)
|
||||
.bind(target_id)
|
||||
.bind(detail)
|
||||
.execute(db)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(_) => debug!("Audit: user={} action={} target={}/{}", user_id, action, target_type.unwrap_or("-"), target_id.unwrap_or("-")),
|
||||
Err(e) => tracing::warn!("Failed to write audit log: {}", e),
|
||||
}
|
||||
}
|
||||
134
crates/server/src/config.rs
Normal file
134
crates/server/src/config.rs
Normal file
@@ -0,0 +1,134 @@
|
||||
use anyhow::Result;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::Path;
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize, Clone)]
|
||||
pub struct AppConfig {
|
||||
pub server: ServerConfig,
|
||||
pub database: DatabaseConfig,
|
||||
pub auth: AuthConfig,
|
||||
pub retention: RetentionConfig,
|
||||
#[serde(default)]
|
||||
pub notify: NotifyConfig,
|
||||
/// Token required for device registration. Empty = any token accepted.
|
||||
#[serde(default)]
|
||||
pub registration_token: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize, Clone)]
|
||||
pub struct ServerConfig {
|
||||
pub http_addr: String,
|
||||
pub tcp_addr: String,
|
||||
/// Allowed CORS origins. Empty = same-origin only (no CORS headers).
|
||||
#[serde(default)]
|
||||
pub cors_origins: Vec<String>,
|
||||
/// Optional TLS configuration for the TCP listener.
|
||||
#[serde(default)]
|
||||
pub tls: Option<TlsConfig>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize, Clone)]
|
||||
pub struct TlsConfig {
|
||||
/// Path to the server certificate (PEM format)
|
||||
pub cert_path: String,
|
||||
/// Path to the server private key (PEM format)
|
||||
pub key_path: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize, Clone)]
|
||||
pub struct DatabaseConfig {
|
||||
pub path: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize, Clone)]
|
||||
pub struct AuthConfig {
|
||||
pub jwt_secret: String,
|
||||
#[serde(default = "default_access_ttl")]
|
||||
pub access_token_ttl_secs: u64,
|
||||
#[serde(default = "default_refresh_ttl")]
|
||||
pub refresh_token_ttl_secs: u64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize, Clone)]
|
||||
pub struct RetentionConfig {
|
||||
#[serde(default = "default_status_history_days")]
|
||||
pub status_history_days: u32,
|
||||
#[serde(default = "default_usb_events_days")]
|
||||
pub usb_events_days: u32,
|
||||
#[serde(default = "default_asset_changes_days")]
|
||||
pub asset_changes_days: u32,
|
||||
#[serde(default = "default_alert_records_days")]
|
||||
pub alert_records_days: u32,
|
||||
#[serde(default = "default_audit_log_days")]
|
||||
pub audit_log_days: u32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize, Clone, Default)]
|
||||
pub struct NotifyConfig {
|
||||
#[serde(default)]
|
||||
pub smtp: Option<SmtpConfig>,
|
||||
#[serde(default)]
|
||||
pub webhook_urls: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize, Clone)]
|
||||
pub struct SmtpConfig {
|
||||
pub host: String,
|
||||
pub port: u16,
|
||||
pub username: String,
|
||||
pub password: String,
|
||||
pub from: String,
|
||||
}
|
||||
|
||||
impl AppConfig {
|
||||
pub async fn load(path: &str) -> Result<Self> {
|
||||
if Path::new(path).exists() {
|
||||
let content = tokio::fs::read_to_string(path).await?;
|
||||
let config: AppConfig = toml::from_str(&content)?;
|
||||
Ok(config)
|
||||
} else {
|
||||
let config = default_config();
|
||||
// Write default config for reference
|
||||
let toml_str = toml::to_string_pretty(&config)?;
|
||||
tokio::fs::write(path, &toml_str).await?;
|
||||
tracing::warn!("Created default config at {}", path);
|
||||
Ok(config)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn default_access_ttl() -> u64 { 1800 } // 30 minutes
|
||||
fn default_refresh_ttl() -> u64 { 604800 } // 7 days
|
||||
fn default_status_history_days() -> u32 { 7 }
|
||||
fn default_usb_events_days() -> u32 { 90 }
|
||||
fn default_asset_changes_days() -> u32 { 365 }
|
||||
fn default_alert_records_days() -> u32 { 90 }
|
||||
fn default_audit_log_days() -> u32 { 365 }
|
||||
|
||||
pub fn default_config() -> AppConfig {
|
||||
AppConfig {
|
||||
server: ServerConfig {
|
||||
http_addr: "0.0.0.0:8080".into(),
|
||||
tcp_addr: "0.0.0.0:9999".into(),
|
||||
cors_origins: vec![],
|
||||
tls: None,
|
||||
},
|
||||
database: DatabaseConfig {
|
||||
path: "./csm.db".into(),
|
||||
},
|
||||
auth: AuthConfig {
|
||||
jwt_secret: uuid::Uuid::new_v4().to_string(),
|
||||
access_token_ttl_secs: default_access_ttl(),
|
||||
refresh_token_ttl_secs: default_refresh_ttl(),
|
||||
},
|
||||
retention: RetentionConfig {
|
||||
status_history_days: default_status_history_days(),
|
||||
usb_events_days: default_usb_events_days(),
|
||||
asset_changes_days: default_asset_changes_days(),
|
||||
alert_records_days: default_alert_records_days(),
|
||||
audit_log_days: default_audit_log_days(),
|
||||
},
|
||||
notify: NotifyConfig::default(),
|
||||
registration_token: uuid::Uuid::new_v4().to_string(),
|
||||
}
|
||||
}
|
||||
118
crates/server/src/db.rs
Normal file
118
crates/server/src/db.rs
Normal file
@@ -0,0 +1,118 @@
|
||||
use sqlx::SqlitePool;
|
||||
use anyhow::Result;
|
||||
|
||||
/// Database repository for device operations
|
||||
pub struct DeviceRepo;
|
||||
|
||||
impl DeviceRepo {
|
||||
pub async fn upsert_status(pool: &SqlitePool, device_uid: &str, status: &csm_protocol::DeviceStatus) -> Result<()> {
|
||||
let top_procs_json = serde_json::to_string(&status.top_processes)?;
|
||||
|
||||
// Update latest snapshot
|
||||
sqlx::query(
|
||||
"INSERT INTO device_status (device_uid, cpu_usage, memory_usage, memory_total_mb, disk_usage, disk_total_mb, network_rx_rate, network_tx_rate, running_procs, top_processes, reported_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, datetime('now'))
|
||||
ON CONFLICT(device_uid) DO UPDATE SET
|
||||
cpu_usage = excluded.cpu_usage,
|
||||
memory_usage = excluded.memory_usage,
|
||||
memory_total_mb = excluded.memory_total_mb,
|
||||
disk_usage = excluded.disk_usage,
|
||||
disk_total_mb = excluded.disk_total_mb,
|
||||
network_rx_rate = excluded.network_rx_rate,
|
||||
network_tx_rate = excluded.network_tx_rate,
|
||||
running_procs = excluded.running_procs,
|
||||
top_processes = excluded.top_processes,
|
||||
reported_at = datetime('now'),
|
||||
updated_at = datetime('now')"
|
||||
)
|
||||
.bind(device_uid)
|
||||
.bind(status.cpu_usage)
|
||||
.bind(status.memory_usage)
|
||||
.bind(status.memory_total_mb as i64)
|
||||
.bind(status.disk_usage)
|
||||
.bind(status.disk_total_mb as i64)
|
||||
.bind(status.network_rx_rate as i64)
|
||||
.bind(status.network_tx_rate as i64)
|
||||
.bind(status.running_procs as i32)
|
||||
.bind(&top_procs_json)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
// Insert into history
|
||||
sqlx::query(
|
||||
"INSERT INTO device_status_history (device_uid, cpu_usage, memory_usage, disk_usage, network_rx_rate, network_tx_rate, running_procs, reported_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, datetime('now'))"
|
||||
)
|
||||
.bind(device_uid)
|
||||
.bind(status.cpu_usage)
|
||||
.bind(status.memory_usage)
|
||||
.bind(status.disk_usage)
|
||||
.bind(status.network_rx_rate as i64)
|
||||
.bind(status.network_tx_rate as i64)
|
||||
.bind(status.running_procs as i32)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
// Update device heartbeat
|
||||
sqlx::query(
|
||||
"UPDATE devices SET status = 'online', last_heartbeat = datetime('now') WHERE device_uid = ?"
|
||||
)
|
||||
.bind(device_uid)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn insert_usb_event(pool: &SqlitePool, event: &csm_protocol::UsbEvent) -> Result<i64> {
|
||||
let result = sqlx::query(
|
||||
"INSERT INTO usb_events (device_uid, vendor_id, product_id, serial_number, device_name, event_type)
|
||||
VALUES (?, ?, ?, ?, ?, ?)"
|
||||
)
|
||||
.bind(&event.device_uid)
|
||||
.bind(&event.vendor_id)
|
||||
.bind(&event.product_id)
|
||||
.bind(&event.serial)
|
||||
.bind(&event.device_name)
|
||||
.bind(match event.event_type {
|
||||
csm_protocol::UsbEventType::Inserted => "inserted",
|
||||
csm_protocol::UsbEventType::Removed => "removed",
|
||||
csm_protocol::UsbEventType::Blocked => "blocked",
|
||||
})
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
Ok(result.last_insert_rowid())
|
||||
}
|
||||
|
||||
pub async fn upsert_hardware(pool: &SqlitePool, asset: &csm_protocol::HardwareAsset) -> Result<()> {
|
||||
sqlx::query(
|
||||
"INSERT INTO hardware_assets (device_uid, cpu_model, cpu_cores, memory_total_mb, disk_model, disk_total_mb, gpu_model, motherboard, serial_number, reported_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, datetime('now'))
|
||||
ON CONFLICT(device_uid) DO UPDATE SET
|
||||
cpu_model = excluded.cpu_model,
|
||||
cpu_cores = excluded.cpu_cores,
|
||||
memory_total_mb = excluded.memory_total_mb,
|
||||
disk_model = excluded.disk_model,
|
||||
disk_total_mb = excluded.disk_total_mb,
|
||||
gpu_model = excluded.gpu_model,
|
||||
motherboard = excluded.motherboard,
|
||||
serial_number = excluded.serial_number,
|
||||
reported_at = datetime('now'),
|
||||
updated_at = datetime('now')"
|
||||
)
|
||||
.bind(&asset.device_uid)
|
||||
.bind(&asset.cpu_model)
|
||||
.bind(asset.cpu_cores as i32)
|
||||
.bind(asset.memory_total_mb as i64)
|
||||
.bind(&asset.disk_model)
|
||||
.bind(asset.disk_total_mb as i64)
|
||||
.bind(&asset.gpu_model)
|
||||
.bind(&asset.motherboard)
|
||||
.bind(&asset.serial_number)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
264
crates/server/src/main.rs
Normal file
264
crates/server/src/main.rs
Normal file
@@ -0,0 +1,264 @@
|
||||
use anyhow::Result;
|
||||
use axum::Router;
|
||||
use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions, SqliteJournalMode};
|
||||
use std::path::Path;
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
use tokio::net::TcpListener;
|
||||
use tower_http::cors::{CorsLayer, Any};
|
||||
use tower_http::trace::TraceLayer;
|
||||
use tower_http::compression::CompressionLayer;
|
||||
use tower_http::set_header::SetResponseHeaderLayer;
|
||||
use tracing::{info, warn, error};
|
||||
|
||||
mod api;
|
||||
mod audit;
|
||||
mod config;
|
||||
mod db;
|
||||
mod tcp;
|
||||
mod ws;
|
||||
mod alert;
|
||||
|
||||
use config::AppConfig;
|
||||
|
||||
/// Application shared state
|
||||
#[derive(Clone)]
|
||||
pub struct AppState {
|
||||
pub db: sqlx::SqlitePool,
|
||||
pub config: Arc<AppConfig>,
|
||||
pub clients: Arc<tcp::ClientRegistry>,
|
||||
pub ws_hub: Arc<ws::WsHub>,
|
||||
pub login_limiter: Arc<api::auth::LoginRateLimiter>,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
// Initialize tracing
|
||||
tracing_subscriber::fmt()
|
||||
.with_env_filter(
|
||||
tracing_subscriber::EnvFilter::try_from_default_env()
|
||||
.unwrap_or_else(|_| "csm_server=info,tower_http=info".into()),
|
||||
)
|
||||
.json()
|
||||
.init();
|
||||
|
||||
info!("CSM Server starting...");
|
||||
|
||||
// Load configuration
|
||||
let config = AppConfig::load("config.toml").await?;
|
||||
let config = Arc::new(config);
|
||||
|
||||
// Initialize database
|
||||
let db = init_database(&config.database.path).await?;
|
||||
run_migrations(&db).await?;
|
||||
info!("Database initialized at {}", config.database.path);
|
||||
|
||||
// Ensure default admin exists
|
||||
ensure_default_admin(&db).await?;
|
||||
|
||||
// Initialize shared state
|
||||
let clients = Arc::new(tcp::ClientRegistry::new());
|
||||
let ws_hub = Arc::new(ws::WsHub::new());
|
||||
|
||||
let state = AppState {
|
||||
db: db.clone(),
|
||||
config: config.clone(),
|
||||
clients: clients.clone(),
|
||||
ws_hub: ws_hub.clone(),
|
||||
login_limiter: Arc::new(api::auth::LoginRateLimiter::new()),
|
||||
};
|
||||
|
||||
// Start background tasks
|
||||
let cleanup_state = state.clone();
|
||||
tokio::spawn(async move {
|
||||
alert::cleanup_task(cleanup_state).await;
|
||||
});
|
||||
|
||||
// Start TCP listener for client connections
|
||||
let tcp_state = state.clone();
|
||||
let tcp_addr = config.server.tcp_addr.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = tcp::start_tcp_server(tcp_addr, tcp_state).await {
|
||||
error!("TCP server error: {}", e);
|
||||
}
|
||||
});
|
||||
|
||||
// Build HTTP router
|
||||
let app = Router::new()
|
||||
.merge(api::routes(state.clone()))
|
||||
.layer(
|
||||
build_cors_layer(&config.server.cors_origins),
|
||||
)
|
||||
.layer(TraceLayer::new_for_http())
|
||||
.layer(CompressionLayer::new())
|
||||
// Security headers
|
||||
.layer(SetResponseHeaderLayer::if_not_present(
|
||||
axum::http::header::X_CONTENT_TYPE_OPTIONS,
|
||||
axum::http::HeaderValue::from_static("nosniff"),
|
||||
))
|
||||
.layer(SetResponseHeaderLayer::if_not_present(
|
||||
axum::http::header::X_FRAME_OPTIONS,
|
||||
axum::http::HeaderValue::from_static("DENY"),
|
||||
))
|
||||
.layer(SetResponseHeaderLayer::if_not_present(
|
||||
axum::http::header::HeaderName::from_static("x-xss-protection"),
|
||||
axum::http::HeaderValue::from_static("1; mode=block"),
|
||||
))
|
||||
.layer(SetResponseHeaderLayer::if_not_present(
|
||||
axum::http::header::HeaderName::from_static("referrer-policy"),
|
||||
axum::http::HeaderValue::from_static("strict-origin-when-cross-origin"),
|
||||
))
|
||||
.layer(SetResponseHeaderLayer::if_not_present(
|
||||
axum::http::header::HeaderName::from_static("content-security-policy"),
|
||||
axum::http::HeaderValue::from_static("default-src 'self'; script-src 'self' 'unsafe-inline' 'unsafe-eval'; style-src 'self' 'unsafe-inline'; img-src 'self' data:; connect-src 'self' ws: wss:"),
|
||||
))
|
||||
.with_state(state);
|
||||
|
||||
// Start HTTP server
|
||||
let http_addr = &config.server.http_addr;
|
||||
info!("HTTP server listening on {}", http_addr);
|
||||
let listener = TcpListener::bind(http_addr).await?;
|
||||
axum::serve(listener, app).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn init_database(db_path: &str) -> Result<sqlx::SqlitePool> {
|
||||
// Ensure parent directory exists for file-based databases
|
||||
// Strip sqlite: prefix if present for directory creation
|
||||
let file_path = db_path.strip_prefix("sqlite:").unwrap_or(db_path);
|
||||
// Strip query parameters
|
||||
let file_path = file_path.split('?').next().unwrap_or(file_path);
|
||||
if let Some(parent) = Path::new(file_path).parent() {
|
||||
if !parent.as_os_str().is_empty() {
|
||||
tokio::fs::create_dir_all(parent).await?;
|
||||
}
|
||||
}
|
||||
|
||||
let options = SqliteConnectOptions::from_str(db_path)?
|
||||
.journal_mode(SqliteJournalMode::Wal)
|
||||
.synchronous(sqlx::sqlite::SqliteSynchronous::Normal)
|
||||
.busy_timeout(std::time::Duration::from_secs(5))
|
||||
.foreign_keys(true);
|
||||
|
||||
let pool = SqlitePoolOptions::new()
|
||||
.max_connections(8)
|
||||
.connect_with(options)
|
||||
.await?;
|
||||
|
||||
// Set pragmas on each connection
|
||||
sqlx::query("PRAGMA cache_size = -64000")
|
||||
.execute(&pool)
|
||||
.await?;
|
||||
sqlx::query("PRAGMA wal_autocheckpoint = 1000")
|
||||
.execute(&pool)
|
||||
.await?;
|
||||
|
||||
Ok(pool)
|
||||
}
|
||||
|
||||
async fn run_migrations(pool: &sqlx::SqlitePool) -> Result<()> {
|
||||
// Embedded migrations - run in order
|
||||
let migrations = [
|
||||
include_str!("../../../migrations/001_init.sql"),
|
||||
include_str!("../../../migrations/002_assets.sql"),
|
||||
include_str!("../../../migrations/003_usb.sql"),
|
||||
include_str!("../../../migrations/004_alerts.sql"),
|
||||
include_str!("../../../migrations/005_plugins_web_filter.sql"),
|
||||
include_str!("../../../migrations/006_plugins_usage_timer.sql"),
|
||||
include_str!("../../../migrations/007_plugins_software_blocker.sql"),
|
||||
include_str!("../../../migrations/008_plugins_popup_blocker.sql"),
|
||||
include_str!("../../../migrations/009_plugins_usb_file_audit.sql"),
|
||||
include_str!("../../../migrations/010_plugins_watermark.sql"),
|
||||
include_str!("../../../migrations/011_token_security.sql"),
|
||||
];
|
||||
|
||||
// Create migrations tracking table
|
||||
sqlx::query(
|
||||
"CREATE TABLE IF NOT EXISTS _migrations (id INTEGER PRIMARY KEY, name TEXT NOT NULL UNIQUE, applied_at TEXT NOT NULL DEFAULT (datetime('now')))"
|
||||
)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
for (i, migration_sql) in migrations.iter().enumerate() {
|
||||
let name = format!("{:03}", i + 1);
|
||||
let exists: bool = sqlx::query_scalar::<_, i64>(
|
||||
"SELECT COUNT(*) FROM _migrations WHERE name = ?"
|
||||
)
|
||||
.bind(&name)
|
||||
.fetch_one(pool)
|
||||
.await? > 0;
|
||||
|
||||
if !exists {
|
||||
info!("Running migration: {}", name);
|
||||
sqlx::query(migration_sql)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
sqlx::query("INSERT INTO _migrations (name) VALUES (?)")
|
||||
.bind(&name)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn ensure_default_admin(pool: &sqlx::SqlitePool) -> Result<()> {
|
||||
let count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM users")
|
||||
.fetch_one(pool)
|
||||
.await?;
|
||||
|
||||
if count == 0 {
|
||||
// Generate a random 16-character alphanumeric password
|
||||
let random_password: String = {
|
||||
use std::fmt::Write;
|
||||
let bytes = uuid::Uuid::new_v4();
|
||||
let mut s = String::with_capacity(16);
|
||||
for byte in bytes.as_bytes().iter().take(16) {
|
||||
write!(s, "{:02x}", byte).unwrap();
|
||||
}
|
||||
s
|
||||
};
|
||||
|
||||
let password_hash = bcrypt::hash(&random_password, 12)?;
|
||||
sqlx::query(
|
||||
"INSERT INTO users (username, password, role) VALUES (?, ?, 'admin')"
|
||||
)
|
||||
.bind("admin")
|
||||
.bind(&password_hash)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
warn!("Created default admin user (username: admin)");
|
||||
// Print password directly to stderr — bypasses tracing JSON formatter
|
||||
eprintln!("============================================================");
|
||||
eprintln!(" Generated admin password: {}", random_password);
|
||||
eprintln!(" *** Save this password now — it will NOT be shown again! ***");
|
||||
eprintln!("============================================================");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Build CORS layer from configured origins.
|
||||
/// If cors_origins is empty, no CORS headers are sent (same-origin only).
|
||||
/// If origins are specified, only those are allowed.
|
||||
fn build_cors_layer(origins: &[String]) -> CorsLayer {
|
||||
use axum::http::HeaderValue;
|
||||
|
||||
let allowed_origins: Vec<HeaderValue> = origins.iter()
|
||||
.filter_map(|o| o.parse::<HeaderValue>().ok())
|
||||
.collect();
|
||||
|
||||
if allowed_origins.is_empty() {
|
||||
// No CORS — production safe by default
|
||||
CorsLayer::new()
|
||||
} else {
|
||||
CorsLayer::new()
|
||||
.allow_origin(tower_http::cors::AllowOrigin::list(allowed_origins))
|
||||
.allow_methods(Any)
|
||||
.allow_headers(Any)
|
||||
.max_age(std::time::Duration::from_secs(3600))
|
||||
}
|
||||
}
|
||||
844
crates/server/src/tcp.rs
Normal file
844
crates/server/src/tcp.rs
Normal file
@@ -0,0 +1,844 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
use tokio::sync::RwLock;
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
use tracing::{info, warn, debug};
|
||||
use hmac::{Hmac, Mac};
|
||||
use sha2::Sha256;
|
||||
use csm_protocol::{Frame, MessageType, PROTOCOL_VERSION};
|
||||
use crate::AppState;
|
||||
|
||||
/// Maximum frames per second per connection before rate-limiting kicks in
|
||||
const RATE_LIMIT_WINDOW_SECS: u64 = 5;
|
||||
const RATE_LIMIT_MAX_FRAMES: usize = 100;
|
||||
|
||||
/// Per-connection rate limiter using a sliding window of frame timestamps
|
||||
struct RateLimiter {
|
||||
timestamps: Vec<Instant>,
|
||||
}
|
||||
|
||||
impl RateLimiter {
|
||||
fn new() -> Self {
|
||||
Self { timestamps: Vec::with_capacity(RATE_LIMIT_MAX_FRAMES) }
|
||||
}
|
||||
|
||||
/// Returns false if the connection is rate-limited
|
||||
fn check(&mut self) -> bool {
|
||||
let now = Instant::now();
|
||||
let cutoff = now - std::time::Duration::from_secs(RATE_LIMIT_WINDOW_SECS);
|
||||
|
||||
// Evict timestamps outside the window
|
||||
self.timestamps.retain(|t| *t > cutoff);
|
||||
|
||||
if self.timestamps.len() >= RATE_LIMIT_MAX_FRAMES {
|
||||
return false;
|
||||
}
|
||||
|
||||
self.timestamps.push(now);
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
/// Push a plugin config frame to all online clients matching the target scope.
|
||||
/// target_type: "global" | "device" | "group"
|
||||
/// target_id: device_uid or group_name (None for global)
|
||||
pub async fn push_to_targets(
|
||||
db: &sqlx::SqlitePool,
|
||||
clients: &crate::tcp::ClientRegistry,
|
||||
msg_type: MessageType,
|
||||
payload: &impl serde::Serialize,
|
||||
target_type: &str,
|
||||
target_id: Option<&str>,
|
||||
) {
|
||||
let frame = match Frame::new_json(msg_type, payload) {
|
||||
Ok(f) => f.encode(),
|
||||
Err(e) => {
|
||||
warn!("Failed to encode plugin push frame: {}", e);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let online = clients.list_online().await;
|
||||
let mut pushed_count = 0usize;
|
||||
|
||||
// For group targeting, resolve group members from DB once
|
||||
let group_members: Option<Vec<String>> = if target_type == "group" {
|
||||
if let Some(group_name) = target_id {
|
||||
sqlx::query_scalar::<_, String>(
|
||||
"SELECT device_uid FROM devices WHERE group_name = ?"
|
||||
)
|
||||
.bind(group_name)
|
||||
.fetch_all(db)
|
||||
.await
|
||||
.ok()
|
||||
.into()
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
for uid in &online {
|
||||
let should_push = match target_type {
|
||||
"global" => true,
|
||||
"device" => target_id.map_or(false, |id| id == uid),
|
||||
"group" => {
|
||||
if let Some(members) = &group_members {
|
||||
members.contains(uid)
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
other => {
|
||||
warn!("Unknown target_type '{}', skipping push", other);
|
||||
false
|
||||
}
|
||||
};
|
||||
if should_push {
|
||||
if clients.send_to(uid, frame.clone()).await {
|
||||
pushed_count += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
debug!("Pushed {:?} to {}/{} online clients (target={})", msg_type, pushed_count, online.len(), target_type);
|
||||
}
|
||||
|
||||
/// Push all active plugin configs to a newly registered client.
|
||||
pub async fn push_all_plugin_configs(
|
||||
db: &sqlx::SqlitePool,
|
||||
clients: &crate::tcp::ClientRegistry,
|
||||
device_uid: &str,
|
||||
) {
|
||||
use sqlx::Row;
|
||||
|
||||
// Watermark configs — only push the highest-priority enabled config (device > group > global)
|
||||
if let Ok(rows) = sqlx::query(
|
||||
"SELECT content, font_size, opacity, color, angle, enabled FROM watermark_config WHERE enabled = 1 AND (target_type = 'global' OR (target_type = 'device' AND target_id = ?) OR (target_type = 'group' AND target_id = (SELECT group_name FROM devices WHERE device_uid = ?))) ORDER BY CASE WHEN target_type = 'device' THEN 0 WHEN target_type = 'group' THEN 1 ELSE 2 END LIMIT 1"
|
||||
)
|
||||
.bind(device_uid)
|
||||
.bind(device_uid)
|
||||
.bind(device_uid)
|
||||
.fetch_all(db).await
|
||||
{
|
||||
if let Some(row) = rows.first() {
|
||||
let config = csm_protocol::WatermarkConfigPayload {
|
||||
content: row.get("content"),
|
||||
font_size: row.get::<i32, _>("font_size") as u32,
|
||||
opacity: row.get("opacity"),
|
||||
color: row.get("color"),
|
||||
angle: row.get::<i32, _>("angle"),
|
||||
enabled: row.get("enabled"),
|
||||
};
|
||||
if let Ok(frame) = Frame::new_json(MessageType::WatermarkConfig, &config) {
|
||||
clients.send_to(device_uid, frame.encode()).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Web filter rules
|
||||
if let Ok(rows) = sqlx::query(
|
||||
"SELECT id, rule_type, pattern FROM web_filter_rules WHERE enabled = 1 AND (target_type = 'global' OR (target_type = 'device' AND target_id = ?) OR (target_type = 'group' AND target_id = (SELECT group_name FROM devices WHERE device_uid = ?)))"
|
||||
)
|
||||
.bind(device_uid)
|
||||
.bind(device_uid)
|
||||
.fetch_all(db).await
|
||||
{
|
||||
let rules: Vec<serde_json::Value> = rows.iter().map(|r| serde_json::json!({
|
||||
"id": r.get::<i64, _>("id"),
|
||||
"rule_type": r.get::<String, _>("rule_type"),
|
||||
"pattern": r.get::<String, _>("pattern"),
|
||||
})).collect();
|
||||
if !rules.is_empty() {
|
||||
if let Ok(frame) = Frame::new_json(MessageType::WebFilterRuleUpdate, &serde_json::json!({"rules": rules})) {
|
||||
clients.send_to(device_uid, frame.encode()).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Software blacklist
|
||||
if let Ok(rows) = sqlx::query(
|
||||
"SELECT id, name_pattern, action FROM software_blacklist WHERE enabled = 1 AND (target_type = 'global' OR (target_type = 'device' AND target_id = ?) OR (target_type = 'group' AND target_id = (SELECT group_name FROM devices WHERE device_uid = ?)))"
|
||||
)
|
||||
.bind(device_uid)
|
||||
.bind(device_uid)
|
||||
.fetch_all(db).await
|
||||
{
|
||||
let entries: Vec<serde_json::Value> = rows.iter().map(|r| serde_json::json!({
|
||||
"id": r.get::<i64, _>("id"),
|
||||
"name_pattern": r.get::<String, _>("name_pattern"),
|
||||
"action": r.get::<String, _>("action"),
|
||||
})).collect();
|
||||
if !entries.is_empty() {
|
||||
if let Ok(frame) = Frame::new_json(MessageType::SoftwareBlacklist, &serde_json::json!({"blacklist": entries})) {
|
||||
clients.send_to(device_uid, frame.encode()).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Popup blocker rules
|
||||
if let Ok(rows) = sqlx::query(
|
||||
"SELECT id, rule_type, window_title, window_class, process_name FROM popup_filter_rules WHERE enabled = 1 AND (target_type = 'global' OR (target_type = 'device' AND target_id = ?) OR (target_type = 'group' AND target_id = (SELECT group_name FROM devices WHERE device_uid = ?)))"
|
||||
)
|
||||
.bind(device_uid)
|
||||
.bind(device_uid)
|
||||
.fetch_all(db).await
|
||||
{
|
||||
let rules: Vec<serde_json::Value> = rows.iter().map(|r| serde_json::json!({
|
||||
"id": r.get::<i64, _>("id"),
|
||||
"rule_type": r.get::<String, _>("rule_type"),
|
||||
"window_title": r.get::<Option<String>, _>("window_title"),
|
||||
"window_class": r.get::<Option<String>, _>("window_class"),
|
||||
"process_name": r.get::<Option<String>, _>("process_name"),
|
||||
})).collect();
|
||||
if !rules.is_empty() {
|
||||
if let Ok(frame) = Frame::new_json(MessageType::PopupRules, &serde_json::json!({"rules": rules})) {
|
||||
clients.send_to(device_uid, frame.encode()).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// USB policies — push highest-priority enabled policy for the device's group
|
||||
if let Ok(rows) = sqlx::query(
|
||||
"SELECT policy_type, rules, enabled FROM usb_policies WHERE enabled = 1 AND target_group = (SELECT group_name FROM devices WHERE device_uid = ?) ORDER BY CASE WHEN policy_type = 'all_block' THEN 0 WHEN policy_type = 'blacklist' THEN 1 ELSE 2 END LIMIT 1"
|
||||
)
|
||||
.bind(device_uid)
|
||||
.fetch_all(db).await
|
||||
{
|
||||
if let Some(row) = rows.first() {
|
||||
let policy_type: String = row.get("policy_type");
|
||||
let rules_json: String = row.get("rules");
|
||||
let rules: Vec<serde_json::Value> = serde_json::from_str(&rules_json).unwrap_or_default();
|
||||
let payload = csm_protocol::UsbPolicyPayload {
|
||||
policy_type,
|
||||
enabled: true,
|
||||
rules: rules.iter().map(|r| csm_protocol::UsbDeviceRule {
|
||||
vendor_id: r.get("vendor_id").and_then(|v| v.as_str().map(String::from)),
|
||||
product_id: r.get("product_id").and_then(|v| v.as_str().map(String::from)),
|
||||
serial: r.get("serial").and_then(|v| v.as_str().map(String::from)),
|
||||
device_name: r.get("device_name").and_then(|v| v.as_str().map(String::from)),
|
||||
}).collect(),
|
||||
};
|
||||
if let Ok(frame) = Frame::new_json(MessageType::UsbPolicyUpdate, &payload) {
|
||||
clients.send_to(device_uid, frame.encode()).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
info!("Pushed all plugin configs to newly registered device {}", device_uid);
|
||||
}
|
||||
|
||||
/// Maximum accumulated read buffer size per connection (8 MB)
|
||||
const MAX_READ_BUF_SIZE: usize = 8 * 1024 * 1024;
|
||||
|
||||
/// Registry of all connected client sessions
|
||||
#[derive(Clone, Default)]
|
||||
pub struct ClientRegistry {
|
||||
sessions: Arc<RwLock<HashMap<String, Arc<tokio::sync::mpsc::Sender<Vec<u8>>>>>>,
|
||||
}
|
||||
|
||||
impl ClientRegistry {
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
pub async fn register(&self, device_uid: String, tx: Arc<tokio::sync::mpsc::Sender<Vec<u8>>>) {
|
||||
self.sessions.write().await.insert(device_uid, tx);
|
||||
}
|
||||
|
||||
pub async fn unregister(&self, device_uid: &str) {
|
||||
self.sessions.write().await.remove(device_uid);
|
||||
}
|
||||
|
||||
pub async fn send_to(&self, device_uid: &str, data: Vec<u8>) -> bool {
|
||||
if let Some(tx) = self.sessions.read().await.get(device_uid) {
|
||||
tx.send(data).await.is_ok()
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn count(&self) -> usize {
|
||||
self.sessions.read().await.len()
|
||||
}
|
||||
|
||||
pub async fn list_online(&self) -> Vec<String> {
|
||||
self.sessions.read().await.keys().cloned().collect()
|
||||
}
|
||||
}
|
||||
|
||||
/// Start the TCP server for client connections (optionally with TLS)
|
||||
pub async fn start_tcp_server(addr: String, state: AppState) -> anyhow::Result<()> {
|
||||
let listener = TcpListener::bind(&addr).await?;
|
||||
|
||||
// Build TLS acceptor if configured
|
||||
let tls_acceptor = build_tls_acceptor(&state.config.server.tls)?;
|
||||
|
||||
if tls_acceptor.is_some() {
|
||||
info!("TCP server listening on {} (TLS enabled)", addr);
|
||||
} else {
|
||||
info!("TCP server listening on {} (plaintext)", addr);
|
||||
}
|
||||
|
||||
loop {
|
||||
let (stream, peer_addr) = listener.accept().await?;
|
||||
let state = state.clone();
|
||||
let acceptor = tls_acceptor.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
debug!("New TCP connection from {}", peer_addr);
|
||||
match acceptor {
|
||||
Some(acceptor) => {
|
||||
match acceptor.accept(stream).await {
|
||||
Ok(tls_stream) => {
|
||||
if let Err(e) = handle_client_tls(tls_stream, state).await {
|
||||
warn!("Client {} TLS error: {}", peer_addr, e);
|
||||
}
|
||||
}
|
||||
Err(e) => warn!("TLS handshake failed for {}: {}", peer_addr, e),
|
||||
}
|
||||
}
|
||||
None => {
|
||||
if let Err(e) = handle_client(stream, state).await {
|
||||
warn!("Client {} error: {}", peer_addr, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
fn build_tls_acceptor(
|
||||
tls_config: &Option<crate::config::TlsConfig>,
|
||||
) -> anyhow::Result<Option<tokio_rustls::TlsAcceptor>> {
|
||||
let config = match tls_config {
|
||||
Some(c) => c,
|
||||
None => return Ok(None),
|
||||
};
|
||||
|
||||
let cert_pem = std::fs::read(&config.cert_path)
|
||||
.map_err(|e| anyhow::anyhow!("Failed to read TLS cert {}: {}", config.cert_path, e))?;
|
||||
let key_pem = std::fs::read(&config.key_path)
|
||||
.map_err(|e| anyhow::anyhow!("Failed to read TLS key {}: {}", config.key_path, e))?;
|
||||
|
||||
let certs: Vec<rustls_pki_types::CertificateDer> = rustls_pemfile::certs(&mut &cert_pem[..])
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
.map_err(|e| anyhow::anyhow!("Failed to parse TLS cert: {:?}", e))?
|
||||
.into_iter()
|
||||
.map(|c| c.into())
|
||||
.collect();
|
||||
|
||||
let key = rustls_pemfile::private_key(&mut &key_pem[..])
|
||||
.map_err(|e| anyhow::anyhow!("Failed to parse TLS key: {:?}", e))?
|
||||
.ok_or_else(|| anyhow::anyhow!("No private key found in {}", config.key_path))?;
|
||||
|
||||
let server_config = rustls::ServerConfig::builder()
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(certs, key)
|
||||
.map_err(|e| anyhow::anyhow!("Failed to build TLS config: {}", e))?;
|
||||
|
||||
Ok(Some(tokio_rustls::TlsAcceptor::from(Arc::new(server_config))))
|
||||
}
|
||||
|
||||
/// Cleanup on client disconnect: unregister from client map, mark offline, notify WS.
|
||||
async fn cleanup_on_disconnect(state: &AppState, device_uid: &Option<String>) {
|
||||
if let Some(uid) = device_uid {
|
||||
state.clients.unregister(uid).await;
|
||||
sqlx::query("UPDATE devices SET status = 'offline' WHERE device_uid = ?")
|
||||
.bind(uid)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
.ok();
|
||||
|
||||
state.ws_hub.broadcast(serde_json::json!({
|
||||
"type": "device_state",
|
||||
"device_uid": uid,
|
||||
"status": "offline"
|
||||
}).to_string()).await;
|
||||
|
||||
info!("Device disconnected: {}", uid);
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute HMAC-SHA256 for heartbeat verification.
|
||||
/// Format: HMAC-SHA256(device_secret, "{device_uid}\n{timestamp}") → hex-encoded
|
||||
fn compute_hmac(secret: &str, device_uid: &str, timestamp: &str) -> String {
|
||||
type HmacSha256 = Hmac<Sha256>;
|
||||
|
||||
let message = format!("{}\n{}", device_uid, timestamp);
|
||||
let mut mac = match HmacSha256::new_from_slice(secret.as_bytes()) {
|
||||
Ok(m) => m,
|
||||
Err(_) => return String::new(),
|
||||
};
|
||||
mac.update(message.as_bytes());
|
||||
hex::encode(mac.finalize().into_bytes())
|
||||
}
|
||||
|
||||
/// Verify that a frame sender is a registered device and that the claimed device_uid
|
||||
/// matches the one registered on this connection. Returns true if valid.
|
||||
fn verify_device_uid(device_uid: &Option<String>, msg_type: &str, claimed_uid: &str) -> bool {
|
||||
match device_uid {
|
||||
Some(uid) if *uid == claimed_uid => true,
|
||||
Some(uid) => {
|
||||
warn!("{} device_uid mismatch: expected {:?}, got {}", msg_type, uid, claimed_uid);
|
||||
false
|
||||
}
|
||||
None => {
|
||||
warn!("{} from unregistered connection", msg_type);
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Process a single decoded frame. Shared by both plaintext and TLS handlers.
|
||||
async fn process_frame(
|
||||
frame: Frame,
|
||||
state: &AppState,
|
||||
device_uid: &mut Option<String>,
|
||||
tx: &Arc<tokio::sync::mpsc::Sender<Vec<u8>>>,
|
||||
) -> anyhow::Result<()> {
|
||||
match frame.msg_type {
|
||||
MessageType::Register => {
|
||||
let req: csm_protocol::RegisterRequest = frame.decode_payload()
|
||||
.map_err(|e| anyhow::anyhow!("Invalid registration payload: {}", e))?;
|
||||
|
||||
info!("Device registration attempt: {} ({})", req.hostname, req.device_uid);
|
||||
|
||||
// Validate registration token against configured token
|
||||
let expected_token = &state.config.registration_token;
|
||||
if !expected_token.is_empty() {
|
||||
if req.registration_token.is_empty() || req.registration_token != *expected_token {
|
||||
warn!("Registration rejected for {}: invalid token", req.device_uid);
|
||||
let err_frame = Frame::new_json(MessageType::RegisterResponse,
|
||||
&serde_json::json!({"error": "invalid_registration_token"}))?;
|
||||
tx.send(err_frame.encode()).await.ok();
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
||||
// Check if device already exists with a secret (reconnection scenario)
|
||||
let existing_secret: Option<String> = sqlx::query_scalar(
|
||||
"SELECT device_secret FROM devices WHERE device_uid = ?"
|
||||
)
|
||||
.bind(&req.device_uid)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
.ok()
|
||||
.flatten();
|
||||
|
||||
let device_secret = match existing_secret {
|
||||
// Existing device — keep the same secret, don't rotate
|
||||
Some(secret) if !secret.is_empty() => secret,
|
||||
// New device — generate a fresh secret
|
||||
_ => uuid::Uuid::new_v4().to_string(),
|
||||
};
|
||||
|
||||
sqlx::query(
|
||||
"INSERT INTO devices (device_uid, hostname, ip_address, mac_address, os_version, device_secret, status) \
|
||||
VALUES (?, ?, '0.0.0.0', ?, ?, ?, 'online') \
|
||||
ON CONFLICT(device_uid) DO UPDATE SET hostname=excluded.hostname, os_version=excluded.os_version, \
|
||||
mac_address=excluded.mac_address, status='online'"
|
||||
)
|
||||
.bind(&req.device_uid)
|
||||
.bind(&req.hostname)
|
||||
.bind(&req.mac_address)
|
||||
.bind(&req.os_version)
|
||||
.bind(&device_secret)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("DB error during registration: {}", e))?;
|
||||
|
||||
*device_uid = Some(req.device_uid.clone());
|
||||
// If this device was already connected on a different session, evict the old one
|
||||
// The new register() call will replace it in the hashmap
|
||||
state.clients.register(req.device_uid.clone(), tx.clone()).await;
|
||||
|
||||
// Send registration response
|
||||
let config = csm_protocol::ClientConfig::default();
|
||||
let response = csm_protocol::RegisterResponse {
|
||||
device_secret,
|
||||
config,
|
||||
};
|
||||
let resp_frame = Frame::new_json(MessageType::RegisterResponse, &response)?;
|
||||
tx.send(resp_frame.encode()).await?;
|
||||
|
||||
info!("Device registered successfully: {} ({})", req.hostname, req.device_uid);
|
||||
|
||||
// Push all active plugin configs to newly registered client
|
||||
push_all_plugin_configs(&state.db, &state.clients, &req.device_uid).await;
|
||||
}
|
||||
|
||||
MessageType::Heartbeat => {
|
||||
let heartbeat: csm_protocol::HeartbeatPayload = frame.decode_payload()
|
||||
.map_err(|e| anyhow::anyhow!("Invalid heartbeat: {}", e))?;
|
||||
|
||||
if !verify_device_uid(device_uid, "Heartbeat", &heartbeat.device_uid) {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Verify HMAC — reject if secret exists but HMAC is missing or wrong
|
||||
let secret: Option<String> = sqlx::query_scalar(
|
||||
"SELECT device_secret FROM devices WHERE device_uid = ?"
|
||||
)
|
||||
.bind(&heartbeat.device_uid)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
warn!("DB error fetching device_secret for {}: {}", heartbeat.device_uid, e);
|
||||
anyhow::anyhow!("DB error during HMAC verification")
|
||||
})?;
|
||||
|
||||
if let Some(ref secret) = secret {
|
||||
if !secret.is_empty() {
|
||||
if heartbeat.hmac.is_empty() {
|
||||
warn!("Heartbeat missing HMAC for device {}", heartbeat.device_uid);
|
||||
return Ok(());
|
||||
}
|
||||
// Constant-time HMAC verification using hmac::Mac::verify_slice
|
||||
let message = format!("{}\n{}", heartbeat.device_uid, heartbeat.timestamp);
|
||||
let mut mac = Hmac::<Sha256>::new_from_slice(secret.as_bytes())
|
||||
.map_err(|_| anyhow::anyhow!("HMAC key error"))?;
|
||||
mac.update(message.as_bytes());
|
||||
let provided_bytes = hex::decode(&heartbeat.hmac).unwrap_or_default();
|
||||
if mac.verify_slice(&provided_bytes).is_err() {
|
||||
warn!("Heartbeat HMAC mismatch for device {}", heartbeat.device_uid);
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
debug!("Heartbeat from {} (hmac verified)", heartbeat.device_uid);
|
||||
|
||||
// Update device status in DB
|
||||
sqlx::query("UPDATE devices SET status = 'online', last_heartbeat = datetime('now') WHERE device_uid = ?")
|
||||
.bind(&heartbeat.device_uid)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
.ok();
|
||||
|
||||
// Push to WebSocket subscribers
|
||||
state.ws_hub.broadcast(serde_json::json!({
|
||||
"type": "device_state",
|
||||
"device_uid": heartbeat.device_uid,
|
||||
"status": "online"
|
||||
}).to_string()).await;
|
||||
}
|
||||
|
||||
MessageType::StatusReport => {
|
||||
let status: csm_protocol::DeviceStatus = frame.decode_payload()
|
||||
.map_err(|e| anyhow::anyhow!("Invalid status report: {}", e))?;
|
||||
|
||||
if !verify_device_uid(device_uid, "StatusReport", &status.device_uid) {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
crate::db::DeviceRepo::upsert_status(&state.db, &status.device_uid, &status).await?;
|
||||
|
||||
// Push to WebSocket subscribers
|
||||
state.ws_hub.broadcast(serde_json::json!({
|
||||
"type": "device_status",
|
||||
"device_uid": status.device_uid,
|
||||
"cpu": status.cpu_usage,
|
||||
"memory": status.memory_usage,
|
||||
"disk": status.disk_usage
|
||||
}).to_string()).await;
|
||||
}
|
||||
|
||||
MessageType::UsbEvent => {
|
||||
let event: csm_protocol::UsbEvent = frame.decode_payload()
|
||||
.map_err(|e| anyhow::anyhow!("Invalid USB event: {}", e))?;
|
||||
|
||||
if !verify_device_uid(device_uid, "UsbEvent", &event.device_uid) {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
crate::db::DeviceRepo::insert_usb_event(&state.db, &event).await?;
|
||||
|
||||
state.ws_hub.broadcast(serde_json::json!({
|
||||
"type": "usb_event",
|
||||
"device_uid": event.device_uid,
|
||||
"event": event.event_type,
|
||||
"usb_name": event.device_name
|
||||
}).to_string()).await;
|
||||
}
|
||||
|
||||
MessageType::AssetReport => {
|
||||
let asset: csm_protocol::HardwareAsset = frame.decode_payload()
|
||||
.map_err(|e| anyhow::anyhow!("Invalid asset report: {}", e))?;
|
||||
|
||||
if !verify_device_uid(device_uid, "AssetReport", &asset.device_uid) {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
crate::db::DeviceRepo::upsert_hardware(&state.db, &asset).await?;
|
||||
}
|
||||
|
||||
MessageType::UsageReport => {
|
||||
let report: csm_protocol::UsageDailyReport = frame.decode_payload()
|
||||
.map_err(|e| anyhow::anyhow!("Invalid usage report: {}", e))?;
|
||||
|
||||
if !verify_device_uid(device_uid, "UsageReport", &report.device_uid) {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
sqlx::query(
|
||||
"INSERT INTO usage_daily (device_uid, date, total_active_minutes, total_idle_minutes, first_active_at, last_active_at) \
|
||||
VALUES (?, ?, ?, ?, ?, ?) \
|
||||
ON CONFLICT(device_uid, date) DO UPDATE SET \
|
||||
total_active_minutes = excluded.total_active_minutes, \
|
||||
total_idle_minutes = excluded.total_idle_minutes, \
|
||||
first_active_at = excluded.first_active_at, \
|
||||
last_active_at = excluded.last_active_at"
|
||||
)
|
||||
.bind(&report.device_uid)
|
||||
.bind(&report.date)
|
||||
.bind(report.total_active_minutes as i32)
|
||||
.bind(report.total_idle_minutes as i32)
|
||||
.bind(&report.first_active_at)
|
||||
.bind(&report.last_active_at)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("DB error inserting usage report: {}", e))?;
|
||||
|
||||
debug!("Usage report saved for device {}", report.device_uid);
|
||||
}
|
||||
|
||||
MessageType::AppUsageReport => {
|
||||
let report: csm_protocol::AppUsageEntry = frame.decode_payload()
|
||||
.map_err(|e| anyhow::anyhow!("Invalid app usage report: {}", e))?;
|
||||
|
||||
if !verify_device_uid(device_uid, "AppUsageReport", &report.device_uid) {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
sqlx::query(
|
||||
"INSERT INTO app_usage_daily (device_uid, date, app_name, usage_minutes) \
|
||||
VALUES (?, ?, ?, ?) \
|
||||
ON CONFLICT(device_uid, date, app_name) DO UPDATE SET \
|
||||
usage_minutes = MAX(usage_minutes, excluded.usage_minutes)"
|
||||
)
|
||||
.bind(&report.device_uid)
|
||||
.bind(&report.date)
|
||||
.bind(&report.app_name)
|
||||
.bind(report.usage_minutes as i32)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("DB error inserting app usage: {}", e))?;
|
||||
|
||||
debug!("App usage saved: {} -> {} ({} min)", report.device_uid, report.app_name, report.usage_minutes);
|
||||
}
|
||||
|
||||
MessageType::SoftwareViolation => {
|
||||
let report: csm_protocol::SoftwareViolationReport = frame.decode_payload()
|
||||
.map_err(|e| anyhow::anyhow!("Invalid software violation: {}", e))?;
|
||||
|
||||
if !verify_device_uid(device_uid, "SoftwareViolation", &report.device_uid) {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
sqlx::query(
|
||||
"INSERT INTO software_violations (device_uid, software_name, action_taken, timestamp) VALUES (?, ?, ?, ?)"
|
||||
)
|
||||
.bind(&report.device_uid)
|
||||
.bind(&report.software_name)
|
||||
.bind(&report.action_taken)
|
||||
.bind(&report.timestamp)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("DB error inserting software violation: {}", e))?;
|
||||
|
||||
info!("Software violation: {} tried to run {} -> {}", report.device_uid, report.software_name, report.action_taken);
|
||||
|
||||
state.ws_hub.broadcast(serde_json::json!({
|
||||
"type": "software_violation",
|
||||
"device_uid": report.device_uid,
|
||||
"software_name": report.software_name,
|
||||
"action_taken": report.action_taken
|
||||
}).to_string()).await;
|
||||
}
|
||||
|
||||
MessageType::UsbFileOp => {
|
||||
let entry: csm_protocol::UsbFileOpEntry = frame.decode_payload()
|
||||
.map_err(|e| anyhow::anyhow!("Invalid USB file op: {}", e))?;
|
||||
|
||||
if !verify_device_uid(device_uid, "UsbFileOp", &entry.device_uid) {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
sqlx::query(
|
||||
"INSERT INTO usb_file_operations (device_uid, usb_serial, drive_letter, operation, file_path, file_size, timestamp) VALUES (?, ?, ?, ?, ?, ?, ?)"
|
||||
)
|
||||
.bind(&entry.device_uid)
|
||||
.bind(&entry.usb_serial)
|
||||
.bind(&entry.drive_letter)
|
||||
.bind(&entry.operation)
|
||||
.bind(&entry.file_path)
|
||||
.bind(entry.file_size.map(|s| s as i64))
|
||||
.bind(&entry.timestamp)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("DB error inserting USB file op: {}", e))?;
|
||||
|
||||
debug!("USB file op: {} {} on {}", entry.operation, entry.file_path, entry.device_uid);
|
||||
}
|
||||
|
||||
MessageType::WebAccessLog => {
|
||||
let entry: csm_protocol::WebAccessLogEntry = frame.decode_payload()
|
||||
.map_err(|e| anyhow::anyhow!("Invalid web access log: {}", e))?;
|
||||
|
||||
if !verify_device_uid(device_uid, "WebAccessLog", &entry.device_uid) {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
sqlx::query(
|
||||
"INSERT INTO web_access_log (device_uid, url, action, timestamp) VALUES (?, ?, ?, ?)"
|
||||
)
|
||||
.bind(&entry.device_uid)
|
||||
.bind(&entry.url)
|
||||
.bind(&entry.action)
|
||||
.bind(&entry.timestamp)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("DB error inserting web access log: {}", e))?;
|
||||
|
||||
debug!("Web access log: {} {} {}", entry.device_uid, entry.action, entry.url);
|
||||
}
|
||||
|
||||
_ => {
|
||||
debug!("Unhandled message type: {:?}", frame.msg_type);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handle a single client TCP connection
|
||||
async fn handle_client(stream: TcpStream, state: AppState) -> anyhow::Result<()> {
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
|
||||
// Set read timeout to detect stale connections
|
||||
let _ = stream.set_nodelay(true);
|
||||
|
||||
let (mut reader, mut writer) = stream.into_split();
|
||||
let (tx, mut rx) = tokio::sync::mpsc::channel::<Vec<u8>>(256);
|
||||
let tx = Arc::new(tx);
|
||||
|
||||
let mut buffer = vec![0u8; 65536];
|
||||
let mut read_buf = Vec::with_capacity(65536);
|
||||
let mut device_uid: Option<String> = None;
|
||||
let mut rate_limiter = RateLimiter::new();
|
||||
|
||||
// Writer task: forwards messages from channel to TCP stream
|
||||
let write_task = tokio::spawn(async move {
|
||||
while let Some(data) = rx.recv().await {
|
||||
if writer.write_all(&data).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Reader loop
|
||||
'reader: loop {
|
||||
let n = reader.read(&mut buffer).await?;
|
||||
if n == 0 {
|
||||
break; // Connection closed
|
||||
}
|
||||
read_buf.extend_from_slice(&buffer[..n]);
|
||||
|
||||
// Guard against unbounded buffer growth
|
||||
if read_buf.len() > MAX_READ_BUF_SIZE {
|
||||
warn!("Connection exceeded max buffer size, dropping");
|
||||
break;
|
||||
}
|
||||
|
||||
// Process complete frames
|
||||
while let Some(frame) = Frame::decode(&read_buf)? {
|
||||
let frame_size = frame.encoded_size();
|
||||
// Remove consumed bytes without reallocating
|
||||
read_buf.drain(..frame_size);
|
||||
|
||||
// Rate limit check
|
||||
if !rate_limiter.check() {
|
||||
warn!("Rate limit exceeded for device {:?}, dropping connection", device_uid);
|
||||
break 'reader;
|
||||
}
|
||||
|
||||
// Verify protocol version
|
||||
if frame.version != PROTOCOL_VERSION {
|
||||
warn!("Unsupported protocol version: 0x{:02X}", frame.version);
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Err(e) = process_frame(frame, &state, &mut device_uid, &tx).await {
|
||||
warn!("Frame processing error: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
cleanup_on_disconnect(&state, &device_uid).await;
|
||||
write_task.abort();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handle a TLS-wrapped client connection
|
||||
async fn handle_client_tls(
|
||||
stream: tokio_rustls::server::TlsStream<TcpStream>,
|
||||
state: AppState,
|
||||
) -> anyhow::Result<()> {
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
|
||||
let (mut reader, mut writer) = tokio::io::split(stream);
|
||||
let (tx, mut rx) = tokio::sync::mpsc::channel::<Vec<u8>>(256);
|
||||
let tx = Arc::new(tx);
|
||||
|
||||
let mut buffer = vec![0u8; 65536];
|
||||
let mut read_buf = Vec::with_capacity(65536);
|
||||
let mut device_uid: Option<String> = None;
|
||||
let mut rate_limiter = RateLimiter::new();
|
||||
|
||||
let write_task = tokio::spawn(async move {
|
||||
while let Some(data) = rx.recv().await {
|
||||
if writer.write_all(&data).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Reader loop — same logic as plaintext handler
|
||||
'reader: loop {
|
||||
let n = reader.read(&mut buffer).await?;
|
||||
if n == 0 {
|
||||
break;
|
||||
}
|
||||
read_buf.extend_from_slice(&buffer[..n]);
|
||||
|
||||
if read_buf.len() > MAX_READ_BUF_SIZE {
|
||||
warn!("TLS connection exceeded max buffer size, dropping");
|
||||
break;
|
||||
}
|
||||
|
||||
while let Some(frame) = Frame::decode(&read_buf)? {
|
||||
let frame_size = frame.encoded_size();
|
||||
read_buf.drain(..frame_size);
|
||||
|
||||
if frame.version != PROTOCOL_VERSION {
|
||||
warn!("Unsupported protocol version: 0x{:02X}", frame.version);
|
||||
continue;
|
||||
}
|
||||
|
||||
if !rate_limiter.check() {
|
||||
warn!("Rate limit exceeded for TLS device {:?}, dropping connection", device_uid);
|
||||
break 'reader;
|
||||
}
|
||||
|
||||
if let Err(e) = process_frame(frame, &state, &mut device_uid, &tx).await {
|
||||
warn!("Frame processing error: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
cleanup_on_disconnect(&state, &device_uid).await;
|
||||
write_task.abort();
|
||||
Ok(())
|
||||
}
|
||||
125
crates/server/src/ws.rs
Normal file
125
crates/server/src/ws.rs
Normal file
@@ -0,0 +1,125 @@
|
||||
use axum::extract::ws::{WebSocket, WebSocketUpgrade, Message};
|
||||
use axum::response::IntoResponse;
|
||||
use axum::extract::Query;
|
||||
use jsonwebtoken::{decode, Validation, DecodingKey};
|
||||
use serde::Deserialize;
|
||||
use tokio::sync::broadcast;
|
||||
use std::sync::Arc;
|
||||
use tracing::{debug, warn};
|
||||
use crate::api::auth::Claims;
|
||||
use crate::AppState;
|
||||
|
||||
/// WebSocket hub for broadcasting real-time events to admin browsers
|
||||
#[derive(Clone)]
|
||||
pub struct WsHub {
|
||||
tx: broadcast::Sender<String>,
|
||||
}
|
||||
|
||||
impl WsHub {
|
||||
pub fn new() -> Self {
|
||||
let (tx, _) = broadcast::channel(1024);
|
||||
Self { tx }
|
||||
}
|
||||
|
||||
pub async fn broadcast(&self, message: String) {
|
||||
if self.tx.send(message).is_err() {
|
||||
debug!("No WebSocket subscribers to receive broadcast");
|
||||
}
|
||||
}
|
||||
|
||||
pub fn subscribe(&self) -> broadcast::Receiver<String> {
|
||||
self.tx.subscribe()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct WsAuthParams {
|
||||
pub token: Option<String>,
|
||||
}
|
||||
|
||||
/// HTTP upgrade handler for WebSocket connections
|
||||
/// Validates JWT token from query parameter before upgrading
|
||||
pub async fn ws_handler(
|
||||
ws: WebSocketUpgrade,
|
||||
Query(params): Query<WsAuthParams>,
|
||||
axum::extract::State(state): axum::extract::State<AppState>,
|
||||
) -> impl IntoResponse {
|
||||
let token = match params.token {
|
||||
Some(t) => t,
|
||||
None => {
|
||||
warn!("WebSocket connection rejected: no token provided");
|
||||
return (axum::http::StatusCode::UNAUTHORIZED, "Missing token").into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let claims = match decode::<Claims>(
|
||||
&token,
|
||||
&DecodingKey::from_secret(state.config.auth.jwt_secret.as_bytes()),
|
||||
&Validation::default(),
|
||||
) {
|
||||
Ok(c) => c.claims,
|
||||
Err(e) => {
|
||||
warn!("WebSocket connection rejected: invalid token - {}", e);
|
||||
return (axum::http::StatusCode::UNAUTHORIZED, "Invalid token").into_response();
|
||||
}
|
||||
};
|
||||
|
||||
if claims.token_type != "access" {
|
||||
warn!("WebSocket connection rejected: not an access token");
|
||||
return (axum::http::StatusCode::UNAUTHORIZED, "Invalid token type").into_response();
|
||||
}
|
||||
|
||||
let hub = state.ws_hub.clone();
|
||||
ws.on_upgrade(move |socket| handle_socket(socket, claims, hub))
|
||||
}
|
||||
|
||||
async fn handle_socket(mut socket: WebSocket, claims: Claims, hub: Arc<WsHub>) {
|
||||
debug!("WebSocket client connected: user={}", claims.username);
|
||||
|
||||
let welcome = serde_json::json!({
|
||||
"type": "connected",
|
||||
"message": "CSM real-time feed active",
|
||||
"user": claims.username
|
||||
});
|
||||
if socket.send(Message::Text(welcome.to_string())).await.is_err() {
|
||||
return;
|
||||
}
|
||||
|
||||
// Subscribe to broadcast hub for real-time events
|
||||
let mut rx = hub.subscribe();
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
// Forward broadcast messages to WebSocket client
|
||||
msg = rx.recv() => {
|
||||
match msg {
|
||||
Ok(text) => {
|
||||
if socket.send(Message::Text(text)).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(broadcast::error::RecvError::Lagged(n)) => {
|
||||
debug!("WebSocket client lagged {} messages, continuing", n);
|
||||
}
|
||||
Err(broadcast::error::RecvError::Closed) => break,
|
||||
}
|
||||
}
|
||||
// Handle incoming WebSocket messages (ping/close)
|
||||
msg = socket.recv() => {
|
||||
match msg {
|
||||
Some(Ok(Message::Ping(data))) => {
|
||||
if socket.send(Message::Pong(data)).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Some(Ok(Message::Close(_))) => break,
|
||||
Some(Err(_)) => break,
|
||||
None => break,
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
debug!("WebSocket client disconnected: user={}", claims.username);
|
||||
}
|
||||
Reference in New Issue
Block a user