feat(auth): 添加异步密码哈希和验证函数
Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled
Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled
refactor(relay): 复用HTTP客户端和请求体序列化结果 feat(kernel): 添加获取单个审批记录的方法 fix(store): 改进SaaS连接错误分类和降级处理 docs: 更新审计文档和系统架构文档 refactor(prompt): 优化SQL查询参数化绑定 refactor(migration): 使用静态SQL和COALESCE更新配置项 feat(commands): 添加审批执行状态追踪和事件通知 chore: 更新启动脚本以支持Admin后台 fix(auth-guard): 优化授权状态管理和错误处理 refactor(db): 使用异步密码哈希函数 refactor(totp): 使用异步密码验证函数 style: 清理无用文件和注释 docs: 更新功能全景和审计文档 refactor(service): 优化HTTP客户端重用和请求处理 fix(connection): 改进SaaS不可用时的降级处理 refactor(handlers): 使用异步密码验证函数 chore: 更新依赖和工具链配置
This commit is contained in:
@@ -823,6 +823,14 @@ impl Kernel {
|
||||
approvals.iter().filter(|a| a.status == "pending").cloned().collect()
|
||||
}
|
||||
|
||||
/// Get a single approval by ID (any status, not just pending)
|
||||
///
|
||||
/// Returns None if no approval with the given ID exists.
|
||||
pub async fn get_approval(&self, id: &str) -> Option<ApprovalEntry> {
|
||||
let approvals = self.pending_approvals.lock().await;
|
||||
approvals.iter().find(|a| a.id == id).cloned()
|
||||
}
|
||||
|
||||
/// Create a pending approval (called when a needs_approval hand is triggered)
|
||||
pub async fn create_approval(&self, hand_id: String, input: serde_json::Value) -> ApprovalEntry {
|
||||
let entry = ApprovalEntry {
|
||||
|
||||
@@ -137,7 +137,8 @@ CREATE TABLE IF NOT EXISTS usage_records (
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_usage_account ON usage_records(account_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_usage_time ON usage_records(created_at);
|
||||
CREATE INDEX IF NOT EXISTS idx_usage_day ON usage_records((created_at::date));
|
||||
-- idx_usage_day: Skipping because ::date on TIMESTAMPTZ is not IMMUTABLE
|
||||
-- CREATE INDEX IF NOT EXISTS idx_usage_day ON usage_records((created_at::date));
|
||||
|
||||
CREATE TABLE IF NOT EXISTS relay_tasks (
|
||||
id TEXT PRIMARY KEY,
|
||||
@@ -163,7 +164,8 @@ CREATE INDEX IF NOT EXISTS idx_relay_status ON relay_tasks(status);
|
||||
CREATE INDEX IF NOT EXISTS idx_relay_account ON relay_tasks(account_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_relay_provider ON relay_tasks(provider_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_relay_time ON relay_tasks(created_at);
|
||||
CREATE INDEX IF NOT EXISTS idx_relay_day ON relay_tasks((created_at::date));
|
||||
-- idx_relay_day: Skipping because ::date on TIMESTAMPTZ is not IMMUTABLE
|
||||
-- CREATE INDEX IF NOT EXISTS idx_relay_day ON relay_tasks((created_at::date));
|
||||
|
||||
CREATE TABLE IF NOT EXISTS config_items (
|
||||
id TEXT PRIMARY KEY,
|
||||
@@ -318,7 +320,8 @@ CREATE TABLE IF NOT EXISTS telemetry_reports (
|
||||
CREATE INDEX IF NOT EXISTS idx_telemetry_account ON telemetry_reports(account_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_telemetry_time ON telemetry_reports(reported_at);
|
||||
CREATE INDEX IF NOT EXISTS idx_telemetry_model ON telemetry_reports(model_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_telemetry_day ON telemetry_reports((reported_at::date));
|
||||
-- idx_telemetry_day: Skipping because ::date on TIMESTAMPTZ is not IMMUTABLE
|
||||
-- CREATE INDEX IF NOT EXISTS idx_telemetry_day ON telemetry_reports((reported_at::date));
|
||||
|
||||
-- Refresh Token storage (single-use, JWT jti tracking)
|
||||
CREATE TABLE IF NOT EXISTS refresh_tokens (
|
||||
|
||||
@@ -14,55 +14,109 @@ pub async fn list_accounts(
|
||||
let page_size = query.page_size.unwrap_or(20).min(100);
|
||||
let offset = (page - 1) * page_size;
|
||||
|
||||
let mut where_clauses = Vec::new();
|
||||
let mut params: Vec<String> = Vec::new();
|
||||
let mut param_idx = 1usize;
|
||||
|
||||
if let Some(role) = &query.role {
|
||||
where_clauses.push(format!("role = ${}", param_idx));
|
||||
param_idx += 1;
|
||||
params.push(role.clone());
|
||||
}
|
||||
if let Some(status) = &query.status {
|
||||
where_clauses.push(format!("status = ${}", param_idx));
|
||||
param_idx += 1;
|
||||
params.push(status.clone());
|
||||
}
|
||||
if let Some(search) = &query.search {
|
||||
where_clauses.push(format!("(username LIKE ${} OR email LIKE ${} OR display_name LIKE ${})", param_idx, param_idx + 1, param_idx + 2));
|
||||
param_idx += 3;
|
||||
let pattern = format!("%{}%", search);
|
||||
params.push(pattern.clone());
|
||||
params.push(pattern.clone());
|
||||
params.push(pattern);
|
||||
}
|
||||
|
||||
let where_sql = if where_clauses.is_empty() {
|
||||
String::new()
|
||||
} else {
|
||||
format!("WHERE {}", where_clauses.join(" AND "))
|
||||
// Static SQL per combination -- no format!() string interpolation
|
||||
let (total, rows) = match (&query.role, &query.status, &query.search) {
|
||||
// role + status + search
|
||||
(Some(role), Some(status), Some(search)) => {
|
||||
let pattern = format!("%{}%", search);
|
||||
let total: i64 = sqlx::query_scalar(
|
||||
"SELECT COUNT(*) FROM accounts WHERE role = $1 AND status = $2 AND (username LIKE $3 OR email LIKE $3 OR display_name LIKE $3)"
|
||||
).bind(role).bind(status).bind(&pattern).fetch_one(db).await?;
|
||||
let rows = sqlx::query_as::<_, AccountRow>(
|
||||
"SELECT id, username, email, display_name, role, status, totp_enabled, last_login_at, created_at
|
||||
FROM accounts WHERE role = $1 AND status = $2 AND (username LIKE $3 OR email LIKE $3 OR display_name LIKE $3)
|
||||
ORDER BY created_at DESC LIMIT $4 OFFSET $5"
|
||||
).bind(role).bind(status).bind(&pattern).bind(page_size as i64).bind(offset as i64).fetch_all(db).await?;
|
||||
(total, rows)
|
||||
}
|
||||
// role + status
|
||||
(Some(role), Some(status), None) => {
|
||||
let total: i64 = sqlx::query_scalar(
|
||||
"SELECT COUNT(*) FROM accounts WHERE role = $1 AND status = $2"
|
||||
).bind(role).bind(status).fetch_one(db).await?;
|
||||
let rows = sqlx::query_as::<_, AccountRow>(
|
||||
"SELECT id, username, email, display_name, role, status, totp_enabled, last_login_at, created_at
|
||||
FROM accounts WHERE role = $1 AND status = $2
|
||||
ORDER BY created_at DESC LIMIT $3 OFFSET $4"
|
||||
).bind(role).bind(status).bind(page_size as i64).bind(offset as i64).fetch_all(db).await?;
|
||||
(total, rows)
|
||||
}
|
||||
// role + search
|
||||
(Some(role), None, Some(search)) => {
|
||||
let pattern = format!("%{}%", search);
|
||||
let total: i64 = sqlx::query_scalar(
|
||||
"SELECT COUNT(*) FROM accounts WHERE role = $1 AND (username LIKE $2 OR email LIKE $2 OR display_name LIKE $2)"
|
||||
).bind(role).bind(&pattern).fetch_one(db).await?;
|
||||
let rows = sqlx::query_as::<_, AccountRow>(
|
||||
"SELECT id, username, email, display_name, role, status, totp_enabled, last_login_at, created_at
|
||||
FROM accounts WHERE role = $1 AND (username LIKE $2 OR email LIKE $2 OR display_name LIKE $2)
|
||||
ORDER BY created_at DESC LIMIT $3 OFFSET $4"
|
||||
).bind(role).bind(&pattern).bind(page_size as i64).bind(offset as i64).fetch_all(db).await?;
|
||||
(total, rows)
|
||||
}
|
||||
// status + search
|
||||
(None, Some(status), Some(search)) => {
|
||||
let pattern = format!("%{}%", search);
|
||||
let total: i64 = sqlx::query_scalar(
|
||||
"SELECT COUNT(*) FROM accounts WHERE status = $1 AND (username LIKE $2 OR email LIKE $2 OR display_name LIKE $2)"
|
||||
).bind(status).bind(&pattern).fetch_one(db).await?;
|
||||
let rows = sqlx::query_as::<_, AccountRow>(
|
||||
"SELECT id, username, email, display_name, role, status, totp_enabled, last_login_at, created_at
|
||||
FROM accounts WHERE status = $1 AND (username LIKE $2 OR email LIKE $2 OR display_name LIKE $2)
|
||||
ORDER BY created_at DESC LIMIT $3 OFFSET $4"
|
||||
).bind(status).bind(&pattern).bind(page_size as i64).bind(offset as i64).fetch_all(db).await?;
|
||||
(total, rows)
|
||||
}
|
||||
// role only
|
||||
(Some(role), None, None) => {
|
||||
let total: i64 = sqlx::query_scalar(
|
||||
"SELECT COUNT(*) FROM accounts WHERE role = $1"
|
||||
).bind(role).fetch_one(db).await?;
|
||||
let rows = sqlx::query_as::<_, AccountRow>(
|
||||
"SELECT id, username, email, display_name, role, status, totp_enabled, last_login_at, created_at
|
||||
FROM accounts WHERE role = $1
|
||||
ORDER BY created_at DESC LIMIT $2 OFFSET $3"
|
||||
).bind(role).bind(page_size as i64).bind(offset as i64).fetch_all(db).await?;
|
||||
(total, rows)
|
||||
}
|
||||
// status only
|
||||
(None, Some(status), None) => {
|
||||
let total: i64 = sqlx::query_scalar(
|
||||
"SELECT COUNT(*) FROM accounts WHERE status = $1"
|
||||
).bind(status).fetch_one(db).await?;
|
||||
let rows = sqlx::query_as::<_, AccountRow>(
|
||||
"SELECT id, username, email, display_name, role, status, totp_enabled, last_login_at, created_at
|
||||
FROM accounts WHERE status = $1
|
||||
ORDER BY created_at DESC LIMIT $2 OFFSET $3"
|
||||
).bind(status).bind(page_size as i64).bind(offset as i64).fetch_all(db).await?;
|
||||
(total, rows)
|
||||
}
|
||||
// search only
|
||||
(None, None, Some(search)) => {
|
||||
let pattern = format!("%{}%", search);
|
||||
let total: i64 = sqlx::query_scalar(
|
||||
"SELECT COUNT(*) FROM accounts WHERE (username LIKE $1 OR email LIKE $1 OR display_name LIKE $1)"
|
||||
).bind(&pattern).fetch_one(db).await?;
|
||||
let rows = sqlx::query_as::<_, AccountRow>(
|
||||
"SELECT id, username, email, display_name, role, status, totp_enabled, last_login_at, created_at
|
||||
FROM accounts WHERE (username LIKE $1 OR email LIKE $1 OR display_name LIKE $1)
|
||||
ORDER BY created_at DESC LIMIT $2 OFFSET $3"
|
||||
).bind(&pattern).bind(page_size as i64).bind(offset as i64).fetch_all(db).await?;
|
||||
(total, rows)
|
||||
}
|
||||
// no filter
|
||||
(None, None, None) => {
|
||||
let total: i64 = sqlx::query_scalar(
|
||||
"SELECT COUNT(*) FROM accounts"
|
||||
).fetch_one(db).await?;
|
||||
let rows = sqlx::query_as::<_, AccountRow>(
|
||||
"SELECT id, username, email, display_name, role, status, totp_enabled, last_login_at, created_at
|
||||
FROM accounts ORDER BY created_at DESC LIMIT $1 OFFSET $2"
|
||||
).bind(page_size as i64).bind(offset as i64).fetch_all(db).await?;
|
||||
(total, rows)
|
||||
}
|
||||
};
|
||||
|
||||
let count_sql = format!("SELECT COUNT(*) as count FROM accounts {}", where_sql);
|
||||
let mut count_query = sqlx::query_scalar::<_, i64>(&count_sql);
|
||||
for p in ¶ms {
|
||||
count_query = count_query.bind(p);
|
||||
}
|
||||
let total: i64 = count_query.fetch_one(db).await?;
|
||||
|
||||
let limit_idx = param_idx;
|
||||
let offset_idx = param_idx + 1;
|
||||
let data_sql = format!(
|
||||
"SELECT id, username, email, display_name, role, status, totp_enabled, last_login_at, created_at
|
||||
FROM accounts {} ORDER BY created_at DESC LIMIT ${} OFFSET ${}",
|
||||
where_sql, limit_idx, offset_idx
|
||||
);
|
||||
let mut data_query = sqlx::query_as::<_, AccountRow>(&data_sql);
|
||||
for p in ¶ms {
|
||||
data_query = data_query.bind(p);
|
||||
}
|
||||
let rows = data_query.bind(page_size as i64).bind(offset as i64).fetch_all(db).await?;
|
||||
|
||||
let items: Vec<serde_json::Value> = rows
|
||||
.into_iter()
|
||||
.map(|r| {
|
||||
@@ -102,30 +156,26 @@ pub async fn update_account(
|
||||
req: &UpdateAccountRequest,
|
||||
) -> SaasResult<serde_json::Value> {
|
||||
let now = chrono::Utc::now().to_rfc3339();
|
||||
let mut updates = Vec::new();
|
||||
let mut params: Vec<String> = Vec::new();
|
||||
let mut param_idx = 1usize;
|
||||
|
||||
if let Some(ref v) = req.display_name { updates.push(format!("display_name = ${}", param_idx)); param_idx += 1; params.push(v.clone()); }
|
||||
if let Some(ref v) = req.email { updates.push(format!("email = ${}", param_idx)); param_idx += 1; params.push(v.clone()); }
|
||||
if let Some(ref v) = req.role { updates.push(format!("role = ${}", param_idx)); param_idx += 1; params.push(v.clone()); }
|
||||
if let Some(ref v) = req.avatar_url { updates.push(format!("avatar_url = ${}", param_idx)); param_idx += 1; params.push(v.clone()); }
|
||||
// COALESCE pattern: all updatable fields in a single static SQL.
|
||||
// NULL parameters leave the column unchanged.
|
||||
sqlx::query(
|
||||
"UPDATE accounts SET
|
||||
display_name = COALESCE($1, display_name),
|
||||
email = COALESCE($2, email),
|
||||
role = COALESCE($3, role),
|
||||
avatar_url = COALESCE($4, avatar_url),
|
||||
updated_at = $5
|
||||
WHERE id = $6"
|
||||
)
|
||||
.bind(req.display_name.as_deref())
|
||||
.bind(req.email.as_deref())
|
||||
.bind(req.role.as_deref())
|
||||
.bind(req.avatar_url.as_deref())
|
||||
.bind(&now)
|
||||
.bind(account_id)
|
||||
.execute(db).await?;
|
||||
|
||||
if updates.is_empty() {
|
||||
return get_account(db, account_id).await;
|
||||
}
|
||||
|
||||
updates.push(format!("updated_at = ${}", param_idx));
|
||||
param_idx += 1;
|
||||
params.push(now.clone());
|
||||
params.push(account_id.to_string());
|
||||
|
||||
let sql = format!("UPDATE accounts SET {} WHERE id = ${}", updates.join(", "), param_idx);
|
||||
let mut query = sqlx::query(&sql);
|
||||
for p in ¶ms {
|
||||
query = query.bind(p);
|
||||
}
|
||||
query.execute(db).await?;
|
||||
get_account(db, account_id).await
|
||||
}
|
||||
|
||||
|
||||
@@ -17,6 +17,9 @@ fn row_to_template(
|
||||
}
|
||||
}
|
||||
|
||||
/// Row type for agent_template queries (avoids multi-line turbofish parsing issues)
|
||||
type AgentTemplateRow = (String, String, Option<String>, String, String, Option<String>, Option<String>, String, String, Option<f64>, Option<i32>, String, String, i32, String, String);
|
||||
|
||||
/// 创建 Agent 模板
|
||||
pub async fn create_template(
|
||||
db: &PgPool,
|
||||
@@ -58,7 +61,7 @@ pub async fn create_template(
|
||||
|
||||
/// 获取单个模板
|
||||
pub async fn get_template(db: &PgPool, id: &str) -> SaasResult<AgentTemplateInfo> {
|
||||
let row: Option<_> = sqlx::query_as(
|
||||
let row: Option<AgentTemplateRow> = sqlx::query_as(
|
||||
"SELECT id, name, description, category, source, model, system_prompt,
|
||||
tools, capabilities, temperature, max_tokens, visibility, status,
|
||||
current_version, created_at, updated_at
|
||||
@@ -70,7 +73,8 @@ pub async fn get_template(db: &PgPool, id: &str) -> SaasResult<AgentTemplateInfo
|
||||
}
|
||||
|
||||
/// 列出模板(分页 + 过滤)
|
||||
/// 使用动态参数化查询,安全拼接 WHERE 条件。
|
||||
/// Static SQL + conditional filter pattern: ($N IS NULL OR col = $N).
|
||||
/// When the parameter is NULL the whole OR evaluates to TRUE (no filter).
|
||||
pub async fn list_templates(
|
||||
db: &PgPool,
|
||||
query: &AgentTemplateListQuery,
|
||||
@@ -79,80 +83,35 @@ pub async fn list_templates(
|
||||
let page_size = query.page_size.unwrap_or(20).min(100);
|
||||
let offset = ((page - 1) * page_size) as i64;
|
||||
|
||||
// 动态构建参数化 WHERE 子句
|
||||
let mut conditions: Vec<String> = vec!["1=1".to_string()];
|
||||
let mut param_idx = 1u32;
|
||||
let mut cat_bind: Option<String> = None;
|
||||
let mut src_bind: Option<String> = None;
|
||||
let mut vis_bind: Option<String> = None;
|
||||
let mut st_bind: Option<String> = None;
|
||||
|
||||
if let Some(ref cat) = query.category {
|
||||
param_idx += 1;
|
||||
conditions.push(format!("category = ${}", param_idx));
|
||||
cat_bind = Some(cat.clone());
|
||||
}
|
||||
if let Some(ref src) = query.source {
|
||||
param_idx += 1;
|
||||
conditions.push(format!("source = ${}", param_idx));
|
||||
src_bind = Some(src.clone());
|
||||
}
|
||||
if let Some(ref vis) = query.visibility {
|
||||
param_idx += 1;
|
||||
conditions.push(format!("visibility = ${}", param_idx));
|
||||
vis_bind = Some(vis.clone());
|
||||
}
|
||||
if let Some(ref st) = query.status {
|
||||
param_idx += 1;
|
||||
conditions.push(format!("status = ${}", param_idx));
|
||||
st_bind = Some(st.clone());
|
||||
}
|
||||
|
||||
let where_clause = conditions.join(" AND ");
|
||||
|
||||
// COUNT 查询: WHERE 参数绑定 ($1..$N)
|
||||
let count_idx = param_idx;
|
||||
let count_sql = format!(
|
||||
"SELECT COUNT(*) FROM agent_templates WHERE {}",
|
||||
where_clause
|
||||
);
|
||||
let count_limit_idx = count_idx + 1;
|
||||
let count_offset_idx = count_limit_idx + 1;
|
||||
let data_sql = format!(
|
||||
"SELECT id, name, description, category, source, model, system_prompt,
|
||||
let count_sql = "SELECT COUNT(*) FROM agent_templates WHERE ($1 IS NULL OR category = $1) AND ($2 IS NULL OR source = $2) AND ($3 IS NULL OR visibility = $3) AND ($4 IS NULL OR status = $4)";
|
||||
let data_sql = "SELECT id, name, description, category, source, model, system_prompt,
|
||||
tools, capabilities, temperature, max_tokens, visibility, status,
|
||||
current_version, created_at, updated_at
|
||||
FROM agent_templates WHERE {} ORDER BY created_at DESC LIMIT ${} OFFSET ${}",
|
||||
where_clause, count_limit_idx, count_offset_idx
|
||||
);
|
||||
FROM agent_templates WHERE ($1 IS NULL OR category = $1) AND ($2 IS NULL OR source = $2) AND ($3 IS NULL OR visibility = $3) AND ($4 IS NULL OR status = $4) ORDER BY created_at DESC LIMIT $5 OFFSET $6";
|
||||
|
||||
// 构建 COUNT 查询并绑定参数
|
||||
let mut count_q = sqlx::query_scalar::<_, i64>(&count_sql);
|
||||
if let Some(ref v) = cat_bind { count_q = count_q.bind(v); }
|
||||
if let Some(ref v) = src_bind { count_q = count_q.bind(v); }
|
||||
if let Some(ref v) = vis_bind { count_q = count_q.bind(v); }
|
||||
if let Some(ref v) = st_bind { count_q = count_q.bind(v); }
|
||||
let total: i64 = count_q.fetch_one(db).await?;
|
||||
let total: i64 = sqlx::query_scalar(count_sql)
|
||||
.bind(&query.category)
|
||||
.bind(&query.source)
|
||||
.bind(&query.visibility)
|
||||
.bind(&query.status)
|
||||
.fetch_one(db).await?;
|
||||
|
||||
// 构建数据查询并绑定参数
|
||||
let mut data_q = sqlx::query_as::<_, (
|
||||
String, String, Option<String>, String, String, Option<String>, Option<String>,
|
||||
String, String, Option<f64>, Option<i32>, String, String, i32, String, String
|
||||
)>(&data_sql);
|
||||
if let Some(ref v) = cat_bind { data_q = data_q.bind(v); }
|
||||
if let Some(ref v) = src_bind { data_q = data_q.bind(v); }
|
||||
if let Some(ref v) = vis_bind { data_q = data_q.bind(v); }
|
||||
if let Some(ref v) = st_bind { data_q = data_q.bind(v); }
|
||||
data_q = data_q.bind(page_size as i64).bind(offset);
|
||||
|
||||
let rows = data_q.fetch_all(db).await?;
|
||||
let rows: Vec<AgentTemplateRow> = sqlx::query_as(data_sql)
|
||||
.bind(&query.category)
|
||||
.bind(&query.source)
|
||||
.bind(&query.visibility)
|
||||
.bind(&query.status)
|
||||
.bind(page_size as i64)
|
||||
.bind(offset)
|
||||
.fetch_all(db).await?;
|
||||
let items = rows.into_iter().map(row_to_template).collect();
|
||||
|
||||
Ok(crate::common::PaginatedResponse { items, total, page, page_size })
|
||||
}
|
||||
|
||||
/// 更新模板
|
||||
/// 使用动态参数化查询,安全拼接 SET 子句。
|
||||
/// COALESCE pattern: all updatable fields in a single static SQL.
|
||||
/// NULL parameters leave the column unchanged.
|
||||
pub async fn update_template(
|
||||
db: &PgPool,
|
||||
id: &str,
|
||||
@@ -166,102 +125,41 @@ pub async fn update_template(
|
||||
visibility: Option<&str>,
|
||||
status: Option<&str>,
|
||||
) -> SaasResult<AgentTemplateInfo> {
|
||||
// 确认存在
|
||||
// Confirm existence
|
||||
get_template(db, id).await?;
|
||||
|
||||
let now = chrono::Utc::now().to_rfc3339();
|
||||
let mut set_clauses: Vec<String> = vec![];
|
||||
let mut param_idx = 1u32;
|
||||
|
||||
// 收集需要绑定的值(按顺序)
|
||||
let mut desc_val: Option<String> = None;
|
||||
let mut model_val: Option<String> = None;
|
||||
let mut sp_val: Option<String> = None;
|
||||
let mut tools_val: Option<String> = None;
|
||||
let mut caps_val: Option<String> = None;
|
||||
let mut temp_val: Option<f64> = None;
|
||||
let mut mt_val: Option<i32> = None;
|
||||
let mut vis_val: Option<String> = None;
|
||||
let mut st_val: Option<String> = None;
|
||||
// Serialize JSON fields upfront so we can bind Option<&str> consistently
|
||||
let tools_json = tools.map(|t| serde_json::to_string(t).unwrap_or_else(|_| "[]".to_string()));
|
||||
let caps_json = capabilities.map(|c| serde_json::to_string(c).unwrap_or_else(|_| "[]".to_string()));
|
||||
|
||||
if let Some(desc) = description {
|
||||
param_idx += 1;
|
||||
set_clauses.push(format!("description = ${}", param_idx));
|
||||
desc_val = Some(desc.to_string());
|
||||
}
|
||||
if let Some(m) = model {
|
||||
param_idx += 1;
|
||||
set_clauses.push(format!("model = ${}", param_idx));
|
||||
model_val = Some(m.to_string());
|
||||
}
|
||||
if let Some(sp) = system_prompt {
|
||||
param_idx += 1;
|
||||
set_clauses.push(format!("system_prompt = ${}", param_idx));
|
||||
sp_val = Some(sp.to_string());
|
||||
}
|
||||
if let Some(t) = tools {
|
||||
let json = serde_json::to_string(t).unwrap_or_else(|_| "[]".to_string());
|
||||
param_idx += 1;
|
||||
set_clauses.push(format!("tools = ${}", param_idx));
|
||||
tools_val = Some(json);
|
||||
}
|
||||
if let Some(c) = capabilities {
|
||||
let json = serde_json::to_string(c).unwrap_or_else(|_| "[]".to_string());
|
||||
param_idx += 1;
|
||||
set_clauses.push(format!("capabilities = ${}", param_idx));
|
||||
caps_val = Some(json);
|
||||
}
|
||||
if let Some(t) = temperature {
|
||||
param_idx += 1;
|
||||
set_clauses.push(format!("temperature = ${}", param_idx));
|
||||
temp_val = Some(t);
|
||||
}
|
||||
if let Some(m) = max_tokens {
|
||||
param_idx += 1;
|
||||
set_clauses.push(format!("max_tokens = ${}", param_idx));
|
||||
mt_val = Some(m);
|
||||
}
|
||||
if let Some(v) = visibility {
|
||||
param_idx += 1;
|
||||
set_clauses.push(format!("visibility = ${}", param_idx));
|
||||
vis_val = Some(v.to_string());
|
||||
}
|
||||
if let Some(s) = status {
|
||||
param_idx += 1;
|
||||
set_clauses.push(format!("status = ${}", param_idx));
|
||||
st_val = Some(s.to_string());
|
||||
}
|
||||
|
||||
if set_clauses.is_empty() {
|
||||
return get_template(db, id).await;
|
||||
}
|
||||
|
||||
// updated_at
|
||||
param_idx += 1;
|
||||
set_clauses.push(format!("updated_at = ${}", param_idx));
|
||||
|
||||
// WHERE id = $N
|
||||
let id_idx = param_idx + 1;
|
||||
|
||||
let sql = format!(
|
||||
"UPDATE agent_templates SET {} WHERE id = ${}",
|
||||
set_clauses.join(", "), id_idx
|
||||
);
|
||||
|
||||
let mut q = sqlx::query(&sql);
|
||||
if let Some(ref v) = desc_val { q = q.bind(v); }
|
||||
if let Some(ref v) = model_val { q = q.bind(v); }
|
||||
if let Some(ref v) = sp_val { q = q.bind(v); }
|
||||
if let Some(ref v) = tools_val { q = q.bind(v); }
|
||||
if let Some(ref v) = caps_val { q = q.bind(v); }
|
||||
if let Some(v) = temp_val { q = q.bind(v); }
|
||||
if let Some(v) = mt_val { q = q.bind(v); }
|
||||
if let Some(ref v) = vis_val { q = q.bind(v); }
|
||||
if let Some(ref v) = st_val { q = q.bind(v); }
|
||||
q = q.bind(&now);
|
||||
q = q.bind(id);
|
||||
|
||||
q.execute(db).await?;
|
||||
sqlx::query(
|
||||
"UPDATE agent_templates SET
|
||||
description = COALESCE($1, description),
|
||||
model = COALESCE($2, model),
|
||||
system_prompt = COALESCE($3, system_prompt),
|
||||
tools = COALESCE($4, tools),
|
||||
capabilities = COALESCE($5, capabilities),
|
||||
temperature = COALESCE($6, temperature),
|
||||
max_tokens = COALESCE($7, max_tokens),
|
||||
visibility = COALESCE($8, visibility),
|
||||
status = COALESCE($9, status),
|
||||
updated_at = $10
|
||||
WHERE id = $11"
|
||||
)
|
||||
.bind(description)
|
||||
.bind(model)
|
||||
.bind(system_prompt)
|
||||
.bind(tools_json.as_deref())
|
||||
.bind(caps_json.as_deref())
|
||||
.bind(temperature)
|
||||
.bind(max_tokens)
|
||||
.bind(visibility)
|
||||
.bind(status)
|
||||
.bind(&now)
|
||||
.bind(id)
|
||||
.execute(db).await?;
|
||||
|
||||
get_template(db, id).await
|
||||
}
|
||||
|
||||
@@ -8,7 +8,7 @@ use crate::error::{SaasError, SaasResult};
|
||||
use crate::models::{AccountAuthRow, AccountLoginRow};
|
||||
use super::{
|
||||
jwt::{create_token, create_refresh_token, verify_token, verify_token_skip_expiry},
|
||||
password::{hash_password, verify_password},
|
||||
password::{hash_password_async, verify_password_async},
|
||||
types::{AuthContext, LoginRequest, LoginResponse, RegisterRequest, ChangePasswordRequest, AccountPublic, RefreshRequest},
|
||||
};
|
||||
|
||||
@@ -25,7 +25,8 @@ pub async fn register(
|
||||
if req.username.len() > 32 {
|
||||
return Err(SaasError::InvalidInput("用户名最多 32 个字符".into()));
|
||||
}
|
||||
let username_re = regex::Regex::new(r"^[a-zA-Z0-9_-]+$").unwrap();
|
||||
static USERNAME_RE: std::sync::OnceLock<regex::Regex> = std::sync::OnceLock::new();
|
||||
let username_re = USERNAME_RE.get_or_init(|| regex::Regex::new(r"^[a-zA-Z0-9_-]+$").unwrap());
|
||||
if !username_re.is_match(&req.username) {
|
||||
return Err(SaasError::InvalidInput("用户名只能包含字母、数字、下划线和连字符".into()));
|
||||
}
|
||||
@@ -56,7 +57,7 @@ pub async fn register(
|
||||
return Err(SaasError::AlreadyExists("用户名或邮箱已存在".into()));
|
||||
}
|
||||
|
||||
let password_hash = hash_password(&req.password)?;
|
||||
let password_hash = hash_password_async(req.password.clone()).await?;
|
||||
let account_id = uuid::Uuid::new_v4().to_string();
|
||||
let role = "user".to_string(); // 注册固定为普通用户,角色由管理员分配
|
||||
let display_name = req.display_name.unwrap_or_default();
|
||||
@@ -138,7 +139,7 @@ pub async fn login(
|
||||
return Err(SaasError::Forbidden(format!("账号已{},请联系管理员", r.status)));
|
||||
}
|
||||
|
||||
if !verify_password(&req.password, &r.password_hash)? {
|
||||
if !verify_password_async(req.password.clone(), r.password_hash.clone()).await? {
|
||||
return Err(SaasError::AuthError("用户名或密码错误".into()));
|
||||
}
|
||||
|
||||
@@ -328,12 +329,12 @@ pub async fn change_password(
|
||||
.await?;
|
||||
|
||||
// 验证旧密码
|
||||
if !verify_password(&req.old_password, &password_hash)? {
|
||||
if !verify_password_async(req.old_password.clone(), password_hash.clone()).await? {
|
||||
return Err(SaasError::AuthError("旧密码错误".into()));
|
||||
}
|
||||
|
||||
// 更新密码
|
||||
let new_hash = hash_password(&req.new_password)?;
|
||||
let new_hash = hash_password_async(req.new_password.clone()).await?;
|
||||
let now = chrono::Utc::now().to_rfc3339();
|
||||
sqlx::query("UPDATE accounts SET password_hash = $1, updated_at = $2 WHERE id = $3")
|
||||
.bind(&new_hash)
|
||||
|
||||
@@ -1,4 +1,8 @@
|
||||
//! 密码哈希 (Argon2id)
|
||||
//!
|
||||
//! Argon2 是 CPU 密集型操作(~100-500ms),不能在 tokio worker 线程上直接执行,
|
||||
//! 否则会阻塞整个异步运行时。所有 async 上下文必须使用 `hash_password_async`
|
||||
//! 和 `verify_password_async`。
|
||||
|
||||
use argon2::{
|
||||
password_hash::{rand_core::OsRng, PasswordHash, PasswordHasher, PasswordVerifier, SaltString},
|
||||
@@ -7,7 +11,7 @@ use argon2::{
|
||||
|
||||
use crate::error::{SaasError, SaasResult};
|
||||
|
||||
/// 哈希密码
|
||||
/// 哈希密码(同步版本,仅用于测试和启动时 seed)
|
||||
pub fn hash_password(password: &str) -> SaasResult<String> {
|
||||
let salt = SaltString::generate(&mut OsRng);
|
||||
let argon2 = Argon2::default();
|
||||
@@ -17,7 +21,7 @@ pub fn hash_password(password: &str) -> SaasResult<String> {
|
||||
Ok(hash.to_string())
|
||||
}
|
||||
|
||||
/// 验证密码
|
||||
/// 验证密码(同步版本,仅用于测试)
|
||||
pub fn verify_password(password: &str, hash: &str) -> SaasResult<bool> {
|
||||
let parsed_hash = PasswordHash::new(hash)
|
||||
.map_err(|e| SaasError::PasswordHash(e.to_string()))?;
|
||||
@@ -26,6 +30,20 @@ pub fn verify_password(password: &str, hash: &str) -> SaasResult<bool> {
|
||||
.is_ok())
|
||||
}
|
||||
|
||||
/// 异步哈希密码 — 在 spawn_blocking 线程池中执行 Argon2
|
||||
pub async fn hash_password_async(password: String) -> SaasResult<String> {
|
||||
tokio::task::spawn_blocking(move || hash_password(&password))
|
||||
.await
|
||||
.map_err(|e| SaasError::Internal(format!("spawn_blocking error: {e}")))?
|
||||
}
|
||||
|
||||
/// 异步验证密码 — 在 spawn_blocking 线程池中执行 Argon2
|
||||
pub async fn verify_password_async(password: String, hash: String) -> SaasResult<bool> {
|
||||
tokio::task::spawn_blocking(move || verify_password(&password, &hash))
|
||||
.await
|
||||
.map_err(|e| SaasError::Internal(format!("spawn_blocking error: {e}")))?
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
@@ -212,7 +212,7 @@ pub async fn disable_totp(
|
||||
.fetch_one(&state.db)
|
||||
.await?;
|
||||
|
||||
if !crate::auth::password::verify_password(&req.password, &password_hash)? {
|
||||
if !crate::auth::password::verify_password_async(req.password.clone(), password_hash.clone()).await? {
|
||||
return Err(SaasError::AuthError("密码错误".into()));
|
||||
}
|
||||
|
||||
|
||||
@@ -150,11 +150,10 @@ pub async fn seed_admin_account(pool: &PgPool) -> SaasResult<()> {
|
||||
.await?;
|
||||
|
||||
if let Some((account_id,)) = existing {
|
||||
// 更新现有用户的密码和角色
|
||||
use crate::auth::password::hash_password;
|
||||
let password_hash = hash_password(&admin_password)?;
|
||||
// 更新现有用户的密码和角色(使用 spawn_blocking 避免阻塞 tokio 运行时)
|
||||
let password_hash = crate::auth::password::hash_password_async(admin_password.clone()).await?;
|
||||
let now = chrono::Utc::now().to_rfc3339();
|
||||
|
||||
|
||||
sqlx::query(
|
||||
"UPDATE accounts SET password_hash = $1, role = 'super_admin', updated_at = $2 WHERE id = $3"
|
||||
)
|
||||
@@ -163,12 +162,11 @@ pub async fn seed_admin_account(pool: &PgPool) -> SaasResult<()> {
|
||||
.bind(&account_id)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
|
||||
tracing::info!("已更新用户 {} 的密码和角色为 super_admin", admin_username);
|
||||
} else {
|
||||
// 创建新的 super_admin 账号
|
||||
use crate::auth::password::hash_password;
|
||||
let password_hash = hash_password(&admin_password)?;
|
||||
let password_hash = crate::auth::password::hash_password_async(admin_password.clone()).await?;
|
||||
let account_id = uuid::Uuid::new_v4().to_string();
|
||||
let email = format!("{}@zclaw.local", admin_username);
|
||||
let now = chrono::Utc::now().to_rfc3339();
|
||||
|
||||
@@ -54,39 +54,50 @@ pub async fn list_config_items(
|
||||
) -> SaasResult<PaginatedResponse<ConfigItemInfo>> {
|
||||
let (p, ps, offset) = normalize_pagination(page, page_size);
|
||||
|
||||
// Build WHERE clause for count and data queries
|
||||
let (where_clause, has_category, has_source) = match (&query.category, &query.source) {
|
||||
(Some(_), Some(_)) => ("WHERE category = $1 AND source = $2", true, true),
|
||||
(Some(_), None) => ("WHERE category = $1", true, false),
|
||||
(None, Some(_)) => ("WHERE source = $1", false, true),
|
||||
(None, None) => ("", false, false),
|
||||
// Static SQL per combination -- no format!() string interpolation
|
||||
let (total, rows) = match (&query.category, &query.source) {
|
||||
(Some(cat), Some(src)) => {
|
||||
let total: i64 = sqlx::query_scalar(
|
||||
"SELECT COUNT(*) FROM config_items WHERE category = $1 AND source = $2"
|
||||
).bind(cat).bind(src).fetch_one(db).await?;
|
||||
let rows = sqlx::query_as::<_, ConfigItemRow>(
|
||||
"SELECT id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at, updated_at
|
||||
FROM config_items WHERE category = $1 AND source = $2 ORDER BY category, key_path LIMIT $3 OFFSET $4"
|
||||
).bind(cat).bind(src).bind(ps as i64).bind(offset).fetch_all(db).await?;
|
||||
(total, rows)
|
||||
}
|
||||
(Some(cat), None) => {
|
||||
let total: i64 = sqlx::query_scalar(
|
||||
"SELECT COUNT(*) FROM config_items WHERE category = $1"
|
||||
).bind(cat).fetch_one(db).await?;
|
||||
let rows = sqlx::query_as::<_, ConfigItemRow>(
|
||||
"SELECT id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at, updated_at
|
||||
FROM config_items WHERE category = $1 ORDER BY category, key_path LIMIT $2 OFFSET $3"
|
||||
).bind(cat).bind(ps as i64).bind(offset).fetch_all(db).await?;
|
||||
(total, rows)
|
||||
}
|
||||
(None, Some(src)) => {
|
||||
let total: i64 = sqlx::query_scalar(
|
||||
"SELECT COUNT(*) FROM config_items WHERE source = $1"
|
||||
).bind(src).fetch_one(db).await?;
|
||||
let rows = sqlx::query_as::<_, ConfigItemRow>(
|
||||
"SELECT id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at, updated_at
|
||||
FROM config_items WHERE source = $1 ORDER BY category, key_path LIMIT $2 OFFSET $3"
|
||||
).bind(src).bind(ps as i64).bind(offset).fetch_all(db).await?;
|
||||
(total, rows)
|
||||
}
|
||||
(None, None) => {
|
||||
let total: i64 = sqlx::query_scalar(
|
||||
"SELECT COUNT(*) FROM config_items"
|
||||
).fetch_one(db).await?;
|
||||
let rows = sqlx::query_as::<_, ConfigItemRow>(
|
||||
"SELECT id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at, updated_at
|
||||
FROM config_items ORDER BY category, key_path LIMIT $1 OFFSET $2"
|
||||
).bind(ps as i64).bind(offset).fetch_all(db).await?;
|
||||
(total, rows)
|
||||
}
|
||||
};
|
||||
|
||||
let count_sql = format!("SELECT COUNT(*) FROM config_items {}", where_clause);
|
||||
let data_sql = format!(
|
||||
"SELECT id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at, updated_at
|
||||
FROM config_items {} ORDER BY category, key_path LIMIT {} OFFSET {}",
|
||||
where_clause, "$p", "$o"
|
||||
);
|
||||
|
||||
// Determine param indices for LIMIT/OFFSET based on filter params
|
||||
let (limit_idx, offset_idx) = match (has_category, has_source) {
|
||||
(true, true) => ("$3", "$4"),
|
||||
(true, false) | (false, true) => ("$2", "$3"),
|
||||
(false, false) => ("$1", "$2"),
|
||||
};
|
||||
let data_sql = data_sql.replace("$p", limit_idx).replace("$o", offset_idx);
|
||||
|
||||
let mut count_query = sqlx::query_scalar::<_, i64>(&count_sql);
|
||||
if has_category { count_query = count_query.bind(&query.category); }
|
||||
if has_source { count_query = count_query.bind(&query.source); }
|
||||
let total: i64 = count_query.fetch_one(db).await?;
|
||||
|
||||
let mut data_query = sqlx::query_as::<_, ConfigItemRow>(&data_sql);
|
||||
if has_category { data_query = data_query.bind(&query.category); }
|
||||
if has_source { data_query = data_query.bind(&query.source); }
|
||||
let rows = data_query.bind(ps as i64).bind(offset).fetch_all(db).await?;
|
||||
|
||||
let items = rows.into_iter().map(|r| {
|
||||
ConfigItemInfo { id: r.id, category: r.category, key_path: r.key_path, value_type: r.value_type, current_value: r.current_value, default_value: r.default_value, source: r.source, description: r.description, requires_restart: r.requires_restart, created_at: r.created_at, updated_at: r.updated_at }
|
||||
}).collect();
|
||||
@@ -146,29 +157,23 @@ pub async fn update_config_item(
|
||||
db: &PgPool, item_id: &str, req: &UpdateConfigItemRequest,
|
||||
) -> SaasResult<ConfigItemInfo> {
|
||||
let now = chrono::Utc::now().to_rfc3339();
|
||||
let mut updates = Vec::new();
|
||||
let mut params: Vec<String> = Vec::new();
|
||||
let mut param_idx = 1usize;
|
||||
|
||||
if let Some(ref v) = req.current_value { updates.push(format!("current_value = ${}", param_idx)); params.push(v.clone()); param_idx += 1; }
|
||||
if let Some(ref v) = req.source { updates.push(format!("source = ${}", param_idx)); params.push(v.clone()); param_idx += 1; }
|
||||
if let Some(ref v) = req.description { updates.push(format!("description = ${}", param_idx)); params.push(v.clone()); param_idx += 1; }
|
||||
|
||||
if updates.is_empty() {
|
||||
return get_config_item(db, item_id).await;
|
||||
}
|
||||
|
||||
updates.push(format!("updated_at = ${}", param_idx));
|
||||
params.push(now);
|
||||
param_idx += 1;
|
||||
params.push(item_id.to_string());
|
||||
|
||||
let sql = format!("UPDATE config_items SET {} WHERE id = ${}", updates.join(", "), param_idx);
|
||||
let mut query = sqlx::query(&sql);
|
||||
for p in ¶ms {
|
||||
query = query.bind(p);
|
||||
}
|
||||
query.execute(db).await?;
|
||||
// COALESCE pattern: all updatable fields in a single static SQL.
|
||||
// NULL parameters leave the column unchanged.
|
||||
sqlx::query(
|
||||
"UPDATE config_items SET
|
||||
current_value = COALESCE($1, current_value),
|
||||
source = COALESCE($2, source),
|
||||
description = COALESCE($3, description),
|
||||
updated_at = $4
|
||||
WHERE id = $5"
|
||||
)
|
||||
.bind(req.current_value.as_deref())
|
||||
.bind(req.source.as_deref())
|
||||
.bind(req.description.as_deref())
|
||||
.bind(&now)
|
||||
.bind(item_id)
|
||||
.execute(db).await?;
|
||||
|
||||
get_config_item(db, item_id).await
|
||||
}
|
||||
|
||||
@@ -104,36 +104,38 @@ pub async fn update_provider(
|
||||
db: &PgPool, provider_id: &str, req: &UpdateProviderRequest, enc_key: &[u8; 32],
|
||||
) -> SaasResult<ProviderInfo> {
|
||||
let now = chrono::Utc::now().to_rfc3339();
|
||||
let mut updates = Vec::new();
|
||||
let mut params: Vec<Box<dyn std::fmt::Display + Send + Sync>> = Vec::new();
|
||||
let mut param_idx = 1;
|
||||
|
||||
if let Some(ref v) = req.display_name { updates.push(format!("display_name = ${}", param_idx)); params.push(Box::new(v.clone())); param_idx += 1; }
|
||||
if let Some(ref v) = req.base_url { updates.push(format!("base_url = ${}", param_idx)); params.push(Box::new(v.clone())); param_idx += 1; }
|
||||
if let Some(ref v) = req.api_protocol { updates.push(format!("api_protocol = ${}", param_idx)); params.push(Box::new(v.clone())); param_idx += 1; }
|
||||
if let Some(ref v) = req.api_key {
|
||||
let encrypted = if v.is_empty() { String::new() } else { crypto::encrypt_value(v, enc_key)? };
|
||||
updates.push(format!("api_key = ${}", param_idx)); params.push(Box::new(encrypted)); param_idx += 1;
|
||||
}
|
||||
if let Some(v) = req.enabled { updates.push(format!("enabled = ${}", param_idx)); params.push(Box::new(v)); param_idx += 1; }
|
||||
if let Some(v) = req.rate_limit_rpm { updates.push(format!("rate_limit_rpm = ${}", param_idx)); params.push(Box::new(v)); param_idx += 1; }
|
||||
if let Some(v) = req.rate_limit_tpm { updates.push(format!("rate_limit_tpm = ${}", param_idx)); params.push(Box::new(v)); param_idx += 1; }
|
||||
// Encrypt api_key upfront if provided
|
||||
let encrypted_api_key = match req.api_key {
|
||||
Some(ref v) if !v.is_empty() => Some(crypto::encrypt_value(v, enc_key)?),
|
||||
Some(ref v) if v.is_empty() => Some(String::new()),
|
||||
_ => None,
|
||||
};
|
||||
|
||||
if updates.is_empty() {
|
||||
return get_provider(db, provider_id).await;
|
||||
}
|
||||
|
||||
updates.push(format!("updated_at = ${}", param_idx));
|
||||
params.push(Box::new(now.clone()));
|
||||
param_idx += 1;
|
||||
params.push(Box::new(provider_id.to_string()));
|
||||
|
||||
let sql = format!("UPDATE providers SET {} WHERE id = ${}", updates.join(", "), param_idx);
|
||||
let mut query = sqlx::query(&sql);
|
||||
for p in ¶ms {
|
||||
query = query.bind(format!("{}", p));
|
||||
}
|
||||
query.execute(db).await?;
|
||||
// COALESCE pattern: all updatable fields in a single static SQL.
|
||||
// NULL parameters leave the column unchanged.
|
||||
sqlx::query(
|
||||
"UPDATE providers SET
|
||||
display_name = COALESCE($1, display_name),
|
||||
base_url = COALESCE($2, base_url),
|
||||
api_protocol = COALESCE($3, api_protocol),
|
||||
api_key = COALESCE($4, api_key),
|
||||
enabled = COALESCE($5, enabled),
|
||||
rate_limit_rpm = COALESCE($6, rate_limit_rpm),
|
||||
rate_limit_tpm = COALESCE($7, rate_limit_tpm),
|
||||
updated_at = $8
|
||||
WHERE id = $9"
|
||||
)
|
||||
.bind(req.display_name.as_deref())
|
||||
.bind(req.base_url.as_deref())
|
||||
.bind(req.api_protocol.as_deref())
|
||||
.bind(encrypted_api_key.as_deref())
|
||||
.bind(req.enabled)
|
||||
.bind(req.rate_limit_rpm)
|
||||
.bind(req.rate_limit_tpm)
|
||||
.bind(&now)
|
||||
.bind(provider_id)
|
||||
.execute(db).await?;
|
||||
|
||||
get_provider(db, provider_id).await
|
||||
}
|
||||
@@ -245,34 +247,33 @@ pub async fn update_model(
|
||||
db: &PgPool, model_id: &str, req: &UpdateModelRequest,
|
||||
) -> SaasResult<ModelInfo> {
|
||||
let now = chrono::Utc::now().to_rfc3339();
|
||||
let mut updates = Vec::new();
|
||||
let mut params: Vec<Box<dyn std::fmt::Display + Send + Sync>> = Vec::new();
|
||||
let mut param_idx = 1;
|
||||
|
||||
if let Some(ref v) = req.alias { updates.push(format!("alias = ${}", param_idx)); params.push(Box::new(v.clone())); param_idx += 1; }
|
||||
if let Some(v) = req.context_window { updates.push(format!("context_window = ${}", param_idx)); params.push(Box::new(v)); param_idx += 1; }
|
||||
if let Some(v) = req.max_output_tokens { updates.push(format!("max_output_tokens = ${}", param_idx)); params.push(Box::new(v)); param_idx += 1; }
|
||||
if let Some(v) = req.supports_streaming { updates.push(format!("supports_streaming = ${}", param_idx)); params.push(Box::new(v)); param_idx += 1; }
|
||||
if let Some(v) = req.supports_vision { updates.push(format!("supports_vision = ${}", param_idx)); params.push(Box::new(v)); param_idx += 1; }
|
||||
if let Some(v) = req.enabled { updates.push(format!("enabled = ${}", param_idx)); params.push(Box::new(v)); param_idx += 1; }
|
||||
if let Some(v) = req.pricing_input { updates.push(format!("pricing_input = ${}", param_idx)); params.push(Box::new(v)); param_idx += 1; }
|
||||
if let Some(v) = req.pricing_output { updates.push(format!("pricing_output = ${}", param_idx)); params.push(Box::new(v)); param_idx += 1; }
|
||||
|
||||
if updates.is_empty() {
|
||||
return get_model(db, model_id).await;
|
||||
}
|
||||
|
||||
updates.push(format!("updated_at = ${}", param_idx));
|
||||
params.push(Box::new(now.clone()));
|
||||
param_idx += 1;
|
||||
params.push(Box::new(model_id.to_string()));
|
||||
|
||||
let sql = format!("UPDATE models SET {} WHERE id = ${}", updates.join(", "), param_idx);
|
||||
let mut query = sqlx::query(&sql);
|
||||
for p in ¶ms {
|
||||
query = query.bind(format!("{}", p));
|
||||
}
|
||||
query.execute(db).await?;
|
||||
// COALESCE pattern: all updatable fields in a single static SQL.
|
||||
// NULL parameters leave the column unchanged.
|
||||
sqlx::query(
|
||||
"UPDATE models SET
|
||||
alias = COALESCE($1, alias),
|
||||
context_window = COALESCE($2, context_window),
|
||||
max_output_tokens = COALESCE($3, max_output_tokens),
|
||||
supports_streaming = COALESCE($4, supports_streaming),
|
||||
supports_vision = COALESCE($5, supports_vision),
|
||||
enabled = COALESCE($6, enabled),
|
||||
pricing_input = COALESCE($7, pricing_input),
|
||||
pricing_output = COALESCE($8, pricing_output),
|
||||
updated_at = $9
|
||||
WHERE id = $10"
|
||||
)
|
||||
.bind(req.alias.as_deref())
|
||||
.bind(req.context_window)
|
||||
.bind(req.max_output_tokens)
|
||||
.bind(req.supports_streaming)
|
||||
.bind(req.supports_vision)
|
||||
.bind(req.enabled)
|
||||
.bind(req.pricing_input)
|
||||
.bind(req.pricing_output)
|
||||
.bind(&now)
|
||||
.bind(model_id)
|
||||
.execute(db).await?;
|
||||
|
||||
get_model(db, model_id).await
|
||||
}
|
||||
@@ -401,58 +402,33 @@ pub async fn revoke_account_api_key(
|
||||
pub async fn get_usage_stats(
|
||||
db: &PgPool, account_id: &str, query: &UsageQuery,
|
||||
) -> SaasResult<UsageStats> {
|
||||
let mut param_idx = 1;
|
||||
let mut where_clauses = vec![format!("account_id = ${}", param_idx)];
|
||||
let mut params: Vec<String> = vec![account_id.to_string()];
|
||||
param_idx += 1;
|
||||
// Static SQL with conditional filter pattern:
|
||||
// account_id is always required; optional filters use ($N IS NULL OR col = $N).
|
||||
let total_sql = "SELECT COUNT(*)::bigint, COALESCE(SUM(input_tokens), 0), COALESCE(SUM(output_tokens), 0)
|
||||
FROM usage_records WHERE account_id = $1 AND ($2 IS NULL OR created_at >= $2) AND ($3 IS NULL OR created_at <= $3) AND ($4 IS NULL OR provider_id = $4) AND ($5 IS NULL OR model_id = $5)";
|
||||
|
||||
if let Some(ref from) = query.from {
|
||||
where_clauses.push(format!("created_at >= ${}", param_idx));
|
||||
params.push(from.clone());
|
||||
param_idx += 1;
|
||||
}
|
||||
if let Some(ref to) = query.to {
|
||||
where_clauses.push(format!("created_at <= ${}", param_idx));
|
||||
params.push(to.clone());
|
||||
param_idx += 1;
|
||||
}
|
||||
if let Some(ref pid) = query.provider_id {
|
||||
where_clauses.push(format!("provider_id = ${}", param_idx));
|
||||
params.push(pid.clone());
|
||||
param_idx += 1;
|
||||
}
|
||||
if let Some(ref mid) = query.model_id {
|
||||
where_clauses.push(format!("model_id = ${}", param_idx));
|
||||
params.push(mid.clone());
|
||||
}
|
||||
|
||||
let where_sql = where_clauses.join(" AND ");
|
||||
|
||||
// 总量统计
|
||||
let total_sql = format!(
|
||||
"SELECT COUNT(*)::bigint, COALESCE(SUM(input_tokens), 0), COALESCE(SUM(output_tokens), 0)
|
||||
FROM usage_records WHERE {}", where_sql
|
||||
);
|
||||
let mut total_query = sqlx::query(&total_sql);
|
||||
for p in ¶ms {
|
||||
total_query = total_query.bind(p);
|
||||
}
|
||||
let row = total_query.fetch_one(db).await?;
|
||||
let row = sqlx::query(total_sql)
|
||||
.bind(account_id)
|
||||
.bind(&query.from)
|
||||
.bind(&query.to)
|
||||
.bind(&query.provider_id)
|
||||
.bind(&query.model_id)
|
||||
.fetch_one(db).await?;
|
||||
let total_requests: i64 = row.try_get(0).unwrap_or(0);
|
||||
let total_input: i64 = row.try_get(1).unwrap_or(0);
|
||||
let total_output: i64 = row.try_get(2).unwrap_or(0);
|
||||
|
||||
// 按模型统计
|
||||
let by_model_sql = format!(
|
||||
"SELECT provider_id, model_id, COUNT(*)::bigint AS request_count, COALESCE(SUM(input_tokens), 0) AS input_tokens, COALESCE(SUM(output_tokens), 0) AS output_tokens
|
||||
FROM usage_records WHERE {} GROUP BY provider_id, model_id ORDER BY COUNT(*) DESC LIMIT 20",
|
||||
where_sql
|
||||
);
|
||||
let mut by_model_query = sqlx::query_as::<_, UsageByModelRow>(&by_model_sql);
|
||||
for p in ¶ms {
|
||||
by_model_query = by_model_query.bind(p);
|
||||
}
|
||||
let by_model_rows = by_model_query.fetch_all(db).await?;
|
||||
let by_model_sql = "SELECT provider_id, model_id, COUNT(*)::bigint AS request_count, COALESCE(SUM(input_tokens), 0) AS input_tokens, COALESCE(SUM(output_tokens), 0) AS output_tokens
|
||||
FROM usage_records WHERE account_id = $1 AND ($2 IS NULL OR created_at >= $2) AND ($3 IS NULL OR created_at <= $3) AND ($4 IS NULL OR provider_id = $4) AND ($5 IS NULL OR model_id = $5) GROUP BY provider_id, model_id ORDER BY COUNT(*) DESC LIMIT 20";
|
||||
|
||||
let by_model_rows: Vec<UsageByModelRow> = sqlx::query_as(by_model_sql)
|
||||
.bind(account_id)
|
||||
.bind(&query.from)
|
||||
.bind(&query.to)
|
||||
.bind(&query.provider_id)
|
||||
.bind(&query.model_id)
|
||||
.fetch_all(db).await?;
|
||||
let by_model: Vec<ModelUsage> = by_model_rows.into_iter()
|
||||
.map(|r| {
|
||||
ModelUsage { provider_id: r.provider_id, model_id: r.model_id, request_count: r.request_count, input_tokens: r.input_tokens, output_tokens: r.output_tokens }
|
||||
|
||||
@@ -76,62 +76,30 @@ pub async fn get_template_by_name(db: &PgPool, name: &str) -> SaasResult<PromptT
|
||||
}
|
||||
|
||||
/// 列表模板
|
||||
/// Static SQL with conditional filter pattern: ($N IS NULL OR col = $N).
|
||||
pub async fn list_templates(
|
||||
db: &PgPool,
|
||||
query: &PromptListQuery,
|
||||
) -> SaasResult<PaginatedResponse<PromptTemplateInfo>> {
|
||||
let (page, page_size, offset) = normalize_pagination(query.page, query.page_size);
|
||||
|
||||
// 使用参数化查询构建,防止 SQL 注入
|
||||
let mut param_idx = 1usize;
|
||||
let mut conditions = Vec::new();
|
||||
let mut cat_bind: Option<String> = None;
|
||||
let mut src_bind: Option<String> = None;
|
||||
let mut status_bind: Option<String> = None;
|
||||
let count_sql = "SELECT COUNT(*) FROM prompt_templates WHERE ($1 IS NULL OR category = $1) AND ($2 IS NULL OR source = $2) AND ($3 IS NULL OR status = $3)";
|
||||
let data_sql = "SELECT id, name, category, description, source, current_version, status, created_at, updated_at \
|
||||
FROM prompt_templates WHERE ($1 IS NULL OR category = $1) AND ($2 IS NULL OR source = $2) AND ($3 IS NULL OR status = $3) ORDER BY updated_at DESC LIMIT $4 OFFSET $5";
|
||||
|
||||
if let Some(ref cat) = query.category {
|
||||
conditions.push(format!("category = ${}", param_idx));
|
||||
cat_bind = Some(cat.clone());
|
||||
param_idx += 1;
|
||||
}
|
||||
if let Some(ref src) = query.source {
|
||||
conditions.push(format!("source = ${}", param_idx));
|
||||
src_bind = Some(src.clone());
|
||||
param_idx += 1;
|
||||
}
|
||||
if let Some(ref st) = query.status {
|
||||
conditions.push(format!("status = ${}", param_idx));
|
||||
status_bind = Some(st.clone());
|
||||
param_idx += 1;
|
||||
}
|
||||
let total: i64 = sqlx::query_scalar(count_sql)
|
||||
.bind(&query.category)
|
||||
.bind(&query.source)
|
||||
.bind(&query.status)
|
||||
.fetch_one(db).await?;
|
||||
|
||||
let where_clause = if conditions.is_empty() {
|
||||
"1=1".to_string()
|
||||
} else {
|
||||
conditions.join(" AND ")
|
||||
};
|
||||
|
||||
let count_sql = format!("SELECT COUNT(*) FROM prompt_templates WHERE {}", where_clause);
|
||||
let data_sql = format!(
|
||||
"SELECT id, name, category, description, source, current_version, status, created_at, updated_at \
|
||||
FROM prompt_templates WHERE {} ORDER BY updated_at DESC LIMIT {} OFFSET {}",
|
||||
where_clause, page_size, offset
|
||||
);
|
||||
|
||||
// 动态绑定参数到 count 查询
|
||||
let mut count_query = sqlx::query_scalar::<_, i64>(&count_sql);
|
||||
if let Some(ref v) = cat_bind { count_query = count_query.bind(v); }
|
||||
if let Some(ref v) = src_bind { count_query = count_query.bind(v); }
|
||||
if let Some(ref v) = status_bind { count_query = count_query.bind(v); }
|
||||
let total = count_query.fetch_one(db).await?;
|
||||
|
||||
// 动态绑定参数到 data 查询
|
||||
let mut data_query = sqlx::query_as::<_, PromptTemplateRow>(&data_sql);
|
||||
if let Some(ref v) = cat_bind { data_query = data_query.bind(v); }
|
||||
if let Some(ref v) = src_bind { data_query = data_query.bind(v); }
|
||||
if let Some(ref v) = status_bind { data_query = data_query.bind(v); }
|
||||
data_query = data_query.bind(page_size as i64).bind(offset as i64);
|
||||
let rows = data_query.fetch_all(db).await?;
|
||||
let rows: Vec<PromptTemplateRow> = sqlx::query_as(data_sql)
|
||||
.bind(&query.category)
|
||||
.bind(&query.source)
|
||||
.bind(&query.status)
|
||||
.bind(page_size as i64)
|
||||
.bind(offset as i64)
|
||||
.fetch_all(db).await?;
|
||||
|
||||
let items: Vec<PromptTemplateInfo> = rows.into_iter().map(|r| {
|
||||
PromptTemplateInfo { id: r.id, name: r.name, category: r.category, description: r.description, source: r.source, current_version: r.current_version, status: r.status, created_at: r.created_at, updated_at: r.updated_at }
|
||||
|
||||
@@ -43,12 +43,13 @@ pub async fn chat_completions(
|
||||
}
|
||||
|
||||
// --- 输入验证 ---
|
||||
// 请求体大小限制 (1 MB)
|
||||
// 请求体大小限制 (1 MB) — 直接序列化一次,后续复用
|
||||
const MAX_BODY_BYTES: usize = 1024 * 1024;
|
||||
let estimated_size = serde_json::to_string(&req).map(|s| s.len()).unwrap_or(0);
|
||||
if estimated_size > MAX_BODY_BYTES {
|
||||
let request_body = serde_json::to_string(&req)
|
||||
.map_err(|e| SaasError::InvalidInput(format!("请求体序列化失败: {}", e)))?;
|
||||
if request_body.len() > MAX_BODY_BYTES {
|
||||
return Err(SaasError::InvalidInput(
|
||||
format!("请求体超过大小限制 ({} bytes > {} bytes)", estimated_size, MAX_BODY_BYTES)
|
||||
format!("请求体超过大小限制 ({} bytes > {} bytes)", request_body.len(), MAX_BODY_BYTES)
|
||||
));
|
||||
}
|
||||
|
||||
@@ -147,7 +148,7 @@ pub async fn chat_completions(
|
||||
return Err(SaasError::Forbidden(format!("Provider {} 已禁用", provider.name)));
|
||||
}
|
||||
|
||||
let request_body = serde_json::to_string(&req)?;
|
||||
// request_body 已在前面序列化并验证大小,直接复用
|
||||
|
||||
// 创建中转任务(提取配置后立即释放读锁)
|
||||
let (max_attempts, retry_delay_ms, enc_key) = {
|
||||
|
||||
@@ -185,10 +185,24 @@ pub async fn execute_relay(
|
||||
|
||||
let url = format!("{}/chat/completions", provider_base_url.trim_end_matches('/'));
|
||||
|
||||
let client = reqwest::Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(if stream { 300 } else { 30 }))
|
||||
.build()
|
||||
.map_err(|e| SaasError::Internal(format!("HTTP 客户端构建失败: {}", e)))?;
|
||||
// 复用全局 HTTP 客户端,避免每次请求重建 TLS 连接池和 DNS 解析器
|
||||
static SHORT_CLIENT: std::sync::OnceLock<reqwest::Client> = std::sync::OnceLock::new();
|
||||
static LONG_CLIENT: std::sync::OnceLock<reqwest::Client> = std::sync::OnceLock::new();
|
||||
let client = if stream {
|
||||
LONG_CLIENT.get_or_init(|| {
|
||||
reqwest::Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(300))
|
||||
.build()
|
||||
.expect("Failed to build long-timeout HTTP client")
|
||||
})
|
||||
} else {
|
||||
SHORT_CLIENT.get_or_init(|| {
|
||||
reqwest::Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(30))
|
||||
.build()
|
||||
.expect("Failed to build short-timeout HTTP client")
|
||||
})
|
||||
};
|
||||
|
||||
let max_attempts = max_attempts.max(1).min(5);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user