use futures::Stream; use sea_orm::{ActiveModelTrait, ColumnTrait, EntityTrait, PaginatorTrait, QueryFilter, QueryOrder, QuerySelect, Set}; use sha2::{Digest, Sha256}; use std::pin::Pin; use uuid::Uuid; use erp_core::types::Pagination; use crate::dto::{AnalysisType, GenerateRequest}; use crate::entity::ai_analysis; use crate::error::{AiError, AiResult}; use crate::prompt::PromptRenderer; use crate::provider::AiProvider; use crate::sanitization::SanitizationService; pub struct AnalysisService { pub provider: Box, pub sanitizer: SanitizationService, pub renderer: PromptRenderer, pub db: sea_orm::DatabaseConnection, } impl AnalysisService { pub fn new(provider: Box, db: sea_orm::DatabaseConnection) -> Self { Self { provider, sanitizer: SanitizationService::new(), renderer: PromptRenderer::new(), db, } } /// 执行流式分析 — 返回 SSE 事件流 pub async fn stream_analyze( &self, tenant_id: Uuid, user_id: Uuid, patient_id: Uuid, analysis_type: AnalysisType, source_ref: String, system_prompt: String, user_template: String, sanitized_data: serde_json::Value, model: String, temperature: f32, max_tokens: u32, ) -> AiResult<( Pin> + Send>>, uuid::Uuid, String, )> { let analysis_id = Uuid::now_v7(); let input_hash = self.compute_hash(&sanitized_data); let provider_name = self.provider.name().to_string(); // 1. 渲染 Prompt let user_prompt = self.renderer.render(&user_template, &sanitized_data)?; // 2. 创建分析记录 self.create_analysis_record( analysis_id, tenant_id, user_id, patient_id, analysis_type.as_str(), &source_ref, &input_hash, &provider_name, &model, ) .await?; // 3. 调用 AI 流式生成 let req = GenerateRequest { system_prompt, user_prompt, model, temperature, max_tokens, }; let stream = self.provider.stream_generate(req).await?; Ok((stream, analysis_id, provider_name)) } /// 更新分析记录为完成 pub async fn complete_analysis( &self, analysis_id: Uuid, content: String, metadata: serde_json::Value, ) -> AiResult<()> { let entity = ai_analysis::Entity::find_by_id(analysis_id) .one(&self.db) .await? .ok_or_else(|| AiError::AnalysisNotFound(analysis_id.to_string()))?; let mut active: ai_analysis::ActiveModel = entity.into(); active.status = Set("completed".into()); active.result_content = Set(Some(content)); active.result_metadata = Set(Some(metadata)); active.updated_at = Set(chrono::Utc::now()); active.update(&self.db).await?; Ok(()) } /// 标记分析失败 pub async fn fail_analysis(&self, analysis_id: Uuid, error: String) -> AiResult<()> { let entity = ai_analysis::Entity::find_by_id(analysis_id) .one(&self.db) .await? .ok_or_else(|| AiError::AnalysisNotFound(analysis_id.to_string()))?; let mut active: ai_analysis::ActiveModel = entity.into(); active.status = Set("failed".into()); active.error_message = Set(Some(error)); active.updated_at = Set(chrono::Utc::now()); active.update(&self.db).await?; Ok(()) } /// 查找缓存 pub async fn find_cached( &self, tenant_id: Uuid, input_hash: &str, prompt_version: i32, ) -> AiResult> { let result = ai_analysis::Entity::find() .filter(ai_analysis::Column::TenantId.eq(tenant_id)) .filter(ai_analysis::Column::InputDataHash.eq(input_hash)) .filter(ai_analysis::Column::PromptVersion.eq(prompt_version)) .filter(ai_analysis::Column::Status.eq("completed")) .filter(ai_analysis::Column::DeletedAt.is_null()) .one(&self.db) .await?; Ok(result) } fn compute_hash(&self, data: &serde_json::Value) -> String { let canonical = serde_json::to_string(data).unwrap_or_default(); let mut hasher = Sha256::new(); hasher.update(canonical.as_bytes()); hex::encode(hasher.finalize()) } /// 分页查询分析记录 pub async fn list_analysis( &self, tenant_id: Uuid, patient_id: Option, analysis_type: Option, pagination: &Pagination, ) -> AiResult<(Vec, u64)> { let mut query = ai_analysis::Entity::find() .filter(ai_analysis::Column::TenantId.eq(tenant_id)) .filter(ai_analysis::Column::DeletedAt.is_null()); if let Some(pid) = patient_id { query = query.filter(ai_analysis::Column::PatientId.eq(pid)); } if let Some(at) = &analysis_type { query = query.filter(ai_analysis::Column::AnalysisType.eq(at.as_str())); } let total = query.clone().count(&self.db).await?; let items = query .order_by_desc(ai_analysis::Column::CreatedAt) .offset(pagination.offset()) .limit(pagination.limit()) .all(&self.db) .await?; Ok((items, total)) } /// 获取单条分析记录 pub async fn get_analysis( &self, id: Uuid, tenant_id: Uuid, ) -> AiResult { let model = ai_analysis::Entity::find_by_id(id) .one(&self.db) .await? .ok_or_else(|| AiError::AnalysisNotFound(id.to_string()))?; if model.tenant_id != tenant_id { return Err(AiError::AnalysisNotFound(id.to_string())); } Ok(model) } async fn create_analysis_record( &self, id: Uuid, tenant_id: Uuid, user_id: Uuid, patient_id: Uuid, analysis_type: &str, source_ref: &str, input_hash: &str, _provider: &str, model: &str, ) -> AiResult<()> { let now = chrono::Utc::now(); let active = ai_analysis::ActiveModel { id: Set(id), tenant_id: Set(tenant_id), patient_id: Set(patient_id), analysis_type: Set(analysis_type.into()), source_ref: Set(source_ref.into()), prompt_id: Set(Uuid::nil()), prompt_version: Set(1), model_used: Set(model.into()), input_data_hash: Set(input_hash.into()), sanitized_input: Set(None), result_content: Set(None), result_metadata: Set(None), status: Set("streaming".into()), error_message: Set(None), created_at: Set(now), updated_at: Set(now), created_by: Set(Some(user_id)), updated_by: Set(Some(user_id)), deleted_at: Set(None), version_lock: Set(1), }; active.insert(&self.db).await?; Ok(()) } }