Files
hms/crates/erp-ai/src/service/cost.rs
iven 0da59c6a0e
Some checks failed
CI / rust-test (push) Has been cancelled
CI / frontend-build (push) Has been cancelled
CI / security-audit (push) Has been cancelled
CI / rust-check (push) Has been cancelled
feat(ai): 成本估算 + 预算告警服务 — CostService
Phase 3 Task 24:
- 按分析类型+模型估算 token 用量和 USD 成本
- 查询租户月度预算状态和告警等级(Normal/Warning/Critical/Exceeded)
2026-05-05 16:03:32 +08:00

214 lines
6.2 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
//! 成本估算与预算告警服务
use sea_orm::{ColumnTrait, EntityTrait, FromQueryResult, QueryFilter, Statement};
use uuid::Uuid;
use crate::entity::ai_tenant_config;
use crate::error::AiResult;
/// Estimated token usage and USD cost of a single AI analysis.
#[derive(Debug, Clone)]
#[derive(serde::Serialize, serde::Deserialize)]
pub struct CostEstimate {
    /// Analysis type the estimate was computed for (e.g. `"lab_report"`).
    pub analysis_type: String,
    /// Estimated number of input tokens.
    pub estimated_input_tokens: u32,
    /// Estimated number of output tokens.
    pub estimated_output_tokens: u32,
    /// Estimated total cost in US dollars (input + output).
    pub estimated_cost_usd: f64,
    /// Model identifier whose pricing was used for the estimate.
    pub model: String,
}
/// Snapshot of a tenant's token budget for the current month.
#[derive(Debug, Clone)]
#[derive(serde::Serialize, serde::Deserialize)]
pub struct BudgetStatus {
    /// Tenant the status belongs to.
    pub tenant_id: Uuid,
    /// Monthly token budget that applies to the tenant.
    pub monthly_budget: i64,
    /// Tokens (input + output) consumed so far this month.
    pub monthly_used: i64,
    /// `monthly_used` as a percentage (0-100 scale, may exceed 100) of `monthly_budget`.
    pub usage_percentage: f64,
    /// `true` when `monthly_used` is strictly greater than `monthly_budget`.
    pub is_over_budget: bool,
    /// Alert level derived from the usage figures above.
    pub warning_level: BudgetWarningLevel,
}
/// Budget alert level for a tenant's monthly token usage.
///
/// Variants are declared from least to most severe. `CostService::get_budget_status`
/// assigns them from the usage percentage (70% and 90% thresholds) and from whether
/// usage is strictly over budget.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
pub enum BudgetWarningLevel {
    /// Usage below 70% of the budget (a non-positive budget reports 0% usage).
    Normal,
    /// Usage at or above 70% but below 90% of the budget.
    Warning,
    /// Usage at or above 90% of the budget, but not over it.
    Critical,
    /// Usage strictly greater than the budget.
    Exceeded,
}
/// Returns the default `(input_tokens, output_tokens)` estimate for an analysis type.
///
/// The lookup is an exact, case-sensitive match on `analysis_type`; any unlisted type
/// falls back to `(1500, 1000)`.
fn default_token_estimate(analysis_type: &str) -> (u32, u32) {
    // (analysis type, (input tokens, output tokens))
    const ESTIMATES: [(&str, (u32, u32)); 5] = [
        ("lab_report", (2000, 1500)),
        ("trend", (3000, 2000)),
        ("checkup_plan", (1500, 2000)),
        ("report_summary", (2500, 1500)),
        ("dialysis_risk", (2000, 1000)),
    ];
    const FALLBACK: (u32, u32) = (1500, 1000);

    ESTIMATES
        .iter()
        .find(|(kind, _)| *kind == analysis_type)
        .map_or(FALLBACK, |&(_, tokens)| tokens)
}
/// Returns the `(input, output)` price in USD per million tokens for `model`.
///
/// Matching is a case-sensitive substring test, checked top to bottom; the first hit
/// wins. Unrecognised models fall back to `(2.0, 8.0)`.
fn model_cost_per_million(model: &str) -> (f64, f64) {
    // All Claude variants (including `claude-3*`) share one price, so a single `claude`
    // arm covers them; a separate `claude-3` arm with the same price would be redundant.
    // NOTE(review): `gpt-4` also matches names such as `gpt-4o` / `gpt-4o-mini`, which may
    // be priced differently — confirm this grouping is intended.
    match model {
        m if m.contains("claude") => (3.0, 15.0),
        m if m.contains("gpt-4") => (30.0, 60.0),
        m if m.contains("gpt-3.5") => (0.5, 1.5),
        m if m.contains("qwen") => (0.5, 1.0),
        _ => (2.0, 8.0),
    }
}
/// Cost estimation and monthly budget alerting for AI analyses.
pub struct CostService {
    /// Connection used for tenant-config and usage queries.
    db: sea_orm::DatabaseConnection,
}
impl CostService {
pub fn new(db: sea_orm::DatabaseConnection) -> Self {
Self { db }
}
/// 估算单次分析成本
pub fn estimate_cost(analysis_type: &str, model: &str) -> CostEstimate {
let (input_tokens, output_tokens) = default_token_estimate(analysis_type);
let (input_cost, output_cost) = model_cost_per_million(model);
let estimated_cost_usd =
(input_tokens as f64 * input_cost / 1_000_000.0)
+ (output_tokens as f64 * output_cost / 1_000_000.0);
CostEstimate {
analysis_type: analysis_type.to_string(),
estimated_input_tokens: input_tokens,
estimated_output_tokens: output_tokens,
estimated_cost_usd,
model: model.to_string(),
}
}
/// 获取租户预算状态
pub async fn get_budget_status(&self, tenant_id: Uuid) -> AiResult<BudgetStatus> {
let config = ai_tenant_config::Entity::find()
.filter(ai_tenant_config::Column::TenantId.eq(tenant_id))
.filter(ai_tenant_config::Column::DeletedAt.is_null())
.one(&self.db)
.await?;
let monthly_budget = config
.as_ref()
.map(|c| c.monthly_token_budget)
.unwrap_or(1_000_000);
// 查询当月已用 token
let used = self.get_monthly_usage(tenant_id).await?;
let usage_percentage = if monthly_budget > 0 {
(used as f64 / monthly_budget as f64) * 100.0
} else {
0.0
};
let is_over_budget = used > monthly_budget;
let warning_level = if is_over_budget {
BudgetWarningLevel::Exceeded
} else if usage_percentage >= 90.0 {
BudgetWarningLevel::Critical
} else if usage_percentage >= 70.0 {
BudgetWarningLevel::Warning
} else {
BudgetWarningLevel::Normal
};
Ok(BudgetStatus {
tenant_id,
monthly_budget,
monthly_used: used,
usage_percentage,
is_over_budget,
warning_level,
})
}
async fn get_monthly_usage(&self, tenant_id: Uuid) -> AiResult<i64> {
#[derive(Debug, FromQueryResult)]
struct TokenSum {
total: Option<i64>,
}
let sql = r#"
SELECT COALESCE(SUM(input_tokens + output_tokens), 0) AS total
FROM ai_usage
WHERE tenant_id = $1
AND deleted_at IS NULL
AND created_at >= DATE_TRUNC('month', CURRENT_DATE)
"#;
let row: Option<TokenSum> = TokenSum::find_by_statement(
Statement::from_sql_and_values(
sea_orm::DatabaseBackend::Postgres,
sql,
[tenant_id.into()],
),
)
.one(&self.db)
.await?;
Ok(row.and_then(|r| r.total).unwrap_or(0))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for USD amounts produced by floating-point arithmetic.
    const USD_EPSILON: f64 = 1e-12;

    #[test]
    fn estimate_lab_report_cost() {
        let est = CostService::estimate_cost("lab_report", "claude-3-sonnet");
        assert_eq!(est.analysis_type, "lab_report");
        assert_eq!(est.model, "claude-3-sonnet");
        assert_eq!(est.estimated_input_tokens, 2000);
        assert_eq!(est.estimated_output_tokens, 1500);
        // 2000 tokens * $3/M + 1500 tokens * $15/M = $0.0060 + $0.0225
        assert!((est.estimated_cost_usd - 0.0285).abs() < USD_EPSILON);
    }

    #[test]
    fn estimate_trend_cost() {
        let est = CostService::estimate_cost("trend", "gpt-4");
        assert_eq!(est.estimated_input_tokens, 3000);
        assert_eq!(est.estimated_output_tokens, 2000);
        // 3000 tokens * $30/M + 2000 tokens * $60/M = $0.09 + $0.12
        assert!((est.estimated_cost_usd - 0.21).abs() < USD_EPSILON);
    }

    #[test]
    fn estimate_unknown_type_and_model_uses_fallbacks() {
        let est = CostService::estimate_cost("unknown_type", "some-local-model");
        assert_eq!(est.estimated_input_tokens, 1500);
        assert_eq!(est.estimated_output_tokens, 1000);
        // 1500 tokens * $2/M + 1000 tokens * $8/M = $0.003 + $0.008
        assert!((est.estimated_cost_usd - 0.011).abs() < USD_EPSILON);
    }

    #[test]
    fn budget_warning_levels() {
        // Every variant compares equal to itself and unequal to every other variant.
        let levels = [
            BudgetWarningLevel::Normal,
            BudgetWarningLevel::Warning,
            BudgetWarningLevel::Critical,
            BudgetWarningLevel::Exceeded,
        ];
        for (i, a) in levels.iter().enumerate() {
            for (j, b) in levels.iter().enumerate() {
                assert_eq!(i == j, a == b, "{a:?} vs {b:?}");
            }
        }
    }

    // NOTE(review): the next two tests re-derive the arithmetic used by
    // `CostService::get_budget_status` instead of calling it (that method needs a database).
    #[test]
    fn usage_percentage_calculation() {
        let monthly_budget: i64 = 1_000_000;
        let monthly_used: i64 = 750_000;
        let pct = (monthly_used as f64 / monthly_budget as f64) * 100.0;
        assert!((pct - 75.0).abs() < 0.01);
    }

    #[test]
    fn over_budget_detection() {
        let monthly_budget: i64 = 100_000;
        let monthly_used: i64 = 150_000;
        assert!(monthly_used > monthly_budget);
    }

    #[test]
    fn default_unknown_type() {
        let (input, output) = default_token_estimate("unknown_type");
        assert_eq!(input, 1500);
        assert_eq!(output, 1000);
    }

    #[test]
    fn default_known_types() {
        assert_eq!(default_token_estimate("lab_report"), (2000, 1500));
        assert_eq!(default_token_estimate("trend"), (3000, 2000));
        assert_eq!(default_token_estimate("checkup_plan"), (1500, 2000));
        assert_eq!(default_token_estimate("report_summary"), (2500, 1500));
        assert_eq!(default_token_estimate("dialysis_risk"), (2000, 1000));
    }

    #[test]
    fn model_cost_claude() {
        assert_eq!(model_cost_per_million("claude-3-sonnet-20240229"), (3.0, 15.0));
        assert_eq!(model_cost_per_million("claude-2.1"), (3.0, 15.0));
    }

    #[test]
    fn model_cost_other_families() {
        assert_eq!(model_cost_per_million("gpt-4-turbo"), (30.0, 60.0));
        assert_eq!(model_cost_per_million("gpt-3.5-turbo"), (0.5, 1.5));
        assert_eq!(model_cost_per_million("qwen-max"), (0.5, 1.0));
        assert_eq!(model_cost_per_million("llama-3-70b"), (2.0, 8.0));
    }
}