Files
csm/crates/client/src/network/mod.rs
iven 60ee38a3c2 feat: 新增补丁管理和异常检测插件及相关功能
feat(protocol): 添加补丁管理和行为指标协议类型
feat(client): 实现补丁管理插件采集功能
feat(server): 添加补丁管理和异常检测API
feat(database): 新增补丁状态和异常检测相关表
feat(web): 添加补丁管理和异常检测前端页面
fix(security): 增强输入验证和防注入保护
refactor(auth): 重构认证检查逻辑
perf(service): 优化Windows服务恢复策略
style: 统一健康评分显示样式
docs: 更新知识库文档
2026-04-11 15:59:53 +08:00

599 lines
25 KiB
Rust

use anyhow::Result;
use std::path::PathBuf;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use tokio::net::TcpStream;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tracing::{info, debug, warn};
use csm_protocol::{Frame, MessageType, RegisterRequest, RegisterResponse, HeartbeatPayload, WatermarkConfigPayload, UsbPolicyPayload, DiskEncryptionConfigPayload, PatchScanConfigPayload};
use hmac::{Hmac, Mac};
use sha2::{Sha256, Digest};
use crate::ClientState;
/// Sending halves of the `tokio::sync::watch` channels that carry configuration to each plugin.
///
/// The communication loop publishes the most recent configuration received from the server
/// (or a "disabled" configuration on a `PluginDisable` command) on the matching sender.
/// Each plugin presumably holds the corresponding receiver — TODO confirm against the code
/// that spawns the plugins.
pub struct PluginChannels {
    /// Watermark overlay configuration; `None` is sent when the watermark plugin is disabled.
    pub watermark_tx: tokio::sync::watch::Sender<Option<WatermarkConfigPayload>>,
    /// Web filter enabled flag and rule list.
    pub web_filter_tx: tokio::sync::watch::Sender<crate::web_filter::WebFilterConfig>,
    /// Software blocker enabled flag plus blacklist/whitelist.
    pub software_blocker_tx: tokio::sync::watch::Sender<crate::software_blocker::SoftwareBlockerConfig>,
    /// Popup blocker enabled flag and rule list.
    pub popup_blocker_tx: tokio::sync::watch::Sender<crate::popup_blocker::PopupBlockerConfig>,
    /// USB audit configuration (only a disabled config is sent from this module).
    pub usb_audit_tx: tokio::sync::watch::Sender<crate::usb_audit::UsbAuditConfig>,
    /// Usage timer configuration (only a disabled config is sent from this module).
    pub usage_timer_tx: tokio::sync::watch::Sender<crate::usage_timer::UsageConfig>,
    /// USB device policy pushed by the server.
    pub usb_policy_tx: tokio::sync::watch::Sender<Option<UsbPolicyPayload>>,
    /// Disk encryption reporting configuration (enabled flag and report interval).
    pub disk_encryption_tx: tokio::sync::watch::Sender<crate::disk_encryption::DiskEncryptionConfig>,
    /// Print audit configuration (only a disabled config is sent from this module).
    pub print_audit_tx: tokio::sync::watch::Sender<crate::print_audit::PrintAuditConfig>,
    /// Clipboard control enabled flag and rule list.
    pub clipboard_control_tx: tokio::sync::watch::Sender<crate::clipboard_control::ClipboardControlConfig>,
    /// Patch scanning configuration (enabled flag and scan interval in seconds).
    pub patch_tx: tokio::sync::watch::Sender<crate::patch::PluginConfig>,
}
/// Open a TCP connection to `state.server_addr` and run the protocol loop on it.
///
/// When `state.use_tls` is set, the socket is first upgraded to TLS (with certificate pinning)
/// before the loop starts; otherwise the loop runs over the raw TCP stream.
///
/// # Errors
/// Returns connection, TLS handshake, and communication-loop errors unchanged.
pub async fn connect_and_run(
    state: &ClientState,
    data_rx: &mut tokio::sync::mpsc::Receiver<Frame>,
    plugins: &PluginChannels,
) -> Result<()> {
    let addr = &state.server_addr;
    let tcp = TcpStream::connect(addr).await?;
    debug!("TCP connected to {}", addr);

    if !state.use_tls {
        return run_comm_loop(tcp, state, data_rx, plugins).await;
    }

    let secured = wrap_tls(tcp, addr).await?;
    run_comm_loop(secured, state, data_rx, plugins).await
}
/// Wrap a TCP stream in TLS, verifying the server with PKIX validation plus certificate pinning.
///
/// Trust anchors are the `webpki_roots` set bundled into the binary plus, when
/// `CSM_TLS_CA_CERT` names a PEM file, every certificate in that file. Certificate
/// verification is skipped entirely only when both `CSM_TLS_SKIP_VERIFY=true` and `CSM_DEV`
/// (any value) are set.
///
/// `server_addr` is the same `host:port` string used for the TCP connection. Its host part,
/// with IPv6 brackets removed, becomes the TLS server name.
///
/// # Errors
/// Fails if the CA file cannot be read or parsed, the verifier cannot be built, the host is
/// neither a valid DNS name nor an IP address, or the handshake fails (including a pin mismatch).
async fn wrap_tls(stream: TcpStream, server_addr: &str) -> Result<tokio_rustls::client::TlsStream<TcpStream>> {
    let mut root_store = rustls::RootCertStore::empty();
    // Optional additional CA certificates from a PEM file
    if let Ok(ca_path) = std::env::var("CSM_TLS_CA_CERT") {
        let ca_pem = std::fs::read(&ca_path)
            .map_err(|e| anyhow::anyhow!("Failed to read CA cert {}: {}", ca_path, e))?;
        let certs = rustls_pemfile::certs(&mut &ca_pem[..])
            .collect::<Result<Vec<_>, _>>()
            .map_err(|e| anyhow::anyhow!("Failed to parse CA cert: {:?}", e))?;
        for cert in certs {
            root_store.add(cert)?;
        }
        info!("Loaded custom CA certificates from {}", ca_path);
    }
    // Always trust the bundled Mozilla root set from `webpki_roots` as well (compiled into the
    // binary; this is not the operating system's certificate store).
    root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());

    // Skipping verification requires two explicit opt-ins (CSM_TLS_SKIP_VERIFY=true + CSM_DEV)
    let skip_verify = std::env::var("CSM_TLS_SKIP_VERIFY").as_deref() == Ok("true")
        && std::env::var("CSM_DEV").is_ok();
    let config = if skip_verify {
        warn!("TLS certificate verification DISABLED — CSM_DEV mode only!");
        rustls::ClientConfig::builder()
            .dangerous()
            .with_custom_certificate_verifier(Arc::new(NoVerifier))
            .with_no_client_auth()
    } else {
        // Standard WebPKI verifier, wrapped with fingerprint pinning
        let inner = rustls::client::WebPkiServerVerifier::builder(Arc::new(root_store))
            .build()
            .map_err(|e| anyhow::anyhow!("Failed to build TLS verifier: {:?}", e))?;
        let pin_file = pin_file_path();
        let pinned_hashes = load_pinned_hashes(&pin_file);
        let verifier = PinnedCertVerifier {
            inner,
            pin_file,
            pinned_hashes: Arc::new(Mutex::new(pinned_hashes)),
        };
        rustls::ClientConfig::builder()
            .dangerous()
            .with_custom_certificate_verifier(Arc::new(verifier))
            .with_no_client_auth()
    };

    let connector = tokio_rustls::TlsConnector::from(Arc::new(config));
    let domain = tls_host_from_addr(server_addr).to_string();
    let server_name = rustls_pki_types::ServerName::try_from(domain.clone())
        .map_err(|e| anyhow::anyhow!("Invalid TLS server name '{}': {:?}", domain, e))?;
    let tls_stream = connector.connect(server_name, stream).await?;
    info!("TLS handshake completed with {}", domain);
    Ok(tls_stream)
}

