//! Data Masking Middleware — protect sensitive business data from leaving the user's machine. //! //! Before LLM calls, replaces detected entities (company names, amounts, phone numbers) //! with deterministic tokens. After responses, the caller can restore the original entities. //! //! Priority: 90 (runs before Compaction@100 and Memory@150) use std::collections::HashMap; use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::{Arc, LazyLock, RwLock}; use async_trait::async_trait; use regex::Regex; use zclaw_types::{Message, Result}; use super::{AgentMiddleware, MiddlewareContext, MiddlewareDecision}; // --------------------------------------------------------------------------- // Pre-compiled regex patterns (compiled once, reused across all calls) // --------------------------------------------------------------------------- static RE_COMPANY: LazyLock = LazyLock::new(|| { Regex::new(r"[^\s]{1,20}(?:公司|厂|集团|工作室|商行|有限|股份)").unwrap() }); static RE_MONEY: LazyLock = LazyLock::new(|| { Regex::new(r"[¥¥$]\s*[\d,.]+[万亿]?元?|[\d,.]+[万亿]元").unwrap() }); static RE_PHONE: LazyLock = LazyLock::new(|| { Regex::new(r"1[3-9]\d-?\d{4}-?\d{4}").unwrap() }); static RE_EMAIL: LazyLock = LazyLock::new(|| { Regex::new(r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}").unwrap() }); static RE_ID_CARD: LazyLock = LazyLock::new(|| { Regex::new(r"\b\d{17}[\dXx]\b").unwrap() }); // --------------------------------------------------------------------------- // DataMasker — entity detection and token mapping // --------------------------------------------------------------------------- /// Counts entities by type for token generation. static ENTITY_COUNTER: AtomicU64 = AtomicU64::new(1); /// Detects and replaces sensitive entities with deterministic tokens. pub struct DataMasker { /// entity text → token mapping (persistent across conversations). forward: Arc>>, /// token → entity text reverse mapping (in-memory only). reverse: Arc>>, } impl DataMasker { pub fn new() -> Self { Self { forward: Arc::new(RwLock::new(HashMap::new())), reverse: Arc::new(RwLock::new(HashMap::new())), } } /// Mask all detected entities in `text`, replacing them with tokens. pub fn mask(&self, text: &str) -> Result { let entities = self.detect_entities(text); if entities.is_empty() { return Ok(text.to_string()); } let mut result = text.to_string(); for entity in entities { let token = self.get_or_create_token(&entity); // Replace all occurrences (longest entities first to avoid partial matches) result = result.replace(&entity, &token); } Ok(result) } /// Restore all tokens in `text` back to their original entities. pub fn unmask(&self, text: &str) -> Result { let reverse = self.reverse.read().map_err(|e| zclaw_types::ZclawError::IoError(std::io::Error::other(e.to_string())))?; if reverse.is_empty() { return Ok(text.to_string()); } let mut result = text.to_string(); for (token, entity) in reverse.iter() { result = result.replace(token, entity); } Ok(result) } /// Detect sensitive entities in text using regex patterns. fn detect_entities(&self, text: &str) -> Vec { let mut entities = Vec::new(); // Company names: X公司、XX集团、XX工作室 (1-20 char prefix + suffix) for cap in RE_COMPANY.find_iter(text) { entities.push(cap.as_str().to_string()); } // Money amounts: ¥50万、¥100元、$200、50万元 for cap in RE_MONEY.find_iter(text) { entities.push(cap.as_str().to_string()); } // Phone numbers: 1XX-XXXX-XXXX or 1XXXXXXXXXX for cap in RE_PHONE.find_iter(text) { entities.push(cap.as_str().to_string()); } // Email addresses for cap in RE_EMAIL.find_iter(text) { entities.push(cap.as_str().to_string()); } // ID card numbers (simplified): 18 digits for cap in RE_ID_CARD.find_iter(text) { entities.push(cap.as_str().to_string()); } // Sort by length descending to replace longest entities first entities.sort_by(|a, b| b.len().cmp(&a.len())); entities.dedup(); entities } /// Get existing token for entity or create a new one. fn get_or_create_token(&self, entity: &str) -> String { /// Recover from a poisoned RwLock by taking the inner value and re-wrapping. /// A poisoned lock only means a panic occurred while holding it — the data is still valid. fn recover_read(lock: &RwLock) -> std::sync::LockResult> { match lock.read() { Ok(guard) => Ok(guard), Err(_e) => { tracing::warn!("[DataMasker] RwLock poisoned during read, recovering"); // Poison error still gives us access to the inner guard lock.read() } } } fn recover_write(lock: &RwLock) -> std::sync::LockResult> { match lock.write() { Ok(guard) => Ok(guard), Err(_e) => { tracing::warn!("[DataMasker] RwLock poisoned during write, recovering"); lock.write() } } } // Check if already mapped { if let Ok(forward) = recover_read(&self.forward) { if let Some(token) = forward.get(entity) { return token.clone(); } } } // Create new token let counter = ENTITY_COUNTER.fetch_add(1, Ordering::Relaxed); let token = format!("__ENTITY_{}__", counter); // Store in both mappings if let Ok(mut forward) = recover_write(&self.forward) { forward.insert(entity.to_string(), token.clone()); } if let Ok(mut reverse) = recover_write(&self.reverse) { reverse.insert(token.clone(), entity.to_string()); } token } } impl Default for DataMasker { fn default() -> Self { Self::new() } } // --------------------------------------------------------------------------- // DataMaskingMiddleware — masks user messages before LLM completion // --------------------------------------------------------------------------- pub struct DataMaskingMiddleware { masker: Arc, } impl DataMaskingMiddleware { pub fn new(masker: Arc) -> Self { Self { masker } } /// Get a reference to the masker for unmasking responses externally. pub fn masker(&self) -> &Arc { &self.masker } } #[async_trait] impl AgentMiddleware for DataMaskingMiddleware { fn name(&self) -> &str { "data_masking" } fn priority(&self) -> i32 { 90 } async fn before_completion(&self, ctx: &mut MiddlewareContext) -> Result { // Mask user messages — replace sensitive entities with tokens for msg in &mut ctx.messages { if let Message::User { ref mut content } = msg { let masked = self.masker.mask(content)?; *content = masked; } } // Also mask user_input field if !ctx.user_input.is_empty() { ctx.user_input = self.masker.mask(&ctx.user_input)?; } Ok(MiddlewareDecision::Continue) } } // --------------------------------------------------------------------------- // Tests // --------------------------------------------------------------------------- #[cfg(test)] mod tests { use super::*; #[test] fn test_mask_company_name() { let masker = DataMasker::new(); let input = "A公司的订单被退了"; let masked = masker.mask(input).unwrap(); assert!(!masked.contains("A公司"), "Company name should be masked: {}", masked); assert!(masked.contains("__ENTITY_"), "Should contain token: {}", masked); let unmasked = masker.unmask(&masked).unwrap(); assert_eq!(unmasked, input, "Unmask should restore original"); } #[test] fn test_mask_consistency() { let masker = DataMasker::new(); let masked1 = masker.mask("A公司").unwrap(); let masked2 = masker.mask("A公司").unwrap(); assert_eq!(masked1, masked2, "Same entity should always get same token"); } #[test] fn test_mask_money() { let masker = DataMasker::new(); let input = "成本是¥50万"; let masked = masker.mask(input).unwrap(); assert!(!masked.contains("¥50万"), "Money should be masked: {}", masked); let unmasked = masker.unmask(&masked).unwrap(); assert_eq!(unmasked, input); } #[test] fn test_mask_phone() { let masker = DataMasker::new(); let input = "联系13812345678"; let masked = masker.mask(input).unwrap(); assert!(!masked.contains("13812345678"), "Phone should be masked: {}", masked); let unmasked = masker.unmask(&masked).unwrap(); assert_eq!(unmasked, input); } #[test] fn test_mask_email() { let masker = DataMasker::new(); let input = "发到 test@example.com 吧"; let masked = masker.mask(input).unwrap(); assert!(!masked.contains("test@example.com"), "Email should be masked: {}", masked); let unmasked = masker.unmask(&masked).unwrap(); assert_eq!(unmasked, input); } #[test] fn test_mask_no_entities() { let masker = DataMasker::new(); let input = "今天天气不错"; let masked = masker.mask(input).unwrap(); assert_eq!(masked, input, "Text without entities should pass through unchanged"); } #[test] fn test_mask_multiple_entities() { let masker = DataMasker::new(); let input = "A公司的订单花了¥50万,联系13812345678"; let masked = masker.mask(input).unwrap(); assert!(!masked.contains("A公司")); assert!(!masked.contains("¥50万")); assert!(!masked.contains("13812345678")); let unmasked = masker.unmask(&masked).unwrap(); assert_eq!(unmasked, input); } #[test] fn test_unmask_empty() { let masker = DataMasker::new(); let result = masker.unmask("hello world").unwrap(); assert_eq!(result, "hello world"); } #[test] fn test_mask_id_card() { let masker = DataMasker::new(); let input = "身份证号 110101199001011234"; let masked = masker.mask(input).unwrap(); assert!(!masked.contains("110101199001011234"), "ID card should be masked: {}", masked); let unmasked = masker.unmask(&masked).unwrap(); assert_eq!(unmasked, input); } }