From 0da59c6a0e2db7f1c58d70b68bdfb60d3d503ebf Mon Sep 17 00:00:00 2001
From: iven
Date: Tue, 5 May 2026 16:03:32 +0800
Subject: [PATCH] =?UTF-8?q?feat(ai):=20=E6=88=90=E6=9C=AC=E4=BC=B0?=
 =?UTF-8?q?=E7=AE=97=20+=20=E9=A2=84=E7=AE=97=E5=91=8A=E8=AD=A6=E6=9C=8D?=
 =?UTF-8?q?=E5=8A=A1=20=E2=80=94=20CostService?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Phase 3 Task 24:
- 按分析类型+模型估算 token 用量和 USD 成本
- 查询租户月度预算状态和告警等级(Normal/Warning/Critical/Exceeded)
---
 crates/erp-ai/src/service/cost.rs | 236 ++++++++++++++++++++++++++++++
 crates/erp-ai/src/service/mod.rs  |   1 +
 2 files changed, 237 insertions(+)
 create mode 100644 crates/erp-ai/src/service/cost.rs

diff --git a/crates/erp-ai/src/service/cost.rs b/crates/erp-ai/src/service/cost.rs
new file mode 100644
index 0000000..f9f83e9
--- /dev/null
+++ b/crates/erp-ai/src/service/cost.rs
@@ -0,0 +1,236 @@
+//! Cost estimation and budget alerting service.
+
+use sea_orm::{ColumnTrait, EntityTrait, FromQueryResult, QueryFilter, Statement};
+use uuid::Uuid;
+
+use crate::entity::ai_tenant_config;
+use crate::error::AiResult;
+
+/// Token budget assumed when a tenant has no active `ai_tenant_config` row.
+const DEFAULT_MONTHLY_TOKEN_BUDGET: i64 = 1_000_000;
+
+/// Usage percentage at or above which the alert level is `Warning`.
+const WARNING_THRESHOLD_PCT: f64 = 70.0;
+
+/// Usage percentage at or above which the alert level is `Critical`.
+const CRITICAL_THRESHOLD_PCT: f64 = 90.0;
+
+/// Cost estimate for a single analysis run.
+#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
+pub struct CostEstimate {
+    pub analysis_type: String,
+    pub estimated_input_tokens: u32,
+    pub estimated_output_tokens: u32,
+    pub estimated_cost_usd: f64,
+    pub model: String,
+}
+
+/// A tenant's token budget status for the current month.
+#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
+pub struct BudgetStatus {
+    pub tenant_id: Uuid,
+    pub monthly_budget: i64,
+    pub monthly_used: i64,
+    pub usage_percentage: f64,
+    pub is_over_budget: bool,
+    pub warning_level: BudgetWarningLevel,
+}
+
+/// Budget alert level.
+#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
+pub enum BudgetWarningLevel {
+    Normal,
+    Warning,
+    Critical,
+    Exceeded,
+}
+
+/// Default `(input_tokens, output_tokens)` estimate for each analysis type.
+fn default_token_estimate(analysis_type: &str) -> (u32, u32) {
+    match analysis_type {
+        "lab_report" => (2000, 1500),
+        "trend" => (3000, 2000),
+        "checkup_plan" => (1500, 2000),
+        "report_summary" => (2500, 1500),
+        "dialysis_risk" => (2000, 1000),
+        _ => (1500, 1000),
+    }
+}
+
+/// `(input_cost, output_cost)` in USD per million tokens for a model.
+///
+/// Matching is by substring and arms are tried top to bottom.
+fn model_cost_per_million(model: &str) -> (f64, f64) {
+    match model {
+        m if m.contains("claude") => (3.0, 15.0),
+        m if m.contains("gpt-4") => (30.0, 60.0),
+        m if m.contains("gpt-3.5") => (0.5, 1.5),
+        m if m.contains("qwen") => (0.5, 1.0),
+        _ => (2.0, 8.0),
+    }
+}
+
+/// Usage as a percentage of the budget; `0.0` when the budget is not positive.
+fn usage_percentage(used: i64, budget: i64) -> f64 {
+    if budget > 0 {
+        (used as f64 / budget as f64) * 100.0
+    } else {
+        0.0
+    }
+}
+
+/// Alert level for the given usage.
+///
+/// `Exceeded` requires `used > budget`, so usage of exactly 100% is `Critical`.
+fn warning_level(used: i64, budget: i64) -> BudgetWarningLevel {
+    let pct = usage_percentage(used, budget);
+    if used > budget {
+        BudgetWarningLevel::Exceeded
+    } else if pct >= CRITICAL_THRESHOLD_PCT {
+        BudgetWarningLevel::Critical
+    } else if pct >= WARNING_THRESHOLD_PCT {
+        BudgetWarningLevel::Warning
+    } else {
+        BudgetWarningLevel::Normal
+    }
+}
+
+/// Cost estimation and budget lookups backed by the database.
+pub struct CostService {
+    db: sea_orm::DatabaseConnection,
+}
+
+impl CostService {
+    /// Creates a service that runs its queries on `db`.
+    pub fn new(db: sea_orm::DatabaseConnection) -> Self {
+        Self { db }
+    }
+
+    /// Estimates the token usage and USD cost of one analysis.
+    pub fn estimate_cost(analysis_type: &str, model: &str) -> CostEstimate {
+        let (input_tokens, output_tokens) = default_token_estimate(analysis_type);
+        let (input_cost, output_cost) = model_cost_per_million(model);
+        let estimated_cost_usd = (f64::from(input_tokens) * input_cost / 1_000_000.0)
+            + (f64::from(output_tokens) * output_cost / 1_000_000.0);
+
+        CostEstimate {
+            analysis_type: analysis_type.to_string(),
+            estimated_input_tokens: input_tokens,
+            estimated_output_tokens: output_tokens,
+            estimated_cost_usd,
+            model: model.to_string(),
+        }
+    }
+
+    /// Returns the tenant's budget status for the current month.
+    ///
+    /// # Errors
+    ///
+    /// Returns an error if either database query fails.
+    pub async fn get_budget_status(&self, tenant_id: Uuid) -> AiResult<BudgetStatus> {
+        let config = ai_tenant_config::Entity::find()
+            .filter(ai_tenant_config::Column::TenantId.eq(tenant_id))
+            .filter(ai_tenant_config::Column::DeletedAt.is_null())
+            .one(&self.db)
+            .await?;
+
+        let monthly_budget = config
+            .map(|c| c.monthly_token_budget)
+            .unwrap_or(DEFAULT_MONTHLY_TOKEN_BUDGET);
+
+        let used = self.get_monthly_usage(tenant_id).await?;
+
+        Ok(BudgetStatus {
+            tenant_id,
+            monthly_budget,
+            monthly_used: used,
+            usage_percentage: usage_percentage(used, monthly_budget),
+            is_over_budget: used > monthly_budget,
+            warning_level: warning_level(used, monthly_budget),
+        })
+    }
+
+    /// Sums input and output tokens of the tenant's live `ai_usage` rows for this month.
+    async fn get_monthly_usage(&self, tenant_id: Uuid) -> AiResult<i64> {
+        #[derive(Debug, FromQueryResult)]
+        struct TokenSum {
+            total: Option<i64>,
+        }
+
+        // PostgreSQL's SUM over BIGINT yields NUMERIC; the explicit cast keeps
+        // `total` a BIGINT regardless of the token columns' integer width.
+        let sql = r#"
+            SELECT COALESCE(SUM(input_tokens + output_tokens), 0)::BIGINT AS total
+            FROM ai_usage
+            WHERE tenant_id = $1
+              AND deleted_at IS NULL
+              AND created_at >= DATE_TRUNC('month', CURRENT_DATE)
+        "#;
+
+        let row: Option<TokenSum> = TokenSum::find_by_statement(
+            Statement::from_sql_and_values(
+                sea_orm::DatabaseBackend::Postgres,
+                sql,
+                [tenant_id.into()],
+            ),
+        )
+        .one(&self.db)
+        .await?;
+
+        Ok(row.and_then(|r| r.total).unwrap_or(0))
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn estimate_lab_report_cost() {
+        let est = CostService::estimate_cost("lab_report", "claude-3-sonnet");
+        assert_eq!(est.estimated_input_tokens, 2000);
+        assert_eq!(est.estimated_output_tokens, 1500);
+        assert!(est.estimated_cost_usd > 0.0);
+    }
+
+    #[test]
+    fn estimate_trend_cost() {
+        let est = CostService::estimate_cost("trend", "gpt-4");
+        assert_eq!(est.estimated_input_tokens, 3000);
+        assert!(est.estimated_cost_usd > 0.0);
+    }
+
+    #[test]
+    fn budget_warning_levels() {
+        assert_eq!(warning_level(100_000, 1_000_000), BudgetWarningLevel::Normal);
+        assert_eq!(warning_level(750_000, 1_000_000), BudgetWarningLevel::Warning);
+        assert_eq!(warning_level(950_000, 1_000_000), BudgetWarningLevel::Critical);
+        assert_eq!(warning_level(1_000_000, 1_000_000), BudgetWarningLevel::Critical);
+    }
+
+    #[test]
+    fn usage_percentage_calculation() {
+        assert!((usage_percentage(750_000, 1_000_000) - 75.0).abs() < 0.01);
+        assert_eq!(usage_percentage(10, 0), 0.0);
+    }
+
+    #[test]
+    fn over_budget_detection() {
+        assert_eq!(warning_level(150_000, 100_000), BudgetWarningLevel::Exceeded);
+        assert_eq!(warning_level(1, 0), BudgetWarningLevel::Exceeded);
+    }
+
+    #[test]
+    fn default_unknown_type() {
+        let (input, output) = default_token_estimate("unknown_type");
+        assert_eq!(input, 1500);
+        assert_eq!(output, 1000);
+    }
+
+    #[test]
+    fn model_cost_claude() {
+        let (input, output) = model_cost_per_million("claude-3-sonnet-20240229");
+        assert!(input > 0.0);
+        assert!(output > input);
+    }
+}
diff --git a/crates/erp-ai/src/service/mod.rs b/crates/erp-ai/src/service/mod.rs
index bfb0ac8..b66ebb8 100644
--- a/crates/erp-ai/src/service/mod.rs
+++ b/crates/erp-ai/src/service/mod.rs
@@ -3,6 +3,7 @@ pub mod analysis_queue;
 pub mod auto_analysis;
 pub mod cache;
 pub mod comparison;
+pub mod cost;
 pub mod dialysis_risk_scorer;
 pub mod local_rules;
 pub mod output_parser;