- 迁移 000153: 新增 ai_feature_flags / ai_usage_daily / ai_suggestion_feedback 三张表, ai_tenant_configs 增加 billing_enabled 列, seed 12 个功能开关 + 2 个管理权限码 - 新增 FeatureFlagService: 5 分钟缓存 + DB 回退 + 即时更新 - VitalSignsTab 添加 AI 趋势分析按钮 (SSE 流式) - 新增 3 个 Entity (ai_feature_flags / ai_usage_daily / ai_suggestion_feedback) - AiState 扩展 feature_flags 字段 - 设计规格 + 讨论记录文档 Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
316 lines
11 KiB
Rust
316 lines
11 KiB
Rust
use erp_core::types::Pagination;
|
||
use futures::Stream;
|
||
use sea_orm::{
|
||
ActiveModelTrait, ColumnTrait, EntityTrait, PaginatorTrait, QueryFilter, QueryOrder,
|
||
QuerySelect, Set,
|
||
};
|
||
use sha2::{Digest, Sha256};
|
||
use std::pin::Pin;
|
||
use uuid::Uuid;
|
||
|
||
use crate::dto::{AnalysisType, GenerateRequest};
|
||
use crate::entity::ai_analysis;
|
||
use crate::error::{AiError, AiResult};
|
||
use crate::knowledge::KnowledgeSource;
|
||
use crate::prompt::PromptRenderer;
|
||
use crate::provider::registry::ProviderRegistry;
|
||
use crate::sanitization::SanitizationService;
|
||
|
||
pub struct AnalysisService {
|
||
pub provider_registry: std::sync::Arc<ProviderRegistry>,
|
||
pub sanitizer: SanitizationService,
|
||
pub renderer: PromptRenderer,
|
||
pub db: sea_orm::DatabaseConnection,
|
||
pub knowledge_source: Option<std::sync::Arc<dyn KnowledgeSource>>,
|
||
}
|
||
|
||
impl AnalysisService {
|
||
pub fn new(
|
||
provider_registry: std::sync::Arc<ProviderRegistry>,
|
||
db: sea_orm::DatabaseConnection,
|
||
) -> Self {
|
||
Self {
|
||
provider_registry,
|
||
sanitizer: SanitizationService::new(),
|
||
renderer: PromptRenderer::new(),
|
||
db,
|
||
knowledge_source: None,
|
||
}
|
||
}
|
||
|
||
pub fn with_knowledge_source(mut self, source: std::sync::Arc<dyn KnowledgeSource>) -> Self {
|
||
self.knowledge_source = Some(source);
|
||
self
|
||
}
|
||
|
||
/// 执行流式分析 — 返回 SSE 事件流
|
||
#[allow(clippy::too_many_arguments)]
|
||
pub async fn stream_analyze(
|
||
&self,
|
||
tenant_id: Uuid,
|
||
user_id: Uuid,
|
||
patient_id: Uuid,
|
||
analysis_type: AnalysisType,
|
||
source_ref: String,
|
||
system_prompt: String,
|
||
user_template: String,
|
||
sanitized_data: serde_json::Value,
|
||
model: String,
|
||
temperature: f32,
|
||
max_tokens: u32,
|
||
) -> AiResult<(
|
||
Pin<Box<dyn Stream<Item = AiResult<String>> + Send>>,
|
||
uuid::Uuid,
|
||
String,
|
||
)> {
|
||
let analysis_id = Uuid::now_v7();
|
||
let input_hash = self.compute_hash(&sanitized_data);
|
||
|
||
// 从 config_resolver 获取 default_provider,然后从 registry 解析
|
||
let default_provider_name = crate::config_resolver::load_ai_config(tenant_id, &self.db)
|
||
.await
|
||
.default_provider;
|
||
let resolved = self
|
||
.provider_registry
|
||
.resolve(&default_provider_name)
|
||
.await
|
||
.map_err(|e| {
|
||
tracing::error!(error = %e, "无法解析 AI Provider");
|
||
AiError::ProviderUnavailable(default_provider_name.clone())
|
||
})?;
|
||
let provider_name = resolved.provider_name().to_string();
|
||
|
||
// 0. 缓存命中检查(相同输入 + prompt 版本 → 复用已有结果)
|
||
if let Some(cached) = self.find_cached(tenant_id, &input_hash, 1).await? {
|
||
tracing::info!(analysis = %cached.id, "AI 分析缓存命中,复用已有结果");
|
||
let content = cached.result_content.clone().unwrap_or_default();
|
||
let metadata = cached
|
||
.result_metadata
|
||
.clone()
|
||
.unwrap_or(serde_json::json!({}));
|
||
let stream = self.replay_cached(content, metadata);
|
||
return Ok((stream, cached.id, provider_name));
|
||
}
|
||
|
||
tracing::info!(analysis = %analysis_id, tenant = %tenant_id, r#type = %analysis_type.as_str(), "发起 AI 分析");
|
||
|
||
// 0.5 知识库上下文注入
|
||
let system_prompt = if let Some(ref ks) = self.knowledge_source {
|
||
let query = crate::knowledge::KnowledgeQuery {
|
||
tenant_id,
|
||
analysis_type: analysis_type.as_str().to_string(),
|
||
patient_context: None,
|
||
query_text: None,
|
||
};
|
||
match ks.get_context(&query).await {
|
||
Ok(ctx) if ctx.confidence > 0.0 => {
|
||
tracing::info!(
|
||
source = %ctx.source,
|
||
confidence = ctx.confidence,
|
||
"知识库上下文注入"
|
||
);
|
||
format!(
|
||
"{}\n\n=== 知识库参考 ===\n{}",
|
||
system_prompt, ctx.context_text
|
||
)
|
||
}
|
||
Ok(_) => system_prompt,
|
||
Err(e) => {
|
||
tracing::warn!(error = %e, "知识库查询失败,跳过注入");
|
||
system_prompt
|
||
}
|
||
}
|
||
} else {
|
||
system_prompt
|
||
};
|
||
|
||
// 1. 渲染 Prompt
|
||
let user_prompt = self.renderer.render(&user_template, &sanitized_data)?;
|
||
|
||
// 2. 创建分析记录
|
||
self.create_analysis_record(
|
||
analysis_id,
|
||
tenant_id,
|
||
user_id,
|
||
patient_id,
|
||
analysis_type.as_str(),
|
||
&source_ref,
|
||
&input_hash,
|
||
&provider_name,
|
||
&model,
|
||
)
|
||
.await?;
|
||
|
||
// 3. 调用 AI 流式生成
|
||
let req = GenerateRequest {
|
||
system_prompt,
|
||
user_prompt,
|
||
model,
|
||
temperature,
|
||
max_tokens,
|
||
};
|
||
let stream = resolved.provider().stream_generate(req).await?;
|
||
|
||
Ok((stream, analysis_id, provider_name))
|
||
}
|
||
|
||
/// 将缓存结果构造为一次性 Stream(直接回放纯文本,不额外包装 JSON)
|
||
fn replay_cached(
|
||
&self,
|
||
content: String,
|
||
_metadata: serde_json::Value,
|
||
) -> Pin<Box<dyn Stream<Item = AiResult<String>> + Send>> {
|
||
use futures::stream;
|
||
Box::pin(stream::once(async move { Ok(content) }))
|
||
}
|
||
|
||
/// 更新分析记录为完成
|
||
pub async fn complete_analysis(
|
||
&self,
|
||
analysis_id: Uuid,
|
||
content: String,
|
||
metadata: serde_json::Value,
|
||
) -> AiResult<()> {
|
||
let entity = ai_analysis::Entity::find_by_id(analysis_id)
|
||
.one(&self.db)
|
||
.await?
|
||
.ok_or_else(|| AiError::AnalysisNotFound(analysis_id.to_string()))?;
|
||
|
||
// 缓存回放时记录已是 completed,跳过重复更新
|
||
if entity.status == "completed" {
|
||
tracing::debug!(analysis = %analysis_id, "分析已完成,跳过重复 complete");
|
||
return Ok(());
|
||
}
|
||
let mut active: ai_analysis::ActiveModel = entity.into();
|
||
active.status = Set("completed".into());
|
||
active.result_content = Set(Some(content));
|
||
active.result_metadata = Set(Some(metadata));
|
||
active.updated_at = Set(chrono::Utc::now());
|
||
active.version_lock = Set(active.version_lock.take().unwrap_or(0) + 1);
|
||
active.update(&self.db).await?;
|
||
Ok(())
|
||
}
|
||
|
||
/// 标记分析失败
|
||
pub async fn fail_analysis(&self, analysis_id: Uuid, error: String) -> AiResult<()> {
|
||
let entity = ai_analysis::Entity::find_by_id(analysis_id)
|
||
.one(&self.db)
|
||
.await?
|
||
.ok_or_else(|| AiError::AnalysisNotFound(analysis_id.to_string()))?;
|
||
|
||
let mut active: ai_analysis::ActiveModel = entity.into();
|
||
active.status = Set("failed".into());
|
||
active.error_message = Set(Some(error));
|
||
active.updated_at = Set(chrono::Utc::now());
|
||
active.version_lock = Set(active.version_lock.take().unwrap_or(0) + 1);
|
||
active.update(&self.db).await?;
|
||
Ok(())
|
||
}
|
||
|
||
/// 查找缓存
|
||
pub async fn find_cached(
|
||
&self,
|
||
tenant_id: Uuid,
|
||
input_hash: &str,
|
||
prompt_version: i32,
|
||
) -> AiResult<Option<ai_analysis::Model>> {
|
||
let result = ai_analysis::Entity::find()
|
||
.filter(ai_analysis::Column::TenantId.eq(tenant_id))
|
||
.filter(ai_analysis::Column::InputDataHash.eq(input_hash))
|
||
.filter(ai_analysis::Column::PromptVersion.eq(prompt_version))
|
||
.filter(ai_analysis::Column::Status.eq("completed"))
|
||
.filter(ai_analysis::Column::DeletedAt.is_null())
|
||
.one(&self.db)
|
||
.await?;
|
||
Ok(result)
|
||
}
|
||
|
||
fn compute_hash(&self, data: &serde_json::Value) -> String {
|
||
let canonical = serde_json::to_string(data).unwrap_or_default();
|
||
let mut hasher = Sha256::new();
|
||
hasher.update(canonical.as_bytes());
|
||
hex::encode(hasher.finalize())
|
||
}
|
||
|
||
/// 分页查询分析记录
|
||
pub async fn list_analysis(
|
||
&self,
|
||
tenant_id: Uuid,
|
||
patient_id: Option<Uuid>,
|
||
analysis_type: Option<String>,
|
||
pagination: &Pagination,
|
||
) -> AiResult<(Vec<ai_analysis::Model>, u64)> {
|
||
let mut query = ai_analysis::Entity::find()
|
||
.filter(ai_analysis::Column::TenantId.eq(tenant_id))
|
||
.filter(ai_analysis::Column::DeletedAt.is_null());
|
||
|
||
if let Some(pid) = patient_id {
|
||
query = query.filter(ai_analysis::Column::PatientId.eq(pid));
|
||
}
|
||
if let Some(at) = &analysis_type {
|
||
query = query.filter(ai_analysis::Column::AnalysisType.eq(at.as_str()));
|
||
}
|
||
|
||
let total = query.clone().count(&self.db).await?;
|
||
let items = query
|
||
.order_by_desc(ai_analysis::Column::CreatedAt)
|
||
.offset(pagination.offset())
|
||
.limit(pagination.limit())
|
||
.all(&self.db)
|
||
.await?;
|
||
Ok((items, total))
|
||
}
|
||
|
||
/// 获取单条分析记录
|
||
pub async fn get_analysis(&self, id: Uuid, tenant_id: Uuid) -> AiResult<ai_analysis::Model> {
|
||
let model = ai_analysis::Entity::find_by_id(id)
|
||
.one(&self.db)
|
||
.await?
|
||
.ok_or_else(|| AiError::AnalysisNotFound(id.to_string()))?;
|
||
if model.tenant_id != tenant_id {
|
||
return Err(AiError::AnalysisNotFound(id.to_string()));
|
||
}
|
||
Ok(model)
|
||
}
|
||
|
||
#[allow(clippy::too_many_arguments)]
|
||
async fn create_analysis_record(
|
||
&self,
|
||
id: Uuid,
|
||
tenant_id: Uuid,
|
||
user_id: Uuid,
|
||
patient_id: Uuid,
|
||
analysis_type: &str,
|
||
source_ref: &str,
|
||
input_hash: &str,
|
||
_provider: &str,
|
||
model: &str,
|
||
) -> AiResult<()> {
|
||
let now = chrono::Utc::now();
|
||
let active = ai_analysis::ActiveModel {
|
||
id: Set(id),
|
||
tenant_id: Set(tenant_id),
|
||
patient_id: Set(patient_id),
|
||
analysis_type: Set(analysis_type.into()),
|
||
source_ref: Set(source_ref.into()),
|
||
prompt_id: Set(Uuid::nil()),
|
||
prompt_version: Set(1),
|
||
model_used: Set(model.into()),
|
||
input_data_hash: Set(input_hash.into()),
|
||
sanitized_input: Set(None),
|
||
result_content: Set(None),
|
||
result_metadata: Set(None),
|
||
status: Set("streaming".into()),
|
||
error_message: Set(None),
|
||
created_at: Set(now),
|
||
updated_at: Set(now),
|
||
created_by: Set(Some(user_id)),
|
||
updated_by: Set(Some(user_id)),
|
||
deleted_at: Set(None),
|
||
version_lock: Set(1),
|
||
};
|
||
active.insert(&self.db).await?;
|
||
Ok(())
|
||
}
|
||
}
|