feat(ai): 实现 CacheService 两级缓存 + 集成到 AiState
Redis TTL (L1) + DB SHA-256 hash (L2),Redis 不可用时自动降级; CacheKey 基于 tenant_id + analysis_type + input_hash + prompt_version; AiState 新增 cache 字段,main.rs 注入共享 Redis Client
This commit is contained in:
@@ -23,4 +23,5 @@ reqwest.workspace = true
|
||||
handlebars.workspace = true
|
||||
dashmap.workspace = true
|
||||
sha2.workspace = true
|
||||
redis.workspace = true
|
||||
hex.workspace = true
|
||||
|
||||
244
crates/erp-ai/src/service/cache.rs
Normal file
244
crates/erp-ai/src/service/cache.rs
Normal file
@@ -0,0 +1,244 @@
|
||||
use redis::AsyncCommands;
|
||||
use sea_orm::{ColumnTrait, EntityTrait, QueryFilter};
|
||||
use sha2::{Digest, Sha256};
|
||||
use std::time::Duration;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::entity::ai_analysis;
|
||||
use crate::error::AiResult;
|
||||
|
||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||||
pub struct CacheKey {
|
||||
pub tenant_id: Uuid,
|
||||
pub analysis_type: String,
|
||||
pub input_hash: String,
|
||||
pub prompt_version: i32,
|
||||
}
|
||||
|
||||
impl CacheKey {
|
||||
pub fn new(tenant_id: Uuid, analysis_type: &str, input: &serde_json::Value, prompt_version: i32) -> Self {
|
||||
let canonical = serde_json::to_string(input).unwrap_or_default();
|
||||
let mut hasher = Sha256::new();
|
||||
hasher.update(canonical.as_bytes());
|
||||
let hash = hex::encode(hasher.finalize());
|
||||
|
||||
Self {
|
||||
tenant_id,
|
||||
analysis_type: analysis_type.to_string(),
|
||||
input_hash: hash,
|
||||
prompt_version,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn redis_key(&self) -> String {
|
||||
format!(
|
||||
"ai:cache:{}:{}:{}:{}",
|
||||
self.tenant_id, self.analysis_type, self.input_hash, self.prompt_version
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||||
pub struct CachedAnalysis {
|
||||
pub analysis_id: Uuid,
|
||||
pub content: String,
|
||||
pub metadata: Option<serde_json::Value>,
|
||||
pub cached_at: chrono::DateTime<chrono::Utc>,
|
||||
pub is_degraded: bool,
|
||||
}
|
||||
|
||||
pub struct CacheService {
|
||||
redis: redis::Client,
|
||||
db: sea_orm::DatabaseConnection,
|
||||
default_ttl: Duration,
|
||||
}
|
||||
|
||||
impl CacheService {
|
||||
pub fn new(redis: redis::Client, db: sea_orm::DatabaseConnection, default_ttl: Duration) -> Self {
|
||||
Self { redis, db, default_ttl }
|
||||
}
|
||||
|
||||
pub async fn get(&self, key: &CacheKey) -> AiResult<Option<CachedAnalysis>> {
|
||||
let redis_key = key.redis_key();
|
||||
|
||||
// L1: Redis 查询
|
||||
match self.try_redis_get(&redis_key).await {
|
||||
Ok(Some(cached)) => return Ok(Some(cached)),
|
||||
Ok(None) => {}
|
||||
Err(e) => {
|
||||
tracing::warn!(error = %e, "Redis 缓存查询失败,降级到 DB");
|
||||
}
|
||||
}
|
||||
|
||||
// L2: DB 查询(复用现有 SHA-256 hash 逻辑)
|
||||
let db_result = ai_analysis::Entity::find()
|
||||
.filter(ai_analysis::Column::TenantId.eq(key.tenant_id))
|
||||
.filter(ai_analysis::Column::InputDataHash.eq(&key.input_hash))
|
||||
.filter(ai_analysis::Column::PromptVersion.eq(key.prompt_version))
|
||||
.filter(ai_analysis::Column::Status.eq("completed"))
|
||||
.filter(ai_analysis::Column::DeletedAt.is_null())
|
||||
.one(&self.db)
|
||||
.await?;
|
||||
|
||||
if let Some(record) = db_result {
|
||||
let cached = CachedAnalysis {
|
||||
analysis_id: record.id,
|
||||
content: record.result_content.clone().unwrap_or_default(),
|
||||
metadata: record.result_metadata.clone(),
|
||||
cached_at: record.created_at,
|
||||
is_degraded: false,
|
||||
};
|
||||
|
||||
// 回填 Redis
|
||||
if let Err(e) = self.try_redis_set(&redis_key, &cached).await {
|
||||
tracing::warn!(error = %e, "Redis 缓存回填失败");
|
||||
}
|
||||
|
||||
return Ok(Some(cached));
|
||||
}
|
||||
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
pub async fn set(&self, key: &CacheKey, value: &CachedAnalysis) -> AiResult<()> {
|
||||
let redis_key = key.redis_key();
|
||||
|
||||
// 写 Redis
|
||||
if let Err(e) = self.try_redis_set(&redis_key, value).await {
|
||||
tracing::warn!(error = %e, "Redis 缓存写入失败");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn invalidate_tenant(&self, tenant_id: Uuid) -> AiResult<u64> {
|
||||
let pattern = format!("ai:cache:{}:*", tenant_id);
|
||||
match self.try_redis_scan_del(&pattern).await {
|
||||
Ok(count) => Ok(count),
|
||||
Err(e) => {
|
||||
tracing::warn!(error = %e, "Redis 缓存失效失败");
|
||||
Ok(0)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn try_redis_get(&self, key: &str) -> redis::RedisResult<Option<CachedAnalysis>> {
|
||||
let mut conn = self.redis.get_multiplexed_async_connection().await?;
|
||||
let data: Option<String> = conn.get(key).await?;
|
||||
match data {
|
||||
Some(json) => {
|
||||
let cached: CachedAnalysis = serde_json::from_str(&json)
|
||||
.map_err(|e| redis::RedisError::from((
|
||||
redis::ErrorKind::TypeError,
|
||||
"反序列化失败",
|
||||
e.to_string(),
|
||||
)))?;
|
||||
Ok(Some(cached))
|
||||
}
|
||||
None => Ok(None),
|
||||
}
|
||||
}
|
||||
|
||||
async fn try_redis_set(&self, key: &str, value: &CachedAnalysis) -> redis::RedisResult<()> {
|
||||
let mut conn = self.redis.get_multiplexed_async_connection().await?;
|
||||
let json = serde_json::to_string(value).map_err(|e| redis::RedisError::from((
|
||||
redis::ErrorKind::TypeError,
|
||||
"序列化失败",
|
||||
e.to_string(),
|
||||
)))?;
|
||||
let (): () = conn.set_ex(key, json, self.default_ttl.as_secs() as u64).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn try_redis_scan_del(&self, pattern: &str) -> redis::RedisResult<u64> {
|
||||
let mut conn = self.redis.get_multiplexed_async_connection().await?;
|
||||
let mut count = 0u64;
|
||||
let mut cursor: u64 = 0;
|
||||
loop {
|
||||
let (new_cursor, keys): (u64, Vec<String>) =
|
||||
redis::cmd("SCAN")
|
||||
.arg(cursor)
|
||||
.arg("MATCH")
|
||||
.arg(pattern)
|
||||
.arg("COUNT")
|
||||
.arg(100)
|
||||
.query_async(&mut conn)
|
||||
.await?;
|
||||
if !keys.is_empty() {
|
||||
let del_count: u64 = conn.del(&keys).await?;
|
||||
count += del_count;
|
||||
}
|
||||
cursor = new_cursor;
|
||||
if cursor == 0 {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Ok(count)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixed tenant id used where a test needs a predictable key string.
    const FIXED_TENANT: &str = "00000000-0000-0000-0000-000000000001";

    /// The Redis key joins the four key fields with `:` after the `ai:cache:` prefix.
    #[test]
    fn cache_key_format() {
        let tenant_id = Uuid::parse_str(FIXED_TENANT).unwrap();
        let key = CacheKey {
            tenant_id,
            analysis_type: "lab_report".into(),
            input_hash: "abc123".into(),
            prompt_version: 1,
        };

        let rendered = key.redis_key();

        assert_eq!(
            rendered,
            "ai:cache:00000000-0000-0000-0000-000000000001:lab_report:abc123:1"
        );
    }

    /// `CacheKey::new` yields a 64-character hash and keeps the other fields as given.
    #[test]
    fn cache_key_from_input() {
        let tenant_id = Uuid::parse_str(FIXED_TENANT).unwrap();
        let payload = serde_json::json!({"hemoglobin": 120});

        let key = CacheKey::new(tenant_id, "trend", &payload, 2);

        assert!(key.input_hash.len() == 64);
        assert_eq!(key.analysis_type, "trend");
        assert_eq!(key.prompt_version, 2);
        assert!(key.redis_key().starts_with("ai:cache:"));
    }

    /// A `CachedAnalysis` survives a JSON serialize/deserialize round trip.
    #[test]
    fn cached_analysis_serde_roundtrip() {
        let original = CachedAnalysis {
            analysis_id: Uuid::now_v7(),
            content: "血红蛋白偏低".into(),
            metadata: Some(serde_json::json!({"model": "claude"})),
            cached_at: chrono::Utc::now(),
            is_degraded: false,
        };

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: CachedAnalysis = serde_json::from_str(&encoded).unwrap();

        assert_eq!(decoded.content, "血红蛋白偏低");
        assert!(!decoded.is_degraded);
    }

    /// Identical input produces the same hash even for different tenants.
    #[test]
    fn same_input_same_hash() {
        let payload = serde_json::json!({"test": 1});

        let first = CacheKey::new(Uuid::now_v7(), "lab", &payload, 1);
        let second = CacheKey::new(Uuid::now_v7(), "lab", &payload, 1);

        assert_eq!(first.input_hash, second.input_hash);
    }

    /// Different input produces different hashes.
    #[test]
    fn different_input_different_hash() {
        let payload_a = serde_json::json!({"test": 1});
        let payload_b = serde_json::json!({"test": 2});

        let first = CacheKey::new(Uuid::now_v7(), "lab", &payload_a, 1);
        let second = CacheKey::new(Uuid::now_v7(), "lab", &payload_b, 1);

        assert_ne!(first.input_hash, second.input_hash);
    }
}
|
||||
@@ -1,5 +1,6 @@
|
||||
pub mod analysis;
|
||||
pub mod auto_analysis;
|
||||
pub mod cache;
|
||||
pub mod comparison;
|
||||
pub mod dialysis_risk_scorer;
|
||||
pub mod local_rules;
|
||||
|
||||
@@ -6,6 +6,7 @@ use sea_orm::DatabaseConnection;
|
||||
|
||||
use crate::provider::registry::ProviderRegistry;
|
||||
use crate::service::analysis::AnalysisService;
|
||||
use crate::service::cache::CacheService;
|
||||
use crate::service::prompt::PromptService;
|
||||
use crate::service::quota::QuotaService;
|
||||
use crate::service::suggestion::SuggestionService;
|
||||
@@ -22,4 +23,5 @@ pub struct AiState {
|
||||
pub health_provider: Arc<dyn HealthDataProvider>,
|
||||
pub provider_registry: Arc<ProviderRegistry>,
|
||||
pub quota: Arc<QuotaService>,
|
||||
pub cache: Arc<CacheService>,
|
||||
}
|
||||
|
||||
@@ -535,6 +535,12 @@ async fn main() -> anyhow::Result<()> {
|
||||
db.clone(),
|
||||
config.ai.quota_check_enabled,
|
||||
));
|
||||
let cache_ttl = std::time::Duration::from_secs(config.ai.cache_ttl_seconds);
|
||||
let cache = std::sync::Arc::new(erp_ai::service::cache::CacheService::new(
|
||||
redis_client.clone(),
|
||||
db.clone(),
|
||||
cache_ttl,
|
||||
));
|
||||
|
||||
erp_ai::AiState {
|
||||
db: db.clone(),
|
||||
@@ -546,6 +552,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
health_provider,
|
||||
provider_registry: registry,
|
||||
quota,
|
||||
cache,
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user