feat(message): SSE 增强 — Event ID + 心跳保活 + Last-Event-ID + 患者订阅

- 每个 SSE 事件附加 id 字段(UUID v7)用于断点续传
- 30s timeout 心跳保活防止连接断开
- Last-Event-ID header 恢复:重连跳过已发送事件
- ?patient_ids=id1,id2 查询参数选择性订阅患者
This commit is contained in:
iven
2026-05-04 02:49:23 +08:00
parent 975d699e42
commit bb5298ee0f

View File

@@ -1,8 +1,12 @@
use std::cell::Cell;
use std::collections::HashSet;
use std::convert::Infallible;
use axum::extract::Extension;
use axum::extract::{Extension, Query};
use axum::http::HeaderMap;
use axum::response::sse::{Event, KeepAlive, Sse};
use futures::stream::Stream;
use serde::Deserialize;
use sea_orm::ConnectionTrait;
use uuid::Uuid;
@@ -11,34 +15,75 @@ use erp_core::types::TenantContext;
use crate::message_state::MessageState;
/// SSE 查询参数
#[derive(Debug, Deserialize, Default)]
pub struct SseQuery {
/// 逗号分隔的患者 ID 列表,为空则订阅所有管床患者
pub patient_ids: Option<String>,
}
/// SSE 消息推送端点。
///
/// 监听所有事件,按类型分发为不同 SSE event
/// - `message.sent` → SSE event: `message`
/// - `alert.triggered` → SSE event: `alert`
/// - `device.readings.synced` → SSE event: `vital_update`
///
/// 增强:
/// - Event ID支持 Last-Event-ID 断点续传)
/// - 30s 心跳保活
/// - 患者选择性订阅(?patient_ids=id1,id2
pub async fn message_stream(
axum::extract::State(state): axum::extract::State<MessageState>,
Extension(ctx): Extension<TenantContext>,
headers: HeaderMap,
Query(query): Query<SseQuery>,
) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, AppError> {
let user_id = ctx.user_id;
let tenant_id = ctx.tenant_id;
// 空前缀 = 订阅所有事件
let last_event_id: Option<Uuid> = headers
.get("Last-Event-ID")
.and_then(|v| v.to_str().ok())
.and_then(|s| Uuid::parse_str(s).ok());
let subscribed_patient_ids: Option<HashSet<String>> = query.patient_ids.as_ref().map(|s| {
s.split(',')
.map(|id| id.trim().to_string())
.filter(|id| !id.is_empty())
.collect()
});
let (mut rx, _handle) = state.event_bus.subscribe_filtered(String::new());
let db = state.db.clone();
let last_event_id_cell = Cell::new(last_event_id);
let sse_stream = async_stream::stream! {
loop {
match rx.recv().await {
Some(event) => {
let result = tokio::time::timeout(
std::time::Duration::from_secs(30),
rx.recv(),
).await;
match result {
Ok(Some(event)) => {
if event.tenant_id != tenant_id {
continue;
}
// Last-Event-ID 恢复:跳过已发送的事件
if let Some(skip_until) = last_event_id_cell.take() {
if event.id <= skip_until {
last_event_id_cell.set(Some(skip_until));
continue;
}
}
match event.event_type.as_str() {
"message.sent" => {
let is_recipient = event.payload.get("recipient_id")
.and_then(|v: &serde_json::Value| v.as_str())
.and_then(|v| v.as_str())
.map(|s| s == user_id.to_string())
.unwrap_or(false);
if !is_recipient {
@@ -48,12 +93,20 @@ pub async fn message_stream(
.unwrap_or_default();
yield Ok(Event::default()
.event("message")
.id(event.id.to_string())
.data(data));
}
"alert.triggered" => {
// 医患关系过滤:只推送给该患者的管床医生
let patient_id = event.payload.get("patient_id")
.and_then(|v| v.as_str());
// 患者订阅过滤
if let (Some(pid_str), Some(subscribed)) = (patient_id, &subscribed_patient_ids) {
if !subscribed.contains(pid_str) {
continue;
}
}
if let Some(pid_str) = patient_id {
let pid = Uuid::parse_str(pid_str).ok();
if let Some(pid) = pid {
@@ -69,12 +122,20 @@ pub async fn message_stream(
.unwrap_or_default();
yield Ok(Event::default()
.event("alert")
.id(event.id.to_string())
.data(data));
}
"device.readings.synced" => {
// 医患关系过滤:只推送给该患者的管床医生
let patient_id = event.payload.get("patient_id")
.and_then(|v| v.as_str());
// 患者订阅过滤
if let (Some(pid_str), Some(subscribed)) = (patient_id, &subscribed_patient_ids) {
if !subscribed.contains(pid_str) {
continue;
}
}
if let Some(pid_str) = patient_id {
let pid = Uuid::parse_str(pid_str).ok();
if let Some(pid) = pid {
@@ -90,29 +151,31 @@ pub async fn message_stream(
.unwrap_or_default();
yield Ok(Event::default()
.event("vital_update")
.id(event.id.to_string())
.data(data));
}
_ => {}
}
}
None => {
Ok(None) => {
break;
}
Err(_) => {
// 超时 = 发送心跳
yield Ok(Event::default().comment("ping"));
}
}
}
};
Ok(Sse::new(sse_stream).keep_alive(KeepAlive::default()))
Ok(Sse::new(sse_stream).keep_alive(
KeepAlive::new()
.interval(std::time::Duration::from_secs(30))
.text("ping"),
))
}
/// 检查 user_id 对应的医生是否是某患者的管床医生。
///
/// 查询 `patient_doctor_relation` 表:
/// - `doctor_id` 匹配 `user_id`doctor_profile 主键即 user_id
/// - `patient_id` 匹配目标患者
/// - 未软删除
///
/// 查询失败时返回 false宁可漏推不可误推
async fn is_doctor_for_patient(
db: &sea_orm::DatabaseConnection,
tenant_id: Uuid,
@@ -149,10 +212,6 @@ async fn is_doctor_for_patient(
mod tests {
use super::*;
/// 验证 is_doctor_for_patient 函数签名和基础逻辑。
///
/// 由于需要真实数据库连接,此处仅测试参数构造正确性。
/// 完整的数据库集成测试在 erp-health 的集成测试中覆盖。
#[test]
fn patient_id_parsing_from_payload() {
let payload = serde_json::json!({
@@ -189,4 +248,59 @@ mod tests {
let pid = Uuid::parse_str(pid_str.unwrap()).ok();
assert!(pid.is_none());
}
#[test]
fn sse_query_parses_patient_ids() {
let query: SseQuery = serde_urlencoded::from_str("patient_ids=id1,id2,id3").unwrap();
assert!(query.patient_ids.is_some());
let ids = query.patient_ids.unwrap();
assert_eq!(ids, "id1,id2,id3");
}
#[test]
fn sse_query_default_is_empty() {
let query: SseQuery = serde_urlencoded::from_str("").unwrap();
assert!(query.patient_ids.is_none());
}
#[test]
fn subscribed_patient_ids_parsing() {
let query: SseQuery = serde_urlencoded::from_str("patient_ids=aaa,bbb,ccc").unwrap();
let set: Option<HashSet<String>> = query.patient_ids.map(|s| {
s.split(',')
.map(|id| id.trim().to_string())
.filter(|id| !id.is_empty())
.collect()
});
assert!(set.is_some());
let set = set.unwrap();
assert_eq!(set.len(), 3);
assert!(set.contains("aaa"));
assert!(set.contains("bbb"));
assert!(set.contains("ccc"));
}
#[test]
fn last_event_id_parsing_from_headers() {
let event_id = Uuid::now_v7();
let mut headers = HeaderMap::new();
headers.insert("Last-Event-ID", event_id.to_string().parse().unwrap());
let parsed: Option<Uuid> = headers
.get("Last-Event-ID")
.and_then(|v| v.to_str().ok())
.and_then(|s| Uuid::parse_str(s).ok());
assert_eq!(parsed, Some(event_id));
}
#[test]
fn last_event_id_missing_returns_none() {
let headers = HeaderMap::new();
let parsed: Option<Uuid> = headers
.get("Last-Event-ID")
.and_then(|v| v.to_str().ok())
.and_then(|s| Uuid::parse_str(s).ok());
assert!(parsed.is_none());
}
}