feat(message): add message center module (Phase 5)

Implement the complete message center with:
- Database migrations for message_templates, messages, message_subscriptions tables
- erp-message crate with entities, DTOs, services, handlers
- Message CRUD, send, read/unread tracking, soft delete
- Template management with variable interpolation
- Subscription preferences with DND support
- Frontend: messages page, notification panel, unread count badge
- Server integration with module registration and routing

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
iven
2026-04-11 12:25:05 +08:00
parent 91ecaa3ed7
commit 5ceed71e62
35 changed files with 2252 additions and 15 deletions

View File

@@ -5,12 +5,16 @@ edition.workspace = true
[dependencies]
erp-core.workspace = true
tokio.workspace = true
serde.workspace = true
serde_json.workspace = true
uuid.workspace = true
chrono.workspace = true
axum.workspace = true
sea-orm.workspace = true
tracing.workspace = true
tokio = { workspace = true, features = ["full"] }
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true }
uuid = { workspace = true, features = ["v7", "serde"] }
chrono = { workspace = true, features = ["serde"] }
axum = { workspace = true }
sea-orm = { workspace = true, features = ["sqlx-postgres", "runtime-tokio-rustls", "with-uuid", "with-chrono", "with-json"] }
tracing = { workspace = true }
anyhow.workspace = true
thiserror.workspace = true
utoipa = { workspace = true, features = ["uuid", "chrono"] }
async-trait.workspace = true
validator.workspace = true

View File

@@ -0,0 +1,144 @@
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use utoipa::ToSchema;
use validator::Validate;
// ============ 消息 DTO ============
/// 消息响应
#[derive(Debug, Serialize, ToSchema)]
pub struct MessageResp {
pub id: Uuid,
pub tenant_id: Uuid,
pub template_id: Option<Uuid>,
pub sender_id: Option<Uuid>,
pub sender_type: String,
pub recipient_id: Uuid,
pub recipient_type: String,
pub title: String,
pub body: String,
pub priority: String,
pub business_type: Option<String>,
pub business_id: Option<Uuid>,
pub is_read: bool,
#[serde(skip_serializing_if = "Option::is_none")]
pub read_at: Option<DateTime<Utc>>,
pub is_archived: bool,
pub status: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub sent_at: Option<DateTime<Utc>>,
pub created_at: DateTime<Utc>,
pub updated_at: DateTime<Utc>,
}
/// 发送消息请求
#[derive(Debug, Deserialize, Validate, ToSchema)]
pub struct SendMessageReq {
#[validate(length(min = 1, max = 200, message = "标题不能为空且不超过200字符"))]
pub title: String,
#[validate(length(min = 1, message = "内容不能为空"))]
pub body: String,
pub recipient_id: Uuid,
#[serde(default = "default_recipient_type")]
pub recipient_type: String,
#[serde(default = "default_priority")]
pub priority: String,
pub template_id: Option<Uuid>,
pub business_type: Option<String>,
pub business_id: Option<Uuid>,
}
fn default_recipient_type() -> String {
"user".to_string()
}
fn default_priority() -> String {
"normal".to_string()
}
/// 消息列表查询参数
#[derive(Debug, Deserialize, ToSchema)]
pub struct MessageQuery {
pub page: Option<u64>,
pub page_size: Option<u64>,
pub is_read: Option<bool>,
pub priority: Option<String>,
pub business_type: Option<String>,
pub status: Option<String>,
}
/// 未读消息计数响应
#[derive(Debug, Serialize, ToSchema)]
pub struct UnreadCountResp {
pub count: i64,
}
// ============ 消息模板 DTO ============
/// 消息模板响应
#[derive(Debug, Serialize, ToSchema)]
pub struct MessageTemplateResp {
pub id: Uuid,
pub tenant_id: Uuid,
pub name: String,
pub code: String,
pub channel: String,
pub title_template: String,
pub body_template: String,
pub language: String,
pub created_at: DateTime<Utc>,
pub updated_at: DateTime<Utc>,
}
/// 创建消息模板请求
#[derive(Debug, Deserialize, Validate, ToSchema)]
pub struct CreateTemplateReq {
#[validate(length(min = 1, max = 100, message = "名称不能为空且不超过100字符"))]
pub name: String,
#[validate(length(min = 1, max = 50, message = "编码不能为空且不超过50字符"))]
pub code: String,
#[serde(default = "default_channel")]
pub channel: String,
#[validate(length(min = 1, max = 200, message = "标题模板不能为空"))]
pub title_template: String,
#[validate(length(min = 1, message = "内容模板不能为空"))]
pub body_template: String,
#[serde(default = "default_language")]
pub language: String,
}
fn default_channel() -> String {
"in_app".to_string()
}
fn default_language() -> String {
"zh-CN".to_string()
}
// ============ 消息订阅偏好 DTO ============
/// 消息订阅偏好响应
#[derive(Debug, Serialize, ToSchema)]
pub struct MessageSubscriptionResp {
pub id: Uuid,
pub tenant_id: Uuid,
pub user_id: Uuid,
pub notification_types: Option<serde_json::Value>,
pub channel_preferences: Option<serde_json::Value>,
pub dnd_enabled: bool,
pub dnd_start: Option<String>,
pub dnd_end: Option<String>,
pub created_at: DateTime<Utc>,
pub updated_at: DateTime<Utc>,
}
/// 更新消息订阅偏好请求
#[derive(Debug, Deserialize, ToSchema)]
pub struct UpdateSubscriptionReq {
pub notification_types: Option<serde_json::Value>,
pub channel_preferences: Option<serde_json::Value>,
pub dnd_enabled: Option<bool>,
pub dnd_start: Option<String>,
pub dnd_end: Option<String>,
}

View File

