feat: 初始化项目基础架构和核心功能

- 添加项目基础结构:Cargo.toml、.gitignore、设备UID和密钥文件
- 实现前端Vue3项目结构:路由、登录页面、设备管理页面
- 添加核心协议定义(crates/protocol):设备状态、资产、USB事件等
- 实现客户端监控模块:系统状态收集、资产收集
- 实现服务端基础API和插件系统
- 添加数据库迁移脚本:设备管理、资产跟踪、告警系统等
- 实现前端设备状态展示和基本交互
- 添加使用时长统计和水印功能插件
This commit is contained in:
iven
2026-04-05 00:57:51 +08:00
commit fd6fb5cca0
87 changed files with 19576 additions and 0 deletions

51
crates/server/Cargo.toml Normal file
View File

@@ -0,0 +1,51 @@
[package]
name = "csm-server"
version.workspace = true
edition.workspace = true

[dependencies]
# Workspace-local protocol crate shared with the client.
csm-protocol = { path = "../protocol" }
# Async runtime
tokio = { workspace = true }
# Web framework
axum = { version = "0.7", features = ["ws"] }
tower-http = { version = "0.5", features = ["cors", "fs", "trace", "compression-gzip", "set-header"] }
tower = "0.4"
# Database
sqlx = { version = "0.8", features = ["runtime-tokio", "sqlite"] }
# TLS
rustls = "0.23"
tokio-rustls = "0.26"
rustls-pemfile = "2"
rustls-pki-types = "1"
# Serialization
serde = { workspace = true }
serde_json = { workspace = true }
# Auth
jsonwebtoken = "9"
bcrypt = "0.15"
# Notifications (tokio1-rustls-tls enables lettre's async SMTP transport)
lettre = { version = "0.11", default-features = false, features = ["tokio1-rustls-tls", "smtp-transport", "builder", "hostname"] }
reqwest = { version = "0.12", features = ["json", "rustls-tls"] }
# Config & logging
toml = "0.8"
tracing = { workspace = true }
tracing-subscriber = { workspace = true }
anyhow = { workspace = true }
# Utilities
uuid = { workspace = true }
chrono = { workspace = true }
thiserror = { workspace = true }
include_dir = "0.7"
hmac = "0.12"
sha2 = "0.10"
hex = "0.4"

118
crates/server/src/alert.rs Normal file
View File

@@ -0,0 +1,118 @@
use crate::AppState;
use tracing::{info, warn, error};
/// Background task for data cleanup and alert processing
pub async fn cleanup_task(state: AppState) {
let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(3600));
loop {
interval.tick().await;
// Cleanup old status history
if let Err(e) = sqlx::query(
"DELETE FROM device_status_history WHERE reported_at < datetime('now', ?)"
)
.bind(format!("-{} days", state.config.retention.status_history_days))
.execute(&state.db)
.await
{
error!("Failed to cleanup status history: {}", e);
}
// Cleanup old USB events
if let Err(e) = sqlx::query(
"DELETE FROM usb_events WHERE event_time < datetime('now', ?)"
)
.bind(format!("-{} days", state.config.retention.usb_events_days))
.execute(&state.db)
.await
{
error!("Failed to cleanup USB events: {}", e);
}
// Cleanup handled alert records
if let Err(e) = sqlx::query(
"DELETE FROM alert_records WHERE handled = 1 AND triggered_at < datetime('now', ?)"
)
.bind(format!("-{} days", state.config.retention.alert_records_days))
.execute(&state.db)
.await
{
error!("Failed to cleanup alert records: {}", e);
}
// Mark devices as offline if no heartbeat for 2 minutes
if let Err(e) = sqlx::query(
"UPDATE devices SET status = 'offline' WHERE status = 'online' AND last_heartbeat < datetime('now', '-2 minutes')"
)
.execute(&state.db)
.await
{
error!("Failed to mark stale devices offline: {}", e);
}
// SQLite WAL checkpoint
if let Err(e) = sqlx::query("PRAGMA wal_checkpoint(TRUNCATE)")
.execute(&state.db)
.await
{
warn!("WAL checkpoint failed: {}", e);
}
info!("Cleanup cycle completed");
}
}
/// Send an HTML email notification through the configured SMTP relay.
///
/// Uses lettre's async transport (`AsyncSmtpTransport<Tokio1Executor>`) so the
/// SMTP round-trip does not block a tokio worker thread — the original used
/// the blocking `SmtpTransport::send` inside this async fn, which can stall
/// the runtime for the duration of the SMTP handshake.
///
/// # Errors
/// Returns an error if the from/to addresses fail to parse, the message
/// cannot be built, or the SMTP transaction fails.
pub async fn send_email(
    smtp_config: &crate::config::SmtpConfig,
    to: &str,
    subject: &str,
    body: &str,
) -> anyhow::Result<()> {
    use lettre::message::header::ContentType;
    use lettre::transport::smtp::authentication::Credentials;
    use lettre::{AsyncSmtpTransport, AsyncTransport, Message, Tokio1Executor};
    let email = Message::builder()
        .from(smtp_config.from.parse()?)
        .to(to.parse()?)
        .subject(subject)
        .header(ContentType::TEXT_HTML)
        .body(body.to_string())?;
    let creds = Credentials::new(
        smtp_config.username.clone(),
        smtp_config.password.clone(),
    );
    // STARTTLS relay on the configured host/port; credentials sent after upgrade.
    let mailer = AsyncSmtpTransport::<Tokio1Executor>::starttls_relay(&smtp_config.host)?
        .port(smtp_config.port)
        .credentials(creds)
        .build();
    mailer.send(email).await?;
    Ok(())
}
/// Shared HTTP client for webhook notifications.
/// Lazily initialized once and reused across calls to benefit from connection pooling.
static WEBHOOK_CLIENT: std::sync::OnceLock<reqwest::Client> = std::sync::OnceLock::new();

/// Returns the process-wide webhook client, building it on first use.
fn webhook_client() -> &'static reqwest::Client {
    WEBHOOK_CLIENT.get_or_init(|| {
        reqwest::Client::builder()
            // 10s total timeout covers connect + request + response.
            .timeout(std::time::Duration::from_secs(10))
            .build()
            // NOTE(review): builder failure is practically unreachable, but the
            // fallback `Client::new()` has NO timeout, so a webhook call could
            // hang indefinitely on that path — consider propagating instead.
            .unwrap_or_else(|_| reqwest::Client::new())
    })
}
/// Send a webhook notification by POSTing `payload` as JSON to `url`.
///
/// # Errors
/// Returns an error if the request fails to send OR if the server responds
/// with a 4xx/5xx status — the original treated any completed HTTP exchange
/// as success, silently swallowing server-side rejections.
pub async fn send_webhook(url: &str, payload: &serde_json::Value) -> anyhow::Result<()> {
    webhook_client()
        .post(url)
        .json(payload)
        .send()
        .await?
        // Surface HTTP-level failures (e.g. 401/500) as errors.
        .error_for_status()?;
    Ok(())
}

View File

@@ -0,0 +1,243 @@
use axum::{extract::{State, Path, Query, Json}, http::StatusCode};
use serde::Deserialize;
use sqlx::Row;
use crate::AppState;
use super::ApiResponse;
use super::auth::Claims;
/// Query-string filters for listing alert records.
#[derive(Debug, Deserialize)]
pub struct AlertRecordListParams {
    pub device_uid: Option<String>, // exact match on device UID
    pub alert_type: Option<String>, // exact match
    pub severity: Option<String>,   // exact match
    pub handled: Option<i32>,       // 0/1 handled flag; None = no filter
    pub page: Option<u32>,          // 1-based page number, defaults to 1
    pub page_size: Option<u32>,     // handler caps this at 100
}
/// List all alert rules, newest first.
pub async fn list_rules(
    State(state): State<AppState>,
) -> Json<ApiResponse<serde_json::Value>> {
    let fetched = sqlx::query(
        "SELECT id, name, rule_type, condition, severity, enabled, notify_email, notify_webhook, created_at, updated_at
        FROM alert_rules ORDER BY created_at DESC"
    )
    .fetch_all(&state.db)
    .await;
    let rows = match fetched {
        Ok(rows) => rows,
        Err(e) => return Json(ApiResponse::internal_error("query alert rules", e)),
    };
    // Project each row into the JSON shape the frontend expects.
    let rules: Vec<serde_json::Value> = rows
        .iter()
        .map(|row| serde_json::json!({
            "id": row.get::<i64, _>("id"),
            "name": row.get::<String, _>("name"),
            "rule_type": row.get::<String, _>("rule_type"),
            "condition": row.get::<String, _>("condition"),
            "severity": row.get::<String, _>("severity"),
            "enabled": row.get::<i32, _>("enabled"),
            "notify_email": row.get::<Option<String>, _>("notify_email"),
            "notify_webhook": row.get::<Option<String>, _>("notify_webhook"),
            "created_at": row.get::<String, _>("created_at"),
            "updated_at": row.get::<String, _>("updated_at"),
        }))
        .collect();
    Json(ApiResponse::ok(serde_json::json!({
        "rules": rules,
    })))
}
/// List alert records with optional filters, newest first.
///
/// Each optional filter is bound twice: `(? IS NULL OR col = ?)` lets SQLite
/// skip the predicate entirely when no filter was supplied.
pub async fn list_records(
    State(state): State<AppState>,
    Query(params): Query<AlertRecordListParams>,
) -> Json<ApiResponse<serde_json::Value>> {
    // Clamp page size to 100. Saturating arithmetic: `page` is attacker
    // controlled, and the original unchecked `* limit` u32 multiply would
    // panic on overflow in debug builds for a huge `page` value.
    let limit = params.page_size.unwrap_or(20).min(100);
    let offset = params.page.unwrap_or(1).saturating_sub(1).saturating_mul(limit);
    // Normalize empty strings to None (Axum deserializes `key=` as Some(""))
    let device_uid = params.device_uid.as_deref().filter(|s| !s.is_empty()).map(String::from);
    let alert_type = params.alert_type.as_deref().filter(|s| !s.is_empty()).map(String::from);
    let severity = params.severity.as_deref().filter(|s| !s.is_empty()).map(String::from);
    let handled = params.handled;
    let rows = sqlx::query(
        "SELECT id, rule_id, device_uid, alert_type, severity, detail, handled, handled_by, handled_at, triggered_at
        FROM alert_records WHERE 1=1
        AND (? IS NULL OR device_uid = ?)
        AND (? IS NULL OR alert_type = ?)
        AND (? IS NULL OR severity = ?)
        AND (? IS NULL OR handled = ?)
        ORDER BY triggered_at DESC LIMIT ? OFFSET ?"
    )
    .bind(&device_uid).bind(&device_uid)
    .bind(&alert_type).bind(&alert_type)
    .bind(&severity).bind(&severity)
    .bind(&handled).bind(&handled)
    .bind(limit).bind(offset)
    .fetch_all(&state.db)
    .await;
    match rows {
        Ok(records) => {
            let items: Vec<serde_json::Value> = records.iter().map(|r| serde_json::json!({
                "id": r.get::<i64, _>("id"),
                "rule_id": r.get::<Option<i64>, _>("rule_id"),
                "device_uid": r.get::<Option<String>, _>("device_uid"),
                "alert_type": r.get::<String, _>("alert_type"),
                "severity": r.get::<String, _>("severity"),
                "detail": r.get::<String, _>("detail"),
                "handled": r.get::<i32, _>("handled"),
                "handled_by": r.get::<Option<String>, _>("handled_by"),
                "handled_at": r.get::<Option<String>, _>("handled_at"),
                "triggered_at": r.get::<String, _>("triggered_at"),
            })).collect();
            Json(ApiResponse::ok(serde_json::json!({
                "records": items,
                "page": params.page.unwrap_or(1),
                "page_size": limit,
            })))
        }
        Err(e) => Json(ApiResponse::internal_error("query alert records", e)),
    }
}
/// Request body for creating an alert rule.
#[derive(Debug, Deserialize)]
pub struct CreateRuleRequest {
    pub name: String,
    pub rule_type: String,
    pub condition: String,              // rule condition; stored verbatim — format defined by the evaluator (not visible here)
    pub severity: Option<String>,       // defaults to "medium" when omitted
    pub notify_email: Option<String>,   // optional email notification target
    pub notify_webhook: Option<String>, // optional webhook notification URL
}
/// Create a new alert rule; responds 201 with the new row id on success.
pub async fn create_rule(
    State(state): State<AppState>,
    Json(body): Json<CreateRuleRequest>,
) -> (StatusCode, Json<ApiResponse<serde_json::Value>>) {
    // Default severity when the caller omits it.
    let severity = body.severity.unwrap_or_else(|| "medium".to_string());
    let inserted = sqlx::query(
        "INSERT INTO alert_rules (name, rule_type, condition, severity, notify_email, notify_webhook)
        VALUES (?, ?, ?, ?, ?, ?)"
    )
    .bind(&body.name)
    .bind(&body.rule_type)
    .bind(&body.condition)
    .bind(&severity)
    .bind(&body.notify_email)
    .bind(&body.notify_webhook)
    .execute(&state.db)
    .await;
    match inserted {
        Err(e) => (
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(ApiResponse::internal_error("create alert rule", e)),
        ),
        Ok(r) => (
            StatusCode::CREATED,
            Json(ApiResponse::ok(serde_json::json!({
                "id": r.last_insert_rowid(),
            }))),
        ),
    }
}
/// Request body for partially updating an alert rule.
/// `None` fields keep their current stored values.
#[derive(Debug, Deserialize)]
pub struct UpdateRuleRequest {
    pub name: Option<String>,
    pub rule_type: Option<String>,
    pub condition: Option<String>,
    pub severity: Option<String>,
    pub enabled: Option<i32>,           // 0/1 toggle
    // NOTE(review): because absent and null both deserialize to None, a caller
    // cannot clear notify_email/notify_webhook back to NULL via this endpoint —
    // confirm whether that is intended.
    pub notify_email: Option<String>,
    pub notify_webhook: Option<String>,
}
/// Partially update an alert rule: fields omitted from the request keep the
/// values currently stored in the database.
pub async fn update_rule(
    State(state): State<AppState>,
    Path(id): Path<i64>,
    Json(body): Json<UpdateRuleRequest>,
) -> Json<ApiResponse<serde_json::Value>> {
    // Load the current row so omitted fields can fall back to stored values.
    let current = match sqlx::query("SELECT * FROM alert_rules WHERE id = ?")
        .bind(id)
        .fetch_optional(&state.db)
        .await
    {
        Ok(Some(row)) => row,
        Ok(None) => return Json(ApiResponse::error("Rule not found")),
        Err(e) => return Json(ApiResponse::internal_error("query alert rule", e)),
    };
    // Merge request fields over the existing row.
    let name = body.name.unwrap_or_else(|| current.get::<String, _>("name"));
    let rule_type = body.rule_type.unwrap_or_else(|| current.get::<String, _>("rule_type"));
    let condition = body.condition.unwrap_or_else(|| current.get::<String, _>("condition"));
    let severity = body.severity.unwrap_or_else(|| current.get::<String, _>("severity"));
    let enabled = body.enabled.unwrap_or_else(|| current.get::<i32, _>("enabled"));
    let notify_email = body.notify_email.or_else(|| current.get::<Option<String>, _>("notify_email"));
    let notify_webhook = body.notify_webhook.or_else(|| current.get::<Option<String>, _>("notify_webhook"));
    let updated = sqlx::query(
        "UPDATE alert_rules SET name = ?, rule_type = ?, condition = ?, severity = ?, enabled = ?,
        notify_email = ?, notify_webhook = ?, updated_at = datetime('now') WHERE id = ?"
    )
    .bind(&name)
    .bind(&rule_type)
    .bind(&condition)
    .bind(&severity)
    .bind(enabled)
    .bind(&notify_email)
    .bind(&notify_webhook)
    .bind(id)
    .execute(&state.db)
    .await;
    match updated {
        Err(e) => Json(ApiResponse::internal_error("update alert rule", e)),
        Ok(_) => Json(ApiResponse::ok(serde_json::json!({"updated": true}))),
    }
}
pub async fn delete_rule(
State(state): State<AppState>,
Path(id): Path<i64>,
) -> Json<ApiResponse<serde_json::Value>> {
let result = sqlx::query("DELETE FROM alert_rules WHERE id = ?")
.bind(id)
.execute(&state.db)
.await;
match result {
Ok(r) => {
if r.rows_affected() > 0 {
Json(ApiResponse::ok(serde_json::json!({"deleted": true})))
} else {
Json(ApiResponse::error("Rule not found"))
}
}
Err(e) => Json(ApiResponse::internal_error("delete alert rule", e)),
}
}
pub async fn handle_record(
State(state): State<AppState>,
Path(id): Path<i64>,
claims: axum::Extension<Claims>,
) -> Json<ApiResponse<serde_json::Value>> {
let handled_by = &claims.username;
let result = sqlx::query(
"UPDATE alert_records SET handled = 1, handled_by = ?, handled_at = datetime('now') WHERE id = ?"
)
.bind(handled_by)
.bind(id)
.execute(&state.db)
.await;
match result {
Ok(r) => {
if r.rows_affected() > 0 {
Json(ApiResponse::ok(serde_json::json!({"handled": true})))
} else {
Json(ApiResponse::error("Alert record not found"))
}
}
Err(e) => Json(ApiResponse::internal_error("handle alert record", e)),
}
}

View File

@@ -0,0 +1,143 @@
use axum::{extract::{State, Query}, Json};
use serde::Deserialize;
use sqlx::Row;
use crate::AppState;
use super::{ApiResponse, Pagination};
/// Query-string filters for the hardware/software asset listings.
#[derive(Debug, Deserialize)]
pub struct AssetListParams {
    pub device_uid: Option<String>, // exact match on device UID
    pub search: Option<String>,     // substring match (columns differ per endpoint)
    pub page: Option<u32>,          // 1-based page number, defaults to 1
    pub page_size: Option<u32>,     // handler caps this at 100
}
/// List hardware asset records, newest report first.
///
/// `search` matches as a substring against cpu_model or gpu_model. Optional
/// filters bind twice (`? IS NULL OR col = ?`) so SQLite skips them when NULL.
pub async fn list_hardware(
    State(state): State<AppState>,
    Query(params): Query<AssetListParams>,
) -> Json<ApiResponse<serde_json::Value>> {
    // Clamp page size; saturating math so a huge attacker-supplied `page`
    // cannot overflow the u32 multiply (panic in debug builds).
    let limit = params.page_size.unwrap_or(20).min(100);
    let offset = params.page.unwrap_or(1).saturating_sub(1).saturating_mul(limit);
    // Normalize empty strings to None
    let device_uid = params.device_uid.as_deref().filter(|s| !s.is_empty()).map(String::from);
    let search = params.search.as_deref().filter(|s| !s.is_empty()).map(String::from);
    let rows = sqlx::query(
        "SELECT id, device_uid, cpu_model, cpu_cores, memory_total_mb, disk_model, disk_total_mb,
        gpu_model, motherboard, serial_number, reported_at
        FROM hardware_assets WHERE 1=1
        AND (? IS NULL OR device_uid = ?)
        AND (? IS NULL OR cpu_model LIKE '%' || ? || '%' OR gpu_model LIKE '%' || ? || '%')
        ORDER BY reported_at DESC LIMIT ? OFFSET ?"
    )
    .bind(&device_uid).bind(&device_uid)
    .bind(&search).bind(&search).bind(&search)
    .bind(limit).bind(offset)
    .fetch_all(&state.db)
    .await;
    match rows {
        Ok(records) => {
            let items: Vec<serde_json::Value> = records.iter().map(|r| serde_json::json!({
                "id": r.get::<i64, _>("id"),
                "device_uid": r.get::<String, _>("device_uid"),
                "cpu_model": r.get::<String, _>("cpu_model"),
                "cpu_cores": r.get::<i32, _>("cpu_cores"),
                "memory_total_mb": r.get::<i64, _>("memory_total_mb"),
                "disk_model": r.get::<String, _>("disk_model"),
                "disk_total_mb": r.get::<i64, _>("disk_total_mb"),
                "gpu_model": r.get::<Option<String>, _>("gpu_model"),
                "motherboard": r.get::<Option<String>, _>("motherboard"),
                "serial_number": r.get::<Option<String>, _>("serial_number"),
                "reported_at": r.get::<String, _>("reported_at"),
            })).collect();
            Json(ApiResponse::ok(serde_json::json!({
                "hardware": items,
                "page": params.page.unwrap_or(1),
                "page_size": limit,
            })))
        }
        Err(e) => Json(ApiResponse::internal_error("query hardware assets", e)),
    }
}
/// List software asset records, sorted by name.
///
/// `search` matches as a substring against name or publisher. Optional filters
/// bind twice (`? IS NULL OR col = ?`) so SQLite skips them when NULL.
pub async fn list_software(
    State(state): State<AppState>,
    Query(params): Query<AssetListParams>,
) -> Json<ApiResponse<serde_json::Value>> {
    // Clamp page size; saturating math so a huge attacker-supplied `page`
    // cannot overflow the u32 multiply (panic in debug builds).
    let limit = params.page_size.unwrap_or(20).min(100);
    let offset = params.page.unwrap_or(1).saturating_sub(1).saturating_mul(limit);
    // Normalize empty strings to None
    let device_uid = params.device_uid.as_deref().filter(|s| !s.is_empty()).map(String::from);
    let search = params.search.as_deref().filter(|s| !s.is_empty()).map(String::from);
    let rows = sqlx::query(
        "SELECT id, device_uid, name, version, publisher, install_date, install_path
        FROM software_assets WHERE 1=1
        AND (? IS NULL OR device_uid = ?)
        AND (? IS NULL OR name LIKE '%' || ? || '%' OR publisher LIKE '%' || ? || '%')
        ORDER BY name ASC LIMIT ? OFFSET ?"
    )
    .bind(&device_uid).bind(&device_uid)
    .bind(&search).bind(&search).bind(&search)
    .bind(limit).bind(offset)
    .fetch_all(&state.db)
    .await;
    match rows {
        Ok(records) => {
            let items: Vec<serde_json::Value> = records.iter().map(|r| serde_json::json!({
                "id": r.get::<i64, _>("id"),
                "device_uid": r.get::<String, _>("device_uid"),
                "name": r.get::<String, _>("name"),
                "version": r.get::<Option<String>, _>("version"),
                "publisher": r.get::<Option<String>, _>("publisher"),
                "install_date": r.get::<Option<String>, _>("install_date"),
                "install_path": r.get::<Option<String>, _>("install_path"),
            })).collect();
            Json(ApiResponse::ok(serde_json::json!({
                "software": items,
                "page": params.page.unwrap_or(1),
                "page_size": limit,
            })))
        }
        Err(e) => Json(ApiResponse::internal_error("query software assets", e)),
    }
}
/// List asset change events across all devices, newest first.
pub async fn list_changes(
    State(state): State<AppState>,
    Query(page): Query<Pagination>,
) -> Json<ApiResponse<serde_json::Value>> {
    let limit = page.limit();
    let offset = page.offset();
    let fetched = sqlx::query(
        "SELECT id, device_uid, change_type, change_detail, detected_at
        FROM asset_changes ORDER BY detected_at DESC LIMIT ? OFFSET ?"
    )
    .bind(limit)
    .bind(offset)
    .fetch_all(&state.db)
    .await;
    match fetched {
        Err(e) => Json(ApiResponse::internal_error("query asset changes", e)),
        Ok(rows) => {
            let changes: Vec<serde_json::Value> = rows
                .iter()
                .map(|row| serde_json::json!({
                    "id": row.get::<i64, _>("id"),
                    "device_uid": row.get::<String, _>("device_uid"),
                    "change_type": row.get::<String, _>("change_type"),
                    "change_detail": row.get::<String, _>("change_detail"),
                    "detected_at": row.get::<String, _>("detected_at"),
                }))
                .collect();
            Json(ApiResponse::ok(serde_json::json!({
                "changes": changes,
                "page": page.page.unwrap_or(1),
                "page_size": limit,
            })))
        }
    }
}

