Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled
- runtime: 移除未使用的 SessionId/Datelike import,修复 unused variable - intelligence: 模块级 #![allow(dead_code)] 抑制 Hermes 预留代码警告 - mcp.rs/persist.rs/nl_schedule.rs: 标注 #[allow(dead_code)] 保留接口
324 lines
11 KiB
Rust
324 lines
11 KiB
Rust
//! Data Masking Middleware — keeps sensitive business data from leaving the user's machine.
//!
//! Before LLM calls, replaces detected entities (company names, monetary amounts, phone
//! numbers, email addresses and ID card numbers) with opaque `__ENTITY_N__` tokens. A masker
//! reuses the token it already assigned to an entity, so after a response comes back the
//! caller can restore the original entities with `DataMasker::unmask`.
//!
//! Priority: 90 (runs before Compaction@100 and Memory@150)
|
||
|
||
use std::collections::HashMap;
|
||
use std::sync::atomic::{AtomicU64, Ordering};
|
||
use std::sync::{Arc, LazyLock, RwLock};
|
||
|
||
use async_trait::async_trait;
|
||
use regex::Regex;
|
||
use zclaw_types::{Message, Result};
|
||
|
||
use super::{AgentMiddleware, MiddlewareContext, MiddlewareDecision};
|
||
|
||
// ---------------------------------------------------------------------------
|
||
// Pre-compiled regex patterns (compiled once, reused across all calls)
|
||
// ---------------------------------------------------------------------------
|
||
|
||
static RE_COMPANY: LazyLock<Regex> = LazyLock::new(|| {
|
||
Regex::new(r"[^\s]{1,20}(?:公司|厂|集团|工作室|商行|有限|股份)").unwrap()
|
||
});
|
||
static RE_MONEY: LazyLock<Regex> = LazyLock::new(|| {
|
||
Regex::new(r"[¥¥$]\s*[\d,.]+[万亿]?元?|[\d,.]+[万亿]元").unwrap()
|
||
});
|
||
static RE_PHONE: LazyLock<Regex> = LazyLock::new(|| {
|
||
Regex::new(r"1[3-9]\d-?\d{4}-?\d{4}").unwrap()
|
||
});
|
||
static RE_EMAIL: LazyLock<Regex> = LazyLock::new(|| {
|
||
Regex::new(r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}").unwrap()
|
||
});
|
||
static RE_ID_CARD: LazyLock<Regex> = LazyLock::new(|| {
|
||
Regex::new(r"\b\d{17}[\dXx]\b").unwrap()
|
||
});
|
||
|
||
// ---------------------------------------------------------------------------
|
||
// DataMasker — entity detection and token mapping
|
||
// ---------------------------------------------------------------------------
|
||
|
||
/// Process-wide source of the numbers used in `__ENTITY_N__` tokens.
///
/// This is a single counter shared by every `DataMasker` instance, not one per entity type.
/// Tokens are therefore unique within the process, but a token's number depends on how many
/// entities were tokenized before it. The counter starts at 1. `Relaxed` ordering is enough
/// because callers only need each `fetch_add` result to be distinct.
static ENTITY_COUNTER: AtomicU64 = AtomicU64::new(1);
|
||
|
||
/// Detects and replaces sensitive entities with deterministic tokens.
|
||
pub struct DataMasker {
|
||
/// entity text → token mapping (persistent across conversations).
|
||
forward: Arc<RwLock<HashMap<String, String>>>,
|
||
/// token → entity text reverse mapping (in-memory only).
|
||
reverse: Arc<RwLock<HashMap<String, String>>>,
|
||
}
|
||
|
||
impl DataMasker {
|
||
pub fn new() -> Self {
|
||
Self {
|
||
forward: Arc::new(RwLock::new(HashMap::new())),
|
||
reverse: Arc::new(RwLock::new(HashMap::new())),
|
||
}
|
||
}
|
||
|
||
/// Mask all detected entities in `text`, replacing them with tokens.
|
||
pub fn mask(&self, text: &str) -> Result<String> {
|
||
let entities = self.detect_entities(text);
|
||
if entities.is_empty() {
|
||
return Ok(text.to_string());
|
||
}
|
||
|
||
let mut result = text.to_string();
|
||
for entity in entities {
|
||
let token = self.get_or_create_token(&entity);
|
||
// Replace all occurrences (longest entities first to avoid partial matches)
|
||
result = result.replace(&entity, &token);
|
||
}
|
||
Ok(result)
|
||
}
|
||
|
||
/// Restore all tokens in `text` back to their original entities.
|
||
pub fn unmask(&self, text: &str) -> Result<String> {
|
||
let reverse = self.reverse.read().map_err(|e| zclaw_types::ZclawError::IoError(std::io::Error::other(e.to_string())))?;
|
||
if reverse.is_empty() {
|
||
return Ok(text.to_string());
|
||
}
|
||
|
||
let mut result = text.to_string();
|
||
for (token, entity) in reverse.iter() {
|
||
result = result.replace(token, entity);
|
||
}
|
||
Ok(result)
|
||
}
|
||
|
||
/// Detect sensitive entities in text using regex patterns.
|
||
fn detect_entities(&self, text: &str) -> Vec<String> {
|
||
let mut entities = Vec::new();
|
||
|
||
// Company names: X公司、XX集团、XX工作室 (1-20 char prefix + suffix)
|
||
for cap in RE_COMPANY.find_iter(text) {
|
||
entities.push(cap.as_str().to_string());
|
||
}
|
||
|
||
// Money amounts: ¥50万、¥100元、$200、50万元
|
||
for cap in RE_MONEY.find_iter(text) {
|
||
entities.push(cap.as_str().to_string());
|
||
}
|
||
|
||
// Phone numbers: 1XX-XXXX-XXXX or 1XXXXXXXXXX
|
||
for cap in RE_PHONE.find_iter(text) {
|
||
entities.push(cap.as_str().to_string());
|
||
}
|
||
|
||
// Email addresses
|
||
for cap in RE_EMAIL.find_iter(text) {
|
||
entities.push(cap.as_str().to_string());
|
||
}
|
||
|
||
// ID card numbers (simplified): 18 digits
|
||
for cap in RE_ID_CARD.find_iter(text) {
|
||
entities.push(cap.as_str().to_string());
|
||
}
|
||
|
||
// Sort by length descending to replace longest entities first
|
||
entities.sort_by(|a, b| b.len().cmp(&a.len()));
|
||
entities.dedup();
|
||
entities
|
||
}
|
||
|
||
/// Get existing token for entity or create a new one.
|
||
fn get_or_create_token(&self, entity: &str) -> String {
|
||
/// Recover from a poisoned RwLock by taking the inner value and re-wrapping.
|
||
/// A poisoned lock only means a panic occurred while holding it — the data is still valid.
|
||
fn recover_read<T>(lock: &RwLock<T>) -> std::sync::LockResult<std::sync::RwLockReadGuard<'_, T>> {
|
||
match lock.read() {
|
||
Ok(guard) => Ok(guard),
|
||
Err(_e) => {
|
||
tracing::warn!("[DataMasker] RwLock poisoned during read, recovering");
|
||
// Poison error still gives us access to the inner guard
|
||
lock.read()
|
||
}
|
||
}
|
||
}
|
||
|
||
fn recover_write<T>(lock: &RwLock<T>) -> std::sync::LockResult<std::sync::RwLockWriteGuard<'_, T>> {
|
||
match lock.write() {
|
||
Ok(guard) => Ok(guard),
|
||
Err(_e) => {
|
||
tracing::warn!("[DataMasker] RwLock poisoned during write, recovering");
|
||
lock.write()
|
||
}
|
||
}
|
||
}
|
||
|
||
// Check if already mapped
|
||
{
|
||
if let Ok(forward) = recover_read(&self.forward) {
|
||
if let Some(token) = forward.get(entity) {
|
||
return token.clone();
|
||
}
|
||
}
|
||
}
|
||
|
||
// Create new token
|
||
let counter = ENTITY_COUNTER.fetch_add(1, Ordering::Relaxed);
|
||
let token = format!("__ENTITY_{}__", counter);
|
||
|
||
// Store in both mappings
|
||
if let Ok(mut forward) = recover_write(&self.forward) {
|
||
forward.insert(entity.to_string(), token.clone());
|
||
}
|
||
if let Ok(mut reverse) = recover_write(&self.reverse) {
|
||
reverse.insert(token.clone(), entity.to_string());
|
||
}
|
||
|
||
token
|
||
}
|
||
}
|
||
|
||
impl Default for DataMasker {
|
||
fn default() -> Self {
|
||
Self::new()
|
||
}
|
||
}
|
||
|
||
// ---------------------------------------------------------------------------
|
||
// DataMaskingMiddleware — masks user messages before LLM completion
|
||
// ---------------------------------------------------------------------------
|
||
|
||
pub struct DataMaskingMiddleware {
|
||
masker: Arc<DataMasker>,
|
||
}
|
||
|
||
impl DataMaskingMiddleware {
|
||
pub fn new(masker: Arc<DataMasker>) -> Self {
|
||
Self { masker }
|
||
}
|
||
|
||
/// Get a reference to the masker for unmasking responses externally.
|
||
pub fn masker(&self) -> &Arc<DataMasker> {
|
||
&self.masker
|
||
}
|
||
}
|
||
|
||
#[async_trait]
|
||
impl AgentMiddleware for DataMaskingMiddleware {
|
||
fn name(&self) -> &str { "data_masking" }
|
||
fn priority(&self) -> i32 { 90 }
|
||
|
||
async fn before_completion(&self, ctx: &mut MiddlewareContext) -> Result<MiddlewareDecision> {
|
||
// Mask user messages — replace sensitive entities with tokens
|
||
for msg in &mut ctx.messages {
|
||
if let Message::User { ref mut content } = msg {
|
||
let masked = self.masker.mask(content)?;
|
||
*content = masked;
|
||
}
|
||
}
|
||
|
||
// Also mask user_input field
|
||
if !ctx.user_input.is_empty() {
|
||
ctx.user_input = self.masker.mask(&ctx.user_input)?;
|
||
}
|
||
|
||
Ok(MiddlewareDecision::Continue)
|
||
}
|
||
}
|
||
|
||
// ---------------------------------------------------------------------------
|
||
// Tests
|
||
// ---------------------------------------------------------------------------
|
||
|
||
#[cfg(test)]
mod tests {
    //! Round-trip and consistency tests for `DataMasker`.

