feat(ai): 实现 CacheService 两级缓存 + 集成到 AiState

Redis TTL (L1) + DB SHA-256 hash (L2),Redis 不可用时自动降级
CacheKey 基于 tenant_id + analysis_type + input_hash + prompt_version
AiState 新增 cache 字段,main.rs 注入共享 Redis Client
This commit is contained in:
iven
2026-05-05 15:33:58 +08:00
parent 50b9e8d683
commit c268229311
5 changed files with 255 additions and 0 deletions

View File

@@ -23,4 +23,5 @@ reqwest.workspace = true
handlebars.workspace = true
dashmap.workspace = true
sha2.workspace = true
redis.workspace = true
hex.workspace = true

View File

@@ -0,0 +1,244 @@
use redis::AsyncCommands;
use sea_orm::{ColumnTrait, EntityTrait, QueryFilter};
use sha2::{Digest, Sha256};
use std::time::Duration;
use uuid::Uuid;
use crate::entity::ai_analysis;
use crate::error::AiResult;
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct CacheKey {
pub tenant_id: Uuid,
pub analysis_type: String,
pub input_hash: String,
pub prompt_version: i32,
}
impl CacheKey {
pub fn new(tenant_id: Uuid, analysis_type: &str, input: &serde_json::Value, prompt_version: i32) -> Self {
let canonical = serde_json::to_string(input).unwrap_or_default();
let mut hasher = Sha256::new();
hasher.update(canonical.as_bytes());
let hash = hex::encode(hasher.finalize());
Self {
tenant_id,
analysis_type: analysis_type.to_string(),
input_hash: hash,
prompt_version,
}
}
pub fn redis_key(&self) -> String {
format!(
"ai:cache:{}:{}:{}:{}",
self.tenant_id, self.analysis_type, self.input_hash, self.prompt_version
)
}
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct CachedAnalysis {
pub analysis_id: Uuid,
pub content: String,
pub metadata: Option<serde_json::Value>,
pub cached_at: chrono::DateTime<chrono::Utc>,
pub is_degraded: bool,
}
pub struct CacheService {
redis: redis::Client,
db: sea_orm::DatabaseConnection,
default_ttl: Duration,
}
impl CacheService {
pub fn new(redis: redis::Client, db: sea_orm::DatabaseConnection, default_ttl: Duration) -> Self {
Self { redis, db, default_ttl }
}
pub async fn get(&self, key: &CacheKey) -> AiResult<Option<CachedAnalysis>> {
let redis_key = key.redis_key();
// L1: Redis 查询
match self.try_redis_get(&redis_key).await {
Ok(Some(cached)) => return Ok(Some(cached)),
Ok(None) => {}
Err(e) => {
tracing::warn!(error = %e, "Redis 缓存查询失败,降级到 DB");
}
}
// L2: DB 查询(复用现有 SHA-256 hash 逻辑)
let db_result = ai_analysis::Entity::find()
.filter(ai_analysis::Column::TenantId.eq(key.tenant_id))
.filter(ai_analysis::Column::InputDataHash.eq(&key.input_hash))
.filter(ai_analysis::Column::PromptVersion.eq(key.prompt_version))
.filter(ai_analysis::Column::Status.eq("completed"))
.filter(ai_analysis::Column::DeletedAt.is_null())
.one(&self.db)
.await?;
if let Some(record) = db_result {
let cached = CachedAnalysis {
analysis_id: record.id,
content: record.result_content.clone().unwrap_or_default(),
metadata: record.result_metadata.clone(),
cached_at: record.created_at,
is_degraded: false,
};
// 回填 Redis
if let Err(e) = self.try_redis_set(&redis_key, &cached).await {
tracing::warn!(error = %e, "Redis 缓存回填失败");
}
return Ok(Some(cached));
}
Ok(None)
}
pub async fn set(&self, key: &CacheKey, value: &CachedAnalysis) -> AiResult<()> {
let redis_key = key.redis_key();
// 写 Redis
if let Err(e) = self.try_redis_set(&redis_key, value).await {
tracing::warn!(error = %e, "Redis 缓存写入失败");
}
Ok(())
}
pub async fn invalidate_tenant(&self, tenant_id: Uuid) -> AiResult<u64> {
let pattern = format!("ai:cache:{}:*", tenant_id);
match self.try_redis_scan_del(&pattern).await {
Ok(count) => Ok(count),
Err(e) => {
tracing::warn!(error = %e, "Redis 缓存失效失败");
Ok(0)
}
}
}
async fn try_redis_get(&self, key: &str) -> redis::RedisResult<Option<CachedAnalysis>> {
let mut conn = self.redis.get_multiplexed_async_connection().await?;
let data: Option<String> = conn.get(key).await?;
match data {
Some(json) => {
let cached: CachedAnalysis = serde_json::from_str(&json)
.map_err(|e| redis::RedisError::from((
redis::ErrorKind::TypeError,
"反序列化失败",
e.to_string(),
)))?;
Ok(Some(cached))
}
None => Ok(None),
}
}
async fn try_redis_set(&self, key: &str, value: &CachedAnalysis) -> redis::RedisResult<()> {
let mut conn = self.redis.get_multiplexed_async_connection().await?;
let json = serde_json::to_string(value).map_err(|e| redis::RedisError::from((
redis::ErrorKind::TypeError,
"序列化失败",
e.to_string(),
)))?;
let (): () = conn.set_ex(key, json, self.default_ttl.as_secs() as u64).await?;
Ok(())
}
async fn try_redis_scan_del(&self, pattern: &str) -> redis::RedisResult<u64> {
let mut conn = self.redis.get_multiplexed_async_connection().await?;
let mut count = 0u64;
let mut cursor: u64 = 0;
loop {
let (new_cursor, keys): (u64, Vec<String>) =
redis::cmd("SCAN")
.arg(cursor)
.arg("MATCH")
.arg(pattern)
.arg("COUNT")
.arg(100)
.query_async(&mut conn)
.await?;
if !keys.is_empty() {
let del_count: u64 = conn.del(&keys).await?;
count += del_count;
}
cursor = new_cursor;
if cursor == 0 {
break;
}
}
Ok(count)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn cache_key_format() {
let key = CacheKey {
tenant_id: Uuid::parse_str("00000000-0000-0000-0000-000000000001").unwrap(),
analysis_type: "lab_report".into(),
input_hash: "abc123".into(),
prompt_version: 1,
};
assert_eq!(
key.redis_key(),
"ai:cache:00000000-0000-0000-0000-000000000001:lab_report:abc123:1"
);
}
#[test]
fn cache_key_from_input() {
let input = serde_json::json!({"hemoglobin": 120});
let key = CacheKey::new(
Uuid::parse_str("00000000-0000-0000-0000-000000000001").unwrap(),
"trend",
&input,
2,
);
assert!(key.input_hash.len() == 64);
assert_eq!(key.analysis_type, "trend");
assert_eq!(key.prompt_version, 2);
assert!(key.redis_key().starts_with("ai:cache:"));
}
#[test]
fn cached_analysis_serde_roundtrip() {
let cached = CachedAnalysis {
analysis_id: Uuid::now_v7(),
content: "血红蛋白偏低".into(),
metadata: Some(serde_json::json!({"model": "claude"})),
cached_at: chrono::Utc::now(),
is_degraded: false,
};
let json = serde_json::to_string(&cached).unwrap();
let back: CachedAnalysis = serde_json::from_str(&json).unwrap();
assert_eq!(back.content, "血红蛋白偏低");
assert!(!back.is_degraded);
}
#[test]
fn same_input_same_hash() {
let input = serde_json::json!({"test": 1});
let key1 = CacheKey::new(Uuid::now_v7(), "lab", &input, 1);
let key2 = CacheKey::new(Uuid::now_v7(), "lab", &input, 1);
assert_eq!(key1.input_hash, key2.input_hash);
}
#[test]
fn different_input_different_hash() {
let input1 = serde_json::json!({"test": 1});
let input2 = serde_json::json!({"test": 2});
let key1 = CacheKey::new(Uuid::now_v7(), "lab", &input1, 1);
let key2 = CacheKey::new(Uuid::now_v7(), "lab", &input2, 1);
assert_ne!(key1.input_hash, key2.input_hash);
}
}

View File

@@ -1,5 +1,6 @@
pub mod analysis;
pub mod auto_analysis;
pub mod cache;
pub mod comparison;
pub mod dialysis_risk_scorer;
pub mod local_rules;

View File

@@ -6,6 +6,7 @@ use sea_orm::DatabaseConnection;
use crate::provider::registry::ProviderRegistry;
use crate::service::analysis::AnalysisService;
use crate::service::cache::CacheService;
use crate::service::prompt::PromptService;
use crate::service::quota::QuotaService;
use crate::service::suggestion::SuggestionService;
@@ -22,4 +23,5 @@ pub struct AiState {
pub health_provider: Arc<dyn HealthDataProvider>,
pub provider_registry: Arc<ProviderRegistry>,
pub quota: Arc<QuotaService>,
pub cache: Arc<CacheService>,
}

View File

@@ -535,6 +535,12 @@ async fn main() -> anyhow::Result<()> {
db.clone(),
config.ai.quota_check_enabled,
));
let cache_ttl = std::time::Duration::from_secs(config.ai.cache_ttl_seconds);
let cache = std::sync::Arc::new(erp_ai::service::cache::CacheService::new(
redis_client.clone(),
db.clone(),
cache_ttl,
));
erp_ai::AiState {
db: db.clone(),
@@ -546,6 +552,7 @@ async fn main() -> anyhow::Result<()> {
health_provider,
provider_registry: registry,
quota,
cache,
}
};