View File

@@ -0,0 +1,295 @@
use axum::{extract::State, Json, http::StatusCode, extract::Request, middleware::Next, response::Response};
use serde::{Deserialize, Serialize};
use jsonwebtoken::{encode, decode, Header, EncodingKey, DecodingKey, Validation};
use std::sync::Arc;
use std::collections::HashMap;
use std::time::Instant;
use tokio::sync::Mutex;
use crate::AppState;
use super::ApiResponse;
/// JWT claims carried by both access and refresh tokens.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Claims {
    pub sub: i64,           // user id (JWT subject)
    pub username: String,
    pub role: String,       // e.g. "admin" (checked by require_admin)
    pub exp: u64,           // expiry, seconds since Unix epoch
    pub iat: u64,           // issued-at, seconds since Unix epoch
    pub token_type: String, // "access" or "refresh"
    /// Random family ID for refresh token rotation detection
    pub family: String,
}
/// Credentials supplied to `POST /api/auth/login`.
#[derive(Debug, Deserialize)]
pub struct LoginRequest {
    pub username: String,
    pub password: String,
}

/// Successful login/refresh payload: a fresh token pair plus the user profile.
#[derive(Debug, Serialize)]
pub struct LoginResponse {
    pub access_token: String,
    pub refresh_token: String,
    pub user: UserInfo,
}

/// Public user profile (no password hash).
#[derive(Debug, Serialize, sqlx::FromRow)]
pub struct UserInfo {
    pub id: i64,
    pub username: String,
    pub role: String,
}

/// Body for `POST /api/auth/refresh`.
#[derive(Debug, Deserialize)]
pub struct RefreshRequest {
    pub refresh_token: String,
}
/// In-memory fixed-window rate limiter for login attempts.
///
/// Maps a key (the attempted username) to `(window_start, attempt_count)`.
/// A key becomes limited once it has made 10 attempts inside a 5-minute
/// window; the window resets on the first attempt after it expires.
#[derive(Clone, Default)]
pub struct LoginRateLimiter {
    // std::sync::Mutex instead of tokio::sync::Mutex: the guard is never held
    // across an .await, so the cheaper synchronous lock suffices and avoids
    // an unnecessary async dependency in this type.
    attempts: Arc<std::sync::Mutex<HashMap<String, (Instant, u32)>>>,
}

impl LoginRateLimiter {
    pub fn new() -> Self {
        Self::default()
    }

    /// Records an attempt for `key` and returns true if the request should be
    /// rate-limited (quota for the current window already exhausted).
    pub async fn is_limited(&self, key: &str) -> bool {
        // Recover from poisoning rather than panicking: a panic elsewhere
        // should not permanently wedge all logins.
        let mut attempts = self
            .attempts
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        let now = Instant::now();
        let window = std::time::Duration::from_secs(300); // 5-minute window
        let max_attempts = 10u32;
        if let Some((first_attempt, count)) = attempts.get_mut(key) {
            if now.duration_since(*first_attempt) > window {
                // Window expired — start a fresh one for this key.
                *first_attempt = now;
                *count = 1;
                false
            } else if *count >= max_attempts {
                true // Rate limited
            } else {
                *count += 1;
                false
            }
        } else {
            attempts.insert(key.to_string(), (now, 1));
            // Opportunistic cleanup so the map cannot grow without bound.
            if attempts.len() > 1000 {
                // checked_sub: `now - window` can panic if the monotonic clock
                // is younger than the window (the original used bare `-`).
                if let Some(cutoff) = now.checked_sub(window) {
                    attempts.retain(|_, (t, _)| *t > cutoff);
                }
            }
            false
        }
    }
}
pub async fn login(
State(state): State<AppState>,
Json(req): Json<LoginRequest>,
) -> Result<(StatusCode, Json<ApiResponse<LoginResponse>>), StatusCode> {
// Rate limit check
if state.login_limiter.is_limited(&req.username).await {
return Ok((StatusCode::TOO_MANY_REQUESTS, Json(ApiResponse::error("Too many login attempts. Try again later."))));
}
let user: Option<UserInfo> = sqlx::query_as::<_, UserInfo>(
"SELECT id, username, role FROM users WHERE username = ?"
)
.bind(&req.username)
.fetch_optional(&state.db)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
let user = match user {
Some(u) => u,
None => return Ok((StatusCode::UNAUTHORIZED, Json(ApiResponse::error("Invalid credentials")))),
};
let hash: String = sqlx::query_scalar::<_, String>(
"SELECT password FROM users WHERE id = ?"
)
.bind(user.id)
.fetch_one(&state.db)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
if !bcrypt::verify(&req.password, &hash).unwrap_or(false) {
return Ok((StatusCode::UNAUTHORIZED, Json(ApiResponse::error("Invalid credentials"))));
}
let now = chrono::Utc::now().timestamp() as u64;
let family = uuid::Uuid::new_v4().to_string();
let access_token = create_token(&user, "access", state.config.auth.access_token_ttl_secs, now, &state.config.auth.jwt_secret, &family)?;
let refresh_token = create_token(&user, "refresh", state.config.auth.refresh_token_ttl_secs, now, &state.config.auth.jwt_secret, &family)?;
// Audit log
let _ = sqlx::query(
"INSERT INTO admin_audit_log (user_id, action, detail) VALUES (?, 'login', ?)"
)
.bind(user.id)
.bind(format!("User {} logged in", user.username))
.execute(&state.db)
.await;
Ok((StatusCode::OK, Json(ApiResponse::ok(LoginResponse {
access_token,
refresh_token,
user,
}))))
}
/// Exchange a valid refresh token for a new access/refresh token pair.
///
/// Implements refresh-token rotation: every successful refresh issues tokens
/// under a brand-new `family` id and marks the presented token's family as
/// revoked. Presenting a token from an already-revoked family is treated as
/// token theft (the legitimate holder already rotated it).
pub async fn refresh(
    State(state): State<AppState>,
    Json(req): Json<RefreshRequest>,
) -> Result<(StatusCode, Json<ApiResponse<LoginResponse>>), StatusCode> {
    // Signature + expiry check (jsonwebtoken's default validation includes exp).
    let claims = decode::<Claims>(
        &req.refresh_token,
        &DecodingKey::from_secret(state.config.auth.jwt_secret.as_bytes()),
        &Validation::default(),
    )
    .map_err(|_| StatusCode::UNAUTHORIZED)?;
    // An access token must never be usable as a refresh token.
    if claims.claims.token_type != "refresh" {
        return Ok((StatusCode::UNAUTHORIZED, Json(ApiResponse::error("Invalid token type"))));
    }
    // Check if this refresh token family has been revoked (reuse detection).
    // NOTE(review): `unwrap_or(0)` means a DB error fails OPEN (treated as not
    // revoked) — consider failing closed here.
    let revoked: bool = sqlx::query_scalar::<_, i64>(
        "SELECT COUNT(*) FROM revoked_token_families WHERE family = ?"
    )
    .bind(&claims.claims.family)
    .fetch_one(&state.db)
    .await
    .unwrap_or(0) > 0;
    if revoked {
        // Token reuse detected — revoke entire family and force re-login
        tracing::warn!("Refresh token reuse detected for user {} family {}", claims.claims.sub, claims.claims.family);
        // NOTE(review): this clears rows from `refresh_tokens`, but nothing in
        // this handler ever inserts into that table — presumably populated
        // elsewhere; confirm it is actually used.
        let _ = sqlx::query("DELETE FROM refresh_tokens WHERE user_id = ?")
            .bind(claims.claims.sub)
            .execute(&state.db)
            .await;
        return Ok((StatusCode::UNAUTHORIZED, Json(ApiResponse::error("Token reuse detected. Please log in again."))));
    }
    // User identity is taken from the (verified) token claims — no DB lookup.
    let user = UserInfo {
        id: claims.claims.sub,
        username: claims.claims.username,
        role: claims.claims.role,
    };
    // Rotate: new family for each refresh
    let new_family = uuid::Uuid::new_v4().to_string();
    let now = chrono::Utc::now().timestamp() as u64;
    let access_token = create_token(&user, "access", state.config.auth.access_token_ttl_secs, now, &state.config.auth.jwt_secret, &new_family)?;
    let refresh_token = create_token(&user, "refresh", state.config.auth.refresh_token_ttl_secs, now, &state.config.auth.jwt_secret, &new_family)?;
    // Revoke old family so the presented token cannot be replayed.
    let _ = sqlx::query("INSERT OR IGNORE INTO revoked_token_families (family, user_id, revoked_at) VALUES (?, ?, datetime('now'))")
        .bind(&claims.claims.family)
        .bind(claims.claims.sub)
        .execute(&state.db)
        .await;
    Ok((StatusCode::OK, Json(ApiResponse::ok(LoginResponse {
        access_token,
        refresh_token,
        user,
    }))))
}
/// Build and sign a JWT of the given `token_type` for `user`.
///
/// `now` and `ttl` are in seconds since the Unix epoch; `family` links the
/// access/refresh pair for rotation tracking.
fn create_token(user: &UserInfo, token_type: &str, ttl: u64, now: u64, secret: &str, family: &str) -> Result<String, StatusCode> {
    let payload = Claims {
        sub: user.id,
        username: user.username.clone(),
        role: user.role.clone(),
        iat: now,
        exp: now + ttl,
        token_type: token_type.to_string(),
        family: family.to_string(),
    };
    let key = EncodingKey::from_secret(secret.as_bytes());
    // Header::default() signs with HS256.
    encode(&Header::default(), &payload, &key).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
}
/// Axum middleware: require a valid JWT access token.
///
/// On success the decoded [`Claims`] are inserted into the request extensions
/// so downstream handlers (and `require_admin`) can read them.
pub async fn require_auth(
    State(state): State<AppState>,
    mut request: Request,
    next: Next,
) -> Result<Response, StatusCode> {
    // Extract "Bearer <token>" from the Authorization header.
    let Some(token) = request
        .headers()
        .get("Authorization")
        .and_then(|v| v.to_str().ok())
        .and_then(|v| v.strip_prefix("Bearer "))
    else {
        return Err(StatusCode::UNAUTHORIZED);
    };
    let decoded = decode::<Claims>(
        token,
        &DecodingKey::from_secret(state.config.auth.jwt_secret.as_bytes()),
        &Validation::default(),
    )
    .map_err(|_| StatusCode::UNAUTHORIZED)?;
    // Refresh tokens must not grant API access.
    if decoded.claims.token_type != "access" {
        return Err(StatusCode::UNAUTHORIZED);
    }
    // Make the claims available to downstream handlers.
    request.extensions_mut().insert(decoded.claims);
    Ok(next.run(request).await)
}
/// Axum middleware: require the admin role for write operations, and record
/// successful admin actions in the audit log.
///
/// Must be layered AFTER `require_auth`, which injects [`Claims`] into the
/// request extensions.
pub async fn require_admin(
    State(state): State<AppState>,
    request: Request,
    next: Next,
) -> Result<Response, StatusCode> {
    let Some(claims) = request.extensions().get::<Claims>() else {
        return Err(StatusCode::UNAUTHORIZED);
    };
    if claims.role != "admin" {
        return Err(StatusCode::FORBIDDEN);
    }
    // Snapshot audit fields before the request is consumed by the handler.
    let user_id = claims.sub;
    let username = claims.username.clone();
    let method = request.method().clone();
    let path = request.uri().path().to_string();
    let response = next.run(request).await;
    // Record only successful admin actions; a logging failure is ignored so it
    // never blocks the response.
    if response.status().is_success() {
        let _ = sqlx::query(
            "INSERT INTO admin_audit_log (user_id, action, detail) VALUES (?, ?, ?)"
        )
        .bind(user_id)
        .bind(format!("{} {}", method, path))
        .bind(format!("by {}", username))
        .execute(&state.db)
        .await;
    }
    Ok(response)
}

View File

@@ -0,0 +1,263 @@
use axum::{extract::{State, Path, Query}, Json};
use serde::{Deserialize, Serialize};
use sqlx::Row;
use crate::AppState;
use super::{ApiResponse, Pagination};
/// Query-string filters for the device listing endpoint.
#[derive(Debug, Deserialize)]
pub struct DeviceListParams {
    pub status: Option<String>, // exact match, e.g. "online"/"offline"
    pub group: Option<String>,  // exact match on group_name
    pub search: Option<String>, // substring match on hostname or ip_address
    pub page: Option<u32>,      // 1-based page number, defaults to 1
    pub page_size: Option<u32>, // handler caps this at 100
}
/// One row from the `devices` table, serialized directly into API responses.
#[derive(Debug, Serialize, sqlx::FromRow)]
pub struct DeviceRow {
    pub id: i64,
    pub device_uid: String,           // stable unique device identifier
    pub hostname: String,
    pub ip_address: String,
    pub mac_address: Option<String>,
    pub os_version: Option<String>,
    pub client_version: Option<String>,
    pub status: String,               // "online" / "offline" (set by heartbeat + cleanup task)
    pub last_heartbeat: Option<String>,
    pub registered_at: Option<String>,
    pub group_name: Option<String>,
}
/// List devices with optional status/group/search filters, newest first.
///
/// Optional filters bind twice (`? IS NULL OR col = ?`) so SQLite skips a
/// predicate entirely when its parameter is NULL.
pub async fn list(
    State(state): State<AppState>,
    Query(params): Query<DeviceListParams>,
) -> Json<ApiResponse<serde_json::Value>> {
    // Clamp page size to 100. Saturating arithmetic: `page` is attacker
    // controlled, and the original unchecked `* limit` u32 multiply would
    // panic on overflow in debug builds for a huge `page` value.
    let limit = params.page_size.unwrap_or(20).min(100);
    let offset = params.page.unwrap_or(1).saturating_sub(1).saturating_mul(limit);
    // Normalize empty strings to None (Axum deserializes `status=` as Some(""))
    let status = params.status.as_deref().filter(|s| !s.is_empty()).map(String::from);
    let group = params.group.as_deref().filter(|s| !s.is_empty()).map(String::from);
    let search = params.search.as_deref().filter(|s| !s.is_empty()).map(String::from);
    let devices = sqlx::query_as::<_, DeviceRow>(
        "SELECT id, device_uid, hostname, ip_address, mac_address, os_version, client_version,
        status, last_heartbeat, registered_at, group_name
        FROM devices WHERE 1=1
        AND (? IS NULL OR status = ?)
        AND (? IS NULL OR group_name = ?)
        AND (? IS NULL OR hostname LIKE '%' || ? || '%' OR ip_address LIKE '%' || ? || '%')
        ORDER BY registered_at DESC LIMIT ? OFFSET ?"
    )
    .bind(&status).bind(&status)
    .bind(&group).bind(&group)
    .bind(&search).bind(&search).bind(&search)
    .bind(limit).bind(offset)
    .fetch_all(&state.db)
    .await;
    // Total count with the same filters; a count failure degrades to total=0
    // rather than failing the whole request.
    let total: i64 = sqlx::query_scalar(
        "SELECT COUNT(*) FROM devices WHERE 1=1
        AND (? IS NULL OR status = ?)
        AND (? IS NULL OR group_name = ?)
        AND (? IS NULL OR hostname LIKE '%' || ? || '%' OR ip_address LIKE '%' || ? || '%')"
    )
    .bind(&status).bind(&status)
    .bind(&group).bind(&group)
    .bind(&search).bind(&search).bind(&search)
    .fetch_one(&state.db)
    .await
    .unwrap_or(0);
    match devices {
        Ok(rows) => Json(ApiResponse::ok(serde_json::json!({
            "devices": rows,
            "total": total,
            "page": params.page.unwrap_or(1),
            "page_size": limit,
        }))),
        Err(e) => Json(ApiResponse::internal_error("query devices", e)),
    }
}
/// Fetch a single device record by its UID.
pub async fn get_detail(
    State(state): State<AppState>,
    Path(uid): Path<String>,
) -> Json<ApiResponse<serde_json::Value>> {
    let lookup = sqlx::query_as::<_, DeviceRow>(
        "SELECT id, device_uid, hostname, ip_address, mac_address, os_version, client_version,
        status, last_heartbeat, registered_at, group_name
        FROM devices WHERE device_uid = ?"
    )
    .bind(&uid)
    .fetch_optional(&state.db)
    .await;
    match lookup {
        Err(e) => Json(ApiResponse::internal_error("query devices", e)),
        Ok(None) => Json(ApiResponse::error("Device not found")),
        Ok(Some(device)) => Json(ApiResponse::ok(serde_json::to_value(device).unwrap_or_default())),
    }
}
/// Latest status snapshot for a device, from the `device_status` table.
#[derive(Debug, Serialize, sqlx::FromRow)]
struct StatusRow {
    pub cpu_usage: f64,
    pub memory_usage: f64,
    pub memory_total_mb: i64,
    pub disk_usage: f64,
    pub disk_total_mb: i64,
    pub network_rx_rate: i64,
    pub network_tx_rate: i64,
    pub running_procs: i32,
    pub top_processes: Option<String>, // JSON array stored as a string; re-parsed by get_status
    pub reported_at: String,
}
/// Return a device's latest status snapshot.
///
/// `top_processes` is stored as a JSON string; it is parsed back into a JSON
/// array before the response is sent (left as-is if parsing fails).
pub async fn get_status(
    State(state): State<AppState>,
    Path(uid): Path<String>,
) -> Json<ApiResponse<serde_json::Value>> {
    let status = sqlx::query_as::<_, StatusRow>(
        "SELECT cpu_usage, memory_usage, memory_total_mb, disk_usage, disk_total_mb,
        network_rx_rate, network_tx_rate, running_procs, top_processes, reported_at
        FROM device_status WHERE device_uid = ?"
    )
    .bind(&uid)
    .fetch_optional(&state.db)
    .await;
    match status {
        Err(e) => Json(ApiResponse::internal_error("query devices", e)),
        Ok(None) => Json(ApiResponse::error("No status data found")),
        Ok(Some(row)) => {
            let mut payload = serde_json::to_value(&row).unwrap_or_default();
            // Replace the stringified top_processes with the parsed array.
            let parsed = row
                .top_processes
                .as_deref()
                .and_then(|raw| serde_json::from_str::<serde_json::Value>(raw).ok());
            if let Some(tp) = parsed {
                payload["top_processes"] = tp;
            }
            Json(ApiResponse::ok(payload))
        }
    }
}
/// Return a device's paginated status history, newest first.
pub async fn get_history(
    State(state): State<AppState>,
    Path(uid): Path<String>,
    Query(page): Query<Pagination>,
) -> Json<ApiResponse<serde_json::Value>> {
    let limit = page.limit();
    let offset = page.offset();
    let fetched = sqlx::query(
        "SELECT cpu_usage, memory_usage, disk_usage, running_procs, reported_at
        FROM device_status_history WHERE device_uid = ?
        ORDER BY reported_at DESC LIMIT ? OFFSET ?"
    )
    .bind(&uid)
    .bind(limit)
    .bind(offset)
    .fetch_all(&state.db)
    .await;
    match fetched {
        Err(e) => Json(ApiResponse::internal_error("query devices", e)),
        Ok(rows) => {
            let history: Vec<serde_json::Value> = rows
                .iter()
                .map(|row| serde_json::json!({
                    "cpu_usage": row.get::<f64, _>("cpu_usage"),
                    "memory_usage": row.get::<f64, _>("memory_usage"),
                    "disk_usage": row.get::<f64, _>("disk_usage"),
                    "running_procs": row.get::<i32, _>("running_procs"),
                    "reported_at": row.get::<String, _>("reported_at"),
                }))
                .collect();
            Json(ApiResponse::ok(serde_json::json!({
                "history": history,
                "page": page.page.unwrap_or(1),
                "page_size": limit,
            })))
        }
    }
}
/// Delete a device and all of its associated data, then instruct a connected
/// client to self-destruct.
///
/// All deletions run in one transaction. The self-destruct command is sent
/// only AFTER the transaction commits: the original sent it before touching
/// the database, so a failed deletion could destroy a client whose
/// server-side record still existed.
pub async fn remove(
    State(state): State<AppState>,
    Path(uid): Path<String>,
) -> Json<ApiResponse<()>> {
    // Delete device and all associated data in a transaction.
    let mut tx = match state.db.begin().await {
        Ok(tx) => tx,
        Err(e) => return Json(ApiResponse::internal_error("begin transaction", e)),
    };
    // Delete status history (on error, dropping `tx` rolls everything back).
    if let Err(e) = sqlx::query("DELETE FROM device_status_history WHERE device_uid = ?")
        .bind(&uid)
        .execute(&mut *tx)
        .await
    {
        return Json(ApiResponse::internal_error("remove device history", e));
    }
    // Delete current status
    if let Err(e) = sqlx::query("DELETE FROM device_status WHERE device_uid = ?")
        .bind(&uid)
        .execute(&mut *tx)
        .await
    {
        return Json(ApiResponse::internal_error("remove device status", e));
    }
    // Plugin-owned tables: best-effort cleanup — failures are logged but do
    // not abort the deletion.
    let cleanup_tables = [
        "hardware_assets",
        "usb_events",
        "usb_file_operations",
        "usage_daily",
        "app_usage_daily",
        "software_violations",
        "web_access_log",
        "popup_block_stats",
    ];
    for table in &cleanup_tables {
        // Table names come from the fixed list above, never from user input,
        // so this format! is not an injection risk.
        if let Err(e) = sqlx::query(&format!("DELETE FROM {} WHERE device_uid = ?", table))
            .bind(&uid)
            .execute(&mut *tx)
            .await
        {
            tracing::warn!("Failed to clean {} for device {}: {}", table, uid, e);
        }
    }
    // Finally delete the device itself
    let delete_result = sqlx::query("DELETE FROM devices WHERE device_uid = ?")
        .bind(&uid)
        .execute(&mut *tx)
        .await;
    match delete_result {
        Ok(r) if r.rows_affected() > 0 => {
            if let Err(e) = tx.commit().await {
                return Json(ApiResponse::internal_error("commit device deletion", e));
            }
            // Deletion is durable — now (and only now) send the self-destruct
            // command, while the client connection is still registered.
            if let Ok(frame) = csm_protocol::Frame::new_json(
                csm_protocol::MessageType::ConfigUpdate,
                &serde_json::json!({"type": "SelfDestruct"}),
            ) {
                state.clients.send_to(&uid, frame.encode()).await;
            }
            state.clients.unregister(&uid).await;
            tracing::info!(device_uid = %uid, "Device and all associated data deleted");
            Json(ApiResponse::ok(()))
        }
        Ok(_) => Json(ApiResponse::error("Device not found")),
        Err(e) => Json(ApiResponse::internal_error("remove device", e)),
    }
}