    use super::*;

    #[test]
    fn test_mask_company_name() {
        let masker = DataMasker::new();
        let input = "A公司的订单被退了";
        let masked = masker.mask(input).unwrap();
        assert!(!masked.contains("A公司"), "Company name should be masked: {}", masked);
        assert!(masked.contains("__ENTITY_"), "Should contain token: {}", masked);

        let unmasked = masker.unmask(&masked).unwrap();
        assert_eq!(unmasked, input, "Unmask should restore original");
    }

    #[test]
    fn test_mask_consistency() {
        let masker = DataMasker::new();
        let masked1 = masker.mask("A公司").unwrap();
        let masked2 = masker.mask("A公司").unwrap();
        assert_eq!(masked1, masked2, "Same entity should always get same token");
    }

    #[test]
    fn test_mask_money() {
        let masker = DataMasker::new();
        let input = "成本是¥50万";
        let masked = masker.mask(input).unwrap();
        assert!(!masked.contains("¥50万"), "Money should be masked: {}", masked);

        let unmasked = masker.unmask(&masked).unwrap();
        assert_eq!(unmasked, input);
    }

    #[test]
    fn test_mask_phone() {
        let masker = DataMasker::new();
        let input = "联系13812345678";
        let masked = masker.mask(input).unwrap();
        assert!(!masked.contains("13812345678"), "Phone should be masked: {}", masked);

        let unmasked = masker.unmask(&masked).unwrap();
        assert_eq!(unmasked, input);
    }

