use std::path::PathBuf;
use std::sync::{Arc, Mutex};
use std::time::Duration;

use anyhow::Result;
use csm_protocol::{
    DiskEncryptionConfigPayload, Frame, HeartbeatPayload, MessageType, PatchScanConfigPayload,
    RegisterRequest, RegisterResponse, UsbPolicyPayload, WatermarkConfigPayload,
};
use hmac::{Hmac, Mac};
use sha2::{Digest, Sha256};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tracing::{debug, info, warn};

use crate::ClientState;

/// Sending halves of the `tokio::sync::watch` channels that carry the latest
/// configuration to each client plugin.
///
/// The communication loop pushes a new value into the matching channel when
/// the server sends a config, enable or disable message. (Presumably each
/// receiver is owned by the corresponding plugin task — confirm at the
/// construction site.)
pub struct PluginChannels {
    /// Watermark overlay settings; `None` removes the overlay.
    pub watermark_tx: tokio::sync::watch::Sender<Option<WatermarkConfigPayload>>,
    /// Web filter (hosts rules) configuration.
    pub web_filter_tx: tokio::sync::watch::Sender<crate::web_filter::WebFilterConfig>,
    /// Software black/white-list configuration.
    pub software_blocker_tx:
        tokio::sync::watch::Sender<crate::software_blocker::SoftwareBlockerConfig>,
    /// Popup blocker rules.
    pub popup_blocker_tx: tokio::sync::watch::Sender<crate::popup_blocker::PopupBlockerConfig>,
    /// USB file-audit configuration.
    pub usb_audit_tx: tokio::sync::watch::Sender<crate::usb_audit::UsbAuditConfig>,
    /// Usage timer configuration.
    pub usage_timer_tx: tokio::sync::watch::Sender<crate::usage_timer::UsageConfig>,
    /// USB device policy; `None` means no policy has been received yet.
    pub usb_policy_tx: tokio::sync::watch::Sender<Option<UsbPolicyPayload>>,
    /// Disk encryption reporting configuration.
    pub disk_encryption_tx:
        tokio::sync::watch::Sender<crate::disk_encryption::DiskEncryptionConfig>,
    /// Print audit configuration.
    pub print_audit_tx: tokio::sync::watch::Sender<crate::print_audit::PrintAuditConfig>,
    /// Clipboard control rules.
    pub clipboard_control_tx:
        tokio::sync::watch::Sender<crate::clipboard_control::ClipboardControlConfig>,
    /// Patch scan configuration.
    pub patch_tx: tokio::sync::watch::Sender<crate::patch::PluginConfig>,
}

/// Connect to the server and drive the main communication loop.
///
/// Opens a TCP connection to `state.server_addr`. When `state.use_tls` is set,
/// the stream is first upgraded with `wrap_tls`. The stream is then handed to
/// `run_comm_loop`, and this function returns whatever that loop returns.
///
/// # Errors
/// Returns an error in three cases:
/// - the TCP connect fails;
/// - the TLS upgrade fails;
/// - the communication loop ends with an error.
pub async fn connect_and_run(
    state: &ClientState,
    data_rx: &mut tokio::sync::mpsc::Receiver<Frame>,
    plugins: &PluginChannels,
) -> Result<()> {
    let addr = &state.server_addr;
    let tcp = TcpStream::connect(addr).await?;
    debug!("TCP connected to {}", addr);

    if !state.use_tls {
        return run_comm_loop(tcp, state, data_rx, plugins).await;
    }

    let tls = wrap_tls(tcp, addr).await?;
    run_comm_loop(tls, state, data_rx, plugins).await
}

