Files
zclaw_openfang/crates/zclaw-saas/src/model_config/service.rs
iven 13c0b18bbc
Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled
feat: Batch 5-9 — GrowthIntegration桥接、验证补全、死代码清理、Pipeline模板、Speech/Twitter真实实现
Batch 5 (P0): GrowthIntegration 接入 Tauri
- Kernel 新增 set_viking()/set_extraction_driver() 桥接 SqliteStorage
- 中间件链共享存储,MemoryExtractor 接入 LLM 驱动

Batch 6 (P1): 输入验证 + Heartbeat
- Relay 验证补全(stream 兼容检查、API key 格式校验)
- UUID 类型校验、SessionId 错误返回
- Heartbeat 默认开启 + 首次聊天自动初始化

Batch 7 (P2): 死代码清理
- zclaw-channels 整体移除(317 行)
- multi-agent 特性门控、admin 方法标注

Batch 8 (P2): Pipeline 模板
- PipelineMetadata 新增 annotations 字段
- pipeline_templates 命令 + 2 个示例模板
- fallback driver base_url 修复(doubao/qwen/deepseek 端点)

Batch 9 (P1): SpeechHand/TwitterHand 真实实现
- SpeechHand: tts_method 字段 + Browser TTS 前端集成 (Web Speech API)
- TwitterHand: 12 个 action 全部替换为 Twitter API v2 真实 HTTP 调用
- chatStore/useAutomationEvents 双路径 TTS 触发
2026-03-30 09:24:50 +08:00

494 lines
20 KiB
Rust