View File

@@ -0,0 +1,120 @@
use axum::{routing::{get, post, put, delete}, Router, Json, extract::State, middleware};
use serde::{Deserialize, Serialize};
use crate::AppState;
pub mod auth;
pub mod devices;
pub mod assets;
pub mod usb;
pub mod alerts;
pub mod plugins;
/// Assemble the full HTTP router from four layers:
/// public (no auth), read-only (any authenticated user),
/// admin-only write routes, and the WebSocket endpoint.
///
/// `state` is cloned into each sub-router and into the auth middleware.
pub fn routes(state: AppState) -> Router<AppState> {
    // Public endpoints: login/refresh and the liveness probe.
    let public = Router::new()
        .route("/api/auth/login", post(auth::login))
        .route("/api/auth/refresh", post(auth::refresh))
        .route("/health", get(health_check))
        .with_state(state.clone());
    // Read-only routes (any authenticated user)
    let read_routes = Router::new()
        // Devices
        .route("/api/devices", get(devices::list))
        .route("/api/devices/:uid", get(devices::get_detail))
        .route("/api/devices/:uid/status", get(devices::get_status))
        .route("/api/devices/:uid/history", get(devices::get_history))
        // Assets
        .route("/api/assets/hardware", get(assets::list_hardware))
        .route("/api/assets/software", get(assets::list_software))
        .route("/api/assets/changes", get(assets::list_changes))
        // USB (read)
        .route("/api/usb/events", get(usb::list_events))
        .route("/api/usb/policies", get(usb::list_policies))
        // Alerts (read)
        .route("/api/alerts/rules", get(alerts::list_rules))
        .route("/api/alerts/records", get(alerts::list_records))
        // Plugin read routes
        .merge(plugins::read_routes())
        .layer(middleware::from_fn_with_state(state.clone(), auth::require_auth));
    // Write routes (admin only)
    let write_routes = Router::new()
        // Devices
        .route("/api/devices/:uid", delete(devices::remove))
        // USB (write)
        .route("/api/usb/policies", post(usb::create_policy))
        .route("/api/usb/policies/:id", put(usb::update_policy).delete(usb::delete_policy))
        // Alerts (write)
        .route("/api/alerts/rules", post(alerts::create_rule))
        .route("/api/alerts/rules/:id", put(alerts::update_rule).delete(alerts::delete_rule))
        .route("/api/alerts/records/:id/handle", put(alerts::handle_record))
        // Plugin write routes (already has require_admin layer internally)
        .merge(plugins::write_routes())
        // Layer order: outer (require_admin) runs AFTER inner (require_auth)
        // so require_auth sets Claims extension first, then require_admin checks it
        .layer(middleware::from_fn_with_state(state.clone(), auth::require_admin))
        .layer(middleware::from_fn_with_state(state.clone(), auth::require_auth));
    // WebSocket has its own JWT auth via query parameter
    let ws_router = Router::new()
        .route("/ws", get(crate::ws::ws_handler))
        .with_state(state.clone());
    // Merge all sub-routers into the final application router.
    Router::new()
        .merge(public)
        .merge(read_routes)
        .merge(write_routes)
        .merge(ws_router)
}
/// Payload returned by the `/health` liveness endpoint.
#[derive(Serialize)]
struct HealthResponse {
    // Always "ok" while the process is serving (see `health_check`).
    status: &'static str,
}
/// Liveness probe handler: reports `{"status": "ok"}` unconditionally.
async fn health_check() -> Json<HealthResponse> {
    Json(HealthResponse { status: "ok" })
}
/// Standard API response envelope
///
/// Every JSON endpoint wraps its payload in this structure so clients
/// can check `success` uniformly before reading `data` or `error`.
#[derive(Serialize)]
pub struct ApiResponse<T: Serialize> {
    // True on success; false when `error` is populated.
    pub success: bool,
    // Payload on success; `None` on failure.
    pub data: Option<T>,
    // Client-visible message on failure; `None` on success.
    pub error: Option<String>,
}
impl<T: Serialize> ApiResponse<T> {
    /// Successful response wrapping `data`.
    pub fn ok(data: T) -> Self {
        Self {
            success: true,
            data: Some(data),
            error: None,
        }
    }

    /// Failed response carrying a client-visible message.
    pub fn error(msg: impl Into<String>) -> Self {
        Self {
            success: false,
            data: None,
            error: Some(msg.into()),
        }
    }

    /// Log internal error and return sanitized message to client
    pub fn internal_error(context: &str, e: impl std::fmt::Display) -> Self {
        tracing::error!("{}: {}", context, e);
        Self::error("Internal server error")
    }
}
/// Pagination parameters
///
/// Deserialized from query-string values; both fields are optional
/// (see `offset`/`limit` for the defaults applied).
#[derive(Debug, Deserialize)]
pub struct Pagination {
    // 1-based page number; defaults to 1.
    pub page: Option<u32>,
    // Rows per page; defaults to 20, capped at 100.
    pub page_size: Option<u32>,
}
impl Pagination {
    /// Row offset for the current page (pages are 1-based; `page` of 0
    /// or 1 both map to offset 0).
    ///
    /// Uses saturating arithmetic throughout: a query string like
    /// `?page=4294967295` must clamp instead of overflowing the `u32`
    /// multiply (panic in debug builds, silent wrap in release).
    pub fn offset(&self) -> u32 {
        self.page
            .unwrap_or(1)
            .saturating_sub(1)
            .saturating_mul(self.limit())
    }

    /// Page size, defaulting to 20 and capped at 100 rows per page.
    pub fn limit(&self) -> u32 {
        self.page_size.unwrap_or(20).min(100)
    }
}

View File

@@ -0,0 +1,49 @@
pub mod web_filter;
pub mod usage_timer;
pub mod software_blocker;
pub mod popup_blocker;
pub mod usb_file_audit;
pub mod watermark;
use axum::{Router, routing::{get, post, put}};
use crate::AppState;
/// Read-only plugin routes (accessible by admin + viewer)
///
/// Returned un-layered: the caller (`crate::api::routes`) attaches the
/// `require_auth` middleware after merging.
pub fn read_routes() -> Router<AppState> {
    Router::new()
        // Web Filter
        .route("/api/plugins/web-filter/rules", get(web_filter::list_rules))
        .route("/api/plugins/web-filter/log", get(web_filter::list_access_log))
        // Usage Timer
        .route("/api/plugins/usage-timer/daily", get(usage_timer::list_daily))
        .route("/api/plugins/usage-timer/app-usage", get(usage_timer::list_app_usage))
        .route("/api/plugins/usage-timer/leaderboard", get(usage_timer::leaderboard))
        // Software Blocker
        .route("/api/plugins/software-blocker/blacklist", get(software_blocker::list_blacklist))
        .route("/api/plugins/software-blocker/violations", get(software_blocker::list_violations))
        // Popup Blocker
        .route("/api/plugins/popup-blocker/rules", get(popup_blocker::list_rules))
        .route("/api/plugins/popup-blocker/stats", get(popup_blocker::list_stats))
        // USB File Audit
        .route("/api/plugins/usb-file-audit/log", get(usb_file_audit::list_operations))
        .route("/api/plugins/usb-file-audit/summary", get(usb_file_audit::summary))
        // Watermark
        .route("/api/plugins/watermark/config", get(watermark::get_config_list))
}
/// Write plugin routes (admin only — require_admin middleware applied by caller)
///
/// Returned un-layered; `crate::api::routes` wraps these in both
/// `require_auth` and `require_admin`.
pub fn write_routes() -> Router<AppState> {
    Router::new()
        // Web Filter
        .route("/api/plugins/web-filter/rules", post(web_filter::create_rule))
        .route("/api/plugins/web-filter/rules/:id", put(web_filter::update_rule).delete(web_filter::delete_rule))
        // Software Blocker
        .route("/api/plugins/software-blocker/blacklist", post(software_blocker::add_to_blacklist))
        .route("/api/plugins/software-blocker/blacklist/:id", put(software_blocker::update_blacklist).delete(software_blocker::remove_from_blacklist))
        // Popup Blocker
        .route("/api/plugins/popup-blocker/rules", post(popup_blocker::create_rule))
        .route("/api/plugins/popup-blocker/rules/:id", put(popup_blocker::update_rule).delete(popup_blocker::delete_rule))
        // Watermark
        .route("/api/plugins/watermark/config", post(watermark::create_config))
        .route("/api/plugins/watermark/config/:id", put(watermark::update_config).delete(watermark::delete_config))
}

View File

@@ -0,0 +1,155 @@
use axum::{extract::{State, Path, Json}, http::StatusCode};
use serde::Deserialize;
use sqlx::Row;
use csm_protocol::MessageType;
use crate::AppState;
use crate::api::ApiResponse;
use crate::tcp::push_to_targets;
/// Body for creating a popup filter rule.
///
/// At least one of `window_title` / `window_class` / `process_name`
/// must be non-empty (enforced in `create_rule`).
#[derive(Debug, Deserialize)]
pub struct CreateRuleRequest {
    pub rule_type: String, // "block" | "allow"
    pub window_title: Option<String>,
    pub window_class: Option<String>,
    pub process_name: Option<String>,
    pub target_type: Option<String>, // "global" | "device" | "group"; defaults to "global"
    pub target_id: Option<String>,   // presumably a device uid or group id — confirm against schema
}
/// GET handler: return every popup filter rule, newest first.
pub async fn list_rules(State(state): State<AppState>) -> Json<ApiResponse<serde_json::Value>> {
    let fetched = sqlx::query("SELECT id, rule_type, window_title, window_class, process_name, target_type, target_id, enabled, created_at FROM popup_filter_rules ORDER BY created_at DESC")
        .fetch_all(&state.db)
        .await;
    match fetched {
        Ok(rows) => {
            let rules: Vec<serde_json::Value> = rows
                .iter()
                .map(|r| serde_json::json!({
                    "id": r.get::<i64, _>("id"),
                    "rule_type": r.get::<String, _>("rule_type"),
                    "window_title": r.get::<Option<String>, _>("window_title"),
                    "window_class": r.get::<Option<String>, _>("window_class"),
                    "process_name": r.get::<Option<String>, _>("process_name"),
                    "target_type": r.get::<String, _>("target_type"),
                    "target_id": r.get::<Option<String>, _>("target_id"),
                    "enabled": r.get::<bool, _>("enabled"),
                    "created_at": r.get::<String, _>("created_at"),
                }))
                .collect();
            Json(ApiResponse::ok(serde_json::json!({"rules": rules})))
        }
        Err(e) => Json(ApiResponse::internal_error("query popup filter rules", e)),
    }
}
/// POST handler: validate and insert a popup filter rule, then push the
/// refreshed rule set to affected online clients.
pub async fn create_rule(State(state): State<AppState>, Json(req): Json<CreateRuleRequest>) -> (StatusCode, Json<ApiResponse<serde_json::Value>>) {
    // Scope defaults to "global" when the caller omits target_type.
    let target_type = req.target_type.unwrap_or_else(|| "global".to_string());
    // Validate inputs
    if !matches!(req.rule_type.as_str(), "block" | "allow") {
        return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("rule_type must be 'block' or 'allow'")));
    }
    if !matches!(target_type.as_str(), "global" | "device" | "group") {
        return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("invalid target_type")));
    }
    // A rule with no non-empty filter would match everything; reject it.
    let has_filter = req.window_title.as_ref().map_or(false, |s| !s.is_empty())
        || req.window_class.as_ref().map_or(false, |s| !s.is_empty())
        || req.process_name.as_ref().map_or(false, |s| !s.is_empty());
    if !has_filter {
        return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("at least one filter (window_title/window_class/process_name) required")));
    }
    match sqlx::query("INSERT INTO popup_filter_rules (rule_type, window_title, window_class, process_name, target_type, target_id) VALUES (?,?,?,?,?,?)")
        .bind(&req.rule_type).bind(&req.window_title).bind(&req.window_class).bind(&req.process_name).bind(&target_type).bind(&req.target_id)
        .execute(&state.db).await {
        Ok(r) => {
            let new_id = r.last_insert_rowid();
            // Recompute the effective rule set and push it so connected
            // clients pick up the change without reconnecting.
            let rules = fetch_popup_rules_for_push(&state.db, &target_type, req.target_id.as_deref()).await;
            push_to_targets(&state.db, &state.clients, MessageType::PopupRules, &serde_json::json!({"rules": rules}), &target_type, req.target_id.as_deref()).await;
            (StatusCode::CREATED, Json(ApiResponse::ok(serde_json::json!({"id": new_id}))))
        }
        Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::internal_error("create popup filter rule", e))),
    }
}
/// Body for updating a popup filter rule; omitted fields keep their
/// stored values (note: a field cannot be cleared back to NULL — see
/// the `or_else` merge in `update_rule`).
#[derive(Debug, Deserialize)]
pub struct UpdateRuleRequest { pub window_title: Option<String>, pub window_class: Option<String>, pub process_name: Option<String>, pub enabled: Option<bool> }
/// PUT handler: merge the request over the stored rule, update it, and
/// push the refreshed rule set to the rule's original targets.
pub async fn update_rule(State(state): State<AppState>, Path(id): Path<i64>, Json(body): Json<UpdateRuleRequest>) -> Json<ApiResponse<()>> {
    // Load the current row so omitted request fields keep stored values.
    let existing = sqlx::query("SELECT * FROM popup_filter_rules WHERE id = ?")
        .bind(id)
        .fetch_optional(&state.db)
        .await;
    let existing = match existing {
        Ok(Some(row)) => row,
        Ok(None) => return Json(ApiResponse::error("Not found")),
        Err(e) => return Json(ApiResponse::internal_error("query popup filter rule", e)),
    };
    // Request values win; fall back to stored values. This means a field
    // can be overwritten but never cleared to NULL via this endpoint.
    let window_title = body.window_title.or_else(|| existing.get::<Option<String>, _>("window_title"));
    let window_class = body.window_class.or_else(|| existing.get::<Option<String>, _>("window_class"));
    let process_name = body.process_name.or_else(|| existing.get::<Option<String>, _>("process_name"));
    let enabled = body.enabled.unwrap_or_else(|| existing.get::<bool, _>("enabled"));
    let result = sqlx::query("UPDATE popup_filter_rules SET window_title = ?, window_class = ?, process_name = ?, enabled = ? WHERE id = ?")
        .bind(&window_title)
        .bind(&window_class)
        .bind(&process_name)
        .bind(enabled)
        .bind(id)
        .execute(&state.db)
        .await;
    match result {
        Ok(r) if r.rows_affected() > 0 => {
            // Push using the rule's stored scope (target_type/target_id
            // are not updatable through this endpoint).
            let target_type_val: String = existing.get("target_type");
            let target_id_val: Option<String> = existing.get("target_id");
            let rules = fetch_popup_rules_for_push(&state.db, &target_type_val, target_id_val.as_deref()).await;
            push_to_targets(&state.db, &state.clients, MessageType::PopupRules, &serde_json::json!({"rules": rules}), &target_type_val, target_id_val.as_deref()).await;
            Json(ApiResponse::ok(()))
        }
        Ok(_) => Json(ApiResponse::error("Not found")),
        Err(e) => Json(ApiResponse::internal_error("update popup filter rule", e)),
    }
}
/// DELETE handler: remove a rule and push the remaining effective rule
/// set to the deleted rule's former targets.
pub async fn delete_rule(State(state): State<AppState>, Path(id): Path<i64>) -> Json<ApiResponse<()>> {
    // Capture the rule's scope before deleting so we know whom to notify.
    let existing = sqlx::query("SELECT target_type, target_id FROM popup_filter_rules WHERE id = ?")
        .bind(id).fetch_optional(&state.db).await;
    let (target_type, target_id) = match existing {
        Ok(Some(row)) => (row.get::<String, _>("target_type"), row.get::<Option<String>, _>("target_id")),
        // NOTE: a DB error here is also reported as "Not found".
        _ => return Json(ApiResponse::error("Not found")),
    };
    match sqlx::query("DELETE FROM popup_filter_rules WHERE id=?").bind(id).execute(&state.db).await {
        Ok(r) if r.rows_affected() > 0 => {
            let rules = fetch_popup_rules_for_push(&state.db, &target_type, target_id.as_deref()).await;
            push_to_targets(&state.db, &state.clients, MessageType::PopupRules, &serde_json::json!({"rules": rules}), &target_type, target_id.as_deref()).await;
            Json(ApiResponse::ok(()))
        }
        _ => Json(ApiResponse::error("Not found")),
    }
}
/// GET handler: the 30 most recent daily popup-block counters.
pub async fn list_stats(State(state): State<AppState>) -> Json<ApiResponse<serde_json::Value>> {
    let fetched = sqlx::query("SELECT device_uid, blocked_count, date FROM popup_block_stats ORDER BY date DESC LIMIT 30")
        .fetch_all(&state.db)
        .await;
    match fetched {
        Ok(rows) => {
            let stats: Vec<serde_json::Value> = rows
                .iter()
                .map(|r| serde_json::json!({
                    "device_uid": r.get::<String, _>("device_uid"),
                    "blocked_count": r.get::<i32, _>("blocked_count"),
                    "date": r.get::<String, _>("date"),
                }))
                .collect();
            Json(ApiResponse::ok(serde_json::json!({"stats": stats})))
        }
        Err(e) => Json(ApiResponse::internal_error("query popup block stats", e)),
    }
}
/// Load the enabled popup rules that apply to the given target, shaped
/// as JSON for pushing to clients.
///
/// "device" targets receive global rules plus that device's own rules;
/// any other target type falls through to global rules only.
/// NOTE(review): "group" targets therefore get only global rules here —
/// confirm whether group-scoped rules are expanded elsewhere.
async fn fetch_popup_rules_for_push(
    db: &sqlx::SqlitePool,
    target_type: &str,
    target_id: Option<&str>,
) -> Vec<serde_json::Value> {
    let query = match target_type {
        "device" => sqlx::query(
            "SELECT id, rule_type, window_title, window_class, process_name FROM popup_filter_rules WHERE enabled = 1 AND (target_type = 'global' OR (target_type = 'device' AND target_id = ?))"
        ).bind(target_id),
        _ => sqlx::query(
            "SELECT id, rule_type, window_title, window_class, process_name FROM popup_filter_rules WHERE enabled = 1 AND target_type = 'global'"
        ),
    };
    // A DB error degrades to an empty rule set instead of failing the caller.
    query.fetch_all(db).await
        .map(|rows| rows.iter().map(|r| serde_json::json!({
            "id": r.get::<i64,_>("id"), "rule_type": r.get::<String,_>("rule_type"),
            "window_title": r.get::<Option<String>,_>("window_title"),
            "window_class": r.get::<Option<String>,_>("window_class"),
            "process_name": r.get::<Option<String>,_>("process_name"),
        })).collect())
        .unwrap_or_default()
}

View File

