Files
hms/crates/erp-health/src/oauth/service.rs
iven 975d699e42 feat(health): 告警降噪集成 alert_engine + OAuth service 编译修复
- alert_engine: create_alert_and_notify 调用 noise_reducer,升级严重度+suppressed标记
- oauth/service: 修复 OsRng import + ActiveModel move 问题
- fhir/handler: linter 补全完整实现
2026-05-04 02:43:32 +08:00

395 lines
13 KiB
Rust

use argon2::{
Argon2,
password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString, rand_core::RngCore},
};
use chrono::Utc;
use sea_orm::{
ActiveModelTrait, ColumnTrait, DatabaseConnection, EntityTrait, QueryFilter, Set,
};
use uuid::Uuid;
use crate::entity::api_client;
use crate::oauth::dto::*;
use crate::oauth::error::{OAuthError, OAuthResult};
use crate::oauth::middleware::ClientCredentialsClaims;
/// Whitelist of every scope an API client can ever be registered for or request.
/// All entries are `<Resource>.read`, i.e. read-only resource access.
const ALLOWED_SCOPES: &[&str] = &[
    "Patient.read",
    "Observation.read",
    "Device.read",
    "DiagnosticReport.read",
    "Encounter.read",
    "Practitioner.read",
    "Appointment.read",
    "Task.read",
];

/// Checks that every entry of `requested` appears in [`ALLOWED_SCOPES`].
///
/// On success, returns an owned copy of `requested`, with order and duplicates preserved.
/// An empty input is accepted and yields an empty vector.
///
/// # Errors
/// Returns [`OAuthError::InvalidScope`] if any requested scope is not whitelisted.
fn validate_scopes(requested: &[String]) -> OAuthResult<Vec<String>> {
    let every_scope_known = requested
        .iter()
        .all(|wanted| ALLOWED_SCOPES.iter().any(|known| *known == wanted.as_str()));
    if every_scope_known {
        Ok(requested.iter().cloned().collect())
    } else {
        Err(OAuthError::InvalidScope)
    }
}
/// Produces a new public client identifier.
///
/// The identifier is 16 bytes from `OsRng`, hex-encoded into a 32-character string.
fn generate_client_id() -> String {
    use argon2::password_hash::rand_core::OsRng;

    let random_bytes: [u8; 16] = {
        let mut buf = [0u8; 16];
        OsRng.fill_bytes(&mut buf);
        buf
    };
    hex::encode(random_bytes)
}
/// Creates a fresh client secret and its Argon2 hash.
///
/// Returns `(plain, hash)`:
/// - `plain` is 32 random bytes from `OsRng`, hex-encoded (64 characters).
/// - `hash` is the Argon2 (default parameters) hash of `plain` with a random salt, in string
///   form.
///
/// Callers in this file persist only `hash` and hand `plain` back to the requester.
///
/// # Errors
/// Returns [`OAuthError::HashError`] if Argon2 hashing fails.
fn generate_client_secret() -> OAuthResult<(String, String)> {
    use argon2::password_hash::rand_core::OsRng;

    let plain = {
        let mut secret_bytes = [0u8; 32];
        OsRng.fill_bytes(&mut secret_bytes);
        hex::encode(secret_bytes)
    };
    let salt = SaltString::generate(&mut OsRng);
    let hashed = Argon2::default()
        .hash_password(plain.as_bytes(), &salt)
        .map(|phc| phc.to_string())
        .map_err(|e| OAuthError::HashError(e.to_string()))?;
    Ok((plain, hashed))
}
/// Checks a plaintext secret against a stored hash string.
///
/// Returns `Ok(true)` when Argon2 verification succeeds. Returns `Ok(false)` when verification
/// fails for any reason, including a wrong secret.
///
/// # Errors
/// Returns [`OAuthError::HashError`] only when `hash` cannot be parsed.
fn verify_client_secret(plain: &str, hash: &str) -> OAuthResult<bool> {
    let stored = match PasswordHash::new(hash) {
        Ok(parsed) => parsed,
        Err(e) => return Err(OAuthError::HashError(e.to_string())),
    };
    let verdict = Argon2::default().verify_password(plain.as_bytes(), &stored);
    Ok(verdict.is_ok())
}
/// OAuth 2.0 client-credentials token issuance and per-tenant API-client management.
///
/// All operations are associated functions. The database connection, tenant and signing
/// secret are passed in explicitly.
pub struct OAuthService;