/// Wrap a TCP stream with TLS and certificate pinning.
async fn wrap_tls(stream: TcpStream, server_addr: &str) -> Result> { let mut root_store = rustls::RootCertStore::empty(); // Load custom CA certificate if specified if let Ok(ca_path) = std::env::var("CSM_TLS_CA_CERT") { let ca_pem = std::fs::read(&ca_path) .map_err(|e| anyhow::anyhow!("Failed to read CA cert {}: {}", ca_path, e))?; let certs = rustls_pemfile::certs(&mut &ca_pem[..]) .collect::, _>>() .map_err(|e| anyhow::anyhow!("Failed to parse CA cert: {:?}", e))?; for cert in certs { root_store.add(cert)?; } info!("Loaded custom CA certificates from {}", ca_path); } // Always include system roots as fallback root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); // Check if skip-verify is allowed (only in CSM_DEV mode) let skip_verify = std::env::var("CSM_TLS_SKIP_VERIFY").as_deref() == Ok("true") && std::env::var("CSM_DEV").is_ok(); let config = if skip_verify { warn!("TLS certificate verification DISABLED — CSM_DEV mode only!"); rustls::ClientConfig::builder() .dangerous() .with_custom_certificate_verifier(Arc::new(NoVerifier)) .with_no_client_auth() } else { // Build standard verifier with pinning wrapper let inner = rustls::client::WebPkiServerVerifier::builder(Arc::new(root_store)) .build() .map_err(|e| anyhow::anyhow!("Failed to build TLS verifier: {:?}", e))?; let pin_file = pin_file_path(); let pinned_hashes = load_pinned_hashes(&pin_file); let verifier = PinnedCertVerifier { inner, pin_file, pinned_hashes: Arc::new(Mutex::new(pinned_hashes)), }; rustls::ClientConfig::builder() .dangerous() .with_custom_certificate_verifier(Arc::new(verifier)) .with_no_client_auth() }; let connector = tokio_rustls::TlsConnector::from(Arc::new(config)); // Extract hostname from server_addr (strip port) let domain = server_addr.split(':').next().unwrap_or("localhost").to_string(); let server_name = rustls_pki_types::ServerName::try_from(domain.clone()) .map_err(|e| anyhow::anyhow!("Invalid TLS server name '{}': {:?}", domain, e))?; let tls_stream = 
connector.connect(server_name, stream).await?; info!("TLS handshake completed with {}", domain); Ok(tls_stream) } /// Default pin file path: %PROGRAMDATA%\CSM\server_cert_pin (Windows) fn pin_file_path() -> PathBuf { if let Ok(custom) = std::env::var("CSM_TLS_PIN_FILE") { PathBuf::from(custom) } else if cfg!(target_os = "windows") { std::env::var("PROGRAMDATA") .map(|p| PathBuf::from(p).join("CSM").join("server_cert_pin")) .unwrap_or_else(|_| PathBuf::from("server_cert_pin")) } else { PathBuf::from("/var/lib/csm/server_cert_pin") } } /// Load pinned certificate hashes from file. /// Format: one hex-encoded SHA-256 hash per line. fn load_pinned_hashes(path: &PathBuf) -> Vec { match std::fs::read_to_string(path) { Ok(content) => content.lines() .map(|l| l.trim().to_string()) .filter(|l| !l.is_empty()) .collect(), Err(_) => Vec::new(), // First connection — no pin file yet } } /// Save a pinned hash to the pin file. fn save_pinned_hash(path: &PathBuf, hash: &str) { if let Some(parent) = path.parent() { let _ = std::fs::create_dir_all(parent); } let _ = std::fs::write(path, format!("{}\n", hash)); } /// Compute SHA-256 fingerprint of a DER-encoded certificate. fn cert_fingerprint(cert: &rustls_pki_types::CertificateDer) -> String { let mut hasher = Sha256::new(); hasher.update(cert.as_ref()); hex::encode(hasher.finalize()) } /// Certificate verifier with pinning support. /// On first connection (no stored pin), records the certificate fingerprint. /// On subsequent connections, verifies the fingerprint matches. #[derive(Debug)] struct PinnedCertVerifier { inner: Arc, pin_file: PathBuf, pinned_hashes: Arc>>, } impl rustls::client::danger::ServerCertVerifier for PinnedCertVerifier { fn verify_server_cert( &self, end_entity: &rustls_pki_types::CertificateDer, intermediates: &[rustls_pki_types::CertificateDer], server_name: &rustls_pki_types::ServerName, ocsp_response: &[u8], now: rustls_pki_types::UnixTime, ) -> Result { // 1. 
Standard PKIX verification self.inner.verify_server_cert(end_entity, intermediates, server_name, ocsp_response, now)?; // 2. Compute certificate fingerprint let fingerprint = cert_fingerprint(end_entity); // 3. Check against pinned hashes let mut pinned = self.pinned_hashes.lock().unwrap(); if pinned.is_empty() { // First connection — record the certificate fingerprint info!("Recording server certificate pin: {}...", &fingerprint[..16]); save_pinned_hash(&self.pin_file, &fingerprint); pinned.push(fingerprint); } else if !pinned.contains(&fingerprint) { warn!("Certificate pin mismatch! Expected one of {:?}, got {}", pinned, fingerprint); return Err(rustls::Error::General( "Server certificate does not match pinned fingerprint. Possible MITM attack.".into(), )); } Ok(rustls::client::danger::ServerCertVerified::assertion()) } fn verify_tls12_signature( &self, message: &[u8], cert: &rustls_pki_types::CertificateDer, dss: &rustls::DigitallySignedStruct, ) -> Result { self.inner.verify_tls12_signature(message, cert, dss) } fn verify_tls13_signature( &self, message: &[u8], cert: &rustls_pki_types::CertificateDer, dss: &rustls::DigitallySignedStruct, ) -> Result { self.inner.verify_tls13_signature(message, cert, dss) } fn supported_verify_schemes(&self) -> Vec { self.inner.supported_verify_schemes() } } /// Update pinned certificate hash (called when receiving TlsCertRotate). 
pub fn update_cert_pin(new_hash: &str) { let pin_file = pin_file_path(); let mut pinned = load_pinned_hashes(&pin_file); if !pinned.contains(&new_hash.to_string()) { pinned.push(new_hash.to_string()); // Keep only the last 2 hashes (current + rotating) while pinned.len() > 2 { pinned.remove(0); } // Write all hashes to file if let Some(parent) = pin_file.parent() { let _ = std::fs::create_dir_all(parent); } let content = pinned.iter().map(|h| h.as_str()).collect::>().join("\n"); let _ = std::fs::write(&pin_file, format!("{}\n", content)); info!("Updated certificate pin file with new hash: {}...", &new_hash[..16]); } } /// A no-op certificate verifier for development use (CSM_TLS_SKIP_VERIFY=true). #[derive(Debug)] struct NoVerifier; impl rustls::client::danger::ServerCertVerifier for NoVerifier { fn verify_server_cert( &self, _end_entity: &rustls_pki_types::CertificateDer, _intermediates: &[rustls_pki_types::CertificateDer], _server_name: &rustls_pki_types::ServerName, _ocsp_response: &[u8], _now: rustls_pki_types::UnixTime, ) -> Result { Ok(rustls::client::danger::ServerCertVerified::assertion()) } fn verify_tls12_signature( &self, _message: &[u8], _cert: &rustls_pki_types::CertificateDer, _dss: &rustls::DigitallySignedStruct, ) -> Result { Ok(rustls::client::danger::HandshakeSignatureValid::assertion()) } fn verify_tls13_signature( &self, _message: &[u8], _cert: &rustls_pki_types::CertificateDer, _dss: &rustls::DigitallySignedStruct, ) -> Result { Ok(rustls::client::danger::HandshakeSignatureValid::assertion()) } fn supported_verify_schemes(&self) -> Vec { vec![ rustls::SignatureScheme::RSA_PKCS1_SHA256, rustls::SignatureScheme::RSA_PKCS1_SHA384, rustls::SignatureScheme::RSA_PKCS1_SHA512, rustls::SignatureScheme::ECDSA_NISTP256_SHA256, rustls::SignatureScheme::ECDSA_NISTP384_SHA384, rustls::SignatureScheme::RSA_PSS_SHA256, rustls::SignatureScheme::RSA_PSS_SHA384, rustls::SignatureScheme::RSA_PSS_SHA512, rustls::SignatureScheme::ED25519, ] } } /// Main 
communication loop over any read+write stream async fn run_comm_loop( mut stream: S, state: &ClientState, data_rx: &mut tokio::sync::mpsc::Receiver, plugins: &PluginChannels, ) -> Result<()> where S: AsyncReadExt + AsyncWriteExt + Unpin, { // Send registration let register = RegisterRequest { device_uid: state.device_uid.clone(), hostname: hostname::get() .map(|h| h.to_string_lossy().to_string()) .unwrap_or_else(|_| "unknown".to_string()), registration_token: state.registration_token.clone(), os_version: get_os_info(), mac_address: None, }; let frame = Frame::new_json(MessageType::Register, ®ister)?; stream.write_all(&frame.encode()).await?; info!("Registration request sent"); let mut buffer = vec![0u8; 65536]; let mut read_buf = Vec::with_capacity(65536); // Clamp heartbeat interval to sane range [5, 3600] to prevent CPU spin or effective disable let heartbeat_secs = state.config.heartbeat_interval_secs.clamp(5, 3600); let mut heartbeat_interval = tokio::time::interval(Duration::from_secs(heartbeat_secs)); heartbeat_interval.tick().await; // Skip first tick // HMAC key — set after receiving RegisterResponse let mut device_secret: Option = state.device_secret.clone(); loop { tokio::select! { // Read from server result = stream.read(&mut buffer) => { let n = result?; if n == 0 { return Err(anyhow::anyhow!("Server closed connection")); } read_buf.extend_from_slice(&buffer[..n]); // Guard against unbounded buffer growth from a malicious server if read_buf.len() > 1_048_576 { return Err(anyhow::anyhow!("Read buffer exceeded 1MB, server may be malicious")); } // Process complete frames loop { match Frame::decode(&read_buf)? 
{ Some(frame) => { let consumed = frame.encoded_size(); read_buf.drain(..consumed); // Capture device_secret from registration response if frame.msg_type == MessageType::RegisterResponse { if let Ok(resp) = frame.decode_payload::() { device_secret = Some(resp.device_secret.clone()); crate::save_device_secret(&resp.device_secret); info!("Device secret received and persisted, HMAC enabled for heartbeats"); } } handle_server_message(frame, plugins)?; } None => break, // Incomplete frame, wait for more data } } } // Send queued data frame = data_rx.recv() => { let frame = frame.ok_or_else(|| anyhow::anyhow!("Channel closed"))?; stream.write_all(&frame.encode()).await?; } // Heartbeat _ = heartbeat_interval.tick() => { let timestamp = chrono::Utc::now().to_rfc3339(); let hmac_value = compute_hmac(device_secret.as_deref(), &state.device_uid, ×tamp); let heartbeat = HeartbeatPayload { device_uid: state.device_uid.clone(), timestamp, hmac: hmac_value, }; let frame = Frame::new_json(MessageType::Heartbeat, &heartbeat)?; stream.write_all(&frame.encode()).await?; debug!("Heartbeat sent (hmac={})", !heartbeat.hmac.is_empty()); } } } } fn handle_server_message(frame: Frame, plugins: &PluginChannels) -> Result<()> { match frame.msg_type { MessageType::RegisterResponse => { let resp: RegisterResponse = frame.decode_payload() .map_err(|e| anyhow::anyhow!("Invalid registration response: {}", e))?; info!("Registration accepted by server (server version: {})", resp.config.server_version); } MessageType::PolicyUpdate => { let policy: serde_json::Value = frame.decode_payload() .map_err(|e| anyhow::anyhow!("Invalid policy update: {}", e))?; info!("Received policy update: {}", policy); } MessageType::ConfigUpdate => { let update: csm_protocol::ConfigUpdateType = frame.decode_payload() .map_err(|e| anyhow::anyhow!("Invalid config update: {}", e))?; match update { csm_protocol::ConfigUpdateType::UpdateIntervals { heartbeat, status, asset } => { info!("Config update: intervals heartbeat={}s 
status={}s asset={}s", heartbeat, status, asset); } csm_protocol::ConfigUpdateType::TlsCertRotate { new_cert_hash, valid_until } => { info!("Certificate rotation: new hash={}... valid_until={}", &new_cert_hash[..16.min(new_cert_hash.len())], valid_until); update_cert_pin(&new_cert_hash); } csm_protocol::ConfigUpdateType::SelfDestruct => { warn!("Self-destruct command received (not implemented)"); } } } MessageType::TaskExecute => { warn!("Task execution requested (not yet implemented)"); } MessageType::WatermarkConfig => { let config: WatermarkConfigPayload = frame.decode_payload() .map_err(|e| anyhow::anyhow!("Invalid watermark config: {}", e))?; info!("Received watermark config: enabled={}", config.enabled); plugins.watermark_tx.send(Some(config))?; } MessageType::UsbPolicyUpdate => { let policy: UsbPolicyPayload = frame.decode_payload() .map_err(|e| anyhow::anyhow!("Invalid USB policy: {}", e))?; info!("Received USB policy: type={}, enabled={}", policy.policy_type, policy.enabled); plugins.usb_policy_tx.send(Some(policy))?; } MessageType::WebFilterRuleUpdate => { let payload: serde_json::Value = frame.decode_payload() .map_err(|e| anyhow::anyhow!("Invalid web filter update: {}", e))?; info!("Received web filter rules update"); let rules: Vec = payload.get("rules") .and_then(|r| serde_json::from_value(r.clone()).ok()) .unwrap_or_default(); let config = crate::web_filter::WebFilterConfig { enabled: true, rules }; plugins.web_filter_tx.send(config)?; } MessageType::SoftwareBlacklist => { let payload: serde_json::Value = frame.decode_payload() .map_err(|e| anyhow::anyhow!("Invalid software blacklist: {}", e))?; info!("Received software blacklist update"); let blacklist: Vec = payload.get("blacklist") .and_then(|r| serde_json::from_value(r.clone()).ok()) .unwrap_or_default(); let whitelist: Vec = payload.get("whitelist") .and_then(|r| serde_json::from_value(r.clone()).ok()) .unwrap_or_default(); let config = crate::software_blocker::SoftwareBlockerConfig { enabled: 
true, blacklist, whitelist, }; plugins.software_blocker_tx.send(config)?; } MessageType::PopupRules => { let payload: serde_json::Value = frame.decode_payload() .map_err(|e| anyhow::anyhow!("Invalid popup rules: {}", e))?; info!("Received popup blocker rules update"); let rules: Vec = payload.get("rules") .and_then(|r| serde_json::from_value(r.clone()).ok()) .unwrap_or_default(); let config = crate::popup_blocker::PopupBlockerConfig { enabled: true, rules }; plugins.popup_blocker_tx.send(config)?; } MessageType::DiskEncryptionConfig => { let config: DiskEncryptionConfigPayload = frame.decode_payload() .map_err(|e| anyhow::anyhow!("Invalid disk encryption config: {}", e))?; info!("Received disk encryption config: enabled={}, interval={}s", config.enabled, config.report_interval_secs); let plugin_config = crate::disk_encryption::DiskEncryptionConfig { enabled: config.enabled, report_interval_secs: config.report_interval_secs, }; plugins.disk_encryption_tx.send(plugin_config)?; } MessageType::PluginEnable => { let payload: csm_protocol::PluginControlPayload = frame.decode_payload() .map_err(|e| anyhow::anyhow!("Invalid plugin enable: {}", e))?; info!("Plugin enabled: {}", payload.plugin_name); // Route to appropriate plugin channel based on plugin_name handle_plugin_control(&payload, plugins, true)?; } MessageType::PluginDisable => { let payload: csm_protocol::PluginControlPayload = frame.decode_payload() .map_err(|e| anyhow::anyhow!("Invalid plugin disable: {}", e))?; info!("Plugin disabled: {}", payload.plugin_name); handle_plugin_control(&payload, plugins, false)?; } MessageType::ClipboardRules => { let payload: csm_protocol::ClipboardRulesPayload = frame.decode_payload() .map_err(|e| anyhow::anyhow!("Invalid clipboard rules: {}", e))?; info!("Received clipboard rules update: {} rules", payload.rules.len()); let config = crate::clipboard_control::ClipboardControlConfig { enabled: true, rules: payload.rules, }; plugins.clipboard_control_tx.send(config)?; } 
MessageType::PatchScanConfig => { let config: PatchScanConfigPayload = frame.decode_payload() .map_err(|e| anyhow::anyhow!("Invalid patch scan config: {}", e))?; info!("Received patch scan config: enabled={}, interval={}s", config.enabled, config.scan_interval_secs); let plugin_config = crate::patch::PluginConfig { enabled: config.enabled, scan_interval_secs: config.scan_interval_secs, }; plugins.patch_tx.send(plugin_config)?; } _ => { debug!("Unhandled message type: {:?}", frame.msg_type); } } Ok(()) } fn handle_plugin_control( payload: &csm_protocol::PluginControlPayload, plugins: &PluginChannels, enabled: bool, ) -> Result<()> { match payload.plugin_name.as_str() { "watermark" => { if !enabled { // Send disabled config to remove overlay plugins.watermark_tx.send(None)?; } // When enabling, server will push the actual config next } "web_filter" => { if !enabled { // Clear hosts rules on disable plugins.web_filter_tx.send(crate::web_filter::WebFilterConfig { enabled: false, rules: vec![] })?; } // When enabling, server will push rules } "software_blocker" => { if !enabled { plugins.software_blocker_tx.send(crate::software_blocker::SoftwareBlockerConfig { enabled: false, blacklist: vec![], whitelist: vec![] })?; } } "popup_blocker" => { if !enabled { plugins.popup_blocker_tx.send(crate::popup_blocker::PopupBlockerConfig { enabled: false, rules: vec![] })?; } } "usb_audit" => { if !enabled { plugins.usb_audit_tx.send(crate::usb_audit::UsbAuditConfig { enabled: false, monitored_extensions: vec![] })?; } } "usage_timer" => { if !enabled { plugins.usage_timer_tx.send(crate::usage_timer::UsageConfig { enabled: false, ..Default::default() })?; } } "disk_encryption" => { if !enabled { plugins.disk_encryption_tx.send(crate::disk_encryption::DiskEncryptionConfig { enabled: false, ..Default::default() })?; } } "print_audit" => { if !enabled { plugins.print_audit_tx.send(crate::print_audit::PrintAuditConfig { enabled: false, ..Default::default() })?; } } "clipboard_control" 
=> { if !enabled { plugins.clipboard_control_tx.send(crate::clipboard_control::ClipboardControlConfig { enabled: false, ..Default::default() })?; } } "patch" => { if !enabled { plugins.patch_tx.send(crate::patch::PluginConfig { enabled: false, scan_interval_secs: 43200 })?; } } _ => { warn!("Unknown plugin: {}", payload.plugin_name); } } Ok(()) } /// Compute HMAC-SHA256 for heartbeat verification. /// Format: HMAC-SHA256(device_secret, "{device_uid}\n{timestamp}") fn compute_hmac(secret: Option<&str>, device_uid: &str, timestamp: &str) -> String { let secret = match secret { Some(s) if !s.is_empty() => s, _ => return String::new(), }; type HmacSha256 = Hmac; let message = format!("{}\n{}", device_uid, timestamp); let mut mac = match HmacSha256::new_from_slice(secret.as_bytes()) { Ok(m) => m, Err(_) => return String::new(), }; mac.update(message.as_bytes()); hex::encode(mac.finalize().into_bytes()) } fn get_os_info() -> String { use sysinfo::System; let name = System::name().unwrap_or_else(|| "Unknown".to_string()); let version = System::os_version().unwrap_or_else(|| "Unknown".to_string()); format!("{} {}", name, version) }