/// Extract the host part of a connection address for use as the TLS server name.
///
/// Handles `host:port`, bracketed IPv6 `[addr]:port` (brackets are stripped), and unbracketed
/// `v6addr:port` (split at the last colon). A string without a trailing numeric port is
/// returned as-is, with surrounding brackets removed if present.
fn tls_host_from_addr(server_addr: &str) -> &str {
    let host = match server_addr.rsplit_once(':') {
        Some((host, port)) if !port.is_empty() && port.bytes().all(|b| b.is_ascii_digit()) => host,
        _ => server_addr,
    };
    host.strip_prefix('[')
        .and_then(|inner| inner.strip_suffix(']'))
        .unwrap_or(host)
}
/// Resolve where the server certificate pin file lives.
///
/// Resolution order:
/// 1. `CSM_TLS_PIN_FILE`, if set, is used verbatim.
/// 2. On Windows: `%PROGRAMDATA%\CSM\server_cert_pin`, or `server_cert_pin` relative to the
///    working directory when `PROGRAMDATA` is not set.
/// 3. Elsewhere: `/var/lib/csm/server_cert_pin`.
fn pin_file_path() -> PathBuf {
    if let Ok(override_path) = std::env::var("CSM_TLS_PIN_FILE") {
        return PathBuf::from(override_path);
    }
    if !cfg!(target_os = "windows") {
        return PathBuf::from("/var/lib/csm/server_cert_pin");
    }
    match std::env::var("PROGRAMDATA") {
        Ok(program_data) => PathBuf::from(program_data).join("CSM").join("server_cert_pin"),
        Err(_) => PathBuf::from("server_cert_pin"),
    }
}
/// Load pinned certificate fingerprints from `path`.
///
/// File format: one hex-encoded SHA-256 hash per line. Surrounding whitespace and blank lines
/// are ignored. Hashes are lowercased so they compare equal to the lowercase hex produced by
/// `cert_fingerprint`, even when the file was provisioned with upper-case hex.
///
/// Any read error yields an empty list, which the pinning verifier treats as "no pin yet"
/// (trust on first use).
/// NOTE(review): that includes errors other than "file not found" (e.g. permission denied),
/// which silently re-enables trust-on-first-use; consider failing closed for those.
fn load_pinned_hashes(path: &std::path::Path) -> Vec<String> {
    match std::fs::read_to_string(path) {
        Ok(content) => content
            .lines()
            .map(str::trim)
            .filter(|line| !line.is_empty())
            .map(str::to_ascii_lowercase)
            .collect(),
        // First connection — no pin file yet
        Err(_) => Vec::new(),
    }
}
/// Persist `hash` as the only pinned fingerprint in `path`, replacing any previous content.
///
/// Best-effort: I/O failures are ignored. Missing parent directories are created.
/// The content goes to a sibling `*.tmp` file that is then renamed over `path`. A crash
/// mid-write therefore cannot leave a truncated pin, which would make every future
/// certificate mismatch. If the atomic path fails, a direct write is attempted instead.
fn save_pinned_hash(path: &std::path::Path, hash: &str) {
    if let Some(parent) = path.parent() {
        let _ = std::fs::create_dir_all(parent);
    }
    let content = format!("{}\n", hash);
    let tmp = path.with_extension("tmp");
    let atomic_ok = std::fs::write(&tmp, &content).is_ok() && std::fs::rename(&tmp, path).is_ok();
    if !atomic_ok {
        // Clean up any leftover temp file and fall back to writing the target directly
        let _ = std::fs::remove_file(&tmp);
        let _ = std::fs::write(path, &content);
    }
}
/// Return the lowercase hex SHA-256 digest of a DER-encoded certificate (64 characters).
fn cert_fingerprint(cert: &rustls_pki_types::CertificateDer) -> String {
    let der_bytes: &[u8] = cert.as_ref();
    let digest = Sha256::digest(der_bytes);
    hex::encode(digest)
}
/// Certificate verifier that adds trust-on-first-use pinning on top of standard PKIX checks.
///
/// Every certificate must first pass `inner` (WebPKI chain and server-name verification).
/// Its SHA-256 fingerprint is then compared with `pinned_hashes`:
/// - no pins yet → the fingerprint is recorded in memory and written to `pin_file`;
/// - pins present → the fingerprint must match one of them, otherwise the handshake fails.
///
/// Stored pins are compared ignoring ASCII case and `:` separators. A hand-provisioned pin in
/// upper-case or colon-separated hex therefore matches the lowercase hex from `cert_fingerprint`.
#[derive(Debug)]
struct PinnedCertVerifier {
    /// Standard WebPKI verifier; must accept the certificate before pins are consulted.
    inner: Arc<rustls::client::WebPkiServerVerifier>,
    /// File that receives the fingerprint recorded on first use.
    pin_file: PathBuf,
    /// Accepted fingerprints; a first-use fingerprint is appended here.
    pinned_hashes: Arc<Mutex<Vec<String>>>,
}