impl OAuthService {
    /// Client Credentials Grant: authenticates the client and issues a signed JWT.
    ///
    /// Scope handling:
    /// - If `req.scope` is present, every space-separated scope must be both globally
    ///   whitelisted and registered for this client.
    /// - If `req.scope` is absent, all of the client's registered scopes are granted.
    ///
    /// # Errors
    /// - `UnsupportedGrantType` for any grant other than `client_credentials`.
    /// - `InvalidClient` for an unknown or deleted client, or a wrong secret.
    /// - `ClientInactive` for a correctly authenticated but disabled client.
    /// - `InvalidScope` for a scope that is not allowed.
    /// - Database, hashing and JWT-encoding failures are propagated.
    pub async fn token(
        db: &DatabaseConnection,
        req: &TokenRequest,
        jwt_secret: &str,
    ) -> OAuthResult<TokenResponse> {
        if req.grant_type != "client_credentials" {
            return Err(OAuthError::UnsupportedGrantType);
        }

        let client = api_client::Entity::find()
            .filter(api_client::Column::ClientId.eq(&req.client_id))
            .filter(api_client::Column::DeletedAt.is_null())
            .one(db)
            .await?
            .ok_or(OAuthError::InvalidClient)?;

        // Authenticate before revealing anything about the client's state. Checking
        // `is_active` first would tell a caller without the secret that this client_id
        // exists but is disabled.
        if !verify_client_secret(&req.client_secret, &client.client_secret_hash)? {
            return Err(OAuthError::InvalidClient);
        }
        if !client.is_active {
            return Err(OAuthError::ClientInactive);
        }

        // Scopes registered for this client. A JSON column that fails to parse yields none.
        let registered: Vec<String> =
            serde_json::from_value(client.scopes.clone()).unwrap_or_default();
        let granted_scopes = match &req.scope {
            Some(scope_str) => {
                // The scope parameter is a space-delimited list; skip empty segments.
                let requested: Vec<String> = scope_str
                    .split(' ')
                    .filter(|s| !s.is_empty())
                    .map(str::to_string)
                    .collect();
                validate_scopes(&requested)?;
                if requested.iter().any(|s| !registered.contains(s)) {
                    return Err(OAuthError::InvalidScope);
                }
                requested
            }
            None => registered,
        };

        let lifetime = client.token_lifetime_seconds as i64;
        // Read the clock once so that `exp - iat` is exactly the configured lifetime.
        let issued_at = Utc::now().timestamp();
        let claims = ClientCredentialsClaims {
            sub: client.id,
            tid: client.tenant_id,
            scopes: granted_scopes.clone(),
            // NOTE(review): a stored value that fails to parse becomes `None` here, the same
            // as "no restriction configured". Confirm how the middleware treats `None`. If it
            // means "all patients", this path fails open.
            allowed_patient_ids: client
                .allowed_patient_ids
                .as_ref()
                .and_then(|v| serde_json::from_value(v.clone()).ok()),
            rate_limit_per_minute: client.rate_limit_per_minute,
            exp: issued_at + lifetime,
            iat: issued_at,
            token_type: "client_credentials".to_string(),
        };
        let token = jsonwebtoken::encode(
            &jsonwebtoken::Header::default(),
            &claims,
            &jsonwebtoken::EncodingKey::from_secret(jwt_secret.as_bytes()),
        )?;
        Ok(TokenResponse {
            access_token: token,
            token_type: "Bearer".to_string(),
            expires_in: lifetime,
            scope: granted_scopes.join(" "),
        })
    }

    /// Creates a new API client for `tenant_id`.
    ///
    /// The returned response is the only place the plaintext `client_secret` appears. Only
    /// its Argon2 hash is stored.
    ///
    /// # Errors
    /// - `InvalidScope` if `req.scopes` contains a non-whitelisted scope.
    /// - Hashing and database failures are propagated.
    pub async fn create_client(
        db: &DatabaseConnection,
        tenant_id: Uuid,
        req: &CreateApiClientReq,
        created_by: Uuid,
    ) -> OAuthResult<ApiClientResp> {
        let scopes = validate_scopes(&req.scopes)?;
        let (secret_plain, secret_hash) = generate_client_secret()?;
        let allowed_patient_ids_json = req
            .allowed_patient_ids
            .as_ref()
            .map(|ids| serde_json::to_value(ids).unwrap_or(serde_json::Value::Null));
        let now = Utc::now();
        let active_model = api_client::ActiveModel {
            id: Set(Uuid::now_v7()),
            tenant_id: Set(tenant_id),
            client_id: Set(generate_client_id()),
            client_secret_hash: Set(secret_hash),
            client_name: Set(req.client_name.clone()),
            scopes: Set(serde_json::to_value(&scopes).unwrap_or(serde_json::Value::Array(vec![]))),
            allowed_patient_ids: Set(allowed_patient_ids_json),
            rate_limit_per_minute: Set(req.rate_limit_per_minute),
            is_active: Set(true),
            token_lifetime_seconds: Set(req.token_lifetime_seconds),
            created_at: Set(now.into()),
            updated_at: Set(now.into()),
            created_by: Set(Some(created_by)),
            updated_by: Set(None),
            deleted_at: Set(None),
            version: Set(1),
        };
        let model = active_model.insert(db).await?;
        Ok(ApiClientResp {
            id: model.id.to_string(),
            tenant_id: model.tenant_id.to_string(),
            client_id: model.client_id,
            client_secret: secret_plain,
            client_name: model.client_name,
            scopes,
            allowed_patient_ids: req.allowed_patient_ids.clone(),
            rate_limit_per_minute: model.rate_limit_per_minute,
            is_active: model.is_active,
            token_lifetime_seconds: model.token_lifetime_seconds,
            created_at: model.created_at.to_rfc3339(),
        })
    }

