Files
hms/crates/erp-server/src/config.rs
iven 4d02b2b531 feat(ai): 扩展 AiConfig 支持多 Provider 配置
- config/default.toml 新增 providers 子段(claude/openai/ollama)
- erp-server/config.rs AiConfig 新增 quota_check_enabled + providers HashMap
- erp-ai/config.rs 新增 ProviderType 枚举 + ProviderConfig 结构体
2026-05-05 15:01:24 +08:00

198 lines
5.1 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
use serde::Deserialize;
/// Top-level application configuration, produced by `AppConfig::load`.
///
/// Each field maps to one configuration section. Every section must be present
/// when deserializing except `rate_limit`, which is marked `#[serde(default)]`.
#[derive(Debug, Clone, Deserialize)]
pub struct AppConfig {
/// `server` section: listen address and ports.
pub server: ServerConfig,
/// `database` section: connection URL and pool bounds.
pub database: DatabaseConfig,
/// `redis` section.
pub redis: RedisConfig,
/// `jwt` section: signing secret and token lifetimes.
pub jwt: JwtConfig,
/// `auth` section.
pub auth: AuthConfig,
/// `log` section.
pub log: LogConfig,
/// `cors` section.
pub cors: CorsConfig,
/// `wechat` section.
pub wechat: WechatConfig,
/// `health` section: AES / HMAC key material.
pub health: HealthConfig,
/// `crypto` section: master key-encryption key.
pub crypto: CryptoConfig,
/// `ai` section: AI provider settings.
pub ai: AiConfig,
/// `storage` section: upload directory and size limit.
pub storage: StorageConfig,
/// `rate_limit` section; when missing, `RateLimitConfig::default()` is used.
#[serde(default)]
pub rate_limit: RateLimitConfig,
}
/// Settings from the `server` configuration section.
#[derive(Debug, Clone, Deserialize)]
pub struct ServerConfig {
/// Host name or address (presumably the bind address; verify against the caller).
pub host: String,
/// Main port number.
pub port: u16,
/// Metrics port; falls back to `default_metrics_port()` (9090) when omitted.
#[serde(default = "default_metrics_port")]
pub metrics_port: u16,
}
/// Serde fallback for `ServerConfig::metrics_port` when the key is not configured.
fn default_metrics_port() -> u16 {
    const FALLBACK_METRICS_PORT: u16 = 9090;
    FALLBACK_METRICS_PORT
}
/// Settings from the `database` configuration section.
#[derive(Debug, Clone, Deserialize)]
pub struct DatabaseConfig {
/// Database connection URL.
pub url: String,
/// Upper bound on connections (presumably the pool size limit; verify against the pool setup).
pub max_connections: u32,
/// Lower bound on connections (presumably kept open by the pool; verify against the pool setup).
pub min_connections: u32,
}
/// Settings from the `redis` configuration section.
#[derive(Debug, Clone, Deserialize)]
pub struct RedisConfig {
/// Redis connection URL.
pub url: String,
}
/// Settings from the `jwt` configuration section.
#[derive(Debug, Clone, Deserialize)]
pub struct JwtConfig {
/// JWT secret. `AppConfig::load` logs a warning when this still equals
/// the placeholder `"change-me-in-production"`.
pub secret: String,
/// Access-token lifetime, kept as a string.
/// NOTE(review): the accepted duration format is parsed elsewhere and is not visible here.
pub access_token_ttl: String,
/// Refresh-token lifetime, kept as a string (same caveat as `access_token_ttl`).
pub refresh_token_ttl: String,
}
/// Settings from the `log` configuration section.
#[derive(Debug, Clone, Deserialize)]
pub struct LogConfig {
/// Log level, kept as a string (presumably a tracing filter/level name; confirm at the use site).
pub level: String,
}
/// Settings from the `auth` configuration section.
#[derive(Debug, Clone, Deserialize)]
pub struct AuthConfig {
/// Password for the super-admin account.
/// NOTE(review): how and when this is applied (e.g. seeding) is not visible in this file.
pub super_admin_password: String,
}
/// Settings from the `cors` configuration section.
#[derive(Debug, Clone, Deserialize)]
pub struct CorsConfig {
/// Comma-separated list of allowed origins.
/// Use "*" to allow all origins (development only).
pub allowed_origins: String,
}
/// Settings from the `wechat` configuration section.
#[derive(Debug, Clone, Deserialize)]
pub struct WechatConfig {
/// WeChat application id.
pub appid: String,
/// WeChat application secret.
pub secret: String,
/// Development-mode flag; defaults to `false` when omitted.
/// NOTE(review): what dev mode changes is decided by the consumers of this flag, not here.
#[serde(default)]
pub dev_mode: bool,
}
/// Key material from the `health` configuration section.
///
/// NOTE(review): which data these keys protect is not visible in this file.
#[derive(Debug, Clone, Deserialize)]
pub struct HealthConfig {
/// AES-256 key (64 hex characters, encoding 32 bytes).
pub aes_key: String,
/// HMAC-SHA256 key (64 hex characters, encoding 32 bytes).
pub hmac_key: String,
}
/// Settings from the `crypto` configuration section.
#[derive(Debug, Clone, Deserialize)]
pub struct CryptoConfig {
/// Master KEK (64 hex characters, encoding 32 bytes). Used to encrypt the per-tenant DEKs.
/// During Phase A it also serves as the global data-encryption key.
pub kek: String,
}
/// Settings from the `ai` configuration section.
///
/// Holds a set of top-level provider settings plus an optional map of
/// per-provider entries (`providers`).
#[derive(Debug, Clone, Deserialize)]
pub struct AiConfig {
/// Name of the provider to use by default
/// (presumably a key of `providers`; verify against the consumer).
pub default_provider: String,
/// API key for the AI service.
pub api_key: String,
/// Optional base URL override for the AI service.
pub base_url: Option<String>,
/// Model identifier.
pub model: String,
/// Token limit passed to the model.
pub max_tokens: u32,
/// Sampling temperature.
pub temperature: f32,
/// Cache time-to-live, in seconds.
pub cache_ttl_seconds: u64,
/// Daily rate limit per patient (enforcement is not visible in this file).
pub rate_limit_patient_daily: u32,
/// Whether quota checking is enabled; defaults to `false` when omitted.
#[serde(default)]
pub quota_check_enabled: bool,
/// Per-provider settings keyed by provider name; defaults to an empty map when omitted.
#[serde(default)]
pub providers: std::collections::HashMap<String, ProviderConfig>,
}
/// Settings for a single AI provider entry inside `AiConfig::providers`.
#[derive(Debug, Clone, Deserialize)]
pub struct ProviderConfig {
/// Provider kind, kept as a plain string
/// (NOTE(review): the set of accepted values is defined by the consumer, not here).
pub provider_type: String,
/// Optional value, presumably the NAME of the environment variable that holds
/// the API key rather than the key itself; confirm at the use site.
pub api_key_env: Option<String>,
/// Optional base URL override for this provider.
pub base_url: Option<String>,
/// Model used when none is specified.
pub default_model: String,
/// Token limit for this provider.
pub max_tokens: u32,
/// Sampling temperature for this provider.
pub temperature: f32,
/// Whether this provider is enabled; defaults to `true` (via `default_true`) when omitted.
#[serde(default = "default_true")]
pub is_enabled: bool,
}
/// Serde fallback for boolean fields that should be switched on unless configured otherwise.
fn default_true() -> bool {
    const ENABLED_UNLESS_CONFIGURED: bool = true;
    ENABLED_UNLESS_CONFIGURED
}
/// Settings from the `storage` configuration section.
#[derive(Debug, Clone, Deserialize)]
pub struct StorageConfig {
/// File upload directory (local storage).
pub upload_dir: String,
/// Maximum size of a single file (e.g. "10MB"); see `max_file_size_bytes` for parsing.
pub max_file_size: String,
}
impl StorageConfig {
    /// Parses `max_file_size` into a number of bytes.
    ///
    /// Accepted forms (case-insensitive, surrounding whitespace ignored):
    /// `"<n>B"`, `"<n>KB"`, `"<n>MB"`, `"<n>GB"` (multiples of 1024), or a bare
    /// integer byte count.
    ///
    /// This accessor is infallible, so malformed input falls back instead of failing:
    /// an unparsable count in front of `MB` / `KB` / `GB` becomes 10 / 1024 / 1 of
    /// that unit, and any other unparsable value becomes 10 MiB. Results that do not
    /// fit in a `u64` saturate at `u64::MAX` instead of overflowing.
    pub fn max_file_size_bytes(&self) -> u64 {
        const KIB: u64 = 1024;
        const MIB: u64 = 1024 * KIB;
        const GIB: u64 = 1024 * MIB;
        const FALLBACK_BYTES: u64 = 10 * MIB;

        // (suffix, bytes per unit, count used when the number is unparsable).
        // Two-letter suffixes come first so that "MB" is never mistaken for "B".
        const UNITS: [(&str, u64, u64); 4] = [
            ("MB", MIB, 10),
            ("KB", KIB, 1024),
            ("GB", GIB, 1),
            ("B", 1, FALLBACK_BYTES),
        ];

        // Trim before matching: otherwise a trailing space hides the unit suffix
        // and the value silently degrades to the fallback.
        let normalized = self.max_file_size.trim().to_uppercase();

        for &(suffix, bytes_per_unit, fallback_count) in UNITS.iter() {
            if let Some(count) = normalized.strip_suffix(suffix) {
                return count
                    .trim()
                    .parse::<u64>()
                    .unwrap_or(fallback_count)
                    // Saturate: plain `*` panics in debug and wraps in release.
                    .saturating_mul(bytes_per_unit);
            }
        }

        normalized.parse::<u64>().unwrap_or(FALLBACK_BYTES)
    }
}
/// Settings from the `rate_limit` configuration section.
#[derive(Debug, Clone, Deserialize)]
pub struct RateLimitConfig {
/// Whether to reject requests when Redis is unreachable. Defaults to `true` (safety first).
#[serde(default = "default_fail_close")]
pub fail_close: bool,
}
/// Serde fallback for `RateLimitConfig::fail_close`: reject requests unless configured otherwise.
fn default_fail_close() -> bool {
    const FAIL_CLOSED: bool = true;
    FAIL_CLOSED
}
impl Default for RateLimitConfig {
fn default() -> Self {
Self { fail_close: true }
}
}
impl AppConfig {
    /// Builds the application configuration from its two sources.
    ///
    /// Sources are registered in this order: the `config/default` file, then
    /// environment variables prefixed with `ERP` and using `__` as the nesting
    /// separator (for example `ERP__JWT__SECRET`).
    ///
    /// # Errors
    ///
    /// Returns an error when the sources cannot be built or when the merged
    /// values do not deserialize into `AppConfig`.
    pub fn load() -> anyhow::Result<Self> {
        let file_source = config::File::with_name("config/default");
        let env_source = config::Environment::with_prefix("ERP").separator("__");

        let loaded: Self = config::Config::builder()
            .add_source(file_source)
            .add_source(env_source)
            .build()?
            .try_deserialize()?;

        // Safety check: the placeholder JWT secret must not reach production.
        // This only logs a warning; loading still succeeds.
        let uses_placeholder_secret = loaded.jwt.secret == "change-me-in-production";
        if uses_placeholder_secret {
            tracing::warn!("⚠️ JWT 密钥使用默认值,请通过 ERP__JWT__SECRET 环境变量设置安全密钥");
        }

        Ok(loaded)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The hand-written `Default` impl must fail closed.
    #[test]
    fn rate_limit_default_is_fail_close() {
        let RateLimitConfig { fail_close } = RateLimitConfig::default();
        assert!(fail_close, "RateLimitConfig 默认应为 fail_close = true");
    }

    /// Deserializing an empty object must take the field-level serde default.
    #[test]
    fn serde_default_uses_custom_fn() {
        let parsed = serde_json::from_str::<RateLimitConfig>("{}").unwrap();
        assert!(parsed.fail_close, "serde 反序列化缺失字段时应使用 default_fail_close() = true");
    }
}