/// Whether a stored pin equals `fingerprint` (lowercase hex), ignoring ASCII case and `:`.
fn pin_matches(pin: &str, fingerprint: &str) -> bool {
    pin.bytes()
        .filter(|&b| b != b':')
        .map(|b| b.to_ascii_lowercase())
        .eq(fingerprint.bytes())
}

impl rustls::client::danger::ServerCertVerifier for PinnedCertVerifier {
    fn verify_server_cert(
        &self,
        end_entity: &rustls_pki_types::CertificateDer,
        intermediates: &[rustls_pki_types::CertificateDer],
        server_name: &rustls_pki_types::ServerName,
        ocsp_response: &[u8],
        now: rustls_pki_types::UnixTime,
    ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
        // 1. Standard PKIX verification
        self.inner.verify_server_cert(end_entity, intermediates, server_name, ocsp_response, now)?;
        // 2. Compute certificate fingerprint (64 lowercase hex chars)
        let fingerprint = cert_fingerprint(end_entity);
        // 3. Check against pinned hashes. The Vec is never left half-updated, so a poisoned
        //    mutex is safe to recover instead of panicking inside the TLS handshake.
        let mut pinned = self
            .pinned_hashes
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        if pinned.is_empty() {
            // First connection — record the certificate fingerprint
            info!("Recording server certificate pin: {}...", &fingerprint[..16]);
            save_pinned_hash(&self.pin_file, &fingerprint);
            pinned.push(fingerprint);
        } else if !pinned.iter().any(|pin| pin_matches(pin, &fingerprint)) {
            warn!("Certificate pin mismatch! Expected one of {:?}, got {}", pinned, fingerprint);
            return Err(rustls::Error::General(
                "Server certificate does not match pinned fingerprint. Possible MITM attack.".into(),
            ));
        }
        Ok(rustls::client::danger::ServerCertVerified::assertion())
    }

