- 修复 RngCore import:使用 rand_core::RngCore 替代 argon2 password_hash 重导出
- 修复 ActiveModel version/id move 问题:先读取再 unwrap
- 添加 rand_core 依赖
396 lines
13 KiB
Rust
396 lines
13 KiB
Rust
use argon2::{
|
|
Argon2,
|
|
password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString},
|
|
};
|
|
use rand_core::RngCore;
|
|
use chrono::Utc;
|
|
use sea_orm::{
|
|
ActiveModelTrait, ColumnTrait, DatabaseConnection, EntityTrait, QueryFilter, Set,
|
|
};
|
|
use uuid::Uuid;
|
|
|
|
use crate::entity::api_client;
|
|
use crate::oauth::dto::*;
|
|
use crate::oauth::error::{OAuthError, OAuthResult};
|
|
use crate::oauth::middleware::ClientCredentialsClaims;
|
|
|
|
/// Allow-list of every scope name this service recognises.
///
/// All entries follow the `<Resource>.read` pattern (presumably FHIR resource types — confirm),
/// so only read access can be granted. `validate_scopes` rejects anything not listed here.
const ALLOWED_SCOPES: &[&str] = &[
    "Patient.read", "Observation.read", "Device.read", "DiagnosticReport.read",
    "Encounter.read", "Practitioner.read", "Appointment.read", "Task.read",
];
|
|
|
|
fn validate_scopes(requested: &[String]) -> OAuthResult<Vec<String>> {
|
|
for scope in requested {
|
|
if !ALLOWED_SCOPES.contains(&scope.as_str()) {
|
|
return Err(OAuthError::InvalidScope);
|
|
}
|
|
}
|
|
Ok(requested.to_vec())
|
|
}
|
|
|
|
fn generate_client_id() -> String {
|
|
use rand_core::OsRng;
|
|
let mut bytes = [0u8; 16];
|
|
OsRng.fill_bytes(&mut bytes);
|
|
hex::encode(bytes)
|
|
}
|
|
|
|
fn generate_client_secret() -> OAuthResult<(String, String)> {
|
|
use rand_core::OsRng;
|
|
let mut bytes = [0u8; 32];
|
|
OsRng.fill_bytes(&mut bytes);
|
|
let plain = hex::encode(bytes);
|
|
|
|
let salt = SaltString::generate(&mut OsRng);
|
|
let hash = Argon2::default()
|
|
.hash_password(plain.as_bytes(), &salt)
|
|
.map_err(|e| OAuthError::HashError(e.to_string()))?;
|
|
|
|
Ok((plain, hash.to_string()))
|
|
}
|
|
|
|
fn verify_client_secret(plain: &str, hash: &str) -> OAuthResult<bool> {
|
|
let parsed = PasswordHash::new(hash).map_err(|e| OAuthError::HashError(e.to_string()))?;
|
|
Ok(Argon2::default()
|
|
.verify_password(plain.as_bytes(), &parsed)
|
|
.is_ok())
|
|
}
|
|
|
|
pub struct OAuthService;
|
|
|
|
impl OAuthService {
|
|
/// Client Credentials Grant — 验证客户端并签发 JWT
|
|
pub async fn token(
|
|
db: &DatabaseConnection,
|
|
req: &TokenRequest,
|
|
jwt_secret: &str,
|
|
) -> OAuthResult<TokenResponse> {
|
|
if req.grant_type != "client_credentials" {
|
|
return Err(OAuthError::UnsupportedGrantType);
|
|
}
|
|
|
|
let client = api_client::Entity::find()
|
|
.filter(api_client::Column::ClientId.eq(&req.client_id))
|
|
.filter(api_client::Column::DeletedAt.is_null())
|
|
.one(db)
|
|
.await?
|
|
.ok_or(OAuthError::InvalidClient)?;
|
|
|
|
if !client.is_active {
|
|
return Err(OAuthError::ClientInactive);
|
|
}
|
|
|
|
if !verify_client_secret(&req.client_secret, &client.client_secret_hash)? {
|
|
return Err(OAuthError::InvalidClient);
|
|
}
|
|
|
|
let granted_scopes = if let Some(ref scope_str) = req.scope {
|
|
let requested: Vec<String> = scope_str
|
|
.split(' ')
|
|
.filter(|s| !s.is_empty())
|
|
.map(|s| s.to_string())
|
|
.collect();
|
|
validate_scopes(&requested)?;
|
|
|
|
let allowed: Vec<String> =
|
|
serde_json::from_value(client.scopes.clone()).unwrap_or_default();
|
|
for s in &requested {
|
|
if !allowed.contains(s) {
|
|
return Err(OAuthError::InvalidScope);
|
|
}
|
|
}
|
|
requested
|
|
} else {
|
|
serde_json::from_value(client.scopes.clone()).unwrap_or_default()
|
|
};
|
|
|
|
let claims = ClientCredentialsClaims {
|
|
sub: client.id,
|
|
tid: client.tenant_id,
|
|
scopes: granted_scopes.clone(),
|
|
allowed_patient_ids: client
|
|
.allowed_patient_ids
|
|
.as_ref()
|
|
.and_then(|v| serde_json::from_value(v.clone()).ok()),
|
|
rate_limit_per_minute: client.rate_limit_per_minute,
|
|
exp: Utc::now().timestamp() + client.token_lifetime_seconds as i64,
|
|
iat: Utc::now().timestamp(),
|
|
token_type: "client_credentials".to_string(),
|
|
};
|
|
|
|
let header = jsonwebtoken::Header::default();
|
|
let token = jsonwebtoken::encode(
|
|
&header,
|
|
&claims,
|
|
&jsonwebtoken::EncodingKey::from_secret(jwt_secret.as_bytes()),
|
|
)?;
|
|
|
|
Ok(TokenResponse {
|
|
access_token: token,
|
|
token_type: "Bearer".to_string(),
|
|
expires_in: client.token_lifetime_seconds as i64,
|
|
scope: granted_scopes.join(" "),
|
|
})
|
|
}
|
|
|
|
/// 创建新的 API 客户端
|
|
pub async fn create_client(
|
|
db: &DatabaseConnection,
|
|
tenant_id: Uuid,
|
|
req: &CreateApiClientReq,
|
|
created_by: Uuid,
|
|
) -> OAuthResult<ApiClientResp> {
|
|
let scopes = validate_scopes(&req.scopes)?;
|
|
|
|
let client_id = generate_client_id();
|
|
let (secret_plain, secret_hash) = generate_client_secret()?;
|
|
|
|
let allowed_patient_ids_json = req
|
|
.allowed_patient_ids
|
|
.as_ref()
|
|
.map(|ids| serde_json::to_value(ids).unwrap_or(serde_json::Value::Null));
|
|
|
|
let now = Utc::now();
|
|
let active_model = api_client::ActiveModel {
|
|
id: Set(Uuid::now_v7()),
|
|
tenant_id: Set(tenant_id),
|
|
client_id: Set(client_id.clone()),
|
|
client_secret_hash: Set(secret_hash),
|
|
client_name: Set(req.client_name.clone()),
|
|
scopes: Set(serde_json::to_value(&scopes).unwrap_or(serde_json::Value::Array(vec![]))),
|
|
allowed_patient_ids: Set(allowed_patient_ids_json),
|
|
rate_limit_per_minute: Set(req.rate_limit_per_minute),
|
|
is_active: Set(true),
|
|
token_lifetime_seconds: Set(req.token_lifetime_seconds),
|
|
created_at: Set(now.into()),
|
|
updated_at: Set(now.into()),
|
|
created_by: Set(Some(created_by)),
|
|
updated_by: Set(None),
|
|
deleted_at: Set(None),
|
|
version: Set(1),
|
|
};
|
|
|
|
let model = active_model.insert(db).await?;
|
|
|
|
Ok(ApiClientResp {
|
|
id: model.id.to_string(),
|
|
tenant_id: model.tenant_id.to_string(),
|
|
client_id,
|
|
client_secret: secret_plain,
|
|
client_name: model.client_name,
|
|
scopes,
|
|
allowed_patient_ids: req.allowed_patient_ids.clone(),
|
|
rate_limit_per_minute: model.rate_limit_per_minute,
|
|
is_active: model.is_active,
|
|
token_lifetime_seconds: model.token_lifetime_seconds,
|
|
created_at: model.created_at.to_rfc3339(),
|
|
})
|
|
}
|
|
|
|
/// 列出租户下的 API 客户端
|
|
pub async fn list_clients(
|
|
db: &DatabaseConnection,
|
|
tenant_id: Uuid,
|
|
) -> OAuthResult<Vec<ApiClientListItem>> {
|
|
let clients = api_client::Entity::find()
|
|
.filter(api_client::Column::TenantId.eq(tenant_id))
|
|
.filter(api_client::Column::DeletedAt.is_null())
|
|
.all(db)
|
|
.await?;
|
|
|
|
Ok(clients
|
|
.into_iter()
|
|
.map(|c| ApiClientListItem {
|
|
id: c.id.to_string(),
|
|
client_id: c.client_id,
|
|
client_name: c.client_name,
|
|
scopes: serde_json::from_value(c.scopes).unwrap_or_default(),
|
|
rate_limit_per_minute: c.rate_limit_per_minute,
|
|
is_active: c.is_active,
|
|
token_lifetime_seconds: c.token_lifetime_seconds,
|
|
created_at: c.created_at.to_rfc3339(),
|
|
})
|
|
.collect())
|
|
}
|
|
|
|
/// 更新 API 客户端
|
|
pub async fn update_client(
|
|
db: &DatabaseConnection,
|
|
tenant_id: Uuid,
|
|
client_id: Uuid,
|
|
req: &UpdateApiClientReq,
|
|
updated_by: Uuid,
|
|
) -> OAuthResult<ApiClientListItem> {
|
|
let client = api_client::Entity::find_by_id(client_id)
|
|
.one(db)
|
|
.await?
|
|
.ok_or(OAuthError::ClientNotFound)?;
|
|
|
|
if client.tenant_id != tenant_id {
|
|
return Err(OAuthError::ClientNotFound);
|
|
}
|
|
|
|
if client.version != req.version {
|
|
return Err(OAuthError::DbError("版本冲突".into()));
|
|
}
|
|
|
|
let scopes = if let Some(ref s) = req.scopes {
|
|
validate_scopes(s)?;
|
|
serde_json::to_value(s).unwrap_or(serde_json::Value::Array(vec![]))
|
|
} else {
|
|
client.scopes.clone()
|
|
};
|
|
|
|
let mut active: api_client::ActiveModel = client.into();
|
|
if let Some(ref name) = req.client_name {
|
|
active.client_name = Set(name.clone());
|
|
}
|
|
if req.scopes.is_some() {
|
|
active.scopes = Set(scopes);
|
|
}
|
|
if req.allowed_patient_ids.is_some() {
|
|
let ids_json = req.allowed_patient_ids.as_ref().unwrap().as_ref().map(
|
|
|ids| serde_json::to_value(ids).unwrap_or(serde_json::Value::Null),
|
|
);
|
|
active.allowed_patient_ids = Set(ids_json);
|
|
}
|
|
if let Some(rl) = req.rate_limit_per_minute {
|
|
active.rate_limit_per_minute = Set(rl);
|
|
}
|
|
if let Some(active_flag) = req.is_active {
|
|
active.is_active = Set(active_flag);
|
|
}
|
|
if let Some(tl) = req.token_lifetime_seconds {
|
|
active.token_lifetime_seconds = Set(tl);
|
|
}
|
|
active.updated_by = Set(Some(updated_by));
|
|
active.updated_at = Set(Utc::now().into());
|
|
active.version = Set(req.version + 1);
|
|
|
|
let model = active.update(db).await?;
|
|
|
|
Ok(ApiClientListItem {
|
|
id: model.id.to_string(),
|
|
client_id: model.client_id,
|
|
client_name: model.client_name,
|
|
scopes: serde_json::from_value(model.scopes).unwrap_or_default(),
|
|
rate_limit_per_minute: model.rate_limit_per_minute,
|
|
is_active: model.is_active,
|
|
token_lifetime_seconds: model.token_lifetime_seconds,
|
|
created_at: model.created_at.to_rfc3339(),
|
|
})
|
|
}
|
|
|
|
/// 软删除 API 客户端
|
|
pub async fn delete_client(
|
|
db: &DatabaseConnection,
|
|
tenant_id: Uuid,
|
|
client_id: Uuid,
|
|
) -> OAuthResult<()> {
|
|
let client = api_client::Entity::find_by_id(client_id)
|
|
.one(db)
|
|
.await?
|
|
.ok_or(OAuthError::ClientNotFound)?;
|
|
|
|
if client.tenant_id != tenant_id {
|
|
return Err(OAuthError::ClientNotFound);
|
|
}
|
|
|
|
let mut active: api_client::ActiveModel = client.into();
|
|
active.deleted_at = Set(Some(Utc::now().into()));
|
|
active.update(db).await?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// 重新生成 client_secret
|
|
pub async fn regenerate_secret(
|
|
db: &DatabaseConnection,
|
|
tenant_id: Uuid,
|
|
client_id: Uuid,
|
|
) -> OAuthResult<(String, String)> {
|
|
let client = api_client::Entity::find_by_id(client_id)
|
|
.one(db)
|
|
.await?
|
|
.ok_or(OAuthError::ClientNotFound)?;
|
|
|
|
if client.tenant_id != tenant_id {
|
|
return Err(OAuthError::ClientNotFound);
|
|
}
|
|
|
|
let (plain, hash) = generate_client_secret()?;
|
|
|
|
let mut active: api_client::ActiveModel = client.into();
|
|
active.client_secret_hash = Set(hash);
|
|
active.updated_at = Set(Utc::now().into());
|
|
active.version = Set(active.version.clone().unwrap() + 1);
|
|
let id = active.id.clone().unwrap().to_string();
|
|
active.update(db).await?;
|
|
|
|
Ok((id, plain))
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn validate_scopes_accepts_valid() {
        let scopes = vec!["Patient.read".into(), "Observation.read".into()];
        assert!(validate_scopes(&scopes).is_ok());
    }

    #[test]
    fn validate_scopes_accepts_every_allowed_scope() {
        let scopes: Vec<String> = ALLOWED_SCOPES.iter().map(|s| s.to_string()).collect();
        assert_eq!(validate_scopes(&scopes).unwrap(), scopes);
    }

    #[test]
    fn validate_scopes_rejects_invalid() {
        let scopes = vec!["Patient.write".into()];
        assert!(validate_scopes(&scopes).is_err());
    }

    #[test]
    fn validate_scopes_rejects_mixed_valid_and_invalid() {
        // A single unknown scope must reject the whole request.
        let scopes = vec!["Patient.read".into(), "Patient.write".into()];
        assert!(matches!(
            validate_scopes(&scopes),
            Err(OAuthError::InvalidScope)
        ));
    }

    #[test]
    fn validate_scopes_accepts_empty() {
        let scopes: Vec<String> = vec![];
        let result = validate_scopes(&scopes);
        assert!(result.is_ok());
        assert!(result.unwrap().is_empty());
    }

    #[test]
    fn generate_client_id_is_32_hex_chars() {
        let id = generate_client_id();
        assert_eq!(id.len(), 32);
        assert!(id.chars().all(|c| c.is_ascii_hexdigit()));
    }

    #[test]
    fn generate_client_id_is_random() {
        assert_ne!(generate_client_id(), generate_client_id());
    }

    #[test]
    fn generate_client_secret_produces_valid_hash() {
        let (plain, hash) = generate_client_secret().unwrap();
        assert_eq!(plain.len(), 64);
        assert!(hash.starts_with("$argon2"));
        assert!(verify_client_secret(&plain, &hash).unwrap());
    }

    #[test]
    fn verify_client_secret_rejects_wrong() {
        let (_plain, hash) = generate_client_secret().unwrap();
        assert!(!verify_client_secret("wrong_secret", &hash).unwrap());
    }

    #[test]
    fn verify_client_secret_errors_on_malformed_hash() {
        assert!(verify_client_secret("anything", "not-a-phc-string").is_err());
    }

    #[test]
    fn token_request_dto_constructable() {
        let req = TokenRequest {
            grant_type: "authorization_code".into(),
            client_id: "test".into(),
            client_secret: "test".into(),
            scope: None,
        };
        assert_eq!(req.grant_type, "authorization_code");
    }
}
|