Files
hms/crates/erp-ai/src/handler/mod.rs
iven 710b2e2423 feat(ai): 新增 AI 客服聊天功能 + 消息页重构为小华助手
- 新增 POST /ai/chat 端点,由 LLM(Ollama qwen3)担任 24h 健康客服"小华"
- 新增 ai.chat.send 权限,绑定管理员/患者/医生/护士/健康管理师角色
- 消息页从咨询列表重构为单窗口 AI 对话(欢迎态 + 聊天态 + 快捷问诊)
- 通知功能迁移到"我的"页面菜单项(带未读角标),独立通知列表页
- 修复气泡文字截断:改用百分比 max-width + block Text + pre-wrap 换行
- 修复权限绑定:迁移 SQL 角色名从英文改为中文(admin→管理员,patient→患者)
2026-05-17 00:49:41 +08:00

913 lines
28 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
use axum::Json;
use axum::extract::{Extension, FromRef, Path, Query, State};
use axum::response::sse::{Event, KeepAlive, Sse};
use erp_core::health_provider::TimeRange;
use erp_core::rbac::require_permission;
use erp_core::types::{ApiResponse, TenantContext};
use futures::StreamExt;
use serde::Deserialize;
use std::convert::Infallible;
use crate::dto::{AnalysisSseEvent, AnalysisType};
use crate::state::AiState;
pub mod chat_handler;
pub mod insight_handler;
pub mod risk_handler;
pub mod rule_handler;
pub mod suggestion_handler;
// === 分析请求 Body ===
// Request body shared by the SSE analysis endpoints (`/ai/analyze/*`).
// Every field is optional at the type level; each handler checks the fields it
// needs and rejects missing ones with `AppError::Validation`.
#[derive(Debug, Deserialize, utoipa::ToSchema)]
pub struct AnalyzeBody {
    // Required by `lab-report` and `report-summary`; ignored by the other endpoints.
    pub report_id: Option<uuid::Uuid>,
    // Required by `trends` and `checkup-plan`; not read by `lab-report` or `report-summary`.
    pub patient_id: Option<uuid::Uuid>,
    // Only read by `trends`; when absent a default list of vital-sign metrics is used.
    pub metrics: Option<Vec<String>>,
}
// === SSE 分析端点 ===
// Streams an AI interpretation of a lab report as Server-Sent Events.
//
// Requires the `ai.analysis.manage` permission and `body.report_id`. The report
// is loaded with the caller's tenant id and rejected with `AppError::Validation`
// when it has no items. It is then sanitized and sent to the model configured on
// the tenant's active `lab_report_interpretation` prompt. `build_sse_stream`
// forwards the chunks and handles persistence and post-processing.
#[utoipa::path(
    post,
    path = "/ai/analyze/lab-report",
    request_body = AnalyzeBody,
    responses((status = 200, description = "SSE 化验报告分析流")),
    tag = "AI 分析",
    security(("bearer_auth" = [])),
)]
pub async fn stream_lab_report<S>(
    State(state): State<AiState>,
    Extension(ctx): Extension<TenantContext>,
    Json(body): Json<AnalyzeBody>,
) -> Result<Sse<impl futures::Stream<Item = Result<Event, Infallible>>>, erp_core::error::AppError>
where
    AiState: FromRef<S>,
    S: Clone + Send + Sync + 'static,
{
    require_permission(&ctx, "ai.analysis.manage")?;
    let report_id = body
        .report_id
        .ok_or_else(|| erp_core::error::AppError::Validation("report_id 必填".into()))?;
    let lab_dto = state
        .health_provider
        .get_lab_report(ctx.tenant_id, report_id)
        .await?;
    if lab_dto.items.is_empty() {
        return Err(erp_core::error::AppError::Validation(
            "化验报告缺少检查项目数据,无法进行 AI 分析。请先录入完整的化验指标。".into(),
        ));
    }
    let sanitized_data = state.analysis.sanitizer.sanitize_lab_report(&lab_dto)?;
    let prompt = state
        .prompt
        .get_active_prompt(ctx.tenant_id, "lab_report_interpretation")
        .await?;

    // Missing or mistyped `model_config` keys fall back to the defaults below.
    let model_config = &prompt.model_config;
    let model = model_config["model"]
        .as_str()
        .unwrap_or("claude-sonnet-4-6")
        .to_string();
    let temperature = model_config["temperature"].as_f64().unwrap_or(0.3) as f32;
    // Values that do not fit in a u32 also fall back to the default. Previously an
    // `as u32` cast silently truncated them to an arbitrary token budget.
    let max_tokens = model_config["max_tokens"]
        .as_u64()
        .and_then(|v| u32::try_from(v).ok())
        .unwrap_or(2048);

    // The patient is not resolved in the lab-report flow, so the nil UUID is used
    // as a placeholder for both the analysis record and post-processing.
    // NOTE(review): the original comment said the patient should be associated via
    // the report. Confirm whether `lab_dto` exposes a patient id to use here.
    let patient_id = uuid::Uuid::nil();

    let (stream, analysis_id, _provider_name) = state
        .analysis
        .stream_analyze(
            ctx.tenant_id,
            ctx.user_id,
            patient_id,
            AnalysisType::LabReport,
            report_id.to_string(),
            prompt.system_prompt,
            prompt.user_prompt_template,
            sanitized_data,
            model,
            temperature,
            max_tokens,
        )
        .await?;
    // `state` is not needed after this point, so it is moved rather than cloned.
    let sse_stream = build_sse_stream(
        stream,
        analysis_id,
        state,
        "lab_report",
        ctx.tenant_id,
        patient_id,
        ctx.user_id,
    );
    Ok(Sse::new(sse_stream).keep_alive(KeepAlive::default()))
}
// Streams an AI analysis of a patient's vital-sign trends over the last 90 days as SSE.
//
// Requires the `ai.analysis.manage` permission and `body.patient_id`.
// `body.metrics` selects the series; a default vital-sign list is used when it is
// omitted. The request is rejected with `AppError::Validation` when no data exists
// in the window. The prompt comes from the tenant's active `health_trend_analysis`
// template.
#[utoipa::path(
    post,
    path = "/ai/analyze/trends",
    request_body = AnalyzeBody,
    responses((status = 200, description = "SSE 趋势分析流")),
    tag = "AI 分析",
    security(("bearer_auth" = [])),
)]
pub async fn stream_trends<S>(
    State(state): State<AiState>,
    Extension(ctx): Extension<TenantContext>,
    Json(body): Json<AnalyzeBody>,
) -> Result<Sse<impl futures::Stream<Item = Result<Event, Infallible>>>, erp_core::error::AppError>
where
    AiState: FromRef<S>,
    S: Clone + Send + Sync + 'static,
{
    require_permission(&ctx, "ai.analysis.manage")?;
    let patient_id = body
        .patient_id
        .ok_or_else(|| erp_core::error::AppError::Validation("patient_id 必填".into()))?;
    let metrics = body.metrics.unwrap_or_else(|| {
        vec![
            "systolic_bp_morning".into(),
            "diastolic_bp_morning".into(),
            "heart_rate".into(),
            "weight".into(),
            "blood_sugar".into(),
        ]
    });
    // Read the clock once so the window is exactly 90 days and internally consistent.
    let now = chrono::Utc::now();
    let range = TimeRange {
        start: now - chrono::Duration::days(90),
        end: now,
    };
    let trend_data = state
        .health_provider
        .get_trend_analysis_data(ctx.tenant_id, patient_id, &metrics, &range)
        .await?;
    if trend_data.metrics.is_empty() {
        return Err(erp_core::error::AppError::Validation(
            "患者在选定时间段内无体征监测数据,无法进行趋势分析。".into(),
        ));
    }
    let sanitized_data = state
        .analysis
        .sanitizer
        .sanitize_trend_analysis(&trend_data)?;
    let prompt = state
        .prompt
        .get_active_prompt(ctx.tenant_id, "health_trend_analysis")
        .await?;

    // Missing or mistyped `model_config` keys fall back to the defaults below.
    let model_config = &prompt.model_config;
    let model = model_config["model"]
        .as_str()
        .unwrap_or("claude-sonnet-4-6")
        .to_string();
    let temperature = model_config["temperature"].as_f64().unwrap_or(0.3) as f32;
    // Out-of-range values fall back to the default instead of being truncated by `as u32`.
    let max_tokens = model_config["max_tokens"]
        .as_u64()
        .and_then(|v| u32::try_from(v).ok())
        .unwrap_or(2048);

    let (stream, analysis_id, _) = state
        .analysis
        .stream_analyze(
            ctx.tenant_id,
            ctx.user_id,
            patient_id,
            AnalysisType::Trends,
            patient_id.to_string(),
            prompt.system_prompt,
            prompt.user_prompt_template,
            sanitized_data,
            model,
            temperature,
            max_tokens,
        )
        .await?;
    // Forward the real patient id to post-processing, matching the id recorded by
    // `stream_analyze`. Previously the nil UUID was passed here even though the
    // patient is known.
    let sse_stream = build_sse_stream(
        stream,
        analysis_id,
        state,
        "trend",
        ctx.tenant_id,
        patient_id,
        ctx.user_id,
    );
    Ok(Sse::new(sse_stream).keep_alive(KeepAlive::default()))
}
// Streams an AI-generated personalized checkup plan for a patient as SSE.
//
// Requires the `ai.analysis.manage` permission and `body.patient_id`. The patient
// summary is loaded with the caller's tenant id, sanitized, and sent with the
// tenant's active `personalized_checkup_plan` prompt.
// NOTE(review): unlike the other analysis handlers there is no emptiness check on
// the summary before calling the model. Confirm whether one is needed.
#[utoipa::path(
    post,
    path = "/ai/analyze/checkup-plan",
    request_body = AnalyzeBody,
    responses((status = 200, description = "SSE 体检计划分析流")),
    tag = "AI 分析",
    security(("bearer_auth" = [])),
)]
pub async fn stream_checkup_plan<S>(
    State(state): State<AiState>,
    Extension(ctx): Extension<TenantContext>,
    Json(body): Json<AnalyzeBody>,
) -> Result<Sse<impl futures::Stream<Item = Result<Event, Infallible>>>, erp_core::error::AppError>
where
    AiState: FromRef<S>,
    S: Clone + Send + Sync + 'static,
{
    require_permission(&ctx, "ai.analysis.manage")?;
    let patient_id = body
        .patient_id
        .ok_or_else(|| erp_core::error::AppError::Validation("patient_id 必填".into()))?;
    let summary_dto = state
        .health_provider
        .get_patient_summary(ctx.tenant_id, patient_id)
        .await?;
    let sanitized_data = state
        .analysis
        .sanitizer
        .sanitize_patient_summary(&summary_dto)?;
    let prompt = state
        .prompt
        .get_active_prompt(ctx.tenant_id, "personalized_checkup_plan")
        .await?;

    // Missing or mistyped `model_config` keys fall back to the defaults below.
    let model_config = &prompt.model_config;
    let model = model_config["model"]
        .as_str()
        .unwrap_or("claude-sonnet-4-6")
        .to_string();
    let temperature = model_config["temperature"].as_f64().unwrap_or(0.3) as f32;
    // Out-of-range values fall back to the default instead of being truncated by `as u32`.
    let max_tokens = model_config["max_tokens"]
        .as_u64()
        .and_then(|v| u32::try_from(v).ok())
        .unwrap_or(2048);

    let (stream, analysis_id, _) = state
        .analysis
        .stream_analyze(
            ctx.tenant_id,
            ctx.user_id,
            patient_id,
            AnalysisType::CheckupPlan,
            patient_id.to_string(),
            prompt.system_prompt,
            prompt.user_prompt_template,
            sanitized_data,
            model,
            temperature,
            max_tokens,
        )
        .await?;
    // Forward the real patient id to post-processing, matching the id recorded by
    // `stream_analyze`. Previously the nil UUID was passed here even though the
    // patient is known.
    let sse_stream = build_sse_stream(
        stream,
        analysis_id,
        state,
        "checkup_plan",
        ctx.tenant_id,
        patient_id,
        ctx.user_id,
    );
    Ok(Sse::new(sse_stream).keep_alive(KeepAlive::default()))
}
// Streams an AI-generated summary of a full health report as SSE.
//
// Requires the `ai.analysis.manage` permission and `body.report_id`. The report is
// rejected with `AppError::Validation` when it has no sections. Otherwise it is
// sanitized and sent with the tenant's active `report_summary_generation` prompt.
#[utoipa::path(
    post,
    path = "/ai/analyze/report-summary",
    request_body = AnalyzeBody,
    responses((status = 200, description = "SSE 报告摘要分析流")),
    tag = "AI 分析",
    security(("bearer_auth" = [])),
)]
pub async fn stream_report_summary<S>(
    State(state): State<AiState>,
    Extension(ctx): Extension<TenantContext>,
    Json(body): Json<AnalyzeBody>,
) -> Result<Sse<impl futures::Stream<Item = Result<Event, Infallible>>>, erp_core::error::AppError>
where
    AiState: FromRef<S>,
    S: Clone + Send + Sync + 'static,
{
    require_permission(&ctx, "ai.analysis.manage")?;
    let report_id = body
        .report_id
        .ok_or_else(|| erp_core::error::AppError::Validation("report_id 必填".into()))?;
    let report_dto = state
        .health_provider
        .get_full_report(ctx.tenant_id, report_id)
        .await?;
    if report_dto.sections.is_empty() {
        return Err(erp_core::error::AppError::Validation(
            "健康报告缺少内容数据,无法生成摘要。请先完善报告内容。".into(),
        ));
    }
    let sanitized_data = state
        .analysis
        .sanitizer
        .sanitize_health_report(&report_dto)?;
    let prompt = state
        .prompt
        .get_active_prompt(ctx.tenant_id, "report_summary_generation")
        .await?;

    // Missing or mistyped `model_config` keys fall back to the defaults below.
    let model_config = &prompt.model_config;
    let model = model_config["model"]
        .as_str()
        .unwrap_or("claude-sonnet-4-6")
        .to_string();
    let temperature = model_config["temperature"].as_f64().unwrap_or(0.3) as f32;
    // Out-of-range values fall back to the default instead of being truncated by `as u32`.
    let max_tokens = model_config["max_tokens"]
        .as_u64()
        .and_then(|v| u32::try_from(v).ok())
        .unwrap_or(2048);

    // The patient is not resolved in this flow, so the nil UUID is used as a placeholder.
    let patient_id = uuid::Uuid::nil();

    let (stream, analysis_id, _) = state
        .analysis
        .stream_analyze(
            ctx.tenant_id,
            ctx.user_id,
            patient_id,
            AnalysisType::ReportSummary,
            report_id.to_string(),
            prompt.system_prompt,
            prompt.user_prompt_template,
            sanitized_data,
            model,
            temperature,
            max_tokens,
        )
        .await?;
    let sse_stream = build_sse_stream(
        stream,
        analysis_id,
        state,
        "report_summary",
        ctx.tenant_id,
        patient_id,
        ctx.user_id,
    );
    Ok(Sse::new(sse_stream).keep_alive(KeepAlive::default()))
}
// === 分析历史 ===
// Query parameters for `GET /ai/analysis/history`.
#[derive(Debug, Deserialize, utoipa::IntoParams)]
pub struct ListAnalysisQuery {
    // Forwarded to `list_analysis`; presumably restricts results to one patient.
    pub patient_id: Option<uuid::Uuid>,
    // Forwarded to `list_analysis` as-is; presumably filters by analysis type.
    pub analysis_type: Option<String>,
    // 1-based page number; the response echoes 1 when omitted.
    pub page: Option<u64>,
    // Page size; the effective value is whatever `Pagination::limit()` resolves.
    pub page_size: Option<u64>,
}
#[utoipa::path(
get,
path = "/ai/analysis/history",
params(ListAnalysisQuery),
responses((status = 200, description = "分析历史列表")),
tag = "AI 分析",
security(("bearer_auth" = [])),
)]
pub async fn list_analysis<S>(
State(state): State<AiState>,
Extension(ctx): Extension<TenantContext>,
Query(params): Query<ListAnalysisQuery>,
) -> Result<Json<ApiResponse<serde_json::Value>>, erp_core::error::AppError>
where
AiState: FromRef<S>,
S: Clone + Send + Sync + 'static,
{
require_permission(&ctx, "ai.analysis.list")?;
let pagination = erp_core::types::Pagination {
page: params.page,
page_size: params.page_size,
};
let (items, total) = state
.analysis
.list_analysis(
ctx.tenant_id,
params.patient_id,
params.analysis_type,
&pagination,
)
.await?;
// 批量查询 patient_name通过 raw SQL 避免跨 crate 依赖 erp-health
let patient_ids: std::collections::HashSet<uuid::Uuid> = items
.iter()
.filter(|a| a.patient_id != uuid::Uuid::nil())
.map(|a| a.patient_id)
.collect();
let patient_names: std::collections::HashMap<uuid::Uuid, String> = if !patient_ids.is_empty() {
#[derive(sea_orm::FromQueryResult)]
struct PatientName {
id: uuid::Uuid,
name: String,
}
let ids: Vec<uuid::Uuid> = patient_ids.into_iter().collect();
use sea_orm::FromQueryResult;
PatientName::find_by_statement(sea_orm::Statement::from_sql_and_values(
sea_orm::DatabaseBackend::Postgres,
"SELECT id, name FROM patient WHERE id = ANY($1) AND tenant_id = $2 AND deleted_at IS NULL",
[ids.into(), ctx.tenant_id.into()],
))
.all(&state.db)
.await
.unwrap_or_default()
.into_iter()
.map(|p| (p.id, p.name))
.collect()
} else {
std::collections::HashMap::new()
};
let data: Vec<serde_json::Value> = items
.into_iter()
.map(|a| {
let mut val = serde_json::to_value(&a).unwrap_or_default();
if let Some(obj) = val.as_object_mut() {
obj.insert(
"patient_name".to_string(),
serde_json::json!(patient_names.get(&a.patient_id).cloned()),
);
}
val
})
.collect();
Ok(Json(ApiResponse::ok(serde_json::json!({
"data": data,
"total": total,
"page": pagination.page.unwrap_or(1),
"page_size": pagination.limit(),
}))))
}
// Returns a single AI analysis record, looked up by `id` together with the
// caller's tenant id. Requires the `ai.analysis.list` permission.
#[utoipa::path(
    get,
    path = "/ai/analysis/{id}",
    responses((status = 200, description = "分析详情")),
    tag = "AI 分析",
    security(("bearer_auth" = [])),
)]
pub async fn get_analysis<S>(
    State(state): State<AiState>,
    Extension(ctx): Extension<TenantContext>,
    Path(id): Path<uuid::Uuid>,
) -> Result<Json<ApiResponse<crate::entity::ai_analysis::Model>>, erp_core::error::AppError>
where
    AiState: FromRef<S>,
    S: Clone + Send + Sync + 'static,
{
    require_permission(&ctx, "ai.analysis.list")?;
    let tenant_id = ctx.tenant_id;
    let record = state.analysis.get_analysis(id, tenant_id).await?;
    Ok(Json(ApiResponse::ok(record)))
}
// === Prompt 管理 ===
// Query parameters for `GET /ai/prompts`.
#[derive(Debug, Deserialize, utoipa::IntoParams)]
pub struct ListPromptsQuery {
    // Forwarded to `list_prompts`; presumably filters templates by category.
    pub category: Option<String>,
    // 1-based page number; the response echoes 1 when omitted.
    pub page: Option<u64>,
    // Page size; the effective value is whatever `Pagination::limit()` resolves.
    pub page_size: Option<u64>,
}
#[utoipa::path(
get,
path = "/ai/prompts",
params(ListPromptsQuery),
responses((status = 200, description = "Prompt 模板列表")),
tag = "AI Prompt",
security(("bearer_auth" = [])),
)]
pub async fn list_prompts<S>(
State(state): State<AiState>,
Extension(ctx): Extension<TenantContext>,
Query(params): Query<ListPromptsQuery>,
) -> Result<Json<ApiResponse<serde_json::Value>>, erp_core::error::AppError>
where
AiState: FromRef<S>,
S: Clone + Send + Sync + 'static,
{
require_permission(&ctx, "ai.prompt.list")?;
let pagination = erp_core::types::Pagination {
page: params.page,
page_size: params.page_size,
};
let (items, total) = state
.prompt
.list_prompts(ctx.tenant_id, params.category, &pagination)
.await?;
Ok(Json(ApiResponse::ok(serde_json::json!({
"data": items,
"total": total,
"page": pagination.page.unwrap_or(1),
"page_size": pagination.limit(),
}))))
}
// Request body for `POST /ai/prompts`.
#[derive(Debug, Deserialize, utoipa::ToSchema)]
pub struct CreatePromptBody {
    pub name: String,
    // NOTE(review): accepted but never forwarded by the `create_prompt` handler, so
    // it is currently discarded. Confirm whether the prompt service should store it.
    pub description: Option<String>,
    // Screened by `validate_prompt_safety` before the template is created.
    pub system_prompt: String,
    // Screened by `validate_prompt_safety` before the template is created.
    pub user_prompt_template: String,
    // The analysis handlers read `model`, `temperature` and `max_tokens` from an
    // active prompt's `model_config`, falling back to defaults for missing keys.
    pub model_config: serde_json::Value,
    // Grouping key; `GET /ai/prompts` accepts a `category` filter.
    pub category: String,
}
#[utoipa::path(
post,
path = "/ai/prompts",
request_body = CreatePromptBody,
responses((status = 200, description = "创建 Prompt 模板")),
tag = "AI Prompt",
security(("bearer_auth" = [])),
)]
pub async fn create_prompt<S>(
State(state): State<AiState>,
Extension(ctx): Extension<TenantContext>,
Json(body): Json<CreatePromptBody>,
) -> Result<Json<ApiResponse<crate::entity::ai_prompt::Model>>, erp_core::error::AppError>
where
AiState: FromRef<S>,
S: Clone + Send + Sync + 'static,
{
require_permission(&ctx, "ai.prompt.manage")?;
validate_prompt_safety(&body.system_prompt)?;
validate_prompt_safety(&body.user_prompt_template)?;
let prompt = state
.prompt
.create_prompt(
ctx.tenant_id,
ctx.user_id,
body.name,
body.system_prompt,
body.user_prompt_template,
body.model_config,
body.category,
)
.await?;
Ok(Json(ApiResponse::ok(prompt)))
}
// Activates prompt template `id` for the caller's tenant and returns the updated
// template. Requires the `ai.prompt.manage` permission.
#[utoipa::path(
    post,
    path = "/ai/prompts/{id}/activate",
    responses((status = 200, description = "激活 Prompt 模板")),
    tag = "AI Prompt",
    security(("bearer_auth" = [])),
)]
pub async fn activate_prompt<S>(
    State(state): State<AiState>,
    Extension(ctx): Extension<TenantContext>,
    Path(id): Path<uuid::Uuid>,
) -> Result<Json<ApiResponse<crate::entity::ai_prompt::Model>>, erp_core::error::AppError>
where
    AiState: FromRef<S>,
    S: Clone + Send + Sync + 'static,
{
    require_permission(&ctx, "ai.prompt.manage")?;
    let tenant_id = ctx.tenant_id;
    let activated = state.prompt.activate_prompt(id, tenant_id).await?;
    Ok(Json(ApiResponse::ok(activated)))
}
// Rolls back prompt template `id` for the caller's tenant and returns the
// resulting template. The exact rollback semantics live in the prompt service.
// Requires the `ai.prompt.manage` permission.
#[utoipa::path(
    post,
    path = "/ai/prompts/{id}/rollback",
    responses((status = 200, description = "回滚 Prompt 模板")),
    tag = "AI Prompt",
    security(("bearer_auth" = [])),
)]
pub async fn rollback_prompt<S>(
    State(state): State<AiState>,
    Extension(ctx): Extension<TenantContext>,
    Path(id): Path<uuid::Uuid>,
) -> Result<Json<ApiResponse<crate::entity::ai_prompt::Model>>, erp_core::error::AppError>
where
    AiState: FromRef<S>,
    S: Clone + Send + Sync + 'static,
{
    require_permission(&ctx, "ai.prompt.manage")?;
    let tenant_id = ctx.tenant_id;
    let restored = state.prompt.rollback_prompt(id, tenant_id).await?;
    Ok(Json(ApiResponse::ok(restored)))
}
// === 用量统计 ===
// Returns the tenant's AI usage overview. Only `total_count` is exposed.
// Requires the `ai.usage.list` permission.
#[utoipa::path(
    get,
    path = "/ai/usage/overview",
    responses((status = 200, description = "AI 用量概览")),
    tag = "AI 用量",
    security(("bearer_auth" = [])),
)]
pub async fn usage_overview<S>(
    State(state): State<AiState>,
    Extension(ctx): Extension<TenantContext>,
) -> Result<Json<ApiResponse<serde_json::Value>>, erp_core::error::AppError>
where
    AiState: FromRef<S>,
    S: Clone + Send + Sync + 'static,
{
    require_permission(&ctx, "ai.usage.list")?;
    let summary = state.usage.get_overview(ctx.tenant_id).await?;
    let payload = serde_json::json!({ "total_count": summary.total_count });
    Ok(Json(ApiResponse::ok(payload)))
}
// Returns per-analysis-type usage counts for the tenant as
// `[{ analysis_type, count }, ...]`. Requires the `ai.usage.list` permission.
#[utoipa::path(
    get,
    path = "/ai/usage/by-type",
    responses((status = 200, description = "按类型用量统计")),
    tag = "AI 用量",
    security(("bearer_auth" = [])),
)]
pub async fn usage_by_type<S>(
    State(state): State<AiState>,
    Extension(ctx): Extension<TenantContext>,
) -> Result<Json<ApiResponse<Vec<serde_json::Value>>>, erp_core::error::AppError>
where
    AiState: FromRef<S>,
    S: Clone + Send + Sync + 'static,
{
    require_permission(&ctx, "ai.usage.list")?;
    let rows = state.usage.get_by_type(ctx.tenant_id).await?;

    let mut breakdown: Vec<serde_json::Value> = Vec::new();
    for row in rows {
        breakdown.push(serde_json::json!({
            "analysis_type": row.analysis_type,
            "count": row.count,
        }));
    }
    Ok(Json(ApiResponse::ok(breakdown)))
}
// === Provider 管理 ===
#[utoipa::path(
get,
path = "/ai/providers/health",
responses((status = 200, description = "AI Provider 健康检查")),
tag = "AI Provider",
security(("bearer_auth" = [])),
)]
pub async fn provider_health<S>(
State(state): State<AiState>,
Extension(ctx): Extension<TenantContext>,
) -> Result<Json<ApiResponse<serde_json::Value>>, erp_core::error::AppError>
where
AiState: FromRef<S>,
S: Clone + Send + Sync + 'static,
{
require_permission(&ctx, "ai.analysis.list")?;
let statuses = state.provider_registry.health_check_all().await;
let result: serde_json::Value = statuses.iter().map(|entry| {
let (name, health) = entry.pair();
serde_json::json!({
"provider": name,
"healthy": health.is_healthy(),
"status": match health {
crate::provider::registry::ProviderHealth::Healthy { last_check } =>
serde_json::json!({"status": "healthy", "last_check": last_check.to_rfc3339()}),
crate::provider::registry::ProviderHealth::Degraded { last_check, error } =>
serde_json::json!({"status": "degraded", "last_check": last_check.to_rfc3339(), "error": error}),
crate::provider::registry::ProviderHealth::Unavailable { since, error } =>
serde_json::json!({"status": "unavailable", "since": since.to_rfc3339(), "error": error}),
},
})
}).collect();
Ok(Json(ApiResponse::ok(result)))
}
// Returns the names of all registered AI providers.
// Requires the `ai.analysis.list` permission.
#[utoipa::path(
    get,
    path = "/ai/providers",
    responses((status = 200, description = "AI Provider 列表")),
    tag = "AI Provider",
    security(("bearer_auth" = [])),
)]
pub async fn provider_names<S>(
    State(state): State<AiState>,
    Extension(ctx): Extension<TenantContext>,
) -> Result<Json<ApiResponse<Vec<String>>>, erp_core::error::AppError>
where
    AiState: FromRef<S>,
    S: Clone + Send + Sync + 'static,
{
    require_permission(&ctx, "ai.analysis.list")?;
    let names = state.provider_registry.provider_names();
    Ok(Json(ApiResponse::ok(names)))
}
// Returns the tenant's AI quota usage summary from the quota service.
// Requires the `ai.usage.list` permission.
#[utoipa::path(
    get,
    path = "/ai/quota/summary",
    responses((status = 200, description = "AI 配额汇总")),
    tag = "AI 用量",
    security(("bearer_auth" = [])),
)]
pub async fn quota_summary<S>(
    State(state): State<AiState>,
    Extension(ctx): Extension<TenantContext>,
) -> Result<Json<ApiResponse<crate::service::quota::QuotaSummary>>, erp_core::error::AppError>
where
    AiState: FromRef<S>,
    S: Clone + Send + Sync + 'static,
{
    require_permission(&ctx, "ai.usage.list")?;
    let tenant_id = ctx.tenant_id;
    let usage = state.quota.get_usage_summary(tenant_id).await?;
    Ok(Json(ApiResponse::ok(usage)))
}
// === 透析风险评估(KDIGO 规则) ===
// Scores dialysis risk from the posted lab values using `DialysisRiskScorer`.
// The computation is local: no model call and no shared state are involved.
// Requires the `ai.analysis.manage` permission.
#[utoipa::path(
    post,
    path = "/ai/dialysis/risk-assessment",
    responses((status = 200, description = "透析风险评估")),
    tag = "AI 分析",
    security(("bearer_auth" = [])),
)]
pub async fn assess_dialysis_risk<S>(
    Extension(ctx): Extension<TenantContext>,
    Json(body): Json<crate::service::dialysis_risk_scorer::DialysisLabInput>,
) -> Result<
    Json<ApiResponse<crate::service::dialysis_risk_scorer::DialysisRiskAssessment>>,
    erp_core::error::AppError,
