feat(middleware): add DataMaskingMiddleware — sensitive entity protection (Chunk 3)
Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled

Priority 90 — runs before Compaction@100 and Memory@150.
Detects and replaces company names, money amounts, phone numbers,
emails, and ID card numbers with deterministic tokens (__ENTITY_N__).
External callers can restore originals via DataMasker::unmask().
This commit is contained in:
iven
2026-04-07 08:01:05 +08:00
parent deb206ec0b
commit 8aed363fc8
5 changed files with 302 additions and 0 deletions

1
Cargo.lock generated
View File

@@ -9039,6 +9039,7 @@ dependencies = [
"dirs",
"futures",
"rand 0.8.5",
"regex",
"reqwest 0.12.28",
"secrecy",
"serde",

View File

@@ -190,6 +190,14 @@ impl Kernel {
pub(crate) fn create_middleware_chain(&self) -> Option<zclaw_runtime::middleware::MiddlewareChain> {
let mut chain = zclaw_runtime::middleware::MiddlewareChain::new();
// Data masking middleware — mask sensitive entities before any other processing
{
use std::sync::Arc;
let masker = Arc::new(zclaw_runtime::middleware::data_masking::DataMasker::new());
let mw = zclaw_runtime::middleware::data_masking::DataMaskingMiddleware::new(masker);
chain.register(Arc::new(mw));
}
// Growth integration — shared VikingAdapter for memory middleware & compaction
let mut growth = zclaw_runtime::GrowthIntegration::new(self.viking.clone());
if let Some(ref driver) = self.extraction_driver {

View File

@@ -24,6 +24,7 @@ uuid = { workspace = true }
chrono = { workspace = true }
tracing = { workspace = true }
async-trait = { workspace = true }
regex = { workspace = true }
# HTTP client
reqwest = { workspace = true }

View File

@@ -267,6 +267,7 @@ impl Default for MiddlewareChain {
pub mod compaction;
pub mod dangling_tool;
pub mod data_masking;
pub mod guardrail;
pub mod loop_guard;
pub mod memory;

View File

@@ -0,0 +1,291 @@
//! Data Masking Middleware — protect sensitive business data from leaving the user's machine.
//!
//! Before LLM calls, replaces detected entities (company names, money amounts, phone numbers,
//! email addresses, ID card numbers) with deterministic tokens. After responses, the caller
//! can restore the original entities via `DataMasker::unmask`.
//!
//! Priority: 90 (runs before Compaction@100 and Memory@150)
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, RwLock};
use async_trait::async_trait;
use regex::Regex;
use zclaw_types::{Message, Result};
use super::{AgentMiddleware, MiddlewareContext, MiddlewareDecision};
// ---------------------------------------------------------------------------
// DataMasker — entity detection and token mapping
// ---------------------------------------------------------------------------
/// Process-wide sequence number used to build `__ENTITY_N__` tokens.
///
/// A single counter is shared by every entity kind and every `DataMasker`
/// instance (it is a `static`, not a field); numbering starts at 1 and each
/// newly created token consumes one value via `fetch_add`.
static ENTITY_COUNTER: AtomicU64 = AtomicU64::new(1);
/// Detects and replaces sensitive entities with deterministic tokens.
///
/// Both lookup tables are plain in-memory `HashMap`s guarded by `RwLock`s.
/// They live as long as this masker, so an entity that was masked once keeps
/// the same token on every later call made through the same instance.
pub struct DataMasker {
/// entity text → token mapping, consulted before a new token is created.
forward: Arc<RwLock<HashMap<String, String>>>,
/// token → entity text reverse mapping, used by `unmask` (in-memory only).
reverse: Arc<RwLock<HashMap<String, String>>>,
}
impl DataMasker {
pub fn new() -> Self {
Self {
forward: Arc::new(RwLock::new(HashMap::new())),
reverse: Arc::new(RwLock::new(HashMap::new())),
}
}
/// Mask all detected entities in `text`, replacing them with tokens.
pub fn mask(&self, text: &str) -> Result<String> {
let entities = self.detect_entities(text);
if entities.is_empty() {
return Ok(text.to_string());
}
let mut result = text.to_string();
for entity in entities {
let token = self.get_or_create_token(&entity);
// Replace all occurrences (longest entities first to avoid partial matches)
result = result.replace(&entity, &token);
}
Ok(result)
}
/// Restore all tokens in `text` back to their original entities.
pub fn unmask(&self, text: &str) -> Result<String> {
let reverse = self.reverse.read().map_err(|e| zclaw_types::ZclawError::IoError(std::io::Error::other(e.to_string())))?;
if reverse.is_empty() {
return Ok(text.to_string());
}
let mut result = text.to_string();
for (token, entity) in reverse.iter() {
result = result.replace(token, entity);
}
Ok(result)
}
/// Detect sensitive entities in text using regex patterns.
fn detect_entities(&self, text: &str) -> Vec<String> {
let mut entities = Vec::new();
// Company names: X公司、XX集团、XX工作室 (1-20 char prefix + suffix)
if let Ok(re) = Regex::new(r"[^\s]{1,20}(?:公司|厂|集团|工作室|商行|有限|股份)") {
for cap in re.find_iter(text) {
entities.push(cap.as_str().to_string());
}
}
// Money amounts: ¥50万、¥100元、$200、50万元
if let Ok(re) = Regex::new(r"[¥¥$]\s*[\d,.]+[万亿]?元?|[\d,.]+[万亿]元") {
for cap in re.find_iter(text) {
entities.push(cap.as_str().to_string());
}
}
// Phone numbers: 1XX-XXXX-XXXX or 1XXXXXXXXXX
if let Ok(re) = Regex::new(r"1[3-9]\d-?\d{4}-?\d{4}") {
for cap in re.find_iter(text) {
entities.push(cap.as_str().to_string());
}
}
// Email addresses
if let Ok(re) = Regex::new(r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}") {
for cap in re.find_iter(text) {
entities.push(cap.as_str().to_string());
}
}
// ID card numbers (simplified): 18 digits
if let Ok(re) = Regex::new(r"\b\d{17}[\dXx]\b") {
for cap in re.find_iter(text) {
entities.push(cap.as_str().to_string());
}
}
// Sort by length descending to replace longest entities first
entities.sort_by(|a, b| b.len().cmp(&a.len()));
entities.dedup();
entities
}
/// Get existing token for entity or create a new one.
fn get_or_create_token(&self, entity: &str) -> String {
// Check if already mapped
{
let forward = self.forward.read().unwrap();
if let Some(token) = forward.get(entity) {
return token.clone();
}
}
// Create new token
let counter = ENTITY_COUNTER.fetch_add(1, Ordering::Relaxed);
let token = format!("__ENTITY_{}__", counter);
// Store in both mappings
{
let mut forward = self.forward.write().unwrap();
forward.insert(entity.to_string(), token.clone());
}
{
let mut reverse = self.reverse.write().unwrap();
reverse.insert(token.clone(), entity.to_string());
}
token
}
}
impl Default for DataMasker {
fn default() -> Self {
Self::new()
}
}
// ---------------------------------------------------------------------------
// DataMaskingMiddleware — masks user messages before LLM completion
// ---------------------------------------------------------------------------
/// Middleware that masks sensitive entities in user-authored text before a
/// completion request is made.
pub struct DataMaskingMiddleware {
/// Shared masker; the same instance can be used elsewhere to unmask output.
masker: Arc<DataMasker>,
}
impl DataMaskingMiddleware {
pub fn new(masker: Arc<DataMasker>) -> Self {
Self { masker }
}
/// Get a reference to the masker for unmasking responses externally.
pub fn masker(&self) -> &Arc<DataMasker> {
&self.masker
}
}
#[async_trait]
impl AgentMiddleware for DataMaskingMiddleware {
    /// Identifier under which this middleware reports itself.
    fn name(&self) -> &str {
        "data_masking"
    }

    /// Priority 90; per the module docs this runs before Compaction@100 and Memory@150.
    fn priority(&self) -> i32 {
        90
    }

    /// Rewrites every user message, and the `user_input` field when it is
    /// non-empty, so that detected entities are replaced by tokens.
    ///
    /// Always returns `MiddlewareDecision::Continue` on success; a masking
    /// error is propagated to the caller.
    async fn before_completion(&self, ctx: &mut MiddlewareContext) -> Result<MiddlewareDecision> {
        for message in &mut ctx.messages {
            // Only user-authored messages are masked; other variants pass through.
            let Message::User { content } = message else {
                continue;
            };
            let replacement = self.masker.mask(content)?;
            *content = replacement;
        }

        if !ctx.user_input.is_empty() {
            let replacement = self.masker.mask(&ctx.user_input)?;
            ctx.user_input = replacement;
        }

        Ok(MiddlewareDecision::Continue)
    }
}
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
// Unit tests for `DataMasker`: each test masks a sample string, checks that the
// sensitive text is gone, and (where applicable) that `unmask` restores the input.
#[cfg(test)]
mod tests {
use super::*;
// A company name is replaced by a token and restored by `unmask`.
#[test]
fn test_mask_company_name() {
let masker = DataMasker::new();
let input = "A公司的订单被退了";
let masked = masker.mask(input).unwrap();
assert!(!masked.contains("A公司"), "Company name should be masked: {}", masked);
assert!(masked.contains("__ENTITY_"), "Should contain token: {}", masked);
let unmasked = masker.unmask(&masked).unwrap();
assert_eq!(unmasked, input, "Unmask should restore original");
}
// Masking the same entity twice through one masker yields the same token.
#[test]
fn test_mask_consistency() {
let masker = DataMasker::new();
let masked1 = masker.mask("A公司").unwrap();
let masked2 = masker.mask("A公司").unwrap();
assert_eq!(masked1, masked2, "Same entity should always get same token");
}
// A currency-prefixed amount is masked and round-trips.
#[test]
fn test_mask_money() {
let masker = DataMasker::new();
let input = "成本是¥50万";
let masked = masker.mask(input).unwrap();
assert!(!masked.contains("¥50万"), "Money should be masked: {}", masked);
let unmasked = masker.unmask(&masked).unwrap();
assert_eq!(unmasked, input);
}
// An 11-digit mobile number is masked and round-trips.
#[test]
fn test_mask_phone() {
let masker = DataMasker::new();
let input = "联系13812345678";
let masked = masker.mask(input).unwrap();
assert!(!masked.contains("13812345678"), "Phone should be masked: {}", masked);
let unmasked = masker.unmask(&masked).unwrap();
assert_eq!(unmasked, input);
}
// An email address is masked and round-trips.
#[test]
fn test_mask_email() {
let masker = DataMasker::new();
let input = "发到 test@example.com 吧";
let masked = masker.mask(input).unwrap();
assert!(!masked.contains("test@example.com"), "Email should be masked: {}", masked);
let unmasked = masker.unmask(&masked).unwrap();
assert_eq!(unmasked, input);
}
// Text with no detectable entity is returned unchanged.
#[test]
fn test_mask_no_entities() {
let masker = DataMasker::new();
let input = "今天天气不错";
let masked = masker.mask(input).unwrap();
assert_eq!(masked, input, "Text without entities should pass through unchanged");
}
// Several entity kinds in one string are all masked and all restored.
#[test]
fn test_mask_multiple_entities() {
let masker = DataMasker::new();
let input = "A公司的订单花了¥50万联系13812345678";
let masked = masker.mask(input).unwrap();
assert!(!masked.contains("A公司"));
assert!(!masked.contains("¥50万"));
assert!(!masked.contains("13812345678"));
let unmasked = masker.unmask(&masked).unwrap();
assert_eq!(unmasked, input);
}
// `unmask` on a fresh masker (no mappings yet) returns the text as-is.
#[test]
fn test_unmask_empty() {
let masker = DataMasker::new();
let result = masker.unmask("hello world").unwrap();
assert_eq!(result, "hello world");
}
// An 18-digit ID card number is masked and round-trips.
#[test]
fn test_mask_id_card() {
let masker = DataMasker::new();
let input = "身份证号 110101199001011234";
let masked = masker.mask(input).unwrap();
assert!(!masked.contains("110101199001011234"), "ID card should be masked: {}", masked);
let unmasked = masker.unmask(&masked).unwrap();
assert_eq!(unmasked, input);
}
}