    /// Lists the tenant's API clients that have not been soft-deleted.
    pub async fn list_clients(
        db: &DatabaseConnection,
        tenant_id: Uuid,
    ) -> OAuthResult<Vec<ApiClientListItem>> {
        let clients = api_client::Entity::find()
            .filter(api_client::Column::TenantId.eq(tenant_id))
            .filter(api_client::Column::DeletedAt.is_null())
            .all(db)
            .await?;
        Ok(clients.into_iter().map(Self::to_list_item).collect())
    }

    /// Applies a partial update to an API client, using optimistic locking on `version`.
    ///
    /// Only fields present in `req` are changed. The version is incremented and
    /// `updated_by`/`updated_at` are stamped.
    ///
    /// # Errors
    /// - `ClientNotFound` if the client is missing, deleted, or owned by another tenant.
    /// - `DbError` if `req.version` does not match the stored version.
    /// - `InvalidScope` for non-whitelisted scopes.
    /// - Other database failures are propagated.
    pub async fn update_client(
        db: &DatabaseConnection,
        tenant_id: Uuid,
        client_id: Uuid,
        req: &UpdateApiClientReq,
        updated_by: Uuid,
    ) -> OAuthResult<ApiClientListItem> {
        let client = Self::find_tenant_client(db, tenant_id, client_id).await?;
        // NOTE(review): this version check runs in memory, and the UPDATE below matches only
        // the primary key, so two concurrent writers holding the same version can both
        // succeed. Adding `version = req.version` to the UPDATE's WHERE clause would close
        // the race.
        if client.version != req.version {
            return Err(OAuthError::DbError("版本冲突".into()));
        }
        // Validate before touching the model so that a bad scope leaves nothing half-applied.
        let scopes_json = match &req.scopes {
            Some(requested) => {
                validate_scopes(requested)?;
                Some(serde_json::to_value(requested).unwrap_or(serde_json::Value::Array(vec![])))
            }
            None => None,
        };

        let next_version = client.version + 1;
        let mut active: api_client::ActiveModel = client.into();
        if let Some(name) = &req.client_name {
            active.client_name = Set(name.clone());
        }
        if let Some(scopes) = scopes_json {
            active.scopes = Set(scopes);
        }
        // The outer `Some` means the field was supplied. An inner `None` stores NULL.
        if let Some(ids) = &req.allowed_patient_ids {
            let ids_json = ids
                .as_ref()
                .map(|list| serde_json::to_value(list).unwrap_or(serde_json::Value::Null));
            active.allowed_patient_ids = Set(ids_json);
        }
        if let Some(rl) = req.rate_limit_per_minute {
            active.rate_limit_per_minute = Set(rl);
        }
        if let Some(flag) = req.is_active {
            active.is_active = Set(flag);
        }
        if let Some(tl) = req.token_lifetime_seconds {
            active.token_lifetime_seconds = Set(tl);
        }
        active.updated_by = Set(Some(updated_by));
        active.updated_at = Set(Utc::now().into());
        active.version = Set(next_version);
        let model = active.update(db).await?;
        Ok(Self::to_list_item(model))
    }

    /// Soft-deletes an API client by setting `deleted_at`.
    ///
    /// # Errors
    /// - `ClientNotFound` if the client is missing, already deleted, or owned by another
    ///   tenant.
    /// - Database failures are propagated.
    pub async fn delete_client(
        db: &DatabaseConnection,
        tenant_id: Uuid,
        client_id: Uuid,
    ) -> OAuthResult<()> {
        let client = Self::find_tenant_client(db, tenant_id, client_id).await?;
        let next_version = client.version + 1;
        let now = Utc::now();
        let mut active: api_client::ActiveModel = client.into();
        active.deleted_at = Set(Some(now.into()));
        // Stamp the change the same way the other mutators do.
        active.updated_at = Set(now.into());
        active.version = Set(next_version);
        active.update(db).await?;
        Ok(())
    }