@@ -0,0 +1,57 @@
use sea_orm::entity::prelude::*;
use serde::{Deserialize, Serialize};
#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Serialize, Deserialize)]
#[sea_orm(table_name = "messages")]
pub struct Model {
#[sea_orm(primary_key, auto_increment = false)]
pub id: Uuid,
pub tenant_id: Uuid,
#[serde(skip_serializing_if = "Option::is_none")]
pub template_id: Option<Uuid>,
#[serde(skip_serializing_if = "Option::is_none")]
pub sender_id: Option<Uuid>,
pub sender_type: String,
pub recipient_id: Uuid,
pub recipient_type: String,
pub title: String,
pub body: String,
pub priority: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub business_type: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub business_id: Option<Uuid>,
pub is_read: bool,
#[serde(skip_serializing_if = "Option::is_none")]
pub read_at: Option<DateTimeUtc>,
pub is_archived: bool,
#[serde(skip_serializing_if = "Option::is_none")]
pub archived_at: Option<DateTimeUtc>,
#[serde(skip_serializing_if = "Option::is_none")]
pub sent_at: Option<DateTimeUtc>,
pub status: String,
pub created_at: DateTimeUtc,
pub updated_at: DateTimeUtc,
pub created_by: Uuid,
pub updated_by: Uuid,
#[serde(skip_serializing_if = "Option::is_none")]
pub deleted_at: Option<DateTimeUtc>,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {
#[sea_orm(
belongs_to = "super::message_template::Entity",
from = "Column::TemplateId",
to = "super::message_template::Column::Id"
)]
MessageTemplate,
}
impl Related<super::message_template::Entity> for Entity {
fn to() -> RelationDef {
Relation::MessageTemplate.def()
}
}
impl ActiveModelBehavior for ActiveModel {}

View File

