diff --git a/config/default.toml b/config/default.toml index 0137de5..84dda02 100644 --- a/config/default.toml +++ b/config/default.toml @@ -48,6 +48,33 @@ max_tokens = 2048 temperature = 0.3 cache_ttl_seconds = 604800 rate_limit_patient_daily = 10 +quota_check_enabled = true + +[ai.providers.claude] +provider_type = "claude" +api_key_env = "ANTHROPIC_API_KEY" +base_url = "https://api.anthropic.com" +default_model = "claude-sonnet-4-6" +max_tokens = 2048 +temperature = 0.3 +is_enabled = true + +[ai.providers.openai] +provider_type = "openai" +api_key_env = "OPENAI_API_KEY" +base_url = "https://api.openai.com" +default_model = "gpt-4o" +max_tokens = 2048 +temperature = 0.3 +is_enabled = false + +[ai.providers.ollama] +provider_type = "ollama" +base_url = "http://localhost:11434" +default_model = "qwen2.5:7b" +max_tokens = 2048 +temperature = 0.3 +is_enabled = false [storage] upload_dir = "./uploads" diff --git a/crates/erp-ai/src/config.rs b/crates/erp-ai/src/config.rs new file mode 100644 index 0000000..c05fcf9 --- /dev/null +++ b/crates/erp-ai/src/config.rs @@ -0,0 +1,90 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum ProviderType { + Claude, + Openai, + Ollama, + Rules, +} + +#[derive(Debug, Clone, Deserialize)] +pub struct ProviderConfig { + pub provider_type: ProviderType, + pub api_key_env: Option, + pub base_url: Option, + pub default_model: String, + pub max_tokens: u32, + pub temperature: f32, + #[serde(default = "default_true")] + pub is_enabled: bool, +} + +fn default_true() -> bool { + true +} + +#[derive(Debug, Clone, Deserialize)] +pub struct AiModuleConfig { + pub default_provider: String, + pub cache_ttl_seconds: u64, + pub quota_check_enabled: bool, + pub rate_limit_patient_daily: u32, + pub providers: std::collections::HashMap, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn provider_type_serde_roundtrip() { + let pt = ProviderType::Claude; + let json = serde_json::to_string(&pt).unwrap(); + assert_eq!(json, "\"claude\""); + let back: ProviderType = serde_json::from_str(&json).unwrap(); + assert_eq!(back, pt); + } + + #[test] + fn provider_type_all_variants() { + for pt in [ProviderType::Claude, ProviderType::Openai, ProviderType::Ollama, ProviderType::Rules] { + let json = serde_json::to_string(&pt).unwrap(); + let back: ProviderType = serde_json::from_str(&json).unwrap(); + assert_eq!(back, pt); + } + } + + #[test] + fn parse_provider_config() { + let json = r#"{ + "provider_type": "openai", + "api_key_env": "OPENAI_API_KEY", + "base_url": "https://api.openai.com", + "default_model": "gpt-4o", + "max_tokens": 2048, + "temperature": 0.3, + "is_enabled": true + }"#; + let config: ProviderConfig = serde_json::from_str(json).unwrap(); + assert_eq!(config.provider_type, ProviderType::Openai); + assert_eq!(config.default_model, "gpt-4o"); + assert!(config.is_enabled); + } + + #[test] + fn parse_ollama_config_no_api_key() { + let json = r#"{ + "provider_type": "ollama", + "base_url": "http://localhost:11434", + "default_model": "qwen2.5:7b", + "max_tokens": 2048, + "temperature": 0.3 + }"#; + let config: ProviderConfig = serde_json::from_str(json).unwrap(); + assert_eq!(config.provider_type, ProviderType::Ollama); + assert!(config.api_key_env.is_none()); + assert!(config.is_enabled); // default + } +} diff --git a/crates/erp-ai/src/lib.rs b/crates/erp-ai/src/lib.rs index b6ea311..7d7b9f8 100644 --- a/crates/erp-ai/src/lib.rs +++ b/crates/erp-ai/src/lib.rs @@ -1,3 +1,4 @@ +pub mod config; pub mod dto; pub mod entity; pub mod error; diff --git a/crates/erp-server/src/config.rs b/crates/erp-server/src/config.rs index c2035fd..c77f0c7 100644 --- a/crates/erp-server/src/config.rs +++ b/crates/erp-server/src/config.rs @@ -99,6 +99,26 @@ pub struct AiConfig { pub temperature: f32, pub cache_ttl_seconds: u64, pub rate_limit_patient_daily: u32, + #[serde(default)] + pub quota_check_enabled: bool, + #[serde(default)] + pub providers: std::collections::HashMap, +} + +#[derive(Debug, Clone, Deserialize)] +pub struct ProviderConfig { + pub provider_type: String, + pub api_key_env: Option, + pub base_url: Option, + pub default_model: String, + pub max_tokens: u32, + pub temperature: f32, + #[serde(default = "default_true")] + pub is_enabled: bool, +} + +fn default_true() -> bool { + true } #[derive(Debug, Clone, Deserialize)]