feat(ai): 扩展 AiConfig 支持多 Provider 配置
- config/default.toml 新增 providers 子段(claude/openai/ollama) - erp-server/config.rs AiConfig 新增 quota_check_enabled + providers HashMap - erp-ai/config.rs 新增 ProviderType 枚举 + ProviderConfig 结构体
This commit is contained in:
@@ -48,6 +48,33 @@ max_tokens = 2048
|
|||||||
temperature = 0.3
|
temperature = 0.3
|
||||||
cache_ttl_seconds = 604800
|
cache_ttl_seconds = 604800
|
||||||
rate_limit_patient_daily = 10
|
rate_limit_patient_daily = 10
|
||||||
|
quota_check_enabled = true
|
||||||
|
|
||||||
|
[ai.providers.claude]
|
||||||
|
provider_type = "claude"
|
||||||
|
api_key_env = "ANTHROPIC_API_KEY"
|
||||||
|
base_url = "https://api.anthropic.com"
|
||||||
|
default_model = "claude-sonnet-4-6"
|
||||||
|
max_tokens = 2048
|
||||||
|
temperature = 0.3
|
||||||
|
is_enabled = true
|
||||||
|
|
||||||
|
[ai.providers.openai]
|
||||||
|
provider_type = "openai"
|
||||||
|
api_key_env = "OPENAI_API_KEY"
|
||||||
|
base_url = "https://api.openai.com"
|
||||||
|
default_model = "gpt-4o"
|
||||||
|
max_tokens = 2048
|
||||||
|
temperature = 0.3
|
||||||
|
is_enabled = false
|
||||||
|
|
||||||
|
[ai.providers.ollama]
|
||||||
|
provider_type = "ollama"
|
||||||
|
base_url = "http://localhost:11434"
|
||||||
|
default_model = "qwen2.5:7b"
|
||||||
|
max_tokens = 2048
|
||||||
|
temperature = 0.3
|
||||||
|
is_enabled = false
|
||||||
|
|
||||||
[storage]
|
[storage]
|
||||||
upload_dir = "./uploads"
|
upload_dir = "./uploads"
|
||||||
|
|||||||
90
crates/erp-ai/src/config.rs
Normal file
90
crates/erp-ai/src/config.rs
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "snake_case")]
|
||||||
|
pub enum ProviderType {
|
||||||
|
Claude,
|
||||||
|
Openai,
|
||||||
|
Ollama,
|
||||||
|
Rules,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize)]
|
||||||
|
pub struct ProviderConfig {
|
||||||
|
pub provider_type: ProviderType,
|
||||||
|
pub api_key_env: Option<String>,
|
||||||
|
pub base_url: Option<String>,
|
||||||
|
pub default_model: String,
|
||||||
|
pub max_tokens: u32,
|
||||||
|
pub temperature: f32,
|
||||||
|
#[serde(default = "default_true")]
|
||||||
|
pub is_enabled: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_true() -> bool {
|
||||||
|
true
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize)]
|
||||||
|
pub struct AiModuleConfig {
|
||||||
|
pub default_provider: String,
|
||||||
|
pub cache_ttl_seconds: u64,
|
||||||
|
pub quota_check_enabled: bool,
|
||||||
|
pub rate_limit_patient_daily: u32,
|
||||||
|
pub providers: std::collections::HashMap<String, ProviderConfig>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn provider_type_serde_roundtrip() {
|
||||||
|
let pt = ProviderType::Claude;
|
||||||
|
let json = serde_json::to_string(&pt).unwrap();
|
||||||
|
assert_eq!(json, "\"claude\"");
|
||||||
|
let back: ProviderType = serde_json::from_str(&json).unwrap();
|
||||||
|
assert_eq!(back, pt);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn provider_type_all_variants() {
|
||||||
|
for pt in [ProviderType::Claude, ProviderType::Openai, ProviderType::Ollama, ProviderType::Rules] {
|
||||||
|
let json = serde_json::to_string(&pt).unwrap();
|
||||||
|
let back: ProviderType = serde_json::from_str(&json).unwrap();
|
||||||
|
assert_eq!(back, pt);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_provider_config() {
|
||||||
|
let json = r#"{
|
||||||
|
"provider_type": "openai",
|
||||||
|
"api_key_env": "OPENAI_API_KEY",
|
||||||
|
"base_url": "https://api.openai.com",
|
||||||
|
"default_model": "gpt-4o",
|
||||||
|
"max_tokens": 2048,
|
||||||
|
"temperature": 0.3,
|
||||||
|
"is_enabled": true
|
||||||
|
}"#;
|
||||||
|
let config: ProviderConfig = serde_json::from_str(json).unwrap();
|
||||||
|
assert_eq!(config.provider_type, ProviderType::Openai);
|
||||||
|
assert_eq!(config.default_model, "gpt-4o");
|
||||||
|
assert!(config.is_enabled);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_ollama_config_no_api_key() {
|
||||||
|
let json = r#"{
|
||||||
|
"provider_type": "ollama",
|
||||||
|
"base_url": "http://localhost:11434",
|
||||||
|
"default_model": "qwen2.5:7b",
|
||||||
|
"max_tokens": 2048,
|
||||||
|
"temperature": 0.3
|
||||||
|
}"#;
|
||||||
|
let config: ProviderConfig = serde_json::from_str(json).unwrap();
|
||||||
|
assert_eq!(config.provider_type, ProviderType::Ollama);
|
||||||
|
assert!(config.api_key_env.is_none());
|
||||||
|
assert!(config.is_enabled); // default
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,3 +1,4 @@
|
|||||||
|
pub mod config;
|
||||||
pub mod dto;
|
pub mod dto;
|
||||||
pub mod entity;
|
pub mod entity;
|
||||||
pub mod error;
|
pub mod error;
|
||||||
|
|||||||
@@ -99,6 +99,26 @@ pub struct AiConfig {
|
|||||||
pub temperature: f32,
|
pub temperature: f32,
|
||||||
pub cache_ttl_seconds: u64,
|
pub cache_ttl_seconds: u64,
|
||||||
pub rate_limit_patient_daily: u32,
|
pub rate_limit_patient_daily: u32,
|
||||||
|
#[serde(default)]
|
||||||
|
pub quota_check_enabled: bool,
|
||||||
|
#[serde(default)]
|
||||||
|
pub providers: std::collections::HashMap<String, ProviderConfig>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize)]
|
||||||
|
pub struct ProviderConfig {
|
||||||
|
pub provider_type: String,
|
||||||
|
pub api_key_env: Option<String>,
|
||||||
|
pub base_url: Option<String>,
|
||||||
|
pub default_model: String,
|
||||||
|
pub max_tokens: u32,
|
||||||
|
pub temperature: f32,
|
||||||
|
#[serde(default = "default_true")]
|
||||||
|
pub is_enabled: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_true() -> bool {
|
||||||
|
true
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize)]
|
#[derive(Debug, Clone, Deserialize)]
|
||||||
|
|||||||
Reference in New Issue
Block a user