>
where
    AiState: FromRef<S>,
    S: Clone + Send + Sync + 'static,
{
    use crate::service::dialysis_risk_scorer::DialysisRiskScorer;

    require_permission(&ctx, "ai.analysis.manage")?;
    let assessment = DialysisRiskScorer::new().assess(&body);
    Ok(Json(ApiResponse::ok(assessment)))
}
// === 成本与预算 ===
// Returns the tenant's AI budget status. A `CostService` is built on the shared
// database handle for each request. Requires the `ai.usage.list` permission.
#[utoipa::path(
    get,
    path = "/ai/budget/status",
    responses((status = 200, description = "AI 预算状态")),
    tag = "AI 用量",
    security(("bearer_auth" = [])),
)]
pub async fn budget_status<S>(
    State(state): State<AiState>,
    Extension(ctx): Extension<TenantContext>,
) -> Result<Json<ApiResponse<crate::service::cost::BudgetStatus>>, erp_core::error::AppError>
where
    AiState: FromRef<S>,
    S: Clone + Send + Sync + 'static,
{
    require_permission(&ctx, "ai.usage.list")?;
    let budget = crate::service::cost::CostService::new(state.db.clone())
        .get_budget_status(ctx.tenant_id)
        .await?;
    Ok(Json(ApiResponse::ok(budget)))
}
// Query parameters for `GET /ai/cost/estimate`.
//
// Derives `utoipa::IntoParams` and is listed in the endpoint's `params(...)`, like
// the other query structs in this module. Previously the generated OpenAPI spec
// documented no query parameters for this endpoint.
#[derive(Debug, Deserialize, utoipa::IntoParams)]
pub struct CostEstimateQuery {
    // Analysis type to price; passed verbatim to `CostService::estimate_cost`.
    pub analysis_type: String,
    // Model to price; defaults to `claude-sonnet-4-6` when omitted.
    pub model: Option<String>,
}
// Returns a cost estimate for one analysis of the given type and model.
// This is a pure computation with no state access.
// Requires the `ai.usage.list` permission.
#[utoipa::path(
    get,
    path = "/ai/cost/estimate",
    params(CostEstimateQuery),
    responses((status = 200, description = "AI 成本预估")),
    tag = "AI 用量",
    security(("bearer_auth" = [])),
)]
pub async fn cost_estimate<S>(
    Extension(ctx): Extension<TenantContext>,
    Query(params): Query<CostEstimateQuery>,
) -> Result<Json<ApiResponse<crate::service::cost::CostEstimate>>, erp_core::error::AppError>
where
    AiState: FromRef<S>,
    S: Clone + Send + Sync + 'static,
{
    require_permission(&ctx, "ai.usage.list")?;
    let model = params
        .model
        .unwrap_or_else(|| "claude-sonnet-4-6".to_string());
    let estimate = crate::service::cost::CostService::estimate_cost(&params.analysis_type, &model);
    Ok(Json(ApiResponse::ok(estimate)))
}
// === SSE 流构建辅助 ===
// Adapts a raw model output stream into SSE events for the client.
//
// Each `Ok` chunk is appended to a running transcript and emitted as a `chunk`
// event with a 1-based index. On the first `Err`, the following happens in order,
// and then the stream ends:
// - an `error` event is emitted;
// - the analysis is marked failed;
// - an `ai.analysis.failed` domain event is published.
//
// When the model stream finishes cleanly, the following happens in order:
// - the transcript is stored via `complete_analysis`;
// - `post_process_analysis` runs;
// - a final `done` event is emitted.
//
// Results of `fail_analysis` and `complete_analysis` are ignored, so persistence is
// best-effort. The client still receives `done` even if saving failed.
// NOTE(review): if the client disconnects mid-stream, this generator is dropped
// before reaching either branch, so nothing is recorded. Confirm this is acceptable.
fn build_sse_stream(
    stream: std::pin::Pin<Box<dyn futures::Stream<Item = crate::error::AiResult<String>> + Send>>,
    analysis_id: uuid::Uuid,
    state: AiState,
    analysis_type: &'static str,
    tenant_id: uuid::Uuid,
    patient_id: uuid::Uuid,
    doctor_id: uuid::Uuid,
) -> impl futures::Stream<Item = Result<Event, Infallible>> {
    async_stream::stream! {
        // `Pin<Box<dyn Stream>>` is already `Unpin`, so `next()` can be polled on it directly.
        let mut upstream = stream;
        let mut transcript = String::new();
        let mut chunk_index: u32 = 0;

        while let Some(item) = upstream.next().await {
            let chunk = match item {
                Ok(chunk) => chunk,
                Err(err) => {
                    let message = err.to_string();
                    let payload = serde_json::to_string(&AnalysisSseEvent::Error {
                        message: message.clone(),
                    })
                    .unwrap_or_default();
                    yield Ok(Event::default().event("error").data(payload));

                    let _ = state.analysis.fail_analysis(analysis_id, message.clone()).await;

                    // Publish the AI-analysis-failed domain event.
                    let failed = erp_core::events::DomainEvent::new(
                        "ai.analysis.failed",
                        tenant_id,
                        erp_core::events::build_event_payload(serde_json::json!({
                            "analysis_id": analysis_id,
                            "error": message,
                        })),
                    );
                    state.event_bus.publish(failed, &state.db).await;
                    return;
                }
            };

            transcript.push_str(&chunk);
            chunk_index += 1;
            let payload = serde_json::to_string(&AnalysisSseEvent::Chunk {
                content: chunk,
                index: chunk_index,
            })
            .unwrap_or_default();
            yield Ok(Event::default().event("chunk").data(payload));
        }

        let metadata = serde_json::json!({ "analysis_type": analysis_type });
        let _ = state
            .analysis
            .complete_analysis(analysis_id, transcript.clone(), metadata.clone())
            .await;

        // Post-processing: parse the dual-channel output, create suggestions, publish events.
        crate::service::post_process::post_process_analysis(
            &state,
            analysis_id,
            &transcript,
            tenant_id,
            patient_id,
            doctor_id,
            analysis_type,
            metadata,
        )
        .await;

        let finished = AnalysisSseEvent::Done {
            analysis_id,
            status: "completed".into(),
        };
        let payload = serde_json::to_string(&finished).unwrap_or_default();
        yield Ok(Event::default().event("done").data(payload));
    }
}
/// 检查提示词内容是否包含可疑注入模式
fn validate_prompt_safety(content: &str) -> Result<(), erp_core::error::AppError> {
let suspicious = [
"ignore previous",
"ignore all previous",
"ignore above",
"disregard previous",
"you are now",
"new instructions:",
];
let lower = content.to_lowercase();
for pattern in &suspicious {
if lower.contains(pattern) {
return Err(erp_core::error::AppError::Validation(format!(
"提示词内容包含不安全模式: {}",
pattern
)));
}
}
Ok(())
}