@@ -0,0 +1,31 @@
use sea_orm::entity::prelude::*;
use serde::{Deserialize, Serialize};
#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Serialize, Deserialize)]
#[sea_orm(table_name = "message_subscriptions")]
pub struct Model {
#[sea_orm(primary_key, auto_increment = false)]
pub id: Uuid,
pub tenant_id: Uuid,
pub user_id: Uuid,
#[serde(skip_serializing_if = "Option::is_none")]
pub notification_types: Option<serde_json::Value>,
#[serde(skip_serializing_if = "Option::is_none")]
pub channel_preferences: Option<serde_json::Value>,
pub dnd_enabled: bool,
#[serde(skip_serializing_if = "Option::is_none")]
pub dnd_start: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub dnd_end: Option<String>,
pub created_at: DateTimeUtc,
pub updated_at: DateTimeUtc,
pub created_by: Uuid,
pub updated_by: Uuid,
#[serde(skip_serializing_if = "Option::is_none")]
pub deleted_at: Option<DateTimeUtc>,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {}
impl ActiveModelBehavior for ActiveModel {}

View File

@@ -0,0 +1,36 @@
use sea_orm::entity::prelude::*;
use serde::{Deserialize, Serialize};
#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Serialize, Deserialize)]
#[sea_orm(table_name = "message_templates")]
pub struct Model {
#[sea_orm(primary_key, auto_increment = false)]
pub id: Uuid,
pub tenant_id: Uuid,
pub name: String,
pub code: String,
pub channel: String,
pub title_template: String,
pub body_template: String,
pub language: String,
pub created_at: DateTimeUtc,
pub updated_at: DateTimeUtc,
pub created_by: Uuid,
pub updated_by: Uuid,
#[serde(skip_serializing_if = "Option::is_none")]
pub deleted_at: Option<DateTimeUtc>,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {
#[sea_orm(has_many = "super::message::Entity")]
Message,
}
impl Related<super::message::Entity> for Entity {
fn to() -> RelationDef {
Relation::Message.def()
}
}
impl ActiveModelBehavior for ActiveModel {}

View File

@@ -0,0 +1,3 @@
pub mod message;
pub mod message_subscription;
pub mod message_template;

View File

@@ -0,0 +1,41 @@
use erp_core::error::AppError;
/// 消息中心模块错误类型。
#[derive(Debug, thiserror::Error)]
pub enum MessageError {
#[error("验证失败: {0}")]
Validation(String),
#[error("未找到: {0}")]
NotFound(String),
#[error("模板编码已存在: {0}")]
DuplicateTemplateCode(String),
#[error("渲染失败: {0}")]
TemplateRenderError(String),
}
impl From<MessageError> for AppError {
fn from(err: MessageError) -> Self {
match err {
MessageError::Validation(msg) => AppError::Validation(msg),
MessageError::NotFound(msg) => AppError::NotFound(msg),
MessageError::DuplicateTemplateCode(msg) => AppError::Conflict(msg),
MessageError::TemplateRenderError(msg) => AppError::Internal(msg),
}
}
}
impl From<sea_orm::TransactionError<MessageError>> for MessageError {
fn from(err: sea_orm::TransactionError<MessageError>) -> Self {
match err {
sea_orm::TransactionError::Connection(db_err) => {
MessageError::Validation(db_err.to_string())
}
sea_orm::TransactionError::Transaction(msg_err) => msg_err,
}
}
}
pub type MessageResult<T> = Result<T, MessageError>;

View File

@@ -0,0 +1,122 @@
use axum::extract::{Extension, Path, Query, State};
use axum::extract::FromRef;
use axum::Json;
use uuid::Uuid;
use erp_core::error::AppError;
use erp_core::rbac::require_permission;
use erp_core::types::{ApiResponse, PaginatedResponse, TenantContext};
use validator::Validate;
use crate::dto::{MessageQuery, MessageResp, SendMessageReq, UnreadCountResp};
use crate::message_state::MessageState;
use crate::service::message_service::MessageService;
/// 查询消息列表。
pub async fn list_messages<S>(
State(_state): State<MessageState>,
Extension(ctx): Extension<TenantContext>,
Query(query): Query<MessageQuery>,
) -> Result<Json<ApiResponse<PaginatedResponse<MessageResp>>>, AppError>
where
MessageState: FromRef<S>,
S: Clone + Send + Sync + 'static,
{
require_permission(&ctx, "message:list")?;
let db = &_state.db;
let page = query.page.unwrap_or(1);
let page_size = query.page_size.unwrap_or(20);
let (messages, total) = MessageService::list(ctx.tenant_id, ctx.user_id, &query, db).await?;
let total_pages = (total + page_size - 1) / page_size;
Ok(Json(ApiResponse::ok(PaginatedResponse {
data: messages,
total,
page,
page_size,
total_pages,
})))
}
/// 获取未读消息数量。
pub async fn unread_count<S>(
State(_state): State<MessageState>,
Extension(ctx): Extension<TenantContext>,
) -> Result<Json<ApiResponse<UnreadCountResp>>, AppError>
where
MessageState: FromRef<S>,
S: Clone + Send + Sync + 'static,
{
let result = MessageService::unread_count(ctx.tenant_id, ctx.user_id, &_state.db).await?;
Ok(Json(ApiResponse::ok(result)))
}
/// 发送消息。
pub async fn send_message<S>(
State(_state): State<MessageState>,
Extension(ctx): Extension<TenantContext>,
Json(req): Json<SendMessageReq>,
) -> Result<Json<ApiResponse<MessageResp>>, AppError>
where
MessageState: FromRef<S>,
S: Clone + Send + Sync + 'static,
{
require_permission(&ctx, "message:send")?;
req.validate()
.map_err(|e| AppError::Validation(e.to_string()))?;
let resp = MessageService::send(
ctx.tenant_id,
ctx.user_id,
&req,
&_state.db,
&_state.event_bus,
)
.await?;
Ok(Json(ApiResponse::ok(resp)))
}
/// 标记消息已读。
pub async fn mark_read<S>(
State(_state): State<MessageState>,
Extension(ctx): Extension<TenantContext>,
Path(id): Path<Uuid>,
) -> Result<Json<ApiResponse<()>>, AppError>
where
MessageState: FromRef<S>,
S: Clone + Send + Sync + 'static,
{
MessageService::mark_read(id, ctx.tenant_id, ctx.user_id, &_state.db).await?;
Ok(Json(ApiResponse::ok(())))
}
/// 标记所有消息已读。
pub async fn mark_all_read<S>(
State(_state): State<MessageState>,
Extension(ctx): Extension<TenantContext>,
) -> Result<Json<ApiResponse<()>>, AppError>
where
MessageState: FromRef<S>,
S: Clone + Send + Sync + 'static,
{
MessageService::mark_all_read(ctx.tenant_id, ctx.user_id, &_state.db).await?;
Ok(Json(ApiResponse::ok(())))
}
/// 删除消息。
pub async fn delete_message<S>(
State(_state): State<MessageState>,
Extension(ctx): Extension<TenantContext>,
Path(id): Path<Uuid>,
) -> Result<Json<ApiResponse<()>>, AppError>
where
MessageState: FromRef<S>,
S: Clone + Send + Sync + 'static,
{
MessageService::delete(id, ctx.tenant_id, ctx.user_id, &_state.db).await?;
Ok(Json(ApiResponse::ok(())))
}

View File

@@ -0,0 +1,3 @@
pub mod message_handler;
pub mod subscription_handler;
pub mod template_handler;

View File

@@ -0,0 +1,31 @@
use axum::extract::{Extension, State};
use axum::extract::FromRef;
use axum::Json;
use erp_core::error::AppError;
use erp_core::types::{ApiResponse, TenantContext};
use crate::dto::UpdateSubscriptionReq;
use crate::message_state::MessageState;
use crate::service::subscription_service::SubscriptionService;
/// 更新消息订阅偏好。
pub async fn update_subscription<S>(
State(_state): State<MessageState>,
Extension(ctx): Extension<TenantContext>,
Json(req): Json<UpdateSubscriptionReq>,
) -> Result<Json<ApiResponse<crate::dto::MessageSubscriptionResp>>, AppError>
where
MessageState: FromRef<S>,
S: Clone + Send + Sync + 'static,
{
let resp = SubscriptionService::upsert(
ctx.tenant_id,
ctx.user_id,
&req,
&_state.db,
)
.await?;
Ok(Json(ApiResponse::ok(resp)))
}

View File

@@ -0,0 +1,66 @@
use axum::extract::{Extension, Query, State};
use axum::extract::FromRef;
use axum::Json;
use serde::Deserialize;
use erp_core::error::AppError;
use erp_core::rbac::require_permission;
use erp_core::types::{ApiResponse, PaginatedResponse, TenantContext};
use validator::Validate;
use crate::dto::{CreateTemplateReq, MessageTemplateResp};
use crate::message_state::MessageState;
use crate::service::template_service::TemplateService;
#[derive(Debug, Deserialize)]
pub struct TemplateQuery {
pub page: Option<u64>,
pub page_size: Option<u64>,
}
/// 查询消息模板列表。
pub async fn list_templates<S>(
State(_state): State<MessageState>,
Extension(ctx): Extension<TenantContext>,
Query(query): Query<TemplateQuery>,
) -> Result<Json<ApiResponse<PaginatedResponse<MessageTemplateResp>>>, AppError>
where
MessageState: FromRef<S>,
S: Clone + Send + Sync + 'static,
{
require_permission(&ctx, "message:template:list")?;
let page = query.page.unwrap_or(1);
let page_size = query.page_size.unwrap_or(20);
let (templates, total) =
TemplateService::list(ctx.tenant_id, page, page_size, &_state.db).await?;
let total_pages = (total + page_size - 1) / page_size;
Ok(Json(ApiResponse::ok(PaginatedResponse {
data: templates,
total,
page,
page_size,
total_pages,
})))
}
/// 创建消息模板。
pub async fn create_template<S>(
State(_state): State<MessageState>,
Extension(ctx): Extension<TenantContext>,
Json(req): Json<CreateTemplateReq>,
) -> Result<Json<ApiResponse<MessageTemplateResp>>, AppError>
where
MessageState: FromRef<S>,
S: Clone + Send + Sync + 'static,
{
require_permission(&ctx, "message:template:create")?;
req.validate()
.map_err(|e| AppError::Validation(e.to_string()))?;
let resp = TemplateService::create(ctx.tenant_id, ctx.user_id, &req, &_state.db).await?;
Ok(Json(ApiResponse::ok(resp)))
}

View File

@@ -1 +1,10 @@
// erp-message: 消息中心模块 (Phase 5)
pub mod dto;
pub mod entity;
pub mod error;
pub mod handler;
pub mod message_state;
pub mod module;
pub mod service;
pub use message_state::MessageState;
pub use module::MessageModule;

View File

@@ -0,0 +1,9 @@
use erp_core::events::EventBus;
use sea_orm::DatabaseConnection;
/// 消息中心模块状态,通过 FromRef 从 AppState 提取。
#[derive(Clone)]
pub struct MessageState {
pub db: DatabaseConnection,
pub event_bus: EventBus,
}

View File

@@ -0,0 +1,99 @@
use axum::Router;
use axum::routing::{delete, get, put};
use uuid::Uuid;
use erp_core::error::AppResult;
use erp_core::events::EventBus;
use erp_core::module::ErpModule;
use crate::handler::{
message_handler, subscription_handler, template_handler,
};
/// 消息中心模块,实现 ErpModule trait。
pub struct MessageModule;
impl MessageModule {
pub fn new() -> Self {
Self
}
/// 构建需要认证的路由。
pub fn protected_routes<S>() -> Router<S>
where
crate::message_state::MessageState: axum::extract::FromRef<S>,
S: Clone + Send + Sync + 'static,
{
Router::new()
// 消息路由
.route(
"/messages",
get(message_handler::list_messages).post(message_handler::send_message),
)
.route(
"/messages/unread-count",
get(message_handler::unread_count),
)
.route(
"/messages/{id}/read",
put(message_handler::mark_read),
)
.route(
"/messages/read-all",
put(message_handler::mark_all_read),
)
.route(
"/messages/{id}",
delete(message_handler::delete_message),
)
// 模板路由
.route(
"/message-templates",
get(template_handler::list_templates).post(template_handler::create_template),
)
// 订阅偏好路由
.route(
"/message-subscriptions",
put(subscription_handler::update_subscription),
)
}
}
impl Default for MessageModule {
fn default() -> Self {
Self::new()
}
}
#[async_trait::async_trait]
impl ErpModule for MessageModule {
fn name(&self) -> &str {
"message"
}
fn version(&self) -> &str {
env!("CARGO_PKG_VERSION")
}
fn dependencies(&self) -> Vec<&str> {
vec!["auth"]
}
fn register_routes(&self, router: Router) -> Router {
router
}
fn register_event_handlers(&self, _bus: &EventBus) {}
async fn on_tenant_created(&self, _tenant_id: Uuid) -> AppResult<()> {
Ok(())
}
async fn on_tenant_deleted(&self, _tenant_id: Uuid) -> AppResult<()> {
Ok(())
}
fn as_any(&self) -> &dyn std::any::Any {
self
}
}

View File

@@ -0,0 +1,316 @@
use chrono::Utc;
use sea_orm::{
ActiveModelTrait, ColumnTrait, EntityTrait, PaginatorTrait, QueryFilter, Set,
};
use uuid::Uuid;
use crate::dto::{MessageQuery, MessageResp, SendMessageReq, UnreadCountResp};
use crate::entity::message;
use crate::error::{MessageError, MessageResult};
use erp_core::events::EventBus;
/// 消息服务。
pub struct MessageService;
impl MessageService {
/// 查询消息列表(分页)。
pub async fn list(
tenant_id: Uuid,
recipient_id: Uuid,
query: &MessageQuery,
db: &sea_orm::DatabaseConnection,
) -> MessageResult<(Vec<MessageResp>, u64)> {
let page_size = query.page_size.unwrap_or(20);
let mut q = message::Entity::find()
.filter(message::Column::TenantId.eq(tenant_id))
.filter(message::Column::RecipientId.eq(recipient_id))
.filter(message::Column::DeletedAt.is_null());
if let Some(is_read) = query.is_read {
q = q.filter(message::Column::IsRead.eq(is_read));
}
if let Some(ref priority) = query.priority {
q = q.filter(message::Column::Priority.eq(priority.as_str()));
}
if let Some(ref business_type) = query.business_type {
q = q.filter(message::Column::BusinessType.eq(business_type.as_str()));
}
if let Some(ref status) = query.status {
q = q.filter(message::Column::Status.eq(status.as_str()));
}
let paginator = q.paginate(db, page_size);
let total = paginator
.num_items()
.await
.map_err(|e| MessageError::Validation(e.to_string()))?;
let page_index = query.page.unwrap_or(1).saturating_sub(1) as u64;
let models = paginator
.fetch_page(page_index)
.await
.map_err(|e| MessageError::Validation(e.to_string()))?;
let resps = models.iter().map(Self::model_to_resp).collect();
Ok((resps, total))
}
/// 获取未读消息数量。
pub async fn unread_count(
tenant_id: Uuid,
recipient_id: Uuid,
db: &sea_orm::DatabaseConnection,
) -> MessageResult<UnreadCountResp> {
let count = message::Entity::find()
.filter(message::Column::TenantId.eq(tenant_id))
.filter(message::Column::RecipientId.eq(recipient_id))
.filter(message::Column::IsRead.eq(false))
.filter(message::Column::DeletedAt.is_null())
.count(db)
.await
.map_err(|e| MessageError::Validation(e.to_string()))?;
Ok(UnreadCountResp {
count: count as i64,
})
}
/// 发送消息。
pub async fn send(
tenant_id: Uuid,
sender_id: Uuid,
req: &SendMessageReq,
db: &sea_orm::DatabaseConnection,
event_bus: &EventBus,
) -> MessageResult<MessageResp> {
let id = Uuid::now_v7();
let now = Utc::now();
let model = message::ActiveModel {
id: Set(id),
tenant_id: Set(tenant_id),
template_id: Set(req.template_id),
sender_id: Set(Some(sender_id)),
sender_type: Set("user".to_string()),
recipient_id: Set(req.recipient_id),
recipient_type: Set(req.recipient_type.clone()),
title: Set(req.title.clone()),
body: Set(req.body.clone()),
priority: Set(req.priority.clone()),
business_type: Set(req.business_type.clone()),
business_id: Set(req.business_id),
is_read: Set(false),
read_at: Set(None),
is_archived: Set(false),
archived_at: Set(None),
sent_at: Set(Some(now)),
status: Set("sent".to_string()),
created_at: Set(now),
updated_at: Set(now),
created_by: Set(sender_id),
updated_by: Set(sender_id),
deleted_at: Set(None),
};
let inserted = model
.insert(db)
.await
.map_err(|e| MessageError::Validation(e.to_string()))?;
event_bus.publish(erp_core::events::DomainEvent::new(
"message.sent",
tenant_id,
serde_json::json!({
"message_id": id,
"recipient_id": req.recipient_id,
"title": req.title,
}),
));
Ok(Self::model_to_resp(&inserted))
}
/// 系统发送消息(由事件处理器调用)。
pub async fn send_system(
tenant_id: Uuid,
recipient_id: Uuid,
title: String,
body: String,
priority: &str,
business_type: Option<String>,
business_id: Option<Uuid>,
db: &sea_orm::DatabaseConnection,
event_bus: &EventBus,
) -> MessageResult<MessageResp> {
let id = Uuid::now_v7();
let now = Utc::now();
let system_user = Uuid::nil();
let model = message::ActiveModel {
id: Set(id),
tenant_id: Set(tenant_id),
template_id: Set(None),
sender_id: Set(None),
sender_type: Set("system".to_string()),
recipient_id: Set(recipient_id),
recipient_type: Set("user".to_string()),
title: Set(title),
body: Set(body),
priority: Set(priority.to_string()),
business_type: Set(business_type),
business_id: Set(business_id),
is_read: Set(false),
read_at: Set(None),
is_archived: Set(false),
archived_at: Set(None),
sent_at: Set(Some(now)),
status: Set("sent".to_string()),
created_at: Set(now),
updated_at: Set(now),
created_by: Set(system_user),
updated_by: Set(system_user),
deleted_at: Set(None),
};
let inserted = model
.insert(db)
.await
.map_err(|e| MessageError::Validation(e.to_string()))?;
event_bus.publish(erp_core::events::DomainEvent::new(
"message.sent",
tenant_id,
serde_json::json!({
"message_id": id,
"recipient_id": recipient_id,
}),
));
Ok(Self::model_to_resp(&inserted))
}
/// 标记消息已读。
pub async fn mark_read(
id: Uuid,
tenant_id: Uuid,
user_id: Uuid,
db: &sea_orm::DatabaseConnection,
) -> MessageResult<()> {
let model = message::Entity::find_by_id(id)
.one(db)
.await
.map_err(|e| MessageError::Validation(e.to_string()))?
.filter(|m| m.tenant_id == tenant_id && m.deleted_at.is_none())
.ok_or_else(|| MessageError::NotFound(format!("消息不存在: {id}")))?;
if model.recipient_id != user_id {
return Err(MessageError::Validation(
"只能标记自己的消息为已读".to_string(),
));
}
if model.is_read {
return Ok(());
}
let mut active: message::ActiveModel = model.into();
active.is_read = Set(true);
active.read_at = Set(Some(Utc::now()));
active.updated_at = Set(Utc::now());
active.updated_by = Set(user_id);
active
.update(db)
.await
.map_err(|e| MessageError::Validation(e.to_string()))?;
Ok(())
}
/// 标记所有消息已读。
pub async fn mark_all_read(
tenant_id: Uuid,
user_id: Uuid,
db: &sea_orm::DatabaseConnection,
) -> MessageResult<()> {
let unread = message::Entity::find()
.filter(message::Column::TenantId.eq(tenant_id))
.filter(message::Column::RecipientId.eq(user_id))
.filter(message::Column::IsRead.eq(false))
.filter(message::Column::DeletedAt.is_null())
.all(db)
.await
.map_err(|e| MessageError::Validation(e.to_string()))?;
let now = Utc::now();
for m in unread {
let mut active: message::ActiveModel = m.into();
active.is_read = Set(true);
active.read_at = Set(Some(now));
active.updated_at = Set(now);
active.updated_by = Set(user_id);
active
.update(db)
.await
.map_err(|e| MessageError::Validation(e.to_string()))?;
}
Ok(())
}
/// 删除消息(软删除)。
pub async fn delete(
id: Uuid,
tenant_id: Uuid,
user_id: Uuid,
db: &sea_orm::DatabaseConnection,
) -> MessageResult<()> {
let model = message::Entity::find_by_id(id)
.one(db)
.await
.map_err(|e| MessageError::Validation(e.to_string()))?
.filter(|m| m.tenant_id == tenant_id && m.deleted_at.is_none())
.ok_or_else(|| MessageError::NotFound(format!("消息不存在: {id}")))?;
if model.recipient_id != user_id {
return Err(MessageError::Validation(
"只能删除自己的消息".to_string(),
));
}
let mut active: message::ActiveModel = model.into();
active.deleted_at = Set(Some(Utc::now()));
active.updated_at = Set(Utc::now());
active.updated_by = Set(user_id);
active
.update(db)
.await
.map_err(|e| MessageError::Validation(e.to_string()))?;
Ok(())
}
fn model_to_resp(m: &message::Model) -> MessageResp {
MessageResp {
id: m.id,
tenant_id: m.tenant_id,
template_id: m.template_id,
sender_id: m.sender_id,
sender_type: m.sender_type.clone(),
recipient_id: m.recipient_id,
recipient_type: m.recipient_type.clone(),
title: m.title.clone(),
body: m.body.clone(),
priority: m.priority.clone(),
business_type: m.business_type.clone(),
business_id: m.business_id,
is_read: m.is_read,
read_at: m.read_at,
is_archived: m.is_archived,
status: m.status.clone(),
sent_at: m.sent_at,
created_at: m.created_at,
updated_at: m.updated_at,
}
}
}

View File

@@ -0,0 +1,3 @@
pub mod message_service;
pub mod subscription_service;
pub mod template_service;

View File

@@ -0,0 +1,116 @@
use chrono::Utc;
use sea_orm::{ActiveModelTrait, ColumnTrait, EntityTrait, QueryFilter, Set};
use uuid::Uuid;
use crate::dto::{MessageSubscriptionResp, UpdateSubscriptionReq};
use crate::entity::message_subscription;
use crate::error::{MessageError, MessageResult};
/// 消息订阅偏好服务。
pub struct SubscriptionService;
impl SubscriptionService {
/// 获取用户订阅偏好。
pub async fn get(
tenant_id: Uuid,
user_id: Uuid,
db: &sea_orm::DatabaseConnection,
) -> MessageResult<MessageSubscriptionResp> {
let model = message_subscription::Entity::find()
.filter(message_subscription::Column::TenantId.eq(tenant_id))
.filter(message_subscription::Column::UserId.eq(user_id))
.filter(message_subscription::Column::DeletedAt.is_null())
.one(db)
.await
.map_err(|e| MessageError::Validation(e.to_string()))?
.ok_or_else(|| MessageError::NotFound("订阅偏好不存在".to_string()))?;
Ok(Self::model_to_resp(&model))
}
/// 创建或更新用户订阅偏好upsert
pub async fn upsert(
tenant_id: Uuid,
user_id: Uuid,
req: &UpdateSubscriptionReq,
db: &sea_orm::DatabaseConnection,
) -> MessageResult<MessageSubscriptionResp> {
let existing = message_subscription::Entity::find()
.filter(message_subscription::Column::TenantId.eq(tenant_id))
.filter(message_subscription::Column::UserId.eq(user_id))
.filter(message_subscription::Column::DeletedAt.is_null())
.one(db)
.await
.map_err(|e| MessageError::Validation(e.to_string()))?;
let now = Utc::now();
if let Some(model) = existing {
let mut active: message_subscription::ActiveModel = model.into();
if let Some(types) = &req.notification_types {
active.notification_types = Set(Some(types.clone()));
}
if let Some(prefs) = &req.channel_preferences {
active.channel_preferences = Set(Some(prefs.clone()));
}
if let Some(dnd) = req.dnd_enabled {
active.dnd_enabled = Set(dnd);
}
if let Some(ref start) = req.dnd_start {
active.dnd_start = Set(Some(start.clone()));
}
if let Some(ref end) = req.dnd_end {
active.dnd_end = Set(Some(end.clone()));
}
active.updated_at = Set(now);
active.updated_by = Set(user_id);
let updated = active
.update(db)
.await
.map_err(|e| MessageError::Validation(e.to_string()))?;
Ok(Self::model_to_resp(&updated))
} else {
let id = Uuid::now_v7();
let model = message_subscription::ActiveModel {
id: Set(id),
tenant_id: Set(tenant_id),
user_id: Set(user_id),
notification_types: Set(req.notification_types.clone()),
channel_preferences: Set(req.channel_preferences.clone()),
dnd_enabled: Set(req.dnd_enabled.unwrap_or(false)),
dnd_start: Set(req.dnd_start.clone()),
dnd_end: Set(req.dnd_end.clone()),
created_at: Set(now),
updated_at: Set(now),
created_by: Set(user_id),
updated_by: Set(user_id),
deleted_at: Set(None),
};
let inserted = model
.insert(db)
.await
.map_err(|e| MessageError::Validation(e.to_string()))?;
Ok(Self::model_to_resp(&inserted))
}
}
fn model_to_resp(m: &message_subscription::Model) -> MessageSubscriptionResp {
MessageSubscriptionResp {
id: m.id,
tenant_id: m.tenant_id,
user_id: m.user_id,
notification_types: m.notification_types.clone(),
channel_preferences: m.channel_preferences.clone(),
dnd_enabled: m.dnd_enabled,
dnd_start: m.dnd_start.clone(),
dnd_end: m.dnd_end.clone(),
created_at: m.created_at,
updated_at: m.updated_at,
}
}
}

View File

@@ -0,0 +1,116 @@
use chrono::Utc;
use sea_orm::{
ActiveModelTrait, ColumnTrait, EntityTrait, PaginatorTrait, QueryFilter, Set,
};
use uuid::Uuid;
use crate::dto::{CreateTemplateReq, MessageTemplateResp};
use crate::entity::message_template;
use crate::error::{MessageError, MessageResult};
/// 消息模板服务。
pub struct TemplateService;
impl TemplateService {
/// 查询模板列表。
pub async fn list(
tenant_id: Uuid,
page: u64,
page_size: u64,
db: &sea_orm::DatabaseConnection,
) -> MessageResult<(Vec<MessageTemplateResp>, u64)> {
let paginator = message_template::Entity::find()
.filter(message_template::Column::TenantId.eq(tenant_id))
.filter(message_template::Column::DeletedAt.is_null())
.paginate(db, page_size);
let total = paginator
.num_items()
.await
.map_err(|e| MessageError::Validation(e.to_string()))?;
let page_index = page.saturating_sub(1);
let models = paginator
.fetch_page(page_index)
.await
.map_err(|e| MessageError::Validation(e.to_string()))?;
let resps = models.iter().map(Self::model_to_resp).collect();
Ok((resps, total))
}
/// 创建消息模板。
pub async fn create(
tenant_id: Uuid,
operator_id: Uuid,
req: &CreateTemplateReq,
db: &sea_orm::DatabaseConnection,
) -> MessageResult<MessageTemplateResp> {
// 检查编码唯一性
let existing = message_template::Entity::find()
.filter(message_template::Column::TenantId.eq(tenant_id))
.filter(message_template::Column::Code.eq(&req.code))
.filter(message_template::Column::DeletedAt.is_null())
.one(db)
.await
.map_err(|e| MessageError::Validation(e.to_string()))?;
if existing.is_some() {
return Err(MessageError::DuplicateTemplateCode(format!(
"模板编码已存在: {}",
req.code
)));
}
let id = Uuid::now_v7();
let now = Utc::now();
let model = message_template::ActiveModel {
id: Set(id),
tenant_id: Set(tenant_id),
name: Set(req.name.clone()),
code: Set(req.code.clone()),
channel: Set(req.channel.clone()),
title_template: Set(req.title_template.clone()),
body_template: Set(req.body_template.clone()),
language: Set(req.language.clone()),
created_at: Set(now),
updated_at: Set(now),
created_by: Set(operator_id),
updated_by: Set(operator_id),
deleted_at: Set(None),
};
let inserted = model
.insert(db)
.await
.map_err(|e| MessageError::Validation(e.to_string()))?;
Ok(Self::model_to_resp(&inserted))
}
/// 使用模板渲染消息内容。
/// 支持 {{variable}} 格式的变量插值。
pub fn render(template: &str, variables: &std::collections::HashMap<String, String>) -> String {
let mut result = template.to_string();
for (key, value) in variables {
result = result.replace(&format!("{{{{{}}}}}", key), value);
}
result
}
fn model_to_resp(m: &message_template::Model) -> MessageTemplateResp {
MessageTemplateResp {
id: m.id,
tenant_id: m.tenant_id,
name: m.name.clone(),
code: m.code.clone(),
channel: m.channel.clone(),
title_template: m.title_template.clone(),
body_template: m.body_template.clone(),
language: m.language.clone(),
created_at: m.created_at,
updated_at: m.updated_at,
}
}
}

View File

@@ -22,6 +22,9 @@ mod m20260412_000019_create_process_instances;
mod m20260412_000020_create_tokens;
mod m20260412_000021_create_tasks;
mod m20260412_000022_create_process_variables;
mod m20260413_000023_create_message_templates;
mod m20260413_000024_create_messages;
mod m20260413_000025_create_message_subscriptions;
pub struct Migrator;
@@ -51,6 +54,9 @@ impl MigratorTrait for Migrator {
Box::new(m20260412_000020_create_tokens::Migration),
Box::new(m20260412_000021_create_tasks::Migration),
Box::new(m20260412_000022_create_process_variables::Migration),
Box::new(m20260413_000023_create_message_templates::Migration),
Box::new(m20260413_000024_create_messages::Migration),
Box::new(m20260413_000025_create_message_subscriptions::Migration),
]
}
}