    #[test]
    fn test_mask_email() {
        let masker = DataMasker::new();
        let input = "发到 test@example.com 吧";
        let masked = masker.mask(input).unwrap();
        assert!(!masked.contains("test@example.com"), "Email should be masked: {}", masked);

        let unmasked = masker.unmask(&masked).unwrap();
        assert_eq!(unmasked, input);
    }

    #[test]
    fn test_mask_no_entities() {
        let masker = DataMasker::new();
        let input = "今天天气不错";
        let masked = masker.mask(input).unwrap();
        assert_eq!(masked, input, "Text without entities should pass through unchanged");
    }

    #[test]
    fn test_mask_multiple_entities() {
        let masker = DataMasker::new();
        let input = "A公司的订单花了¥50万,联系13812345678";
        let masked = masker.mask(input).unwrap();
        assert!(!masked.contains("A公司"));
        assert!(!masked.contains("¥50万"));
        assert!(!masked.contains("13812345678"));

        let unmasked = masker.unmask(&masked).unwrap();
        assert_eq!(unmasked, input);
    }

    #[test]
    fn test_unmask_empty() {
        let masker = DataMasker::new();
        let result = masker.unmask("hello world").unwrap();
        assert_eq!(result, "hello world");
    }

    #[test]
    fn test_mask_id_card() {
        let masker = DataMasker::new();
        let input = "身份证号 110101199001011234";
        let masked = masker.mask(input).unwrap();
        assert!(!masked.contains("110101199001011234"), "ID card should be masked: {}", masked);

        let unmasked = masker.unmask(&masked).unwrap();
        assert_eq!(unmasked, input);
    }

    /// A repeated entity separated by a different entity must map to one token everywhere.
    #[test]
    fn test_mask_repeated_entity_shares_token() {
        let masker = DataMasker::new();
        let input = "A公司 B公司 A公司";
        let masked = masker.mask(input).unwrap();

        let parts: Vec<&str> = masked.split(' ').collect();
        assert_eq!(parts.len(), 3, "Spaces should survive masking: {}", masked);
        assert!(
            parts.iter().all(|part| part.starts_with("__ENTITY_")),
            "Every company should be masked: {}",
            masked
        );
        assert_eq!(parts[0], parts[2], "Same entity should reuse its token: {}", masked);
        assert_ne!(parts[0], parts[1], "Different entities need different tokens: {}", masked);

        let unmasked = masker.unmask(&masked).unwrap();
        assert_eq!(unmasked, input);
    }

    /// Masking already-masked text must not change it. The middleware masks every user
    /// message in `ctx.messages` on each completion, which may include ones masked earlier.
    #[test]
    fn test_mask_is_idempotent() {
        let masker = DataMasker::new();
        let once = masker.mask("A公司的订单花了¥50万,联系13812345678").unwrap();
        let twice = masker.mask(&once).unwrap();
        assert_eq!(once, twice, "Tokens must not be re-detected as entities");
    }

    #[test]
    fn test_default_starts_empty() {
        let masker = DataMasker::default();
        let result = masker.unmask("__ENTITY_1__").unwrap();
        assert_eq!(result, "__ENTITY_1__", "A fresh masker knows no tokens");
    }
}
|