- config/default.toml: 新增 providers 子段 (claude / openai / ollama)
- erp-server/config.rs: AiConfig 新增 quota_check_enabled + providers HashMap
- erp-ai/config.rs: 新增 ProviderType 枚举 + ProviderConfig 结构体

198 lines · 5.1 KiB · Rust
use serde::Deserialize;
|
||
|
||
#[derive(Debug, Clone, Deserialize)]
|
||
pub struct AppConfig {
|
||
pub server: ServerConfig,
|
||
pub database: DatabaseConfig,
|
||
pub redis: RedisConfig,
|
||
pub jwt: JwtConfig,
|
||
pub auth: AuthConfig,
|
||
pub log: LogConfig,
|
||
pub cors: CorsConfig,
|
||
pub wechat: WechatConfig,
|
||
pub health: HealthConfig,
|
||
pub crypto: CryptoConfig,
|
||
pub ai: AiConfig,
|
||
pub storage: StorageConfig,
|
||
#[serde(default)]
|
||
pub rate_limit: RateLimitConfig,
|
||
}
|
||
|
||
#[derive(Debug, Clone, Deserialize)]
|
||
pub struct ServerConfig {
|
||
pub host: String,
|
||
pub port: u16,
|
||
#[serde(default = "default_metrics_port")]
|
||
pub metrics_port: u16,
|
||
}
|
||
|
||
/// Serde default for [`ServerConfig::metrics_port`]: port 9090.
fn default_metrics_port() -> u16 {
    9_090
}
|
||
|
||
#[derive(Debug, Clone, Deserialize)]
|
||
pub struct DatabaseConfig {
|
||
pub url: String,
|
||
pub max_connections: u32,
|
||
pub min_connections: u32,
|
||
}
|
||
|
||
#[derive(Debug, Clone, Deserialize)]
|
||
pub struct RedisConfig {
|
||
pub url: String,
|
||
}
|
||
|
||
#[derive(Debug, Clone, Deserialize)]
|
||
pub struct JwtConfig {
|
||
pub secret: String,
|
||
pub access_token_ttl: String,
|
||
pub refresh_token_ttl: String,
|
||
}
|
||
|
||
#[derive(Debug, Clone, Deserialize)]
|
||
pub struct LogConfig {
|
||
pub level: String,
|
||
}
|
||
|
||
#[derive(Debug, Clone, Deserialize)]
|
||
pub struct AuthConfig {
|
||
pub super_admin_password: String,
|
||
}
|
||
|
||
#[derive(Debug, Clone, Deserialize)]
|
||
pub struct CorsConfig {
|
||
/// Comma-separated list of allowed origins.
|
||
/// Use "*" to allow all origins (development only).
|
||
pub allowed_origins: String,
|
||
}
|
||
|
||
#[derive(Debug, Clone, Deserialize)]
|
||
pub struct WechatConfig {
|
||
pub appid: String,
|
||
pub secret: String,
|
||
#[serde(default)]
|
||
pub dev_mode: bool,
|
||
}
|
||
|
||
#[derive(Debug, Clone, Deserialize)]
|
||
pub struct HealthConfig {
|
||
/// AES-256 密钥 (64 字符 hex 编码,32 字节)
|
||
pub aes_key: String,
|
||
/// HMAC-SHA256 密钥 (64 字符 hex 编码,32 字节)
|
||
pub hmac_key: String,
|
||
}
|
||
|
||
#[derive(Debug, Clone, Deserialize)]
|
||
pub struct CryptoConfig {
|
||
/// Master KEK (64 字符 hex 编码,32 字节)。用于加密保护每租户 DEK。
|
||
/// Phase A 阶段同时作为全局数据加密密钥使用。
|
||
pub kek: String,
|
||
}
|
||
|
||
#[derive(Debug, Clone, Deserialize)]
|
||
pub struct AiConfig {
|
||
pub default_provider: String,
|
||
pub api_key: String,
|
||
pub base_url: Option<String>,
|
||
pub model: String,
|
||
pub max_tokens: u32,
|
||
pub temperature: f32,
|
||
pub cache_ttl_seconds: u64,
|
||
pub rate_limit_patient_daily: u32,
|
||
#[serde(default)]
|
||
pub quota_check_enabled: bool,
|
||
#[serde(default)]
|
||
pub providers: std::collections::HashMap<String, ProviderConfig>,
|
||
}
|
||
|
||
#[derive(Debug, Clone, Deserialize)]
|
||
pub struct ProviderConfig {
|
||
pub provider_type: String,
|
||
pub api_key_env: Option<String>,
|
||
pub base_url: Option<String>,
|
||
pub default_model: String,
|
||
pub max_tokens: u32,
|
||
pub temperature: f32,
|
||
#[serde(default = "default_true")]
|
||
pub is_enabled: bool,
|
||
}
|
||
|
||
/// Serde default helper that always yields `true`
/// (used by [`ProviderConfig::is_enabled`]).
fn default_true() -> bool {
    true
}
|
||
|
||
#[derive(Debug, Clone, Deserialize)]
|
||
pub struct StorageConfig {
|
||
/// 文件上传目录(本地存储)
|
||
pub upload_dir: String,
|
||
/// 单文件最大大小(如 "10MB")
|
||
pub max_file_size: String,
|
||
}
|
||
|
||
impl StorageConfig {
|
||
/// 解析 max_file_size 为字节数
|
||
pub fn max_file_size_bytes(&self) -> u64 {
|
||
let s = self.max_file_size.to_uppercase();
|
||
if let Some(num) = s.strip_suffix("MB") {
|
||
num.trim().parse::<u64>().unwrap_or(10) * 1024 * 1024
|
||
} else if let Some(num) = s.strip_suffix("KB") {
|
||
num.trim().parse::<u64>().unwrap_or(1024) * 1024
|
||
} else if let Some(num) = s.strip_suffix("GB") {
|
||
num.trim().parse::<u64>().unwrap_or(1) * 1024 * 1024 * 1024
|
||
} else {
|
||
s.parse::<u64>().unwrap_or(10 * 1024 * 1024)
|
||
}
|
||
}
|
||
}
|
||
|
||
#[derive(Debug, Clone, Deserialize)]
|
||
pub struct RateLimitConfig {
|
||
/// Redis 不可达时是否拒绝请求。默认 true(安全优先)。
|
||
#[serde(default = "default_fail_close")]
|
||
pub fail_close: bool,
|
||
}
|
||
|
||
/// Serde default for [`RateLimitConfig::fail_close`]: fail closed (`true`).
fn default_fail_close() -> bool {
    true
}
|
||
|
||
impl Default for RateLimitConfig {
|
||
fn default() -> Self {
|
||
Self { fail_close: true }
|
||
}
|
||
}
|
||
|
||
impl AppConfig {
|
||
pub fn load() -> anyhow::Result<Self> {
|
||
let config = config::Config::builder()
|
||
.add_source(config::File::with_name("config/default"))
|
||
.add_source(config::Environment::with_prefix("ERP").separator("__"))
|
||
.build()?;
|
||
let app_config: Self = config.try_deserialize()?;
|
||
|
||
// 安全检查:禁止在生产使用默认 JWT 密钥
|
||
if app_config.jwt.secret == "change-me-in-production" {
|
||
tracing::warn!("⚠️ JWT 密钥使用默认值,请通过 ERP__JWT__SECRET 环境变量设置安全密钥");
|
||
}
|
||
|
||
Ok(app_config)
|
||
}
|
||
}
|
||
|
||
#[cfg(test)]
mod tests {
    use super::*;

    /// The programmatic default must fail closed.
    #[test]
    fn rate_limit_default_is_fail_close() {
        assert!(
            RateLimitConfig::default().fail_close,
            "RateLimitConfig 默认应为 fail_close = true"
        );
    }

    /// A missing `fail_close` key must fall back to `default_fail_close()`.
    #[test]
    fn serde_default_uses_custom_fn() {
        let parsed = serde_json::from_str::<RateLimitConfig>("{}").unwrap();
        assert!(
            parsed.fail_close,
            "serde 反序列化缺失字段时应使用 default_fail_close() = true"
        );
    }
}
|