@@ -0,0 +1,155 @@
use axum::{extract::{State, Path, Query, Json}, http::StatusCode};
use serde::Deserialize;
use sqlx::Row;
use csm_protocol::MessageType;
use crate::AppState;
use crate::api::ApiResponse;
use crate::tcp::push_to_targets;
/// Body for adding a software blacklist entry.
///
/// `name_pattern` is required (1-255 bytes, enforced in
/// `add_to_blacklist`); the remaining fields default in the handler.
#[derive(Debug, Deserialize)]
pub struct CreateBlacklistRequest {
    pub name_pattern: String,
    pub category: Option<String>, // "game" | "social" | "vpn" | "mining" | "custom"
    pub action: Option<String>,   // "block" | "alert"; defaults to "block"
    pub target_type: Option<String>, // "global" | "device" | "group"; defaults to "global"
    pub target_id: Option<String>,
}
/// GET handler: return every software blacklist entry, newest first.
pub async fn list_blacklist(State(state): State<AppState>) -> Json<ApiResponse<serde_json::Value>> {
    let fetched = sqlx::query("SELECT id, name_pattern, category, action, target_type, target_id, enabled, created_at FROM software_blacklist ORDER BY created_at DESC")
        .fetch_all(&state.db)
        .await;
    match fetched {
        Err(e) => Json(ApiResponse::internal_error("query software blacklist", e)),
        Ok(rows) => {
            let entries: Vec<serde_json::Value> = rows
                .iter()
                .map(|r| serde_json::json!({
                    "id": r.get::<i64, _>("id"),
                    "name_pattern": r.get::<String, _>("name_pattern"),
                    "category": r.get::<Option<String>, _>("category"),
                    "action": r.get::<String, _>("action"),
                    "target_type": r.get::<String, _>("target_type"),
                    "target_id": r.get::<Option<String>, _>("target_id"),
                    "enabled": r.get::<bool, _>("enabled"),
                    "created_at": r.get::<String, _>("created_at"),
                }))
                .collect();
            Json(ApiResponse::ok(serde_json::json!({"blacklist": entries})))
        }
    }
}
/// POST handler: validate and insert a blacklist entry, then push the
/// refreshed blacklist to affected online clients.
pub async fn add_to_blacklist(State(state): State<AppState>, Json(req): Json<CreateBlacklistRequest>) -> (StatusCode, Json<ApiResponse<serde_json::Value>>) {
    // Defaults: "block" action, "global" scope.
    let action = req.action.unwrap_or_else(|| "block".to_string());
    let target_type = req.target_type.unwrap_or_else(|| "global".to_string());
    // Validate inputs
    // NOTE: len() counts bytes, so the 255 limit is a byte budget.
    if req.name_pattern.trim().is_empty() || req.name_pattern.len() > 255 {
        return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("name_pattern must be 1-255 chars")));
    }
    if !matches!(action.as_str(), "block" | "alert") {
        return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("action must be 'block' or 'alert'")));
    }
    if let Some(ref cat) = req.category {
        if !matches!(cat.as_str(), "game" | "social" | "vpn" | "mining" | "custom") {
            return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("invalid category")));
        }
    }
    if !matches!(target_type.as_str(), "global" | "device" | "group") {
        return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("invalid target_type")));
    }
    match sqlx::query("INSERT INTO software_blacklist (name_pattern, category, action, target_type, target_id) VALUES (?,?,?,?,?)")
        .bind(&req.name_pattern).bind(&req.category).bind(&action).bind(&target_type).bind(&req.target_id)
        .execute(&state.db).await {
        Ok(r) => {
            let new_id = r.last_insert_rowid();
            // Recompute the effective blacklist and push it so connected
            // clients pick up the change without reconnecting.
            let blacklist = fetch_blacklist_for_push(&state.db, &target_type, req.target_id.as_deref()).await;
            push_to_targets(&state.db, &state.clients, MessageType::SoftwareBlacklist, &serde_json::json!({"blacklist": blacklist}), &target_type, req.target_id.as_deref()).await;
            (StatusCode::CREATED, Json(ApiResponse::ok(serde_json::json!({"id": new_id}))))
        }
        Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::internal_error("add software blacklist entry", e))),
    }
}
/// Body for updating a blacklist entry; omitted fields keep their
/// stored values.
#[derive(Debug, Deserialize)]
pub struct UpdateBlacklistRequest { pub name_pattern: Option<String>, pub action: Option<String>, pub enabled: Option<bool> }
/// PUT handler: merge the request over the stored entry, update it, and
/// push the refreshed blacklist to the entry's original targets.
///
/// NOTE(review): unlike `add_to_blacklist`, no validation is re-applied
/// here (pattern length, action enum) — confirm whether that is intended.
pub async fn update_blacklist(State(state): State<AppState>, Path(id): Path<i64>, Json(body): Json<UpdateBlacklistRequest>) -> Json<ApiResponse<()>> {
    // Load the current row so omitted request fields keep stored values.
    let existing = sqlx::query("SELECT * FROM software_blacklist WHERE id = ?")
        .bind(id)
        .fetch_optional(&state.db)
        .await;
    let existing = match existing {
        Ok(Some(row)) => row,
        Ok(None) => return Json(ApiResponse::error("Not found")),
        Err(e) => return Json(ApiResponse::internal_error("query software blacklist", e)),
    };
    let name_pattern = body.name_pattern.unwrap_or_else(|| existing.get::<String, _>("name_pattern"));
    let action = body.action.unwrap_or_else(|| existing.get::<String, _>("action"));
    let enabled = body.enabled.unwrap_or_else(|| existing.get::<bool, _>("enabled"));
    let result = sqlx::query("UPDATE software_blacklist SET name_pattern = ?, action = ?, enabled = ? WHERE id = ?")
        .bind(&name_pattern)
        .bind(&action)
        .bind(enabled)
        .bind(id)
        .execute(&state.db)
        .await;
    match result {
        Ok(r) if r.rows_affected() > 0 => {
            // Push using the entry's stored scope (not updatable here).
            let target_type_val: String = existing.get("target_type");
            let target_id_val: Option<String> = existing.get("target_id");
            let blacklist = fetch_blacklist_for_push(&state.db, &target_type_val, target_id_val.as_deref()).await;
            push_to_targets(&state.db, &state.clients, MessageType::SoftwareBlacklist, &serde_json::json!({"blacklist": blacklist}), &target_type_val, target_id_val.as_deref()).await;
            Json(ApiResponse::ok(()))
        }
        Ok(_) => Json(ApiResponse::error("Not found")),
        Err(e) => Json(ApiResponse::internal_error("update software blacklist", e)),
    }
}
/// DELETE handler: remove an entry and push the remaining effective
/// blacklist to the entry's former targets.
pub async fn remove_from_blacklist(State(state): State<AppState>, Path(id): Path<i64>) -> Json<ApiResponse<()>> {
    // Capture the entry's scope before deleting so we know whom to notify.
    let existing = sqlx::query("SELECT target_type, target_id FROM software_blacklist WHERE id = ?")
        .bind(id).fetch_optional(&state.db).await;
    let (target_type, target_id) = match existing {
        Ok(Some(row)) => (row.get::<String, _>("target_type"), row.get::<Option<String>, _>("target_id")),
        // NOTE: a DB error here is also reported as "Not found".
        _ => return Json(ApiResponse::error("Not found")),
    };
    match sqlx::query("DELETE FROM software_blacklist WHERE id=?").bind(id).execute(&state.db).await {
        Ok(r) if r.rows_affected() > 0 => {
            let blacklist = fetch_blacklist_for_push(&state.db, &target_type, target_id.as_deref()).await;
            push_to_targets(&state.db, &state.clients, MessageType::SoftwareBlacklist, &serde_json::json!({"blacklist": blacklist}), &target_type, target_id.as_deref()).await;
            Json(ApiResponse::ok(()))
        }
        _ => Json(ApiResponse::error("Not found")),
    }
}
/// Query filter for the violations list; `device_uid` narrows the
/// result to one device when present.
#[derive(Debug, Deserialize)]
pub struct ViolationFilters { pub device_uid: Option<String> }
/// GET handler: the 200 most recent software violations, optionally
/// filtered by device. The "(? IS NULL OR device_uid=?)" pattern binds
/// the same optional value twice on purpose.
pub async fn list_violations(State(state): State<AppState>, Query(f): Query<ViolationFilters>) -> Json<ApiResponse<serde_json::Value>> {
    let fetched = sqlx::query("SELECT id, device_uid, software_name, action_taken, timestamp FROM software_violations WHERE (? IS NULL OR device_uid=?) ORDER BY timestamp DESC LIMIT 200")
        .bind(&f.device_uid)
        .bind(&f.device_uid)
        .fetch_all(&state.db)
        .await;
    match fetched {
        Ok(rows) => {
            let violations: Vec<serde_json::Value> = rows
                .iter()
                .map(|r| serde_json::json!({
                    "id": r.get::<i64, _>("id"),
                    "device_uid": r.get::<String, _>("device_uid"),
                    "software_name": r.get::<String, _>("software_name"),
                    "action_taken": r.get::<String, _>("action_taken"),
                    "timestamp": r.get::<String, _>("timestamp"),
                }))
                .collect();
            Json(ApiResponse::ok(serde_json::json!({"violations": violations})))
        }
        Err(e) => Json(ApiResponse::internal_error("query software violations", e)),
    }
}
/// Load the enabled blacklist that applies to the given target, shaped
/// as JSON for pushing to clients.
///
/// "device" targets receive global entries plus that device's own
/// entries; any other target type falls through to global entries only.
/// NOTE(review): "group" targets therefore get only global entries here —
/// confirm whether group expansion happens elsewhere.
async fn fetch_blacklist_for_push(
    db: &sqlx::SqlitePool,
    target_type: &str,
    target_id: Option<&str>,
) -> Vec<serde_json::Value> {
    let query = match target_type {
        "device" => sqlx::query(
            "SELECT id, name_pattern, action FROM software_blacklist WHERE enabled = 1 AND (target_type = 'global' OR (target_type = 'device' AND target_id = ?))"
        ).bind(target_id),
        _ => sqlx::query(
            "SELECT id, name_pattern, action FROM software_blacklist WHERE enabled = 1 AND target_type = 'global'"
        ),
    };
    // A DB error degrades to an empty list instead of failing the caller.
    query.fetch_all(db).await
        .map(|rows| rows.iter().map(|r| serde_json::json!({
            "id": r.get::<i64,_>("id"), "name_pattern": r.get::<String,_>("name_pattern"), "action": r.get::<String,_>("action")
        })).collect())
        .unwrap_or_default()
}

View File

@@ -0,0 +1,60 @@
use axum::{extract::{State, Query}, Json};
use serde::Deserialize;
use sqlx::Row;
use crate::AppState;
use crate::api::ApiResponse;
/// Query filters for daily usage rows; all optional. Dates are compared
/// lexically against the `usage_daily.date` column in SQL.
#[derive(Debug, Deserialize)]
pub struct DailyFilters { pub device_uid: Option<String>, pub start_date: Option<String>, pub end_date: Option<String> }
/// GET handler: up to 90 daily usage rows, newest first, optionally
/// filtered by device and an inclusive date range. Each optional filter
/// is bound twice for the "(? IS NULL OR col=?)" SQL pattern.
pub async fn list_daily(State(state): State<AppState>, Query(f): Query<DailyFilters>) -> Json<ApiResponse<serde_json::Value>> {
    match sqlx::query(
        "SELECT id, device_uid, date, total_active_minutes, total_idle_minutes, first_active_at, last_active_at
         FROM usage_daily WHERE (? IS NULL OR device_uid=?) AND (? IS NULL OR date>=?) AND (? IS NULL OR date<=?)
         ORDER BY date DESC LIMIT 90")
        .bind(&f.device_uid).bind(&f.device_uid)
        .bind(&f.start_date).bind(&f.start_date)
        .bind(&f.end_date).bind(&f.end_date)
        .fetch_all(&state.db).await {
        Ok(rows) => Json(ApiResponse::ok(serde_json::json!({"daily": rows.iter().map(|r| serde_json::json!({
            "id": r.get::<i64,_>("id"), "device_uid": r.get::<String,_>("device_uid"),
            "date": r.get::<String,_>("date"), "total_active_minutes": r.get::<i32,_>("total_active_minutes"),
            "total_idle_minutes": r.get::<i32,_>("total_idle_minutes"),
            "first_active_at": r.get::<Option<String>,_>("first_active_at"),
            "last_active_at": r.get::<Option<String>,_>("last_active_at")
        })).collect::<Vec<_>>()}))),
        Err(e) => Json(ApiResponse::internal_error("query daily usage", e)),
    }
}
/// Query filters for per-app usage; both optional. `date` selects one
/// exact day when present.
#[derive(Debug, Deserialize)]
pub struct AppUsageFilters { pub device_uid: Option<String>, pub date: Option<String> }
/// GET handler: up to 100 per-app usage rows ordered by minutes used,
/// optionally filtered by device and/or an exact date.
pub async fn list_app_usage(State(state): State<AppState>, Query(f): Query<AppUsageFilters>) -> Json<ApiResponse<serde_json::Value>> {
    match sqlx::query(
        "SELECT id, device_uid, date, app_name, usage_minutes FROM app_usage_daily
         WHERE (? IS NULL OR device_uid=?) AND (? IS NULL OR date=?)
         ORDER BY usage_minutes DESC LIMIT 100")
        .bind(&f.device_uid).bind(&f.device_uid)
        .bind(&f.date).bind(&f.date)
        .fetch_all(&state.db).await {
        Ok(rows) => Json(ApiResponse::ok(serde_json::json!({"app_usage": rows.iter().map(|r| serde_json::json!({
            "id": r.get::<i64,_>("id"), "device_uid": r.get::<String,_>("device_uid"),
            "date": r.get::<String,_>("date"), "app_name": r.get::<String,_>("app_name"),
            "usage_minutes": r.get::<i32,_>("usage_minutes")
        })).collect::<Vec<_>>()}))),
        Err(e) => Json(ApiResponse::internal_error("query app usage", e)),
    }
}
/// GET handler: top 20 devices by total active minutes over the
/// trailing 7 days.
pub async fn leaderboard(State(state): State<AppState>) -> Json<ApiResponse<serde_json::Value>> {
    let fetched = sqlx::query(
        "SELECT device_uid, SUM(total_active_minutes) as total_minutes FROM usage_daily
         WHERE date >= date('now', '-7 days') GROUP BY device_uid ORDER BY total_minutes DESC LIMIT 20")
        .fetch_all(&state.db)
        .await;
    match fetched {
        Ok(rows) => {
            let board: Vec<serde_json::Value> = rows
                .iter()
                .map(|r| serde_json::json!({
                    "device_uid": r.get::<String, _>("device_uid"),
                    "total_minutes": r.get::<i64, _>("total_minutes"),
                }))
                .collect();
            Json(ApiResponse::ok(serde_json::json!({"leaderboard": board})))
        }
        Err(e) => Json(ApiResponse::internal_error("query usage leaderboard", e)),
    }
}

View File

@@ -0,0 +1,47 @@
use axum::{extract::{State, Query}, Json};
use serde::Deserialize;
use sqlx::Row;
use crate::AppState;
use crate::api::ApiResponse;
/// Query filters for the USB file-operation log; all optional, each
/// narrowing the result when present.
#[derive(Debug, Deserialize)]
pub struct LogFilters {
    pub device_uid: Option<String>,
    pub operation: Option<String>,
    pub usb_serial: Option<String>,
}
/// GET handler: the 200 most recent USB file operations, optionally
/// filtered by device, operation type, and/or USB serial number. Each
/// optional filter is bound twice for the "(? IS NULL OR col=?)" pattern.
pub async fn list_operations(State(state): State<AppState>, Query(f): Query<LogFilters>) -> Json<ApiResponse<serde_json::Value>> {
    match sqlx::query(
        "SELECT id, device_uid, usb_serial, drive_letter, operation, file_path, file_size, timestamp
         FROM usb_file_operations WHERE (? IS NULL OR device_uid=?) AND (? IS NULL OR operation=?) AND (? IS NULL OR usb_serial=?)
         ORDER BY timestamp DESC LIMIT 200")
        .bind(&f.device_uid).bind(&f.device_uid)
        .bind(&f.operation).bind(&f.operation)
        .bind(&f.usb_serial).bind(&f.usb_serial)
        .fetch_all(&state.db).await {
        Ok(rows) => Json(ApiResponse::ok(serde_json::json!({"operations": rows.iter().map(|r| serde_json::json!({
            "id": r.get::<i64,_>("id"), "device_uid": r.get::<String,_>("device_uid"),
            "usb_serial": r.get::<Option<String>,_>("usb_serial"), "drive_letter": r.get::<Option<String>,_>("drive_letter"),
            "operation": r.get::<String,_>("operation"), "file_path": r.get::<String,_>("file_path"),
            "file_size": r.get::<Option<i64>,_>("file_size"), "timestamp": r.get::<String,_>("timestamp")
        })).collect::<Vec<_>>()}))),
        Err(e) => Json(ApiResponse::internal_error("query USB file operations", e)),
    }
}
/// GET handler: per-device USB activity summary over the trailing
/// 7 days (operation count, distinct USB devices, first/last activity),
/// for up to 50 busiest devices.
pub async fn summary(State(state): State<AppState>) -> Json<ApiResponse<serde_json::Value>> {
    match sqlx::query(
        "SELECT device_uid, COUNT(*) as op_count, COUNT(DISTINCT usb_serial) as usb_count,
         MIN(timestamp) as first_op, MAX(timestamp) as last_op
         FROM usb_file_operations WHERE timestamp >= datetime('now', '-7 days')
         GROUP BY device_uid ORDER BY op_count DESC LIMIT 50")
        .fetch_all(&state.db).await {
        Ok(rows) => Json(ApiResponse::ok(serde_json::json!({"summary": rows.iter().map(|r| serde_json::json!({
            "device_uid": r.get::<String,_>("device_uid"), "op_count": r.get::<i64,_>("op_count"),
            "usb_count": r.get::<i64,_>("usb_count"), "first_op": r.get::<String,_>("first_op"),
            "last_op": r.get::<String,_>("last_op")
        })).collect::<Vec<_>>()}))),
        Err(e) => Json(ApiResponse::internal_error("query USB file audit summary", e)),
    }
}

View File

@@ -0,0 +1,186 @@
use axum::{extract::{State, Path, Json}, http::StatusCode};
use serde::Deserialize;
use sqlx::Row;
use csm_protocol::{MessageType, WatermarkConfigPayload};
use crate::AppState;
use crate::api::ApiResponse;
use crate::tcp::push_to_targets;
/// Body for creating a watermark configuration. Every field is
/// optional; defaults are applied in `create_config`.
#[derive(Debug, Deserialize)]
pub struct CreateConfigRequest {
    pub target_type: Option<String>, // "global" | "device" | "group"; defaults to "global"
    pub target_id: Option<String>,
    pub content: Option<String>,   // template text; defaults to "{company} | {username} | {date}"
    pub font_size: Option<u32>,    // clamped to 8..=72; defaults to 14
    pub opacity: Option<f64>,      // clamped to 0.01..=1.0; defaults to 0.15
    pub color: Option<String>,     // "#RRGGBB"; defaults to "#808080"
    pub angle: Option<i32>,        // degrees; defaults to -30
    pub enabled: Option<bool>,     // defaults to true
}
/// GET handler: return every watermark configuration ordered by id.
pub async fn get_config_list(State(state): State<AppState>) -> Json<ApiResponse<serde_json::Value>> {
    match sqlx::query("SELECT id, target_type, target_id, content, font_size, opacity, color, angle, enabled, updated_at FROM watermark_config ORDER BY id")
        .fetch_all(&state.db).await {
        Ok(rows) => Json(ApiResponse::ok(serde_json::json!({"configs": rows.iter().map(|r| serde_json::json!({
            "id": r.get::<i64,_>("id"), "target_type": r.get::<String,_>("target_type"),
            "target_id": r.get::<Option<String>,_>("target_id"), "content": r.get::<String,_>("content"),
            "font_size": r.get::<i32,_>("font_size"), "opacity": r.get::<f64,_>("opacity"),
            "color": r.get::<String,_>("color"), "angle": r.get::<i32,_>("angle"),
            "enabled": r.get::<bool,_>("enabled"), "updated_at": r.get::<String,_>("updated_at")
        })).collect::<Vec<_>>()}))),
        Err(e) => Json(ApiResponse::internal_error("query watermark configs", e)),
    }
}
/// POST handler: validate and insert a watermark configuration, then
/// push it to affected online clients.
pub async fn create_config(
    State(state): State<AppState>,
    Json(req): Json<CreateConfigRequest>,
) -> (StatusCode, Json<ApiResponse<serde_json::Value>>) {
    // Apply defaults for omitted fields; clamp numeric inputs to sane ranges.
    let target_type = req.target_type.unwrap_or_else(|| "global".to_string());
    let content = req.content.unwrap_or_else(|| "{company} | {username} | {date}".to_string());
    let font_size = req.font_size.unwrap_or(14).clamp(8, 72) as i32;
    let opacity = req.opacity.unwrap_or(0.15).clamp(0.01, 1.0);
    let color = req.color.unwrap_or_else(|| "#808080".to_string());
    let angle = req.angle.unwrap_or(-30);
    let enabled = req.enabled.unwrap_or(true);
    // Validate inputs
    // Count Unicode scalar values rather than bytes so multi-byte (e.g.
    // CJK) content gets the full advertised 200-character budget.
    if content.chars().count() > 200 {
        return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("content too long (max 200 chars)")));
    }
    if !is_valid_hex_color(&color) {
        return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("invalid color format (expected #RRGGBB)")));
    }
    if !matches!(target_type.as_str(), "global" | "device" | "group") {
        return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("invalid target_type")));
    }
    match sqlx::query("INSERT INTO watermark_config (target_type, target_id, content, font_size, opacity, color, angle, enabled) VALUES (?,?,?,?,?,?,?,?)")
        .bind(&target_type).bind(&req.target_id).bind(&content).bind(font_size).bind(opacity).bind(&color).bind(angle).bind(enabled)
        .execute(&state.db).await {
        Ok(r) => {
            let new_id = r.last_insert_rowid();
            // Push to online clients
            let config = WatermarkConfigPayload {
                content: content.clone(),
                font_size: font_size as u32,
                opacity,
                color: color.clone(),
                angle,
                enabled,
            };
            push_to_targets(&state.db, &state.clients, MessageType::WatermarkConfig, &config, &target_type, req.target_id.as_deref()).await;
            (StatusCode::CREATED, Json(ApiResponse::ok(serde_json::json!({"id": new_id}))))
        }
        Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::internal_error("create watermark config", e))),
    }
}
/// Body for updating a watermark configuration; omitted fields keep
/// their stored values (see `update_config` for clamping/validation).
#[derive(Debug, Deserialize)]
pub struct UpdateConfigRequest {
    pub content: Option<String>, pub font_size: Option<u32>, pub opacity: Option<f64>,
    pub color: Option<String>, pub angle: Option<i32>, pub enabled: Option<bool>,
}
/// PUT handler: merge the request over the stored configuration,
/// validate, update it, and push the result to affected online clients.
pub async fn update_config(
    State(state): State<AppState>,
    Path(id): Path<i64>,
    Json(body): Json<UpdateConfigRequest>,
) -> Json<ApiResponse<()>> {
    // Load the current row so omitted request fields keep stored values.
    let existing = sqlx::query("SELECT * FROM watermark_config WHERE id = ?")
        .bind(id)
        .fetch_optional(&state.db)
        .await;
    let existing = match existing {
        Ok(Some(row)) => row,
        Ok(None) => return Json(ApiResponse::error("Not found")),
        Err(e) => return Json(ApiResponse::internal_error("query watermark config", e)),
    };
    // Merge request values over stored ones, clamping numeric inputs.
    let content = body.content.unwrap_or_else(|| existing.get::<String, _>("content"));
    let font_size = body.font_size.map(|v| v.clamp(8, 72) as i32).unwrap_or_else(|| existing.get::<i32, _>("font_size"));
    let opacity = body.opacity.map(|v| v.clamp(0.01, 1.0)).unwrap_or_else(|| existing.get::<f64, _>("opacity"));
    let color = body.color.unwrap_or_else(|| existing.get::<String, _>("color"));
    let angle = body.angle.unwrap_or_else(|| existing.get::<i32, _>("angle"));
    let enabled = body.enabled.unwrap_or_else(|| existing.get::<bool, _>("enabled"));
    // Validate inputs
    // Count Unicode scalar values rather than bytes so multi-byte (e.g.
    // CJK) content gets the full advertised 200-character budget.
    if content.chars().count() > 200 {
        return Json(ApiResponse::error("content too long (max 200 chars)"));
    }
    if !is_valid_hex_color(&color) {
        return Json(ApiResponse::error("invalid color format (expected #RRGGBB)"));
    }
    let result = sqlx::query(
        "UPDATE watermark_config SET content = ?, font_size = ?, opacity = ?, color = ?, angle = ?, enabled = ?, updated_at = datetime('now') WHERE id = ?"
    )
    .bind(&content)
    .bind(font_size)
    .bind(opacity)
    .bind(&color)
    .bind(angle)
    .bind(enabled)
    .bind(id)
    .execute(&state.db)
    .await;
    match result {
        Ok(r) if r.rows_affected() > 0 => {
            // Push updated config to online clients so the change takes
            // effect without waiting for a reconnect.
            let config = WatermarkConfigPayload {
                content: content.clone(),
                font_size: font_size as u32,
                opacity,
                color: color.clone(),
                angle,
                enabled,
            };
            let target_type_val: String = existing.get("target_type");
            let target_id_val: Option<String> = existing.get("target_id");
            push_to_targets(&state.db, &state.clients, MessageType::WatermarkConfig, &config, &target_type_val, target_id_val.as_deref()).await;
            Json(ApiResponse::ok(()))
        }
        Ok(_) => Json(ApiResponse::error("Not found")),
        Err(e) => Json(ApiResponse::internal_error("update watermark config", e)),
    }
}
/// DELETE handler: remove a watermark configuration and push a disabled
/// watermark to the configuration's former targets.
pub async fn delete_config(State(state): State<AppState>, Path(id): Path<i64>) -> Json<ApiResponse<()>> {
    // Fetch existing config to get target info for push
    let existing = sqlx::query("SELECT target_type, target_id FROM watermark_config WHERE id = ?")
        .bind(id)
        .fetch_optional(&state.db)
        .await;
    let (target_type, target_id) = match existing {
        Ok(Some(row)) => (row.get::<String, _>("target_type"), row.get::<Option<String>, _>("target_id")),
        // NOTE: a DB error here is also reported as "Not found".
        _ => return Json(ApiResponse::error("Not found")),
    };
    match sqlx::query("DELETE FROM watermark_config WHERE id=?").bind(id).execute(&state.db).await {
        Ok(r) if r.rows_affected() > 0 => {
            // Push disabled watermark to clients
            // NOTE(review): this disables the watermark for those targets
            // even if another (e.g. global) config still applies to them —
            // confirm whether a re-resolve should happen instead.
            let disabled = WatermarkConfigPayload {
                content: String::new(),
                font_size: 0,
                opacity: 0.0,
                color: String::new(),
                angle: 0,
                enabled: false,
            };
            push_to_targets(&state.db, &state.clients, MessageType::WatermarkConfig, &disabled, &target_type, target_id.as_deref()).await;
            Json(ApiResponse::ok(()))
        }
        _ => Json(ApiResponse::error("Not found")),
    }
}
/// Validate a hex color string: exactly a '#' followed by six ASCII hex
/// digits ("#RRGGBB"). Anything else — wrong length, missing '#', non-hex
/// characters — is rejected.
fn is_valid_hex_color(color: &str) -> bool {
    match color.strip_prefix('#') {
        Some(hex) => hex.len() == 6 && hex.bytes().all(|b| b.is_ascii_hexdigit()),
        None => false,
    }
}

