diff --git a/crates/erp-ai/Cargo.toml b/crates/erp-ai/Cargo.toml
index a087f91..3bf8ad9 100644
--- a/crates/erp-ai/Cargo.toml
+++ b/crates/erp-ai/Cargo.toml
@@ -23,4 +23,5 @@ reqwest.workspace = true
 handlebars.workspace = true
 dashmap.workspace = true
 sha2.workspace = true
+redis.workspace = true
 hex.workspace = true
diff --git a/crates/erp-ai/src/service/cache.rs b/crates/erp-ai/src/service/cache.rs
new file mode 100644
index 0000000..5ecad86
--- /dev/null
+++ b/crates/erp-ai/src/service/cache.rs
@@ -0,0 +1,304 @@
+//! Two-tier cache for AI analysis results.
+//!
+//! Lookups try Redis first and fall back to completed rows of the `ai_analysis`
+//! entity; database hits are written back to Redis. Redis failures are logged
+//! and swallowed, so only a failing database query surfaces as an error.
+
+use redis::AsyncCommands;
+use sea_orm::{ColumnTrait, EntityTrait, QueryFilter, QueryOrder};
+use sha2::{Digest, Sha256};
+use std::time::Duration;
+use uuid::Uuid;
+
+use crate::entity::ai_analysis;
+use crate::error::AiResult;
+
+/// Identifies one cached analysis: tenant, analysis type, input hash and prompt version.
+#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
+pub struct CacheKey {
+    pub tenant_id: Uuid,
+    pub analysis_type: String,
+    /// Hex-encoded SHA-256 of the JSON-serialized input.
+    pub input_hash: String,
+    pub prompt_version: i32,
+}
+
+impl CacheKey {
+    /// Builds a key by hashing the JSON serialization of `input`.
+    ///
+    /// NOTE(review): equal objects only hash equally if serde_json emits keys in a
+    /// fixed order; confirm the workspace does not enable its `preserve_order` feature.
+    pub fn new(
+        tenant_id: Uuid,
+        analysis_type: &str,
+        input: &serde_json::Value,
+        prompt_version: i32,
+    ) -> Self {
+        // If serialization fails, the empty string is hashed instead.
+        let canonical = serde_json::to_string(input).unwrap_or_default();
+        let input_hash = hex::encode(Sha256::digest(canonical.as_bytes()));
+
+        Self {
+            tenant_id,
+            analysis_type: analysis_type.to_string(),
+            input_hash,
+            prompt_version,
+        }
+    }
+
+    /// Returns the Redis key, `ai:cache:{tenant}:{type}:{hash}:{version}`.
+    pub fn redis_key(&self) -> String {
+        format!(
+            "ai:cache:{}:{}:{}:{}",
+            self.tenant_id, self.analysis_type, self.input_hash, self.prompt_version
+        )
+    }
+}
+
+/// Cached analysis payload; stored in Redis as JSON.
+#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
+pub struct CachedAnalysis {
+    pub analysis_id: Uuid,
+    pub content: String,
+    pub metadata: Option<serde_json::Value>,
+    pub cached_at: chrono::DateTime<chrono::Utc>,
+    pub is_degraded: bool,
+}
+
+/// Redis cache in front of the `ai_analysis` entity.
+pub struct CacheService {
+    redis: redis::Client,
+    db: sea_orm::DatabaseConnection,
+    default_ttl: Duration,
+}
+
+impl CacheService {
+    /// Creates the service; `default_ttl` is the expiry applied to every Redis entry.
+    pub fn new(
+        redis: redis::Client,
+        db: sea_orm::DatabaseConnection,
+        default_ttl: Duration,
+    ) -> Self {
+        Self {
+            redis,
+            db,
+            default_ttl,
+        }
+    }
+
+    /// Looks up `key` in Redis, then in the database.
+    ///
+    /// A Redis error is logged and treated as a miss. A database hit is written
+    /// back to Redis before it is returned.
+    ///
+    /// # Errors
+    ///
+    /// Fails only when the database query fails.
+    pub async fn get(&self, key: &CacheKey) -> AiResult<Option<CachedAnalysis>> {
+        let redis_key = key.redis_key();
+
+        // Tier 1: Redis.
+        match self.try_redis_get(&redis_key).await {
+            Ok(Some(cached)) => return Ok(Some(cached)),
+            Ok(None) => {}
+            Err(e) => {
+                tracing::warn!(error = %e, "Redis 缓存查询失败,降级到 DB");
+            }
+        }
+
+        // Tier 2: database, matched on the stored SHA-256 input hash.
+        // NOTE(review): unlike the Redis key, this query does not filter on
+        // `analysis_type`; confirm the stored hash already distinguishes types.
+        let db_result = ai_analysis::Entity::find()
+            .filter(ai_analysis::Column::TenantId.eq(key.tenant_id))
+            .filter(ai_analysis::Column::InputDataHash.eq(&key.input_hash))
+            .filter(ai_analysis::Column::PromptVersion.eq(key.prompt_version))
+            .filter(ai_analysis::Column::Status.eq("completed"))
+            .filter(ai_analysis::Column::DeletedAt.is_null())
+            // Several completed rows may share a hash; always serve the newest.
+            .order_by_desc(ai_analysis::Column::CreatedAt)
+            .one(&self.db)
+            .await?;
+
+        let record = match db_result {
+            Some(record) => record,
+            None => return Ok(None),
+        };
+
+        // `record` is owned, so its fields are moved out instead of cloned.
+        let cached = CachedAnalysis {
+            analysis_id: record.id,
+            content: record.result_content.unwrap_or_default(),
+            metadata: record.result_metadata,
+            cached_at: record.created_at,
+            is_degraded: false,
+        };
+
+        // Backfill Redis so the next lookup is served from tier 1.
+        if let Err(e) = self.try_redis_set(&redis_key, &cached).await {
+            tracing::warn!(error = %e, "Redis 缓存回填失败");
+        }
+
+        Ok(Some(cached))
+    }
+
+    /// Stores `value` in Redis under `key`.
+    ///
+    /// Best effort: a Redis failure is logged and `Ok(())` is still returned.
+    pub async fn set(&self, key: &CacheKey, value: &CachedAnalysis) -> AiResult<()> {
+        let redis_key = key.redis_key();
+
+        if let Err(e) = self.try_redis_set(&redis_key, value).await {
+            tracing::warn!(error = %e, "Redis 缓存写入失败");
+        }
+
+        Ok(())
+    }
+
+    /// Deletes every Redis entry of `tenant_id` and returns the number of keys removed.
+    ///
+    /// Best effort: a Redis failure is logged and reported as `Ok(0)`.
+    pub async fn invalidate_tenant(&self, tenant_id: Uuid) -> AiResult<u64> {
+        let pattern = format!("ai:cache:{}:*", tenant_id);
+        match self.try_redis_scan_del(&pattern).await {
+            Ok(count) => Ok(count),
+            Err(e) => {
+                tracing::warn!(error = %e, "Redis 缓存失效失败");
+                Ok(0)
+            }
+        }
+    }
+
+    /// Reads and deserializes one entry; `Ok(None)` when the key is absent.
+    ///
+    /// NOTE(review): every helper requests a connection from the client per call;
+    /// consider obtaining one multiplexed connection up front and cloning it.
+    async fn try_redis_get(&self, key: &str) -> redis::RedisResult<Option<CachedAnalysis>> {
+        let mut conn = self.redis.get_multiplexed_async_connection().await?;
+        let data: Option<String> = conn.get(key).await?;
+        data.map(|json| {
+            serde_json::from_str(&json).map_err(|e| {
+                redis::RedisError::from((
+                    redis::ErrorKind::TypeError,
+                    "反序列化失败",
+                    e.to_string(),
+                ))
+            })
+        })
+        .transpose()
+    }
+
+    /// Serializes `value` and writes it with the configured TTL.
+    async fn try_redis_set(&self, key: &str, value: &CachedAnalysis) -> redis::RedisResult<()> {
+        // Redis rejects a zero expiry, so a TTL below one second disables the write
+        // instead of failing (and logging a warning) on every call.
+        let ttl_secs = self.default_ttl.as_secs();
+        if ttl_secs == 0 {
+            return Ok(());
+        }
+
+        let json = serde_json::to_string(value).map_err(|e| {
+            redis::RedisError::from((
+                redis::ErrorKind::TypeError,
+                "序列化失败",
+                e.to_string(),
+            ))
+        })?;
+        let mut conn = self.redis.get_multiplexed_async_connection().await?;
+        let (): () = conn.set_ex(key, json, ttl_secs).await?;
+        Ok(())
+    }
+
+    /// Deletes all keys matching `pattern` with cursor-based `SCAN`; returns the number removed.
+    async fn try_redis_scan_del(&self, pattern: &str) -> redis::RedisResult<u64> {
+        let mut conn = self.redis.get_multiplexed_async_connection().await?;
+        let mut count = 0u64;
+        let mut cursor: u64 = 0;
+        loop {
+            let (new_cursor, keys): (u64, Vec<String>) = redis::cmd("SCAN")
+                .arg(cursor)
+                .arg("MATCH")
+                .arg(pattern)
+                .arg("COUNT")
+                .arg(100)
+                .query_async(&mut conn)
+                .await?;
+            if !keys.is_empty() {
+                let del_count: u64 = conn.del(&keys).await?;
+                count += del_count;
+            }
+            cursor = new_cursor;
+            // SCAN signals the end of the iteration by returning cursor 0.
+            if cursor == 0 {
+                break;
+            }
+        }
+        Ok(count)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn cache_key_format() {
+        let key = CacheKey {
+            tenant_id: Uuid::parse_str("00000000-0000-0000-0000-000000000001").unwrap(),
+            analysis_type: "lab_report".into(),
+            input_hash: "abc123".into(),
+            prompt_version: 1,
+        };
+        assert_eq!(
+            key.redis_key(),
+            "ai:cache:00000000-0000-0000-0000-000000000001:lab_report:abc123:1"
+        );
+    }
+
+    #[test]
+    fn cache_key_from_input() {
+        let input = serde_json::json!({"hemoglobin": 120});
+        let key = CacheKey::new(
+            Uuid::parse_str("00000000-0000-0000-0000-000000000001").unwrap(),
+            "trend",
+            &input,
+            2,
+        );
+        assert_eq!(key.input_hash.len(), 64);
+        assert_eq!(key.analysis_type, "trend");
+        assert_eq!(key.prompt_version, 2);
+        assert!(key.redis_key().starts_with("ai:cache:"));
+    }
+
+    #[test]
+    fn cached_analysis_serde_roundtrip() {
+        let cached = CachedAnalysis {
+            analysis_id: Uuid::now_v7(),
+            content: "血红蛋白偏低".into(),
+            metadata: Some(serde_json::json!({"model": "claude"})),
+            cached_at: chrono::Utc::now(),
+            is_degraded: false,
+        };
+        let json = serde_json::to_string(&cached).unwrap();
+        let back: CachedAnalysis = serde_json::from_str(&json).unwrap();
+        assert_eq!(back.content, "血红蛋白偏低");
+        assert!(!back.is_degraded);
+    }
+
+    #[test]
+    fn same_input_same_hash() {
+        let input = serde_json::json!({"test": 1});
+        let key1 = CacheKey::new(Uuid::now_v7(), "lab", &input, 1);
+        let key2 = CacheKey::new(Uuid::now_v7(), "lab", &input, 1);
+        assert_eq!(key1.input_hash, key2.input_hash);
+    }
+
+    #[test]
+    fn different_input_different_hash() {
+        let input1 = serde_json::json!({"test": 1});
+        let input2 = serde_json::json!({"test": 2});
+        let key1 = CacheKey::new(Uuid::now_v7(), "lab", &input1, 1);
+        let key2 = CacheKey::new(Uuid::now_v7(), "lab", &input2, 1);
+        assert_ne!(key1.input_hash, key2.input_hash);
+    }
+}
diff --git a/crates/erp-ai/src/service/mod.rs b/crates/erp-ai/src/service/mod.rs
index b4b4b80..8c23d09 100644
--- a/crates/erp-ai/src/service/mod.rs
+++ b/crates/erp-ai/src/service/mod.rs
@@ -1,5 +1,6 @@
 pub mod analysis;
 pub mod auto_analysis;
+pub mod cache;
 pub mod comparison;
 pub mod dialysis_risk_scorer;
 pub mod local_rules;
diff --git a/crates/erp-ai/src/state.rs b/crates/erp-ai/src/state.rs
index 78c7a9c..ddbc95f 100644
--- a/crates/erp-ai/src/state.rs
+++ b/crates/erp-ai/src/state.rs
@@ -6,6 +6,7 @@ use sea_orm::DatabaseConnection;
 
 use crate::provider::registry::ProviderRegistry;
 use crate::service::analysis::AnalysisService;
+use crate::service::cache::CacheService;
 use crate::service::prompt::PromptService;
 use crate::service::quota::QuotaService;
 use crate::service::suggestion::SuggestionService;
@@ -22,4 +23,5 @@ pub struct AiState {
     pub health_provider: Arc,
     pub provider_registry: Arc,
     pub quota: Arc,
+    pub cache: Arc<CacheService>,
 }
diff --git a/crates/erp-server/src/main.rs b/crates/erp-server/src/main.rs
index 2fc481b..7a0b8fa 100644
--- a/crates/erp-server/src/main.rs
+++ b/crates/erp-server/src/main.rs
@@ -535,6 +535,12 @@ async fn main() -> anyhow::Result<()> {
             db.clone(),
             config.ai.quota_check_enabled,
         ));
+        let cache_ttl = std::time::Duration::from_secs(config.ai.cache_ttl_seconds);
+        let cache = std::sync::Arc::new(erp_ai::service::cache::CacheService::new(
+            redis_client.clone(),
+            db.clone(),
+            cache_ttl,
+        ));
 
         erp_ai::AiState {
             db: db.clone(),
@@ -546,6 +552,7 @@ async fn main() -> anyhow::Result<()> {
             health_provider,
             provider_registry: registry,
             quota,
+            cache,
         }
     };
 