Files
hms/crates/erp-ai/src/service/analysis.rs
iven 1f91dcc5cc
Some checks failed
CI / rust-check (push) Has been cancelled
CI / rust-test (push) Has been cancelled
CI / frontend-build (push) Has been cancelled
CI / security-audit (push) Has been cancelled
fix(ai): 修复分析结果 JSON 嵌套 bug
- replay_cached 直接回放纯文本,不再包装 JSON 壳
- complete_analysis 跳过已完成的记录,防止缓存命中时覆写
- 前端 AnalysisContent 增加 extractPlainText 递归解析 JSON
2026-05-05 19:45:36 +08:00

291 lines
9.8 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
use futures::Stream;
use sea_orm::{ActiveModelTrait, ColumnTrait, EntityTrait, PaginatorTrait, QueryFilter, QueryOrder, QuerySelect, Set};
use sha2::{Digest, Sha256};
use std::pin::Pin;
use uuid::Uuid;
use erp_core::types::Pagination;
use crate::dto::{AnalysisType, GenerateRequest};
use crate::entity::ai_analysis;
use crate::error::{AiError, AiResult};
use crate::knowledge::KnowledgeSource;
use crate::prompt::PromptRenderer;
use crate::provider::AiProvider;
use crate::sanitization::SanitizationService;
/// AI analysis service: runs streaming LLM analyses and persists each run as an
/// `ai_analysis` record, reusing completed records whose sanitized-input hash matches.
pub struct AnalysisService {
    /// AI backend; `stream_analyze` uses its `name()` and `stream_generate`.
    pub provider: Box<dyn AiProvider>,
    /// Sanitization service; `new` builds it with `SanitizationService::new()`.
    /// NOTE(review): not referenced by the methods in this file — presumably callers
    /// sanitize data before calling `stream_analyze`; confirm.
    pub sanitizer: SanitizationService,
    /// Renders the user prompt template against the sanitized input data.
    pub renderer: PromptRenderer,
    /// Database connection used for every `ai_analysis` read and write.
    pub db: sea_orm::DatabaseConnection,
    /// Optional knowledge source; when set, `stream_analyze` may append its context to
    /// the system prompt.
    pub knowledge_source: Option<std::sync::Arc<dyn KnowledgeSource>>,
}
impl AnalysisService {
pub fn new(provider: Box<dyn AiProvider>, db: sea_orm::DatabaseConnection) -> Self {
Self {
provider,
sanitizer: SanitizationService::new(),
renderer: PromptRenderer::new(),
db,
knowledge_source: None,
}
}
pub fn with_knowledge_source(mut self, source: std::sync::Arc<dyn KnowledgeSource>) -> Self {
self.knowledge_source = Some(source);
self
}
/// 执行流式分析 — 返回 SSE 事件流
pub async fn stream_analyze(
&self,
tenant_id: Uuid,
user_id: Uuid,
patient_id: Uuid,
analysis_type: AnalysisType,
source_ref: String,
system_prompt: String,
user_template: String,
sanitized_data: serde_json::Value,
model: String,
temperature: f32,
max_tokens: u32,
) -> AiResult<(
Pin<Box<dyn Stream<Item = AiResult<String>> + Send>>,
uuid::Uuid,
String,
)> {
let analysis_id = Uuid::now_v7();
let input_hash = self.compute_hash(&sanitized_data);
let provider_name = self.provider.name().to_string();
// 0. 缓存命中检查(相同输入 + prompt 版本 → 复用已有结果)
if let Some(cached) = self.find_cached(tenant_id, &input_hash, 1).await? {
tracing::info!(analysis = %cached.id, "AI 分析缓存命中,复用已有结果");
let content = cached.result_content.clone().unwrap_or_default();
let metadata = cached.result_metadata.clone().unwrap_or(serde_json::json!({}));
let stream = self.replay_cached(content, metadata);
return Ok((stream, cached.id, provider_name));
}
tracing::info!(analysis = %analysis_id, tenant = %tenant_id, r#type = %analysis_type.as_str(), "发起 AI 分析");
// 0.5 知识库上下文注入
let system_prompt = if let Some(ref ks) = self.knowledge_source {
let query = crate::knowledge::KnowledgeQuery {
tenant_id,
analysis_type: analysis_type.as_str().to_string(),
patient_context: None,
query_text: None,
};
match ks.get_context(&query).await {
Ok(ctx) if ctx.confidence > 0.0 => {
tracing::info!(
source = %ctx.source,
confidence = ctx.confidence,
"知识库上下文注入"
);
format!("{}\n\n=== 知识库参考 ===\n{}", system_prompt, ctx.context_text)
}
Ok(_) => system_prompt,
Err(e) => {
tracing::warn!(error = %e, "知识库查询失败,跳过注入");
system_prompt
}
}
} else {
system_prompt
};
// 1. 渲染 Prompt
let user_prompt = self.renderer.render(&user_template, &sanitized_data)?;
// 2. 创建分析记录
self.create_analysis_record(
analysis_id,
tenant_id,
user_id,
patient_id,
analysis_type.as_str(),
&source_ref,
&input_hash,
&provider_name,
&model,
)
.await?;
// 3. 调用 AI 流式生成
let req = GenerateRequest {
system_prompt,
user_prompt,
model,
temperature,
max_tokens,
};
let stream = self.provider.stream_generate(req).await?;
Ok((stream, analysis_id, provider_name))
}
/// 将缓存结果构造为一次性 Stream直接回放纯文本不额外包装 JSON
fn replay_cached(
&self,
content: String,
_metadata: serde_json::Value,
) -> Pin<Box<dyn Stream<Item = AiResult<String>> + Send>> {
use futures::stream;
Box::pin(stream::once(async move { Ok(content) }))
}
/// 更新分析记录为完成
pub async fn complete_analysis(
&self,
analysis_id: Uuid,
content: String,
metadata: serde_json::Value,
) -> AiResult<()> {
let entity = ai_analysis::Entity::find_by_id(analysis_id)
.one(&self.db)
.await?
.ok_or_else(|| AiError::AnalysisNotFound(analysis_id.to_string()))?;
// 缓存回放时记录已是 completed跳过重复更新
if entity.status == "completed" {
tracing::debug!(analysis = %analysis_id, "分析已完成,跳过重复 complete");
return Ok(());
}
let mut active: ai_analysis::ActiveModel = entity.into();
active.status = Set("completed".into());
active.result_content = Set(Some(content));
active.result_metadata = Set(Some(metadata));
active.updated_at = Set(chrono::Utc::now());
active.update(&self.db).await?;
Ok(())
}
/// 标记分析失败
pub async fn fail_analysis(&self, analysis_id: Uuid, error: String) -> AiResult<()> {
let entity = ai_analysis::Entity::find_by_id(analysis_id)
.one(&self.db)
.await?
.ok_or_else(|| AiError::AnalysisNotFound(analysis_id.to_string()))?;
let mut active: ai_analysis::ActiveModel = entity.into();
active.status = Set("failed".into());
active.error_message = Set(Some(error));
active.updated_at = Set(chrono::Utc::now());
active.update(&self.db).await?;
Ok(())
}
/// 查找缓存
pub async fn find_cached(
&self,
tenant_id: Uuid,
input_hash: &str,
prompt_version: i32,
) -> AiResult<Option<ai_analysis::Model>> {
let result = ai_analysis::Entity::find()
.filter(ai_analysis::Column::TenantId.eq(tenant_id))
.filter(ai_analysis::Column::InputDataHash.eq(input_hash))
.filter(ai_analysis::Column::PromptVersion.eq(prompt_version))
.filter(ai_analysis::Column::Status.eq("completed"))
.filter(ai_analysis::Column::DeletedAt.is_null())
.one(&self.db)
.await?;
Ok(result)
}
fn compute_hash(&self, data: &serde_json::Value) -> String {
let canonical = serde_json::to_string(data).unwrap_or_default();
let mut hasher = Sha256::new();
hasher.update(canonical.as_bytes());
hex::encode(hasher.finalize())
}
/// 分页查询分析记录
pub async fn list_analysis(
&self,
tenant_id: Uuid,
patient_id: Option<Uuid>,
analysis_type: Option<String>,
pagination: &Pagination,
) -> AiResult<(Vec<ai_analysis::Model>, u64)> {
let mut query = ai_analysis::Entity::find()
.filter(ai_analysis::Column::TenantId.eq(tenant_id))
.filter(ai_analysis::Column::DeletedAt.is_null());
if let Some(pid) = patient_id {
query = query.filter(ai_analysis::Column::PatientId.eq(pid));
}
if let Some(at) = &analysis_type {
query = query.filter(ai_analysis::Column::AnalysisType.eq(at.as_str()));
}
let total = query.clone().count(&self.db).await?;
let items = query
.order_by_desc(ai_analysis::Column::CreatedAt)
.offset(pagination.offset())
.limit(pagination.limit())
.all(&self.db)
.await?;
Ok((items, total))
}
/// 获取单条分析记录
pub async fn get_analysis(
&self,
id: Uuid,
tenant_id: Uuid,
) -> AiResult<ai_analysis::Model> {
let model = ai_analysis::Entity::find_by_id(id)
.one(&self.db)
.await?
.ok_or_else(|| AiError::AnalysisNotFound(id.to_string()))?;
if model.tenant_id != tenant_id {
return Err(AiError::AnalysisNotFound(id.to_string()));
}
Ok(model)
}
async fn create_analysis_record(
&self,
id: Uuid,
tenant_id: Uuid,
user_id: Uuid,
patient_id: Uuid,
analysis_type: &str,
source_ref: &str,
input_hash: &str,
_provider: &str,
model: &str,
) -> AiResult<()> {
let now = chrono::Utc::now();
let active = ai_analysis::ActiveModel {
id: Set(id),
tenant_id: Set(tenant_id),
patient_id: Set(patient_id),
analysis_type: Set(analysis_type.into()),
source_ref: Set(source_ref.into()),
prompt_id: Set(Uuid::nil()),
prompt_version: Set(1),
model_used: Set(model.into()),
input_data_hash: Set(input_hash.into()),
sanitized_input: Set(None),
result_content: Set(None),
result_metadata: Set(None),
status: Set("streaming".into()),
error_message: Set(None),
created_at: Set(now),
updated_at: Set(now),
created_by: Set(Some(user_id)),
updated_by: Set(Some(user_id)),
deleted_at: Set(None),
version_lock: Set(1),
};
active.insert(&self.db).await?;
Ok(())
}
}