feat(auth): 添加异步密码哈希和验证函数
Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled

refactor(relay): 复用HTTP客户端和请求体序列化结果

feat(kernel): 添加获取单个审批记录的方法

fix(store): 改进SaaS连接错误分类和降级处理

docs: 更新审计文档和系统架构文档

refactor(prompt): 优化SQL查询参数化绑定

refactor(migration): 使用静态SQL和COALESCE更新配置项

feat(commands): 添加审批执行状态追踪和事件通知

chore: 更新启动脚本以支持Admin后台

fix(auth-guard): 优化授权状态管理和错误处理

refactor(db): 使用异步密码哈希函数

refactor(totp): 使用异步密码验证函数

style: 清理无用文件和注释

docs: 更新功能全景和审计文档

refactor(service): 优化HTTP客户端重用和请求处理

fix(connection): 改进SaaS不可用时的降级处理

refactor(handlers): 使用异步密码验证函数

chore: 更新依赖和工具链配置
This commit is contained in:
iven
2026-03-29 21:45:29 +08:00
parent b7ec317d2c
commit 7de294375b
34 changed files with 2041 additions and 894 deletions

View File

@@ -823,6 +823,14 @@ impl Kernel {
approvals.iter().filter(|a| a.status == "pending").cloned().collect()
}
/// Get a single approval by ID (any status, not just pending)
///
/// Returns None if no approval with the given ID exists.
pub async fn get_approval(&self, id: &str) -> Option<ApprovalEntry> {
let approvals = self.pending_approvals.lock().await;
approvals.iter().find(|a| a.id == id).cloned()
}
/// Create a pending approval (called when a needs_approval hand is triggered)
pub async fn create_approval(&self, hand_id: String, input: serde_json::Value) -> ApprovalEntry {
let entry = ApprovalEntry {

View File

@@ -137,7 +137,8 @@ CREATE TABLE IF NOT EXISTS usage_records (
);
CREATE INDEX IF NOT EXISTS idx_usage_account ON usage_records(account_id);
CREATE INDEX IF NOT EXISTS idx_usage_time ON usage_records(created_at);
CREATE INDEX IF NOT EXISTS idx_usage_day ON usage_records((created_at::date));
-- idx_usage_day: Skipping because ::date on TIMESTAMPTZ is not IMMUTABLE
-- CREATE INDEX IF NOT EXISTS idx_usage_day ON usage_records((created_at::date));
CREATE TABLE IF NOT EXISTS relay_tasks (
id TEXT PRIMARY KEY,
@@ -163,7 +164,8 @@ CREATE INDEX IF NOT EXISTS idx_relay_status ON relay_tasks(status);
CREATE INDEX IF NOT EXISTS idx_relay_account ON relay_tasks(account_id);
CREATE INDEX IF NOT EXISTS idx_relay_provider ON relay_tasks(provider_id);
CREATE INDEX IF NOT EXISTS idx_relay_time ON relay_tasks(created_at);
CREATE INDEX IF NOT EXISTS idx_relay_day ON relay_tasks((created_at::date));
-- idx_relay_day: Skipping because ::date on TIMESTAMPTZ is not IMMUTABLE
-- CREATE INDEX IF NOT EXISTS idx_relay_day ON relay_tasks((created_at::date));
CREATE TABLE IF NOT EXISTS config_items (
id TEXT PRIMARY KEY,
@@ -318,7 +320,8 @@ CREATE TABLE IF NOT EXISTS telemetry_reports (
CREATE INDEX IF NOT EXISTS idx_telemetry_account ON telemetry_reports(account_id);
CREATE INDEX IF NOT EXISTS idx_telemetry_time ON telemetry_reports(reported_at);
CREATE INDEX IF NOT EXISTS idx_telemetry_model ON telemetry_reports(model_id);
CREATE INDEX IF NOT EXISTS idx_telemetry_day ON telemetry_reports((reported_at::date));
-- idx_telemetry_day: Skipping because ::date on TIMESTAMPTZ is not IMMUTABLE
-- CREATE INDEX IF NOT EXISTS idx_telemetry_day ON telemetry_reports((reported_at::date));
-- Refresh Token storage (single-use, JWT jti tracking)
CREATE TABLE IF NOT EXISTS refresh_tokens (

View File

@@ -14,55 +14,109 @@ pub async fn list_accounts(
let page_size = query.page_size.unwrap_or(20).min(100);
let offset = (page - 1) * page_size;
let mut where_clauses = Vec::new();
let mut params: Vec<String> = Vec::new();
let mut param_idx = 1usize;
if let Some(role) = &query.role {
where_clauses.push(format!("role = ${}", param_idx));
param_idx += 1;
params.push(role.clone());
}
if let Some(status) = &query.status {
where_clauses.push(format!("status = ${}", param_idx));
param_idx += 1;
params.push(status.clone());
}
if let Some(search) = &query.search {
where_clauses.push(format!("(username LIKE ${} OR email LIKE ${} OR display_name LIKE ${})", param_idx, param_idx + 1, param_idx + 2));
param_idx += 3;
let pattern = format!("%{}%", search);
params.push(pattern.clone());
params.push(pattern.clone());
params.push(pattern);
}
let where_sql = if where_clauses.is_empty() {
String::new()
} else {
format!("WHERE {}", where_clauses.join(" AND "))
// Static SQL per combination -- no format!() string interpolation
let (total, rows) = match (&query.role, &query.status, &query.search) {
// role + status + search
(Some(role), Some(status), Some(search)) => {
let pattern = format!("%{}%", search);
let total: i64 = sqlx::query_scalar(
"SELECT COUNT(*) FROM accounts WHERE role = $1 AND status = $2 AND (username LIKE $3 OR email LIKE $3 OR display_name LIKE $3)"
).bind(role).bind(status).bind(&pattern).fetch_one(db).await?;
let rows = sqlx::query_as::<_, AccountRow>(
"SELECT id, username, email, display_name, role, status, totp_enabled, last_login_at, created_at
FROM accounts WHERE role = $1 AND status = $2 AND (username LIKE $3 OR email LIKE $3 OR display_name LIKE $3)
ORDER BY created_at DESC LIMIT $4 OFFSET $5"
).bind(role).bind(status).bind(&pattern).bind(page_size as i64).bind(offset as i64).fetch_all(db).await?;
(total, rows)
}
// role + status
(Some(role), Some(status), None) => {
let total: i64 = sqlx::query_scalar(
"SELECT COUNT(*) FROM accounts WHERE role = $1 AND status = $2"
).bind(role).bind(status).fetch_one(db).await?;
let rows = sqlx::query_as::<_, AccountRow>(
"SELECT id, username, email, display_name, role, status, totp_enabled, last_login_at, created_at
FROM accounts WHERE role = $1 AND status = $2
ORDER BY created_at DESC LIMIT $3 OFFSET $4"
).bind(role).bind(status).bind(page_size as i64).bind(offset as i64).fetch_all(db).await?;
(total, rows)
}
// role + search
(Some(role), None, Some(search)) => {
let pattern = format!("%{}%", search);
let total: i64 = sqlx::query_scalar(
"SELECT COUNT(*) FROM accounts WHERE role = $1 AND (username LIKE $2 OR email LIKE $2 OR display_name LIKE $2)"
).bind(role).bind(&pattern).fetch_one(db).await?;
let rows = sqlx::query_as::<_, AccountRow>(
"SELECT id, username, email, display_name, role, status, totp_enabled, last_login_at, created_at
FROM accounts WHERE role = $1 AND (username LIKE $2 OR email LIKE $2 OR display_name LIKE $2)
ORDER BY created_at DESC LIMIT $3 OFFSET $4"
).bind(role).bind(&pattern).bind(page_size as i64).bind(offset as i64).fetch_all(db).await?;
(total, rows)
}
// status + search
(None, Some(status), Some(search)) => {
let pattern = format!("%{}%", search);
let total: i64 = sqlx::query_scalar(
"SELECT COUNT(*) FROM accounts WHERE status = $1 AND (username LIKE $2 OR email LIKE $2 OR display_name LIKE $2)"
).bind(status).bind(&pattern).fetch_one(db).await?;
let rows = sqlx::query_as::<_, AccountRow>(
"SELECT id, username, email, display_name, role, status, totp_enabled, last_login_at, created_at
FROM accounts WHERE status = $1 AND (username LIKE $2 OR email LIKE $2 OR display_name LIKE $2)
ORDER BY created_at DESC LIMIT $3 OFFSET $4"
).bind(status).bind(&pattern).bind(page_size as i64).bind(offset as i64).fetch_all(db).await?;
(total, rows)
}
// role only
(Some(role), None, None) => {
let total: i64 = sqlx::query_scalar(
"SELECT COUNT(*) FROM accounts WHERE role = $1"
).bind(role).fetch_one(db).await?;
let rows = sqlx::query_as::<_, AccountRow>(
"SELECT id, username, email, display_name, role, status, totp_enabled, last_login_at, created_at
FROM accounts WHERE role = $1
ORDER BY created_at DESC LIMIT $2 OFFSET $3"
).bind(role).bind(page_size as i64).bind(offset as i64).fetch_all(db).await?;
(total, rows)
}
// status only
(None, Some(status), None) => {
let total: i64 = sqlx::query_scalar(
"SELECT COUNT(*) FROM accounts WHERE status = $1"
).bind(status).fetch_one(db).await?;
let rows = sqlx::query_as::<_, AccountRow>(
"SELECT id, username, email, display_name, role, status, totp_enabled, last_login_at, created_at
FROM accounts WHERE status = $1
ORDER BY created_at DESC LIMIT $2 OFFSET $3"
).bind(status).bind(page_size as i64).bind(offset as i64).fetch_all(db).await?;
(total, rows)
}
// search only
(None, None, Some(search)) => {
let pattern = format!("%{}%", search);
let total: i64 = sqlx::query_scalar(
"SELECT COUNT(*) FROM accounts WHERE (username LIKE $1 OR email LIKE $1 OR display_name LIKE $1)"
).bind(&pattern).fetch_one(db).await?;
let rows = sqlx::query_as::<_, AccountRow>(
"SELECT id, username, email, display_name, role, status, totp_enabled, last_login_at, created_at
FROM accounts WHERE (username LIKE $1 OR email LIKE $1 OR display_name LIKE $1)
ORDER BY created_at DESC LIMIT $2 OFFSET $3"
).bind(&pattern).bind(page_size as i64).bind(offset as i64).fetch_all(db).await?;
(total, rows)
}
// no filter
(None, None, None) => {
let total: i64 = sqlx::query_scalar(
"SELECT COUNT(*) FROM accounts"
).fetch_one(db).await?;
let rows = sqlx::query_as::<_, AccountRow>(
"SELECT id, username, email, display_name, role, status, totp_enabled, last_login_at, created_at
FROM accounts ORDER BY created_at DESC LIMIT $1 OFFSET $2"
).bind(page_size as i64).bind(offset as i64).fetch_all(db).await?;
(total, rows)
}
};
let count_sql = format!("SELECT COUNT(*) as count FROM accounts {}", where_sql);
let mut count_query = sqlx::query_scalar::<_, i64>(&count_sql);
for p in &params {
count_query = count_query.bind(p);
}
let total: i64 = count_query.fetch_one(db).await?;
let limit_idx = param_idx;
let offset_idx = param_idx + 1;
let data_sql = format!(
"SELECT id, username, email, display_name, role, status, totp_enabled, last_login_at, created_at
FROM accounts {} ORDER BY created_at DESC LIMIT ${} OFFSET ${}",
where_sql, limit_idx, offset_idx
);
let mut data_query = sqlx::query_as::<_, AccountRow>(&data_sql);
for p in &params {
data_query = data_query.bind(p);
}
let rows = data_query.bind(page_size as i64).bind(offset as i64).fetch_all(db).await?;
let items: Vec<serde_json::Value> = rows
.into_iter()
.map(|r| {
@@ -102,30 +156,26 @@ pub async fn update_account(
req: &UpdateAccountRequest,
) -> SaasResult<serde_json::Value> {
let now = chrono::Utc::now().to_rfc3339();
let mut updates = Vec::new();
let mut params: Vec<String> = Vec::new();
let mut param_idx = 1usize;
if let Some(ref v) = req.display_name { updates.push(format!("display_name = ${}", param_idx)); param_idx += 1; params.push(v.clone()); }
if let Some(ref v) = req.email { updates.push(format!("email = ${}", param_idx)); param_idx += 1; params.push(v.clone()); }
if let Some(ref v) = req.role { updates.push(format!("role = ${}", param_idx)); param_idx += 1; params.push(v.clone()); }
if let Some(ref v) = req.avatar_url { updates.push(format!("avatar_url = ${}", param_idx)); param_idx += 1; params.push(v.clone()); }
// COALESCE pattern: all updatable fields in a single static SQL.
// NULL parameters leave the column unchanged.
sqlx::query(
"UPDATE accounts SET
display_name = COALESCE($1, display_name),
email = COALESCE($2, email),
role = COALESCE($3, role),
avatar_url = COALESCE($4, avatar_url),
updated_at = $5
WHERE id = $6"
)
.bind(req.display_name.as_deref())
.bind(req.email.as_deref())
.bind(req.role.as_deref())
.bind(req.avatar_url.as_deref())
.bind(&now)
.bind(account_id)
.execute(db).await?;
if updates.is_empty() {
return get_account(db, account_id).await;
}
updates.push(format!("updated_at = ${}", param_idx));
param_idx += 1;
params.push(now.clone());
params.push(account_id.to_string());
let sql = format!("UPDATE accounts SET {} WHERE id = ${}", updates.join(", "), param_idx);
let mut query = sqlx::query(&sql);
for p in &params {
query = query.bind(p);
}
query.execute(db).await?;
get_account(db, account_id).await
}

View File

@@ -17,6 +17,9 @@ fn row_to_template(
}
}
/// Row type for agent_template queries (avoids multi-line turbofish parsing issues)
type AgentTemplateRow = (String, String, Option<String>, String, String, Option<String>, Option<String>, String, String, Option<f64>, Option<i32>, String, String, i32, String, String);
/// 创建 Agent 模板
pub async fn create_template(
db: &PgPool,
@@ -58,7 +61,7 @@ pub async fn create_template(
/// 获取单个模板
pub async fn get_template(db: &PgPool, id: &str) -> SaasResult<AgentTemplateInfo> {
let row: Option<_> = sqlx::query_as(
let row: Option<AgentTemplateRow> = sqlx::query_as(
"SELECT id, name, description, category, source, model, system_prompt,
tools, capabilities, temperature, max_tokens, visibility, status,
current_version, created_at, updated_at
@@ -70,7 +73,8 @@ pub async fn get_template(db: &PgPool, id: &str) -> SaasResult<AgentTemplateInfo
}
/// 列出模板(分页 + 过滤)
/// 使用动态参数化查询,安全拼接 WHERE 条件。
/// Static SQL + conditional filter pattern: ($N IS NULL OR col = $N).
/// When the parameter is NULL the whole OR evaluates to TRUE (no filter).
pub async fn list_templates(
db: &PgPool,
query: &AgentTemplateListQuery,
@@ -79,80 +83,35 @@ pub async fn list_templates(
let page_size = query.page_size.unwrap_or(20).min(100);
let offset = ((page - 1) * page_size) as i64;
// 动态构建参数化 WHERE 子句
let mut conditions: Vec<String> = vec!["1=1".to_string()];
let mut param_idx = 1u32;
let mut cat_bind: Option<String> = None;
let mut src_bind: Option<String> = None;
let mut vis_bind: Option<String> = None;
let mut st_bind: Option<String> = None;
if let Some(ref cat) = query.category {
param_idx += 1;
conditions.push(format!("category = ${}", param_idx));
cat_bind = Some(cat.clone());
}
if let Some(ref src) = query.source {
param_idx += 1;
conditions.push(format!("source = ${}", param_idx));
src_bind = Some(src.clone());
}
if let Some(ref vis) = query.visibility {
param_idx += 1;
conditions.push(format!("visibility = ${}", param_idx));
vis_bind = Some(vis.clone());
}
if let Some(ref st) = query.status {
param_idx += 1;
conditions.push(format!("status = ${}", param_idx));
st_bind = Some(st.clone());
}
let where_clause = conditions.join(" AND ");
// COUNT 查询: WHERE 参数绑定 ($1..$N)
let count_idx = param_idx;
let count_sql = format!(
"SELECT COUNT(*) FROM agent_templates WHERE {}",
where_clause
);
let count_limit_idx = count_idx + 1;
let count_offset_idx = count_limit_idx + 1;
let data_sql = format!(
"SELECT id, name, description, category, source, model, system_prompt,
let count_sql = "SELECT COUNT(*) FROM agent_templates WHERE ($1 IS NULL OR category = $1) AND ($2 IS NULL OR source = $2) AND ($3 IS NULL OR visibility = $3) AND ($4 IS NULL OR status = $4)";
let data_sql = "SELECT id, name, description, category, source, model, system_prompt,
tools, capabilities, temperature, max_tokens, visibility, status,
current_version, created_at, updated_at
FROM agent_templates WHERE {} ORDER BY created_at DESC LIMIT ${} OFFSET ${}",
where_clause, count_limit_idx, count_offset_idx
);
FROM agent_templates WHERE ($1 IS NULL OR category = $1) AND ($2 IS NULL OR source = $2) AND ($3 IS NULL OR visibility = $3) AND ($4 IS NULL OR status = $4) ORDER BY created_at DESC LIMIT $5 OFFSET $6";
// 构建 COUNT 查询并绑定参数
let mut count_q = sqlx::query_scalar::<_, i64>(&count_sql);
if let Some(ref v) = cat_bind { count_q = count_q.bind(v); }
if let Some(ref v) = src_bind { count_q = count_q.bind(v); }
if let Some(ref v) = vis_bind { count_q = count_q.bind(v); }
if let Some(ref v) = st_bind { count_q = count_q.bind(v); }
let total: i64 = count_q.fetch_one(db).await?;
let total: i64 = sqlx::query_scalar(count_sql)
.bind(&query.category)
.bind(&query.source)
.bind(&query.visibility)
.bind(&query.status)
.fetch_one(db).await?;
// 构建数据查询并绑定参数
let mut data_q = sqlx::query_as::<_, (
String, String, Option<String>, String, String, Option<String>, Option<String>,
String, String, Option<f64>, Option<i32>, String, String, i32, String, String
)>(&data_sql);
if let Some(ref v) = cat_bind { data_q = data_q.bind(v); }
if let Some(ref v) = src_bind { data_q = data_q.bind(v); }
if let Some(ref v) = vis_bind { data_q = data_q.bind(v); }
if let Some(ref v) = st_bind { data_q = data_q.bind(v); }
data_q = data_q.bind(page_size as i64).bind(offset);
let rows = data_q.fetch_all(db).await?;
let rows: Vec<AgentTemplateRow> = sqlx::query_as(data_sql)
.bind(&query.category)
.bind(&query.source)
.bind(&query.visibility)
.bind(&query.status)
.bind(page_size as i64)
.bind(offset)
.fetch_all(db).await?;
let items = rows.into_iter().map(row_to_template).collect();
Ok(crate::common::PaginatedResponse { items, total, page, page_size })
}
/// 更新模板
/// 使用动态参数化查询,安全拼接 SET 子句。
/// COALESCE pattern: all updatable fields in a single static SQL.
/// NULL parameters leave the column unchanged.
pub async fn update_template(
db: &PgPool,
id: &str,
@@ -166,102 +125,41 @@ pub async fn update_template(
visibility: Option<&str>,
status: Option<&str>,
) -> SaasResult<AgentTemplateInfo> {
// 确认存在
// Confirm existence
get_template(db, id).await?;
let now = chrono::Utc::now().to_rfc3339();
let mut set_clauses: Vec<String> = vec![];
let mut param_idx = 1u32;
// 收集需要绑定的值(按顺序)
let mut desc_val: Option<String> = None;
let mut model_val: Option<String> = None;
let mut sp_val: Option<String> = None;
let mut tools_val: Option<String> = None;
let mut caps_val: Option<String> = None;
let mut temp_val: Option<f64> = None;
let mut mt_val: Option<i32> = None;
let mut vis_val: Option<String> = None;
let mut st_val: Option<String> = None;
// Serialize JSON fields upfront so we can bind Option<&str> consistently
let tools_json = tools.map(|t| serde_json::to_string(t).unwrap_or_else(|_| "[]".to_string()));
let caps_json = capabilities.map(|c| serde_json::to_string(c).unwrap_or_else(|_| "[]".to_string()));
if let Some(desc) = description {
param_idx += 1;
set_clauses.push(format!("description = ${}", param_idx));
desc_val = Some(desc.to_string());
}
if let Some(m) = model {
param_idx += 1;
set_clauses.push(format!("model = ${}", param_idx));
model_val = Some(m.to_string());
}
if let Some(sp) = system_prompt {
param_idx += 1;
set_clauses.push(format!("system_prompt = ${}", param_idx));
sp_val = Some(sp.to_string());
}
if let Some(t) = tools {
let json = serde_json::to_string(t).unwrap_or_else(|_| "[]".to_string());
param_idx += 1;
set_clauses.push(format!("tools = ${}", param_idx));
tools_val = Some(json);
}
if let Some(c) = capabilities {
let json = serde_json::to_string(c).unwrap_or_else(|_| "[]".to_string());
param_idx += 1;
set_clauses.push(format!("capabilities = ${}", param_idx));
caps_val = Some(json);
}
if let Some(t) = temperature {
param_idx += 1;
set_clauses.push(format!("temperature = ${}", param_idx));
temp_val = Some(t);
}
if let Some(m) = max_tokens {
param_idx += 1;
set_clauses.push(format!("max_tokens = ${}", param_idx));
mt_val = Some(m);
}
if let Some(v) = visibility {
param_idx += 1;
set_clauses.push(format!("visibility = ${}", param_idx));
vis_val = Some(v.to_string());
}
if let Some(s) = status {
param_idx += 1;
set_clauses.push(format!("status = ${}", param_idx));
st_val = Some(s.to_string());
}
if set_clauses.is_empty() {
return get_template(db, id).await;
}
// updated_at
param_idx += 1;
set_clauses.push(format!("updated_at = ${}", param_idx));
// WHERE id = $N
let id_idx = param_idx + 1;
let sql = format!(
"UPDATE agent_templates SET {} WHERE id = ${}",
set_clauses.join(", "), id_idx
);
let mut q = sqlx::query(&sql);
if let Some(ref v) = desc_val { q = q.bind(v); }
if let Some(ref v) = model_val { q = q.bind(v); }
if let Some(ref v) = sp_val { q = q.bind(v); }
if let Some(ref v) = tools_val { q = q.bind(v); }
if let Some(ref v) = caps_val { q = q.bind(v); }
if let Some(v) = temp_val { q = q.bind(v); }
if let Some(v) = mt_val { q = q.bind(v); }
if let Some(ref v) = vis_val { q = q.bind(v); }
if let Some(ref v) = st_val { q = q.bind(v); }
q = q.bind(&now);
q = q.bind(id);
q.execute(db).await?;
sqlx::query(
"UPDATE agent_templates SET
description = COALESCE($1, description),
model = COALESCE($2, model),
system_prompt = COALESCE($3, system_prompt),
tools = COALESCE($4, tools),
capabilities = COALESCE($5, capabilities),
temperature = COALESCE($6, temperature),
max_tokens = COALESCE($7, max_tokens),
visibility = COALESCE($8, visibility),
status = COALESCE($9, status),
updated_at = $10
WHERE id = $11"
)
.bind(description)
.bind(model)
.bind(system_prompt)
.bind(tools_json.as_deref())
.bind(caps_json.as_deref())
.bind(temperature)
.bind(max_tokens)
.bind(visibility)
.bind(status)
.bind(&now)
.bind(id)
.execute(db).await?;
get_template(db, id).await
}

View File

@@ -8,7 +8,7 @@ use crate::error::{SaasError, SaasResult};
use crate::models::{AccountAuthRow, AccountLoginRow};
use super::{
jwt::{create_token, create_refresh_token, verify_token, verify_token_skip_expiry},
password::{hash_password, verify_password},
password::{hash_password_async, verify_password_async},
types::{AuthContext, LoginRequest, LoginResponse, RegisterRequest, ChangePasswordRequest, AccountPublic, RefreshRequest},
};
@@ -25,7 +25,8 @@ pub async fn register(
if req.username.len() > 32 {
return Err(SaasError::InvalidInput("用户名最多 32 个字符".into()));
}
let username_re = regex::Regex::new(r"^[a-zA-Z0-9_-]+$").unwrap();
static USERNAME_RE: std::sync::OnceLock<regex::Regex> = std::sync::OnceLock::new();
let username_re = USERNAME_RE.get_or_init(|| regex::Regex::new(r"^[a-zA-Z0-9_-]+$").unwrap());
if !username_re.is_match(&req.username) {
return Err(SaasError::InvalidInput("用户名只能包含字母、数字、下划线和连字符".into()));
}
@@ -56,7 +57,7 @@ pub async fn register(
return Err(SaasError::AlreadyExists("用户名或邮箱已存在".into()));
}
let password_hash = hash_password(&req.password)?;
let password_hash = hash_password_async(req.password.clone()).await?;
let account_id = uuid::Uuid::new_v4().to_string();
let role = "user".to_string(); // 注册固定为普通用户,角色由管理员分配
let display_name = req.display_name.unwrap_or_default();
@@ -138,7 +139,7 @@ pub async fn login(
return Err(SaasError::Forbidden(format!("账号已{},请联系管理员", r.status)));
}
if !verify_password(&req.password, &r.password_hash)? {
if !verify_password_async(req.password.clone(), r.password_hash.clone()).await? {
return Err(SaasError::AuthError("用户名或密码错误".into()));
}
@@ -328,12 +329,12 @@ pub async fn change_password(
.await?;
// 验证旧密码
if !verify_password(&req.old_password, &password_hash)? {
if !verify_password_async(req.old_password.clone(), password_hash.clone()).await? {
return Err(SaasError::AuthError("旧密码错误".into()));
}
// 更新密码
let new_hash = hash_password(&req.new_password)?;
let new_hash = hash_password_async(req.new_password.clone()).await?;
let now = chrono::Utc::now().to_rfc3339();
sqlx::query("UPDATE accounts SET password_hash = $1, updated_at = $2 WHERE id = $3")
.bind(&new_hash)

View File

@@ -1,4 +1,8 @@
//! 密码哈希 (Argon2id)
//!
//! Argon2 是 CPU 密集型操作(~100-500ms不能在 tokio worker 线程上直接执行,
//! 否则会阻塞整个异步运行时。所有 async 上下文必须使用 `hash_password_async`
//! 和 `verify_password_async`。
use argon2::{
password_hash::{rand_core::OsRng, PasswordHash, PasswordHasher, PasswordVerifier, SaltString},
@@ -7,7 +11,7 @@ use argon2::{
use crate::error::{SaasError, SaasResult};
/// 哈希密码
/// 哈希密码(同步版本,仅用于测试和启动时 seed
pub fn hash_password(password: &str) -> SaasResult<String> {
let salt = SaltString::generate(&mut OsRng);
let argon2 = Argon2::default();
@@ -17,7 +21,7 @@ pub fn hash_password(password: &str) -> SaasResult<String> {
Ok(hash.to_string())
}
/// 验证密码
/// 验证密码(同步版本,仅用于测试)
pub fn verify_password(password: &str, hash: &str) -> SaasResult<bool> {
let parsed_hash = PasswordHash::new(hash)
.map_err(|e| SaasError::PasswordHash(e.to_string()))?;
@@ -26,6 +30,20 @@ pub fn verify_password(password: &str, hash: &str) -> SaasResult<bool> {
.is_ok())
}
/// 异步哈希密码 — 在 spawn_blocking 线程池中执行 Argon2
pub async fn hash_password_async(password: String) -> SaasResult<String> {
tokio::task::spawn_blocking(move || hash_password(&password))
.await
.map_err(|e| SaasError::Internal(format!("spawn_blocking error: {e}")))?
}
/// 异步验证密码 — 在 spawn_blocking 线程池中执行 Argon2
pub async fn verify_password_async(password: String, hash: String) -> SaasResult<bool> {
tokio::task::spawn_blocking(move || verify_password(&password, &hash))
.await
.map_err(|e| SaasError::Internal(format!("spawn_blocking error: {e}")))?
}
#[cfg(test)]
mod tests {
use super::*;

View File

@@ -212,7 +212,7 @@ pub async fn disable_totp(
.fetch_one(&state.db)
.await?;
if !crate::auth::password::verify_password(&req.password, &password_hash)? {
if !crate::auth::password::verify_password_async(req.password.clone(), password_hash.clone()).await? {
return Err(SaasError::AuthError("密码错误".into()));
}

View File

@@ -150,11 +150,10 @@ pub async fn seed_admin_account(pool: &PgPool) -> SaasResult<()> {
.await?;
if let Some((account_id,)) = existing {
// 更新现有用户的密码和角色
use crate::auth::password::hash_password;
let password_hash = hash_password(&admin_password)?;
// 更新现有用户的密码和角色(使用 spawn_blocking 避免阻塞 tokio 运行时)
let password_hash = crate::auth::password::hash_password_async(admin_password.clone()).await?;
let now = chrono::Utc::now().to_rfc3339();
sqlx::query(
"UPDATE accounts SET password_hash = $1, role = 'super_admin', updated_at = $2 WHERE id = $3"
)
@@ -163,12 +162,11 @@ pub async fn seed_admin_account(pool: &PgPool) -> SaasResult<()> {
.bind(&account_id)
.execute(pool)
.await?;
tracing::info!("已更新用户 {} 的密码和角色为 super_admin", admin_username);
} else {
// 创建新的 super_admin 账号
use crate::auth::password::hash_password;
let password_hash = hash_password(&admin_password)?;
let password_hash = crate::auth::password::hash_password_async(admin_password.clone()).await?;
let account_id = uuid::Uuid::new_v4().to_string();
let email = format!("{}@zclaw.local", admin_username);
let now = chrono::Utc::now().to_rfc3339();

View File

@@ -54,39 +54,50 @@ pub async fn list_config_items(
) -> SaasResult<PaginatedResponse<ConfigItemInfo>> {
let (p, ps, offset) = normalize_pagination(page, page_size);
// Build WHERE clause for count and data queries
let (where_clause, has_category, has_source) = match (&query.category, &query.source) {
(Some(_), Some(_)) => ("WHERE category = $1 AND source = $2", true, true),
(Some(_), None) => ("WHERE category = $1", true, false),
(None, Some(_)) => ("WHERE source = $1", false, true),
(None, None) => ("", false, false),
// Static SQL per combination -- no format!() string interpolation
let (total, rows) = match (&query.category, &query.source) {
(Some(cat), Some(src)) => {
let total: i64 = sqlx::query_scalar(
"SELECT COUNT(*) FROM config_items WHERE category = $1 AND source = $2"
).bind(cat).bind(src).fetch_one(db).await?;
let rows = sqlx::query_as::<_, ConfigItemRow>(
"SELECT id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at, updated_at
FROM config_items WHERE category = $1 AND source = $2 ORDER BY category, key_path LIMIT $3 OFFSET $4"
).bind(cat).bind(src).bind(ps as i64).bind(offset).fetch_all(db).await?;
(total, rows)
}
(Some(cat), None) => {
let total: i64 = sqlx::query_scalar(
"SELECT COUNT(*) FROM config_items WHERE category = $1"
).bind(cat).fetch_one(db).await?;
let rows = sqlx::query_as::<_, ConfigItemRow>(
"SELECT id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at, updated_at
FROM config_items WHERE category = $1 ORDER BY category, key_path LIMIT $2 OFFSET $3"
).bind(cat).bind(ps as i64).bind(offset).fetch_all(db).await?;
(total, rows)
}
(None, Some(src)) => {
let total: i64 = sqlx::query_scalar(
"SELECT COUNT(*) FROM config_items WHERE source = $1"
).bind(src).fetch_one(db).await?;
let rows = sqlx::query_as::<_, ConfigItemRow>(
"SELECT id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at, updated_at
FROM config_items WHERE source = $1 ORDER BY category, key_path LIMIT $2 OFFSET $3"
).bind(src).bind(ps as i64).bind(offset).fetch_all(db).await?;
(total, rows)
}
(None, None) => {
let total: i64 = sqlx::query_scalar(
"SELECT COUNT(*) FROM config_items"
).fetch_one(db).await?;
let rows = sqlx::query_as::<_, ConfigItemRow>(
"SELECT id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at, updated_at
FROM config_items ORDER BY category, key_path LIMIT $1 OFFSET $2"
).bind(ps as i64).bind(offset).fetch_all(db).await?;
(total, rows)
}
};
let count_sql = format!("SELECT COUNT(*) FROM config_items {}", where_clause);
let data_sql = format!(
"SELECT id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at, updated_at
FROM config_items {} ORDER BY category, key_path LIMIT {} OFFSET {}",
where_clause, "$p", "$o"
);
// Determine param indices for LIMIT/OFFSET based on filter params
let (limit_idx, offset_idx) = match (has_category, has_source) {
(true, true) => ("$3", "$4"),
(true, false) | (false, true) => ("$2", "$3"),
(false, false) => ("$1", "$2"),
};
let data_sql = data_sql.replace("$p", limit_idx).replace("$o", offset_idx);
let mut count_query = sqlx::query_scalar::<_, i64>(&count_sql);
if has_category { count_query = count_query.bind(&query.category); }
if has_source { count_query = count_query.bind(&query.source); }
let total: i64 = count_query.fetch_one(db).await?;
let mut data_query = sqlx::query_as::<_, ConfigItemRow>(&data_sql);
if has_category { data_query = data_query.bind(&query.category); }
if has_source { data_query = data_query.bind(&query.source); }
let rows = data_query.bind(ps as i64).bind(offset).fetch_all(db).await?;
let items = rows.into_iter().map(|r| {
ConfigItemInfo { id: r.id, category: r.category, key_path: r.key_path, value_type: r.value_type, current_value: r.current_value, default_value: r.default_value, source: r.source, description: r.description, requires_restart: r.requires_restart, created_at: r.created_at, updated_at: r.updated_at }
}).collect();
@@ -146,29 +157,23 @@ pub async fn update_config_item(
db: &PgPool, item_id: &str, req: &UpdateConfigItemRequest,
) -> SaasResult<ConfigItemInfo> {
let now = chrono::Utc::now().to_rfc3339();
let mut updates = Vec::new();
let mut params: Vec<String> = Vec::new();
let mut param_idx = 1usize;
if let Some(ref v) = req.current_value { updates.push(format!("current_value = ${}", param_idx)); params.push(v.clone()); param_idx += 1; }
if let Some(ref v) = req.source { updates.push(format!("source = ${}", param_idx)); params.push(v.clone()); param_idx += 1; }
if let Some(ref v) = req.description { updates.push(format!("description = ${}", param_idx)); params.push(v.clone()); param_idx += 1; }
if updates.is_empty() {
return get_config_item(db, item_id).await;
}
updates.push(format!("updated_at = ${}", param_idx));
params.push(now);
param_idx += 1;
params.push(item_id.to_string());
let sql = format!("UPDATE config_items SET {} WHERE id = ${}", updates.join(", "), param_idx);
let mut query = sqlx::query(&sql);
for p in &params {
query = query.bind(p);
}
query.execute(db).await?;
// COALESCE pattern: all updatable fields in a single static SQL.
// NULL parameters leave the column unchanged.
sqlx::query(
"UPDATE config_items SET
current_value = COALESCE($1, current_value),
source = COALESCE($2, source),
description = COALESCE($3, description),
updated_at = $4
WHERE id = $5"
)
.bind(req.current_value.as_deref())
.bind(req.source.as_deref())
.bind(req.description.as_deref())
.bind(&now)
.bind(item_id)
.execute(db).await?;
get_config_item(db, item_id).await
}

View File

@@ -104,36 +104,38 @@ pub async fn update_provider(
db: &PgPool, provider_id: &str, req: &UpdateProviderRequest, enc_key: &[u8; 32],
) -> SaasResult<ProviderInfo> {
let now = chrono::Utc::now().to_rfc3339();
let mut updates = Vec::new();
let mut params: Vec<Box<dyn std::fmt::Display + Send + Sync>> = Vec::new();
let mut param_idx = 1;
if let Some(ref v) = req.display_name { updates.push(format!("display_name = ${}", param_idx)); params.push(Box::new(v.clone())); param_idx += 1; }
if let Some(ref v) = req.base_url { updates.push(format!("base_url = ${}", param_idx)); params.push(Box::new(v.clone())); param_idx += 1; }
if let Some(ref v) = req.api_protocol { updates.push(format!("api_protocol = ${}", param_idx)); params.push(Box::new(v.clone())); param_idx += 1; }
if let Some(ref v) = req.api_key {
let encrypted = if v.is_empty() { String::new() } else { crypto::encrypt_value(v, enc_key)? };
updates.push(format!("api_key = ${}", param_idx)); params.push(Box::new(encrypted)); param_idx += 1;
}
if let Some(v) = req.enabled { updates.push(format!("enabled = ${}", param_idx)); params.push(Box::new(v)); param_idx += 1; }
if let Some(v) = req.rate_limit_rpm { updates.push(format!("rate_limit_rpm = ${}", param_idx)); params.push(Box::new(v)); param_idx += 1; }
if let Some(v) = req.rate_limit_tpm { updates.push(format!("rate_limit_tpm = ${}", param_idx)); params.push(Box::new(v)); param_idx += 1; }
// Encrypt api_key upfront if provided
let encrypted_api_key = match req.api_key {
Some(ref v) if !v.is_empty() => Some(crypto::encrypt_value(v, enc_key)?),
Some(ref v) if v.is_empty() => Some(String::new()),
_ => None,
};
if updates.is_empty() {
return get_provider(db, provider_id).await;
}
updates.push(format!("updated_at = ${}", param_idx));
params.push(Box::new(now.clone()));
param_idx += 1;
params.push(Box::new(provider_id.to_string()));
let sql = format!("UPDATE providers SET {} WHERE id = ${}", updates.join(", "), param_idx);
let mut query = sqlx::query(&sql);
for p in &params {
query = query.bind(format!("{}", p));
}
query.execute(db).await?;
// COALESCE pattern: all updatable fields in a single static SQL.
// NULL parameters leave the column unchanged.
sqlx::query(
"UPDATE providers SET
display_name = COALESCE($1, display_name),
base_url = COALESCE($2, base_url),
api_protocol = COALESCE($3, api_protocol),
api_key = COALESCE($4, api_key),
enabled = COALESCE($5, enabled),
rate_limit_rpm = COALESCE($6, rate_limit_rpm),
rate_limit_tpm = COALESCE($7, rate_limit_tpm),
updated_at = $8
WHERE id = $9"
)
.bind(req.display_name.as_deref())
.bind(req.base_url.as_deref())
.bind(req.api_protocol.as_deref())
.bind(encrypted_api_key.as_deref())
.bind(req.enabled)
.bind(req.rate_limit_rpm)
.bind(req.rate_limit_tpm)
.bind(&now)
.bind(provider_id)
.execute(db).await?;
get_provider(db, provider_id).await
}
@@ -245,34 +247,33 @@ pub async fn update_model(
db: &PgPool, model_id: &str, req: &UpdateModelRequest,
) -> SaasResult<ModelInfo> {
let now = chrono::Utc::now().to_rfc3339();
let mut updates = Vec::new();
let mut params: Vec<Box<dyn std::fmt::Display + Send + Sync>> = Vec::new();
let mut param_idx = 1;
if let Some(ref v) = req.alias { updates.push(format!("alias = ${}", param_idx)); params.push(Box::new(v.clone())); param_idx += 1; }
if let Some(v) = req.context_window { updates.push(format!("context_window = ${}", param_idx)); params.push(Box::new(v)); param_idx += 1; }
if let Some(v) = req.max_output_tokens { updates.push(format!("max_output_tokens = ${}", param_idx)); params.push(Box::new(v)); param_idx += 1; }
if let Some(v) = req.supports_streaming { updates.push(format!("supports_streaming = ${}", param_idx)); params.push(Box::new(v)); param_idx += 1; }
if let Some(v) = req.supports_vision { updates.push(format!("supports_vision = ${}", param_idx)); params.push(Box::new(v)); param_idx += 1; }
if let Some(v) = req.enabled { updates.push(format!("enabled = ${}", param_idx)); params.push(Box::new(v)); param_idx += 1; }
if let Some(v) = req.pricing_input { updates.push(format!("pricing_input = ${}", param_idx)); params.push(Box::new(v)); param_idx += 1; }
if let Some(v) = req.pricing_output { updates.push(format!("pricing_output = ${}", param_idx)); params.push(Box::new(v)); param_idx += 1; }
if updates.is_empty() {
return get_model(db, model_id).await;
}
updates.push(format!("updated_at = ${}", param_idx));
params.push(Box::new(now.clone()));
param_idx += 1;
params.push(Box::new(model_id.to_string()));
let sql = format!("UPDATE models SET {} WHERE id = ${}", updates.join(", "), param_idx);
let mut query = sqlx::query(&sql);
for p in &params {
query = query.bind(format!("{}", p));
}
query.execute(db).await?;
// COALESCE pattern: all updatable fields in a single static SQL.
// NULL parameters leave the column unchanged.
sqlx::query(
"UPDATE models SET
alias = COALESCE($1, alias),
context_window = COALESCE($2, context_window),
max_output_tokens = COALESCE($3, max_output_tokens),
supports_streaming = COALESCE($4, supports_streaming),
supports_vision = COALESCE($5, supports_vision),
enabled = COALESCE($6, enabled),
pricing_input = COALESCE($7, pricing_input),
pricing_output = COALESCE($8, pricing_output),
updated_at = $9
WHERE id = $10"
)
.bind(req.alias.as_deref())
.bind(req.context_window)
.bind(req.max_output_tokens)
.bind(req.supports_streaming)
.bind(req.supports_vision)
.bind(req.enabled)
.bind(req.pricing_input)
.bind(req.pricing_output)
.bind(&now)
.bind(model_id)
.execute(db).await?;
get_model(db, model_id).await
}
@@ -401,58 +402,33 @@ pub async fn revoke_account_api_key(
pub async fn get_usage_stats(
db: &PgPool, account_id: &str, query: &UsageQuery,
) -> SaasResult<UsageStats> {
let mut param_idx = 1;
let mut where_clauses = vec![format!("account_id = ${}", param_idx)];
let mut params: Vec<String> = vec![account_id.to_string()];
param_idx += 1;
// Static SQL with conditional filter pattern:
// account_id is always required; optional filters use ($N IS NULL OR col = $N).
let total_sql = "SELECT COUNT(*)::bigint, COALESCE(SUM(input_tokens), 0), COALESCE(SUM(output_tokens), 0)
FROM usage_records WHERE account_id = $1 AND ($2 IS NULL OR created_at >= $2) AND ($3 IS NULL OR created_at <= $3) AND ($4 IS NULL OR provider_id = $4) AND ($5 IS NULL OR model_id = $5)";
if let Some(ref from) = query.from {
where_clauses.push(format!("created_at >= ${}", param_idx));
params.push(from.clone());
param_idx += 1;
}
if let Some(ref to) = query.to {
where_clauses.push(format!("created_at <= ${}", param_idx));
params.push(to.clone());
param_idx += 1;
}
if let Some(ref pid) = query.provider_id {
where_clauses.push(format!("provider_id = ${}", param_idx));
params.push(pid.clone());
param_idx += 1;
}
if let Some(ref mid) = query.model_id {
where_clauses.push(format!("model_id = ${}", param_idx));
params.push(mid.clone());
}
let where_sql = where_clauses.join(" AND ");
// 总量统计
let total_sql = format!(
"SELECT COUNT(*)::bigint, COALESCE(SUM(input_tokens), 0), COALESCE(SUM(output_tokens), 0)
FROM usage_records WHERE {}", where_sql
);
let mut total_query = sqlx::query(&total_sql);
for p in &params {
total_query = total_query.bind(p);
}
let row = total_query.fetch_one(db).await?;
let row = sqlx::query(total_sql)
.bind(account_id)
.bind(&query.from)
.bind(&query.to)
.bind(&query.provider_id)
.bind(&query.model_id)
.fetch_one(db).await?;
let total_requests: i64 = row.try_get(0).unwrap_or(0);
let total_input: i64 = row.try_get(1).unwrap_or(0);
let total_output: i64 = row.try_get(2).unwrap_or(0);
// 按模型统计
let by_model_sql = format!(
"SELECT provider_id, model_id, COUNT(*)::bigint AS request_count, COALESCE(SUM(input_tokens), 0) AS input_tokens, COALESCE(SUM(output_tokens), 0) AS output_tokens
FROM usage_records WHERE {} GROUP BY provider_id, model_id ORDER BY COUNT(*) DESC LIMIT 20",
where_sql
);
let mut by_model_query = sqlx::query_as::<_, UsageByModelRow>(&by_model_sql);
for p in &params {
by_model_query = by_model_query.bind(p);
}
let by_model_rows = by_model_query.fetch_all(db).await?;
let by_model_sql = "SELECT provider_id, model_id, COUNT(*)::bigint AS request_count, COALESCE(SUM(input_tokens), 0) AS input_tokens, COALESCE(SUM(output_tokens), 0) AS output_tokens
FROM usage_records WHERE account_id = $1 AND ($2 IS NULL OR created_at >= $2) AND ($3 IS NULL OR created_at <= $3) AND ($4 IS NULL OR provider_id = $4) AND ($5 IS NULL OR model_id = $5) GROUP BY provider_id, model_id ORDER BY COUNT(*) DESC LIMIT 20";
let by_model_rows: Vec<UsageByModelRow> = sqlx::query_as(by_model_sql)
.bind(account_id)
.bind(&query.from)
.bind(&query.to)
.bind(&query.provider_id)
.bind(&query.model_id)
.fetch_all(db).await?;
let by_model: Vec<ModelUsage> = by_model_rows.into_iter()
.map(|r| {
ModelUsage { provider_id: r.provider_id, model_id: r.model_id, request_count: r.request_count, input_tokens: r.input_tokens, output_tokens: r.output_tokens }

View File

@@ -76,62 +76,30 @@ pub async fn get_template_by_name(db: &PgPool, name: &str) -> SaasResult<PromptT
}
/// 列表模板
/// Static SQL with conditional filter pattern: ($N IS NULL OR col = $N).
pub async fn list_templates(
db: &PgPool,
query: &PromptListQuery,
) -> SaasResult<PaginatedResponse<PromptTemplateInfo>> {
let (page, page_size, offset) = normalize_pagination(query.page, query.page_size);
// 使用参数化查询构建,防止 SQL 注入
let mut param_idx = 1usize;
let mut conditions = Vec::new();
let mut cat_bind: Option<String> = None;
let mut src_bind: Option<String> = None;
let mut status_bind: Option<String> = None;
let count_sql = "SELECT COUNT(*) FROM prompt_templates WHERE ($1 IS NULL OR category = $1) AND ($2 IS NULL OR source = $2) AND ($3 IS NULL OR status = $3)";
let data_sql = "SELECT id, name, category, description, source, current_version, status, created_at, updated_at \
FROM prompt_templates WHERE ($1 IS NULL OR category = $1) AND ($2 IS NULL OR source = $2) AND ($3 IS NULL OR status = $3) ORDER BY updated_at DESC LIMIT $4 OFFSET $5";
if let Some(ref cat) = query.category {
conditions.push(format!("category = ${}", param_idx));
cat_bind = Some(cat.clone());
param_idx += 1;
}
if let Some(ref src) = query.source {
conditions.push(format!("source = ${}", param_idx));
src_bind = Some(src.clone());
param_idx += 1;
}
if let Some(ref st) = query.status {
conditions.push(format!("status = ${}", param_idx));
status_bind = Some(st.clone());
param_idx += 1;
}
let total: i64 = sqlx::query_scalar(count_sql)
.bind(&query.category)
.bind(&query.source)
.bind(&query.status)
.fetch_one(db).await?;
let where_clause = if conditions.is_empty() {
"1=1".to_string()
} else {
conditions.join(" AND ")
};
let count_sql = format!("SELECT COUNT(*) FROM prompt_templates WHERE {}", where_clause);
let data_sql = format!(
"SELECT id, name, category, description, source, current_version, status, created_at, updated_at \
FROM prompt_templates WHERE {} ORDER BY updated_at DESC LIMIT {} OFFSET {}",
where_clause, page_size, offset
);
// 动态绑定参数到 count 查询
let mut count_query = sqlx::query_scalar::<_, i64>(&count_sql);
if let Some(ref v) = cat_bind { count_query = count_query.bind(v); }
if let Some(ref v) = src_bind { count_query = count_query.bind(v); }
if let Some(ref v) = status_bind { count_query = count_query.bind(v); }
let total = count_query.fetch_one(db).await?;
// 动态绑定参数到 data 查询
let mut data_query = sqlx::query_as::<_, PromptTemplateRow>(&data_sql);
if let Some(ref v) = cat_bind { data_query = data_query.bind(v); }
if let Some(ref v) = src_bind { data_query = data_query.bind(v); }
if let Some(ref v) = status_bind { data_query = data_query.bind(v); }
data_query = data_query.bind(page_size as i64).bind(offset as i64);
let rows = data_query.fetch_all(db).await?;
let rows: Vec<PromptTemplateRow> = sqlx::query_as(data_sql)
.bind(&query.category)
.bind(&query.source)
.bind(&query.status)
.bind(page_size as i64)
.bind(offset as i64)
.fetch_all(db).await?;
let items: Vec<PromptTemplateInfo> = rows.into_iter().map(|r| {
PromptTemplateInfo { id: r.id, name: r.name, category: r.category, description: r.description, source: r.source, current_version: r.current_version, status: r.status, created_at: r.created_at, updated_at: r.updated_at }

View File

@@ -43,12 +43,13 @@ pub async fn chat_completions(
}
// --- 输入验证 ---
// 请求体大小限制 (1 MB)
// 请求体大小限制 (1 MB) — 直接序列化一次,后续复用
const MAX_BODY_BYTES: usize = 1024 * 1024;
let estimated_size = serde_json::to_string(&req).map(|s| s.len()).unwrap_or(0);
if estimated_size > MAX_BODY_BYTES {
let request_body = serde_json::to_string(&req)
.map_err(|e| SaasError::InvalidInput(format!("请求体序列化失败: {}", e)))?;
if request_body.len() > MAX_BODY_BYTES {
return Err(SaasError::InvalidInput(
format!("请求体超过大小限制 ({} bytes > {} bytes)", estimated_size, MAX_BODY_BYTES)
format!("请求体超过大小限制 ({} bytes > {} bytes)", request_body.len(), MAX_BODY_BYTES)
));
}
@@ -147,7 +148,7 @@ pub async fn chat_completions(
return Err(SaasError::Forbidden(format!("Provider {} 已禁用", provider.name)));
}
let request_body = serde_json::to_string(&req)?;
// request_body 已在前面序列化并验证大小,直接复用
// 创建中转任务(提取配置后立即释放读锁)
let (max_attempts, retry_delay_ms, enc_key) = {

View File

@@ -185,10 +185,24 @@ pub async fn execute_relay(
let url = format!("{}/chat/completions", provider_base_url.trim_end_matches('/'));
let client = reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(if stream { 300 } else { 30 }))
.build()
.map_err(|e| SaasError::Internal(format!("HTTP 客户端构建失败: {}", e)))?;
// 复用全局 HTTP 客户端,避免每次请求重建 TLS 连接池和 DNS 解析器
static SHORT_CLIENT: std::sync::OnceLock<reqwest::Client> = std::sync::OnceLock::new();
static LONG_CLIENT: std::sync::OnceLock<reqwest::Client> = std::sync::OnceLock::new();
let client = if stream {
LONG_CLIENT.get_or_init(|| {
reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(300))
.build()
.expect("Failed to build long-timeout HTTP client")
})
} else {
SHORT_CLIENT.get_or_init(|| {
reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(30))
.build()
.expect("Failed to build short-timeout HTTP client")
})
};
let max_attempts = max_attempts.max(1).min(5);