    /// Replaces a client's secret with a freshly generated one.
    ///
    /// Returns `(internal id, new plaintext secret)`. Only the hash is stored.
    ///
    /// # Errors
    /// - `ClientNotFound` if the client is missing, deleted, or owned by another tenant.
    /// - Hashing and database failures are propagated.
    pub async fn regenerate_secret(
        db: &DatabaseConnection,
        tenant_id: Uuid,
        client_id: Uuid,
    ) -> OAuthResult<(String, String)> {
        let client = Self::find_tenant_client(db, tenant_id, client_id).await?;
        let (plain, hash) = generate_client_secret()?;
        let id = client.id.to_string();
        let next_version = client.version + 1;
        let mut active: api_client::ActiveModel = client.into();
        active.client_secret_hash = Set(hash);
        active.updated_at = Set(Utc::now().into());
        active.version = Set(next_version);
        active.update(db).await?;
        Ok((id, plain))
    }

    /// Loads the client with primary key `client_id` that belongs to `tenant_id` and is not
    /// soft-deleted.
    ///
    /// Other tenants' clients produce the same `ClientNotFound` as missing ones, so callers
    /// cannot probe across tenants.
    async fn find_tenant_client(
        db: &DatabaseConnection,
        tenant_id: Uuid,
        client_id: Uuid,
    ) -> OAuthResult<api_client::Model> {
        api_client::Entity::find_by_id(client_id)
            .filter(api_client::Column::TenantId.eq(tenant_id))
            .filter(api_client::Column::DeletedAt.is_null())
            .one(db)
            .await?
            .ok_or(OAuthError::ClientNotFound)
    }

    /// Converts a stored client row into the list/detail DTO. The DTO carries no secret
    /// material.
    fn to_list_item(c: api_client::Model) -> ApiClientListItem {
        ApiClientListItem {
            id: c.id.to_string(),
            client_id: c.client_id,
            client_name: c.client_name,
            scopes: serde_json::from_value(c.scopes).unwrap_or_default(),
            rate_limit_per_minute: c.rate_limit_per_minute,
            is_active: c.is_active,
            token_lifetime_seconds: c.token_lifetime_seconds,
            created_at: c.created_at.to_rfc3339(),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn validate_scopes_accepts_valid() {
        let scopes = vec!["Patient.read".into(), "Observation.read".into()];
        assert!(validate_scopes(&scopes).is_ok());
    }

    #[test]
    fn validate_scopes_rejects_invalid() {
        let scopes = vec!["Patient.write".into()];
        assert!(validate_scopes(&scopes).is_err());
    }

    #[test]
    fn validate_scopes_rejects_when_any_scope_is_invalid() {
        let scopes = vec!["Patient.read".into(), "Patient.write".into()];
        assert!(validate_scopes(&scopes).is_err());
    }

    #[test]
    fn validate_scopes_accepts_empty() {
        let scopes: Vec<String> = vec![];
        let result = validate_scopes(&scopes);
        assert!(result.is_ok());
        assert!(result.unwrap().is_empty());
    }

    #[test]
    fn generate_client_id_is_32_hex_chars() {
        let id = generate_client_id();
        assert_eq!(id.len(), 32);
        assert!(id.chars().all(|c| c.is_ascii_hexdigit()));
    }

    #[test]
    fn generate_client_id_is_not_constant() {
        // Two 128-bit random draws colliding would indicate a broken RNG.
        assert_ne!(generate_client_id(), generate_client_id());
    }

    #[test]
    fn generate_client_secret_produces_valid_hash() {
        let (plain, hash) = generate_client_secret().unwrap();
        assert_eq!(plain.len(), 64);
        assert!(hash.starts_with("$argon2"));
        assert!(verify_client_secret(&plain, &hash).unwrap());
    }

    #[test]
    fn verify_client_secret_rejects_wrong() {
        let (plain, hash) = generate_client_secret().unwrap();
        assert!(!verify_client_secret("wrong_secret", &hash).unwrap());
        // The genuine secret must still verify against the same hash.
        assert!(verify_client_secret(&plain, &hash).unwrap());
    }

    #[test]
    fn verify_client_secret_errors_on_malformed_hash() {
        assert!(verify_client_secret("anything", "not-a-phc-string").is_err());
    }

    #[test]
    fn token_request_dto_constructable() {
        let req = TokenRequest {
            grant_type: "authorization_code".into(),
            client_id: "test".into(),
            client_secret: "test".into(),
            scope: None,
        };
        assert_eq!(req.grant_type, "authorization_code");
    }
}