View File

@@ -0,0 +1,156 @@
use axum::{extract::{State, Path, Query, Json}, http::StatusCode};
use serde::Deserialize;
use sqlx::Row;
use csm_protocol::MessageType;
use crate::AppState;
use crate::api::ApiResponse;
use crate::tcp::push_to_targets;
/// Query-string filters for web-filter rule listing.
/// NOTE(review): `list_rules` below does not currently apply these filters —
/// confirm whether filtered listing was intended.
#[derive(Debug, Deserialize)]
pub struct RuleFilters { pub rule_type: Option<String>, pub enabled: Option<bool> }
/// Request body for creating a web-filter rule.
#[derive(Debug, Deserialize)]
pub struct CreateRuleRequest {
    pub rule_type: String,           // validated: "blacklist" | "whitelist" | "category"
    pub pattern: String,             // validated: 1-255 chars, non-blank
    pub target_type: Option<String>, // "global" (default) | "device" | "group"
    pub target_id: Option<String>,   // device_uid or group name when scoped
    pub enabled: Option<bool>,       // defaults to true
}
/// List every web-filter rule, newest first, as `{ "rules": [...] }`.
/// NOTE(review): the `RuleFilters` extractor defined above is not used here —
/// confirm whether query-string filtering was intended.
pub async fn list_rules(State(state): State<AppState>) -> Json<ApiResponse<serde_json::Value>> {
    match sqlx::query("SELECT id, rule_type, pattern, target_type, target_id, enabled, created_at FROM web_filter_rules ORDER BY created_at DESC")
        .fetch_all(&state.db).await {
        Ok(rows) => Json(ApiResponse::ok(serde_json::json!({ "rules": rows.iter().map(|r| serde_json::json!({
            "id": r.get::<i64,_>("id"), "rule_type": r.get::<String,_>("rule_type"),
            "pattern": r.get::<String,_>("pattern"), "target_type": r.get::<String,_>("target_type"),
            "target_id": r.get::<Option<String>,_>("target_id"), "enabled": r.get::<bool,_>("enabled"),
            "created_at": r.get::<String,_>("created_at")
        })).collect::<Vec<_>>() }))),
        Err(e) => Json(ApiResponse::internal_error("query web filter rules", e)),
    }
}
/// Create a web-filter rule and immediately push the recomputed effective rule
/// set to online clients in the rule's target scope.
///
/// Returns 201 with the new row id, 400 on invalid input, 500 on DB failure.
pub async fn create_rule(State(state): State<AppState>, Json(req): Json<CreateRuleRequest>) -> (StatusCode, Json<ApiResponse<serde_json::Value>>) {
    let enabled = req.enabled.unwrap_or(true);
    let target_type = req.target_type.unwrap_or_else(|| "global".to_string());
    // Validate inputs
    if !matches!(req.rule_type.as_str(), "blacklist" | "whitelist" | "category") {
        return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("invalid rule_type (expected blacklist|whitelist|category)")));
    }
    if req.pattern.trim().is_empty() || req.pattern.len() > 255 {
        return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("pattern must be 1-255 chars")));
    }
    if !matches!(target_type.as_str(), "global" | "device" | "group") {
        return (StatusCode::BAD_REQUEST, Json(ApiResponse::error("invalid target_type")));
    }
    match sqlx::query("INSERT INTO web_filter_rules (rule_type, pattern, target_type, target_id, enabled) VALUES (?,?,?,?,?)")
        .bind(&req.rule_type).bind(&req.pattern).bind(&target_type).bind(&req.target_id).bind(enabled)
        .execute(&state.db).await {
        Ok(r) => {
            let new_id = r.last_insert_rowid();
            // Re-derive the full effective rule set for the scope and push it,
            // so clients never apply a partial update.
            let rules = fetch_rules_for_push(&state.db, &target_type, req.target_id.as_deref()).await;
            push_to_targets(&state.db, &state.clients, MessageType::WebFilterRuleUpdate, &serde_json::json!({"rules": rules}), &target_type, req.target_id.as_deref()).await;
            (StatusCode::CREATED, Json(ApiResponse::ok(serde_json::json!({"id": new_id}))))
        }
        Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::internal_error("create web filter rule", e))),
    }
}
/// Partial-update body for a web-filter rule; omitted fields keep stored values.
/// Note the target scope (target_type/target_id) is intentionally not updatable.
#[derive(Debug, Deserialize)]
pub struct UpdateRuleRequest { pub rule_type: Option<String>, pub pattern: Option<String>, pub enabled: Option<bool> }
pub async fn update_rule(State(state): State<AppState>, Path(id): Path<i64>, Json(body): Json<UpdateRuleRequest>) -> Json<ApiResponse<()>> {
let existing = sqlx::query("SELECT * FROM web_filter_rules WHERE id = ?")
.bind(id)
.fetch_optional(&state.db)
.await;
let existing = match existing {
Ok(Some(row)) => row,
Ok(None) => return Json(ApiResponse::error("Not found")),
Err(e) => return Json(ApiResponse::internal_error("query web filter rule", e)),
};
let rule_type = body.rule_type.unwrap_or_else(|| existing.get::<String, _>("rule_type"));
let pattern = body.pattern.unwrap_or_else(|| existing.get::<String, _>("pattern"));
let enabled = body.enabled.unwrap_or_else(|| existing.get::<bool, _>("enabled"));
let result = sqlx::query("UPDATE web_filter_rules SET rule_type = ?, pattern = ?, enabled = ? WHERE id = ?")
.bind(&rule_type)
.bind(&pattern)
.bind(enabled)
.bind(id)
.execute(&state.db)
.await;
match result {
Ok(r) if r.rows_affected() > 0 => {
let target_type_val: String = existing.get("target_type");
let target_id_val: Option<String> = existing.get("target_id");
let rules = fetch_rules_for_push(&state.db, &target_type_val, target_id_val.as_deref()).await;
push_to_targets(&state.db, &state.clients, MessageType::WebFilterRuleUpdate, &serde_json::json!({"rules": rules}), &target_type_val, target_id_val.as_deref()).await;
Json(ApiResponse::ok(()))
}
Ok(_) => Json(ApiResponse::error("Not found")),
Err(e) => Json(ApiResponse::internal_error("update web filter rule", e)),
}
}
pub async fn delete_rule(State(state): State<AppState>, Path(id): Path<i64>) -> Json<ApiResponse<()>> {
let existing = sqlx::query("SELECT target_type, target_id FROM web_filter_rules WHERE id = ?")
.bind(id).fetch_optional(&state.db).await;
let (target_type, target_id) = match existing {
Ok(Some(row)) => (row.get::<String, _>("target_type"), row.get::<Option<String>, _>("target_id")),
_ => return Json(ApiResponse::error("Not found")),
};
match sqlx::query("DELETE FROM web_filter_rules WHERE id=?").bind(id).execute(&state.db).await {
Ok(r) if r.rows_affected() > 0 => {
let rules = fetch_rules_for_push(&state.db, &target_type, target_id.as_deref()).await;
push_to_targets(&state.db, &state.clients, MessageType::WebFilterRuleUpdate, &serde_json::json!({"rules": rules}), &target_type, target_id.as_deref()).await;
Json(ApiResponse::ok(()))
}
_ => Json(ApiResponse::error("Not found")),
}
}
/// Optional filters for the web access log listing.
#[derive(Debug, Deserialize)]
pub struct LogFilters { pub device_uid: Option<String>, pub action: Option<String> }
/// List the most recent web access log entries (hard cap 200), newest first.
/// Each `(? IS NULL OR col=?)` pair makes the corresponding filter optional.
pub async fn list_access_log(State(state): State<AppState>, Query(f): Query<LogFilters>) -> Json<ApiResponse<serde_json::Value>> {
    match sqlx::query("SELECT id, device_uid, url, action, timestamp FROM web_access_log WHERE (? IS NULL OR device_uid=?) AND (? IS NULL OR action=?) ORDER BY timestamp DESC LIMIT 200")
        .bind(&f.device_uid).bind(&f.device_uid).bind(&f.action).bind(&f.action)
        .fetch_all(&state.db).await {
        Ok(rows) => Json(ApiResponse::ok(serde_json::json!({"log": rows.iter().map(|r| serde_json::json!({
            "id": r.get::<i64,_>("id"), "device_uid": r.get::<String,_>("device_uid"),
            "url": r.get::<String,_>("url"), "action": r.get::<String,_>("action"),
            "timestamp": r.get::<String,_>("timestamp")
        })).collect::<Vec<_>>() }))),
        Err(e) => Json(ApiResponse::internal_error("query web access log", e)),
    }
}
/// Fetch enabled web filter rules applicable to a given target scope.
/// For "device" targets, includes both global rules and device-specific rules
/// (matching the logic used during registration push in tcp.rs).
/// A query failure degrades to an empty rule list rather than an error.
async fn fetch_rules_for_push(
    db: &sqlx::SqlitePool,
    target_type: &str,
    target_id: Option<&str>,
) -> Vec<serde_json::Value> {
    let query = if target_type == "device" {
        sqlx::query(
            "SELECT id, rule_type, pattern FROM web_filter_rules WHERE enabled = 1 AND (target_type = 'global' OR (target_type = 'device' AND target_id = ?))"
        ).bind(target_id)
    } else {
        // "global" and "group" scopes fall back to the global rule set.
        sqlx::query(
            "SELECT id, rule_type, pattern FROM web_filter_rules WHERE enabled = 1 AND target_type = 'global'"
        )
    };
    let rows = match query.fetch_all(db).await {
        Ok(rows) => rows,
        Err(_) => return Vec::new(),
    };
    rows.iter()
        .map(|row| serde_json::json!({
            "id": row.get::<i64,_>("id"),
            "rule_type": row.get::<String,_>("rule_type"),
            "pattern": row.get::<String,_>("pattern"),
        }))
        .collect()
}

View File

@@ -0,0 +1,246 @@
use axum::{extract::{State, Path, Query, Json}, http::StatusCode};
use serde::Deserialize;
use sqlx::Row;
use crate::AppState;
use super::ApiResponse;
use crate::tcp::push_to_targets;
use csm_protocol::{MessageType, UsbPolicyPayload, UsbDeviceRule};
/// Query-string parameters for paginated USB event listing.
#[derive(Debug, Deserialize)]
pub struct UsbEventListParams {
    pub device_uid: Option<String>, // filter by device; empty string treated as absent
    pub event_type: Option<String>, // filter by event type; empty string treated as absent
    pub page: Option<u32>,          // 1-based page number (default 1)
    pub page_size: Option<u32>,     // default 20, capped at 100
}
/// List USB events with optional device/event-type filters, newest first.
///
/// `page` is 1-based and `page_size` is capped at 100. Empty-string filters
/// are normalized to None so they match the SQL `? IS NULL` branches.
pub async fn list_events(
    State(state): State<AppState>,
    Query(params): Query<UsbEventListParams>,
) -> Json<ApiResponse<serde_json::Value>> {
    let limit = params.page_size.unwrap_or(20).min(100);
    // saturating_mul: a hostile `page` value must not overflow u32 (plain `*`
    // panics in debug builds); saturate and simply return an empty page.
    let offset = params.page.unwrap_or(1).saturating_sub(1).saturating_mul(limit);
    // Normalize empty strings to None
    let device_uid = params.device_uid.as_deref().filter(|s| !s.is_empty()).map(String::from);
    let event_type = params.event_type.as_deref().filter(|s| !s.is_empty()).map(String::from);
    let rows = sqlx::query(
        "SELECT id, device_uid, vendor_id, product_id, serial_number, device_name, event_type, event_time
         FROM usb_events WHERE 1=1
         AND (? IS NULL OR device_uid = ?)
         AND (? IS NULL OR event_type = ?)
         ORDER BY event_time DESC LIMIT ? OFFSET ?"
    )
    .bind(&device_uid).bind(&device_uid)
    .bind(&event_type).bind(&event_type)
    .bind(limit).bind(offset)
    .fetch_all(&state.db)
    .await;
    match rows {
        Ok(records) => {
            let items: Vec<serde_json::Value> = records.iter().map(|r| serde_json::json!({
                "id": r.get::<i64, _>("id"),
                "device_uid": r.get::<String, _>("device_uid"),
                "vendor_id": r.get::<Option<String>, _>("vendor_id"),
                "product_id": r.get::<Option<String>, _>("product_id"),
                "serial_number": r.get::<Option<String>, _>("serial_number"),
                "device_name": r.get::<Option<String>, _>("device_name"),
                "event_type": r.get::<String, _>("event_type"),
                "event_time": r.get::<String, _>("event_time"),
            })).collect();
            Json(ApiResponse::ok(serde_json::json!({
                "events": items,
                "page": params.page.unwrap_or(1),
                "page_size": limit,
            })))
        }
        Err(e) => Json(ApiResponse::internal_error("query usb events", e)),
    }
}
/// List all USB policies, newest first. The `rules` column is returned as the
/// raw JSON string stored at creation time; `enabled` is the raw 0/1 integer.
pub async fn list_policies(
    State(state): State<AppState>,
) -> Json<ApiResponse<serde_json::Value>> {
    let rows = sqlx::query(
        "SELECT id, name, policy_type, target_group, rules, enabled, created_at, updated_at
         FROM usb_policies ORDER BY created_at DESC"
    )
    .fetch_all(&state.db)
    .await;
    match rows {
        Ok(records) => {
            let items: Vec<serde_json::Value> = records.iter().map(|r| serde_json::json!({
                "id": r.get::<i64, _>("id"),
                "name": r.get::<String, _>("name"),
                "policy_type": r.get::<String, _>("policy_type"),
                "target_group": r.get::<Option<String>, _>("target_group"),
                "rules": r.get::<String, _>("rules"),
                "enabled": r.get::<i32, _>("enabled"),
                "created_at": r.get::<String, _>("created_at"),
                "updated_at": r.get::<String, _>("updated_at"),
            })).collect();
            Json(ApiResponse::ok(serde_json::json!({
                "policies": items,
            })))
        }
        Err(e) => Json(ApiResponse::internal_error("query usb policies", e)),
    }
}
/// Request body for creating a USB policy.
#[derive(Debug, Deserialize)]
pub struct CreatePolicyRequest {
    pub name: String,
    pub policy_type: String,          // semantics defined by the client plugin
    pub target_group: Option<String>, // device group to scope to; None = unscoped
    pub rules: String,                // JSON array of device-rule objects, stored verbatim
    pub enabled: Option<i32>,         // 0/1; defaults to 1
}
pub async fn create_policy(
State(state): State<AppState>,
Json(body): Json<CreatePolicyRequest>,
) -> (StatusCode, Json<ApiResponse<serde_json::Value>>) {
let enabled = body.enabled.unwrap_or(1);
let result = sqlx::query(
"INSERT INTO usb_policies (name, policy_type, target_group, rules, enabled) VALUES (?, ?, ?, ?, ?)"
)
.bind(&body.name)
.bind(&body.policy_type)
.bind(&body.target_group)
.bind(&body.rules)
.bind(enabled)
.execute(&state.db)
.await;
match result {
Ok(r) => {
let new_id = r.last_insert_rowid();
// Push USB policy to matching online clients
if enabled == 1 {
let payload = build_usb_policy_payload(&body.policy_type, true, &body.rules);
let target_group = body.target_group.as_deref();
push_to_targets(&state.db, &state.clients, MessageType::UsbPolicyUpdate, &payload, "group", target_group).await;
}
(StatusCode::CREATED, Json(ApiResponse::ok(serde_json::json!({
"id": new_id,
}))))
}
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::internal_error("create usb policy", e))),
}
}
/// Partial-update body for a USB policy; omitted fields keep stored values.
/// NOTE(review): `target_group` uses `Option::or_else` during merge, so a
/// client cannot clear it back to NULL — confirm that is intended.
#[derive(Debug, Deserialize)]
pub struct UpdatePolicyRequest {
    pub name: Option<String>,
    pub policy_type: Option<String>,
    pub target_group: Option<String>,
    pub rules: Option<String>,
    pub enabled: Option<i32>,
}
pub async fn update_policy(
State(state): State<AppState>,
Path(id): Path<i64>,
Json(body): Json<UpdatePolicyRequest>,
) -> Json<ApiResponse<serde_json::Value>> {
// Fetch existing policy
let existing = sqlx::query("SELECT * FROM usb_policies WHERE id = ?")
.bind(id)
.fetch_optional(&state.db)
.await;
let existing = match existing {
Ok(Some(row)) => row,
Ok(None) => return Json(ApiResponse::error("Policy not found")),
Err(e) => return Json(ApiResponse::internal_error("query usb policy", e)),
};
let name = body.name.unwrap_or_else(|| existing.get::<String, _>("name"));
let policy_type = body.policy_type.unwrap_or_else(|| existing.get::<String, _>("policy_type"));
let target_group = body.target_group.or_else(|| existing.get::<Option<String>, _>("target_group"));
let rules = body.rules.unwrap_or_else(|| existing.get::<String, _>("rules"));
let enabled = body.enabled.unwrap_or_else(|| existing.get::<i32, _>("enabled"));
let result = sqlx::query(
"UPDATE usb_policies SET name = ?, policy_type = ?, target_group = ?, rules = ?, enabled = ?, updated_at = datetime('now') WHERE id = ?"
)
.bind(&name)
.bind(&policy_type)
.bind(&target_group)
.bind(&rules)
.bind(enabled)
.bind(id)
.execute(&state.db)
.await;
match result {
Ok(_) => {
// Push updated USB policy to matching online clients
let payload = build_usb_policy_payload(&policy_type, enabled == 1, &rules);
let target_group = target_group.as_deref();
push_to_targets(&state.db, &state.clients, MessageType::UsbPolicyUpdate, &payload, "group", target_group).await;
Json(ApiResponse::ok(serde_json::json!({"updated": true})))
}
Err(e) => Json(ApiResponse::internal_error("update usb policy", e)),
}
}
pub async fn delete_policy(
State(state): State<AppState>,
Path(id): Path<i64>,
) -> Json<ApiResponse<serde_json::Value>> {
// Fetch existing policy to get target info for push
let existing = sqlx::query("SELECT target_group FROM usb_policies WHERE id = ?")
.bind(id)
.fetch_optional(&state.db)
.await;
let target_group = match existing {
Ok(Some(row)) => row.get::<Option<String>, _>("target_group"),
_ => return Json(ApiResponse::error("Policy not found")),
};
let result = sqlx::query("DELETE FROM usb_policies WHERE id = ?")
.bind(id)
.execute(&state.db)
.await;
match result {
Ok(r) => {
if r.rows_affected() > 0 {
// Push disabled policy to clients
let disabled = UsbPolicyPayload {
policy_type: String::new(),
enabled: false,
rules: vec![],
};
push_to_targets(&state.db, &state.clients, MessageType::UsbPolicyUpdate, &disabled, "group", target_group.as_deref()).await;
Json(ApiResponse::ok(serde_json::json!({"deleted": true})))
} else {
Json(ApiResponse::error("Policy not found"))
}
}
Err(e) => Json(ApiResponse::internal_error("delete usb policy", e)),
}
}
/// Build a UsbPolicyPayload from raw policy fields.
/// `rules_json` is expected to be a JSON array of objects; anything that fails
/// to parse degrades to an empty rule list.
fn build_usb_policy_payload(policy_type: &str, enabled: bool, rules_json: &str) -> UsbPolicyPayload {
    let parsed: Vec<serde_json::Value> = serde_json::from_str(rules_json).unwrap_or_default();
    // Pull an optional string field out of one rule object.
    let field = |obj: &serde_json::Value, key: &str| -> Option<String> {
        obj.get(key).and_then(|v| v.as_str()).map(str::to_string)
    };
    let rules: Vec<UsbDeviceRule> = parsed
        .iter()
        .map(|obj| UsbDeviceRule {
            vendor_id: field(obj, "vendor_id"),
            product_id: field(obj, "product_id"),
            serial: field(obj, "serial"),
            device_name: field(obj, "device_name"),
        })
        .collect();
    UsbPolicyPayload {
        policy_type: policy_type.to_string(),
        enabled,
        rules,
    }
}