View File

@@ -0,0 +1,93 @@
use sea_orm_migration::prelude::*;
#[derive(DeriveMigrationName)]
pub struct Migration;
#[async_trait::async_trait]
impl MigrationTrait for Migration {
async fn up(&self, manager: &SchemaManager) -> Result<(), DbErr> {
manager
.create_table(
Table::create()
.table(MessageTemplates::Table)
.if_not_exists()
.col(
ColumnDef::new(MessageTemplates::Id)
.uuid()
.not_null()
.primary_key(),
)
.col(ColumnDef::new(MessageTemplates::TenantId).uuid().not_null())
.col(
ColumnDef::new(MessageTemplates::Name)
.string()
.not_null(),
)
.col(
ColumnDef::new(MessageTemplates::Code)
.string()
.not_null(),
)
.col(
ColumnDef::new(MessageTemplates::Channel)
.string()
.not_null()
.default("in_app"),
)
.col(
ColumnDef::new(MessageTemplates::TitleTemplate)
.string()
.not_null(),
)
.col(
ColumnDef::new(MessageTemplates::BodyTemplate)
.text()
.not_null(),
)
.col(
ColumnDef::new(MessageTemplates::Language)
.string()
.not_null()
.default("zh-CN"),
)
.col(ColumnDef::new(MessageTemplates::CreatedAt).timestamp_with_time_zone().not_null())
.col(ColumnDef::new(MessageTemplates::UpdatedAt).timestamp_with_time_zone().not_null())
.col(ColumnDef::new(MessageTemplates::CreatedBy).uuid().not_null())
.col(ColumnDef::new(MessageTemplates::UpdatedBy).uuid().not_null())
.col(ColumnDef::new(MessageTemplates::DeletedAt).timestamp_with_time_zone().null())
.to_owned(),
)
.await?;
manager.get_connection().execute(sea_orm::Statement::from_string(
sea_orm::DatabaseBackend::Postgres,
"CREATE UNIQUE INDEX idx_message_templates_tenant_code ON message_templates (tenant_id, code) WHERE deleted_at IS NULL".to_string(),
)).await.map_err(|e| DbErr::Custom(e.to_string()))?;
Ok(())
}
async fn down(&self, manager: &SchemaManager) -> Result<(), DbErr> {
manager
.drop_table(Table::drop().table(MessageTemplates::Table).to_owned())
.await
}
}
#[derive(DeriveIden)]
enum MessageTemplates {
Table,
Id,
TenantId,
Name,
Code,
Channel,
TitleTemplate,
BodyTemplate,
Language,
CreatedAt,
UpdatedAt,
CreatedBy,
UpdatedBy,
DeletedAt,
}