    fn verify_tls12_signature(
        &self,
        message: &[u8],
        cert: &rustls_pki_types::CertificateDer,
        dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        self.inner.verify_tls12_signature(message, cert, dss)
    }

    fn verify_tls13_signature(
        &self,
        message: &[u8],
        cert: &rustls_pki_types::CertificateDer,
        dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        self.inner.verify_tls13_signature(message, cert, dss)
    }

    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
        self.inner.supported_verify_schemes()
    }
}
/// Update pinned certificate hash (called when receiving TlsCertRotate).
pub fn update_cert_pin(new_hash: &str) {
let pin_file = pin_file_path();
let mut pinned = load_pinned_hashes(&pin_file);
if !pinned.contains(&new_hash.to_string()) {
pinned.push(new_hash.to_string());
// Keep only the last 2 hashes (current + rotating)
while pinned.len() > 2 {
pinned.remove(0);
}
// Write all hashes to file
if let Some(parent) = pin_file.parent() {
let _ = std::fs::create_dir_all(parent);
}
let content = pinned.iter().map(|h| h.as_str()).collect::<Vec<_>>().join("\n");
let _ = std::fs::write(&pin_file, format!("{}\n", content));
info!("Updated certificate pin file with new hash: {}...", &new_hash[..16]);
}
}
/// A no-op certificate verifier for development use (CSM_TLS_SKIP_VERIFY=true).
#[derive(Debug)]
struct NoVerifier;
impl rustls::client::danger::ServerCertVerifier for NoVerifier {
fn verify_server_cert(
&self,
_end_entity: &rustls_pki_types::CertificateDer,
_intermediates: &[rustls_pki_types::CertificateDer],
_server_name: &rustls_pki_types::ServerName,
_ocsp_response: &[u8],
_now: rustls_pki_types::UnixTime,
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
Ok(rustls::client::danger::ServerCertVerified::assertion())
}
fn verify_tls12_signature(
&self,
_message: &[u8],
_cert: &rustls_pki_types::CertificateDer,
_dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}
fn verify_tls13_signature(
&self,
_message: &[u8],
_cert: &rustls_pki_types::CertificateDer,
_dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
vec![
rustls::SignatureScheme::RSA_PKCS1_SHA256,
rustls::SignatureScheme::RSA_PKCS1_SHA384,
rustls::SignatureScheme::RSA_PKCS1_SHA512,
rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
rustls::SignatureScheme::RSA_PSS_SHA256,
rustls::SignatureScheme::RSA_PSS_SHA384,
rustls::SignatureScheme::RSA_PSS_SHA512,
rustls::SignatureScheme::ED25519,
]
}
}
/// Main communication loop over an established plain-TCP or TLS stream.
///
/// Sends a `Register` frame first, then multiplexes three event sources with `tokio::select!`:
/// - bytes from the server, reassembled into frames and dispatched to `handle_server_message`
///   (a `RegisterResponse` additionally stores and persists the device secret);
/// - frames queued on `data_rx`, forwarded verbatim;
/// - the heartbeat timer, emitting a `Heartbeat` every `heartbeat_interval_secs` (clamped to
///   5..=3600 s), HMAC-signed once a device secret is known.
///
/// Every write is followed by a flush (see `write_frame`), so TLS records are not left
/// buffered after `write_all` returns.
///
/// # Errors
/// Never returns `Ok`. It returns an error when:
/// - the server closes the connection;
/// - the read buffer exceeds 1 MiB;
/// - a frame fails to decode or be handled;
/// - `data_rx` is closed;
/// - any read or write fails.
async fn run_comm_loop<S>(
    mut stream: S,
    state: &ClientState,
    data_rx: &mut tokio::sync::mpsc::Receiver<Frame>,
    plugins: &PluginChannels,
) -> Result<()>
where
    S: AsyncReadExt + AsyncWriteExt + Unpin,
{
    // Send registration; the server is expected to reply with a RegisterResponse
    let register = RegisterRequest {
        device_uid: state.device_uid.clone(),
        hostname: hostname::get()
            .map(|h| h.to_string_lossy().to_string())
            .unwrap_or_else(|_| "unknown".to_string()),
        registration_token: state.registration_token.clone(),
        os_version: get_os_info(),
        mac_address: None,
    };
    let frame = Frame::new_json(MessageType::Register, &register)?;
    write_frame(&mut stream, frame).await?;
    info!("Registration request sent");

    let mut buffer = vec![0u8; 65536];
    let mut read_buf = Vec::with_capacity(65536);
    // Clamp heartbeat interval to sane range [5, 3600] to prevent CPU spin or effective disable
    let heartbeat_secs = state.config.heartbeat_interval_secs.clamp(5, 3600);
    let mut heartbeat_interval = tokio::time::interval(Duration::from_secs(heartbeat_secs));
    // After a stall (e.g. a slow write), resume the normal cadence instead of firing one
    // back-to-back heartbeat per missed tick (tokio's default Burst behaviour).
    heartbeat_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
    heartbeat_interval.tick().await; // Skip first tick (it completes immediately)
    // HMAC key — set after receiving RegisterResponse
    let mut device_secret: Option<String> = state.device_secret.clone();

    loop {
        tokio::select! {
            // Read from server
            result = stream.read(&mut buffer) => {
                let n = result?;
                if n == 0 {
                    return Err(anyhow::anyhow!("Server closed connection"));
                }
                read_buf.extend_from_slice(&buffer[..n]);
                // Guard against unbounded buffer growth from a malicious server
                if read_buf.len() > 1_048_576 {
                    return Err(anyhow::anyhow!("Read buffer exceeded 1MB, server may be malicious"));
                }
                // Process complete frames
                loop {
                    match Frame::decode(&read_buf)? {
                        Some(frame) => {
                            let consumed = frame.encoded_size();
                            read_buf.drain(..consumed);
                            // Capture device_secret from registration response
                            if frame.msg_type == MessageType::RegisterResponse {
                                if let Ok(resp) = frame.decode_payload::<RegisterResponse>() {
                                    device_secret = Some(resp.device_secret.clone());
                                    crate::save_device_secret(&resp.device_secret);
                                    info!("Device secret received and persisted, HMAC enabled for heartbeats");
                                }
                            }
                            handle_server_message(frame, plugins)?;
                        }
                        None => break, // Incomplete frame, wait for more data
                    }
                }
            }
            // Send queued data
            frame = data_rx.recv() => {
                let frame = frame.ok_or_else(|| anyhow::anyhow!("Channel closed"))?;
                write_frame(&mut stream, frame).await?;
            }
            // Heartbeat
            _ = heartbeat_interval.tick() => {
                let timestamp = chrono::Utc::now().to_rfc3339();
                let hmac_value = compute_hmac(device_secret.as_deref(), &state.device_uid, &timestamp);
                let heartbeat = HeartbeatPayload {
                    device_uid: state.device_uid.clone(),
                    timestamp,
                    hmac: hmac_value,
                };
                let frame = Frame::new_json(MessageType::Heartbeat, &heartbeat)?;
                write_frame(&mut stream, frame).await?;
                debug!("Heartbeat sent (hmac={})", !heartbeat.hmac.is_empty());
            }
        }
    }
}

