use axum::extract::{Extension, FromRef, Path, Query, State}; use axum::response::sse::{Event, KeepAlive, Sse}; use axum::Json; use erp_core::health_provider::TimeRange; use erp_core::rbac::require_permission; use erp_core::types::{ApiResponse, TenantContext}; use futures::StreamExt; use serde::Deserialize; use std::convert::Infallible; use crate::dto::{AnalysisSseEvent, AnalysisType}; use crate::service::suggestion::SuggestionService; use crate::state::AiState; pub mod suggestion_handler; // === 分析请求 Body === #[derive(Debug, Deserialize)] pub struct AnalyzeBody { pub report_id: Option, pub patient_id: Option, pub metrics: Option>, } // === SSE 分析端点 === pub async fn stream_lab_report( State(state): State, Extension(ctx): Extension, Json(body): Json, ) -> Result>>, erp_core::error::AppError> where AiState: FromRef, S: Clone + Send + Sync + 'static, { require_permission(&ctx, "ai.analysis.manage")?; let report_id = body.report_id.ok_or_else(|| { erp_core::error::AppError::Validation("report_id 必填".into()) })?; let lab_dto = state .health_provider .get_lab_report(ctx.tenant_id, report_id) .await?; let sanitized_data = state.analysis.sanitizer.sanitize_lab_report(&lab_dto)?; let prompt = state .prompt .get_active_prompt(ctx.tenant_id, "lab_report_interpretation") .await?; let model_config = &prompt.model_config; let model = model_config["model"].as_str().unwrap_or("claude-sonnet-4-6").to_string(); let temperature = model_config["temperature"].as_f64().unwrap_or(0.3) as f32; let max_tokens = model_config["max_tokens"].as_u64().unwrap_or(2048) as u32; let (stream, analysis_id, _provider_name) = state .analysis .stream_analyze( ctx.tenant_id, ctx.user_id, uuid::Uuid::nil(), AnalysisType::LabReport, report_id.to_string(), prompt.system_prompt, prompt.user_prompt_template, sanitized_data, model, temperature, max_tokens, ) .await?; let analysis_id_clone = analysis_id; let state_clone = state.clone(); let patient_id_clone = uuid::Uuid::nil(); // lab report 场景 patient_id 从 report 关联 let doctor_id_clone = ctx.user_id; let sse_stream = build_sse_stream(stream, analysis_id_clone, state_clone, "lab_report", ctx.tenant_id, patient_id_clone, doctor_id_clone); Ok(Sse::new(sse_stream).keep_alive(KeepAlive::default())) } pub async fn stream_trends( State(state): State, Extension(ctx): Extension, Json(body): Json, ) -> Result>>, erp_core::error::AppError> where AiState: FromRef, S: Clone + Send + Sync + 'static, { require_permission(&ctx, "ai.analysis.manage")?; let patient_id = body.patient_id.ok_or_else(|| { erp_core::error::AppError::Validation("patient_id 必填".into()) })?; let metrics = body.metrics.unwrap_or_else(|| { vec![ "systolic_bp_morning".into(), "diastolic_bp_morning".into(), "heart_rate".into(), "weight".into(), "blood_sugar".into(), ] }); let range = TimeRange { start: chrono::Utc::now() - chrono::Duration::days(90), end: chrono::Utc::now(), }; let trend_data = state .health_provider .get_trend_analysis_data(ctx.tenant_id, patient_id, &metrics, &range) .await?; let sanitized_data = state.analysis.sanitizer.sanitize_trend_analysis(&trend_data)?; let prompt = state .prompt .get_active_prompt(ctx.tenant_id, "health_trend_analysis") .await?; let model_config = &prompt.model_config; let model = model_config["model"].as_str().unwrap_or("claude-sonnet-4-6").to_string(); let temperature = model_config["temperature"].as_f64().unwrap_or(0.3) as f32; let max_tokens = model_config["max_tokens"].as_u64().unwrap_or(2048) as u32; let (stream, analysis_id, _) = state .analysis .stream_analyze( ctx.tenant_id, ctx.user_id, patient_id, AnalysisType::Trends, patient_id.to_string(), prompt.system_prompt, prompt.user_prompt_template, sanitized_data, model, temperature, max_tokens, ) .await?; let analysis_id_clone = analysis_id; let state_clone = state.clone(); let sse_stream = build_sse_stream(stream, analysis_id_clone, state_clone, "trend", ctx.tenant_id, uuid::Uuid::nil(), ctx.user_id); Ok(Sse::new(sse_stream).keep_alive(KeepAlive::default())) } pub async fn stream_checkup_plan( State(state): State, Extension(ctx): Extension, Json(body): Json, ) -> Result>>, erp_core::error::AppError> where AiState: FromRef, S: Clone + Send + Sync + 'static, { require_permission(&ctx, "ai.analysis.manage")?; let patient_id = body.patient_id.ok_or_else(|| { erp_core::error::AppError::Validation("patient_id 必填".into()) })?; let summary_dto = state .health_provider .get_patient_summary(ctx.tenant_id, patient_id) .await?; let sanitized_data = state .analysis .sanitizer .sanitize_patient_summary(&summary_dto)?; let prompt = state .prompt .get_active_prompt(ctx.tenant_id, "personalized_checkup_plan") .await?; let model_config = &prompt.model_config; let model = model_config["model"].as_str().unwrap_or("claude-sonnet-4-6").to_string(); let temperature = model_config["temperature"].as_f64().unwrap_or(0.3) as f32; let max_tokens = model_config["max_tokens"].as_u64().unwrap_or(2048) as u32; let (stream, analysis_id, _) = state .analysis .stream_analyze( ctx.tenant_id, ctx.user_id, patient_id, AnalysisType::CheckupPlan, patient_id.to_string(), prompt.system_prompt, prompt.user_prompt_template, sanitized_data, model, temperature, max_tokens, ) .await?; let analysis_id_clone = analysis_id; let state_clone = state.clone(); let sse_stream = build_sse_stream(stream, analysis_id_clone, state_clone, "checkup_plan", ctx.tenant_id, uuid::Uuid::nil(), ctx.user_id); Ok(Sse::new(sse_stream).keep_alive(KeepAlive::default())) } pub async fn stream_report_summary( State(state): State, Extension(ctx): Extension, Json(body): Json, ) -> Result>>, erp_core::error::AppError> where AiState: FromRef, S: Clone + Send + Sync + 'static, { require_permission(&ctx, "ai.analysis.manage")?; let report_id = body.report_id.ok_or_else(|| { erp_core::error::AppError::Validation("report_id 必填".into()) })?; let report_dto = state .health_provider .get_full_report(ctx.tenant_id, report_id) .await?; let sanitized_data = state .analysis .sanitizer .sanitize_health_report(&report_dto)?; let prompt = state .prompt .get_active_prompt(ctx.tenant_id, "report_summary_generation") .await?; let model_config = &prompt.model_config; let model = model_config["model"].as_str().unwrap_or("claude-sonnet-4-6").to_string(); let temperature = model_config["temperature"].as_f64().unwrap_or(0.3) as f32; let max_tokens = model_config["max_tokens"].as_u64().unwrap_or(2048) as u32; let (stream, analysis_id, _) = state .analysis .stream_analyze( ctx.tenant_id, ctx.user_id, uuid::Uuid::nil(), AnalysisType::ReportSummary, report_id.to_string(), prompt.system_prompt, prompt.user_prompt_template, sanitized_data, model, temperature, max_tokens, ) .await?; let analysis_id_clone = analysis_id; let state_clone = state.clone(); let sse_stream = build_sse_stream(stream, analysis_id_clone, state_clone, "report_summary", ctx.tenant_id, uuid::Uuid::nil(), ctx.user_id); Ok(Sse::new(sse_stream).keep_alive(KeepAlive::default())) } // === 分析历史 === #[derive(Debug, Deserialize)] pub struct ListAnalysisQuery { pub patient_id: Option, pub analysis_type: Option, pub page: Option, pub page_size: Option, } pub async fn list_analysis( State(state): State, Extension(ctx): Extension, Query(params): Query, ) -> Result>, erp_core::error::AppError> where AiState: FromRef, S: Clone + Send + Sync + 'static, { require_permission(&ctx, "ai.analysis.list")?; let pagination = erp_core::types::Pagination { page: params.page, page_size: params.page_size, }; let (items, total) = state .analysis .list_analysis(ctx.tenant_id, params.patient_id, params.analysis_type, &pagination) .await?; Ok(Json(ApiResponse::ok(serde_json::json!({ "data": items, "total": total, "page": pagination.page.unwrap_or(1), "page_size": pagination.limit(), })))) } pub async fn get_analysis( State(state): State, Extension(ctx): Extension, Path(id): Path, ) -> Result>, erp_core::error::AppError> where AiState: FromRef, S: Clone + Send + Sync + 'static, { require_permission(&ctx, "ai.analysis.list")?; let analysis = state.analysis.get_analysis(id, ctx.tenant_id).await?; Ok(Json(ApiResponse::ok(analysis))) } // === Prompt 管理 === #[derive(Debug, Deserialize)] pub struct ListPromptsQuery { pub category: Option, pub page: Option, pub page_size: Option, } pub async fn list_prompts( State(state): State, Extension(ctx): Extension, Query(params): Query, ) -> Result>, erp_core::error::AppError> where AiState: FromRef, S: Clone + Send + Sync + 'static, { require_permission(&ctx, "ai.prompt.list")?; let pagination = erp_core::types::Pagination { page: params.page, page_size: params.page_size, }; let (items, total) = state .prompt .list_prompts(ctx.tenant_id, params.category, &pagination) .await?; Ok(Json(ApiResponse::ok(serde_json::json!({ "data": items, "total": total, "page": pagination.page.unwrap_or(1), "page_size": pagination.limit(), })))) } #[derive(Debug, Deserialize)] pub struct CreatePromptBody { pub name: String, pub description: Option, pub system_prompt: String, pub user_prompt_template: String, pub model_config: serde_json::Value, pub category: String, } pub async fn create_prompt( State(state): State, Extension(ctx): Extension, Json(body): Json, ) -> Result>, erp_core::error::AppError> where AiState: FromRef, S: Clone + Send + Sync + 'static, { require_permission(&ctx, "ai.prompt.manage")?; let prompt = state .prompt .create_prompt( ctx.tenant_id, ctx.user_id, body.name, body.system_prompt, body.user_prompt_template, body.model_config, body.category, ) .await?; Ok(Json(ApiResponse::ok(prompt))) } pub async fn activate_prompt( State(state): State, Extension(ctx): Extension, Path(id): Path, ) -> Result>, erp_core::error::AppError> where AiState: FromRef, S: Clone + Send + Sync + 'static, { require_permission(&ctx, "ai.prompt.manage")?; let prompt = state.prompt.activate_prompt(id, ctx.tenant_id).await?; Ok(Json(ApiResponse::ok(prompt))) } pub async fn rollback_prompt( State(state): State, Extension(ctx): Extension, Path(id): Path, ) -> Result>, erp_core::error::AppError> where AiState: FromRef, S: Clone + Send + Sync + 'static, { require_permission(&ctx, "ai.prompt.manage")?; let prompt = state.prompt.rollback_prompt(id, ctx.tenant_id).await?; Ok(Json(ApiResponse::ok(prompt))) } // === 用量统计 === pub async fn usage_overview( State(state): State, Extension(ctx): Extension, ) -> Result>, erp_core::error::AppError> where AiState: FromRef, S: Clone + Send + Sync + 'static, { require_permission(&ctx, "ai.usage.list")?; let overview = state.usage.get_overview(ctx.tenant_id).await?; Ok(Json(ApiResponse::ok(serde_json::json!({ "total_count": overview.total_count, })))) } pub async fn usage_by_type( State(state): State, Extension(ctx): Extension, ) -> Result>>, erp_core::error::AppError> where AiState: FromRef, S: Clone + Send + Sync + 'static, { require_permission(&ctx, "ai.usage.list")?; let types = state.usage.get_by_type(ctx.tenant_id).await?; let result: Vec = types .into_iter() .map(|t| { serde_json::json!({ "analysis_type": t.analysis_type, "count": t.count, }) }) .collect(); Ok(Json(ApiResponse::ok(result))) } // === SSE 流构建辅助 === fn build_sse_stream( stream: std::pin::Pin> + Send>>, analysis_id: uuid::Uuid, state: AiState, analysis_type: &'static str, tenant_id: uuid::Uuid, patient_id: uuid::Uuid, doctor_id: uuid::Uuid, ) -> impl futures::Stream> { async_stream::stream! { let mut full_content = String::new(); let mut index: u32 = 0; let mut stream = std::pin::pin!(stream); while let Some(result) = stream.next().await { match result { Ok(chunk) => { full_content.push_str(&chunk); index += 1; let event = AnalysisSseEvent::Chunk { content: chunk, index }; let data = serde_json::to_string(&event).unwrap_or_default(); yield Ok(Event::default().event("chunk").data(data)); } Err(e) => { let event = AnalysisSseEvent::Error { message: e.to_string() }; let data = serde_json::to_string(&event).unwrap_or_default(); yield Ok(Event::default().event("error").data(data)); let _ = state.analysis.fail_analysis(analysis_id, e.to_string()).await; // 发布 AI 分析失败事件 let fail_event = erp_core::events::DomainEvent::new( "ai.analysis.failed", tenant_id, erp_core::events::build_event_payload(serde_json::json!({ "analysis_id": analysis_id, "error": e.to_string(), })), ); state.event_bus.publish(fail_event, &state.db).await; return; } } } let metadata = serde_json::json!({"analysis_type": analysis_type}); let _ = state.analysis.complete_analysis(analysis_id, full_content.clone(), metadata).await; // 解析双通道输出并创建建议记录 let parsed = crate::service::output_parser::parse_dual_channel(&full_content).unwrap_or( crate::dto::suggestion::ParsedOutput { text_content: full_content.clone(), structured: None, }, ); let mut event_payload = serde_json::json!({ "analysis_id": analysis_id, "analysis_type": analysis_type, "patient_id": patient_id, "doctor_id": doctor_id, }); if let Some(ref structured) = parsed.structured { event_payload["risk_level"] = serde_json::json!(structured.risk_level.as_str()); event_payload["suggestion_count"] = serde_json::json!(structured.suggestions.len()); if !structured.suggestions.is_empty() { let _ = SuggestionService::create_suggestions( &state.db, tenant_id, analysis_id, &structured.suggestions, structured.risk_level, &structured.baseline_summary, Some(doctor_id), ).await; } } else { let _ = SuggestionService::mark_parse_failed(&state.db, analysis_id).await; } // 发布 AI 分析完成事件 let event = erp_core::events::DomainEvent::new( "ai.analysis.completed", tenant_id, erp_core::events::build_event_payload(event_payload), ); state.event_bus.publish(event, &state.db).await; let done_event = AnalysisSseEvent::Done { analysis_id, status: "completed".into(), }; let data = serde_json::to_string(&done_event).unwrap_or_default(); yield Ok(Event::default().event("done").data(data)); } }