View File

@@ -0,0 +1,28 @@
use sqlx::SqlitePool;
use tracing::debug;
/// Record an admin audit log entry.
///
/// Best-effort: a failed insert is logged at WARN and never propagated, so an
/// audit write can never break the calling request handler.
pub async fn audit_log(
    db: &SqlitePool,
    user_id: i64,
    action: &str,
    target_type: Option<&str>,
    target_id: Option<&str>,
    detail: Option<&str>,
) {
    let outcome = sqlx::query(
        "INSERT INTO admin_audit_log (user_id, action, target_type, target_id, detail) VALUES (?, ?, ?, ?, ?)"
    )
    .bind(user_id)
    .bind(action)
    .bind(target_type)
    .bind(target_id)
    .bind(detail)
    .execute(db)
    .await;
    if let Err(e) = outcome {
        tracing::warn!("Failed to write audit log: {}", e);
    } else {
        debug!("Audit: user={} action={} target={}/{}", user_id, action, target_type.unwrap_or("-"), target_id.unwrap_or("-"));
    }
}

134
crates/server/src/config.rs Normal file
View File

@@ -0,0 +1,134 @@
use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::path::Path;
/// Top-level application configuration, loaded from `config.toml` (a default
/// file is generated on first run — see `AppConfig::load`).
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct AppConfig {
    pub server: ServerConfig,
    pub database: DatabaseConfig,
    pub auth: AuthConfig,
    pub retention: RetentionConfig,
    // Optional notification section (SMTP / webhooks).
    #[serde(default)]
    pub notify: NotifyConfig,
    /// Token required for device registration. Empty = any token accepted.
    #[serde(default)]
    pub registration_token: String,
}
/// Network listener settings.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ServerConfig {
    pub http_addr: String, // bind address for the HTTP/REST API (e.g. "0.0.0.0:8080")
    pub tcp_addr: String,  // bind address for the client TCP protocol
    /// Allowed CORS origins. Empty = same-origin only (no CORS headers).
    #[serde(default)]
    pub cors_origins: Vec<String>,
    /// Optional TLS configuration for the TCP listener.
    #[serde(default)]
    pub tls: Option<TlsConfig>,
}
/// TLS material for the TCP listener.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct TlsConfig {
    /// Path to the server certificate (PEM format)
    pub cert_path: String,
    /// Path to the server private key (PEM format)
    pub key_path: String,
}
/// SQLite database location (plain path or sqlite: URL — see init_database).
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct DatabaseConfig {
    pub path: String,
}
/// JWT authentication settings.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct AuthConfig {
    pub jwt_secret: String,
    #[serde(default = "default_access_ttl")]
    pub access_token_ttl_secs: u64,  // access-token lifetime (default 1800 s)
    #[serde(default = "default_refresh_ttl")]
    pub refresh_token_ttl_secs: u64, // refresh-token lifetime (default 604800 s)
}
/// Data-retention windows (in days) enforced by the background cleanup task.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct RetentionConfig {
    #[serde(default = "default_status_history_days")]
    pub status_history_days: u32,
    #[serde(default = "default_usb_events_days")]
    pub usb_events_days: u32,
    #[serde(default = "default_asset_changes_days")]
    pub asset_changes_days: u32,
    #[serde(default = "default_alert_records_days")]
    pub alert_records_days: u32,
    #[serde(default = "default_audit_log_days")]
    pub audit_log_days: u32,
}
/// Notification channels; every channel is optional.
#[derive(Debug, Deserialize, Serialize, Clone, Default)]
pub struct NotifyConfig {
    #[serde(default)]
    pub smtp: Option<SmtpConfig>,
    #[serde(default)]
    pub webhook_urls: Vec<String>,
}
/// SMTP credentials for e-mail notifications.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct SmtpConfig {
    pub host: String,
    pub port: u16,
    pub username: String,
    pub password: String,
    pub from: String,
}
impl AppConfig {
    /// Load configuration from `path`. When the file is missing, write a
    /// freshly generated default config there (so the operator has a template
    /// to edit) and return it.
    pub async fn load(path: &str) -> Result<Self> {
        if !Path::new(path).exists() {
            let config = default_config();
            let rendered = toml::to_string_pretty(&config)?;
            tokio::fs::write(path, &rendered).await?;
            tracing::warn!("Created default config at {}", path);
            return Ok(config);
        }
        let content = tokio::fs::read_to_string(path).await?;
        let config: AppConfig = toml::from_str(&content)?;
        Ok(config)
    }
}
// Serde default helpers for the config structs above; these values also feed
// the generated first-run config via default_config().
fn default_access_ttl() -> u64 { 1800 } // 30 minutes
fn default_refresh_ttl() -> u64 { 604800 } // 7 days
fn default_status_history_days() -> u32 { 7 }
fn default_usb_events_days() -> u32 { 90 }
fn default_asset_changes_days() -> u32 { 365 }
fn default_alert_records_days() -> u32 { 90 }
fn default_audit_log_days() -> u32 { 365 }
/// Build the configuration written on first run when no config file exists.
/// `jwt_secret` and `registration_token` are randomized per installation so a
/// fresh deployment never ships with guessable secrets.
pub fn default_config() -> AppConfig {
    AppConfig {
        server: ServerConfig {
            http_addr: "0.0.0.0:8080".into(),
            tcp_addr: "0.0.0.0:9999".into(),
            cors_origins: vec![], // empty = same-origin only
            tls: None,
        },
        database: DatabaseConfig {
            path: "./csm.db".into(),
        },
        auth: AuthConfig {
            // Random UUID as the initial signing secret.
            jwt_secret: uuid::Uuid::new_v4().to_string(),
            access_token_ttl_secs: default_access_ttl(),
            refresh_token_ttl_secs: default_refresh_ttl(),
        },
        retention: RetentionConfig {
            status_history_days: default_status_history_days(),
            usb_events_days: default_usb_events_days(),
            asset_changes_days: default_asset_changes_days(),
            alert_records_days: default_alert_records_days(),
            audit_log_days: default_audit_log_days(),
        },
        notify: NotifyConfig::default(),
        registration_token: uuid::Uuid::new_v4().to_string(),
    }
}

118
crates/server/src/db.rs Normal file
View File

@@ -0,0 +1,118 @@
use sqlx::SqlitePool;
use anyhow::Result;
/// Database repository for device operations.
/// Stateless marker type — all methods take the pool explicitly.
pub struct DeviceRepo;
impl DeviceRepo {
    /// Persist a device status report: upsert the latest snapshot, append a
    /// history row, and refresh the device's heartbeat/online flag.
    /// The three statements run sequentially without a transaction —
    /// NOTE(review): a crash mid-way can leave snapshot and history out of
    /// sync; confirm whether that is acceptable.
    pub async fn upsert_status(pool: &SqlitePool, device_uid: &str, status: &csm_protocol::DeviceStatus) -> Result<()> {
        let top_procs_json = serde_json::to_string(&status.top_processes)?;
        // Update latest snapshot
        sqlx::query(
            "INSERT INTO device_status (device_uid, cpu_usage, memory_usage, memory_total_mb, disk_usage, disk_total_mb, network_rx_rate, network_tx_rate, running_procs, top_processes, reported_at)
             VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, datetime('now'))
             ON CONFLICT(device_uid) DO UPDATE SET
                cpu_usage = excluded.cpu_usage,
                memory_usage = excluded.memory_usage,
                memory_total_mb = excluded.memory_total_mb,
                disk_usage = excluded.disk_usage,
                disk_total_mb = excluded.disk_total_mb,
                network_rx_rate = excluded.network_rx_rate,
                network_tx_rate = excluded.network_tx_rate,
                running_procs = excluded.running_procs,
                top_processes = excluded.top_processes,
                reported_at = datetime('now'),
                updated_at = datetime('now')"
        )
        .bind(device_uid)
        .bind(status.cpu_usage)
        .bind(status.memory_usage)
        .bind(status.memory_total_mb as i64)
        .bind(status.disk_usage)
        .bind(status.disk_total_mb as i64)
        .bind(status.network_rx_rate as i64)
        .bind(status.network_tx_rate as i64)
        .bind(status.running_procs as i32)
        .bind(&top_procs_json)
        .execute(pool)
        .await?;
        // Insert into history
        sqlx::query(
            "INSERT INTO device_status_history (device_uid, cpu_usage, memory_usage, disk_usage, network_rx_rate, network_tx_rate, running_procs, reported_at)
             VALUES (?, ?, ?, ?, ?, ?, ?, datetime('now'))"
        )
        .bind(device_uid)
        .bind(status.cpu_usage)
        .bind(status.memory_usage)
        .bind(status.disk_usage)
        .bind(status.network_rx_rate as i64)
        .bind(status.network_tx_rate as i64)
        .bind(status.running_procs as i32)
        .execute(pool)
        .await?;
        // Update device heartbeat
        sqlx::query(
            "UPDATE devices SET status = 'online', last_heartbeat = datetime('now') WHERE device_uid = ?"
        )
        .bind(device_uid)
        .execute(pool)
        .await?;
        Ok(())
    }
    /// Insert a USB plug/unplug/blocked event and return the new row id.
    pub async fn insert_usb_event(pool: &SqlitePool, event: &csm_protocol::UsbEvent) -> Result<i64> {
        let result = sqlx::query(
            "INSERT INTO usb_events (device_uid, vendor_id, product_id, serial_number, device_name, event_type)
             VALUES (?, ?, ?, ?, ?, ?)"
        )
        .bind(&event.device_uid)
        .bind(&event.vendor_id)
        .bind(&event.product_id)
        .bind(&event.serial)
        .bind(&event.device_name)
        // Map the protocol enum to the stored string; exhaustive match keeps
        // DB values in lockstep with the enum.
        .bind(match event.event_type {
            csm_protocol::UsbEventType::Inserted => "inserted",
            csm_protocol::UsbEventType::Removed => "removed",
            csm_protocol::UsbEventType::Blocked => "blocked",
        })
        .execute(pool)
        .await?;
        Ok(result.last_insert_rowid())
    }
    /// Upsert the latest hardware inventory snapshot for a device.
    pub async fn upsert_hardware(pool: &SqlitePool, asset: &csm_protocol::HardwareAsset) -> Result<()> {
        sqlx::query(
            "INSERT INTO hardware_assets (device_uid, cpu_model, cpu_cores, memory_total_mb, disk_model, disk_total_mb, gpu_model, motherboard, serial_number, reported_at)
             VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, datetime('now'))
             ON CONFLICT(device_uid) DO UPDATE SET
                cpu_model = excluded.cpu_model,
                cpu_cores = excluded.cpu_cores,
                memory_total_mb = excluded.memory_total_mb,
                disk_model = excluded.disk_model,
                disk_total_mb = excluded.disk_total_mb,
                gpu_model = excluded.gpu_model,
                motherboard = excluded.motherboard,
                serial_number = excluded.serial_number,
                reported_at = datetime('now'),
                updated_at = datetime('now')"
        )
        .bind(&asset.device_uid)
        .bind(&asset.cpu_model)
        .bind(asset.cpu_cores as i32)
        .bind(asset.memory_total_mb as i64)
        .bind(&asset.disk_model)
        .bind(asset.disk_total_mb as i64)
        .bind(&asset.gpu_model)
        .bind(&asset.motherboard)
        .bind(&asset.serial_number)
        .execute(pool)
        .await?;
        Ok(())
    }
}

264
crates/server/src/main.rs Normal file
View File

@@ -0,0 +1,264 @@
use anyhow::Result;
use axum::Router;
use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions, SqliteJournalMode};
use std::path::Path;
use std::str::FromStr;
use std::sync::Arc;
use tokio::net::TcpListener;
use tower_http::cors::{CorsLayer, Any};
use tower_http::trace::TraceLayer;
use tower_http::compression::CompressionLayer;
use tower_http::set_header::SetResponseHeaderLayer;
use tracing::{info, warn, error};
mod api;
mod audit;
mod config;
mod db;
mod tcp;
mod ws;
mod alert;
use config::AppConfig;
/// Application shared state, cheaply cloneable (pool + Arcs) into every
/// handler, the TCP server, and background tasks.
#[derive(Clone)]
pub struct AppState {
    pub db: sqlx::SqlitePool,              // SQLite connection pool
    pub config: Arc<AppConfig>,            // immutable runtime configuration
    pub clients: Arc<tcp::ClientRegistry>, // registry of online device connections
    pub ws_hub: Arc<ws::WsHub>,            // dashboard WebSocket broadcast hub
    pub login_limiter: Arc<api::auth::LoginRateLimiter>, // login brute-force throttle
}
/// Server entry point: config → database → shared state → background tasks →
/// TCP listener (spawned) → HTTP server (blocks until shutdown).
#[tokio::main]
async fn main() -> Result<()> {
    // Initialize tracing (JSON output; level overridable via RUST_LOG)
    tracing_subscriber::fmt()
        .with_env_filter(
            tracing_subscriber::EnvFilter::try_from_default_env()
                .unwrap_or_else(|_| "csm_server=info,tower_http=info".into()),
        )
        .json()
        .init();
    info!("CSM Server starting...");
    // Load configuration (creates a default config.toml on first run)
    let config = AppConfig::load("config.toml").await?;
    let config = Arc::new(config);
    // Initialize database
    let db = init_database(&config.database.path).await?;
    run_migrations(&db).await?;
    info!("Database initialized at {}", config.database.path);
    // Ensure default admin exists
    ensure_default_admin(&db).await?;
    // Initialize shared state
    let clients = Arc::new(tcp::ClientRegistry::new());
    let ws_hub = Arc::new(ws::WsHub::new());
    let state = AppState {
        db: db.clone(),
        config: config.clone(),
        clients: clients.clone(),
        ws_hub: ws_hub.clone(),
        login_limiter: Arc::new(api::auth::LoginRateLimiter::new()),
    };
    // Start background tasks (retention cleanup / alert processing)
    let cleanup_state = state.clone();
    tokio::spawn(async move {
        alert::cleanup_task(cleanup_state).await;
    });
    // Start TCP listener for client connections (runs concurrently with HTTP)
    let tcp_state = state.clone();
    let tcp_addr = config.server.tcp_addr.clone();
    tokio::spawn(async move {
        if let Err(e) = tcp::start_tcp_server(tcp_addr, tcp_state).await {
            error!("TCP server error: {}", e);
        }
    });
    // Build HTTP router: API routes, then CORS/trace/compression and a set of
    // security headers applied to every response.
    let app = Router::new()
        .merge(api::routes(state.clone()))
        .layer(
            build_cors_layer(&config.server.cors_origins),
        )
        .layer(TraceLayer::new_for_http())
        .layer(CompressionLayer::new())
        // Security headers
        .layer(SetResponseHeaderLayer::if_not_present(
            axum::http::header::X_CONTENT_TYPE_OPTIONS,
            axum::http::HeaderValue::from_static("nosniff"),
        ))
        .layer(SetResponseHeaderLayer::if_not_present(
            axum::http::header::X_FRAME_OPTIONS,
            axum::http::HeaderValue::from_static("DENY"),
        ))
        .layer(SetResponseHeaderLayer::if_not_present(
            axum::http::header::HeaderName::from_static("x-xss-protection"),
            axum::http::HeaderValue::from_static("1; mode=block"),
        ))
        .layer(SetResponseHeaderLayer::if_not_present(
            axum::http::header::HeaderName::from_static("referrer-policy"),
            axum::http::HeaderValue::from_static("strict-origin-when-cross-origin"),
        ))
        .layer(SetResponseHeaderLayer::if_not_present(
            axum::http::header::HeaderName::from_static("content-security-policy"),
            axum::http::HeaderValue::from_static("default-src 'self'; script-src 'self' 'unsafe-inline' 'unsafe-eval'; style-src 'self' 'unsafe-inline'; img-src 'self' data:; connect-src 'self' ws: wss:"),
        ))
        .with_state(state);
    // Start HTTP server (this call blocks for the lifetime of the process)
    let http_addr = &config.server.http_addr;
    info!("HTTP server listening on {}", http_addr);
    let listener = TcpListener::bind(http_addr).await?;
    axum::serve(listener, app).await?;
    Ok(())
}
/// Open (creating if necessary) the SQLite database at `db_path` and return a
/// pool configured for concurrent server use (WAL journal, busy timeout).
/// Accepts either a plain file path or a `sqlite:` URL with query parameters.
async fn init_database(db_path: &str) -> Result<sqlx::SqlitePool> {
    // Ensure parent directory exists for file-based databases
    // Strip sqlite: prefix if present for directory creation
    let file_path = db_path.strip_prefix("sqlite:").unwrap_or(db_path);
    // Strip query parameters
    let file_path = file_path.split('?').next().unwrap_or(file_path);
    if let Some(parent) = Path::new(file_path).parent() {
        if !parent.as_os_str().is_empty() {
            tokio::fs::create_dir_all(parent).await?;
        }
    }
    let options = SqliteConnectOptions::from_str(db_path)?
        .journal_mode(SqliteJournalMode::Wal)
        .synchronous(sqlx::sqlite::SqliteSynchronous::Normal)
        .busy_timeout(std::time::Duration::from_secs(5))
        .foreign_keys(true);
    let pool = SqlitePoolOptions::new()
        .max_connections(8)
        .connect_with(options)
        .await?;
    // Set pragmas on each connection
    // (negative cache_size means KiB in SQLite: -64000 => ~64 MB page cache)
    sqlx::query("PRAGMA cache_size = -64000")
        .execute(&pool)
        .await?;
    sqlx::query("PRAGMA wal_autocheckpoint = 1000")
        .execute(&pool)
        .await?;
    Ok(pool)
}
/// Apply embedded SQL migrations that are not yet recorded in the
/// `_migrations` tracking table.
///
/// NOTE(review): migration names are positional ("001".."011", derived from
/// the array index) — entries must only ever be appended; removing or
/// reordering one would shift names and re-run applied migrations.
/// NOTE(review): each migration file is executed through a single
/// `sqlx::query` call; this assumes the sqlx SQLite driver executes
/// multi-statement SQL strings — confirm against the pinned sqlx version.
async fn run_migrations(pool: &sqlx::SqlitePool) -> Result<()> {
    // Embedded migrations - run in order
    let migrations = [
        include_str!("../../../migrations/001_init.sql"),
        include_str!("../../../migrations/002_assets.sql"),
        include_str!("../../../migrations/003_usb.sql"),
        include_str!("../../../migrations/004_alerts.sql"),
        include_str!("../../../migrations/005_plugins_web_filter.sql"),
        include_str!("../../../migrations/006_plugins_usage_timer.sql"),
        include_str!("../../../migrations/007_plugins_software_blocker.sql"),
        include_str!("../../../migrations/008_plugins_popup_blocker.sql"),
        include_str!("../../../migrations/009_plugins_usb_file_audit.sql"),
        include_str!("../../../migrations/010_plugins_watermark.sql"),
        include_str!("../../../migrations/011_token_security.sql"),
    ];
    // Create migrations tracking table
    sqlx::query(
        "CREATE TABLE IF NOT EXISTS _migrations (id INTEGER PRIMARY KEY, name TEXT NOT NULL UNIQUE, applied_at TEXT NOT NULL DEFAULT (datetime('now')))"
    )
    .execute(pool)
    .await?;
    for (i, migration_sql) in migrations.iter().enumerate() {
        let name = format!("{:03}", i + 1);
        let exists: bool = sqlx::query_scalar::<_, i64>(
            "SELECT COUNT(*) FROM _migrations WHERE name = ?"
        )
        .bind(&name)
        .fetch_one(pool)
        .await? > 0;
        if !exists {
            info!("Running migration: {}", name);
            sqlx::query(migration_sql)
                .execute(pool)
                .await?;
            // Record success so the migration is never re-applied.
            sqlx::query("INSERT INTO _migrations (name) VALUES (?)")
                .bind(&name)
                .execute(pool)
                .await?;
        }
    }
    Ok(())
}
/// On first boot (empty `users` table) create an `admin` user with a random
/// password, printing the password once to stderr.
async fn ensure_default_admin(pool: &sqlx::SqlitePool) -> Result<()> {
    let count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM users")
        .fetch_one(pool)
        .await?;
    if count == 0 {
        // Generate a random password: 16 random UUID bytes hex-encoded.
        // Two hex digits are written per byte, so the result is 32 chars —
        // size the String accordingly (the old comment/capacity said 16).
        let random_password: String = {
            use std::fmt::Write;
            let bytes = uuid::Uuid::new_v4();
            let mut s = String::with_capacity(32);
            for byte in bytes.as_bytes().iter().take(16) {
                write!(s, "{:02x}", byte).unwrap();
            }
            s
        };
        // bcrypt cost 12: deliberate slowdown against offline cracking.
        let password_hash = bcrypt::hash(&random_password, 12)?;
        sqlx::query(
            "INSERT INTO users (username, password, role) VALUES (?, ?, 'admin')"
        )
        .bind("admin")
        .bind(&password_hash)
        .execute(pool)
        .await?;
        warn!("Created default admin user (username: admin)");
        // Print password directly to stderr — bypasses tracing JSON formatter
        eprintln!("============================================================");
        eprintln!("  Generated admin password: {}", random_password);
        eprintln!("  *** Save this password now — it will NOT be shown again! ***");
        eprintln!("============================================================");
    }
    Ok(())
}
/// Build CORS layer from configured origins.
/// If cors_origins is empty, no CORS headers are sent (same-origin only).
/// If origins are specified, only those are allowed.
/// Origins that fail to parse as header values are silently dropped.
fn build_cors_layer(origins: &[String]) -> CorsLayer {
    use axum::http::HeaderValue;
    let allowed_origins: Vec<HeaderValue> = origins.iter()
        .filter_map(|o| o.parse::<HeaderValue>().ok())
        .collect();
    if allowed_origins.is_empty() {
        // No CORS — production safe by default
        CorsLayer::new()
    } else {
        CorsLayer::new()
            .allow_origin(tower_http::cors::AllowOrigin::list(allowed_origins))
            .allow_methods(Any)
            .allow_headers(Any)
            // Preflight responses may be cached for an hour.
            .max_age(std::time::Duration::from_secs(3600))
    }
}

