Files
hms/crates/erp-ai/src/service/feature_flag_service.rs
iven bf37acc681 feat(ai): AI 健康管家 V2 基础设施 — 功能开关 + 角色沙箱准备 + 体征页 AI 趋势分析
- 迁移 000153: 新增 ai_feature_flags / ai_usage_daily / ai_suggestion_feedback 三张表,
  ai_tenant_configs 增加 billing_enabled 列, seed 12 个功能开关 + 2 个管理权限码
- 新增 FeatureFlagService: 5 分钟缓存 + DB 回退 + 即时更新
- VitalSignsTab 添加 AI 趋势分析按钮 (SSE 流式)
- 新增 3 个 Entity (ai_feature_flags / ai_usage_daily / ai_suggestion_feedback)
- AiState 扩展 feature_flags 字段
- 设计规格 + 讨论记录文档

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-05-18 22:55:40 +08:00

168 lines
4.8 KiB
Rust

use std::collections::HashMap;
use sea_orm::{ActiveModelTrait, ColumnTrait, EntityTrait, QueryFilter, Set};
use tokio::sync::RwLock;
use uuid::Uuid;
use crate::entity::ai_feature_flags;
use crate::error::AiResult;
/// Per-tenant AI feature-flag lookups backed by the `ai_feature_flags` entity,
/// with an in-memory cache of recently read values.
pub struct FeatureFlagService {
    /// Connection used for every flag read and write.
    db: sea_orm::DatabaseConnection,
    /// Cached flag states, keyed by `(tenant_id, feature name)`.
    cache: RwLock<HashMap<(Uuid, String), CacheEntry>>,
    /// How long a cached entry is treated as fresh before the database is re-queried.
    cache_ttl: std::time::Duration,
}
/// One cached flag value plus the instant it was stored.
struct CacheEntry {
    /// Cached on/off state of the flag.
    enabled: bool,
    /// When this value was cached; its age is compared against the service TTL.
    cached_at: std::time::Instant,
}
impl FeatureFlagService {
    /// Default lifetime of a cached flag value (5 minutes).
    const DEFAULT_CACHE_TTL: std::time::Duration = std::time::Duration::from_secs(300);

    /// Creates a service backed by `db` using the default 5-minute cache TTL.
    pub fn new(db: sea_orm::DatabaseConnection) -> Self {
        Self::with_cache_ttl(db, Self::DEFAULT_CACHE_TTL)
    }

    /// Creates a service backed by `db` whose cached flag values expire after `cache_ttl`.
    ///
    /// A zero TTL means no entry is ever considered fresh, so every lookup queries
    /// the database.
    pub fn with_cache_ttl(db: sea_orm::DatabaseConnection, cache_ttl: std::time::Duration) -> Self {
        Self {
            db,
            cache: RwLock::new(HashMap::new()),
            cache_ttl,
        }
    }

    /// Returns whether `feature` is enabled for `tenant_id`.
    ///
    /// A fresh cache entry (younger than the TTL) is returned directly. Otherwise
    /// the database is queried and the result is cached. A feature with no row is
    /// treated as enabled.
    ///
    /// If the database query fails, this returns `true` (fail-open) but does NOT
    /// cache that fallback, so the next call retries the database instead of
    /// serving a possibly wrong value for a whole TTL.
    pub async fn is_enabled(&self, tenant_id: Uuid, feature: &str) -> bool {
        let key = (tenant_id, feature.to_string());

        // Fast path: a fresh cache hit. The read guard is dropped at the end of
        // this scope, before the database is awaited.
        {
            let cache = self.cache.read().await;
            if let Some(entry) = cache.get(&key)
                && entry.cached_at.elapsed() < self.cache_ttl
            {
                return entry.enabled;
            }
        }

        let enabled = match self.query_db(tenant_id, feature).await {
            Ok(v) => v,
            Err(e) => {
                tracing::warn!(tenant_id = %tenant_id, feature = %feature, error = %e, "Feature flag query failed, defaulting to enabled");
                // Deliberately not cached: a transient DB error must not mask a
                // disabled flag until the TTL expires.
                return true;
            }
        };

        self.cache.write().await.insert(
            key,
            CacheEntry {
                enabled,
                cached_at: std::time::Instant::now(),
            },
        );
        enabled
    }