View File

@@ -0,0 +1,143 @@
use sea_orm_migration::prelude::*;
#[derive(DeriveMigrationName)]
pub struct Migration;
#[async_trait::async_trait]
impl MigrationTrait for Migration {
async fn up(&self, manager: &SchemaManager) -> Result<(), DbErr> {
manager
.create_table(
Table::create()
.table(Messages::Table)
.if_not_exists()
.col(
ColumnDef::new(Messages::Id)
.uuid()
.not_null()
.primary_key(),
)
.col(ColumnDef::new(Messages::TenantId).uuid().not_null())
.col(ColumnDef::new(Messages::TemplateId).uuid().null())
.col(ColumnDef::new(Messages::SenderId).uuid().null())
.col(
ColumnDef::new(Messages::SenderType)
.string()
.not_null()
.default("system"),
)
.col(ColumnDef::new(Messages::RecipientId).uuid().not_null())
.col(
ColumnDef::new(Messages::RecipientType)
.string()
.not_null()
.default("user"),
)
.col(ColumnDef::new(Messages::Title).string().not_null())
.col(ColumnDef::new(Messages::Body).text().not_null())
.col(
ColumnDef::new(Messages::Priority)
.string()
.not_null()
.default("normal"),
)
.col(ColumnDef::new(Messages::BusinessType).string().null())
.col(ColumnDef::new(Messages::BusinessId).uuid().null())
.col(
ColumnDef::new(Messages::IsRead)
.boolean()
.not_null()
.default(false),
)
.col(ColumnDef::new(Messages::ReadAt).timestamp_with_time_zone().null())
.col(
ColumnDef::new(Messages::IsArchived)
.boolean()
.not_null()
.default(false),
)
.col(ColumnDef::new(Messages::ArchivedAt).timestamp_with_time_zone().null())
.col(ColumnDef::new(Messages::SentAt).timestamp_with_time_zone().null())
.col(
ColumnDef::new(Messages::Status)
.string()
.not_null()
.default("sent"),
)
.col(ColumnDef::new(Messages::CreatedAt).timestamp_with_time_zone().not_null())
.col(ColumnDef::new(Messages::UpdatedAt).timestamp_with_time_zone().not_null())
.col(ColumnDef::new(Messages::CreatedBy).uuid().not_null())
.col(ColumnDef::new(Messages::UpdatedBy).uuid().not_null())
.col(ColumnDef::new(Messages::DeletedAt).timestamp_with_time_zone().null())
.to_owned(),
)
.await?;
manager.get_connection().execute(sea_orm::Statement::from_string(
sea_orm::DatabaseBackend::Postgres,
"CREATE INDEX idx_messages_tenant_recipient ON messages (tenant_id, recipient_id) WHERE deleted_at IS NULL".to_string(),
)).await.map_err(|e| DbErr::Custom(e.to_string()))?;
manager.get_connection().execute(sea_orm::Statement::from_string(
sea_orm::DatabaseBackend::Postgres,
"CREATE INDEX idx_messages_tenant_recipient_unread ON messages (tenant_id, recipient_id) WHERE deleted_at IS NULL AND is_read = false".to_string(),
)).await.map_err(|e| DbErr::Custom(e.to_string()))?;
manager.get_connection().execute(sea_orm::Statement::from_string(
sea_orm::DatabaseBackend::Postgres,
"CREATE INDEX idx_messages_tenant_business ON messages (tenant_id, business_type, business_id) WHERE deleted_at IS NULL".to_string(),
)).await.map_err(|e| DbErr::Custom(e.to_string()))?;
manager
.create_foreign_key(
ForeignKey::create()
.name("fk_messages_template")
.from(Messages::Table, Messages::TemplateId)
.to(MessageTemplatesRef::Table, MessageTemplatesRef::Id)
.to_owned(),
)
.await?;
Ok(())
}
async fn down(&self, manager: &SchemaManager) -> Result<(), DbErr> {
manager
.drop_table(Table::drop().table(Messages::Table).to_owned())
.await
}
}
#[derive(DeriveIden)]
enum Messages {
Table,
Id,
TenantId,
TemplateId,
SenderId,
SenderType,
RecipientId,
RecipientType,
Title,
Body,
Priority,
BusinessType,
BusinessId,
IsRead,
ReadAt,
IsArchived,
ArchivedAt,
SentAt,
Status,
CreatedAt,
UpdatedAt,
CreatedBy,
UpdatedBy,
DeletedAt,
}
#[derive(DeriveIden)]
enum MessageTemplatesRef {
Table,
Id,
}