844
crates/server/src/tcp.rs Normal file
View File

@@ -0,0 +1,844 @@
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::RwLock;
use tokio::net::{TcpListener, TcpStream};
use tracing::{info, warn, debug};
use hmac::{Hmac, Mac};
use sha2::Sha256;
use csm_protocol::{Frame, MessageType, PROTOCOL_VERSION};
use crate::AppState;
/// Length of the sliding window used for per-connection rate limiting.
const RATE_LIMIT_WINDOW_SECS: u64 = 5;
/// Maximum frames a connection may send within one RATE_LIMIT_WINDOW_SECS
/// window (i.e. 100 frames per 5 seconds, not per second).
const RATE_LIMIT_MAX_FRAMES: usize = 100;
/// Per-connection rate limiter using a sliding window of frame timestamps
struct RateLimiter {
    // Arrival times of frames accepted within the current window.
    timestamps: Vec<Instant>,
}
impl RateLimiter {
    fn new() -> Self {
        Self { timestamps: Vec::with_capacity(RATE_LIMIT_MAX_FRAMES) }
    }
    /// Record one frame arrival. Returns false (and does not record the frame)
    /// if the connection has already used its budget for the current window.
    fn check(&mut self) -> bool {
        let now = Instant::now();
        let cutoff = now - std::time::Duration::from_secs(RATE_LIMIT_WINDOW_SECS);
        // Evict timestamps that have aged out of the sliding window.
        self.timestamps.retain(|t| *t > cutoff);
        if self.timestamps.len() >= RATE_LIMIT_MAX_FRAMES {
            return false;
        }
        self.timestamps.push(now);
        true
    }
}
/// Push a plugin config frame to all online clients matching the target scope.
/// target_type: "global" | "device" | "group"
/// target_id: device_uid or group_name (None for global)
pub async fn push_to_targets(
    db: &sqlx::SqlitePool,
    clients: &crate::tcp::ClientRegistry,
    msg_type: MessageType,
    payload: &impl serde::Serialize,
    target_type: &str,
    target_id: Option<&str>,
) {
    // Encode the frame once; the same bytes go to every matching client.
    let encoded = match Frame::new_json(msg_type, payload) {
        Ok(frame) => frame.encode(),
        Err(e) => {
            warn!("Failed to encode plugin push frame: {}", e);
            return;
        }
    };
    let online = clients.list_online().await;
    // For group targeting, resolve group membership from the DB a single time
    // instead of once per connected client.
    let mut group_members: Option<Vec<String>> = None;
    if target_type == "group" {
        if let Some(group_name) = target_id {
            group_members = sqlx::query_scalar::<_, String>(
                "SELECT device_uid FROM devices WHERE group_name = ?"
            )
            .bind(group_name)
            .fetch_all(db)
            .await
            .ok();
        }
    }
    let mut pushed_count = 0usize;
    for uid in &online {
        let should_push = match target_type {
            "global" => true,
            "device" => target_id == Some(uid.as_str()),
            "group" => group_members.as_ref().map_or(false, |members| members.contains(uid)),
            other => {
                warn!("Unknown target_type '{}', skipping push", other);
                false
            }
        };
        if should_push && clients.send_to(uid, encoded.clone()).await {
            pushed_count += 1;
        }
    }
    debug!("Pushed {:?} to {}/{} online clients (target={})", msg_type, pushed_count, online.len(), target_type);
}
/// Push all active plugin configs to a newly registered client.
///
/// Each config family (watermark, web filter rules, software blacklist, popup
/// rules, USB policy) is looked up with device > group > global precedence
/// where the schema supports targeting. DB lookup failures are deliberately
/// non-fatal: a failed query simply skips that config family.
pub async fn push_all_plugin_configs(
    db: &sqlx::SqlitePool,
    clients: &crate::tcp::ClientRegistry,
    device_uid: &str,
) {
    use sqlx::Row;
    // Watermark configs — only push the highest-priority enabled config (device > group > global)
    // The query has exactly two `?` placeholders, so exactly two binds: a
    // surplus bind makes sqlx reject the query at runtime, and the `if let Ok`
    // would silently swallow that error — the config would never be pushed.
    if let Ok(rows) = sqlx::query(
        "SELECT content, font_size, opacity, color, angle, enabled FROM watermark_config WHERE enabled = 1 AND (target_type = 'global' OR (target_type = 'device' AND target_id = ?) OR (target_type = 'group' AND target_id = (SELECT group_name FROM devices WHERE device_uid = ?))) ORDER BY CASE WHEN target_type = 'device' THEN 0 WHEN target_type = 'group' THEN 1 ELSE 2 END LIMIT 1"
    )
    .bind(device_uid)
    .bind(device_uid)
    .fetch_all(db).await
    {
        if let Some(row) = rows.first() {
            let config = csm_protocol::WatermarkConfigPayload {
                content: row.get("content"),
                font_size: row.get::<i32, _>("font_size") as u32,
                opacity: row.get("opacity"),
                color: row.get("color"),
                angle: row.get::<i32, _>("angle"),
                enabled: row.get("enabled"),
            };
            if let Ok(frame) = Frame::new_json(MessageType::WatermarkConfig, &config) {
                clients.send_to(device_uid, frame.encode()).await;
            }
        }
    }
    // Web filter rules — all matching rules, not just the top one.
    if let Ok(rows) = sqlx::query(
        "SELECT id, rule_type, pattern FROM web_filter_rules WHERE enabled = 1 AND (target_type = 'global' OR (target_type = 'device' AND target_id = ?) OR (target_type = 'group' AND target_id = (SELECT group_name FROM devices WHERE device_uid = ?)))"
    )
    .bind(device_uid)
    .bind(device_uid)
    .fetch_all(db).await
    {
        let rules: Vec<serde_json::Value> = rows.iter().map(|r| serde_json::json!({
            "id": r.get::<i64, _>("id"),
            "rule_type": r.get::<String, _>("rule_type"),
            "pattern": r.get::<String, _>("pattern"),
        })).collect();
        if !rules.is_empty() {
            if let Ok(frame) = Frame::new_json(MessageType::WebFilterRuleUpdate, &serde_json::json!({"rules": rules})) {
                clients.send_to(device_uid, frame.encode()).await;
            }
        }
    }
    // Software blacklist
    if let Ok(rows) = sqlx::query(
        "SELECT id, name_pattern, action FROM software_blacklist WHERE enabled = 1 AND (target_type = 'global' OR (target_type = 'device' AND target_id = ?) OR (target_type = 'group' AND target_id = (SELECT group_name FROM devices WHERE device_uid = ?)))"
    )
    .bind(device_uid)
    .bind(device_uid)
    .fetch_all(db).await
    {
        let entries: Vec<serde_json::Value> = rows.iter().map(|r| serde_json::json!({
            "id": r.get::<i64, _>("id"),
            "name_pattern": r.get::<String, _>("name_pattern"),
            "action": r.get::<String, _>("action"),
        })).collect();
        if !entries.is_empty() {
            if let Ok(frame) = Frame::new_json(MessageType::SoftwareBlacklist, &serde_json::json!({"blacklist": entries})) {
                clients.send_to(device_uid, frame.encode()).await;
            }
        }
    }
    // Popup blocker rules
    if let Ok(rows) = sqlx::query(
        "SELECT id, rule_type, window_title, window_class, process_name FROM popup_filter_rules WHERE enabled = 1 AND (target_type = 'global' OR (target_type = 'device' AND target_id = ?) OR (target_type = 'group' AND target_id = (SELECT group_name FROM devices WHERE device_uid = ?)))"
    )
    .bind(device_uid)
    .bind(device_uid)
    .fetch_all(db).await
    {
        let rules: Vec<serde_json::Value> = rows.iter().map(|r| serde_json::json!({
            "id": r.get::<i64, _>("id"),
            "rule_type": r.get::<String, _>("rule_type"),
            "window_title": r.get::<Option<String>, _>("window_title"),
            "window_class": r.get::<Option<String>, _>("window_class"),
            "process_name": r.get::<Option<String>, _>("process_name"),
        })).collect();
        if !rules.is_empty() {
            if let Ok(frame) = Frame::new_json(MessageType::PopupRules, &serde_json::json!({"rules": rules})) {
                clients.send_to(device_uid, frame.encode()).await;
            }
        }
    }
    // USB policies — push highest-priority enabled policy for the device's group
    // (all_block beats blacklist beats anything else).
    if let Ok(rows) = sqlx::query(
        "SELECT policy_type, rules, enabled FROM usb_policies WHERE enabled = 1 AND target_group = (SELECT group_name FROM devices WHERE device_uid = ?) ORDER BY CASE WHEN policy_type = 'all_block' THEN 0 WHEN policy_type = 'blacklist' THEN 1 ELSE 2 END LIMIT 1"
    )
    .bind(device_uid)
    .fetch_all(db).await
    {
        if let Some(row) = rows.first() {
            let policy_type: String = row.get("policy_type");
            let rules_json: String = row.get("rules");
            // Rules are stored as a JSON array; a malformed blob degrades to empty.
            let rules: Vec<serde_json::Value> = serde_json::from_str(&rules_json).unwrap_or_default();
            let payload = csm_protocol::UsbPolicyPayload {
                policy_type,
                enabled: true,
                rules: rules.iter().map(|r| csm_protocol::UsbDeviceRule {
                    vendor_id: r.get("vendor_id").and_then(|v| v.as_str().map(String::from)),
                    product_id: r.get("product_id").and_then(|v| v.as_str().map(String::from)),
                    serial: r.get("serial").and_then(|v| v.as_str().map(String::from)),
                    device_name: r.get("device_name").and_then(|v| v.as_str().map(String::from)),
                }).collect(),
            };
            if let Ok(frame) = Frame::new_json(MessageType::UsbPolicyUpdate, &payload) {
                clients.send_to(device_uid, frame.encode()).await;
            }
        }
    }
    info!("Pushed all plugin configs to newly registered device {}", device_uid);
}
/// Maximum accumulated read buffer size per connection (8 MB).
/// Connections whose undecoded buffer grows past this are dropped,
/// guarding against a peer that streams bytes without ever completing a frame.
const MAX_READ_BUF_SIZE: usize = 8 * 1024 * 1024;
/// Registry of all connected client sessions, keyed by device UID.
/// Each entry holds the outbound byte channel feeding that connection's
/// writer task. Cloning the registry is cheap (shared Arc).
#[derive(Clone, Default)]
pub struct ClientRegistry {
    sessions: Arc<RwLock<HashMap<String, Arc<tokio::sync::mpsc::Sender<Vec<u8>>>>>>,
}
impl ClientRegistry {
    pub fn new() -> Self {
        Self::default()
    }
    /// Register (or replace) the outbound channel for a device session.
    pub async fn register(&self, device_uid: String, tx: Arc<tokio::sync::mpsc::Sender<Vec<u8>>>) {
        let mut sessions = self.sessions.write().await;
        sessions.insert(device_uid, tx);
    }
    /// Drop a device's session entry (called on disconnect).
    pub async fn unregister(&self, device_uid: &str) {
        let mut sessions = self.sessions.write().await;
        sessions.remove(device_uid);
    }
    /// Send raw frame bytes to one device. Returns false when the device is
    /// offline or its channel has been closed.
    pub async fn send_to(&self, device_uid: &str, data: Vec<u8>) -> bool {
        let sessions = self.sessions.read().await;
        match sessions.get(device_uid) {
            Some(tx) => tx.send(data).await.is_ok(),
            None => false,
        }
    }
    /// Number of currently connected devices.
    pub async fn count(&self) -> usize {
        self.sessions.read().await.len()
    }
    /// UIDs of all currently connected devices.
    pub async fn list_online(&self) -> Vec<String> {
        self.sessions.read().await.keys().cloned().collect()
    }
}
/// Start the TCP server for client connections (optionally with TLS)
pub async fn start_tcp_server(addr: String, state: AppState) -> anyhow::Result<()> {
let listener = TcpListener::bind(&addr).await?;
// Build TLS acceptor if configured
let tls_acceptor = build_tls_acceptor(&state.config.server.tls)?;
if tls_acceptor.is_some() {
info!("TCP server listening on {} (TLS enabled)", addr);
} else {
info!("TCP server listening on {} (plaintext)", addr);
}
loop {
let (stream, peer_addr) = listener.accept().await?;
let state = state.clone();
let acceptor = tls_acceptor.clone();
tokio::spawn(async move {
debug!("New TCP connection from {}", peer_addr);
match acceptor {
Some(acceptor) => {
match acceptor.accept(stream).await {
Ok(tls_stream) => {
if let Err(e) = handle_client_tls(tls_stream, state).await {
warn!("Client {} TLS error: {}", peer_addr, e);
}
}
Err(e) => warn!("TLS handshake failed for {}: {}", peer_addr, e),
}
}
None => {
if let Err(e) = handle_client(stream, state).await {
warn!("Client {} error: {}", peer_addr, e);
}
}
}
});
}
}
/// Build a TLS acceptor from the optional server TLS config.
/// Returns Ok(None) when no TLS section is configured (plaintext mode).
fn build_tls_acceptor(
    tls_config: &Option<crate::config::TlsConfig>,
) -> anyhow::Result<Option<tokio_rustls::TlsAcceptor>> {
    let config = if let Some(c) = tls_config {
        c
    } else {
        // Absent TLS config means the TCP server runs plaintext.
        return Ok(None);
    };
    // Load PEM-encoded certificate chain and private key from disk.
    let cert_pem = std::fs::read(&config.cert_path)
        .map_err(|e| anyhow::anyhow!("Failed to read TLS cert {}: {}", config.cert_path, e))?;
    let key_pem = std::fs::read(&config.key_path)
        .map_err(|e| anyhow::anyhow!("Failed to read TLS key {}: {}", config.key_path, e))?;
    let mut cert_reader = &cert_pem[..];
    let parsed = rustls_pemfile::certs(&mut cert_reader)
        .collect::<Result<Vec<_>, _>>()
        .map_err(|e| anyhow::anyhow!("Failed to parse TLS cert: {:?}", e))?;
    let certs: Vec<rustls_pki_types::CertificateDer> =
        parsed.into_iter().map(Into::into).collect();
    let mut key_reader = &key_pem[..];
    let key = rustls_pemfile::private_key(&mut key_reader)
        .map_err(|e| anyhow::anyhow!("Failed to parse TLS key: {:?}", e))?
        .ok_or_else(|| anyhow::anyhow!("No private key found in {}", config.key_path))?;
    // No client-certificate auth here: devices authenticate at the application
    // layer (registration token and per-device HMAC secret).
    let server_config = rustls::ServerConfig::builder()
        .with_no_client_auth()
        .with_single_cert(certs, key)
        .map_err(|e| anyhow::anyhow!("Failed to build TLS config: {}", e))?;
    Ok(Some(tokio_rustls::TlsAcceptor::from(Arc::new(server_config))))
}
/// Cleanup on client disconnect: unregister from client map, mark offline, notify WS.
async fn cleanup_on_disconnect(state: &AppState, device_uid: &Option<String>) {
    let uid = match device_uid {
        Some(uid) => uid,
        // Connection dropped before it ever registered — nothing to clean up.
        None => return,
    };
    state.clients.unregister(uid).await;
    // Best-effort DB update: a failure here only delays the offline marker.
    sqlx::query("UPDATE devices SET status = 'offline' WHERE device_uid = ?")
        .bind(uid)
        .execute(&state.db)
        .await
        .ok();
    let event = serde_json::json!({
        "type": "device_state",
        "device_uid": uid,
        "status": "offline"
    });
    state.ws_hub.broadcast(event.to_string()).await;
    info!("Device disconnected: {}", uid);
}
/// Compute HMAC-SHA256 for heartbeat verification.
/// Format: HMAC-SHA256(device_secret, "{device_uid}\n{timestamp}") → hex-encoded.
/// Returns an empty string if the key is rejected.
///
/// NOTE(review): within this file the heartbeat handler verifies HMACs inline
/// via `Mac::verify_slice` (constant-time) rather than calling this helper —
/// confirm whether any other caller uses it.
fn compute_hmac(secret: &str, device_uid: &str, timestamp: &str) -> String {
    type HmacSha256 = Hmac<Sha256>;
    let message = format!("{}\n{}", device_uid, timestamp);
    match HmacSha256::new_from_slice(secret.as_bytes()) {
        Ok(mut mac) => {
            mac.update(message.as_bytes());
            hex::encode(mac.finalize().into_bytes())
        }
        // HMAC-SHA256 accepts keys of any length, so this arm is effectively
        // unreachable; keep the original empty-string fallback regardless.
        Err(_) => String::new(),
    }
}
/// Verify that a frame sender is a registered device and that the claimed device_uid
/// matches the one registered on this connection. Returns true if valid.
fn verify_device_uid(device_uid: &Option<String>, msg_type: &str, claimed_uid: &str) -> bool {
    match device_uid.as_deref() {
        // Claimed uid matches the uid bound to this connection at Register time.
        Some(uid) if uid == claimed_uid => true,
        Some(uid) => {
            warn!("{} device_uid mismatch: expected {:?}, got {}", msg_type, uid, claimed_uid);
            false
        }
        None => {
            warn!("{} from unregistered connection", msg_type);
            false
        }
    }
}
/// Process a single decoded frame. Shared by both plaintext and TLS handlers.
///
/// `device_uid` is this connection's session state: `None` until a successful
/// Register frame, after which every later frame's claimed uid must match it
/// (enforced by `verify_device_uid`). `tx` is the outbound channel feeding this
/// connection's writer task. Spoofed or invalid frames are logged and ignored
/// (returning Ok) so one bad frame does not drop the connection; Err is
/// reserved for payload-decode and DB failures worth surfacing to the caller.
async fn process_frame(
    frame: Frame,
    state: &AppState,
    device_uid: &mut Option<String>,
    tx: &Arc<tokio::sync::mpsc::Sender<Vec<u8>>>,
) -> anyhow::Result<()> {
    match frame.msg_type {
        MessageType::Register => {
            let req: csm_protocol::RegisterRequest = frame.decode_payload()
                .map_err(|e| anyhow::anyhow!("Invalid registration payload: {}", e))?;
            info!("Device registration attempt: {} ({})", req.hostname, req.device_uid);
            // Validate registration token against configured token.
            // An empty configured token disables this check entirely.
            let expected_token = &state.config.registration_token;
            if !expected_token.is_empty() {
                if req.registration_token.is_empty() || req.registration_token != *expected_token {
                    warn!("Registration rejected for {}: invalid token", req.device_uid);
                    let err_frame = Frame::new_json(MessageType::RegisterResponse,
                        &serde_json::json!({"error": "invalid_registration_token"}))?;
                    tx.send(err_frame.encode()).await.ok();
                    return Ok(());
                }
            }
            // Check if device already exists with a secret (reconnection scenario)
            let existing_secret: Option<String> = sqlx::query_scalar(
                "SELECT device_secret FROM devices WHERE device_uid = ?"
            )
            .bind(&req.device_uid)
            .fetch_optional(&state.db)
            .await
            .ok()
            .flatten();
            let device_secret = match existing_secret {
                // Existing device — keep the same secret, don't rotate
                Some(secret) if !secret.is_empty() => secret,
                // New device — generate a fresh secret
                _ => uuid::Uuid::new_v4().to_string(),
            };
            // Upsert the device row. Note: the upsert deliberately does not
            // overwrite device_secret, so a reconnect keeps the stored secret.
            sqlx::query(
                "INSERT INTO devices (device_uid, hostname, ip_address, mac_address, os_version, device_secret, status) \
                 VALUES (?, ?, '0.0.0.0', ?, ?, ?, 'online') \
                 ON CONFLICT(device_uid) DO UPDATE SET hostname=excluded.hostname, os_version=excluded.os_version, \
                 mac_address=excluded.mac_address, status='online'"
            )
            .bind(&req.device_uid)
            .bind(&req.hostname)
            .bind(&req.mac_address)
            .bind(&req.os_version)
            .bind(&device_secret)
            .execute(&state.db)
            .await
            .map_err(|e| anyhow::anyhow!("DB error during registration: {}", e))?;
            // Bind this connection to the device uid for all subsequent frames.
            *device_uid = Some(req.device_uid.clone());
            // If this device was already connected on a different session, evict the old one
            // The new register() call will replace it in the hashmap
            state.clients.register(req.device_uid.clone(), tx.clone()).await;
            // Send registration response
            let config = csm_protocol::ClientConfig::default();
            let response = csm_protocol::RegisterResponse {
                device_secret,
                config,
            };
            let resp_frame = Frame::new_json(MessageType::RegisterResponse, &response)?;
            tx.send(resp_frame.encode()).await?;
            info!("Device registered successfully: {} ({})", req.hostname, req.device_uid);
            // Push all active plugin configs to newly registered client
            push_all_plugin_configs(&state.db, &state.clients, &req.device_uid).await;
        }
        MessageType::Heartbeat => {
            let heartbeat: csm_protocol::HeartbeatPayload = frame.decode_payload()
                .map_err(|e| anyhow::anyhow!("Invalid heartbeat: {}", e))?;
            if !verify_device_uid(device_uid, "Heartbeat", &heartbeat.device_uid) {
                return Ok(());
            }
            // Verify HMAC — reject if secret exists but HMAC is missing or wrong.
            // Devices with no stored secret (empty string) skip verification.
            let secret: Option<String> = sqlx::query_scalar(
                "SELECT device_secret FROM devices WHERE device_uid = ?"
            )
            .bind(&heartbeat.device_uid)
            .fetch_optional(&state.db)
            .await
            .map_err(|e| {
                warn!("DB error fetching device_secret for {}: {}", heartbeat.device_uid, e);
                anyhow::anyhow!("DB error during HMAC verification")
            })?;
            if let Some(ref secret) = secret {
                if !secret.is_empty() {
                    if heartbeat.hmac.is_empty() {
                        warn!("Heartbeat missing HMAC for device {}", heartbeat.device_uid);
                        return Ok(());
                    }
                    // Constant-time HMAC verification using hmac::Mac::verify_slice
                    let message = format!("{}\n{}", heartbeat.device_uid, heartbeat.timestamp);
                    let mut mac = Hmac::<Sha256>::new_from_slice(secret.as_bytes())
                        .map_err(|_| anyhow::anyhow!("HMAC key error"))?;
                    mac.update(message.as_bytes());
                    let provided_bytes = hex::decode(&heartbeat.hmac).unwrap_or_default();
                    if mac.verify_slice(&provided_bytes).is_err() {
                        warn!("Heartbeat HMAC mismatch for device {}", heartbeat.device_uid);
                        return Ok(());
                    }
                }
            }
            debug!("Heartbeat from {} (hmac verified)", heartbeat.device_uid);
            // Update device status in DB (best-effort; failure is ignored)
            sqlx::query("UPDATE devices SET status = 'online', last_heartbeat = datetime('now') WHERE device_uid = ?")
                .bind(&heartbeat.device_uid)
                .execute(&state.db)
                .await
                .ok();
            // Push to WebSocket subscribers
            state.ws_hub.broadcast(serde_json::json!({
                "type": "device_state",
                "device_uid": heartbeat.device_uid,
                "status": "online"
            }).to_string()).await;
        }
        // Periodic CPU/memory/disk snapshot from the client.
        MessageType::StatusReport => {
            let status: csm_protocol::DeviceStatus = frame.decode_payload()
                .map_err(|e| anyhow::anyhow!("Invalid status report: {}", e))?;
            if !verify_device_uid(device_uid, "StatusReport", &status.device_uid) {
                return Ok(());
            }
            crate::db::DeviceRepo::upsert_status(&state.db, &status.device_uid, &status).await?;
            // Push to WebSocket subscribers
            state.ws_hub.broadcast(serde_json::json!({
                "type": "device_status",
                "device_uid": status.device_uid,
                "cpu": status.cpu_usage,
                "memory": status.memory_usage,
                "disk": status.disk_usage
            }).to_string()).await;
        }
        // USB plug/unplug events: persisted and forwarded live to the dashboard.
        MessageType::UsbEvent => {
            let event: csm_protocol::UsbEvent = frame.decode_payload()
                .map_err(|e| anyhow::anyhow!("Invalid USB event: {}", e))?;
            if !verify_device_uid(device_uid, "UsbEvent", &event.device_uid) {
                return Ok(());
            }
            crate::db::DeviceRepo::insert_usb_event(&state.db, &event).await?;
            state.ws_hub.broadcast(serde_json::json!({
                "type": "usb_event",
                "device_uid": event.device_uid,
                "event": event.event_type,
                "usb_name": event.device_name
            }).to_string()).await;
        }
        // Hardware inventory snapshot; stored only, no live broadcast.
        MessageType::AssetReport => {
            let asset: csm_protocol::HardwareAsset = frame.decode_payload()
                .map_err(|e| anyhow::anyhow!("Invalid asset report: {}", e))?;
            if !verify_device_uid(device_uid, "AssetReport", &asset.device_uid) {
                return Ok(());
            }
            crate::db::DeviceRepo::upsert_hardware(&state.db, &asset).await?;
        }
        // Daily usage totals; upsert replaces the day's row with the latest values.
        MessageType::UsageReport => {
            let report: csm_protocol::UsageDailyReport = frame.decode_payload()
                .map_err(|e| anyhow::anyhow!("Invalid usage report: {}", e))?;
            if !verify_device_uid(device_uid, "UsageReport", &report.device_uid) {
                return Ok(());
            }
            sqlx::query(
                "INSERT INTO usage_daily (device_uid, date, total_active_minutes, total_idle_minutes, first_active_at, last_active_at) \
                 VALUES (?, ?, ?, ?, ?, ?) \
                 ON CONFLICT(device_uid, date) DO UPDATE SET \
                 total_active_minutes = excluded.total_active_minutes, \
                 total_idle_minutes = excluded.total_idle_minutes, \
                 first_active_at = excluded.first_active_at, \
                 last_active_at = excluded.last_active_at"
            )
            .bind(&report.device_uid)
            .bind(&report.date)
            .bind(report.total_active_minutes as i32)
            .bind(report.total_idle_minutes as i32)
            .bind(&report.first_active_at)
            .bind(&report.last_active_at)
            .execute(&state.db)
            .await
            .map_err(|e| anyhow::anyhow!("DB error inserting usage report: {}", e))?;
            debug!("Usage report saved for device {}", report.device_uid);
        }
        // Per-app daily usage; MAX() keeps the largest figure if reports replay.
        MessageType::AppUsageReport => {
            let report: csm_protocol::AppUsageEntry = frame.decode_payload()
                .map_err(|e| anyhow::anyhow!("Invalid app usage report: {}", e))?;
            if !verify_device_uid(device_uid, "AppUsageReport", &report.device_uid) {
                return Ok(());
            }
            sqlx::query(
                "INSERT INTO app_usage_daily (device_uid, date, app_name, usage_minutes) \
                 VALUES (?, ?, ?, ?) \
                 ON CONFLICT(device_uid, date, app_name) DO UPDATE SET \
                 usage_minutes = MAX(usage_minutes, excluded.usage_minutes)"
            )
            .bind(&report.device_uid)
            .bind(&report.date)
            .bind(&report.app_name)
            .bind(report.usage_minutes as i32)
            .execute(&state.db)
            .await
            .map_err(|e| anyhow::anyhow!("DB error inserting app usage: {}", e))?;
            debug!("App usage saved: {} -> {} ({} min)", report.device_uid, report.app_name, report.usage_minutes);
        }
        // Blacklisted-software hit: persisted and broadcast to the dashboard.
        MessageType::SoftwareViolation => {
            let report: csm_protocol::SoftwareViolationReport = frame.decode_payload()
                .map_err(|e| anyhow::anyhow!("Invalid software violation: {}", e))?;
            if !verify_device_uid(device_uid, "SoftwareViolation", &report.device_uid) {
                return Ok(());
            }
            sqlx::query(
                "INSERT INTO software_violations (device_uid, software_name, action_taken, timestamp) VALUES (?, ?, ?, ?)"
            )
            .bind(&report.device_uid)
            .bind(&report.software_name)
            .bind(&report.action_taken)
            .bind(&report.timestamp)
            .execute(&state.db)
            .await
            .map_err(|e| anyhow::anyhow!("DB error inserting software violation: {}", e))?;
            info!("Software violation: {} tried to run {} -> {}", report.device_uid, report.software_name, report.action_taken);
            state.ws_hub.broadcast(serde_json::json!({
                "type": "software_violation",
                "device_uid": report.device_uid,
                "software_name": report.software_name,
                "action_taken": report.action_taken
            }).to_string()).await;
        }
        // File operations on removable media; audit log only.
        MessageType::UsbFileOp => {
            let entry: csm_protocol::UsbFileOpEntry = frame.decode_payload()
                .map_err(|e| anyhow::anyhow!("Invalid USB file op: {}", e))?;
            if !verify_device_uid(device_uid, "UsbFileOp", &entry.device_uid) {
                return Ok(());
            }
            sqlx::query(
                "INSERT INTO usb_file_operations (device_uid, usb_serial, drive_letter, operation, file_path, file_size, timestamp) VALUES (?, ?, ?, ?, ?, ?, ?)"
            )
            .bind(&entry.device_uid)
            .bind(&entry.usb_serial)
            .bind(&entry.drive_letter)
            .bind(&entry.operation)
            .bind(&entry.file_path)
            .bind(entry.file_size.map(|s| s as i64))
            .bind(&entry.timestamp)
            .execute(&state.db)
            .await
            .map_err(|e| anyhow::anyhow!("DB error inserting USB file op: {}", e))?;
            debug!("USB file op: {} {} on {}", entry.operation, entry.file_path, entry.device_uid);
        }
        // Web access / filter-action log entry; audit log only.
        MessageType::WebAccessLog => {
            let entry: csm_protocol::WebAccessLogEntry = frame.decode_payload()
                .map_err(|e| anyhow::anyhow!("Invalid web access log: {}", e))?;
            if !verify_device_uid(device_uid, "WebAccessLog", &entry.device_uid) {
                return Ok(());
            }
            sqlx::query(
                "INSERT INTO web_access_log (device_uid, url, action, timestamp) VALUES (?, ?, ?, ?)"
            )
            .bind(&entry.device_uid)
            .bind(&entry.url)
            .bind(&entry.action)
            .bind(&entry.timestamp)
            .execute(&state.db)
            .await
            .map_err(|e| anyhow::anyhow!("DB error inserting web access log: {}", e))?;
            debug!("Web access log: {} {} {}", entry.device_uid, entry.action, entry.url);
        }
        // Unknown/unhandled types are ignored rather than treated as errors.
        _ => {
            debug!("Unhandled message type: {:?}", frame.msg_type);
        }
    }
    Ok(())
}
/// Handle a single client TCP connection.
///
/// Spawns a writer task fed by an mpsc channel (so `process_frame` can queue
/// responses without blocking the read loop), then reads length-delimited
/// frames until EOF, buffer overflow, or a rate-limit violation. Always runs
/// `cleanup_on_disconnect` before returning.
async fn handle_client(stream: TcpStream, state: AppState) -> anyhow::Result<()> {
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    // Disable Nagle's algorithm so small frames are flushed without delay
    let _ = stream.set_nodelay(true);
    let (mut reader, mut writer) = stream.into_split();
    let (tx, mut rx) = tokio::sync::mpsc::channel::<Vec<u8>>(256);
    let tx = Arc::new(tx);
    let mut buffer = vec![0u8; 65536];
    let mut read_buf = Vec::with_capacity(65536);
    // None until a successful Register frame binds this connection to a device.
    let mut device_uid: Option<String> = None;
    let mut rate_limiter = RateLimiter::new();
    // Writer task: forwards messages from channel to TCP stream
    let write_task = tokio::spawn(async move {
        while let Some(data) = rx.recv().await {
            if writer.write_all(&data).await.is_err() {
                break;
            }
        }
    });
    // Reader loop
    'reader: loop {
        let n = reader.read(&mut buffer).await?;
        if n == 0 {
            break; // Connection closed
        }
        read_buf.extend_from_slice(&buffer[..n]);
        // Guard against unbounded buffer growth
        if read_buf.len() > MAX_READ_BUF_SIZE {
            warn!("Connection exceeded max buffer size, dropping");
            break;
        }
        // Process complete frames; partial frames stay buffered for the next read
        while let Some(frame) = Frame::decode(&read_buf)? {
            let frame_size = frame.encoded_size();
            // Remove consumed bytes without reallocating
            read_buf.drain(..frame_size);
            // Rate limit check (counts every decoded frame, even bad-version ones)
            if !rate_limiter.check() {
                warn!("Rate limit exceeded for device {:?}, dropping connection", device_uid);
                break 'reader;
            }
            // Verify protocol version
            if frame.version != PROTOCOL_VERSION {
                warn!("Unsupported protocol version: 0x{:02X}", frame.version);
                continue;
            }
            // A failed frame is logged but does not kill the connection
            if let Err(e) = process_frame(frame, &state, &mut device_uid, &tx).await {
                warn!("Frame processing error: {}", e);
            }
        }
    }
    cleanup_on_disconnect(&state, &device_uid).await;
    write_task.abort();
    Ok(())
}
/// Handle a TLS-wrapped client connection.
///
/// Mirrors `handle_client`; only the stream type differs. The rate-limit check
/// runs BEFORE the protocol-version check — matching the plaintext handler —
/// so a flood of bad-version frames still counts against the sender's rate
/// budget instead of bypassing it via the `continue`.
async fn handle_client_tls(
    stream: tokio_rustls::server::TlsStream<TcpStream>,
    state: AppState,
) -> anyhow::Result<()> {
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    let (mut reader, mut writer) = tokio::io::split(stream);
    let (tx, mut rx) = tokio::sync::mpsc::channel::<Vec<u8>>(256);
    let tx = Arc::new(tx);
    let mut buffer = vec![0u8; 65536];
    let mut read_buf = Vec::with_capacity(65536);
    // None until a successful Register frame binds this connection to a device.
    let mut device_uid: Option<String> = None;
    let mut rate_limiter = RateLimiter::new();
    // Writer task: drains the outbound channel into the TLS stream.
    let write_task = tokio::spawn(async move {
        while let Some(data) = rx.recv().await {
            if writer.write_all(&data).await.is_err() {
                break;
            }
        }
    });
    // Reader loop — same logic as plaintext handler
    'reader: loop {
        let n = reader.read(&mut buffer).await?;
        if n == 0 {
            break;
        }
        read_buf.extend_from_slice(&buffer[..n]);
        // Guard against unbounded buffer growth
        if read_buf.len() > MAX_READ_BUF_SIZE {
            warn!("TLS connection exceeded max buffer size, dropping");
            break;
        }
        while let Some(frame) = Frame::decode(&read_buf)? {
            let frame_size = frame.encoded_size();
            read_buf.drain(..frame_size);
            // Rate limit first: every decoded frame consumes budget, even ones
            // that will be rejected for an unsupported version below.
            if !rate_limiter.check() {
                warn!("Rate limit exceeded for TLS device {:?}, dropping connection", device_uid);
                break 'reader;
            }
            if frame.version != PROTOCOL_VERSION {
                warn!("Unsupported protocol version: 0x{:02X}", frame.version);
                continue;
            }
            if let Err(e) = process_frame(frame, &state, &mut device_uid, &tx).await {
                warn!("Frame processing error: {}", e);
            }
        }
    }
    cleanup_on_disconnect(&state, &device_uid).await;
    write_task.abort();
    Ok(())
}