//! 模型配置业务逻辑
use sqlx::{PgPool, Row};
use crate::error::{SaasError, SaasResult};
use crate::common::{PaginatedResponse, normalize_pagination};
use crate::crypto;
use crate::models::{ProviderRow, ModelRow, AccountApiKeyRow, UsageByModelRow, UsageByDayRow};
use super::types::*;
// ============ Providers ============
/// Returns one page of providers ordered by `name`.
///
/// When `enabled_filter` is `Some(flag)`, only providers whose `enabled` column
/// equals `flag` are counted and listed; `None` lists every provider.
/// `total` is the size of the whole (filtered) result set, not just this page.
pub async fn list_providers(
    db: &PgPool,
    page: Option<u32>,
    page_size: Option<u32>,
    enabled_filter: Option<bool>,
) -> SaasResult<PaginatedResponse<ProviderInfo>> {
    let (page_no, limit, offset) = normalize_pagination(page, page_size);

    let (total, rows) = match enabled_filter {
        Some(enabled) => {
            let (count,): (i64,) =
                sqlx::query_as("SELECT COUNT(*) FROM providers WHERE enabled = $1")
                    .bind(enabled)
                    .fetch_one(db)
                    .await?;
            let page_rows: Vec<ProviderRow> = sqlx::query_as(concat!(
                "SELECT id, name, display_name, base_url, api_protocol, enabled, rate_limit_rpm, rate_limit_tpm, created_at, updated_at\n",
                "FROM providers WHERE enabled = $1 ORDER BY name LIMIT $2 OFFSET $3"
            ))
            .bind(enabled)
            .bind(limit as i64)
            .bind(offset)
            .fetch_all(db)
            .await?;
            (count, page_rows)
        }
        None => {
            let (count,): (i64,) = sqlx::query_as("SELECT COUNT(*) FROM providers")
                .fetch_one(db)
                .await?;
            let page_rows: Vec<ProviderRow> = sqlx::query_as(concat!(
                "SELECT id, name, display_name, base_url, api_protocol, enabled, rate_limit_rpm, rate_limit_tpm, created_at, updated_at\n",
                "FROM providers ORDER BY name LIMIT $1 OFFSET $2"
            ))
            .bind(limit as i64)
            .bind(offset)
            .fetch_all(db)
            .await?;
            (count, page_rows)
        }
    };

    let items = rows
        .into_iter()
        .map(|row| ProviderInfo {
            id: row.id,
            name: row.name,
            display_name: row.display_name,
            base_url: row.base_url,
            api_protocol: row.api_protocol,
            enabled: row.enabled,
            rate_limit_rpm: row.rate_limit_rpm,
            rate_limit_tpm: row.rate_limit_tpm,
            created_at: row.created_at,
            updated_at: row.updated_at,
        })
        .collect();

    Ok(PaginatedResponse { items, total, page: page_no, page_size: limit })
}
/// Fetches a single provider by its `id`.
///
/// # Errors
/// Returns `SaasError::NotFound` when no provider has this id; database errors
/// are propagated.
pub async fn get_provider(db: &PgPool, provider_id: &str) -> SaasResult<ProviderInfo> {
    let sql = concat!(
        "SELECT id, name, display_name, base_url, api_protocol, enabled, rate_limit_rpm, rate_limit_tpm, created_at, updated_at\n",
        "FROM providers WHERE id = $1"
    );
    let found: Option<ProviderRow> = sqlx::query_as(sql)
        .bind(provider_id)
        .fetch_optional(db)
        .await?;

    match found {
        Some(row) => Ok(ProviderInfo {
            id: row.id,
            name: row.name,
            display_name: row.display_name,
            base_url: row.base_url,
            api_protocol: row.api_protocol,
            enabled: row.enabled,
            rate_limit_rpm: row.rate_limit_rpm,
            rate_limit_tpm: row.rate_limit_tpm,
            created_at: row.created_at,
            updated_at: row.updated_at,
        }),
        None => Err(SaasError::NotFound(format!("Provider {} 不存在", provider_id))),
    }
}
pub async fn create_provider(db: &PgPool, req: &CreateProviderRequest, enc_key: &[u8; 32]) -> SaasResult<ProviderInfo> {
let id = uuid::Uuid::new_v4().to_string();
let now = chrono::Utc::now().to_rfc3339();
// 检查名称唯一性
let existing: Option<(String,)> = sqlx::query_as("SELECT id FROM providers WHERE name = $1")
.bind(&req.name).fetch_optional(db).await?;
if existing.is_some() {
return Err(SaasError::AlreadyExists(format!("Provider '{}' 已存在", req.name)));
}
// 加密 API Key 后存储
let encrypted_api_key = if let Some(ref key) = req.api_key {
if key.is_empty() {
String::new()
} else {
crypto::encrypt_value(key, enc_key)?
}
} else {
String::new()
};
sqlx::query(
"INSERT INTO providers (id, name, display_name, api_key, base_url, api_protocol, enabled, rate_limit_rpm, rate_limit_tpm, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, true, $7, $8, $9, $9)"
)
.bind(&id).bind(&req.name).bind(&req.display_name).bind(&encrypted_api_key)
.bind(&req.base_url).bind(&req.api_protocol).bind(&req.rate_limit_rpm).bind(&req.rate_limit_tpm).bind(&now)
.execute(db).await?;
get_provider(db, &id).await
}
pub async fn update_provider(
db: &PgPool, provider_id: &str, req: &UpdateProviderRequest, enc_key: &[u8; 32],
) -> SaasResult<ProviderInfo> {
let now = chrono::Utc::now().to_rfc3339();
// Encrypt api_key upfront if provided
let encrypted_api_key = match req.api_key {
Some(ref v) if !v.is_empty() => Some(crypto::encrypt_value(v, enc_key)?),
Some(ref v) if v.is_empty() => Some(String::new()),
_ => None,
};
// COALESCE pattern: all updatable fields in a single static SQL.
// NULL parameters leave the column unchanged.
sqlx::query(
"UPDATE providers SET
display_name = COALESCE($1, display_name),
base_url = COALESCE($2, base_url),
api_protocol = COALESCE($3, api_protocol),
api_key = COALESCE($4, api_key),
enabled = COALESCE($5, enabled),
rate_limit_rpm = COALESCE($6, rate_limit_rpm),
rate_limit_tpm = COALESCE($7, rate_limit_tpm),
updated_at = $8
WHERE id = $9"
)
.bind(req.display_name.as_deref())
.bind(req.base_url.as_deref())
.bind(req.api_protocol.as_deref())
.bind(encrypted_api_key.as_deref())
.bind(req.enabled)
.bind(req.rate_limit_rpm)
.bind(req.rate_limit_tpm)
.bind(&now)
.bind(provider_id)
.execute(db).await?;
get_provider(db, provider_id).await
}
/// Deletes a provider by `id`.
///
/// # Errors
/// Returns `SaasError::NotFound` when no row was deleted; database errors are
/// propagated.
pub async fn delete_provider(db: &PgPool, provider_id: &str) -> SaasResult<()> {
    let outcome = sqlx::query("DELETE FROM providers WHERE id = $1")
        .bind(provider_id)
        .execute(db)
        .await?;
    match outcome.rows_affected() {
        0 => Err(SaasError::NotFound(format!("Provider {} 不存在", provider_id))),
        _ => Ok(()),
    }
}
// ============ Models ============
/// Returns one page of models.
///
/// With `provider_id = Some(pid)`, only that provider's models are counted and
/// listed, ordered by `alias`; with `None`, all models are listed ordered by
/// `provider_id, alias`. `total` covers the whole (filtered) result set.
pub async fn list_models(
    db: &PgPool, provider_id: Option<&str>, page: Option<u32>, page_size: Option<u32>,
) -> SaasResult<PaginatedResponse<ModelInfo>> {
    let (page_no, limit, offset) = normalize_pagination(page, page_size);

    let (total, rows) = match provider_id {
        Some(pid) => {
            let (count,): (i64,) =
                sqlx::query_as("SELECT COUNT(*) FROM models WHERE provider_id = $1")
                    .bind(pid)
                    .fetch_one(db)
                    .await?;
            let page_rows: Vec<ModelRow> = sqlx::query_as(concat!(
                "SELECT id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, pricing_input, pricing_output, created_at, updated_at\n",
                "FROM models WHERE provider_id = $1 ORDER BY alias LIMIT $2 OFFSET $3"
            ))
            .bind(pid)
            .bind(limit as i64)
            .bind(offset)
            .fetch_all(db)
            .await?;
            (count, page_rows)
        }
        None => {
            let (count,): (i64,) = sqlx::query_as("SELECT COUNT(*) FROM models")
                .fetch_one(db)
                .await?;
            let page_rows: Vec<ModelRow> = sqlx::query_as(concat!(
                "SELECT id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, pricing_input, pricing_output, created_at, updated_at\n",
                "FROM models ORDER BY provider_id, alias LIMIT $1 OFFSET $2"
            ))
            .bind(limit as i64)
            .bind(offset)
            .fetch_all(db)
            .await?;
            (count, page_rows)
        }
    };

    let items = rows
        .into_iter()
        .map(|row| ModelInfo {
            id: row.id,
            provider_id: row.provider_id,
            model_id: row.model_id,
            alias: row.alias,
            context_window: row.context_window,
            max_output_tokens: row.max_output_tokens,
            supports_streaming: row.supports_streaming,
            supports_vision: row.supports_vision,
            enabled: row.enabled,
            pricing_input: row.pricing_input,
            pricing_output: row.pricing_output,
            created_at: row.created_at,
            updated_at: row.updated_at,
        })
        .collect();

    Ok(PaginatedResponse { items, total, page: page_no, page_size: limit })
}
/// Registers a model under an existing provider and returns the stored record.
///
/// Unset optional fields fall back to: `context_window` 8192,
/// `max_output_tokens` 4096, `supports_streaming` true, `supports_vision`
/// false, pricing 0.0. New models are stored with `enabled = true`.
///
/// # Errors
/// Propagates `SaasError::NotFound` from the provider lookup; returns
/// `SaasError::AlreadyExists` when the provider already has a model with the
/// same `model_id`; database errors are propagated.
pub async fn create_model(db: &PgPool, req: &CreateModelRequest) -> SaasResult<ModelInfo> {
    // The provider must exist; its name is also used in the duplicate error below.
    let provider = get_provider(db, &req.provider_id).await?;
    let id = uuid::Uuid::new_v4().to_string();
    let now = chrono::Utc::now().to_rfc3339();

    let clash: Option<(String,)> =
        sqlx::query_as("SELECT id FROM models WHERE provider_id = $1 AND model_id = $2")
            .bind(&req.provider_id)
            .bind(&req.model_id)
            .fetch_optional(db)
            .await?;
    if clash.is_some() {
        return Err(SaasError::AlreadyExists(format!(
            "模型 '{}' 已存在于 provider '{}'", req.model_id, provider.name
        )));
    }

    sqlx::query(concat!(
        "INSERT INTO models (id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, pricing_input, pricing_output, created_at, updated_at)\n",
        "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, true, $9, $10, $11, $11)"
    ))
    .bind(&id)
    .bind(&req.provider_id)
    .bind(&req.model_id)
    .bind(&req.alias)
    .bind(req.context_window.unwrap_or(8192))
    .bind(req.max_output_tokens.unwrap_or(4096))
    .bind(req.supports_streaming.unwrap_or(true))
    .bind(req.supports_vision.unwrap_or(false))
    .bind(req.pricing_input.unwrap_or(0.0))
    .bind(req.pricing_output.unwrap_or(0.0))
    .bind(&now)
    .execute(db)
    .await?;

    get_model(db, &id).await
}
/// Fetches a single model by its row `id` (not its provider-side `model_id`).
///
/// # Errors
/// Returns `SaasError::NotFound` when no row has this id; database errors are
/// propagated.
pub async fn get_model(db: &PgPool, model_id: &str) -> SaasResult<ModelInfo> {
    let sql = concat!(
        "SELECT id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, pricing_input, pricing_output, created_at, updated_at\n",
        "FROM models WHERE id = $1"
    );
    let found: Option<ModelRow> = sqlx::query_as(sql)
        .bind(model_id)
        .fetch_optional(db)
        .await?;

    match found {
        Some(row) => Ok(ModelInfo {
            id: row.id,
            provider_id: row.provider_id,
            model_id: row.model_id,
            alias: row.alias,
            context_window: row.context_window,
            max_output_tokens: row.max_output_tokens,
            supports_streaming: row.supports_streaming,
            supports_vision: row.supports_vision,
            enabled: row.enabled,
            pricing_input: row.pricing_input,
            pricing_output: row.pricing_output,
            created_at: row.created_at,
            updated_at: row.updated_at,
        }),
        None => Err(SaasError::NotFound(format!("模型 {} 不存在", model_id))),
    }
}
/// Partially updates a model and returns the updated record.
///
/// Every `None` field in `req` leaves its column unchanged (via `COALESCE`);
/// `updated_at` is always refreshed.
///
/// # Errors
/// Returns `SaasError::NotFound` (from the final lookup) when the model does
/// not exist; database errors are propagated.
pub async fn update_model(
    db: &PgPool, model_id: &str, req: &UpdateModelRequest,
) -> SaasResult<ModelInfo> {
    let now = chrono::Utc::now().to_rfc3339();

    // A single static statement: NULL parameters leave their column as-is.
    let sql = concat!(
        "UPDATE models SET\n",
        "alias = COALESCE($1, alias),\n",
        "context_window = COALESCE($2, context_window),\n",
        "max_output_tokens = COALESCE($3, max_output_tokens),\n",
        "supports_streaming = COALESCE($4, supports_streaming),\n",
        "supports_vision = COALESCE($5, supports_vision),\n",
        "enabled = COALESCE($6, enabled),\n",
        "pricing_input = COALESCE($7, pricing_input),\n",
        "pricing_output = COALESCE($8, pricing_output),\n",
        "updated_at = $9\n",
        "WHERE id = $10"
    );
    sqlx::query(sql)
        .bind(req.alias.as_deref())
        .bind(req.context_window)
        .bind(req.max_output_tokens)
        .bind(req.supports_streaming)
        .bind(req.supports_vision)
        .bind(req.enabled)
        .bind(req.pricing_input)
        .bind(req.pricing_output)
        .bind(&now)
        .bind(model_id)
        .execute(db)
        .await?;

    get_model(db, model_id).await
}
/// Deletes a model by its row `id`.
///
/// # Errors
/// Returns `SaasError::NotFound` when no row was deleted; database errors are
/// propagated.
pub async fn delete_model(db: &PgPool, model_id: &str) -> SaasResult<()> {
    let outcome = sqlx::query("DELETE FROM models WHERE id = $1")
        .bind(model_id)
        .execute(db)
        .await?;
    match outcome.rows_affected() {
        0 => Err(SaasError::NotFound(format!("模型 {} 不存在", model_id))),
        _ => Ok(()),
    }
}
// ============ Account API Keys ============
/// Returns one page of an account's non-revoked API keys, newest first.
///
/// With `provider_id = Some(pid)`, only keys for that provider are counted and
/// listed. `permissions` is parsed from its stored JSON text and falls back to
/// an empty list if it does not parse. Keys are returned masked, never in full.
pub async fn list_account_api_keys(
    db: &PgPool, account_id: &str, provider_id: Option<&str>,
    page: Option<u32>, page_size: Option<u32>,
) -> SaasResult<PaginatedResponse<AccountApiKeyInfo>> {
    let (page_no, limit, offset) = normalize_pagination(page, page_size);

    let (total, rows) = match provider_id {
        Some(pid) => {
            let (count,): (i64,) = sqlx::query_as(
                "SELECT COUNT(*) FROM account_api_keys WHERE account_id = $1 AND provider_id = $2 AND revoked_at IS NULL",
            )
            .bind(account_id)
            .bind(pid)
            .fetch_one(db)
            .await?;
            let page_rows: Vec<AccountApiKeyRow> = sqlx::query_as(concat!(
                "SELECT id, provider_id, key_label, permissions, enabled, last_used_at, created_at, key_value\n",
                "FROM account_api_keys WHERE account_id = $1 AND provider_id = $2 AND revoked_at IS NULL ORDER BY created_at DESC LIMIT $3 OFFSET $4"
            ))
            .bind(account_id)
            .bind(pid)
            .bind(limit as i64)
            .bind(offset)
            .fetch_all(db)
            .await?;
            (count, page_rows)
        }
        None => {
            let (count,): (i64,) = sqlx::query_as(
                "SELECT COUNT(*) FROM account_api_keys WHERE account_id = $1 AND revoked_at IS NULL",
            )
            .bind(account_id)
            .fetch_one(db)
            .await?;
            let page_rows: Vec<AccountApiKeyRow> = sqlx::query_as(concat!(
                "SELECT id, provider_id, key_label, permissions, enabled, last_used_at, created_at, key_value\n",
                "FROM account_api_keys WHERE account_id = $1 AND revoked_at IS NULL ORDER BY created_at DESC LIMIT $2 OFFSET $3"
            ))
            .bind(account_id)
            .bind(limit as i64)
            .bind(offset)
            .fetch_all(db)
            .await?;
            (count, page_rows)
        }
    };

    let mut items = Vec::with_capacity(rows.len());
    for row in rows {
        let permissions: Vec<String> = serde_json::from_str(&row.permissions).unwrap_or_default();
        // NOTE(review): key_value is written encrypted by create_account_api_key and
        // rotate_account_api_key, so this masks the ciphertext rather than the original key.
        // The result will not match the mask returned at creation time. Consider decrypting
        // first (this would need the encryption key passed in).
        let masked_key = mask_api_key(&row.key_value);
        items.push(AccountApiKeyInfo {
            id: row.id,
            provider_id: row.provider_id,
            key_label: row.key_label,
            permissions,
            enabled: row.enabled,
            last_used_at: row.last_used_at,
            created_at: row.created_at,
            masked_key,
        });
    }

    Ok(PaginatedResponse { items, total, page: page_no, page_size: limit })
}
/// Stores a new API key for `account_id` and returns its masked description.
///
/// `key_value` is encrypted with `enc_key` before storage, and `permissions`
/// is stored as JSON text. The returned mask is computed from the plaintext
/// key. New keys are stored with `enabled = true`.
///
/// # Errors
/// Propagates `SaasError::NotFound` when the provider does not exist, plus
/// serialization, encryption and database errors.
pub async fn create_account_api_key(
    db: &PgPool, account_id: &str, req: &CreateAccountApiKeyRequest, enc_key: &[u8; 32],
) -> SaasResult<AccountApiKeyInfo> {
    // The referenced provider must exist.
    get_provider(db, &req.provider_id).await?;

    let id = uuid::Uuid::new_v4().to_string();
    let created_at = chrono::Utc::now().to_rfc3339();
    let permissions_json = serde_json::to_string(&req.permissions)?;
    let ciphertext = crypto::encrypt_value(&req.key_value, enc_key)?;

    sqlx::query(concat!(
        "INSERT INTO account_api_keys (id, account_id, provider_id, key_value, key_label, permissions, enabled, created_at, updated_at)\n",
        "VALUES ($1, $2, $3, $4, $5, $6, true, $7, $7)"
    ))
    .bind(&id)
    .bind(account_id)
    .bind(&req.provider_id)
    .bind(&ciphertext)
    .bind(&req.key_label)
    .bind(&permissions_json)
    .bind(&created_at)
    .execute(db)
    .await?;

    Ok(AccountApiKeyInfo {
        id,
        provider_id: req.provider_id.clone(),
        key_label: req.key_label.clone(),
        permissions: req.permissions.clone(),
        enabled: true,
        last_used_at: None,
        created_at,
        masked_key: mask_api_key(&req.key_value),
    })
}
/// Replaces the stored value of a non-revoked key owned by `account_id`.
///
/// The new value is encrypted with `enc_key` before it is written, and
/// `updated_at` is refreshed.
///
/// # Errors
/// Returns `SaasError::NotFound` when no matching non-revoked key exists for
/// this account; encryption and database errors are propagated.
pub async fn rotate_account_api_key(
    db: &PgPool, key_id: &str, account_id: &str, new_key_value: &str, enc_key: &[u8; 32],
) -> SaasResult<()> {
    let updated_at = chrono::Utc::now().to_rfc3339();
    let ciphertext = crypto::encrypt_value(new_key_value, enc_key)?;

    let outcome = sqlx::query(
        "UPDATE account_api_keys SET key_value = $1, updated_at = $2 WHERE id = $3 AND account_id = $4 AND revoked_at IS NULL",
    )
    .bind(&ciphertext)
    .bind(&updated_at)
    .bind(key_id)
    .bind(account_id)
    .execute(db)
    .await?;

    match outcome.rows_affected() {
        0 => Err(SaasError::NotFound("API Key 不存在或已撤销".into())),
        _ => Ok(()),
    }
}
/// Marks a non-revoked key owned by `account_id` as revoked by stamping
/// `revoked_at` with the current time.
///
/// # Errors
/// Returns `SaasError::NotFound` when the key does not exist, belongs to
/// another account, or is already revoked; database errors are propagated.
pub async fn revoke_account_api_key(
    db: &PgPool, key_id: &str, account_id: &str,
) -> SaasResult<()> {
    let revoked_at = chrono::Utc::now().to_rfc3339();
    let outcome = sqlx::query(
        "UPDATE account_api_keys SET revoked_at = $1 WHERE id = $2 AND account_id = $3 AND revoked_at IS NULL",
    )
    .bind(&revoked_at)
    .bind(key_id)
    .bind(account_id)
    .execute(db)
    .await?;

    match outcome.rows_affected() {
        0 => Err(SaasError::NotFound("API Key 不存在或已撤销".into())),
        _ => Ok(()),
    }
}
// ============ Usage Statistics ============
/// Aggregates usage statistics for one account.
///
/// * Totals and the per-model breakdown (top 20 by request count) honour every
///   optional filter in `query`: `from`, `to`, `provider_id`, `model_id`.
/// * The per-day breakdown covers the window starting at midnight (UTC),
///   `query.days` days ago (default 30, clamped to 1..=365). It returns at most
///   `days` of the most recent calendar days. It also honours `provider_id` /
///   `model_id`, so it agrees with the totals it is reported alongside.
/// * `query.group_by` selects which breakdowns are returned: `"model"` gives
///   only `by_model`, `"day"` gives only `by_day`, absent gives both, and any
///   other value gives neither. Breakdowns that will not be returned are not
///   queried at all.
///
/// # Errors
/// Propagates database and column-decode errors instead of reporting zeros.
pub async fn get_usage_stats(
    db: &PgPool, account_id: &str, query: &UsageQuery,
) -> SaasResult<UsageStats> {
    let group_by = query.group_by.as_deref();
    let want_by_model = matches!(group_by, None | Some("model"));
    let want_by_day = matches!(group_by, None | Some("day"));

    // Totals. account_id is required; optional filters use ($N IS NULL OR col = $N).
    // SUM over a BIGINT column yields NUMERIC in PostgreSQL, which does not decode
    // to i64, so the sums are cast back to bigint explicitly.
    let total_sql = concat!(
        "SELECT COUNT(*)::bigint, COALESCE(SUM(input_tokens), 0)::bigint, COALESCE(SUM(output_tokens), 0)::bigint\n",
        "FROM usage_records WHERE account_id = $1 AND ($2 IS NULL OR created_at >= $2) AND ($3 IS NULL OR created_at <= $3) AND ($4 IS NULL OR provider_id = $4) AND ($5 IS NULL OR model_id = $5)"
    );
    let totals = sqlx::query(total_sql)
        .bind(account_id)
        .bind(&query.from)
        .bind(&query.to)
        .bind(&query.provider_id)
        .bind(&query.model_id)
        .fetch_one(db)
        .await?;
    // COALESCE guarantees non-NULL values, so a failure here is a real schema problem.
    let total_requests: i64 = totals.try_get(0)?;
    let total_input: i64 = totals.try_get(1)?;
    let total_output: i64 = totals.try_get(2)?;

    // Per-model breakdown (same filters as the totals).
    let by_model: Vec<ModelUsage> = if want_by_model {
        let by_model_sql = concat!(
            "SELECT provider_id, model_id, COUNT(*)::bigint AS request_count, COALESCE(SUM(input_tokens), 0) AS input_tokens, COALESCE(SUM(output_tokens), 0) AS output_tokens\n",
            "FROM usage_records WHERE account_id = $1 AND ($2 IS NULL OR created_at >= $2) AND ($3 IS NULL OR created_at <= $3) AND ($4 IS NULL OR provider_id = $4) AND ($5 IS NULL OR model_id = $5) GROUP BY provider_id, model_id ORDER BY COUNT(*) DESC LIMIT 20"
        );
        let rows: Vec<UsageByModelRow> = sqlx::query_as(by_model_sql)
            .bind(account_id)
            .bind(&query.from)
            .bind(&query.to)
            .bind(&query.provider_id)
            .bind(&query.model_id)
            .fetch_all(db)
            .await?;
        rows.into_iter()
            .map(|r| ModelUsage {
                provider_id: r.provider_id,
                model_id: r.model_id,
                request_count: r.request_count,
                input_tokens: r.input_tokens,
                output_tokens: r.output_tokens,
            })
            .collect()
    } else {
        Vec::new()
    };

    // Per-day breakdown over the `days` window (default 30 days).
    let by_day: Vec<DailyUsage> = if want_by_day {
        let days = query.days.unwrap_or(30).clamp(1, 365) as i64;
        let window_start = (chrono::Utc::now() - chrono::Duration::days(days))
            .date_naive()
            .and_hms_opt(0, 0, 0)
            .expect("00:00:00 is always a valid time of day")
            .and_utc()
            .to_rfc3339();
        let daily_sql = concat!(
            "SELECT created_at::date::text as day, COUNT(*)::bigint AS request_count, COALESCE(SUM(input_tokens), 0) AS input_tokens, COALESCE(SUM(output_tokens), 0) AS output_tokens\n",
            "FROM usage_records WHERE account_id = $1 AND created_at >= $2 AND ($3 IS NULL OR provider_id = $3) AND ($4 IS NULL OR model_id = $4)\n",
            "GROUP BY created_at::date ORDER BY day DESC LIMIT $5"
        );
        let rows: Vec<UsageByDayRow> = sqlx::query_as(daily_sql)
            .bind(account_id)
            .bind(&window_start)
            .bind(&query.provider_id)
            .bind(&query.model_id)
            // days is clamped to 1..=365 above, so this cast cannot truncate.
            .bind(days as i32)
            .fetch_all(db)
            .await?;
        rows.into_iter()
            .map(|r| DailyUsage {
                date: r.day,
                request_count: r.request_count,
                input_tokens: r.input_tokens,
                output_tokens: r.output_tokens,
            })
            .collect()
    } else {
        Vec::new()
    };

    Ok(UsageStats {
        total_requests,
        total_input_tokens: total_input,
        total_output_tokens: total_output,
        by_model,
        by_day,
    })
}
/// Appends one usage record for a request, timestamped with the current UTC
/// time in RFC 3339 format.
///
/// # Errors
/// Propagates database errors.
pub async fn record_usage(
    db: &PgPool,
    account_id: &str,
    provider_id: &str,
    model_id: &str,
    input_tokens: i64,
    output_tokens: i64,
    latency_ms: Option<i64>,
    status: &str,
    error_message: Option<&str>,
) -> SaasResult<()> {
    let created_at = chrono::Utc::now().to_rfc3339();
    let insert_sql = concat!(
        "INSERT INTO usage_records (account_id, provider_id, model_id, input_tokens, output_tokens, latency_ms, status, error_message, created_at)\n",
        "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)"
    );
    sqlx::query(insert_sql)
        .bind(account_id)
        .bind(provider_id)
        .bind(model_id)
        .bind(input_tokens)
        .bind(output_tokens)
        .bind(latency_ms)
        .bind(status)
        .bind(error_message)
        .bind(&created_at)
        .execute(db)
        .await?;
    Ok(())
}
// ============ Helpers ============
/// Masks an API key for display.
///
/// Keys of at most 8 characters become one `*` per character. Longer keys
/// keep their first and last 4 characters joined by `...`
/// (e.g. `"sk-abcdef123456"` → `"sk-a...3456"`).
///
/// Works on `char`s rather than bytes. Byte slicing would panic when a 4-byte
/// boundary falls inside a multi-byte UTF-8 character, and keys can be arbitrary
/// user input.
fn mask_api_key(key: &str) -> String {
    let char_count = key.chars().count();
    if char_count <= 8 {
        return "*".repeat(char_count);
    }
    let head: String = key.chars().take(4).collect();
    let tail: String = key.chars().skip(char_count - 4).collect();
    format!("{}...{}", head, tail)
}