View File

@@ -0,0 +1,72 @@
use sea_orm_migration::prelude::*;
#[derive(DeriveMigrationName)]
pub struct Migration;
#[async_trait::async_trait]
impl MigrationTrait for Migration {
async fn up(&self, manager: &SchemaManager) -> Result<(), DbErr> {
manager
.create_table(
Table::create()
.table(MessageSubscriptions::Table)
.if_not_exists()
.col(
ColumnDef::new(MessageSubscriptions::Id)
.uuid()
.not_null()
.primary_key(),
)
.col(ColumnDef::new(MessageSubscriptions::TenantId).uuid().not_null())
.col(ColumnDef::new(MessageSubscriptions::UserId).uuid().not_null())
.col(ColumnDef::new(MessageSubscriptions::NotificationTypes).json().null())
.col(ColumnDef::new(MessageSubscriptions::ChannelPreferences).json().null())
.col(
ColumnDef::new(MessageSubscriptions::DndEnabled)
.boolean()
.not_null()
.default(false),
)
.col(ColumnDef::new(MessageSubscriptions::DndStart).string().null())
.col(ColumnDef::new(MessageSubscriptions::DndEnd).string().null())
.col(ColumnDef::new(MessageSubscriptions::CreatedAt).timestamp_with_time_zone().not_null())
.col(ColumnDef::new(MessageSubscriptions::UpdatedAt).timestamp_with_time_zone().not_null())
.col(ColumnDef::new(MessageSubscriptions::CreatedBy).uuid().not_null())
.col(ColumnDef::new(MessageSubscriptions::UpdatedBy).uuid().not_null())
.col(ColumnDef::new(MessageSubscriptions::DeletedAt).timestamp_with_time_zone().null())
.to_owned(),
)
.await?;
manager.get_connection().execute(sea_orm::Statement::from_string(
sea_orm::DatabaseBackend::Postgres,
"CREATE UNIQUE INDEX idx_message_subscriptions_tenant_user ON message_subscriptions (tenant_id, user_id) WHERE deleted_at IS NULL".to_string(),
)).await.map_err(|e| DbErr::Custom(e.to_string()))?;
Ok(())
}
async fn down(&self, manager: &SchemaManager) -> Result<(), DbErr> {
manager
.drop_table(Table::drop().table(MessageSubscriptions::Table).to_owned())
.await
}
}
#[derive(DeriveIden)]
enum MessageSubscriptions {
Table,
Id,
TenantId,
UserId,
NotificationTypes,
ChannelPreferences,
DndEnabled,
DndStart,
DndEnd,
CreatedAt,
UpdatedAt,
CreatedBy,
UpdatedBy,
DeletedAt,
}