/// Encode `frame`, write it to `stream`, and flush.
///
/// The flush matters for TLS: tokio-rustls can report `write_all` as complete while encrypted
/// records are still buffered inside the TLS session, and `flush` pushes them to the socket.
/// For a plain `TcpStream` the flush is a no-op.
async fn write_frame<S>(stream: &mut S, frame: Frame) -> Result<()>
where
    S: AsyncWriteExt + Unpin,
{
    stream.write_all(&frame.encode()).await?;
    stream.flush().await?;
    Ok(())
}
/// Dispatch a single frame received from the server.
///
/// How each kind of message is handled:
/// - configuration-bearing messages are decoded and forwarded to the matching plugin's watch
///   channel in `plugins`;
/// - informational messages are only logged;
/// - unhandled message types are logged at debug level.
///
/// List fields in JSON rule payloads (`rules`, `blacklist`, `whitelist`) that are missing or
/// `null` count as empty lists. A field that is present but malformed causes the update to be
/// logged and skipped, so the plugin keeps its current rules. Without this, the rules would
/// silently be replaced by an empty, still-enabled rule set.
///
/// # Errors
/// Returns an error if a frame's payload does not decode as the type expected for its message
/// type, or if sending on a plugin channel fails.
fn handle_server_message(frame: Frame, plugins: &PluginChannels) -> Result<()> {
    match frame.msg_type {
        MessageType::RegisterResponse => {
            let resp: RegisterResponse = frame.decode_payload()
                .map_err(|e| anyhow::anyhow!("Invalid registration response: {}", e))?;
            info!("Registration accepted by server (server version: {})", resp.config.server_version);
        }
        MessageType::PolicyUpdate => {
            let policy: serde_json::Value = frame.decode_payload()
                .map_err(|e| anyhow::anyhow!("Invalid policy update: {}", e))?;
            info!("Received policy update: {}", policy);
        }
        MessageType::ConfigUpdate => {
            let update: csm_protocol::ConfigUpdateType = frame.decode_payload()
                .map_err(|e| anyhow::anyhow!("Invalid config update: {}", e))?;
            match update {
                csm_protocol::ConfigUpdateType::UpdateIntervals { heartbeat, status, asset } => {
                    info!("Config update: intervals heartbeat={}s status={}s asset={}s", heartbeat, status, asset);
                }
                csm_protocol::ConfigUpdateType::TlsCertRotate { new_cert_hash, valid_until } => {
                    // `get` avoids a panic when byte 16 is not a char boundary (non-ASCII input)
                    let hash_prefix = new_cert_hash.get(..16).unwrap_or(new_cert_hash.as_str());
                    info!("Certificate rotation: new hash={}... valid_until={}", hash_prefix, valid_until);
                    update_cert_pin(&new_cert_hash);
                }
                csm_protocol::ConfigUpdateType::SelfDestruct => {
                    warn!("Self-destruct command received (not implemented)");
                }
            }
        }
        MessageType::TaskExecute => {
            warn!("Task execution requested (not yet implemented)");
        }
        MessageType::WatermarkConfig => {
            let config: WatermarkConfigPayload = frame.decode_payload()
                .map_err(|e| anyhow::anyhow!("Invalid watermark config: {}", e))?;
            info!("Received watermark config: enabled={}", config.enabled);
            plugins.watermark_tx.send(Some(config))?;
        }
        MessageType::UsbPolicyUpdate => {
            let policy: UsbPolicyPayload = frame.decode_payload()
                .map_err(|e| anyhow::anyhow!("Invalid USB policy: {}", e))?;
            info!("Received USB policy: type={}, enabled={}", policy.policy_type, policy.enabled);
            plugins.usb_policy_tx.send(Some(policy))?;
        }
        MessageType::WebFilterRuleUpdate => {
            let payload: serde_json::Value = frame.decode_payload()
                .map_err(|e| anyhow::anyhow!("Invalid web filter update: {}", e))?;
            info!("Received web filter rules update");
            let rules: Option<Vec<crate::web_filter::WebFilterRule>> =
                list_field(&payload, "rules", "web filter", serde_json::from_value);
            if let Some(rules) = rules {
                let config = crate::web_filter::WebFilterConfig { enabled: true, rules };
                plugins.web_filter_tx.send(config)?;
            }
        }
        MessageType::SoftwareBlacklist => {
            let payload: serde_json::Value = frame.decode_payload()
                .map_err(|e| anyhow::anyhow!("Invalid software blacklist: {}", e))?;
            info!("Received software blacklist update");
            let blacklist: Option<Vec<crate::software_blocker::BlacklistEntry>> =
                list_field(&payload, "blacklist", "software blacklist", serde_json::from_value);
            let whitelist: Option<Vec<String>> =
                list_field(&payload, "whitelist", "software blacklist", serde_json::from_value);
            if let (Some(blacklist), Some(whitelist)) = (blacklist, whitelist) {
                let config = crate::software_blocker::SoftwareBlockerConfig {
                    enabled: true,
                    blacklist,
                    whitelist,
                };
                plugins.software_blocker_tx.send(config)?;
            }
        }
        MessageType::PopupRules => {
            let payload: serde_json::Value = frame.decode_payload()
                .map_err(|e| anyhow::anyhow!("Invalid popup rules: {}", e))?;
            info!("Received popup blocker rules update");
            let rules: Option<Vec<crate::popup_blocker::PopupRule>> =
                list_field(&payload, "rules", "popup blocker", serde_json::from_value);
            if let Some(rules) = rules {
                let config = crate::popup_blocker::PopupBlockerConfig { enabled: true, rules };
                plugins.popup_blocker_tx.send(config)?;
            }
        }
        MessageType::DiskEncryptionConfig => {
            let config: DiskEncryptionConfigPayload = frame.decode_payload()
                .map_err(|e| anyhow::anyhow!("Invalid disk encryption config: {}", e))?;
            info!("Received disk encryption config: enabled={}, interval={}s", config.enabled, config.report_interval_secs);
            let plugin_config = crate::disk_encryption::DiskEncryptionConfig {
                enabled: config.enabled,
                report_interval_secs: config.report_interval_secs,
            };
            plugins.disk_encryption_tx.send(plugin_config)?;
        }
        MessageType::PluginEnable => {
            let payload: csm_protocol::PluginControlPayload = frame.decode_payload()
                .map_err(|e| anyhow::anyhow!("Invalid plugin enable: {}", e))?;
            info!("Plugin enabled: {}", payload.plugin_name);
            // Route to appropriate plugin channel based on plugin_name
            handle_plugin_control(&payload, plugins, true)?;
        }
        MessageType::PluginDisable => {
            let payload: csm_protocol::PluginControlPayload = frame.decode_payload()
                .map_err(|e| anyhow::anyhow!("Invalid plugin disable: {}", e))?;
            info!("Plugin disabled: {}", payload.plugin_name);
            handle_plugin_control(&payload, plugins, false)?;
        }
        MessageType::ClipboardRules => {
            let payload: csm_protocol::ClipboardRulesPayload = frame.decode_payload()
                .map_err(|e| anyhow::anyhow!("Invalid clipboard rules: {}", e))?;
            info!("Received clipboard rules update: {} rules", payload.rules.len());
            let config = crate::clipboard_control::ClipboardControlConfig {
                enabled: true,
                rules: payload.rules,
            };
            plugins.clipboard_control_tx.send(config)?;
        }
        MessageType::PatchScanConfig => {
            let config: PatchScanConfigPayload = frame.decode_payload()
                .map_err(|e| anyhow::anyhow!("Invalid patch scan config: {}", e))?;
            info!("Received patch scan config: enabled={}, interval={}s", config.enabled, config.scan_interval_secs);
            let plugin_config = crate::patch::PluginConfig {
                enabled: config.enabled,
                scan_interval_secs: config.scan_interval_secs,
            };
            plugins.patch_tx.send(plugin_config)?;
        }
        _ => {
            debug!("Unhandled message type: {:?}", frame.msg_type);
        }
    }
    Ok(())
}