    /// Persists the on/off state of `feature` for `tenant_id` and drops its cache entry.
    ///
    /// Updates the existing `ai_feature_flags` row when one exists. Otherwise it
    /// inserts a new row with `config` unset. `updated_by` and the current UTC
    /// time are recorded on the row.
    ///
    /// # Errors
    ///
    /// Returns an error if the lookup, update, or insert fails. In that case the
    /// cache is left unchanged.
    pub async fn set_enabled(
        &self,
        tenant_id: Uuid,
        feature: &str,
        enabled: bool,
        updated_by: Uuid,
    ) -> AiResult<()> {
        // NOTE(review): find-then-insert is not atomic; two concurrent calls for a
        // feature with no row may both attempt an insert. Consider an upsert if the
        // table has a unique (tenant_id, feature) constraint — confirm schema.
        let existing = ai_feature_flags::Entity::find()
            .filter(ai_feature_flags::Column::TenantId.eq(tenant_id))
            .filter(ai_feature_flags::Column::Feature.eq(feature))
            .one(&self.db)
            .await?;

        let now = chrono::Utc::now();
        if let Some(model) = existing {
            let mut active: ai_feature_flags::ActiveModel = model.into();
            active.is_enabled = Set(enabled);
            active.updated_at = Set(now);
            active.updated_by = Set(Some(updated_by));
            active.update(&self.db).await?;
        } else {
            let active = ai_feature_flags::ActiveModel {
                id: Set(Uuid::now_v7()),
                tenant_id: Set(tenant_id),
                feature: Set(feature.to_string()),
                is_enabled: Set(enabled),
                config: Set(None),
                updated_at: Set(now),
                updated_by: Set(Some(updated_by)),
            };
            active.insert(&self.db).await?;
        }

        // Invalidate so the next `is_enabled` call reads the new state from the DB.
        // NOTE(review): an `is_enabled` call that queried the DB before this write
        // can still re-insert the old value afterwards, until the TTL expires.
        self.cache
            .write()
            .await
            .remove(&(tenant_id, feature.to_string()));

        tracing::info!(tenant_id = %tenant_id, feature = %feature, enabled = enabled, "Feature flag updated");
        Ok(())
    }

    /// Returns every flag row stored for `tenant_id`, read directly from the database.
    ///
    /// Features with no row are not included, although `is_enabled` treats them
    /// as enabled. The cache is neither read nor updated.
    ///
    /// # Errors
    ///
    /// Returns an error if the database query fails.
    pub async fn get_all(&self, tenant_id: Uuid) -> AiResult<Vec<FeatureFlag>> {
        let rows = ai_feature_flags::Entity::find()
            .filter(ai_feature_flags::Column::TenantId.eq(tenant_id))
            .all(&self.db)
            .await?;
        Ok(rows
            .into_iter()
            .map(|r| FeatureFlag {
                feature: r.feature,
                is_enabled: r.is_enabled,
            })
            .collect())
    }

    /// Reads the stored state of one flag, bypassing the cache.
    ///
    /// Returns `Ok(true)` when no row exists (a missing flag means enabled).
    async fn query_db(&self, tenant_id: Uuid, feature: &str) -> AiResult<bool> {
        let row = ai_feature_flags::Entity::find()
            .filter(ai_feature_flags::Column::TenantId.eq(tenant_id))
            .filter(ai_feature_flags::Column::Feature.eq(feature))
            .one(&self.db)
            .await?;
        // No row -> enabled by default.
        Ok(row.map_or(true, |r| r.is_enabled))
    }
}
/// A feature flag's stored state, as returned by `FeatureFlagService::get_all`.
///
/// Serializes to `{"feature": ..., "is_enabled": ...}`.
#[derive(Debug, Clone, serde::Serialize)]
pub struct FeatureFlag {
    /// Feature identifier.
    pub feature: String,
    /// Whether the feature is switched on for the tenant.
    pub is_enabled: bool,
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn feature_flag_serialization() {
        let flag = FeatureFlag {
            feature: "ai.chat".to_string(),
            is_enabled: true,
        };
        let json = serde_json::to_value(&flag).unwrap();
        assert_eq!(json["feature"], "ai.chat");
        assert_eq!(json["is_enabled"], true);
    }

    #[test]
    fn cache_entry_expiry() {
        let ttl = std::time::Duration::from_secs(300);

        // `Instant - Duration` panics when the result is not representable (e.g.
        // shortly after boot on some platforms), so back-date with `checked_sub`
        // and skip the stale case if that instant is unavailable.
        if let Some(cached_at) =
            std::time::Instant::now().checked_sub(ttl + std::time::Duration::from_secs(1))
        {
            let stale = CacheEntry {
                enabled: false,
                cached_at,
            };
            // Inverse of the freshness check used by `is_enabled`.
            assert!(stale.cached_at.elapsed() >= ttl);
            assert!(!stale.enabled);
        }

        let fresh = CacheEntry {
            enabled: true,
            cached_at: std::time::Instant::now(),
        };
        assert!(fresh.cached_at.elapsed() < ttl);
        assert!(fresh.enabled);
    }
}