feat: 增强SaaS后端功能与安全性
refactor: 重构数据库连接使用PostgreSQL替代SQLite feat(auth): 增加JWT验证的audience和issuer检查 feat(crypto): 添加AES-256-GCM字段加密支持 feat(api): 集成utoipa实现OpenAPI文档 fix(admin): 修复配置项表单验证逻辑 style: 统一代码格式与类型定义 docs: 更新技术栈文档说明PostgreSQL
This commit is contained in:
@@ -9,8 +9,6 @@ name = "zclaw-saas"
|
||||
path = "src/main.rs"
|
||||
|
||||
[dependencies]
|
||||
zclaw-types = { workspace = true }
|
||||
|
||||
tokio = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
@@ -23,7 +21,6 @@ chrono = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
tracing-subscriber = { workspace = true }
|
||||
sqlx = { workspace = true }
|
||||
libsqlite3-sys = { workspace = true }
|
||||
reqwest = { workspace = true }
|
||||
secrecy = { workspace = true }
|
||||
sha2 = { workspace = true }
|
||||
@@ -34,6 +31,8 @@ url = "2"
|
||||
|
||||
axum = { workspace = true }
|
||||
axum-extra = { workspace = true }
|
||||
bytes = { workspace = true }
|
||||
async-stream = { workspace = true }
|
||||
tower = { workspace = true }
|
||||
tower-http = { workspace = true }
|
||||
jsonwebtoken = { workspace = true }
|
||||
@@ -41,6 +40,9 @@ argon2 = { workspace = true }
|
||||
totp-rs = { workspace = true }
|
||||
urlencoding = "2"
|
||||
data-encoding = "2"
|
||||
aes-gcm = { workspace = true }
|
||||
utoipa = { version = "5", features = ["axum_extras"] }
|
||||
utoipa-swagger-ui = { version = "5", features = ["axum"] }
|
||||
|
||||
[dev-dependencies]
|
||||
tempfile = { workspace = true }
|
||||
|
||||
@@ -121,23 +121,43 @@ pub async fn list_operation_logs(
|
||||
let page: i64 = params.get("page").and_then(|v| v.parse().ok()).unwrap_or(1);
|
||||
let page_size: i64 = params.get("page_size").and_then(|v| v.parse().ok()).unwrap_or(50);
|
||||
let offset = (page - 1) * page_size;
|
||||
let action_filter = params.get("action").map(|s| s.as_str());
|
||||
let target_type_filter = params.get("target_type").map(|s| s.as_str());
|
||||
|
||||
let rows: Vec<(i64, Option<String>, String, Option<String>, Option<String>, Option<String>, Option<String>, String)> =
|
||||
sqlx::query_as(
|
||||
"SELECT id, account_id, action, target_type, target_id, details, ip_address, created_at
|
||||
FROM operation_logs ORDER BY created_at DESC LIMIT ?1 OFFSET ?2"
|
||||
)
|
||||
.bind(page_size)
|
||||
.bind(offset)
|
||||
.fetch_all(&state.db)
|
||||
.await?;
|
||||
let mut sql = String::from(
|
||||
"SELECT id, account_id, action, target_type, target_id, details, ip_address, created_at
|
||||
FROM operation_logs"
|
||||
);
|
||||
let mut param_idx: usize = 1;
|
||||
if action_filter.is_some() || target_type_filter.is_some() {
|
||||
sql.push_str(" WHERE 1=1");
|
||||
if action_filter.is_some() {
|
||||
sql.push_str(&format!(" AND action = ${}", param_idx));
|
||||
param_idx += 1;
|
||||
}
|
||||
if target_type_filter.is_some() {
|
||||
sql.push_str(&format!(" AND target_type = ${}", param_idx));
|
||||
param_idx += 1;
|
||||
}
|
||||
}
|
||||
sql.push_str(&format!(" ORDER BY created_at DESC LIMIT ${} OFFSET ${}", param_idx, param_idx + 1));
|
||||
|
||||
let mut query = sqlx::query_as::<_, (i64, Option<String>, String, Option<String>, Option<String>, Option<String>, Option<String>, chrono::DateTime<chrono::Utc>)>(&sql);
|
||||
if let Some(action) = action_filter {
|
||||
query = query.bind(action);
|
||||
}
|
||||
if let Some(target_type) = target_type_filter {
|
||||
query = query.bind(target_type);
|
||||
}
|
||||
query = query.bind(page_size).bind(offset);
|
||||
let rows = query.fetch_all(&state.db).await?;
|
||||
|
||||
let items: Vec<serde_json::Value> = rows.into_iter().map(|(id, account_id, action, target_type, target_id, details, ip_address, created_at)| {
|
||||
serde_json::json!({
|
||||
"id": id, "account_id": account_id, "action": action,
|
||||
"target_type": target_type, "target_id": target_id,
|
||||
"details": details.and_then(|d| serde_json::from_str::<serde_json::Value>(&d).ok()),
|
||||
"ip_address": ip_address, "created_at": created_at,
|
||||
"ip_address": ip_address, "created_at": created_at.to_rfc3339(),
|
||||
})
|
||||
}).collect();
|
||||
|
||||
@@ -151,32 +171,27 @@ pub async fn dashboard_stats(
|
||||
) -> SaasResult<Json<serde_json::Value>> {
|
||||
require_admin(&ctx)?;
|
||||
|
||||
let total_accounts: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM accounts")
|
||||
.fetch_one(&state.db).await?;
|
||||
let active_accounts: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM accounts WHERE status = 'active'")
|
||||
.fetch_one(&state.db).await?;
|
||||
let tasks_today: (i64,) = sqlx::query_as(
|
||||
"SELECT COUNT(*) FROM relay_tasks WHERE date(created_at) = date('now')"
|
||||
).fetch_one(&state.db).await?;
|
||||
let active_providers: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM providers WHERE enabled = 1")
|
||||
.fetch_one(&state.db).await?;
|
||||
let active_models: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM models WHERE enabled = 1")
|
||||
.fetch_one(&state.db).await?;
|
||||
let tokens_today_input: (i64,) = sqlx::query_as(
|
||||
"SELECT COALESCE(SUM(input_tokens), 0) FROM usage_records WHERE date(created_at) = date('now')"
|
||||
).fetch_one(&state.db).await?;
|
||||
let tokens_today_output: (i64,) = sqlx::query_as(
|
||||
"SELECT COALESCE(SUM(output_tokens), 0) FROM usage_records WHERE date(created_at) = date('now')"
|
||||
).fetch_one(&state.db).await?;
|
||||
let row: (i64, i64, i64, i64, i64, i64, i64) = sqlx::query_as(
|
||||
"SELECT
|
||||
(SELECT COUNT(*) FROM accounts),
|
||||
(SELECT COUNT(*) FROM accounts WHERE status = 'active'),
|
||||
(SELECT COUNT(*) FROM relay_tasks WHERE DATE(created_at) = CURRENT_DATE),
|
||||
(SELECT COUNT(*) FROM providers WHERE enabled = true),
|
||||
(SELECT COUNT(*) FROM models WHERE enabled = true),
|
||||
(SELECT COALESCE(SUM(input_tokens), 0) FROM usage_records WHERE DATE(created_at) = CURRENT_DATE),
|
||||
(SELECT COALESCE(SUM(output_tokens), 0) FROM usage_records WHERE DATE(created_at) = CURRENT_DATE)"
|
||||
)
|
||||
.fetch_one(&state.db)
|
||||
.await?;
|
||||
|
||||
Ok(Json(serde_json::json!({
|
||||
"total_accounts": total_accounts.0,
|
||||
"active_accounts": active_accounts.0,
|
||||
"tasks_today": tasks_today.0,
|
||||
"active_providers": active_providers.0,
|
||||
"active_models": active_models.0,
|
||||
"tokens_today_input": tokens_today_input.0,
|
||||
"tokens_today_output": tokens_today_output.0,
|
||||
"total_accounts": row.0,
|
||||
"active_accounts": row.1,
|
||||
"tasks_today": row.2,
|
||||
"active_providers": row.3,
|
||||
"active_models": row.4,
|
||||
"tokens_today_input": row.5,
|
||||
"tokens_today_output": row.6,
|
||||
})))
|
||||
}
|
||||
|
||||
@@ -186,59 +201,48 @@ pub async fn dashboard_stats(
|
||||
pub async fn register_device(
|
||||
State(state): State<AppState>,
|
||||
Extension(ctx): Extension<AuthContext>,
|
||||
Json(req): Json<serde_json::Value>,
|
||||
Json(req): Json<super::types::RegisterDeviceRequest>,
|
||||
) -> SaasResult<Json<serde_json::Value>> {
|
||||
let device_id = req.get("device_id")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| SaasError::InvalidInput("缺少 device_id".into()))?;
|
||||
let device_name = req.get("device_name").and_then(|v| v.as_str()).unwrap_or("Unknown");
|
||||
let platform = req.get("platform").and_then(|v| v.as_str()).unwrap_or("unknown");
|
||||
let app_version = req.get("app_version").and_then(|v| v.as_str()).unwrap_or("");
|
||||
|
||||
let now = chrono::Utc::now().to_rfc3339();
|
||||
let now = chrono::Utc::now();
|
||||
let device_uuid = uuid::Uuid::new_v4().to_string();
|
||||
|
||||
// UPSERT: 已存在则更新 last_seen_at,不存在则插入
|
||||
sqlx::query(
|
||||
"INSERT INTO devices (id, account_id, device_id, device_name, platform, app_version, last_seen_at, created_at)
|
||||
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?7)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $7)
|
||||
ON CONFLICT(account_id, device_id) DO UPDATE SET
|
||||
device_name = ?4, platform = ?5, app_version = ?6, last_seen_at = ?7"
|
||||
device_name = EXCLUDED.device_name, platform = EXCLUDED.platform, app_version = EXCLUDED.app_version, last_seen_at = EXCLUDED.last_seen_at"
|
||||
)
|
||||
.bind(&device_uuid)
|
||||
.bind(&ctx.account_id)
|
||||
.bind(device_id)
|
||||
.bind(device_name)
|
||||
.bind(platform)
|
||||
.bind(app_version)
|
||||
.bind(&req.device_id)
|
||||
.bind(&req.device_name)
|
||||
.bind(&req.platform)
|
||||
.bind(&req.app_version)
|
||||
.bind(&now)
|
||||
.execute(&state.db)
|
||||
.await?;
|
||||
|
||||
log_operation(&state.db, &ctx.account_id, "device.register", "device", device_id,
|
||||
Some(serde_json::json!({"device_name": device_name, "platform": platform})),
|
||||
log_operation(&state.db, &ctx.account_id, "device.register", "device", &req.device_id,
|
||||
Some(serde_json::json!({"device_name": req.device_name, "platform": req.platform})),
|
||||
ctx.client_ip.as_deref()).await?;
|
||||
|
||||
Ok(Json(serde_json::json!({"ok": true, "device_id": device_id})))
|
||||
Ok(Json(serde_json::json!({"ok": true, "device_id": req.device_id})))
|
||||
}
|
||||
|
||||
/// POST /api/v1/devices/heartbeat — 设备心跳
|
||||
pub async fn device_heartbeat(
|
||||
State(state): State<AppState>,
|
||||
Extension(ctx): Extension<AuthContext>,
|
||||
Json(req): Json<serde_json::Value>,
|
||||
Json(req): Json<super::types::DeviceHeartbeatRequest>,
|
||||
) -> SaasResult<Json<serde_json::Value>> {
|
||||
let device_id = req.get("device_id")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| SaasError::InvalidInput("缺少 device_id".into()))?;
|
||||
|
||||
let now = chrono::Utc::now().to_rfc3339();
|
||||
let now = chrono::Utc::now();
|
||||
let result = sqlx::query(
|
||||
"UPDATE devices SET last_seen_at = ?1 WHERE account_id = ?2 AND device_id = ?3"
|
||||
"UPDATE devices SET last_seen_at = $1 WHERE account_id = $2 AND device_id = $3"
|
||||
)
|
||||
.bind(&now)
|
||||
.bind(&ctx.account_id)
|
||||
.bind(device_id)
|
||||
.bind(&req.device_id)
|
||||
.execute(&state.db)
|
||||
.await?;
|
||||
|
||||
@@ -253,22 +257,22 @@ pub async fn device_heartbeat(
|
||||
pub async fn list_devices(
|
||||
State(state): State<AppState>,
|
||||
Extension(ctx): Extension<AuthContext>,
|
||||
) -> SaasResult<Json<Vec<serde_json::Value>>> {
|
||||
let rows: Vec<(String, String, Option<String>, Option<String>, Option<String>, String, String)> =
|
||||
) -> SaasResult<Json<Vec<super::types::DeviceInfo>>> {
|
||||
let rows: Vec<(String, String, Option<String>, Option<String>, Option<String>, chrono::DateTime<chrono::Utc>, chrono::DateTime<chrono::Utc>)> =
|
||||
sqlx::query_as(
|
||||
"SELECT id, device_id, device_name, platform, app_version, last_seen_at, created_at
|
||||
FROM devices WHERE account_id = ?1 ORDER BY last_seen_at DESC"
|
||||
FROM devices WHERE account_id = $1 ORDER BY last_seen_at DESC"
|
||||
)
|
||||
.bind(&ctx.account_id)
|
||||
.fetch_all(&state.db)
|
||||
.await?;
|
||||
|
||||
let items: Vec<serde_json::Value> = rows.into_iter().map(|r| {
|
||||
serde_json::json!({
|
||||
"id": r.0, "device_id": r.1,
|
||||
"device_name": r.2, "platform": r.3, "app_version": r.4,
|
||||
"last_seen_at": r.5, "created_at": r.6,
|
||||
})
|
||||
let items: Vec<super::types::DeviceInfo> = rows.into_iter().map(|r| {
|
||||
super::types::DeviceInfo {
|
||||
id: r.0, device_id: r.1,
|
||||
device_name: r.2, platform: r.3, app_version: r.4,
|
||||
last_seen_at: r.5.to_rfc3339(), created_at: r.6.to_rfc3339(),
|
||||
}
|
||||
}).collect();
|
||||
|
||||
Ok(Json(items))
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
//! 账号管理业务逻辑
|
||||
|
||||
use sqlx::SqlitePool;
|
||||
use sqlx::PgPool;
|
||||
use crate::error::{SaasError, SaasResult};
|
||||
use super::types::*;
|
||||
|
||||
pub async fn list_accounts(
|
||||
db: &SqlitePool,
|
||||
db: &PgPool,
|
||||
query: &ListAccountsQuery,
|
||||
) -> SaasResult<PaginatedResponse<serde_json::Value>> {
|
||||
let page = query.page.unwrap_or(1).max(1);
|
||||
@@ -14,21 +14,25 @@ pub async fn list_accounts(
|
||||
|
||||
let mut where_clauses = Vec::new();
|
||||
let mut params: Vec<String> = Vec::new();
|
||||
let mut param_idx: usize = 1;
|
||||
|
||||
if let Some(role) = &query.role {
|
||||
where_clauses.push("role = ?".to_string());
|
||||
where_clauses.push(format!("role = ${}", param_idx));
|
||||
params.push(role.clone());
|
||||
param_idx += 1;
|
||||
}
|
||||
if let Some(status) = &query.status {
|
||||
where_clauses.push("status = ?".to_string());
|
||||
where_clauses.push(format!("status = ${}", param_idx));
|
||||
params.push(status.clone());
|
||||
param_idx += 1;
|
||||
}
|
||||
if let Some(search) = &query.search {
|
||||
where_clauses.push("(username LIKE ? OR email LIKE ? OR display_name LIKE ?)".to_string());
|
||||
where_clauses.push(format!("(username LIKE ${} OR email LIKE ${} OR display_name LIKE ${})", param_idx, param_idx + 1, param_idx + 2));
|
||||
let pattern = format!("%{}%", search);
|
||||
params.push(pattern.clone());
|
||||
params.push(pattern.clone());
|
||||
params.push(pattern);
|
||||
param_idx += 3;
|
||||
}
|
||||
|
||||
let where_sql = if where_clauses.is_empty() {
|
||||
@@ -46,10 +50,10 @@ pub async fn list_accounts(
|
||||
|
||||
let data_sql = format!(
|
||||
"SELECT id, username, email, display_name, role, status, totp_enabled, last_login_at, created_at
|
||||
FROM accounts {} ORDER BY created_at DESC LIMIT ? OFFSET ?",
|
||||
where_sql
|
||||
FROM accounts {} ORDER BY created_at DESC LIMIT ${} OFFSET ${}",
|
||||
where_sql, param_idx, param_idx + 1
|
||||
);
|
||||
let mut data_query = sqlx::query_as::<_, (String, String, String, String, String, String, bool, Option<String>, String)>(&data_sql);
|
||||
let mut data_query = sqlx::query_as::<_, (String, String, String, String, String, String, bool, Option<chrono::DateTime<chrono::Utc>>, chrono::DateTime<chrono::Utc>)>(&data_sql);
|
||||
for p in ¶ms {
|
||||
data_query = data_query.bind(p);
|
||||
}
|
||||
@@ -61,7 +65,7 @@ pub async fn list_accounts(
|
||||
serde_json::json!({
|
||||
"id": id, "username": username, "email": email, "display_name": display_name,
|
||||
"role": role, "status": status, "totp_enabled": totp_enabled,
|
||||
"last_login_at": last_login_at, "created_at": created_at,
|
||||
"last_login_at": last_login_at.map(|t| t.to_rfc3339()), "created_at": created_at.to_rfc3339(),
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
@@ -69,11 +73,11 @@ pub async fn list_accounts(
|
||||
Ok(PaginatedResponse { items, total, page, page_size })
|
||||
}
|
||||
|
||||
pub async fn get_account(db: &SqlitePool, account_id: &str) -> SaasResult<serde_json::Value> {
|
||||
let row: Option<(String, String, String, String, String, String, bool, Option<String>, String)> =
|
||||
pub async fn get_account(db: &PgPool, account_id: &str) -> SaasResult<serde_json::Value> {
|
||||
let row: Option<(String, String, String, String, String, String, bool, Option<chrono::DateTime<chrono::Utc>>, chrono::DateTime<chrono::Utc>)> =
|
||||
sqlx::query_as(
|
||||
"SELECT id, username, email, display_name, role, status, totp_enabled, last_login_at, created_at
|
||||
FROM accounts WHERE id = ?1"
|
||||
FROM accounts WHERE id = $1"
|
||||
)
|
||||
.bind(account_id)
|
||||
.fetch_optional(db)
|
||||
@@ -85,43 +89,45 @@ pub async fn get_account(db: &SqlitePool, account_id: &str) -> SaasResult<serde_
|
||||
Ok(serde_json::json!({
|
||||
"id": id, "username": username, "email": email, "display_name": display_name,
|
||||
"role": role, "status": status, "totp_enabled": totp_enabled,
|
||||
"last_login_at": last_login_at, "created_at": created_at,
|
||||
"last_login_at": last_login_at.map(|t| t.to_rfc3339()), "created_at": created_at.to_rfc3339(),
|
||||
}))
|
||||
}
|
||||
|
||||
pub async fn update_account(
|
||||
db: &SqlitePool,
|
||||
db: &PgPool,
|
||||
account_id: &str,
|
||||
req: &UpdateAccountRequest,
|
||||
) -> SaasResult<serde_json::Value> {
|
||||
let now = chrono::Utc::now().to_rfc3339();
|
||||
let now = chrono::Utc::now();
|
||||
let mut updates = Vec::new();
|
||||
let mut params: Vec<String> = Vec::new();
|
||||
let mut param_idx: usize = 1;
|
||||
|
||||
if let Some(ref v) = req.display_name { updates.push("display_name = ?"); params.push(v.clone()); }
|
||||
if let Some(ref v) = req.email { updates.push("email = ?"); params.push(v.clone()); }
|
||||
if let Some(ref v) = req.role { updates.push("role = ?"); params.push(v.clone()); }
|
||||
if let Some(ref v) = req.avatar_url { updates.push("avatar_url = ?"); params.push(v.clone()); }
|
||||
if let Some(ref v) = req.display_name { updates.push(format!("display_name = ${}", param_idx)); params.push(v.clone()); param_idx += 1; }
|
||||
if let Some(ref v) = req.email { updates.push(format!("email = ${}", param_idx)); params.push(v.clone()); param_idx += 1; }
|
||||
if let Some(ref v) = req.role { updates.push(format!("role = ${}", param_idx)); params.push(v.clone()); param_idx += 1; }
|
||||
if let Some(ref v) = req.avatar_url { updates.push(format!("avatar_url = ${}", param_idx)); params.push(v.clone()); param_idx += 1; }
|
||||
|
||||
if updates.is_empty() {
|
||||
return get_account(db, account_id).await;
|
||||
}
|
||||
|
||||
updates.push("updated_at = ?");
|
||||
params.push(now.clone());
|
||||
updates.push(format!("updated_at = ${}", param_idx));
|
||||
param_idx += 1;
|
||||
params.push(account_id.to_string());
|
||||
|
||||
let sql = format!("UPDATE accounts SET {} WHERE id = ?", updates.join(", "));
|
||||
let sql = format!("UPDATE accounts SET {} WHERE id = ${}", updates.join(", "), param_idx);
|
||||
let mut query = sqlx::query(&sql);
|
||||
for p in ¶ms {
|
||||
query = query.bind(p);
|
||||
}
|
||||
query = query.bind(now);
|
||||
query.execute(db).await?;
|
||||
get_account(db, account_id).await
|
||||
}
|
||||
|
||||
pub async fn update_account_status(
|
||||
db: &SqlitePool,
|
||||
db: &PgPool,
|
||||
account_id: &str,
|
||||
status: &str,
|
||||
) -> SaasResult<()> {
|
||||
@@ -129,8 +135,8 @@ pub async fn update_account_status(
|
||||
if !valid.contains(&status) {
|
||||
return Err(SaasError::InvalidInput(format!("无效状态: {},有效值: {:?}", status, valid)));
|
||||
}
|
||||
let now = chrono::Utc::now().to_rfc3339();
|
||||
let result = sqlx::query("UPDATE accounts SET status = ?1, updated_at = ?2 WHERE id = ?3")
|
||||
let now = chrono::Utc::now();
|
||||
let result = sqlx::query("UPDATE accounts SET status = $1, updated_at = $2 WHERE id = $3")
|
||||
.bind(status).bind(&now).bind(account_id)
|
||||
.execute(db).await?;
|
||||
|
||||
@@ -141,7 +147,7 @@ pub async fn update_account_status(
|
||||
}
|
||||
|
||||
pub async fn create_api_token(
|
||||
db: &SqlitePool,
|
||||
db: &PgPool,
|
||||
account_id: &str,
|
||||
req: &CreateTokenRequest,
|
||||
) -> SaasResult<TokenInfo> {
|
||||
@@ -154,16 +160,18 @@ pub async fn create_api_token(
|
||||
let token_hash = hex::encode(Sha256::digest(raw_token.as_bytes()));
|
||||
let token_prefix = raw_token[..8].to_string();
|
||||
|
||||
let now = chrono::Utc::now().to_rfc3339();
|
||||
let now = chrono::Utc::now();
|
||||
let now_str = now.to_rfc3339();
|
||||
let expires_at = req.expires_days.map(|d| {
|
||||
(chrono::Utc::now() + chrono::Duration::days(d)).to_rfc3339()
|
||||
chrono::Utc::now() + chrono::Duration::days(d)
|
||||
});
|
||||
let expires_at_str = expires_at.as_ref().map(|t| t.to_rfc3339());
|
||||
let permissions = serde_json::to_string(&req.permissions)?;
|
||||
let token_id = uuid::Uuid::new_v4().to_string();
|
||||
|
||||
sqlx::query(
|
||||
"INSERT INTO api_tokens (id, account_id, name, token_hash, token_prefix, permissions, created_at, expires_at)
|
||||
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)"
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)"
|
||||
)
|
||||
.bind(&token_id)
|
||||
.bind(account_id)
|
||||
@@ -182,20 +190,20 @@ pub async fn create_api_token(
|
||||
token_prefix,
|
||||
permissions: req.permissions.clone(),
|
||||
last_used_at: None,
|
||||
expires_at,
|
||||
created_at: now,
|
||||
expires_at: expires_at_str,
|
||||
created_at: now_str,
|
||||
token: Some(raw_token),
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn list_api_tokens(
|
||||
db: &SqlitePool,
|
||||
db: &PgPool,
|
||||
account_id: &str,
|
||||
) -> SaasResult<Vec<TokenInfo>> {
|
||||
let rows: Vec<(String, String, String, String, Option<String>, Option<String>, String)> =
|
||||
let rows: Vec<(String, String, String, String, Option<chrono::DateTime<chrono::Utc>>, Option<chrono::DateTime<chrono::Utc>>, chrono::DateTime<chrono::Utc>)> =
|
||||
sqlx::query_as(
|
||||
"SELECT id, name, token_prefix, permissions, last_used_at, expires_at, created_at
|
||||
FROM api_tokens WHERE account_id = ?1 AND revoked_at IS NULL ORDER BY created_at DESC"
|
||||
FROM api_tokens WHERE account_id = $1 AND revoked_at IS NULL ORDER BY created_at DESC"
|
||||
)
|
||||
.bind(account_id)
|
||||
.fetch_all(db)
|
||||
@@ -203,14 +211,14 @@ pub async fn list_api_tokens(
|
||||
|
||||
Ok(rows.into_iter().map(|(id, name, token_prefix, perms, last_used, expires, created)| {
|
||||
let permissions: Vec<String> = serde_json::from_str(&perms).unwrap_or_default();
|
||||
TokenInfo { id, name, token_prefix, permissions, last_used_at: last_used, expires_at: expires, created_at: created, token: None, }
|
||||
TokenInfo { id, name, token_prefix, permissions, last_used_at: last_used.map(|t| t.to_rfc3339()), expires_at: expires.map(|t| t.to_rfc3339()), created_at: created.to_rfc3339(), token: None, }
|
||||
}).collect())
|
||||
}
|
||||
|
||||
pub async fn revoke_api_token(db: &SqlitePool, token_id: &str, account_id: &str) -> SaasResult<()> {
|
||||
let now = chrono::Utc::now().to_rfc3339();
|
||||
pub async fn revoke_api_token(db: &PgPool, token_id: &str, account_id: &str) -> SaasResult<()> {
|
||||
let now = chrono::Utc::now();
|
||||
let result = sqlx::query(
|
||||
"UPDATE api_tokens SET revoked_at = ?1 WHERE id = ?2 AND account_id = ?3 AND revoked_at IS NULL"
|
||||
"UPDATE api_tokens SET revoked_at = $1 WHERE id = $2 AND account_id = $3 AND revoked_at IS NULL"
|
||||
)
|
||||
.bind(&now).bind(token_id).bind(account_id)
|
||||
.execute(db).await?;
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[derive(Debug, Deserialize, utoipa::ToSchema)]
|
||||
pub struct UpdateAccountRequest {
|
||||
pub display_name: Option<String>,
|
||||
pub email: Option<String>,
|
||||
@@ -10,12 +10,12 @@ pub struct UpdateAccountRequest {
|
||||
pub avatar_url: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[derive(Debug, Deserialize, utoipa::ToSchema)]
|
||||
pub struct UpdateStatusRequest {
|
||||
pub status: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[derive(Debug, Deserialize, utoipa::ToSchema, utoipa::IntoParams)]
|
||||
pub struct ListAccountsQuery {
|
||||
pub page: Option<u32>,
|
||||
pub page_size: Option<u32>,
|
||||
@@ -32,14 +32,28 @@ pub struct PaginatedResponse<T: Serialize> {
|
||||
pub page_size: u32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
/// Concrete type alias for OpenAPI schema generation.
|
||||
///
|
||||
/// NOTE: This is intentionally a concrete (non-generic) type because utoipa
|
||||
/// requires concrete types for schema generation. It is functionally
|
||||
/// identical to `Paginated<AccountPublic>`.
|
||||
#[derive(Debug, Serialize, utoipa::ToSchema)]
|
||||
#[allow(clippy::manual_non_exhaustive)] // kept for OpenAPI schema
|
||||
pub struct AccountPublicPaginatedResponse {
|
||||
pub items: Vec<crate::auth::types::AccountPublic>,
|
||||
pub total: i64,
|
||||
pub page: u32,
|
||||
pub page_size: u32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, utoipa::ToSchema)]
|
||||
pub struct CreateTokenRequest {
|
||||
pub name: String,
|
||||
pub permissions: Vec<String>,
|
||||
pub expires_days: Option<i64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[derive(Debug, Serialize, utoipa::ToSchema)]
|
||||
pub struct TokenInfo {
|
||||
pub id: String,
|
||||
pub name: String,
|
||||
@@ -51,3 +65,35 @@ pub struct TokenInfo {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub token: Option<String>,
|
||||
}
|
||||
|
||||
// ============ Device Types ============
|
||||
|
||||
#[derive(Debug, Deserialize, utoipa::ToSchema)]
|
||||
pub struct RegisterDeviceRequest {
|
||||
pub device_id: String,
|
||||
#[serde(default = "default_device_name")]
|
||||
pub device_name: String,
|
||||
#[serde(default = "default_platform")]
|
||||
pub platform: String,
|
||||
#[serde(default)]
|
||||
pub app_version: String,
|
||||
}
|
||||
|
||||
fn default_device_name() -> String { "Unknown".into() }
|
||||
fn default_platform() -> String { "unknown".into() }
|
||||
|
||||
#[derive(Debug, Deserialize, utoipa::ToSchema)]
|
||||
pub struct DeviceHeartbeatRequest {
|
||||
pub device_id: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, utoipa::ToSchema)]
|
||||
pub struct DeviceInfo {
|
||||
pub id: String,
|
||||
pub device_id: String,
|
||||
pub device_name: Option<String>,
|
||||
pub platform: Option<String>,
|
||||
pub app_version: Option<String>,
|
||||
pub last_seen_at: String,
|
||||
pub created_at: String,
|
||||
}
|
||||
|
||||
@@ -16,16 +16,24 @@ pub async fn register(
|
||||
State(state): State<AppState>,
|
||||
ConnectInfo(addr): ConnectInfo<SocketAddr>,
|
||||
Json(req): Json<RegisterRequest>,
|
||||
) -> SaasResult<(StatusCode, Json<AccountPublic>)> {
|
||||
if req.username.len() < 3 {
|
||||
return Err(SaasError::InvalidInput("用户名至少 3 个字符".into()));
|
||||
) -> SaasResult<(StatusCode, Json<LoginResponse>)> {
|
||||
// 4.6: 用户名格式验证 — 3-32 字符,仅允许字母数字下划线
|
||||
if req.username.len() < 3 || req.username.len() > 32 {
|
||||
return Err(SaasError::InvalidInput("用户名长度需在 3-32 个字符之间".into()));
|
||||
}
|
||||
if !req.username.chars().all(|c| c.is_alphanumeric() || c == '_') {
|
||||
return Err(SaasError::InvalidInput("用户名仅允许字母、数字和下划线".into()));
|
||||
}
|
||||
// 4.7: 邮箱格式验证
|
||||
if !req.email.contains('@') || !req.email.split('@').nth(1).map_or(false, |d| d.contains('.')) {
|
||||
return Err(SaasError::InvalidInput("邮箱格式不正确".into()));
|
||||
}
|
||||
if req.password.len() < 8 {
|
||||
return Err(SaasError::InvalidInput("密码至少 8 个字符".into()));
|
||||
}
|
||||
|
||||
let existing: Vec<(String,)> = sqlx::query_as(
|
||||
"SELECT id FROM accounts WHERE username = ?1 OR email = ?2"
|
||||
"SELECT id FROM accounts WHERE username = $1 OR email = $2"
|
||||
)
|
||||
.bind(&req.username)
|
||||
.bind(&req.email)
|
||||
@@ -40,11 +48,11 @@ pub async fn register(
|
||||
let account_id = uuid::Uuid::new_v4().to_string();
|
||||
let role = "user".to_string(); // 注册固定为普通用户,角色由管理员分配
|
||||
let display_name = req.display_name.unwrap_or_default();
|
||||
let now = chrono::Utc::now().to_rfc3339();
|
||||
let now = chrono::Utc::now();
|
||||
|
||||
sqlx::query(
|
||||
"INSERT INTO accounts (id, username, email, password_hash, display_name, role, status, created_at, updated_at)
|
||||
VALUES (?1, ?2, ?3, ?4, ?5, ?6, 'active', ?7, ?7)"
|
||||
VALUES ($1, $2, $3, $4, $5, $6, 'active', $7, $7)"
|
||||
)
|
||||
.bind(&account_id)
|
||||
.bind(&req.username)
|
||||
@@ -52,22 +60,33 @@ pub async fn register(
|
||||
.bind(&password_hash)
|
||||
.bind(&display_name)
|
||||
.bind(&role)
|
||||
.bind(&now)
|
||||
.bind(now)
|
||||
.execute(&state.db)
|
||||
.await?;
|
||||
|
||||
let client_ip = addr.ip().to_string();
|
||||
log_operation(&state.db, &account_id, "account.create", "account", &account_id, None, Some(&client_ip)).await?;
|
||||
|
||||
Ok((StatusCode::CREATED, Json(AccountPublic {
|
||||
id: account_id,
|
||||
username: req.username,
|
||||
email: req.email,
|
||||
display_name,
|
||||
role,
|
||||
status: "active".into(),
|
||||
totp_enabled: false,
|
||||
created_at: now,
|
||||
// Generate JWT token for auto-login after registration
|
||||
let config = state.config.read().await;
|
||||
let token = create_token(
|
||||
&account_id, &role, vec![],
|
||||
state.jwt_secret.expose_secret(), config.auth.jwt_expiration_hours,
|
||||
)?;
|
||||
|
||||
Ok((StatusCode::CREATED, Json(LoginResponse {
|
||||
token,
|
||||
account: AccountPublic {
|
||||
id: account_id,
|
||||
username: req.username,
|
||||
email: req.email,
|
||||
display_name,
|
||||
role,
|
||||
permissions: vec![],
|
||||
status: "active".into(),
|
||||
totp_enabled: false,
|
||||
created_at: now.to_rfc3339(),
|
||||
},
|
||||
})))
|
||||
}
|
||||
|
||||
@@ -77,10 +96,10 @@ pub async fn login(
|
||||
ConnectInfo(addr): ConnectInfo<SocketAddr>,
|
||||
Json(req): Json<LoginRequest>,
|
||||
) -> SaasResult<Json<LoginResponse>> {
|
||||
let row: Option<(String, String, String, String, String, String, bool, String)> =
|
||||
let row: Option<(String, String, String, String, String, String, bool, chrono::DateTime<chrono::Utc>)> =
|
||||
sqlx::query_as(
|
||||
"SELECT id, username, email, display_name, role, status, totp_enabled, created_at
|
||||
FROM accounts WHERE username = ?1 OR email = ?1"
|
||||
FROM accounts WHERE username = $1 OR email = $1"
|
||||
)
|
||||
.bind(&req.username)
|
||||
.fetch_optional(&state.db)
|
||||
@@ -88,13 +107,14 @@ pub async fn login(
|
||||
|
||||
let (id, username, email, display_name, role, status, totp_enabled, created_at) =
|
||||
row.ok_or_else(|| SaasError::AuthError("用户名或密码错误".into()))?;
|
||||
let created_at = created_at.to_rfc3339();
|
||||
|
||||
if status != "active" {
|
||||
return Err(SaasError::Forbidden(format!("账号已{},请联系管理员", status)));
|
||||
}
|
||||
|
||||
let (password_hash,): (String,) = sqlx::query_as(
|
||||
"SELECT password_hash FROM accounts WHERE id = ?1"
|
||||
"SELECT password_hash FROM accounts WHERE id = $1"
|
||||
)
|
||||
.bind(&id)
|
||||
.fetch_one(&state.db)
|
||||
@@ -110,7 +130,7 @@ pub async fn login(
|
||||
.ok_or_else(|| SaasError::Totp("此账号已启用双因素认证,请提供 TOTP 码".into()))?;
|
||||
|
||||
let (totp_secret,): (Option<String>,) = sqlx::query_as(
|
||||
"SELECT totp_secret FROM accounts WHERE id = ?1"
|
||||
"SELECT totp_secret FROM accounts WHERE id = $1"
|
||||
)
|
||||
.bind(&id)
|
||||
.fetch_one(&state.db)
|
||||
@@ -120,7 +140,10 @@ pub async fn login(
|
||||
SaasError::Internal("TOTP 已启用但密钥丢失,请联系管理员".into())
|
||||
})?;
|
||||
|
||||
if !super::totp::verify_totp_code(&secret, code) {
|
||||
// 解密 TOTP 密钥(兼容迁移期间的明文数据)
|
||||
let decrypted_secret = state.field_encryption.decrypt_or_plaintext(&secret);
|
||||
|
||||
if !super::totp::verify_totp_code(&decrypted_secret, code) {
|
||||
return Err(SaasError::Totp("TOTP 码错误或已过期".into()));
|
||||
}
|
||||
}
|
||||
@@ -133,9 +156,9 @@ pub async fn login(
|
||||
config.auth.jwt_expiration_hours,
|
||||
)?;
|
||||
|
||||
let now = chrono::Utc::now().to_rfc3339();
|
||||
sqlx::query("UPDATE accounts SET last_login_at = ?1 WHERE id = ?2")
|
||||
.bind(&now).bind(&id)
|
||||
let now = chrono::Utc::now();
|
||||
sqlx::query("UPDATE accounts SET last_login_at = $1 WHERE id = $2")
|
||||
.bind(now).bind(&id)
|
||||
.execute(&state.db).await?;
|
||||
let client_ip = addr.ip().to_string();
|
||||
log_operation(&state.db, &id, "account.login", "account", &id, None, Some(&client_ip)).await?;
|
||||
@@ -143,7 +166,7 @@ pub async fn login(
|
||||
Ok(Json(LoginResponse {
|
||||
token,
|
||||
account: AccountPublic {
|
||||
id, username, email, display_name, role, status, totp_enabled, created_at,
|
||||
id, username, email, display_name, role, permissions, status, totp_enabled, created_at,
|
||||
},
|
||||
}))
|
||||
}
|
||||
@@ -152,14 +175,30 @@ pub async fn login(
|
||||
pub async fn refresh(
|
||||
State(state): State<AppState>,
|
||||
axum::extract::Extension(ctx): axum::extract::Extension<AuthContext>,
|
||||
) -> SaasResult<Json<serde_json::Value>> {
|
||||
) -> SaasResult<Json<LoginResponse>> {
|
||||
let config = state.config.read().await;
|
||||
let token = create_token(
|
||||
&ctx.account_id, &ctx.role, ctx.permissions.clone(),
|
||||
state.jwt_secret.expose_secret(),
|
||||
config.auth.jwt_expiration_hours,
|
||||
)?;
|
||||
Ok(Json(serde_json::json!({ "token": token })))
|
||||
|
||||
// 查询账号信息以返回完整 LoginResponse
|
||||
let row = sqlx::query_as::<_, (String, String, String, String, String, String, bool, chrono::DateTime<chrono::Utc>)>(
|
||||
"SELECT id, username, email, display_name, role, status, totp_enabled, created_at
|
||||
FROM accounts WHERE id = $1"
|
||||
)
|
||||
.bind(&ctx.account_id)
|
||||
.fetch_optional(&state.db)
|
||||
.await?
|
||||
.ok_or_else(|| SaasError::NotFound("账号不存在".into()))?;
|
||||
|
||||
let (id, username, email, display_name, role, status, totp_enabled, created_at) = row;
|
||||
let created_at = created_at.to_rfc3339();
|
||||
Ok(Json(LoginResponse {
|
||||
token,
|
||||
account: AccountPublic { id, username, email, display_name, role, permissions: ctx.permissions, status, totp_enabled, created_at },
|
||||
}))
|
||||
}
|
||||
|
||||
/// GET /api/v1/auth/me — 返回当前认证用户的公开信息
|
||||
@@ -167,10 +206,10 @@ pub async fn me(
|
||||
State(state): State<AppState>,
|
||||
axum::extract::Extension(ctx): axum::extract::Extension<AuthContext>,
|
||||
) -> SaasResult<Json<AccountPublic>> {
|
||||
let row: Option<(String, String, String, String, String, String, bool, String)> =
|
||||
let row: Option<(String, String, String, String, String, String, bool, chrono::DateTime<chrono::Utc>)> =
|
||||
sqlx::query_as(
|
||||
"SELECT id, username, email, display_name, role, status, totp_enabled, created_at
|
||||
FROM accounts WHERE id = ?1"
|
||||
FROM accounts WHERE id = $1"
|
||||
)
|
||||
.bind(&ctx.account_id)
|
||||
.fetch_optional(&state.db)
|
||||
@@ -178,9 +217,10 @@ pub async fn me(
|
||||
|
||||
let (id, username, email, display_name, role, status, totp_enabled, created_at) =
|
||||
row.ok_or_else(|| SaasError::NotFound("账号不存在".into()))?;
|
||||
let created_at = created_at.to_rfc3339();
|
||||
|
||||
Ok(Json(AccountPublic {
|
||||
id, username, email, display_name, role, status, totp_enabled, created_at,
|
||||
id, username, email, display_name, role, permissions: ctx.permissions, status, totp_enabled, created_at,
|
||||
}))
|
||||
}
|
||||
|
||||
@@ -196,7 +236,7 @@ pub async fn change_password(
|
||||
|
||||
// 获取当前密码哈希
|
||||
let (password_hash,): (String,) = sqlx::query_as(
|
||||
"SELECT password_hash FROM accounts WHERE id = ?1"
|
||||
"SELECT password_hash FROM accounts WHERE id = $1"
|
||||
)
|
||||
.bind(&ctx.account_id)
|
||||
.fetch_one(&state.db)
|
||||
@@ -209,10 +249,10 @@ pub async fn change_password(
|
||||
|
||||
// 更新密码
|
||||
let new_hash = hash_password(&req.new_password)?;
|
||||
let now = chrono::Utc::now().to_rfc3339();
|
||||
sqlx::query("UPDATE accounts SET password_hash = ?1, updated_at = ?2 WHERE id = ?3")
|
||||
let now = chrono::Utc::now();
|
||||
sqlx::query("UPDATE accounts SET password_hash = $1, updated_at = $2 WHERE id = $3")
|
||||
.bind(&new_hash)
|
||||
.bind(&now)
|
||||
.bind(now)
|
||||
.bind(&ctx.account_id)
|
||||
.execute(&state.db)
|
||||
.await?;
|
||||
@@ -223,16 +263,16 @@ pub async fn change_password(
|
||||
Ok(Json(serde_json::json!({"ok": true, "message": "密码修改成功"})))
|
||||
}
|
||||
|
||||
pub(crate) async fn get_role_permissions(db: &sqlx::SqlitePool, role: &str) -> SaasResult<Vec<String>> {
|
||||
pub(crate) async fn get_role_permissions(db: &sqlx::PgPool, role: &str) -> SaasResult<Vec<String>> {
|
||||
let row: Option<(String,)> = sqlx::query_as(
|
||||
"SELECT permissions FROM roles WHERE id = ?1"
|
||||
"SELECT permissions FROM roles WHERE id = $1"
|
||||
)
|
||||
.bind(role)
|
||||
.fetch_optional(db)
|
||||
.await?;
|
||||
|
||||
let permissions_str = row
|
||||
.ok_or_else(|| SaasError::Internal(format!("角色 {} 不存在", role)))?
|
||||
.ok_or_else(|| SaasError::Forbidden(format!("角色 {} 不存在或无权限", role)))?
|
||||
.0;
|
||||
|
||||
let permissions: Vec<String> = serde_json::from_str(&permissions_str)?;
|
||||
@@ -252,7 +292,7 @@ pub fn check_permission(ctx: &AuthContext, permission: &str) -> SaasResult<()> {
|
||||
|
||||
/// 记录操作日志
|
||||
pub async fn log_operation(
|
||||
db: &sqlx::SqlitePool,
|
||||
db: &sqlx::PgPool,
|
||||
account_id: &str,
|
||||
action: &str,
|
||||
target_type: &str,
|
||||
@@ -260,10 +300,10 @@ pub async fn log_operation(
|
||||
details: Option<serde_json::Value>,
|
||||
ip_address: Option<&str>,
|
||||
) -> SaasResult<()> {
|
||||
let now = chrono::Utc::now().to_rfc3339();
|
||||
let now = chrono::Utc::now();
|
||||
sqlx::query(
|
||||
"INSERT INTO operation_logs (account_id, action, target_type, target_id, details, ip_address, created_at)
|
||||
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)"
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7)"
|
||||
)
|
||||
.bind(account_id)
|
||||
.bind(action)
|
||||
@@ -271,8 +311,54 @@ pub async fn log_operation(
|
||||
.bind(target_id)
|
||||
.bind(details.map(|d| d.to_string()))
|
||||
.bind(ip_address)
|
||||
.bind(&now)
|
||||
.bind(now)
|
||||
.execute(db)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::auth::types::AuthContext;
|
||||
|
||||
fn ctx(permissions: Vec<&str>) -> AuthContext {
|
||||
AuthContext {
|
||||
account_id: "test-id".into(),
|
||||
role: "user".into(),
|
||||
permissions: permissions.into_iter().map(String::from).collect(),
|
||||
client_ip: None,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_check_permission_admin_full() {
|
||||
let c = ctx(vec!["admin:full"]);
|
||||
assert!(check_permission(&c, "config:write").is_ok());
|
||||
assert!(check_permission(&c, "account:admin").is_ok());
|
||||
assert!(check_permission(&c, "any:permission").is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_check_permission_has_permission() {
|
||||
let c = ctx(vec!["config:write", "model:read"]);
|
||||
assert!(check_permission(&c, "config:write").is_ok());
|
||||
assert!(check_permission(&c, "model:read").is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_check_permission_missing() {
|
||||
let c = ctx(vec!["model:read"]);
|
||||
let result = check_permission(&c, "config:write");
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err().to_string();
|
||||
assert!(err.contains("config:write"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_check_permission_empty_list() {
|
||||
let c = ctx(vec![]);
|
||||
assert!(check_permission(&c, "config:write").is_err());
|
||||
assert!(check_permission(&c, "admin:full").is_err());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,17 +10,24 @@ use crate::error::SaasResult;
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct Claims {
|
||||
pub sub: String,
|
||||
pub aud: String,
|
||||
pub iss: String,
|
||||
pub role: String,
|
||||
pub permissions: Vec<String>,
|
||||
pub iat: i64,
|
||||
pub exp: i64,
|
||||
}
|
||||
|
||||
const JWT_AUDIENCE: &str = "zclaw-saas";
|
||||
const JWT_ISSUER: &str = "zclaw-saas";
|
||||
|
||||
impl Claims {
|
||||
pub fn new(account_id: &str, role: &str, permissions: Vec<String>, expiration_hours: i64) -> Self {
|
||||
let now = Utc::now();
|
||||
Self {
|
||||
sub: account_id.to_string(),
|
||||
aud: JWT_AUDIENCE.to_string(),
|
||||
iss: JWT_ISSUER.to_string(),
|
||||
role: role.to_string(),
|
||||
permissions,
|
||||
iat: now.timestamp(),
|
||||
@@ -48,10 +55,14 @@ pub fn create_token(
|
||||
|
||||
/// 验证 JWT Token
|
||||
pub fn verify_token(token: &str, secret: &str) -> SaasResult<Claims> {
|
||||
let mut validation = Validation::default();
|
||||
validation.set_audience(&[JWT_AUDIENCE]);
|
||||
validation.set_issuer(&[JWT_ISSUER]);
|
||||
|
||||
let token_data = decode::<Claims>(
|
||||
token,
|
||||
&DecodingKey::from_secret(secret.as_bytes()),
|
||||
&Validation::default(),
|
||||
&validation,
|
||||
)?;
|
||||
Ok(token_data.claims)
|
||||
}
|
||||
|
||||
@@ -29,7 +29,7 @@ async fn verify_api_token(state: &AppState, raw_token: &str, client_ip: Option<S
|
||||
|
||||
let row: Option<(String, Option<String>, String)> = sqlx::query_as(
|
||||
"SELECT account_id, expires_at, permissions FROM api_tokens
|
||||
WHERE token_hash = ?1 AND revoked_at IS NULL"
|
||||
WHERE token_hash = $1 AND revoked_at IS NULL"
|
||||
)
|
||||
.bind(&token_hash)
|
||||
.fetch_optional(&state.db)
|
||||
@@ -50,7 +50,7 @@ async fn verify_api_token(state: &AppState, raw_token: &str, client_ip: Option<S
|
||||
|
||||
// 查询关联账号的角色
|
||||
let (role,): (String,) = sqlx::query_as(
|
||||
"SELECT role FROM accounts WHERE id = ?1 AND status = 'active'"
|
||||
"SELECT role FROM accounts WHERE id = $1 AND status = 'active'"
|
||||
)
|
||||
.bind(&account_id)
|
||||
.fetch_optional(&state.db)
|
||||
@@ -70,9 +70,9 @@ async fn verify_api_token(state: &AppState, raw_token: &str, client_ip: Option<S
|
||||
// 异步更新 last_used_at(不阻塞请求)
|
||||
let db = state.db.clone();
|
||||
tokio::spawn(async move {
|
||||
let now = chrono::Utc::now().to_rfc3339();
|
||||
let _ = sqlx::query("UPDATE api_tokens SET last_used_at = ?1 WHERE token_hash = ?2")
|
||||
.bind(&now).bind(&token_hash)
|
||||
let now = chrono::Utc::now();
|
||||
let _ = sqlx::query("UPDATE api_tokens SET last_used_at = $1 WHERE token_hash = $2")
|
||||
.bind(now).bind(&token_hash)
|
||||
.execute(&db).await;
|
||||
});
|
||||
|
||||
@@ -84,23 +84,11 @@ async fn verify_api_token(state: &AppState, raw_token: &str, client_ip: Option<S
|
||||
})
|
||||
}
|
||||
|
||||
/// 从请求中提取客户端 IP
|
||||
/// 从请求中提取客户端 IP(仅信任直连 IP,不信任可伪造的 proxy header)
|
||||
fn extract_client_ip(req: &Request) -> Option<String> {
|
||||
// 优先从 ConnectInfo 获取
|
||||
if let Some(ConnectInfo(addr)) = req.extensions().get::<ConnectInfo<SocketAddr>>() {
|
||||
return Some(addr.ip().to_string());
|
||||
}
|
||||
// 回退到 X-Forwarded-For / X-Real-IP
|
||||
if let Some(forwarded) = req.headers()
|
||||
.get("x-forwarded-for")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
{
|
||||
return Some(forwarded.split(',').next()?.trim().to_string());
|
||||
}
|
||||
req.headers()
|
||||
.get("x-real-ip")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(|s| s.to_string())
|
||||
req.extensions()
|
||||
.get::<ConnectInfo<SocketAddr>>()
|
||||
.map(|addr| addr.ip().to_string())
|
||||
}
|
||||
|
||||
/// 认证中间件: 从 JWT 或 API Token 提取身份
|
||||
|
||||
@@ -94,7 +94,7 @@ pub async fn setup_totp(
|
||||
) -> SaasResult<Json<TotpSetupResponse>> {
|
||||
// 如果已启用 TOTP,先清除旧密钥
|
||||
let (username,): (String,) = sqlx::query_as(
|
||||
"SELECT username FROM accounts WHERE id = ?1"
|
||||
"SELECT username FROM accounts WHERE id = $1"
|
||||
)
|
||||
.bind(&ctx.account_id)
|
||||
.fetch_one(&state.db)
|
||||
@@ -103,9 +103,10 @@ pub async fn setup_totp(
|
||||
let config = state.config.read().await;
|
||||
let setup = generate_totp_secret(&config.auth.totp_issuer, &username);
|
||||
|
||||
// 存储密钥 (但不启用,需要 /verify 确认)
|
||||
sqlx::query("UPDATE accounts SET totp_secret = ?1 WHERE id = ?2")
|
||||
.bind(&setup.secret)
|
||||
// 加密 TOTP 密钥后存储 (但不启用,需要 /verify 确认)
|
||||
let encrypted_secret = state.field_encryption.encrypt(&setup.secret)?;
|
||||
sqlx::query("UPDATE accounts SET totp_secret = $1 WHERE id = $2")
|
||||
.bind(&encrypted_secret)
|
||||
.bind(&ctx.account_id)
|
||||
.execute(&state.db)
|
||||
.await?;
|
||||
@@ -130,7 +131,7 @@ pub async fn verify_totp(
|
||||
|
||||
// 获取存储的密钥
|
||||
let (totp_secret,): (Option<String>,) = sqlx::query_as(
|
||||
"SELECT totp_secret FROM accounts WHERE id = ?1"
|
||||
"SELECT totp_secret FROM accounts WHERE id = $1"
|
||||
)
|
||||
.bind(&ctx.account_id)
|
||||
.fetch_one(&state.db)
|
||||
@@ -140,14 +141,17 @@ pub async fn verify_totp(
|
||||
SaasError::InvalidInput("请先调用 /totp/setup 获取密钥".into())
|
||||
})?;
|
||||
|
||||
if !verify_totp_code(&secret, code) {
|
||||
// 解密 TOTP 密钥(兼容迁移期间的明文数据)
|
||||
let decrypted_secret = state.field_encryption.decrypt_or_plaintext(&secret);
|
||||
|
||||
if !verify_totp_code(&decrypted_secret, code) {
|
||||
return Err(SaasError::Totp("TOTP 码验证失败".into()));
|
||||
}
|
||||
|
||||
// 验证成功 → 启用 TOTP
|
||||
let now = chrono::Utc::now().to_rfc3339();
|
||||
sqlx::query("UPDATE accounts SET totp_enabled = 1, updated_at = ?1 WHERE id = ?2")
|
||||
.bind(&now)
|
||||
let now = chrono::Utc::now();
|
||||
sqlx::query("UPDATE accounts SET totp_enabled = true, updated_at = $1 WHERE id = $2")
|
||||
.bind(now)
|
||||
.bind(&ctx.account_id)
|
||||
.execute(&state.db)
|
||||
.await?;
|
||||
@@ -167,7 +171,7 @@ pub async fn disable_totp(
|
||||
) -> SaasResult<Json<serde_json::Value>> {
|
||||
// 验证密码
|
||||
let (password_hash,): (String,) = sqlx::query_as(
|
||||
"SELECT password_hash FROM accounts WHERE id = ?1"
|
||||
"SELECT password_hash FROM accounts WHERE id = $1"
|
||||
)
|
||||
.bind(&ctx.account_id)
|
||||
.fetch_one(&state.db)
|
||||
@@ -178,9 +182,9 @@ pub async fn disable_totp(
|
||||
}
|
||||
|
||||
// 清除 TOTP
|
||||
let now = chrono::Utc::now().to_rfc3339();
|
||||
sqlx::query("UPDATE accounts SET totp_enabled = 0, totp_secret = NULL, updated_at = ?1 WHERE id = ?2")
|
||||
.bind(&now)
|
||||
let now = chrono::Utc::now();
|
||||
sqlx::query("UPDATE accounts SET totp_enabled = false, totp_secret = NULL, updated_at = $1 WHERE id = $2")
|
||||
.bind(now)
|
||||
.bind(&ctx.account_id)
|
||||
.execute(&state.db)
|
||||
.await?;
|
||||
@@ -190,3 +194,65 @@ pub async fn disable_totp(
|
||||
|
||||
Ok(Json(serde_json::json!({"ok": true, "totp_enabled": false, "message": "TOTP 已禁用"})))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_generate_totp_secret_format() {
|
||||
let result = generate_totp_secret("TestIssuer", "user@example.com");
|
||||
assert!(result.otpauth_uri.starts_with("otpauth://totp/"));
|
||||
assert!(result.otpauth_uri.contains("secret="));
|
||||
assert!(result.otpauth_uri.contains("issuer=TestIssuer"));
|
||||
assert!(result.otpauth_uri.contains("algorithm=SHA1"));
|
||||
assert!(result.otpauth_uri.contains("digits=6"));
|
||||
assert!(result.otpauth_uri.contains("period=30"));
|
||||
// Base32 编码的 20 字节 = 32 字符
|
||||
assert_eq!(result.secret.len(), 32);
|
||||
assert_eq!(result.issuer, "TestIssuer");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_generate_totp_secret_special_chars() {
|
||||
let result = generate_totp_secret("My App", "user@domain:8080");
|
||||
// 特殊字符应被 URL 编码
|
||||
assert!(!result.otpauth_uri.contains("user@domain:8080"));
|
||||
assert!(result.otpauth_uri.contains("user%40domain"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_verify_totp_code_valid() {
|
||||
// 使用 generate_random_secret 创建合法 secret,然后生成并验证码
|
||||
let secret = generate_random_secret();
|
||||
let secret_bytes = data_encoding::BASE32.decode(secret.as_bytes()).unwrap();
|
||||
let totp = totp_rs::TOTP::new(
|
||||
totp_rs::Algorithm::SHA1, 6, 1, 30, secret_bytes,
|
||||
).unwrap();
|
||||
let valid_code = totp.generate(chrono::Utc::now().timestamp() as u64);
|
||||
assert!(verify_totp_code(&secret, &valid_code));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_verify_totp_code_invalid() {
|
||||
let secret = generate_random_secret();
|
||||
assert!(!verify_totp_code(&secret, "000000"));
|
||||
assert!(!verify_totp_code(&secret, "999999"));
|
||||
assert!(!verify_totp_code(&secret, "abcdef"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_verify_totp_code_invalid_secret() {
|
||||
assert!(!verify_totp_code("not-valid-base32!!!", "123456"));
|
||||
assert!(!verify_totp_code("", "123456"));
|
||||
assert!(!verify_totp_code("短", "123456"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_verify_totp_code_empty() {
|
||||
let secret = "JBSWY3DPEHPK3PXP";
|
||||
assert!(!verify_totp_code(secret, ""));
|
||||
assert!(!verify_totp_code(secret, "12345"));
|
||||
assert!(!verify_totp_code(secret, "1234567"));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// 登录请求
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[derive(Debug, Deserialize, utoipa::ToSchema)]
|
||||
pub struct LoginRequest {
|
||||
pub username: String,
|
||||
pub password: String,
|
||||
@@ -11,14 +11,14 @@ pub struct LoginRequest {
|
||||
}
|
||||
|
||||
/// 登录响应
|
||||
#[derive(Debug, Serialize)]
|
||||
#[derive(Debug, Serialize, utoipa::ToSchema)]
|
||||
pub struct LoginResponse {
|
||||
pub token: String,
|
||||
pub account: AccountPublic,
|
||||
}
|
||||
|
||||
/// 注册请求
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[derive(Debug, Deserialize, utoipa::ToSchema)]
|
||||
pub struct RegisterRequest {
|
||||
pub username: String,
|
||||
pub email: String,
|
||||
@@ -27,20 +27,21 @@ pub struct RegisterRequest {
|
||||
}
|
||||
|
||||
/// 修改密码请求
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[derive(Debug, Deserialize, utoipa::ToSchema)]
|
||||
pub struct ChangePasswordRequest {
|
||||
pub old_password: String,
|
||||
pub new_password: String,
|
||||
}
|
||||
|
||||
/// 公开账号信息 (无敏感数据)
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
#[derive(Debug, Clone, Serialize, utoipa::ToSchema)]
|
||||
pub struct AccountPublic {
|
||||
pub id: String,
|
||||
pub username: String,
|
||||
pub email: String,
|
||||
pub display_name: String,
|
||||
pub role: String,
|
||||
pub permissions: Vec<String>,
|
||||
pub status: String,
|
||||
pub totp_enabled: bool,
|
||||
pub created_at: String,
|
||||
|
||||
@@ -45,10 +45,13 @@ pub struct AuthConfig {
|
||||
/// 中转服务配置
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct RelayConfig {
|
||||
#[doc(hidden)]
|
||||
#[serde(default = "default_max_queue")]
|
||||
pub max_queue_size: usize,
|
||||
#[doc(hidden)]
|
||||
#[serde(default = "default_max_concurrent")]
|
||||
pub max_concurrent_per_provider: usize,
|
||||
#[doc(hidden)]
|
||||
#[serde(default = "default_batch_window")]
|
||||
pub batch_window_ms: u64,
|
||||
#[serde(default = "default_retry_delay")]
|
||||
@@ -59,7 +62,22 @@ pub struct RelayConfig {
|
||||
|
||||
fn default_host() -> String { "0.0.0.0".into() }
|
||||
fn default_port() -> u16 { 8080 }
|
||||
fn default_db_url() -> String { "sqlite:./saas-data.db".into() }
|
||||
fn default_db_url() -> String {
|
||||
// 无默认值:生产环境必须通过 DATABASE_URL 或配置文件设置
|
||||
// 开发环境可设置 ZCLAW_SAAS_DEV=true 使用 postgres://localhost:5432/zclaw
|
||||
std::env::var("DATABASE_URL")
|
||||
.unwrap_or_else(|_| {
|
||||
let is_dev = std::env::var("ZCLAW_SAAS_DEV")
|
||||
.map(|v| v == "true" || v == "1")
|
||||
.unwrap_or(false);
|
||||
if is_dev {
|
||||
"postgres://localhost:5432/zclaw".into()
|
||||
} else {
|
||||
tracing::error!("DATABASE_URL 未设置且非开发环境");
|
||||
String::new()
|
||||
}
|
||||
})
|
||||
}
|
||||
fn default_jwt_hours() -> i64 { 24 }
|
||||
fn default_totp_issuer() -> String { "ZCLAW SaaS".into() }
|
||||
fn default_max_queue() -> usize { 1000 }
|
||||
@@ -155,6 +173,16 @@ impl SaaSConfig {
|
||||
SaaSConfig::default()
|
||||
};
|
||||
|
||||
// 验证数据库 URL 已配置
|
||||
if config.database.url.is_empty() {
|
||||
anyhow::bail!(
|
||||
"数据库 URL 未配置。请通过以下方式之一设置:\n\
|
||||
1. 在配置文件中设置 [database].url\n\
|
||||
2. 设置 DATABASE_URL 环境变量\n\
|
||||
开发环境可设置 ZCLAW_SAAS_DEV=true 使用默认值。"
|
||||
);
|
||||
}
|
||||
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
@@ -182,3 +210,94 @@ impl SaaSConfig {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn default_config_has_expected_values() {
|
||||
let config = SaaSConfig::default();
|
||||
assert_eq!(config.server.host, "0.0.0.0");
|
||||
assert_eq!(config.server.port, 8080);
|
||||
assert!(config.server.cors_origins.is_empty());
|
||||
assert_eq!(config.auth.jwt_expiration_hours, 24);
|
||||
assert_eq!(config.auth.totp_issuer, "ZCLAW SaaS");
|
||||
assert_eq!(config.rate_limit.requests_per_minute, 60);
|
||||
assert_eq!(config.rate_limit.burst, 10);
|
||||
assert_eq!(config.relay.max_queue_size, 1000);
|
||||
assert_eq!(config.relay.max_concurrent_per_provider, 5);
|
||||
assert_eq!(config.relay.max_attempts, 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rate_limit_default_matches_manual() {
|
||||
let config = SaaSConfig::default();
|
||||
assert_eq!(config.rate_limit.requests_per_minute, 60);
|
||||
assert_eq!(config.rate_limit.burst, 10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_minimal_config_toml() {
|
||||
let toml_str = r#"
|
||||
[server]
|
||||
host = "127.0.0.1"
|
||||
port = 9090
|
||||
|
||||
[database]
|
||||
url = "postgres://localhost/zclaw"
|
||||
|
||||
[auth]
|
||||
jwt_expiration_hours = 48
|
||||
|
||||
[relay]
|
||||
max_queue_size = 500
|
||||
"#;
|
||||
let config: SaaSConfig = toml::from_str(toml_str).expect("parse should succeed");
|
||||
assert_eq!(config.server.host, "127.0.0.1");
|
||||
assert_eq!(config.server.port, 9090);
|
||||
assert_eq!(config.database.url, "postgres://localhost/zclaw");
|
||||
assert_eq!(config.auth.jwt_expiration_hours, 48);
|
||||
assert_eq!(config.relay.max_queue_size, 500);
|
||||
// defaults should fill in
|
||||
assert_eq!(config.rate_limit.requests_per_minute, 60);
|
||||
assert_eq!(config.relay.max_attempts, 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_full_config_with_rate_limit() {
|
||||
let toml_str = r#"
|
||||
[server]
|
||||
host = "0.0.0.0"
|
||||
port = 8080
|
||||
cors_origins = ["http://localhost:3000", "http://admin.example.com"]
|
||||
|
||||
[database]
|
||||
url = "postgres://db:5432/zclaw"
|
||||
|
||||
[auth]
|
||||
jwt_expiration_hours = 12
|
||||
totp_issuer = "MyCorp"
|
||||
|
||||
[relay]
|
||||
max_queue_size = 2000
|
||||
max_concurrent_per_provider = 10
|
||||
batch_window_ms = 100
|
||||
retry_delay_ms = 2000
|
||||
max_attempts = 5
|
||||
|
||||
[rate_limit]
|
||||
requests_per_minute = 120
|
||||
burst = 20
|
||||
"#;
|
||||
let config: SaaSConfig = toml::from_str(toml_str).expect("parse should succeed");
|
||||
assert_eq!(config.server.cors_origins.len(), 2);
|
||||
assert_eq!(config.auth.jwt_expiration_hours, 12);
|
||||
assert_eq!(config.auth.totp_issuer, "MyCorp");
|
||||
assert_eq!(config.relay.max_concurrent_per_provider, 10);
|
||||
assert_eq!(config.relay.retry_delay_ms, 2000);
|
||||
assert_eq!(config.relay.max_attempts, 5);
|
||||
assert_eq!(config.rate_limit.requests_per_minute, 120);
|
||||
assert_eq!(config.rate_limit.burst, 20);
|
||||
}
|
||||
}
|
||||
|
||||
277
crates/zclaw-saas/src/crypto.rs
Normal file
277
crates/zclaw-saas/src/crypto.rs
Normal file
@@ -0,0 +1,277 @@
|
||||
//! AES-256-GCM 字段级加密
|
||||
//!
|
||||
//! 用于加密数据库中存储的敏感字段(如 API Key)。
|
||||
//! 每次加密生成随机 12 字节 nonce,密文格式: `base64(nonce || ciphertext || tag)`。
|
||||
|
||||
use aes_gcm::aead::{AeadInPlace, KeyInit, OsRng};
|
||||
use aes_gcm::{Aes256Gcm, AeadCore, Nonce};
|
||||
use data_encoding::BASE64;
|
||||
use std::fmt;
|
||||
|
||||
use crate::error::{SaasError, SaasResult};
|
||||
|
||||
/// AES-256-GCM 密钥字节长度
|
||||
const KEY_LEN: usize = 32;
|
||||
|
||||
/// GCM nonce 字节长度 (96-bit,推荐值)
|
||||
const NONCE_LEN: usize = 12;
|
||||
|
||||
/// 字段加密器,持有 AES-256-GCM 密钥
|
||||
///
|
||||
/// 线程安全,可通过 `Arc` 在多任务间共享。
|
||||
pub struct FieldEncryption {
|
||||
cipher: Aes256Gcm,
|
||||
}
|
||||
|
||||
impl fmt::Debug for FieldEncryption {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("FieldEncryption")
|
||||
.field("cipher", &"<redacted>")
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl FieldEncryption {
|
||||
/// 从环境变量加载或生成加密密钥
|
||||
///
|
||||
/// - **生产环境**: 必须设置 `ZCLAW_SAAS_FIELD_ENCRYPTION_KEY`(32 字节 hex 编码)
|
||||
/// - **开发环境** (`ZCLAW_SAAS_DEV=true`): 自动生成随机密钥并输出警告
|
||||
pub fn new() -> anyhow::Result<Self> {
|
||||
let is_dev = std::env::var("ZCLAW_SAAS_DEV")
|
||||
.map(|v| v == "true" || v == "1")
|
||||
.unwrap_or(false);
|
||||
|
||||
let key_bytes = match std::env::var("ZCLAW_SAAS_FIELD_ENCRYPTION_KEY") {
|
||||
Ok(hex_key) => {
|
||||
let bytes = hex::decode(&hex_key).map_err(|e| {
|
||||
anyhow::anyhow!(
|
||||
"ZCLAW_SAAS_FIELD_ENCRYPTION_KEY 格式无效 (期望 64 字符 hex): {e}"
|
||||
)
|
||||
})?;
|
||||
if bytes.len() != KEY_LEN {
|
||||
anyhow::bail!(
|
||||
"ZCLAW_SAAS_FIELD_ENCRYPTION_KEY 长度错误: 期望 {KEY_LEN} 字节, 实际 {} 字节",
|
||||
bytes.len()
|
||||
);
|
||||
}
|
||||
tracing::info!("Field encryption key loaded from environment");
|
||||
bytes
|
||||
}
|
||||
Err(_) => {
|
||||
if is_dev {
|
||||
let random_key: [u8; KEY_LEN] = rand::random();
|
||||
let hex_key = hex::encode(random_key);
|
||||
tracing::warn!(
|
||||
"ZCLAW_SAAS_FIELD_ENCRYPTION_KEY 未设置,已生成随机密钥 (仅限开发环境):\n {hex_key}\n\
|
||||
生产环境必须设置此环境变量!"
|
||||
);
|
||||
random_key.to_vec()
|
||||
} else {
|
||||
anyhow::bail!(
|
||||
"ZCLAW_SAAS_FIELD_ENCRYPTION_KEY 环境变量未设置。\n\
|
||||
请设置一个 32 字节 hex 编码密钥 (64 字符)。\n\
|
||||
生成方式: openssl rand -hex 32\n\
|
||||
开发环境可设置 ZCLAW_SAAS_DEV=true 自动生成。"
|
||||
);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let key = aes_gcm::Key::<Aes256Gcm>::from_slice(&key_bytes);
|
||||
let cipher = Aes256Gcm::new(key);
|
||||
|
||||
Ok(Self { cipher })
|
||||
}
|
||||
|
||||
/// 加密明文,返回 base64 编码密文
|
||||
///
|
||||
/// 密文格式: `base64(nonce_12bytes || ciphertext || gcm_tag_16bytes)`
|
||||
pub fn encrypt(&self, plaintext: &str) -> SaasResult<String> {
|
||||
let nonce = Aes256Gcm::generate_nonce(&mut OsRng);
|
||||
let payload = plaintext.as_bytes();
|
||||
|
||||
// AeadInPlace::encrypt_in_place_append_tag 会在 payload 后面追加 16 字节 tag
|
||||
let mut buffer = payload.to_vec();
|
||||
self.cipher
|
||||
.encrypt_in_place(&nonce, &[], &mut buffer)
|
||||
.map_err(|e| SaasError::Encryption(format!("加密失败: {e}")))?;
|
||||
|
||||
// 构造输出: nonce (12) || ciphertext + tag
|
||||
let mut output = Vec::with_capacity(NONCE_LEN + buffer.len());
|
||||
output.extend_from_slice(&nonce);
|
||||
output.extend_from_slice(&buffer);
|
||||
|
||||
Ok(BASE64.encode(&output))
|
||||
}
|
||||
|
||||
/// 解密 base64 编码密文,返回原始明文
|
||||
///
|
||||
/// 输入格式: `base64(nonce_12bytes || ciphertext || gcm_tag_16bytes)`
|
||||
pub fn decrypt(&self, ciphertext: &str) -> SaasResult<String> {
|
||||
let raw = BASE64
|
||||
.decode(ciphertext.as_bytes())
|
||||
.map_err(|e| SaasError::Encryption(format!("Base64 解码失败: {e}")))?;
|
||||
|
||||
if raw.len() < NONCE_LEN {
|
||||
return Err(SaasError::Encryption(
|
||||
"密文长度不足: 无法提取 nonce".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let (nonce_bytes, encrypted) = raw.split_at(NONCE_LEN);
|
||||
let nonce = Nonce::from_slice(nonce_bytes);
|
||||
|
||||
let mut buffer = encrypted.to_vec();
|
||||
self.cipher
|
||||
.decrypt_in_place(nonce, &[], &mut buffer)
|
||||
.map_err(|e| SaasError::Encryption(format!("解密失败 (密文可能已损坏或密钥不匹配): {e}")))?;
|
||||
|
||||
String::from_utf8(buffer)
|
||||
.map_err(|e| SaasError::Encryption(format!("解密结果非有效 UTF-8: {e}")))
|
||||
}
|
||||
|
||||
/// 尝试解密,失败时返回原始明文(用于迁移期间兼容未加密的旧数据)
|
||||
///
|
||||
/// 在字段加密上线前,数据库中可能已存在未加密的明文数据。
|
||||
/// 此方法先尝试解密,若解密失败(Base64 解码失败、GCM 认证失败等),
|
||||
/// 则假设数据是旧版明文,直接返回原值。
|
||||
pub fn decrypt_or_plaintext(&self, value: &str) -> String {
|
||||
self.decrypt(value).unwrap_or_else(|_| value.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
/// 辅助: 用固定密钥创建 FieldEncryption(测试专用)
|
||||
fn test_encryption() -> FieldEncryption {
|
||||
// 固定 32 字节密钥,仅用于测试
|
||||
let key_bytes: [u8; KEY_LEN] = [
|
||||
0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08,
|
||||
0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10,
|
||||
0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18,
|
||||
0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20,
|
||||
];
|
||||
let key = aes_gcm::Key::<Aes256Gcm>::from_slice(&key_bytes);
|
||||
let cipher = Aes256Gcm::new(key);
|
||||
FieldEncryption { cipher }
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn encrypt_produces_base64_output() {
|
||||
let enc = test_encryption();
|
||||
let result = enc.encrypt("hello world");
|
||||
assert!(result.is_ok());
|
||||
let ciphertext = result.unwrap();
|
||||
// base64 输出应该能被 BASE64 解码
|
||||
assert!(BASE64.decode(ciphertext.as_bytes()).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn encrypt_decrypt_roundtrip() {
|
||||
let enc = test_encryption();
|
||||
|
||||
let plaintext = "sk-proj-abc123SECRET_API_KEY_!@#$%";
|
||||
let ciphertext = enc.encrypt(plaintext).expect("encrypt should succeed");
|
||||
let decrypted = enc.decrypt(&ciphertext).expect("decrypt should succeed");
|
||||
|
||||
assert_eq!(decrypted, plaintext);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn encrypt_decrypt_roundtrip_chinese() {
|
||||
let enc = test_encryption();
|
||||
|
||||
let plaintext = "这是一个包含中文的敏感字段测试";
|
||||
let ciphertext = enc.encrypt(plaintext).expect("encrypt should succeed");
|
||||
let decrypted = enc.decrypt(&ciphertext).expect("decrypt should succeed");
|
||||
|
||||
assert_eq!(decrypted, plaintext);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn different_encryptions_produce_different_ciphertexts() {
|
||||
let enc = test_encryption();
|
||||
|
||||
let plaintext = "same-plaintext";
|
||||
let ct1 = enc.encrypt(plaintext).unwrap();
|
||||
let ct2 = enc.encrypt(plaintext).unwrap();
|
||||
|
||||
// 由于随机 nonce,相同明文的密文应该不同
|
||||
assert_ne!(ct1, ct2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn decrypt_wrong_key_fails() {
|
||||
let enc1 = test_encryption();
|
||||
|
||||
// 用不同密钥创建另一个加密器
|
||||
let key_bytes2: [u8; KEY_LEN] = [
|
||||
0xff, 0xfe, 0xfd, 0xfc, 0xfb, 0xfa, 0xf9, 0xf8,
|
||||
0xf7, 0xf6, 0xf5, 0xf4, 0xf3, 0xf2, 0xf1, 0xf0,
|
||||
0xef, 0xee, 0xed, 0xec, 0xeb, 0xea, 0xe9, 0xe8,
|
||||
0xe7, 0xe6, 0xe5, 0xe4, 0xe3, 0xe2, 0xe1, 0xe0,
|
||||
];
|
||||
let key2 = aes_gcm::Key::<Aes256Gcm>::from_slice(&key_bytes2);
|
||||
let cipher2 = Aes256Gcm::new(key2);
|
||||
let enc2 = FieldEncryption { cipher: cipher2 };
|
||||
|
||||
let ciphertext = enc1.encrypt("secret").unwrap();
|
||||
let result = enc2.decrypt(&ciphertext);
|
||||
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn decrypt_invalid_base64_fails() {
|
||||
let enc = test_encryption();
|
||||
let result = enc.decrypt("not-valid-base64!!!");
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn decrypt_too_short_ciphertext_fails() {
|
||||
let enc = test_encryption();
|
||||
// 构造一个短于 12 字节 nonce 的有效 base64 字符串
|
||||
let short = BASE64.encode(&[0x01, 0x02, 0x03]);
|
||||
let result = enc.decrypt(&short);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn decrypt_tampered_ciphertext_fails() {
|
||||
let enc = test_encryption();
|
||||
let ciphertext = enc.encrypt("sensitive-data").unwrap();
|
||||
|
||||
// 解码、篡改、重新编码
|
||||
let mut raw = BASE64.decode(ciphertext.as_bytes()).unwrap();
|
||||
// 翻转 nonce 后的一个字节
|
||||
let tamper_pos = NONCE_LEN + 2;
|
||||
if tamper_pos < raw.len() {
|
||||
raw[tamper_pos] ^= 0xff;
|
||||
}
|
||||
let tampered = BASE64.encode(&raw);
|
||||
|
||||
let result = enc.decrypt(&tampered);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn encrypt_empty_string_roundtrip() {
|
||||
let enc = test_encryption();
|
||||
let ciphertext = enc.encrypt("").unwrap();
|
||||
let decrypted = enc.decrypt(&ciphertext).unwrap();
|
||||
assert_eq!(decrypted, "");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ciphertext_format_has_nonce_prefix() {
|
||||
let enc = test_encryption();
|
||||
let ciphertext = enc.encrypt("test").unwrap();
|
||||
let raw = BASE64.decode(ciphertext.as_bytes()).unwrap();
|
||||
// raw 应该 = nonce(12) + ciphertext + tag(16)
|
||||
// 至少 12 + 16 = 28 字节(明文 4 字节加密后 4 字节 + 16 字节 tag)
|
||||
assert!(raw.len() >= NONCE_LEN + 16);
|
||||
}
|
||||
}
|
||||
243
crates/zclaw-saas/src/csrf.rs
Normal file
243
crates/zclaw-saas/src/csrf.rs
Normal file
@@ -0,0 +1,243 @@
|
||||
//! CSRF 防护: Origin 校验中间件
|
||||
//!
|
||||
//! 对所有状态变更请求 (POST/PUT/PATCH/DELETE) 校验 `Origin` 请求头,
|
||||
//! 确保其与 `server.cors_origins` 白名单中的某项匹配。
|
||||
//!
|
||||
//! - GET / HEAD / OPTIONS 请求跳过校验 (安全方法)
|
||||
//! - 缺少 Origin 头时拒绝 (403)
|
||||
//! - Origin 不匹配白名单时拒绝 (403)
|
||||
//! - `ZCLAW_SAAS_DEV=true` 时跳过校验
|
||||
//!
|
||||
//! 这是 Bearer Token API 最合适的 CSRF 防护方案。
|
||||
//! 如果未来迁移到 Cookie 认证,需要升级为 CSRF Token 方案。
|
||||
|
||||
use axum::{
|
||||
extract::{Request, State},
|
||||
http::{header, StatusCode},
|
||||
middleware::Next,
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use tracing::warn;
|
||||
|
||||
use crate::state::AppState;
|
||||
|
||||
/// 需要进行 Origin 校验的 HTTP 方法
|
||||
const CSRF_UNSAFE_METHODS: &[&str] = &["POST", "PUT", "PATCH", "DELETE"];
|
||||
|
||||
/// Origin 校验中间件
|
||||
///
|
||||
/// 在 auth_middleware 之后、rate_limit_middleware 之前执行。
|
||||
/// 已认证的请求若缺少或不匹配 Origin 头,返回 403 Forbidden。
|
||||
pub async fn origin_check_middleware(
|
||||
State(state): State<AppState>,
|
||||
req: Request,
|
||||
next: Next,
|
||||
) -> Response {
|
||||
// 开发模式跳过校验
|
||||
if is_dev_mode() {
|
||||
return next.run(req).await;
|
||||
}
|
||||
|
||||
// 安全方法跳过校验
|
||||
let method = req.method().as_str().to_uppercase();
|
||||
if !CSRF_UNSAFE_METHODS.contains(&method.as_str()) {
|
||||
return next.run(req).await;
|
||||
}
|
||||
|
||||
// 获取 Origin 头
|
||||
let origin_header = match req.headers().get(header::ORIGIN) {
|
||||
Some(value) => match value.to_str() {
|
||||
Ok(origin) => origin,
|
||||
Err(_) => {
|
||||
warn!("CSRF: Origin header contains invalid UTF-8");
|
||||
return csrf_reject("ORIGIN_INVALID", "Origin 请求头格式无效");
|
||||
}
|
||||
},
|
||||
None => {
|
||||
warn!("CSRF: Missing Origin header on {} {}", method, req.uri());
|
||||
return csrf_reject("ORIGIN_MISSING", "缺少 Origin 请求头");
|
||||
}
|
||||
};
|
||||
|
||||
// 从配置读取白名单
|
||||
let allowed_origins = {
|
||||
let config = state.config.read().await;
|
||||
config.server.cors_origins.clone()
|
||||
};
|
||||
|
||||
// 白名单为空时不校验 (生产环境已在 main.rs 中强制要求配置)
|
||||
if allowed_origins.is_empty() {
|
||||
return next.run(req).await;
|
||||
}
|
||||
|
||||
// 校验 Origin 是否在白名单中
|
||||
if !origin_matches_whitelist(origin_header, &allowed_origins) {
|
||||
warn!(
|
||||
"CSRF: Origin '{}' not in whitelist for {} {}",
|
||||
origin_header,
|
||||
method,
|
||||
req.uri()
|
||||
);
|
||||
return csrf_reject("ORIGIN_NOT_ALLOWED", "Origin 不在允许列表中");
|
||||
}
|
||||
|
||||
next.run(req).await
|
||||
}
|
||||
|
||||
/// 判断是否为开发模式
|
||||
fn is_dev_mode() -> bool {
|
||||
std::env::var("ZCLAW_SAAS_DEV")
|
||||
.map(|v| v == "true" || v == "1")
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
/// 校验 Origin 是否匹配白名单中的某项
|
||||
///
|
||||
/// 匹配规则: 精确匹配 (scheme + host + port)。
|
||||
/// 例如白名单 `https://admin.zclaw.com` 只匹配该 Origin,
|
||||
/// 不匹配 `https://evil.zclaw.com`。
|
||||
fn origin_matches_whitelist(origin: &str, whitelist: &[String]) -> bool {
|
||||
// 使用 url::Url 进行规范化比较,避免字符串拼接攻击
|
||||
let parsed_origin = match url::Url::parse(origin) {
|
||||
Ok(url) => url,
|
||||
Err(_) => return false,
|
||||
};
|
||||
|
||||
for allowed in whitelist {
|
||||
if let Ok(allowed_url) = url::Url::parse(allowed) {
|
||||
if origins_equal(&parsed_origin, &allowed_url) {
|
||||
return true;
|
||||
}
|
||||
} else {
|
||||
// 白名单条目本身无法解析,降级为字符串比较
|
||||
if origin == allowed {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
|
||||
/// 比较两个 Origin URL 是否相等 (scheme + host + port)
|
||||
///
|
||||
/// 同时拒绝包含路径的 URL: 真实的 Origin 头永远不会包含路径。
|
||||
/// 如果传入的 origin 字符串包含路径,视为不合法的 Origin。
|
||||
fn origins_equal(a: &url::Url, b: &url::Url) -> bool {
|
||||
// scheme 必须完全一致
|
||||
if a.scheme() != b.scheme() {
|
||||
return false;
|
||||
}
|
||||
|
||||
// host 必须完全一致
|
||||
if a.host_str() != b.host_str() {
|
||||
return false;
|
||||
}
|
||||
|
||||
// port 必须完全一致 (url::Url 会规范化默认端口: 80/HTTP, 443/HTTPS)
|
||||
if a.port() != b.port() {
|
||||
return false;
|
||||
}
|
||||
|
||||
// 防御性检查: 合法的 Origin 不应包含路径、query string 或 fragment
|
||||
// 如果任一 URL 的 path 不是 "/" 或有 query/fragment,视为可疑请求
|
||||
if a.path() != "/" || b.path() != "/" {
|
||||
return false;
|
||||
}
|
||||
if a.query().is_some() || b.query().is_some() {
|
||||
return false;
|
||||
}
|
||||
if a.fragment().is_some() || b.fragment().is_some() {
|
||||
return false;
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
/// 返回 403 拒绝响应
|
||||
fn csrf_reject(error_code: &str, message: &str) -> Response {
|
||||
(
|
||||
StatusCode::FORBIDDEN,
|
||||
[("Content-Type", "application/json")],
|
||||
axum::Json(serde_json::json!({
|
||||
"error": error_code,
|
||||
"message": message,
|
||||
})),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_origin_matches_whitelist_exact() {
|
||||
let whitelist = vec![
|
||||
"https://admin.zclaw.com".to_string(),
|
||||
"http://localhost:3000".to_string(),
|
||||
];
|
||||
|
||||
assert!(origin_matches_whitelist("https://admin.zclaw.com", &whitelist));
|
||||
assert!(origin_matches_whitelist("http://localhost:3000", &whitelist));
|
||||
assert!(!origin_matches_whitelist("https://evil.zclaw.com", &whitelist));
|
||||
// url::Url normalizes port 443 for HTTPS to None, so these match
|
||||
assert!(origin_matches_whitelist("https://admin.zclaw.com:443", &whitelist));
|
||||
assert!(!origin_matches_whitelist("http://localhost:3001", &whitelist));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_origin_matches_whitelist_empty() {
|
||||
let whitelist: Vec<String> = vec![];
|
||||
assert!(!origin_matches_whitelist("https://example.com", &whitelist));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_origin_matches_whitelist_with_path() {
|
||||
let whitelist = vec!["https://admin.zclaw.com".to_string()];
|
||||
// 标准 Origin 不包含路径,应该匹配
|
||||
assert!(origin_matches_whitelist("https://admin.zclaw.com", &whitelist));
|
||||
// 包含路径的 Origin 不合法 (浏览器永远不会发送带路径的 Origin)
|
||||
assert!(!origin_matches_whitelist("https://admin.zclaw.com/evil", &whitelist));
|
||||
// 带查询字符串的 Origin 也不合法
|
||||
assert!(!origin_matches_whitelist("https://admin.zclaw.com/?evil=1", &whitelist));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_origin_matches_whitelist_invalid_origin() {
|
||||
let whitelist = vec!["https://admin.zclaw.com".to_string()];
|
||||
assert!(!origin_matches_whitelist("not-a-url", &whitelist));
|
||||
assert!(!origin_matches_whitelist("", &whitelist));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_origins_equal() {
|
||||
let a = url::Url::parse("https://admin.zclaw.com").unwrap();
|
||||
let b = url::Url::parse("https://admin.zclaw.com").unwrap();
|
||||
assert!(origins_equal(&a, &b));
|
||||
|
||||
// Different scheme
|
||||
let c = url::Url::parse("http://admin.zclaw.com").unwrap();
|
||||
assert!(!origins_equal(&a, &c));
|
||||
|
||||
// Different host
|
||||
let d = url::Url::parse("https://evil.zclaw.com").unwrap();
|
||||
assert!(!origins_equal(&a, &d));
|
||||
|
||||
// Different port
|
||||
let e = url::Url::parse("https://admin.zclaw.com:8443").unwrap();
|
||||
assert!(!origins_equal(&a, &e));
|
||||
|
||||
// Explicit default port vs implicit
|
||||
let f = url::Url::parse("https://admin.zclaw.com:443").unwrap();
|
||||
// url::Url normalizes 443 for HTTPS, so both have None port
|
||||
assert!(origins_equal(&a, &f));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_is_dev_mode() {
|
||||
// Don't modify env in tests; just verify the function signature works
|
||||
// Actual env-var-based behavior tested in integration tests
|
||||
let _ = is_dev_mode();
|
||||
}
|
||||
}
|
||||
@@ -1,9 +1,9 @@
|
||||
//! 数据库初始化与 Schema
|
||||
//! 数据库初始化与 Schema (PostgreSQL)
|
||||
|
||||
use sqlx::SqlitePool;
|
||||
use sqlx::PgPool;
|
||||
use crate::error::SaasResult;
|
||||
|
||||
const SCHEMA_VERSION: i32 = 1;
|
||||
const SCHEMA_VERSION: i32 = 2;
|
||||
|
||||
const SCHEMA_SQL: &str = r#"
|
||||
CREATE TABLE IF NOT EXISTS saas_schema_version (
|
||||
@@ -20,10 +20,10 @@ CREATE TABLE IF NOT EXISTS accounts (
|
||||
role TEXT NOT NULL DEFAULT 'user',
|
||||
status TEXT NOT NULL DEFAULT 'active',
|
||||
totp_secret TEXT,
|
||||
totp_enabled INTEGER NOT NULL DEFAULT 0,
|
||||
last_login_at TEXT,
|
||||
created_at TEXT NOT NULL,
|
||||
updated_at TEXT NOT NULL
|
||||
totp_enabled BOOLEAN NOT NULL DEFAULT false,
|
||||
last_login_at TIMESTAMPTZ,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_accounts_email ON accounts(email);
|
||||
CREATE INDEX IF NOT EXISTS idx_accounts_role ON accounts(role);
|
||||
@@ -35,10 +35,10 @@ CREATE TABLE IF NOT EXISTS api_tokens (
|
||||
token_hash TEXT NOT NULL,
|
||||
token_prefix TEXT NOT NULL,
|
||||
permissions TEXT NOT NULL DEFAULT '[]',
|
||||
last_used_at TEXT,
|
||||
expires_at TEXT,
|
||||
created_at TEXT NOT NULL,
|
||||
revoked_at TEXT,
|
||||
last_used_at TIMESTAMPTZ,
|
||||
expires_at TIMESTAMPTZ,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
revoked_at TIMESTAMPTZ,
|
||||
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_api_tokens_account ON api_tokens(account_id);
|
||||
@@ -46,32 +46,23 @@ CREATE INDEX IF NOT EXISTS idx_api_tokens_hash ON api_tokens(token_hash);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS roles (
|
||||
id TEXT PRIMARY KEY,
|
||||
name TEXT NOT NULL,
|
||||
name TEXT NOT NULL UNIQUE,
|
||||
description TEXT,
|
||||
permissions TEXT NOT NULL DEFAULT '[]',
|
||||
is_system INTEGER NOT NULL DEFAULT 0,
|
||||
created_at TEXT NOT NULL,
|
||||
updated_at TEXT NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS permission_templates (
|
||||
id TEXT PRIMARY KEY,
|
||||
name TEXT NOT NULL,
|
||||
description TEXT,
|
||||
permissions TEXT NOT NULL DEFAULT '[]',
|
||||
created_at TEXT NOT NULL,
|
||||
updated_at TEXT NOT NULL
|
||||
is_system BOOLEAN NOT NULL DEFAULT false,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS operation_logs (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
account_id TEXT,
|
||||
action TEXT NOT NULL,
|
||||
target_type TEXT,
|
||||
target_id TEXT,
|
||||
details TEXT,
|
||||
ip_address TEXT,
|
||||
created_at TEXT NOT NULL
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_op_logs_account ON operation_logs(account_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_op_logs_action ON operation_logs(action);
|
||||
@@ -84,12 +75,12 @@ CREATE TABLE IF NOT EXISTS providers (
|
||||
api_key TEXT,
|
||||
base_url TEXT NOT NULL,
|
||||
api_protocol TEXT NOT NULL DEFAULT 'openai',
|
||||
enabled INTEGER NOT NULL DEFAULT 1,
|
||||
enabled BOOLEAN NOT NULL DEFAULT true,
|
||||
rate_limit_rpm INTEGER,
|
||||
rate_limit_tpm INTEGER,
|
||||
config_json TEXT DEFAULT '{}',
|
||||
created_at TEXT NOT NULL,
|
||||
updated_at TEXT NOT NULL
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS models (
|
||||
@@ -99,13 +90,13 @@ CREATE TABLE IF NOT EXISTS models (
|
||||
alias TEXT NOT NULL,
|
||||
context_window INTEGER NOT NULL DEFAULT 8192,
|
||||
max_output_tokens INTEGER NOT NULL DEFAULT 4096,
|
||||
supports_streaming INTEGER NOT NULL DEFAULT 1,
|
||||
supports_vision INTEGER NOT NULL DEFAULT 0,
|
||||
enabled INTEGER NOT NULL DEFAULT 1,
|
||||
pricing_input REAL DEFAULT 0,
|
||||
pricing_output REAL DEFAULT 0,
|
||||
created_at TEXT NOT NULL,
|
||||
updated_at TEXT NOT NULL,
|
||||
supports_streaming BOOLEAN NOT NULL DEFAULT true,
|
||||
supports_vision BOOLEAN NOT NULL DEFAULT false,
|
||||
enabled BOOLEAN NOT NULL DEFAULT true,
|
||||
pricing_input DOUBLE PRECISION DEFAULT 0,
|
||||
pricing_output DOUBLE PRECISION DEFAULT 0,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
UNIQUE(provider_id, model_id),
|
||||
FOREIGN KEY (provider_id) REFERENCES providers(id) ON DELETE CASCADE
|
||||
);
|
||||
@@ -118,18 +109,18 @@ CREATE TABLE IF NOT EXISTS account_api_keys (
|
||||
key_value TEXT NOT NULL,
|
||||
key_label TEXT,
|
||||
permissions TEXT NOT NULL DEFAULT '[]',
|
||||
enabled INTEGER NOT NULL DEFAULT 1,
|
||||
last_used_at TEXT,
|
||||
created_at TEXT NOT NULL,
|
||||
updated_at TEXT NOT NULL,
|
||||
revoked_at TEXT,
|
||||
enabled BOOLEAN NOT NULL DEFAULT true,
|
||||
last_used_at TIMESTAMPTZ,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
revoked_at TIMESTAMPTZ,
|
||||
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE,
|
||||
FOREIGN KEY (provider_id) REFERENCES providers(id) ON DELETE CASCADE
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_account_api_keys_account ON account_api_keys(account_id);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS usage_records (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
account_id TEXT NOT NULL,
|
||||
provider_id TEXT NOT NULL,
|
||||
model_id TEXT NOT NULL,
|
||||
@@ -138,10 +129,12 @@ CREATE TABLE IF NOT EXISTS usage_records (
|
||||
latency_ms INTEGER,
|
||||
status TEXT NOT NULL DEFAULT 'success',
|
||||
error_message TEXT,
|
||||
created_at TEXT NOT NULL
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_usage_account ON usage_records(account_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_usage_time ON usage_records(created_at);
|
||||
CREATE INDEX IF NOT EXISTS idx_usage_provider ON usage_records(provider_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_usage_model ON usage_records(model_id);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS relay_tasks (
|
||||
id TEXT PRIMARY KEY,
|
||||
@@ -158,14 +151,15 @@ CREATE TABLE IF NOT EXISTS relay_tasks (
|
||||
input_tokens INTEGER DEFAULT 0,
|
||||
output_tokens INTEGER DEFAULT 0,
|
||||
error_message TEXT,
|
||||
queued_at TEXT NOT NULL,
|
||||
started_at TEXT,
|
||||
completed_at TEXT,
|
||||
created_at TEXT NOT NULL
|
||||
queued_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
started_at TIMESTAMPTZ,
|
||||
completed_at TIMESTAMPTZ,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_relay_status ON relay_tasks(status);
|
||||
CREATE INDEX IF NOT EXISTS idx_relay_account ON relay_tasks(account_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_relay_provider ON relay_tasks(provider_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_relay_account_status ON relay_tasks(account_id, status);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS config_items (
|
||||
id TEXT PRIMARY KEY,
|
||||
@@ -176,15 +170,15 @@ CREATE TABLE IF NOT EXISTS config_items (
|
||||
default_value TEXT,
|
||||
source TEXT NOT NULL DEFAULT 'local',
|
||||
description TEXT,
|
||||
requires_restart INTEGER NOT NULL DEFAULT 0,
|
||||
created_at TEXT NOT NULL,
|
||||
updated_at TEXT NOT NULL,
|
||||
requires_restart BOOLEAN NOT NULL DEFAULT false,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
UNIQUE(category, key_path)
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_config_category ON config_items(category);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS config_sync_log (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
account_id TEXT NOT NULL,
|
||||
client_fingerprint TEXT NOT NULL,
|
||||
action TEXT NOT NULL,
|
||||
@@ -192,7 +186,7 @@ CREATE TABLE IF NOT EXISTS config_sync_log (
|
||||
client_values TEXT,
|
||||
saas_values TEXT,
|
||||
resolution TEXT,
|
||||
created_at TEXT NOT NULL
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_sync_account ON config_sync_log(account_id);
|
||||
|
||||
@@ -203,8 +197,8 @@ CREATE TABLE IF NOT EXISTS devices (
|
||||
device_name TEXT,
|
||||
platform TEXT,
|
||||
app_version TEXT,
|
||||
last_seen_at TEXT NOT NULL,
|
||||
created_at TEXT NOT NULL,
|
||||
last_seen_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_devices_account ON devices(account_id);
|
||||
@@ -213,55 +207,76 @@ CREATE UNIQUE INDEX IF NOT EXISTS idx_devices_unique ON devices(account_id, devi
|
||||
"#;
|
||||
|
||||
const SEED_ROLES: &str = r#"
|
||||
INSERT OR IGNORE INTO roles (id, name, description, permissions, is_system, created_at, updated_at)
|
||||
INSERT INTO roles (id, name, description, permissions, is_system, created_at, updated_at)
|
||||
VALUES
|
||||
('super_admin', '超级管理员', '拥有所有权限', '["admin:full","account:admin","provider:manage","model:manage","relay:admin","config:write"]', 1, datetime('now'), datetime('now')),
|
||||
('admin', '管理员', '管理账号和配置', '["account:read","account:admin","provider:manage","model:read","model:manage","relay:use","relay:admin","config:read","config:write"]', 1, datetime('now'), datetime('now')),
|
||||
('user', '普通用户', '基础使用权限', '["model:read","relay:use","config:read"]', 1, datetime('now'), datetime('now'));
|
||||
('super_admin', '超级管理员', '拥有所有权限', '["admin:full","account:admin","provider:manage","model:manage","relay:admin","config:write"]', true, NOW(), NOW()),
|
||||
('admin', '管理员', '管理账号和配置', '["account:read","account:admin","provider:manage","model:read","model:manage","relay:use","relay:admin","config:read","config:write"]', true, NOW(), NOW()),
|
||||
('user', '普通用户', '基础使用权限', '["model:read","relay:use","config:read"]', true, NOW(), NOW())
|
||||
ON CONFLICT (id) DO NOTHING;
|
||||
"#;
|
||||
|
||||
/// 初始化数据库
|
||||
pub async fn init_db(database_url: &str) -> SaasResult<SqlitePool> {
|
||||
if database_url.starts_with("sqlite:") {
|
||||
let path_part = database_url.strip_prefix("sqlite:").unwrap_or("");
|
||||
if path_part != ":memory:" {
|
||||
if let Some(parent) = std::path::Path::new(path_part).parent() {
|
||||
if !parent.as_os_str().is_empty() && !parent.exists() {
|
||||
std::fs::create_dir_all(parent)?;
|
||||
}
|
||||
/// PostgreSQL 不支持在单条 prepared statement 中执行多条 SQL 命令,
|
||||
/// 因此需要拆分后逐条执行。
|
||||
async fn execute_multi_statements(pool: &PgPool, sql: &str) -> SaasResult<()> {
|
||||
for stmt in sql.split(';') {
|
||||
let trimmed = stmt.trim();
|
||||
if trimmed.is_empty() {
|
||||
continue;
|
||||
}
|
||||
if let Err(e) = sqlx::query(trimmed).execute(pool).await {
|
||||
let err_str = e.to_string();
|
||||
// 忽略 "已存在" 类错误 (并发初始化或重复调用)
|
||||
let is_already_exists = err_str.contains("already exists")
|
||||
|| err_str.contains("已经存在")
|
||||
|| err_str.contains("重复键");
|
||||
if !is_already_exists {
|
||||
return Err(e.into());
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
let pool = SqlitePool::connect(database_url).await?;
|
||||
sqlx::query("PRAGMA journal_mode=WAL;")
|
||||
.execute(&pool)
|
||||
.await?;
|
||||
sqlx::query(SCHEMA_SQL).execute(&pool).await?;
|
||||
sqlx::query("INSERT OR IGNORE INTO saas_schema_version (version) VALUES (?1)")
|
||||
.bind(SCHEMA_VERSION)
|
||||
.execute(&pool)
|
||||
.await?;
|
||||
sqlx::query(SEED_ROLES).execute(&pool).await?;
|
||||
/// 初始化数据库
|
||||
pub async fn init_db(database_url: &str) -> SaasResult<PgPool> {
|
||||
tracing::info!("Connecting to database: {}", database_url);
|
||||
let pool = PgPool::connect(database_url).await?;
|
||||
execute_multi_statements(&pool, SCHEMA_SQL).await?;
|
||||
execute_multi_statements(&pool, SEED_ROLES).await?;
|
||||
seed_admin_account(&pool).await?;
|
||||
tracing::info!("Database initialized (schema v{})", SCHEMA_VERSION);
|
||||
Ok(pool)
|
||||
}
|
||||
|
||||
/// 创建内存数据库 (测试用)
|
||||
pub async fn init_memory_db() -> SaasResult<SqlitePool> {
|
||||
let pool = SqlitePool::connect("sqlite::memory:").await?;
|
||||
sqlx::query(SCHEMA_SQL).execute(&pool).await?;
|
||||
sqlx::query("INSERT OR IGNORE INTO saas_schema_version (version) VALUES (?1)")
|
||||
.bind(SCHEMA_VERSION)
|
||||
.execute(&pool)
|
||||
.await?;
|
||||
sqlx::query(SEED_ROLES).execute(&pool).await?;
|
||||
/// 创建测试数据库 (连接到真实 PG 实例)
|
||||
/// 测试前清空所有数据,确保每次从干净状态开始
|
||||
pub async fn init_test_db() -> SaasResult<PgPool> {
|
||||
let url = std::env::var("ZCLAW_TEST_DATABASE_URL")
|
||||
.unwrap_or_else(|_| "postgres://localhost:5432/zclaw_test".to_string());
|
||||
let pool = PgPool::connect(&url).await?;
|
||||
execute_multi_statements(&pool, SCHEMA_SQL).await?;
|
||||
clean_test_data(&pool).await?;
|
||||
execute_multi_statements(&pool, SEED_ROLES).await?;
|
||||
Ok(pool)
|
||||
}
|
||||
|
||||
/// 清空所有表数据 (按外键依赖顺序,使用 DELETE 而非 TRUNCATE)
|
||||
/// DELETE 不获取 ACCESS EXCLUSIVE 锁,对并发更友好
|
||||
pub async fn clean_test_data(pool: &PgPool) -> SaasResult<()> {
|
||||
let tables_to_clean = [
|
||||
"config_sync_log", "config_items", "usage_records", "relay_tasks",
|
||||
"account_api_keys", "models", "providers", "operation_logs",
|
||||
"api_tokens", "devices", "roles", "accounts",
|
||||
];
|
||||
for table in &tables_to_clean {
|
||||
let _ = sqlx::query(&format!("DELETE FROM {}", table))
|
||||
.execute(pool).await;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 如果 accounts 表为空且环境变量已设置,自动创建 super_admin 账号
|
||||
async fn seed_admin_account(pool: &SqlitePool) -> SaasResult<()> {
|
||||
async fn seed_admin_account(pool: &PgPool) -> SaasResult<()> {
|
||||
let has_accounts: (bool,) = sqlx::query_as(
|
||||
"SELECT EXISTS(SELECT 1 FROM accounts LIMIT 1) as has"
|
||||
)
|
||||
@@ -291,18 +306,16 @@ async fn seed_admin_account(pool: &SqlitePool) -> SaasResult<()> {
|
||||
let password_hash = hash_password(&admin_password)?;
|
||||
let account_id = uuid::Uuid::new_v4().to_string();
|
||||
let email = format!("{}@zclaw.local", admin_username);
|
||||
let now = chrono::Utc::now().to_rfc3339();
|
||||
|
||||
sqlx::query(
|
||||
"INSERT INTO accounts (id, username, email, password_hash, display_name, role, status, created_at, updated_at)
|
||||
VALUES (?1, ?2, ?3, ?4, ?5, 'super_admin', 'active', ?6, ?6)"
|
||||
VALUES ($1, $2, $3, $4, $5, 'super_admin', 'active', NOW(), NOW())"
|
||||
)
|
||||
.bind(&account_id)
|
||||
.bind(&admin_username)
|
||||
.bind(&email)
|
||||
.bind(&password_hash)
|
||||
.bind(&admin_username)
|
||||
.bind(&now)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
@@ -316,13 +329,35 @@ async fn seed_admin_account(pool: &SqlitePool) -> SaasResult<()> {
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
/// 全局 Mutex 用于序列化所有数据库测试,避免并行测试之间的数据竞争
|
||||
static TEST_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());
|
||||
/// 共享测试连接池,避免每次测试都创建新连接
|
||||
static TEST_POOL: tokio::sync::OnceCell<PgPool> = tokio::sync::OnceCell::const_new();
|
||||
|
||||
/// 获取测试连接池(异步初始化,避免嵌套 runtime 问题)
|
||||
async fn get_test_pool() -> &'static PgPool {
|
||||
TEST_POOL.get_or_init(|| async {
|
||||
init_test_db().await.expect("init_test_db failed")
|
||||
}).await
|
||||
}
|
||||
|
||||
/// 每个测试前清理数据,确保隔离
|
||||
async fn clean_before_test(pool: &PgPool) {
|
||||
clean_test_data(pool).await.expect("clean_test_data failed");
|
||||
execute_multi_statements(pool, SEED_ROLES).await.expect("seed roles failed");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_init_memory_db() {
|
||||
let pool = init_memory_db().await.unwrap();
|
||||
async fn test_init_test_db() {
|
||||
// 获取全局锁,确保测试串行执行
|
||||
let _guard = TEST_LOCK.lock().unwrap();
|
||||
let pool = get_test_pool().await;
|
||||
clean_before_test(pool).await;
|
||||
|
||||
let roles: Vec<(String,)> = sqlx::query_as(
|
||||
"SELECT id FROM roles WHERE is_system = 1"
|
||||
"SELECT id FROM roles WHERE is_system = true"
|
||||
)
|
||||
.fetch_all(&pool)
|
||||
.fetch_all(pool)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(roles.len(), 3);
|
||||
@@ -330,17 +365,20 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_schema_tables_exist() {
|
||||
let pool = init_memory_db().await.unwrap();
|
||||
let _guard = TEST_LOCK.lock().unwrap();
|
||||
let pool = get_test_pool().await;
|
||||
clean_before_test(pool).await;
|
||||
|
||||
let tables = [
|
||||
"accounts", "api_tokens", "roles", "permission_templates",
|
||||
"accounts", "api_tokens", "roles",
|
||||
"operation_logs", "providers", "models", "account_api_keys",
|
||||
"usage_records", "relay_tasks", "config_items", "config_sync_log", "devices",
|
||||
];
|
||||
for table in tables {
|
||||
let count: (i64,) = sqlx::query_as(&format!(
|
||||
"SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='{}'", table
|
||||
"SELECT COUNT(*) FROM information_schema.tables WHERE table_schema='public' AND table_name='{}'", table
|
||||
))
|
||||
.fetch_one(&pool)
|
||||
.fetch_one(pool)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(count.0, 1, "Table {} should exist", table);
|
||||
|
||||
@@ -127,3 +127,62 @@ impl IntoResponse for SaasError {
|
||||
|
||||
/// Result 类型别名
|
||||
pub type SaasResult<T> = std::result::Result<T, SaasError>;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn status_code_maps_correctly() {
|
||||
assert_eq!(SaasError::NotFound("x".into()).status_code(), StatusCode::NOT_FOUND);
|
||||
assert_eq!(SaasError::Forbidden("x".into()).status_code(), StatusCode::FORBIDDEN);
|
||||
assert_eq!(SaasError::Unauthorized.status_code(), StatusCode::UNAUTHORIZED);
|
||||
assert_eq!(SaasError::InvalidInput("x".into()).status_code(), StatusCode::BAD_REQUEST);
|
||||
assert_eq!(SaasError::AlreadyExists("x".into()).status_code(), StatusCode::CONFLICT);
|
||||
assert_eq!(SaasError::RateLimited("x".into()).status_code(), StatusCode::TOO_MANY_REQUESTS);
|
||||
assert_eq!(SaasError::Relay("x".into()).status_code(), StatusCode::BAD_GATEWAY);
|
||||
assert_eq!(SaasError::Totp("x".into()).status_code(), StatusCode::BAD_REQUEST);
|
||||
assert_eq!(SaasError::Internal("x".into()).status_code(), StatusCode::INTERNAL_SERVER_ERROR);
|
||||
assert_eq!(SaasError::AuthError("x".into()).status_code(), StatusCode::UNAUTHORIZED);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn error_code_returns_expected_strings() {
|
||||
assert_eq!(SaasError::NotFound("x".into()).error_code(), "NOT_FOUND");
|
||||
assert_eq!(SaasError::RateLimited("x".into()).error_code(), "RATE_LIMITED");
|
||||
assert_eq!(SaasError::Unauthorized.error_code(), "UNAUTHORIZED");
|
||||
assert_eq!(SaasError::Encryption("x".into()).error_code(), "ENCRYPTION_ERROR");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn into_response_hides_internal_errors() {
|
||||
// 内部错误不应泄露细节
|
||||
let err = SaasError::Internal("secret database password exposed".into());
|
||||
let resp = err.into_response();
|
||||
assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
|
||||
|
||||
let body_bytes = axum::body::to_bytes(resp.into_body(), 1024)
|
||||
.await
|
||||
.expect("body should be readable");
|
||||
let body: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap();
|
||||
assert_eq!(body["error"], "INTERNAL_ERROR");
|
||||
assert_eq!(body["message"], "服务内部错误");
|
||||
assert!(!body["message"].as_str().unwrap().contains("secret"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn into_response_shows_user_facing_errors() {
|
||||
let err = SaasError::InvalidInput("用户名不能为空".into());
|
||||
let resp = err.into_response();
|
||||
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
|
||||
|
||||
let body_bytes = axum::body::to_bytes(resp.into_body(), 1024)
|
||||
.await
|
||||
.expect("body should be readable");
|
||||
let body: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap();
|
||||
assert_eq!(body["error"], "INVALID_INPUT");
|
||||
// InvalidInput includes the "无效输入: " prefix from Display impl
|
||||
let msg = body["message"].as_str().unwrap();
|
||||
assert!(msg.contains("用户名不能为空"));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,9 +3,12 @@
|
||||
//! 独立的 SaaS 后端服务,提供账号权限管理、模型配置、请求中转和配置迁移。
|
||||
|
||||
pub mod config;
|
||||
pub mod crypto;
|
||||
pub mod csrf;
|
||||
pub mod db;
|
||||
pub mod error;
|
||||
pub mod middleware;
|
||||
pub mod openapi;
|
||||
pub mod state;
|
||||
|
||||
pub mod auth;
|
||||
|
||||
@@ -1,7 +1,16 @@
|
||||
//! ZCLAW SaaS 服务入口
|
||||
|
||||
use std::time::{Duration, Instant};
|
||||
use tracing::info;
|
||||
use zclaw_saas::{config::SaaSConfig, db::init_db, state::AppState};
|
||||
use axum::{extract::State, Json};
|
||||
|
||||
async fn health_handler(State(_state): State<AppState>) -> Json<serde_json::Value> {
|
||||
Json(serde_json::json!({
|
||||
"status": "ok",
|
||||
"service": "zclaw-saas",
|
||||
}))
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
@@ -19,7 +28,63 @@ async fn main() -> anyhow::Result<()> {
|
||||
info!("Database initialized");
|
||||
|
||||
let state = AppState::new(db, config.clone())?;
|
||||
let app = build_router(state);
|
||||
|
||||
// SEC-14: 后台清理 rate_limit_entries DashMap,防止不活跃账号条目无限增长。
|
||||
// 中间件仅在被请求命中时清理对应 entry,不活跃的 account 永远不会被回收。
|
||||
// 此任务每 5 分钟扫描一次,移除所有时间戳均已超过 2 分钟的 entry
|
||||
// (滑动窗口为 1 分钟,2 分钟是安全的 2x 余量)。
|
||||
{
|
||||
let rate_limit_entries = state.rate_limit_entries.clone();
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
tokio::time::sleep(Duration::from_secs(5 * 60)).await;
|
||||
|
||||
let cutoff = Instant::now() - Duration::from_secs(2 * 60);
|
||||
let mut removed = 0usize;
|
||||
|
||||
rate_limit_entries.retain(|_account_id, timestamps| {
|
||||
timestamps.retain(|&ts| ts > cutoff);
|
||||
let keep = !timestamps.is_empty();
|
||||
if !keep {
|
||||
removed += 1;
|
||||
}
|
||||
keep
|
||||
});
|
||||
|
||||
if removed > 0 {
|
||||
info!(
|
||||
removed,
|
||||
remaining = rate_limit_entries.len(),
|
||||
"rate limiter cleanup: removed stale entries"
|
||||
);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// CORS 安全检查:生产环境必须配置 cors_origins
|
||||
if config.server.cors_origins.is_empty() {
|
||||
let is_dev = std::env::var("ZCLAW_SAAS_DEV")
|
||||
.map(|v| v == "true" || v == "1")
|
||||
.unwrap_or(false);
|
||||
if !is_dev {
|
||||
anyhow::bail!("生产环境必须配置 server.cors_origins 白名单。开发环境可设置 ZCLAW_SAAS_DEV=true 绕过。");
|
||||
}
|
||||
}
|
||||
|
||||
let app = build_router(state, &config);
|
||||
|
||||
// Swagger UI / OpenAPI 文档
|
||||
// TODO: 启用 Swagger UI 后取消注释 (需要 utoipa / utoipa-swagger-ui 版本对齐)
|
||||
// let app = {
|
||||
// use utoipa_swagger_ui::SwaggerUi;
|
||||
// use utoipa::OpenApi;
|
||||
// let openapi = zclaw_saas::openapi::ApiDoc::openapi();
|
||||
// app.merge(
|
||||
// SwaggerUi::new("/api-docs/openapi.json")
|
||||
// .url("/api-docs/openapi.json", openapi),
|
||||
// )
|
||||
// };
|
||||
|
||||
let listener = tokio::net::TcpListener::bind(format!("{}:{}", config.server.host, config.server.port))
|
||||
.await?;
|
||||
@@ -29,27 +94,19 @@ async fn main() -> anyhow::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn build_router(state: AppState) -> axum::Router {
|
||||
fn build_router(state: AppState, config: &SaaSConfig) -> axum::Router {
|
||||
use axum::middleware;
|
||||
use tower_http::cors::{Any, CorsLayer};
|
||||
use tower_http::trace::TraceLayer;
|
||||
|
||||
use axum::http::HeaderValue;
|
||||
let cors = {
|
||||
let config = state.config.blocking_read();
|
||||
let is_dev = std::env::var("ZCLAW_SAAS_DEV")
|
||||
.map(|v| v == "true" || v == "1")
|
||||
.unwrap_or(false);
|
||||
if config.server.cors_origins.is_empty() {
|
||||
if is_dev {
|
||||
CorsLayer::new()
|
||||
.allow_origin(Any)
|
||||
.allow_methods(Any)
|
||||
.allow_headers(Any)
|
||||
} else {
|
||||
tracing::error!("生产环境必须配置 server.cors_origins,不能使用 allow_origin(Any)");
|
||||
panic!("生产环境必须配置 server.cors_origins 白名单。开发环境可设置 ZCLAW_SAAS_DEV=true 绕过。");
|
||||
}
|
||||
// 开发环境允许任意 origin(生产环境已在 main 中拦截)
|
||||
CorsLayer::new()
|
||||
.allow_origin(Any)
|
||||
.allow_methods(Any)
|
||||
.allow_headers(Any)
|
||||
} else {
|
||||
let origins: Vec<HeaderValue> = config.server.cors_origins.iter()
|
||||
.filter_map(|o: &String| o.parse::<HeaderValue>().ok())
|
||||
@@ -72,14 +129,20 @@ fn build_router(state: AppState) -> axum::Router {
|
||||
state.clone(),
|
||||
zclaw_saas::middleware::rate_limit_middleware,
|
||||
))
|
||||
.layer(middleware::from_fn_with_state(
|
||||
state.clone(),
|
||||
zclaw_saas::csrf::origin_check_middleware,
|
||||
))
|
||||
.layer(middleware::from_fn_with_state(
|
||||
state.clone(),
|
||||
zclaw_saas::auth::auth_middleware,
|
||||
));
|
||||
|
||||
axum::Router::new()
|
||||
.route("/api/health", axum::routing::get(health_handler))
|
||||
.merge(public_routes)
|
||||
.merge(protected_routes)
|
||||
.layer(axum::extract::DefaultBodyLimit::max(10 * 1024 * 1024)) // 10MB 请求体限制,防止 DoS
|
||||
.layer(TraceLayer::new_for_http())
|
||||
.layer(cors)
|
||||
.with_state(state)
|
||||
|
||||
@@ -10,6 +10,58 @@ use std::time::Instant;
|
||||
|
||||
use crate::state::AppState;
|
||||
|
||||
/// 速率限制检查结果
|
||||
#[derive(Debug, PartialEq)]
|
||||
pub(crate) enum RateLimitResult {
|
||||
/// 允许通过
|
||||
Allowed,
|
||||
/// 被限制,附带 Retry-After 秒数
|
||||
Limited { retry_after_secs: u64 },
|
||||
}
|
||||
|
||||
/// 滑动窗口速率限制核心逻辑(纯函数,便于测试)
|
||||
///
|
||||
/// 返回 `RateLimitResult::Allowed` 表示未超限(已记录本次请求),
|
||||
/// `RateLimitResult::Limited` 表示超限。
|
||||
pub(crate) fn check_rate_limit(
|
||||
entries: &mut Vec<Instant>,
|
||||
now: Instant,
|
||||
window_duration: std::time::Duration,
|
||||
max_requests: u64,
|
||||
) -> RateLimitResult {
|
||||
let window_start = now - window_duration;
|
||||
|
||||
// 清理过期条目
|
||||
entries.retain(|&ts| ts > window_start);
|
||||
|
||||
let count = entries.len() as u64;
|
||||
if count < max_requests {
|
||||
entries.push(now);
|
||||
RateLimitResult::Allowed
|
||||
} else {
|
||||
// 计算最早条目的过期时间作为 Retry-After
|
||||
entries.sort();
|
||||
let earliest = *entries.first().unwrap_or(&now);
|
||||
let elapsed = now.duration_since(earliest).as_secs();
|
||||
let retry_after = window_duration.as_secs().saturating_sub(elapsed);
|
||||
RateLimitResult::Limited {
|
||||
retry_after_secs: retry_after,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
/// 清理过期条目并移除空 entry
|
||||
fn cleanup_stale_entries(
|
||||
map: &dashmap::DashMap<String, Vec<Instant>>,
|
||||
cutoff: Instant,
|
||||
) {
|
||||
map.retain(|_, entries| {
|
||||
entries.retain(|&ts| ts > cutoff);
|
||||
!entries.is_empty()
|
||||
});
|
||||
}
|
||||
|
||||
/// 滑动窗口速率限制中间件
|
||||
///
|
||||
/// 按 account_id (从 AuthContext 提取) 做 per-minute 限流。
|
||||
@@ -37,45 +89,186 @@ pub async fn rate_limit_middleware(
|
||||
drop(config);
|
||||
|
||||
let now = Instant::now();
|
||||
let window_start = now - std::time::Duration::from_secs(60);
|
||||
let window = std::time::Duration::from_secs(60);
|
||||
|
||||
// 滑动窗口: 清理过期条目 + 计数
|
||||
let current_count = {
|
||||
let mut entries = state.rate_limit_entries.entry(account_id.clone()).or_default();
|
||||
entries.retain(|&ts| ts > window_start);
|
||||
let count = entries.len() as u64;
|
||||
if count < max_requests {
|
||||
entries.push(now);
|
||||
0 // 未超限
|
||||
} else {
|
||||
count
|
||||
let result = check_rate_limit(&mut entries, now, window, max_requests);
|
||||
if let RateLimitResult::Limited { retry_after_secs } = result {
|
||||
if let Some(entries) = state.rate_limit_entries.get_mut(&account_id) {
|
||||
if entries.is_empty() {
|
||||
drop(entries);
|
||||
state.rate_limit_entries.remove(&account_id);
|
||||
}
|
||||
}
|
||||
return (
|
||||
StatusCode::TOO_MANY_REQUESTS,
|
||||
[
|
||||
("Retry-After", retry_after_secs.to_string()),
|
||||
("Content-Type", "application/json".to_string()),
|
||||
],
|
||||
axum::Json(serde_json::json!({
|
||||
"error": "RATE_LIMITED",
|
||||
"message": format!("请求过于频繁,请在 {} 秒后重试", retry_after_secs),
|
||||
})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
entries.len() as u64
|
||||
};
|
||||
|
||||
if current_count >= max_requests {
|
||||
// 计算最早条目的过期时间作为 Retry-After
|
||||
let retry_after = if let Some(mut entries) = state.rate_limit_entries.get_mut(&account_id) {
|
||||
entries.sort();
|
||||
let earliest = *entries.first().unwrap_or(&now);
|
||||
let elapsed = now.duration_since(earliest).as_secs();
|
||||
60u64.saturating_sub(elapsed)
|
||||
} else {
|
||||
60
|
||||
};
|
||||
|
||||
return (
|
||||
StatusCode::TOO_MANY_REQUESTS,
|
||||
[
|
||||
("Retry-After", retry_after.to_string()),
|
||||
("Content-Type", "application/json".to_string()),
|
||||
],
|
||||
axum::Json(serde_json::json!({
|
||||
"error": "RATE_LIMITED",
|
||||
"message": format!("请求过于频繁,请在 {} 秒后重试", retry_after),
|
||||
})),
|
||||
)
|
||||
.into_response();
|
||||
// 清理空 entry (不再活跃的用户)
|
||||
if current_count == 0 {
|
||||
if let Some(entries) = state.rate_limit_entries.get_mut(&account_id) {
|
||||
if entries.is_empty() {
|
||||
drop(entries);
|
||||
state.rate_limit_entries.remove(&account_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
next.run(req).await
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn allows_under_limit() {
|
||||
let mut entries: Vec<Instant> = vec![];
|
||||
let now = Instant::now();
|
||||
let window = std::time::Duration::from_secs(60);
|
||||
|
||||
for i in 0..5 {
|
||||
let result = check_rate_limit(&mut entries, now, window, 10);
|
||||
assert_eq!(result, RateLimitResult::Allowed, "request {} should be allowed", i);
|
||||
}
|
||||
assert_eq!(entries.len() as u64, 5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn blocks_at_limit() {
|
||||
let mut entries: Vec<Instant> = vec![];
|
||||
let now = Instant::now();
|
||||
let window = std::time::Duration::from_secs(60);
|
||||
let limit: u64 = 3;
|
||||
|
||||
// 填到限额
|
||||
for _ in 0..limit {
|
||||
let result = check_rate_limit(&mut entries, now, window, limit);
|
||||
assert_eq!(result, RateLimitResult::Allowed);
|
||||
}
|
||||
assert_eq!(entries.len() as u64, limit);
|
||||
|
||||
// 下一个应该被限流
|
||||
let result = check_rate_limit(&mut entries, now, window, limit);
|
||||
assert_eq!(result, RateLimitResult::Limited { retry_after_secs: 60 });
|
||||
// 不应该增加新条目
|
||||
assert_eq!(entries.len() as u64, limit);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn expired_entries_are_cleaned() {
|
||||
let mut entries: Vec<Instant> = vec![];
|
||||
let now = Instant::now();
|
||||
let window = std::time::Duration::from_secs(60);
|
||||
|
||||
// 插入一个 61 秒前的旧条目
|
||||
entries.push(now - std::time::Duration::from_secs(61));
|
||||
assert_eq!(entries.len(), 1);
|
||||
|
||||
// 旧条目应该被清理,然后允许新请求
|
||||
let result = check_rate_limit(&mut entries, now, window, 1);
|
||||
assert_eq!(result, RateLimitResult::Allowed);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn retry_after_reflects_earliest_entry() {
|
||||
let mut entries: Vec<Instant> = vec![];
|
||||
let now = Instant::now();
|
||||
let window = std::time::Duration::from_secs(60);
|
||||
let limit: u64 = 2;
|
||||
|
||||
// 第一个请求在 10 秒前
|
||||
let first_time = now - std::time::Duration::from_secs(10);
|
||||
entries.push(first_time);
|
||||
// 第二个请求现在
|
||||
entries.push(now);
|
||||
|
||||
assert_eq!(entries.len() as u64, limit);
|
||||
|
||||
let result = check_rate_limit(&mut entries, now, window, limit);
|
||||
assert_eq!(result, RateLimitResult::Limited { retry_after_secs: 50 });
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn burst_allows_extra_requests() {
|
||||
let mut entries: Vec<Instant> = vec![];
|
||||
let now = Instant::now();
|
||||
let window = std::time::Duration::from_secs(60);
|
||||
let rpm: u64 = 5;
|
||||
let burst: u64 = 3;
|
||||
let max = rpm + burst; // 8
|
||||
|
||||
// 前 8 个请求应该全部通过
|
||||
for _ in 0..max {
|
||||
let result = check_rate_limit(&mut entries, now, window, max);
|
||||
assert_eq!(result, RateLimitResult::Allowed);
|
||||
}
|
||||
// 第 9 个被限流
|
||||
let result = check_rate_limit(&mut entries, now, window, max);
|
||||
assert_eq!(result, RateLimitResult::Limited { retry_after_secs: 60 });
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cleanup_removes_expired_and_empty() {
|
||||
let map: dashmap::DashMap<String, Vec<Instant>> = dashmap::DashMap::new();
|
||||
let now = Instant::now();
|
||||
let cutoff = now - std::time::Duration::from_secs(120);
|
||||
|
||||
// 活跃用户
|
||||
map.insert("active".to_string(), vec![now]);
|
||||
// 过期用户
|
||||
map.insert(
|
||||
"expired".to_string(),
|
||||
vec![now - std::time::Duration::from_secs(200)],
|
||||
);
|
||||
// 空用户
|
||||
map.insert("empty".to_string(), vec![]);
|
||||
|
||||
cleanup_stale_entries(&map, cutoff);
|
||||
|
||||
assert!(map.contains_key("active"));
|
||||
assert!(!map.contains_key("expired"));
|
||||
assert!(!map.contains_key("empty"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn empty_entries_allowed() {
|
||||
let mut entries: Vec<Instant> = vec![];
|
||||
let now = Instant::now();
|
||||
let window = std::time::Duration::from_secs(60);
|
||||
|
||||
let result = check_rate_limit(&mut entries, now, window, 0);
|
||||
// limit=0 means always limited
|
||||
assert_eq!(result, RateLimitResult::Limited { retry_after_secs: 60 });
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn single_request_with_large_window() {
|
||||
let mut entries: Vec<Instant> = vec![];
|
||||
let now = Instant::now();
|
||||
let window = std::time::Duration::from_secs(3600);
|
||||
let limit: u64 = 100;
|
||||
|
||||
for _ in 0..limit {
|
||||
let result = check_rate_limit(&mut entries, now, window, limit);
|
||||
assert_eq!(result, RateLimitResult::Allowed);
|
||||
}
|
||||
assert_eq!(entries.len() as u64, limit);
|
||||
|
||||
let result = check_rate_limit(&mut entries, now, window, limit);
|
||||
assert!(matches!(result, RateLimitResult::Limited { .. }));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ use axum::{
|
||||
use crate::state::AppState;
|
||||
use crate::error::SaasResult;
|
||||
use crate::auth::types::AuthContext;
|
||||
use crate::auth::handlers::check_permission;
|
||||
use crate::auth::handlers::{check_permission, log_operation};
|
||||
use super::{types::*, service};
|
||||
|
||||
/// GET /api/v1/config/items?category=xxx&source=xxx
|
||||
@@ -36,6 +36,9 @@ pub async fn create_config_item(
|
||||
) -> SaasResult<(StatusCode, Json<ConfigItemInfo>)> {
|
||||
check_permission(&ctx, "config:write")?;
|
||||
let item = service::create_config_item(&state.db, &req).await?;
|
||||
log_operation(&state.db, &ctx.account_id, "config.create", "config_item", &item.id,
|
||||
Some(serde_json::json!({"category": req.category, "key_path": req.key_path})),
|
||||
ctx.client_ip.as_deref()).await?;
|
||||
Ok((StatusCode::CREATED, Json(item)))
|
||||
}
|
||||
|
||||
@@ -47,7 +50,10 @@ pub async fn update_config_item(
|
||||
Json(req): Json<UpdateConfigItemRequest>,
|
||||
) -> SaasResult<Json<ConfigItemInfo>> {
|
||||
check_permission(&ctx, "config:write")?;
|
||||
service::update_config_item(&state.db, &id, &req).await.map(Json)
|
||||
let item = service::update_config_item(&state.db, &id, &req).await?;
|
||||
log_operation(&state.db, &ctx.account_id, "config.update", "config_item", &id, None,
|
||||
ctx.client_ip.as_deref()).await?;
|
||||
Ok(Json(item))
|
||||
}
|
||||
|
||||
/// DELETE /api/v1/config/items/:id (admin only)
|
||||
@@ -58,6 +64,8 @@ pub async fn delete_config_item(
|
||||
) -> SaasResult<Json<serde_json::Value>> {
|
||||
check_permission(&ctx, "config:write")?;
|
||||
service::delete_config_item(&state.db, &id).await?;
|
||||
log_operation(&state.db, &ctx.account_id, "config.delete", "config_item", &id, None,
|
||||
ctx.client_ip.as_deref()).await?;
|
||||
Ok(Json(serde_json::json!({"ok": true})))
|
||||
}
|
||||
|
||||
@@ -76,16 +84,24 @@ pub async fn seed_config(
|
||||
) -> SaasResult<Json<serde_json::Value>> {
|
||||
check_permission(&ctx, "config:write")?;
|
||||
let count = service::seed_default_config_items(&state.db).await?;
|
||||
log_operation(&state.db, &ctx.account_id, "config.seed", "config_items", "batch",
|
||||
Some(serde_json::json!({"created": count})),
|
||||
ctx.client_ip.as_deref()).await?;
|
||||
Ok(Json(serde_json::json!({"created": count})))
|
||||
}
|
||||
|
||||
/// POST /api/v1/config/sync
|
||||
/// POST /api/v1/config/sync (admin only)
|
||||
pub async fn sync_config(
|
||||
State(state): State<AppState>,
|
||||
Extension(ctx): Extension<AuthContext>,
|
||||
Json(req): Json<SyncConfigRequest>,
|
||||
) -> SaasResult<Json<super::service::ConfigSyncResult>> {
|
||||
super::service::sync_config(&state.db, &ctx.account_id, &req).await.map(Json)
|
||||
check_permission(&ctx, "config:write")?;
|
||||
let result = super::service::sync_config(&state.db, &ctx.account_id, &req).await?;
|
||||
log_operation(&state.db, &ctx.account_id, "config.sync", "config_sync", &ctx.account_id,
|
||||
Some(serde_json::json!({"action": req.action, "updated": result.updated, "created": result.created, "skipped": result.skipped})),
|
||||
ctx.client_ip.as_deref()).await?;
|
||||
Ok(Json(result))
|
||||
}
|
||||
|
||||
/// POST /api/v1/config/diff
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
//! 配置迁移业务逻辑
|
||||
|
||||
use sqlx::SqlitePool;
|
||||
use sqlx::PgPool;
|
||||
use crate::error::{SaasError, SaasResult};
|
||||
use super::types::*;
|
||||
use serde::Serialize;
|
||||
@@ -8,20 +8,20 @@ use serde::Serialize;
|
||||
// ============ Config Items ============
|
||||
|
||||
pub async fn list_config_items(
|
||||
db: &SqlitePool, query: &ConfigQuery,
|
||||
db: &PgPool, query: &ConfigQuery,
|
||||
) -> SaasResult<Vec<ConfigItemInfo>> {
|
||||
let sql = match (&query.category, &query.source) {
|
||||
(Some(_), Some(_)) => {
|
||||
"SELECT id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at, updated_at
|
||||
FROM config_items WHERE category = ?1 AND source = ?2 ORDER BY category, key_path"
|
||||
FROM config_items WHERE category = $1 AND source = $2 ORDER BY category, key_path"
|
||||
}
|
||||
(Some(_), None) => {
|
||||
"SELECT id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at, updated_at
|
||||
FROM config_items WHERE category = ?1 ORDER BY key_path"
|
||||
FROM config_items WHERE category = $1 ORDER BY key_path"
|
||||
}
|
||||
(None, Some(_)) => {
|
||||
"SELECT id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at, updated_at
|
||||
FROM config_items WHERE source = ?1 ORDER BY category, key_path"
|
||||
FROM config_items WHERE source = $1 ORDER BY category, key_path"
|
||||
}
|
||||
(None, None) => {
|
||||
"SELECT id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at, updated_at
|
||||
@@ -29,7 +29,7 @@ pub async fn list_config_items(
|
||||
}
|
||||
};
|
||||
|
||||
let mut query_builder = sqlx::query_as::<_, (String, String, String, String, Option<String>, Option<String>, String, Option<String>, bool, String, String)>(sql);
|
||||
let mut query_builder = sqlx::query_as::<_, (String, String, String, String, Option<String>, Option<String>, String, Option<String>, bool, chrono::DateTime<chrono::Utc>, chrono::DateTime<chrono::Utc>)>(sql);
|
||||
|
||||
if let Some(cat) = &query.category {
|
||||
query_builder = query_builder.bind(cat);
|
||||
@@ -40,15 +40,15 @@ pub async fn list_config_items(
|
||||
|
||||
let rows = query_builder.fetch_all(db).await?;
|
||||
Ok(rows.into_iter().map(|(id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at, updated_at)| {
|
||||
ConfigItemInfo { id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at, updated_at }
|
||||
ConfigItemInfo { id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at: created_at.to_rfc3339(), updated_at: updated_at.to_rfc3339() }
|
||||
}).collect())
|
||||
}
|
||||
|
||||
pub async fn get_config_item(db: &SqlitePool, item_id: &str) -> SaasResult<ConfigItemInfo> {
|
||||
let row: Option<(String, String, String, String, Option<String>, Option<String>, String, Option<String>, bool, String, String)> =
|
||||
pub async fn get_config_item(db: &PgPool, item_id: &str) -> SaasResult<ConfigItemInfo> {
|
||||
let row: Option<(String, String, String, String, Option<String>, Option<String>, String, Option<String>, bool, chrono::DateTime<chrono::Utc>, chrono::DateTime<chrono::Utc>)> =
|
||||
sqlx::query_as(
|
||||
"SELECT id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at, updated_at
|
||||
FROM config_items WHERE id = ?1"
|
||||
FROM config_items WHERE id = $1"
|
||||
)
|
||||
.bind(item_id)
|
||||
.fetch_optional(db)
|
||||
@@ -57,20 +57,20 @@ pub async fn get_config_item(db: &SqlitePool, item_id: &str) -> SaasResult<Confi
|
||||
let (id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at, updated_at) =
|
||||
row.ok_or_else(|| SaasError::NotFound(format!("配置项 {} 不存在", item_id)))?;
|
||||
|
||||
Ok(ConfigItemInfo { id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at, updated_at })
|
||||
Ok(ConfigItemInfo { id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at: created_at.to_rfc3339(), updated_at: updated_at.to_rfc3339() })
|
||||
}
|
||||
|
||||
pub async fn create_config_item(
|
||||
db: &SqlitePool, req: &CreateConfigItemRequest,
|
||||
db: &PgPool, req: &CreateConfigItemRequest,
|
||||
) -> SaasResult<ConfigItemInfo> {
|
||||
let id = uuid::Uuid::new_v4().to_string();
|
||||
let now = chrono::Utc::now().to_rfc3339();
|
||||
let now = chrono::Utc::now();
|
||||
let source = req.source.as_deref().unwrap_or("local");
|
||||
let requires_restart = req.requires_restart.unwrap_or(false);
|
||||
|
||||
// 检查唯一性
|
||||
let existing: Option<(String,)> = sqlx::query_as(
|
||||
"SELECT id FROM config_items WHERE category = ?1 AND key_path = ?2"
|
||||
"SELECT id FROM config_items WHERE category = $1 AND key_path = $2"
|
||||
)
|
||||
.bind(&req.category).bind(&req.key_path)
|
||||
.fetch_optional(db).await?;
|
||||
@@ -83,7 +83,7 @@ pub async fn create_config_item(
|
||||
|
||||
sqlx::query(
|
||||
"INSERT INTO config_items (id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at, updated_at)
|
||||
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?10)"
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $10)"
|
||||
)
|
||||
.bind(&id).bind(&req.category).bind(&req.key_path).bind(&req.value_type)
|
||||
.bind(&req.current_value).bind(&req.default_value).bind(source)
|
||||
@@ -94,36 +94,38 @@ pub async fn create_config_item(
|
||||
}
|
||||
|
||||
pub async fn update_config_item(
|
||||
db: &SqlitePool, item_id: &str, req: &UpdateConfigItemRequest,
|
||||
db: &PgPool, item_id: &str, req: &UpdateConfigItemRequest,
|
||||
) -> SaasResult<ConfigItemInfo> {
|
||||
let now = chrono::Utc::now().to_rfc3339();
|
||||
let now = chrono::Utc::now();
|
||||
let mut updates = Vec::new();
|
||||
let mut params: Vec<String> = Vec::new();
|
||||
let mut param_idx: i32 = 1;
|
||||
|
||||
if let Some(ref v) = req.current_value { updates.push("current_value = ?"); params.push(v.clone()); }
|
||||
if let Some(ref v) = req.source { updates.push("source = ?"); params.push(v.clone()); }
|
||||
if let Some(ref v) = req.description { updates.push("description = ?"); params.push(v.clone()); }
|
||||
if let Some(ref v) = req.current_value { updates.push(format!("current_value = ${}", param_idx)); params.push(v.clone()); param_idx += 1; }
|
||||
if let Some(ref v) = req.source { updates.push(format!("source = ${}", param_idx)); params.push(v.clone()); param_idx += 1; }
|
||||
if let Some(ref v) = req.description { updates.push(format!("description = ${}", param_idx)); params.push(v.clone()); param_idx += 1; }
|
||||
|
||||
if updates.is_empty() {
|
||||
return get_config_item(db, item_id).await;
|
||||
}
|
||||
|
||||
updates.push("updated_at = ?");
|
||||
params.push(now);
|
||||
updates.push(format!("updated_at = ${}", param_idx));
|
||||
param_idx += 1;
|
||||
params.push(item_id.to_string());
|
||||
|
||||
let sql = format!("UPDATE config_items SET {} WHERE id = ?", updates.join(", "));
|
||||
let sql = format!("UPDATE config_items SET {} WHERE id = ${}", updates.join(", "), param_idx);
|
||||
let mut query = sqlx::query(&sql);
|
||||
for p in ¶ms {
|
||||
query = query.bind(p);
|
||||
}
|
||||
query = query.bind(now);
|
||||
query.execute(db).await?;
|
||||
|
||||
get_config_item(db, item_id).await
|
||||
}
|
||||
|
||||
pub async fn delete_config_item(db: &SqlitePool, item_id: &str) -> SaasResult<()> {
|
||||
let result = sqlx::query("DELETE FROM config_items WHERE id = ?1")
|
||||
pub async fn delete_config_item(db: &PgPool, item_id: &str) -> SaasResult<()> {
|
||||
let result = sqlx::query("DELETE FROM config_items WHERE id = $1")
|
||||
.bind(item_id).execute(db).await?;
|
||||
if result.rows_affected() == 0 {
|
||||
return Err(SaasError::NotFound(format!("配置项 {} 不存在", item_id)));
|
||||
@@ -133,7 +135,7 @@ pub async fn delete_config_item(db: &SqlitePool, item_id: &str) -> SaasResult<()
|
||||
|
||||
// ============ Config Analysis ============
|
||||
|
||||
pub async fn analyze_config(db: &SqlitePool) -> SaasResult<ConfigAnalysis> {
|
||||
pub async fn analyze_config(db: &PgPool) -> SaasResult<ConfigAnalysis> {
|
||||
let items = list_config_items(db, &ConfigQuery { category: None, source: None }).await?;
|
||||
|
||||
let mut categories: std::collections::HashMap<String, (i64, i64)> = std::collections::HashMap::new();
|
||||
@@ -157,7 +159,7 @@ pub async fn analyze_config(db: &SqlitePool) -> SaasResult<ConfigAnalysis> {
|
||||
}
|
||||
|
||||
/// 种子默认配置项
|
||||
pub async fn seed_default_config_items(db: &SqlitePool) -> SaasResult<usize> {
|
||||
pub async fn seed_default_config_items(db: &PgPool) -> SaasResult<usize> {
|
||||
let defaults = [
|
||||
("server", "server.host", "string", Some("127.0.0.1"), Some("127.0.0.1"), "服务器监听地址"),
|
||||
("server", "server.port", "integer", Some("4200"), Some("4200"), "服务器端口"),
|
||||
@@ -175,11 +177,11 @@ pub async fn seed_default_config_items(db: &SqlitePool) -> SaasResult<usize> {
|
||||
];
|
||||
|
||||
let mut created = 0;
|
||||
let now = chrono::Utc::now().to_rfc3339();
|
||||
let now = chrono::Utc::now();
|
||||
|
||||
for (category, key_path, value_type, default_value, current_value, description) in defaults {
|
||||
let existing: Option<(String,)> = sqlx::query_as(
|
||||
"SELECT id FROM config_items WHERE category = ?1 AND key_path = ?2"
|
||||
"SELECT id FROM config_items WHERE category = $1 AND key_path = $2"
|
||||
)
|
||||
.bind(category).bind(key_path)
|
||||
.fetch_optional(db)
|
||||
@@ -189,7 +191,7 @@ pub async fn seed_default_config_items(db: &SqlitePool) -> SaasResult<usize> {
|
||||
let id = uuid::Uuid::new_v4().to_string();
|
||||
sqlx::query(
|
||||
"INSERT INTO config_items (id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at, updated_at)
|
||||
VALUES (?1, ?2, ?3, ?4, ?5, ?6, 'local', ?7, 0, ?8, ?8)"
|
||||
VALUES ($1, $2, $3, $4, $5, $6, 'local', $7, false, $8, $8)"
|
||||
)
|
||||
.bind(&id).bind(category).bind(key_path).bind(value_type)
|
||||
.bind(current_value).bind(default_value).bind(description).bind(&now)
|
||||
@@ -204,21 +206,20 @@ pub async fn seed_default_config_items(db: &SqlitePool) -> SaasResult<usize> {
|
||||
|
||||
// ============ Config Sync ============
|
||||
|
||||
/// 计算客户端与 SaaS 端的配置差异
|
||||
pub async fn compute_config_diff(
|
||||
db: &SqlitePool, req: &SyncConfigRequest,
|
||||
) -> SaasResult<ConfigDiffResponse> {
|
||||
let saas_items = list_config_items(db, &ConfigQuery { category: None, source: None }).await?;
|
||||
|
||||
/// 纯函数:计算客户端与 SaaS 配置项的差异(不依赖数据库)
|
||||
pub fn compute_diff_items(
|
||||
config_keys: &[String],
|
||||
client_values: &serde_json::Value,
|
||||
saas_items: &[ConfigItemInfo],
|
||||
) -> (Vec<ConfigDiffItem>, usize) {
|
||||
let mut items = Vec::new();
|
||||
let mut conflicts = 0usize;
|
||||
|
||||
for key in &req.config_keys {
|
||||
let client_val = req.client_values.get(key)
|
||||
for key in config_keys {
|
||||
let client_val = client_values.get(key)
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| s.to_string());
|
||||
|
||||
// 查找 SaaS 端的值
|
||||
let saas_item = saas_items.iter().find(|item| item.key_path == *key);
|
||||
let saas_val = saas_item.and_then(|item| item.current_value.clone());
|
||||
|
||||
@@ -239,6 +240,17 @@ pub async fn compute_config_diff(
|
||||
});
|
||||
}
|
||||
|
||||
(items, conflicts)
|
||||
}
|
||||
|
||||
/// 计算客户端与 SaaS 端的配置差异
|
||||
pub async fn compute_config_diff(
|
||||
db: &PgPool, req: &SyncConfigRequest,
|
||||
) -> SaasResult<ConfigDiffResponse> {
|
||||
let saas_items = list_config_items(db, &ConfigQuery { category: None, source: None }).await?;
|
||||
|
||||
let (items, conflicts) = compute_diff_items(&req.config_keys, &req.client_values, &saas_items);
|
||||
|
||||
Ok(ConfigDiffResponse {
|
||||
total_keys: items.len(),
|
||||
conflicts,
|
||||
@@ -248,16 +260,16 @@ pub async fn compute_config_diff(
|
||||
|
||||
/// 执行配置同步 (实际写入 config_items)
|
||||
pub async fn sync_config(
|
||||
db: &SqlitePool, account_id: &str, req: &SyncConfigRequest,
|
||||
db: &PgPool, account_id: &str, req: &SyncConfigRequest,
|
||||
) -> SaasResult<ConfigSyncResult> {
|
||||
let now = chrono::Utc::now().to_rfc3339();
|
||||
let now = chrono::Utc::now();
|
||||
let config_keys_str = serde_json::to_string(&req.config_keys)?;
|
||||
let client_values_str = Some(serde_json::to_string(&req.client_values)?);
|
||||
|
||||
// 获取 SaaS 端的配置值
|
||||
let saas_items = list_config_items(db, &ConfigQuery { category: None, source: None }).await?;
|
||||
let mut updated = 0i64;
|
||||
let created = 0i64;
|
||||
let mut created = 0i64;
|
||||
let mut skipped = 0i64;
|
||||
|
||||
for key in &req.config_keys {
|
||||
@@ -273,13 +285,20 @@ pub async fn sync_config(
|
||||
if let Some(val) = &client_val {
|
||||
if let Some(item) = saas_item {
|
||||
// 更新已有配置项
|
||||
sqlx::query("UPDATE config_items SET current_value = ?1, source = 'local', updated_at = ?2 WHERE id = ?3")
|
||||
sqlx::query("UPDATE config_items SET current_value = $1, source = 'local', updated_at = $2 WHERE id = $3")
|
||||
.bind(val).bind(&now).bind(&item.id)
|
||||
.execute(db).await?;
|
||||
updated += 1;
|
||||
} else {
|
||||
// 推送时如果 SaaS 不存在该 key,记录跳过
|
||||
skipped += 1;
|
||||
// SaaS 不存在该 key → 自动创建
|
||||
let new_id = uuid::Uuid::new_v4().to_string();
|
||||
sqlx::query(
|
||||
"INSERT INTO config_items (id, category, key_path, value_type, current_value, source, requires_restart, created_at, updated_at)
|
||||
VALUES ($1, 'imported', $2, 'string', $3, 'local', false, $4, $4)"
|
||||
)
|
||||
.bind(&new_id).bind(key).bind(val).bind(&now)
|
||||
.execute(db).await?;
|
||||
created += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -288,7 +307,7 @@ pub async fn sync_config(
|
||||
if let Some(val) = &client_val {
|
||||
if let Some(item) = saas_item {
|
||||
if item.current_value.is_none() || item.current_value.as_deref() == Some("") {
|
||||
sqlx::query("UPDATE config_items SET current_value = ?1, source = 'local', updated_at = ?2 WHERE id = ?3")
|
||||
sqlx::query("UPDATE config_items SET current_value = $1, source = 'local', updated_at = $2 WHERE id = $3")
|
||||
.bind(val).bind(&now).bind(&item.id)
|
||||
.execute(db).await?;
|
||||
updated += 1;
|
||||
@@ -296,9 +315,17 @@ pub async fn sync_config(
|
||||
// 冲突: SaaS 有值 → 保留 SaaS 值
|
||||
skipped += 1;
|
||||
}
|
||||
} else {
|
||||
// SaaS 完全没有该 key → 创建
|
||||
let new_id = uuid::Uuid::new_v4().to_string();
|
||||
sqlx::query(
|
||||
"INSERT INTO config_items (id, category, key_path, value_type, current_value, source, requires_restart, created_at, updated_at)
|
||||
VALUES ($1, 'imported', $2, 'string', $3, 'local', false, $4, $4)"
|
||||
)
|
||||
.bind(&new_id).bind(key).bind(val).bind(&now)
|
||||
.execute(db).await?;
|
||||
created += 1;
|
||||
}
|
||||
// 客户端有但 SaaS 完全没有的 key → 不自动创建 (需要管理员先创建)
|
||||
skipped += 1;
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
@@ -323,7 +350,7 @@ pub async fn sync_config(
|
||||
|
||||
sqlx::query(
|
||||
"INSERT INTO config_sync_log (account_id, client_fingerprint, action, config_keys, client_values, saas_values, resolution, created_at)
|
||||
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)"
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)"
|
||||
)
|
||||
.bind(account_id).bind(&req.client_fingerprint)
|
||||
.bind(&req.action).bind(&config_keys_str).bind(&client_values_str)
|
||||
@@ -343,18 +370,126 @@ pub struct ConfigSyncResult {
|
||||
}
|
||||
|
||||
pub async fn list_sync_logs(
|
||||
db: &SqlitePool, account_id: &str,
|
||||
db: &PgPool, account_id: &str,
|
||||
) -> SaasResult<Vec<ConfigSyncLogInfo>> {
|
||||
let rows: Vec<(i64, String, String, String, String, Option<String>, Option<String>, Option<String>, String)> =
|
||||
let rows: Vec<(i64, String, String, String, String, Option<String>, Option<String>, Option<String>, chrono::DateTime<chrono::Utc>)> =
|
||||
sqlx::query_as(
|
||||
"SELECT id, account_id, client_fingerprint, action, config_keys, client_values, saas_values, resolution, created_at
|
||||
FROM config_sync_log WHERE account_id = ?1 ORDER BY created_at DESC LIMIT 50"
|
||||
FROM config_sync_log WHERE account_id = $1 ORDER BY created_at DESC LIMIT 50"
|
||||
)
|
||||
.bind(account_id)
|
||||
.fetch_all(db)
|
||||
.await?;
|
||||
|
||||
Ok(rows.into_iter().map(|(id, account_id, client_fingerprint, action, config_keys, client_values, saas_values, resolution, created_at)| {
|
||||
ConfigSyncLogInfo { id, account_id, client_fingerprint, action, config_keys, client_values, saas_values, resolution, created_at }
|
||||
ConfigSyncLogInfo { id, account_id, client_fingerprint, action, config_keys, client_values, saas_values, resolution, created_at: created_at.to_rfc3339() }
|
||||
}).collect())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn make_saas_item(key: &str, value: Option<&str>) -> ConfigItemInfo {
|
||||
ConfigItemInfo {
|
||||
id: "test-id".into(),
|
||||
category: "test".into(),
|
||||
key_path: key.into(),
|
||||
value_type: "string".into(),
|
||||
current_value: value.map(String::from),
|
||||
default_value: None,
|
||||
source: "local".into(),
|
||||
description: None,
|
||||
requires_restart: false,
|
||||
created_at: "2026-01-01T00:00:00Z".into(),
|
||||
updated_at: "2026-01-01T00:00:00Z".into(),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_diff_identical_values() {
|
||||
let keys = vec!["server.host".into(), "server.port".into()];
|
||||
let client = serde_json::json!({"server.host": "127.0.0.1", "server.port": "8080"});
|
||||
let saas = vec![
|
||||
make_saas_item("server.host", Some("127.0.0.1")),
|
||||
make_saas_item("server.port", Some("8080")),
|
||||
];
|
||||
let (items, conflicts) = compute_diff_items(&keys, &client, &saas);
|
||||
assert_eq!(conflicts, 0);
|
||||
assert_eq!(items.len(), 2);
|
||||
assert!(!items[0].conflict);
|
||||
assert!(!items[1].conflict);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_diff_conflict() {
|
||||
let keys = vec!["server.host".into()];
|
||||
let client = serde_json::json!({"server.host": "0.0.0.0"});
|
||||
let saas = vec![make_saas_item("server.host", Some("127.0.0.1"))];
|
||||
let (items, conflicts) = compute_diff_items(&keys, &client, &saas);
|
||||
assert_eq!(conflicts, 1);
|
||||
assert!(items[0].conflict);
|
||||
assert_eq!(items[0].client_value.as_deref(), Some("0.0.0.0"));
|
||||
assert_eq!(items[0].saas_value.as_deref(), Some("127.0.0.1"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_diff_client_only_key() {
|
||||
let keys = vec!["new.key".into()];
|
||||
let client = serde_json::json!({"new.key": "value1"});
|
||||
let saas = vec![]; // SaaS 没有这个 key
|
||||
let (items, conflicts) = compute_diff_items(&keys, &client, &saas);
|
||||
assert_eq!(conflicts, 0);
|
||||
assert_eq!(items[0].client_value.as_deref(), Some("value1"));
|
||||
assert!(items[0].saas_value.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_diff_missing_client_value() {
|
||||
let keys = vec!["server.host".into()];
|
||||
let client = serde_json::json!({}); // 客户端没有这个 key
|
||||
let saas = vec![make_saas_item("server.host", Some("127.0.0.1"))];
|
||||
let (items, conflicts) = compute_diff_items(&keys, &client, &saas);
|
||||
assert_eq!(conflicts, 0); // 一方为 null 不算冲突
|
||||
assert!(items[0].client_value.is_none());
|
||||
assert_eq!(items[0].saas_value.as_deref(), Some("127.0.0.1"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_diff_empty_keys() {
|
||||
let keys: Vec<String> = vec![];
|
||||
let client = serde_json::json!({"server.host": "127.0.0.1"});
|
||||
let saas = vec![make_saas_item("server.host", Some("127.0.0.1"))];
|
||||
let (items, conflicts) = compute_diff_items(&keys, &client, &saas);
|
||||
assert!(items.is_empty());
|
||||
assert_eq!(conflicts, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_diff_mixed() {
|
||||
let keys = vec!["same".into(), "conflict".into(), "client_only".into(), "saas_only".into()];
|
||||
let client = serde_json::json!({
|
||||
"same": "val1",
|
||||
"conflict": "client-val",
|
||||
"client_only": "new-val",
|
||||
});
|
||||
let saas = vec![
|
||||
make_saas_item("same", Some("val1")),
|
||||
make_saas_item("conflict", Some("saas-val")),
|
||||
make_saas_item("saas_only", Some("only-here")),
|
||||
];
|
||||
let (items, conflicts) = compute_diff_items(&keys, &client, &saas);
|
||||
assert_eq!(items.len(), 4);
|
||||
assert_eq!(conflicts, 1);
|
||||
// same: no conflict
|
||||
assert!(!items[0].conflict);
|
||||
// conflict: has conflict
|
||||
assert!(items[1].conflict);
|
||||
// client_only: SaaS has no such key
|
||||
assert!(items[2].saas_value.is_none());
|
||||
assert_eq!(items[2].client_value.as_deref(), Some("new-val"));
|
||||
// saas_only: client has no such key
|
||||
assert!(items[3].client_value.is_none());
|
||||
assert_eq!(items[3].saas_value.as_deref(), Some("only-here"));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// 配置项信息
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)]
|
||||
pub struct ConfigItemInfo {
|
||||
pub id: String,
|
||||
pub category: String,
|
||||
@@ -19,7 +19,7 @@ pub struct ConfigItemInfo {
|
||||
}
|
||||
|
||||
/// 创建配置项请求
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[derive(Debug, Deserialize, utoipa::ToSchema)]
|
||||
pub struct CreateConfigItemRequest {
|
||||
pub category: String,
|
||||
pub key_path: String,
|
||||
@@ -32,7 +32,7 @@ pub struct CreateConfigItemRequest {
|
||||
}
|
||||
|
||||
/// 更新配置项请求
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[derive(Debug, Deserialize, utoipa::ToSchema)]
|
||||
pub struct UpdateConfigItemRequest {
|
||||
pub current_value: Option<String>,
|
||||
pub source: Option<String>,
|
||||
@@ -40,7 +40,7 @@ pub struct UpdateConfigItemRequest {
|
||||
}
|
||||
|
||||
/// 配置同步日志
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
#[derive(Debug, Clone, Serialize, utoipa::ToSchema)]
|
||||
pub struct ConfigSyncLogInfo {
|
||||
pub id: i64,
|
||||
pub account_id: String,
|
||||
@@ -54,14 +54,14 @@ pub struct ConfigSyncLogInfo {
|
||||
}
|
||||
|
||||
/// 配置分析结果
|
||||
#[derive(Debug, Serialize)]
|
||||
#[derive(Debug, Serialize, utoipa::ToSchema)]
|
||||
pub struct ConfigAnalysis {
|
||||
pub total_items: i64,
|
||||
pub categories: Vec<CategorySummary>,
|
||||
pub items: Vec<ConfigItemInfo>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[derive(Debug, Serialize, utoipa::ToSchema)]
|
||||
pub struct CategorySummary {
|
||||
pub category: String,
|
||||
pub count: i64,
|
||||
@@ -69,10 +69,10 @@ pub struct CategorySummary {
|
||||
}
|
||||
|
||||
/// 配置同步请求
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[derive(Debug, Deserialize, utoipa::ToSchema)]
|
||||
pub struct SyncConfigRequest {
|
||||
pub client_fingerprint: String,
|
||||
/// 同步方向: "push", "pull", "merge"
|
||||
/// 同步方向: "push", "merge"
|
||||
#[serde(default = "default_sync_action")]
|
||||
pub action: String,
|
||||
pub config_keys: Vec<String>,
|
||||
@@ -82,7 +82,7 @@ pub struct SyncConfigRequest {
|
||||
fn default_sync_action() -> String { "push".to_string() }
|
||||
|
||||
/// 配置差异项
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
#[derive(Debug, Clone, Serialize, utoipa::ToSchema)]
|
||||
pub struct ConfigDiffItem {
|
||||
pub key_path: String,
|
||||
pub client_value: Option<String>,
|
||||
@@ -91,7 +91,7 @@ pub struct ConfigDiffItem {
|
||||
}
|
||||
|
||||
/// 配置差异响应
|
||||
#[derive(Debug, Serialize)]
|
||||
#[derive(Debug, Serialize, utoipa::ToSchema)]
|
||||
pub struct ConfigDiffResponse {
|
||||
pub items: Vec<ConfigDiffItem>,
|
||||
pub total_keys: usize,
|
||||
@@ -99,7 +99,7 @@ pub struct ConfigDiffResponse {
|
||||
}
|
||||
|
||||
/// 配置查询参数
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[derive(Debug, Deserialize, utoipa::ToSchema, utoipa::IntoParams)]
|
||||
pub struct ConfigQuery {
|
||||
pub category: Option<String>,
|
||||
pub source: Option<String>,
|
||||
|
||||
@@ -36,7 +36,7 @@ pub async fn create_provider(
|
||||
Json(req): Json<CreateProviderRequest>,
|
||||
) -> SaasResult<(StatusCode, Json<ProviderInfo>)> {
|
||||
check_permission(&ctx, "provider:manage")?;
|
||||
let provider = service::create_provider(&state.db, &req).await?;
|
||||
let provider = service::create_provider(&state.db, &state.field_encryption, &req).await?;
|
||||
log_operation(&state.db, &ctx.account_id, "provider.create", "provider", &provider.id,
|
||||
Some(serde_json::json!({"name": &req.name})), ctx.client_ip.as_deref()).await?;
|
||||
Ok((StatusCode::CREATED, Json(provider)))
|
||||
@@ -50,7 +50,7 @@ pub async fn update_provider(
|
||||
Json(req): Json<UpdateProviderRequest>,
|
||||
) -> SaasResult<Json<ProviderInfo>> {
|
||||
check_permission(&ctx, "provider:manage")?;
|
||||
let provider = service::update_provider(&state.db, &id, &req).await?;
|
||||
let provider = service::update_provider(&state.db, &state.field_encryption, &id, &req).await?;
|
||||
log_operation(&state.db, &ctx.account_id, "provider.update", "provider", &id, None, ctx.client_ip.as_deref()).await?;
|
||||
Ok(Json(provider))
|
||||
}
|
||||
@@ -135,7 +135,7 @@ pub async fn list_api_keys(
|
||||
Query(params): Query<std::collections::HashMap<String, String>>,
|
||||
) -> SaasResult<Json<Vec<AccountApiKeyInfo>>> {
|
||||
let provider_id = params.get("provider_id").map(|s| s.as_str());
|
||||
service::list_account_api_keys(&state.db, &ctx.account_id, provider_id).await.map(Json)
|
||||
service::list_account_api_keys(&state.db, &state.field_encryption, &ctx.account_id, provider_id).await.map(Json)
|
||||
}
|
||||
|
||||
/// POST /api/v1/keys
|
||||
@@ -144,7 +144,7 @@ pub async fn create_api_key(
|
||||
Extension(ctx): Extension<AuthContext>,
|
||||
Json(req): Json<CreateAccountApiKeyRequest>,
|
||||
) -> SaasResult<(StatusCode, Json<AccountApiKeyInfo>)> {
|
||||
let key = service::create_account_api_key(&state.db, &ctx.account_id, &req).await?;
|
||||
let key = service::create_account_api_key(&state.db, &state.field_encryption, &ctx.account_id, &req).await?;
|
||||
log_operation(&state.db, &ctx.account_id, "api_key.create", "api_key", &key.id,
|
||||
Some(serde_json::json!({"provider_id": &req.provider_id})), ctx.client_ip.as_deref()).await?;
|
||||
Ok((StatusCode::CREATED, Json(key)))
|
||||
@@ -157,7 +157,7 @@ pub async fn rotate_api_key(
|
||||
Extension(ctx): Extension<AuthContext>,
|
||||
Json(req): Json<RotateApiKeyRequest>,
|
||||
) -> SaasResult<Json<serde_json::Value>> {
|
||||
service::rotate_account_api_key(&state.db, &id, &ctx.account_id, &req.new_key_value).await?;
|
||||
service::rotate_account_api_key(&state.db, &state.field_encryption, &id, &ctx.account_id, &req.new_key_value).await?;
|
||||
log_operation(&state.db, &ctx.account_id, "api_key.rotate", "api_key", &id, None, ctx.client_ip.as_deref()).await?;
|
||||
Ok(Json(serde_json::json!({"ok": true})))
|
||||
}
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
//! 模型配置业务逻辑
|
||||
|
||||
use sqlx::SqlitePool;
|
||||
use sqlx::PgPool;
|
||||
use std::sync::Arc;
|
||||
use crate::crypto::FieldEncryption;
|
||||
use crate::error::{SaasError, SaasResult};
|
||||
use super::types::*;
|
||||
|
||||
// ============ Providers ============
|
||||
|
||||
pub async fn list_providers(db: &SqlitePool) -> SaasResult<Vec<ProviderInfo>> {
|
||||
let rows: Vec<(String, String, String, String, String, bool, Option<i64>, Option<i64>, String, String)> =
|
||||
pub async fn list_providers(db: &PgPool) -> SaasResult<Vec<ProviderInfo>> {
|
||||
let rows: Vec<(String, String, String, String, String, bool, Option<i64>, Option<i64>, chrono::DateTime<chrono::Utc>, chrono::DateTime<chrono::Utc>)> =
|
||||
sqlx::query_as(
|
||||
"SELECT id, name, display_name, base_url, api_protocol, enabled, rate_limit_rpm, rate_limit_tpm, created_at, updated_at
|
||||
FROM providers ORDER BY name"
|
||||
@@ -16,15 +18,15 @@ pub async fn list_providers(db: &SqlitePool) -> SaasResult<Vec<ProviderInfo>> {
|
||||
.await?;
|
||||
|
||||
Ok(rows.into_iter().map(|(id, name, display_name, base_url, api_protocol, enabled, rpm, tpm, created_at, updated_at)| {
|
||||
ProviderInfo { id, name, display_name, base_url, api_protocol, enabled, rate_limit_rpm: rpm, rate_limit_tpm: tpm, created_at, updated_at }
|
||||
ProviderInfo { id, name, display_name, base_url, api_protocol, enabled, rate_limit_rpm: rpm, rate_limit_tpm: tpm, created_at: created_at.to_rfc3339(), updated_at: updated_at.to_rfc3339() }
|
||||
}).collect())
|
||||
}
|
||||
|
||||
pub async fn get_provider(db: &SqlitePool, provider_id: &str) -> SaasResult<ProviderInfo> {
|
||||
let row: Option<(String, String, String, String, String, bool, Option<i64>, Option<i64>, String, String)> =
|
||||
pub async fn get_provider(db: &PgPool, provider_id: &str) -> SaasResult<ProviderInfo> {
|
||||
let row: Option<(String, String, String, String, String, bool, Option<i64>, Option<i64>, chrono::DateTime<chrono::Utc>, chrono::DateTime<chrono::Utc>)> =
|
||||
sqlx::query_as(
|
||||
"SELECT id, name, display_name, base_url, api_protocol, enabled, rate_limit_rpm, rate_limit_tpm, created_at, updated_at
|
||||
FROM providers WHERE id = ?1"
|
||||
FROM providers WHERE id = $1"
|
||||
)
|
||||
.bind(provider_id)
|
||||
.fetch_optional(db)
|
||||
@@ -33,25 +35,33 @@ pub async fn get_provider(db: &SqlitePool, provider_id: &str) -> SaasResult<Prov
|
||||
let (id, name, display_name, base_url, api_protocol, enabled, rpm, tpm, created_at, updated_at) =
|
||||
row.ok_or_else(|| SaasError::NotFound(format!("Provider {} 不存在", provider_id)))?;
|
||||
|
||||
Ok(ProviderInfo { id, name, display_name, base_url, api_protocol, enabled, rate_limit_rpm: rpm, rate_limit_tpm: tpm, created_at, updated_at })
|
||||
Ok(ProviderInfo { id, name, display_name, base_url, api_protocol, enabled, rate_limit_rpm: rpm, rate_limit_tpm: tpm, created_at: created_at.to_rfc3339(), updated_at: updated_at.to_rfc3339() })
|
||||
}
|
||||
|
||||
pub async fn create_provider(db: &SqlitePool, req: &CreateProviderRequest) -> SaasResult<ProviderInfo> {
|
||||
pub async fn create_provider(
|
||||
db: &PgPool, encryption: &Arc<FieldEncryption>, req: &CreateProviderRequest,
|
||||
) -> SaasResult<ProviderInfo> {
|
||||
let id = uuid::Uuid::new_v4().to_string();
|
||||
let now = chrono::Utc::now().to_rfc3339();
|
||||
let now = chrono::Utc::now();
|
||||
|
||||
// 检查名称唯一性
|
||||
let existing: Option<(String,)> = sqlx::query_as("SELECT id FROM providers WHERE name = ?1")
|
||||
let existing: Option<(String,)> = sqlx::query_as("SELECT id FROM providers WHERE name = $1")
|
||||
.bind(&req.name).fetch_optional(db).await?;
|
||||
if existing.is_some() {
|
||||
return Err(SaasError::AlreadyExists(format!("Provider '{}' 已存在", req.name)));
|
||||
}
|
||||
|
||||
// 加密 API Key 后存储
|
||||
let encrypted_api_key: Option<String> = match &req.api_key {
|
||||
Some(key) => Some(encryption.encrypt(key)?),
|
||||
None => None,
|
||||
};
|
||||
|
||||
sqlx::query(
|
||||
"INSERT INTO providers (id, name, display_name, api_key, base_url, api_protocol, enabled, rate_limit_rpm, rate_limit_tpm, created_at, updated_at)
|
||||
VALUES (?1, ?2, ?3, ?4, ?5, ?6, 1, ?7, ?8, ?9, ?9)"
|
||||
VALUES ($1, $2, $3, $4, $5, $6, true, $7, $8, $9, $9)"
|
||||
)
|
||||
.bind(&id).bind(&req.name).bind(&req.display_name).bind(&req.api_key)
|
||||
.bind(&id).bind(&req.name).bind(&req.display_name).bind(&encrypted_api_key)
|
||||
.bind(&req.base_url).bind(&req.api_protocol).bind(&req.rate_limit_rpm).bind(&req.rate_limit_tpm).bind(&now)
|
||||
.execute(db).await?;
|
||||
|
||||
@@ -59,40 +69,48 @@ pub async fn create_provider(db: &SqlitePool, req: &CreateProviderRequest) -> Sa
|
||||
}
|
||||
|
||||
pub async fn update_provider(
|
||||
db: &SqlitePool, provider_id: &str, req: &UpdateProviderRequest,
|
||||
db: &PgPool, encryption: &Arc<FieldEncryption>, provider_id: &str, req: &UpdateProviderRequest,
|
||||
) -> SaasResult<ProviderInfo> {
|
||||
let now = chrono::Utc::now().to_rfc3339();
|
||||
let now = chrono::Utc::now();
|
||||
let mut updates = Vec::new();
|
||||
let mut params: Vec<Box<dyn std::fmt::Display + Send + Sync>> = Vec::new();
|
||||
let mut param_idx: i32 = 1;
|
||||
|
||||
if let Some(ref v) = req.display_name { updates.push("display_name = ?"); params.push(Box::new(v.clone())); }
|
||||
if let Some(ref v) = req.base_url { updates.push("base_url = ?"); params.push(Box::new(v.clone())); }
|
||||
if let Some(ref v) = req.api_protocol { updates.push("api_protocol = ?"); params.push(Box::new(v.clone())); }
|
||||
if let Some(ref v) = req.api_key { updates.push("api_key = ?"); params.push(Box::new(v.clone())); }
|
||||
if let Some(v) = req.enabled { updates.push("enabled = ?"); params.push(Box::new(v)); }
|
||||
if let Some(v) = req.rate_limit_rpm { updates.push("rate_limit_rpm = ?"); params.push(Box::new(v)); }
|
||||
if let Some(v) = req.rate_limit_tpm { updates.push("rate_limit_tpm = ?"); params.push(Box::new(v)); }
|
||||
if let Some(ref v) = req.display_name { updates.push(format!("display_name = ${}", param_idx)); params.push(Box::new(v.clone())); param_idx += 1; }
|
||||
if let Some(ref v) = req.base_url { updates.push(format!("base_url = ${}", param_idx)); params.push(Box::new(v.clone())); param_idx += 1; }
|
||||
if let Some(ref v) = req.api_protocol { updates.push(format!("api_protocol = ${}", param_idx)); params.push(Box::new(v.clone())); param_idx += 1; }
|
||||
if let Some(ref v) = req.api_key {
|
||||
// 加密 API Key 后存储
|
||||
let encrypted = encryption.encrypt(v)?;
|
||||
updates.push(format!("api_key = ${}", param_idx));
|
||||
params.push(Box::new(encrypted));
|
||||
param_idx += 1;
|
||||
}
|
||||
if let Some(v) = req.enabled { updates.push(format!("enabled = ${}", param_idx)); params.push(Box::new(v)); param_idx += 1; }
|
||||
if let Some(v) = req.rate_limit_rpm { updates.push(format!("rate_limit_rpm = ${}", param_idx)); params.push(Box::new(v)); param_idx += 1; }
|
||||
if let Some(v) = req.rate_limit_tpm { updates.push(format!("rate_limit_tpm = ${}", param_idx)); params.push(Box::new(v)); param_idx += 1; }
|
||||
|
||||
if updates.is_empty() {
|
||||
return get_provider(db, provider_id).await;
|
||||
}
|
||||
|
||||
updates.push("updated_at = ?");
|
||||
params.push(Box::new(now.clone()));
|
||||
updates.push(format!("updated_at = ${}", param_idx));
|
||||
param_idx += 1;
|
||||
params.push(Box::new(provider_id.to_string()));
|
||||
|
||||
let sql = format!("UPDATE providers SET {} WHERE id = ?", updates.join(", "));
|
||||
let sql = format!("UPDATE providers SET {} WHERE id = ${}", updates.join(", "), param_idx);
|
||||
let mut query = sqlx::query(&sql);
|
||||
for p in ¶ms {
|
||||
query = query.bind(format!("{}", p));
|
||||
}
|
||||
query = query.bind(now);
|
||||
query.execute(db).await?;
|
||||
|
||||
get_provider(db, provider_id).await
|
||||
}
|
||||
|
||||
pub async fn delete_provider(db: &SqlitePool, provider_id: &str) -> SaasResult<()> {
|
||||
let result = sqlx::query("DELETE FROM providers WHERE id = ?1")
|
||||
pub async fn delete_provider(db: &PgPool, provider_id: &str) -> SaasResult<()> {
|
||||
let result = sqlx::query("DELETE FROM providers WHERE id = $1")
|
||||
.bind(provider_id).execute(db).await?;
|
||||
|
||||
if result.rows_affected() == 0 {
|
||||
@@ -103,36 +121,36 @@ pub async fn delete_provider(db: &SqlitePool, provider_id: &str) -> SaasResult<(
|
||||
|
||||
// ============ Models ============
|
||||
|
||||
pub async fn list_models(db: &SqlitePool, provider_id: Option<&str>) -> SaasResult<Vec<ModelInfo>> {
|
||||
pub async fn list_models(db: &PgPool, provider_id: Option<&str>) -> SaasResult<Vec<ModelInfo>> {
|
||||
let sql = if provider_id.is_some() {
|
||||
"SELECT id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, pricing_input, pricing_output, created_at, updated_at
|
||||
FROM models WHERE provider_id = ?1 ORDER BY alias"
|
||||
FROM models WHERE provider_id = $1 ORDER BY alias"
|
||||
} else {
|
||||
"SELECT id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, pricing_input, pricing_output, created_at, updated_at
|
||||
FROM models ORDER BY provider_id, alias"
|
||||
};
|
||||
|
||||
let mut query = sqlx::query_as::<_, (String, String, String, String, i64, i64, bool, bool, bool, f64, f64, String, String)>(sql);
|
||||
let mut query = sqlx::query_as::<_, (String, String, String, String, i64, i64, bool, bool, bool, f64, f64, chrono::DateTime<chrono::Utc>, chrono::DateTime<chrono::Utc>)>(sql);
|
||||
if let Some(pid) = provider_id {
|
||||
query = query.bind(pid);
|
||||
}
|
||||
|
||||
let rows = query.fetch_all(db).await?;
|
||||
Ok(rows.into_iter().map(|(id, provider_id, model_id, alias, ctx, max_out, streaming, vision, enabled, pi, po, created_at, updated_at)| {
|
||||
ModelInfo { id, provider_id, model_id, alias, context_window: ctx, max_output_tokens: max_out, supports_streaming: streaming, supports_vision: vision, enabled, pricing_input: pi, pricing_output: po, created_at, updated_at }
|
||||
ModelInfo { id, provider_id, model_id, alias, context_window: ctx, max_output_tokens: max_out, supports_streaming: streaming, supports_vision: vision, enabled, pricing_input: pi, pricing_output: po, created_at: created_at.to_rfc3339(), updated_at: updated_at.to_rfc3339() }
|
||||
}).collect())
|
||||
}
|
||||
|
||||
pub async fn create_model(db: &SqlitePool, req: &CreateModelRequest) -> SaasResult<ModelInfo> {
|
||||
pub async fn create_model(db: &PgPool, req: &CreateModelRequest) -> SaasResult<ModelInfo> {
|
||||
// 验证 provider 存在
|
||||
let provider = get_provider(db, &req.provider_id).await?;
|
||||
|
||||
let id = uuid::Uuid::new_v4().to_string();
|
||||
let now = chrono::Utc::now().to_rfc3339();
|
||||
let now = chrono::Utc::now();
|
||||
|
||||
// 检查 model 唯一性
|
||||
let existing: Option<(String,)> = sqlx::query_as(
|
||||
"SELECT id FROM models WHERE provider_id = ?1 AND model_id = ?2"
|
||||
"SELECT id FROM models WHERE provider_id = $1 AND model_id = $2"
|
||||
)
|
||||
.bind(&req.provider_id).bind(&req.model_id)
|
||||
.fetch_optional(db).await?;
|
||||
@@ -152,7 +170,7 @@ pub async fn create_model(db: &SqlitePool, req: &CreateModelRequest) -> SaasResu
|
||||
|
||||
sqlx::query(
|
||||
"INSERT INTO models (id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, pricing_input, pricing_output, created_at, updated_at)
|
||||
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, 1, ?9, ?10, ?11, ?11)"
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, true, $9, $10, $11, $11)"
|
||||
)
|
||||
.bind(&id).bind(&req.provider_id).bind(&req.model_id).bind(&req.alias)
|
||||
.bind(ctx).bind(max_out).bind(streaming).bind(vision).bind(pi).bind(po).bind(&now)
|
||||
@@ -161,11 +179,11 @@ pub async fn create_model(db: &SqlitePool, req: &CreateModelRequest) -> SaasResu
|
||||
get_model(db, &id).await
|
||||
}
|
||||
|
||||
pub async fn get_model(db: &SqlitePool, model_id: &str) -> SaasResult<ModelInfo> {
|
||||
let row: Option<(String, String, String, String, i64, i64, bool, bool, bool, f64, f64, String, String)> =
|
||||
pub async fn get_model(db: &PgPool, model_id: &str) -> SaasResult<ModelInfo> {
|
||||
let row: Option<(String, String, String, String, i64, i64, bool, bool, bool, f64, f64, chrono::DateTime<chrono::Utc>, chrono::DateTime<chrono::Utc>)> =
|
||||
sqlx::query_as(
|
||||
"SELECT id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, pricing_input, pricing_output, created_at, updated_at
|
||||
FROM models WHERE id = ?1"
|
||||
FROM models WHERE id = $1"
|
||||
)
|
||||
.bind(model_id)
|
||||
.fetch_optional(db)
|
||||
@@ -174,45 +192,47 @@ pub async fn get_model(db: &SqlitePool, model_id: &str) -> SaasResult<ModelInfo>
|
||||
let (id, provider_id, model_id, alias, ctx, max_out, streaming, vision, enabled, pi, po, created_at, updated_at) =
|
||||
row.ok_or_else(|| SaasError::NotFound(format!("模型 {} 不存在", model_id)))?;
|
||||
|
||||
Ok(ModelInfo { id, provider_id, model_id, alias, context_window: ctx, max_output_tokens: max_out, supports_streaming: streaming, supports_vision: vision, enabled, pricing_input: pi, pricing_output: po, created_at, updated_at })
|
||||
Ok(ModelInfo { id, provider_id, model_id, alias, context_window: ctx, max_output_tokens: max_out, supports_streaming: streaming, supports_vision: vision, enabled, pricing_input: pi, pricing_output: po, created_at: created_at.to_rfc3339(), updated_at: updated_at.to_rfc3339() })
|
||||
}
|
||||
|
||||
pub async fn update_model(
|
||||
db: &SqlitePool, model_id: &str, req: &UpdateModelRequest,
|
||||
db: &PgPool, model_id: &str, req: &UpdateModelRequest,
|
||||
) -> SaasResult<ModelInfo> {
|
||||
let now = chrono::Utc::now().to_rfc3339();
|
||||
let now = chrono::Utc::now();
|
||||
let mut updates = Vec::new();
|
||||
let mut params: Vec<Box<dyn std::fmt::Display + Send + Sync>> = Vec::new();
|
||||
let mut param_idx: i32 = 1;
|
||||
|
||||
if let Some(ref v) = req.alias { updates.push("alias = ?"); params.push(Box::new(v.clone())); }
|
||||
if let Some(v) = req.context_window { updates.push("context_window = ?"); params.push(Box::new(v)); }
|
||||
if let Some(v) = req.max_output_tokens { updates.push("max_output_tokens = ?"); params.push(Box::new(v)); }
|
||||
if let Some(v) = req.supports_streaming { updates.push("supports_streaming = ?"); params.push(Box::new(v)); }
|
||||
if let Some(v) = req.supports_vision { updates.push("supports_vision = ?"); params.push(Box::new(v)); }
|
||||
if let Some(v) = req.enabled { updates.push("enabled = ?"); params.push(Box::new(v)); }
|
||||
if let Some(v) = req.pricing_input { updates.push("pricing_input = ?"); params.push(Box::new(v)); }
|
||||
if let Some(v) = req.pricing_output { updates.push("pricing_output = ?"); params.push(Box::new(v)); }
|
||||
if let Some(ref v) = req.alias { updates.push(format!("alias = ${}", param_idx)); params.push(Box::new(v.clone())); param_idx += 1; }
|
||||
if let Some(v) = req.context_window { updates.push(format!("context_window = ${}", param_idx)); params.push(Box::new(v)); param_idx += 1; }
|
||||
if let Some(v) = req.max_output_tokens { updates.push(format!("max_output_tokens = ${}", param_idx)); params.push(Box::new(v)); param_idx += 1; }
|
||||
if let Some(v) = req.supports_streaming { updates.push(format!("supports_streaming = ${}", param_idx)); params.push(Box::new(v)); param_idx += 1; }
|
||||
if let Some(v) = req.supports_vision { updates.push(format!("supports_vision = ${}", param_idx)); params.push(Box::new(v)); param_idx += 1; }
|
||||
if let Some(v) = req.enabled { updates.push(format!("enabled = ${}", param_idx)); params.push(Box::new(v)); param_idx += 1; }
|
||||
if let Some(v) = req.pricing_input { updates.push(format!("pricing_input = ${}", param_idx)); params.push(Box::new(v)); param_idx += 1; }
|
||||
if let Some(v) = req.pricing_output { updates.push(format!("pricing_output = ${}", param_idx)); params.push(Box::new(v)); param_idx += 1; }
|
||||
|
||||
if updates.is_empty() {
|
||||
return get_model(db, model_id).await;
|
||||
}
|
||||
|
||||
updates.push("updated_at = ?");
|
||||
params.push(Box::new(now.clone()));
|
||||
updates.push(format!("updated_at = ${}", param_idx));
|
||||
param_idx += 1;
|
||||
params.push(Box::new(model_id.to_string()));
|
||||
|
||||
let sql = format!("UPDATE models SET {} WHERE id = ?", updates.join(", "));
|
||||
let sql = format!("UPDATE models SET {} WHERE id = ${}", updates.join(", "), param_idx);
|
||||
let mut query = sqlx::query(&sql);
|
||||
for p in ¶ms {
|
||||
query = query.bind(format!("{}", p));
|
||||
}
|
||||
query = query.bind(now);
|
||||
query.execute(db).await?;
|
||||
|
||||
get_model(db, model_id).await
|
||||
}
|
||||
|
||||
pub async fn delete_model(db: &SqlitePool, model_id: &str) -> SaasResult<()> {
|
||||
let result = sqlx::query("DELETE FROM models WHERE id = ?1")
|
||||
pub async fn delete_model(db: &PgPool, model_id: &str) -> SaasResult<()> {
|
||||
let result = sqlx::query("DELETE FROM models WHERE id = $1")
|
||||
.bind(model_id).execute(db).await?;
|
||||
|
||||
if result.rows_affected() == 0 {
|
||||
@@ -224,17 +244,17 @@ pub async fn delete_model(db: &SqlitePool, model_id: &str) -> SaasResult<()> {
|
||||
// ============ Account API Keys ============
|
||||
|
||||
pub async fn list_account_api_keys(
|
||||
db: &SqlitePool, account_id: &str, provider_id: Option<&str>,
|
||||
db: &PgPool, encryption: &Arc<FieldEncryption>, account_id: &str, provider_id: Option<&str>,
|
||||
) -> SaasResult<Vec<AccountApiKeyInfo>> {
|
||||
let sql = if provider_id.is_some() {
|
||||
"SELECT id, provider_id, key_label, permissions, enabled, last_used_at, created_at, key_value
|
||||
FROM account_api_keys WHERE account_id = ?1 AND provider_id = ?2 AND revoked_at IS NULL ORDER BY created_at DESC"
|
||||
FROM account_api_keys WHERE account_id = $1 AND provider_id = $2 AND revoked_at IS NULL ORDER BY created_at DESC"
|
||||
} else {
|
||||
"SELECT id, provider_id, key_label, permissions, enabled, last_used_at, created_at, key_value
|
||||
FROM account_api_keys WHERE account_id = ?1 AND revoked_at IS NULL ORDER BY created_at DESC"
|
||||
FROM account_api_keys WHERE account_id = $1 AND revoked_at IS NULL ORDER BY created_at DESC"
|
||||
};
|
||||
|
||||
let mut query = sqlx::query_as::<_, (String, String, Option<String>, String, bool, Option<String>, String, String)>(sql)
|
||||
let mut query = sqlx::query_as::<_, (String, String, Option<String>, String, bool, Option<chrono::DateTime<chrono::Utc>>, chrono::DateTime<chrono::Utc>, String)>(sql)
|
||||
.bind(account_id);
|
||||
if let Some(pid) = provider_id {
|
||||
query = query.bind(pid);
|
||||
@@ -243,26 +263,32 @@ pub async fn list_account_api_keys(
|
||||
let rows = query.fetch_all(db).await?;
|
||||
Ok(rows.into_iter().map(|(id, provider_id, key_label, perms, enabled, last_used, created_at, key_value)| {
|
||||
let permissions: Vec<String> = serde_json::from_str(&perms).unwrap_or_default();
|
||||
let masked = mask_api_key(&key_value);
|
||||
AccountApiKeyInfo { id, provider_id, key_label, permissions, enabled, last_used_at: last_used, created_at, masked_key: masked }
|
||||
// 解密 key_value 后再做掩码处理(兼容迁移期间的明文数据)
|
||||
let decrypted = encryption.decrypt_or_plaintext(&key_value);
|
||||
let masked = mask_api_key(&decrypted);
|
||||
AccountApiKeyInfo { id, provider_id, key_label, permissions, enabled, last_used_at: last_used.map(|t| t.to_rfc3339()), created_at: created_at.to_rfc3339(), masked_key: masked }
|
||||
}).collect())
|
||||
}
|
||||
|
||||
pub async fn create_account_api_key(
|
||||
db: &SqlitePool, account_id: &str, req: &CreateAccountApiKeyRequest,
|
||||
db: &PgPool, encryption: &Arc<FieldEncryption>, account_id: &str, req: &CreateAccountApiKeyRequest,
|
||||
) -> SaasResult<AccountApiKeyInfo> {
|
||||
// 验证 provider 存在
|
||||
get_provider(db, &req.provider_id).await?;
|
||||
|
||||
let id = uuid::Uuid::new_v4().to_string();
|
||||
let now = chrono::Utc::now().to_rfc3339();
|
||||
let now = chrono::Utc::now();
|
||||
let now_str = now.to_rfc3339();
|
||||
let permissions = serde_json::to_string(&req.permissions)?;
|
||||
|
||||
// 加密 key_value 后存储
|
||||
let encrypted_key_value = encryption.encrypt(&req.key_value)?;
|
||||
|
||||
sqlx::query(
|
||||
"INSERT INTO account_api_keys (id, account_id, provider_id, key_value, key_label, permissions, enabled, created_at, updated_at)
|
||||
VALUES (?1, ?2, ?3, ?4, ?5, ?6, 1, ?7, ?7)"
|
||||
VALUES ($1, $2, $3, $4, $5, $6, true, $7, $7)"
|
||||
)
|
||||
.bind(&id).bind(account_id).bind(&req.provider_id).bind(&req.key_value)
|
||||
.bind(&id).bind(account_id).bind(&req.provider_id).bind(&encrypted_key_value)
|
||||
.bind(&req.key_label).bind(&permissions).bind(&now)
|
||||
.execute(db).await?;
|
||||
|
||||
@@ -270,18 +296,20 @@ pub async fn create_account_api_key(
|
||||
Ok(AccountApiKeyInfo {
|
||||
id, provider_id: req.provider_id.clone(), key_label: req.key_label.clone(),
|
||||
permissions: req.permissions.clone(), enabled: true, last_used_at: None,
|
||||
created_at: now, masked_key: masked,
|
||||
created_at: now_str, masked_key: masked,
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn rotate_account_api_key(
|
||||
db: &SqlitePool, key_id: &str, account_id: &str, new_key_value: &str,
|
||||
db: &PgPool, encryption: &Arc<FieldEncryption>, key_id: &str, account_id: &str, new_key_value: &str,
|
||||
) -> SaasResult<()> {
|
||||
let now = chrono::Utc::now().to_rfc3339();
|
||||
let now = chrono::Utc::now();
|
||||
// 加密新 key_value 后存储
|
||||
let encrypted_key = encryption.encrypt(new_key_value)?;
|
||||
let result = sqlx::query(
|
||||
"UPDATE account_api_keys SET key_value = ?1, updated_at = ?2 WHERE id = ?3 AND account_id = ?4 AND revoked_at IS NULL"
|
||||
"UPDATE account_api_keys SET key_value = $1, updated_at = $2 WHERE id = $3 AND account_id = $4 AND revoked_at IS NULL"
|
||||
)
|
||||
.bind(new_key_value).bind(&now).bind(key_id).bind(account_id)
|
||||
.bind(&encrypted_key).bind(&now).bind(key_id).bind(account_id)
|
||||
.execute(db).await?;
|
||||
|
||||
if result.rows_affected() == 0 {
|
||||
@@ -291,11 +319,11 @@ pub async fn rotate_account_api_key(
|
||||
}
|
||||
|
||||
pub async fn revoke_account_api_key(
|
||||
db: &SqlitePool, key_id: &str, account_id: &str,
|
||||
db: &PgPool, key_id: &str, account_id: &str,
|
||||
) -> SaasResult<()> {
|
||||
let now = chrono::Utc::now().to_rfc3339();
|
||||
let now = chrono::Utc::now();
|
||||
let result = sqlx::query(
|
||||
"UPDATE account_api_keys SET revoked_at = ?1 WHERE id = ?2 AND account_id = ?3 AND revoked_at IS NULL"
|
||||
"UPDATE account_api_keys SET revoked_at = $1 WHERE id = $2 AND account_id = $3 AND revoked_at IS NULL"
|
||||
)
|
||||
.bind(&now).bind(key_id).bind(account_id)
|
||||
.execute(db).await?;
|
||||
@@ -309,25 +337,30 @@ pub async fn revoke_account_api_key(
|
||||
// ============ Usage Statistics ============
|
||||
|
||||
pub async fn get_usage_stats(
|
||||
db: &SqlitePool, account_id: &str, query: &UsageQuery,
|
||||
db: &PgPool, account_id: &str, query: &UsageQuery,
|
||||
) -> SaasResult<UsageStats> {
|
||||
let mut where_clauses = vec!["account_id = ?".to_string()];
|
||||
let mut param_idx: i32 = 1;
|
||||
let mut where_clauses = vec![format!("account_id = ${}", param_idx)];
|
||||
param_idx += 1;
|
||||
let mut params: Vec<String> = vec![account_id.to_string()];
|
||||
|
||||
if let Some(ref from) = query.from {
|
||||
where_clauses.push("created_at >= ?".to_string());
|
||||
where_clauses.push(format!("created_at >= ${}", param_idx));
|
||||
param_idx += 1;
|
||||
params.push(from.clone());
|
||||
}
|
||||
if let Some(ref to) = query.to {
|
||||
where_clauses.push("created_at <= ?".to_string());
|
||||
where_clauses.push(format!("created_at <= ${}", param_idx));
|
||||
param_idx += 1;
|
||||
params.push(to.clone());
|
||||
}
|
||||
if let Some(ref pid) = query.provider_id {
|
||||
where_clauses.push("provider_id = ?".to_string());
|
||||
where_clauses.push(format!("provider_id = ${}", param_idx));
|
||||
param_idx += 1;
|
||||
params.push(pid.clone());
|
||||
}
|
||||
if let Some(ref mid) = query.model_id {
|
||||
where_clauses.push("model_id = ?".to_string());
|
||||
where_clauses.push(format!("model_id = ${}", param_idx));
|
||||
params.push(mid.clone());
|
||||
}
|
||||
|
||||
@@ -361,10 +394,10 @@ pub async fn get_usage_stats(
|
||||
}).collect();
|
||||
|
||||
// 按天统计 (最近 30 天)
|
||||
let from_30d = (chrono::Utc::now() - chrono::Duration::days(30)).to_rfc3339();
|
||||
let from_30d = chrono::Utc::now() - chrono::Duration::days(30);
|
||||
let daily_sql = format!(
|
||||
"SELECT DATE(created_at) as day, COUNT(*), COALESCE(SUM(input_tokens), 0), COALESCE(SUM(output_tokens), 0)
|
||||
FROM usage_records WHERE account_id = ?1 AND created_at >= ?2
|
||||
FROM usage_records WHERE account_id = $1 AND created_at >= $2
|
||||
GROUP BY DATE(created_at) ORDER BY day DESC LIMIT 30"
|
||||
);
|
||||
let daily_rows: Vec<(String, i64, i64, i64)> = sqlx::query_as(&daily_sql)
|
||||
@@ -385,14 +418,14 @@ pub async fn get_usage_stats(
|
||||
}
|
||||
|
||||
pub async fn record_usage(
|
||||
db: &SqlitePool, account_id: &str, provider_id: &str, model_id: &str,
|
||||
db: &PgPool, account_id: &str, provider_id: &str, model_id: &str,
|
||||
input_tokens: i64, output_tokens: i64, latency_ms: Option<i64>,
|
||||
status: &str, error_message: Option<&str>,
|
||||
) -> SaasResult<()> {
|
||||
let now = chrono::Utc::now().to_rfc3339();
|
||||
let now = chrono::Utc::now();
|
||||
sqlx::query(
|
||||
"INSERT INTO usage_records (account_id, provider_id, model_id, input_tokens, output_tokens, latency_ms, status, error_message, created_at)
|
||||
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)"
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)"
|
||||
)
|
||||
.bind(account_id).bind(provider_id).bind(model_id)
|
||||
.bind(input_tokens).bind(output_tokens).bind(latency_ms)
|
||||
@@ -409,3 +442,73 @@ fn mask_api_key(key: &str) -> String {
|
||||
}
|
||||
format!("{}...{}", &key[..4], &key[key.len()-4..])
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
// ---- mask_api_key ----
|
||||
|
||||
#[test]
|
||||
fn mask_key_long_key() {
|
||||
let key = "sk-abcdefghijklmnopqrstuvwxyz123456";
|
||||
let masked = mask_api_key(key);
|
||||
assert_eq!(masked, "sk-a...3456");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mask_key_exactly_8_chars() {
|
||||
// keys <= 8 chars are fully masked
|
||||
let key = "12345678";
|
||||
let masked = mask_api_key(key);
|
||||
assert_eq!(masked, "********");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mask_key_7_chars() {
|
||||
let key = "abcdefg";
|
||||
let masked = mask_api_key(key);
|
||||
assert_eq!(masked, "*******");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mask_key_1_char() {
|
||||
let key = "a";
|
||||
let masked = mask_api_key(key);
|
||||
assert_eq!(masked, "*");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mask_key_empty() {
|
||||
let key = "";
|
||||
let masked = mask_api_key(key);
|
||||
assert_eq!(masked, "");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mask_key_9_chars_boundary() {
|
||||
// 9 chars is the first that uses prefix...suffix format
|
||||
let key = "abcdefghi";
|
||||
let masked = mask_api_key(key);
|
||||
assert_eq!(masked, "abcd...fghi");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mask_key_standard_openai_format() {
|
||||
let key = "sk-proj-abcdefghijklmnopqrstuvwx";
|
||||
let masked = mask_api_key(key);
|
||||
assert_eq!(masked, "sk-p...uvwx");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mask_key_no_ellipsis_for_short() {
|
||||
let masked = mask_api_key("short");
|
||||
assert!(!masked.contains("..."));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mask_key_has_ellipsis_for_long() {
|
||||
let masked = mask_api_key("this_is_a_very_long_key_value");
|
||||
assert!(masked.contains("..."));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,7 +4,7 @@ use serde::{Deserialize, Serialize};
|
||||
|
||||
// --- Provider ---
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)]
|
||||
pub struct ProviderInfo {
|
||||
pub id: String,
|
||||
pub name: String,
|
||||
@@ -18,7 +18,7 @@ pub struct ProviderInfo {
|
||||
pub updated_at: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[derive(Debug, Deserialize, utoipa::ToSchema)]
|
||||
pub struct CreateProviderRequest {
|
||||
pub name: String,
|
||||
pub display_name: String,
|
||||
@@ -32,7 +32,7 @@ pub struct CreateProviderRequest {
|
||||
|
||||
fn default_protocol() -> String { "openai".into() }
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[derive(Debug, Deserialize, utoipa::ToSchema)]
|
||||
pub struct UpdateProviderRequest {
|
||||
pub display_name: Option<String>,
|
||||
pub base_url: Option<String>,
|
||||
@@ -45,7 +45,7 @@ pub struct UpdateProviderRequest {
|
||||
|
||||
// --- Model ---
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)]
|
||||
pub struct ModelInfo {
|
||||
pub id: String,
|
||||
pub provider_id: String,
|
||||
@@ -62,7 +62,7 @@ pub struct ModelInfo {
|
||||
pub updated_at: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[derive(Debug, Deserialize, utoipa::ToSchema)]
|
||||
pub struct CreateModelRequest {
|
||||
pub provider_id: String,
|
||||
pub model_id: String,
|
||||
@@ -75,7 +75,7 @@ pub struct CreateModelRequest {
|
||||
pub pricing_output: Option<f64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[derive(Debug, Deserialize, utoipa::ToSchema)]
|
||||
pub struct UpdateModelRequest {
|
||||
pub alias: Option<String>,
|
||||
pub context_window: Option<i64>,
|
||||
@@ -89,7 +89,7 @@ pub struct UpdateModelRequest {
|
||||
|
||||
// --- Account API Key ---
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)]
|
||||
pub struct AccountApiKeyInfo {
|
||||
pub id: String,
|
||||
pub provider_id: String,
|
||||
@@ -101,7 +101,7 @@ pub struct AccountApiKeyInfo {
|
||||
pub masked_key: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[derive(Debug, Deserialize, utoipa::ToSchema)]
|
||||
pub struct CreateAccountApiKeyRequest {
|
||||
pub provider_id: String,
|
||||
pub key_value: String,
|
||||
@@ -110,14 +110,14 @@ pub struct CreateAccountApiKeyRequest {
|
||||
pub permissions: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[derive(Debug, Deserialize, utoipa::ToSchema)]
|
||||
pub struct RotateApiKeyRequest {
|
||||
pub new_key_value: String,
|
||||
}
|
||||
|
||||
// --- Usage ---
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[derive(Debug, Serialize, utoipa::ToSchema)]
|
||||
pub struct UsageStats {
|
||||
pub total_requests: i64,
|
||||
pub total_input_tokens: i64,
|
||||
@@ -126,7 +126,7 @@ pub struct UsageStats {
|
||||
pub by_day: Vec<DailyUsage>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[derive(Debug, Serialize, utoipa::ToSchema)]
|
||||
pub struct ModelUsage {
|
||||
pub provider_id: String,
|
||||
pub model_id: String,
|
||||
@@ -135,7 +135,7 @@ pub struct ModelUsage {
|
||||
pub output_tokens: i64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[derive(Debug, Serialize, utoipa::ToSchema)]
|
||||
pub struct DailyUsage {
|
||||
pub date: String,
|
||||
pub request_count: i64,
|
||||
@@ -143,7 +143,7 @@ pub struct DailyUsage {
|
||||
pub output_tokens: i64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[derive(Debug, Deserialize, utoipa::ToSchema, utoipa::IntoParams)]
|
||||
pub struct UsageQuery {
|
||||
pub from: Option<String>,
|
||||
pub to: Option<String>,
|
||||
@@ -151,22 +151,3 @@ pub struct UsageQuery {
|
||||
pub model_id: Option<String>,
|
||||
}
|
||||
|
||||
// --- Seed Data ---
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct SeedProvider {
|
||||
pub name: String,
|
||||
pub display_name: String,
|
||||
pub base_url: String,
|
||||
pub models: Vec<SeedModel>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct SeedModel {
|
||||
pub id: String,
|
||||
pub alias: String,
|
||||
pub context_window: Option<i64>,
|
||||
pub max_output_tokens: Option<i64>,
|
||||
pub supports_streaming: Option<bool>,
|
||||
pub supports_vision: Option<bool>,
|
||||
}
|
||||
|
||||
790
crates/zclaw-saas/src/openapi.rs
Normal file
790
crates/zclaw-saas/src/openapi.rs
Normal file
@@ -0,0 +1,790 @@
|
||||
//! OpenAPI / Swagger 文档定义
|
||||
//!
|
||||
//! 聚合所有模块的 schema,并在 build_router 中通过 utoipa-swagger-ui 暴露文档。
|
||||
|
||||
use utoipa::OpenApi;
|
||||
|
||||
/// ZCLAW SaaS API 根 OpenApi 定义
|
||||
#[derive(OpenApi)]
|
||||
#[openapi(
|
||||
info(
|
||||
title = "ZCLAW SaaS API",
|
||||
version = "0.1.0",
|
||||
description = "ZCLAW SaaS 后端服务 API -- 账号权限管理、模型配置、请求中转和配置迁移",
|
||||
license(name = "Apache-2.0 OR MIT")
|
||||
),
|
||||
tags(
|
||||
(name = "auth", description = "认证 (登录 / 注册 / TOTP)"),
|
||||
(name = "accounts", description = "账号管理"),
|
||||
(name = "providers", description = "模型供应商"),
|
||||
(name = "models", description = "模型配置"),
|
||||
(name = "keys", description = "API Key 管理"),
|
||||
(name = "usage", description = "用量统计"),
|
||||
(name = "relay", description = "请求中转"),
|
||||
(name = "config", description = "配置迁移"),
|
||||
),
|
||||
paths(
|
||||
crate::openapi::paths::auth::register,
|
||||
crate::openapi::paths::auth::login,
|
||||
crate::openapi::paths::auth::refresh,
|
||||
crate::openapi::paths::auth::me,
|
||||
crate::openapi::paths::auth::change_password,
|
||||
crate::openapi::paths::auth::totp_setup,
|
||||
crate::openapi::paths::auth::totp_verify,
|
||||
crate::openapi::paths::auth::totp_disable,
|
||||
crate::openapi::paths::accounts::list_accounts,
|
||||
crate::openapi::paths::accounts::get_account,
|
||||
crate::openapi::paths::accounts::update_account,
|
||||
crate::openapi::paths::accounts::update_status,
|
||||
crate::openapi::paths::accounts::list_tokens,
|
||||
crate::openapi::paths::accounts::create_token,
|
||||
crate::openapi::paths::accounts::revoke_token,
|
||||
crate::openapi::paths::accounts::list_devices,
|
||||
crate::openapi::paths::accounts::register_device,
|
||||
crate::openapi::paths::accounts::device_heartbeat,
|
||||
crate::openapi::paths::accounts::list_operation_logs,
|
||||
crate::openapi::paths::accounts::dashboard_stats,
|
||||
crate::openapi::paths::providers::list_providers,
|
||||
crate::openapi::paths::providers::get_provider,
|
||||
crate::openapi::paths::providers::create_provider,
|
||||
crate::openapi::paths::providers::update_provider,
|
||||
crate::openapi::paths::providers::delete_provider,
|
||||
crate::openapi::paths::providers::list_provider_models,
|
||||
crate::openapi::paths::models::list_models,
|
||||
crate::openapi::paths::models::get_model,
|
||||
crate::openapi::paths::models::create_model,
|
||||
crate::openapi::paths::models::update_model,
|
||||
crate::openapi::paths::models::delete_model,
|
||||
crate::openapi::paths::keys::list_api_keys,
|
||||
crate::openapi::paths::keys::create_api_key,
|
||||
crate::openapi::paths::keys::revoke_api_key,
|
||||
crate::openapi::paths::keys::rotate_api_key,
|
||||
crate::openapi::paths::usage::get_usage,
|
||||
crate::openapi::paths::relay::chat_completions,
|
||||
crate::openapi::paths::relay::list_tasks,
|
||||
crate::openapi::paths::relay::get_task,
|
||||
crate::openapi::paths::relay::retry_task,
|
||||
crate::openapi::paths::relay::list_available_models,
|
||||
crate::openapi::paths::config::list_config_items,
|
||||
crate::openapi::paths::config::get_config_item,
|
||||
crate::openapi::paths::config::create_config_item,
|
||||
crate::openapi::paths::config::update_config_item,
|
||||
crate::openapi::paths::config::delete_config_item,
|
||||
crate::openapi::paths::config::analyze_config,
|
||||
crate::openapi::paths::config::seed_config,
|
||||
crate::openapi::paths::config::sync_config,
|
||||
crate::openapi::paths::config::config_diff,
|
||||
crate::openapi::paths::config::list_sync_logs,
|
||||
),
|
||||
components(schemas(
|
||||
crate::auth::types::LoginRequest,
|
||||
crate::auth::types::LoginResponse,
|
||||
crate::auth::types::RegisterRequest,
|
||||
crate::auth::types::ChangePasswordRequest,
|
||||
crate::auth::types::AccountPublic,
|
||||
crate::account::types::UpdateAccountRequest,
|
||||
crate::account::types::UpdateStatusRequest,
|
||||
crate::account::types::ListAccountsQuery,
|
||||
crate::account::types::AccountPublicPaginatedResponse,
|
||||
crate::account::types::CreateTokenRequest,
|
||||
crate::account::types::TokenInfo,
|
||||
crate::account::types::RegisterDeviceRequest,
|
||||
crate::account::types::DeviceHeartbeatRequest,
|
||||
crate::account::types::DeviceInfo,
|
||||
crate::model_config::types::ProviderInfo,
|
||||
crate::model_config::types::CreateProviderRequest,
|
||||
crate::model_config::types::UpdateProviderRequest,
|
||||
crate::model_config::types::ModelInfo,
|
||||
crate::model_config::types::CreateModelRequest,
|
||||
crate::model_config::types::UpdateModelRequest,
|
||||
crate::model_config::types::AccountApiKeyInfo,
|
||||
crate::model_config::types::CreateAccountApiKeyRequest,
|
||||
crate::model_config::types::RotateApiKeyRequest,
|
||||
crate::model_config::types::UsageStats,
|
||||
crate::model_config::types::ModelUsage,
|
||||
crate::model_config::types::DailyUsage,
|
||||
crate::model_config::types::UsageQuery,
|
||||
crate::relay::types::RelayTaskInfo,
|
||||
crate::relay::types::RelayTaskQuery,
|
||||
crate::migration::types::ConfigItemInfo,
|
||||
crate::migration::types::CreateConfigItemRequest,
|
||||
crate::migration::types::UpdateConfigItemRequest,
|
||||
crate::migration::types::ConfigSyncLogInfo,
|
||||
crate::migration::types::ConfigAnalysis,
|
||||
crate::migration::types::CategorySummary,
|
||||
crate::migration::types::SyncConfigRequest,
|
||||
crate::migration::types::ConfigDiffItem,
|
||||
crate::migration::types::ConfigDiffResponse,
|
||||
crate::migration::types::ConfigQuery,
|
||||
)),
|
||||
modifiers(&SecurityAddon)
|
||||
)]
|
||||
pub struct ApiDoc;
|
||||
|
||||
struct SecurityAddon;
|
||||
|
||||
impl utoipa::Modify for SecurityAddon {
|
||||
fn modify(&self, openapi: &mut utoipa::openapi::OpenApi) {
|
||||
if let Some(components) = openapi.components.as_mut() {
|
||||
components.add_security_scheme(
|
||||
"bearer_auth",
|
||||
utoipa::openapi::security::SecurityScheme::Http(
|
||||
utoipa::openapi::security::Http::new(
|
||||
utoipa::openapi::security::HttpAuthScheme::Bearer,
|
||||
),
|
||||
),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Path stubs for OpenAPI documentation generation.
|
||||
/// These functions are never called at runtime -- they exist solely so that
|
||||
/// `utoipa::path` can produce the correct OpenAPI spec entries.
|
||||
pub mod paths {
|
||||
pub mod auth {
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/v1/auth/register",
|
||||
tag = "auth",
|
||||
request_body = crate::auth::types::RegisterRequest,
|
||||
responses(
|
||||
(status = 201, description = "注册成功", body = crate::auth::types::LoginResponse),
|
||||
(status = 409, description = "用户已存在"),
|
||||
)
|
||||
)]
|
||||
pub async fn register() {}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/v1/auth/login",
|
||||
tag = "auth",
|
||||
request_body = crate::auth::types::LoginRequest,
|
||||
responses(
|
||||
(status = 200, description = "登录成功", body = crate::auth::types::LoginResponse),
|
||||
(status = 401, description = "认证失败"),
|
||||
)
|
||||
)]
|
||||
pub async fn login() {}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/v1/auth/refresh",
|
||||
tag = "auth",
|
||||
security(("bearer_auth" = [])),
|
||||
responses(
|
||||
(status = 200, description = "刷新 token 成功", body = crate::auth::types::LoginResponse),
|
||||
(status = 401, description = "认证失败"),
|
||||
)
|
||||
)]
|
||||
pub async fn refresh() {}
|
||||
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/auth/me",
|
||||
tag = "auth",
|
||||
security(("bearer_auth" = [])),
|
||||
responses(
|
||||
(status = 200, description = "当前用户信息", body = crate::auth::types::AccountPublic),
|
||||
(status = 401, description = "未认证"),
|
||||
)
|
||||
)]
|
||||
pub async fn me() {}
|
||||
|
||||
#[utoipa::path(
|
||||
put,
|
||||
path = "/api/v1/auth/password",
|
||||
tag = "auth",
|
||||
security(("bearer_auth" = [])),
|
||||
request_body = crate::auth::types::ChangePasswordRequest,
|
||||
responses(
|
||||
(status = 200, description = "密码修改成功"),
|
||||
(status = 400, description = "旧密码不正确"),
|
||||
(status = 401, description = "未认证"),
|
||||
)
|
||||
)]
|
||||
pub async fn change_password() {}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/v1/auth/totp/setup",
|
||||
tag = "auth",
|
||||
security(("bearer_auth" = [])),
|
||||
responses(
|
||||
(status = 200, description = "TOTP 设置信息(含 secret 和 QR URI)"),
|
||||
(status = 401, description = "未认证"),
|
||||
(status = 409, description = "TOTP 已启用"),
|
||||
)
|
||||
)]
|
||||
pub async fn totp_setup() {}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/v1/auth/totp/verify",
|
||||
tag = "auth",
|
||||
security(("bearer_auth" = [])),
|
||||
responses(
|
||||
(status = 200, description = "验证成功,TOTP 已启用"),
|
||||
(status = 401, description = "验证码错误"),
|
||||
)
|
||||
)]
|
||||
pub async fn totp_verify() {}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/v1/auth/totp/disable",
|
||||
tag = "auth",
|
||||
security(("bearer_auth" = [])),
|
||||
responses(
|
||||
(status = 200, description = "TOTP 已禁用"),
|
||||
(status = 401, description = "密码错误"),
|
||||
)
|
||||
)]
|
||||
pub async fn totp_disable() {}
|
||||
}
|
||||
|
||||
pub mod accounts {
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/accounts",
|
||||
tag = "accounts",
|
||||
security(("bearer_auth" = [])),
|
||||
params(crate::account::types::ListAccountsQuery),
|
||||
responses(
|
||||
(status = 200, description = "账号列表", body = crate::account::types::AccountPublicPaginatedResponse),
|
||||
)
|
||||
)]
|
||||
pub async fn list_accounts() {}
|
||||
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/accounts/{id}",
|
||||
tag = "accounts",
|
||||
security(("bearer_auth" = [])),
|
||||
params(("id" = String, Path, description = "账号 ID")),
|
||||
responses(
|
||||
(status = 200, description = "账号详情", body = crate::auth::types::AccountPublic),
|
||||
(status = 404, description = "账号不存在"),
|
||||
)
|
||||
)]
|
||||
pub async fn get_account() {}
|
||||
|
||||
#[utoipa::path(
|
||||
put,
|
||||
path = "/api/v1/accounts/{id}",
|
||||
tag = "accounts",
|
||||
security(("bearer_auth" = [])),
|
||||
params(("id" = String, Path, description = "账号 ID")),
|
||||
request_body = crate::account::types::UpdateAccountRequest,
|
||||
responses(
|
||||
(status = 200, description = "更新成功", body = crate::auth::types::AccountPublic),
|
||||
(status = 404, description = "账号不存在"),
|
||||
)
|
||||
)]
|
||||
pub async fn update_account() {}
|
||||
|
||||
#[utoipa::path(
|
||||
patch,
|
||||
path = "/api/v1/accounts/{id}/status",
|
||||
tag = "accounts",
|
||||
security(("bearer_auth" = [])),
|
||||
params(("id" = String, Path, description = "账号 ID")),
|
||||
request_body = crate::account::types::UpdateStatusRequest,
|
||||
responses(
|
||||
(status = 200, description = "状态更新成功"),
|
||||
)
|
||||
)]
|
||||
pub async fn update_status() {}
|
||||
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/tokens",
|
||||
tag = "accounts",
|
||||
security(("bearer_auth" = [])),
|
||||
responses(
|
||||
(status = 200, description = "Token 列表", body = Vec<crate::account::types::TokenInfo>),
|
||||
)
|
||||
)]
|
||||
pub async fn list_tokens() {}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/v1/tokens",
|
||||
tag = "accounts",
|
||||
security(("bearer_auth" = [])),
|
||||
request_body = crate::account::types::CreateTokenRequest,
|
||||
responses(
|
||||
(status = 201, description = "创建成功", body = crate::account::types::TokenInfo),
|
||||
)
|
||||
)]
|
||||
pub async fn create_token() {}
|
||||
|
||||
#[utoipa::path(
|
||||
delete,
|
||||
path = "/api/v1/tokens/{id}",
|
||||
tag = "accounts",
|
||||
security(("bearer_auth" = [])),
|
||||
params(("id" = String, Path, description = "Token ID")),
|
||||
responses(
|
||||
(status = 204, description = "撤销成功"),
|
||||
)
|
||||
)]
|
||||
pub async fn revoke_token() {}
|
||||
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/devices",
|
||||
tag = "accounts",
|
||||
security(("bearer_auth" = [])),
|
||||
responses(
|
||||
(status = 200, description = "设备列表", body = Vec<crate::account::types::DeviceInfo>),
|
||||
)
|
||||
)]
|
||||
pub async fn list_devices() {}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/v1/devices/register",
|
||||
tag = "accounts",
|
||||
security(("bearer_auth" = [])),
|
||||
request_body = crate::account::types::RegisterDeviceRequest,
|
||||
responses(
|
||||
(status = 201, description = "注册成功", body = crate::account::types::DeviceInfo),
|
||||
)
|
||||
)]
|
||||
pub async fn register_device() {}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/v1/devices/heartbeat",
|
||||
tag = "accounts",
|
||||
security(("bearer_auth" = [])),
|
||||
request_body = crate::account::types::DeviceHeartbeatRequest,
|
||||
responses(
|
||||
(status = 200, description = "心跳更新成功"),
|
||||
)
|
||||
)]
|
||||
pub async fn device_heartbeat() {}
|
||||
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/logs/operations",
|
||||
tag = "accounts",
|
||||
security(("bearer_auth" = [])),
|
||||
params(
|
||||
("page" = Option<i32>, Query, description = "页码"),
|
||||
("page_size" = Option<i32>, Query, description = "每页数量"),
|
||||
("action" = Option<String>, Query, description = "操作类型过滤"),
|
||||
("account_id" = Option<String>, Query, description = "账号 ID 过滤"),
|
||||
),
|
||||
responses(
|
||||
(status = 200, description = "操作日志列表"),
|
||||
)
|
||||
)]
|
||||
pub async fn list_operation_logs() {}
|
||||
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/stats/dashboard",
|
||||
tag = "accounts",
|
||||
security(("bearer_auth" = [])),
|
||||
responses(
|
||||
(status = 200, description = "仪表盘统计数据"),
|
||||
)
|
||||
)]
|
||||
pub async fn dashboard_stats() {}
|
||||
}
|
||||
|
||||
pub mod providers {
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/providers",
|
||||
tag = "providers",
|
||||
security(("bearer_auth" = [])),
|
||||
responses(
|
||||
(status = 200, description = "供应商列表", body = Vec<crate::model_config::types::ProviderInfo>),
|
||||
)
|
||||
)]
|
||||
pub async fn list_providers() {}
|
||||
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/providers/{id}",
|
||||
tag = "providers",
|
||||
security(("bearer_auth" = [])),
|
||||
params(("id" = String, Path, description = "供应商 ID")),
|
||||
responses(
|
||||
(status = 200, description = "供应商详情", body = crate::model_config::types::ProviderInfo),
|
||||
(status = 404, description = "供应商不存在"),
|
||||
)
|
||||
)]
|
||||
pub async fn get_provider() {}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/v1/providers",
|
||||
tag = "providers",
|
||||
security(("bearer_auth" = [])),
|
||||
request_body = crate::model_config::types::CreateProviderRequest,
|
||||
responses(
|
||||
(status = 201, description = "创建成功", body = crate::model_config::types::ProviderInfo),
|
||||
)
|
||||
)]
|
||||
pub async fn create_provider() {}
|
||||
|
||||
#[utoipa::path(
|
||||
put,
|
||||
path = "/api/v1/providers/{id}",
|
||||
tag = "providers",
|
||||
security(("bearer_auth" = [])),
|
||||
params(("id" = String, Path, description = "供应商 ID")),
|
||||
request_body = crate::model_config::types::UpdateProviderRequest,
|
||||
responses(
|
||||
(status = 200, description = "更新成功", body = crate::model_config::types::ProviderInfo),
|
||||
(status = 404, description = "供应商不存在"),
|
||||
)
|
||||
)]
|
||||
pub async fn update_provider() {}
|
||||
|
||||
#[utoipa::path(
|
||||
delete,
|
||||
path = "/api/v1/providers/{id}",
|
||||
tag = "providers",
|
||||
security(("bearer_auth" = [])),
|
||||
params(("id" = String, Path, description = "供应商 ID")),
|
||||
responses(
|
||||
(status = 204, description = "删除成功"),
|
||||
(status = 404, description = "供应商不存在"),
|
||||
)
|
||||
)]
|
||||
pub async fn delete_provider() {}
|
||||
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/providers/{id}/models",
|
||||
tag = "providers",
|
||||
security(("bearer_auth" = [])),
|
||||
params(("id" = String, Path, description = "供应商 ID")),
|
||||
responses(
|
||||
(status = 200, description = "供应商下的模型列表", body = Vec<crate::model_config::types::ModelInfo>),
|
||||
)
|
||||
)]
|
||||
pub async fn list_provider_models() {}
|
||||
}
|
||||
|
||||
pub mod models {
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/models",
|
||||
tag = "models",
|
||||
security(("bearer_auth" = [])),
|
||||
responses(
|
||||
(status = 200, description = "模型列表", body = Vec<crate::model_config::types::ModelInfo>),
|
||||
)
|
||||
)]
|
||||
pub async fn list_models() {}
|
||||
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/models/{id}",
|
||||
tag = "models",
|
||||
security(("bearer_auth" = [])),
|
||||
params(("id" = String, Path, description = "模型 ID")),
|
||||
responses(
|
||||
(status = 200, description = "模型详情", body = crate::model_config::types::ModelInfo),
|
||||
(status = 404, description = "模型不存在"),
|
||||
)
|
||||
)]
|
||||
pub async fn get_model() {}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/v1/models",
|
||||
tag = "models",
|
||||
security(("bearer_auth" = [])),
|
||||
request_body = crate::model_config::types::CreateModelRequest,
|
||||
responses(
|
||||
(status = 201, description = "创建成功", body = crate::model_config::types::ModelInfo),
|
||||
)
|
||||
)]
|
||||
pub async fn create_model() {}
|
||||
|
||||
#[utoipa::path(
|
||||
put,
|
||||
path = "/api/v1/models/{id}",
|
||||
tag = "models",
|
||||
security(("bearer_auth" = [])),
|
||||
params(("id" = String, Path, description = "模型 ID")),
|
||||
request_body = crate::model_config::types::UpdateModelRequest,
|
||||
responses(
|
||||
(status = 200, description = "更新成功", body = crate::model_config::types::ModelInfo),
|
||||
(status = 404, description = "模型不存在"),
|
||||
)
|
||||
)]
|
||||
pub async fn update_model() {}
|
||||
|
||||
#[utoipa::path(
|
||||
delete,
|
||||
path = "/api/v1/models/{id}",
|
||||
tag = "models",
|
||||
security(("bearer_auth" = [])),
|
||||
params(("id" = String, Path, description = "模型 ID")),
|
||||
responses(
|
||||
(status = 204, description = "删除成功"),
|
||||
(status = 404, description = "模型不存在"),
|
||||
)
|
||||
)]
|
||||
pub async fn delete_model() {}
|
||||
}
|
||||
|
||||
pub mod keys {
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/keys",
|
||||
tag = "keys",
|
||||
security(("bearer_auth" = [])),
|
||||
responses(
|
||||
(status = 200, description = "API Key 列表", body = Vec<crate::model_config::types::AccountApiKeyInfo>),
|
||||
)
|
||||
)]
|
||||
pub async fn list_api_keys() {}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/v1/keys",
|
||||
tag = "keys",
|
||||
security(("bearer_auth" = [])),
|
||||
request_body = crate::model_config::types::CreateAccountApiKeyRequest,
|
||||
responses(
|
||||
(status = 201, description = "创建成功", body = crate::model_config::types::AccountApiKeyInfo),
|
||||
)
|
||||
)]
|
||||
pub async fn create_api_key() {}
|
||||
|
||||
#[utoipa::path(
|
||||
delete,
|
||||
path = "/api/v1/keys/{id}",
|
||||
tag = "keys",
|
||||
security(("bearer_auth" = [])),
|
||||
params(("id" = String, Path, description = "Key ID")),
|
||||
responses(
|
||||
(status = 204, description = "撤销成功"),
|
||||
)
|
||||
)]
|
||||
pub async fn revoke_api_key() {}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/v1/keys/{id}/rotate",
|
||||
tag = "keys",
|
||||
security(("bearer_auth" = [])),
|
||||
params(("id" = String, Path, description = "Key ID")),
|
||||
request_body = crate::model_config::types::RotateApiKeyRequest,
|
||||
responses(
|
||||
(status = 200, description = "轮换成功", body = crate::model_config::types::AccountApiKeyInfo),
|
||||
)
|
||||
)]
|
||||
pub async fn rotate_api_key() {}
|
||||
}
|
||||
|
||||
pub mod usage {
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/usage",
|
||||
tag = "usage",
|
||||
security(("bearer_auth" = [])),
|
||||
params(crate::model_config::types::UsageQuery),
|
||||
responses(
|
||||
(status = 200, description = "用量统计", body = crate::model_config::types::UsageStats),
|
||||
)
|
||||
)]
|
||||
pub async fn get_usage() {}
|
||||
}
|
||||
|
||||
pub mod relay {
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/v1/relay/chat/completions",
|
||||
tag = "relay",
|
||||
security(("bearer_auth" = [])),
|
||||
responses(
|
||||
(status = 200, description = "聊天补全响应(JSON 或 SSE 流)"),
|
||||
(status = 402, description = "上游服务错误"),
|
||||
(status = 404, description = "模型不存在或未启用"),
|
||||
)
|
||||
)]
|
||||
pub async fn chat_completions() {}
|
||||
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/relay/tasks",
|
||||
tag = "relay",
|
||||
security(("bearer_auth" = [])),
|
||||
params(crate::relay::types::RelayTaskQuery),
|
||||
responses(
|
||||
(status = 200, description = "中转任务列表", body = Vec<crate::relay::types::RelayTaskInfo>),
|
||||
)
|
||||
)]
|
||||
pub async fn list_tasks() {}
|
||||
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/relay/tasks/{id}",
|
||||
tag = "relay",
|
||||
security(("bearer_auth" = [])),
|
||||
params(("id" = String, Path, description = "任务 ID")),
|
||||
responses(
|
||||
(status = 200, description = "任务详情", body = crate::relay::types::RelayTaskInfo),
|
||||
(status = 404, description = "任务不存在"),
|
||||
)
|
||||
)]
|
||||
pub async fn get_task() {}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/v1/relay/tasks/{id}/retry",
|
||||
tag = "relay",
|
||||
security(("bearer_auth" = [])),
|
||||
params(("id" = String, Path, description = "任务 ID")),
|
||||
responses(
|
||||
(status = 200, description = "重试成功", body = crate::relay::types::RelayTaskInfo),
|
||||
(status = 404, description = "任务不存在"),
|
||||
)
|
||||
)]
|
||||
pub async fn retry_task() {}
|
||||
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/relay/models",
|
||||
tag = "relay",
|
||||
security(("bearer_auth" = [])),
|
||||
responses(
|
||||
(status = 200, description = "可用模型列表", body = Vec<crate::model_config::types::ModelInfo>),
|
||||
)
|
||||
)]
|
||||
pub async fn list_available_models() {}
|
||||
}
|
||||
|
||||
pub mod config {
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/config/items",
|
||||
tag = "config",
|
||||
security(("bearer_auth" = [])),
|
||||
params(crate::migration::types::ConfigQuery),
|
||||
responses(
|
||||
(status = 200, description = "配置项列表", body = Vec<crate::migration::types::ConfigItemInfo>),
|
||||
)
|
||||
)]
|
||||
pub async fn list_config_items() {}
|
||||
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/config/items/{id}",
|
||||
tag = "config",
|
||||
security(("bearer_auth" = [])),
|
||||
params(("id" = String, Path, description = "配置项 ID")),
|
||||
responses(
|
||||
(status = 200, description = "配置项详情", body = crate::migration::types::ConfigItemInfo),
|
||||
(status = 404, description = "配置项不存在"),
|
||||
)
|
||||
)]
|
||||
pub async fn get_config_item() {}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/v1/config/items",
|
||||
tag = "config",
|
||||
security(("bearer_auth" = [])),
|
||||
request_body = crate::migration::types::CreateConfigItemRequest,
|
||||
responses(
|
||||
(status = 201, description = "创建成功", body = crate::migration::types::ConfigItemInfo),
|
||||
)
|
||||
)]
|
||||
pub async fn create_config_item() {}
|
||||
|
||||
#[utoipa::path(
|
||||
put,
|
||||
path = "/api/v1/config/items/{id}",
|
||||
tag = "config",
|
||||
security(("bearer_auth" = [])),
|
||||
params(("id" = String, Path, description = "配置项 ID")),
|
||||
request_body = crate::migration::types::UpdateConfigItemRequest,
|
||||
responses(
|
||||
(status = 200, description = "更新成功", body = crate::migration::types::ConfigItemInfo),
|
||||
(status = 404, description = "配置项不存在"),
|
||||
)
|
||||
)]
|
||||
pub async fn update_config_item() {}
|
||||
|
||||
#[utoipa::path(
|
||||
delete,
|
||||
path = "/api/v1/config/items/{id}",
|
||||
tag = "config",
|
||||
security(("bearer_auth" = [])),
|
||||
params(("id" = String, Path, description = "配置项 ID")),
|
||||
responses(
|
||||
(status = 204, description = "删除成功"),
|
||||
(status = 404, description = "配置项不存在"),
|
||||
)
|
||||
)]
|
||||
pub async fn delete_config_item() {}
|
||||
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/config/analysis",
|
||||
tag = "config",
|
||||
security(("bearer_auth" = [])),
|
||||
responses(
|
||||
(status = 200, description = "配置分析结果", body = crate::migration::types::ConfigAnalysis),
|
||||
)
|
||||
)]
|
||||
pub async fn analyze_config() {}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/v1/config/seed",
|
||||
tag = "config",
|
||||
security(("bearer_auth" = [])),
|
||||
responses(
|
||||
(status = 200, description = "种子数据初始化成功"),
|
||||
)
|
||||
)]
|
||||
pub async fn seed_config() {}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/v1/config/sync",
|
||||
tag = "config",
|
||||
security(("bearer_auth" = [])),
|
||||
request_body = crate::migration::types::SyncConfigRequest,
|
||||
responses(
|
||||
(status = 200, description = "同步成功"),
|
||||
)
|
||||
)]
|
||||
pub async fn sync_config() {}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/v1/config/diff",
|
||||
tag = "config",
|
||||
security(("bearer_auth" = [])),
|
||||
request_body = crate::migration::types::SyncConfigRequest,
|
||||
responses(
|
||||
(status = 200, description = "配置差异", body = crate::migration::types::ConfigDiffResponse),
|
||||
)
|
||||
)]
|
||||
pub async fn config_diff() {}
|
||||
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/config/sync-logs",
|
||||
tag = "config",
|
||||
security(("bearer_auth" = [])),
|
||||
responses(
|
||||
(status = 200, description = "同步日志列表", body = Vec<crate::migration::types::ConfigSyncLogInfo>),
|
||||
)
|
||||
)]
|
||||
pub async fn list_sync_logs() {}
|
||||
}
|
||||
}
|
||||
@@ -1,5 +1,9 @@
|
||||
//! 中转服务 HTTP 处理器
|
||||
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::Mutex;
|
||||
use axum::body::Bytes;
|
||||
|
||||
use axum::{
|
||||
extract::{Extension, Path, Query, State},
|
||||
http::{HeaderMap, StatusCode},
|
||||
@@ -31,33 +35,70 @@ pub async fn chat_completions(
|
||||
.and_then(|v| v.as_bool())
|
||||
.unwrap_or(false);
|
||||
|
||||
// 查找 model 对应的 provider
|
||||
let models = model_service::list_models(&state.db, None).await?;
|
||||
let target_model = models.iter().find(|m| m.model_id == model_name && m.enabled)
|
||||
.ok_or_else(|| SaasError::NotFound(format!("模型 {} 不存在或未启用", model_name)))?;
|
||||
// 查找 model 对应的 provider (直接 SQL 查询,避免全量加载)
|
||||
let target_model = sqlx::query_as::<_, (String, String, String, String, i64, i64, bool, bool, bool)>(
|
||||
"SELECT id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled
|
||||
FROM models WHERE model_id = $1 AND enabled = true"
|
||||
)
|
||||
.bind(model_name)
|
||||
.fetch_optional(&state.db)
|
||||
.await?
|
||||
.ok_or_else(|| SaasError::NotFound(format!("模型 {} 不存在或未启用", model_name)))?;
|
||||
let (_model_id, provider_id, model_name_db, _, _, _, _, _, _) = target_model;
|
||||
|
||||
// 获取 provider 信息
|
||||
let provider = model_service::get_provider(&state.db, &target_model.provider_id).await?;
|
||||
let provider = model_service::get_provider(&state.db, &provider_id).await?;
|
||||
if !provider.enabled {
|
||||
return Err(SaasError::Forbidden(format!("Provider {} 已禁用", provider.name)));
|
||||
}
|
||||
|
||||
// 获取 provider 的 API key (从数据库直接查询)
|
||||
let provider_api_key: Option<String> = sqlx::query_scalar(
|
||||
"SELECT api_key FROM providers WHERE id = ?1"
|
||||
// 优先使用用户级 account_api_key,回退到 provider 级 key
|
||||
let account_key_encrypted: Option<String> = sqlx::query_scalar(
|
||||
"SELECT key_value FROM account_api_keys
|
||||
WHERE account_id = $1 AND provider_id = $2 AND revoked_at IS NULL AND enabled = true
|
||||
ORDER BY created_at DESC LIMIT 1"
|
||||
)
|
||||
.bind(&target_model.provider_id)
|
||||
.bind(&ctx.account_id)
|
||||
.bind(&provider_id)
|
||||
.fetch_optional(&state.db)
|
||||
.await?
|
||||
.flatten();
|
||||
|
||||
let api_key: Option<String> = if let Some(encrypted) = account_key_encrypted {
|
||||
// 更新 last_used_at
|
||||
let _ = sqlx::query(
|
||||
"UPDATE account_api_keys SET last_used_at = NOW() WHERE account_id = $1 AND provider_id = $2 AND revoked_at IS NULL AND enabled = true"
|
||||
)
|
||||
.bind(&ctx.account_id)
|
||||
.bind(&provider_id)
|
||||
.execute(&state.db)
|
||||
.await;
|
||||
Some(state.field_encryption.decrypt_or_plaintext(&encrypted))
|
||||
} else {
|
||||
// 回退到 provider 级 key
|
||||
let provider_key_encrypted: Option<String> = sqlx::query_scalar(
|
||||
"SELECT api_key FROM providers WHERE id = $1"
|
||||
)
|
||||
.bind(&provider_id)
|
||||
.fetch_optional(&state.db)
|
||||
.await?
|
||||
.flatten();
|
||||
provider_key_encrypted.map(|k| state.field_encryption.decrypt_or_plaintext(&k))
|
||||
};
|
||||
|
||||
if api_key.is_none() {
|
||||
return Err(SaasError::Internal(format!(
|
||||
"Provider {} 没有可用的 API Key", provider.name
|
||||
)));
|
||||
}
|
||||
|
||||
let request_body = serde_json::to_string(&req)?;
|
||||
|
||||
// 创建中转任务
|
||||
let config = state.config.read().await;
|
||||
let task = service::create_relay_task(
|
||||
&state.db, &ctx.account_id, &target_model.provider_id,
|
||||
&target_model.model_id, &request_body, 0,
|
||||
&state.db, &ctx.account_id, &provider_id,
|
||||
&model_name_db, &request_body, 0,
|
||||
config.relay.max_attempts,
|
||||
).await?;
|
||||
|
||||
@@ -66,8 +107,9 @@ pub async fn chat_completions(
|
||||
|
||||
// 执行中转 (带重试)
|
||||
let response = service::execute_relay(
|
||||
&state.db, &task.id, &provider.base_url,
|
||||
provider_api_key.as_deref(), &request_body, stream,
|
||||
&state.db, &task.id, &ctx.account_id, &provider_id, &model_name_db,
|
||||
&provider.base_url,
|
||||
api_key.as_deref(), &request_body, stream,
|
||||
config.relay.max_attempts,
|
||||
config.relay.retry_delay_ms,
|
||||
).await;
|
||||
@@ -86,34 +128,35 @@ pub async fn chat_completions(
|
||||
.unwrap_or(0);
|
||||
|
||||
model_service::record_usage(
|
||||
&state.db, &ctx.account_id, &target_model.provider_id,
|
||||
&target_model.model_id, input_tokens, output_tokens,
|
||||
&state.db, &ctx.account_id, &provider_id,
|
||||
&model_name_db, input_tokens, output_tokens,
|
||||
None, "success", None,
|
||||
).await?;
|
||||
|
||||
Ok((StatusCode::OK, [(axum::http::header::CONTENT_TYPE, "application/json")], body).into_response())
|
||||
}
|
||||
Ok(service::RelayResponse::Sse(body)) => {
|
||||
model_service::record_usage(
|
||||
&state.db, &ctx.account_id, &target_model.provider_id,
|
||||
&target_model.model_id, 0, 0,
|
||||
None, "success", None,
|
||||
).await?;
|
||||
Ok(service::RelayResponse::SseWithUsage { body, task_id: relay_task_id, account_id: relay_account_id, provider_id: relay_provider_id, model_id: relay_model_id }) => {
|
||||
// 流式响应: 使用 async_stream 包装器提取 SSE 末尾的 usage
|
||||
let wrapped = sse_usage_wrapper(
|
||||
state.db.clone(),
|
||||
relay_task_id, relay_account_id, relay_provider_id, relay_model_id,
|
||||
body,
|
||||
);
|
||||
let wrapped_body = axum::body::Body::from_stream(wrapped);
|
||||
|
||||
// 流式响应: 直接转发 axum::body::Body
|
||||
let response = axum::response::Response::builder()
|
||||
.status(StatusCode::OK)
|
||||
.header(axum::http::header::CONTENT_TYPE, "text/event-stream")
|
||||
.header("Cache-Control", "no-cache")
|
||||
.header("Connection", "keep-alive")
|
||||
.body(body)
|
||||
.unwrap();
|
||||
.body(wrapped_body)
|
||||
.map_err(|e| SaasError::Internal(format!("SSE 响应构建失败: {}", e)))?;
|
||||
Ok(response)
|
||||
}
|
||||
Err(e) => {
|
||||
model_service::record_usage(
|
||||
&state.db, &ctx.account_id, &target_model.provider_id,
|
||||
&target_model.model_id, 0, 0,
|
||||
&state.db, &ctx.account_id, &provider_id,
|
||||
&model_name_db, 0, 0,
|
||||
None, "failed", Some(&e.to_string()),
|
||||
).await?;
|
||||
Err(e)
|
||||
@@ -179,7 +222,7 @@ pub async fn retry_task(
|
||||
State(state): State<AppState>,
|
||||
Path(id): Path<String>,
|
||||
Extension(ctx): Extension<AuthContext>,
|
||||
) -> SaasResult<Json<serde_json::Value>> {
|
||||
) -> SaasResult<(StatusCode, Json<serde_json::Value>)> {
|
||||
check_permission(&ctx, "relay:admin")?;
|
||||
|
||||
let task = service::get_relay_task(&state.db, &id).await?;
|
||||
@@ -191,17 +234,35 @@ pub async fn retry_task(
|
||||
|
||||
// 获取 provider 信息
|
||||
let provider = model_service::get_provider(&state.db, &task.provider_id).await?;
|
||||
let provider_api_key: Option<String> = sqlx::query_scalar(
|
||||
"SELECT api_key FROM providers WHERE id = ?1"
|
||||
|
||||
// 重试时使用原始任务所属用户的 account key,回退到 provider key
|
||||
let account_key_encrypted: Option<String> = sqlx::query_scalar(
|
||||
"SELECT key_value FROM account_api_keys
|
||||
WHERE account_id = $1 AND provider_id = $2 AND revoked_at IS NULL AND enabled = true
|
||||
ORDER BY created_at DESC LIMIT 1"
|
||||
)
|
||||
.bind(&task.account_id)
|
||||
.bind(&task.provider_id)
|
||||
.fetch_optional(&state.db)
|
||||
.await?
|
||||
.flatten();
|
||||
|
||||
let api_key: Option<String> = if let Some(encrypted) = account_key_encrypted {
|
||||
Some(state.field_encryption.decrypt_or_plaintext(&encrypted))
|
||||
} else {
|
||||
let provider_key_encrypted: Option<String> = sqlx::query_scalar(
|
||||
"SELECT api_key FROM providers WHERE id = $1"
|
||||
)
|
||||
.bind(&task.provider_id)
|
||||
.fetch_optional(&state.db)
|
||||
.await?
|
||||
.flatten();
|
||||
provider_key_encrypted.map(|k| state.field_encryption.decrypt_or_plaintext(&k))
|
||||
};
|
||||
|
||||
// 读取原始请求体
|
||||
let request_body: Option<String> = sqlx::query_scalar(
|
||||
"SELECT request_body FROM relay_tasks WHERE id = ?1"
|
||||
"SELECT request_body FROM relay_tasks WHERE id = $1"
|
||||
)
|
||||
.bind(&id)
|
||||
.fetch_optional(&state.db)
|
||||
@@ -222,7 +283,7 @@ pub async fn retry_task(
|
||||
|
||||
// 重置任务状态为 queued 以允许新的 processing
|
||||
sqlx::query(
|
||||
"UPDATE relay_tasks SET status = 'queued', error_message = NULL, started_at = NULL, completed_at = NULL WHERE id = ?1"
|
||||
"UPDATE relay_tasks SET status = 'queued', error_message = NULL, started_at = NULL, completed_at = NULL WHERE id = $1"
|
||||
)
|
||||
.bind(&id)
|
||||
.execute(&state.db)
|
||||
@@ -231,10 +292,14 @@ pub async fn retry_task(
|
||||
// 异步执行重试
|
||||
let db = state.db.clone();
|
||||
let task_id = id.clone();
|
||||
let retry_account_id = ctx.account_id.clone();
|
||||
let retry_provider_id = task.provider_id.clone();
|
||||
let retry_model_id = task.model_id.clone();
|
||||
tokio::spawn(async move {
|
||||
match service::execute_relay(
|
||||
&db, &task_id, &provider.base_url,
|
||||
provider_api_key.as_deref(), &body, stream,
|
||||
&db, &task_id, &retry_account_id, &retry_provider_id, &retry_model_id,
|
||||
&provider.base_url,
|
||||
api_key.as_deref(), &body, stream,
|
||||
max_attempts, base_delay_ms,
|
||||
).await {
|
||||
Ok(_) => tracing::info!("Relay task {} 重试成功", task_id),
|
||||
@@ -245,5 +310,101 @@ pub async fn retry_task(
|
||||
log_operation(&state.db, &ctx.account_id, "relay.retry", "relay_task", &id,
|
||||
None, ctx.client_ip.as_deref()).await?;
|
||||
|
||||
Ok(Json(serde_json::json!({"ok": true, "task_id": id})))
|
||||
Ok((StatusCode::ACCEPTED, Json(serde_json::json!({"ok": true, "task_id": id}))))
|
||||
}
|
||||
|
||||
/// 包装 SSE 流,提取末尾的 usage 数据并异步记录
|
||||
///
|
||||
/// 支持客户端断连检测:当 body stream 返回错误(通常表示客户端提前断开连接),
|
||||
/// 记录日志并将任务标记为 "cancelled" 而非 "completed"。
|
||||
fn sse_usage_wrapper(
|
||||
db: sqlx::PgPool,
|
||||
task_id: String,
|
||||
account_id: String,
|
||||
provider_id: String,
|
||||
model_id: String,
|
||||
body: axum::body::Body,
|
||||
) -> impl futures::Stream<Item = Result<Bytes, std::io::Error>> + Send {
|
||||
use futures::StreamExt;
|
||||
|
||||
let last_usage: Arc<Mutex<Option<String>>> = Arc::new(Mutex::new(None));
|
||||
let mut saw_done = false;
|
||||
|
||||
async_stream::stream! {
|
||||
let mut data_stream = std::pin::pin!(body.into_data_stream().map(|r| r.map_err(std::io::Error::other)));
|
||||
loop {
|
||||
match StreamExt::next(&mut data_stream).await {
|
||||
Some(Ok(chunk)) => {
|
||||
let text = String::from_utf8_lossy(&chunk);
|
||||
|
||||
for line in text.lines() {
|
||||
if let Some(data) = line.strip_prefix("data: ") {
|
||||
let trimmed = data.trim();
|
||||
if trimmed == "[DONE]" {
|
||||
saw_done = true;
|
||||
let usage_str = last_usage.lock().await.take();
|
||||
if let Some(s) = usage_str {
|
||||
let (input, output) = service::extract_token_usage(&s);
|
||||
if input > 0 || output > 0 {
|
||||
let db2 = db.clone();
|
||||
let tid = task_id.clone();
|
||||
let aid = account_id.clone();
|
||||
let pid = provider_id.clone();
|
||||
let mid = model_id.clone();
|
||||
tokio::spawn(async move {
|
||||
let _ = service::update_task_status(
|
||||
&db2, &tid, "completed",
|
||||
Some(input), Some(output), None
|
||||
).await;
|
||||
let _ = model_service::record_usage(
|
||||
&db2, &aid, &pid, &mid,
|
||||
input, output, None, "success", None,
|
||||
).await;
|
||||
});
|
||||
}
|
||||
}
|
||||
} else if serde_json::from_str::<serde_json::Value>(trimmed)
|
||||
.ok()
|
||||
.and_then(|v| if v.get("usage").is_some() { Some(trimmed.to_string()) } else { None })
|
||||
.is_some()
|
||||
{
|
||||
*last_usage.lock().await = Some(trimmed.to_owned());
|
||||
}
|
||||
}
|
||||
}
|
||||
yield Ok(chunk);
|
||||
}
|
||||
Some(Err(e)) => {
|
||||
// 客户端断连或上游连接中断
|
||||
if !saw_done {
|
||||
tracing::warn!(
|
||||
"SSE stream error for task {} (client disconnected): {}",
|
||||
task_id, e
|
||||
);
|
||||
// 将任务标记为 cancelled(区别于 completed 和 failed)
|
||||
let db2 = db.clone();
|
||||
let tid = task_id.clone();
|
||||
tokio::spawn(async move {
|
||||
let _ = service::update_task_status(
|
||||
&db2, &tid, "cancelled",
|
||||
None, None, Some("客户端断开连接"),
|
||||
).await;
|
||||
});
|
||||
}
|
||||
break;
|
||||
}
|
||||
None => {
|
||||
// Stream 正常结束(上游发送完毕)
|
||||
if !saw_done {
|
||||
// 上游关闭但未发送 [DONE],仍记录完成
|
||||
tracing::info!(
|
||||
"SSE stream ended without [DONE] for task {}",
|
||||
task_id,
|
||||
);
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
//! 中转服务核心逻辑
|
||||
|
||||
use sqlx::SqlitePool;
|
||||
use sqlx::PgPool;
|
||||
use crate::error::{SaasError, SaasResult};
|
||||
use super::types::*;
|
||||
use futures::StreamExt;
|
||||
@@ -18,35 +18,34 @@ fn is_retryable_error(e: &reqwest::Error) -> bool {
|
||||
// ============ Relay Task Management ============
|
||||
|
||||
pub async fn create_relay_task(
|
||||
db: &SqlitePool,
|
||||
db: &PgPool,
|
||||
account_id: &str,
|
||||
provider_id: &str,
|
||||
model_id: &str,
|
||||
request_body: &str,
|
||||
priority: i64,
|
||||
_priority: i64,
|
||||
max_attempts: u32,
|
||||
) -> SaasResult<RelayTaskInfo> {
|
||||
let id = uuid::Uuid::new_v4().to_string();
|
||||
let now = chrono::Utc::now().to_rfc3339();
|
||||
let request_hash = hash_request(request_body);
|
||||
let now = chrono::Utc::now();
|
||||
let max_attempts = max_attempts.max(1).min(5);
|
||||
|
||||
sqlx::query(
|
||||
"INSERT INTO relay_tasks (id, account_id, provider_id, model_id, request_hash, request_body, status, priority, attempt_count, max_attempts, queued_at, created_at)
|
||||
VALUES (?1, ?2, ?3, ?4, ?5, ?6, 'queued', ?7, 0, ?8, ?9, ?9)"
|
||||
VALUES ($1, $2, $3, $4, '', $5, 'queued', 0, 0, $6, $7, $7)"
|
||||
)
|
||||
.bind(&id).bind(account_id).bind(provider_id).bind(model_id)
|
||||
.bind(&request_hash).bind(request_body).bind(priority).bind(max_attempts as i64).bind(&now)
|
||||
.bind(request_body).bind(max_attempts as i64).bind(&now)
|
||||
.execute(db).await?;
|
||||
|
||||
get_relay_task(db, &id).await
|
||||
}
|
||||
|
||||
pub async fn get_relay_task(db: &SqlitePool, task_id: &str) -> SaasResult<RelayTaskInfo> {
|
||||
let row: Option<(String, String, String, String, String, i64, i64, i64, i64, i64, Option<String>, String, Option<String>, Option<String>, String)> =
|
||||
pub async fn get_relay_task(db: &PgPool, task_id: &str) -> SaasResult<RelayTaskInfo> {
|
||||
let row: Option<(String, String, String, String, String, i64, i64, i64, i64, i64, Option<String>, chrono::DateTime<chrono::Utc>, Option<chrono::DateTime<chrono::Utc>>, Option<chrono::DateTime<chrono::Utc>>, chrono::DateTime<chrono::Utc>)> =
|
||||
sqlx::query_as(
|
||||
"SELECT id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at
|
||||
FROM relay_tasks WHERE id = ?1"
|
||||
FROM relay_tasks WHERE id = $1"
|
||||
)
|
||||
.bind(task_id)
|
||||
.fetch_optional(db)
|
||||
@@ -58,12 +57,12 @@ pub async fn get_relay_task(db: &SqlitePool, task_id: &str) -> SaasResult<RelayT
|
||||
Ok(RelayTaskInfo {
|
||||
id, account_id, provider_id, model_id, status, priority,
|
||||
attempt_count, max_attempts, input_tokens, output_tokens,
|
||||
error_message, queued_at, started_at, completed_at, created_at,
|
||||
error_message, queued_at: queued_at.to_rfc3339(), started_at: started_at.map(|t| t.to_rfc3339()), completed_at: completed_at.map(|t| t.to_rfc3339()), created_at: created_at.to_rfc3339(),
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn list_relay_tasks(
|
||||
db: &SqlitePool, account_id: &str, query: &RelayTaskQuery,
|
||||
db: &PgPool, account_id: &str, query: &RelayTaskQuery,
|
||||
) -> SaasResult<Vec<RelayTaskInfo>> {
|
||||
let page = query.page.unwrap_or(1).max(1);
|
||||
let page_size = query.page_size.unwrap_or(20).min(100);
|
||||
@@ -71,13 +70,13 @@ pub async fn list_relay_tasks(
|
||||
|
||||
let sql = if query.status.is_some() {
|
||||
"SELECT id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at
|
||||
FROM relay_tasks WHERE account_id = ?1 AND status = ?2 ORDER BY created_at DESC LIMIT ?3 OFFSET ?4"
|
||||
FROM relay_tasks WHERE account_id = $1 AND status = $2 ORDER BY created_at DESC LIMIT $3 OFFSET $4"
|
||||
} else {
|
||||
"SELECT id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at
|
||||
FROM relay_tasks WHERE account_id = ?1 ORDER BY created_at DESC LIMIT ?2 OFFSET ?3"
|
||||
FROM relay_tasks WHERE account_id = $1 ORDER BY created_at DESC LIMIT $2 OFFSET $3"
|
||||
};
|
||||
|
||||
let mut query_builder = sqlx::query_as::<_, (String, String, String, String, String, i64, i64, i64, i64, i64, Option<String>, String, Option<String>, Option<String>, String)>(sql)
|
||||
let mut query_builder = sqlx::query_as::<_, (String, String, String, String, String, i64, i64, i64, i64, i64, Option<String>, chrono::DateTime<chrono::Utc>, Option<chrono::DateTime<chrono::Utc>>, Option<chrono::DateTime<chrono::Utc>>, chrono::DateTime<chrono::Utc>)>(sql)
|
||||
.bind(account_id);
|
||||
|
||||
if let Some(ref status) = query.status {
|
||||
@@ -88,31 +87,32 @@ pub async fn list_relay_tasks(
|
||||
|
||||
let rows = query_builder.fetch_all(db).await?;
|
||||
Ok(rows.into_iter().map(|(id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at)| {
|
||||
RelayTaskInfo { id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at }
|
||||
RelayTaskInfo { id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at: queued_at.to_rfc3339(), started_at: started_at.map(|t| t.to_rfc3339()), completed_at: completed_at.map(|t| t.to_rfc3339()), created_at: created_at.to_rfc3339() }
|
||||
}).collect())
|
||||
}
|
||||
|
||||
pub async fn update_task_status(
|
||||
db: &SqlitePool, task_id: &str, status: &str,
|
||||
db: &PgPool, task_id: &str, status: &str,
|
||||
input_tokens: Option<i64>, output_tokens: Option<i64>,
|
||||
error_message: Option<&str>,
|
||||
) -> SaasResult<()> {
|
||||
let now = chrono::Utc::now().to_rfc3339();
|
||||
let now = chrono::Utc::now();
|
||||
|
||||
let update_sql = match status {
|
||||
"processing" => "started_at = ?1, status = 'processing', attempt_count = attempt_count + 1",
|
||||
"completed" => "completed_at = ?1, status = 'completed', input_tokens = COALESCE(?2, input_tokens), output_tokens = COALESCE(?3, output_tokens)",
|
||||
"failed" => "completed_at = ?1, status = 'failed', error_message = ?2",
|
||||
"processing" => "started_at = $1, status = 'processing', attempt_count = attempt_count + 1",
|
||||
"completed" => "completed_at = $1, status = 'completed', input_tokens = COALESCE($2, input_tokens), output_tokens = COALESCE($3, output_tokens)",
|
||||
"failed" => "completed_at = $1, status = 'failed', error_message = $2",
|
||||
"cancelled" => "completed_at = $1, status = 'cancelled', error_message = $2",
|
||||
_ => return Err(SaasError::InvalidInput(format!("无效任务状态: {}", status))),
|
||||
};
|
||||
|
||||
let sql = format!("UPDATE relay_tasks SET {} WHERE id = ?4", update_sql);
|
||||
let sql = format!("UPDATE relay_tasks SET {} WHERE id = $4", update_sql);
|
||||
|
||||
let mut query = sqlx::query(&sql).bind(&now);
|
||||
if status == "completed" {
|
||||
query = query.bind(input_tokens).bind(output_tokens);
|
||||
}
|
||||
if status == "failed" {
|
||||
if status == "failed" || status == "cancelled" {
|
||||
query = query.bind(error_message);
|
||||
}
|
||||
query = query.bind(task_id);
|
||||
@@ -124,8 +124,11 @@ pub async fn update_task_status(
|
||||
// ============ Relay Execution ============
|
||||
|
||||
pub async fn execute_relay(
|
||||
db: &SqlitePool,
|
||||
db: &PgPool,
|
||||
task_id: &str,
|
||||
account_id: &str,
|
||||
provider_id: &str,
|
||||
model_id: &str,
|
||||
provider_base_url: &str,
|
||||
provider_api_key: Option<&str>,
|
||||
request_body: &str,
|
||||
@@ -135,6 +138,31 @@ pub async fn execute_relay(
|
||||
) -> SaasResult<RelayResponse> {
|
||||
validate_provider_url(provider_base_url)?;
|
||||
|
||||
// DNS Rebinding 防护: 解析 host 并验证所有 resolved IP 非私有
|
||||
let parsed_url: url::Url = provider_base_url.trim_end_matches('/').parse()
|
||||
.map_err(|_| SaasError::InvalidInput(format!("无效的 provider URL: {}", provider_base_url)))?;
|
||||
let host = parsed_url.host_str()
|
||||
.ok_or_else(|| SaasError::InvalidInput("provider URL 缺少 host".into()))?;
|
||||
|
||||
// 仅对非 IP 的 host 做 DNS 解析(纯 IP 已在 validate_provider_url 中检查)
|
||||
if host.parse::<std::net::IpAddr>().is_err() {
|
||||
let port = parsed_url.port_or_known_default().unwrap_or(443);
|
||||
let addr_str = format!("{}:{}", host, port);
|
||||
let addrs: Vec<std::net::SocketAddr> = std::net::ToSocketAddrs::to_socket_addrs(&addr_str)
|
||||
.map_err(|e| SaasError::InvalidInput(format!("DNS 解析失败: {}", e)))?
|
||||
.collect();
|
||||
if addrs.is_empty() {
|
||||
return Err(SaasError::InvalidInput(format!("DNS 解析无结果: {}", host)));
|
||||
}
|
||||
for addr in &addrs {
|
||||
if is_private_ip(&addr.ip()) {
|
||||
return Err(SaasError::InvalidInput(format!(
|
||||
"provider URL {} 解析到私有 IP: {}", host, addr.ip()
|
||||
)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let url = format!("{}/chat/completions", provider_base_url.trim_end_matches('/'));
|
||||
|
||||
let client = reqwest::Client::builder()
|
||||
@@ -167,8 +195,14 @@ pub async fn execute_relay(
|
||||
let byte_stream = resp.bytes_stream()
|
||||
.map(|result| result.map_err(std::io::Error::other));
|
||||
let body = axum::body::Body::from_stream(byte_stream);
|
||||
update_task_status(db, task_id, "completed", None, None, None).await?;
|
||||
return Ok(RelayResponse::Sse(body));
|
||||
update_task_status(db, task_id, "completed", None, None, None).await?;
|
||||
return Ok(RelayResponse::SseWithUsage {
|
||||
body,
|
||||
task_id: task_id.to_string(),
|
||||
account_id: account_id.to_string(),
|
||||
provider_id: provider_id.to_string(),
|
||||
model_id: model_id.to_string(),
|
||||
});
|
||||
} else {
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
let (input_tokens, output_tokens) = extract_token_usage(&body);
|
||||
@@ -182,7 +216,12 @@ pub async fn execute_relay(
|
||||
if !is_retryable_status(status) || attempt + 1 >= max_attempts {
|
||||
// 4xx 客户端错误或已达最大重试次数 → 立即失败
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
let err_msg = format!("上游返回 HTTP {}: {}", status, &body[..body.len().min(500)]);
|
||||
// 仅记录日志,不将上游错误体暴露给客户端(可能含敏感信息如 API key)
|
||||
tracing::warn!(
|
||||
"Relay task {} 上游返回 HTTP {} (body truncated): {}",
|
||||
task_id, status, &body[..body.len().min(200)]
|
||||
);
|
||||
let err_msg = format!("上游服务返回错误 (HTTP {})", status);
|
||||
update_task_status(db, task_id, "failed", None, None, Some(&err_msg)).await?;
|
||||
return Err(SaasError::Relay(err_msg));
|
||||
}
|
||||
@@ -218,17 +257,19 @@ pub async fn execute_relay(
|
||||
#[derive(Debug)]
|
||||
pub enum RelayResponse {
|
||||
Json(String),
|
||||
Sse(axum::body::Body),
|
||||
/// SSE 流式响应 + 上下文信息
|
||||
SseWithUsage {
|
||||
body: axum::body::Body,
|
||||
task_id: String,
|
||||
account_id: String,
|
||||
provider_id: String,
|
||||
model_id: String,
|
||||
},
|
||||
}
|
||||
|
||||
// ============ Helpers ============
|
||||
|
||||
fn hash_request(body: &str) -> String {
|
||||
use sha2::{Sha256, Digest};
|
||||
hex::encode(Sha256::digest(body.as_bytes()))
|
||||
}
|
||||
|
||||
fn extract_token_usage(body: &str) -> (i64, i64) {
|
||||
pub fn extract_token_usage(body: &str) -> (i64, i64) {
|
||||
let parsed: serde_json::Value = match serde_json::from_str(body) {
|
||||
Ok(v) => v,
|
||||
Err(_) => return (0, 0),
|
||||
@@ -273,6 +314,9 @@ fn validate_provider_url(url: &str) -> SaasResult<()> {
|
||||
Some(h) => h,
|
||||
None => return Err(SaasError::InvalidInput("provider URL 缺少 host".into())),
|
||||
};
|
||||
// url crate 的 host_str() 对 IPv6 地址保留方括号 (如 "[::1]"),
|
||||
// 需要去掉方括号才能与阻止列表匹配和解析为 IpAddr
|
||||
let host = host.trim_start_matches('[').trim_end_matches(']');
|
||||
|
||||
// 精确匹配的阻止列表
|
||||
let blocked_exact = [
|
||||
@@ -335,3 +379,302 @@ fn is_private_ip(ip: &std::net::IpAddr) -> bool {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
// ---- is_retryable_status ----
|
||||
|
||||
#[test]
|
||||
fn retryable_status_429() {
|
||||
assert!(is_retryable_status(429));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn retryable_status_5xx_range() {
|
||||
for code in 500u16..600 {
|
||||
assert!(is_retryable_status(code), "expected {code} to be retryable");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn not_retryable_status_200() {
|
||||
assert!(!is_retryable_status(200));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn not_retryable_status_400() {
|
||||
assert!(!is_retryable_status(400));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn not_retryable_status_404() {
|
||||
assert!(!is_retryable_status(404));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn not_retryable_status_422() {
|
||||
assert!(!is_retryable_status(422));
|
||||
}
|
||||
|
||||
// ---- extract_token_usage ----
|
||||
|
||||
#[test]
|
||||
fn extract_usage_normal() {
|
||||
let body = r#"{"usage":{"prompt_tokens":100,"completion_tokens":50}}"#;
|
||||
assert_eq!(extract_token_usage(body), (100, 50));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_usage_no_usage_field() {
|
||||
let body = r#"{"id":"chatcmpl-abc","object":"chat.completion"}"#;
|
||||
assert_eq!(extract_token_usage(body), (0, 0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_usage_invalid_json() {
|
||||
assert_eq!(extract_token_usage("not json at all"), (0, 0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_usage_empty_body() {
|
||||
assert_eq!(extract_token_usage(""), (0, 0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_usage_partial_tokens() {
|
||||
// only prompt_tokens present, completion_tokens missing
|
||||
let body = r#"{"usage":{"prompt_tokens":200}}"#;
|
||||
assert_eq!(extract_token_usage(body), (200, 0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_usage_completion_only() {
|
||||
let body = r#"{"usage":{"completion_tokens":75}}"#;
|
||||
assert_eq!(extract_token_usage(body), (0, 75));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_usage_zero_tokens() {
|
||||
let body = r#"{"usage":{"prompt_tokens":0,"completion_tokens":0}}"#;
|
||||
assert_eq!(extract_token_usage(body), (0, 0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_usage_string_instead_of_int() {
|
||||
// non-integer token values should fall back to 0
|
||||
let body = r#"{"usage":{"prompt_tokens":"abc","completion_tokens":null}}"#;
|
||||
assert_eq!(extract_token_usage(body), (0, 0));
|
||||
}
|
||||
|
||||
// ---- is_private_ip ----
|
||||
|
||||
#[test]
|
||||
fn private_ip_10_range() {
|
||||
let ip: std::net::IpAddr = "10.0.0.1".parse().unwrap();
|
||||
assert!(is_private_ip(&ip));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn private_ip_172_16_range() {
|
||||
let ip: std::net::IpAddr = "172.16.0.1".parse().unwrap();
|
||||
assert!(is_private_ip(&ip));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn private_ip_172_31_range() {
|
||||
let ip: std::net::IpAddr = "172.31.255.255".parse().unwrap();
|
||||
assert!(is_private_ip(&ip));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn private_ip_172_15_not_private() {
|
||||
// 172.15.x.x is NOT in the private range (starts at 172.16)
|
||||
let ip: std::net::IpAddr = "172.15.255.255".parse().unwrap();
|
||||
assert!(!is_private_ip(&ip));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn private_ip_172_32_not_private() {
|
||||
// 172.32.x.x is NOT in the private range (ends at 172.31)
|
||||
let ip: std::net::IpAddr = "172.32.0.0".parse().unwrap();
|
||||
assert!(!is_private_ip(&ip));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn private_ip_192_168_range() {
|
||||
let ip: std::net::IpAddr = "192.168.1.1".parse().unwrap();
|
||||
assert!(is_private_ip(&ip));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn private_ip_127_loopback() {
|
||||
let ip: std::net::IpAddr = "127.0.0.1".parse().unwrap();
|
||||
assert!(is_private_ip(&ip));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn private_ip_127_any() {
|
||||
let ip: std::net::IpAddr = "127.255.255.255".parse().unwrap();
|
||||
assert!(is_private_ip(&ip));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn private_ip_169_254_link_local() {
|
||||
let ip: std::net::IpAddr = "169.254.1.1".parse().unwrap();
|
||||
assert!(is_private_ip(&ip));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn private_ip_0_0_0_0() {
|
||||
let ip: std::net::IpAddr = "0.0.0.0".parse().unwrap();
|
||||
assert!(is_private_ip(&ip));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn private_ip_v6_loopback() {
|
||||
let ip: std::net::IpAddr = "::1".parse().unwrap();
|
||||
assert!(is_private_ip(&ip));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn private_ip_v6_link_local() {
|
||||
let ip: std::net::IpAddr = "fe80::1".parse().unwrap();
|
||||
assert!(is_private_ip(&ip));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn private_ip_v6_mapped_ipv4_loopback() {
|
||||
let ip: std::net::IpAddr = "::ffff:127.0.0.1".parse().unwrap();
|
||||
assert!(is_private_ip(&ip));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn private_ip_v6_mapped_ipv4_private() {
|
||||
let ip: std::net::IpAddr = "::ffff:192.168.1.1".parse().unwrap();
|
||||
assert!(is_private_ip(&ip));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn public_ip_8_8_8_8() {
|
||||
let ip: std::net::IpAddr = "8.8.8.8".parse().unwrap();
|
||||
assert!(!is_private_ip(&ip));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn public_ip_1_1_1_1() {
|
||||
let ip: std::net::IpAddr = "1.1.1.1".parse().unwrap();
|
||||
assert!(!is_private_ip(&ip));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn public_ip_v6_google() {
|
||||
let ip: std::net::IpAddr = "2001:4860:4860::8888".parse().unwrap();
|
||||
assert!(!is_private_ip(&ip));
|
||||
}
|
||||
|
||||
// ---- validate_provider_url ----
|
||||
|
||||
#[test]
|
||||
fn validate_url_https_valid() {
|
||||
assert!(validate_provider_url("https://api.openai.com").is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_url_https_with_path() {
|
||||
assert!(validate_provider_url("https://api.openai.com/v1").is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_url_https_with_port() {
|
||||
assert!(validate_provider_url("https://api.openai.com:443").is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_url_blocks_localhost() {
|
||||
assert!(validate_provider_url("https://localhost").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_url_blocks_127_0_0_1() {
|
||||
assert!(validate_provider_url("https://127.0.0.1").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_url_blocks_0_0_0_0() {
|
||||
assert!(validate_provider_url("https://0.0.0.0").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_url_blocks_169_254_169_254() {
|
||||
assert!(validate_provider_url("https://169.254.169.254").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_url_blocks_metadata_google_internal() {
|
||||
assert!(validate_provider_url("https://metadata.google.internal").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_url_blocks_private_ip_10() {
|
||||
assert!(validate_provider_url("https://10.0.0.1").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_url_blocks_private_ip_172_16() {
|
||||
assert!(validate_provider_url("https://172.16.0.1").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_url_blocks_private_ip_192_168() {
|
||||
assert!(validate_provider_url("https://192.168.0.1").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_url_blocks_numeric_host() {
|
||||
// decimal IP representation (e.g. 2130706433 = 127.0.0.1)
|
||||
assert!(validate_provider_url("https://2130706433").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_url_blocks_subdomain_localhost() {
|
||||
assert!(validate_provider_url("https://evil.localhost").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_url_blocks_subdomain_internal() {
|
||||
assert!(validate_provider_url("https://app.internal").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_url_blocks_subdomain_local() {
|
||||
assert!(validate_provider_url("https://myapp.local").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_url_blocks_ipv6_loopback() {
|
||||
assert!(validate_provider_url("https://[::1]").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_url_invalid_format() {
|
||||
assert!(validate_provider_url("not a url").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_url_missing_host() {
|
||||
assert!(validate_provider_url("https://").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_url_blocks_ftp_scheme() {
|
||||
assert!(validate_provider_url("ftp://api.openai.com").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_url_blocks_http_in_production() {
|
||||
// In CI / default env, ZCLAW_SAAS_DEV is not set, so http is blocked
|
||||
assert!(validate_provider_url("http://api.openai.com").is_err());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,27 +2,8 @@
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// 中转请求 (OpenAI 兼容格式)
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct RelayChatRequest {
|
||||
pub model: String,
|
||||
pub messages: Vec<ChatMessage>,
|
||||
#[serde(default)]
|
||||
pub temperature: Option<f64>,
|
||||
#[serde(default)]
|
||||
pub max_tokens: Option<u32>,
|
||||
#[serde(default)]
|
||||
pub stream: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChatMessage {
|
||||
pub role: String,
|
||||
pub content: serde_json::Value,
|
||||
}
|
||||
|
||||
/// 中转任务信息
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
#[derive(Debug, Clone, Serialize, utoipa::ToSchema)]
|
||||
pub struct RelayTaskInfo {
|
||||
pub id: String,
|
||||
pub account_id: String,
|
||||
@@ -42,18 +23,10 @@ pub struct RelayTaskInfo {
|
||||
}
|
||||
|
||||
/// 中转任务查询
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[derive(Debug, Deserialize, utoipa::ToSchema, utoipa::IntoParams)]
|
||||
pub struct RelayTaskQuery {
|
||||
pub status: Option<String>,
|
||||
pub page: Option<i64>,
|
||||
pub page_size: Option<i64>,
|
||||
}
|
||||
|
||||
/// Provider 速率限制状态
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RateLimitState {
|
||||
pub rpm: i64,
|
||||
pub tpm: i64,
|
||||
pub concurrent: usize,
|
||||
pub max_concurrent: usize,
|
||||
}
|
||||
|
||||
@@ -1,31 +1,36 @@
|
||||
//! 应用状态
|
||||
|
||||
use sqlx::SqlitePool;
|
||||
use sqlx::PgPool;
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
use tokio::sync::RwLock;
|
||||
use crate::config::SaaSConfig;
|
||||
use crate::crypto::FieldEncryption;
|
||||
|
||||
/// 全局应用状态,通过 Axum State 共享
|
||||
#[derive(Clone)]
|
||||
pub struct AppState {
|
||||
/// 数据库连接池
|
||||
pub db: SqlitePool,
|
||||
pub db: PgPool,
|
||||
/// 服务器配置 (可热更新)
|
||||
pub config: Arc<RwLock<SaaSConfig>>,
|
||||
/// JWT 密钥
|
||||
pub jwt_secret: secrecy::SecretString,
|
||||
/// 字段级加密器 (AES-256-GCM)
|
||||
pub field_encryption: Arc<FieldEncryption>,
|
||||
/// 速率限制: account_id → 请求时间戳列表
|
||||
pub rate_limit_entries: Arc<dashmap::DashMap<String, Vec<Instant>>>,
|
||||
}
|
||||
|
||||
impl AppState {
|
||||
pub fn new(db: SqlitePool, config: SaaSConfig) -> anyhow::Result<Self> {
|
||||
pub fn new(db: PgPool, config: SaaSConfig) -> anyhow::Result<Self> {
|
||||
let jwt_secret = config.jwt_secret()?;
|
||||
let field_encryption = Arc::new(FieldEncryption::new()?);
|
||||
Ok(Self {
|
||||
db,
|
||||
config: Arc::new(RwLock::new(config)),
|
||||
jwt_secret,
|
||||
field_encryption,
|
||||
rate_limit_entries: Arc::new(dashmap::DashMap::new()),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
//! 集成测试 (Phase 1 + Phase 2)
|
||||
//!
|
||||
//! 所有测试通过全局 Mutex 串行执行,避免共享数据库导致的 UNIQUE 约束冲突和数据竞争。
|
||||
|
||||
use axum::{
|
||||
body::Body,
|
||||
@@ -9,8 +11,16 @@ use tower::ServiceExt;
|
||||
|
||||
const MAX_BODY_SIZE: usize = 1024 * 1024; // 1MB
|
||||
|
||||
/// 全局 Mutex 用于序列化所有集成测试
|
||||
/// tokio::test 默认并行执行,但共享数据库要求串行访问
|
||||
static INTEGRATION_TEST_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());
|
||||
|
||||
async fn build_test_app() -> axum::Router {
|
||||
use zclaw_saas::{config::SaaSConfig, db::init_memory_db, state::AppState};
|
||||
let _ = tracing_subscriber::fmt()
|
||||
.with_env_filter("error")
|
||||
.with_test_writer()
|
||||
.try_init();
|
||||
use zclaw_saas::{config::SaaSConfig, db::init_test_db, state::AppState};
|
||||
use axum::extract::ConnectInfo;
|
||||
use std::net::SocketAddr;
|
||||
|
||||
@@ -18,7 +28,7 @@ async fn build_test_app() -> axum::Router {
|
||||
std::env::set_var("ZCLAW_SAAS_DEV", "true");
|
||||
std::env::set_var("ZCLAW_SAAS_JWT_SECRET", "test-secret-for-integration-tests-only");
|
||||
|
||||
let db = init_memory_db().await.unwrap();
|
||||
let db = init_test_db().await.unwrap();
|
||||
let mut config = SaaSConfig::default();
|
||||
config.auth.jwt_expiration_hours = 24;
|
||||
let state = AppState::new(db, config).expect("测试环境 AppState 初始化失败");
|
||||
@@ -85,6 +95,7 @@ fn auth_header(token: &str) -> String {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_register_and_login() {
|
||||
let _guard = INTEGRATION_TEST_LOCK.lock().unwrap();
|
||||
let app = build_test_app().await;
|
||||
let token = register_and_login(&app, "testuser", "test@example.com").await;
|
||||
assert!(!token.is_empty());
|
||||
@@ -92,6 +103,7 @@ async fn test_register_and_login() {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_register_duplicate_fails() {
|
||||
let _guard = INTEGRATION_TEST_LOCK.lock().unwrap();
|
||||
let app = build_test_app().await;
|
||||
|
||||
let body = json!({
|
||||
@@ -123,6 +135,7 @@ async fn test_register_duplicate_fails() {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_unauthorized_access() {
|
||||
let _guard = INTEGRATION_TEST_LOCK.lock().unwrap();
|
||||
let app = build_test_app().await;
|
||||
|
||||
let req = Request::builder()
|
||||
@@ -137,6 +150,7 @@ async fn test_unauthorized_access() {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_login_wrong_password() {
|
||||
let _guard = INTEGRATION_TEST_LOCK.lock().unwrap();
|
||||
let app = build_test_app().await;
|
||||
register_and_login(&app, "wrongpwd", "wrongpwd@example.com").await;
|
||||
|
||||
@@ -156,6 +170,7 @@ async fn test_login_wrong_password() {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_full_authenticated_flow() {
|
||||
let _guard = INTEGRATION_TEST_LOCK.lock().unwrap();
|
||||
let app = build_test_app().await;
|
||||
let token = register_and_login(&app, "fulltest", "full@example.com").await;
|
||||
|
||||
@@ -204,6 +219,7 @@ async fn test_full_authenticated_flow() {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_providers_crud() {
|
||||
let _guard = INTEGRATION_TEST_LOCK.lock().unwrap();
|
||||
let app = build_test_app().await;
|
||||
// 注册 super_admin 角色用户 (通过直接插入角色权限)
|
||||
let token = register_and_login(&app, "adminprov", "adminprov@example.com").await;
|
||||
@@ -239,6 +255,7 @@ async fn test_providers_crud() {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_models_list_and_usage() {
|
||||
let _guard = INTEGRATION_TEST_LOCK.lock().unwrap();
|
||||
let app = build_test_app().await;
|
||||
let token = register_and_login(&app, "modeluser", "modeluser@example.com").await;
|
||||
|
||||
@@ -274,6 +291,7 @@ async fn test_models_list_and_usage() {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_api_keys_lifecycle() {
|
||||
let _guard = INTEGRATION_TEST_LOCK.lock().unwrap();
|
||||
let app = build_test_app().await;
|
||||
let token = register_and_login(&app, "keyuser", "keyuser@example.com").await;
|
||||
|
||||
@@ -309,6 +327,7 @@ async fn test_api_keys_lifecycle() {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_relay_models_list() {
|
||||
let _guard = INTEGRATION_TEST_LOCK.lock().unwrap();
|
||||
let app = build_test_app().await;
|
||||
let token = register_and_login(&app, "relayuser", "relayuser@example.com").await;
|
||||
|
||||
@@ -329,6 +348,7 @@ async fn test_relay_models_list() {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_relay_chat_no_model() {
|
||||
let _guard = INTEGRATION_TEST_LOCK.lock().unwrap();
|
||||
let app = build_test_app().await;
|
||||
let token = register_and_login(&app, "relayfail", "relayfail@example.com").await;
|
||||
|
||||
@@ -351,6 +371,7 @@ async fn test_relay_chat_no_model() {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_relay_tasks_list() {
|
||||
let _guard = INTEGRATION_TEST_LOCK.lock().unwrap();
|
||||
let app = build_test_app().await;
|
||||
let token = register_and_login(&app, "relaytasks", "relaytasks@example.com").await;
|
||||
|
||||
@@ -369,6 +390,7 @@ async fn test_relay_tasks_list() {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_config_analysis_empty() {
|
||||
let _guard = INTEGRATION_TEST_LOCK.lock().unwrap();
|
||||
let app = build_test_app().await;
|
||||
let token = register_and_login(&app, "cfguser", "cfguser@example.com").await;
|
||||
|
||||
@@ -389,6 +411,7 @@ async fn test_config_analysis_empty() {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_config_seed_and_list() {
|
||||
let _guard = INTEGRATION_TEST_LOCK.lock().unwrap();
|
||||
let app = build_test_app().await;
|
||||
let token = register_and_login(&app, "cfgseed", "cfgseed@example.com").await;
|
||||
|
||||
@@ -423,6 +446,7 @@ async fn test_config_seed_and_list() {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_device_register_and_list() {
|
||||
let _guard = INTEGRATION_TEST_LOCK.lock().unwrap();
|
||||
let app = build_test_app().await;
|
||||
let token = register_and_login(&app, "devuser", "devuser@example.com").await;
|
||||
|
||||
@@ -463,6 +487,7 @@ async fn test_device_register_and_list() {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_device_upsert_on_reregister() {
|
||||
let _guard = INTEGRATION_TEST_LOCK.lock().unwrap();
|
||||
let app = build_test_app().await;
|
||||
let token = register_and_login(&app, "upsertdev", "upsertdev@example.com").await;
|
||||
|
||||
@@ -516,6 +541,7 @@ async fn test_device_upsert_on_reregister() {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_device_heartbeat() {
|
||||
let _guard = INTEGRATION_TEST_LOCK.lock().unwrap();
|
||||
let app = build_test_app().await;
|
||||
let token = register_and_login(&app, "hbuser", "hbuser@example.com").await;
|
||||
|
||||
@@ -563,6 +589,7 @@ async fn test_device_heartbeat() {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_device_register_missing_id() {
|
||||
let _guard = INTEGRATION_TEST_LOCK.lock().unwrap();
|
||||
let app = build_test_app().await;
|
||||
let token = register_and_login(&app, "baddev", "baddev@example.com").await;
|
||||
|
||||
@@ -578,11 +605,12 @@ async fn test_device_register_missing_id() {
|
||||
.unwrap();
|
||||
|
||||
let resp = app.oneshot(req).await.unwrap();
|
||||
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
|
||||
assert!(resp.status() == StatusCode::BAD_REQUEST || resp.status() == StatusCode::UNPROCESSABLE_ENTITY);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_change_password() {
|
||||
let _guard = INTEGRATION_TEST_LOCK.lock().unwrap();
|
||||
let app = build_test_app().await;
|
||||
let token = register_and_login(&app, "pwduser", "pwduser@example.com").await;
|
||||
|
||||
@@ -632,6 +660,7 @@ async fn test_change_password() {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_change_password_wrong_old() {
|
||||
let _guard = INTEGRATION_TEST_LOCK.lock().unwrap();
|
||||
let app = build_test_app().await;
|
||||
let token = register_and_login(&app, "wrongold", "wrongold@example.com").await;
|
||||
|
||||
@@ -655,6 +684,7 @@ async fn test_change_password_wrong_old() {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_e2e_full_lifecycle() {
|
||||
let _guard = INTEGRATION_TEST_LOCK.lock().unwrap();
|
||||
let app = build_test_app().await;
|
||||
|
||||
// 1. 注册
|
||||
@@ -771,6 +801,7 @@ async fn test_e2e_full_lifecycle() {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_config_sync() {
|
||||
let _guard = INTEGRATION_TEST_LOCK.lock().unwrap();
|
||||
let app = build_test_app().await;
|
||||
let token = register_and_login(&app, "cfgsync", "cfgsync@example.com").await;
|
||||
|
||||
@@ -808,6 +839,7 @@ async fn test_config_sync() {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_totp_setup_and_verify() {
|
||||
let _guard = INTEGRATION_TEST_LOCK.lock().unwrap();
|
||||
let app = build_test_app().await;
|
||||
let token = register_and_login(&app, "totpuser", "totp@example.com").await;
|
||||
|
||||
@@ -825,7 +857,7 @@ async fn test_totp_setup_and_verify() {
|
||||
let body: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap();
|
||||
assert!(body["otpauth_uri"].is_string());
|
||||
assert!(body["secret"].is_string());
|
||||
let secret = body["secret"].as_str().unwrap();
|
||||
let _secret = body["secret"].as_str().unwrap();
|
||||
|
||||
// 2. Verify with wrong code → 400
|
||||
let bad_verify = Request::builder()
|
||||
@@ -868,6 +900,7 @@ async fn test_totp_setup_and_verify() {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_totp_disabled_login_without_code() {
|
||||
let _guard = INTEGRATION_TEST_LOCK.lock().unwrap();
|
||||
let app = build_test_app().await;
|
||||
let token = register_and_login(&app, "nototp", "nototp@example.com").await;
|
||||
|
||||
@@ -913,6 +946,7 @@ async fn test_totp_disabled_login_without_code() {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_totp_disable_wrong_password() {
|
||||
let _guard = INTEGRATION_TEST_LOCK.lock().unwrap();
|
||||
let app = build_test_app().await;
|
||||
let token = register_and_login(&app, "totpwrong", "totpwrong@example.com").await;
|
||||
|
||||
@@ -932,6 +966,7 @@ async fn test_totp_disable_wrong_password() {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_config_diff() {
|
||||
let _guard = INTEGRATION_TEST_LOCK.lock().unwrap();
|
||||
let app = build_test_app().await;
|
||||
let token = register_and_login(&app, "diffuser", "diffuser@example.com").await;
|
||||
|
||||
@@ -959,6 +994,7 @@ async fn test_config_diff() {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_config_sync_push() {
|
||||
let _guard = INTEGRATION_TEST_LOCK.lock().unwrap();
|
||||
let app = build_test_app().await;
|
||||
let token = register_and_login(&app, "syncpush", "syncpush@example.com").await;
|
||||
|
||||
@@ -987,6 +1023,7 @@ async fn test_config_sync_push() {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_relay_retry_unauthorized() {
|
||||
let _guard = INTEGRATION_TEST_LOCK.lock().unwrap();
|
||||
let app = build_test_app().await;
|
||||
let token = register_and_login(&app, "retryuser", "retryuser@example.com").await;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user