/// Decode the optional list stored under `key` in a JSON rule payload.
///
/// A missing or `null` key yields an empty list. A present but malformed value is logged, with
/// `what` naming the update, and yields `None`. The caller can then skip the update rather
/// than wipe the plugin's current rules.
fn list_field<T>(
    payload: &serde_json::Value,
    key: &str,
    what: &str,
    decode: impl FnOnce(serde_json::Value) -> serde_json::Result<Vec<T>>,
) -> Option<Vec<T>> {
    match payload.get(key) {
        None | Some(serde_json::Value::Null) => Some(Vec::new()),
        Some(value) => match decode(value.clone()) {
            Ok(list) => Some(list),
            Err(e) => {
                warn!("Ignoring {} update: malformed '{}' field: {}", what, key, e);
                None
            }
        },
    }
}
/// Apply a `PluginEnable` / `PluginDisable` command for the plugin named in `payload`.
///
/// Disabling a known plugin sends an "off" configuration on its watch channel (with empty rule
/// lists where the config has them; `None` for the watermark). Enabling sends nothing here,
/// because the server is expected to push the plugin's actual configuration afterwards.
/// Unknown plugin names are logged and otherwise ignored.
///
/// # Errors
/// Propagates a failed send on the plugin's channel.
fn handle_plugin_control(
    payload: &csm_protocol::PluginControlPayload,
    plugins: &PluginChannels,
    enabled: bool,
) -> Result<()> {
    let name = payload.plugin_name.as_str();
    match name {
        "watermark" if !enabled => plugins.watermark_tx.send(None)?,
        "web_filter" if !enabled => plugins
            .web_filter_tx
            .send(crate::web_filter::WebFilterConfig { enabled: false, rules: vec![] })?,
        "software_blocker" if !enabled => plugins.software_blocker_tx.send(
            crate::software_blocker::SoftwareBlockerConfig { enabled: false, blacklist: vec![], whitelist: vec![] },
        )?,
        "popup_blocker" if !enabled => plugins
            .popup_blocker_tx
            .send(crate::popup_blocker::PopupBlockerConfig { enabled: false, rules: vec![] })?,
        "usb_audit" if !enabled => plugins
            .usb_audit_tx
            .send(crate::usb_audit::UsbAuditConfig { enabled: false, monitored_extensions: vec![] })?,
        "usage_timer" if !enabled => plugins
            .usage_timer_tx
            .send(crate::usage_timer::UsageConfig { enabled: false, ..Default::default() })?,
        "disk_encryption" if !enabled => plugins
            .disk_encryption_tx
            .send(crate::disk_encryption::DiskEncryptionConfig { enabled: false, ..Default::default() })?,
        "print_audit" if !enabled => plugins
            .print_audit_tx
            .send(crate::print_audit::PrintAuditConfig { enabled: false, ..Default::default() })?,
        "clipboard_control" if !enabled => plugins
            .clipboard_control_tx
            .send(crate::clipboard_control::ClipboardControlConfig { enabled: false, ..Default::default() })?,
        "patch" if !enabled => plugins
            .patch_tx
            .send(crate::patch::PluginConfig { enabled: false, scan_interval_secs: 43200 })?,
        // Enabling a known plugin: its configuration arrives in a separate server message
        "watermark" | "web_filter" | "software_blocker" | "popup_blocker" | "usb_audit"
        | "usage_timer" | "disk_encryption" | "print_audit" | "clipboard_control" | "patch" => {}
        unknown => warn!("Unknown plugin: {}", unknown),
    }
    Ok(())
}
/// Sign a heartbeat: hex-encoded HMAC-SHA256 keyed by `secret` over `"{device_uid}\n{timestamp}"`.
///
/// Returns an empty string when no usable secret is available (`None` or empty), so heartbeats
/// sent before registration completes carry no signature.
fn compute_hmac(secret: Option<&str>, device_uid: &str, timestamp: &str) -> String {
    type HmacSha256 = Hmac<Sha256>;

    let key = match secret.filter(|s| !s.is_empty()) {
        Some(k) => k,
        None => return String::new(),
    };

    HmacSha256::new_from_slice(key.as_bytes())
        .map(|mut mac| {
            let signed_text = format!("{}\n{}", device_uid, timestamp);
            mac.update(signed_text.as_bytes());
            hex::encode(mac.finalize().into_bytes())
        })
        .unwrap_or_default()
}
/// Human-readable OS description as `"<name> <version>"`, using `"Unknown"` for either part
/// that `sysinfo` cannot determine.
fn get_os_info() -> String {
    use sysinfo::System;
    let unknown = || "Unknown".to_string();
    let os_name = System::name().unwrap_or_else(unknown);
    let os_version = System::os_version().unwrap_or_else(unknown);
    [os_name, os_version].join(" ")
}