View File

@@ -8,6 +8,7 @@ pub struct AppConfig {
pub jwt: JwtConfig,
pub auth: AuthConfig,
pub log: LogConfig,
pub cors: CorsConfig,
}
#[derive(Debug, Clone, Deserialize)]
@@ -45,6 +46,13 @@ pub struct AuthConfig {
pub super_admin_password: String,
}
#[derive(Debug, Clone, Deserialize)]
pub struct CorsConfig {
/// Comma-separated list of allowed origins.
/// Use "*" to allow all origins (development only).
pub allowed_origins: String,
}
impl AppConfig {
pub fn load() -> anyhow::Result<Self> {
let config = config::Config::builder()

View File

@@ -109,11 +109,16 @@ async fn main() -> anyhow::Result<()> {
let workflow_module = erp_workflow::WorkflowModule::new();
tracing::info!(module = workflow_module.name(), version = workflow_module.version(), "Workflow module initialized");
// Initialize message module
let message_module = erp_message::MessageModule::new();
tracing::info!(module = message_module.name(), version = message_module.version(), "Message module initialized");
// Initialize module registry and register modules
let registry = ModuleRegistry::new()
.register(auth_module)
.register(config_module)
.register(workflow_module);
.register(workflow_module)
.register(message_module);
tracing::info!(module_count = registry.modules().len(), "Modules registered");
// Register event handlers
@@ -152,6 +157,7 @@ async fn main() -> anyhow::Result<()> {
let protected_routes = erp_auth::AuthModule::protected_routes()
.merge(erp_config::ConfigModule::protected_routes())
.merge(erp_workflow::WorkflowModule::protected_routes())
.merge(erp_message::MessageModule::protected_routes())
.layer(middleware::from_fn(move |req, next| {
let secret = jwt_secret.clone();
async move { jwt_auth_middleware_fn(secret, req, next).await }
@@ -159,7 +165,7 @@ async fn main() -> anyhow::Result<()> {
.with_state(state.clone());
// Merge public + protected into the final application router
let cors = tower_http::cors::CorsLayer::permissive(); // TODO: restrict origins in production
let cors = build_cors_layer(&state.config.cors.allowed_origins);
let app = Router::new()
.merge(public_routes)
.merge(protected_routes)
@@ -178,6 +184,48 @@ async fn main() -> anyhow::Result<()> {
Ok(())
}
/// Build a CORS layer from the comma-separated allowed origins config.
///
/// If the config is "*", allows all origins (development mode).
/// Otherwise, parses each origin as a URL and restricts to those origins only.
fn build_cors_layer(allowed_origins: &str) -> tower_http::cors::CorsLayer {
use axum::http::HeaderValue;
use tower_http::cors::AllowOrigin;
let origins = allowed_origins
.split(',')
.map(|s| s.trim())
.filter(|s| !s.is_empty())
.collect::<Vec<_>>();
if origins.len() == 1 && origins[0] == "*" {
tracing::warn!("CORS: allowing all origins — only use in development!");
tower_http::cors::CorsLayer::permissive()
} else {
let allowed: Vec<HeaderValue> = origins
.iter()
.filter_map(|o| o.parse::<HeaderValue>().ok())
.collect();
tracing::info!(origins = ?origins, "CORS: restricting to allowed origins");
tower_http::cors::CorsLayer::new()
.allow_origin(AllowOrigin::list(allowed))
.allow_methods([
axum::http::Method::GET,
axum::http::Method::POST,
axum::http::Method::PUT,
axum::http::Method::DELETE,
axum::http::Method::PATCH,
])
.allow_headers([
axum::http::header::AUTHORIZATION,
axum::http::header::CONTENT_TYPE,
])
.allow_credentials(true)
}
}
async fn shutdown_signal() {
let ctrl_c = async {
tokio::signal::ctrl_c()

View File

@@ -70,3 +70,13 @@ impl FromRef<AppState> for erp_workflow::WorkflowState {
}
}
}
/// Allow erp-message handlers to extract their required state without depending on erp-server.
impl FromRef<AppState> for erp_message::MessageState {
fn from_ref(state: &AppState) -> Self {
Self {
db: state.db.clone(),
event_bus: state.event_bus.clone(),
}
}
}