use futures::Stream; use sea_orm::{ActiveModelTrait, ColumnTrait, EntityTrait, PaginatorTrait, QueryFilter, QueryOrder, QuerySelect, Set}; use sha2::{Digest, Sha256}; use std::pin::Pin; use uuid::Uuid; use erp_core::types::Pagination; use crate::dto::{AnalysisType, GenerateRequest}; use crate::entity::ai_analysis; use crate::error::{AiError, AiResult}; use crate::knowledge::KnowledgeSource; use crate::prompt::PromptRenderer; use crate::provider::AiProvider; use crate::sanitization::SanitizationService; pub struct AnalysisService { pub provider: Box, pub sanitizer: SanitizationService, pub renderer: PromptRenderer, pub db: sea_orm::DatabaseConnection, pub knowledge_source: Option>, } impl AnalysisService { pub fn new(provider: Box, db: sea_orm::DatabaseConnection) -> Self { Self { provider, sanitizer: SanitizationService::new(), renderer: PromptRenderer::new(), db, knowledge_source: None, } } pub fn with_knowledge_source(mut self, source: std::sync::Arc) -> Self { self.knowledge_source = Some(source); self } /// 执行流式分析 — 返回 SSE 事件流 pub async fn stream_analyze( &self, tenant_id: Uuid, user_id: Uuid, patient_id: Uuid, analysis_type: AnalysisType, source_ref: String, system_prompt: String, user_template: String, sanitized_data: serde_json::Value, model: String, temperature: f32, max_tokens: u32, ) -> AiResult<( Pin> + Send>>, uuid::Uuid, String, )> { let analysis_id = Uuid::now_v7(); let input_hash = self.compute_hash(&sanitized_data); let provider_name = self.provider.name().to_string(); // 0. 缓存命中检查(相同输入 + prompt 版本 → 复用已有结果) if let Some(cached) = self.find_cached(tenant_id, &input_hash, 1).await? { tracing::info!(analysis = %cached.id, "AI 分析缓存命中,复用已有结果"); let content = cached.result_content.clone().unwrap_or_default(); let metadata = cached.result_metadata.clone().unwrap_or(serde_json::json!({})); let stream = self.replay_cached(content, metadata); return Ok((stream, cached.id, provider_name)); } tracing::info!(analysis = %analysis_id, tenant = %tenant_id, r#type = %analysis_type.as_str(), "发起 AI 分析"); // 0.5 知识库上下文注入 let system_prompt = if let Some(ref ks) = self.knowledge_source { let query = crate::knowledge::KnowledgeQuery { tenant_id, analysis_type: analysis_type.as_str().to_string(), patient_context: None, query_text: None, }; match ks.get_context(&query).await { Ok(ctx) if ctx.confidence > 0.0 => { tracing::info!( source = %ctx.source, confidence = ctx.confidence, "知识库上下文注入" ); format!("{}\n\n=== 知识库参考 ===\n{}", system_prompt, ctx.context_text) } Ok(_) => system_prompt, Err(e) => { tracing::warn!(error = %e, "知识库查询失败,跳过注入"); system_prompt } } } else { system_prompt }; // 1. 渲染 Prompt let user_prompt = self.renderer.render(&user_template, &sanitized_data)?; // 2. 创建分析记录 self.create_analysis_record( analysis_id, tenant_id, user_id, patient_id, analysis_type.as_str(), &source_ref, &input_hash, &provider_name, &model, ) .await?; // 3. 调用 AI 流式生成 let req = GenerateRequest { system_prompt, user_prompt, model, temperature, max_tokens, }; let stream = self.provider.stream_generate(req).await?; Ok((stream, analysis_id, provider_name)) } /// 将缓存结果构造为一次性 Stream(直接回放纯文本,不额外包装 JSON) fn replay_cached( &self, content: String, _metadata: serde_json::Value, ) -> Pin> + Send>> { use futures::stream; Box::pin(stream::once(async move { Ok(content) })) } /// 更新分析记录为完成 pub async fn complete_analysis( &self, analysis_id: Uuid, content: String, metadata: serde_json::Value, ) -> AiResult<()> { let entity = ai_analysis::Entity::find_by_id(analysis_id) .one(&self.db) .await? .ok_or_else(|| AiError::AnalysisNotFound(analysis_id.to_string()))?; // 缓存回放时记录已是 completed,跳过重复更新 if entity.status == "completed" { tracing::debug!(analysis = %analysis_id, "分析已完成,跳过重复 complete"); return Ok(()); } let mut active: ai_analysis::ActiveModel = entity.into(); active.status = Set("completed".into()); active.result_content = Set(Some(content)); active.result_metadata = Set(Some(metadata)); active.updated_at = Set(chrono::Utc::now()); active.update(&self.db).await?; Ok(()) } /// 标记分析失败 pub async fn fail_analysis(&self, analysis_id: Uuid, error: String) -> AiResult<()> { let entity = ai_analysis::Entity::find_by_id(analysis_id) .one(&self.db) .await? .ok_or_else(|| AiError::AnalysisNotFound(analysis_id.to_string()))?; let mut active: ai_analysis::ActiveModel = entity.into(); active.status = Set("failed".into()); active.error_message = Set(Some(error)); active.updated_at = Set(chrono::Utc::now()); active.update(&self.db).await?; Ok(()) } /// 查找缓存 pub async fn find_cached( &self, tenant_id: Uuid, input_hash: &str, prompt_version: i32, ) -> AiResult> { let result = ai_analysis::Entity::find() .filter(ai_analysis::Column::TenantId.eq(tenant_id)) .filter(ai_analysis::Column::InputDataHash.eq(input_hash)) .filter(ai_analysis::Column::PromptVersion.eq(prompt_version)) .filter(ai_analysis::Column::Status.eq("completed")) .filter(ai_analysis::Column::DeletedAt.is_null()) .one(&self.db) .await?; Ok(result) } fn compute_hash(&self, data: &serde_json::Value) -> String { let canonical = serde_json::to_string(data).unwrap_or_default(); let mut hasher = Sha256::new(); hasher.update(canonical.as_bytes()); hex::encode(hasher.finalize()) } /// 分页查询分析记录 pub async fn list_analysis( &self, tenant_id: Uuid, patient_id: Option, analysis_type: Option, pagination: &Pagination, ) -> AiResult<(Vec, u64)> { let mut query = ai_analysis::Entity::find() .filter(ai_analysis::Column::TenantId.eq(tenant_id)) .filter(ai_analysis::Column::DeletedAt.is_null()); if let Some(pid) = patient_id { query = query.filter(ai_analysis::Column::PatientId.eq(pid)); } if let Some(at) = &analysis_type { query = query.filter(ai_analysis::Column::AnalysisType.eq(at.as_str())); } let total = query.clone().count(&self.db).await?; let items = query .order_by_desc(ai_analysis::Column::CreatedAt) .offset(pagination.offset()) .limit(pagination.limit()) .all(&self.db) .await?; Ok((items, total)) } /// 获取单条分析记录 pub async fn get_analysis( &self, id: Uuid, tenant_id: Uuid, ) -> AiResult { let model = ai_analysis::Entity::find_by_id(id) .one(&self.db) .await? .ok_or_else(|| AiError::AnalysisNotFound(id.to_string()))?; if model.tenant_id != tenant_id { return Err(AiError::AnalysisNotFound(id.to_string())); } Ok(model) } async fn create_analysis_record( &self, id: Uuid, tenant_id: Uuid, user_id: Uuid, patient_id: Uuid, analysis_type: &str, source_ref: &str, input_hash: &str, _provider: &str, model: &str, ) -> AiResult<()> { let now = chrono::Utc::now(); let active = ai_analysis::ActiveModel { id: Set(id), tenant_id: Set(tenant_id), patient_id: Set(patient_id), analysis_type: Set(analysis_type.into()), source_ref: Set(source_ref.into()), prompt_id: Set(Uuid::nil()), prompt_version: Set(1), model_used: Set(model.into()), input_data_hash: Set(input_hash.into()), sanitized_input: Set(None), result_content: Set(None), result_metadata: Set(None), status: Set("streaming".into()), error_message: Set(None), created_at: Set(now), updated_at: Set(now), created_by: Set(Some(user_id)), updated_by: Set(Some(user_id)), deleted_at: Set(None), version_lock: Set(1), }; active.insert(&self.db).await?; Ok(()) } }