125
crates/server/src/ws.rs Normal file
View File

@@ -0,0 +1,125 @@
use axum::extract::ws::{WebSocket, WebSocketUpgrade, Message};
use axum::response::IntoResponse;
use axum::extract::Query;
use jsonwebtoken::{decode, Validation, DecodingKey};
use serde::Deserialize;
use tokio::sync::broadcast;
use std::sync::Arc;
use tracing::{debug, warn};
use crate::api::auth::Claims;
use crate::AppState;
/// WebSocket hub for broadcasting real-time events to admin browsers.
/// Internally a tokio broadcast channel: every subscriber gets its own
/// receiver, and slow subscribers may lag (handled per-connection).
#[derive(Clone)]
pub struct WsHub {
    tx: broadcast::Sender<String>,
}
impl WsHub {
    /// Create a hub whose ring buffer holds up to 1024 undelivered messages.
    pub fn new() -> Self {
        let (tx, _rx) = broadcast::channel(1024);
        Self { tx }
    }
    /// Fan a message out to every current subscriber. A send error only means
    /// nobody is subscribed right now, which is not a failure condition.
    pub async fn broadcast(&self, message: String) {
        match self.tx.send(message) {
            Ok(_) => {}
            Err(_) => debug!("No WebSocket subscribers to receive broadcast"),
        }
    }
    /// Obtain a receiver that observes every message broadcast from now on.
    pub fn subscribe(&self) -> broadcast::Receiver<String> {
        self.tx.subscribe()
    }
}
/// Query-string parameters accepted by the WebSocket upgrade endpoint.
#[derive(Debug, Deserialize)]
pub struct WsAuthParams {
    /// JWT access token. Browsers cannot set custom headers on WebSocket
    /// upgrade requests, so the token travels as `?token=...`.
    pub token: Option<String>,
}
/// HTTP upgrade handler for WebSocket connections.
/// Validates the JWT token from the query parameter before upgrading;
/// rejects with 401 on a missing, invalid, or non-access token.
pub async fn ws_handler(
    ws: WebSocketUpgrade,
    Query(params): Query<WsAuthParams>,
    axum::extract::State(state): axum::extract::State<AppState>,
) -> impl IntoResponse {
    // A token must be present at all.
    let token = if let Some(t) = params.token {
        t
    } else {
        warn!("WebSocket connection rejected: no token provided");
        return (axum::http::StatusCode::UNAUTHORIZED, "Missing token").into_response();
    };
    // Decode and validate the JWT against the server's shared secret.
    let decoded = decode::<Claims>(
        &token,
        &DecodingKey::from_secret(state.config.auth.jwt_secret.as_bytes()),
        &Validation::default(),
    );
    let claims = match decoded {
        Ok(data) => data.claims,
        Err(e) => {
            warn!("WebSocket connection rejected: invalid token - {}", e);
            return (axum::http::StatusCode::UNAUTHORIZED, "Invalid token").into_response();
        }
    };
    // Only access tokens may open the realtime feed (not refresh tokens).
    if claims.token_type != "access" {
        warn!("WebSocket connection rejected: not an access token");
        return (axum::http::StatusCode::UNAUTHORIZED, "Invalid token type").into_response();
    }
    let hub = state.ws_hub.clone();
    ws.on_upgrade(move |socket| handle_socket(socket, claims, hub))
}
/// Per-connection WebSocket pump: sends a welcome message, then forwards hub
/// broadcasts to the browser and answers pings until either side closes.
async fn handle_socket(mut socket: WebSocket, claims: Claims, hub: Arc<WsHub>) {
    debug!("WebSocket client connected: user={}", claims.username);
    let welcome = serde_json::json!({
        "type": "connected",
        "message": "CSM real-time feed active",
        "user": claims.username
    });
    if socket.send(Message::Text(welcome.to_string())).await.is_err() {
        return;
    }
    // Subscribe to broadcast hub for real-time events
    let mut rx = hub.subscribe();
    loop {
        tokio::select! {
            // Forward broadcast messages to WebSocket client
            msg = rx.recv() => {
                match msg {
                    Ok(text) => {
                        if socket.send(Message::Text(text)).await.is_err() {
                            break;
                        }
                    }
                    // A slow client can fall behind the broadcast ring buffer;
                    // drop the missed messages and keep streaming rather than
                    // tearing the connection down.
                    Err(broadcast::error::RecvError::Lagged(n)) => {
                        debug!("WebSocket client lagged {} messages, continuing", n);
                    }
                    Err(broadcast::error::RecvError::Closed) => break,
                }
            }
            // Handle incoming WebSocket messages (ping/close)
            msg = socket.recv() => {
                match msg {
                    Some(Ok(Message::Ping(data))) => {
                        if socket.send(Message::Pong(data)).await.is_err() {
                            break;
                        }
                    }
                    Some(Ok(Message::Close(_))) => break,
                    Some(Err(_)) => break,
                    // Stream ended: the client went away without a Close frame.
                    None => break,
                    // Text/Binary/Pong from the client are ignored — one-way feed.
                    _ => {}
                }
            }
        }
    }
    debug!("WebSocket client disconnected: user={}", claims.username);
}