feat(ai): 成本估算 + 预算告警服务 — CostService
Some checks failed
CI / rust-test (push) Has been cancelled
CI / frontend-build (push) Has been cancelled
CI / security-audit (push) Has been cancelled
CI / rust-check (push) Has been cancelled

Phase 3 Task 24:
- 按分析类型+模型估算 token 用量和 USD 成本
- 查询租户月度预算状态和告警等级(Normal/Warning/Critical/Exceeded)
This commit is contained in:
iven
2026-05-05 16:03:32 +08:00
parent d2512ca9db
commit 0da59c6a0e
2 changed files with 214 additions and 0 deletions

View File

@@ -0,0 +1,213 @@
//! 成本估算与预算告警服务
use sea_orm::{ColumnTrait, EntityTrait, FromQueryResult, QueryFilter, Statement};
use uuid::Uuid;
use crate::entity::ai_tenant_config;
use crate::error::AiResult;
/// 单次分析成本估算
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct CostEstimate {
pub analysis_type: String,
pub estimated_input_tokens: u32,
pub estimated_output_tokens: u32,
pub estimated_cost_usd: f64,
pub model: String,
}
/// 预算状态
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct BudgetStatus {
pub tenant_id: Uuid,
pub monthly_budget: i64,
pub monthly_used: i64,
pub usage_percentage: f64,
pub is_over_budget: bool,
pub warning_level: BudgetWarningLevel,
}
/// 预算告警等级
#[derive(Debug, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)]
pub enum BudgetWarningLevel {
Normal,
Warning,
Critical,
Exceeded,
}
/// 各分析类型的默认 token 估算
fn default_token_estimate(analysis_type: &str) -> (u32, u32) {
match analysis_type {
"lab_report" => (2000, 1500),
"trend" => (3000, 2000),
"checkup_plan" => (1500, 2000),
"report_summary" => (2500, 1500),
"dialysis_risk" => (2000, 1000),
_ => (1500, 1000),
}
}
/// 各模型的每百万 token 成本USD
fn model_cost_per_million(model: &str) -> (f64, f64) {
// (input_cost, output_cost) per million tokens
match model {
m if m.contains("claude-3") => (3.0, 15.0),
m if m.contains("claude") => (3.0, 15.0),
m if m.contains("gpt-4") => (30.0, 60.0),
m if m.contains("gpt-3.5") => (0.5, 1.5),
m if m.contains("qwen") => (0.5, 1.0),
_ => (2.0, 8.0),
}
}
pub struct CostService {
db: sea_orm::DatabaseConnection,
}
impl CostService {
pub fn new(db: sea_orm::DatabaseConnection) -> Self {
Self { db }
}
/// 估算单次分析成本
pub fn estimate_cost(analysis_type: &str, model: &str) -> CostEstimate {
let (input_tokens, output_tokens) = default_token_estimate(analysis_type);
let (input_cost, output_cost) = model_cost_per_million(model);
let estimated_cost_usd =
(input_tokens as f64 * input_cost / 1_000_000.0)
+ (output_tokens as f64 * output_cost / 1_000_000.0);
CostEstimate {
analysis_type: analysis_type.to_string(),
estimated_input_tokens: input_tokens,
estimated_output_tokens: output_tokens,
estimated_cost_usd,
model: model.to_string(),
}
}
/// 获取租户预算状态
pub async fn get_budget_status(&self, tenant_id: Uuid) -> AiResult<BudgetStatus> {
let config = ai_tenant_config::Entity::find()
.filter(ai_tenant_config::Column::TenantId.eq(tenant_id))
.filter(ai_tenant_config::Column::DeletedAt.is_null())
.one(&self.db)
.await?;
let monthly_budget = config
.as_ref()
.map(|c| c.monthly_token_budget)
.unwrap_or(1_000_000);
// 查询当月已用 token
let used = self.get_monthly_usage(tenant_id).await?;
let usage_percentage = if monthly_budget > 0 {
(used as f64 / monthly_budget as f64) * 100.0
} else {
0.0
};
let is_over_budget = used > monthly_budget;
let warning_level = if is_over_budget {
BudgetWarningLevel::Exceeded
} else if usage_percentage >= 90.0 {
BudgetWarningLevel::Critical
} else if usage_percentage >= 70.0 {
BudgetWarningLevel::Warning
} else {
BudgetWarningLevel::Normal
};
Ok(BudgetStatus {
tenant_id,
monthly_budget,
monthly_used: used,
usage_percentage,
is_over_budget,
warning_level,
})
}
async fn get_monthly_usage(&self, tenant_id: Uuid) -> AiResult<i64> {
#[derive(Debug, FromQueryResult)]
struct TokenSum {
total: Option<i64>,
}
let sql = r#"
SELECT COALESCE(SUM(input_tokens + output_tokens), 0) AS total
FROM ai_usage
WHERE tenant_id = $1
AND deleted_at IS NULL
AND created_at >= DATE_TRUNC('month', CURRENT_DATE)
"#;
let row: Option<TokenSum> = TokenSum::find_by_statement(
Statement::from_sql_and_values(
sea_orm::DatabaseBackend::Postgres,
sql,
[tenant_id.into()],
),
)
.one(&self.db)
.await?;
Ok(row.and_then(|r| r.total).unwrap_or(0))
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn estimate_lab_report_cost() {
let est = CostService::estimate_cost("lab_report", "claude-3-sonnet");
assert_eq!(est.estimated_input_tokens, 2000);
assert_eq!(est.estimated_output_tokens, 1500);
assert!(est.estimated_cost_usd > 0.0);
}
#[test]
fn estimate_trend_cost() {
let est = CostService::estimate_cost("trend", "gpt-4");
assert_eq!(est.estimated_input_tokens, 3000);
assert!(est.estimated_cost_usd > 0.0);
}
#[test]
fn budget_warning_levels() {
assert_eq!(BudgetWarningLevel::Normal, BudgetWarningLevel::Normal);
assert!(matches!(BudgetWarningLevel::Exceeded, BudgetWarningLevel::Exceeded));
}
#[test]
fn usage_percentage_calculation() {
let monthly_budget: i64 = 1_000_000;
let monthly_used: i64 = 750_000;
let pct = (monthly_used as f64 / monthly_budget as f64) * 100.0;
assert!((pct - 75.0).abs() < 0.01);
}
#[test]
fn over_budget_detection() {
let monthly_budget: i64 = 100_000;
let monthly_used: i64 = 150_000;
assert!(monthly_used > monthly_budget);
}
#[test]
fn default_unknown_type() {
let (input, output) = default_token_estimate("unknown_type");
assert_eq!(input, 1500);
assert_eq!(output, 1000);
}
#[test]
fn model_cost_claude() {
let (input, output) = model_cost_per_million("claude-3-sonnet-20240229");
assert!(input > 0.0);
assert!(output > input);
}
}

View File

@@ -3,6 +3,7 @@ pub mod analysis_queue;
pub mod auto_analysis;
pub mod cache;
pub mod comparison;
pub mod cost;
pub mod dialysis_risk_scorer;
pub mod local_rules;
pub mod output_parser;